Files
LLMsForDisinformationAnalysis/supporting/RAGAS_Service/roberta_service.py
T

25 lines
757 B
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os

import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
# Directory holding the fine-tuned RoBERTa checkpoint (tokenizer + model files).
# The default is a relative path, so it is resolved against the process's
# current working directory, not this file's location. Set ROBERTA_MODEL_PATH
# to point at the checkpoint explicitly; when it is unset or empty the
# original default "./roberta_classifier" is used unchanged.
MODEL_PATH = os.environ.get("ROBERTA_MODEL_PATH") or "./roberta_classifier"

tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)

# Sample inputs for a quick manual check. Only ``text`` is classified below;
# ``text2`` is an alternative sample that is currently not used.
# NOTE(review): "MarchAugust" and "20202023" look like a dash (likely an en
# dash, U+2013, per the file's ambiguous-Unicode warning) was lost when this
# file was copied -- confirm against the repository before relying on them.
text2 = "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in MarchAugust 2020)"
text = "Multiple mirrored reuploads (20202023) put the clip on other channels with titles implying it was a genuine 1970s public information film."

# Tokenize into PyTorch tensors. truncation=True cuts over-long input down to
# the tokenizer's maximum length; padding=True pads to the longest sequence in
# the batch, which is a no-op for a single string.
inputs = tokenizer(
    text,
    return_tensors="pt",
    truncation=True,
    padding=True,
)

# Inference only: eval() turns off training-time behavior such as dropout,
# and no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    logits = model(**inputs).logits
    # Softmax along dim=1 turns each row of logits into class probabilities.
    # NOTE(review): which column corresponds to which label is defined by the
    # saved checkpoint's config (model.config.id2label) -- confirm before use.
    probs = torch.softmax(logits, dim=1)

print(probs)