Add ROBERTA classifier ranking PoC, with 77pc off the bat
This commit is contained in:
@@ -3,7 +3,7 @@ import { MessagesState } from "../state";
|
|||||||
import { BaseMessage } from "@langchain/core/messages";
|
import { BaseMessage } from "@langchain/core/messages";
|
||||||
|
|
||||||
//TODO: Each of these might need different weights
|
//TODO: Each of these might need different weights
|
||||||
const keys = ["CONFIDENCE", "RAGAS", "RELATION"];
|
const keys = ["CONFIDENCE", "RELATION", "RAGAS", "ROBERTA"];
|
||||||
|
|
||||||
const mapping = {
|
const mapping = {
|
||||||
VERYHIGH: 1.0,
|
VERYHIGH: 1.0,
|
||||||
@@ -16,7 +16,7 @@ const mapping = {
|
|||||||
type Priority = keyof typeof mapping;
|
type Priority = keyof typeof mapping;
|
||||||
|
|
||||||
function mapResponse(value: string | undefined | null): number {
|
function mapResponse(value: string | undefined | null): number {
|
||||||
if (!value) return 0;
|
if (!value) return 1;
|
||||||
|
|
||||||
const trimmed = value.trim();
|
const trimmed = value.trim();
|
||||||
const num = parseFloat(trimmed);
|
const num = parseFloat(trimmed);
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
|
import { MessagesState } from "../state";
|
||||||
|
import { AIMessage } from "@langchain/core/messages";
|
||||||
|
import { evaluateWithRoberta } from "../tools/robertaCall";
|
||||||
|
|
||||||
|
/**
 * Graph node that runs the RoBERTa classifier on the currently selected
 * proposed trigger event and reports the outcome as an AI message.
 *
 * Reads `state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event`,
 * passes it to `evaluateWithRoberta`, and returns a state update whose single
 * message is the text `"ROBERTA:"` followed by the classifier result.
 *
 * @param state Current graph state; must provide `proposedTriggerEvent` and
 *              `proposedTriggerEventIndex`.
 * @returns A partial state update containing one `AIMessage`.
 * @throws Error when no event exists at the current index (for example when
 *         the index has not yet been advanced to a valid position). Errors
 *         raised by `evaluateWithRoberta` propagate unchanged.
 */
export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) => {
  const index = state.proposedTriggerEventIndex;
  const candidate = state.proposedTriggerEvent?.[index];

  // Fail with a descriptive message instead of an opaque
  // "Cannot read properties of undefined (reading 'Event')".
  if (candidate == null) {
    throw new Error(
      `robertaMetrics: no proposed trigger event at index ${index}`
    );
  }

  const result = await evaluateWithRoberta({ answer: candidate.Event });

  return {
    // The "ROBERTA:" prefix tags the message with the metric it carries.
    messages: [new AIMessage("ROBERTA:" + result)]
  };
};
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
import axios from "axios";
|
||||||
|
|
||||||
|
/**
 * Ask the RoBERTa classifier service to judge a single answer.
 *
 * POSTs `{ answer }` to the evaluate endpoint and compares the first two
 * entries of the first row of the returned `probabilities` matrix.
 *
 * @param answer   Text to classify.
 * @param endpoint URL of the evaluate endpoint. Optional; defaults to the
 *                 local service at `http://localhost:8000/evaluate`.
 * @returns `1` when `probabilities[0][0]` is strictly greater than
 *          `probabilities[0][1]`, otherwise `0`.
 *          NOTE(review): index 0 is treated as the "valid" class — confirm
 *          this matches the label order the model was trained with.
 * @throws Error when the response does not contain numeric
 *         `probabilities[0][0]` and `probabilities[0][1]`. Errors raised by
 *         axios (network failure, non-2xx status) propagate unchanged.
 */
export async function evaluateWithRoberta({
  answer,
  endpoint = "http://localhost:8000/evaluate"
}: {
  answer: string;
  endpoint?: string;
}): Promise<0 | 1> {
  const res = await axios.post(endpoint, { answer });

  // Validate the payload shape up front so a malformed response produces a
  // clear error rather than a TypeError deep inside property access.
  const firstRow: unknown = res.data?.probabilities?.[0];
  if (
    !Array.isArray(firstRow) ||
    typeof firstRow[0] !== "number" ||
    typeof firstRow[1] !== "number"
  ) {
    throw new Error(
      `Unexpected response from RoBERTa service at ${endpoint}: ` +
        JSON.stringify(res.data)
    );
  }

  const validProb: number = firstRow[0];
  const invalidProb: number = firstRow[1];

  return validProb > invalidProb ? 1 : 0;
}
|
||||||
|
|
||||||
|
// let res = await evaluateWithRoberta({answer: "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in March–August 2020)"});
|
||||||
|
// console.log(res)
|
||||||
|
|
||||||
|
// res = await evaluateWithRoberta({answer: "Multiple mirrored reuploads (2020–2023) put the clip on other channels with titles implying it was a genuine 1970s public information film."});
|
||||||
|
// console.log(res)
|
||||||
+13
-9
@@ -6,6 +6,7 @@ import { produceRanking } from "./nodes/produceRanking";
|
|||||||
import { createModelNode } from "./nodes/model";
|
import { createModelNode } from "./nodes/model";
|
||||||
import { loopEndConditional } from "./conditionals/loop_end";
|
import { loopEndConditional } from "./conditionals/loop_end";
|
||||||
import { sort } from "./nodes/sort";
|
import { sort } from "./nodes/sort";
|
||||||
|
import { robertaMetrics } from "./nodes/robertaMetrics";
|
||||||
|
|
||||||
const verificationModel = createModelNode([], "verify.txt");
|
const verificationModel = createModelNode([], "verify.txt");
|
||||||
const relationModel = createModelNode([], "relation.txt");
|
const relationModel = createModelNode([], "relation.txt");
|
||||||
@@ -14,21 +15,24 @@ const agent = new StateGraph(MessagesState)
|
|||||||
|
|
||||||
//NODES
|
//NODES
|
||||||
.addNode(verificationSetup.name, verificationSetup)
|
.addNode(verificationSetup.name, verificationSetup)
|
||||||
.addNode("verificationModel", verificationModel)
|
// .addNode("verificationModel", verificationModel)
|
||||||
.addNode(ragasMetrics.name, ragasMetrics)
|
// .addNode(ragasMetrics.name, ragasMetrics)
|
||||||
.addNode("relationModel", relationModel)
|
.addNode(robertaMetrics.name, robertaMetrics)
|
||||||
|
// .addNode("relationModel", relationModel)
|
||||||
|
|
||||||
.addNode(produceRanking.name, produceRanking)
|
.addNode(produceRanking.name, produceRanking)
|
||||||
.addNode(sort.name, sort)
|
.addNode(sort.name, sort)
|
||||||
|
|
||||||
.addEdge(START, verificationSetup.name)
|
.addEdge(START, verificationSetup.name)
|
||||||
.addEdge(verificationSetup.name, "verificationModel")
|
// .addEdge(verificationSetup.name, "verificationModel")
|
||||||
.addEdge(verificationSetup.name, ragasMetrics.name)
|
// .addEdge(verificationSetup.name, ragasMetrics.name)
|
||||||
.addEdge(verificationSetup.name, "relationModel")
|
.addEdge(verificationSetup.name, robertaMetrics.name)
|
||||||
|
// .addEdge(verificationSetup.name, "relationModel")
|
||||||
|
|
||||||
.addEdge(ragasMetrics.name, produceRanking.name)
|
// .addEdge(ragasMetrics.name, produceRanking.name)
|
||||||
.addEdge("verificationModel", produceRanking.name)
|
.addEdge(robertaMetrics.name, produceRanking.name)
|
||||||
.addEdge("relationModel", produceRanking.name)
|
// .addEdge("verificationModel", produceRanking.name)
|
||||||
|
// .addEdge("relationModel", produceRanking.name)
|
||||||
|
|
||||||
// @ts-expect-error
|
// @ts-expect-error
|
||||||
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
||||||
|
|||||||
@@ -1,25 +1,33 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
||||||
import torch
|
import torch
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
MODEL_PATH = "./roberta_classifier"
|
MODEL_PATH = "./roberta_classifier"
|
||||||
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
||||||
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
||||||
|
|
||||||
text2 = "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in March–August 2020)"
|
class EvalRequest(BaseModel):
|
||||||
text = "Multiple mirrored reuploads (2020–2023) put the clip on other channels with titles implying it was a genuine 1970s public information film."
|
answer: str
|
||||||
|
|
||||||
inputs = tokenizer(
|
@app.post("/evaluate")
|
||||||
text,
|
def evaluate_rob(req: EvalRequest):
|
||||||
|
inputs = tokenizer(
|
||||||
|
req.answer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
padding=True
|
padding=True
|
||||||
)
|
)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(**inputs).logits
|
logits = model(**inputs).logits
|
||||||
|
|
||||||
probs = torch.softmax(logits, dim=1)
|
probs = torch.softmax(logits, dim=1)
|
||||||
print(probs)
|
return {
|
||||||
|
"probabilities": probs.cpu().numpy().tolist()
|
||||||
|
}
|
||||||
@@ -6,17 +6,17 @@ from collections import Counter
|
|||||||
import sys
|
import sys
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
NUM_CLASSES = 3
|
NUM_CLASSES = 2
|
||||||
model_name = "roberta-base"
|
model_name = "roberta-base"
|
||||||
|
|
||||||
LABEL_PRIORITY = [
|
LABEL_PRIORITY = [
|
||||||
("PERFECT", 0),
|
("PERFECT", 0),
|
||||||
("STORY", 1),
|
("STORY", 1),
|
||||||
("NSPECIFIC", 2),
|
("NSPECIFIC", 1),
|
||||||
("REWORDING", 2),
|
("REWORDING", 1),
|
||||||
("TINCORRECT", -1),
|
("TINCORRECT", -1),
|
||||||
("DUPLICATE", -1),
|
("DUPLICATE", -1),
|
||||||
("", 2), # fallback to PERFECT
|
("", 0), # fallback to PERFECT
|
||||||
]
|
]
|
||||||
|
|
||||||
def label_to_int(extra_info: str) -> int:
|
def label_to_int(extra_info: str) -> int:
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ function buildAgentInput(record: Claim | VerifierInput) {
|
|||||||
date: v.date,
|
date: v.date,
|
||||||
proposedTriggerEvent: v.events,
|
proposedTriggerEvent: v.events,
|
||||||
normalizedClaim: v.normalizedClaim,
|
normalizedClaim: v.normalizedClaim,
|
||||||
proposedTriggerEventIndex: 0
|
proposedTriggerEventIndex: -1
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def render():
|
|||||||
st.error("Invalid folder path.")
|
st.error("Invalid folder path.")
|
||||||
return
|
return
|
||||||
|
|
||||||
jsonl_files = list(path.glob("*.jsonl"))
|
jsonl_files = sorted(path.glob("*.jsonl"))
|
||||||
if not jsonl_files:
|
if not jsonl_files:
|
||||||
st.info("No .jsonl files found in this folder.")
|
st.info("No .jsonl files found in this folder.")
|
||||||
return
|
return
|
||||||
@@ -80,13 +80,13 @@ def render():
|
|||||||
print(extra_lower)
|
print(extra_lower)
|
||||||
if score is not None:
|
if score is not None:
|
||||||
if score > THRESH and extra_lower == "perfect":
|
if score > THRESH and extra_lower == "perfect":
|
||||||
confidence_counter["Correct"] += 1
|
confidence_counter["Correct-TRUE"] += 1
|
||||||
elif score > THRESH and extra_lower != "perfect":
|
elif score > THRESH and extra_lower != "perfect":
|
||||||
confidence_counter["Over-confident"] += 1
|
confidence_counter["Over-confident"] += 1
|
||||||
elif score < THRESH and extra_lower == "perfect":
|
elif score < THRESH and extra_lower == "perfect":
|
||||||
confidence_counter["Under-confident"] += 1
|
confidence_counter["Under-confident"] += 1
|
||||||
else:
|
else:
|
||||||
confidence_counter["Other"] += 1
|
confidence_counter["Correct-FALSE"] += 1
|
||||||
|
|
||||||
if confidence_counter:
|
if confidence_counter:
|
||||||
df_conf = pd.DataFrame(
|
df_conf = pd.DataFrame(
|
||||||
@@ -104,6 +104,11 @@ def render():
|
|||||||
ax.axis("equal")
|
ax.axis("equal")
|
||||||
ax.set_title(file_path.name)
|
ax.set_title(file_path.name)
|
||||||
|
|
||||||
|
total = sum(confidence_counter.values())
|
||||||
|
correct = confidence_counter["Correct-TRUE"] + confidence_counter["Correct-FALSE"]
|
||||||
|
|
||||||
|
corr_percent = (correct / total) * 100
|
||||||
|
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
|
||||||
st.pyplot(fig, width=500)
|
st.pyplot(fig, width=500)
|
||||||
else:
|
else:
|
||||||
st.info("No score data available in this file.")
|
st.info("No score data available in this file.")
|
||||||
Reference in New Issue
Block a user