From e368c5057717fcf5dbce9e6fa79fe53384be21a7 Mon Sep 17 00:00:00 2001 From: William Jeynes Date: Mon, 23 Mar 2026 22:43:59 +0000 Subject: [PATCH] Add training scripts for distilled, flan. Add run service for flan --- supporting/RAGAS_Service/.gitignore | 1 + supporting/RAGAS_Service/flan_service.py | 89 ++++++++ supporting/RAGAS_Service/roberta_service.py | 2 +- supporting/RAGAS_Service/train_distilled.py | 234 ++++++++++++++++++++ supporting/RAGAS_Service/train_flan.py | 227 +++++++++++++++++++ supporting/Wrapper/run.ts | 2 +- supporting/scorer/views/stats.py | 13 +- 7 files changed, 561 insertions(+), 7 deletions(-) create mode 100644 supporting/RAGAS_Service/flan_service.py create mode 100644 supporting/RAGAS_Service/train_distilled.py create mode 100644 supporting/RAGAS_Service/train_flan.py diff --git a/supporting/RAGAS_Service/.gitignore b/supporting/RAGAS_Service/.gitignore index 0098cc5..be4ae57 100644 --- a/supporting/RAGAS_Service/.gitignore +++ b/supporting/RAGAS_Service/.gitignore @@ -1,6 +1,7 @@ # -- OURS -- results/ roberta_classifier/ +roberta_distilled_classifier/ roberta_classifier*/ output* diff --git a/supporting/RAGAS_Service/flan_service.py b/supporting/RAGAS_Service/flan_service.py new file mode 100644 index 0000000..152d792 --- /dev/null +++ b/supporting/RAGAS_Service/flan_service.py @@ -0,0 +1,89 @@ +from pydantic import BaseModel +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM +import torch +from fastapi import FastAPI + +app = FastAPI() + +MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan" + +INT_TO_LABEL = { + 0: "perfect", + 1: "story", + 2: "not specific", +} + +LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()} + +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) +model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH) +model.eval() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model.to(device) + + +def format_prompt(text: str) -> str: + return ( + "Classify the following event into one of these categories: " + "perfect, story, not specific.\n\n" + f"Event: {text}\n\n" + "Category:" + ) + + +def parse_generated_label(generated: str) -> int | None: + generated = generated.strip().lower() + for label_text, label_int in LABEL_TO_INT.items(): + if label_text in generated: + return label_int + return None + + +class EvalRequest(BaseModel): + answer: str + + +@app.post("/evaluate") +def evaluate(req: EvalRequest): + prompt = format_prompt(req.answer) + + inputs = tokenizer( + prompt, + return_tensors="pt", + truncation=True, + padding=True, + max_length=256, + ).to(device) + + with torch.no_grad(): + # Get the generated label + outputs = model.generate( + **inputs, + max_new_tokens=8, + ) + + # Produce a confidence score + decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device) + logits_output = model(**inputs, decoder_input_ids=decoder_input_ids) + logits = logits_output.logits[:, 0, :] + + # Decode the generated text label + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + predicted_int = parse_generated_label(generated_text) + + # Extract probabilities + label_token_ids = { + label: tokenizer(label, add_special_tokens=False).input_ids[0] + for label in LABEL_TO_INT.keys() + } + + label_logits = torch.tensor( + [logits[0, tid].item() for tid in label_token_ids.values()] + ) + label_probs = torch.softmax(label_logits, dim=0).tolist() + + return { + "generated": generated_text, + "probabilities": [label_probs], + } \ No newline at end of file diff --git a/supporting/RAGAS_Service/roberta_service.py b/supporting/RAGAS_Service/roberta_service.py index a22b87f..d3d99dc 100644 --- a/supporting/RAGAS_Service/roberta_service.py +++ b/supporting/RAGAS_Service/roberta_service.py @@ -5,7 +5,7 @@ from fastapi import FastAPI app = FastAPI() -MODEL_PATH = "./roberta_classifier" +MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis" tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH) model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH) diff --git a/supporting/RAGAS_Service/train_distilled.py b/supporting/RAGAS_Service/train_distilled.py new file mode 100644 index 0000000..8e660d7 --- /dev/null +++ b/supporting/RAGAS_Service/train_distilled.py @@ -0,0 +1,234 @@ +from sklearn.utils import compute_class_weight +from torch.nn import CrossEntropyLoss +from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification +import torch +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score +from collections import Counter +import sys +import csv +import numpy as np + +NUM_CLASSES = 3 +model_name = "distilbert/distilroberta-base" # Or MiniLM, or any other transformer model + +LABEL_PRIORITY = [ + ("PERFECT", 0), + ("STORY", 1), + ("NSPECIFIC", 2), + ("REWORDING", 1), + ("TINCORRECT", -1), + ("DUPLICATE", -1), + ("", 0), # fallback to PERFECT +] + +class WeightedTrainer(Trainer): + def __init__(self, *args, class_weights=None, **kwargs): + super().__init__(*args, **kwargs) + self.class_weights = class_weights + + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + labels = inputs.get("labels") + # print("DBG: Before forward") + outputs = model(**inputs) + # print("DBG: After forward") + logits = outputs.get("logits") + + loss_fct = CrossEntropyLoss( + weight=self.class_weights.to(logits.device).to(logits.dtype) + ) + # loss_fct = CrossEntropyLoss() + + # print("DBG: Before loss") + loss = loss_fct(logits, labels) + # loss.backward() + # print("DBG: After loss") + return (loss, outputs) if return_outputs else loss + +def label_to_int(extra_info: str) -> int: + """ + Convert extra_info string to integer label using priority rules. + """ + + if extra_info is None: + extra_info = "" + + extra_info = extra_info.strip() + + # Handle empty string explicitly + if extra_info == "": + for key, value in LABEL_PRIORITY: + if key == "": + return value + raise ValueError("Empty extra_info but no empty mapping defined") + + # Split words (case-insensitive) + tokens = set(extra_info.upper().split()) + + # Priority matching + for key, value in LABEL_PRIORITY: + if key == "": + continue + + if key in tokens: + return value + + raise ValueError(f"Unknown label content: '{extra_info}'") + + +def load_dataset_from_csv(path): + texts = [] + labels = [] + + removed_rows = 0 + + with open(path, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + + for i, row in enumerate(reader, start=1): + text = row["event"] + label_str = row["extra_info"] + + try: + label_int = label_to_int(label_str) + except Exception as e: + print(f"ERROR converting label on line {i}: {label_str}") + print(e) + sys.exit(1) + + # Skip rows marked for removal + if label_int == -1: + removed_rows += 1 + continue + + texts.append(text) + labels.append(label_int) + + print(f"Loaded {len(texts)} samples (removed {removed_rows})") + + return texts, labels + + + +def compute_metrics(eval_pred): + logits, labels = eval_pred + preds = logits.argmax(axis=1) + + return { + "accuracy": accuracy_score(labels, preds), + "f1": f1_score(labels, preds, average="weighted", zero_division=0), + "precision": precision_score(labels, preds, average="weighted", zero_division=0), + "recall": recall_score(labels, preds, average="weighted", zero_division=0), + } + +def main(): + torch.multiprocessing.set_start_method('fork') + print("CUDA available:", torch.cuda.is_available()) + print("CUDA device count:", torch.cuda.device_count()) + print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU") + texts, labels = load_dataset_from_csv("../../data/classify.csv") + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + model = AutoModelForSequenceClassification.from_pretrained( + model_name, + num_labels=NUM_CLASSES + ) + + print("Dataset size:", len(texts)) + print("Label distribution:") + print(Counter(labels)) + + train_texts, val_texts, train_labels, val_labels = train_test_split( + texts, + labels, + test_size=0.2, + random_state=42, + stratify=labels + ) + + + class_weights = compute_class_weight( + class_weight="balanced", + classes=np.unique(train_labels), + y=train_labels + ) + + class_weights = torch.tensor(class_weights, dtype=torch.float) + print("Class weights:", class_weights) + + train_encodings = tokenizer( + train_texts, + truncation=True, + padding=True, + max_length=256 + ) + + val_encodings = tokenizer( + val_texts, + truncation=True, + padding=True, + max_length=256 + ) + + class TextDataset(torch.utils.data.Dataset): + def __init__(self, encodings, labels): + self.encodings = encodings + self.labels = labels + + def __getitem__(self, idx): + # print(f"DBG: Loading item {idx}") + item = { + key: torch.tensor(val[idx]) + for key, val in self.encodings.items() + } + item["labels"] = torch.tensor(self.labels[idx]) + return item + + def __len__(self): + return len(self.labels) + + training_args = TrainingArguments( + output_dir="./results", + learning_rate=2e-5, + per_device_train_batch_size=32, + # gradient_accumulation_steps=2, + num_train_epochs=15, + weight_decay=0.01, + load_best_model_at_end=True, + eval_strategy="epoch", + save_strategy="epoch", + metric_for_best_model="f1", + greater_is_better=True, + dataloader_num_workers=4, + dataloader_pin_memory=True, + # warmup_steps=100, + ) + + train_dataset = TextDataset(train_encodings, train_labels) + + val_dataset = TextDataset(val_encodings, val_labels) + + trainer = WeightedTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + compute_metrics=compute_metrics, + class_weights=class_weights + ) + + trainer.train() + + metrics = trainer.evaluate() + print("Final evaluation metrics:") + for k, v in metrics.items(): + print(f"{k}: {v}") + + trainer.save_model("./roberta_distilled_classifier") + tokenizer.save_pretrained("./roberta_distilled_classifier") + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/supporting/RAGAS_Service/train_flan.py b/supporting/RAGAS_Service/train_flan.py new file mode 100644 index 0000000..3d830c1 --- /dev/null +++ b/supporting/RAGAS_Service/train_flan.py @@ -0,0 +1,227 @@ +from sklearn.utils import compute_class_weight +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq +import torch +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score +from collections import Counter +import sys +import csv +import numpy as np + +NUM_CLASSES = 3 +model_name = "google/flan-t5-base" + +INT_TO_LABEL = { + 0: "perfect", + 1: "story", + 2: "not specific", +} +LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()} + +LABEL_PRIORITY = [ + ("PERFECT", 0), + ("STORY", 1), + ("NSPECIFIC", 2), + ("REWORDING", 1), + ("TINCORRECT", -1), + ("DUPLICATE", -1), + ("", 0), +] + +def label_to_int(extra_info: str) -> int: + if extra_info is None: + extra_info = "" + extra_info = extra_info.strip() + if extra_info == "": + for key, value in LABEL_PRIORITY: + if key == "": + return value + raise ValueError("Empty extra_info but no empty mapping defined") + tokens = set(extra_info.upper().split()) + for key, value in LABEL_PRIORITY: + if key == "" : + continue + if key in tokens: + return value + raise ValueError(f"Unknown label content: '{extra_info}'") + + +def load_dataset_from_csv(path): + texts = [] + labels = [] + removed_rows = 0 + with open(path, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for i, row in enumerate(reader, start=1): + text = row["event"] + label_str = row["extra_info"] + try: + label_int = label_to_int(label_str) + except Exception as e: + print(f"ERROR converting label on line {i}: {label_str}") + print(e) + sys.exit(1) + if label_int == -1: + removed_rows += 1 + continue + texts.append(text) + labels.append(label_int) + print(f"Loaded {len(texts)} samples (removed {removed_rows})") + return texts, labels + + +def format_prompt(text: str) -> str: + return ( + "Classify the following event into one of these categories: " + "perfect, story, not specific.\n\n" + f"Event: {text}\n\n" + "Category:" + ) + + +def parse_generated_label(generated: str) -> int: + generated = generated.strip().lower() + for label_text, label_int in LABEL_TO_INT.items(): + if label_text in generated: + return label_int + print("invlid label:" + generated) + return -1 # unknown / unparseable output + + +class GenerativeTextDataset(torch.utils.data.Dataset): + def __init__(self, texts, labels, tokenizer, max_input_length=256, max_target_length=8): + self.tokenizer = tokenizer + self.max_input_length = max_input_length + self.max_target_length = max_target_length + + self.inputs = [format_prompt(t) for t in texts] + # Convert integer labels to their text equivalents for the target sequence + self.targets = [INT_TO_LABEL[l] for l in labels] + self.int_labels = labels # keep originals for metric computation + + def __len__(self): + return len(self.inputs) + + def __getitem__(self, idx): + model_inputs = self.tokenizer( + self.inputs[idx], + max_length=self.max_input_length, + truncation=True, + padding=False, + ) + target_encoding = self.tokenizer( + self.targets[idx], + max_length=self.max_target_length, + truncation=True, + padding=False, + ) + # Seq2Seq convention: labels use -100 to ignore padding tokens in loss + labels = target_encoding["input_ids"] + labels = [token if token != self.tokenizer.pad_token_id else -100 for token in labels] + + model_inputs["labels"] = labels + return {k: torch.tensor(v) for k, v in model_inputs.items()} + + +def compute_metrics_generative(eval_pred, tokenizer): + predictions, label_ids = eval_pred + + # Decode predictions + # Replace -100 in labels before decoding + label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id) + + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True) + + # Map decoded text back to integer labels + pred_ints = [parse_generated_label(p) for p in decoded_preds] + true_ints = [parse_generated_label(l) for l in decoded_labels] + + # Filter out any rows where parsing failed + valid = [(p, t) for p, t in zip(pred_ints, true_ints) if t != -1] + if not valid: + return {"accuracy": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0} + + preds_filtered, true_filtered = zip(*valid) + + return { + "accuracy": accuracy_score(true_filtered, preds_filtered), + "f1": f1_score(true_filtered, preds_filtered, average="weighted", zero_division=0), + "precision": precision_score(true_filtered, preds_filtered, average="weighted", zero_division=0), + "recall": recall_score(true_filtered, preds_filtered, average="weighted", zero_division=0), + } + + +def main(): + torch.multiprocessing.set_start_method('spawn', force=True) + print("CUDA available:", torch.cuda.is_available()) + print("CUDA device count:", torch.cuda.device_count()) + + texts, labels = load_dataset_from_csv("../../data/classify.csv") + + print("Dataset size:", len(texts)) + print("Label distribution:", Counter(labels)) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + + train_texts, val_texts, train_labels, val_labels = train_test_split( + texts, labels, + test_size=0.2, + random_state=42, + stratify=labels + ) + + train_dataset = GenerativeTextDataset(train_texts, train_labels, tokenizer) + val_dataset = GenerativeTextDataset(val_texts, val_labels, tokenizer) + + data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + model=model, + padding=True, + label_pad_token_id=-100, + ) + + training_args = Seq2SeqTrainingArguments( + output_dir="./results", + learning_rate=5e-5, + per_device_train_batch_size=16, + per_device_eval_batch_size=16, + num_train_epochs=10, + weight_decay=0.01, + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="f1", + greater_is_better=True, + predict_with_generate=True, + generation_max_length=8, + dataloader_num_workers=0, + dataloader_pin_memory=False, + fp16=False, + max_grad_norm=1.0, + ) + + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + processing_class=tokenizer, + data_collator=data_collator, + compute_metrics=lambda ep: compute_metrics_generative(ep, tokenizer), + ) + + trainer.train() + + metrics = trainer.evaluate() + print("\nFinal evaluation metrics:") + for k, v in metrics.items(): + print(f" {k}: {v}") + + trainer.save_model("./flan_classifier") + tokenizer.save_pretrained("./flan_classifier") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/supporting/Wrapper/run.ts b/supporting/Wrapper/run.ts index 7e362bb..277bfad 100644 --- a/supporting/Wrapper/run.ts +++ b/supporting/Wrapper/run.ts @@ -17,7 +17,7 @@ const AGENT_NAME = process.env.AGENT ?? "agent"; */ const MODE = process.env.MODE ?? "claim"; -const MAX_CONCURRENCY = 5; +const MAX_CONCURRENCY = 1; const client = new Client({ apiUrl: API_URL }); diff --git a/supporting/scorer/views/stats.py b/supporting/scorer/views/stats.py index 4da21cc..5b9c190 100644 --- a/supporting/scorer/views/stats.py +++ b/supporting/scorer/views/stats.py @@ -5,7 +5,6 @@ import streamlit as st import pandas as pd import matplotlib.pyplot as plt -# THRESH = 0.4 THRESH = 0.6 def page_title() -> str: @@ -61,6 +60,10 @@ def render(): return for file_path in jsonl_files: + thresh = THRESH + if ("flan" in file_path.name): + thresh = 0.94 + st.subheader(f"File: {file_path.name}") confidence_counter = Counter() @@ -86,15 +89,15 @@ def render(): dup_counter += 1 elif "ranked" not in event: "ignore for now" - elif score > THRESH and extra_lower == "perfect": + elif score > thresh and extra_lower == "perfect": confidence_counter["Correct-PERFECT"] += 1 - elif score > THRESH and extra_lower == "": + elif score > thresh and extra_lower == "": confidence_counter["Correct-FINE"] += 1 - elif score > THRESH and extra_lower != "perfect" and extra_lower != "": + elif score > thresh and extra_lower != "perfect" and extra_lower != "": confidence_counter["Over-confident"] += 1 wrong_counter[extra_lower] += 1 overconfident_docs.append(doc_id) - elif score < THRESH and (extra_lower == "perfect" or extra_lower == ""): + elif score < thresh and (extra_lower == "perfect" or extra_lower == ""): confidence_counter["Under-confident"] += 1 underconfident_docs.append(doc_id) else: