Add training scripts for distilled, flan. Add run service for flan
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# -- OURS --
|
# -- OURS --
|
||||||
results/
|
results/
|
||||||
roberta_classifier/
|
roberta_classifier/
|
||||||
|
roberta_distilled_classifier/
|
||||||
roberta_classifier*/
|
roberta_classifier*/
|
||||||
output*
|
output*
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,89 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
|
import torch
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
|
||||||
|
|
||||||
|
INT_TO_LABEL = {
|
||||||
|
0: "perfect",
|
||||||
|
1: "story",
|
||||||
|
2: "not specific",
|
||||||
|
}
|
||||||
|
|
||||||
|
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt(text: str) -> str:
|
||||||
|
return (
|
||||||
|
"Classify the following event into one of these categories: "
|
||||||
|
"perfect, story, not specific.\n\n"
|
||||||
|
f"Event: {text}\n\n"
|
||||||
|
"Category:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_generated_label(generated: str) -> int | None:
|
||||||
|
generated = generated.strip().lower()
|
||||||
|
for label_text, label_int in LABEL_TO_INT.items():
|
||||||
|
if label_text in generated:
|
||||||
|
return label_int
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class EvalRequest(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/evaluate")
|
||||||
|
def evaluate(req: EvalRequest):
|
||||||
|
prompt = format_prompt(req.answer)
|
||||||
|
|
||||||
|
inputs = tokenizer(
|
||||||
|
prompt,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
padding=True,
|
||||||
|
max_length=256,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Get the generated label
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Produce a confidence score
|
||||||
|
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)
|
||||||
|
logits_output = model(**inputs, decoder_input_ids=decoder_input_ids)
|
||||||
|
logits = logits_output.logits[:, 0, :]
|
||||||
|
|
||||||
|
# Decode the generated text label
|
||||||
|
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
predicted_int = parse_generated_label(generated_text)
|
||||||
|
|
||||||
|
# Extract probabilities
|
||||||
|
label_token_ids = {
|
||||||
|
label: tokenizer(label, add_special_tokens=False).input_ids[0]
|
||||||
|
for label in LABEL_TO_INT.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
label_logits = torch.tensor(
|
||||||
|
[logits[0, tid].item() for tid in label_token_ids.values()]
|
||||||
|
)
|
||||||
|
label_probs = torch.softmax(label_logits, dim=0).tolist()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"generated": generated_text,
|
||||||
|
"probabilities": [label_probs],
|
||||||
|
}
|
||||||
@@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
MODEL_PATH = "./roberta_classifier"
|
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
|
||||||
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
||||||
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
||||||
|
|||||||
@@ -0,0 +1,234 @@
|
|||||||
|
from sklearn.utils import compute_class_weight
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
import torch
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
|
from collections import Counter
|
||||||
|
import sys
|
||||||
|
import csv
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
NUM_CLASSES = 3
|
||||||
|
model_name = "distilbert/distilroberta-base" # Or MiniLM, or any other transformer model
|
||||||
|
|
||||||
|
LABEL_PRIORITY = [
|
||||||
|
("PERFECT", 0),
|
||||||
|
("STORY", 1),
|
||||||
|
("NSPECIFIC", 2),
|
||||||
|
("REWORDING", 1),
|
||||||
|
("TINCORRECT", -1),
|
||||||
|
("DUPLICATE", -1),
|
||||||
|
("", 0), # fallback to PERFECT
|
||||||
|
]
|
||||||
|
|
||||||
|
class WeightedTrainer(Trainer):
|
||||||
|
def __init__(self, *args, class_weights=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.class_weights = class_weights
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||||
|
labels = inputs.get("labels")
|
||||||
|
# print("DBG: Before forward")
|
||||||
|
outputs = model(**inputs)
|
||||||
|
# print("DBG: After forward")
|
||||||
|
logits = outputs.get("logits")
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss(
|
||||||
|
weight=self.class_weights.to(logits.device).to(logits.dtype)
|
||||||
|
)
|
||||||
|
# loss_fct = CrossEntropyLoss()
|
||||||
|
|
||||||
|
# print("DBG: Before loss")
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
# loss.backward()
|
||||||
|
# print("DBG: After loss")
|
||||||
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
|
def label_to_int(extra_info: str) -> int:
|
||||||
|
"""
|
||||||
|
Convert extra_info string to integer label using priority rules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if extra_info is None:
|
||||||
|
extra_info = ""
|
||||||
|
|
||||||
|
extra_info = extra_info.strip()
|
||||||
|
|
||||||
|
# Handle empty string explicitly
|
||||||
|
if extra_info == "":
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key == "":
|
||||||
|
return value
|
||||||
|
raise ValueError("Empty extra_info but no empty mapping defined")
|
||||||
|
|
||||||
|
# Split words (case-insensitive)
|
||||||
|
tokens = set(extra_info.upper().split())
|
||||||
|
|
||||||
|
# Priority matching
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key == "":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key in tokens:
|
||||||
|
return value
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown label content: '{extra_info}'")
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_from_csv(path):
|
||||||
|
texts = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
removed_rows = 0
|
||||||
|
|
||||||
|
with open(path, newline="", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
|
||||||
|
for i, row in enumerate(reader, start=1):
|
||||||
|
text = row["event"]
|
||||||
|
label_str = row["extra_info"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
label_int = label_to_int(label_str)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR converting label on line {i}: {label_str}")
|
||||||
|
print(e)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Skip rows marked for removal
|
||||||
|
if label_int == -1:
|
||||||
|
removed_rows += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
texts.append(text)
|
||||||
|
labels.append(label_int)
|
||||||
|
|
||||||
|
print(f"Loaded {len(texts)} samples (removed {removed_rows})")
|
||||||
|
|
||||||
|
return texts, labels
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def compute_metrics(eval_pred):
|
||||||
|
logits, labels = eval_pred
|
||||||
|
preds = logits.argmax(axis=1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"accuracy": accuracy_score(labels, preds),
|
||||||
|
"f1": f1_score(labels, preds, average="weighted", zero_division=0),
|
||||||
|
"precision": precision_score(labels, preds, average="weighted", zero_division=0),
|
||||||
|
"recall": recall_score(labels, preds, average="weighted", zero_division=0),
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
torch.multiprocessing.set_start_method('fork')
|
||||||
|
print("CUDA available:", torch.cuda.is_available())
|
||||||
|
print("CUDA device count:", torch.cuda.device_count())
|
||||||
|
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
|
||||||
|
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
num_labels=NUM_CLASSES
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Dataset size:", len(texts))
|
||||||
|
print("Label distribution:")
|
||||||
|
print(Counter(labels))
|
||||||
|
|
||||||
|
train_texts, val_texts, train_labels, val_labels = train_test_split(
|
||||||
|
texts,
|
||||||
|
labels,
|
||||||
|
test_size=0.2,
|
||||||
|
random_state=42,
|
||||||
|
stratify=labels
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class_weights = compute_class_weight(
|
||||||
|
class_weight="balanced",
|
||||||
|
classes=np.unique(train_labels),
|
||||||
|
y=train_labels
|
||||||
|
)
|
||||||
|
|
||||||
|
class_weights = torch.tensor(class_weights, dtype=torch.float)
|
||||||
|
print("Class weights:", class_weights)
|
||||||
|
|
||||||
|
train_encodings = tokenizer(
|
||||||
|
train_texts,
|
||||||
|
truncation=True,
|
||||||
|
padding=True,
|
||||||
|
max_length=256
|
||||||
|
)
|
||||||
|
|
||||||
|
val_encodings = tokenizer(
|
||||||
|
val_texts,
|
||||||
|
truncation=True,
|
||||||
|
padding=True,
|
||||||
|
max_length=256
|
||||||
|
)
|
||||||
|
|
||||||
|
class TextDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, encodings, labels):
|
||||||
|
self.encodings = encodings
|
||||||
|
self.labels = labels
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# print(f"DBG: Loading item {idx}")
|
||||||
|
item = {
|
||||||
|
key: torch.tensor(val[idx])
|
||||||
|
for key, val in self.encodings.items()
|
||||||
|
}
|
||||||
|
item["labels"] = torch.tensor(self.labels[idx])
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir="./results",
|
||||||
|
learning_rate=2e-5,
|
||||||
|
per_device_train_batch_size=32,
|
||||||
|
# gradient_accumulation_steps=2,
|
||||||
|
num_train_epochs=15,
|
||||||
|
weight_decay=0.01,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="epoch",
|
||||||
|
metric_for_best_model="f1",
|
||||||
|
greater_is_better=True,
|
||||||
|
dataloader_num_workers=4,
|
||||||
|
dataloader_pin_memory=True,
|
||||||
|
# warmup_steps=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = TextDataset(train_encodings, train_labels)
|
||||||
|
|
||||||
|
val_dataset = TextDataset(val_encodings, val_labels)
|
||||||
|
|
||||||
|
trainer = WeightedTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
compute_metrics=compute_metrics,
|
||||||
|
class_weights=class_weights
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
metrics = trainer.evaluate()
|
||||||
|
print("Final evaluation metrics:")
|
||||||
|
for k, v in metrics.items():
|
||||||
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
|
trainer.save_model("./roberta_distilled_classifier")
|
||||||
|
tokenizer.save_pretrained("./roberta_distilled_classifier")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,227 @@
|
|||||||
|
from sklearn.utils import compute_class_weight
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
|
||||||
|
import torch
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
|
from collections import Counter
|
||||||
|
import sys
|
||||||
|
import csv
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
NUM_CLASSES = 3
|
||||||
|
model_name = "google/flan-t5-base"
|
||||||
|
|
||||||
|
INT_TO_LABEL = {
|
||||||
|
0: "perfect",
|
||||||
|
1: "story",
|
||||||
|
2: "not specific",
|
||||||
|
}
|
||||||
|
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
||||||
|
|
||||||
|
LABEL_PRIORITY = [
|
||||||
|
("PERFECT", 0),
|
||||||
|
("STORY", 1),
|
||||||
|
("NSPECIFIC", 2),
|
||||||
|
("REWORDING", 1),
|
||||||
|
("TINCORRECT", -1),
|
||||||
|
("DUPLICATE", -1),
|
||||||
|
("", 0),
|
||||||
|
]
|
||||||
|
|
||||||
|
def label_to_int(extra_info: str) -> int:
|
||||||
|
if extra_info is None:
|
||||||
|
extra_info = ""
|
||||||
|
extra_info = extra_info.strip()
|
||||||
|
if extra_info == "":
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key == "":
|
||||||
|
return value
|
||||||
|
raise ValueError("Empty extra_info but no empty mapping defined")
|
||||||
|
tokens = set(extra_info.upper().split())
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key == "" :
|
||||||
|
continue
|
||||||
|
if key in tokens:
|
||||||
|
return value
|
||||||
|
raise ValueError(f"Unknown label content: '{extra_info}'")
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_from_csv(path):
|
||||||
|
texts = []
|
||||||
|
labels = []
|
||||||
|
removed_rows = 0
|
||||||
|
with open(path, newline="", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for i, row in enumerate(reader, start=1):
|
||||||
|
text = row["event"]
|
||||||
|
label_str = row["extra_info"]
|
||||||
|
try:
|
||||||
|
label_int = label_to_int(label_str)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR converting label on line {i}: {label_str}")
|
||||||
|
print(e)
|
||||||
|
sys.exit(1)
|
||||||
|
if label_int == -1:
|
||||||
|
removed_rows += 1
|
||||||
|
continue
|
||||||
|
texts.append(text)
|
||||||
|
labels.append(label_int)
|
||||||
|
print(f"Loaded {len(texts)} samples (removed {removed_rows})")
|
||||||
|
return texts, labels
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt(text: str) -> str:
|
||||||
|
return (
|
||||||
|
"Classify the following event into one of these categories: "
|
||||||
|
"perfect, story, not specific.\n\n"
|
||||||
|
f"Event: {text}\n\n"
|
||||||
|
"Category:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_generated_label(generated: str) -> int:
|
||||||
|
generated = generated.strip().lower()
|
||||||
|
for label_text, label_int in LABEL_TO_INT.items():
|
||||||
|
if label_text in generated:
|
||||||
|
return label_int
|
||||||
|
print("invlid label:" + generated)
|
||||||
|
return -1 # unknown / unparseable output
|
||||||
|
|
||||||
|
|
||||||
|
class GenerativeTextDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, texts, labels, tokenizer, max_input_length=256, max_target_length=8):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_input_length = max_input_length
|
||||||
|
self.max_target_length = max_target_length
|
||||||
|
|
||||||
|
self.inputs = [format_prompt(t) for t in texts]
|
||||||
|
# Convert integer labels to their text equivalents for the target sequence
|
||||||
|
self.targets = [INT_TO_LABEL[l] for l in labels]
|
||||||
|
self.int_labels = labels # keep originals for metric computation
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.inputs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
model_inputs = self.tokenizer(
|
||||||
|
self.inputs[idx],
|
||||||
|
max_length=self.max_input_length,
|
||||||
|
truncation=True,
|
||||||
|
padding=False,
|
||||||
|
)
|
||||||
|
target_encoding = self.tokenizer(
|
||||||
|
self.targets[idx],
|
||||||
|
max_length=self.max_target_length,
|
||||||
|
truncation=True,
|
||||||
|
padding=False,
|
||||||
|
)
|
||||||
|
# Seq2Seq convention: labels use -100 to ignore padding tokens in loss
|
||||||
|
labels = target_encoding["input_ids"]
|
||||||
|
labels = [token if token != self.tokenizer.pad_token_id else -100 for token in labels]
|
||||||
|
|
||||||
|
model_inputs["labels"] = labels
|
||||||
|
return {k: torch.tensor(v) for k, v in model_inputs.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_metrics_generative(eval_pred, tokenizer):
|
||||||
|
predictions, label_ids = eval_pred
|
||||||
|
|
||||||
|
# Decode predictions
|
||||||
|
# Replace -100 in labels before decoding
|
||||||
|
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
|
||||||
|
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Map decoded text back to integer labels
|
||||||
|
pred_ints = [parse_generated_label(p) for p in decoded_preds]
|
||||||
|
true_ints = [parse_generated_label(l) for l in decoded_labels]
|
||||||
|
|
||||||
|
# Filter out any rows where parsing failed
|
||||||
|
valid = [(p, t) for p, t in zip(pred_ints, true_ints) if t != -1]
|
||||||
|
if not valid:
|
||||||
|
return {"accuracy": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
|
||||||
|
|
||||||
|
preds_filtered, true_filtered = zip(*valid)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"accuracy": accuracy_score(true_filtered, preds_filtered),
|
||||||
|
"f1": f1_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
|
||||||
|
"precision": precision_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
|
||||||
|
"recall": recall_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
torch.multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
print("CUDA available:", torch.cuda.is_available())
|
||||||
|
print("CUDA device count:", torch.cuda.device_count())
|
||||||
|
|
||||||
|
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
||||||
|
|
||||||
|
print("Dataset size:", len(texts))
|
||||||
|
print("Label distribution:", Counter(labels))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||||
|
|
||||||
|
train_texts, val_texts, train_labels, val_labels = train_test_split(
|
||||||
|
texts, labels,
|
||||||
|
test_size=0.2,
|
||||||
|
random_state=42,
|
||||||
|
stratify=labels
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = GenerativeTextDataset(train_texts, train_labels, tokenizer)
|
||||||
|
val_dataset = GenerativeTextDataset(val_texts, val_labels, tokenizer)
|
||||||
|
|
||||||
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
model=model,
|
||||||
|
padding=True,
|
||||||
|
label_pad_token_id=-100,
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args = Seq2SeqTrainingArguments(
|
||||||
|
output_dir="./results",
|
||||||
|
learning_rate=5e-5,
|
||||||
|
per_device_train_batch_size=16,
|
||||||
|
per_device_eval_batch_size=16,
|
||||||
|
num_train_epochs=10,
|
||||||
|
weight_decay=0.01,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="epoch",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="f1",
|
||||||
|
greater_is_better=True,
|
||||||
|
predict_with_generate=True,
|
||||||
|
generation_max_length=8,
|
||||||
|
dataloader_num_workers=0,
|
||||||
|
dataloader_pin_memory=False,
|
||||||
|
fp16=False,
|
||||||
|
max_grad_norm=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Seq2SeqTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
compute_metrics=lambda ep: compute_metrics_generative(ep, tokenizer),
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
metrics = trainer.evaluate()
|
||||||
|
print("\nFinal evaluation metrics:")
|
||||||
|
for k, v in metrics.items():
|
||||||
|
print(f" {k}: {v}")
|
||||||
|
|
||||||
|
trainer.save_model("./flan_classifier")
|
||||||
|
tokenizer.save_pretrained("./flan_classifier")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -17,7 +17,7 @@ const AGENT_NAME = process.env.AGENT ?? "agent";
|
|||||||
*/
|
*/
|
||||||
const MODE = process.env.MODE ?? "claim";
|
const MODE = process.env.MODE ?? "claim";
|
||||||
|
|
||||||
const MAX_CONCURRENCY = 5;
|
const MAX_CONCURRENCY = 1;
|
||||||
|
|
||||||
const client = new Client({ apiUrl: API_URL });
|
const client = new Client({ apiUrl: API_URL });
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import streamlit as st
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
# THRESH = 0.4
|
|
||||||
THRESH = 0.6
|
THRESH = 0.6
|
||||||
|
|
||||||
def page_title() -> str:
|
def page_title() -> str:
|
||||||
@@ -61,6 +60,10 @@ def render():
|
|||||||
return
|
return
|
||||||
|
|
||||||
for file_path in jsonl_files:
|
for file_path in jsonl_files:
|
||||||
|
thresh = THRESH
|
||||||
|
if ("flan" in file_path.name):
|
||||||
|
thresh = 0.94
|
||||||
|
|
||||||
st.subheader(f"File: {file_path.name}")
|
st.subheader(f"File: {file_path.name}")
|
||||||
|
|
||||||
confidence_counter = Counter()
|
confidence_counter = Counter()
|
||||||
@@ -86,15 +89,15 @@ def render():
|
|||||||
dup_counter += 1
|
dup_counter += 1
|
||||||
elif "ranked" not in event:
|
elif "ranked" not in event:
|
||||||
"ignore for now"
|
"ignore for now"
|
||||||
elif score > THRESH and extra_lower == "perfect":
|
elif score > thresh and extra_lower == "perfect":
|
||||||
confidence_counter["Correct-PERFECT"] += 1
|
confidence_counter["Correct-PERFECT"] += 1
|
||||||
elif score > THRESH and extra_lower == "":
|
elif score > thresh and extra_lower == "":
|
||||||
confidence_counter["Correct-FINE"] += 1
|
confidence_counter["Correct-FINE"] += 1
|
||||||
elif score > THRESH and extra_lower != "perfect" and extra_lower != "":
|
elif score > thresh and extra_lower != "perfect" and extra_lower != "":
|
||||||
confidence_counter["Over-confident"] += 1
|
confidence_counter["Over-confident"] += 1
|
||||||
wrong_counter[extra_lower] += 1
|
wrong_counter[extra_lower] += 1
|
||||||
overconfident_docs.append(doc_id)
|
overconfident_docs.append(doc_id)
|
||||||
elif score < THRESH and (extra_lower == "perfect" or extra_lower == ""):
|
elif score < thresh and (extra_lower == "perfect" or extra_lower == ""):
|
||||||
confidence_counter["Under-confident"] += 1
|
confidence_counter["Under-confident"] += 1
|
||||||
underconfident_docs.append(doc_id)
|
underconfident_docs.append(doc_id)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user