Add ROBERTA classifier ranking PoC, with 77pc off the bat

This commit is contained in:
William Jeynes
2026-03-13 11:24:51 +00:00
parent f09e36e740
commit 8311556855
8 changed files with 85 additions and 32 deletions
+21 -13
View File
@@ -1,25 +1,33 @@
from pydantic import BaseModel
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch
from fastapi import FastAPI
app = FastAPI()
MODEL_PATH = "./roberta_classifier"
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
text2 = "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in MarchAugust 2020)"
text = "Multiple mirrored reuploads (20202023) put the clip on other channels with titles implying it was a genuine 1970s public information film."
class EvalRequest(BaseModel):
answer: str
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding=True
)
@app.post("/evaluate")
def evaluate_rob(req: EvalRequest):
inputs = tokenizer(
req.answer,
return_tensors="pt",
truncation=True,
padding=True
)
model.eval()
model.eval()
with torch.no_grad():
logits = model(**inputs).logits
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.softmax(logits, dim=1)
print(probs)
probs = torch.softmax(logits, dim=1)
return {
"probabilities": probs.cpu().numpy().tolist()
}
+4 -4
View File
@@ -6,17 +6,17 @@ from collections import Counter
import sys
import csv
NUM_CLASSES = 3
NUM_CLASSES = 2
model_name = "roberta-base"
LABEL_PRIORITY = [
("PERFECT", 0),
("STORY", 1),
("NSPECIFIC", 2),
("REWORDING", 2),
("NSPECIFIC", 1),
("REWORDING", 1),
("TINCORRECT", -1),
("DUPLICATE", -1),
("", 2), # fallback to PERFECT
("", 0), # fallback to PERFECT
]
def label_to_int(extra_info: str) -> int:
+1 -1
View File
@@ -102,7 +102,7 @@ function buildAgentInput(record: Claim | VerifierInput) {
date: v.date,
proposedTriggerEvent: v.events,
normalizedClaim: v.normalizedClaim,
proposedTriggerEventIndex: 0
proposedTriggerEventIndex: -1
};
}
+8 -3
View File
@@ -56,7 +56,7 @@ def render():
st.error("Invalid folder path.")
return
jsonl_files = list(path.glob("*.jsonl"))
jsonl_files = sorted(path.glob("*.jsonl"))
if not jsonl_files:
st.info("No .jsonl files found in this folder.")
return
@@ -80,13 +80,13 @@ def render():
print(extra_lower)
if score is not None:
if score > THRESH and extra_lower == "perfect":
confidence_counter["Correct"] += 1
confidence_counter["Correct-TRUE"] += 1
elif score > THRESH and extra_lower != "perfect":
confidence_counter["Over-confident"] += 1
elif score < THRESH and extra_lower == "perfect":
confidence_counter["Under-confident"] += 1
else:
confidence_counter["Other"] += 1
confidence_counter["Correct-FALSE"] += 1
if confidence_counter:
df_conf = pd.DataFrame(
@@ -104,6 +104,11 @@ def render():
ax.axis("equal")
ax.set_title(file_path.name)
total = sum(confidence_counter.values())
correct = confidence_counter["Correct-TRUE"] + confidence_counter["Correct-FALSE"]
corr_percent = (correct / total) * 100
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
st.pyplot(fig, width=500)
else:
st.info("No score data available in this file.")