Add date ranges to frontend visualisation
This commit is contained in:
@@ -2,4 +2,6 @@
|
|||||||
dist/
|
dist/
|
||||||
node_modules/
|
node_modules/
|
||||||
src/data.json
|
src/data.json
|
||||||
|
src/data_date.json
|
||||||
src/titles.json
|
src/titles.json
|
||||||
|
src/titles_date.json
|
||||||
@@ -0,0 +1,335 @@
|
|||||||
|
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||||
|
import ForceGraph2D from "react-force-graph-2d";
|
||||||
|
import * as d3 from "d3-force-3d";
|
||||||
|
|
||||||
|
import data from "./data_date.json";
|
||||||
|
import titlesData from "./titles_date.json";
|
||||||
|
|
||||||
|
function drawRoundedRect(ctx, x, y, width, height, radius) {
|
||||||
|
const r = Math.min(radius, width / 2, height / 2);
|
||||||
|
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.moveTo(x + r, y);
|
||||||
|
ctx.lineTo(x + width - r, y);
|
||||||
|
ctx.quadraticCurveTo(x + width, y, x + width, y + r);
|
||||||
|
ctx.lineTo(x + width, y + height - r);
|
||||||
|
ctx.quadraticCurveTo(x + width, y + height, x + width - r, y + height);
|
||||||
|
ctx.lineTo(x + r, y + height);
|
||||||
|
ctx.quadraticCurveTo(x, y + height, x, y + height - r);
|
||||||
|
ctx.lineTo(x, y + r);
|
||||||
|
ctx.quadraticCurveTo(x, y, x + r, y);
|
||||||
|
ctx.closePath();
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseDateSafe(dateStr) {
|
||||||
|
if (!dateStr) return null;
|
||||||
|
const d = new Date(dateStr);
|
||||||
|
if (isNaN(d.getTime())) return null;
|
||||||
|
if (d.getFullYear() < 2016) return null; // filter erroneous
|
||||||
|
return d;
|
||||||
|
}
|
||||||
|
|
||||||
|
function monthsDiff(a, b) {
|
||||||
|
const ms = Math.abs(a - b);
|
||||||
|
return ms / (1000 * 60 * 60 * 24 * 30.44);
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildLookupMaps(data) {
|
||||||
|
const claimMap = new Map(data.claims.map(c => [c.id, c]));
|
||||||
|
const eventMap = new Map(data.events.map(e => [e.id, e]));
|
||||||
|
return { claimMap, eventMap };
|
||||||
|
}
|
||||||
|
|
||||||
|
function computeClusterAvgDate(members, claimMap, eventMap) {
|
||||||
|
const dates = [];
|
||||||
|
|
||||||
|
members.forEach(id => {
|
||||||
|
const c = claimMap.get(id);
|
||||||
|
const e = eventMap.get(id);
|
||||||
|
|
||||||
|
const raw = c?.date || e?.date;
|
||||||
|
const parsed = parseDateSafe(raw);
|
||||||
|
|
||||||
|
if (parsed) dates.push(parsed.getTime());
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!dates.length) return null;
|
||||||
|
|
||||||
|
const avg = dates.reduce((a, b) => a + b, 0) / dates.length;
|
||||||
|
return new Date(avg);
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildGraph(data) {
|
||||||
|
const nodes = [];
|
||||||
|
const links = [];
|
||||||
|
|
||||||
|
const titleMap = new Map(titlesData.map(t => [t.cluster_id, t.title]));
|
||||||
|
const { claimMap, eventMap } = buildLookupMaps(data);
|
||||||
|
|
||||||
|
data.claim_clusters.forEach((cluster) => {
|
||||||
|
const avgDate = computeClusterAvgDate(cluster.members, claimMap, eventMap);
|
||||||
|
|
||||||
|
nodes.push({
|
||||||
|
id: cluster.cluster_id,
|
||||||
|
label: titleMap.get(cluster.cluster_id) || cluster.title || "Unnamed Claim Cluster",
|
||||||
|
type: "claim_cluster",
|
||||||
|
members: cluster.members,
|
||||||
|
avgDate
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
data.event_clusters.forEach((cluster) => {
|
||||||
|
const avgDate = computeClusterAvgDate(cluster.members, claimMap, eventMap);
|
||||||
|
|
||||||
|
nodes.push({
|
||||||
|
id: cluster.cluster_id,
|
||||||
|
label: titleMap.get(cluster.cluster_id) || cluster.title || "Unnamed Event Cluster",
|
||||||
|
type: "event_cluster",
|
||||||
|
members: cluster.members,
|
||||||
|
avgDate
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
data.cluster_links.forEach((link) => {
|
||||||
|
links.push({
|
||||||
|
source: link.claim_cluster_id,
|
||||||
|
target: link.event_cluster_id
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return { nodes, links };
|
||||||
|
}
|
||||||
|
|
||||||
|
function getConnectedComponents(nodes, links) {
|
||||||
|
const adj = new Map();
|
||||||
|
nodes.forEach(n => adj.set(n.id, new Set()));
|
||||||
|
|
||||||
|
links.forEach(l => {
|
||||||
|
adj.get(l.source)?.add(l.target);
|
||||||
|
adj.get(l.target)?.add(l.source);
|
||||||
|
});
|
||||||
|
|
||||||
|
const visited = new Set();
|
||||||
|
const components = [];
|
||||||
|
|
||||||
|
for (const node of nodes) {
|
||||||
|
if (visited.has(node.id)) continue;
|
||||||
|
|
||||||
|
const stack = [node.id];
|
||||||
|
const comp = [];
|
||||||
|
|
||||||
|
while (stack.length) {
|
||||||
|
const id = stack.pop();
|
||||||
|
if (visited.has(id)) continue;
|
||||||
|
|
||||||
|
visited.add(id);
|
||||||
|
comp.push(id);
|
||||||
|
|
||||||
|
adj.get(id)?.forEach(nei => {
|
||||||
|
if (!visited.has(nei)) stack.push(nei);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
components.push(comp);
|
||||||
|
}
|
||||||
|
|
||||||
|
return components;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function App2() {
|
||||||
|
const fgRef = useRef();
|
||||||
|
const [selectedNode, setSelectedNode] = useState(null);
|
||||||
|
const [inputDate, setInputDate] = useState("");
|
||||||
|
|
||||||
|
const parsedInputDate = useMemo(() => {
|
||||||
|
const d = new Date(inputDate);
|
||||||
|
return isNaN(d.getTime()) ? null : d;
|
||||||
|
}, [inputDate]);
|
||||||
|
|
||||||
|
const graphData = useMemo(() => {
|
||||||
|
const full = buildGraph(data);
|
||||||
|
const components = getConnectedComponents(full.nodes, full.links);
|
||||||
|
|
||||||
|
const validIds = new Set(
|
||||||
|
components.filter(c => c.length > 1000).flat()
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
nodes: full.nodes.filter(n => validIds.has(n.id)),
|
||||||
|
links: full.links.filter(
|
||||||
|
l => validIds.has(l.source) && validIds.has(l.target)
|
||||||
|
)
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!fgRef.current) return;
|
||||||
|
|
||||||
|
fgRef.current.d3Force("charge", d3.forceManyBody().strength(-30));
|
||||||
|
|
||||||
|
fgRef.current.d3Force(
|
||||||
|
"link",
|
||||||
|
d3.forceLink().distance(140)
|
||||||
|
);
|
||||||
|
|
||||||
|
fgRef.current.d3Force(
|
||||||
|
"collision",
|
||||||
|
d3.forceCollide((node) => {
|
||||||
|
const dims = node.__bckgDimensions;
|
||||||
|
return dims ? Math.max(dims[0], dims[1]) / 2 + 32 : 40;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
fgRef.current.d3ReheatSimulation();
|
||||||
|
}, [graphData]);
|
||||||
|
function isNodeHighlighted(node, referenceDate) {
|
||||||
|
if (!referenceDate || !node.avgDate) return false;
|
||||||
|
const diffMonths = Math.abs(referenceDate - node.avgDate) / (1000 * 60 * 60 * 24 * 30.44);
|
||||||
|
return diffMonths <= 6;
|
||||||
|
}
|
||||||
|
const highlightedNodeIds = useMemo(() => {
|
||||||
|
if (!parsedInputDate) return new Set();
|
||||||
|
|
||||||
|
const set = new Set();
|
||||||
|
|
||||||
|
graphData.nodes.forEach((n) => {
|
||||||
|
if (isNodeHighlighted(n, parsedInputDate)) {
|
||||||
|
set.add(n.id);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return set;
|
||||||
|
}, [graphData.nodes, parsedInputDate]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<ForceGraph2D
|
||||||
|
ref={fgRef}
|
||||||
|
graphData={graphData}
|
||||||
|
nodeLabel={(node) => node.label}
|
||||||
|
nodeAutoColorBy="type"
|
||||||
|
linkColor={(link) => {
|
||||||
|
const sourceId =
|
||||||
|
typeof link.source === "object" ? link.source.id : link.source;
|
||||||
|
|
||||||
|
const targetId =
|
||||||
|
typeof link.target === "object" ? link.target.id : link.target;
|
||||||
|
|
||||||
|
const bothHighlighted =
|
||||||
|
highlightedNodeIds.has(sourceId) &&
|
||||||
|
highlightedNodeIds.has(targetId);
|
||||||
|
|
||||||
|
return bothHighlighted ? "orange" : "white";
|
||||||
|
}}
|
||||||
|
linkWidth={2.5}
|
||||||
|
onNodeClick={(node) => setSelectedNode(node)}
|
||||||
|
nodeCanvasObject={(node, ctx) => {
|
||||||
|
const label = node.label;
|
||||||
|
|
||||||
|
const fontSize = 16 + 32 * Math.min(node.members.length, 5);
|
||||||
|
ctx.font = `${fontSize}px Sans-Serif`;
|
||||||
|
|
||||||
|
const textWidth = ctx.measureText(label).width;
|
||||||
|
const padding = fontSize * 0.6;
|
||||||
|
|
||||||
|
const width = textWidth + padding;
|
||||||
|
const height = fontSize + padding;
|
||||||
|
|
||||||
|
const x = node.x - width / 2;
|
||||||
|
const y = node.y - height / 2;
|
||||||
|
|
||||||
|
const radius = Math.min(10, fontSize * 0.6);
|
||||||
|
|
||||||
|
let isHighlighted = false;
|
||||||
|
|
||||||
|
if (parsedInputDate && node.avgDate) {
|
||||||
|
const diffMonths = monthsDiff(parsedInputDate, node.avgDate);
|
||||||
|
isHighlighted = diffMonths <= 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.fillStyle = node.type.includes("claim")
|
||||||
|
? "blue"
|
||||||
|
: "green"
|
||||||
|
|
||||||
|
if (isHighlighted) {
|
||||||
|
drawRoundedRect(ctx, x, y, width, height, radius);
|
||||||
|
ctx.fill();
|
||||||
|
ctx.strokeStyle = "white";
|
||||||
|
ctx.stroke();
|
||||||
|
|
||||||
|
ctx.textAlign = "center";
|
||||||
|
ctx.textBaseline = "middle";
|
||||||
|
ctx.fillStyle = "white";
|
||||||
|
ctx.fillText(label, node.x, node.y);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
node.__bckgDimensions = [width, height];
|
||||||
|
node.__bckgPos = { x, y };
|
||||||
|
}}
|
||||||
|
nodePointerAreaPaint={(node, color, ctx) => {
|
||||||
|
const dims = node.__bckgDimensions;
|
||||||
|
const pos = node.__bckgPos;
|
||||||
|
if (!dims || !pos) return;
|
||||||
|
|
||||||
|
ctx.fillStyle = color;
|
||||||
|
drawRoundedRect(ctx, pos.x, pos.y, dims[0], dims[1], 6);
|
||||||
|
ctx.fill();
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
position: "absolute",
|
||||||
|
top: "10px",
|
||||||
|
right: "10px",
|
||||||
|
borderRadius: "3px",
|
||||||
|
backgroundColor: "gray",
|
||||||
|
padding: "20px",
|
||||||
|
maxWidth: "500px"
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<h2>FILTERS</h2>
|
||||||
|
|
||||||
|
<label>
|
||||||
|
Reference date:
|
||||||
|
<input
|
||||||
|
type="date"
|
||||||
|
value={inputDate}
|
||||||
|
onChange={(e) => setInputDate(e.target.value)}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<h2>Details</h2>
|
||||||
|
{selectedNode ? (
|
||||||
|
<div>
|
||||||
|
<p><strong>Title:</strong> {selectedNode.label}</p>
|
||||||
|
|
||||||
|
{selectedNode.members && (
|
||||||
|
<div>
|
||||||
|
<p><strong>Members:</strong></p>
|
||||||
|
<ul>
|
||||||
|
{selectedNode.members.map((m) => {
|
||||||
|
const memberData =
|
||||||
|
data.claims.find((c) => c.id === m) ||
|
||||||
|
data.events.find((e) => e.id === m);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<li key={m}>
|
||||||
|
{memberData ? memberData.text : m}
|
||||||
|
</li>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<p>Click a node to see details</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
import { createRoot } from 'react-dom/client';
|
import { createRoot } from 'react-dom/client';
|
||||||
import { StrictMode } from 'react';
|
import { StrictMode } from 'react';
|
||||||
import { App } from './App';
|
import { App2 } from './App2';
|
||||||
|
|
||||||
let container = document.getElementById("app")!;
|
let container = document.getElementById("app")!;
|
||||||
let root = createRoot(container)
|
let root = createRoot(container)
|
||||||
root.render(
|
root.render(
|
||||||
<StrictMode>
|
<StrictMode>
|
||||||
<App />
|
<App2 />
|
||||||
</StrictMode>
|
</StrictMode>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import csv
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
import dateparser
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from sklearn.cluster import AgglomerativeClustering
|
from sklearn.cluster import AgglomerativeClustering
|
||||||
@@ -10,7 +9,7 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
INPUT_CSV = "../../data/dataset.csv"
|
INPUT_CSV = "../../data/dataset.jsonl"
|
||||||
OUTPUT_JSON = "../../data/clustered_output.json"
|
OUTPUT_JSON = "../../data/clustered_output.json"
|
||||||
MODEL_NAME = "all-MiniLM-L6-v2"
|
MODEL_NAME = "all-MiniLM-L6-v2"
|
||||||
SIMILARITY_THRESHOLD = 0.8
|
SIMILARITY_THRESHOLD = 0.8
|
||||||
@@ -19,38 +18,51 @@ def generate_guid():
|
|||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
def read_csv(file_path: str):
|
def read_jsonl(file_path: str):
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
with open(file_path, newline='', encoding='utf-8') as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
reader = csv.reader(f)
|
for line in tqdm(f, desc="Reading JSONL"):
|
||||||
for row in tqdm(reader, desc="Reading CSV"):
|
line = line.strip()
|
||||||
row = [r.strip() for r in row if r.strip()]
|
if not line:
|
||||||
if not row:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
claim = row[0]
|
obj = json.loads(line)
|
||||||
events = row[1:]
|
|
||||||
|
claim_text = obj.get("claim", "").strip()
|
||||||
|
claim_date = obj.get("date", "").strip()
|
||||||
|
events = obj.get("events", [])
|
||||||
|
|
||||||
|
if not claim_text:
|
||||||
|
continue
|
||||||
|
|
||||||
claim_id = generate_guid()
|
claim_id = generate_guid()
|
||||||
|
|
||||||
event_objects = []
|
event_objects = []
|
||||||
for e in events:
|
for e in events:
|
||||||
|
event_text = e.get("Event", "").strip()
|
||||||
|
event_date = e.get("Date", "").strip()
|
||||||
|
if not event_text:
|
||||||
|
continue
|
||||||
|
|
||||||
event_objects.append({
|
event_objects.append({
|
||||||
"id": generate_guid(),
|
"id": generate_guid(),
|
||||||
"text": e
|
"text": event_text,
|
||||||
|
"date": dateparser.parse(event_date)
|
||||||
})
|
})
|
||||||
|
|
||||||
data.append({
|
data.append({
|
||||||
"claim": {
|
"claim": {
|
||||||
"id": claim_id,
|
"id": claim_id,
|
||||||
"text": claim
|
"text": claim_text,
|
||||||
|
"date": dateparser.parse(claim_date)
|
||||||
},
|
},
|
||||||
"events": event_objects
|
"events": event_objects
|
||||||
})
|
})
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def embed_texts(model, texts: List[str], desc="Embedding"):
|
def embed_texts(model, texts: List[str], desc="Embedding"):
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for t in tqdm(texts, desc=desc):
|
for t in tqdm(texts, desc=desc):
|
||||||
@@ -76,10 +88,10 @@ def main():
|
|||||||
print("Loading model...")
|
print("Loading model...")
|
||||||
model = SentenceTransformer(MODEL_NAME)
|
model = SentenceTransformer(MODEL_NAME)
|
||||||
|
|
||||||
data = read_csv(INPUT_CSV)
|
data = read_jsonl(INPUT_CSV)
|
||||||
|
|
||||||
claim_texts, claim_ids = [], []
|
claim_texts, claim_ids, claim_dates = [], [], []
|
||||||
event_texts, event_ids = [], []
|
event_texts, event_ids, event_dates = [], [], []
|
||||||
|
|
||||||
raw_links = [] # temporary for cluster mapping
|
raw_links = [] # temporary for cluster mapping
|
||||||
|
|
||||||
@@ -87,10 +99,12 @@ def main():
|
|||||||
claim = entry["claim"]
|
claim = entry["claim"]
|
||||||
claim_ids.append(claim["id"])
|
claim_ids.append(claim["id"])
|
||||||
claim_texts.append(f"Claim: {claim['text']}")
|
claim_texts.append(f"Claim: {claim['text']}")
|
||||||
|
claim_dates.append(claim['date'])
|
||||||
|
|
||||||
for event in entry["events"]:
|
for event in entry["events"]:
|
||||||
event_ids.append(event["id"])
|
event_ids.append(event["id"])
|
||||||
event_texts.append(f"Event: {event['text']}")
|
event_texts.append(f"Event: {event['text']}")
|
||||||
|
event_dates.append(event['date'])
|
||||||
|
|
||||||
raw_links.append({
|
raw_links.append({
|
||||||
"claim_id": claim["id"],
|
"claim_id": claim["id"],
|
||||||
@@ -148,12 +162,12 @@ def main():
|
|||||||
|
|
||||||
output = {
|
output = {
|
||||||
"claims": [
|
"claims": [
|
||||||
{"id": cid, "text": txt.replace("Claim: ", "")}
|
{"id": cid, "text": txt.replace("Claim: ", ""), "date": str(dat)}
|
||||||
for cid, txt in zip(claim_ids, claim_texts)
|
for cid, txt, dat in zip(claim_ids, claim_texts, claim_dates)
|
||||||
],
|
],
|
||||||
"events": [
|
"events": [
|
||||||
{"id": eid, "text": txt.replace("Event: ", "")}
|
{"id": eid, "text": txt.replace("Event: ", ""), "date": str(dat)}
|
||||||
for eid, txt in zip(event_ids, event_texts)
|
for eid, txt, dat in zip(event_ids, event_texts, event_dates)
|
||||||
],
|
],
|
||||||
"claim_clusters": [
|
"claim_clusters": [
|
||||||
{"cluster_id": k, "members": v}
|
{"cluster_id": k, "members": v}
|
||||||
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
import json
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from openai import OpenAI
|
||||||
|
from tqdm import tqdm
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Load environment and OpenAI client
|
||||||
|
# -------------------------------
|
||||||
|
load_dotenv() # Load environment variables from .env file
|
||||||
|
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# CONFIG
|
||||||
|
# -------------------------------
|
||||||
|
INPUT_FILE = "../../data/clustered_output.json" # Your original JSON
|
||||||
|
OUTPUT_FILE = "../../data/clustered_output_time.json" # Output JSON file
|
||||||
|
OPENAI_MODEL = "gpt-5-nano"
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Load data
|
||||||
|
# -------------------------------
|
||||||
|
with open(INPUT_FILE, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Prepare cluster sets
|
||||||
|
# -------------------------------
|
||||||
|
claim_clusters = {c["cluster_id"] for c in data["claim_clusters"]}
|
||||||
|
event_clusters = {e["cluster_id"] for e in data["event_clusters"]}
|
||||||
|
all_clusters = claim_clusters.union(event_clusters)
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Build graph
|
||||||
|
# -------------------------------
|
||||||
|
graph = defaultdict(set)
|
||||||
|
for link in data.get("cluster_links", []):
|
||||||
|
c_id = link["claim_cluster_id"]
|
||||||
|
e_id = link["event_cluster_id"]
|
||||||
|
graph[c_id].add(e_id)
|
||||||
|
graph[e_id].add(c_id)
|
||||||
|
|
||||||
|
for cid in all_clusters:
|
||||||
|
graph[cid] = graph[cid]
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Find connected components
|
||||||
|
# -------------------------------
|
||||||
|
visited = set()
|
||||||
|
components = []
|
||||||
|
|
||||||
|
for node in graph:
|
||||||
|
if node not in visited:
|
||||||
|
queue = deque([node])
|
||||||
|
component = set()
|
||||||
|
while queue:
|
||||||
|
current = queue.popleft()
|
||||||
|
if current in visited:
|
||||||
|
continue
|
||||||
|
visited.add(current)
|
||||||
|
component.add(current)
|
||||||
|
for neighbor in graph[current]:
|
||||||
|
if neighbor not in visited:
|
||||||
|
queue.append(neighbor)
|
||||||
|
components.append(component)
|
||||||
|
|
||||||
|
# Filter components with size > 8 and < 50
|
||||||
|
large_components = [c for c in components if len(c) > 1000]
|
||||||
|
|
||||||
|
print("Connected components (size > 8):", len(large_components))
|
||||||
|
print("Total clusters in those components:", sum(len(c) for c in large_components))
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Prepare lookups
|
||||||
|
# -------------------------------
|
||||||
|
claim_lookup = {c["id"]: c["text"] for c in data["claims"]}
|
||||||
|
event_lookup = {e["id"]: e["text"] for e in data["events"]}
|
||||||
|
claim_cluster_map = {c["cluster_id"]: c["members"] for c in data["claim_clusters"]}
|
||||||
|
event_cluster_map = {e["cluster_id"]: e["members"] for e in data["event_clusters"]}
|
||||||
|
|
||||||
|
def extract_texts_for_cluster(cluster_id):
|
||||||
|
texts = []
|
||||||
|
if cluster_id in claim_cluster_map:
|
||||||
|
texts.extend([claim_lookup[mid] for mid in claim_cluster_map[cluster_id] if mid in claim_lookup])
|
||||||
|
elif cluster_id in event_cluster_map:
|
||||||
|
texts.extend([event_lookup[mid] for mid in event_cluster_map[cluster_id] if mid in event_lookup])
|
||||||
|
return texts
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# GPT-based title generation
|
||||||
|
# -------------------------------
|
||||||
|
def generate_title(texts):
|
||||||
|
prompt = (
|
||||||
|
"Summarize the following texts into a concise 3 - 6 word title that captures the main theme:\n\n"
|
||||||
|
+ "\n".join(f"- {t}" for t in texts) +
|
||||||
|
"\n\nTitle:"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# response = client.chat.completions.create(
|
||||||
|
# model=OPENAI_MODEL,
|
||||||
|
# messages=[
|
||||||
|
# {"role": "system", "content": "You are a helpful assistant who creates short, meaningful titles."},
|
||||||
|
# {"role": "user", "content": prompt}
|
||||||
|
# ]
|
||||||
|
# )
|
||||||
|
# title = response.choices[0].message.content.strip()
|
||||||
|
# if title.lower().startswith("title:"):
|
||||||
|
# title = title[6:].strip()
|
||||||
|
# return title
|
||||||
|
return "UNNAMED"
|
||||||
|
except Exception as e:
|
||||||
|
print("Error generating title:", e)
|
||||||
|
return "Untitled Cluster"
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Wrapper for parallel execution
|
||||||
|
# -------------------------------
|
||||||
|
def generate_title_for_cluster(cluster_id):
|
||||||
|
texts = extract_texts_for_cluster(cluster_id)
|
||||||
|
title = generate_title(texts)
|
||||||
|
return {"cluster_id": cluster_id, "title": title}
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Generate titles in parallel
|
||||||
|
# -------------------------------
|
||||||
|
clusters_in_large_components = [cid for comp in large_components for cid in comp]
|
||||||
|
output = []
|
||||||
|
|
||||||
|
print("\nGenerating GPT titles for clusters (parallel)...")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||||
|
future_to_cluster = {executor.submit(generate_title_for_cluster, cid): cid for cid in clusters_in_large_components}
|
||||||
|
for future in tqdm(as_completed(future_to_cluster), total=len(clusters_in_large_components), desc="Clusters", ncols=100):
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
output.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
cid = future_to_cluster[future]
|
||||||
|
print(f"Error processing cluster {cid}: {e}")
|
||||||
|
output.append({"cluster_id": cid, "title": "Untitled Cluster"})
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Save JSON
|
||||||
|
# -------------------------------
|
||||||
|
with open(OUTPUT_FILE, "w") as f:
|
||||||
|
json.dump(output, f, indent=2)
|
||||||
|
|
||||||
|
print(f"\nSaved cluster titles to {OUTPUT_FILE}")
|
||||||
@@ -1 +1,2 @@
|
|||||||
sentence_transformers
|
sentence_transformers
|
||||||
|
dateparser
|
||||||
Reference in New Issue
Block a user