Add initial code for retreival ranking for normalisation

This commit is contained in:
William Jeynes
2026-01-29 21:53:38 +00:00
parent a1373da891
commit 8eaa7bfbff
7 changed files with 1502 additions and 3 deletions
+2
View File
@@ -0,0 +1,2 @@
# TEMP
literature/
+5
View File
@@ -1,3 +1,8 @@
# -------- Ours --------
tools/clan/*.csv
tools/clan/*.json
# --------- Github -----------
# Logs # Logs
logs logs
*.log *.log
+1 -1
View File
@@ -2,7 +2,7 @@ import { AIMessage, ToolMessage } from "@langchain/core/messages";
import { GraphNode } from "@langchain/langgraph"; import { GraphNode } from "@langchain/langgraph";
import { MessagesState } from "../state"; import { MessagesState } from "../state";
export function createToolNode(tools): GraphNode<typeof MessagesState> { export function createToolNode(tools: any): GraphNode<typeof MessagesState> {
return async (state) => { return async (state) => {
const lastMessage = state.messages.at(-1); const lastMessage = state.messages.at(-1);
+1364 -1
View File
File diff suppressed because it is too large Load Diff
+7 -1
View File
@@ -10,14 +10,20 @@
"dev": "tsx run.ts" "dev": "tsx run.ts"
}, },
"dependencies": { "dependencies": {
"@huggingface/transformers": "^3.8.1",
"@langchain/core": "^1.1.17", "@langchain/core": "^1.1.17",
"@langchain/langgraph": "^1.1.2", "@langchain/langgraph": "^1.1.2",
"@langchain/langgraph-sdk": "^1.5.5", "@langchain/langgraph-sdk": "^1.5.5",
"@langchain/openai": "^1.2.3", "@langchain/openai": "^1.2.3",
"compute-cosine-similarity": "^1.1.0",
"csv-parse": "^6.1.0",
"dotenv": "^17.2.3", "dotenv": "^17.2.3",
"langchain": "^1.2.14" "fs": "^0.0.1-security",
"langchain": "^1.2.14",
"winston": "^3.19.0"
}, },
"devDependencies": { "devDependencies": {
"@types/node": "^25.1.0",
"tsx": "^4.21.0" "tsx": "^4.21.0"
} }
} }
+107
View File
@@ -0,0 +1,107 @@
import { parse } from "csv-parse";
import fs from "fs";
import { pipeline, cos_sim } from "@huggingface/transformers";
import { logger } from "../../utils/logger";
const CSV_PATH = "./tools/clan/dev-eng.csv";
const CACHE_PATH = "./tools/clan/dev-eng.embeddings.json";
type EmbeddingCache = {
texts: string[];
embeddings: number[][];
};
export type NormalisedMatch = {
index: number;
score: number;
text: string
};
let texts: string[] = [];
let embeddings: number[][] = [];
const featureExtractor = await pipeline(
"feature-extraction",
"Xenova/all-MiniLM-L6-v2"
);
async function loadOrBuildCache(): Promise<void> {
if (fs.existsSync(CACHE_PATH)) {
logger.info("Loading embeddings from cache");
const raw = fs.readFileSync(CACHE_PATH, "utf-8");
const cache: EmbeddingCache = JSON.parse(raw);
texts = cache.texts;
embeddings = cache.embeddings.map(e => Array.from(e));
logger.info("Loaded %s embeddings", embeddings.length);
return;
}
logger.warn("Cache not found. Generating embeddings", embeddings.length);
await buildCacheFromCSV();
const cache: EmbeddingCache = {
texts,
embeddings,
};
fs.writeFileSync(CACHE_PATH, JSON.stringify(cache));
logger.info("Cached %s embeddings", embeddings.length);
}
async function buildCacheFromCSV(): Promise<void> {
let count = 0;
const stream = fs.createReadStream(CSV_PATH).pipe(parse());
for await (const row of stream) {
const text = row[0];
if (!text) continue;
const output = await featureExtractor(text, {
pooling: "mean",
normalize: true,
});
texts.push(text);
const vector = Array.from(output.data as Float32Array);
embeddings.push(vector);
count++;
if (count % 100 === 0) {
logger.info("Processed %s", count);
}
}
}
export async function calculateSimilarity(query: string,topK = 5): Promise<NormalisedMatch[]> {
const queryEmbedding = await featureExtractor(query, {
pooling: "mean",
normalize: true,
});
return embeddings
.map((embedding, index) => ({
index,
score: cos_sim(embedding, queryEmbedding.data as number[]),
text: texts[index],
}))
.sort((a, b) => b.score - a.score)
.slice(0, topK);
}
//TEMP: testing code
await loadOrBuildCache();
const results = await calculateSimilarity(
"Wonderful to see London has taken a stand to defend freedom and the right to choose."
);
console.log(results);
+16
View File
@@ -0,0 +1,16 @@
import winston from "winston";
export const logger = winston.createLogger({
level: "info",
format: winston.format.combine(
winston.format.splat(),
winston.format.colorize(),
winston.format.timestamp({ format: "HH:mm:ss" }),
winston.format.printf(({ level, message, timestamp }) => {
return `${timestamp} ${level}: ${message}`;
})
),
transports: [
new winston.transports.Console(),
],
});