Add re-ranker mode to support re-ranking experiments, hopefully we can reduce the loss
This commit is contained in:
@@ -7,7 +7,8 @@
|
||||
"type": "commonjs",
|
||||
"main": "run.ts",
|
||||
"scripts": {
|
||||
"dev": "npx tsx run.ts"
|
||||
"dev": "npx tsx run.ts",
|
||||
"rerank": "INPUT_FILE='../../data/rerank.jsonl' OUTPUT_FILE='../../data/reranked.jsonl' AGENT='verifier' MODE='verifier' npx tsx run.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@langchain/langgraph-sdk": "^1.6.2",
|
||||
|
||||
+119
-47
@@ -1,18 +1,27 @@
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import readline from "readline";
|
||||
import { Client } from "@langchain/langgraph-sdk";
|
||||
import cliProgress from "cli-progress";
|
||||
import pLimit from "p-limit";
|
||||
|
||||
const INPUT_FILE = process.env.INPUT_FILE ?? "../../data/claims.json";
|
||||
const OUTPUT_FILE = process.env.OUTPUT_FILE ?? "../../data/results.jsonl";
|
||||
|
||||
const INPUT_FILE = "../../data/claims.json";
|
||||
const OUTPUT_FILE = "../../data/results.jsonl";
|
||||
const API_URL = "http://localhost:2024";
|
||||
const AGENT_NAME = "agent";
|
||||
const AGENT_NAME = process.env.AGENT ?? "agent";
|
||||
|
||||
/**
|
||||
* Modes
|
||||
* claim -> claims from DBKF
|
||||
* verifier -> jsonl claims to test reranking with
|
||||
*/
|
||||
const MODE = process.env.MODE ?? "claim";
|
||||
|
||||
const MAX_CONCURRENCY = 5;
|
||||
|
||||
const client = new Client({ apiUrl: API_URL });
|
||||
|
||||
|
||||
type Claim = {
|
||||
documentUrl: string;
|
||||
text: string;
|
||||
@@ -20,12 +29,25 @@ type Claim = {
|
||||
[key: string]: any;
|
||||
};
|
||||
|
||||
type VerifierInput = {
|
||||
documentUrl?: string;
|
||||
text?: string;
|
||||
normalized?: string;
|
||||
events?: any;
|
||||
run_id: string;
|
||||
date?: string;
|
||||
[key: string]: any;
|
||||
};
|
||||
|
||||
type ResultRecord = {
|
||||
documentUrl: string;
|
||||
text: string;
|
||||
documentUrl?: string;
|
||||
text?: string;
|
||||
status: "success" | "error" | "wrapper_crash";
|
||||
normalized?: string,
|
||||
output?: any;
|
||||
normalized?: string;
|
||||
events?: any;
|
||||
run_id: string;
|
||||
date?: string;
|
||||
// error handling
|
||||
error?: string;
|
||||
dump?: any;
|
||||
};
|
||||
@@ -34,56 +56,106 @@ function appendResult(record: ResultRecord) {
|
||||
fs.appendFileSync(OUTPUT_FILE, JSON.stringify(record) + "\n");
|
||||
}
|
||||
|
||||
async function processClaim(claim: Claim): Promise<ResultRecord> {
|
||||
async function readJSONL(file: string): Promise<any[]> {
|
||||
const stream = fs.createReadStream(file);
|
||||
|
||||
const rl = readline.createInterface({
|
||||
input: stream,
|
||||
crlfDelay: Infinity
|
||||
});
|
||||
|
||||
const results: any[] = [];
|
||||
|
||||
for await (const line of rl) {
|
||||
if (line.trim().length === 0) continue;
|
||||
results.push(JSON.parse(line));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
async function loadInputs(): Promise<any[]> {
|
||||
if (INPUT_FILE.endsWith(".jsonl")) {
|
||||
return readJSONL(INPUT_FILE);
|
||||
}
|
||||
|
||||
const raw = fs.readFileSync(INPUT_FILE, "utf-8");
|
||||
return JSON.parse(raw);
|
||||
}
|
||||
|
||||
|
||||
function buildAgentInput(record: Claim | VerifierInput) {
|
||||
if (MODE === "claim") {
|
||||
const claim = record as Claim;
|
||||
|
||||
return {
|
||||
disinformationTitle: claim.text,
|
||||
date: claim.dateCreated
|
||||
};
|
||||
}
|
||||
|
||||
if (MODE === "verifier") {
|
||||
const v = record as VerifierInput;
|
||||
|
||||
return {
|
||||
disinformationTitle: v.text,
|
||||
date: v.date,
|
||||
proposedTriggerEvent: v.events,
|
||||
normalizedClaim: v.normalizedClaim,
|
||||
proposedTriggerEventIndex: 0
|
||||
};
|
||||
}
|
||||
|
||||
throw new Error(`Unknown mode: ${MODE}`);
|
||||
}
|
||||
|
||||
|
||||
async function processRecord(record: any): Promise<ResultRecord> {
|
||||
try {
|
||||
const thread = await client.threads.create();
|
||||
|
||||
const stream = client.runs.stream(
|
||||
thread.thread_id,
|
||||
AGENT_NAME,
|
||||
{
|
||||
input: {
|
||||
disinformationTitle: claim.text,
|
||||
date: claim.dateCreated
|
||||
},
|
||||
streamMode: "values",
|
||||
config: {
|
||||
recursion_limit: 50
|
||||
}
|
||||
const stream = client.runs.stream(thread.thread_id, AGENT_NAME, {
|
||||
input: buildAgentInput(record),
|
||||
streamMode: "values",
|
||||
config: {
|
||||
recursion_limit: 50
|
||||
}
|
||||
);
|
||||
|
||||
});
|
||||
|
||||
let lastContent: any = null;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
// capture latest output
|
||||
lastContent = chunk
|
||||
lastContent = chunk;
|
||||
}
|
||||
|
||||
if (lastContent?.event != "error") {
|
||||
|
||||
if (lastContent?.event !== "error") {
|
||||
return {
|
||||
documentUrl: claim.documentUrl,
|
||||
text: claim.text,
|
||||
documentUrl: record.documentUrl,
|
||||
text: record.text,
|
||||
date: record.dateCreated,
|
||||
status: "success",
|
||||
output: lastContent.data.messages?.at(-1) ?? "",
|
||||
normalized: lastContent.data.normalizedClaim
|
||||
events: lastContent.data.proposedTriggerEvent,
|
||||
normalized: lastContent.data.normalizedClaim,
|
||||
run_id: thread.thread_id
|
||||
};
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
return {
|
||||
documentUrl: claim.documentUrl,
|
||||
text: claim.text,
|
||||
documentUrl: record.documentUrl,
|
||||
text: record.text,
|
||||
date: record.date,
|
||||
status: "error",
|
||||
dump: lastContent
|
||||
dump: lastContent,
|
||||
run_id: thread.thread_id
|
||||
};
|
||||
}
|
||||
} catch (err: any) {
|
||||
return {
|
||||
documentUrl: claim.documentUrl,
|
||||
text: claim.text,
|
||||
documentUrl: record.documentUrl,
|
||||
text: record.text,
|
||||
date: record.date,
|
||||
status: "wrapper_crash",
|
||||
error: err?.message ?? String(err)
|
||||
error: err?.message ?? String(err),
|
||||
run_id: "NONE"
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -92,10 +164,9 @@ async function processClaim(claim: Claim): Promise<ResultRecord> {
|
||||
async function main() {
|
||||
console.log("Reading input file...");
|
||||
|
||||
const raw = fs.readFileSync(INPUT_FILE, "utf-8");
|
||||
const claims: Claim[] = JSON.parse(raw);
|
||||
const records = await loadInputs();
|
||||
|
||||
console.log(`Loaded ${claims.length} records`);
|
||||
console.log(`Loaded ${records.length} records`);
|
||||
|
||||
fs.writeFileSync(OUTPUT_FILE, "", { flag: "a" });
|
||||
|
||||
@@ -104,18 +175,18 @@ async function main() {
|
||||
const progressBar = new cliProgress.SingleBar(
|
||||
{
|
||||
format:
|
||||
"Progress |{bar}| {percentage}% | {value}/{total} | ETA: {eta}s",
|
||||
"Progress |{bar}| {percentage}% | {value}/{total} | ETA: {eta}s"
|
||||
},
|
||||
cliProgress.Presets.shades_classic
|
||||
);
|
||||
|
||||
progressBar.start(claims.length, 0);
|
||||
progressBar.start(records.length, 0);
|
||||
|
||||
let completed = 0;
|
||||
|
||||
const tasks = claims.map((claim) =>
|
||||
const tasks = records.map((record) =>
|
||||
limit(async () => {
|
||||
const result = await processClaim(claim);
|
||||
const result = await processRecord(record);
|
||||
|
||||
appendResult(result);
|
||||
|
||||
@@ -127,9 +198,10 @@ async function main() {
|
||||
await Promise.all(tasks);
|
||||
|
||||
progressBar.stop();
|
||||
|
||||
console.log("Processing complete");
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Fatal error:", err);
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user