Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 00e1596be0 | |||
| 070aab6a5c | |||
| bff5423f3d |
@@ -1,3 +1,2 @@
|
|||||||
# TEMP
|
# TEMP
|
||||||
literature/
|
literature/
|
||||||
backup.tar.gz
|
|
||||||
@@ -7,15 +7,6 @@ Final Dissertation Submission Repository
|
|||||||
## Solution Diagram
|
## Solution Diagram
|
||||||
-- todo --
|
-- todo --
|
||||||
|
|
||||||
## Classifier Refinement
|
|
||||||
[See RAGAS_Service](/supporting/RAGAS_Service/)
|
|
||||||
|
|
||||||
## Agent Refinement
|
|
||||||
[See agent](/agent/)
|
|
||||||
|
|
||||||
## Generated Database Link and Usage Experiments
|
|
||||||
-- todo --
|
|
||||||
|
|
||||||
## Repository Structure
|
## Repository Structure
|
||||||
```
|
```
|
||||||
├── run.sh # Bash script to run project elements from one place
|
├── run.sh # Bash script to run project elements from one place
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
## Refining the agent output
|
|
||||||
|
|
||||||
TODO: Table and document experiments
|
|
||||||
+6
-35
@@ -10,23 +10,14 @@ import { createModelNode } from "./nodes/model";
|
|||||||
import { loopEndConditional } from "./conditionals/loop_end";
|
import { loopEndConditional } from "./conditionals/loop_end";
|
||||||
import { sort } from "./nodes/sort";
|
import { sort } from "./nodes/sort";
|
||||||
import { triggerEventSetup } from "./nodes/triggerEventSetup";
|
import { triggerEventSetup } from "./nodes/triggerEventSetup";
|
||||||
import { createEnsembleNode } from "./nodes/ensembleNode";
|
import { robertaMetrics } from "./nodes/robertaMetrics";
|
||||||
import { selfEvalSetup } from "./nodes/selfEvalSetup";
|
|
||||||
|
|
||||||
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
|
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
|
||||||
const peToolNode = createToolNode(triggerEventToolsByName);
|
|
||||||
|
|
||||||
const normalisationModel = createModelNode([], "normalization.txt");
|
const normalisationModel = createModelNode([], "normalization.txt");
|
||||||
const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt");
|
const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt");
|
||||||
const evaluationModel = createModelNode([], "eval.txt");
|
|
||||||
const peModel = createModelNode(triggerEventToolsByName, "posteval.txt");
|
|
||||||
|
|
||||||
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", selfEvalSetup.name);
|
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name);
|
||||||
const peToolConditional = createToolConditional("peToolNode", verificationSetup.name);
|
|
||||||
|
|
||||||
const roNode = createEnsembleNode("ROBERTA", "roberta");
|
|
||||||
const flNode = createEnsembleNode("FLAN", "flan");
|
|
||||||
const lrNode = createEnsembleNode("REGRESSION", "logreg");
|
|
||||||
|
|
||||||
const agent = new StateGraph(MessagesState)
|
const agent = new StateGraph(MessagesState)
|
||||||
|
|
||||||
@@ -38,17 +29,8 @@ const agent = new StateGraph(MessagesState)
|
|||||||
.addNode("triggerEventToolNode", triggerEventToolNode)
|
.addNode("triggerEventToolNode", triggerEventToolNode)
|
||||||
.addNode("triggerEventModel", triggerEventModel)
|
.addNode("triggerEventModel", triggerEventModel)
|
||||||
|
|
||||||
.addNode(selfEvalSetup.name, selfEvalSetup)
|
|
||||||
.addNode("evaluationModel", evaluationModel)
|
|
||||||
|
|
||||||
.addNode("peToolNode", peToolNode)
|
|
||||||
.addNode("peModel", peModel)
|
|
||||||
|
|
||||||
.addNode(verificationSetup.name, verificationSetup)
|
.addNode(verificationSetup.name, verificationSetup)
|
||||||
|
.addNode(robertaMetrics.name, robertaMetrics)
|
||||||
.addNode("roNode", roNode)
|
|
||||||
.addNode("flNode", flNode)
|
|
||||||
.addNode("lrNode", lrNode)
|
|
||||||
|
|
||||||
.addNode(produceRanking.name, produceRanking)
|
.addNode(produceRanking.name, produceRanking)
|
||||||
.addNode(sort.name, sort)
|
.addNode(sort.name, sort)
|
||||||
@@ -60,23 +42,12 @@ const agent = new StateGraph(MessagesState)
|
|||||||
.addEdge(triggerEventSetup.name, "triggerEventModel")
|
.addEdge(triggerEventSetup.name, "triggerEventModel")
|
||||||
|
|
||||||
// @ts-expect-error
|
// @ts-expect-error
|
||||||
.addConditionalEdges("triggerEventModel", triggerEventToolConditional, ["triggerEventToolNode", selfEvalSetup.name])
|
.addConditionalEdges("triggerEventModel", triggerEventToolConditional, ["triggerEventToolNode", verificationSetup.name])
|
||||||
.addEdge("triggerEventToolNode", "triggerEventModel")
|
.addEdge("triggerEventToolNode", "triggerEventModel")
|
||||||
|
|
||||||
.addEdge(selfEvalSetup.name, "evaluationModel")
|
.addEdge(verificationSetup.name, robertaMetrics.name)
|
||||||
.addEdge("evaluationModel", "peModel")
|
|
||||||
|
|
||||||
// @ts-expect-error
|
.addEdge(robertaMetrics.name, produceRanking.name)
|
||||||
.addConditionalEdges("peModel", peToolConditional, ["peToolNode", verificationSetup.name])
|
|
||||||
.addEdge("peToolNode", "peModel")
|
|
||||||
|
|
||||||
.addEdge(verificationSetup.name, "roNode")
|
|
||||||
.addEdge(verificationSetup.name, "flNode")
|
|
||||||
.addEdge(verificationSetup.name, "lrNode")
|
|
||||||
|
|
||||||
.addEdge("roNode", produceRanking.name)
|
|
||||||
.addEdge("flNode", produceRanking.name)
|
|
||||||
.addEdge("lrNode", produceRanking.name)
|
|
||||||
|
|
||||||
// @ts-expect-error
|
// @ts-expect-error
|
||||||
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
|
||||||
import { MessagesState } from "../state";
|
|
||||||
import { AIMessage } from "@langchain/core/messages";
|
|
||||||
import { evaluateWithEnsemble } from "../tools/ensembleCall";
|
|
||||||
|
|
||||||
export function createEnsembleNode(title: string, method: string): GraphNode<typeof MessagesState> {
|
|
||||||
return async (state) => {
|
|
||||||
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
|
||||||
|
|
||||||
const result = await evaluateWithEnsemble({ answer, method })
|
|
||||||
const score = result.validProb - result.invalidProb;
|
|
||||||
|
|
||||||
return {
|
|
||||||
messages: [new AIMessage(title + ":" + score)]
|
|
||||||
};
|
|
||||||
};
|
|
||||||
};
|
|
||||||
@@ -2,25 +2,31 @@ import { GraphNode } from "@langchain/langgraph";
|
|||||||
import { MessagesState } from "../state";
|
import { MessagesState } from "../state";
|
||||||
import { BaseMessage } from "@langchain/core/messages";
|
import { BaseMessage } from "@langchain/core/messages";
|
||||||
|
|
||||||
const models = {
|
//TODO: Each of these might need different weights
|
||||||
REGRESSION: 0.3,
|
const keys = ["CONFIDENCE", "RELATION", "RAGAS", "ROBERTA"];
|
||||||
ROBERTA: 0.5,
|
|
||||||
FLAN: 0.3,
|
const mapping = {
|
||||||
|
VERYHIGH: 1.0,
|
||||||
|
HIGH: 0.75,
|
||||||
|
MEDIUM: 0.5,
|
||||||
|
LOW: 0.25,
|
||||||
|
VERYLOW: 0.0,
|
||||||
} as const;
|
} as const;
|
||||||
|
|
||||||
type ModelKey = keyof typeof models;
|
type Priority = keyof typeof mapping;
|
||||||
|
|
||||||
function mapResponse(value: string | undefined | null): number {
|
function mapResponse(value: string | undefined | null): number {
|
||||||
if (!value) return 0;
|
if (!value) return 1;
|
||||||
|
|
||||||
const trimmed = value.trim();
|
const trimmed = value.trim();
|
||||||
const num = parseFloat(trimmed);
|
const num = parseFloat(trimmed);
|
||||||
|
|
||||||
if (!isNaN(num)) {
|
// If number, return it
|
||||||
return num;
|
if (!isNaN(num)) return num;
|
||||||
} else {
|
|
||||||
return 0;
|
// Otherwise, map to value
|
||||||
}
|
const upper = trimmed.toUpperCase() as Priority;
|
||||||
|
return mapping[upper] ?? 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getLastMessageContaining(
|
function getLastMessageContaining(
|
||||||
@@ -37,15 +43,15 @@ function getLastMessageContaining(
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const produceRanking: GraphNode<typeof MessagesState> = async (state) => {
|
export const produceRanking: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
const values = (Object.keys(models) as ModelKey[]).map((key) => {
|
// Extract and map values
|
||||||
|
const values = keys.map((key) => {
|
||||||
const msg = getLastMessageContaining(state.messages, key);
|
const msg = getLastMessageContaining(state.messages, key);
|
||||||
const part = msg?.split(":").at(1);
|
const part = msg?.split(":").at(1);
|
||||||
const baseValue = mapResponse(part);
|
return mapResponse(part);
|
||||||
|
|
||||||
return baseValue * models[key];
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = values.reduce((acc, val) => acc + val, 0);
|
// Multiply!
|
||||||
|
const result = values.reduce((acc, val) => acc * val, 1);
|
||||||
|
|
||||||
const current = state.proposedTriggerEvent;
|
const current = state.proposedTriggerEvent;
|
||||||
current[state.proposedTriggerEventIndex].score = result;
|
current[state.proposedTriggerEventIndex].score = result;
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
|
import { MessagesState } from "../state";
|
||||||
|
import { AIMessage, HumanMessage } from "@langchain/core/messages";
|
||||||
|
import { evaluateWithRagas } from "../tools/ragasCall";
|
||||||
|
|
||||||
|
export const ragasMetrics: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
|
const question = "A possible trigger event for: " + state.disinformationTitle //Should it be raw, or normalized?
|
||||||
|
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
||||||
|
const contexts = state.proposedTriggerEvent[state.proposedTriggerEventIndex].context?.split("^^^") ?? []
|
||||||
|
|
||||||
|
const results = await evaluateWithRagas({question, answer, contexts})
|
||||||
|
|
||||||
|
return {
|
||||||
|
messages: [ new AIMessage("RAGAS:" + results.faithfulness)]
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
|
import { MessagesState } from "../state";
|
||||||
|
import { AIMessage } from "@langchain/core/messages";
|
||||||
|
import { evaluateWithRoberta } from "../tools/robertaCall";
|
||||||
|
|
||||||
|
export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
|
const answer = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
||||||
|
|
||||||
|
const result = await evaluateWithRoberta({answer})
|
||||||
|
|
||||||
|
|
||||||
|
const score = result.validProb - result.invalidProb;
|
||||||
|
|
||||||
|
|
||||||
|
return {
|
||||||
|
messages: [ new AIMessage("ROBERTA:" + score)]
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
|
||||||
import { MessagesState, ProposedTriggerEventArray } from "../state";
|
|
||||||
import { logger } from "../utils/logger";
|
|
||||||
import { queryScraper } from "../tools/webSearch";
|
|
||||||
import { rankAndDisplayData } from "../tools/triggerEventTools";
|
|
||||||
|
|
||||||
export const selfEvalSetup: GraphNode<typeof MessagesState> = async (state) => {
|
|
||||||
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
|
|
||||||
const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse));
|
|
||||||
|
|
||||||
for (let i = 0; i < parsed.length; i++) {
|
|
||||||
const search = parsed[i].SearchQuery
|
|
||||||
const data = await queryScraper(search);
|
|
||||||
const output = await rankAndDisplayData(data, search);
|
|
||||||
|
|
||||||
parsed[i].context = output;
|
|
||||||
}
|
|
||||||
|
|
||||||
return { evalTriggerEvent: parsed };
|
|
||||||
|
|
||||||
};
|
|
||||||
@@ -13,6 +13,14 @@ export const verificationSetup: GraphNode<typeof MessagesState> = async (state)
|
|||||||
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
|
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
|
||||||
const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse));
|
const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse));
|
||||||
|
|
||||||
|
for (let i = 0; i < parsed.length; i++) {
|
||||||
|
const search = parsed[i].SearchQuery
|
||||||
|
const data = await queryScraper(search);
|
||||||
|
const output = await rankAndDisplayData(data, search);
|
||||||
|
|
||||||
|
parsed[i].context = output;
|
||||||
|
}
|
||||||
|
|
||||||
return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 };
|
return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 };
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
You are an impartial and meticulous evaluator assessing LLM’s response based on key quality dimensions
|
|
||||||
of honesty and usefulness. Your goal is to provide structured feedback that can be used to improve the
|
|
||||||
response.
|
|
||||||
Evaluation task: please follow these steps carefully:
|
|
||||||
1. Analyze the response based on the three dimensions below.
|
|
||||||
2. Provide justifications first: write a brief explanation justifying your assessment for each dimension.
|
|
||||||
3. Assign scores after justification: assign a score from 1 (poor) to 10 (excellent) for each dimension
|
|
||||||
based on your justification.
|
|
||||||
4. Synthesize: provide a brief overall impression and the single most important suggestion for improvement,
|
|
||||||
keeping in mind that explanation/honesty is the top priority, then followed by guidance.
|
|
||||||
Critique dimensions (evaluate in this order):
|
|
||||||
(1) Specificity and usefulness: Can the proposed event be used to create a dataset of concrete events mapped to later
|
|
||||||
disinformation.
|
|
||||||
(2) Existence: Using the context provided, can the user be certain that the proposed trigger event actually happened
|
|
||||||
(3) Causality: Is there a possible link from the proposed trigger event to the disinformation at hand
|
|
||||||
Overall impression & key improvement suggestion: Briefly summarize the overall quality and state the
|
|
||||||
most critical change needed to improve the response.
|
|
||||||
|
|
||||||
Disinformation query:
|
|
||||||
###NTITLE###
|
|
||||||
Disinformation date:
|
|
||||||
###CDATE###
|
|
||||||
|
|
||||||
LLM’s response to evaluate:
|
|
||||||
###LM###
|
|
||||||
|
|
||||||
Provided context:
|
|
||||||
###VESEARCHES###
|
|
||||||
|
|
||||||
Let's think it through step by step
|
|
||||||
@@ -15,10 +15,6 @@ export async function hydratePrompt(path: string, state: any) : Promise<string>
|
|||||||
raw = raw.replace("###LM###", state.messages.at(-1).content);
|
raw = raw.replace("###LM###", state.messages.at(-1).content);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (raw.indexOf("###L2M###") != -1) {
|
|
||||||
raw = raw.replace("###L2M###", state.messages.at(-2).content);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (raw.indexOf("###NTITLE###") != -1) {
|
if (raw.indexOf("###NTITLE###") != -1) {
|
||||||
raw = raw.replace("###NTITLE###", state.normalizedClaim);
|
raw = raw.replace("###NTITLE###", state.normalizedClaim);
|
||||||
}
|
}
|
||||||
@@ -37,12 +33,5 @@ export async function hydratePrompt(path: string, state: any) : Promise<string>
|
|||||||
raw = raw.replace("###TESEARCH###", output)
|
raw = raw.replace("###TESEARCH###", output)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (raw.indexOf("###VESEARCHES###") != -1) {
|
|
||||||
const output = state.evalTriggerEvent
|
|
||||||
.map(e => e.context)
|
|
||||||
.join("\n")
|
|
||||||
raw = raw.replace("###VESEARCHES###", output)
|
|
||||||
}
|
|
||||||
|
|
||||||
return raw;
|
return raw;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
You are an expert editor tasked with making targeted improvements to an existing LLM’s response based
|
|
||||||
on a specific critique with the primary goal of enhancing its score according to evaluation standards while
|
|
||||||
preserving its strengths.
|
|
||||||
Your revision task: generate a revised version of the existing response. Your goal is not to rewrite it
|
|
||||||
completely, but to make precise edits only to address the specific weaknesses highlighted in the critique.
|
|
||||||
Instructions for editing:
|
|
||||||
- Identify specific flaws: carefully read the critique and pinpoint the exact issues raised (e.g., unclear
|
|
||||||
explanation, vagueness, inappropriate responses, the key suggestion).
|
|
||||||
- Perform minimal targeted edits: modify only the necessary sentences or paragraphs within the existing
|
|
||||||
response to directly fix these identified flaws.
|
|
||||||
- Strongly preserve strengths: crucially keep all other parts of the existing response intact. Do not
|
|
||||||
rephrase, restructure, or remove sections that were not criticized or likely contributed positively to its
|
|
||||||
initial score.
|
|
||||||
- Ensure coherence: verify that your targeted edits integrate smoothly and do not introduce contradictions
|
|
||||||
or awkward phrasing.
|
|
||||||
Output requirements:
|
|
||||||
- It should feel like a slightly polished or corrected version of the existing response, not a fundamentally
|
|
||||||
different answer.
|
|
||||||
- Do not mention the critique, scores, or the editing process. The output should be clean json that passes validation checks
|
|
||||||
|
|
||||||
Again, use a JSON format with each entry containing "Event,ReasoningWhyRelevant,SearchQuery,Url,Date".
|
|
||||||
Use tools available to you if further information is required
|
|
||||||
|
|
||||||
Add no new events, only improve the existing items
|
|
||||||
|
|
||||||
Disinformation query:
|
|
||||||
###NTITLE###
|
|
||||||
Disinformation date:
|
|
||||||
###CDATE###
|
|
||||||
|
|
||||||
LLM’s response to improve:
|
|
||||||
###L2M###
|
|
||||||
|
|
||||||
Critique:
|
|
||||||
###LM###
|
|
||||||
|
|
||||||
This contains specific feedback, justifications, scores from 1 to 10, and potentially a key improvement
|
|
||||||
suggestion. Focus on the justifications for low scores and the key suggestion.
|
|
||||||
|
|
||||||
Let's think it through step by step
|
|
||||||
@@ -14,9 +14,7 @@ Include a concise but specific search query that can be looked up on a search en
|
|||||||
|
|
||||||
Include a url to a source for your trigger event (not a web search, a specific url from a reputable source). Do not use OAI cite, include url as text in response.
|
Include a url to a source for your trigger event (not a web search, a specific url from a reputable source). Do not use OAI cite, include url as text in response.
|
||||||
|
|
||||||
Include the date that the event happened ("March 2022" for example)
|
Use a JSON format with each entry containing "Event,ReasoningWhyRelevant,SearchQuery,Url".
|
||||||
|
|
||||||
Use a JSON format with each entry containing "Event,ReasoningWhyRelevant,SearchQuery,Url,Date".
|
|
||||||
|
|
||||||
Multiple tool invocations should be requested at once, if applicable.
|
Multiple tool invocations should be requested at once, if applicable.
|
||||||
Use your abilities to look between the lines and produce some insightful analysis, thinking both short and long term.
|
Use your abilities to look between the lines and produce some insightful analysis, thinking both short and long term.
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ export const ProposedTriggerEvent = z.object({
|
|||||||
ReasoningWhyRelevant: z.string(),
|
ReasoningWhyRelevant: z.string(),
|
||||||
SearchQuery: z.string(),
|
SearchQuery: z.string(),
|
||||||
Url: z.url(),
|
Url: z.url(),
|
||||||
Date: z.string(),
|
|
||||||
context: z.string().optional(),
|
context: z.string().optional(),
|
||||||
score: z.number().optional()
|
score: z.number().optional()
|
||||||
})
|
})
|
||||||
@@ -21,7 +20,6 @@ export const MessagesState = new StateSchema({
|
|||||||
date: z.string(),
|
date: z.string(),
|
||||||
messages: MessagesValue,
|
messages: MessagesValue,
|
||||||
proposedTriggerEvent: ProposedTriggerEventArray,
|
proposedTriggerEvent: ProposedTriggerEventArray,
|
||||||
evalTriggerEvent: ProposedTriggerEventArray,
|
|
||||||
proposedTriggerEventIndex: z.int(),
|
proposedTriggerEventIndex: z.int(),
|
||||||
normalizedClaim: z.string(),
|
normalizedClaim: z.string(),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -15,8 +15,6 @@ const CACHE_PATH = "../data/csv.cache.json";
|
|||||||
|
|
||||||
const JSONL_PATH = "../data/input.jsonl"
|
const JSONL_PATH = "../data/input.jsonl"
|
||||||
|
|
||||||
const BM25_MIN_DOCS = 3;
|
|
||||||
|
|
||||||
type EmbeddingCache = {
|
type EmbeddingCache = {
|
||||||
rawtexts: string[];
|
rawtexts: string[];
|
||||||
cleantexts: string[];
|
cleantexts: string[];
|
||||||
@@ -289,20 +287,8 @@ async function embedText(text: string): Promise<number[]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function buildBM25(texts: string[]) {
|
function buildBM25(texts: string[]) {
|
||||||
let paddedTexts = texts;
|
logger.info("Building BM25 index (%s docs)...", texts.length);
|
||||||
|
|
||||||
if (texts.length < BM25_MIN_DOCS) {
|
|
||||||
const needed = BM25_MIN_DOCS - texts.length;
|
|
||||||
logger.error(
|
|
||||||
"Corpus too small for BM25 (%s docs, need %s+), padding with %s dummy doc(s)",
|
|
||||||
texts.length,
|
|
||||||
BM25_MIN_DOCS,
|
|
||||||
needed
|
|
||||||
);
|
|
||||||
paddedTexts = [...texts, ...Array(needed).fill("placeholder dummy document")];
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info("Building BM25 index (%s docs)...", paddedTexts.length);
|
|
||||||
const bm25 = bm25Factory();
|
const bm25 = bm25Factory();
|
||||||
|
|
||||||
bm25.defineConfig({
|
bm25.defineConfig({
|
||||||
@@ -316,7 +302,7 @@ function buildBM25(texts: string[]) {
|
|||||||
nlp.tokens.removeWords,
|
nlp.tokens.removeWords,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
paddedTexts.forEach((text, i) => {
|
texts.forEach((text, i) => {
|
||||||
bm25.addDoc({ text }, i);
|
bm25.addDoc({ text }, i);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,13 @@
|
|||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
|
|
||||||
export async function evaluateWithEnsemble({
|
export async function evaluateWithRoberta({
|
||||||
answer,
|
answer
|
||||||
method
|
|
||||||
}: {
|
}: {
|
||||||
answer: string;
|
answer: string;
|
||||||
method: string
|
|
||||||
}): Promise<{ validProb: number; invalidProb: number; }> {
|
}): Promise<{ validProb: number; invalidProb: number; }> {
|
||||||
const res = await axios.post("http://localhost:8000/evaluate", {
|
const res = await axios.post("http://localhost:8000/evaluate", {
|
||||||
answer,
|
answer
|
||||||
method
|
});
|
||||||
}, {timeout: 0});
|
|
||||||
// console.log(res.data)
|
// console.log(res.data)
|
||||||
const validProb = res.data["probabilities"][0][0]
|
const validProb = res.data["probabilities"][0][0]
|
||||||
const invalidProb = res.data["probabilities"][0][1] + res.data["probabilities"][0][2]
|
const invalidProb = res.data["probabilities"][0][1] + res.data["probabilities"][0][2]
|
||||||
+19
-79
@@ -1,92 +1,32 @@
|
|||||||
import { Builder, Browser } from "selenium-webdriver";
|
import { Builder, Browser } from "selenium-webdriver";
|
||||||
import firefox from "selenium-webdriver/firefox";
|
import firefox from "selenium-webdriver/firefox";
|
||||||
import { backOff } from "exponential-backoff";
|
|
||||||
import { logger } from "../utils/logger";
|
|
||||||
|
|
||||||
export async function extractWebpageContent(url: string): Promise<string[]> {
|
export async function extractWebpageContent(url: string) : Promise<string[]>{
|
||||||
try {
|
|
||||||
const response = await backOff(async () => {
|
|
||||||
return await extractWebpageContentWorker(url);
|
|
||||||
}, {
|
|
||||||
numOfAttempts: 10,
|
|
||||||
startingDelay: 500,
|
|
||||||
timeMultiple: 2,
|
|
||||||
jitter: "full",
|
|
||||||
maxDelay: 50000,
|
|
||||||
});
|
|
||||||
return response;
|
|
||||||
} catch (err: any) {
|
|
||||||
logger.error(`Failed out of retry loop for URL "${url}", returning placeholder to pipeline`);
|
|
||||||
return ["API EXCEPTION"];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function extractWebpageContentWorker(url: string): Promise<string[]> {
|
|
||||||
let driver;
|
|
||||||
try {
|
|
||||||
const options = new firefox.Options();
|
const options = new firefox.Options();
|
||||||
options.addArguments("--headless");
|
options.addArguments("--headless");
|
||||||
driver = await new Builder()
|
|
||||||
.forBrowser(Browser.FIREFOX)
|
|
||||||
.setFirefoxOptions(options)
|
|
||||||
.build();
|
|
||||||
} catch (err: any) {
|
|
||||||
const desc = `Failed to launch Firefox driver: ${err.message}`;
|
|
||||||
logger.error(desc);
|
|
||||||
throw new Error(desc);
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
let driver = await new Builder().forBrowser(Browser.FIREFOX).setFirefoxOptions(options).build()
|
||||||
try {
|
try {
|
||||||
await driver.get(url);
|
await driver.get(url)
|
||||||
} catch (err: any) {
|
await driver.wait(async () => {
|
||||||
const desc = `Failed to navigate to URL "${url}": ${err.message}`;
|
return await driver.executeScript(
|
||||||
logger.error(desc);
|
"return document.readyState === 'complete'"
|
||||||
throw new Error(desc);
|
);
|
||||||
}
|
}, 5000);
|
||||||
|
|
||||||
try {
|
const readableText = await driver.executeScript(
|
||||||
await driver.wait(async () => {
|
"return document.body.innerText;"
|
||||||
return await driver.executeScript(
|
) as string;
|
||||||
"return document.readyState === 'complete'"
|
|
||||||
);
|
|
||||||
}, 5000);
|
|
||||||
} catch (err: any) {
|
|
||||||
logger.error(`Page load timed out for "${url}", attempting to read partial content: ${err.message}`);
|
|
||||||
// do not throw, attempt to read
|
|
||||||
}
|
|
||||||
|
|
||||||
let readableText: string;
|
const filteredLines = readableText
|
||||||
try {
|
.split(/\r?\n/)
|
||||||
readableText = await driver.executeScript(
|
.map(line => line.trim())
|
||||||
"return document.body.innerText;"
|
.filter(line => line.split(/\s+/).length > 1);
|
||||||
) as string;
|
|
||||||
} catch (err: any) {
|
|
||||||
const desc = `Failed to extract page text from "${url}": ${err.message}`;
|
|
||||||
logger.error(desc);
|
|
||||||
throw new Error(desc);
|
|
||||||
}
|
|
||||||
|
|
||||||
const filteredLines = readableText
|
return filteredLines;
|
||||||
.split(/\r?\n/)
|
} finally {
|
||||||
.map(line => line.trim())
|
await driver.quit()
|
||||||
.filter(line => line.split(/\s+/).length > 1);
|
|
||||||
|
|
||||||
if (filteredLines.length === 0) {
|
|
||||||
const desc = `No content extracted from "${url}"`;
|
|
||||||
logger.error(desc);
|
|
||||||
throw new Error(desc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return filteredLines;
|
|
||||||
} finally {
|
|
||||||
try {
|
|
||||||
await driver.quit();
|
|
||||||
} catch (err: any) {
|
|
||||||
logger.error(`Failed to quit Firefox driver cleanly: ${err.message}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// console.log(await extractWebpageContent("https://www.bbc.co.uk/news/live/c74wd01egvyt"))
|
//console.log(await extractWebpageContent("https://www.bbc.co.uk/news/live/c74wd01egvyt"))
|
||||||
// console.log(await extractWebpageContent("https://badcertificate.int.jeynes.uk/"))
|
|
||||||
+17
-14
@@ -1,35 +1,38 @@
|
|||||||
import { END, START, StateGraph } from "@langchain/langgraph";
|
import { END, START, StateGraph } from "@langchain/langgraph";
|
||||||
import { MessagesState } from "./state";
|
import { MessagesState } from "./state";
|
||||||
import { verificationSetup } from "./nodes/verificationSetup";
|
import { verificationSetup } from "./nodes/verificationSetup";
|
||||||
|
import { ragasMetrics } from "./nodes/ragasMetrics";
|
||||||
import { produceRanking } from "./nodes/produceRanking";
|
import { produceRanking } from "./nodes/produceRanking";
|
||||||
|
import { createModelNode } from "./nodes/model";
|
||||||
import { loopEndConditional } from "./conditionals/loop_end";
|
import { loopEndConditional } from "./conditionals/loop_end";
|
||||||
import { sort } from "./nodes/sort";
|
import { sort } from "./nodes/sort";
|
||||||
import { createEnsembleNode } from "./nodes/ensembleNode";
|
import { robertaMetrics } from "./nodes/robertaMetrics";
|
||||||
|
|
||||||
const roNode = createEnsembleNode("ROBERTA", "roberta");
|
const verificationModel = createModelNode([], "verify.txt");
|
||||||
const flNode = createEnsembleNode("FLAN", "flan");
|
const relationModel = createModelNode([], "relation.txt");
|
||||||
const lrNode = createEnsembleNode("REGRESSION", "logreg");
|
|
||||||
|
|
||||||
const agent = new StateGraph(MessagesState)
|
const agent = new StateGraph(MessagesState)
|
||||||
|
|
||||||
//NODES
|
//NODES
|
||||||
.addNode(verificationSetup.name, verificationSetup)
|
.addNode(verificationSetup.name, verificationSetup)
|
||||||
.addNode("roNode", roNode)
|
// .addNode("verificationModel", verificationModel)
|
||||||
.addNode("flNode", flNode)
|
// .addNode(ragasMetrics.name, ragasMetrics)
|
||||||
.addNode("lrNode", lrNode)
|
.addNode(robertaMetrics.name, robertaMetrics)
|
||||||
|
// .addNode("relationModel", relationModel)
|
||||||
|
|
||||||
.addNode(produceRanking.name, produceRanking)
|
.addNode(produceRanking.name, produceRanking)
|
||||||
.addNode(sort.name, sort)
|
.addNode(sort.name, sort)
|
||||||
|
|
||||||
.addEdge(START, verificationSetup.name)
|
.addEdge(START, verificationSetup.name)
|
||||||
|
// .addEdge(verificationSetup.name, "verificationModel")
|
||||||
|
// .addEdge(verificationSetup.name, ragasMetrics.name)
|
||||||
|
.addEdge(verificationSetup.name, robertaMetrics.name)
|
||||||
|
// .addEdge(verificationSetup.name, "relationModel")
|
||||||
|
|
||||||
.addEdge(verificationSetup.name, "roNode")
|
// .addEdge(ragasMetrics.name, produceRanking.name)
|
||||||
.addEdge(verificationSetup.name, "flNode")
|
.addEdge(robertaMetrics.name, produceRanking.name)
|
||||||
.addEdge(verificationSetup.name, "lrNode")
|
// .addEdge("verificationModel", produceRanking.name)
|
||||||
|
// .addEdge("relationModel", produceRanking.name)
|
||||||
.addEdge("roNode", produceRanking.name)
|
|
||||||
.addEdge("flNode", produceRanking.name)
|
|
||||||
.addEdge("lrNode", produceRanking.name)
|
|
||||||
|
|
||||||
// @ts-expect-error
|
// @ts-expect-error
|
||||||
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ run_agent () {
|
|||||||
npx @langchain/langgraph-cli dev
|
npx @langchain/langgraph-cli dev
|
||||||
}
|
}
|
||||||
|
|
||||||
run_ensemble_service () {
|
run_ragas_service () {
|
||||||
echo "Starting Ensemble service..."
|
echo "Starting RAGAS service..."
|
||||||
cd "supporting/RAGAS_Service"
|
cd "supporting/RAGAS_Service"
|
||||||
.venv/bin/uvicorn ensemble_service:app --timeout-keep-alive 300
|
.venv/bin/uvicorn ragas_service:app --port 8001
|
||||||
}
|
}
|
||||||
|
|
||||||
run_frontend () {
|
run_frontend () {
|
||||||
@@ -34,13 +34,13 @@ run_wrapper () {
|
|||||||
|
|
||||||
case "$1" in
|
case "$1" in
|
||||||
agent) run_agent ;;
|
agent) run_agent ;;
|
||||||
ensemble_service) run_ensemble_service ;;
|
ragas_service) run_ragas_service ;;
|
||||||
frontend) run_frontend ;;
|
frontend) run_frontend ;;
|
||||||
fetch) run_fetch ;;
|
fetch) run_fetch ;;
|
||||||
wrapper) run_wrapper ;;
|
wrapper) run_wrapper ;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown command: $1"
|
echo "Unknown command: $1"
|
||||||
echo "Usage: ./runproject [agent|ensemble_service|frontend|fetch|wrapper]"
|
echo "Usage: ./runproject [agent|ragas_service|frontend|fetch|wrapper]"
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
# -- OURS --
|
# -- OURS --
|
||||||
results/
|
results/
|
||||||
roberta_classifier/
|
roberta_classifier/
|
||||||
roberta_distilled_classifier/
|
|
||||||
roberta_classifier*/
|
roberta_classifier*/
|
||||||
*.pt
|
|
||||||
output*
|
output*
|
||||||
|
|
||||||
# -- THEIRS --
|
# -- THEIRS --
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
# Classifier work for evaluating model quality
|
|
||||||
|
|
||||||
Made using a dataset of 1000 labeled claims from MVP pipeline.
|
|
||||||
|
|
||||||
Roberta model trained on an augmented dataset with LLM generated adversarial examples for low frequency labels.
|
|
||||||
|
|
||||||
Flan model trained using the raw labelled claims; its inherent natural-language ability allows for pattern recognition without the need for fake data.
|
|
||||||
|
|
||||||
Regression model trained using the roberta dataset.
|
|
||||||
|
|
||||||
Used ensemble model in the final version, with the component models available on Hugging Face.
|
|
||||||
|
|
||||||
| Model | % Correct | % Valid taken forward|Used in ensemble|Link
|
|
||||||
|------------------------------------------------------------|-----------|----------------------|----------------|-
|
|
||||||
| Original | 53.22 | 61.72 |
|
|
||||||
| Original (RAGAS) | 56.01 | 57.73 |
|
|
||||||
| Roberta (base) | 75 | 70 |
|
|
||||||
| Roberta (Generated Data) | 76 | 71 |
|
|
||||||
| Roberta (Generated Data + Back Translation) | 74 | 71 |
|
|
||||||
| Roberta (Generated Data + Back Translation + Thresholding) | 77 | 90 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis)
|
|
||||||
| Distilled Roberta | 72.73 | 69.57 |
|
|
||||||
| Flan | 79.17 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Flan)
|
|
||||||
| Simple Regression Model | 74.77 | 85.71 |Y|[Here](https://huggingface.co/WillJeynes/LLMsForDisinformationAnalysis-Regression)
|
|
||||||
| Ensemble Model (weighted confidence score sum) | 84.21 | 83.33 |
|
|
||||||
| Ensemble Model (majority voting) | 80.2 | 95.12 |
|
|
||||||
@@ -1,224 +0,0 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from fastapi import FastAPI
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import os
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
|
||||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
############################################
|
|
||||||
# SCHEMA
|
|
||||||
############################################
|
|
||||||
|
|
||||||
class EvalRequest(BaseModel):
    """JSON body accepted by the ``/evaluate`` endpoint."""

    # Text to be classified.
    answer: str
    # Which classifier to run: "logreg", "roberta" or "flan".
    method: str
|
|
||||||
|
|
||||||
|
|
||||||
############################################
|
|
||||||
# REGRESSION MODEL
|
|
||||||
############################################
|
|
||||||
|
|
||||||
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
|
|
||||||
MODEL_FILENAME = "logreg_classifier.pt"
|
|
||||||
CACHE_DIR = "./model_cache"
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict:
    """Return the checkpoint dict for ``filename``, downloading it on first use.

    Args:
        repo_id: Hugging Face Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside the repository.
        cache_dir: Local directory used as a download cache.

    Returns:
        The deserialised checkpoint, with all tensors mapped to the CPU.
    """
    local_path = os.path.join(cache_dir, filename)
    if not os.path.exists(local_path):
        os.makedirs(cache_dir, exist_ok=True)
        # Trust the path the hub client reports rather than assuming its layout.
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=cache_dir)
    # weights_only=True restricts unpickling to tensors and plain containers, so
    # a tampered download cannot execute arbitrary code while being loaded.
    return torch.load(local_path, map_location="cpu", weights_only=True)
|
|
||||||
|
|
||||||
|
|
||||||
class LogisticNet(nn.Module):
    """Single-hidden-layer classifier head that maps an embedding to class logits."""

    def __init__(self, input_dim, hidden_dim, num_classes, dropout):
        super().__init__()
        # Layer order fixes the state-dict keys (net.0, net.1, net.4).
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-normalised) logits for ``x``."""
        logits = self.net(x)
        return logits
|
|
||||||
|
|
||||||
|
|
||||||
checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR)
|
|
||||||
|
|
||||||
encoder = SentenceTransformer(checkpoint["embedding_model"])
|
|
||||||
|
|
||||||
logreg_model = LogisticNet(
|
|
||||||
checkpoint["input_dim"],
|
|
||||||
checkpoint["hidden_dim"],
|
|
||||||
checkpoint["num_classes"],
|
|
||||||
checkpoint["dropout"],
|
|
||||||
)
|
|
||||||
logreg_model.load_state_dict(checkpoint["model_state"])
|
|
||||||
logreg_model.eval()
|
|
||||||
|
|
||||||
|
|
||||||
############################################
|
|
||||||
# ROBERTA
|
|
||||||
############################################
|
|
||||||
|
|
||||||
ROBERTA_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
|
|
||||||
|
|
||||||
roberta_tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_PATH)
|
|
||||||
roberta_model = RobertaForSequenceClassification.from_pretrained(ROBERTA_PATH)
|
|
||||||
roberta_model.eval()
|
|
||||||
|
|
||||||
|
|
||||||
############################################
|
|
||||||
# FLAN
|
|
||||||
############################################
|
|
||||||
|
|
||||||
FLAN_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
|
|
||||||
|
|
||||||
INT_TO_LABEL = {
|
|
||||||
0: "perfect",
|
|
||||||
1: "story",
|
|
||||||
2: "not specific",
|
|
||||||
}
|
|
||||||
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
|
||||||
|
|
||||||
flan_tokenizer = AutoTokenizer.from_pretrained(FLAN_PATH)
|
|
||||||
flan_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_PATH)
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
flan_model.to(device)
|
|
||||||
flan_model.eval()
|
|
||||||
|
|
||||||
label_token_ids = {
|
|
||||||
label: flan_tokenizer(label, add_special_tokens=False).input_ids[0]
|
|
||||||
for label in LABEL_TO_INT.keys()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def format_prompt(text: str) -> str:
    """Wrap ``text`` in the classification instruction sent to the Flan model."""
    instruction = (
        "Classify the following event into one of these categories: "
        "perfect, story, not specific."
    )
    # Sections are separated by a blank line; the prompt ends right after "Category:".
    return "\n\n".join([instruction, f"Event: {text}", "Category:"])
|
|
||||||
|
|
||||||
|
|
||||||
def parse_generated_label(generated: str):
    """Map generated text to a class id, or ``None`` when no label name occurs in it.

    Matching is a case-insensitive substring test; label names are tried in
    the iteration order of ``LABEL_TO_INT`` and the first hit is returned.
    """
    normalised = generated.strip().lower()
    return next(
        (label_id for label_name, label_id in LABEL_TO_INT.items() if label_name in normalised),
        None,
    )
|
|
||||||
|
|
||||||
|
|
||||||
############################################
|
|
||||||
# ENDPOINT
|
|
||||||
############################################
|
|
||||||
|
|
||||||
def _classify_logreg(text: str) -> dict:
    """Score ``text`` with the sentence-embedding + feed-forward classifier."""
    embedding = encoder.encode(
        [text],
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    features = torch.tensor(embedding, dtype=torch.float32)

    with torch.no_grad():
        logits = logreg_model(features)

    probs = torch.softmax(logits, dim=1)
    return {
        "method": "logreg",
        "probabilities": probs.cpu().numpy().tolist(),
    }


def _classify_roberta(text: str) -> dict:
    """Score ``text`` with the fine-tuned RoBERTa sequence classifier."""
    inputs = roberta_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    with torch.no_grad():
        logits = roberta_model(**inputs).logits

    probs = torch.softmax(logits, dim=1)
    return {
        "method": "roberta",
        "probabilities": probs.cpu().numpy().tolist(),
    }


def _classify_flan(text: str) -> dict:
    """Score ``text`` with the Flan seq2seq model.

    Returns both the free-text generation and a probability per label. The
    probabilities come from a softmax over the first-decoding-step logits of
    each label's first token id (``label_token_ids``).
    """
    inputs = flan_tokenizer(
        format_prompt(text),
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    ).to(device)

    with torch.no_grad():
        outputs = flan_model.generate(**inputs, max_new_tokens=8)

        # A single decoder step from the start token yields the distribution
        # over the first generated token.
        decoder_input_ids = torch.tensor(
            [[flan_model.config.decoder_start_token_id]]
        ).to(device)
        first_step_logits = flan_model(
            **inputs,
            decoder_input_ids=decoder_input_ids,
        ).logits[:, 0, :]

    generated_text = flan_tokenizer.decode(outputs[0], skip_special_tokens=True)

    label_logits = torch.tensor(
        [first_step_logits[0, tid].item() for tid in label_token_ids.values()]
    )
    label_probs = torch.softmax(label_logits, dim=0).tolist()

    return {
        "method": "flan",
        "generated": generated_text,
        # Wrapped in a list to match the (batch, classes) shape of the other methods.
        "probabilities": [label_probs],
    }


# Dispatch table: lower-cased method name -> inference helper.
_CLASSIFIERS = {
    "logreg": _classify_logreg,
    "roberta": _classify_roberta,
    "flan": _classify_flan,
}


@app.post("/evaluate")
def evaluate(req: EvalRequest):
    """Classify ``req.answer`` with the model selected by ``req.method``.

    The method name is matched case-insensitively. An unknown method yields a
    JSON object with a single ``"error"`` key instead of a prediction.
    """
    classify = _CLASSIFIERS.get(req.method.lower())
    if classify is None:
        return {
            "error": "Invalid method. Use 'logreg', 'roberta', or 'flan'."
        }
    return classify(req.answer)
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
||||||
import torch
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
|
|
||||||
|
|
||||||
INT_TO_LABEL = {
|
|
||||||
0: "perfect",
|
|
||||||
1: "story",
|
|
||||||
2: "not specific",
|
|
||||||
}
|
|
||||||
|
|
||||||
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
|
|
||||||
def format_prompt(text: str) -> str:
    """Build the instruction prompt that asks the model to categorise ``text``."""
    header = (
        "Classify the following event into one of these categories: "
        "perfect, story, not specific."
    )
    body = f"Event: {text}"
    # Blank lines separate the three parts; nothing follows "Category:".
    return f"{header}\n\n{body}\n\nCategory:"
|
|
||||||
|
|
||||||
|
|
||||||
def parse_generated_label(generated: str) -> int | None:
    """Return the class id whose label name appears in ``generated``, else ``None``.

    The comparison is case-insensitive and substring-based; labels are tried in
    the iteration order of ``LABEL_TO_INT``.
    """
    cleaned = generated.strip().lower()
    matches = (idx for name, idx in LABEL_TO_INT.items() if name in cleaned)
    return next(matches, None)
|
|
||||||
|
|
||||||
|
|
||||||
class EvalRequest(BaseModel):
    """JSON body accepted by the ``/evaluate`` endpoint."""

    # Text to be classified.
    answer: str
|
|
||||||
|
|
||||||
|
|
||||||
# First sub-token id of every label name. These depend only on the tokenizer,
# so they are computed once at import time instead of on every request.
_LABEL_TOKEN_IDS = {
    label: tokenizer(label, add_special_tokens=False).input_ids[0]
    for label in LABEL_TO_INT.keys()
}


@app.post("/evaluate")
def evaluate(req: EvalRequest):
    """Classify ``req.answer`` with the Flan model.

    Returns:
        A dict with ``"generated"`` (the decoded model output) and
        ``"probabilities"`` (a one-row list holding one probability per label,
        in ``LABEL_TO_INT`` order). The probabilities are a softmax over the
        first-decoding-step logits of each label's first token id.
    """
    inputs = tokenizer(
        format_prompt(req.answer),
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    ).to(device)

    with torch.no_grad():
        # Free-text label produced by the model.
        outputs = model.generate(
            **inputs,
            max_new_tokens=8,
        )

        # Confidence: one decoder step from the start token gives the
        # distribution over the first generated token.
        decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)
        first_step_logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits[:, 0, :]

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    label_logits = torch.tensor(
        [first_step_logits[0, tid].item() for tid in _LABEL_TOKEN_IDS.values()]
    )
    label_probs = torch.softmax(label_logits, dim=0).tolist()

    return {
        "generated": generated_text,
        "probabilities": [label_probs],
    }
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
from fastapi import FastAPI
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
import os
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
|
|
||||||
MODEL_FILENAME = "logreg_classifier.pt"
|
|
||||||
CACHE_DIR = "./model_cache"
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict:
    """Return the checkpoint dict for ``filename``, downloading it on first use.

    Args:
        repo_id: Hugging Face Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside the repository.
        cache_dir: Local directory used as a download cache.

    Returns:
        The deserialised checkpoint, with all tensors mapped to the CPU.
    """
    local_path = os.path.join(cache_dir, filename)
    if os.path.exists(local_path):
        print(f"Using cached model at {local_path}")
    else:
        print(f"Downloading {filename} from {repo_id}...")
        os.makedirs(cache_dir, exist_ok=True)
        # Load from the path the hub client reports rather than assuming its layout.
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=cache_dir,
        )
        print(f"Saved to {local_path}")
    # weights_only=True restricts unpickling to tensors and plain containers, so
    # a tampered download cannot execute arbitrary code while being loaded.
    return torch.load(local_path, map_location="cpu", weights_only=True)
|
|
||||||
|
|
||||||
|
|
||||||
class LogisticNet(nn.Module):
    """Single-hidden-layer classifier head that maps an embedding to class logits."""

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float):
        super().__init__()
        # Layer order fixes the state-dict keys (net.0, net.1, net.4).
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw (un-normalised) logits for ``x``."""
        logits = self.net(x)
        return logits
|
|
||||||
|
|
||||||
|
|
||||||
checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR)
|
|
||||||
|
|
||||||
encoder = SentenceTransformer(checkpoint["embedding_model"])
|
|
||||||
|
|
||||||
model = LogisticNet(
|
|
||||||
input_dim = checkpoint["input_dim"],
|
|
||||||
hidden_dim = checkpoint["hidden_dim"],
|
|
||||||
num_classes = checkpoint["num_classes"],
|
|
||||||
dropout = checkpoint["dropout"],
|
|
||||||
)
|
|
||||||
model.load_state_dict(checkpoint["model_state"])
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
|
|
||||||
class EvalRequest(BaseModel):
    """JSON body accepted by the ``/evaluate`` endpoint."""

    # Text to be classified.
    answer: str
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/evaluate")
def evaluate(req: EvalRequest):
    """Embed ``req.answer`` and return the classifier's class probabilities.

    Returns:
        ``{"probabilities": [[p_0, ..., p_k]]}`` with one row for the single input.
    """
    vectors = encoder.encode(
        [req.answer],
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    features = torch.tensor(vectors, dtype=torch.float32)

    with torch.no_grad():
        scores = model(features)

    distribution = torch.softmax(scores, dim=1)
    payload = {"probabilities": distribution.cpu().numpy().tolist()}
    return payload
|
|
||||||
@@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
MODEL_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
|
MODEL_PATH = "./roberta_classifier"
|
||||||
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
||||||
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
||||||
|
|||||||
@@ -1,227 +0,0 @@
|
|||||||
from sklearn.utils import compute_class_weight
|
|
||||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
|
|
||||||
import torch
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
|
||||||
from collections import Counter
|
|
||||||
import sys
|
|
||||||
import csv
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
NUM_CLASSES = 3
|
|
||||||
model_name = "google/flan-t5-base"
|
|
||||||
|
|
||||||
INT_TO_LABEL = {
|
|
||||||
0: "perfect",
|
|
||||||
1: "story",
|
|
||||||
2: "not specific",
|
|
||||||
}
|
|
||||||
LABEL_TO_INT = {v: k for k, v in INT_TO_LABEL.items()}
|
|
||||||
|
|
||||||
# Ordered (keyword, class id) pairs: the first keyword found in an annotation
# decides its class. -1 marks rows that the loader drops; the "" entry is the
# class used when the annotation is empty.
LABEL_PRIORITY = [
    ("PERFECT", 0),
    ("STORY", 1),
    ("NSPECIFIC", 2),
    ("REWORDING", 1),
    ("TINCORRECT", -1),
    ("DUPLICATE", -1),
    ("", 0),
]


def label_to_int(extra_info: str) -> int:
    """Convert a free-text annotation into a class id using ``LABEL_PRIORITY``.

    Args:
        extra_info: Whitespace-separated annotation keywords (case-insensitive),
            or ``None``/blank for an un-annotated row.

    Returns:
        The class id of the highest-priority keyword present, or the id mapped
        to ``""`` when the annotation is empty.

    Raises:
        ValueError: If no known keyword is present, or the annotation is empty
            and ``LABEL_PRIORITY`` has no ``""`` entry.
    """
    annotation = ("" if extra_info is None else extra_info).strip()

    if annotation == "":
        empty_ids = [class_id for keyword, class_id in LABEL_PRIORITY if keyword == ""]
        if not empty_ids:
            raise ValueError("Empty extra_info but no empty mapping defined")
        return empty_ids[0]

    words = set(annotation.upper().split())
    for keyword, class_id in LABEL_PRIORITY:
        if keyword != "" and keyword in words:
            return class_id
    raise ValueError(f"Unknown label content: '{annotation}'")
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_from_csv(path):
    """Read ``event`` / ``extra_info`` columns from a CSV into parallel lists.

    Rows whose label converts to ``-1`` are dropped. A label that cannot be
    converted is reported (with its 1-based data-row number) and the process
    exits with status 1.

    Returns:
        ``(texts, labels)``: the kept event strings and their class ids.
    """
    kept_texts = []
    kept_labels = []
    dropped = 0

    with open(path, newline="", encoding="utf-8") as handle:
        for row_no, record in enumerate(csv.DictReader(handle), start=1):
            event_text = record["event"]
            raw_label = record["extra_info"]
            try:
                class_id = label_to_int(raw_label)
            except Exception as err:
                print(f"ERROR converting label on line {row_no}: {raw_label}")
                print(err)
                sys.exit(1)

            if class_id == -1:
                dropped += 1
            else:
                kept_texts.append(event_text)
                kept_labels.append(class_id)

    print(f"Loaded {len(kept_texts)} samples (removed {dropped})")
    return kept_texts, kept_labels
|
|
||||||
|
|
||||||
|
|
||||||
def format_prompt(text: str) -> str:
    """Wrap ``text`` in the classification instruction used for training and inference."""
    parts = [
        "Classify the following event into one of these categories: "
        "perfect, story, not specific.",
        f"Event: {text}",
        "Category:",
    ]
    # Parts are separated by a blank line; nothing follows "Category:".
    return "\n\n".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_generated_label(generated: str) -> int:
    """Map generated text to a class id, or ``-1`` when it names no known label.

    Matching is a case-insensitive substring test; label names are tried in
    the iteration order of ``LABEL_TO_INT`` and the first hit wins. Unparseable
    output is logged to stdout before ``-1`` is returned.
    """
    generated = generated.strip().lower()
    for label_text, label_int in LABEL_TO_INT.items():
        if label_text in generated:
            return label_int
    print("invalid label:" + generated)
    return -1  # unknown / unparseable output
|
|
||||||
|
|
||||||
|
|
||||||
class GenerativeTextDataset(torch.utils.data.Dataset):
    """Dataset of (prompt, label-text) pairs tokenised lazily for seq2seq training."""

    def __init__(self, texts, labels, tokenizer, max_input_length=256, max_target_length=8):
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

        self.inputs = list(map(format_prompt, texts))
        # The target sequence is the label's text form, not its integer id.
        self.targets = [INT_TO_LABEL[label_id] for label_id in labels]
        self.int_labels = labels  # original ids, kept for metric computation

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the un-padded tokenised prompt plus ``labels`` for item ``idx``."""
        encoded = self.tokenizer(
            self.inputs[idx],
            max_length=self.max_input_length,
            truncation=True,
            padding=False,
        )
        target_ids = self.tokenizer(
            self.targets[idx],
            max_length=self.max_target_length,
            truncation=True,
            padding=False,
        )["input_ids"]

        # Seq2seq convention: pad positions in the labels become -100 so the
        # loss ignores them.
        pad_id = self.tokenizer.pad_token_id
        encoded["labels"] = [-100 if token == pad_id else token for token in target_ids]

        return {key: torch.tensor(value) for key, value in encoded.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def compute_metrics_generative(eval_pred, tokenizer):
    """Compute accuracy and weighted F1/precision/recall from generated token ids.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair of token-id arrays as
            passed to ``compute_metrics`` by the trainer.
        tokenizer: Tokenizer used to decode both arrays back to text.

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``. All are
        0.0 when no reference label could be parsed.
    """
    predictions, label_ids = eval_pred
    if isinstance(predictions, tuple):
        # Some models return extra outputs alongside the generated ids.
        predictions = predictions[0]

    # -100 is the ignore/padding index and cannot be decoded. It occurs in the
    # labels and can also occur in predictions gathered from batches of
    # different generated length, so replace it in both arrays before decoding.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Map decoded text back to integer labels.
    pred_ints = [parse_generated_label(p) for p in decoded_preds]
    true_ints = [parse_generated_label(t) for t in decoded_labels]

    # Drop rows whose reference label failed to parse. Unparseable predictions
    # stay in as -1 and therefore count as errors.
    valid = [(p, t) for p, t in zip(pred_ints, true_ints) if t != -1]
    if not valid:
        return {"accuracy": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}

    preds_filtered, true_filtered = zip(*valid)

    return {
        "accuracy": accuracy_score(true_filtered, preds_filtered),
        "f1": f1_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
        "precision": precision_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
        "recall": recall_score(true_filtered, preds_filtered, average="weighted", zero_division=0),
    }
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Fine-tune ``model_name`` as a text-to-label classifier and save it to ./flan_classifier."""
    torch.multiprocessing.set_start_method('spawn', force=True)
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA device count:", torch.cuda.device_count())

    texts, labels = load_dataset_from_csv("../../data/classify.csv")

    print("Dataset size:", len(texts))
    print("Label distribution:", Counter(labels))

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Stratified 80/20 split with a fixed seed.
    split = train_test_split(
        texts, labels,
        test_size=0.2,
        random_state=42,
        stratify=labels
    )
    train_texts, val_texts, train_labels, val_labels = split

    train_dataset = GenerativeTextDataset(train_texts, train_labels, tokenizer)
    val_dataset = GenerativeTextDataset(val_texts, val_labels, tokenizer)

    # Pads inputs per batch; label padding uses -100.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
        label_pad_token_id=-100,
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        learning_rate=5e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        predict_with_generate=True,
        generation_max_length=8,
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
        fp16=False,
        max_grad_norm=1.0,
    )

    def metrics_fn(eval_pred):
        # Bind the tokenizer so predictions can be decoded back to label text.
        return compute_metrics_generative(eval_pred, tokenizer)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=metrics_fn,
    )

    trainer.train()

    final_metrics = trainer.evaluate()
    print("\nFinal evaluation metrics:")
    for metric_name, metric_value in final_metrics.items():
        print(f"  {metric_name}: {metric_value}")

    trainer.save_model("./flan_classifier")
    tokenizer.save_pretrained("./flan_classifier")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,209 +0,0 @@
|
|||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.utils import compute_class_weight
|
|
||||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
|
||||||
from collections import Counter
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
|
||||||
import numpy as np
|
|
||||||
import csv
|
|
||||||
import sys
|
|
||||||
|
|
||||||
NUM_CLASSES = 3
|
|
||||||
EMBEDDING_MODEL = "all-mpnet-base-v2"
|
|
||||||
HIDDEN_DIM = 256
|
|
||||||
DROPOUT = 0.4
|
|
||||||
LEARNING_RATE = 2e-3
|
|
||||||
WEIGHT_DECAY = 1e-4
|
|
||||||
BATCH_SIZE = 64
|
|
||||||
NUM_EPOCHS = 30
|
|
||||||
PATIENCE = 5
|
|
||||||
|
|
||||||
# Ordered (keyword, class id) pairs: the first keyword found in an annotation
# decides its class. -1 marks rows that the loader drops.
LABEL_PRIORITY = [
    ("PERFECT", 0),
    ("STORY", 1),
    ("NSPECIFIC", 2),
    ("REWORDING", 1),
    ("TINCORRECT", -1),
    ("DUPLICATE", -1),
    ("", 0),  # fallback to PERFECT
]


def label_to_int(extra_info: str) -> int:
    """Convert a free-text annotation into a class id using ``LABEL_PRIORITY``.

    Args:
        extra_info: Whitespace-separated annotation keywords (case-insensitive),
            or ``None``/blank for an un-annotated row.

    Returns:
        The class id of the highest-priority keyword present, or the id mapped
        to ``""`` when the annotation is empty.

    Raises:
        ValueError: If no known keyword is present, or the annotation is empty
            and ``LABEL_PRIORITY`` has no ``""`` entry.
    """
    annotation = ("" if extra_info is None else extra_info).strip()

    if annotation == "":
        empty_ids = [class_id for keyword, class_id in LABEL_PRIORITY if keyword == ""]
        if not empty_ids:
            raise ValueError("No empty-string fallback defined in LABEL_PRIORITY")
        return empty_ids[0]

    words = set(annotation.upper().split())
    matched = [class_id for keyword, class_id in LABEL_PRIORITY if keyword and keyword in words]
    if matched:
        # LABEL_PRIORITY order is preserved, so the first match has top priority.
        return matched[0]

    raise ValueError(f"Unknown label content: '{annotation}'")
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_from_csv(path: str):
    """Read ``event`` / ``extra_info`` columns from a CSV into parallel lists.

    Rows whose label converts to ``-1`` are dropped. A label that cannot be
    converted is reported (with its 1-based data-row number) and the process
    exits with status 1.

    Returns:
        ``(texts, labels)``: the kept event strings and their class ids.
    """
    kept_texts = []
    kept_labels = []
    dropped = 0

    with open(path, newline="", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        for row_no, record in enumerate(reader, start=1):
            try:
                class_id = label_to_int(record["extra_info"])
            except Exception as err:
                print(f"ERROR on line {row_no}: {record['extra_info']!r}")
                print(err)
                sys.exit(1)

            if class_id != -1:
                kept_texts.append(record["event"])
                kept_labels.append(class_id)
            else:
                dropped += 1

    print(f"Loaded {len(kept_texts)} samples (removed {dropped})")
    return kept_texts, kept_labels
|
|
||||||
|
|
||||||
|
|
||||||
class LogisticNet(nn.Module):
    """Single-hidden-layer classifier head that maps an embedding to class logits."""

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float):
        super().__init__()
        # Layer order fixes the state-dict keys (net.0, net.1, net.4).
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            # Outputs raw logits; the loss applies the softmax.
            nn.Linear(hidden_dim, num_classes),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw (un-normalised) logits for ``x``."""
        logits = self.net(x)
        return logits
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` and return accuracy and weighted F1/precision/recall.

    The model is switched to eval mode and is not switched back afterwards.
    """
    model.eval()
    predicted = []
    expected = []

    with torch.no_grad():
        for features, targets in loader:
            features = features.to(device)
            targets = targets.to(device)
            batch_pred = model(features).argmax(dim=1)
            predicted.extend(batch_pred.cpu().numpy())
            expected.extend(targets.cpu().numpy())

    # Shared settings for the class-averaged metrics.
    weighted = {"average": "weighted", "zero_division": 0}
    return {
        "accuracy": accuracy_score(expected, predicted),
        "f1": f1_score(expected, predicted, **weighted),
        "precision": precision_score(expected, predicted, **weighted),
        "recall": recall_score(expected, predicted, **weighted),
    }
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Train the embedding-based classifier and save the best checkpoint.

    Encodes every event with ``EMBEDDING_MODEL``, trains ``LogisticNet`` with
    class-weighted cross-entropy and early stopping on validation weighted F1,
    then writes the best weights plus the hyper-parameters needed to rebuild
    the model to ``logreg_classifier.pt``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    texts, labels = load_dataset_from_csv("../../data/classify.csv")
    print("Label distribution:", Counter(labels))

    print(f"\nEncoding with '{EMBEDDING_MODEL}' …")
    encoder = SentenceTransformer(EMBEDDING_MODEL)
    embeddings = encoder.encode(texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
    input_dim = embeddings.shape[1]
    print(f"Embedding dim: {input_dim}")

    X_train, X_val, y_train, y_val = train_test_split(
        embeddings, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Inverse-frequency weights so rare classes contribute equally to the loss.
    class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
    weight_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
    print("Class weights:", class_weights)

    def make_loader(X, y, shuffle=False):
        ds = TensorDataset(
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(y, dtype=torch.long),
        )
        return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle)

    train_loader = make_loader(X_train, y_train, shuffle=True)
    val_loader = make_loader(X_val, y_val, shuffle=False)

    model = LogisticNet(input_dim, HIDDEN_DIM, NUM_CLASSES, DROPOUT).to(device)
    criterion = nn.CrossEntropyLoss(weight=weight_tensor)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Start below any reachable F1 so the first epoch always records a
    # snapshot. Starting at 0.0 left best_state as None whenever validation F1
    # never rose above zero, which crashed load_state_dict / torch.save below.
    best_f1 = float("-inf")
    best_state = None
    epochs_no_imp = 0

    print("\n Training:")
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        total_loss = 0.0

        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            # Weight by batch size so avg_loss is a true per-sample mean.
            total_loss += loss.item() * len(yb)

        scheduler.step()
        avg_loss = total_loss / len(train_loader.dataset)
        val_metrics = evaluate(model, val_loader, device)

        print(
            f"Epoch {epoch:3d}/{NUM_EPOCHS} | "
            f"loss {avg_loss:.4f} | "
            f"val_acc {val_metrics['accuracy']:.4f} | "
            f"val_f1 {val_metrics['f1']:.4f}"
        )

        # Early stopping on weighted F1
        if val_metrics["f1"] > best_f1:
            best_f1 = val_metrics["f1"]
            # Clone so later optimiser steps do not overwrite the snapshot.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
            epochs_no_imp = 0
        else:
            epochs_no_imp += 1
            if epochs_no_imp >= PATIENCE:
                print(f"Early stopping at epoch {epoch} (no improvement for {PATIENCE} epochs)")
                break

    print("\n Final evaluation:")
    model.load_state_dict(best_state)
    final = evaluate(model, val_loader, device)
    for k, v in final.items():
        print(f"  {k}: {v:.4f}")

    # Store the hyper-parameters with the weights so the serving code can
    # rebuild the network without importing this script.
    torch.save(
        {
            "model_state": best_state,
            "input_dim": input_dim,
            "hidden_dim": HIDDEN_DIM,
            "num_classes": NUM_CLASSES,
            "dropout": DROPOUT,
            "embedding_model": EMBEDDING_MODEL,
        },
        "logreg_classifier.pt"
    )
    print("\n Model saved to logreg_classifier.pt")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from sklearn.utils import compute_class_weight
|
from sklearn.utils import compute_class_weight
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
|
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification
|
||||||
import torch
|
import torch
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
@@ -10,7 +10,7 @@ import csv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
NUM_CLASSES = 3
|
NUM_CLASSES = 3
|
||||||
model_name = "roberta-base"
|
model_name = "distilbert/distilroberta-base"
|
||||||
|
|
||||||
LABEL_PRIORITY = [
|
LABEL_PRIORITY = [
|
||||||
("PERFECT", 0),
|
("PERFECT", 0),
|
||||||
@@ -29,12 +29,21 @@ class WeightedTrainer(Trainer):
|
|||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||||
labels = inputs.get("labels")
|
labels = inputs.get("labels")
|
||||||
|
# print("DBG: Before forward")
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
# print("DBG: After forward")
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(weight=self.class_weights.to(logits.device))
|
# loss_fct = CrossEntropyLoss(weight=self.class_weights.to(logits.device))
|
||||||
loss = loss_fct(logits, labels)
|
loss_fct = CrossEntropyLoss(
|
||||||
|
weight=self.class_weights.to(logits.device).to(logits.dtype)
|
||||||
|
)
|
||||||
|
# loss_fct = CrossEntropyLoss()
|
||||||
|
|
||||||
|
# print("DBG: Before loss")
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
# loss.backward()
|
||||||
|
# print("DBG: After loss")
|
||||||
return (loss, outputs) if return_outputs else loss
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
def label_to_int(extra_info: str) -> int:
|
def label_to_int(extra_info: str) -> int:
|
||||||
@@ -120,17 +129,23 @@ def main():
|
|||||||
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
|
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
|
||||||
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
||||||
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
|
# tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
|
||||||
model = RobertaForSequenceClassification.from_pretrained(
|
# model = RobertaForSequenceClassification.from_pretrained(
|
||||||
|
# model_name,
|
||||||
|
# num_labels=NUM_CLASSES
|
||||||
|
# )
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
num_labels=NUM_CLASSES
|
num_labels=NUM_CLASSES
|
||||||
)
|
)
|
||||||
|
|
||||||
for param in model.roberta.parameters():
|
# for param in model.deberta.parameters():
|
||||||
param.requires_grad = False
|
# param.requires_grad = True
|
||||||
|
|
||||||
for param in model.roberta.encoder.layer[-6:].parameters():
|
# for param in model.deberta.encoder.layer[-6:].parameters():
|
||||||
param.requires_grad = True
|
# param.requires_grad = True
|
||||||
|
|
||||||
print("Dataset size:", len(texts))
|
print("Dataset size:", len(texts))
|
||||||
print("Label distribution:")
|
print("Label distribution:")
|
||||||
@@ -140,7 +155,8 @@ def main():
|
|||||||
texts,
|
texts,
|
||||||
labels,
|
labels,
|
||||||
test_size=0.2,
|
test_size=0.2,
|
||||||
random_state=42
|
random_state=42,
|
||||||
|
stratify=labels
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -173,6 +189,7 @@ def main():
|
|||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
# print(f"DBG: Loading item {idx}")
|
||||||
item = {
|
item = {
|
||||||
key: torch.tensor(val[idx])
|
key: torch.tensor(val[idx])
|
||||||
for key, val in self.encodings.items()
|
for key, val in self.encodings.items()
|
||||||
@@ -187,7 +204,8 @@ def main():
|
|||||||
output_dir="./results",
|
output_dir="./results",
|
||||||
learning_rate=2e-5,
|
learning_rate=2e-5,
|
||||||
per_device_train_batch_size=32,
|
per_device_train_batch_size=32,
|
||||||
num_train_epochs=5,
|
# gradient_accumulation_steps=2,
|
||||||
|
num_train_epochs=15,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
eval_strategy="epoch",
|
eval_strategy="epoch",
|
||||||
@@ -195,7 +213,8 @@ def main():
|
|||||||
metric_for_best_model="f1",
|
metric_for_best_model="f1",
|
||||||
greater_is_better=True,
|
greater_is_better=True,
|
||||||
dataloader_num_workers=4,
|
dataloader_num_workers=4,
|
||||||
dataloader_pin_memory=True
|
dataloader_pin_memory=True,
|
||||||
|
# warmup_steps=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = TextDataset(train_encodings, train_labels)
|
train_dataset = TextDataset(train_encodings, train_labels)
|
||||||
@@ -218,8 +237,8 @@ def main():
|
|||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
print(f"{k}: {v}")
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
trainer.save_model("./roberta_classifier")
|
trainer.save_model("./roberta_distilled_classifier")
|
||||||
tokenizer.save_pretrained("./roberta_classifier")
|
tokenizer.save_pretrained("./roberta_distilled_classifier")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ async function processRecord(record: any): Promise<ResultRecord> {
|
|||||||
input: buildAgentInput(record),
|
input: buildAgentInput(record),
|
||||||
streamMode: "values",
|
streamMode: "values",
|
||||||
config: {
|
config: {
|
||||||
recursion_limit: 100
|
recursion_limit: 50
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -16,18 +16,18 @@ BASE_URL = "https://dbkf.ontotext.com/rest-api/search/documents"
|
|||||||
|
|
||||||
# "documentTypes": "http://schema.org/Claim",
|
# "documentTypes": "http://schema.org/Claim",
|
||||||
DEFAULT_PARAMS = [
|
DEFAULT_PARAMS = [
|
||||||
("documentTypes", "http://schema.org/Claim"),
|
("concept", "http://weverify.eu/resource/Concept/Q212"),
|
||||||
("from", "2000-01-01"),
|
("from", "2000-01-01"),
|
||||||
("to", "2026-02-19"),
|
("to", "2026-02-19"),
|
||||||
("lang", "en"),
|
("lang", "en"),
|
||||||
("limit", 7000),
|
("limit", 5000),
|
||||||
("page", 1),
|
("page", 1),
|
||||||
("orderBy", "date"),
|
("orderBy", "date"),
|
||||||
("organization", "http://weverify.eu/resource/Organization/128573c5d49d37558706194e755f152d"), # Science Direct
|
|
||||||
("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
|
("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
|
||||||
|
("organization", "http://weverify.eu/resource/Organization/c71953fa6cf24ac4178f751c77862070"), # CheckYourFact
|
||||||
]
|
]
|
||||||
|
|
||||||
NUM_RANDOM_CLAIMS = 200
|
NUM_RANDOM_CLAIMS = 40
|
||||||
|
|
||||||
INPUT_FILE = "../../data/input.jsonl"
|
INPUT_FILE = "../../data/input.jsonl"
|
||||||
OUTPUT_FILE = "../../data/claims.json"
|
OUTPUT_FILE = "../../data/claims.json"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import streamlit as st
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# THRESH = 0.4
|
||||||
THRESH = 0.6
|
THRESH = 0.6
|
||||||
|
|
||||||
def page_title() -> str:
|
def page_title() -> str:
|
||||||
@@ -60,18 +61,6 @@ def render():
|
|||||||
return
|
return
|
||||||
|
|
||||||
for file_path in jsonl_files:
|
for file_path in jsonl_files:
|
||||||
thresh = THRESH
|
|
||||||
if ("flan" in file_path.name):
|
|
||||||
thresh = 0.94
|
|
||||||
if ("regression" in file_path.name):
|
|
||||||
thresh = 0.75
|
|
||||||
if ("ensemble" in file_path.name):
|
|
||||||
thresh = 0.1
|
|
||||||
if ("ensemble" in file_path.name and "2" in file_path.name):
|
|
||||||
thresh = 0.4
|
|
||||||
if ("ensemble" in file_path.name and "vot" in file_path.name):
|
|
||||||
thresh = 0.7
|
|
||||||
|
|
||||||
st.subheader(f"File: {file_path.name}")
|
st.subheader(f"File: {file_path.name}")
|
||||||
|
|
||||||
confidence_counter = Counter()
|
confidence_counter = Counter()
|
||||||
@@ -97,15 +86,15 @@ def render():
|
|||||||
dup_counter += 1
|
dup_counter += 1
|
||||||
elif "ranked" not in event:
|
elif "ranked" not in event:
|
||||||
"ignore for now"
|
"ignore for now"
|
||||||
elif score > thresh and extra_lower == "perfect":
|
elif score > THRESH and extra_lower == "perfect":
|
||||||
confidence_counter["Correct-PERFECT"] += 1
|
confidence_counter["Correct-PERFECT"] += 1
|
||||||
elif score > thresh and extra_lower == "":
|
elif score > THRESH and extra_lower == "":
|
||||||
confidence_counter["Correct-FINE"] += 1
|
confidence_counter["Correct-FINE"] += 1
|
||||||
elif score > thresh and extra_lower != "perfect" and extra_lower != "":
|
elif score > THRESH and extra_lower != "perfect" and extra_lower != "":
|
||||||
confidence_counter["Over-confident"] += 1
|
confidence_counter["Over-confident"] += 1
|
||||||
wrong_counter[extra_lower] += 1
|
wrong_counter[extra_lower] += 1
|
||||||
overconfident_docs.append(doc_id)
|
overconfident_docs.append(doc_id)
|
||||||
elif score < thresh and (extra_lower == "perfect" or extra_lower == ""):
|
elif score < THRESH and (extra_lower == "perfect" or extra_lower == ""):
|
||||||
confidence_counter["Under-confident"] += 1
|
confidence_counter["Under-confident"] += 1
|
||||||
underconfident_docs.append(doc_id)
|
underconfident_docs.append(doc_id)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,78 +0,0 @@
|
|||||||
from collections import Counter
|
|
||||||
from pathlib import Path
|
|
||||||
import json
|
|
||||||
import streamlit as st
|
|
||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
THRESH = 0.4
|
|
||||||
|
|
||||||
def page_title() -> str:
|
|
||||||
return "Statistics 2"
|
|
||||||
|
|
||||||
def render():
|
|
||||||
st.header("Statistics 2")
|
|
||||||
|
|
||||||
path = Path("../../data/refinement")
|
|
||||||
|
|
||||||
if not path.exists() or not path.is_dir():
|
|
||||||
st.error("Invalid folder path.")
|
|
||||||
return
|
|
||||||
|
|
||||||
jsonl_files = sorted(path.glob("*.jsonl"))
|
|
||||||
if not jsonl_files:
|
|
||||||
st.info("No .jsonl files found in this folder.")
|
|
||||||
return
|
|
||||||
|
|
||||||
for file_path in jsonl_files:
|
|
||||||
thresh = THRESH
|
|
||||||
st.subheader(f"File: {file_path.name}")
|
|
||||||
|
|
||||||
confidence_counter = Counter()
|
|
||||||
|
|
||||||
# ---- Read file line by line ----
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
try:
|
|
||||||
entry = json.loads(line)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
if (entry.get("status") != "success"):
|
|
||||||
confidence_counter["Crash"] += 1
|
|
||||||
for event in entry.get("events", []):
|
|
||||||
score = event.get("score", None)
|
|
||||||
|
|
||||||
if score is not None:
|
|
||||||
if score == -1:
|
|
||||||
confidence_counter["BAD-1"] += 1
|
|
||||||
elif score > thresh:
|
|
||||||
confidence_counter["PERFECT"] += 1
|
|
||||||
else:
|
|
||||||
confidence_counter["BAD"] += 1
|
|
||||||
|
|
||||||
if confidence_counter:
|
|
||||||
df_conf = pd.DataFrame(
|
|
||||||
confidence_counter.items(),
|
|
||||||
columns=["Category", "Count"]
|
|
||||||
)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.pie(
|
|
||||||
df_conf["Count"],
|
|
||||||
labels=df_conf["Category"],
|
|
||||||
autopct="%1.1f%%",
|
|
||||||
startangle=90
|
|
||||||
)
|
|
||||||
ax.axis("equal")
|
|
||||||
ax.set_title(file_path.name)
|
|
||||||
|
|
||||||
total = sum(confidence_counter.values())
|
|
||||||
correct = confidence_counter["PERFECT"]
|
|
||||||
|
|
||||||
corr_percent = (correct / total) * 100
|
|
||||||
|
|
||||||
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
|
|
||||||
st.markdown(f"**Crash: {confidence_counter["Crash"]}**")
|
|
||||||
st.pyplot(fig, width=500)
|
|
||||||
else:
|
|
||||||
st.info("No score data available in this file.")
|
|
||||||
Reference in New Issue
Block a user