Add initial code for retrieval ranking for normalisation
This commit is contained in:
@@ -1,3 +1,8 @@
|
||||
# -------- Ours --------
|
||||
tools/clan/*.csv
|
||||
tools/clan/*.json
|
||||
|
||||
# --------- Github -----------
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@ import { AIMessage, ToolMessage } from "@langchain/core/messages";
|
||||
import { GraphNode } from "@langchain/langgraph";
|
||||
import { MessagesState } from "../state";
|
||||
|
||||
export function createToolNode(tools): GraphNode<typeof MessagesState> {
|
||||
export function createToolNode(tools: any): GraphNode<typeof MessagesState> {
|
||||
return async (state) => {
|
||||
const lastMessage = state.messages.at(-1);
|
||||
|
||||
|
||||
Generated
+1364
-1
File diff suppressed because it is too large
Load Diff
+7
-1
@@ -10,14 +10,20 @@
|
||||
"dev": "tsx run.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@huggingface/transformers": "^3.8.1",
|
||||
"@langchain/core": "^1.1.17",
|
||||
"@langchain/langgraph": "^1.1.2",
|
||||
"@langchain/langgraph-sdk": "^1.5.5",
|
||||
"@langchain/openai": "^1.2.3",
|
||||
"compute-cosine-similarity": "^1.1.0",
|
||||
"csv-parse": "^6.1.0",
|
||||
"dotenv": "^17.2.3",
|
||||
"langchain": "^1.2.14"
|
||||
"fs": "^0.0.1-security",
|
||||
"langchain": "^1.2.14",
|
||||
"winston": "^3.19.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^25.1.0",
|
||||
"tsx": "^4.21.0"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
import { parse } from "csv-parse";
|
||||
import fs from "fs";
|
||||
import { pipeline, cos_sim } from "@huggingface/transformers";
|
||||
import { logger } from "../../utils/logger";
|
||||
|
||||
const CSV_PATH = "./tools/clan/dev-eng.csv";
|
||||
const CACHE_PATH = "./tools/clan/dev-eng.embeddings.json";
|
||||
|
||||
type EmbeddingCache = {
|
||||
texts: string[];
|
||||
embeddings: number[][];
|
||||
};
|
||||
|
||||
export type NormalisedMatch = {
|
||||
index: number;
|
||||
score: number;
|
||||
text: string
|
||||
};
|
||||
|
||||
let texts: string[] = [];
|
||||
let embeddings: number[][] = [];
|
||||
|
||||
const featureExtractor = await pipeline(
|
||||
"feature-extraction",
|
||||
"Xenova/all-MiniLM-L6-v2"
|
||||
);
|
||||
|
||||
|
||||
async function loadOrBuildCache(): Promise<void> {
|
||||
if (fs.existsSync(CACHE_PATH)) {
|
||||
logger.info("Loading embeddings from cache");
|
||||
|
||||
const raw = fs.readFileSync(CACHE_PATH, "utf-8");
|
||||
const cache: EmbeddingCache = JSON.parse(raw);
|
||||
|
||||
texts = cache.texts;
|
||||
|
||||
embeddings = cache.embeddings.map(e => Array.from(e));
|
||||
|
||||
logger.info("Loaded %s embeddings", embeddings.length);
|
||||
return;
|
||||
}
|
||||
|
||||
logger.warn("Cache not found. Generating embeddings", embeddings.length);
|
||||
|
||||
await buildCacheFromCSV();
|
||||
|
||||
const cache: EmbeddingCache = {
|
||||
texts,
|
||||
embeddings,
|
||||
};
|
||||
|
||||
fs.writeFileSync(CACHE_PATH, JSON.stringify(cache));
|
||||
|
||||
logger.info("Cached %s embeddings", embeddings.length);
|
||||
}
|
||||
|
||||
async function buildCacheFromCSV(): Promise<void> {
|
||||
let count = 0;
|
||||
|
||||
const stream = fs.createReadStream(CSV_PATH).pipe(parse());
|
||||
|
||||
for await (const row of stream) {
|
||||
const text = row[0];
|
||||
if (!text) continue;
|
||||
|
||||
const output = await featureExtractor(text, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
});
|
||||
|
||||
texts.push(text);
|
||||
const vector = Array.from(output.data as Float32Array);
|
||||
embeddings.push(vector);
|
||||
|
||||
|
||||
count++;
|
||||
if (count % 100 === 0) {
|
||||
logger.info("Processed %s", count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function calculateSimilarity(query: string,topK = 5): Promise<NormalisedMatch[]> {
|
||||
const queryEmbedding = await featureExtractor(query, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
});
|
||||
|
||||
return embeddings
|
||||
.map((embedding, index) => ({
|
||||
index,
|
||||
score: cos_sim(embedding, queryEmbedding.data as number[]),
|
||||
text: texts[index],
|
||||
}))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.slice(0, topK);
|
||||
}
|
||||
|
||||
//TEMP: testing code
|
||||
await loadOrBuildCache();
|
||||
|
||||
const results = await calculateSimilarity(
|
||||
"Wonderful to see London has taken a stand to defend freedom and the right to choose."
|
||||
);
|
||||
|
||||
console.log(results);
|
||||
@@ -0,0 +1,16 @@
|
||||
import winston from "winston";
|
||||
|
||||
export const logger = winston.createLogger({
|
||||
level: "info",
|
||||
format: winston.format.combine(
|
||||
winston.format.splat(),
|
||||
winston.format.colorize(),
|
||||
winston.format.timestamp({ format: "HH:mm:ss" }),
|
||||
winston.format.printf(({ level, message, timestamp }) => {
|
||||
return `${timestamp} ${level}: ${message}`;
|
||||
})
|
||||
),
|
||||
transports: [
|
||||
new winston.transports.Console(),
|
||||
],
|
||||
});
|
||||
Reference in New Issue
Block a user