24 Commits

Author SHA1 Message Date
William Jeynes b37799b3d2 Improve response extraction 2026-04-02 21:02:26 +01:00
William Jeynes 10f2644408 Use a slightly smaller model. Reduce concurreny. Be more clear in the prompts 2026-04-02 20:10:57 +01:00
William Jeynes 7e586fe17d Allow for configurable ranking server url. Delete old ragas call 2026-04-02 13:48:15 +01:00
William Jeynes 7e37a22058 Switch to actual instruction model. For debug, log entire object. 2026-04-02 13:18:02 +01:00
William Jeynes 2ed47980ef Add better error handling to LLM output response 2026-03-31 19:26:56 +01:00
William Jeynes 01b04dd73e use a model we know has tool calling capabilities 2026-03-31 18:26:55 +01:00
William Jeynes 593baf9b15 add extra options 2026-03-31 17:15:55 +01:00
William Jeynes 893829e599 Switch to CPU only, as to not confuse GPU 2026-03-31 16:09:41 +01:00
William Jeynes 36c30a427d update deps. Install ollama for lang chain. Update model to deepseek 2026-03-31 16:08:28 +01:00
William Jeynes b610e8c989 Add sentence transformers to requirements for ensemble service 2026-03-31 15:52:14 +01:00
William Jeynes f8d4155b7c Add more robust parsing of LLM JSON output 2026-03-27 11:09:59 +00:00
William Jeynes 5e374a8bd6 Fix errors seen during longer runs: selenium exceptions, insecure certificates, recusrsion limit exceeded, BM25 document corpus too small 2026-03-26 12:22:13 +00:00
William Jeynes fbc688b8f9 add date to returned data 2026-03-25 22:37:14 +00:00
William Jeynes 77cdd9a01c Add statistics for model experiments. Fix dead link in documentation 2026-03-25 21:57:52 +00:00
William Jeynes a7f5978f64 Update documentation. Stop storing context. Decide on final claims source 2026-03-25 14:24:55 +00:00
William Jeynes 872346c657 Update run.sh to match new evaluation service 2026-03-24 19:16:48 +00:00
William Jeynes 8f939d54c4 Implement ensemble into final model structure 2026-03-24 19:07:24 +00:00
William Jeynes 624d45bc53 Re-allow multithreading on service. Add results table 2026-03-24 18:29:40 +00:00
William Jeynes 80bc151379 add majority voting 2026-03-24 16:50:41 +00:00
William Jeynes 5ce64290ce Make an ensemble model to combine scores together (very high accuracy) 2026-03-24 15:50:41 +00:00
William Jeynes 87fccb7e2b Add downloading from hugging face 2026-03-24 13:23:08 +00:00
William Jeynes 8c1e35f66f Increase dropout on regression model to cut down on overfitting 2026-03-24 13:16:18 +00:00
William Jeynes 44395bb251 add linear regression model initial version 2026-03-24 12:25:15 +00:00
William Jeynes e368c50577 Add training scripts for distilled, flan. Add run service for flan 2026-03-23 22:43:59 +00:00
40 changed files with 1681 additions and 612 deletions
+1
View File
@@ -1,2 +1,3 @@
# TEMP
literature/
backup.tar.gz
+9
View File
@@ -7,6 +7,15 @@ Final Dissertation Submission Repository
## Solution Diagram
-- todo --
## Classifier Refinement
[See RAGAS_Service](/supporting/RAGAS_Service/)
## Agent Refinement
[See agent](/agent/)
## Generated Database Link and Usage Experiments
-- todo --
## Repository Structure
```
├── run.sh # Bash script to run project elements from one place
+1
View File
@@ -4,3 +4,4 @@ LANGSMITH_API_KEY=123456
LANGSMITH_ENDPOINT=https://eu.api.smith.langchain.com
SCRAPER_INSTANCE=https://example.com
SCRAPER_PARAM_ANYTHING=else
RANKING_URL=http://localhost:8000/evaluate
+3
View File
@@ -0,0 +1,3 @@
## Refining the agent output
TODO: Table and document experiments
+15 -4
View File
@@ -10,7 +10,7 @@ import { createModelNode } from "./nodes/model";
import { loopEndConditional } from "./conditionals/loop_end";
import { sort } from "./nodes/sort";
import { triggerEventSetup } from "./nodes/triggerEventSetup";
import { robertaMetrics } from "./nodes/robertaMetrics";
import { createEnsembleNode } from "./nodes/ensembleNode";
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
@@ -19,6 +19,10 @@ const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt"
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name);
const roNode = createEnsembleNode("ROBERTA", "roberta");
const flNode = createEnsembleNode("FLAN", "flan");
const lrNode = createEnsembleNode("REGRESSION", "logreg");
const agent = new StateGraph(MessagesState)
//NODES
@@ -30,7 +34,10 @@ const agent = new StateGraph(MessagesState)
.addNode("triggerEventModel", triggerEventModel)
.addNode(verificationSetup.name, verificationSetup)
.addNode(robertaMetrics.name, robertaMetrics)
.addNode("roNode", roNode)
.addNode("flNode", flNode)
.addNode("lrNode", lrNode)
.addNode(produceRanking.name, produceRanking)
.addNode(sort.name, sort)
@@ -45,9 +52,13 @@ const agent = new StateGraph(MessagesState)
.addConditionalEdges("triggerEventModel", triggerEventToolConditional, ["triggerEventToolNode", verificationSetup.name])
.addEdge("triggerEventToolNode", "triggerEventModel")
.addEdge(verificationSetup.name, robertaMetrics.name)
.addEdge(verificationSetup.name, "roNode")
.addEdge(verificationSetup.name, "flNode")
.addEdge(verificationSetup.name, "lrNode")
.addEdge(robertaMetrics.name, produceRanking.name)
.addEdge("roNode", produceRanking.name)
.addEdge("flNode", produceRanking.name)
.addEdge("lrNode", produceRanking.name)
// @ts-expect-error
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
+17
View File
@@ -0,0 +1,17 @@
import { GraphNode } from "@langchain/langgraph";
import { MessagesState } from "../state";
import { AIMessage } from "@langchain/core/messages";
import { evaluateWithEnsemble } from "../tools/ensembleCall";
export function createEnsembleNode(title: string, method: string): GraphNode<typeof MessagesState> {
return async (state) => {
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
const result = await evaluateWithEnsemble({ answer, method })
const score = result.validProb - result.invalidProb;
return {
messages: [new AIMessage(title + ":" + score)]
};
};
};
+10 -7
View File
@@ -1,25 +1,28 @@
import { HumanMessage, SystemMessage } from "@langchain/core/messages";
import { SystemMessage } from "@langchain/core/messages";
import { GraphNode } from "@langchain/langgraph";
import { MessagesState } from "../state";
import { ChatOpenAI } from "@langchain/openai"
import { ChatOllama } from "@langchain/ollama";
import { hydratePrompt } from "../prompts/hydratePrompt";
import { logger } from "../utils/logger";
export function createModelNode(tools: any, promptPath: string): GraphNode<typeof MessagesState> {
return async (state) => {
const sysPrompt = await hydratePrompt(promptPath, state);
const model = new ChatOpenAI({
model: "gpt-5-mini"
const model = new ChatOllama({
model: "llama3.1:8b-instruct-q4_K_M",
temperature: 0.3
});
const modelWithTools = model.bindTools(Object.values(tools));
const response = await modelWithTools.invoke([
new SystemMessage(
sysPrompt
),
new SystemMessage(sysPrompt),
...state.messages,
]);
logger.error(response);
return {
messages: [response]
};
+16 -22
View File
@@ -2,31 +2,25 @@ import { GraphNode } from "@langchain/langgraph";
import { MessagesState } from "../state";
import { BaseMessage } from "@langchain/core/messages";
//TODO: Each of these might need different weights
const keys = ["CONFIDENCE", "RELATION", "RAGAS", "ROBERTA"];
const mapping = {
VERYHIGH: 1.0,
HIGH: 0.75,
MEDIUM: 0.5,
LOW: 0.25,
VERYLOW: 0.0,
const models = {
REGRESSION: 0.3,
ROBERTA: 0.5,
FLAN: 0.3,
} as const;
type Priority = keyof typeof mapping;
type ModelKey = keyof typeof models;
function mapResponse(value: string | undefined | null): number {
if (!value) return 1;
if (!value) return 0;
const trimmed = value.trim();
const num = parseFloat(trimmed);
// If number, return it
if (!isNaN(num)) return num;
// Otherwise, map to value
const upper = trimmed.toUpperCase() as Priority;
return mapping[upper] ?? 0;
if (!isNaN(num)) {
return num;
} else {
return 0;
}
}
function getLastMessageContaining(
@@ -43,15 +37,15 @@ function getLastMessageContaining(
}
export const produceRanking: GraphNode<typeof MessagesState> = async (state) => {
// Extract and map values
const values = keys.map((key) => {
const values = (Object.keys(models) as ModelKey[]).map((key) => {
const msg = getLastMessageContaining(state.messages, key);
const part = msg?.split(":").at(1);
return mapResponse(part);
const baseValue = mapResponse(part);
return baseValue * models[key];
});
// Multiply!
const result = values.reduce((acc, val) => acc * val, 1);
const result = values.reduce((acc, val) => acc + val, 0);
const current = state.proposedTriggerEvent;
current[state.proposedTriggerEventIndex].score = result;
-16
View File
@@ -1,16 +0,0 @@
import { GraphNode } from "@langchain/langgraph";
import { MessagesState } from "../state";
import { AIMessage, HumanMessage } from "@langchain/core/messages";
import { evaluateWithRagas } from "../tools/ragasCall";
export const ragasMetrics: GraphNode<typeof MessagesState> = async (state) => {
const question = "A possible trigger event for: " + state.disinformationTitle //Should it be raw, or normalized?
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
const contexts = state.proposedTriggerEvent[state.proposedTriggerEventIndex].context?.split("^^^") ?? []
const results = await evaluateWithRagas({question, answer, contexts})
return {
messages: [ new AIMessage("RAGAS:" + results.faithfulness)]
};
};
-18
View File
@@ -1,18 +0,0 @@
import { GraphNode } from "@langchain/langgraph";
import { MessagesState } from "../state";
import { AIMessage } from "@langchain/core/messages";
import { evaluateWithRoberta } from "../tools/robertaCall";
export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) => {
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
const result = await evaluateWithRoberta({answer})
const score = result.validProb - result.invalidProb;
return {
messages: [ new AIMessage("ROBERTA:" + score)]
};
};
+9 -1
View File
@@ -3,8 +3,16 @@ import { MessagesState } from "../state";
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { rankExampleTriggerEvents } from "../tools/retreiveExamples";
function extractTE(text: string) {
const match = text.match(/<norm>([\s\S]*?)<\/norm>/);
if (!match) throw new Error("Nothing found between <norm> tags");
return match[1].trim();
}
export const triggerEventSetup: GraphNode<typeof MessagesState> = async (state) => {
let nc = state?.messages?.at(-1)?.content ?? "" //keep a copy of normalized trigger event. Again two things, womp womp
let raw = state?.messages?.at(-1)?.content ?? "" //keep a copy of normalized trigger event. Again two things, womp womp
let nc = extractTE(raw.toString())
//Now give in-context examples. hopwfully we can self-teach?
let similarityResults = await rankExampleTriggerEvents(state.disinformationTitle)
+45 -16
View File
@@ -1,31 +1,60 @@
import { GraphNode } from "@langchain/langgraph";
import { MessagesState, ProposedTriggerEventArray } from "../state";
import { logger } from "../utils/logger";
import { queryScraper } from "../tools/webSearch";
import { rankAndDisplayData } from "../tools/triggerEventTools";
import { jsonrepair } from 'jsonrepair';
function extractJSON(text: string) {
const match = text.match(/<json>([\s\S]*?)<\/json>/);
if (!match) throw new Error("No JSON found between <json> tags");
return match[1].trim();
}
export const verificationSetup: GraphNode<typeof MessagesState> = async (state) => {
//this is kinda doing two things, but having two nodes for it seems overkill
if (state.proposedTriggerEvent == undefined) {
logger.warn("No trigger events in memory, parsing")
logger.warn("No trigger events in memory, parsing");
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse));
const genResponse = state.messages.at(-1)?.content.toString() ?? "";
for (let i = 0; i < parsed.length; i++) {
const search = parsed[i].SearchQuery
const data = await queryScraper(search);
const output = await rankAndDisplayData(data, search);
let repaired: string;
try {
let extracted = extractJSON(genResponse)
repaired = jsonrepair(extracted);
} catch (repairErr: any) {
logger.error("Failed to repair JSON from LLM response.");
logger.error("Original LLM response:\n%s", genResponse);
throw new Error(`JSON repair failed: ${repairErr.message}`);
}
parsed[i].context = output;
let parsed;
try {
const json = JSON.parse(repaired);
if (Array.isArray(json)) {
parsed = ProposedTriggerEventArray.parse(json);
} else {
// try grab first value
const firstValue = Object.values(json)[0];
if (Array.isArray(firstValue)) {
parsed = ProposedTriggerEventArray.parse(firstValue);
} else {
logger.error("No array found in JSON after parsing.");
logger.error("Repaired JSON:\n%s", repaired);
logger.error("Original LLM response:\n%s", genResponse);
throw new Error("No array found in JSON structure");
}
}
} catch (parseErr: any) {
logger.error("Failed to parse LLM response to JSON or validate array.");
logger.error("Repaired JSON:\n%s", repaired);
logger.error("Original LLM response:\n%s", genResponse);
throw new Error(`Parsing failed: ${parseErr.message}`);
}
return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 };
}
else {
logger.info("Trigger event index %s", state.proposedTriggerEventIndex+1)
} else {
logger.info("Trigger event index %s", state.proposedTriggerEventIndex + 1);
return { proposedTriggerEvent: state.proposedTriggerEvent, proposedTriggerEventIndex: state.proposedTriggerEventIndex+1 };
return { proposedTriggerEvent: state.proposedTriggerEvent, proposedTriggerEventIndex: state.proposedTriggerEventIndex + 1 };
}
};
+392 -357
View File
File diff suppressed because it is too large Load Diff
+2
View File
@@ -17,6 +17,7 @@
"@langchain/core": "^1.1.17",
"@langchain/langgraph": "^1.1.2",
"@langchain/langgraph-sdk": "^1.5.5",
"@langchain/ollama": "^1.2.6",
"@langchain/openai": "^1.2.3",
"axios": "^1.13.5",
"compute-cosine-similarity": "^1.1.0",
@@ -24,6 +25,7 @@
"dotenv": "^17.2.3",
"exponential-backoff": "^3.1.3",
"fs": "^0.0.1-security",
"jsonrepair": "^3.13.3",
"langchain": "^1.2.14",
"selenium-webdriver": "^4.40.0",
"tldts": "^7.0.23",
+4 -1
View File
@@ -16,4 +16,7 @@ Relevent examples are included in preceeding messages, use these as exact inspir
The claim to normalize is:
###TITLE###
Produce no other text other than the condensed claim.
Produce no other text other than the condensed claim, surrounded <norm></norm>
For example: BREAKING: the sky is green!
Becomes: <norm>The sky is green</norm>
-9
View File
@@ -1,9 +0,0 @@
Could the following real-world event:
###TECLAIM###
Be a trigger for the following disinformation:
###TITLE###
Respond with "RELATION", followed by : followed by a confidence score (VERYHIGH, HIGH, MEDIUM, LOW, VERYLOW) followed by : followed by the reason. Use no other words, just return the score and reason in format.
Ignore wether the event happened or not, purely consider the likiness of causation
+16 -1
View File
@@ -14,7 +14,18 @@ Include a concise but specific search query that can be looked up on a search en
Include a url to a source for your trigger event (not a web search, a specific url from a reputuable source). Do not use OAI cite, include url as text in response.
Use a JSON format with each entry containing "Event,ReasoningWhyRelevant,SearchQuery,Url".
Include the date that the event happened ("March 2022" for exmaple)
Use a JSON format with each entry containing "Event,ReasoningWhyRelevant,SearchQuery,Url,Date".
Return ONLY JSON, no extra text. Wrap it like this:
<json>
[
{
"Event": "Example"
...
}
]
</json>
Multiple tool invocations should be requested at once, if applicable.
Use your abilities to look between the lines and produce some insightful analysis, thinking both short and long term.
@@ -24,4 +35,8 @@ Events will be reordered as part of processing, each statement must stand alone
The preceeding messages act as examples of previous responses to potentially ficitonal events and scores given.
Analysis should only be completed for proposed events that would graner >0.7 points
Since URLs change frequently, use tools to retreive up to date informaiton everytime, provided examples or existing knowledge will be wrong or out of date.
Remember to return just json enclosed by <json></json>
Lets go through it step by step
-8
View File
@@ -1,8 +0,0 @@
Do the search results cited below
###TESEARCH###
Support the idea that the following happened:
###TECLAIM###
Respond with "CONFIDENCE", followed by : followed by a confidence score (VERYHIGH, HIGH, MEDIUM, LOW, VERYLOW) followed by : followed by the reason. Use no other words, just return the score and reason in format.
Dates can be off by a few days, that would still be valid
+1
View File
@@ -9,6 +9,7 @@ export const ProposedTriggerEvent = z.object({
ReasoningWhyRelevant: z.string(),
SearchQuery: z.string(),
Url: z.url(),
Date: z.string(),
context: z.string().optional(),
score: z.number().optional()
})
+32
View File
@@ -0,0 +1,32 @@
import axios from "axios";
export async function evaluateWithEnsemble({
answer,
method
}: {
answer: string;
method: string
}): Promise<{ validProb: number; invalidProb: number; }> {
const res = await axios.post(process.env.RANKING_URL ?? "http://localhost:8000/evaluate", {
answer,
method
}, {timeout: 0});
// console.log(res.data)
const validProb = res.data["probabilities"][0][0]
const invalidProb = res.data["probabilities"][0][1] + res.data["probabilities"][0][2]
return {validProb, invalidProb};
}
// import dotenv from "dotenv";
// dotenv.config();
// let res = await evaluateWithEnsemble({method:"flan" ,answer: "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in MarchAugust 2020)"});
// console.log(res)
// res = await evaluateWithEnsemble({method:"roberta" ,answer: "Multiple mirrored reuploads (20202023) put the clip on other channels with titles implying it was a genuine 1970s public information film."});
// console.log(res)
// res = await evaluateWithEnsemble({method:"logreg" ,answer: "The COVID-19 Pandemic"});
// console.log(res)
-22
View File
@@ -1,22 +0,0 @@
import axios from "axios";
export async function evaluateWithRagas({
question,
answer,
contexts,
}: {
question: string;
answer: string;
contexts: string[];
}) {
const res = await axios.post("http://localhost:8001/evaluate", {
question,
answer,
contexts,
});
return res.data;
}
// let res = await evaluateWithRagas({question: "Who was Bill Nye", answer: "Bill Nye was a Scientist", contexts: ["Bill nye was a Scientist"]});
// console.log(res)
+16 -2
View File
@@ -15,6 +15,8 @@ const CACHE_PATH = "../data/csv.cache.json";
const JSONL_PATH = "../data/input.jsonl"
const BM25_MIN_DOCS = 3;
type EmbeddingCache = {
rawtexts: string[];
cleantexts: string[];
@@ -287,8 +289,20 @@ async function embedText(text: string): Promise<number[]> {
}
function buildBM25(texts: string[]) {
logger.info("Building BM25 index (%s docs)...", texts.length);
let paddedTexts = texts;
if (texts.length < BM25_MIN_DOCS) {
const needed = BM25_MIN_DOCS - texts.length;
logger.error(
"Corpus too small for BM25 (%s docs, need %s+), padding with %s dummy doc(s)",
texts.length,
BM25_MIN_DOCS,
needed
);
paddedTexts = [...texts, ...Array(needed).fill("placeholder dummy document")];
}
logger.info("Building BM25 index (%s docs)...", paddedTexts.length);
const bm25 = bm25Factory();
bm25.defineConfig({
@@ -302,7 +316,7 @@ function buildBM25(texts: string[]) {
nlp.tokens.removeWords,
]);
texts.forEach((text, i) => {
paddedTexts.forEach((text, i) => {
bm25.addDoc({ text }, i);
});
-25
View File
@@ -1,25 +0,0 @@
import axios from "axios";
export async function evaluateWithRoberta({
answer
}: {
answer: string;
}): Promise<{ validProb: number; invalidProb: number; }> {
const res = await axios.post("http://localhost:8000/evaluate", {
answer
});
// console.log(res.data)
const validProb = res.data["probabilities"][0][0]
const invalidProb = res.data["probabilities"][0][1] + res.data["probabilities"][0][2]
return {validProb, invalidProb};
}
// let res = await evaluateWithRoberta({answer: "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in MarchAugust 2020)"});
// console.log(res)
// res = await evaluateWithRoberta({answer: "Multiple mirrored reuploads (20202023) put the clip on other channels with titles implying it was a genuine 1970s public information film."});
// console.log(res)
// res = await evaluateWithRoberta({answer: "The COVID-19 Pandemic"});
// console.log(res)
+90 -27
View File
@@ -1,32 +1,95 @@
import { Builder, Browser } from "selenium-webdriver";
import firefox from "selenium-webdriver/firefox";
import { backOff } from "exponential-backoff";
import { logger } from "../utils/logger";
export async function extractWebpageContent(url: string) : Promise<string[]>{
const options = new firefox.Options();
options.addArguments("--headless");
let driver = await new Builder().forBrowser(Browser.FIREFOX).setFirefoxOptions(options).build()
try {
await driver.get(url)
await driver.wait(async () => {
return await driver.executeScript(
"return document.readyState === 'complete'"
);
}, 5000);
const readableText = await driver.executeScript(
"return document.body.innerText;"
) as string;
const filteredLines = readableText
.split(/\r?\n/)
.map(line => line.trim())
.filter(line => line.split(/\s+/).length > 1);
return filteredLines;
} finally {
await driver.quit()
}
export async function extractWebpageContent(url: string): Promise<string[]> {
try {
const response = await backOff(async () => {
return await extractWebpageContentWorker(url);
}, {
numOfAttempts: 10,
startingDelay: 500,
timeMultiple: 2,
jitter: "full",
maxDelay: 50000,
});
return response;
} catch (err: any) {
logger.error(`Failed out of retry loop for URL "${url}", returning placeholder to pipeline`);
return ["API EXCEPTION"];
}
}
//console.log(await extractWebpageContent("https://www.bbc.co.uk/news/live/c74wd01egvyt"))
async function extractWebpageContentWorker(url: string): Promise<string[]> {
let driver;
try {
const options = new firefox.Options();
options.addArguments("--headless");
options.addArguments("--disable-gpu");
options.addArguments("--no-sandbox"); // Linux sandbox issues
options.addArguments("--disable-dev-shm-usage"); // /dev/shm issues
driver = await new Builder()
.forBrowser(Browser.FIREFOX)
.setFirefoxOptions(options)
.build();
} catch (err: any) {
const desc = `Failed to launch Firefox driver: ${err.message}`;
logger.error(desc);
throw new Error(desc);
}
try {
try {
await driver.get(url);
} catch (err: any) {
const desc = `Failed to navigate to URL "${url}": ${err.message}`;
logger.error(desc);
throw new Error(desc);
}
try {
await driver.wait(async () => {
return await driver.executeScript(
"return document.readyState === 'complete'"
);
}, 5000);
} catch (err: any) {
logger.error(`Page load timed out for "${url}", attempting to read partial content: ${err.message}`);
// do not throw, attempt to read
}
let readableText: string;
try {
readableText = await driver.executeScript(
"return document.body.innerText;"
) as string;
} catch (err: any) {
const desc = `Failed to extract page text from "${url}": ${err.message}`;
logger.error(desc);
throw new Error(desc);
}
const filteredLines = readableText
.split(/\r?\n/)
.map(line => line.trim())
.filter(line => line.split(/\s+/).length > 1);
if (filteredLines.length === 0) {
const desc = `No content extracted from "${url}"`;
logger.error(desc);
throw new Error(desc);
}
return filteredLines;
} finally {
try {
await driver.quit();
} catch (err: any) {
logger.error(`Failed to quit Firefox driver cleanly: ${err.message}`);
}
}
}
// console.log(await extractWebpageContent("https://www.bbc.co.uk/news/live/c74wd01egvyt"))
// console.log(await extractWebpageContent("https://badcertificate.int.jeynes.uk/"))
+14 -17
View File
@@ -1,38 +1,35 @@
import { END, START, StateGraph } from "@langchain/langgraph";
import { MessagesState } from "./state";
import { verificationSetup } from "./nodes/verificationSetup";
import { ragasMetrics } from "./nodes/ragasMetrics";
import { produceRanking } from "./nodes/produceRanking";
import { createModelNode } from "./nodes/model";
import { loopEndConditional } from "./conditionals/loop_end";
import { sort } from "./nodes/sort";
import { robertaMetrics } from "./nodes/robertaMetrics";
import { createEnsembleNode } from "./nodes/ensembleNode";
const verificationModel = createModelNode([], "verify.txt");
const relationModel = createModelNode([], "relation.txt");
const roNode = createEnsembleNode("ROBERTA", "roberta");
const flNode = createEnsembleNode("FLAN", "flan");
const lrNode = createEnsembleNode("REGRESSION", "logreg");
const agent = new StateGraph(MessagesState)
//NODES
.addNode(verificationSetup.name, verificationSetup)
// .addNode("verificationModel", verificationModel)
// .addNode(ragasMetrics.name, ragasMetrics)
.addNode(robertaMetrics.name, robertaMetrics)
// .addNode("relationModel", relationModel)
.addNode("roNode", roNode)
.addNode("flNode", flNode)
.addNode("lrNode", lrNode)
.addNode(produceRanking.name, produceRanking)
.addNode(sort.name, sort)
.addEdge(START, verificationSetup.name)
// .addEdge(verificationSetup.name, "verificationModel")
// .addEdge(verificationSetup.name, ragasMetrics.name)
.addEdge(verificationSetup.name, robertaMetrics.name)
// .addEdge(verificationSetup.name, "relationModel")
// .addEdge(ragasMetrics.name, produceRanking.name)
.addEdge(robertaMetrics.name, produceRanking.name)
// .addEdge("verificationModel", produceRanking.name)
// .addEdge("relationModel", produceRanking.name)
.addEdge(verificationSetup.name, "roNode")
.addEdge(verificationSetup.name, "flNode")
.addEdge(verificationSetup.name, "lrNode")
.addEdge("roNode", produceRanking.name)
.addEdge("flNode", produceRanking.name)
.addEdge("lrNode", produceRanking.name)
// @ts-expect-error
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
+5 -5
View File
@@ -8,10 +8,10 @@ run_agent () {
npx @langchain/langgraph-cli dev
}
run_ragas_service () {
echo "Starting RAGAS service..."
run_ensemble_service () {
echo "Starting Ensemble service..."
cd "supporting/RAGAS_Service"
.venv/bin/uvicorn ragas_service:app --port 8001
.venv/bin/uvicorn ensemble_service:app --timeout-keep-alive 300
}
run_frontend () {
@@ -34,13 +34,13 @@ run_wrapper () {
case "$1" in
agent) run_agent ;;
ragas_service) run_ragas_service ;;
ensemble_service) run_ensemble_service ;;
frontend) run_frontend ;;
fetch) run_fetch ;;
wrapper) run_wrapper ;;
*)
echo "Unknown command: $1"
echo "Usage: ./runproject [agent|ragas_service|frontend|fetch|wrapper]"
echo "Usage: ./runproject [agent|ensemble_service|frontend|fetch|wrapper]"
exit 1
;;
esac
+2
View File
@@ -1,7 +1,9 @@
# -- OURS --
results/
roberta_classifier/
roberta_distilled_classifier/
roberta_classifier*/
*.pt
output*
# -- THEIRS --
+25
View File
@@ -0,0 +1,25 @@
# Classifier work for evaluating model quality
Made using a dataset of 1000 labeled claims from MVP pipeline.
Roberta model trained on an augmented dataset with LLM generated adversarial examples for low frequency labels.
Flan model trained using raw labelled claims, inherrent natural language ability allows for pattern recognition without the need for fake data.
Regression model trained using the roberta dataset.
Used ensemble model in the final version, with the component models available on Hugging Face.
| Model | % Correct | % Valid taken forward|Used in ensemble|Link
|------------------------------------------------------------|-----------|----------------------|----------------|-
| Original | 53.22 | 61.72 |
| Original (RAGAS) | 56.01 | 57.73 |
| Roberta (base) | 75 | 70 |
| Roberta (Generated Data) | 76 | 71 |
| Roberta (Generated Data + Back Translation) | 74 | 71 |
| Roberta (Generated Data + Back Translation + Thresholding) | 77 | 90 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis)
| Distilled Roberta | 72.73 | 69.57 |
| Flan | 79.17 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Flan)
| Simple Regression Model | 74.77 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Regression)
| Ensemble Model (weighted confidence score sum) | 84.21 | 83.33 |
| Ensemble Model (majority voting) | 80.2 | 95.12 |
@@ -0,0 +1,224 @@
from pydantic import BaseModel
from fastapi import FastAPI
import torch
import torch.nn as nn
import os
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
app = FastAPI()
############################################
# SCHEMA
############################################
class EvalRequest(BaseModel):
answer: str
method: str # "logreg", "roberta", "flan"
############################################
# REGRESSION MODEL
############################################
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
MODEL_FILENAME = "logreg_classifier.pt"
CACHE_DIR = "./model_cache"
def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict:
local_path = os.path.join(cache_dir, filename)
if not os.path.exists(local_path):
os.makedirs(cache_dir, exist_ok=True)
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=cache_dir)
return torch.load(local_path, map_location="cpu")
class LogisticNet(nn.Module):
def __init__(self, input_dim, hidden_dim, num_classes, dropout):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, num_classes),
)
def forward(self, x):
return self.net(x)
checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR)
encoder = SentenceTransformer(checkpoint["embedding_model"])
logreg_model = LogisticNet(
checkpoint["input_dim"],
checkpoint["hidden_dim"],
checkpoint["num_classes"],
checkpoint["dropout"],
)
logreg_model.load_state_dict(checkpoint["model_state"])
logreg_model.eval()
############################################
# ROBERTA
############################################
ROBERTA_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
roberta_tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_PATH)
roberta_model = RobertaForSequenceClassification.from_pretrained(ROBERTA_PATH)
roberta_model.eval()
############################################
# FLAN
############################################
FLAN_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
INT_TO_LABEL = {
0: "perfect",
1: "story",
2: "not specific",
}
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
flan_tokenizer = AutoTokenizer.from_pretrained(FLAN_PATH)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_PATH)
device = torch.device("cpu")
flan_model.to(device)
flan_model.eval()
label_token_ids = {
label: flan_tokenizer(label, add_special_tokens=False).input_ids[0]
for label in LABEL_TO_INT.keys()
}
def format_prompt(text: str) -> str:
return (
"Classify the following event into one of these categories: "
"perfect, story, not specific.\n\n"
f"Event: {text}\n\n"
"Category:"
)
def parse_generated_label(generated: str):
generated = generated.strip().lower()
for label_text, label_int in LABEL_TO_INT.items():
if label_text in generated:
return label_int
return None
############################################
# ENDPOINT
############################################
@app.post("/evaluate")
def evaluate(req: EvalRequest):
method = req.method.lower()
########################################
# LOGREG
########################################
if method == "logreg":
embedding = encoder.encode(
[req.answer],
normalize_embeddings=True,
show_progress_bar=False,
)
x = torch.tensor(embedding, dtype=torch.float32)
with torch.no_grad():
logits = logreg_model(x)
probs = torch.softmax(logits, dim=1)
return {
"method": "logreg",
"probabilities": probs.cpu().numpy().tolist()
}
########################################
# ROBERTA
########################################
elif method == "roberta":
inputs = roberta_tokenizer(
req.answer,
return_tensors="pt",
truncation=True,
padding=True
)
with torch.no_grad():
logits = roberta_model(**inputs).logits
probs = torch.softmax(logits, dim=1)
return {
"method": "roberta",
"probabilities": probs.cpu().numpy().tolist()
}
########################################
# FLAN
########################################
elif method == "flan":
prompt = format_prompt(req.answer)
inputs = flan_tokenizer(
prompt,
return_tensors="pt",
truncation=True,
padding=True,
max_length=256,
).to(device)
with torch.no_grad():
outputs = flan_model.generate(**inputs, max_new_tokens=8)
decoder_input_ids = torch.tensor(
[[flan_model.config.decoder_start_token_id]]
).to(device)
logits_output = flan_model(
**inputs,
decoder_input_ids=decoder_input_ids
)
logits = logits_output.logits[:, 0, :]
generated_text = flan_tokenizer.decode(
outputs[0],
skip_special_tokens=True
)
label_logits = torch.tensor(
[logits[0, tid].item() for tid in label_token_ids.values()]
)
label_probs = torch.softmax(label_logits, dim=0).tolist()
return {
"method": "flan",
"generated": generated_text,
"probabilities": [label_probs],
}
########################################
# INVALID METHOD
########################################
else:
return {
"error": "Invalid method. Use 'logreg', 'roberta', or 'flan'."
}
+89
View File
@@ -0,0 +1,89 @@
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from fastapi import FastAPI
app = FastAPI()
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
INT_TO_LABEL = {
0: "perfect",
1: "story",
2: "not specific",
}
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def format_prompt(text: str) -> str:
return (
"Classify the following event into one of these categories: "
"perfect, story, not specific.\n\n"
f"Event: {text}\n\n"
"Category:"
)
def parse_generated_label(generated: str) -> int | None:
generated = generated.strip().lower()
for label_text, label_int in LABEL_TO_INT.items():
if label_text in generated:
return label_int
return None
class EvalRequest(BaseModel):
answer: str
@app.post("/evaluate")
def evaluate(req: EvalRequest):
prompt = format_prompt(req.answer)
inputs = tokenizer(
prompt,
return_tensors="pt",
truncation=True,
padding=True,
max_length=256,
).to(device)
with torch.no_grad():
# Get the generated label
outputs = model.generate(
**inputs,
max_new_tokens=8,
)
# Produce a confidence score
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)
logits_output = model(**inputs, decoder_input_ids=decoder_input_ids)
logits = logits_output.logits[:, 0, :]
# Decode the generated text label
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
predicted_int = parse_generated_label(generated_text)
# Extract probabilities
label_token_ids = {
label: tokenizer(label, add_special_tokens=False).input_ids[0]
for label in LABEL_TO_INT.keys()
}
label_logits = torch.tensor(
[logits[0, tid].item() for tid in label_token_ids.values()]
)
label_probs = torch.softmax(label_logits, dim=0).tolist()
return {
"generated": generated_text,
"probabilities": [label_probs],
}
@@ -0,0 +1,82 @@
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import os
app = FastAPI()
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
MODEL_FILENAME = "logreg_classifier.pt"
CACHE_DIR = "./model_cache"
def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict:
local_path = os.path.join(cache_dir, filename)
if not os.path.exists(local_path):
print(f"Downloading {filename} from {repo_id}...")
os.makedirs(cache_dir, exist_ok=True)
downloaded = hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=cache_dir,
)
print(f"Saved to {downloaded}")
else:
print(f"Using cached model at {local_path}")
return torch.load(local_path, map_location="cpu")
class LogisticNet(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, num_classes),
)
def forward(self, x):
return self.net(x)
checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR)
encoder = SentenceTransformer(checkpoint["embedding_model"])
model = LogisticNet(
input_dim = checkpoint["input_dim"],
hidden_dim = checkpoint["hidden_dim"],
num_classes = checkpoint["num_classes"],
dropout = checkpoint["dropout"],
)
model.load_state_dict(checkpoint["model_state"])
model.eval()
class EvalRequest(BaseModel):
answer: str
@app.post("/evaluate")
def evaluate(req: EvalRequest):
embedding = encoder.encode(
[req.answer],
normalize_embeddings=True,
show_progress_bar=False,
)
x = torch.tensor(embedding, dtype=torch.float32)
with torch.no_grad():
logits = model(x)
probs = torch.softmax(logits, dim=1)
return {
"probabilities": probs.cpu().numpy().tolist()
}
@@ -9,6 +9,7 @@ datasets
# ROBERTA
scikit-learn
transformers[torch]
sentence_transformers
# Utils
numpy
+1 -1
View File
@@ -5,7 +5,7 @@ from fastapi import FastAPI
app = FastAPI()
MODEL_PATH = "./roberta_classifier"
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
+227
View File
@@ -0,0 +1,227 @@
from sklearn.utils import compute_class_weight
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from collections import Counter
import sys
import csv
import numpy as np
NUM_CLASSES = 3
model_name = "google/flan-t5-base"
INT_TO_LABEL = {
0: "perfect",
1: "story",
2: "not specific",
}
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
LABEL_PRIORITY = [
("PERFECT", 0),
("STORY", 1),
("NSPECIFIC", 2),
("REWORDING", 1),
("TINCORRECT", -1),
("DUPLICATE", -1),
("", 0),
]
def label_to_int(extra_info: str) -> int:
if extra_info is None:
extra_info = ""
extra_info = extra_info.strip()
if extra_info == "":
for key, value in LABEL_PRIORITY:
if key == "":
return value
raise ValueError("Empty extra_info but no empty mapping defined")
tokens = set(extra_info.upper().split())
for key, value in LABEL_PRIORITY:
if key == "" :
continue
if key in tokens:
return value
raise ValueError(f"Unknown label content: '{extra_info}'")
def load_dataset_from_csv(path):
texts = []
labels = []
removed_rows = 0
with open(path, newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader, start=1):
text = row["event"]
label_str = row["extra_info"]
try:
label_int = label_to_int(label_str)
except Exception as e:
print(f"ERROR converting label on line {i}: {label_str}")
print(e)
sys.exit(1)
if label_int == -1:
removed_rows += 1
continue
texts.append(text)
labels.append(label_int)
print(f"Loaded {len(texts)} samples (removed {removed_rows})")
return texts, labels
def format_prompt(text: str) -> str:
return (
"Classify the following event into one of these categories: "
"perfect, story, not specific.\n\n"
f"Event: {text}\n\n"
"Category:"
)
def parse_generated_label(generated: str) -> int:
generated = generated.strip().lower()
for label_text, label_int in LABEL_TO_INT.items():
if label_text in generated:
return label_int
print("invlid label:" + generated)
return -1 # unknown / unparseable output
class GenerativeTextDataset(torch.utils.data.Dataset):
def __init__(self, texts, labels, tokenizer, max_input_length=256, max_target_length=8):
self.tokenizer = tokenizer
self.max_input_length = max_input_length
self.max_target_length = max_target_length
self.inputs = [format_prompt(t) for t in texts]
# Convert integer labels to their text equivalents for the target sequence
self.targets = [INT_TO_LABEL[l] for l in labels]
self.int_labels = labels # keep originals for metric computation
def __len__(self):
return len(self.inputs)
def __getitem__(self, idx):
model_inputs = self.tokenizer(
self.inputs[idx],
max_length=self.max_input_length,
truncation=True,
padding=False,
)
target_encoding = self.tokenizer(
self.targets[idx],
max_length=self.max_target_length,
truncation=True,
padding=False,
)
# Seq2Seq convention: labels use -100 to ignore padding tokens in loss
labels = target_encoding["input_ids"]
labels = [token if token != self.tokenizer.pad_token_id else -100 for token in labels]
model_inputs["labels"] = labels
return {k: torch.tensor(v) for k, v in model_inputs.items()}
def compute_metrics_generative(eval_pred, tokenizer):
predictions, label_ids = eval_pred
# Decode predictions
# Replace -100 in labels before decoding
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
# Map decoded text back to integer labels
pred_ints = [parse_generated_label(p) for p in decoded_preds]
true_ints = [parse_generated_label(l) for l in decoded_labels]
# Filter out any rows where parsing failed
valid = [(p, t) for p, t in zip(pred_ints, true_ints) if t != -1]
if not valid:
return {"accuracy": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
preds_filtered, true_filtered = zip(*valid)
return {
"accuracy": accuracy_score(true_filtered, preds_filtered),
"f1": f1_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
"precision": precision_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
"recall": recall_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
}
def main():
torch.multiprocessing.set_start_method('spawn', force=True)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
texts, labels = load_dataset_from_csv("../../data/classify.csv")
print("Dataset size:", len(texts))
print("Label distribution:", Counter(labels))
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
train_texts, val_texts, train_labels, val_labels = train_test_split(
texts, labels,
test_size=0.2,
random_state=42,
stratify=labels
)
train_dataset = GenerativeTextDataset(train_texts, train_labels, tokenizer)
val_dataset = GenerativeTextDataset(val_texts, val_labels, tokenizer)
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model,
padding=True,
label_pad_token_id=-100,
)
training_args = Seq2SeqTrainingArguments(
output_dir="./results",
learning_rate=5e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=10,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1",
greater_is_better=True,
predict_with_generate=True,
generation_max_length=8,
dataloader_num_workers=0,
dataloader_pin_memory=False,
fp16=False,
max_grad_norm=1.0,
)
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=lambda ep: compute_metrics_generative(ep, tokenizer),
)
trainer.train()
metrics = trainer.evaluate()
print("\nFinal evaluation metrics:")
for k, v in metrics.items():
print(f" {k}: {v}")
trainer.save_model("./flan_classifier")
tokenizer.save_pretrained("./flan_classifier")
if __name__ == "__main__":
main()
@@ -0,0 +1,209 @@
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from sklearn.utils import compute_class_weight
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from collections import Counter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import csv
import sys
NUM_CLASSES = 3
EMBEDDING_MODEL = "all-mpnet-base-v2"
HIDDEN_DIM = 256
DROPOUT = 0.4
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 64
NUM_EPOCHS = 30
PATIENCE = 5
LABEL_PRIORITY = [
("PERFECT", 0),
("STORY", 1),
("NSPECIFIC", 2),
("REWORDING", 1),
("TINCORRECT", -1),
("DUPLICATE", -1),
("", 0), # fallback to PERFECT
]
def label_to_int(extra_info: str) -> int:
if extra_info is None:
extra_info = ""
extra_info = extra_info.strip()
if extra_info == "":
for key, value in LABEL_PRIORITY:
if key == "":
return value
raise ValueError("No empty-string fallback defined in LABEL_PRIORITY")
tokens = set(extra_info.upper().split())
for key, value in LABEL_PRIORITY:
if key and key in tokens:
return value
raise ValueError(f"Unknown label content: '{extra_info}'")
def load_dataset_from_csv(path: str):
texts, labels = [], []
removed = 0
with open(path, newline="", encoding="utf-8") as f:
for i, row in enumerate(csv.DictReader(f), start=1):
try:
label_int = label_to_int(row["extra_info"])
except Exception as e:
print(f"ERROR on line {i}: {row['extra_info']!r}")
print(e)
sys.exit(1)
if label_int == -1:
removed += 1
continue
texts.append(row["event"])
labels.append(label_int)
print(f"Loaded {len(texts)} samples (removed {removed})")
return texts, labels
class LogisticNet(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, num_classes), # raw logits loss handles softmax
)
def forward(self, x):
return self.net(x)
def evaluate(model, loader, device):
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
for xb, yb in loader:
xb, yb = xb.to(device), yb.to(device)
logits = model(xb)
preds = logits.argmax(dim=1).cpu().numpy()
all_preds.extend(preds)
all_labels.extend(yb.cpu().numpy())
return {
"accuracy": accuracy_score(all_labels, all_preds),
"f1": f1_score(all_labels, all_preds, average="weighted", zero_division=0),
"precision": precision_score(all_labels, all_preds, average="weighted", zero_division=0),
"recall": recall_score(all_labels, all_preds, average="weighted", zero_division=0),
}
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
texts, labels = load_dataset_from_csv("../../data/classify.csv")
print("Label distribution:", Counter(labels))
print(f"\nEncoding with '{EMBEDDING_MODEL}'")
encoder = SentenceTransformer(EMBEDDING_MODEL)
embeddings = encoder.encode(texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
input_dim = embeddings.shape[1]
print(f"Embedding dim: {input_dim}")
X_train, X_val, y_train, y_val = train_test_split(
embeddings, labels, test_size=0.2, random_state=42, stratify=labels
)
class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
weight_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
print("Class weights:", class_weights)
def make_loader(X, y, shuffle=False):
ds = TensorDataset(
torch.tensor(X, dtype=torch.float32),
torch.tensor(y, dtype=torch.long),
)
return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle)
train_loader = make_loader(X_train, y_train, shuffle=True)
val_loader = make_loader(X_val, y_val, shuffle=False)
model = LogisticNet(input_dim, HIDDEN_DIM, NUM_CLASSES, DROPOUT).to(device)
criterion = nn.CrossEntropyLoss(weight=weight_tensor)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
best_f1 = 0.0
best_state = None
epochs_no_imp = 0
print("\n Training:")
for epoch in range(1, NUM_EPOCHS + 1):
model.train()
total_loss = 0.0
for xb, yb in train_loader:
xb, yb = xb.to(device), yb.to(device)
optimizer.zero_grad()
loss = criterion(model(xb), yb)
loss.backward()
optimizer.step()
total_loss += loss.item() * len(yb)
scheduler.step()
avg_loss = total_loss / len(train_loader.dataset)
val_metrics = evaluate(model, val_loader, device)
print(
f"Epoch {epoch:3d}/{NUM_EPOCHS} | "
f"loss {avg_loss:.4f} | "
f"val_acc {val_metrics['accuracy']:.4f} | "
f"val_f1 {val_metrics['f1']:.4f}"
)
# Early stopping on weighted F1
if val_metrics["f1"] > best_f1:
best_f1 = val_metrics["f1"]
best_state = {k: v.clone() for k, v in model.state_dict().items()}
epochs_no_imp = 0
else:
epochs_no_imp += 1
if epochs_no_imp >= PATIENCE:
print(f"Early stopping at epoch {epoch} (no improvement for {PATIENCE} epochs)")
break
print("\n Final evaluation:")
model.load_state_dict(best_state)
final = evaluate(model, val_loader, device)
for k, v in final.items():
print(f" {k}: {v:.4f}")
torch.save(
{
"model_state": best_state,
"input_dim": input_dim,
"hidden_dim": HIDDEN_DIM,
"num_classes": NUM_CLASSES,
"dropout": DROPOUT,
"embedding_model": EMBEDDING_MODEL,
},
"logreg_classifier.pt"
)
print("\n Model saved to logreg_classifier.pt")
if __name__ == "__main__":
main()
+15 -34
View File
@@ -1,6 +1,6 @@
from sklearn.utils import compute_class_weight
from torch.nn import CrossEntropyLoss
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
@@ -10,7 +10,7 @@ import csv
import numpy as np
NUM_CLASSES = 3
model_name = "distilbert/distilroberta-base"
model_name = "roberta-base"
LABEL_PRIORITY = [
("PERFECT", 0),
@@ -29,21 +29,12 @@ class WeightedTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
labels = inputs.get("labels")
# print("DBG: Before forward")
outputs = model(**inputs)
# print("DBG: After forward")
logits = outputs.get("logits")
# loss_fct = CrossEntropyLoss(weight=self.class_weights.to(logits.device))
loss_fct = CrossEntropyLoss(
weight=self.class_weights.to(logits.device).to(logits.dtype)
)
# loss_fct = CrossEntropyLoss()
# print("DBG: Before loss")
loss_fct = CrossEntropyLoss(weight=self.class_weights.to(logits.device))
loss = loss_fct(logits, labels)
# loss.backward()
# print("DBG: After loss")
return (loss, outputs) if return_outputs else loss
def label_to_int(extra_info: str) -> int:
@@ -129,23 +120,17 @@ def main():
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
texts, labels = load_dataset_from_csv("../../data/classify.csv")
# tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
# model = RobertaForSequenceClassification.from_pretrained(
# model_name,
# num_labels=NUM_CLASSES
# )
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
model = RobertaForSequenceClassification.from_pretrained(
model_name,
num_labels=NUM_CLASSES
)
# for param in model.deberta.parameters():
# param.requires_grad = True
for param in model.roberta.parameters():
param.requires_grad = False
# for param in model.deberta.encoder.layer[-6:].parameters():
# param.requires_grad = True
for param in model.roberta.encoder.layer[-6:].parameters():
param.requires_grad = True
print("Dataset size:", len(texts))
print("Label distribution:")
@@ -155,8 +140,7 @@ def main():
texts,
labels,
test_size=0.2,
random_state=42,
stratify=labels
random_state=42
)
@@ -189,7 +173,6 @@ def main():
self.labels = labels
def __getitem__(self, idx):
# print(f"DBG: Loading item {idx}")
item = {
key: torch.tensor(val[idx])
for key, val in self.encodings.items()
@@ -204,8 +187,7 @@ def main():
output_dir="./results",
learning_rate=2e-5,
per_device_train_batch_size=32,
# gradient_accumulation_steps=2,
num_train_epochs=15,
num_train_epochs=5,
weight_decay=0.01,
load_best_model_at_end=True,
eval_strategy="epoch",
@@ -213,8 +195,7 @@ def main():
metric_for_best_model="f1",
greater_is_better=True,
dataloader_num_workers=4,
dataloader_pin_memory=True,
# warmup_steps=100,
dataloader_pin_memory=True
)
train_dataset = TextDataset(train_encodings, train_labels)
@@ -237,8 +218,8 @@ def main():
for k, v in metrics.items():
print(f"{k}: {v}")
trainer.save_model("./roberta_distilled_classifier")
tokenizer.save_pretrained("./roberta_distilled_classifier")
trainer.save_model("./roberta_classifier")
tokenizer.save_pretrained("./roberta_classifier")
+2 -2
View File
@@ -17,7 +17,7 @@ const AGENT_NAME = process.env.AGENT ?? "agent";
*/
const MODE = process.env.MODE ?? "claim";
const MAX_CONCURRENCY = 5;
const MAX_CONCURRENCY = 1;
const client = new Client({ apiUrl: API_URL });
@@ -118,7 +118,7 @@ async function processRecord(record: any): Promise<ResultRecord> {
input: buildAgentInput(record),
streamMode: "values",
config: {
recursion_limit: 50
recursion_limit: 100
}
});
+4 -4
View File
@@ -16,18 +16,18 @@ BASE_URL = "https://dbkf.ontotext.com/rest-api/search/documents"
# "documentTypes": "http://schema.org/Claim",
DEFAULT_PARAMS = [
("concept", "http://weverify.eu/resource/Concept/Q212"),
("documentTypes", "http://schema.org/Claim"),
("from", "2000-01-01"),
("to", "2026-02-19"),
("lang", "en"),
("limit", 5000),
("limit", 7000),
("page", 1),
("orderBy", "date"),
("organization", "http://weverify.eu/resource/Organization/128573c5d49d37558706194e755f152d"), # Science Direct
("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
("organization", "http://weverify.eu/resource/Organization/c71953fa6cf24ac4178f751c77862070"), # CheckYourFact
]
NUM_RANDOM_CLAIMS = 40
NUM_RANDOM_CLAIMS = 200
INPUT_FILE = "../../data/input.jsonl"
OUTPUT_FILE = "../../data/claims.json"
+16 -5
View File
@@ -5,7 +5,6 @@ import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
# THRESH = 0.4
THRESH = 0.6
def page_title() -> str:
@@ -61,6 +60,18 @@ def render():
return
for file_path in jsonl_files:
thresh = THRESH
if ("flan" in file_path.name):
thresh = 0.94
if ("regression" in file_path.name):
thresh = 0.75
if ("ensemble" in file_path.name):
thresh = 0.1
if ("ensemble" in file_path.name and "2" in file_path.name):
thresh = 0.4
if ("ensemble" in file_path.name and "vot" in file_path.name):
thresh = 0.7
st.subheader(f"File: {file_path.name}")
confidence_counter = Counter()
@@ -86,15 +97,15 @@ def render():
dup_counter += 1
elif "ranked" not in event:
"ignore for now"
elif score > THRESH and extra_lower == "perfect":
elif score > thresh and extra_lower == "perfect":
confidence_counter["Correct-PERFECT"] += 1
elif score > THRESH and extra_lower == "":
elif score > thresh and extra_lower == "":
confidence_counter["Correct-FINE"] += 1
elif score > THRESH and extra_lower != "perfect" and extra_lower != "":
elif score > thresh and extra_lower != "perfect" and extra_lower != "":
confidence_counter["Over-confident"] += 1
wrong_counter[extra_lower] += 1
overconfident_docs.append(doc_id)
elif score < THRESH and (extra_lower == "perfect" or extra_lower == ""):
elif score < thresh and (extra_lower == "perfect" or extra_lower == ""):
confidence_counter["Under-confident"] += 1
underconfident_docs.append(doc_id)
else:
+78
View File
@@ -0,0 +1,78 @@
from collections import Counter
from pathlib import Path
import json
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
THRESH = 0.4
def page_title() -> str:
return "Statistics 2"
def render():
st.header("Statistics 2")
path = Path("../../data/refinement")
if not path.exists() or not path.is_dir():
st.error("Invalid folder path.")
return
jsonl_files = sorted(path.glob("*.jsonl"))
if not jsonl_files:
st.info("No .jsonl files found in this folder.")
return
for file_path in jsonl_files:
thresh = THRESH
st.subheader(f"File: {file_path.name}")
confidence_counter = Counter()
# ---- Read file line by line ----
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
try:
entry = json.loads(line)
except json.JSONDecodeError:
continue
if (entry.get("status") != "success"):
confidence_counter["Crash"] += 1
for event in entry.get("events", []):
score = event.get("score", None)
if score is not None:
if score == -1:
confidence_counter["BAD-1"] += 1
elif score > thresh:
confidence_counter["PERFECT"] += 1
else:
confidence_counter["BAD"] += 1
if confidence_counter:
df_conf = pd.DataFrame(
confidence_counter.items(),
columns=["Category", "Count"]
)
fig, ax = plt.subplots()
ax.pie(
df_conf["Count"],
labels=df_conf["Category"],
autopct="%1.1f%%",
startangle=90
)
ax.axis("equal")
ax.set_title(file_path.name)
total = sum(confidence_counter.values())
correct = confidence_counter["PERFECT"]
corr_percent = (correct / total) * 100
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
st.markdown(f"**Crash: {confidence_counter["Crash"]}**")
st.pyplot(fig, width=500)
else:
st.info("No score data available in this file.")