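"""Cluster recent risk_news documents with DBSCAN over embedding cosine distances.

Fetches up to 100,000 documents from the last 7 days that have embeddings,
clusters them on a precomputed cosine-distance matrix, and returns the noise
points plus each cluster sorted by similarity to its centroid.
"""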
import asyncio
import json
from dataclasses import dataclass

import numpy
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import cosine_similarity

from cucyuqing.pg import pool, get_cur


@dataclass
class Document:
    id: str
    """The 32-character hex ID used in ES."""

    title: str
    content: str
    similarity: float = 0.0

    def get_text_for_llm(self) -> str:
        """Use only the title for risk analysis.

        Empty titles are already handled at ingestion time: if a title is
        empty when stored, the first 20 characters of the content or the
        first Chinese sentence is used as the title instead.
        """
        return self.title

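
# A minimal sketch of the ingestion-time title fallback described in the
# docstring above, assuming the rule is "first Chinese sentence if one
# exists, else the first 20 characters"; `_fallback_title` is a hypothetical
# illustration, not part of the original module.
def _fallback_title(content: str) -> str:
    import re

    # Everything up to and including the first sentence-ending mark.
    match = re.match(r"[^。!?!?]*[。!?!?]", content)
    if match:
        return match.group(0)
    return content[:20]
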
@dataclass
class DBScanResult:
    noise: list[Document]
    clusters: list[list[Document]]


async def run_dbscan() -> DBScanResult:
    # Fetch data from the PG database.
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE embedding_updated_at IS NOT NULL
              AND time > now() - interval '7 day'
            ORDER BY time DESC
            LIMIT 100000
            ;"""
        )
        rows = await cur.fetchall()

    # Strip dashes from the UUID-style ids to get the 32-character hex form.
    docs: list[Document] = [
        Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
    ]
    # Embeddings are stored as JSON arrays; decode them into numpy vectors.
    embeddings = [numpy.array(json.loads(row[3])) for row in rows]

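    # NOTE: pairwise_distances below builds a dense n x n matrix, so memory
    # is O(n^2); at the full LIMIT of 100,000 rows that is 100,000^2 float64
    # values (~80 GB). The 7-day window is assumed to keep n well below that.
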
    # Compute the pairwise cosine-distance matrix.
    cosine_distances = pairwise_distances(embeddings, metric="cosine")

    # Initialize the DBSCAN model on the precomputed distance matrix.
    dbscan = DBSCAN(
        eps=0.25, min_samples=2, metric="precomputed"
    )  # Adjust eps as needed
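    # With cosine distance, eps=0.25 means two documents count as neighbors
    # when their cosine similarity is at least 0.75, and min_samples=2 lets
    # a single pair of near-duplicate documents already form a cluster.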

    # Run the clustering.
    dbscan.fit(cosine_distances)

    # Cluster label assigned to each sample; -1 marks noise.
    labels: numpy.ndarray = dbscan.labels_

    # Collect the documents belonging to each cluster.
    ret = DBScanResult(noise=[], clusters=[])
    unique_labels = set(labels)
    for label in unique_labels:
        class_member_mask = labels == label
        cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]
        cluster_embeddings = [
            embeddings[i] for i in range(len(labels)) if class_member_mask[i]
        ]

        if label == -1:
            # -1 is the label for noise points.
            ret.noise = cluster_docs
        else:
            # Compute the cluster centroid.
            centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
            # Similarity of every member to the centroid.
            similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
            # Sort members by centroid similarity, most central first.
            sorted_indices = numpy.argsort(similarities)[::-1]
            sorted_cluster_docs = []
            for i in sorted_indices:
                doc = cluster_docs[i]
                doc.similarity = float(similarities[i])
                sorted_cluster_docs.append(doc)
            ret.clusters.append(sorted_cluster_docs)

    return ret

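
# A hedged sketch of the classic k-distance heuristic for tuning eps (see
# the "Adjust eps as needed" comment above); `suggest_eps_curve` is a
# hypothetical helper, not part of the original module.
def suggest_eps_curve(cosine_distances: numpy.ndarray, k: int = 2) -> numpy.ndarray:
    # Distance from each point to its k-th nearest neighbor (column 0 of the
    # row-wise sorted matrix is the point's zero distance to itself).
    kth = numpy.sort(cosine_distances, axis=1)[:, k]
    # The elbow of this sorted curve is a reasonable eps candidate.
    return numpy.sort(kth)
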
async def main():
    await pool.open()
    result = await run_dbscan()
    print(f"Noise documents: {len(result.noise)}")
    for i, cluster in enumerate(result.clusters):
        print("----------------")
        for doc in cluster:
            print(f"Cluster {i} doc: {doc.title} similarity: {doc.similarity}")


if __name__ == "__main__":
    asyncio.run(main())