diff --git a/cucyuqing/cmd/kmeans.py b/cucyuqing/cmd/kmeans.py new file mode 100644 index 0000000..bffbd94 --- /dev/null +++ b/cucyuqing/cmd/kmeans.py @@ -0,0 +1,70 @@ +import numpy +import asyncio +import json +from sklearn.cluster import KMeans + +from cucyuqing.pg import pool, get_cur + + +async def main(): + # 从 PG 数据库获取数据 + await pool.open() + async with get_cur() as cur: + await cur.execute( + """ + SELECT id, title, content, embedding + FROM risk_news + WHERE NOT embedding_updated_at IS NULL + AND time > now() - interval '14 day' + ORDER BY time desc + LIMIT 1000 + ;""" + ) + rows = await cur.fetchall() + docs = [ + { + "id": row[0], + "title": row[1], + "content": row[2], + } + for row in rows + ] + embeddings = [numpy.array(json.loads(row[3])) for row in rows] + + # 设置聚类的数量 + num_clusters = 50 + + # 初始化KMeans模型 + kmeans = KMeans(n_clusters=num_clusters, random_state=42) + + # 进行聚类 + kmeans.fit(embeddings) + + # 获取每个样本的聚类标签 + labels: list[int] = kmeans.labels_ # type: ignore + + # 计算每个样本到其聚类中心的距离 + distances = kmeans.transform(embeddings) + # 找到每个聚类中距离中心最近的文档 + closest_docs = {} + for i, label in enumerate(labels): + distance_to_center = distances[i][label] + if label not in closest_docs or distance_to_center < closest_docs[label][0]: + closest_docs[label] = (distance_to_center, docs[i]) + + # 输出每个聚类中距离中心最近的文档 + for label, (distance, doc) in closest_docs.items(): + print(f"聚类 {label} 最近的文档: {doc['title']} 距离: {distance}") + + sorted_samples: list[tuple[int, int]] = sorted(enumerate(labels), key=lambda x: x[1]) + + # 随机选择一个聚类 + random_cluster = numpy.random.choice(num_clusters) + for i, label in sorted_samples: + if label == random_cluster: + print(f"聚类 {label} 文档: {docs[i]['title']}") + + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file