Add dbscan

2024-10-10 18:19:50 +08:00
parent a49db1b71b
commit 004552f4d5
2 changed files with 95 additions and 0 deletions

cucyuqing/dbscan.py Normal file

@@ -0,0 +1,94 @@
import asyncio
import json
from dataclasses import dataclass

import numpy
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import cosine_similarity

from cucyuqing.pg import pool, get_cur


@dataclass
class Document:
    id: int
    title: str
    content: str
    similarity: float = 0.0


@dataclass
class DBScanResult:
    noise: list[Document]
    clusters: list[list[Document]]


async def run_dbscan() -> DBScanResult:
    # Fetch recent documents and their embeddings from the PG database
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE embedding_updated_at IS NOT NULL
              AND time > now() - interval '14 day'
            ORDER BY time DESC
            LIMIT 10000
            ;"""
        )
        rows = await cur.fetchall()

    docs: list[Document] = [Document(row[0], row[1], row[2]) for row in rows]
    # Embeddings are stored as JSON arrays; decode each into a numpy vector
    embeddings = [numpy.array(json.loads(row[3])) for row in rows]
    # Compute the pairwise cosine-distance matrix
    cosine_distances = pairwise_distances(embeddings, metric='cosine')

    # Cluster on the precomputed distances
    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed')  # Adjust eps as needed
    dbscan.fit(cosine_distances)

    # Cluster label assigned to each sample; -1 marks noise
    labels: numpy.ndarray = dbscan.labels_

    # Group the documents by cluster label
    ret = DBScanResult(noise=[], clusters=[])
    unique_labels = set(labels)
    for label in unique_labels:
        class_member_mask = labels == label
        cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]
        cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]]
        if label == -1:
            # -1 is the label DBSCAN assigns to noise points
            ret.noise = cluster_docs
        else:
            # Centroid of the cluster in embedding space
            centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
            # Cosine similarity of each member document to the centroid
            similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
            # Sort documents by similarity, most central first
            sorted_indices = numpy.argsort(similarities)[::-1]
            sorted_cluster_docs = []
            for i in sorted_indices:
                doc = cluster_docs[i]
                doc.similarity = float(similarities[i])
                sorted_cluster_docs.append(doc)
            ret.clusters.append(sorted_cluster_docs)
    return ret


async def main():
    await pool.open()
    result = await run_dbscan()
    print(f"Noise documents: {len(result.noise)}")
    for i, cluster in enumerate(result.clusters):
        print('----------------')
        for doc in cluster:
            print(f"Cluster {i} doc: {doc.title} similarity: {doc.similarity}")


if __name__ == '__main__':
    asyncio.run(main())
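A note on the design choice, not part of the commit: sklearn's DBSCAN also accepts metric='cosine' directly, so the hand-built distance matrix above is optional; a minimal equivalent sketch:

    # Equivalent clustering without materializing the distance matrix yourself
    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='cosine')
    dbscan.fit(numpy.vstack(embeddings))  # stack vectors into an (n, d) array

Precomputing does make the same matrix reusable elsewhere (e.g. for inspecting near-duplicate pairs), which is a reasonable trade-off at 10000 rows.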

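The eps=0.25 is flagged "adjust as needed". A common heuristic, assumed here rather than taken from this commit, is the k-distance elbow: sort each point's distance to its min_samples-th nearest neighbor and read eps off near the knee. A self-contained sketch (k_distance_curve is a hypothetical helper, not in the codebase):

    import numpy
    from sklearn.neighbors import NearestNeighbors

    def k_distance_curve(embeddings: numpy.ndarray, k: int = 2) -> numpy.ndarray:
        # Distance from each point to its k-th nearest neighbor; note the
        # nearest neighbor of a training point is itself, at distance 0.
        nn = NearestNeighbors(n_neighbors=k, metric='cosine').fit(embeddings)
        distances, _ = nn.kneighbors(embeddings)
        # Sorted ascending; eps is usually read off near the curve's elbow
        return numpy.sort(distances[:, -1])

Plotting the returned curve and picking the value where it bends sharply upward gives a data-driven starting point for eps instead of a fixed 0.25.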

@@ -8,3 +8,4 @@ databases[aiomysql]
openai
tokenizers
tqdm
scikit-learn