Add re-ranker mode to support re-ranking experiments, hopefully we can reduce the loss

This commit is contained in:
William Jeynes
2026-03-06 17:27:09 +00:00
parent f14d112017
commit ef6330ec07
7 changed files with 189 additions and 58 deletions
+2 -1
View File
@@ -1,7 +1,8 @@
{ {
"dependencies": ["."], "dependencies": ["."],
"graphs": { "graphs": {
"agent": "./agent.ts:agent" "agent": "./agent.ts:agent",
"verifier": "./verify.ts:agent"
}, },
"env": ".env" "env": ".env"
} }
+15 -4
View File
@@ -7,10 +7,21 @@ export async function hydratePrompt(path: string, state: any) : Promise<string>
let raw = fs.readFileSync("prompts/" + path, "utf-8"); let raw = fs.readFileSync("prompts/" + path, "utf-8");
raw = raw.replace("###TITLE###", state.disinformationTitle); if (raw.indexOf("###TITLE###") != -1) {
raw = raw.replace("###LM###", state.messages.at(-1).content); raw = raw.replace("###TITLE###", state.disinformationTitle);
raw = raw.replace("###NTITLE###", state.normalizedClaim); }
raw = raw.replace("###CDATE###", state.date);
if (raw.indexOf("###LM###") != -1) {
raw = raw.replace("###LM###", state.messages.at(-1).content);
}
if (raw.indexOf("###NTITLE###") != -1) {
raw = raw.replace("###NTITLE###", state.normalizedClaim);
}
if (raw.indexOf("###CDATE###") != -1) {
raw = raw.replace("###CDATE###", state.date);
}
if (raw.indexOf("###TECLAIM###") != -1) { if (raw.indexOf("###TECLAIM###") != -1) {
const title = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event const title = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
+40
View File
@@ -0,0 +1,40 @@
import { END, START, StateGraph } from "@langchain/langgraph";
import { MessagesState } from "./state";
import { verificationSetup } from "./nodes/verificationSetup";
import { ragasMetrics } from "./nodes/ragasMetrics";
import { produceRanking } from "./nodes/produceRanking";
import { createModelNode } from "./nodes/model";
import { loopEndConditional } from "./conditionals/loop_end";
import { sort } from "./nodes/sort";
const verificationModel = createModelNode([], "verify.txt");
const relationModel = createModelNode([], "relation.txt");
const agent = new StateGraph(MessagesState)
//NODES
.addNode(verificationSetup.name, verificationSetup)
.addNode("verificationModel", verificationModel)
.addNode(ragasMetrics.name, ragasMetrics)
.addNode("relationModel", relationModel)
.addNode(produceRanking.name, produceRanking)
.addNode(sort.name, sort)
.addEdge(START, verificationSetup.name)
.addEdge(verificationSetup.name, "verificationModel")
.addEdge(verificationSetup.name, ragasMetrics.name)
.addEdge(verificationSetup.name, "relationModel")
.addEdge(ragasMetrics.name, produceRanking.name)
.addEdge("verificationModel", produceRanking.name)
.addEdge("relationModel", produceRanking.name)
// @ts-expect-error
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
.addEdge(sort.name, END)
.compile();
export {agent}
+2 -1
View File
@@ -7,7 +7,8 @@
"type": "commonjs", "type": "commonjs",
"main": "run.ts", "main": "run.ts",
"scripts": { "scripts": {
"dev": "npx tsx run.ts" "dev": "npx tsx run.ts",
"rerank": "INPUT_FILE='../../data/rerank.jsonl' OUTPUT_FILE='../../data/reranked.jsonl' AGENT='verifier' MODE='verifier' npx tsx run.ts"
}, },
"dependencies": { "dependencies": {
"@langchain/langgraph-sdk": "^1.6.2", "@langchain/langgraph-sdk": "^1.6.2",
+119 -47
View File
@@ -1,18 +1,27 @@
import fs from "fs"; import fs from "fs";
import path from "path"; import readline from "readline";
import { Client } from "@langchain/langgraph-sdk"; import { Client } from "@langchain/langgraph-sdk";
import cliProgress from "cli-progress"; import cliProgress from "cli-progress";
import pLimit from "p-limit"; import pLimit from "p-limit";
const INPUT_FILE = process.env.INPUT_FILE ?? "../../data/claims.json";
const OUTPUT_FILE = process.env.OUTPUT_FILE ?? "../../data/results.jsonl";
const INPUT_FILE = "../../data/claims.json";
const OUTPUT_FILE = "../../data/results.jsonl";
const API_URL = "http://localhost:2024"; const API_URL = "http://localhost:2024";
const AGENT_NAME = "agent"; const AGENT_NAME = process.env.AGENT ?? "agent";
/**
* Modes
* claim -> claims from DBKF
* verifier -> jsonl claims to test reranking with
*/
const MODE = process.env.MODE ?? "claim";
const MAX_CONCURRENCY = 5; const MAX_CONCURRENCY = 5;
const client = new Client({ apiUrl: API_URL }); const client = new Client({ apiUrl: API_URL });
type Claim = { type Claim = {
documentUrl: string; documentUrl: string;
text: string; text: string;
@@ -20,12 +29,25 @@ type Claim = {
[key: string]: any; [key: string]: any;
}; };
type VerifierInput = {
documentUrl?: string;
text?: string;
normalized?: string;
events?: any;
run_id: string;
date?: string;
[key: string]: any;
};
type ResultRecord = { type ResultRecord = {
documentUrl: string; documentUrl?: string;
text: string; text?: string;
status: "success" | "error" | "wrapper_crash"; status: "success" | "error" | "wrapper_crash";
normalized?: string, normalized?: string;
output?: any; events?: any;
run_id: string;
date?: string;
// error handling
error?: string; error?: string;
dump?: any; dump?: any;
}; };
@@ -34,56 +56,106 @@ function appendResult(record: ResultRecord) {
fs.appendFileSync(OUTPUT_FILE, JSON.stringify(record) + "\n"); fs.appendFileSync(OUTPUT_FILE, JSON.stringify(record) + "\n");
} }
async function processClaim(claim: Claim): Promise<ResultRecord> { async function readJSONL(file: string): Promise<any[]> {
const stream = fs.createReadStream(file);
const rl = readline.createInterface({
input: stream,
crlfDelay: Infinity
});
const results: any[] = [];
for await (const line of rl) {
if (line.trim().length === 0) continue;
results.push(JSON.parse(line));
}
return results;
}
async function loadInputs(): Promise<any[]> {
if (INPUT_FILE.endsWith(".jsonl")) {
return readJSONL(INPUT_FILE);
}
const raw = fs.readFileSync(INPUT_FILE, "utf-8");
return JSON.parse(raw);
}
function buildAgentInput(record: Claim | VerifierInput) {
if (MODE === "claim") {
const claim = record as Claim;
return {
disinformationTitle: claim.text,
date: claim.dateCreated
};
}
if (MODE === "verifier") {
const v = record as VerifierInput;
return {
disinformationTitle: v.text,
date: v.date,
proposedTriggerEvent: v.events,
normalizedClaim: v.normalizedClaim,
proposedTriggerEventIndex: 0
};
}
throw new Error(`Unknown mode: ${MODE}`);
}
async function processRecord(record: any): Promise<ResultRecord> {
try { try {
const thread = await client.threads.create(); const thread = await client.threads.create();
const stream = client.runs.stream( const stream = client.runs.stream(thread.thread_id, AGENT_NAME, {
thread.thread_id, input: buildAgentInput(record),
AGENT_NAME, streamMode: "values",
{ config: {
input: { recursion_limit: 50
disinformationTitle: claim.text,
date: claim.dateCreated
},
streamMode: "values",
config: {
recursion_limit: 50
}
} }
); });
let lastContent: any = null; let lastContent: any = null;
for await (const chunk of stream) { for await (const chunk of stream) {
// capture latest output lastContent = chunk;
lastContent = chunk
} }
if (lastContent?.event != "error") { if (lastContent?.event !== "error") {
return { return {
documentUrl: claim.documentUrl, documentUrl: record.documentUrl,
text: claim.text, text: record.text,
date: record.dateCreated,
status: "success", status: "success",
output: lastContent.data.messages?.at(-1) ?? "", events: lastContent.data.proposedTriggerEvent,
normalized: lastContent.data.normalizedClaim normalized: lastContent.data.normalizedClaim,
run_id: thread.thread_id
}; };
} } else {
else {
return { return {
documentUrl: claim.documentUrl, documentUrl: record.documentUrl,
text: claim.text, text: record.text,
date: record.date,
status: "error", status: "error",
dump: lastContent dump: lastContent,
run_id: thread.thread_id
}; };
} }
} catch (err: any) { } catch (err: any) {
return { return {
documentUrl: claim.documentUrl, documentUrl: record.documentUrl,
text: claim.text, text: record.text,
date: record.date,
status: "wrapper_crash", status: "wrapper_crash",
error: err?.message ?? String(err) error: err?.message ?? String(err),
run_id: "NONE"
}; };
} }
} }
@@ -92,10 +164,9 @@ async function processClaim(claim: Claim): Promise<ResultRecord> {
async function main() { async function main() {
console.log("Reading input file..."); console.log("Reading input file...");
const raw = fs.readFileSync(INPUT_FILE, "utf-8"); const records = await loadInputs();
const claims: Claim[] = JSON.parse(raw);
console.log(`Loaded ${claims.length} records`); console.log(`Loaded ${records.length} records`);
fs.writeFileSync(OUTPUT_FILE, "", { flag: "a" }); fs.writeFileSync(OUTPUT_FILE, "", { flag: "a" });
@@ -104,18 +175,18 @@ async function main() {
const progressBar = new cliProgress.SingleBar( const progressBar = new cliProgress.SingleBar(
{ {
format: format:
"Progress |{bar}| {percentage}% | {value}/{total} | ETA: {eta}s", "Progress |{bar}| {percentage}% | {value}/{total} | ETA: {eta}s"
}, },
cliProgress.Presets.shades_classic cliProgress.Presets.shades_classic
); );
progressBar.start(claims.length, 0); progressBar.start(records.length, 0);
let completed = 0; let completed = 0;
const tasks = claims.map((claim) => const tasks = records.map((record) =>
limit(async () => { limit(async () => {
const result = await processClaim(claim); const result = await processRecord(record);
appendResult(result); appendResult(result);
@@ -127,9 +198,10 @@ async function main() {
await Promise.all(tasks); await Promise.all(tasks);
progressBar.stop(); progressBar.stop();
console.log("Processing complete"); console.log("Processing complete");
} }
main().catch((err) => { main().catch((err) => {
console.error("Fatal error:", err); console.error("Fatal error:", err);
}); });
+8 -5
View File
@@ -7,21 +7,24 @@ BASE_URL = "https://dbkf.ontotext.com/rest-api/search/documents"
# search parameters # search parameters
# Ukraine: http://weverify.eu/resource/Concept/Q212 # Ukraine: http://weverify.eu/resource/Concept/Q212
# ("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
# ("organization", "http://weverify.eu/resource/Organization/c71953fa6cf24ac4178f751c77862070"), # CheckYourFact
# COVID: http://weverify.eu/resource/Concept/Q84263196 # COVID: http://weverify.eu/resource/Concept/Q84263196
# ("organization", "http://weverify.eu/resource/Organization/72b4f61c7cb49873004bea24f0a8f8f9"), # PolitifactFB
# ("organization", "http://weverify.eu/resource/Organization/552abae8eb4e003e69a3351eb0eae372") # LeadStories
# "documentTypes": "http://schema.org/Claim", # "documentTypes": "http://schema.org/Claim",
DEFAULT_PARAMS = [ DEFAULT_PARAMS = [
("concept", "http://weverify.eu/resource/Concept/Q212"), ("concept", "http://weverify.eu/resource/Concept/Q84263196"),
("from", "2000-01-01"), ("from", "2000-01-01"),
("to", "2026-02-19"), ("to", "2026-02-19"),
("lang", "en"), ("lang", "en"),
("limit", 5000), ("limit", 5000),
("page", 1), ("page", 1),
("orderBy", "date"), ("orderBy", "date"),
("organization", "http://weverify.eu/resource/Organization/72b4f61c7cb49873004bea24f0a8f8f9"), # PolitifactFB
# duplicate keys allowed ("organization", "http://weverify.eu/resource/Organization/552abae8eb4e003e69a3351eb0eae372") # LeadStories
("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
("organization", "http://weverify.eu/resource/Organization/c71953fa6cf24ac4178f751c77862070"), # CheckYourFact
] ]
NUM_RANDOM_CLAIMS = 20 NUM_RANDOM_CLAIMS = 20
+3
View File
@@ -48,7 +48,10 @@ def save_data_clean(file_path, data):
new_entry["events"] = events new_entry["events"] = events
new_entry.pop("output", None) new_entry.pop("output", None)
new_entry.pop("status", None) new_entry.pop("status", None)
new_entry["run_id"]
merged[doc_url] = new_entry merged[doc_url] = new_entry
else: else:
merged[doc_url]["events"].extend(events) merged[doc_url]["events"].extend(events)