Add re-ranker mode to support re-ranking experiments, hopefully we can reduce the loss
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
{
|
{
|
||||||
"dependencies": ["."],
|
"dependencies": ["."],
|
||||||
"graphs": {
|
"graphs": {
|
||||||
"agent": "./agent.ts:agent"
|
"agent": "./agent.ts:agent",
|
||||||
|
"verifier": "./verify.ts:agent"
|
||||||
},
|
},
|
||||||
"env": ".env"
|
"env": ".env"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,10 +7,21 @@ export async function hydratePrompt(path: string, state: any) : Promise<string>
|
|||||||
|
|
||||||
let raw = fs.readFileSync("prompts/" + path, "utf-8");
|
let raw = fs.readFileSync("prompts/" + path, "utf-8");
|
||||||
|
|
||||||
raw = raw.replace("###TITLE###", state.disinformationTitle);
|
if (raw.indexOf("###TITLE###") != -1) {
|
||||||
raw = raw.replace("###LM###", state.messages.at(-1).content);
|
raw = raw.replace("###TITLE###", state.disinformationTitle);
|
||||||
raw = raw.replace("###NTITLE###", state.normalizedClaim);
|
}
|
||||||
raw = raw.replace("###CDATE###", state.date);
|
|
||||||
|
if (raw.indexOf("###LM###") != -1) {
|
||||||
|
raw = raw.replace("###LM###", state.messages.at(-1).content);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (raw.indexOf("###NTITLE###") != -1) {
|
||||||
|
raw = raw.replace("###NTITLE###", state.normalizedClaim);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (raw.indexOf("###CDATE###") != -1) {
|
||||||
|
raw = raw.replace("###CDATE###", state.date);
|
||||||
|
}
|
||||||
|
|
||||||
if (raw.indexOf("###TECLAIM###") != -1) {
|
if (raw.indexOf("###TECLAIM###") != -1) {
|
||||||
const title = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
const title = state.proposedTriggerEvent[state.proposedTriggerEventIndex].Event
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
import { END, START, StateGraph } from "@langchain/langgraph";
|
||||||
|
import { MessagesState } from "./state";
|
||||||
|
import { verificationSetup } from "./nodes/verificationSetup";
|
||||||
|
import { ragasMetrics } from "./nodes/ragasMetrics";
|
||||||
|
import { produceRanking } from "./nodes/produceRanking";
|
||||||
|
import { createModelNode } from "./nodes/model";
|
||||||
|
import { loopEndConditional } from "./conditionals/loop_end";
|
||||||
|
import { sort } from "./nodes/sort";
|
||||||
|
|
||||||
|
const verificationModel = createModelNode([], "verify.txt");
|
||||||
|
const relationModel = createModelNode([], "relation.txt");
|
||||||
|
|
||||||
|
const agent = new StateGraph(MessagesState)
|
||||||
|
|
||||||
|
//NODES
|
||||||
|
.addNode(verificationSetup.name, verificationSetup)
|
||||||
|
.addNode("verificationModel", verificationModel)
|
||||||
|
.addNode(ragasMetrics.name, ragasMetrics)
|
||||||
|
.addNode("relationModel", relationModel)
|
||||||
|
|
||||||
|
.addNode(produceRanking.name, produceRanking)
|
||||||
|
.addNode(sort.name, sort)
|
||||||
|
|
||||||
|
.addEdge(START, verificationSetup.name)
|
||||||
|
.addEdge(verificationSetup.name, "verificationModel")
|
||||||
|
.addEdge(verificationSetup.name, ragasMetrics.name)
|
||||||
|
.addEdge(verificationSetup.name, "relationModel")
|
||||||
|
|
||||||
|
.addEdge(ragasMetrics.name, produceRanking.name)
|
||||||
|
.addEdge("verificationModel", produceRanking.name)
|
||||||
|
.addEdge("relationModel", produceRanking.name)
|
||||||
|
|
||||||
|
// @ts-expect-error
|
||||||
|
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, sort.name])
|
||||||
|
|
||||||
|
.addEdge(sort.name, END)
|
||||||
|
|
||||||
|
.compile();
|
||||||
|
|
||||||
|
export {agent}
|
||||||
@@ -7,7 +7,8 @@
|
|||||||
"type": "commonjs",
|
"type": "commonjs",
|
||||||
"main": "run.ts",
|
"main": "run.ts",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "npx tsx run.ts"
|
"dev": "npx tsx run.ts",
|
||||||
|
"rerank": "INPUT_FILE='../../data/rerank.jsonl' OUTPUT_FILE='../../data/reranked.jsonl' AGENT='verifier' MODE='verifier' npx tsx run.ts"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@langchain/langgraph-sdk": "^1.6.2",
|
"@langchain/langgraph-sdk": "^1.6.2",
|
||||||
|
|||||||
+117
-45
@@ -1,18 +1,27 @@
|
|||||||
import fs from "fs";
|
import fs from "fs";
|
||||||
import path from "path";
|
import readline from "readline";
|
||||||
import { Client } from "@langchain/langgraph-sdk";
|
import { Client } from "@langchain/langgraph-sdk";
|
||||||
import cliProgress from "cli-progress";
|
import cliProgress from "cli-progress";
|
||||||
import pLimit from "p-limit";
|
import pLimit from "p-limit";
|
||||||
|
|
||||||
|
const INPUT_FILE = process.env.INPUT_FILE ?? "../../data/claims.json";
|
||||||
|
const OUTPUT_FILE = process.env.OUTPUT_FILE ?? "../../data/results.jsonl";
|
||||||
|
|
||||||
const INPUT_FILE = "../../data/claims.json";
|
|
||||||
const OUTPUT_FILE = "../../data/results.jsonl";
|
|
||||||
const API_URL = "http://localhost:2024";
|
const API_URL = "http://localhost:2024";
|
||||||
const AGENT_NAME = "agent";
|
const AGENT_NAME = process.env.AGENT ?? "agent";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Modes
|
||||||
|
* claim -> claims from DBKF
|
||||||
|
* verifier -> jsonl claims to test reranking with
|
||||||
|
*/
|
||||||
|
const MODE = process.env.MODE ?? "claim";
|
||||||
|
|
||||||
const MAX_CONCURRENCY = 5;
|
const MAX_CONCURRENCY = 5;
|
||||||
|
|
||||||
const client = new Client({ apiUrl: API_URL });
|
const client = new Client({ apiUrl: API_URL });
|
||||||
|
|
||||||
|
|
||||||
type Claim = {
|
type Claim = {
|
||||||
documentUrl: string;
|
documentUrl: string;
|
||||||
text: string;
|
text: string;
|
||||||
@@ -20,12 +29,25 @@ type Claim = {
|
|||||||
[key: string]: any;
|
[key: string]: any;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type VerifierInput = {
|
||||||
|
documentUrl?: string;
|
||||||
|
text?: string;
|
||||||
|
normalized?: string;
|
||||||
|
events?: any;
|
||||||
|
run_id: string;
|
||||||
|
date?: string;
|
||||||
|
[key: string]: any;
|
||||||
|
};
|
||||||
|
|
||||||
type ResultRecord = {
|
type ResultRecord = {
|
||||||
documentUrl: string;
|
documentUrl?: string;
|
||||||
text: string;
|
text?: string;
|
||||||
status: "success" | "error" | "wrapper_crash";
|
status: "success" | "error" | "wrapper_crash";
|
||||||
normalized?: string,
|
normalized?: string;
|
||||||
output?: any;
|
events?: any;
|
||||||
|
run_id: string;
|
||||||
|
date?: string;
|
||||||
|
// error handling
|
||||||
error?: string;
|
error?: string;
|
||||||
dump?: any;
|
dump?: any;
|
||||||
};
|
};
|
||||||
@@ -34,56 +56,106 @@ function appendResult(record: ResultRecord) {
|
|||||||
fs.appendFileSync(OUTPUT_FILE, JSON.stringify(record) + "\n");
|
fs.appendFileSync(OUTPUT_FILE, JSON.stringify(record) + "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
async function processClaim(claim: Claim): Promise<ResultRecord> {
|
async function readJSONL(file: string): Promise<any[]> {
|
||||||
|
const stream = fs.createReadStream(file);
|
||||||
|
|
||||||
|
const rl = readline.createInterface({
|
||||||
|
input: stream,
|
||||||
|
crlfDelay: Infinity
|
||||||
|
});
|
||||||
|
|
||||||
|
const results: any[] = [];
|
||||||
|
|
||||||
|
for await (const line of rl) {
|
||||||
|
if (line.trim().length === 0) continue;
|
||||||
|
results.push(JSON.parse(line));
|
||||||
|
}
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadInputs(): Promise<any[]> {
|
||||||
|
if (INPUT_FILE.endsWith(".jsonl")) {
|
||||||
|
return readJSONL(INPUT_FILE);
|
||||||
|
}
|
||||||
|
|
||||||
|
const raw = fs.readFileSync(INPUT_FILE, "utf-8");
|
||||||
|
return JSON.parse(raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function buildAgentInput(record: Claim | VerifierInput) {
|
||||||
|
if (MODE === "claim") {
|
||||||
|
const claim = record as Claim;
|
||||||
|
|
||||||
|
return {
|
||||||
|
disinformationTitle: claim.text,
|
||||||
|
date: claim.dateCreated
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (MODE === "verifier") {
|
||||||
|
const v = record as VerifierInput;
|
||||||
|
|
||||||
|
return {
|
||||||
|
disinformationTitle: v.text,
|
||||||
|
date: v.date,
|
||||||
|
proposedTriggerEvent: v.events,
|
||||||
|
normalizedClaim: v.normalizedClaim,
|
||||||
|
proposedTriggerEventIndex: 0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(`Unknown mode: ${MODE}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async function processRecord(record: any): Promise<ResultRecord> {
|
||||||
try {
|
try {
|
||||||
const thread = await client.threads.create();
|
const thread = await client.threads.create();
|
||||||
|
|
||||||
const stream = client.runs.stream(
|
const stream = client.runs.stream(thread.thread_id, AGENT_NAME, {
|
||||||
thread.thread_id,
|
input: buildAgentInput(record),
|
||||||
AGENT_NAME,
|
streamMode: "values",
|
||||||
{
|
config: {
|
||||||
input: {
|
recursion_limit: 50
|
||||||
disinformationTitle: claim.text,
|
|
||||||
date: claim.dateCreated
|
|
||||||
},
|
|
||||||
streamMode: "values",
|
|
||||||
config: {
|
|
||||||
recursion_limit: 50
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
);
|
});
|
||||||
|
|
||||||
|
|
||||||
let lastContent: any = null;
|
let lastContent: any = null;
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
// capture latest output
|
lastContent = chunk;
|
||||||
lastContent = chunk
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lastContent?.event != "error") {
|
if (lastContent?.event !== "error") {
|
||||||
return {
|
return {
|
||||||
documentUrl: claim.documentUrl,
|
documentUrl: record.documentUrl,
|
||||||
text: claim.text,
|
text: record.text,
|
||||||
|
date: record.dateCreated,
|
||||||
status: "success",
|
status: "success",
|
||||||
output: lastContent.data.messages?.at(-1) ?? "",
|
events: lastContent.data.proposedTriggerEvent,
|
||||||
normalized: lastContent.data.normalizedClaim
|
normalized: lastContent.data.normalizedClaim,
|
||||||
|
run_id: thread.thread_id
|
||||||
};
|
};
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
return {
|
return {
|
||||||
documentUrl: claim.documentUrl,
|
documentUrl: record.documentUrl,
|
||||||
text: claim.text,
|
text: record.text,
|
||||||
|
date: record.date,
|
||||||
status: "error",
|
status: "error",
|
||||||
dump: lastContent
|
dump: lastContent,
|
||||||
|
run_id: thread.thread_id
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
return {
|
return {
|
||||||
documentUrl: claim.documentUrl,
|
documentUrl: record.documentUrl,
|
||||||
text: claim.text,
|
text: record.text,
|
||||||
|
date: record.date,
|
||||||
status: "wrapper_crash",
|
status: "wrapper_crash",
|
||||||
error: err?.message ?? String(err)
|
error: err?.message ?? String(err),
|
||||||
|
run_id: "NONE"
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -92,10 +164,9 @@ async function processClaim(claim: Claim): Promise<ResultRecord> {
|
|||||||
async function main() {
|
async function main() {
|
||||||
console.log("Reading input file...");
|
console.log("Reading input file...");
|
||||||
|
|
||||||
const raw = fs.readFileSync(INPUT_FILE, "utf-8");
|
const records = await loadInputs();
|
||||||
const claims: Claim[] = JSON.parse(raw);
|
|
||||||
|
|
||||||
console.log(`Loaded ${claims.length} records`);
|
console.log(`Loaded ${records.length} records`);
|
||||||
|
|
||||||
fs.writeFileSync(OUTPUT_FILE, "", { flag: "a" });
|
fs.writeFileSync(OUTPUT_FILE, "", { flag: "a" });
|
||||||
|
|
||||||
@@ -104,18 +175,18 @@ async function main() {
|
|||||||
const progressBar = new cliProgress.SingleBar(
|
const progressBar = new cliProgress.SingleBar(
|
||||||
{
|
{
|
||||||
format:
|
format:
|
||||||
"Progress |{bar}| {percentage}% | {value}/{total} | ETA: {eta}s",
|
"Progress |{bar}| {percentage}% | {value}/{total} | ETA: {eta}s"
|
||||||
},
|
},
|
||||||
cliProgress.Presets.shades_classic
|
cliProgress.Presets.shades_classic
|
||||||
);
|
);
|
||||||
|
|
||||||
progressBar.start(claims.length, 0);
|
progressBar.start(records.length, 0);
|
||||||
|
|
||||||
let completed = 0;
|
let completed = 0;
|
||||||
|
|
||||||
const tasks = claims.map((claim) =>
|
const tasks = records.map((record) =>
|
||||||
limit(async () => {
|
limit(async () => {
|
||||||
const result = await processClaim(claim);
|
const result = await processRecord(record);
|
||||||
|
|
||||||
appendResult(result);
|
appendResult(result);
|
||||||
|
|
||||||
@@ -127,6 +198,7 @@ async function main() {
|
|||||||
await Promise.all(tasks);
|
await Promise.all(tasks);
|
||||||
|
|
||||||
progressBar.stop();
|
progressBar.stop();
|
||||||
|
|
||||||
console.log("Processing complete");
|
console.log("Processing complete");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,21 +7,24 @@ BASE_URL = "https://dbkf.ontotext.com/rest-api/search/documents"
|
|||||||
|
|
||||||
# search parameters
|
# search parameters
|
||||||
# Ukraine: http://weverify.eu/resource/Concept/Q212
|
# Ukraine: http://weverify.eu/resource/Concept/Q212
|
||||||
|
# ("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
|
||||||
|
# ("organization", "http://weverify.eu/resource/Organization/c71953fa6cf24ac4178f751c77862070"), # CheckYourFact
|
||||||
|
|
||||||
# COVID: http://weverify.eu/resource/Concept/Q84263196
|
# COVID: http://weverify.eu/resource/Concept/Q84263196
|
||||||
|
# ("organization", "http://weverify.eu/resource/Organization/72b4f61c7cb49873004bea24f0a8f8f9"), # PolitifactFB
|
||||||
|
# ("organization", "http://weverify.eu/resource/Organization/552abae8eb4e003e69a3351eb0eae372") # LeadStories
|
||||||
|
|
||||||
# "documentTypes": "http://schema.org/Claim",
|
# "documentTypes": "http://schema.org/Claim",
|
||||||
DEFAULT_PARAMS = [
|
DEFAULT_PARAMS = [
|
||||||
("concept", "http://weverify.eu/resource/Concept/Q212"),
|
("concept", "http://weverify.eu/resource/Concept/Q84263196"),
|
||||||
("from", "2000-01-01"),
|
("from", "2000-01-01"),
|
||||||
("to", "2026-02-19"),
|
("to", "2026-02-19"),
|
||||||
("lang", "en"),
|
("lang", "en"),
|
||||||
("limit", 5000),
|
("limit", 5000),
|
||||||
("page", 1),
|
("page", 1),
|
||||||
("orderBy", "date"),
|
("orderBy", "date"),
|
||||||
|
("organization", "http://weverify.eu/resource/Organization/72b4f61c7cb49873004bea24f0a8f8f9"), # PolitifactFB
|
||||||
# duplicate keys allowed
|
("organization", "http://weverify.eu/resource/Organization/552abae8eb4e003e69a3351eb0eae372") # LeadStories
|
||||||
("organization", "http://weverify.eu/resource/Organization/3727f7b2aa90ec0716693e5464b28d18"), # StopFake
|
|
||||||
("organization", "http://weverify.eu/resource/Organization/c71953fa6cf24ac4178f751c77862070"), # CheckYourFact
|
|
||||||
]
|
]
|
||||||
|
|
||||||
NUM_RANDOM_CLAIMS = 20
|
NUM_RANDOM_CLAIMS = 20
|
||||||
|
|||||||
@@ -48,7 +48,10 @@ def save_data_clean(file_path, data):
|
|||||||
new_entry["events"] = events
|
new_entry["events"] = events
|
||||||
new_entry.pop("output", None)
|
new_entry.pop("output", None)
|
||||||
new_entry.pop("status", None)
|
new_entry.pop("status", None)
|
||||||
|
new_entry["run_id"]
|
||||||
merged[doc_url] = new_entry
|
merged[doc_url] = new_entry
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
merged[doc_url]["events"].extend(events)
|
merged[doc_url]["events"].extend(events)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user