small improvements
This commit is contained in:
+7
-4
@@ -6,7 +6,7 @@ from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling,
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
# Load your CSV
|
# Load your CSV
|
||||||
df = pd.read_csv("../data/dataset-dev.csv")
|
df = pd.read_csv("../data/dataset.csv")
|
||||||
|
|
||||||
# Event columns
|
# Event columns
|
||||||
event_cols = ["Event1", "Event2", "Event3", "Event4", "Event5"]
|
event_cols = ["Event1", "Event2", "Event3", "Event4", "Event5"]
|
||||||
@@ -89,7 +89,7 @@ lora_cfg = LoraConfig(
|
|||||||
r=8, # rank
|
r=8, # rank
|
||||||
lora_alpha=32,
|
lora_alpha=32,
|
||||||
lora_dropout=0.05,
|
lora_dropout=0.05,
|
||||||
target_modules=["c_attn","c_proj"], # common GPT-2 modules
|
target_modules=["c_attn", "c_proj", "c_fc"],
|
||||||
fan_in_fan_out=True,
|
fan_in_fan_out=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -99,12 +99,15 @@ args_lora = TrainingArguments(
|
|||||||
output_dir="./ft_lora",
|
output_dir="./ft_lora",
|
||||||
per_device_train_batch_size=2,
|
per_device_train_batch_size=2,
|
||||||
per_device_eval_batch_size=2,
|
per_device_eval_batch_size=2,
|
||||||
num_train_epochs=20,
|
num_train_epochs=5,
|
||||||
learning_rate=1e-4,
|
learning_rate=2e-5,
|
||||||
eval_strategy="epoch",
|
eval_strategy="epoch",
|
||||||
save_strategy="epoch",
|
save_strategy="epoch",
|
||||||
logging_steps=10,
|
logging_steps=10,
|
||||||
optim="adamw_torch",
|
optim="adamw_torch",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="eval_loss",
|
||||||
|
greater_is_better=False
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_lora = Trainer(
|
trainer_lora = Trainer(
|
||||||
|
|||||||
Reference in New Issue
Block a user