Add ROBERTA classifier ranking PoC, with 77pc off the bat

This commit is contained in:
William Jeynes
2026-03-13 11:24:51 +00:00
parent f09e36e740
commit 8311556855
8 changed files with 85 additions and 32 deletions
+2 -2
View File
@@ -3,7 +3,7 @@ import { MessagesState } from "../state";
import { BaseMessage } from "@langchain/core/messages"; import { BaseMessage } from "@langchain/core/messages";
//TODO: Each of these might need different weights //TODO: Each of these might need different weights
const keys = ["CONFIDENCE", "RAGAS", "RELATION"]; const keys = ["CONFIDENCE", "RELATION", "RAGAS", "ROBERTA"];
const mapping = { const mapping = {
VERYHIGH: 1.0, VERYHIGH: 1.0,
@@ -16,7 +16,7 @@ const mapping = {
type Priority = keyof typeof mapping; type Priority = keyof typeof mapping;
function mapResponse(value: string | undefined | null): number { function mapResponse(value: string | undefined | null): number {
if (!value) return 0; if (!value) return 1;
const trimmed = value.trim(); const trimmed = value.trim();
const num = parseFloat(trimmed); const num = parseFloat(trimmed);
+14
View File
@@ -0,0 +1,14 @@
import { GraphNode } from "@langchain/langgraph";
import { MessagesState } from "../state";
import { AIMessage } from "@langchain/core/messages";
import { evaluateWithRoberta } from "../tools/robertaCall";
export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) => {
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
const result = await evaluateWithRoberta({answer})
return {
messages: [ new AIMessage("ROBERTA:" + result)]
};
};
+22
View File
@@ -0,0 +1,22 @@
import axios from "axios";
export async function evaluateWithRoberta({
answer
}: {
answer: string;
}) {
const res = await axios.post("http://localhost:8000/evaluate", {
answer
});
// console.log(res.data)
const validProb = res.data["probabilities"][0][0]
const invalidProv = res.data["probabilities"][0][1]
return validProb > invalidProv ? 1 : 0;
}
// let res = await evaluateWithRoberta({answer: "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in MarchAugust 2020)"});
// console.log(res)
// res = await evaluateWithRoberta({answer: "Multiple mirrored reuploads (20202023) put the clip on other channels with titles implying it was a genuine 1970s public information film."});
// console.log(res)
+13 -9
View File
@@ -6,6 +6,7 @@ import { produceRanking } from "./nodes/produceRanking";
import { createModelNode } from "./nodes/model"; import { createModelNode } from "./nodes/model";
import { loopEndConditional } from "./conditionals/loop_end"; import { loopEndConditional } from "./conditionals/loop_end";
import { sort } from "./nodes/sort"; import { sort } from "./nodes/sort";
import { robertaMetrics } from "./nodes/robertaMetrics";
const verificationModel = createModelNode([], "verify.txt"); const verificationModel = createModelNode([], "verify.txt");
const relationModel = createModelNode([], "relation.txt"); const relationModel = createModelNode([], "relation.txt");
@@ -14,21 +15,24 @@ const agent = new StateGraph(MessagesState)
//NODES //NODES
.addNode(verificationSetup.name, verificationSetup) .addNode(verificationSetup.name, verificationSetup)
.addNode("verificationModel", verificationModel) // .addNode("verificationModel", verificationModel)
.addNode(ragasMetrics.name, ragasMetrics) // .addNode(ragasMetrics.name, ragasMetrics)
.addNode("relationModel", relationModel) .addNode(robertaMetrics.name, robertaMetrics)
// .addNode("relationModel", relationModel)
.addNode(produceRanking.name, produceRanking) .addNode(produceRanking.name, produceRanking)
.addNode(sort.name, sort) .addNode(sort.name, sort)
.addEdge(START, verificationSetup.name) .addEdge(START, verificationSetup.name)
.addEdge(verificationSetup.name, "verificationModel") // .addEdge(verificationSetup.name, "verificationModel")
.addEdge(verificationSetup.name, ragasMetrics.name) // .addEdge(verificationSetup.name, ragasMetrics.name)
.addEdge(verificationSetup.name, "relationModel") .addEdge(verificationSetup.name, robertaMetrics.name)
// .addEdge(verificationSetup.name, "relationModel")
.addEdge(ragasMetrics.name, produceRanking.name) // .addEdge(ragasMetrics.name, produceRanking.name)
.addEdge("verificationModel", produceRanking.name) .addEdge(robertaMetrics.name, produceRanking.name)
.addEdge("relationModel", produceRanking.name) // .addEdge("verificationModel", produceRanking.name)
// .addEdge("relationModel", produceRanking.name)
// @ts-expect-error // @ts-expect-error
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name]) .addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
+17 -9
View File
@@ -1,25 +1,33 @@
from pydantic import BaseModel
from transformers import RobertaTokenizer, RobertaForSequenceClassification from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch import torch
from fastapi import FastAPI
app = FastAPI()
MODEL_PATH = "./roberta_classifier" MODEL_PATH = "./roberta_classifier"
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH) tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH) model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
text2 = "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in MarchAugust 2020)" class EvalRequest(BaseModel):
text = "Multiple mirrored reuploads (20202023) put the clip on other channels with titles implying it was a genuine 1970s public information film." answer: str
inputs = tokenizer( @app.post("/evaluate")
text, def evaluate_rob(req: EvalRequest):
inputs = tokenizer(
req.answer,
return_tensors="pt", return_tensors="pt",
truncation=True, truncation=True,
padding=True padding=True
) )
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
logits = model(**inputs).logits logits = model(**inputs).logits
probs = torch.softmax(logits, dim=1) probs = torch.softmax(logits, dim=1)
print(probs) return {
"probabilities": probs.cpu().numpy().tolist()
}
+4 -4
View File
@@ -6,17 +6,17 @@ from collections import Counter
import sys import sys
import csv import csv
NUM_CLASSES = 3 NUM_CLASSES = 2
model_name = "roberta-base" model_name = "roberta-base"
LABEL_PRIORITY = [ LABEL_PRIORITY = [
("PERFECT", 0), ("PERFECT", 0),
("STORY", 1), ("STORY", 1),
("NSPECIFIC", 2), ("NSPECIFIC", 1),
("REWORDING", 2), ("REWORDING", 1),
("TINCORRECT", -1), ("TINCORRECT", -1),
("DUPLICATE", -1), ("DUPLICATE", -1),
("", 2), # fallback to PERFECT ("", 0), # fallback to PERFECT
] ]
def label_to_int(extra_info: str) -> int: def label_to_int(extra_info: str) -> int:
+1 -1
View File
@@ -102,7 +102,7 @@ function buildAgentInput(record: Claim | VerifierInput) {
date: v.date, date: v.date,
proposedTriggerEvent: v.events, proposedTriggerEvent: v.events,
normalizedClaim: v.normalizedClaim, normalizedClaim: v.normalizedClaim,
proposedTriggerEventIndex: 0 proposedTriggerEventIndex: -1
}; };
} }
+8 -3
View File
@@ -56,7 +56,7 @@ def render():
st.error("Invalid folder path.") st.error("Invalid folder path.")
return return
jsonl_files = list(path.glob("*.jsonl")) jsonl_files = sorted(path.glob("*.jsonl"))
if not jsonl_files: if not jsonl_files:
st.info("No .jsonl files found in this folder.") st.info("No .jsonl files found in this folder.")
return return
@@ -80,13 +80,13 @@ def render():
print(extra_lower) print(extra_lower)
if score is not None: if score is not None:
if score > THRESH and extra_lower == "perfect": if score > THRESH and extra_lower == "perfect":
confidence_counter["Correct"] += 1 confidence_counter["Correct-TRUE"] += 1
elif score > THRESH and extra_lower != "perfect": elif score > THRESH and extra_lower != "perfect":
confidence_counter["Over-confident"] += 1 confidence_counter["Over-confident"] += 1
elif score < THRESH and extra_lower == "perfect": elif score < THRESH and extra_lower == "perfect":
confidence_counter["Under-confident"] += 1 confidence_counter["Under-confident"] += 1
else: else:
confidence_counter["Other"] += 1 confidence_counter["Correct-FALSE"] += 1
if confidence_counter: if confidence_counter:
df_conf = pd.DataFrame( df_conf = pd.DataFrame(
@@ -104,6 +104,11 @@ def render():
ax.axis("equal") ax.axis("equal")
ax.set_title(file_path.name) ax.set_title(file_path.name)
total = sum(confidence_counter.values())
correct = confidence_counter["Correct-TRUE"] + confidence_counter["Correct-FALSE"]
corr_percent = (correct / total) * 100
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
st.pyplot(fig, width=500) st.pyplot(fig, width=500)
else: else:
st.info("No score data available in this file.") st.info("No score data available in this file.")