Add grph visualiser initial version

This commit is contained in:
William Jeynes
2026-04-08 21:00:24 +01:00
parent cf6b29ca91
commit aa962b1802
9 changed files with 3803 additions and 41 deletions
+39 -41
View File
@@ -13,17 +13,13 @@ from tqdm import tqdm
INPUT_CSV = "../../data/dataset-dev.csv"
OUTPUT_JSON = "../../data/clustered_output.json"
MODEL_NAME = "all-MiniLM-L6-v2"
SIMILARITY_THRESHOLD = 0.65
SIMILARITY_THRESHOLD = 0.55
def generate_guid():
return str(uuid.uuid4())
def read_csv(file_path: str):
"""
Expected format per row:
[claim, event1, event2, event3, ...]
"""
data = []
with open(file_path, newline='', encoding='utf-8') as f:
@@ -63,10 +59,7 @@ def embed_texts(model, texts: List[str], desc="Embedding"):
return np.array(embeddings)
def cluster_embeddings(embeddings, threshold=0.75, desc="Clustering"):
"""
Uses Agglomerative clustering with cosine distance
"""
def cluster_embeddings(embeddings, threshold=0.75):
distance_matrix = 1 - cosine_similarity(embeddings)
clustering = AgglomerativeClustering(
@@ -85,68 +78,74 @@ def main():
data = read_csv(INPUT_CSV)
# Collect all claims and events separately
claim_texts = []
claim_ids = []
claim_texts, claim_ids = [], []
event_texts, event_ids = [], []
event_texts = []
event_ids = []
links = [] # claim -> events
raw_links = [] # temporary for cluster mapping
for entry in tqdm(data, desc="Processing rows"):
claim = entry["claim"]
claim_ids.append(claim["id"])
# Context-enhanced claim
claim_texts.append(f"Claim: {claim['text']}")
for event in entry["events"]:
event_ids.append(event["id"])
# Context-enhanced event
event_texts.append(f"Event: {event['text']}")
links.append({
raw_links.append({
"claim_id": claim["id"],
"event_id": event["id"]
})
# Embed
print("Embedding claims...")
claim_embeddings = embed_texts(model, claim_texts, desc="Claims")
print("Embedding events...")
event_embeddings = embed_texts(model, event_texts, desc="Events")
# Cluster
print("Clustering claims...")
claim_labels = cluster_embeddings(claim_embeddings, SIMILARITY_THRESHOLD)
print("Clustering events...")
event_labels = cluster_embeddings(event_embeddings, SIMILARITY_THRESHOLD)
# Build cluster structures
claim_clusters: Dict[int, List[str]] = {}
# Assign GUIDs to clusters
claim_cluster_map = {}
for label in set(claim_labels):
claim_cluster_map[int(label)] = generate_guid()
event_cluster_map = {}
for label in set(event_labels):
event_cluster_map[int(label)] = generate_guid()
# Build cluster membership
claim_clusters = {}
for cid, label in zip(claim_ids, claim_labels):
claim_clusters.setdefault(int(label), []).append(cid)
cluster_guid = claim_cluster_map[int(label)]
claim_clusters.setdefault(cluster_guid, []).append(cid)
event_clusters: Dict[int, List[str]] = {}
event_clusters = {}
for eid, label in zip(event_ids, event_labels):
event_clusters.setdefault(int(label), []).append(eid)
cluster_guid = event_cluster_map[int(label)]
event_clusters.setdefault(cluster_guid, []).append(eid)
# Build cluster-level links
cluster_links = []
for link in links:
claim_cluster = int(claim_labels[claim_ids.index(link["claim_id"])])
event_cluster = int(event_labels[event_ids.index(link["event_id"])])
# Build ONLY cluster-level links
cluster_links = set()
cluster_links.append({
"claim_cluster": claim_cluster,
"event_cluster": event_cluster
})
for link in raw_links:
claim_label = int(claim_labels[claim_ids.index(link["claim_id"])])
event_label = int(event_labels[event_ids.index(link["event_id"])])
claim_cluster_guid = claim_cluster_map[claim_label]
event_cluster_guid = event_cluster_map[event_label]
cluster_links.add((claim_cluster_guid, event_cluster_guid))
cluster_links = [
{"claim_cluster_id": c, "event_cluster_id": e}
for c, e in cluster_links
]
# Output structure
output = {
"claims": [
{"id": cid, "text": txt.replace("Claim: ", "")}
@@ -157,14 +156,13 @@ def main():
for eid, txt in zip(event_ids, event_texts)
],
"claim_clusters": [
{"cluster_id": int(k), "members": v}
{"cluster_id": k, "members": v}
for k, v in claim_clusters.items()
],
"event_clusters": [
{"cluster_id": int(k), "members": v}
{"cluster_id": k, "members": v}
for k, v in event_clusters.items()
],
"links": links,
"cluster_links": cluster_links
}