diff --git a/cucyuqing/dbscan.py b/cucyuqing/dbscan.py
new file mode 100644
index 0000000..a02cf34
--- /dev/null
+++ b/cucyuqing/dbscan.py
@@ -0,0 +1,100 @@
+import asyncio
+import json
+from dataclasses import dataclass
+
+import numpy
+from sklearn.cluster import DBSCAN
+from sklearn.metrics import pairwise_distances
+from sklearn.metrics.pairwise import cosine_similarity
+
+from cucyuqing.pg import pool, get_cur
+
+
+@dataclass
+class Document:
+    """A news document, optionally scored by similarity to its cluster centroid."""
+    id: int
+    title: str
+    content: str
+    similarity: float = 0.0
+
+
+@dataclass
+class DBScanResult:
+    """Clustering output: noise points plus clusters sorted by centroid similarity."""
+    noise: list[Document]
+    clusters: list[list[Document]]
+
+
+async def run_dbscan() -> DBScanResult:
+    """Cluster recent news embeddings with DBSCAN over cosine distance.
+
+    Returns a DBScanResult whose clusters are each sorted with the most
+    centroid-similar (most representative) document first.
+    """
+    # Fetch recent documents with embeddings from the PG database.
+    async with get_cur() as cur:
+        await cur.execute(
+            """
+            SELECT id, title, content, embedding
+            FROM risk_news
+            WHERE NOT embedding_updated_at IS NULL
+            AND time > now() - interval '14 day'
+            ORDER BY time desc
+            LIMIT 10000
+            ;"""
+        )
+        rows = await cur.fetchall()
+    docs: list[Document] = [Document(row[0], row[1], row[2]) for row in rows]
+    # Guard: pairwise_distances/DBSCAN raise ValueError on an empty matrix.
+    if not docs:
+        return DBScanResult(noise=[], clusters=[])
+    embeddings = numpy.array([json.loads(row[3]) for row in rows])
+
+    # Precompute the full cosine distance matrix for DBSCAN.
+    cosine_distances = pairwise_distances(embeddings, metric='cosine')
+
+    # eps tuned for this embedding space; adjust as needed.
+    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed')
+    dbscan.fit(cosine_distances)
+
+    # labels_ is a numpy integer array; -1 marks noise points.
+    labels: numpy.ndarray = dbscan.labels_
+
+    ret = DBScanResult(noise=[], clusters=[])
+    for label in set(labels):
+        member_indices = numpy.where(labels == label)[0]
+        cluster_docs = [docs[i] for i in member_indices]
+
+        if label == -1:
+            # -1 is the label DBSCAN assigns to noise points.
+            ret.noise = cluster_docs
+            continue
+
+        # Rank cluster members by cosine similarity to the cluster centroid.
+        cluster_embeddings = embeddings[member_indices]
+        centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
+        similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
+        sorted_cluster_docs = []
+        for i in numpy.argsort(similarities)[::-1]:
+            doc = cluster_docs[i]
+            doc.similarity = float(similarities[i])
+            sorted_cluster_docs.append(doc)
+        ret.clusters.append(sorted_cluster_docs)
+
+    return ret
+
+
+async def main():
+    """Open the connection pool, run clustering, and print a summary."""
+    await pool.open()
+    result = await run_dbscan()
+    print(f"噪声文档: {len(result.noise)}")
+    for i, cluster in enumerate(result.clusters):
+        print('----------------')
+        for doc in cluster:
+            print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}")
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/requirements.txt b/requirements.txt
index 3b56848..82ef2a9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ databases[aiomysql]
 openai
 tokenizers
 tqdm
+scikit-learn