Files
cucyuqing/cucyuqing/dbscan.py

112 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing_extensions import Doc
import numpy
import asyncio
import json
from dataclasses import dataclass
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances
from cucyuqing.pg import pool, get_cur
@dataclass
class Document:
    """A single news document fetched from the risk_news table."""

    # 32-character hex ID as stored in ES (a UUID with the dashes stripped).
    id: str
    title: str
    content: str
    # Cosine similarity to the cluster centroid; filled in after clustering.
    similarity: float = 0.0

    def get_text_for_llm(self) -> str:
        """Return the text used for risk analysis — the title only.

        Empty titles are already handled at ingestion time: when a title
        is missing, the first 20 characters of the content (or the first
        Chinese sentence) are stored as the title instead.
        """
        return self.title
@dataclass
class DBScanResult:
    """Outcome of one DBSCAN run over the recent documents."""

    # Documents that DBSCAN labelled -1 (no cluster).
    noise: list[Document]
    # One inner list per cluster, sorted by similarity to the centroid.
    clusters: list[list[Document]]
from sklearn.metrics.pairwise import cosine_similarity
async def run_dbscan() -> DBScanResult:
    """Cluster the last 7 days of risk_news documents with DBSCAN.

    Fetches documents (with precomputed embeddings) from Postgres,
    builds a precomputed cosine-distance matrix, runs DBSCAN, and
    returns a DBScanResult whose clusters are each sorted by cosine
    similarity to the cluster centroid, descending. Documents labelled
    -1 by DBSCAN are collected as noise.

    Returns an empty result (no noise, no clusters) when no rows match,
    instead of crashing inside pairwise_distances on an empty matrix.
    """
    # Only rows whose embedding has actually been computed are usable.
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE NOT embedding_updated_at IS NULL
            AND time > now() - interval '7 day'
            ORDER BY time desc
            LIMIT 100000
            ;"""
        )
        rows = await cur.fetchall()

    ret = DBScanResult(noise=[], clusters=[])
    if not rows:
        # Nothing to cluster; avoid feeding an empty matrix to sklearn.
        return ret

    # ES-style 32-char hex IDs: strip the dashes from the stored UUIDs.
    docs: list[Document] = [
        Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
    ]
    # Embeddings are stored as JSON arrays; stack into one (n, dim) matrix.
    embeddings = numpy.array([json.loads(row[3]) for row in rows])

    # DBSCAN with a precomputed cosine-distance matrix.
    cosine_distances = pairwise_distances(embeddings, metric="cosine")
    dbscan = DBSCAN(
        eps=0.25, min_samples=2, metric="precomputed"
    )  # Adjust eps as needed
    dbscan.fit(cosine_distances)
    # labels_ is an ndarray of cluster ids; -1 marks noise points.
    labels = dbscan.labels_

    for label in set(labels):
        class_member_mask = labels == label
        cluster_docs = [doc for doc, m in zip(docs, class_member_mask) if m]
        if label == -1:
            # -1 is the label for noise points
            ret.noise = cluster_docs
            continue
        cluster_embeddings = embeddings[class_member_mask]
        # Rank cluster members by similarity to the cluster centroid.
        centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
        similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
        sorted_indices = numpy.argsort(similarities)[::-1]
        sorted_cluster_docs: list[Document] = []
        for i in sorted_indices:
            doc = cluster_docs[i]
            # Coerce to a plain float so downstream (e.g. JSON) stays clean.
            doc.similarity = float(similarities[i])
            sorted_cluster_docs.append(doc)
        ret.clusters.append(sorted_cluster_docs)
    return ret
async def main():
    """Open the connection pool, run the clustering, and print a summary."""
    await pool.open()
    result = await run_dbscan()
    # Report noise count first, then each cluster's members in rank order.
    print(f"噪声文档: {len(result.noise)}")
    for i, cluster in enumerate(result.clusters):
        print("----------------")
        for doc in cluster:
            print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}")
if __name__ == "__main__":
asyncio.run(main())