Refactor example retreiving, add option for dynamic data. Add hybrid reranking to tooling. Add parsing and loop infrastructure for trigger event processing
This commit is contained in:
+5
-8
@@ -9,9 +9,9 @@ import { verificationSetup } from "./nodes/verificationSetup";
|
||||
import { dummyRagasMetrics } from "./nodes/dummyRagasMetrics";
|
||||
import { produceRanking } from "./nodes/produceRanking";
|
||||
import { createModelNode } from "./nodes/model";
|
||||
import { loopEndConditional } from "./conditionals/loop_end";
|
||||
|
||||
const triggerEventToolNode = createToolNode(triggerEventToolsByName);
|
||||
const verificationToolNode = createToolNode([]);
|
||||
|
||||
const dummyVerificationModel = createDummyModelNode("verification of");
|
||||
|
||||
@@ -20,8 +20,6 @@ const triggerEventModel = createModelNode(triggerEventToolsByName, "trigger.txt"
|
||||
|
||||
|
||||
const triggerEventToolConditional = createToolConditional("triggerEventToolNode", verificationSetup.name);
|
||||
const verificationToolConditional = createToolConditional("verificationToolNode", produceRanking.name);
|
||||
|
||||
|
||||
const agent = new StateGraph(MessagesState)
|
||||
|
||||
@@ -36,7 +34,6 @@ const agent = new StateGraph(MessagesState)
|
||||
.addNode(verificationSetup.name, verificationSetup)
|
||||
.addNode("dummyVerificationModel", dummyVerificationModel)
|
||||
.addNode(dummyRagasMetrics.name, dummyRagasMetrics)
|
||||
.addNode("verificationToolNode", verificationToolNode)
|
||||
.addNode(produceRanking.name, produceRanking)
|
||||
|
||||
.addEdge(START, normalizationSetup.name)
|
||||
@@ -50,11 +47,11 @@ const agent = new StateGraph(MessagesState)
|
||||
.addEdge(verificationSetup.name, "dummyVerificationModel")
|
||||
.addEdge(verificationSetup.name, dummyRagasMetrics.name)
|
||||
|
||||
// @ts-expect-error
|
||||
.addConditionalEdges("dummyVerificationModel", verificationToolConditional, ["verificationToolNode", produceRanking.name])
|
||||
.addEdge("verificationToolNode", "dummyVerificationModel")
|
||||
|
||||
.addEdge(dummyRagasMetrics.name, produceRanking.name)
|
||||
.addEdge("dummyVerificationModel", produceRanking.name)
|
||||
|
||||
|
||||
.addConditionalEdges(produceRanking.name, loopEndConditional, [verificationSetup.name, END])
|
||||
|
||||
.compile();
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
import { ConditionalEdgeRouter, END } from "@langchain/langgraph";
|
||||
import { MessagesState } from "../state";
|
||||
|
||||
|
||||
export const loopEndConditional: ConditionalEdgeRouter<typeof MessagesState, String> = (state) => {
|
||||
const triggerEvents = state.proposedTriggerEvent;
|
||||
const triggerEventsIndex = state.proposedTriggerEventIndex;
|
||||
|
||||
if (triggerEventsIndex == triggerEvents.length-1) {
|
||||
return END
|
||||
}
|
||||
else {
|
||||
return "verificationSetup"
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { GraphNode } from "@langchain/langgraph";
|
||||
import { MessagesState } from "../state";
|
||||
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
|
||||
import { calculateSimilarity } from "../tools/clan/retreiveExamples";
|
||||
import { rankFromCSV } from "../tools/clan/retreiveExamples";
|
||||
|
||||
export const normalizationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
||||
let similarityResults = await calculateSimilarity(state.disinformationTitle)
|
||||
let similarityResults = await rankFromCSV(state.disinformationTitle)
|
||||
|
||||
console.log(similarityResults)
|
||||
|
||||
|
||||
@@ -1,9 +1,22 @@
|
||||
import { GraphNode } from "@langchain/langgraph";
|
||||
import { MessagesState } from "../state";
|
||||
import { HumanMessage } from "@langchain/core/messages";
|
||||
import { MessagesState, ProposedTriggerEventArray } from "../state";
|
||||
import { logger } from "../utils/logger";
|
||||
|
||||
export const verificationSetup: GraphNode<typeof MessagesState> = async (state) => {
|
||||
//TODO: this might not be needed, looks nice on the graph tho
|
||||
//this is kinda doing two things, but having two nodes for it seems overkill
|
||||
console.log(state.proposedTriggerEvent)
|
||||
console.log(state.proposedTriggerEventIndex)
|
||||
if (state.proposedTriggerEvent == undefined) {
|
||||
logger.warn("No trigger events in memory, parsing")
|
||||
|
||||
return { messages: [ new HumanMessage(state.messages.at(-1)?.content ?? "undefined")] };
|
||||
let genResponse = state.messages.at(-1)?.content.toString() ?? "";
|
||||
const parsed = ProposedTriggerEventArray.parse(JSON.parse(genResponse));
|
||||
|
||||
return { proposedTriggerEvent: parsed, proposedTriggerEventIndex: 0 };
|
||||
}
|
||||
else {
|
||||
logger.info("Trigger event index %s", state.proposedTriggerEventIndex+1)
|
||||
|
||||
return { proposedTriggerEvent: state.proposedTriggerEvent, proposedTriggerEventIndex: state.proposedTriggerEventIndex+1 };
|
||||
}
|
||||
};
|
||||
Generated
+86
@@ -21,6 +21,8 @@
|
||||
"fs": "^0.0.1-security",
|
||||
"langchain": "^1.2.14",
|
||||
"selenium-webdriver": "^4.40.0",
|
||||
"wink-bm25-text-search": "^3.1.2",
|
||||
"wink-nlp-utils": "^2.1.0",
|
||||
"winston": "^3.19.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -1640,6 +1642,12 @@
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/emoji-regex": {
|
||||
"version": "9.2.2",
|
||||
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz",
|
||||
"integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/enabled": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/enabled/-/enabled-2.0.0.tgz",
|
||||
@@ -2787,6 +2795,84 @@
|
||||
"resolved": "https://registry.npmjs.org/validate.io-function/-/validate.io-function-1.0.2.tgz",
|
||||
"integrity": "sha512-LlFybRJEriSuBnUhQyG5bwglhh50EpTL2ul23MPIuR1odjO7XaMLFV8vHGwp7AZciFxtYOeiSCT5st+XSPONiQ=="
|
||||
},
|
||||
"node_modules/wink-bm25-text-search": {
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/wink-bm25-text-search/-/wink-bm25-text-search-3.1.2.tgz",
|
||||
"integrity": "sha512-s+xY0v/yurUhiUop/XZnf9IvO9XVuwI14X+QTW0JqlmQCg+9ZgVXTMudXKqZuQVsnm5J+RjLnqrOflnD5BLApA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"wink-eng-lite-web-model": "^1.4.3",
|
||||
"wink-helpers": "^2.0.0",
|
||||
"wink-nlp": "^1.12.2",
|
||||
"wink-nlp-utils": "^2.0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/wink-distance": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/wink-distance/-/wink-distance-2.0.2.tgz",
|
||||
"integrity": "sha512-pyEhUB/OKFYcgOC4J6E+c+gwVA/8qg2s5n49mIcUsJZM5iDSa17uOxRQXR4rvfp+gbj55K/I08FwjFBwb6fq3g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"wink-helpers": "^2.0.0",
|
||||
"wink-jaro-distance": "^2.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/wink-eng-lite-web-model": {
|
||||
"version": "1.8.1",
|
||||
"resolved": "https://registry.npmjs.org/wink-eng-lite-web-model/-/wink-eng-lite-web-model-1.8.1.tgz",
|
||||
"integrity": "sha512-M2tSOU/rVNkDj8AS8IoKJaM7apJJjS0cN+hE8CPazfnB4A/ojyc9+7RMPk18UOiIdSyWk7MR6w8z9lWix2l5tA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/wink-helpers": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/wink-helpers/-/wink-helpers-2.0.0.tgz",
|
||||
"integrity": "sha512-I/ZzXrHcNRXuoeFJmp2vMVqDI6UCK02Tds1WP4kSGAmx520gjL1BObVzF7d2ps24tyHIly9ngdB2jwhlFUjPvg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/wink-jaro-distance": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/wink-jaro-distance/-/wink-jaro-distance-2.0.0.tgz",
|
||||
"integrity": "sha512-9bcUaXCi9N8iYpGWbFkf83OsBkg17r4hEyxusEzl+nnReLRPqxhB9YNeRn3g54SYnVRNXP029lY3HDsbdxTAuA==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/wink-nlp": {
|
||||
"version": "1.14.3",
|
||||
"resolved": "https://registry.npmjs.org/wink-nlp/-/wink-nlp-1.14.3.tgz",
|
||||
"integrity": "sha512-lvY5iCs3T8I34F8WKS70+2P0U9dWLn3vdPf/Z+m2VK14N7OmqnPzmHfh3moHdusajoQ37Em39z0IZB9K4x/96A==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/wink-nlp-utils": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/wink-nlp-utils/-/wink-nlp-utils-2.1.0.tgz",
|
||||
"integrity": "sha512-b7PcRhEBNxQmsmht70jLOkwYUyie3da4/cgEXL+CumYO5b/nwV+W7fuMXToh5BtGq1RABznmc2TGTp1Qf/JUXg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"wink-distance": "^2.0.1",
|
||||
"wink-eng-lite-web-model": "^1.4.3",
|
||||
"wink-helpers": "^2.0.0",
|
||||
"wink-nlp": "^1.12.0",
|
||||
"wink-porter2-stemmer": "^2.0.1",
|
||||
"wink-tokenizer": "^5.2.3"
|
||||
}
|
||||
},
|
||||
"node_modules/wink-porter2-stemmer": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/wink-porter2-stemmer/-/wink-porter2-stemmer-2.0.1.tgz",
|
||||
"integrity": "sha512-0g+RkkqhRXFmSpJQStVXW5N/WsshWpJXsoDRW7DwVkGI2uDT6IBCoq3xdH5p6IHLaC6ygk7RWUsUx4alKxoagQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/wink-tokenizer": {
|
||||
"version": "5.3.0",
|
||||
"resolved": "https://registry.npmjs.org/wink-tokenizer/-/wink-tokenizer-5.3.0.tgz",
|
||||
"integrity": "sha512-O/yAw0g3FmSgeeQuYAJJfP7fVPB4A6ay0018qASh79aWmIOyPYy4j4r9EQT8xBjicja6lCLvgRVAybmEBaATQA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"emoji-regex": "^9.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/winston": {
|
||||
"version": "3.19.0",
|
||||
"resolved": "https://registry.npmjs.org/winston/-/winston-3.19.0.tgz",
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
"fs": "^0.0.1-security",
|
||||
"langchain": "^1.2.14",
|
||||
"selenium-webdriver": "^4.40.0",
|
||||
"wink-bm25-text-search": "^3.1.2",
|
||||
"wink-nlp-utils": "^2.1.0",
|
||||
"winston": "^3.19.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
+14
-2
@@ -10,8 +10,20 @@ import {
|
||||
} from "@langchain/langgraph";
|
||||
import { z } from "zod/v4";
|
||||
|
||||
export const ProposedTriggerEvent = z.object({
|
||||
Event: z.string(),
|
||||
ReasoningWhyRelevant: z.string(),
|
||||
SearchQuery: z.string(),
|
||||
Url: z.url(),
|
||||
IsItselfDisinformation: z.boolean()
|
||||
})
|
||||
|
||||
export const ProposedTriggerEventArray = z.array(ProposedTriggerEvent);
|
||||
|
||||
export const MessagesState = new StateSchema({
|
||||
disinformationTitle: z.string(),
|
||||
messages: MessagesValue,
|
||||
// normalizationContext: z.map(z.string(), z.string()),
|
||||
disinformationTitle: z.string()
|
||||
proposedTriggerEvent: ProposedTriggerEventArray,
|
||||
proposedTriggerEventIndex: z.int(),
|
||||
});
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import { parse } from "csv-parse";
|
||||
import fs from "fs";
|
||||
import { pipeline, cos_sim } from "@huggingface/transformers";
|
||||
import bm25Factory from "wink-bm25-text-search";
|
||||
import nlp from "wink-nlp-utils";
|
||||
import { logger } from "../../utils/logger";
|
||||
|
||||
//TODO, am getting duplicates, is it from the multi files?
|
||||
const CSV_PATHS = [
|
||||
"./tools/clan/dev-eng.csv",
|
||||
// "./tools/clan/test-eng.csv",
|
||||
"./tools/clan/train-eng.csv",
|
||||
];
|
||||
|
||||
const CACHE_PATH = "./tools/clan/dev.embeddings.json";
|
||||
const CACHE_PATH = "./tools/clan/csv.cache.json";
|
||||
|
||||
type EmbeddingCache = {
|
||||
rawtexts: string[];
|
||||
@@ -18,104 +18,262 @@ type EmbeddingCache = {
|
||||
embeddings: number[][];
|
||||
};
|
||||
|
||||
export type NormalisedMatch = {
|
||||
index: number;
|
||||
score: number;
|
||||
export type RetrievalItem = {
|
||||
id: string | number;
|
||||
rawtext: string;
|
||||
cleantext: string;
|
||||
cleantext?: string;
|
||||
};
|
||||
|
||||
let rawtexts: string[] = [];
|
||||
let cleantexts: string[] = [];
|
||||
let embeddings: number[][] = [];
|
||||
export type RankedResult = RetrievalItem & {
|
||||
denseScore: number;
|
||||
sparseScore: number;
|
||||
fusedScore: number;
|
||||
};
|
||||
|
||||
let csvRawtexts: string[] = [];
|
||||
let csvCleantexts: string[] = [];
|
||||
let csvEmbeddings: number[][] = [];
|
||||
let csvBM25: any = null;
|
||||
let csvLoaded = false;
|
||||
|
||||
logger.info("Loading embedding model...");
|
||||
const featureExtractor = await pipeline(
|
||||
"feature-extraction",
|
||||
"Xenova/all-MiniLM-L6-v2"
|
||||
);
|
||||
logger.info("Embedding model loaded");
|
||||
|
||||
//Cached entrypoint
|
||||
export async function rankFromCSV(
|
||||
query: string,
|
||||
topK = 5
|
||||
): Promise<RankedResult[]> {
|
||||
await ensureCSVLoaded();
|
||||
|
||||
logger.info("Ranking from CSV cache...");
|
||||
|
||||
const queryEmbedding = await embedText(query);
|
||||
|
||||
const denseScores = csvEmbeddings.map((docEmbedding) =>
|
||||
cos_sim(docEmbedding, queryEmbedding)
|
||||
);
|
||||
|
||||
const sparseScores = computeSparseScores(query, csvBM25, csvRawtexts);
|
||||
|
||||
const fusedScores = reciprocalRankFusion([denseScores, sparseScores]);
|
||||
|
||||
const ranked = csvRawtexts
|
||||
.map((text, i) => ({
|
||||
id: i,
|
||||
rawtext: text,
|
||||
cleantext: csvCleantexts[i],
|
||||
denseScore: denseScores[i],
|
||||
sparseScore: sparseScores[i],
|
||||
fusedScore: fusedScores[i],
|
||||
}))
|
||||
.sort((a, b) => b.fusedScore - a.fusedScore);
|
||||
|
||||
logger.info("Ranking complete (CSV mode)");
|
||||
|
||||
return ranked.slice(0, topK);
|
||||
}
|
||||
|
||||
//Dynamic Entrypoint
|
||||
export async function rankDynamically(
|
||||
query: string,
|
||||
items: RetrievalItem[],
|
||||
topK = 5
|
||||
): Promise<RankedResult[]> {
|
||||
logger.info("Ranking dynamically (no cache)...");
|
||||
|
||||
if (!items.length) return [];
|
||||
|
||||
const texts = items.map((i) => i.rawtext);
|
||||
|
||||
const queryEmbedding = await embedText(query);
|
||||
|
||||
const docEmbeddings = await Promise.all(
|
||||
texts.map((text) => embedText(text))
|
||||
);
|
||||
|
||||
const denseScores = docEmbeddings.map((embedding) =>
|
||||
cos_sim(embedding, queryEmbedding)
|
||||
);
|
||||
|
||||
const localBM25 = buildBM25(texts);
|
||||
|
||||
const sparseScores = computeSparseScores(query, localBM25, texts);
|
||||
|
||||
const fusedScores = reciprocalRankFusion([denseScores, sparseScores]);
|
||||
|
||||
const ranked = items
|
||||
.map((item, i) => ({
|
||||
...item,
|
||||
denseScore: denseScores[i],
|
||||
sparseScore: sparseScores[i],
|
||||
fusedScore: fusedScores[i],
|
||||
}))
|
||||
.sort((a, b) => b.fusedScore - a.fusedScore);
|
||||
|
||||
logger.info("Ranking complete (dynamic mode)");
|
||||
|
||||
return ranked.slice(0, topK);
|
||||
}
|
||||
|
||||
//CSV stuff
|
||||
async function ensureCSVLoaded(): Promise<void> {
|
||||
if (csvLoaded) return;
|
||||
|
||||
logger.info("Initializing CSV ranking mode...");
|
||||
|
||||
async function loadOrBuildCache(): Promise<void> {
|
||||
if (fs.existsSync(CACHE_PATH)) {
|
||||
logger.info("Loading embeddings from cache");
|
||||
logger.info("Loading CSV cache from disk...");
|
||||
|
||||
const raw = fs.readFileSync(CACHE_PATH, "utf-8");
|
||||
const cache: EmbeddingCache = JSON.parse(raw);
|
||||
|
||||
rawtexts = cache.rawtexts;
|
||||
cleantexts = cache.cleantexts;
|
||||
embeddings = cache.embeddings.map(e => Array.from(e));
|
||||
csvRawtexts = cache.rawtexts;
|
||||
csvCleantexts = cache.cleantexts;
|
||||
csvEmbeddings = cache.embeddings;
|
||||
|
||||
logger.info("Loaded %s embeddings", embeddings.length);
|
||||
return;
|
||||
}
|
||||
logger.info("Loaded %s cached embeddings", csvEmbeddings.length);
|
||||
} else {
|
||||
logger.warn("CSV cache not found. Building embeddings...");
|
||||
|
||||
logger.warn("Cache not found. Generating embeddings");
|
||||
const seen = new Set<string>();
|
||||
|
||||
for (const csvPath of CSV_PATHS) {
|
||||
await buildCacheFromCSV(csvPath);
|
||||
for (const path of CSV_PATHS) {
|
||||
await processCSV(path, seen);
|
||||
}
|
||||
|
||||
const cache: EmbeddingCache = {
|
||||
rawtexts,
|
||||
cleantexts,
|
||||
embeddings,
|
||||
rawtexts: csvRawtexts,
|
||||
cleantexts: csvCleantexts,
|
||||
embeddings: csvEmbeddings,
|
||||
};
|
||||
|
||||
fs.writeFileSync(CACHE_PATH, JSON.stringify(cache));
|
||||
logger.info("Cache written (%s embeddings)", csvEmbeddings.length);
|
||||
}
|
||||
|
||||
logger.info("Cached %s embeddings", embeddings.length);
|
||||
csvBM25 = buildBM25(csvRawtexts);
|
||||
|
||||
csvLoaded = true;
|
||||
logger.info("CSV mode ready");
|
||||
}
|
||||
|
||||
async function buildCacheFromCSV(csvPath: string): Promise<void> {
|
||||
let count = 0;
|
||||
async function processCSV(
|
||||
path: string,
|
||||
seen: Set<string>
|
||||
): Promise<void> {
|
||||
logger.info("Processing CSV: %s", path);
|
||||
|
||||
logger.info("Processing CSV: %s", csvPath);
|
||||
|
||||
const stream = fs.createReadStream(csvPath).pipe(parse());
|
||||
const stream = fs.createReadStream(path).pipe(parse());
|
||||
|
||||
for await (const row of stream) {
|
||||
const text = row[0];
|
||||
if (!text) continue;
|
||||
if (!text || seen.has(text)) continue;
|
||||
|
||||
seen.add(text);
|
||||
|
||||
const embedding = await embedText(text);
|
||||
|
||||
csvRawtexts.push(text);
|
||||
csvCleantexts.push(row[1]);
|
||||
csvEmbeddings.push(embedding);
|
||||
|
||||
if (csvRawtexts.length % 100 === 0) {
|
||||
logger.info("Embedded %s documents...", csvRawtexts.length);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Finished CSV: %s", path);
|
||||
}
|
||||
|
||||
|
||||
async function embedText(text: string): Promise<number[]> {
|
||||
const output = await featureExtractor(text, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
});
|
||||
|
||||
rawtexts.push(text);
|
||||
cleantexts.push(row[1]);
|
||||
const vector = Array.from(output.data as Float32Array);
|
||||
embeddings.push(vector);
|
||||
|
||||
|
||||
count++;
|
||||
if (count % 100 === 0) {
|
||||
logger.info("[%s] Processed %s rows", csvPath, count);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("[%s] Finished (%s rows)", csvPath, count);
|
||||
return Array.from(output.data as Float32Array);
|
||||
}
|
||||
|
||||
export async function calculateSimilarity(
|
||||
query: string,
|
||||
topK = 5
|
||||
): Promise<NormalisedMatch[]> {
|
||||
await loadOrBuildCache()
|
||||
function buildBM25(texts: string[]) {
|
||||
logger.info("Building BM25 index (%s docs)...", texts.length);
|
||||
|
||||
const queryEmbedding = await featureExtractor(query, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
const bm25 = bm25Factory();
|
||||
|
||||
bm25.defineConfig({
|
||||
fldWeights: { text: 1 },
|
||||
bm25Params: { k1: 1.2, b: 0.75 },
|
||||
});
|
||||
|
||||
return embeddings
|
||||
.map((embedding, index) => ({
|
||||
index,
|
||||
score: cos_sim(embedding, queryEmbedding.data as number[]),
|
||||
rawtext: rawtexts[index],
|
||||
cleantext: cleantexts[index]
|
||||
}))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.slice(0, topK);
|
||||
bm25.definePrepTasks([
|
||||
nlp.string.lowerCase,
|
||||
nlp.string.tokenize0,
|
||||
nlp.tokens.removeWords,
|
||||
]);
|
||||
|
||||
texts.forEach((text, i) => {
|
||||
bm25.addDoc({ text }, i);
|
||||
});
|
||||
|
||||
bm25.consolidate();
|
||||
|
||||
logger.info("BM25 ready");
|
||||
|
||||
return bm25;
|
||||
}
|
||||
|
||||
function computeSparseScores(
|
||||
query: string,
|
||||
bm25: any,
|
||||
texts: string[]
|
||||
): number[] {
|
||||
const results = bm25.search(query);
|
||||
|
||||
const scores = new Array(texts.length).fill(0);
|
||||
|
||||
results.forEach((r: any) => {
|
||||
scores[r[0]] = r[1];
|
||||
});
|
||||
|
||||
return scores;
|
||||
}
|
||||
|
||||
function reciprocalRankFusion(
|
||||
scoreLists: number[][],
|
||||
k = 60
|
||||
): number[] {
|
||||
const length = scoreLists[0].length;
|
||||
const fused = new Array(length).fill(0);
|
||||
|
||||
for (const scores of scoreLists) {
|
||||
const ranked = scores
|
||||
.map((score, i) => ({ score, i }))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.map((x) => x.i);
|
||||
|
||||
ranked.forEach((docIndex, rank) => {
|
||||
fused[docIndex] += 1 / (k + rank);
|
||||
});
|
||||
}
|
||||
|
||||
return fused;
|
||||
}
|
||||
|
||||
// console.log(await rankFromCSV("barrack obama"))
|
||||
// console.log(
|
||||
// await rankDynamically(
|
||||
// "i fell over",
|
||||
// [
|
||||
// { id: 1, rawtext: "I slipped and fell on the floor." },
|
||||
// { id: 2, rawtext: "Barack Obama was the 44th president." },
|
||||
// { id: 3, rawtext: "He tripped and hurt his knee badly." },
|
||||
// { id: 4, rawtext: "The weather is sunny today." },
|
||||
// { id: 5, rawtext: "She lost her balance and fell down the stairs." },
|
||||
// ]
|
||||
// )
|
||||
// );
|
||||
@@ -2,18 +2,20 @@ import { tool } from "@langchain/core/tools";
|
||||
import * as z from "zod";
|
||||
import { queryScraper } from "./webSearch";
|
||||
import { extractWebpageContent } from "./webpageFetch";
|
||||
import { rankDynamically } from "./clan/retreiveExamples";
|
||||
|
||||
|
||||
function rankAndDisplayData(data: string[]):string {
|
||||
//TODO: hybrid re-ranking of the provided data
|
||||
return data.join("\n")
|
||||
async function rankAndDisplayData(data: string[], context: string):Promise<string> {
|
||||
let index = 0;
|
||||
let ranked = await rankDynamically(context, data.map(irm => ({ id: index++, rawtext: irm })))
|
||||
return ranked.map(itm => itm.rawtext).join("\n")
|
||||
}
|
||||
|
||||
// Define tools
|
||||
const webSearch = tool(
|
||||
async ({ a }) => {
|
||||
const data = await queryScraper(a);
|
||||
return rankAndDisplayData(data);
|
||||
return await rankAndDisplayData(data, a);
|
||||
},
|
||||
{
|
||||
name: "WebSearch",
|
||||
@@ -25,15 +27,16 @@ const webSearch = tool(
|
||||
);
|
||||
|
||||
const openWebpage = tool(
|
||||
async ({ a }) => {
|
||||
async ({ a, b }) => {
|
||||
const data = await extractWebpageContent(a);
|
||||
return rankAndDisplayData(data);
|
||||
return rankAndDisplayData(data, b);
|
||||
},
|
||||
{
|
||||
name: "OpenWebpage",
|
||||
description: "Opens webpage and returns most relevent snippets",
|
||||
schema: z.object({
|
||||
a: z.string().describe("URL"),
|
||||
b: z.string().describe("What to match against in webpage content"),
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user