Re-allow multithreading on service. Add results table
This commit is contained in:
@@ -16,22 +16,22 @@ export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) =>
|
|||||||
const flscore = flresult.validProb - flresult.invalidProb;
|
const flscore = flresult.validProb - flresult.invalidProb;
|
||||||
|
|
||||||
//Option 1: combining scores
|
//Option 1: combining scores
|
||||||
// const score = lrscore * 0.3 + roscore * 0.5 + flscore * 0.3
|
const score = lrscore * 0.3 + roscore * 0.5 + flscore * 0.3
|
||||||
|
|
||||||
//Option 2: majority voting
|
//Option 2: majority voting
|
||||||
const rovote = roscore > 0.6
|
// const rovote = roscore > 0.6
|
||||||
const flvote = flscore > 0.94
|
// const flvote = flscore > 0.94
|
||||||
const lrvote = lrscore > 0.75
|
// const lrvote = lrscore > 0.75
|
||||||
|
|
||||||
let counter = 0
|
// let counter = 0
|
||||||
if (rovote) counter++
|
// if (rovote) counter++
|
||||||
if (flvote) counter++
|
// if (flvote) counter++
|
||||||
if (lrvote) counter++
|
// if (lrvote) counter++
|
||||||
|
|
||||||
let score = 0
|
// let score = 0
|
||||||
if (counter >= 2) {
|
// if (counter >= 2) {
|
||||||
score = 0.7 + lrscore + flscore + lrscore
|
// score = 0.7 + lrscore + flscore + lrscore
|
||||||
}
|
// }
|
||||||
|
|
||||||
return {
|
return {
|
||||||
messages: [ new AIMessage("ROBERTA:" + score)]
|
messages: [ new AIMessage("ROBERTA:" + score)]
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
| Model | % Correct | % Valid taken forward|Used in ensemble|Link
|
||||||
|
|------------------------------------------------------------|-----------|----------------------|----------------|-
|
||||||
|
| Original | 53.22 | 61.72 |
|
||||||
|
| Original (RAGAS) | 56.01 | 57.73 |
|
||||||
|
| Roberta (base) | 75 | 70 |
|
||||||
|
| Roberta (Generated Data) | 76 | 71 |
|
||||||
|
| Roberta (Generated Data + Back Translation) | 74 | 71 |
|
||||||
|
| Roberta (Generated Data + Back Translation + Thresholding) | 77 | 90 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis)
|
||||||
|
| Distilled Roberta | 72.73 | 69.57 |
|
||||||
|
| Flan | 79.17 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Flan)
|
||||||
|
| Simple Regression Model | 74.77 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Regression)
|
||||||
|
| Ensemble Model (weighted confidence score sum) | 84.21 | 83.33 |
|
||||||
|
| Ensemble Model (majority voting) | 80.2 | 95.12 |
|
||||||
@@ -102,6 +102,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|||||||
flan_model.to(device)
|
flan_model.to(device)
|
||||||
flan_model.eval()
|
flan_model.eval()
|
||||||
|
|
||||||
|
label_token_ids = {
|
||||||
|
label: flan_tokenizer(label, add_special_tokens=False).input_ids[0]
|
||||||
|
for label in LABEL_TO_INT.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def format_prompt(text: str) -> str:
|
def format_prompt(text: str) -> str:
|
||||||
return (
|
return (
|
||||||
@@ -204,11 +209,6 @@ def evaluate(req: EvalRequest):
|
|||||||
skip_special_tokens=True
|
skip_special_tokens=True
|
||||||
)
|
)
|
||||||
|
|
||||||
label_token_ids = {
|
|
||||||
label: flan_tokenizer(label, add_special_tokens=False).input_ids[0]
|
|
||||||
for label in LABEL_TO_INT.keys()
|
|
||||||
}
|
|
||||||
|
|
||||||
label_logits = torch.tensor(
|
label_logits = torch.tensor(
|
||||||
[logits[0, tid].item() for tid in label_token_ids.values()]
|
[logits[0, tid].item() for tid in label_token_ids.values()]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ const AGENT_NAME = process.env.AGENT ?? "agent";
|
|||||||
*/
|
*/
|
||||||
const MODE = process.env.MODE ?? "claim";
|
const MODE = process.env.MODE ?? "claim";
|
||||||
|
|
||||||
const MAX_CONCURRENCY = 1;
|
const MAX_CONCURRENCY = 5;
|
||||||
|
|
||||||
const client = new Client({ apiUrl: API_URL });
|
const client = new Client({ apiUrl: API_URL });
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user