添加 kmeans (存档)
This commit is contained in:
70
cucyuqing/cmd/kmeans.py
Normal file
70
cucyuqing/cmd/kmeans.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import numpy
|
||||
import asyncio
|
||||
import json
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
from cucyuqing.pg import pool, get_cur
|
||||
|
||||
|
||||
async def main():
|
||||
# 从 PG 数据库获取数据
|
||||
await pool.open()
|
||||
async with get_cur() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT id, title, content, embedding
|
||||
FROM risk_news
|
||||
WHERE NOT embedding_updated_at IS NULL
|
||||
AND time > now() - interval '14 day'
|
||||
ORDER BY time desc
|
||||
LIMIT 1000
|
||||
;"""
|
||||
)
|
||||
rows = await cur.fetchall()
|
||||
docs = [
|
||||
{
|
||||
"id": row[0],
|
||||
"title": row[1],
|
||||
"content": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
embeddings = [numpy.array(json.loads(row[3])) for row in rows]
|
||||
|
||||
# 设置聚类的数量
|
||||
num_clusters = 50
|
||||
|
||||
# 初始化KMeans模型
|
||||
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
|
||||
|
||||
# 进行聚类
|
||||
kmeans.fit(embeddings)
|
||||
|
||||
# 获取每个样本的聚类标签
|
||||
labels: list[int] = kmeans.labels_ # type: ignore
|
||||
|
||||
# 计算每个样本到其聚类中心的距离
|
||||
distances = kmeans.transform(embeddings)
|
||||
# 找到每个聚类中距离中心最近的文档
|
||||
closest_docs = {}
|
||||
for i, label in enumerate(labels):
|
||||
distance_to_center = distances[i][label]
|
||||
if label not in closest_docs or distance_to_center < closest_docs[label][0]:
|
||||
closest_docs[label] = (distance_to_center, docs[i])
|
||||
|
||||
# 输出每个聚类中距离中心最近的文档
|
||||
for label, (distance, doc) in closest_docs.items():
|
||||
print(f"聚类 {label} 最近的文档: {doc['title']} 距离: {distance}")
|
||||
|
||||
sorted_samples: list[tuple[int, int]] = sorted(enumerate(labels), key=lambda x: x[1])
|
||||
|
||||
# 随机选择一个聚类
|
||||
random_cluster = numpy.random.choice(num_clusters)
|
||||
for i, label in sorted_samples:
|
||||
if label == random_cluster:
|
||||
print(f"聚类 {label} 文档: {docs[i]['title']}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user