Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1ac94441c5 | |||
| f3e2897806 | |||
| f821e9643d | |||
| 43ecd04135 | |||
| 8c0921057b | |||
| b610e8c989 | |||
| f8d4155b7c | |||
| 5e374a8bd6 | |||
| fbc688b8f9 | |||
| 77cdd9a01c | |||
| a7f5978f64 | |||
| 872346c657 | |||
| 8f939d54c4 | |||
| 624d45bc53 | |||
| 80bc151379 | |||
| 5ce64290ce | |||
| 87fccb7e2b | |||
| 8c1e35f66f | |||
| 44395bb251 | |||
| e368c50577 |
@@ -1,2 +1,3 @@
|
|||||||
# TEMP
|
# TEMP
|
||||||
literature/
|
literature/
|
||||||
|
backup.tar.gz
|
||||||
@@ -7,6 +7,15 @@ Final Dissertation Submission Repository
|
|||||||
## Solution Diagram
|
## Solution Diagram
|
||||||
-- todo --
|
-- todo --
|
||||||
|
|
||||||
|
## Classifier Refinement
|
||||||
|
[See RAGAS_Service](/supporting/RAGAS_Service/)
|
||||||
|
|
||||||
|
## Agent Refinement
|
||||||
|
[See agent](/agent/)
|
||||||
|
|
||||||
|
## Generated Database Link and Usage Experiments
|
||||||
|
-- todo --
|
||||||
|
|
||||||
## Repository Structure
|
## Repository Structure
|
||||||
```
|
```
|
||||||
├── run.sh # Bash script to run project elements from one place
|
├── run.sh # Bash script to run project elements from one place
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
## Refining the agent output
|
||||||
|
|
||||||
|
Experiments modifying pipeline
|
||||||
|
|
||||||
|
| Model | % Correct | % Change |
|
||||||
|
|------------------|----------:|---------:|
|
||||||
|
| BASELINE | 33 | 0 |
|
||||||
|
| Improv Prompt | 39.96 | 0.21 |
|
||||||
|
| Add Examples | 44.67 | 0.35 |
|
||||||
|
| Date | 45.51 | 0.38 |
|
||||||
|
| Chain of Thought | 43.38 | 0.31 |
|
||||||
|
| Self-Critique | 44.36 | 0.34 |
|
||||||
|
|
||||||
|
Experiments with different model types:
|
||||||
|
| Model | % Correct | % Change |
|
||||||
|
|-------------------------------|----------:|---------:|
|
||||||
|
| gpt-5-mini | 33 | 0 |
|
||||||
|
| gpt-5.4-mini | 32.4 | -0.02 |
|
||||||
|
| llama3.1:8b-instruct-q4_K_M | ? | ? |
|
||||||
|
| qwen3.5:9b | 0 | -100 |
|
||||||
|
|
||||||
|
%age valid URLS
|
||||||
|
| Model | Number | % Age |
|
||||||
|
|-------------------------------|----------:|---------:|
|
||||||
|
| gpt-5-mini | 22/405 | 5.43 |
|
||||||
|
| gpt-5.4-mini | 29/278 | 10.43 |
|
||||||
|
| llama3.1:8b-instruct-q4_K_M | ? | ? |
|
||||||
|
| qwen3.5:9b | 0 | 0 |
|
||||||
+15
-4
@@ -10,7 +10,7 @@ import { createModelNode } from "./nodes/model";
|
|||||||
import { loopEndConditional } from "./conditionals/loop_end";
|
import { loopEndConditional } from "./conditionals/loop_end";
|
||||||
import { sort } from "./nodes/sort";
|
import { sort } from "./nodes/sort";
|
||||||
import { triggerEventSetup } from "./nodes/triggerEventSetup";
|
import { triggerEventSetup } from "./nodes/triggerEventSetup";
|
||||||
import { robertaMetrics } from "./nodes/robertaMetrics";
|
import { createEnsembleNode } from "./nodes/ensembleNode";
|
||||||
|
|
||||||
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
|
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
|
||||||
|
|
||||||
@@ -19,6 +19,10 @@ const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt"
|
|||||||
|
|
||||||
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name);
|
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name);
|
||||||
|
|
||||||
|
const roNode = createEnsembleNode("ROBERTA", "roberta");
|
||||||
|
const flNode = createEnsembleNode("FLAN", "flan");
|
||||||
|
const lrNode = createEnsembleNode("REGRESSION", "logreg");
|
||||||
|
|
||||||
const agent = new StateGraph(MessagesState)
|
const agent = new StateGraph(MessagesState)
|
||||||
|
|
||||||
//NODES
|
//NODES
|
||||||
@@ -30,7 +34,10 @@ const agent = new StateGraph(MessagesState)
|
|||||||
.addNode("triggerEventModel", triggerEventModel)
|
.addNode("triggerEventModel", triggerEventModel)
|
||||||
|
|
||||||
.addNode(verificationSetup.name, verificationSetup)
|
.addNode(verificationSetup.name, verificationSetup)
|
||||||
.addNode(robertaMetrics.name, robertaMetrics)
|
|
||||||
|
.addNode("roNode", roNode)
|
||||||
|
.addNode("flNode", flNode)
|
||||||
|
.addNode("lrNode", lrNode)
|
||||||
|
|
||||||
.addNode(produceRanking.name, produceRanking)
|
.addNode(produceRanking.name, produceRanking)
|
||||||
.addNode(sort.name, sort)
|
.addNode(sort.name, sort)
|
||||||
@@ -45,9 +52,13 @@ const agent = new StateGraph(MessagesState)
|
|||||||
.addConditionalEdges("triggerEventModel", triggerEventToolConditional, ["triggerEventToolNode", verificationSetup.name])
|
.addConditionalEdges("triggerEventModel", triggerEventToolConditional, ["triggerEventToolNode", verificationSetup.name])
|
||||||
.addEdge("triggerEventToolNode", "triggerEventModel")
|
.addEdge("triggerEventToolNode", "triggerEventModel")
|
||||||
|
|
||||||
.addEdge(verificationSetup.name, robertaMetrics.name)
|
.addEdge(verificationSetup.name, "roNode")
|
||||||
|
.addEdge(verificationSetup.name, "flNode")
|
||||||
|
.addEdge(verificationSetup.name, "lrNode")
|
||||||
|
|
||||||
.addEdge(robertaMetrics.name, produceRanking.name)
|
.addEdge("roNode", produceRanking.name)
|
||||||
|
.addEdge("flNode", produceRanking.name)
|
||||||
|
.addEdge("lrNode", produceRanking.name)
|
||||||
|
|
||||||
// @ts-expect-error
|
// @ts-expect-error
|
||||||
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
|
import { MessagesState } from "../state";
|
||||||
|
import { AIMessage } from "@langchain/core/messages";
|
||||||
|
import { evaluateWithEnsemble } from "../tools/ensembleCall";
|
||||||
|
|
||||||
|
export function createEnsembleNode(title: string, method: string): GraphNode<typeof MessagesState> {
|
||||||
|
return async (state) => {
|
||||||
|
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
||||||
|
|
||||||
|
const result = await evaluateWithEnsemble({ answer, method })
|
||||||
|
const score = result.validProb - result.invalidProb;
|
||||||
|
|
||||||
|
return {
|
||||||
|
messages: [new AIMessage(title + ":" + score)]
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -9,7 +9,7 @@ export function createModelNode(tools: any, promptPath: string): GraphNode<typeo
|
|||||||
const sysPrompt = await hydratePrompt(promptPath, state);
|
const sysPrompt = await hydratePrompt(promptPath, state);
|
||||||
|
|
||||||
const model = new ChatOpenAI({
|
const model = new ChatOpenAI({
|
||||||
model: "gpt-5-mini"
|
model: "gpt-5.4-nano"
|
||||||
});
|
});
|
||||||
const modelWithTools = model.bindTools(Object.values(tools));
|
const modelWithTools = model.bindTools(Object.values(tools));
|
||||||
|
|
||||||
|
|||||||
@@ -2,31 +2,25 @@ import { GraphNode } from "@langchain/langgraph";
|
|||||||
import { MessagesState } from "../state";
|
import { MessagesState } from "../state";
|
||||||
import { BaseMessage } from "@langchain/core/messages";
|
import { BaseMessage } from "@langchain/core/messages";
|
||||||
|
|
||||||
//TODO: Each of these might need different weights
|
const models = {
|
||||||
const keys = ["CONFIDENCE", "RELATION", "RAGAS", "ROBERTA"];
|
REGRESSION: 0.3,
|
||||||
|
ROBERTA: 0.5,
|
||||||
const mapping = {
|
FLAN: 0.3,
|
||||||
VERYHIGH: 1.0,
|
|
||||||
HIGH: 0.75,
|
|
||||||
MEDIUM: 0.5,
|
|
||||||
LOW: 0.25,
|
|
||||||
VERYLOW: 0.0,
|
|
||||||
} as const;
|
} as const;
|
||||||
|
|
||||||
type Priority = keyof typeof mapping;
|
type ModelKey = keyof typeof models;
|
||||||
|
|
||||||
function mapResponse(value: string | undefined | null): number {
|
function mapResponse(value: string | undefined | null): number {
|
||||||
if (!value) return 1;
|
if (!value) return 0;
|
||||||
|
|
||||||
const trimmed = value.trim();
|
const trimmed = value.trim();
|
||||||
const num = parseFloat(trimmed);
|
const num = parseFloat(trimmed);
|
||||||
|
|
||||||
// If number, return it
|
if (!isNaN(num)) {
|
||||||
if (!isNaN(num)) return num;
|
return num;
|
||||||
|
} else {
|
||||||
// Otherwise, map to value
|
return 0;
|
||||||
const upper = trimmed.toUpperCase() as Priority;
|
}
|
||||||
return mapping[upper] ?? 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getLastMessageContaining(
|
function getLastMessageContaining(
|
||||||
@@ -43,15 +37,15 @@ function getLastMessageContaining(
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const produceRanking: GraphNode<typeof MessagesState> = async (state) => {
|
export const produceRanking: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
// Extract and map values
|
const values = (Object.keys(models) as ModelKey[]).map((key) => {
|
||||||
const values = keys.map((key) => {
|
|
||||||
const msg = getLastMessageContaining(state.messages, key);
|
const msg = getLastMessageContaining(state.messages, key);
|
||||||
const part = msg?.split(":").at(1);
|
const part = msg?.split(":").at(1);
|
||||||
return mapResponse(part);
|
const baseValue = mapResponse(part);
|
||||||
|
|
||||||
|
return baseValue * models[key];
|
||||||
});
|
});
|
||||||
|
|
||||||
// Multiply!
|
const result = values.reduce((acc, val) => acc + val, 0);
|
||||||
const result = values.reduce((acc, val) => acc * val, 1);
|
|
||||||
|
|
||||||
const current = state.proposedTriggerEvent;
|
const current = state.proposedTriggerEvent;
|
||||||
current[state.proposedTriggerEventIndex].score = result;
|
current[state.proposedTriggerEventIndex].score = result;
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
|
||||||
import { MessagesState } from "../state";
|
|
||||||
import { AIMessage, HumanMessage } from "@langchain/core/messages";
|
|
||||||
import { evaluateWithRagas } from "../tools/ragasCall";
|
|
||||||
|
|
||||||
export const ragasMetrics: GraphNode<typeof MessagesState> = async (state) => {
|
|
||||||
const question = "A possible trigger event for: " + state.disinformationTitle //Should it be raw, or normalized?
|
|
||||||
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
|
||||||
const contexts = state.proposedTriggerEvent[state.proposedTriggerEventIndex].context?.split("^^^") ?? []
|
|
||||||
|
|
||||||
const results = await evaluateWithRagas({question, answer, contexts})
|
|
||||||
|
|
||||||
return {
|
|
||||||
messages: [ new AIMessage("RAGAS:" + results.faithfulness)]
|
|
||||||
};
|
|
||||||
};
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
|
||||||
import { MessagesState } from "../state";
|
|
||||||
import { AIMessage } from "@langchain/core/messages";
|
|
||||||
import { evaluateWithRoberta } from "../tools/robertaCall";
|
|
||||||
|
|
||||||
export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) => {
|
|
||||||
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
|
||||||
|
|
||||||
const result = await evaluateWithRoberta({answer})
|
|
||||||
|
|
||||||
|
|
||||||
const score = result.validProb - result.invalidProb;
|
|
||||||
|
|
||||||
|
|
||||||
return {
|
|
||||||
messages: [ new AIMessage("ROBERTA:" + score)]
|
|
||||||
};
|
|
||||||
};
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
import { MessagesState, ProposedTriggerEventArray } from "../state";
|
import { MessagesState, ProposedTriggerEventArray } from "../state";
|
||||||
import { logger } from "../utils/logger";
|
import { logger } from "../utils/logger";
|
||||||
import { queryScraper } from "../tools/webSearch";
|
import { jsonrepair } from 'jsonrepair'
|
||||||
import { rankAndDisplayData } from "../tools/triggerEventTools";
|
|
||||||
|
|
||||||
export const verificationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
export const verificationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
//this is kinda doing two things, but having two nodes for it seems overkill
|
//this is kinda doing two things, but having two nodes for it seems overkill
|
||||||
@@ -11,14 +10,29 @@ export const verificationSetup: GraphNode<typeof MessagesState> = async (state)
|
|||||||
logger.warn("No trigger events in memory, parsing")
|
logger.warn("No trigger events in memory, parsing")
|
||||||
|
|
||||||
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
|
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
|
||||||
const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse));
|
|
||||||
|
|
||||||
for (let i = 0; i < parsed.length; i++) {
|
const repaired = jsonrepair(genResponse);
|
||||||
const search = parsed[i].SearchQuery
|
|
||||||
const data = await queryScraper(search);
|
|
||||||
const output = await rankAndDisplayData(data, search);
|
|
||||||
|
|
||||||
parsed[i].context = output;
|
let parsed;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const json = JSON.parse(repaired);
|
||||||
|
|
||||||
|
if (Array.isArray(json)) {
|
||||||
|
parsed = ProposedTriggerEventArray.parse(json);
|
||||||
|
} else {
|
||||||
|
// try grab first value
|
||||||
|
const firstValue = Object.values(json)[0];
|
||||||
|
|
||||||
|
if (Array.isArray(firstValue)) {
|
||||||
|
parsed = ProposedTriggerEventArray.parse(firstValue);
|
||||||
|
} else {
|
||||||
|
throw new Error("No array found in JSON");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
logger.error(`Failed to parse LLM response: ${err.message}`);
|
||||||
|
throw new Error(`Failed to parse LLM response: ${err}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 };
|
return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 };
|
||||||
|
|||||||
Generated
+10
@@ -20,6 +20,7 @@
|
|||||||
"dotenv": "^17.2.3",
|
"dotenv": "^17.2.3",
|
||||||
"exponential-backoff": "^3.1.3",
|
"exponential-backoff": "^3.1.3",
|
||||||
"fs": "^0.0.1-security",
|
"fs": "^0.0.1-security",
|
||||||
|
"jsonrepair": "^3.13.3",
|
||||||
"langchain": "^1.2.14",
|
"langchain": "^1.2.14",
|
||||||
"selenium-webdriver": "^4.40.0",
|
"selenium-webdriver": "^4.40.0",
|
||||||
"tldts": "^7.0.23",
|
"tldts": "^7.0.23",
|
||||||
@@ -2075,6 +2076,15 @@
|
|||||||
"integrity": "sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==",
|
"integrity": "sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==",
|
||||||
"license": "ISC"
|
"license": "ISC"
|
||||||
},
|
},
|
||||||
|
"node_modules/jsonrepair": {
|
||||||
|
"version": "3.13.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/jsonrepair/-/jsonrepair-3.13.3.tgz",
|
||||||
|
"integrity": "sha512-BTznj0owIt2CBAH/LTo7+1I5pMvl1e1033LRl/HUowlZmJOIhzC0zbX5bxMngLkfT4WnzPP26QnW5wMr2g9tsQ==",
|
||||||
|
"license": "ISC",
|
||||||
|
"bin": {
|
||||||
|
"jsonrepair": "bin/cli.js"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/jszip": {
|
"node_modules/jszip": {
|
||||||
"version": "3.10.1",
|
"version": "3.10.1",
|
||||||
"resolved": "https://registry.npmjs.org/jszip/-/jszip-3.10.1.tgz",
|
"resolved": "https://registry.npmjs.org/jszip/-/jszip-3.10.1.tgz",
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
"dotenv": "^17.2.3",
|
"dotenv": "^17.2.3",
|
||||||
"exponential-backoff": "^3.1.3",
|
"exponential-backoff": "^3.1.3",
|
||||||
"fs": "^0.0.1-security",
|
"fs": "^0.0.1-security",
|
||||||
|
"jsonrepair": "^3.13.3",
|
||||||
"langchain": "^1.2.14",
|
"langchain": "^1.2.14",
|
||||||
"selenium-webdriver": "^4.40.0",
|
"selenium-webdriver": "^4.40.0",
|
||||||
"tldts": "^7.0.23",
|
"tldts": "^7.0.23",
|
||||||
|
|||||||
@@ -8,13 +8,19 @@ Produce up-to 5 specific "trigger events" that happened that could have led to t
|
|||||||
Remember the time frame of the disinformation campaign: ###CDATE###
|
Remember the time frame of the disinformation campaign: ###CDATE###
|
||||||
Include no information or events that would not have been available at the time.
|
Include no information or events that would not have been available at the time.
|
||||||
|
|
||||||
|
You MEED TO use the tools available to you in order to produce up to date information on URL and search query, else you will be wrong and the analysis invalid.
|
||||||
|
You NEED TO use the web search and open URL tools to ensure page validity or else all work upto this point will have to be discarded.
|
||||||
|
|
||||||
|
|
||||||
Produce no more text other than the json.
|
Produce no more text other than the json.
|
||||||
|
|
||||||
Include a concise but specific search query that can be looked up on a search engine in order to allow for the verification.
|
Include a concise but specific search query that can be looked up on a search engine in order to allow for the verification.
|
||||||
|
|
||||||
Include a url to a source for your trigger event (not a web search, a specific url from a reputuable source). Do not use OAI cite, include url as text in response.
|
Include a url to a source for your trigger event (not a web search, a specific url from a reputuable source). Do not use OAI cite, include url as text in response.
|
||||||
|
|
||||||
Use a JSON format with each entry containing "Event,ReasoningWhyRelevant,SearchQuery,Url".
|
Include the date that the event happened ("March 2022" for exmaple)
|
||||||
|
|
||||||
|
Use a JSON format with each entry containing "Event,ReasoningWhyRelevant,SearchQuery,Url,Date".
|
||||||
|
|
||||||
Multiple tool invocations should be requested at once, if applicable.
|
Multiple tool invocations should be requested at once, if applicable.
|
||||||
Use your abilities to look between the lines and produce some insightful analysis, thinking both short and long term.
|
Use your abilities to look between the lines and produce some insightful analysis, thinking both short and long term.
|
||||||
@@ -24,4 +30,9 @@ Events will be reordered as part of processing, each statement must stand alone
|
|||||||
The preceeding messages act as examples of previous responses to potentially ficitonal events and scores given.
|
The preceeding messages act as examples of previous responses to potentially ficitonal events and scores given.
|
||||||
Analysis should only be completed for proposed events that would graner >0.7 points
|
Analysis should only be completed for proposed events that would graner >0.7 points
|
||||||
|
|
||||||
|
This pipeline is running well pasy your knowledge cutoff.
|
||||||
|
Any URLs will change signigicantly over time.
|
||||||
|
You MEED TO use the tools available to you in order to produce up to date information on URL and search query, else you will be wrong and the analysis invalid.
|
||||||
|
You NEED TO use the web search and open URL tools to ensure page validity or else all work upto this point will have to be discarded.
|
||||||
|
|
||||||
Lets go through it step by step
|
Lets go through it step by step
|
||||||
@@ -9,6 +9,7 @@ export const ProposedTriggerEvent = z.object({
|
|||||||
ReasoningWhyRelevant: z.string(),
|
ReasoningWhyRelevant: z.string(),
|
||||||
SearchQuery: z.string(),
|
SearchQuery: z.string(),
|
||||||
Url: z.url(),
|
Url: z.url(),
|
||||||
|
Date: z.string(),
|
||||||
context: z.string().optional(),
|
context: z.string().optional(),
|
||||||
score: z.number().optional()
|
score: z.number().optional()
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
|
|
||||||
export async function evaluateWithRoberta({
|
export async function evaluateWithEnsemble({
|
||||||
answer
|
answer,
|
||||||
|
method
|
||||||
}: {
|
}: {
|
||||||
answer: string;
|
answer: string;
|
||||||
|
method: string
|
||||||
}): Promise<{ validProb: number; invalidProb: number; }> {
|
}): Promise<{ validProb: number; invalidProb: number; }> {
|
||||||
const res = await axios.post("http://localhost:8000/evaluate", {
|
const res = await axios.post("http://localhost:8000/evaluate", {
|
||||||
answer
|
answer,
|
||||||
});
|
method
|
||||||
|
}, {timeout: 0});
|
||||||
// console.log(res.data)
|
// console.log(res.data)
|
||||||
const validProb = res.data["probabilities"][0][0]
|
const validProb = res.data["probabilities"][0][0]
|
||||||
const invalidProb = res.data["probabilities"][0][1] + res.data["probabilities"][0][2]
|
const invalidProb = res.data["probabilities"][0][1] + res.data["probabilities"][0][2]
|
||||||
@@ -15,6 +15,8 @@ const CACHE_PATH = "../data/csv.cache.json";
|
|||||||
|
|
||||||
const JSONL_PATH = "../data/input.jsonl"
|
const JSONL_PATH = "../data/input.jsonl"
|
||||||
|
|
||||||
|
const BM25_MIN_DOCS = 3;
|
||||||
|
|
||||||
type EmbeddingCache = {
|
type EmbeddingCache = {
|
||||||
rawtexts: string[];
|
rawtexts: string[];
|
||||||
cleantexts: string[];
|
cleantexts: string[];
|
||||||
@@ -287,8 +289,20 @@ async function embedText(text: string): Promise<number[]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function buildBM25(texts: string[]) {
|
function buildBM25(texts: string[]) {
|
||||||
logger.info("Building BM25 index (%s docs)...", texts.length);
|
let paddedTexts = texts;
|
||||||
|
|
||||||
|
if (texts.length < BM25_MIN_DOCS) {
|
||||||
|
const needed = BM25_MIN_DOCS - texts.length;
|
||||||
|
logger.error(
|
||||||
|
"Corpus too small for BM25 (%s docs, need %s+), padding with %s dummy doc(s)",
|
||||||
|
texts.length,
|
||||||
|
BM25_MIN_DOCS,
|
||||||
|
needed
|
||||||
|
);
|
||||||
|
paddedTexts = [...texts, ...Array(needed).fill("placeholder dummy document")];
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("Building BM25 index (%s docs)...", paddedTexts.length);
|
||||||
const bm25 = bm25Factory();
|
const bm25 = bm25Factory();
|
||||||
|
|
||||||
bm25.defineConfig({
|
bm25.defineConfig({
|
||||||
@@ -302,7 +316,7 @@ function buildBM25(texts: string[]) {
|
|||||||
nlp.tokens.removeWords,
|
nlp.tokens.removeWords,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
texts.forEach((text, i) => {
|
paddedTexts.forEach((text, i) => {
|
||||||
bm25.addDoc({ text }, i);
|
bm25.addDoc({ text }, i);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +1,92 @@
|
|||||||
import { Builder, Browser } from "selenium-webdriver";
|
import { Builder, Browser } from "selenium-webdriver";
|
||||||
import firefox from "selenium-webdriver/firefox";
|
import firefox from "selenium-webdriver/firefox";
|
||||||
|
import { backOff } from "exponential-backoff";
|
||||||
|
import { logger } from "../utils/logger";
|
||||||
|
|
||||||
export async function extractWebpageContent(url: string): Promise<string[]> {
|
export async function extractWebpageContent(url: string): Promise<string[]> {
|
||||||
|
try {
|
||||||
|
const response = await backOff(async () => {
|
||||||
|
return await extractWebpageContentWorker(url);
|
||||||
|
}, {
|
||||||
|
numOfAttempts: 10,
|
||||||
|
startingDelay: 500,
|
||||||
|
timeMultiple: 2,
|
||||||
|
jitter: "full",
|
||||||
|
maxDelay: 50000,
|
||||||
|
});
|
||||||
|
return response;
|
||||||
|
} catch (err: any) {
|
||||||
|
logger.error(`Failed out of retry loop for URL "${url}", returning placeholder to pipeline`);
|
||||||
|
return ["API EXCEPTION"];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function extractWebpageContentWorker(url: string): Promise<string[]> {
|
||||||
|
let driver;
|
||||||
|
try {
|
||||||
const options = new firefox.Options();
|
const options = new firefox.Options();
|
||||||
options.addArguments("--headless");
|
options.addArguments("--headless");
|
||||||
|
driver = await new Builder()
|
||||||
|
.forBrowser(Browser.FIREFOX)
|
||||||
|
.setFirefoxOptions(options)
|
||||||
|
.build();
|
||||||
|
} catch (err: any) {
|
||||||
|
const desc = `Failed to launch Firefox driver: ${err.message}`;
|
||||||
|
logger.error(desc);
|
||||||
|
throw new Error(desc);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
try {
|
||||||
|
await driver.get(url);
|
||||||
|
} catch (err: any) {
|
||||||
|
const desc = `Failed to navigate to URL "${url}": ${err.message}`;
|
||||||
|
logger.error(desc);
|
||||||
|
throw new Error(desc);
|
||||||
|
}
|
||||||
|
|
||||||
let driver = await new Builder().forBrowser(Browser.FIREFOX).setFirefoxOptions(options).build()
|
|
||||||
try {
|
try {
|
||||||
await driver.get(url)
|
|
||||||
await driver.wait(async () => {
|
await driver.wait(async () => {
|
||||||
return await driver.executeScript(
|
return await driver.executeScript(
|
||||||
"return document.readyState === 'complete'"
|
"return document.readyState === 'complete'"
|
||||||
);
|
);
|
||||||
}, 5000);
|
}, 5000);
|
||||||
|
} catch (err: any) {
|
||||||
|
logger.error(`Page load timed out for "${url}", attempting to read partial content: ${err.message}`);
|
||||||
|
// do not throw, attempt to read
|
||||||
|
}
|
||||||
|
|
||||||
const readableText = await driver.executeScript(
|
let readableText: string;
|
||||||
|
try {
|
||||||
|
readableText = await driver.executeScript(
|
||||||
"return document.body.innerText;"
|
"return document.body.innerText;"
|
||||||
) as string;
|
) as string;
|
||||||
|
} catch (err: any) {
|
||||||
|
const desc = `Failed to extract page text from "${url}": ${err.message}`;
|
||||||
|
logger.error(desc);
|
||||||
|
throw new Error(desc);
|
||||||
|
}
|
||||||
|
|
||||||
const filteredLines = readableText
|
const filteredLines = readableText
|
||||||
.split(/\r?\n/)
|
.split(/\r?\n/)
|
||||||
.map(line => line.trim())
|
.map(line => line.trim())
|
||||||
.filter(line => line.split(/\s+/).length > 1);
|
.filter(line => line.split(/\s+/).length > 1);
|
||||||
|
|
||||||
|
if (filteredLines.length === 0) {
|
||||||
|
const desc = `No content extracted from "${url}"`;
|
||||||
|
logger.error(desc);
|
||||||
|
throw new Error(desc);
|
||||||
|
}
|
||||||
|
|
||||||
return filteredLines;
|
return filteredLines;
|
||||||
} finally {
|
} finally {
|
||||||
await driver.quit()
|
try {
|
||||||
|
await driver.quit();
|
||||||
|
} catch (err: any) {
|
||||||
|
logger.error(`Failed to quit Firefox driver cleanly: ${err.message}`);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// console.log(await extractWebpageContent("https://www.bbc.co.uk/news/live/c74wd01egvyt"))
|
// console.log(await extractWebpageContent("https://www.bbc.co.uk/news/live/c74wd01egvyt"))
|
||||||
|
// console.log(await extractWebpageContent("https://badcertificate.int.jeynes.uk/"))
|
||||||
+14
-17
@@ -1,38 +1,35 @@
|
|||||||
import { END, START, StateGraph } from "@langchain/langgraph";
|
import { END, START, StateGraph } from "@langchain/langgraph";
|
||||||
import { MessagesState } from "./state";
|
import { MessagesState } from "./state";
|
||||||
import { verificationSetup } from "./nodes/verificationSetup";
|
import { verificationSetup } from "./nodes/verificationSetup";
|
||||||
import { ragasMetrics } from "./nodes/ragasMetrics";
|
|
||||||
import { produceRanking } from "./nodes/produceRanking";
|
import { produceRanking } from "./nodes/produceRanking";
|
||||||
import { createModelNode } from "./nodes/model";
|
|
||||||
import { loopEndConditional } from "./conditionals/loop_end";
|
import { loopEndConditional } from "./conditionals/loop_end";
|
||||||
import { sort } from "./nodes/sort";
|
import { sort } from "./nodes/sort";
|
||||||
import { robertaMetrics } from "./nodes/robertaMetrics";
|
import { createEnsembleNode } from "./nodes/ensembleNode";
|
||||||
|
|
||||||
const verificationModel = createModelNode([], "verify.txt");
|
const roNode = createEnsembleNode("ROBERTA", "roberta");
|
||||||
const relationModel = createModelNode([], "relation.txt");
|
const flNode = createEnsembleNode("FLAN", "flan");
|
||||||
|
const lrNode = createEnsembleNode("REGRESSION", "logreg");
|
||||||
|
|
||||||
const agent = new StateGraph(MessagesState)
|
const agent = new StateGraph(MessagesState)
|
||||||
|
|
||||||
//NODES
|
//NODES
|
||||||
.addNode(verificationSetup.name, verificationSetup)
|
.addNode(verificationSetup.name, verificationSetup)
|
||||||
// .addNode("verificationModel", verificationModel)
|
.addNode("roNode", roNode)
|
||||||
// .addNode(ragasMetrics.name, ragasMetrics)
|
.addNode("flNode", flNode)
|
||||||
.addNode(robertaMetrics.name, robertaMetrics)
|
.addNode("lrNode", lrNode)
|
||||||
// .addNode("relationModel", relationModel)
|
|
||||||
|
|
||||||
.addNode(produceRanking.name, produceRanking)
|
.addNode(produceRanking.name, produceRanking)
|
||||||
.addNode(sort.name, sort)
|
.addNode(sort.name, sort)
|
||||||
|
|
||||||
.addEdge(START, verificationSetup.name)
|
.addEdge(START, verificationSetup.name)
|
||||||
// .addEdge(verificationSetup.name, "verificationModel")
|
|
||||||
// .addEdge(verificationSetup.name, ragasMetrics.name)
|
|
||||||
.addEdge(verificationSetup.name, robertaMetrics.name)
|
|
||||||
// .addEdge(verificationSetup.name, "relationModel")
|
|
||||||
|
|
||||||
// .addEdge(ragasMetrics.name, produceRanking.name)
|
.addEdge(verificationSetup.name, "roNode")
|
||||||
.addEdge(robertaMetrics.name, produceRanking.name)
|
.addEdge(verificationSetup.name, "flNode")
|
||||||
// .addEdge("verificationModel", produceRanking.name)
|
.addEdge(verificationSetup.name, "lrNode")
|
||||||
// .addEdge("relationModel", produceRanking.name)
|
|
||||||
|
.addEdge("roNode", produceRanking.name)
|
||||||
|
.addEdge("flNode", produceRanking.name)
|
||||||
|
.addEdge("lrNode", produceRanking.name)
|
||||||
|
|
||||||
// @ts-expect-error
|
// @ts-expect-error
|
||||||
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ set -e
|
|||||||
run_agent () {
|
run_agent () {
|
||||||
echo "Starting LangGraph agent..."
|
echo "Starting LangGraph agent..."
|
||||||
cd agent
|
cd agent
|
||||||
npx @langchain/langgraph-cli dev
|
npx @langchain/langgraph-cli dev --host 127.0.0.1
|
||||||
}
|
}
|
||||||
|
|
||||||
run_ragas_service () {
|
run_ensemble_service () {
|
||||||
echo "Starting RAGAS service..."
|
echo "Starting Ensemble service..."
|
||||||
cd "supporting/RAGAS_Service"
|
cd "supporting/RAGAS_Service"
|
||||||
.venv/bin/uvicorn ragas_service:app --port 8001
|
.venv/bin/uvicorn ensemble_service:app --timeout-keep-alive 300
|
||||||
}
|
}
|
||||||
|
|
||||||
run_frontend () {
|
run_frontend () {
|
||||||
@@ -34,13 +34,13 @@ run_wrapper () {
|
|||||||
|
|
||||||
case "$1" in
|
case "$1" in
|
||||||
agent) run_agent ;;
|
agent) run_agent ;;
|
||||||
ragas_service) run_ragas_service ;;
|
ensemble_service) run_ensemble_service ;;
|
||||||
frontend) run_frontend ;;
|
frontend) run_frontend ;;
|
||||||
fetch) run_fetch ;;
|
fetch) run_fetch ;;
|
||||||
wrapper) run_wrapper ;;
|
wrapper) run_wrapper ;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown command: $1"
|
echo "Unknown command: $1"
|
||||||
echo "Usage: ./runproject [agent|ragas_service|frontend|fetch|wrapper]"
|
echo "Usage: ./runproject [agent|ensemble_service|frontend|fetch|wrapper]"
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
# -- OURS --
|
# -- OURS --
|
||||||
results/
|
results/
|
||||||
roberta_classifier/
|
roberta_classifier/
|
||||||
|
roberta_distilled_classifier/
|
||||||
roberta_classifier*/
|
roberta_classifier*/
|
||||||
|
*.pt
|
||||||
output*
|
output*
|
||||||
|
|
||||||
# -- THEIRS --
|
# -- THEIRS --
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
# Classifier work for evaluating model quality
|
||||||
|
|
||||||
|
Made using a dataset of 1000 labeled claims from MVP pipeline.
|
||||||
|
|
||||||
|
Roberta model trained on an augmented dataset with LLM generated adversarial examples for low frequency labels.
|
||||||
|
|
||||||
|
Flan model trained using raw labelled claims, inherrent natural language ability allows for pattern recognition without the need for fake data.
|
||||||
|
|
||||||
|
Regression model trained using the roberta dataset.
|
||||||
|
|
||||||
|
Used ensemble model in the final version, with the component models available on Hugging Face.
|
||||||
|
|
||||||
|
| Model | % Correct | % Valid taken forward|Used in ensemble|Link
|
||||||
|
|------------------------------------------------------------|-----------|----------------------|----------------|-
|
||||||
|
| Original | 53.22 | 61.72 |
|
||||||
|
| Original (RAGAS) | 56.01 | 57.73 |
|
||||||
|
| Roberta (base) | 75 | 70 |
|
||||||
|
| Roberta (Generated Data) | 76 | 71 |
|
||||||
|
| Roberta (Generated Data + Back Translation) | 74 | 71 |
|
||||||
|
| Roberta (Generated Data + Back Translation + Thresholding) | 77 | 90 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis)
|
||||||
|
| Distilled Roberta | 72.73 | 69.57 |
|
||||||
|
| Flan | 79.17 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Flan)
|
||||||
|
| Simple Regression Model | 74.77 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Regression)
|
||||||
|
| Ensemble Model (weighted confidence score sum) | 84.21 | 83.33 |
|
||||||
|
| Ensemble Model (majority voting) | 80.2 | 95.12 |
|
||||||
@@ -0,0 +1,224 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi import FastAPI
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import os
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
############################################
|
||||||
|
# SCHEMA
|
||||||
|
############################################
|
||||||
|
|
||||||
|
class EvalRequest(BaseModel):
|
||||||
|
answer: str
|
||||||
|
method: str # "logreg", "roberta", "flan"
|
||||||
|
|
||||||
|
|
||||||
|
############################################
|
||||||
|
# REGRESSION MODEL
|
||||||
|
############################################
|
||||||
|
|
||||||
|
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
|
||||||
|
MODEL_FILENAME = "logreg_classifier.pt"
|
||||||
|
CACHE_DIR = "./model_cache"
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict:
|
||||||
|
local_path = os.path.join(cache_dir, filename)
|
||||||
|
if not os.path.exists(local_path):
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=cache_dir)
|
||||||
|
return torch.load(local_path, map_location="cpu")
|
||||||
|
|
||||||
|
|
||||||
|
class LogisticNet(nn.Module):
|
||||||
|
def __init__(self, input_dim, hidden_dim, num_classes, dropout):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.BatchNorm1d(hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, num_classes),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR)
|
||||||
|
|
||||||
|
encoder = SentenceTransformer(checkpoint["embedding_model"])
|
||||||
|
|
||||||
|
logreg_model = LogisticNet(
|
||||||
|
checkpoint["input_dim"],
|
||||||
|
checkpoint["hidden_dim"],
|
||||||
|
checkpoint["num_classes"],
|
||||||
|
checkpoint["dropout"],
|
||||||
|
)
|
||||||
|
logreg_model.load_state_dict(checkpoint["model_state"])
|
||||||
|
logreg_model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
############################################
|
||||||
|
# ROBERTA
|
||||||
|
############################################
|
||||||
|
|
||||||
|
ROBERTA_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
|
||||||
|
|
||||||
|
roberta_tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_PATH)
|
||||||
|
roberta_model = RobertaForSequenceClassification.from_pretrained(ROBERTA_PATH)
|
||||||
|
roberta_model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
############################################
|
||||||
|
# FLAN
|
||||||
|
############################################
|
||||||
|
|
||||||
|
FLAN_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
|
||||||
|
|
||||||
|
INT_TO_LABEL = {
|
||||||
|
0: "perfect",
|
||||||
|
1: "story",
|
||||||
|
2: "not specific",
|
||||||
|
}
|
||||||
|
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
||||||
|
|
||||||
|
flan_tokenizer = AutoTokenizer.from_pretrained(FLAN_PATH)
|
||||||
|
flan_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_PATH)
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
flan_model.to(device)
|
||||||
|
flan_model.eval()
|
||||||
|
|
||||||
|
label_token_ids = {
|
||||||
|
label: flan_tokenizer(label, add_special_tokens=False).input_ids[0]
|
||||||
|
for label in LABEL_TO_INT.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt(text: str) -> str:
|
||||||
|
return (
|
||||||
|
"Classify the following event into one of these categories: "
|
||||||
|
"perfect, story, not specific.\n\n"
|
||||||
|
f"Event: {text}\n\n"
|
||||||
|
"Category:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_generated_label(generated: str):
|
||||||
|
generated = generated.strip().lower()
|
||||||
|
for label_text, label_int in LABEL_TO_INT.items():
|
||||||
|
if label_text in generated:
|
||||||
|
return label_int
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
############################################
|
||||||
|
# ENDPOINT
|
||||||
|
############################################
|
||||||
|
|
||||||
|
@app.post("/evaluate")
|
||||||
|
def evaluate(req: EvalRequest):
|
||||||
|
method = req.method.lower()
|
||||||
|
|
||||||
|
########################################
|
||||||
|
# LOGREG
|
||||||
|
########################################
|
||||||
|
if method == "logreg":
|
||||||
|
embedding = encoder.encode(
|
||||||
|
[req.answer],
|
||||||
|
normalize_embeddings=True,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.tensor(embedding, dtype=torch.float32)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = logreg_model(x)
|
||||||
|
|
||||||
|
probs = torch.softmax(logits, dim=1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"method": "logreg",
|
||||||
|
"probabilities": probs.cpu().numpy().tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
########################################
|
||||||
|
# ROBERTA
|
||||||
|
########################################
|
||||||
|
elif method == "roberta":
|
||||||
|
inputs = roberta_tokenizer(
|
||||||
|
req.answer,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
padding=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = roberta_model(**inputs).logits
|
||||||
|
|
||||||
|
probs = torch.softmax(logits, dim=1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"method": "roberta",
|
||||||
|
"probabilities": probs.cpu().numpy().tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
########################################
|
||||||
|
# FLAN
|
||||||
|
########################################
|
||||||
|
elif method == "flan":
|
||||||
|
prompt = format_prompt(req.answer)
|
||||||
|
|
||||||
|
inputs = flan_tokenizer(
|
||||||
|
prompt,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
padding=True,
|
||||||
|
max_length=256,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = flan_model.generate(**inputs, max_new_tokens=8)
|
||||||
|
|
||||||
|
decoder_input_ids = torch.tensor(
|
||||||
|
[[flan_model.config.decoder_start_token_id]]
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
logits_output = flan_model(
|
||||||
|
**inputs,
|
||||||
|
decoder_input_ids=decoder_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = logits_output.logits[:, 0, :]
|
||||||
|
|
||||||
|
generated_text = flan_tokenizer.decode(
|
||||||
|
outputs[0],
|
||||||
|
skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
label_logits = torch.tensor(
|
||||||
|
[logits[0, tid].item() for tid in label_token_ids.values()]
|
||||||
|
)
|
||||||
|
|
||||||
|
label_probs = torch.softmax(label_logits, dim=0).tolist()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"method": "flan",
|
||||||
|
"generated": generated_text,
|
||||||
|
"probabilities": [label_probs],
|
||||||
|
}
|
||||||
|
|
||||||
|
########################################
|
||||||
|
# INVALID METHOD
|
||||||
|
########################################
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": "Invalid method. Use 'logreg', 'roberta', or 'flan'."
|
||||||
|
}
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
|
import torch
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
|
||||||
|
|
||||||
|
INT_TO_LABEL = {
|
||||||
|
0: "perfect",
|
||||||
|
1: "story",
|
||||||
|
2: "not specific",
|
||||||
|
}
|
||||||
|
|
||||||
|
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt(text: str) -> str:
|
||||||
|
return (
|
||||||
|
"Classify the following event into one of these categories: "
|
||||||
|
"perfect, story, not specific.\n\n"
|
||||||
|
f"Event: {text}\n\n"
|
||||||
|
"Category:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_generated_label(generated: str) -> int | None:
|
||||||
|
generated = generated.strip().lower()
|
||||||
|
for label_text, label_int in LABEL_TO_INT.items():
|
||||||
|
if label_text in generated:
|
||||||
|
return label_int
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class EvalRequest(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/evaluate")
|
||||||
|
def evaluate(req: EvalRequest):
|
||||||
|
prompt = format_prompt(req.answer)
|
||||||
|
|
||||||
|
inputs = tokenizer(
|
||||||
|
prompt,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
padding=True,
|
||||||
|
max_length=256,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Get the generated label
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Produce a confidence score
|
||||||
|
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)
|
||||||
|
logits_output = model(**inputs, decoder_input_ids=decoder_input_ids)
|
||||||
|
logits = logits_output.logits[:, 0, :]
|
||||||
|
|
||||||
|
# Decode the generated text label
|
||||||
|
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
predicted_int = parse_generated_label(generated_text)
|
||||||
|
|
||||||
|
# Extract probabilities
|
||||||
|
label_token_ids = {
|
||||||
|
label: tokenizer(label, add_special_tokens=False).input_ids[0]
|
||||||
|
for label in LABEL_TO_INT.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
label_logits = torch.tensor(
|
||||||
|
[logits[0, tid].item() for tid in label_token_ids.values()]
|
||||||
|
)
|
||||||
|
label_probs = torch.softmax(label_logits, dim=0).tolist()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"generated": generated_text,
|
||||||
|
"probabilities": [label_probs],
|
||||||
|
}
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from fastapi import FastAPI
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
import os
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
|
||||||
|
MODEL_FILENAME = "logreg_classifier.pt"
|
||||||
|
CACHE_DIR = "./model_cache"
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict:
|
||||||
|
local_path = os.path.join(cache_dir, filename)
|
||||||
|
if not os.path.exists(local_path):
|
||||||
|
print(f"Downloading {filename} from {repo_id}...")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
downloaded = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
local_dir=cache_dir,
|
||||||
|
)
|
||||||
|
print(f"Saved to {downloaded}")
|
||||||
|
else:
|
||||||
|
print(f"Using cached model at {local_path}")
|
||||||
|
return torch.load(local_path, map_location="cpu")
|
||||||
|
|
||||||
|
|
||||||
|
class LogisticNet(nn.Module):
|
||||||
|
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.BatchNorm1d(hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, num_classes),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR)
|
||||||
|
|
||||||
|
encoder = SentenceTransformer(checkpoint["embedding_model"])
|
||||||
|
|
||||||
|
model = LogisticNet(
|
||||||
|
input_dim = checkpoint["input_dim"],
|
||||||
|
hidden_dim = checkpoint["hidden_dim"],
|
||||||
|
num_classes = checkpoint["num_classes"],
|
||||||
|
dropout = checkpoint["dropout"],
|
||||||
|
)
|
||||||
|
model.load_state_dict(checkpoint["model_state"])
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
class EvalRequest(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/evaluate")
|
||||||
|
def evaluate(req: EvalRequest):
|
||||||
|
embedding = encoder.encode(
|
||||||
|
[req.answer],
|
||||||
|
normalize_embeddings=True,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.tensor(embedding, dtype=torch.float32)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(x)
|
||||||
|
|
||||||
|
probs = torch.softmax(logits, dim=1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"probabilities": probs.cpu().numpy().tolist()
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ datasets
|
|||||||
# ROBERTA
|
# ROBERTA
|
||||||
scikit-learn
|
scikit-learn
|
||||||
transformers[torch]
|
transformers[torch]
|
||||||
|
sentence_transformers
|
||||||
|
|
||||||
# Utils
|
# Utils
|
||||||
numpy
|
numpy
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
MODEL_PATH = "./roberta_classifier"
|
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
|
||||||
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
||||||
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
||||||
|
|||||||
@@ -0,0 +1,227 @@
|
|||||||
|
from sklearn.utils import compute_class_weight
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
|
||||||
|
import torch
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
|
from collections import Counter
|
||||||
|
import sys
|
||||||
|
import csv
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
NUM_CLASSES = 3
|
||||||
|
model_name = "google/flan-t5-base"
|
||||||
|
|
||||||
|
INT_TO_LABEL = {
|
||||||
|
0: "perfect",
|
||||||
|
1: "story",
|
||||||
|
2: "not specific",
|
||||||
|
}
|
||||||
|
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
||||||
|
|
||||||
|
LABEL_PRIORITY = [
|
||||||
|
("PERFECT", 0),
|
||||||
|
("STORY", 1),
|
||||||
|
("NSPECIFIC", 2),
|
||||||
|
("REWORDING", 1),
|
||||||
|
("TINCORRECT", -1),
|
||||||
|
("DUPLICATE", -1),
|
||||||
|
("", 0),
|
||||||
|
]
|
||||||
|
|
||||||
|
def label_to_int(extra_info: str) -> int:
|
||||||
|
if extra_info is None:
|
||||||
|
extra_info = ""
|
||||||
|
extra_info = extra_info.strip()
|
||||||
|
if extra_info == "":
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key == "":
|
||||||
|
return value
|
||||||
|
raise ValueError("Empty extra_info but no empty mapping defined")
|
||||||
|
tokens = set(extra_info.upper().split())
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key == "" :
|
||||||
|
continue
|
||||||
|
if key in tokens:
|
||||||
|
return value
|
||||||
|
raise ValueError(f"Unknown label content: '{extra_info}'")
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_from_csv(path):
|
||||||
|
texts = []
|
||||||
|
labels = []
|
||||||
|
removed_rows = 0
|
||||||
|
with open(path, newline="", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for i, row in enumerate(reader, start=1):
|
||||||
|
text = row["event"]
|
||||||
|
label_str = row["extra_info"]
|
||||||
|
try:
|
||||||
|
label_int = label_to_int(label_str)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR converting label on line {i}: {label_str}")
|
||||||
|
print(e)
|
||||||
|
sys.exit(1)
|
||||||
|
if label_int == -1:
|
||||||
|
removed_rows += 1
|
||||||
|
continue
|
||||||
|
texts.append(text)
|
||||||
|
labels.append(label_int)
|
||||||
|
print(f"Loaded {len(texts)} samples (removed {removed_rows})")
|
||||||
|
return texts, labels
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt(text: str) -> str:
|
||||||
|
return (
|
||||||
|
"Classify the following event into one of these categories: "
|
||||||
|
"perfect, story, not specific.\n\n"
|
||||||
|
f"Event: {text}\n\n"
|
||||||
|
"Category:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_generated_label(generated: str) -> int:
|
||||||
|
generated = generated.strip().lower()
|
||||||
|
for label_text, label_int in LABEL_TO_INT.items():
|
||||||
|
if label_text in generated:
|
||||||
|
return label_int
|
||||||
|
print("invlid label:" + generated)
|
||||||
|
return -1 # unknown / unparseable output
|
||||||
|
|
||||||
|
|
||||||
|
class GenerativeTextDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, texts, labels, tokenizer, max_input_length=256, max_target_length=8):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_input_length = max_input_length
|
||||||
|
self.max_target_length = max_target_length
|
||||||
|
|
||||||
|
self.inputs = [format_prompt(t) for t in texts]
|
||||||
|
# Convert integer labels to their text equivalents for the target sequence
|
||||||
|
self.targets = [INT_TO_LABEL[l] for l in labels]
|
||||||
|
self.int_labels = labels # keep originals for metric computation
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.inputs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
model_inputs = self.tokenizer(
|
||||||
|
self.inputs[idx],
|
||||||
|
max_length=self.max_input_length,
|
||||||
|
truncation=True,
|
||||||
|
padding=False,
|
||||||
|
)
|
||||||
|
target_encoding = self.tokenizer(
|
||||||
|
self.targets[idx],
|
||||||
|
max_length=self.max_target_length,
|
||||||
|
truncation=True,
|
||||||
|
padding=False,
|
||||||
|
)
|
||||||
|
# Seq2Seq convention: labels use -100 to ignore padding tokens in loss
|
||||||
|
labels = target_encoding["input_ids"]
|
||||||
|
labels = [token if token != self.tokenizer.pad_token_id else -100 for token in labels]
|
||||||
|
|
||||||
|
model_inputs["labels"] = labels
|
||||||
|
return {k: torch.tensor(v) for k, v in model_inputs.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_metrics_generative(eval_pred, tokenizer):
|
||||||
|
predictions, label_ids = eval_pred
|
||||||
|
|
||||||
|
# Decode predictions
|
||||||
|
# Replace -100 in labels before decoding
|
||||||
|
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
|
||||||
|
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Map decoded text back to integer labels
|
||||||
|
pred_ints = [parse_generated_label(p) for p in decoded_preds]
|
||||||
|
true_ints = [parse_generated_label(l) for l in decoded_labels]
|
||||||
|
|
||||||
|
# Filter out any rows where parsing failed
|
||||||
|
valid = [(p, t) for p, t in zip(pred_ints, true_ints) if t != -1]
|
||||||
|
if not valid:
|
||||||
|
return {"accuracy": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
|
||||||
|
|
||||||
|
preds_filtered, true_filtered = zip(*valid)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"accuracy": accuracy_score(true_filtered, preds_filtered),
|
||||||
|
"f1": f1_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
|
||||||
|
"precision": precision_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
|
||||||
|
"recall": recall_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
torch.multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
print("CUDA available:", torch.cuda.is_available())
|
||||||
|
print("CUDA device count:", torch.cuda.device_count())
|
||||||
|
|
||||||
|
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
||||||
|
|
||||||
|
print("Dataset size:", len(texts))
|
||||||
|
print("Label distribution:", Counter(labels))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||||
|
|
||||||
|
train_texts, val_texts, train_labels, val_labels = train_test_split(
|
||||||
|
texts, labels,
|
||||||
|
test_size=0.2,
|
||||||
|
random_state=42,
|
||||||
|
stratify=labels
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = GenerativeTextDataset(train_texts, train_labels, tokenizer)
|
||||||
|
val_dataset = GenerativeTextDataset(val_texts, val_labels, tokenizer)
|
||||||
|
|
||||||
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
model=model,
|
||||||
|
padding=True,
|
||||||
|
label_pad_token_id=-100,
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args = Seq2SeqTrainingArguments(
|
||||||
|
output_dir="./results",
|
||||||
|
learning_rate=5e-5,
|
||||||
|
per_device_train_batch_size=16,
|
||||||
|
per_device_eval_batch_size=16,
|
||||||
|
num_train_epochs=10,
|
||||||
|
weight_decay=0.01,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="epoch",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="f1",
|
||||||
|
greater_is_better=True,
|
||||||
|
predict_with_generate=True,
|
||||||
|
generation_max_length=8,
|
||||||
|
dataloader_num_workers=0,
|
||||||
|
dataloader_pin_memory=False,
|
||||||
|
fp16=False,
|
||||||
|
max_grad_norm=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Seq2SeqTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
compute_metrics=lambda ep: compute_metrics_generative(ep, tokenizer),
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
metrics = trainer.evaluate()
|
||||||
|
print("\nFinal evaluation metrics:")
|
||||||
|
for k, v in metrics.items():
|
||||||
|
print(f" {k}: {v}")
|
||||||
|
|
||||||
|
trainer.save_model("./flan_classifier")
|
||||||
|
tokenizer.save_pretrained("./flan_classifier")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.utils import compute_class_weight
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
|
from collections import Counter
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
import numpy as np
|
||||||
|
import csv
|
||||||
|
import sys
|
||||||
|
|
||||||
|
NUM_CLASSES = 3
|
||||||
|
EMBEDDING_MODEL = "all-mpnet-base-v2"
|
||||||
|
HIDDEN_DIM = 256
|
||||||
|
DROPOUT = 0.4
|
||||||
|
LEARNING_RATE = 2e-3
|
||||||
|
WEIGHT_DECAY = 1e-4
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
NUM_EPOCHS = 30
|
||||||
|
PATIENCE = 5
|
||||||
|
|
||||||
|
LABEL_PRIORITY = [
|
||||||
|
("PERFECT", 0),
|
||||||
|
("STORY", 1),
|
||||||
|
("NSPECIFIC", 2),
|
||||||
|
("REWORDING", 1),
|
||||||
|
("TINCORRECT", -1),
|
||||||
|
("DUPLICATE", -1),
|
||||||
|
("", 0), # fallback to PERFECT
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def label_to_int(extra_info: str) -> int:
|
||||||
|
if extra_info is None:
|
||||||
|
extra_info = ""
|
||||||
|
extra_info = extra_info.strip()
|
||||||
|
|
||||||
|
if extra_info == "":
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key == "":
|
||||||
|
return value
|
||||||
|
raise ValueError("No empty-string fallback defined in LABEL_PRIORITY")
|
||||||
|
|
||||||
|
tokens = set(extra_info.upper().split())
|
||||||
|
for key, value in LABEL_PRIORITY:
|
||||||
|
if key and key in tokens:
|
||||||
|
return value
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown label content: '{extra_info}'")
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_from_csv(path: str):
|
||||||
|
texts, labels = [], []
|
||||||
|
removed = 0
|
||||||
|
|
||||||
|
with open(path, newline="", encoding="utf-8") as f:
|
||||||
|
for i, row in enumerate(csv.DictReader(f), start=1):
|
||||||
|
try:
|
||||||
|
label_int = label_to_int(row["extra_info"])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR on line {i}: {row['extra_info']!r}")
|
||||||
|
print(e)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if label_int == -1:
|
||||||
|
removed += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
texts.append(row["event"])
|
||||||
|
labels.append(label_int)
|
||||||
|
|
||||||
|
print(f"Loaded {len(texts)} samples (removed {removed})")
|
||||||
|
return texts, labels
|
||||||
|
|
||||||
|
|
||||||
|
class LogisticNet(nn.Module):
|
||||||
|
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.BatchNorm1d(hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, num_classes), # raw logits – loss handles softmax
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, loader, device):
|
||||||
|
model.eval()
|
||||||
|
all_preds, all_labels = [], []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for xb, yb in loader:
|
||||||
|
xb, yb = xb.to(device), yb.to(device)
|
||||||
|
logits = model(xb)
|
||||||
|
preds = logits.argmax(dim=1).cpu().numpy()
|
||||||
|
all_preds.extend(preds)
|
||||||
|
all_labels.extend(yb.cpu().numpy())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"accuracy": accuracy_score(all_labels, all_preds),
|
||||||
|
"f1": f1_score(all_labels, all_preds, average="weighted", zero_division=0),
|
||||||
|
"precision": precision_score(all_labels, all_preds, average="weighted", zero_division=0),
|
||||||
|
"recall": recall_score(all_labels, all_preds, average="weighted", zero_division=0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
||||||
|
print("Label distribution:", Counter(labels))
|
||||||
|
|
||||||
|
print(f"\nEncoding with '{EMBEDDING_MODEL}' …")
|
||||||
|
encoder = SentenceTransformer(EMBEDDING_MODEL)
|
||||||
|
embeddings = encoder.encode(texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
|
||||||
|
input_dim = embeddings.shape[1]
|
||||||
|
print(f"Embedding dim: {input_dim}")
|
||||||
|
|
||||||
|
X_train, X_val, y_train, y_val = train_test_split(
|
||||||
|
embeddings, labels, test_size=0.2, random_state=42, stratify=labels
|
||||||
|
)
|
||||||
|
|
||||||
|
class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
|
||||||
|
weight_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
|
||||||
|
print("Class weights:", class_weights)
|
||||||
|
|
||||||
|
def make_loader(X, y, shuffle=False):
|
||||||
|
ds = TensorDataset(
|
||||||
|
torch.tensor(X, dtype=torch.float32),
|
||||||
|
torch.tensor(y, dtype=torch.long),
|
||||||
|
)
|
||||||
|
return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle)
|
||||||
|
|
||||||
|
train_loader = make_loader(X_train, y_train, shuffle=True)
|
||||||
|
val_loader = make_loader(X_val, y_val, shuffle=False)
|
||||||
|
|
||||||
|
model = LogisticNet(input_dim, HIDDEN_DIM, NUM_CLASSES, DROPOUT).to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss(weight=weight_tensor)
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
|
||||||
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
|
||||||
|
|
||||||
|
best_f1 = 0.0
|
||||||
|
best_state = None
|
||||||
|
epochs_no_imp = 0
|
||||||
|
|
||||||
|
print("\n Training:")
|
||||||
|
for epoch in range(1, NUM_EPOCHS + 1):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
|
||||||
|
for xb, yb in train_loader:
|
||||||
|
xb, yb = xb.to(device), yb.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss = criterion(model(xb), yb)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
total_loss += loss.item() * len(yb)
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
avg_loss = total_loss / len(train_loader.dataset)
|
||||||
|
val_metrics = evaluate(model, val_loader, device)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Epoch {epoch:3d}/{NUM_EPOCHS} | "
|
||||||
|
f"loss {avg_loss:.4f} | "
|
||||||
|
f"val_acc {val_metrics['accuracy']:.4f} | "
|
||||||
|
f"val_f1 {val_metrics['f1']:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Early stopping on weighted F1
|
||||||
|
if val_metrics["f1"] > best_f1:
|
||||||
|
best_f1 = val_metrics["f1"]
|
||||||
|
best_state = {k: v.clone() for k, v in model.state_dict().items()}
|
||||||
|
epochs_no_imp = 0
|
||||||
|
else:
|
||||||
|
epochs_no_imp += 1
|
||||||
|
if epochs_no_imp >= PATIENCE:
|
||||||
|
print(f"Early stopping at epoch {epoch} (no improvement for {PATIENCE} epochs)")
|
||||||
|
break
|
||||||
|
|
||||||
|
print("\n Final evaluation:")
|
||||||
|
model.load_state_dict(best_state)
|
||||||
|
final = evaluate(model, val_loader, device)
|
||||||
|
for k, v in final.items():
|
||||||
|
print(f" {k}: {v:.4f}")
|
||||||
|
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
"model_state": best_state,
|
||||||
|
"input_dim": input_dim,
|
||||||
|
"hidden_dim": HIDDEN_DIM,
|
||||||
|
"num_classes": NUM_CLASSES,
|
||||||
|
"dropout": DROPOUT,
|
||||||
|
"embedding_model": EMBEDDING_MODEL,
|
||||||
|
},
|
||||||
|
"logreg_classifier.pt"
|
||||||
|
)
|
||||||
|
print("\n Model saved to logreg_classifier.pt")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from sklearn.utils import compute_class_weight
|
from sklearn.utils import compute_class_weight
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification
|
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
|
||||||
import torch
|
import torch
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
@@ -10,7 +10,7 @@ import csv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
NUM_CLASSES = 3
|
NUM_CLASSES = 3
|
||||||
model_name = "distilbert/distilroberta-base"
|
model_name = "roberta-base"
|
||||||
|
|
||||||
LABEL_PRIORITY = [
|
LABEL_PRIORITY = [
|
||||||
("PERFECT", 0),
|
("PERFECT", 0),
|
||||||
@@ -29,21 +29,12 @@ class WeightedTrainer(Trainer):
|
|||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||||
labels = inputs.get("labels")
|
labels = inputs.get("labels")
|
||||||
# print("DBG: Before forward")
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
# print("DBG: After forward")
|
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
|
|
||||||
# loss_fct = CrossEntropyLoss(weight=self.class_weights.to(logits.device))
|
loss_fct = CrossEntropyLoss(weight=self.class_weights.to(logits.device))
|
||||||
loss_fct = CrossEntropyLoss(
|
|
||||||
weight=self.class_weights.to(logits.device).to(logits.dtype)
|
|
||||||
)
|
|
||||||
# loss_fct = CrossEntropyLoss()
|
|
||||||
|
|
||||||
# print("DBG: Before loss")
|
|
||||||
loss = loss_fct(logits, labels)
|
loss = loss_fct(logits, labels)
|
||||||
# loss.backward()
|
|
||||||
# print("DBG: After loss")
|
|
||||||
return (loss, outputs) if return_outputs else loss
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
def label_to_int(extra_info: str) -> int:
|
def label_to_int(extra_info: str) -> int:
|
||||||
@@ -129,23 +120,17 @@ def main():
|
|||||||
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
|
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
|
||||||
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
||||||
|
|
||||||
# tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
|
tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
|
||||||
# model = RobertaForSequenceClassification.from_pretrained(
|
model = RobertaForSequenceClassification.from_pretrained(
|
||||||
# model_name,
|
|
||||||
# num_labels=NUM_CLASSES
|
|
||||||
# )
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
|
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(
|
|
||||||
model_name,
|
model_name,
|
||||||
num_labels=NUM_CLASSES
|
num_labels=NUM_CLASSES
|
||||||
)
|
)
|
||||||
|
|
||||||
# for param in model.deberta.parameters():
|
for param in model.roberta.parameters():
|
||||||
# param.requires_grad = True
|
param.requires_grad = False
|
||||||
|
|
||||||
# for param in model.deberta.encoder.layer[-6:].parameters():
|
for param in model.roberta.encoder.layer[-6:].parameters():
|
||||||
# param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
||||||
print("Dataset size:", len(texts))
|
print("Dataset size:", len(texts))
|
||||||
print("Label distribution:")
|
print("Label distribution:")
|
||||||
@@ -155,8 +140,7 @@ def main():
|
|||||||
texts,
|
texts,
|
||||||
labels,
|
labels,
|
||||||
test_size=0.2,
|
test_size=0.2,
|
||||||
random_state=42,
|
random_state=42
|
||||||
stratify=labels
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -189,7 +173,6 @@ def main():
|
|||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
# print(f"DBG: Loading item {idx}")
|
|
||||||
item = {
|
item = {
|
||||||
key: torch.tensor(val[idx])
|
key: torch.tensor(val[idx])
|
||||||
for key, val in self.encodings.items()
|
for key, val in self.encodings.items()
|
||||||
@@ -204,8 +187,7 @@ def main():
|
|||||||
output_dir="./results",
|
output_dir="./results",
|
||||||
learning_rate=2e-5,
|
learning_rate=2e-5,
|
||||||
per_device_train_batch_size=32,
|
per_device_train_batch_size=32,
|
||||||
# gradient_accumulation_steps=2,
|
num_train_epochs=5,
|
||||||
num_train_epochs=15,
|
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
eval_strategy="epoch",
|
eval_strategy="epoch",
|
||||||
@@ -213,8 +195,7 @@ def main():
|
|||||||
metric_for_best_model="f1",
|
metric_for_best_model="f1",
|
||||||
greater_is_better=True,
|
greater_is_better=True,
|
||||||
dataloader_num_workers=4,
|
dataloader_num_workers=4,
|
||||||
dataloader_pin_memory=True,
|
dataloader_pin_memory=True
|
||||||
# warmup_steps=100,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = TextDataset(train_encodings, train_labels)
|
train_dataset = TextDataset(train_encodings, train_labels)
|
||||||
@@ -237,8 +218,8 @@ def main():
|
|||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
print(f"{k}: {v}")
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
trainer.save_model("./roberta_distilled_classifier")
|
trainer.save_model("./roberta_classifier")
|
||||||
tokenizer.save_pretrained("./roberta_distilled_classifier")
|
tokenizer.save_pretrained("./roberta_classifier")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ async function processRecord(record: any): Promise<ResultRecord> {
|
|||||||
input: buildAgentInput(record),
|
input: buildAgentInput(record),
|
||||||
streamMode: "values",
|
streamMode: "values",
|
||||||
config: {
|
config: {
|
||||||
recursion_limit: 50
|
recursion_limit: 100
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,119 @@
|
|||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from selenium import webdriver
|
||||||
|
from selenium.webdriver.chrome.options import Options
|
||||||
|
from selenium.common.exceptions import WebDriverException, TimeoutException, StaleElementReferenceException
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def init_driver():
|
||||||
|
options = Options()
|
||||||
|
options.headless = True
|
||||||
|
options.add_argument("--disable-gpu")
|
||||||
|
options.add_argument("--no-sandbox")
|
||||||
|
options.add_argument("--headless")
|
||||||
|
options.add_argument("--disable-blink-features=AutomationControlled")
|
||||||
|
options.add_argument("--window-size=1920,1080")
|
||||||
|
prefs = {
|
||||||
|
"profile.managed_default_content_settings.images": 2, # block images
|
||||||
|
"profile.default_content_setting_values.stylesheets": 2, # block CSS
|
||||||
|
"profile.managed_default_content_settings.cookies": 2, # optional
|
||||||
|
}
|
||||||
|
options.add_experimental_option("prefs", prefs)
|
||||||
|
|
||||||
|
driver = webdriver.Chrome(options=options)
|
||||||
|
driver.set_page_load_timeout(30)
|
||||||
|
return driver
|
||||||
|
|
||||||
|
def is_root_url(url):
|
||||||
|
parsed = urlparse(url)
|
||||||
|
return parsed.path in ("", "/")
|
||||||
|
|
||||||
|
def is_404_page(driver):
|
||||||
|
"""Safely check for 404, handling stale elements."""
|
||||||
|
try:
|
||||||
|
title = driver.title.lower()
|
||||||
|
body_text = driver.find_element("tag name", "body").text.lower()
|
||||||
|
return "404" in title or "404" in body_text
|
||||||
|
except StaleElementReferenceException:
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_url_selenium(url):
|
||||||
|
driver = None
|
||||||
|
try:
|
||||||
|
driver = init_driver()
|
||||||
|
driver.get(url)
|
||||||
|
# 404 check
|
||||||
|
if is_404_page(driver):
|
||||||
|
return False, "404 page detected"
|
||||||
|
# Root URL after redirects
|
||||||
|
final_url = driver.current_url
|
||||||
|
if is_root_url(final_url):
|
||||||
|
return False, f"Redirected to root URL ({final_url})"
|
||||||
|
return True, None
|
||||||
|
except (WebDriverException, TimeoutException) as e:
|
||||||
|
return False, str(e)
|
||||||
|
finally:
|
||||||
|
if driver:
|
||||||
|
driver.quit()
|
||||||
|
|
||||||
|
def process_event(event):
|
||||||
|
"""Process an event only if score > 0.4."""
|
||||||
|
score = event.get("score", 0)
|
||||||
|
if score <= 0.4:
|
||||||
|
return None, False, "Score too low"
|
||||||
|
url = event.get("Url")
|
||||||
|
if not url:
|
||||||
|
return None, False, "No URL"
|
||||||
|
is_valid, error_msg = check_url_selenium(url)
|
||||||
|
event["url_valid"] = is_valid
|
||||||
|
return url, is_valid, error_msg
|
||||||
|
|
||||||
|
def process_jsonl_file(file_path, max_workers=4):
|
||||||
|
invalid_urls = []
|
||||||
|
valid_urls = 0
|
||||||
|
|
||||||
|
# Gather events with score > 0.4
|
||||||
|
urls_to_check = []
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line_data = json.loads(line)
|
||||||
|
if line_data.get("status") != "success":
|
||||||
|
continue
|
||||||
|
for event in line_data.get("events", []):
|
||||||
|
if event.get("score", 0) > 0.4:
|
||||||
|
urls_to_check.append(event)
|
||||||
|
|
||||||
|
total_urls = len(urls_to_check)
|
||||||
|
|
||||||
|
# ThreadPoolExecutor with tqdm progress bar
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
future_to_event = {executor.submit(process_event, e): e for e in urls_to_check}
|
||||||
|
for future in tqdm(as_completed(future_to_event), total=total_urls, desc="Checking URLs"):
|
||||||
|
url, is_valid, error_msg = future.result()
|
||||||
|
if not is_valid and url:
|
||||||
|
invalid_urls.append((url, error_msg))
|
||||||
|
else:
|
||||||
|
valid_urls += 1
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
if invalid_urls:
|
||||||
|
print("\nList of invalid URLs and reasons:")
|
||||||
|
for url, err in invalid_urls:
|
||||||
|
print(f"{url} --> {err}")
|
||||||
|
print("\n=== URL Validation Summary ===")
|
||||||
|
print(f"Total URLs processed: {total_urls}")
|
||||||
|
print(f"Valid URLs (loaded successfully): {valid_urls}")
|
||||||
|
print(f"Invalid URLs: {len(invalid_urls)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Validate URLs in JSONL file events using Selenium")
|
||||||
|
parser.add_argument("file_path", type=str, help="Path to the JSONL file")
|
||||||
|
parser.add_argument("--workers", type=int, default=4, help="Number of parallel Selenium workers")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
process_jsonl_file(args.file_path, max_workers=args.workers)
|
||||||
@@ -16,18 +16,18 @@ BASE_URL = "https://dbkf.ontotext.com/rest-api/search/documents"
|
|||||||
|
|
||||||
# "documentTypes": "http://schema.org/Claim",
|
# "documentTypes": "http://schema.org/Claim",
|
||||||
DEFAULT_PARAMS = [
|
DEFAULT_PARAMS = [
|
||||||
("concept", "http://weverify.eu/resource/Concept/Q212"),
|
("documentTypes", "http://schema.org/Claim"),
|
||||||
("from", "2000-01-01"),
|
("from", "2000-01-01"),
|
||||||
("to", "2026-02-19"),
|
("to", "2026-02-19"),
|
||||||
("lang", "en"),
|
("lang", "en"),
|
||||||
("limit", 5000),
|
("limit", 7000),
|
||||||
("page", 1),
|
("page", 1),
|
||||||
("orderBy", "date"),
|
("orderBy", "date"),
|
||||||
|
("organization", "http://weverify.eu/resource/Organization/128573c5d49d37558706194e755f152d"), # Science Direct
|
||||||
("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
|
("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
|
||||||
("organization", "http://weverify.eu/resource/Organization/c71953fa6cf24ac4178f751c77862070"), # CheckYourFact
|
|
||||||
]
|
]
|
||||||
|
|
||||||
NUM_RANDOM_CLAIMS = 40
|
NUM_RANDOM_CLAIMS = 200
|
||||||
|
|
||||||
INPUT_FILE = "../../data/input.jsonl"
|
INPUT_FILE = "../../data/input.jsonl"
|
||||||
OUTPUT_FILE = "../../data/claims.json"
|
OUTPUT_FILE = "../../data/claims.json"
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import streamlit as st
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
# THRESH = 0.4
|
|
||||||
THRESH = 0.6
|
THRESH = 0.6
|
||||||
|
|
||||||
def page_title() -> str:
|
def page_title() -> str:
|
||||||
@@ -61,6 +60,18 @@ def render():
|
|||||||
return
|
return
|
||||||
|
|
||||||
for file_path in jsonl_files:
|
for file_path in jsonl_files:
|
||||||
|
thresh = THRESH
|
||||||
|
if ("flan" in file_path.name):
|
||||||
|
thresh = 0.94
|
||||||
|
if ("regression" in file_path.name):
|
||||||
|
thresh = 0.75
|
||||||
|
if ("ensemble" in file_path.name):
|
||||||
|
thresh = 0.1
|
||||||
|
if ("ensemble" in file_path.name and "2" in file_path.name):
|
||||||
|
thresh = 0.4
|
||||||
|
if ("ensemble" in file_path.name and "vot" in file_path.name):
|
||||||
|
thresh = 0.7
|
||||||
|
|
||||||
st.subheader(f"File: {file_path.name}")
|
st.subheader(f"File: {file_path.name}")
|
||||||
|
|
||||||
confidence_counter = Counter()
|
confidence_counter = Counter()
|
||||||
@@ -86,15 +97,15 @@ def render():
|
|||||||
dup_counter += 1
|
dup_counter += 1
|
||||||
elif "ranked" not in event:
|
elif "ranked" not in event:
|
||||||
"ignore for now"
|
"ignore for now"
|
||||||
elif score > THRESH and extra_lower == "perfect":
|
elif score > thresh and extra_lower == "perfect":
|
||||||
confidence_counter["Correct-PERFECT"] += 1
|
confidence_counter["Correct-PERFECT"] += 1
|
||||||
elif score > THRESH and extra_lower == "":
|
elif score > thresh and extra_lower == "":
|
||||||
confidence_counter["Correct-FINE"] += 1
|
confidence_counter["Correct-FINE"] += 1
|
||||||
elif score > THRESH and extra_lower != "perfect" and extra_lower != "":
|
elif score > thresh and extra_lower != "perfect" and extra_lower != "":
|
||||||
confidence_counter["Over-confident"] += 1
|
confidence_counter["Over-confident"] += 1
|
||||||
wrong_counter[extra_lower] += 1
|
wrong_counter[extra_lower] += 1
|
||||||
overconfident_docs.append(doc_id)
|
overconfident_docs.append(doc_id)
|
||||||
elif score < THRESH and (extra_lower == "perfect" or extra_lower == ""):
|
elif score < thresh and (extra_lower == "perfect" or extra_lower == ""):
|
||||||
confidence_counter["Under-confident"] += 1
|
confidence_counter["Under-confident"] += 1
|
||||||
underconfident_docs.append(doc_id)
|
underconfident_docs.append(doc_id)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
from collections import Counter
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import streamlit as st
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
THRESH = 0.4
|
||||||
|
|
||||||
|
def page_title() -> str:
|
||||||
|
return "Statistics 2"
|
||||||
|
|
||||||
|
def render():
|
||||||
|
st.header("Statistics 2")
|
||||||
|
|
||||||
|
path = Path("../../data/refinement")
|
||||||
|
|
||||||
|
if not path.exists() or not path.is_dir():
|
||||||
|
st.error("Invalid folder path.")
|
||||||
|
return
|
||||||
|
|
||||||
|
jsonl_files = sorted(path.glob("*.jsonl"))
|
||||||
|
if not jsonl_files:
|
||||||
|
st.info("No .jsonl files found in this folder.")
|
||||||
|
return
|
||||||
|
|
||||||
|
for file_path in jsonl_files:
|
||||||
|
thresh = THRESH
|
||||||
|
st.subheader(f"File: {file_path.name}")
|
||||||
|
|
||||||
|
confidence_counter = Counter()
|
||||||
|
|
||||||
|
# ---- Read file line by line ----
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
entry = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if (entry.get("status") != "success"):
|
||||||
|
confidence_counter["Crash"] += 1
|
||||||
|
for event in entry.get("events", []):
|
||||||
|
score = event.get("score", None)
|
||||||
|
|
||||||
|
if score is not None:
|
||||||
|
if score == -1:
|
||||||
|
confidence_counter["BAD-1"] += 1
|
||||||
|
elif score > thresh:
|
||||||
|
confidence_counter["PERFECT"] += 1
|
||||||
|
else:
|
||||||
|
confidence_counter["BAD"] += 1
|
||||||
|
|
||||||
|
if confidence_counter:
|
||||||
|
df_conf = pd.DataFrame(
|
||||||
|
confidence_counter.items(),
|
||||||
|
columns=["Category", "Count"]
|
||||||
|
)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.pie(
|
||||||
|
df_conf["Count"],
|
||||||
|
labels=df_conf["Category"],
|
||||||
|
autopct="%1.1f%%",
|
||||||
|
startangle=90
|
||||||
|
)
|
||||||
|
ax.axis("equal")
|
||||||
|
ax.set_title(file_path.name)
|
||||||
|
|
||||||
|
total = sum(confidence_counter.values())
|
||||||
|
correct = confidence_counter["PERFECT"]
|
||||||
|
|
||||||
|
corr_percent = (correct / total) * 100
|
||||||
|
|
||||||
|
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
|
||||||
|
st.markdown(f"**Crash: {confidence_counter["Crash"]}**")
|
||||||
|
st.pyplot(fig, width=500)
|
||||||
|
else:
|
||||||
|
st.info("No score data available in this file.")
|
||||||
Reference in New Issue
Block a user