From 8f939d54c44c0f14258a7208a25ffbd10474a695 Mon Sep 17 00:00:00 2001 From: William Jeynes Date: Tue, 24 Mar 2026 19:07:24 +0000 Subject: [PATCH] Implement ensemble into final model structure --- agent/agent.ts | 19 +- agent/nodes/ensembleNode.ts | 17 ++ agent/nodes/produceRanking.ts | 40 ++- agent/nodes/ragasMetrics.ts | 16 -- agent/nodes/robertaMetrics.ts | 39 --- .../tools/{robertaCall.ts => ensembleCall.ts} | 4 +- agent/verify.ts | 33 ++- supporting/RAGAS_Service/ensemble_service.py | 16 +- supporting/RAGAS_Service/train_distilled.py | 234 ------------------ 9 files changed, 71 insertions(+), 347 deletions(-) create mode 100644 agent/nodes/ensembleNode.ts delete mode 100644 agent/nodes/ragasMetrics.ts delete mode 100644 agent/nodes/robertaMetrics.ts rename agent/tools/{robertaCall.ts => ensembleCall.ts} (93%) delete mode 100644 supporting/RAGAS_Service/train_distilled.py diff --git a/agent/agent.ts b/agent/agent.ts index 5be1b57..9bb59aa 100644 --- a/agent/agent.ts +++ b/agent/agent.ts @@ -10,7 +10,7 @@ import { createModelNode } from "./nodes/model"; import { loopEndConditional } from "./conditionals/loop_end"; import { sort } from "./nodes/sort"; import { triggerEventSetup } from "./nodes/triggerEventSetup"; -import { robertaMetrics } from "./nodes/robertaMetrics"; +import { createEnsembleNode } from "./nodes/ensembleNode"; const triggerEventToolNode = createToolNode(triggerEventToolsByName); @@ -19,6 +19,10 @@ const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt" const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name); +const roNode = createEnsembleNode("ROBERTA", "roberta"); +const flNode = createEnsembleNode("FLAN", "flan"); +const lrNode = createEnsembleNode("REGRESSION", "logreg"); + const agent = new StateGraph(MessagesState) //NODES @@ -30,7 +34,10 @@ const agent = new StateGraph(MessagesState) .addNode("triggerEventModel", triggerEventModel) .addNode(verificationSetup.name, verificationSetup) - .addNode(robertaMetrics.name, robertaMetrics) + + .addNode("roNode", roNode) + .addNode("flNode", flNode) + .addNode("lrNode", lrNode) .addNode(produceRanking.name, produceRanking) .addNode(sort.name, sort) @@ -45,9 +52,13 @@ const agent = new StateGraph(MessagesState) .addConditionalEdges("triggerEventModel", triggerEventToolConditional, ["triggerEventToolNode", verificationSetup.name]) .addEdge("triggerEventToolNode", "triggerEventModel") - .addEdge(verificationSetup.name, robertaMetrics.name) + .addEdge(verificationSetup.name, "roNode") + .addEdge(verificationSetup.name, "flNode") + .addEdge(verificationSetup.name, "lrNode") - .addEdge(robertaMetrics.name, produceRanking.name) + .addEdge("roNode", produceRanking.name) + .addEdge("flNode", produceRanking.name) + .addEdge("lrNode", produceRanking.name) // @ts-expect-error .addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name]) diff --git a/agent/nodes/ensembleNode.ts b/agent/nodes/ensembleNode.ts new file mode 100644 index 0000000..f54b63f --- /dev/null +++ b/agent/nodes/ensembleNode.ts @@ -0,0 +1,17 @@ +import { GraphNode } from "@langchain/langgraph"; +import { MessagesState } from "../state"; +import { AIMessage } from "@langchain/core/messages"; +import { evaluateWithEnsemble } from "../tools/ensembleCall"; + +export function createEnsembleNode(title: string, method: string): GraphNode { + return async (state) => { + const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event + + const result = await evaluateWithEnsemble({ answer, method }) + const score = result.validProb - result.invalidProb; + + return { + messages: [new AIMessage(title + ":" + score)] + }; + }; +}; \ No newline at end of file diff --git a/agent/nodes/produceRanking.ts b/agent/nodes/produceRanking.ts index 434fcf0..8c91e13 100644 --- a/agent/nodes/produceRanking.ts +++ b/agent/nodes/produceRanking.ts @@ -2,31 +2,25 @@ import { GraphNode } from "@langchain/langgraph"; import { MessagesState } from "../state"; import { BaseMessage } from "@langchain/core/messages"; -//TODO: Each of these might need different weights -const keys = ["CONFIDENCE", "RELATION", "RAGAS", "ROBERTA"]; - -const mapping = { - VERYHIGH: 1.0, - HIGH: 0.75, - MEDIUM: 0.5, - LOW: 0.25, - VERYLOW: 0.0, +const models = { + REGRESSION: 0.3, + ROBERTA: 0.5, + FLAN: 0.3, } as const; -type Priority = keyof typeof mapping; +type ModelKey = keyof typeof models; function mapResponse(value: string | undefined | null): number { - if (!value) return 1; + if (!value) return 0; const trimmed = value.trim(); const num = parseFloat(trimmed); - // If number, return it - if (!isNaN(num)) return num; - - // Otherwise, map to value - const upper = trimmed.toUpperCase() as Priority; - return mapping[upper] ?? 0; + if (!isNaN(num)) { + return num; + } else { + return 0; + } } function getLastMessageContaining( @@ -43,18 +37,18 @@ function getLastMessageContaining( } export const produceRanking: GraphNode = async (state) => { - // Extract and map values - const values = keys.map((key) => { + const values = (Object.keys(models) as ModelKey[]).map((key) => { const msg = getLastMessageContaining(state.messages, key); const part = msg?.split(":").at(1); - return mapResponse(part); + const baseValue = mapResponse(part); + + return baseValue * models[key]; }); - // Multiply! - const result = values.reduce((acc, val) => acc * val, 1); + const result = values.reduce((acc, val) => acc + val, 0); const current = state.proposedTriggerEvent; current[state.proposedTriggerEventIndex].score = result; return { proposedTriggerEvent: current }; -}; +}; \ No newline at end of file diff --git a/agent/nodes/ragasMetrics.ts b/agent/nodes/ragasMetrics.ts deleted file mode 100644 index d380fdb..0000000 --- a/agent/nodes/ragasMetrics.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { GraphNode } from "@langchain/langgraph"; -import { MessagesState } from "../state"; -import { AIMessage, HumanMessage } from "@langchain/core/messages"; -import { evaluateWithRagas } from "../tools/ragasCall"; - -export const ragasMetrics: GraphNode = async (state) => { - const question = "A possible trigger event for: " + state.disinformationTitle //Should it be raw, or normalized? - const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event - const contexts = state.proposedTriggerEvent[state.proposedTriggerEventIndex].context?.split("^^^") ?? [] - - const results = await evaluateWithRagas({question, answer, contexts}) - - return { - messages: [ new AIMessage("RAGAS:" + results.faithfulness)] - }; -}; \ No newline at end of file diff --git a/agent/nodes/robertaMetrics.ts b/agent/nodes/robertaMetrics.ts deleted file mode 100644 index 1be834f..0000000 --- a/agent/nodes/robertaMetrics.ts +++ /dev/null @@ -1,39 +0,0 @@ -import { GraphNode } from "@langchain/langgraph"; -import { MessagesState } from "../state"; -import { AIMessage } from "@langchain/core/messages"; -import { evaluateWithRoberta } from "../tools/robertaCall"; - -export const robertaMetrics: GraphNode = async (state) => { - const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event - - const lrresult = await evaluateWithRoberta({answer, method:"logreg"}) - const lrscore = lrresult.validProb - lrresult.invalidProb; - - const roresult = await evaluateWithRoberta({answer, method:"roberta"}) - const roscore = roresult.validProb - roresult.invalidProb; - - const flresult = await evaluateWithRoberta({answer, method:"flan"}) - const flscore = flresult.validProb - flresult.invalidProb; - - //Option 1: combining scores - const score = lrscore * 0.3 + roscore * 0.5 + flscore * 0.3 - - //Option 2: majority voting - // const rovote = roscore > 0.6 - // const flvote = flscore > 0.94 - // const lrvote = lrscore > 0.75 - - // let counter = 0 - // if (rovote) counter++ - // if (flvote) counter++ - // if (lrvote) counter++ - - // let score = 0 - // if (counter >= 2) { - // score = 0.7 + lrscore + flscore + lrscore - // } - - return { - messages: [ new AIMessage("ROBERTA:" + score)] - }; -}; \ No newline at end of file diff --git a/agent/tools/robertaCall.ts b/agent/tools/ensembleCall.ts similarity index 93% rename from agent/tools/robertaCall.ts rename to agent/tools/ensembleCall.ts index 0554ab2..112e1a3 100644 --- a/agent/tools/robertaCall.ts +++ b/agent/tools/ensembleCall.ts @@ -1,6 +1,6 @@ import axios from "axios"; -export async function evaluateWithRoberta({ +export async function evaluateWithEnsemble({ answer, method }: { @@ -10,7 +10,7 @@ export async function evaluateWithRoberta({ const res = await axios.post("http://localhost:8000/evaluate", { answer, method - }); + }, {timeout: 0}); // console.log(res.data) const validProb = res.data["probabilities"][0][0] const invalidProb = res.data["probabilities"][0][1] + res.data["probabilities"][0][2] diff --git a/agent/verify.ts b/agent/verify.ts index bc32e0b..5f357ec 100644 --- a/agent/verify.ts +++ b/agent/verify.ts @@ -1,39 +1,36 @@ import { END, START, StateGraph } from "@langchain/langgraph"; import { MessagesState } from "./state"; import { verificationSetup } from "./nodes/verificationSetup"; -import { ragasMetrics } from "./nodes/ragasMetrics"; import { produceRanking } from "./nodes/produceRanking"; -import { createModelNode } from "./nodes/model"; import { loopEndConditional } from "./conditionals/loop_end"; import { sort } from "./nodes/sort"; -import { robertaMetrics } from "./nodes/robertaMetrics"; +import { createEnsembleNode } from "./nodes/ensembleNode"; -const verificationModel = createModelNode([], "verify.txt"); -const relationModel = createModelNode([], "relation.txt"); +const roNode = createEnsembleNode("ROBERTA", "roberta"); +const flNode = createEnsembleNode("FLAN", "flan"); +const lrNode = createEnsembleNode("REGRESSION", "logreg"); const agent = new StateGraph(MessagesState) //NODES .addNode(verificationSetup.name, verificationSetup) - // .addNode("verificationModel", verificationModel) - // .addNode(ragasMetrics.name, ragasMetrics) - .addNode(robertaMetrics.name, robertaMetrics) - // .addNode("relationModel", relationModel) + .addNode("roNode", roNode) + .addNode("flNode", flNode) + .addNode("lrNode", lrNode) .addNode(produceRanking.name, produceRanking) .addNode(sort.name, sort) .addEdge(START, verificationSetup.name) - // .addEdge(verificationSetup.name, "verificationModel") - // .addEdge(verificationSetup.name, ragasMetrics.name) - .addEdge(verificationSetup.name, robertaMetrics.name) - // .addEdge(verificationSetup.name, "relationModel") - - // .addEdge(ragasMetrics.name, produceRanking.name) - .addEdge(robertaMetrics.name, produceRanking.name) - // .addEdge("verificationModel", produceRanking.name) - // .addEdge("relationModel", produceRanking.name) + + .addEdge(verificationSetup.name, "roNode") + .addEdge(verificationSetup.name, "flNode") + .addEdge(verificationSetup.name, "lrNode") + .addEdge("roNode", produceRanking.name) + .addEdge("flNode", produceRanking.name) + .addEdge("lrNode", produceRanking.name) + // @ts-expect-error .addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name]) diff --git a/supporting/RAGAS_Service/ensemble_service.py b/supporting/RAGAS_Service/ensemble_service.py index a5c5b20..5ceaf04 100644 --- a/supporting/RAGAS_Service/ensemble_service.py +++ b/supporting/RAGAS_Service/ensemble_service.py @@ -3,21 +3,15 @@ from fastapi import FastAPI import torch import torch.nn as nn import os - -# Embedding model from sentence_transformers import SentenceTransformer from huggingface_hub import hf_hub_download - -# Roberta from transformers import RobertaTokenizer, RobertaForSequenceClassification - -# Flan (seq2seq) from transformers import AutoTokenizer, AutoModelForSeq2SeqLM app = FastAPI() ############################################ -# ----------- REQUEST SCHEMA --------------- +# SCHEMA ############################################ class EvalRequest(BaseModel): @@ -26,7 +20,7 @@ class EvalRequest(BaseModel): ############################################ -# ----------- LOGREG MODEL ----------------- +# REGRESSION MODEL ############################################ HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression" @@ -72,7 +66,7 @@ logreg_model.eval() ############################################ -# ----------- ROBERTA MODEL ---------------- +# ROBERTA ############################################ ROBERTA_PATH = "WillJeynes/LLMsForDisinformationAnalysis" @@ -83,7 +77,7 @@ roberta_model.eval() ############################################ -# ----------- FLAN MODEL ------------------- +# FLAN ############################################ FLAN_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan" @@ -126,7 +120,7 @@ def parse_generated_label(generated: str): ############################################ -# ----------- MAIN ENDPOINT --------------- +# ENDPOINT ############################################ @app.post("/evaluate") diff --git a/supporting/RAGAS_Service/train_distilled.py b/supporting/RAGAS_Service/train_distilled.py deleted file mode 100644 index 8e660d7..0000000 --- a/supporting/RAGAS_Service/train_distilled.py +++ /dev/null @@ -1,234 +0,0 @@ -from sklearn.utils import compute_class_weight -from torch.nn import CrossEntropyLoss -from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification -import torch -from sklearn.model_selection import train_test_split -from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score -from collections import Counter -import sys -import csv -import numpy as np - -NUM_CLASSES = 3 -model_name = "distilbert/distilroberta-base" # Or MiniLM, or any other transformer model - -LABEL_PRIORITY = [ - ("PERFECT", 0), - ("STORY", 1), - ("NSPECIFIC", 2), - ("REWORDING", 1), - ("TINCORRECT", -1), - ("DUPLICATE", -1), - ("", 0), # fallback to PERFECT -] - -class WeightedTrainer(Trainer): - def __init__(self, *args, class_weights=None, **kwargs): - super().__init__(*args, **kwargs) - self.class_weights = class_weights - - def compute_loss(self, model, inputs, return_outputs=False, **kwargs): - labels = inputs.get("labels") - # print("DBG: Before forward") - outputs = model(**inputs) - # print("DBG: After forward") - logits = outputs.get("logits") - - loss_fct = CrossEntropyLoss( - weight=self.class_weights.to(logits.device).to(logits.dtype) - ) - # loss_fct = CrossEntropyLoss() - - # print("DBG: Before loss") - loss = loss_fct(logits, labels) - # loss.backward() - # print("DBG: After loss") - return (loss, outputs) if return_outputs else loss - -def label_to_int(extra_info: str) -> int: - """ - Convert extra_info string to integer label using priority rules. - """ - - if extra_info is None: - extra_info = "" - - extra_info = extra_info.strip() - - # Handle empty string explicitly - if extra_info == "": - for key, value in LABEL_PRIORITY: - if key == "": - return value - raise ValueError("Empty extra_info but no empty mapping defined") - - # Split words (case-insensitive) - tokens = set(extra_info.upper().split()) - - # Priority matching - for key, value in LABEL_PRIORITY: - if key == "": - continue - - if key in tokens: - return value - - raise ValueError(f"Unknown label content: '{extra_info}'") - - -def load_dataset_from_csv(path): - texts = [] - labels = [] - - removed_rows = 0 - - with open(path, newline="", encoding="utf-8") as f: - reader = csv.DictReader(f) - - for i, row in enumerate(reader, start=1): - text = row["event"] - label_str = row["extra_info"] - - try: - label_int = label_to_int(label_str) - except Exception as e: - print(f"ERROR converting label on line {i}: {label_str}") - print(e) - sys.exit(1) - - # Skip rows marked for removal - if label_int == -1: - removed_rows += 1 - continue - - texts.append(text) - labels.append(label_int) - - print(f"Loaded {len(texts)} samples (removed {removed_rows})") - - return texts, labels - - - -def compute_metrics(eval_pred): - logits, labels = eval_pred - preds = logits.argmax(axis=1) - - return { - "accuracy": accuracy_score(labels, preds), - "f1": f1_score(labels, preds, average="weighted", zero_division=0), - "precision": precision_score(labels, preds, average="weighted", zero_division=0), - "recall": recall_score(labels, preds, average="weighted", zero_division=0), - } - -def main(): - torch.multiprocessing.set_start_method('fork') - print("CUDA available:", torch.cuda.is_available()) - print("CUDA device count:", torch.cuda.device_count()) - print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU") - texts, labels = load_dataset_from_csv("../../data/classify.csv") - - tokenizer = AutoTokenizer.from_pretrained(model_name) - - model = AutoModelForSequenceClassification.from_pretrained( - model_name, - num_labels=NUM_CLASSES - ) - - print("Dataset size:", len(texts)) - print("Label distribution:") - print(Counter(labels)) - - train_texts, val_texts, train_labels, val_labels = train_test_split( - texts, - labels, - test_size=0.2, - random_state=42, - stratify=labels - ) - - - class_weights = compute_class_weight( - class_weight="balanced", - classes=np.unique(train_labels), - y=train_labels - ) - - class_weights = torch.tensor(class_weights, dtype=torch.float) - print("Class weights:", class_weights) - - train_encodings = tokenizer( - train_texts, - truncation=True, - padding=True, - max_length=256 - ) - - val_encodings = tokenizer( - val_texts, - truncation=True, - padding=True, - max_length=256 - ) - - class TextDataset(torch.utils.data.Dataset): - def __init__(self, encodings, labels): - self.encodings = encodings - self.labels = labels - - def __getitem__(self, idx): - # print(f"DBG: Loading item {idx}") - item = { - key: torch.tensor(val[idx]) - for key, val in self.encodings.items() - } - item["labels"] = torch.tensor(self.labels[idx]) - return item - - def __len__(self): - return len(self.labels) - - training_args = TrainingArguments( - output_dir="./results", - learning_rate=2e-5, - per_device_train_batch_size=32, - # gradient_accumulation_steps=2, - num_train_epochs=15, - weight_decay=0.01, - load_best_model_at_end=True, - eval_strategy="epoch", - save_strategy="epoch", - metric_for_best_model="f1", - greater_is_better=True, - dataloader_num_workers=4, - dataloader_pin_memory=True, - # warmup_steps=100, - ) - - train_dataset = TextDataset(train_encodings, train_labels) - - val_dataset = TextDataset(val_encodings, val_labels) - - trainer = WeightedTrainer( - model=model, - args=training_args, - train_dataset=train_dataset, - eval_dataset=val_dataset, - compute_metrics=compute_metrics, - class_weights=class_weights - ) - - trainer.train() - - metrics = trainer.evaluate() - print("Final evaluation metrics:") - for k, v in metrics.items(): - print(f"{k}: {v}") - - trainer.save_model("./roberta_distilled_classifier") - tokenizer.save_pretrained("./roberta_distilled_classifier") - - - -if __name__ == "__main__": - main() \ No newline at end of file