Allow multiple source CSV files for normalisation. Implement real model node. Add normalization prompt. Implement normalization setup. Start on RAG retrieval functions.
This commit is contained in:
@@ -3,21 +3,29 @@ import fs from "fs";
|
||||
import { pipeline, cos_sim } from "@huggingface/transformers";
|
||||
import { logger } from "../../utils/logger";
|
||||
|
||||
const CSV_PATH = "./tools/clan/dev-eng.csv";
|
||||
const CACHE_PATH = "./tools/clan/dev-eng.embeddings.json";
|
||||
const CSV_PATHS = [
|
||||
"./tools/clan/dev-eng.csv",
|
||||
// "./tools/clan/test-eng.csv",
|
||||
"./tools/clan/train-eng.csv",
|
||||
];
|
||||
|
||||
const CACHE_PATH = "./tools/clan/dev.embeddings.json";
|
||||
|
||||
type EmbeddingCache = {
|
||||
texts: string[];
|
||||
rawtexts: string[];
|
||||
cleantexts: string[];
|
||||
embeddings: number[][];
|
||||
};
|
||||
|
||||
export type NormalisedMatch = {
|
||||
index: number;
|
||||
score: number;
|
||||
text: string
|
||||
index: number;
|
||||
score: number;
|
||||
rawtext: string;
|
||||
cleantext: string;
|
||||
};
|
||||
|
||||
let texts: string[] = [];
|
||||
let rawtexts: string[] = [];
|
||||
let cleantexts: string[] = [];
|
||||
let embeddings: number[][] = [];
|
||||
|
||||
const featureExtractor = await pipeline(
|
||||
@@ -33,20 +41,23 @@ async function loadOrBuildCache(): Promise<void> {
|
||||
const raw = fs.readFileSync(CACHE_PATH, "utf-8");
|
||||
const cache: EmbeddingCache = JSON.parse(raw);
|
||||
|
||||
texts = cache.texts;
|
||||
|
||||
rawtexts = cache.rawtexts;
|
||||
cleantexts = cache.cleantexts;
|
||||
embeddings = cache.embeddings.map(e => Array.from(e));
|
||||
|
||||
logger.info("Loaded %s embeddings", embeddings.length);
|
||||
return;
|
||||
}
|
||||
|
||||
logger.warn("Cache not found. Generating embeddings", embeddings.length);
|
||||
logger.warn("Cache not found. Generating embeddings");
|
||||
|
||||
await buildCacheFromCSV();
|
||||
for (const csvPath of CSV_PATHS) {
|
||||
await buildCacheFromCSV(csvPath);
|
||||
}
|
||||
|
||||
const cache: EmbeddingCache = {
|
||||
texts,
|
||||
rawtexts,
|
||||
cleantexts,
|
||||
embeddings,
|
||||
};
|
||||
|
||||
@@ -55,10 +66,12 @@ async function loadOrBuildCache(): Promise<void> {
|
||||
logger.info("Cached %s embeddings", embeddings.length);
|
||||
}
|
||||
|
||||
async function buildCacheFromCSV(): Promise<void> {
|
||||
async function buildCacheFromCSV(csvPath: string): Promise<void> {
|
||||
let count = 0;
|
||||
|
||||
const stream = fs.createReadStream(CSV_PATH).pipe(parse());
|
||||
logger.info("Processing CSV: %s", csvPath);
|
||||
|
||||
const stream = fs.createReadStream(csvPath).pipe(parse());
|
||||
|
||||
for await (const row of stream) {
|
||||
const text = row[0];
|
||||
@@ -69,19 +82,27 @@ async function buildCacheFromCSV(): Promise<void> {
|
||||
normalize: true,
|
||||
});
|
||||
|
||||
texts.push(text);
|
||||
rawtexts.push(text);
|
||||
cleantexts.push(row[1]);
|
||||
const vector = Array.from(output.data as Float32Array);
|
||||
embeddings.push(vector);
|
||||
|
||||
|
||||
count++;
|
||||
if (count % 100 === 0) {
|
||||
logger.info("Processed %s", count);
|
||||
logger.info("[%s] Processed %s rows", csvPath, count);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("[%s] Finished (%s rows)", csvPath, count);
|
||||
}
|
||||
|
||||
export async function calculateSimilarity(query: string,topK = 5): Promise<NormalisedMatch[]> {
|
||||
export async function calculateSimilarity(
|
||||
query: string,
|
||||
topK = 5
|
||||
): Promise<NormalisedMatch[]> {
|
||||
await loadOrBuildCache()
|
||||
|
||||
const queryEmbedding = await featureExtractor(query, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
@@ -91,17 +112,9 @@ export async function calculateSimilarity(query: string,topK = 5): Promise<Norma
|
||||
.map((embedding, index) => ({
|
||||
index,
|
||||
score: cos_sim(embedding, queryEmbedding.data as number[]),
|
||||
text: texts[index],
|
||||
rawtext: rawtexts[index],
|
||||
cleantext: cleantexts[index]
|
||||
}))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.slice(0, topK);
|
||||
}
|
||||
|
||||
//TEMP: testing code
|
||||
await loadOrBuildCache();
|
||||
|
||||
const results = await calculateSimilarity(
|
||||
"Wonderful to see London has taken a stand to defend freedom and the right to choose."
|
||||
);
|
||||
|
||||
console.log(results);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
import { Builder, Browser } from "selenium-webdriver";
|
||||
import firefox from "selenium-webdriver/firefox";
|
||||
|
||||
/**
 * Opens `url` in a Firefox session built through Selenium's `Builder` and
 * returns the text of the rendered page, read via `document.body.innerText`.
 *
 * The browser is started with the `--headless` argument, and the session is
 * always closed with `driver.quit()` in the `finally` block, whether or not
 * navigation or text extraction throws.
 *
 * @param url - Address handed to `driver.get`.
 * @returns The value of `document.body.innerText` evaluated in the loaded page.
 */
async function extractWebpageContent(url: string) : Promise<string>{
|
||||
const options = new firefox.Options();
|
||||
options.addArguments("--headless");
|
||||
|
||||
let driver = await new Builder().forBrowser(Browser.FIREFOX).setFirefoxOptions(options).build()
|
||||
try {
|
||||
await driver.get(url)
|
||||
// Poll the page until `document.readyState` reports 'complete'; 5000 is the
// timeout argument passed to `driver.wait` (presumably milliseconds — confirm
// against the selenium-webdriver docs).
await driver.wait(async () => {
|
||||
return await driver.executeScript(
|
||||
"return document.readyState === 'complete'"
|
||||
);
|
||||
}, 5000);
|
||||
|
||||
// Read the page text from inside the browser; the script result is asserted
// to be a string rather than checked at runtime.
const readableText = await driver.executeScript(
|
||||
"return document.body.innerText;"
|
||||
) as string;
|
||||
|
||||
return readableText
|
||||
} finally {
|
||||
// Release the browser session on both the success and the error path.
await driver.quit()
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: Extract, rank snippets
|
||||
|
||||
//console.log(await extractWebpageContent("https://www.bbc.co.uk/news/live/c74wd01egvyt"))
|
||||
Reference in New Issue
Block a user