Refactor example retrieval, add option for dynamic data. Add hybrid reranking to tooling. Add parsing and loop infrastructure for trigger event processing.
This commit is contained in:
+5
-8
@@ -9,9 +9,9 @@ import { verificationSetup } from "./nodes/verificationSetup";
|
|||||||
import { dummyRagasMetrics } from "./nodes/dummyRagasMetrics";
|
import { dummyRagasMetrics } from "./nodes/dummyRagasMetrics";
|
||||||
import { produceRanking } from "./nodes/produceRanking";
|
import { produceRanking } from "./nodes/produceRanking";
|
||||||
import { createModelNode } from "./nodes/model";
|
import { createModelNode } from "./nodes/model";
|
||||||
|
import { loopEndConditional } from "./conditionals/loop_end";
|
||||||
|
|
||||||
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
|
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
|
||||||
const verificationToolNode = createToolNode([]);
|
|
||||||
|
|
||||||
const dummyVerificationModel = createDummyModelNode("verification of");
|
const dummyVerificationModel = createDummyModelNode("verification of");
|
||||||
|
|
||||||
@@ -20,8 +20,6 @@ const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt"
|
|||||||
|
|
||||||
|
|
||||||
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name);
|
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name);
|
||||||
const verificationToolConditional = createToolConditional("verificationToolNode", produceRanking.name);
|
|
||||||
|
|
||||||
|
|
||||||
const agent = new StateGraph(MessagesState)
|
const agent = new StateGraph(MessagesState)
|
||||||
|
|
||||||
@@ -36,7 +34,6 @@ const agent = new StateGraph(MessagesState)
|
|||||||
.addNode(verificationSetup.name, verificationSetup)
|
.addNode(verificationSetup.name, verificationSetup)
|
||||||
.addNode("dummyVerificationModel", dummyVerificationModel)
|
.addNode("dummyVerificationModel", dummyVerificationModel)
|
||||||
.addNode(dummyRagasMetrics.name, dummyRagasMetrics)
|
.addNode(dummyRagasMetrics.name, dummyRagasMetrics)
|
||||||
.addNode("verificationToolNode", verificationToolNode)
|
|
||||||
.addNode(produceRanking.name, produceRanking)
|
.addNode(produceRanking.name, produceRanking)
|
||||||
|
|
||||||
.addEdge(START, normalizationSetup.name)
|
.addEdge(START, normalizationSetup.name)
|
||||||
@@ -50,12 +47,12 @@ const agent = new StateGraph(MessagesState)
|
|||||||
.addEdge(verificationSetup.name, "dummyVerificationModel")
|
.addEdge(verificationSetup.name, "dummyVerificationModel")
|
||||||
.addEdge(verificationSetup.name, dummyRagasMetrics.name)
|
.addEdge(verificationSetup.name, dummyRagasMetrics.name)
|
||||||
|
|
||||||
// @ts-expect-error
|
|
||||||
.addConditionalEdges("dummyVerificationModel", verificationToolConditional, ["verificationToolNode", produceRanking.name])
|
|
||||||
.addEdge("verificationToolNode", "dummyVerificationModel")
|
|
||||||
|
|
||||||
.addEdge(dummyRagasMetrics.name, produceRanking.name)
|
.addEdge(dummyRagasMetrics.name, produceRanking.name)
|
||||||
|
.addEdge("dummyVerificationModel", produceRanking.name)
|
||||||
|
|
||||||
|
|
||||||
|
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, END])
|
||||||
|
|
||||||
.compile();
|
.compile();
|
||||||
|
|
||||||
export {agent}
|
export {agent}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import { ConditionalEdgeRouter, END } from "@langchain/langgraph";
|
||||||
|
import { MessagesState } from "../state";
|
||||||
|
|
||||||
|
|
||||||
|
export const loopEndConditional: ConditionalEdgeRouter<typeof MessagesState, String> = (state) => {
|
||||||
|
const triggerEvents = state.proposedTriggerEvent;
|
||||||
|
const triggerEventsIndex = state.proposedTriggerEventIndex;
|
||||||
|
|
||||||
|
if (triggerEventsIndex == triggerEvents.length-1) {
|
||||||
|
return END
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return "verificationSetup"
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
import { MessagesState } from "../state";
|
import { MessagesState } from "../state";
|
||||||
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
|
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
|
||||||
import { calculateSimilarity } from "../tools/clan/retreiveExamples";
|
import { rankFromCSV } from "../tools/clan/retreiveExamples";
|
||||||
|
|
||||||
export const normalizationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
export const normalizationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
let similarityResults = await calculateSimilarity(state.disinformationTitle)
|
let similarityResults = await rankFromCSV(state.disinformationTitle)
|
||||||
|
|
||||||
console.log(similarityResults)
|
console.log(similarityResults)
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,22 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
import { MessagesState } from "../state";
|
import { MessagesState, ProposedTriggerEventArray } from "../state";
|
||||||
import { HumanMessage } from "@langchain/core/messages";
|
import { logger } from "../utils/logger";
|
||||||
|
|
||||||
export const verificationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
export const verificationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
//TODO: this might not be needed, looks nice on the graph tho
|
//this is kinda doing two things, but having two nodes for it seems overkill
|
||||||
|
console.log(state.proposedTriggerEvent)
|
||||||
return { messages: [ new HumanMessage(state.messages.at(-1)?.content ?? "undefined")] };
|
console.log(state.proposedTriggerEventIndex)
|
||||||
|
if (state.proposedTriggerEvent == undefined) {
|
||||||
|
logger.warn("No trigger events in memory, parsing")
|
||||||
|
|
||||||
|
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
|
||||||
|
const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse));
|
||||||
|
|
||||||
|
return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 };
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
logger.info("Trigger event index %s", state.proposedTriggerEventIndex+1)
|
||||||
|
|
||||||
|
return { proposedTriggerEvent: state.proposedTriggerEvent, proposedTriggerEventIndex: state.proposedTriggerEventIndex+1 };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Generated
+86
@@ -21,6 +21,8 @@
|
|||||||
"fs": "^0.0.1-security",
|
"fs": "^0.0.1-security",
|
||||||
"langchain": "^1.2.14",
|
"langchain": "^1.2.14",
|
||||||
"selenium-webdriver": "^4.40.0",
|
"selenium-webdriver": "^4.40.0",
|
||||||
|
"wink-bm25-text-search": "^3.1.2",
|
||||||
|
"wink-nlp-utils": "^2.1.0",
|
||||||
"winston": "^3.19.0"
|
"winston": "^3.19.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
@@ -1640,6 +1642,12 @@
|
|||||||
"node": ">= 0.4"
|
"node": ">= 0.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/emoji-regex": {
|
||||||
|
"version": "9.2.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz",
|
||||||
|
"integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/enabled": {
|
"node_modules/enabled": {
|
||||||
"version": "2.0.0",
|
"version": "2.0.0",
|
||||||
"resolved": "https://registry.npmjs.org/enabled/-/enabled-2.0.0.tgz",
|
"resolved": "https://registry.npmjs.org/enabled/-/enabled-2.0.0.tgz",
|
||||||
@@ -2787,6 +2795,84 @@
|
|||||||
"resolved": "https://registry.npmjs.org/validate.io-function/-/validate.io-function-1.0.2.tgz",
|
"resolved": "https://registry.npmjs.org/validate.io-function/-/validate.io-function-1.0.2.tgz",
|
||||||
"integrity": "sha512-LlFybRJEriSuBnUhQyG5bwglhh50EpTL2ul23MPIuR1odjO7XaMLFV8vHGwp7AZciFxtYOeiSCT5st+XSPONiQ=="
|
"integrity": "sha512-LlFybRJEriSuBnUhQyG5bwglhh50EpTL2ul23MPIuR1odjO7XaMLFV8vHGwp7AZciFxtYOeiSCT5st+XSPONiQ=="
|
||||||
},
|
},
|
||||||
|
"node_modules/wink-bm25-text-search": {
|
||||||
|
"version": "3.1.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-bm25-text-search/-/wink-bm25-text-search-3.1.2.tgz",
|
||||||
|
"integrity": "sha512-s+xY0v/yurUhiUop/XZnf9IvO9XVuwI14X+QTW0JqlmQCg+9ZgVXTMudXKqZuQVsnm5J+RjLnqrOflnD5BLApA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"wink-eng-lite-web-model": "^1.4.3",
|
||||||
|
"wink-helpers": "^2.0.0",
|
||||||
|
"wink-nlp": "^1.12.2",
|
||||||
|
"wink-nlp-utils": "^2.0.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/wink-distance": {
|
||||||
|
"version": "2.0.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-distance/-/wink-distance-2.0.2.tgz",
|
||||||
|
"integrity": "sha512-pyEhUB/OKFYcgOC4J6E+c+gwVA/8qg2s5n49mIcUsJZM5iDSa17uOxRQXR4rvfp+gbj55K/I08FwjFBwb6fq3g==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"wink-helpers": "^2.0.0",
|
||||||
|
"wink-jaro-distance": "^2.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/wink-eng-lite-web-model": {
|
||||||
|
"version": "1.8.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-eng-lite-web-model/-/wink-eng-lite-web-model-1.8.1.tgz",
|
||||||
|
"integrity": "sha512-M2tSOU/rVNkDj8AS8IoKJaM7apJJjS0cN+hE8CPazfnB4A/ojyc9+7RMPk18UOiIdSyWk7MR6w8z9lWix2l5tA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=16.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/wink-helpers": {
|
||||||
|
"version": "2.0.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-helpers/-/wink-helpers-2.0.0.tgz",
|
||||||
|
"integrity": "sha512-I/ZzXrHcNRXuoeFJmp2vMVqDI6UCK02Tds1WP4kSGAmx520gjL1BObVzF7d2ps24tyHIly9ngdB2jwhlFUjPvg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/wink-jaro-distance": {
|
||||||
|
"version": "2.0.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-jaro-distance/-/wink-jaro-distance-2.0.0.tgz",
|
||||||
|
"integrity": "sha512-9bcUaXCi9N8iYpGWbFkf83OsBkg17r4hEyxusEzl+nnReLRPqxhB9YNeRn3g54SYnVRNXP029lY3HDsbdxTAuA==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/wink-nlp": {
|
||||||
|
"version": "1.14.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-nlp/-/wink-nlp-1.14.3.tgz",
|
||||||
|
"integrity": "sha512-lvY5iCs3T8I34F8WKS70+2P0U9dWLn3vdPf/Z+m2VK14N7OmqnPzmHfh3moHdusajoQ37Em39z0IZB9K4x/96A==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/wink-nlp-utils": {
|
||||||
|
"version": "2.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-nlp-utils/-/wink-nlp-utils-2.1.0.tgz",
|
||||||
|
"integrity": "sha512-b7PcRhEBNxQmsmht70jLOkwYUyie3da4/cgEXL+CumYO5b/nwV+W7fuMXToh5BtGq1RABznmc2TGTp1Qf/JUXg==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"wink-distance": "^2.0.1",
|
||||||
|
"wink-eng-lite-web-model": "^1.4.3",
|
||||||
|
"wink-helpers": "^2.0.0",
|
||||||
|
"wink-nlp": "^1.12.0",
|
||||||
|
"wink-porter2-stemmer": "^2.0.1",
|
||||||
|
"wink-tokenizer": "^5.2.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/wink-porter2-stemmer": {
|
||||||
|
"version": "2.0.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-porter2-stemmer/-/wink-porter2-stemmer-2.0.1.tgz",
|
||||||
|
"integrity": "sha512-0g+RkkqhRXFmSpJQStVXW5N/WsshWpJXsoDRW7DwVkGI2uDT6IBCoq3xdH5p6IHLaC6ygk7RWUsUx4alKxoagQ==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/wink-tokenizer": {
|
||||||
|
"version": "5.3.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/wink-tokenizer/-/wink-tokenizer-5.3.0.tgz",
|
||||||
|
"integrity": "sha512-O/yAw0g3FmSgeeQuYAJJfP7fVPB4A6ay0018qASh79aWmIOyPYy4j4r9EQT8xBjicja6lCLvgRVAybmEBaATQA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"emoji-regex": "^9.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/winston": {
|
"node_modules/winston": {
|
||||||
"version": "3.19.0",
|
"version": "3.19.0",
|
||||||
"resolved": "https://registry.npmjs.org/winston/-/winston-3.19.0.tgz",
|
"resolved": "https://registry.npmjs.org/winston/-/winston-3.19.0.tgz",
|
||||||
|
|||||||
@@ -22,6 +22,8 @@
|
|||||||
"fs": "^0.0.1-security",
|
"fs": "^0.0.1-security",
|
||||||
"langchain": "^1.2.14",
|
"langchain": "^1.2.14",
|
||||||
"selenium-webdriver": "^4.40.0",
|
"selenium-webdriver": "^4.40.0",
|
||||||
|
"wink-bm25-text-search": "^3.1.2",
|
||||||
|
"wink-nlp-utils": "^2.1.0",
|
||||||
"winston": "^3.19.0"
|
"winston": "^3.19.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|||||||
+15
-3
@@ -10,8 +10,20 @@ import {
|
|||||||
} from "@langchain/langgraph";
|
} from "@langchain/langgraph";
|
||||||
import { z } from "zod/v4";
|
import { z } from "zod/v4";
|
||||||
|
|
||||||
|
export const ProposedTriggerEvent = z.object({
|
||||||
|
Event: z.string(),
|
||||||
|
ReasoningWhyRelevant: z.string(),
|
||||||
|
SearchQuery: z.string(),
|
||||||
|
Url: z.url(),
|
||||||
|
IsItselfDisinformation: z.boolean()
|
||||||
|
})
|
||||||
|
|
||||||
|
export const ProposedTriggerEventArray = z.array(ProposedTriggerEvent);
|
||||||
|
|
||||||
export const MessagesState = new StateSchema({
|
export const MessagesState = new StateSchema({
|
||||||
|
disinformationTitle: z.string(),
|
||||||
messages: MessagesValue,
|
messages: MessagesValue,
|
||||||
// normalizationContext: z.map(z.string(), z.string()),
|
proposedTriggerEvent: ProposedTriggerEventArray,
|
||||||
disinformationTitle: z.string()
|
proposedTriggerEventIndex: z.int(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
import { parse } from "csv-parse";
|
import { parse } from "csv-parse";
|
||||||
import fs from "fs";
|
import fs from "fs";
|
||||||
import { pipeline, cos_sim } from "@huggingface/transformers";
|
import { pipeline, cos_sim } from "@huggingface/transformers";
|
||||||
|
import bm25Factory from "wink-bm25-text-search";
|
||||||
|
import nlp from "wink-nlp-utils";
|
||||||
import { logger } from "../../utils/logger";
|
import { logger } from "../../utils/logger";
|
||||||
|
|
||||||
//TODO, am getting duplicates, is it from the multi files?
|
|
||||||
const CSV_PATHS = [
|
const CSV_PATHS = [
|
||||||
"./tools/clan/dev-eng.csv",
|
"./tools/clan/dev-eng.csv",
|
||||||
// "./tools/clan/test-eng.csv",
|
|
||||||
"./tools/clan/train-eng.csv",
|
"./tools/clan/train-eng.csv",
|
||||||
];
|
];
|
||||||
|
|
||||||
const CACHE_PATH = "./tools/clan/dev.embeddings.json";
|
const CACHE_PATH = "./tools/clan/csv.cache.json";
|
||||||
|
|
||||||
type EmbeddingCache = {
|
type EmbeddingCache = {
|
||||||
rawtexts: string[];
|
rawtexts: string[];
|
||||||
@@ -18,104 +18,262 @@ type EmbeddingCache = {
|
|||||||
embeddings: number[][];
|
embeddings: number[][];
|
||||||
};
|
};
|
||||||
|
|
||||||
export type NormalisedMatch = {
|
export type RetrievalItem = {
|
||||||
index: number;
|
id: string | number;
|
||||||
score: number;
|
|
||||||
rawtext: string;
|
rawtext: string;
|
||||||
cleantext: string;
|
cleantext?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
let rawtexts: string[] = [];
|
export type RankedResult = RetrievalItem & {
|
||||||
let cleantexts: string[] = [];
|
denseScore: number;
|
||||||
let embeddings: number[][] = [];
|
sparseScore: number;
|
||||||
|
fusedScore: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
let csvRawtexts: string[] = [];
|
||||||
|
let csvCleantexts: string[] = [];
|
||||||
|
let csvEmbeddings: number[][] = [];
|
||||||
|
let csvBM25: any = null;
|
||||||
|
let csvLoaded = false;
|
||||||
|
|
||||||
|
logger.info("Loading embedding model...");
|
||||||
const featureExtractor = await pipeline(
|
const featureExtractor = await pipeline(
|
||||||
"feature-extraction",
|
"feature-extraction",
|
||||||
"Xenova/all-MiniLM-L6-v2"
|
"Xenova/all-MiniLM-L6-v2"
|
||||||
);
|
);
|
||||||
|
logger.info("Embedding model loaded");
|
||||||
|
|
||||||
|
//Cached entrypoint
|
||||||
|
/**
 * Ranks the cached CSV corpus against a query using hybrid retrieval.
 *
 * Makes sure the CSV data is loaded, embeds the query, scores every cached
 * document both densely (cosine similarity of embeddings) and sparsely
 * (BM25), fuses the two score lists with reciprocal rank fusion and returns
 * the best entries.
 *
 * @param query Text to rank the corpus against.
 * @param topK  Maximum number of results to return (default 5).
 * @returns The `topK` documents ordered by descending fused score; `id` is
 *          the document's position in the cached corpus.
 */
export async function rankFromCSV(
  query: string,
  topK = 5
): Promise<RankedResult[]> {
  await ensureCSVLoaded();

  logger.info("Ranking from CSV cache...");

  const queryVector = await embedText(query);

  // Dense signal: one cosine similarity per cached document embedding.
  const denseScores: number[] = [];
  for (const docVector of csvEmbeddings) {
    denseScores.push(cos_sim(docVector, queryVector));
  }

  // Sparse signal and the fusion of both lists.
  const sparseScores = computeSparseScores(query, csvBM25, csvRawtexts);
  const fusedScores = reciprocalRankFusion([denseScores, sparseScores]);

  const results = csvRawtexts.map((rawtext, position) => {
    return {
      id: position,
      rawtext: rawtext,
      cleantext: csvCleantexts[position],
      denseScore: denseScores[position],
      sparseScore: sparseScores[position],
      fusedScore: fusedScores[position],
    };
  });
  results.sort((left, right) => right.fusedScore - left.fusedScore);

  logger.info("Ranking complete (CSV mode)");

  return results.slice(0, topK);
}
|
||||||
|
|
||||||
|
//Dynamic Entrypoint
|
||||||
|
/**
 * Ranks an ad-hoc list of items against a query without touching the CSV
 * cache.
 *
 * Every item's `rawtext` is embedded on the fly and compared to the query
 * embedding (dense score). When the collection is large enough a throw-away
 * BM25 index is built for a sparse score, and both lists are combined with
 * reciprocal rank fusion.
 *
 * @param query Text to rank the items against.
 * @param items Candidate items; all of their fields are carried over into
 *              the results.
 * @param topK  Maximum number of results to return (default 5).
 * @returns Up to `topK` items ordered by descending fused score; an empty
 *          array when `items` is empty.
 */
export async function rankDynamically(
  query: string,
  items: RetrievalItem[],
  topK = 5
): Promise<RankedResult[]> {
  logger.info("Ranking dynamically (no cache)...");

  if (!items.length) return [];

  const texts = items.map((item) => item.rawtext);

  const queryEmbedding = await embedText(query);

  const docEmbeddings = await Promise.all(
    texts.map((text) => embedText(text))
  );

  const denseScores = docEmbeddings.map((embedding) =>
    cos_sim(embedding, queryEmbedding)
  );

  // wink-bm25-text-search rejects consolidation of very small collections
  // (fewer than 3 documents), which would make this function throw for one or
  // two snippets. Below that size, rank on the dense signal alone.
  const minBM25Docs = 3;

  let sparseScores: number[];
  let fusedScores: number[];

  if (texts.length >= minBM25Docs) {
    const localBM25 = buildBM25(texts);
    sparseScores = computeSparseScores(query, localBM25, texts);
    fusedScores = reciprocalRankFusion([denseScores, sparseScores]);
  } else {
    logger.warn(
      "Only %s document(s); skipping BM25 and ranking on dense scores only",
      texts.length
    );
    // Report a neutral sparse score, but keep it out of the fusion so that
    // all-zero ties do not bias the result towards input order.
    sparseScores = texts.map(() => 0);
    fusedScores = reciprocalRankFusion([denseScores]);
  }

  const ranked = items
    .map((item, i) => ({
      ...item,
      denseScore: denseScores[i],
      sparseScore: sparseScores[i],
      fusedScore: fusedScores[i],
    }))
    .sort((a, b) => b.fusedScore - a.fusedScore);

  logger.info("Ranking complete (dynamic mode)");

  return ranked.slice(0, topK);
}
|
||||||
|
|
||||||
|
//CSV stuff
|
||||||
|
async function ensureCSVLoaded(): Promise<void> {
|
||||||
|
if (csvLoaded) return;
|
||||||
|
|
||||||
|
logger.info("Initializing CSV ranking mode...");
|
||||||
|
|
||||||
async function loadOrBuildCache(): Promise<void> {
|
|
||||||
if (fs.existsSync(CACHE_PATH)) {
|
if (fs.existsSync(CACHE_PATH)) {
|
||||||
logger.info("Loading embeddings from cache");
|
logger.info("Loading CSV cache from disk...");
|
||||||
|
|
||||||
const raw = fs.readFileSync(CACHE_PATH, "utf-8");
|
const raw = fs.readFileSync(CACHE_PATH, "utf-8");
|
||||||
const cache: EmbeddingCache = JSON.parse(raw);
|
const cache: EmbeddingCache = JSON.parse(raw);
|
||||||
|
|
||||||
rawtexts = cache.rawtexts;
|
csvRawtexts = cache.rawtexts;
|
||||||
cleantexts = cache.cleantexts;
|
csvCleantexts = cache.cleantexts;
|
||||||
embeddings = cache.embeddings.map(e => Array.from(e));
|
csvEmbeddings = cache.embeddings;
|
||||||
|
|
||||||
logger.info("Loaded %s embeddings", embeddings.length);
|
logger.info("Loaded %s cached embeddings", csvEmbeddings.length);
|
||||||
return;
|
} else {
|
||||||
|
logger.warn("CSV cache not found. Building embeddings...");
|
||||||
|
|
||||||
|
const seen = new Set<string>();
|
||||||
|
|
||||||
|
for (const path of CSV_PATHS) {
|
||||||
|
await processCSV(path, seen);
|
||||||
|
}
|
||||||
|
|
||||||
|
const cache: EmbeddingCache = {
|
||||||
|
rawtexts: csvRawtexts,
|
||||||
|
cleantexts: csvCleantexts,
|
||||||
|
embeddings: csvEmbeddings,
|
||||||
|
};
|
||||||
|
|
||||||
|
fs.writeFileSync(CACHE_PATH, JSON.stringify(cache));
|
||||||
|
logger.info("Cache written (%s embeddings)", csvEmbeddings.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.warn("Cache not found. Generating embeddings");
|
csvBM25 = buildBM25(csvRawtexts);
|
||||||
|
|
||||||
for (const csvPath of CSV_PATHS) {
|
csvLoaded = true;
|
||||||
await buildCacheFromCSV(csvPath);
|
logger.info("CSV mode ready");
|
||||||
}
|
|
||||||
|
|
||||||
const cache: EmbeddingCache = {
|
|
||||||
rawtexts,
|
|
||||||
cleantexts,
|
|
||||||
embeddings,
|
|
||||||
};
|
|
||||||
|
|
||||||
fs.writeFileSync(CACHE_PATH, JSON.stringify(cache));
|
|
||||||
|
|
||||||
logger.info("Cached %s embeddings", embeddings.length);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function buildCacheFromCSV(csvPath: string): Promise<void> {
|
async function processCSV(
|
||||||
let count = 0;
|
path: string,
|
||||||
|
seen: Set<string>
|
||||||
|
): Promise<void> {
|
||||||
|
logger.info("Processing CSV: %s", path);
|
||||||
|
|
||||||
logger.info("Processing CSV: %s", csvPath);
|
const stream = fs.createReadStream(path).pipe(parse());
|
||||||
|
|
||||||
const stream = fs.createReadStream(csvPath).pipe(parse());
|
|
||||||
|
|
||||||
for await (const row of stream) {
|
for await (const row of stream) {
|
||||||
const text = row[0];
|
const text = row[0];
|
||||||
if (!text) continue;
|
if (!text || seen.has(text)) continue;
|
||||||
|
|
||||||
const output = await featureExtractor(text, {
|
seen.add(text);
|
||||||
pooling: "mean",
|
|
||||||
normalize: true,
|
|
||||||
});
|
|
||||||
|
|
||||||
rawtexts.push(text);
|
const embedding = await embedText(text);
|
||||||
cleantexts.push(row[1]);
|
|
||||||
const vector = Array.from(output.data as Float32Array);
|
|
||||||
embeddings.push(vector);
|
|
||||||
|
|
||||||
|
csvRawtexts.push(text);
|
||||||
|
csvCleantexts.push(row[1]);
|
||||||
|
csvEmbeddings.push(embedding);
|
||||||
|
|
||||||
count++;
|
if (csvRawtexts.length % 100 === 0) {
|
||||||
if (count % 100 === 0) {
|
logger.info("Embedded %s documents...", csvRawtexts.length);
|
||||||
logger.info("[%s] Processed %s rows", csvPath, count);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("[%s] Finished (%s rows)", csvPath, count);
|
logger.info("Finished CSV: %s", path);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function calculateSimilarity(
|
|
||||||
query: string,
|
|
||||||
topK = 5
|
|
||||||
): Promise<NormalisedMatch[]> {
|
|
||||||
await loadOrBuildCache()
|
|
||||||
|
|
||||||
const queryEmbedding = await featureExtractor(query, {
|
async function embedText(text: string): Promise<number[]> {
|
||||||
|
const output = await featureExtractor(text, {
|
||||||
pooling: "mean",
|
pooling: "mean",
|
||||||
normalize: true,
|
normalize: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
return embeddings
|
return Array.from(output.data as Float32Array);
|
||||||
.map((embedding, index) => ({
|
}
|
||||||
index,
|
|
||||||
score: cos_sim(embedding, queryEmbedding.data as number[]),
|
function buildBM25(texts: string[]) {
|
||||||
rawtext: rawtexts[index],
|
logger.info("Building BM25 index (%s docs)...", texts.length);
|
||||||
cleantext: cleantexts[index]
|
|
||||||
}))
|
const bm25 = bm25Factory();
|
||||||
.sort((a, b) => b.score - a.score)
|
|
||||||
.slice(0, topK);
|
bm25.defineConfig({
|
||||||
}
|
fldWeights: { text: 1 },
|
||||||
|
bm25Params: { k1: 1.2, b: 0.75 },
|
||||||
|
});
|
||||||
|
|
||||||
|
bm25.definePrepTasks([
|
||||||
|
nlp.string.lowerCase,
|
||||||
|
nlp.string.tokenize0,
|
||||||
|
nlp.tokens.removeWords,
|
||||||
|
]);
|
||||||
|
|
||||||
|
texts.forEach((text, i) => {
|
||||||
|
bm25.addDoc({ text }, i);
|
||||||
|
});
|
||||||
|
|
||||||
|
bm25.consolidate();
|
||||||
|
|
||||||
|
logger.info("BM25 ready");
|
||||||
|
|
||||||
|
return bm25;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Runs a BM25 search and spreads the hits into a dense score array that is
 * index-aligned with `texts`.
 *
 * @param query Search text handed to the BM25 engine.
 * @param bm25  A consolidated search engine exposing `search(text, limit)`
 *              that yields `[docId, score]` pairs; doc ids are expected to be
 *              the positions used when the documents were added.
 * @param texts The indexed documents; only the length is used, to size the
 *              output and the search limit.
 * @returns One score per document; documents without a hit score 0.
 */
function computeSparseScores(
  query: string,
  bm25: any,
  texts: string[]
): number[] {
  const scores: number[] = new Array(texts.length).fill(0);
  if (texts.length === 0) return scores;

  // Ask for as many results as there are documents. Without an explicit
  // limit the engine only returns its default top few hits, and every other
  // matching document would wrongly be treated as a zero score.
  const results: Array<[string | number, number]> = bm25.search(
    query,
    texts.length
  );

  for (const [id, score] of results) {
    // Ids may come back as strings; normalise and ignore anything that does
    // not map onto a position in `texts`, so the array is never stretched.
    const docIndex = Number(id);
    if (Number.isInteger(docIndex) && docIndex >= 0 && docIndex < scores.length) {
      scores[docIndex] = score;
    }
  }

  return scores;
}
|
||||||
|
|
||||||
|
/**
 * Combines several index-aligned score lists into one fused score per
 * document using reciprocal rank fusion.
 *
 * Each list is turned into a ranking (highest score first) and every
 * document receives `1 / (k + rank)` from each list, where `rank` is the
 * zero-based position in that ranking. Only ranks matter, so lists on
 * different scales can be mixed.
 *
 * @param scoreLists Score lists, all of the same length and all indexed by
 *                   document position.
 * @param k          Smoothing constant damping the weight of top ranks
 *                   (default 60).
 * @returns One fused score per document; an empty array when no lists are
 *          given.
 * @throws Error when the lists do not all have the same length.
 */
function reciprocalRankFusion(
  scoreLists: number[][],
  k = 60
): number[] {
  // No lists at all: nothing to fuse (previously a TypeError on [0].length).
  const length = scoreLists[0]?.length ?? 0;
  const fused: number[] = new Array(length).fill(0);

  for (const scores of scoreLists) {
    // A longer or shorter list would silently produce NaN or skewed totals.
    if (scores.length !== length) {
      throw new Error(
        `reciprocalRankFusion: score lists differ in length (${scores.length} vs ${length})`
      );
    }

    // `map` creates a fresh array, so sorting does not mutate the input.
    const ranked = scores
      .map((score, i) => ({ score, i }))
      .sort((a, b) => b.score - a.score)
      .map((entry) => entry.i);

    ranked.forEach((docIndex, rank) => {
      fused[docIndex] += 1 / (k + rank);
    });
  }

  return fused;
}
|
||||||
|
|
||||||
|
// console.log(await rankFromCSV("barrack obama"))
|
||||||
|
// console.log(
|
||||||
|
// await rankDynamically(
|
||||||
|
// "i fell over",
|
||||||
|
// [
|
||||||
|
// { id: 1, rawtext: "I slipped and fell on the floor." },
|
||||||
|
// { id: 2, rawtext: "Barack Obama was the 44th president." },
|
||||||
|
// { id: 3, rawtext: "He tripped and hurt his knee badly." },
|
||||||
|
// { id: 4, rawtext: "The weather is sunny today." },
|
||||||
|
// { id: 5, rawtext: "She lost her balance and fell down the stairs." },
|
||||||
|
// ]
|
||||||
|
// )
|
||||||
|
// );
|
||||||
@@ -2,18 +2,20 @@ import { tool } from "@langchain/core/tools";
|
|||||||
import * as z from "zod";
|
import * as z from "zod";
|
||||||
import { queryScraper } from "./webSearch";
|
import { queryScraper } from "./webSearch";
|
||||||
import { extractWebpageContent } from "./webpageFetch";
|
import { extractWebpageContent } from "./webpageFetch";
|
||||||
|
import { rankDynamically } from "./clan/retreiveExamples";
|
||||||
|
|
||||||
|
|
||||||
function rankAndDisplayData(data: string[]):string {
|
async function rankAndDisplayData(data: string[], context: string):Promise<string> {
|
||||||
//TODO: hybrid re-ranking of the provided data
|
let index = 0;
|
||||||
return data.join("\n")
|
let ranked = await rankDynamically(context, data.map(irm => ({ id: index++, rawtext: irm })))
|
||||||
|
return ranked.map(itm => itm.rawtext).join("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define tools
|
// Define tools
|
||||||
const webSearch = tool(
|
const webSearch = tool(
|
||||||
async ({ a }) => {
|
async ({ a }) => {
|
||||||
const data = await queryScraper(a);
|
const data = await queryScraper(a);
|
||||||
return rankAndDisplayData(data);
|
return await rankAndDisplayData(data, a);
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "WebSearch",
|
name: "WebSearch",
|
||||||
@@ -25,15 +27,16 @@ const webSearch = tool(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const openWebpage = tool(
|
const openWebpage = tool(
|
||||||
async ({ a }) => {
|
async ({ a, b }) => {
|
||||||
const data = await extractWebpageContent(a);
|
const data = await extractWebpageContent(a);
|
||||||
return rankAndDisplayData(data);
|
return rankAndDisplayData(data, b);
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "OpenWebpage",
|
name: "OpenWebpage",
|
||||||
description: "Opens webpage and returns most relevent snippets",
|
description: "Opens webpage and returns most relevent snippets",
|
||||||
schema: z.object({
|
schema: z.object({
|
||||||
a: z.string().describe("URL"),
|
a: z.string().describe("URL"),
|
||||||
|
b: z.string().describe("What to match against in webpage content"),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user