Add duplicate removal to the pipeline and run it in the sort step. Move the score logic into the robertaMetrics node.
This commit is contained in:
@@ -8,7 +8,12 @@ export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) =>
|
|||||||
|
|
||||||
const result = await evaluateWithRoberta({answer})
|
const result = await evaluateWithRoberta({answer})
|
||||||
|
|
||||||
|
let score = 0;
|
||||||
|
if (result.validProb > result.invalidProb) {
|
||||||
|
score = 0.7 + ((result.validProb - result.invalidProb)*0.3);
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
messages: [ new AIMessage("ROBERTA:" + result)]
|
messages: [ new AIMessage("ROBERTA:" + score)]
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
+6
-2
@@ -1,10 +1,14 @@
|
|||||||
import { GraphNode } from "@langchain/langgraph";
|
import { GraphNode } from "@langchain/langgraph";
|
||||||
import { MessagesState } from "../state";
|
import { MessagesState } from "../state";
|
||||||
import { AIMessage } from "@langchain/core/messages";
|
import { AIMessage } from "@langchain/core/messages";
|
||||||
|
import { removeDuplicates } from "../tools/removeDuplicates";
|
||||||
export const sort: GraphNode<typeof MessagesState> = async (state) => {
|
export const sort: GraphNode<typeof MessagesState> = async (state) => {
|
||||||
//not sure which will be better from API, just do both
|
|
||||||
|
|
||||||
let current = state.proposedTriggerEvent;
|
let current = state.proposedTriggerEvent;
|
||||||
|
|
||||||
|
// remove duplicates
|
||||||
|
await removeDuplicates(current)
|
||||||
|
|
||||||
|
// not sure which will be better from API, just do both
|
||||||
current.sort((a, b) => ((b.score as number) ?? 0) - ((a.score as number) ?? 0));
|
current.sort((a, b) => ((b.score as number) ?? 0) - ((a.score as number) ?? 0));
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
import { pipeline, cos_sim } from "@huggingface/transformers";
|
||||||
|
|
||||||
|
// Feature-extraction pipeline for the "Xenova/all-MiniLM-L6-v2" model, used by
// removeDuplicates to embed event text. It is created once, at module load,
// through top-level await; the binding is never reassigned, hence `const`.
const featureExtractor = await pipeline(
  "feature-extraction",
  "Xenova/all-MiniLM-L6-v2"
);
|
||||||
|
|
||||||
|
/**
 * Flags near-duplicate events in place by setting the `score` of the losing
 * entry to -1.
 *
 * Each entry's `Event` text is embedded with the module-level
 * `featureExtractor` (mean pooling, normalized) and every pair of entries that
 * is not already flagged is compared with `cos_sim`. When the similarity is
 * greater than `threshold`, the entry with the lower score is flagged; when
 * the scores are equal the earlier entry is kept. A missing score counts as 0.
 * Entries whose score is already -1 are never compared.
 *
 * @param state - Array of entries exposing an `Event` string and an optional
 *   numeric `score`. The array is mutated in place and also returned. Anything
 *   that is not an array with at least two entries is returned untouched.
 * @param threshold - Similarity above which two entries count as duplicates.
 *   Defaults to 0.55.
 * @returns The same `state` value that was passed in.
 */
export async function removeDuplicates(state: any, threshold = 0.55) {
  // Fewer than two entries cannot form a duplicate pair, so skip the model call.
  if (!Array.isArray(state) || state.length < 2) return state;

  const outputs = await featureExtractor(
    state.map((entry: { Event: string }) => entry.Event),
    { pooling: "mean", normalize: true }
  );

  // One embedding vector per entry, in the same order as `state`.
  const embeddings: number[][] = [];
  for (const row of outputs) {
    embeddings.push(Array.from(row.data));
  }

  const len = state.length;
  for (let i = 0; i < len; i++) {
    if (state[i].score === -1) continue;

    for (let j = i + 1; j < len; j++) {
      if (state[j].score === -1) continue;

      const sim = cos_sim(embeddings[i], embeddings[j]);
      if (sim > threshold) {
        const scoreI = state[i].score ?? 0;
        const scoreJ = state[j].score ?? 0;

        if (scoreJ > scoreI) {
          state[i].score = -1;
          // Entry i is now flagged, so no further pair involving it is compared.
          break;
        }

        // Entry j has the lower score, or the scores tie and the earlier entry wins.
        state[j].score = -1;
      }
    }
  }

  return state;
}
|
||||||
@@ -4,15 +4,15 @@ export async function evaluateWithRoberta({
|
|||||||
answer
|
answer
|
||||||
}: {
|
}: {
|
||||||
answer: string;
|
answer: string;
|
||||||
}) {
|
}): Promise<{ validProb: number; invalidProb: number; }> {
|
||||||
const res = await axios.post("http://localhost:8000/evaluate", {
|
const res = await axios.post("http://localhost:8000/evaluate", {
|
||||||
answer
|
answer
|
||||||
});
|
});
|
||||||
// console.log(res.data)
|
// console.log(res.data)
|
||||||
const validProb = res.data["probabilities"][0][0]
|
const validProb = res.data["probabilities"][0][0]
|
||||||
const invalidProv = res.data["probabilities"][0][1]
|
const invalidProb = res.data["probabilities"][0][1]
|
||||||
|
|
||||||
return validProb > invalidProv ? 1 : 0;
|
return {validProb, invalidProb};
|
||||||
}
|
}
|
||||||
|
|
||||||
// let res = await evaluateWithRoberta({answer: "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in March–August 2020)"});
|
// let res = await evaluateWithRoberta({answer: "High-profile political downplaying of COVID-19 (examples: President Trump saying 'it will go away' in March–August 2020)"});
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
|
# Prints the TOP_K most similar pairs of events from a CSV file, ranked by the
# cosine similarity of their sentence embeddings.

# CONFIG
CSV_PATH = "../../data/classify.csv"  # CSV file that is read
EVENT_COLUMN = "event"  # column of the CSV holding the event text
TOP_K = 60  # number of highest-similarity pairs to print

# Load CSV
df = pd.read_csv(CSV_PATH)

events = df[EVENT_COLUMN].astype(str).tolist()

# Load embedding model
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

print("Embedding events...")
embeddings = model.encode(events, batch_size=32, show_progress_bar=True)

# Compute cosine similarity matrix
sim_matrix = cosine_similarity(embeddings)

# Collect pair similarities: indices of the strict upper triangle (i < j), in
# row-major order, which avoids duplicates and self comparisons.
n = len(events)
rows, cols = np.triu_indices(n, k=1)
pair_sims = sim_matrix[rows, cols]

# Sort by similarity descending. The stable sort keeps tied pairs in their
# original (i, j) order. Only the first TOP_K positions are kept.
order = np.argsort(-pair_sims, kind="stable")[:TOP_K]

# Top K pairs as (similarity, i, j) tuples
top_pairs = [(pair_sims[k], int(rows[k]), int(cols[k])) for k in order]

print("\nTop Similar Event Pairs:\n")

for score, i, j in top_pairs:
    print(f"Similarity: {score:.4f}")
    print(f"Event 1: {events[i]}")
    print(f"Event 2: {events[j]}")
    print("-" * 60)
|
||||||
@@ -82,7 +82,7 @@ def render():
|
|||||||
extra_lower = (event.get("extra_info", "") or "").strip().lower()
|
extra_lower = (event.get("extra_info", "") or "").strip().lower()
|
||||||
# print(extra_lower)
|
# print(extra_lower)
|
||||||
if score is not None:
|
if score is not None:
|
||||||
if "duplicate" in extra_lower:
|
if score == -1 or "duplicate" in extra_lower:
|
||||||
dup_counter += 1
|
dup_counter += 1
|
||||||
elif score > THRESH and extra_lower == "perfect":
|
elif score > THRESH and extra_lower == "perfect":
|
||||||
confidence_counter["Correct-PERFECT"] += 1
|
confidence_counter["Correct-PERFECT"] += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user