smll improvements

This commit is contained in:
WillJeynes
2026-04-10 14:21:06 +01:00
parent 329b49944d
commit 7665292db4
+7 -4
View File
@@ -6,7 +6,7 @@ from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling,
import pandas as pd
# Load your CSV
df = pd.read_csv("../data/dataset-dev.csv")
df = pd.read_csv("../data/dataset.csv")
# Event columns
event_cols = ["Event1", "Event2", "Event3", "Event4", "Event5"]
@@ -89,7 +89,7 @@ lora_cfg = LoraConfig(
r=8, # rank
lora_alpha=32,
lora_dropout=0.05,
target_modules=["c_attn","c_proj"], # common GPT-2 modules
target_modules=["c_attn", "c_proj", "c_fc"],
fan_in_fan_out=True,
)
@@ -99,12 +99,15 @@ args_lora = TrainingArguments(
output_dir="./ft_lora",
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
num_train_epochs=20,
learning_rate=1e-4,
num_train_epochs=5,
learning_rate=2e-5,
eval_strategy="epoch",
save_strategy="epoch",
logging_steps=10,
optim="adamw_torch",
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False
)
trainer_lora = Trainer(