Refactor example retreiving, add option for dynamic data. Add hybrid reranking to tooling. Add parsing and loop infrastructure for trigger event processing
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
import { parse } from "csv-parse";
|
||||
import fs from "fs";
|
||||
import { pipeline, cos_sim } from "@huggingface/transformers";
|
||||
import bm25Factory from "wink-bm25-text-search";
|
||||
import nlp from "wink-nlp-utils";
|
||||
import { logger } from "../../utils/logger";
|
||||
|
||||
//TODO, am getting duplicates, is it from the multi files?
|
||||
const CSV_PATHS = [
|
||||
"./tools/clan/dev-eng.csv",
|
||||
// "./tools/clan/test-eng.csv",
|
||||
"./tools/clan/train-eng.csv",
|
||||
];
|
||||
|
||||
const CACHE_PATH = "./tools/clan/dev.embeddings.json";
|
||||
const CACHE_PATH = "./tools/clan/csv.cache.json";
|
||||
|
||||
type EmbeddingCache = {
|
||||
rawtexts: string[];
|
||||
@@ -18,104 +18,262 @@ type EmbeddingCache = {
|
||||
embeddings: number[][];
|
||||
};
|
||||
|
||||
export type NormalisedMatch = {
|
||||
index: number;
|
||||
score: number;
|
||||
export type RetrievalItem = {
|
||||
id: string | number;
|
||||
rawtext: string;
|
||||
cleantext: string;
|
||||
cleantext?: string;
|
||||
};
|
||||
|
||||
let rawtexts: string[] = [];
|
||||
let cleantexts: string[] = [];
|
||||
let embeddings: number[][] = [];
|
||||
export type RankedResult = RetrievalItem & {
|
||||
denseScore: number;
|
||||
sparseScore: number;
|
||||
fusedScore: number;
|
||||
};
|
||||
|
||||
let csvRawtexts: string[] = [];
|
||||
let csvCleantexts: string[] = [];
|
||||
let csvEmbeddings: number[][] = [];
|
||||
let csvBM25: any = null;
|
||||
let csvLoaded = false;
|
||||
|
||||
logger.info("Loading embedding model...");
|
||||
const featureExtractor = await pipeline(
|
||||
"feature-extraction",
|
||||
"Xenova/all-MiniLM-L6-v2"
|
||||
);
|
||||
logger.info("Embedding model loaded");
|
||||
|
||||
//Cached entrypoint
|
||||
export async function rankFromCSV(
|
||||
query: string,
|
||||
topK = 5
|
||||
): Promise<RankedResult[]> {
|
||||
await ensureCSVLoaded();
|
||||
|
||||
logger.info("Ranking from CSV cache...");
|
||||
|
||||
const queryEmbedding = await embedText(query);
|
||||
|
||||
const denseScores = csvEmbeddings.map((docEmbedding) =>
|
||||
cos_sim(docEmbedding, queryEmbedding)
|
||||
);
|
||||
|
||||
const sparseScores = computeSparseScores(query, csvBM25, csvRawtexts);
|
||||
|
||||
const fusedScores = reciprocalRankFusion([denseScores, sparseScores]);
|
||||
|
||||
const ranked = csvRawtexts
|
||||
.map((text, i) => ({
|
||||
id: i,
|
||||
rawtext: text,
|
||||
cleantext: csvCleantexts[i],
|
||||
denseScore: denseScores[i],
|
||||
sparseScore: sparseScores[i],
|
||||
fusedScore: fusedScores[i],
|
||||
}))
|
||||
.sort((a, b) => b.fusedScore - a.fusedScore);
|
||||
|
||||
logger.info("Ranking complete (CSV mode)");
|
||||
|
||||
return ranked.slice(0, topK);
|
||||
}
|
||||
|
||||
//Dynamic Entrypoint
|
||||
export async function rankDynamically(
|
||||
query: string,
|
||||
items: RetrievalItem[],
|
||||
topK = 5
|
||||
): Promise<RankedResult[]> {
|
||||
logger.info("Ranking dynamically (no cache)...");
|
||||
|
||||
if (!items.length) return [];
|
||||
|
||||
const texts = items.map((i) => i.rawtext);
|
||||
|
||||
const queryEmbedding = await embedText(query);
|
||||
|
||||
const docEmbeddings = await Promise.all(
|
||||
texts.map((text) => embedText(text))
|
||||
);
|
||||
|
||||
const denseScores = docEmbeddings.map((embedding) =>
|
||||
cos_sim(embedding, queryEmbedding)
|
||||
);
|
||||
|
||||
const localBM25 = buildBM25(texts);
|
||||
|
||||
const sparseScores = computeSparseScores(query, localBM25, texts);
|
||||
|
||||
const fusedScores = reciprocalRankFusion([denseScores, sparseScores]);
|
||||
|
||||
const ranked = items
|
||||
.map((item, i) => ({
|
||||
...item,
|
||||
denseScore: denseScores[i],
|
||||
sparseScore: sparseScores[i],
|
||||
fusedScore: fusedScores[i],
|
||||
}))
|
||||
.sort((a, b) => b.fusedScore - a.fusedScore);
|
||||
|
||||
logger.info("Ranking complete (dynamic mode)");
|
||||
|
||||
return ranked.slice(0, topK);
|
||||
}
|
||||
|
||||
//CSV stuff
|
||||
async function ensureCSVLoaded(): Promise<void> {
|
||||
if (csvLoaded) return;
|
||||
|
||||
logger.info("Initializing CSV ranking mode...");
|
||||
|
||||
async function loadOrBuildCache(): Promise<void> {
|
||||
if (fs.existsSync(CACHE_PATH)) {
|
||||
logger.info("Loading embeddings from cache");
|
||||
logger.info("Loading CSV cache from disk...");
|
||||
|
||||
const raw = fs.readFileSync(CACHE_PATH, "utf-8");
|
||||
const cache: EmbeddingCache = JSON.parse(raw);
|
||||
|
||||
rawtexts = cache.rawtexts;
|
||||
cleantexts = cache.cleantexts;
|
||||
embeddings = cache.embeddings.map(e => Array.from(e));
|
||||
csvRawtexts = cache.rawtexts;
|
||||
csvCleantexts = cache.cleantexts;
|
||||
csvEmbeddings = cache.embeddings;
|
||||
|
||||
logger.info("Loaded %s embeddings", embeddings.length);
|
||||
return;
|
||||
logger.info("Loaded %s cached embeddings", csvEmbeddings.length);
|
||||
} else {
|
||||
logger.warn("CSV cache not found. Building embeddings...");
|
||||
|
||||
const seen = new Set<string>();
|
||||
|
||||
for (const path of CSV_PATHS) {
|
||||
await processCSV(path, seen);
|
||||
}
|
||||
|
||||
const cache: EmbeddingCache = {
|
||||
rawtexts: csvRawtexts,
|
||||
cleantexts: csvCleantexts,
|
||||
embeddings: csvEmbeddings,
|
||||
};
|
||||
|
||||
fs.writeFileSync(CACHE_PATH, JSON.stringify(cache));
|
||||
logger.info("Cache written (%s embeddings)", csvEmbeddings.length);
|
||||
}
|
||||
|
||||
logger.warn("Cache not found. Generating embeddings");
|
||||
csvBM25 = buildBM25(csvRawtexts);
|
||||
|
||||
for (const csvPath of CSV_PATHS) {
|
||||
await buildCacheFromCSV(csvPath);
|
||||
}
|
||||
|
||||
const cache: EmbeddingCache = {
|
||||
rawtexts,
|
||||
cleantexts,
|
||||
embeddings,
|
||||
};
|
||||
|
||||
fs.writeFileSync(CACHE_PATH, JSON.stringify(cache));
|
||||
|
||||
logger.info("Cached %s embeddings", embeddings.length);
|
||||
csvLoaded = true;
|
||||
logger.info("CSV mode ready");
|
||||
}
|
||||
|
||||
async function buildCacheFromCSV(csvPath: string): Promise<void> {
|
||||
let count = 0;
|
||||
async function processCSV(
|
||||
path: string,
|
||||
seen: Set<string>
|
||||
): Promise<void> {
|
||||
logger.info("Processing CSV: %s", path);
|
||||
|
||||
logger.info("Processing CSV: %s", csvPath);
|
||||
|
||||
const stream = fs.createReadStream(csvPath).pipe(parse());
|
||||
const stream = fs.createReadStream(path).pipe(parse());
|
||||
|
||||
for await (const row of stream) {
|
||||
const text = row[0];
|
||||
if (!text) continue;
|
||||
if (!text || seen.has(text)) continue;
|
||||
|
||||
const output = await featureExtractor(text, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
});
|
||||
seen.add(text);
|
||||
|
||||
rawtexts.push(text);
|
||||
cleantexts.push(row[1]);
|
||||
const vector = Array.from(output.data as Float32Array);
|
||||
embeddings.push(vector);
|
||||
const embedding = await embedText(text);
|
||||
|
||||
csvRawtexts.push(text);
|
||||
csvCleantexts.push(row[1]);
|
||||
csvEmbeddings.push(embedding);
|
||||
|
||||
count++;
|
||||
if (count % 100 === 0) {
|
||||
logger.info("[%s] Processed %s rows", csvPath, count);
|
||||
if (csvRawtexts.length % 100 === 0) {
|
||||
logger.info("Embedded %s documents...", csvRawtexts.length);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("[%s] Finished (%s rows)", csvPath, count);
|
||||
logger.info("Finished CSV: %s", path);
|
||||
}
|
||||
|
||||
export async function calculateSimilarity(
|
||||
query: string,
|
||||
topK = 5
|
||||
): Promise<NormalisedMatch[]> {
|
||||
await loadOrBuildCache()
|
||||
|
||||
const queryEmbedding = await featureExtractor(query, {
|
||||
async function embedText(text: string): Promise<number[]> {
|
||||
const output = await featureExtractor(text, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
});
|
||||
|
||||
return embeddings
|
||||
.map((embedding, index) => ({
|
||||
index,
|
||||
score: cos_sim(embedding, queryEmbedding.data as number[]),
|
||||
rawtext: rawtexts[index],
|
||||
cleantext: cleantexts[index]
|
||||
}))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.slice(0, topK);
|
||||
}
|
||||
return Array.from(output.data as Float32Array);
|
||||
}
|
||||
|
||||
function buildBM25(texts: string[]) {
|
||||
logger.info("Building BM25 index (%s docs)...", texts.length);
|
||||
|
||||
const bm25 = bm25Factory();
|
||||
|
||||
bm25.defineConfig({
|
||||
fldWeights: { text: 1 },
|
||||
bm25Params: { k1: 1.2, b: 0.75 },
|
||||
});
|
||||
|
||||
bm25.definePrepTasks([
|
||||
nlp.string.lowerCase,
|
||||
nlp.string.tokenize0,
|
||||
nlp.tokens.removeWords,
|
||||
]);
|
||||
|
||||
texts.forEach((text, i) => {
|
||||
bm25.addDoc({ text }, i);
|
||||
});
|
||||
|
||||
bm25.consolidate();
|
||||
|
||||
logger.info("BM25 ready");
|
||||
|
||||
return bm25;
|
||||
}
|
||||
|
||||
function computeSparseScores(
|
||||
query: string,
|
||||
bm25: any,
|
||||
texts: string[]
|
||||
): number[] {
|
||||
const results = bm25.search(query);
|
||||
|
||||
const scores = new Array(texts.length).fill(0);
|
||||
|
||||
results.forEach((r: any) => {
|
||||
scores[r[0]] = r[1];
|
||||
});
|
||||
|
||||
return scores;
|
||||
}
|
||||
|
||||
function reciprocalRankFusion(
|
||||
scoreLists: number[][],
|
||||
k = 60
|
||||
): number[] {
|
||||
const length = scoreLists[0].length;
|
||||
const fused = new Array(length).fill(0);
|
||||
|
||||
for (const scores of scoreLists) {
|
||||
const ranked = scores
|
||||
.map((score, i) => ({ score, i }))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.map((x) => x.i);
|
||||
|
||||
ranked.forEach((docIndex, rank) => {
|
||||
fused[docIndex] += 1 / (k + rank);
|
||||
});
|
||||
}
|
||||
|
||||
return fused;
|
||||
}
|
||||
|
||||
// console.log(await rankFromCSV("barrack obama"))
|
||||
// console.log(
|
||||
// await rankDynamically(
|
||||
// "i fell over",
|
||||
// [
|
||||
// { id: 1, rawtext: "I slipped and fell on the floor." },
|
||||
// { id: 2, rawtext: "Barack Obama was the 44th president." },
|
||||
// { id: 3, rawtext: "He tripped and hurt his knee badly." },
|
||||
// { id: 4, rawtext: "The weather is sunny today." },
|
||||
// { id: 5, rawtext: "She lost her balance and fell down the stairs." },
|
||||
// ]
|
||||
// )
|
||||
// );
|
||||
@@ -2,18 +2,20 @@ import { tool } from "@langchain/core/tools";
|
||||
import * as z from "zod";
|
||||
import { queryScraper } from "./webSearch";
|
||||
import { extractWebpageContent } from "./webpageFetch";
|
||||
import { rankDynamically } from "./clan/retreiveExamples";
|
||||
|
||||
|
||||
function rankAndDisplayData(data: string[]):string {
|
||||
//TODO: hybrid re-ranking of the provided data
|
||||
return data.join("\n")
|
||||
async function rankAndDisplayData(data: string[], context: string):Promise<string> {
|
||||
let index = 0;
|
||||
let ranked = await rankDynamically(context, data.map(irm => ({ id: index++, rawtext: irm })))
|
||||
return ranked.map(itm => itm.rawtext).join("\n")
|
||||
}
|
||||
|
||||
// Define tools
|
||||
const webSearch = tool(
|
||||
async ({ a }) => {
|
||||
const data = await queryScraper(a);
|
||||
return rankAndDisplayData(data);
|
||||
return await rankAndDisplayData(data, a);
|
||||
},
|
||||
{
|
||||
name: "WebSearch",
|
||||
@@ -25,15 +27,16 @@ const webSearch = tool(
|
||||
);
|
||||
|
||||
const openWebpage = tool(
|
||||
async ({ a }) => {
|
||||
async ({ a, b }) => {
|
||||
const data = await extractWebpageContent(a);
|
||||
return rankAndDisplayData(data);
|
||||
return rankAndDisplayData(data, b);
|
||||
},
|
||||
{
|
||||
name: "OpenWebpage",
|
||||
description: "Opens webpage and returns most relevent snippets",
|
||||
schema: z.object({
|
||||
a: z.string().describe("URL"),
|
||||
b: z.string().describe("What to match against in webpage content"),
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user