Add ROBERTA classifier ranking PoC, with 77pc off the bat
This commit is contained in:
@@ -1,25 +1,33 @@
|
||||
from pydantic import BaseModel
|
||||
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
||||
import torch
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
MODEL_PATH = "./roberta_classifier"
|
||||
|
||||
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
||||
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
||||
|
||||
text2 = "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in March–August 2020)"
|
||||
text = "Multiple mirrored reuploads (2020–2023) put the clip on other channels with titles implying it was a genuine 1970s public information film."
|
||||
class EvalRequest(BaseModel):
|
||||
answer: str
|
||||
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
padding=True
|
||||
)
|
||||
@app.post("/evaluate")
|
||||
def evaluate_rob(req: EvalRequest):
|
||||
inputs = tokenizer(
|
||||
req.answer,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
padding=True
|
||||
)
|
||||
|
||||
model.eval()
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(**inputs).logits
|
||||
with torch.no_grad():
|
||||
logits = model(**inputs).logits
|
||||
|
||||
probs = torch.softmax(logits, dim=1)
|
||||
print(probs)
|
||||
probs = torch.softmax(logits, dim=1)
|
||||
return {
|
||||
"probabilities": probs.cpu().numpy().tolist()
|
||||
}
|
||||
@@ -6,17 +6,17 @@ from collections import Counter
|
||||
import sys
|
||||
import csv
|
||||
|
||||
NUM_CLASSES = 3
|
||||
NUM_CLASSES = 2
|
||||
model_name = "roberta-base"
|
||||
|
||||
LABEL_PRIORITY = [
|
||||
("PERFECT", 0),
|
||||
("STORY", 1),
|
||||
("NSPECIFIC", 2),
|
||||
("REWORDING", 2),
|
||||
("NSPECIFIC", 1),
|
||||
("REWORDING", 1),
|
||||
("TINCORRECT", -1),
|
||||
("DUPLICATE", -1),
|
||||
("", 2), # fallback to PERFECT
|
||||
("", 0), # fallback to PERFECT
|
||||
]
|
||||
|
||||
def label_to_int(extra_info: str) -> int:
|
||||
|
||||
Reference in New Issue
Block a user