diff --git a/agent/agent.ts b/agent/agent.ts index 3213059..75bc22d 100644 --- a/agent/agent.ts +++ b/agent/agent.ts @@ -9,9 +9,9 @@ import { verificationSetup } from "./nodes/verificationSetup"; import { dummyRagasMetrics } from "./nodes/dummyRagasMetrics"; import { produceRanking } from "./nodes/produceRanking"; import { createModelNode } from "./nodes/model"; +import { loopEndConditional } from "./conditionals/loop_end"; const triggerEventToolNode = createToolNode(triggerEventToolsByName); -const verificationToolNode = createToolNode([]); const dummyVerificationModel = createDummyModelNode("verification of"); @@ -20,8 +20,6 @@ const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt" const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name); -const verificationToolConditional = createToolConditional("verificationToolNode", produceRanking.name); - const agent = new StateGraph(MessagesState) @@ -36,7 +34,6 @@ const agent = new StateGraph(MessagesState) .addNode(verificationSetup.name, verificationSetup) .addNode("dummyVerificationModel", dummyVerificationModel) .addNode(dummyRagasMetrics.name, dummyRagasMetrics) - .addNode("verificationToolNode", verificationToolNode) .addNode(produceRanking.name, produceRanking) .addEdge(START, normalizationSetup.name) @@ -50,12 +47,12 @@ const agent = new StateGraph(MessagesState) .addEdge(verificationSetup.name, "dummyVerificationModel") .addEdge(verificationSetup.name, dummyRagasMetrics.name) - // @ts-expect-error - .addConditionalEdges("dummyVerificationModel", verificationToolConditional, ["verificationToolNode", produceRanking.name]) - .addEdge("verificationToolNode", "dummyVerificationModel") - .addEdge(dummyRagasMetrics.name, produceRanking.name) + .addEdge("dummyVerificationModel", produceRanking.name) + + .addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, END]) + .compile(); export {agent} \ No newline at end of file diff --git a/agent/conditionals/loop_end.ts b/agent/conditionals/loop_end.ts new file mode 100644 index 0000000..553d3f1 --- /dev/null +++ b/agent/conditionals/loop_end.ts @@ -0,0 +1,16 @@ +import { ConditionalEdgeRouter, END } from "@langchain/langgraph"; +import { MessagesState } from "../state"; + + +export const loopEndConditional: ConditionalEdgeRouter = (state) => { + const triggerEvents = state.proposedTriggerEvent; + const triggerEventsIndex = state.proposedTriggerEventIndex; + + if (triggerEventsIndex == triggerEvents.length-1) { + return END + } + else { + return "verificationSetup" + } + }; + diff --git a/agent/nodes/normalizationSetup.ts b/agent/nodes/normalizationSetup.ts index 58eac63..02a03a1 100644 --- a/agent/nodes/normalizationSetup.ts +++ b/agent/nodes/normalizationSetup.ts @@ -1,10 +1,10 @@ import { GraphNode } from "@langchain/langgraph"; import { MessagesState } from "../state"; import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages"; -import { calculateSimilarity } from "../tools/clan/retreiveExamples"; +import { rankFromCSV } from "../tools/clan/retreiveExamples"; export const normalizationSetup: GraphNode = async (state) => { - let similarityResults = await calculateSimilarity(state.disinformationTitle) + let similarityResults = await rankFromCSV(state.disinformationTitle) console.log(similarityResults) diff --git a/agent/nodes/verificationSetup.ts b/agent/nodes/verificationSetup.ts index 65d7c09..7c20518 100644 --- a/agent/nodes/verificationSetup.ts +++ b/agent/nodes/verificationSetup.ts @@ -1,9 +1,22 @@ import { GraphNode } from "@langchain/langgraph"; -import { MessagesState } from "../state"; -import { HumanMessage } from "@langchain/core/messages"; +import { MessagesState, ProposedTriggerEventArray } from "../state"; +import { logger } from "../utils/logger"; export const verificationSetup: GraphNode = async (state) => { - //TODO: this might not be needed, looks nice on the graph tho - - return { messages: [ new HumanMessage(state.messages.at(-1)?.content ?? "undefined")] }; + //this is kinda doing two things, but having two nodes for it seems overkill + console.log(state.proposedTriggerEvent) + console.log(state.proposedTriggerEventIndex) + if (state.proposedTriggerEvent == undefined) { + logger.warn("No trigger events in memory, parsing") + + let genResponse = state.messages.at(-1)?.content.toString() ?? ""; + const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse)); + + return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 }; + } + else { + logger.info("Trigger event index %s", state.proposedTriggerEventIndex+1) + + return { proposedTriggerEvent: state.proposedTriggerEvent, proposedTriggerEventIndex: state.proposedTriggerEventIndex+1 }; + } }; \ No newline at end of file diff --git a/agent/package-lock.json b/agent/package-lock.json index e4a006a..45f72b0 100644 --- a/agent/package-lock.json +++ b/agent/package-lock.json @@ -21,6 +21,8 @@ "fs": "^0.0.1-security", "langchain": "^1.2.14", "selenium-webdriver": "^4.40.0", + "wink-bm25-text-search": "^3.1.2", + "wink-nlp-utils": "^2.1.0", "winston": "^3.19.0" }, "devDependencies": { @@ -1640,6 +1642,12 @@ "node": ">= 0.4" } }, + "node_modules/emoji-regex": { + "version": "9.2.2", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", + "license": "MIT" + }, "node_modules/enabled": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/enabled/-/enabled-2.0.0.tgz", @@ -2787,6 +2795,84 @@ "resolved": "https://registry.npmjs.org/validate.io-function/-/validate.io-function-1.0.2.tgz", "integrity": "sha512-LlFybRJEriSuBnUhQyG5bwglhh50EpTL2ul23MPIuR1odjO7XaMLFV8vHGwp7AZciFxtYOeiSCT5st+XSPONiQ==" }, + "node_modules/wink-bm25-text-search": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/wink-bm25-text-search/-/wink-bm25-text-search-3.1.2.tgz", + "integrity": "sha512-s+xY0v/yurUhiUop/XZnf9IvO9XVuwI14X+QTW0JqlmQCg+9ZgVXTMudXKqZuQVsnm5J+RjLnqrOflnD5BLApA==", + "license": "MIT", + "dependencies": { + "wink-eng-lite-web-model": "^1.4.3", + "wink-helpers": "^2.0.0", + "wink-nlp": "^1.12.2", + "wink-nlp-utils": "^2.0.4" + } + }, + "node_modules/wink-distance": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/wink-distance/-/wink-distance-2.0.2.tgz", + "integrity": "sha512-pyEhUB/OKFYcgOC4J6E+c+gwVA/8qg2s5n49mIcUsJZM5iDSa17uOxRQXR4rvfp+gbj55K/I08FwjFBwb6fq3g==", + "license": "MIT", + "dependencies": { + "wink-helpers": "^2.0.0", + "wink-jaro-distance": "^2.0.0" + } + }, + "node_modules/wink-eng-lite-web-model": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/wink-eng-lite-web-model/-/wink-eng-lite-web-model-1.8.1.tgz", + "integrity": "sha512-M2tSOU/rVNkDj8AS8IoKJaM7apJJjS0cN+hE8CPazfnB4A/ojyc9+7RMPk18UOiIdSyWk7MR6w8z9lWix2l5tA==", + "license": "MIT", + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/wink-helpers": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/wink-helpers/-/wink-helpers-2.0.0.tgz", + "integrity": "sha512-I/ZzXrHcNRXuoeFJmp2vMVqDI6UCK02Tds1WP4kSGAmx520gjL1BObVzF7d2ps24tyHIly9ngdB2jwhlFUjPvg==", + "license": "MIT" + }, + "node_modules/wink-jaro-distance": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/wink-jaro-distance/-/wink-jaro-distance-2.0.0.tgz", + "integrity": "sha512-9bcUaXCi9N8iYpGWbFkf83OsBkg17r4hEyxusEzl+nnReLRPqxhB9YNeRn3g54SYnVRNXP029lY3HDsbdxTAuA==", + "license": "MIT" + }, + "node_modules/wink-nlp": { + "version": "1.14.3", + "resolved": "https://registry.npmjs.org/wink-nlp/-/wink-nlp-1.14.3.tgz", + "integrity": "sha512-lvY5iCs3T8I34F8WKS70+2P0U9dWLn3vdPf/Z+m2VK14N7OmqnPzmHfh3moHdusajoQ37Em39z0IZB9K4x/96A==", + "license": "MIT" + }, + "node_modules/wink-nlp-utils": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/wink-nlp-utils/-/wink-nlp-utils-2.1.0.tgz", + "integrity": "sha512-b7PcRhEBNxQmsmht70jLOkwYUyie3da4/cgEXL+CumYO5b/nwV+W7fuMXToh5BtGq1RABznmc2TGTp1Qf/JUXg==", + "license": "MIT", + "dependencies": { + "wink-distance": "^2.0.1", + "wink-eng-lite-web-model": "^1.4.3", + "wink-helpers": "^2.0.0", + "wink-nlp": "^1.12.0", + "wink-porter2-stemmer": "^2.0.1", + "wink-tokenizer": "^5.2.3" + } + }, + "node_modules/wink-porter2-stemmer": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/wink-porter2-stemmer/-/wink-porter2-stemmer-2.0.1.tgz", + "integrity": "sha512-0g+RkkqhRXFmSpJQStVXW5N/WsshWpJXsoDRW7DwVkGI2uDT6IBCoq3xdH5p6IHLaC6ygk7RWUsUx4alKxoagQ==", + "license": "MIT" + }, + "node_modules/wink-tokenizer": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/wink-tokenizer/-/wink-tokenizer-5.3.0.tgz", + "integrity": "sha512-O/yAw0g3FmSgeeQuYAJJfP7fVPB4A6ay0018qASh79aWmIOyPYy4j4r9EQT8xBjicja6lCLvgRVAybmEBaATQA==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^9.0.0" + } + }, "node_modules/winston": { "version": "3.19.0", "resolved": "https://registry.npmjs.org/winston/-/winston-3.19.0.tgz", diff --git a/agent/package.json b/agent/package.json index 78fc41c..0aedfc7 100644 --- a/agent/package.json +++ b/agent/package.json @@ -22,6 +22,8 @@ "fs": "^0.0.1-security", "langchain": "^1.2.14", "selenium-webdriver": "^4.40.0", + "wink-bm25-text-search": "^3.1.2", + "wink-nlp-utils": "^2.1.0", "winston": "^3.19.0" }, "devDependencies": { diff --git a/agent/state.ts b/agent/state.ts index 9639a2e..14ef927 100644 --- a/agent/state.ts +++ b/agent/state.ts @@ -10,8 +10,20 @@ import { } from "@langchain/langgraph"; import { z } from "zod/v4"; +export const ProposedTriggerEvent = z.object({ + Event: z.string(), + ReasoningWhyRelevant: z.string(), + SearchQuery: z.string(), + Url: z.url(), + IsItselfDisinformation: z.boolean() +}) + +export const ProposedTriggerEventArray = z.array(ProposedTriggerEvent); + export const MessagesState = new StateSchema({ + disinformationTitle: z.string(), messages: MessagesValue, - // normalizationContext: z.map(z.string(), z.string()), - disinformationTitle: z.string() -}); \ No newline at end of file + proposedTriggerEvent: ProposedTriggerEventArray, + proposedTriggerEventIndex: z.int(), +}); + diff --git a/agent/tools/clan/retreiveExamples.ts b/agent/tools/clan/retreiveExamples.ts index 5dcfdd9..514f6b2 100644 --- a/agent/tools/clan/retreiveExamples.ts +++ b/agent/tools/clan/retreiveExamples.ts @@ -1,16 +1,16 @@ import { parse } from "csv-parse"; import fs from "fs"; import { pipeline, cos_sim } from "@huggingface/transformers"; +import bm25Factory from "wink-bm25-text-search"; +import nlp from "wink-nlp-utils"; import { logger } from "../../utils/logger"; -//TODO, am getting duplicates, is it from the multi files? const CSV_PATHS = [ "./tools/clan/dev-eng.csv", - // "./tools/clan/test-eng.csv", "./tools/clan/train-eng.csv", ]; -const CACHE_PATH = "./tools/clan/dev.embeddings.json"; +const CACHE_PATH = "./tools/clan/csv.cache.json"; type EmbeddingCache = { rawtexts: string[]; @@ -18,104 +18,262 @@ type EmbeddingCache = { embeddings: number[][]; }; -export type NormalisedMatch = { - index: number; - score: number; +export type RetrievalItem = { + id: string | number; rawtext: string; - cleantext: string; + cleantext?: string; }; -let rawtexts: string[] = []; -let cleantexts: string[] = []; -let embeddings: number[][] = []; +export type RankedResult = RetrievalItem & { + denseScore: number; + sparseScore: number; + fusedScore: number; +}; +let csvRawtexts: string[] = []; +let csvCleantexts: string[] = []; +let csvEmbeddings: number[][] = []; +let csvBM25: any = null; +let csvLoaded = false; + +logger.info("Loading embedding model..."); const featureExtractor = await pipeline( "feature-extraction", "Xenova/all-MiniLM-L6-v2" ); +logger.info("Embedding model loaded"); +//Cached entrypoint +export async function rankFromCSV( + query: string, + topK = 5 +): Promise { + await ensureCSVLoaded(); + + logger.info("Ranking from CSV cache..."); + + const queryEmbedding = await embedText(query); + + const denseScores = csvEmbeddings.map((docEmbedding) => + cos_sim(docEmbedding, queryEmbedding) + ); + + const sparseScores = computeSparseScores(query, csvBM25, csvRawtexts); + + const fusedScores = reciprocalRankFusion([denseScores, sparseScores]); + + const ranked = csvRawtexts + .map((text, i) => ({ + id: i, + rawtext: text, + cleantext: csvCleantexts[i], + denseScore: denseScores[i], + sparseScore: sparseScores[i], + fusedScore: fusedScores[i], + })) + .sort((a, b) => b.fusedScore - a.fusedScore); + + logger.info("Ranking complete (CSV mode)"); + + return ranked.slice(0, topK); +} + +//Dynamic Entrypoint +export async function rankDynamically( + query: string, + items: RetrievalItem[], + topK = 5 +): Promise { + logger.info("Ranking dynamically (no cache)..."); + + if (!items.length) return []; + + const texts = items.map((i) => i.rawtext); + + const queryEmbedding = await embedText(query); + + const docEmbeddings = await Promise.all( + texts.map((text) => embedText(text)) + ); + + const denseScores = docEmbeddings.map((embedding) => + cos_sim(embedding, queryEmbedding) + ); + + const localBM25 = buildBM25(texts); + + const sparseScores = computeSparseScores(query, localBM25, texts); + + const fusedScores = reciprocalRankFusion([denseScores, sparseScores]); + + const ranked = items + .map((item, i) => ({ + ...item, + denseScore: denseScores[i], + sparseScore: sparseScores[i], + fusedScore: fusedScores[i], + })) + .sort((a, b) => b.fusedScore - a.fusedScore); + + logger.info("Ranking complete (dynamic mode)"); + + return ranked.slice(0, topK); +} + +//CSV stuff +async function ensureCSVLoaded(): Promise { + if (csvLoaded) return; + + logger.info("Initializing CSV ranking mode..."); -async function loadOrBuildCache(): Promise { if (fs.existsSync(CACHE_PATH)) { - logger.info("Loading embeddings from cache"); + logger.info("Loading CSV cache from disk..."); const raw = fs.readFileSync(CACHE_PATH, "utf-8"); const cache: EmbeddingCache = JSON.parse(raw); - rawtexts = cache.rawtexts; - cleantexts = cache.cleantexts; - embeddings = cache.embeddings.map(e => Array.from(e)); + csvRawtexts = cache.rawtexts; + csvCleantexts = cache.cleantexts; + csvEmbeddings = cache.embeddings; - logger.info("Loaded %s embeddings", embeddings.length); - return; + logger.info("Loaded %s cached embeddings", csvEmbeddings.length); + } else { + logger.warn("CSV cache not found. Building embeddings..."); + + const seen = new Set(); + + for (const path of CSV_PATHS) { + await processCSV(path, seen); + } + + const cache: EmbeddingCache = { + rawtexts: csvRawtexts, + cleantexts: csvCleantexts, + embeddings: csvEmbeddings, + }; + + fs.writeFileSync(CACHE_PATH, JSON.stringify(cache)); + logger.info("Cache written (%s embeddings)", csvEmbeddings.length); } - logger.warn("Cache not found. Generating embeddings"); + csvBM25 = buildBM25(csvRawtexts); - for (const csvPath of CSV_PATHS) { - await buildCacheFromCSV(csvPath); - } - - const cache: EmbeddingCache = { - rawtexts, - cleantexts, - embeddings, - }; - - fs.writeFileSync(CACHE_PATH, JSON.stringify(cache)); - - logger.info("Cached %s embeddings", embeddings.length); + csvLoaded = true; + logger.info("CSV mode ready"); } -async function buildCacheFromCSV(csvPath: string): Promise { - let count = 0; +async function processCSV( + path: string, + seen: Set +): Promise { + logger.info("Processing CSV: %s", path); - logger.info("Processing CSV: %s", csvPath); - - const stream = fs.createReadStream(csvPath).pipe(parse()); + const stream = fs.createReadStream(path).pipe(parse()); for await (const row of stream) { const text = row[0]; - if (!text) continue; + if (!text || seen.has(text)) continue; - const output = await featureExtractor(text, { - pooling: "mean", - normalize: true, - }); + seen.add(text); - rawtexts.push(text); - cleantexts.push(row[1]); - const vector = Array.from(output.data as Float32Array); - embeddings.push(vector); + const embedding = await embedText(text); + csvRawtexts.push(text); + csvCleantexts.push(row[1]); + csvEmbeddings.push(embedding); - count++; - if (count % 100 === 0) { - logger.info("[%s] Processed %s rows", csvPath, count); + if (csvRawtexts.length % 100 === 0) { + logger.info("Embedded %s documents...", csvRawtexts.length); } } - logger.info("[%s] Finished (%s rows)", csvPath, count); + logger.info("Finished CSV: %s", path); } -export async function calculateSimilarity( - query: string, - topK = 5 -): Promise { - await loadOrBuildCache() - const queryEmbedding = await featureExtractor(query, { +async function embedText(text: string): Promise { + const output = await featureExtractor(text, { pooling: "mean", normalize: true, }); - return embeddings - .map((embedding, index) => ({ - index, - score: cos_sim(embedding, queryEmbedding.data as number[]), - rawtext: rawtexts[index], - cleantext: cleantexts[index] - })) - .sort((a, b) => b.score - a.score) - .slice(0, topK); -} \ No newline at end of file + return Array.from(output.data as Float32Array); +} + +function buildBM25(texts: string[]) { + logger.info("Building BM25 index (%s docs)...", texts.length); + + const bm25 = bm25Factory(); + + bm25.defineConfig({ + fldWeights: { text: 1 }, + bm25Params: { k1: 1.2, b: 0.75 }, + }); + + bm25.definePrepTasks([ + nlp.string.lowerCase, + nlp.string.tokenize0, + nlp.tokens.removeWords, + ]); + + texts.forEach((text, i) => { + bm25.addDoc({ text }, i); + }); + + bm25.consolidate(); + + logger.info("BM25 ready"); + + return bm25; +} + +function computeSparseScores( + query: string, + bm25: any, + texts: string[] +): number[] { + const results = bm25.search(query); + + const scores = new Array(texts.length).fill(0); + + results.forEach((r: any) => { + scores[r[0]] = r[1]; + }); + + return scores; +} + +function reciprocalRankFusion( + scoreLists: number[][], + k = 60 +): number[] { + const length = scoreLists[0].length; + const fused = new Array(length).fill(0); + + for (const scores of scoreLists) { + const ranked = scores + .map((score, i) => ({ score, i })) + .sort((a, b) => b.score - a.score) + .map((x) => x.i); + + ranked.forEach((docIndex, rank) => { + fused[docIndex] += 1 / (k + rank); + }); + } + + return fused; +} + +// console.log(await rankFromCSV("barrack obama")) +// console.log( +// await rankDynamically( +// "i fell over", +// [ +// { id: 1, rawtext: "I slipped and fell on the floor." }, +// { id: 2, rawtext: "Barack Obama was the 44th president." }, +// { id: 3, rawtext: "He tripped and hurt his knee badly." }, +// { id: 4, rawtext: "The weather is sunny today." }, +// { id: 5, rawtext: "She lost her balance and fell down the stairs." }, +// ] +// ) +// ); \ No newline at end of file diff --git a/agent/tools/triggerEventTools.ts b/agent/tools/triggerEventTools.ts index afca14f..6d7c0c4 100644 --- a/agent/tools/triggerEventTools.ts +++ b/agent/tools/triggerEventTools.ts @@ -2,18 +2,20 @@ import { tool } from "@langchain/core/tools"; import * as z from "zod"; import { queryScraper } from "./webSearch"; import { extractWebpageContent } from "./webpageFetch"; +import { rankDynamically } from "./clan/retreiveExamples"; -function rankAndDisplayData(data: string[]):string { - //TODO: hybrid re-ranking of the provided data - return data.join("\n") +async function rankAndDisplayData(data: string[], context: string):Promise { + let index = 0; + let ranked = await rankDynamically(context, data.map(irm => ({ id: index++, rawtext: irm }))) + return ranked.map(itm => itm.rawtext).join("\n") } // Define tools const webSearch = tool( async ({ a }) => { const data = await queryScraper(a); - return rankAndDisplayData(data); + return await rankAndDisplayData(data, a); }, { name: "WebSearch", @@ -25,15 +27,16 @@ const webSearch = tool( ); const openWebpage = tool( - async ({ a }) => { + async ({ a, b }) => { const data = await extractWebpageContent(a); - return rankAndDisplayData(data); + return rankAndDisplayData(data, b); }, { name: "OpenWebpage", description: "Opens webpage and returns most relevent snippets", schema: z.object({ a: z.string().describe("URL"), + b: z.string().describe("What to match against in webpage content"), }), } );