diff --git a/agent/nodes/produceRanking.ts b/agent/nodes/produceRanking.ts index 0c50c93..434fcf0 100644 --- a/agent/nodes/produceRanking.ts +++ b/agent/nodes/produceRanking.ts @@ -3,7 +3,7 @@ import { MessagesState } from "../state"; import { BaseMessage } from "@langchain/core/messages"; //TODO: Each of these might need different weights -const keys = ["CONFIDENCE", "RAGAS", "RELATION"]; +const keys = ["CONFIDENCE", "RELATION", "RAGAS", "ROBERTA"]; const mapping = { VERYHIGH: 1.0, @@ -16,7 +16,7 @@ const mapping = { type Priority = keyof typeof mapping; function mapResponse(value: string | undefined | null): number { - if (!value) return 0; + if (!value) return 1; const trimmed = value.trim(); const num = parseFloat(trimmed); diff --git a/agent/nodes/robertaMetrics.ts b/agent/nodes/robertaMetrics.ts new file mode 100644 index 0000000..41e288e --- /dev/null +++ b/agent/nodes/robertaMetrics.ts @@ -0,0 +1,14 @@ +import { GraphNode } from "@langchain/langgraph"; +import { MessagesState } from "../state"; +import { AIMessage } from "@langchain/core/messages"; +import { evaluateWithRoberta } from "../tools/robertaCall"; + +export const robertaMetrics: GraphNode = async (state) => { + const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event + + const result = await evaluateWithRoberta({answer}) + + return { + messages: [ new AIMessage("ROBERTA:" + result)] + }; +}; \ No newline at end of file diff --git a/agent/tools/robertaCall.ts b/agent/tools/robertaCall.ts new file mode 100644 index 0000000..8a63529 --- /dev/null +++ b/agent/tools/robertaCall.ts @@ -0,0 +1,22 @@ +import axios from "axios"; + +export async function evaluateWithRoberta({ + answer +}: { + answer: string; +}) { + const res = await axios.post("http://localhost:8000/evaluate", { + answer + }); + // console.log(res.data) + const validProb = res.data["probabilities"][0][0] + const invalidProv = res.data["probabilities"][0][1] + + return validProb > invalidProv ? 1 : 0; +} + +// let res = await evaluateWithRoberta({answer: "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in March–August 2020)"}); +// console.log(res) + +// res = await evaluateWithRoberta({answer: "Multiple mirrored reuploads (2020–2023) put the clip on other channels with titles implying it was a genuine 1970s public information film."}); +// console.log(res) \ No newline at end of file diff --git a/agent/verify.ts b/agent/verify.ts index ec40211..bc32e0b 100644 --- a/agent/verify.ts +++ b/agent/verify.ts @@ -6,6 +6,7 @@ import { produceRanking } from "./nodes/produceRanking"; import { createModelNode } from "./nodes/model"; import { loopEndConditional } from "./conditionals/loop_end"; import { sort } from "./nodes/sort"; +import { robertaMetrics } from "./nodes/robertaMetrics"; const verificationModel = createModelNode([], "verify.txt"); const relationModel = createModelNode([], "relation.txt"); @@ -14,21 +15,24 @@ const agent = new StateGraph(MessagesState) //NODES .addNode(verificationSetup.name, verificationSetup) - .addNode("verificationModel", verificationModel) - .addNode(ragasMetrics.name, ragasMetrics) - .addNode("relationModel", relationModel) + // .addNode("verificationModel", verificationModel) + // .addNode(ragasMetrics.name, ragasMetrics) + .addNode(robertaMetrics.name, robertaMetrics) + // .addNode("relationModel", relationModel) .addNode(produceRanking.name, produceRanking) .addNode(sort.name, sort) .addEdge(START, verificationSetup.name) - .addEdge(verificationSetup.name, "verificationModel") - .addEdge(verificationSetup.name, ragasMetrics.name) - .addEdge(verificationSetup.name, "relationModel") + // .addEdge(verificationSetup.name, "verificationModel") + // .addEdge(verificationSetup.name, ragasMetrics.name) + .addEdge(verificationSetup.name, robertaMetrics.name) + // .addEdge(verificationSetup.name, "relationModel") - .addEdge(ragasMetrics.name, produceRanking.name) - .addEdge("verificationModel", produceRanking.name) - .addEdge("relationModel", produceRanking.name) + // .addEdge(ragasMetrics.name, produceRanking.name) + .addEdge(robertaMetrics.name, produceRanking.name) + // .addEdge("verificationModel", produceRanking.name) + // .addEdge("relationModel", produceRanking.name) // @ts-expect-error .addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name]) diff --git a/supporting/RAGAS_Service/roberta_service.py b/supporting/RAGAS_Service/roberta_service.py index a01f848..a22b87f 100644 --- a/supporting/RAGAS_Service/roberta_service.py +++ b/supporting/RAGAS_Service/roberta_service.py @@ -1,25 +1,33 @@ +from pydantic import BaseModel from transformers import RobertaTokenizer, RobertaForSequenceClassification import torch +from fastapi import FastAPI + +app = FastAPI() MODEL_PATH = "./roberta_classifier" tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH) model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH) -text2 = "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in March–August 2020)" -text = "Multiple mirrored reuploads (2020–2023) put the clip on other channels with titles implying it was a genuine 1970s public information film." +class EvalRequest(BaseModel): + answer: str -inputs = tokenizer( - text, - return_tensors="pt", - truncation=True, - padding=True -) +@app.post("/evaluate") +def evaluate_rob(req: EvalRequest): + inputs = tokenizer( + req.answer, + return_tensors="pt", + truncation=True, + padding=True + ) -model.eval() + model.eval() -with torch.no_grad(): - logits = model(**inputs).logits + with torch.no_grad(): + logits = model(**inputs).logits -probs = torch.softmax(logits, dim=1) -print(probs) \ No newline at end of file + probs = torch.softmax(logits, dim=1) + return { + "probabilities": probs.cpu().numpy().tolist() + } \ No newline at end of file diff --git a/supporting/RAGAS_Service/train_roberta.py b/supporting/RAGAS_Service/train_roberta.py index b4abf95..58299a0 100644 --- a/supporting/RAGAS_Service/train_roberta.py +++ b/supporting/RAGAS_Service/train_roberta.py @@ -6,17 +6,17 @@ from collections import Counter import sys import csv -NUM_CLASSES = 3 +NUM_CLASSES = 2 model_name = "roberta-base" LABEL_PRIORITY = [ ("PERFECT", 0), ("STORY", 1), - ("NSPECIFIC", 2), - ("REWORDING", 2), + ("NSPECIFIC", 1), + ("REWORDING", 1), ("TINCORRECT", -1), ("DUPLICATE", -1), - ("", 2), # fallback to PERFECT + ("", 0), # fallback to PERFECT ] def label_to_int(extra_info: str) -> int: diff --git a/supporting/Wrapper/run.ts b/supporting/Wrapper/run.ts index 34305fe..7e362bb 100644 --- a/supporting/Wrapper/run.ts +++ b/supporting/Wrapper/run.ts @@ -102,7 +102,7 @@ function buildAgentInput(record: Claim | VerifierInput) { date: v.date, proposedTriggerEvent: v.events, normalizedClaim: v.normalizedClaim, - proposedTriggerEventIndex: 0 + proposedTriggerEventIndex: -1 }; } diff --git a/supporting/scorer/views/stats.py b/supporting/scorer/views/stats.py index 70b2436..57ffb00 100644 --- a/supporting/scorer/views/stats.py +++ b/supporting/scorer/views/stats.py @@ -56,7 +56,7 @@ def render(): st.error("Invalid folder path.") return - jsonl_files = list(path.glob("*.jsonl")) + jsonl_files = sorted(path.glob("*.jsonl")) if not jsonl_files: st.info("No .jsonl files found in this folder.") return @@ -80,13 +80,13 @@ def render(): print(extra_lower) if score is not None: if score > THRESH and extra_lower == "perfect": - confidence_counter["Correct"] += 1 + confidence_counter["Correct-TRUE"] += 1 elif score > THRESH and extra_lower != "perfect": confidence_counter["Over-confident"] += 1 elif score < THRESH and extra_lower == "perfect": confidence_counter["Under-confident"] += 1 else: - confidence_counter["Other"] += 1 + confidence_counter["Correct-FALSE"] += 1 if confidence_counter: df_conf = pd.DataFrame( @@ -104,6 +104,11 @@ def render(): ax.axis("equal") ax.set_title(file_path.name) + total = sum(confidence_counter.values()) + correct = confidence_counter["Correct-TRUE"] + confidence_counter["Correct-FALSE"] + + corr_percent = (correct / total) * 100 + st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**") st.pyplot(fig, width=500) else: st.info("No score data available in this file.") \ No newline at end of file