smll improvements
This commit is contained in:
+7
-4
@@ -6,7 +6,7 @@ from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling,
|
||||
import pandas as pd
|
||||
|
||||
# Load your CSV
|
||||
df = pd.read_csv("../data/dataset-dev.csv")
|
||||
df = pd.read_csv("../data/dataset.csv")
|
||||
|
||||
# Event columns
|
||||
event_cols = ["Event1", "Event2", "Event3", "Event4", "Event5"]
|
||||
@@ -89,7 +89,7 @@ lora_cfg = LoraConfig(
|
||||
r=8, # rank
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
target_modules=["c_attn","c_proj"], # common GPT-2 modules
|
||||
target_modules=["c_attn", "c_proj", "c_fc"],
|
||||
fan_in_fan_out=True,
|
||||
)
|
||||
|
||||
@@ -99,12 +99,15 @@ args_lora = TrainingArguments(
|
||||
output_dir="./ft_lora",
|
||||
per_device_train_batch_size=2,
|
||||
per_device_eval_batch_size=2,
|
||||
num_train_epochs=20,
|
||||
learning_rate=1e-4,
|
||||
num_train_epochs=5,
|
||||
learning_rate=2e-5,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
logging_steps=10,
|
||||
optim="adamw_torch",
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="eval_loss",
|
||||
greater_is_better=False
|
||||
)
|
||||
|
||||
trainer_lora = Trainer(
|
||||
|
||||
Reference in New Issue
Block a user