Files
cucyuqing/cucyuqing/cmd/kmeans.py
2024-10-10 18:31:47 +08:00

70 lines
2.0 KiB
Python

import numpy
import asyncio
import json
from sklearn.cluster import KMeans
from cucyuqing.pg import pool, get_cur
async def _fetch_recent_docs() -> tuple[list[dict], list[numpy.ndarray]]:
    """Fetch up to 1000 embedded risk_news rows from the last 14 days.

    Rows are returned newest first (the query orders by ``time desc``).

    Returns:
        A ``(docs, embeddings)`` pair aligned by index: ``docs[i]`` is a dict
        with ``id``/``title``/``content`` and ``embeddings[i]`` is the parsed
        embedding vector of the same row.
    """
    async with get_cur() as cur:
        await cur.execute(
            """
SELECT id, title, content, embedding
FROM risk_news
WHERE NOT embedding_updated_at IS NULL
AND time > now() - interval '14 day'
ORDER BY time desc
LIMIT 1000
;"""
        )
        rows = await cur.fetchall()
    docs = [{"id": row[0], "title": row[1], "content": row[2]} for row in rows]
    # NOTE(review): assumes the embedding column arrives as JSON-compatible
    # text (e.g. a "[0.1, 0.2, ...]" literal) — confirm against the schema.
    embeddings = [numpy.array(json.loads(row[3])) for row in rows]
    return docs, embeddings


def _closest_doc_per_cluster(labels, distances, docs) -> dict:
    """Return the document nearest to each non-empty cluster's center.

    Args:
        labels: cluster label assigned to each sample.
        distances: per-sample distances to every cluster center, so
            ``distances[i][c]`` is sample ``i``'s distance to center ``c``.
        docs: documents aligned index-for-index with ``labels``.

    Returns:
        ``{label: (distance, doc)}`` in order of each label's first
        appearance; on ties the earliest sample wins.
    """
    closest: dict = {}
    for i, label in enumerate(labels):
        distance_to_center = distances[i][label]
        if label not in closest or distance_to_center < closest[label][0]:
            closest[label] = (distance_to_center, docs[i])
    return closest


async def main(num_clusters: int = 50):
    """Cluster recent risk_news embeddings with k-means and print a summary.

    Prints the document nearest to each cluster center, then the titles of
    every document in one randomly chosen cluster.

    Args:
        num_clusters: requested number of clusters; capped at the number of
            fetched documents because KMeans needs n_samples >= n_clusters.
    """
    # Fetch data from the PG database.
    await pool.open()
    try:
        docs, embeddings = await _fetch_recent_docs()
    finally:
        # Release DB connections before the CPU-bound clustering work.
        # NOTE(review): assumes ``pool`` has an awaitable ``close()`` mirroring
        # ``open()`` (as psycopg_pool's AsyncConnectionPool does) — confirm.
        await pool.close()

    if not docs:
        print("没有可用于聚类的文档")
        return

    # KMeans raises when n_samples < n_clusters, so cap k at the sample count.
    k = min(num_clusters, len(docs))
    kmeans = KMeans(n_clusters=k, random_state=42)
    # Fit the model and get each sample's distance to EVERY cluster center
    # (shape: n_samples x k), not just to its own center.
    distances = kmeans.fit_transform(embeddings)
    labels: numpy.ndarray = kmeans.labels_  # type: ignore

    # Print the document closest to each cluster's center.
    for label, (distance, doc) in _closest_doc_per_cluster(
        labels, distances, docs
    ).items():
        print(f"聚类 {label} 最近的文档: {doc['title']} 距离: {distance}")

    # Show every member of one randomly chosen cluster, in query order.
    random_cluster = numpy.random.choice(k)
    for i, label in enumerate(labels):
        if label == random_cluster:
            print(f"聚类 {label} 文档: {docs[i]['title']}")
if __name__ == "__main__":
    # Executed as a script: build the coroutine and drive it to completion
    # on a fresh event loop.
    entry = main()
    asyncio.run(entry)