Compare commits

...

3 Commits

SHA1 Message Date
ad3ef8e504 Add kmeans (archived) 2024-10-10 18:31:47 +08:00
b3fa8edf60 Add cmd.risk-analyze 2024-10-10 18:31:35 +08:00
004552f4d5 Add dbscan 2024-10-10 18:19:50 +08:00
4 changed files with 171 additions and 9 deletions

70 cucyuqing/cmd/kmeans.py Normal file

@@ -0,0 +1,70 @@
import numpy
import asyncio
import json
from sklearn.cluster import KMeans
from cucyuqing.pg import pool, get_cur


async def main():
    # Fetch recent documents and their embeddings from the PG database.
    await pool.open()
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE embedding_updated_at IS NOT NULL
                AND time > now() - interval '14 day'
            ORDER BY time DESC
            LIMIT 1000
            ;"""
        )
        rows = await cur.fetchall()
    docs = [
        {
            "id": row[0],
            "title": row[1],
            "content": row[2],
        }
        for row in rows
    ]
    embeddings = [numpy.array(json.loads(row[3])) for row in rows]
    # Number of clusters to produce.
    num_clusters = 50
    # Initialise the KMeans model.
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    # Run the clustering.
    kmeans.fit(embeddings)
    # Cluster label assigned to each sample.
    labels: list[int] = kmeans.labels_  # type: ignore
    # Distance from each sample to every cluster centre.
    distances = kmeans.transform(embeddings)
    # Find the document closest to the centre of each cluster.
    closest_docs = {}
    for i, label in enumerate(labels):
        distance_to_center = distances[i][label]
        if label not in closest_docs or distance_to_center < closest_docs[label][0]:
            closest_docs[label] = (distance_to_center, docs[i])
    # Print the document closest to each cluster centre.
    for label, (distance, doc) in closest_docs.items():
        print(f"Cluster {label} closest document: {doc['title']} distance: {distance}")
    sorted_samples: list[tuple[int, int]] = sorted(enumerate(labels), key=lambda x: x[1])
    # Pick one cluster at random and list all of its documents.
    random_cluster = numpy.random.choice(num_clusters)
    for i, label in sorted_samples:
        if label == random_cluster:
            print(f"Cluster {label} document: {docs[i]['title']}")


if __name__ == "__main__":
    asyncio.run(main())
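A side note on the selection logic above: KMeans.transform returns, for every sample, its distance to each cluster centre, so indexing row i with that sample's own label yields the distance to its assigned centroid. Below is a minimal standalone sketch of the same pattern on toy 2-D data; everything in it is illustrative and not part of this change.

    import numpy
    from sklearn.cluster import KMeans

    # Toy 2-D points forming two obvious groups.
    points = numpy.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 4.9]])
    titles = ["a", "b", "c", "d"]

    kmeans = KMeans(n_clusters=2, random_state=42, n_init=10).fit(points)
    labels = kmeans.labels_               # cluster id per sample
    distances = kmeans.transform(points)  # shape (n_samples, n_clusters)

    # For each cluster, keep the sample nearest its own centroid.
    closest = {}
    for i, label in enumerate(labels):
        d = distances[i][label]
        if label not in closest or d < closest[label][0]:
            closest[label] = (d, titles[i])

    for label, (d, title) in closest.items():
        print(f"cluster {label}: closest doc {title!r} at distance {d:.3f}")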


@@ -10,19 +10,16 @@ from cucyuqing.utils import print
 from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
 from cucyuqing.pg import pool, get_cur
 from cucyuqing.mysql import mysql
+from cucyuqing.dbscan import run_dbscan

 async def main():
     await pool.open()
-    print(await batch_risk_analyze(["你是老师", "我是初音未来"]))
-
-async def get_docs() -> list[dict]:
-    # [TODO]
-    raise NotImplemented
-    await mysql.execute("""
-    """)
-    return []
+    dbscan_result = await run_dbscan()
+    docs = [cluster[0] for cluster in dbscan_result.clusters]
+    analyze_result = await batch_risk_analyze([doc.title for doc in docs])
+    for result, doc in zip(analyze_result, docs):
+        print(f"Risk: {result} Title: {doc.title}")

 async def batch_risk_analyze(
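For context on the new main body: cluster[0] is a reasonable representative because run_dbscan (added below) returns each cluster sorted most-central-first. A standalone sketch of this wiring with a stubbed analyzer follows; the stub's one-result-per-title return shape is an assumption inferred from the zip() above, and every name in it is illustrative, not from this repo.

    import asyncio
    from dataclasses import dataclass

    @dataclass
    class Doc:
        title: str
        similarity: float = 0.0

    async def fake_batch_risk_analyze(titles: list[str]) -> list[str]:
        # Stand-in for the real LLM call: one result per input title.
        return [f"risk report for {t!r}" for t in titles]

    async def demo() -> None:
        # Each inner list is already sorted most-central-first,
        # so cluster[0] is the representative document.
        clusters = [
            [Doc("flood warning issued", 0.97), Doc("river levels rising", 0.91)],
            [Doc("data breach at vendor", 0.95)],
        ]
        docs = [cluster[0] for cluster in clusters]
        results = await fake_batch_risk_analyze([d.title for d in docs])
        for result, doc in zip(results, docs):
            print(f"Risk: {result} Title: {doc.title}")

    asyncio.run(demo())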

94 cucyuqing/dbscan.py Normal file

@@ -0,0 +1,94 @@
import numpy
import asyncio
import json
from dataclasses import dataclass
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import cosine_similarity
from cucyuqing.pg import pool, get_cur


@dataclass
class Document:
    id: int
    title: str
    content: str
    similarity: float = 0.0


@dataclass
class DBScanResult:
    noise: list[Document]
    clusters: list[list[Document]]


async def run_dbscan() -> DBScanResult:
    # Fetch recent documents and their embeddings from the PG database.
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE embedding_updated_at IS NOT NULL
                AND time > now() - interval '14 day'
            ORDER BY time DESC
            LIMIT 10000
            ;"""
        )
        rows = await cur.fetchall()
    docs: list[Document] = [
        Document(row[0], row[1], row[2])
        for row in rows
    ]
    embeddings = [numpy.array(json.loads(row[3])) for row in rows]
    # Compute the cosine distance matrix.
    cosine_distances = pairwise_distances(embeddings, metric='cosine')
    # Initialise the DBSCAN model.
    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed')  # Adjust eps as needed
    # Run the clustering.
    dbscan.fit(cosine_distances)
    # Cluster label assigned to each sample (-1 marks noise).
    labels: list[int] = dbscan.labels_  # type: ignore
    # Group the documents by cluster.
    ret: DBScanResult = DBScanResult(noise=[], clusters=[])
    unique_labels = set(labels)
    for label in unique_labels:
        class_member_mask = (labels == label)
        cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
        cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
        if label == -1:
            # -1 is the label for noise points
            ret.noise = cluster_docs
        else:
            # Compute the cluster centroid.
            centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
            # Similarity of each member to the centroid.
            similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
            # Sort members by similarity, most central first.
            sorted_indices = numpy.argsort(similarities)[::-1]
            sorted_cluster_docs = []
            for i in sorted_indices:
                doc = cluster_docs[i]
                doc.similarity = similarities[i]
                sorted_cluster_docs.append(doc)
            ret.clusters.append(sorted_cluster_docs)
    return ret


async def main():
    await pool.open()
    result = await run_dbscan()
    print(f"Noise documents: {len(result.noise)}")
    for i, cluster in enumerate(result.clusters):
        print('----------------')
        for doc in cluster:
            print(f"Cluster {i} document: {doc.title} similarity: {doc.similarity}")


if __name__ == '__main__':
    asyncio.run(main())
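One detail worth calling out: DBSCAN here receives a precomputed cosine distance matrix, and scikit-learn defines cosine distance as 1 - cosine similarity, so eps=0.25 means two documents count as neighbours only when their embeddings' cosine similarity is at least 0.75. A quick standalone check of that identity on toy vectors (illustrative only, not part of this change):

    import numpy
    from sklearn.metrics import pairwise_distances
    from sklearn.metrics.pairwise import cosine_similarity

    a = numpy.array([[1.0, 0.0], [0.8, 0.6]])  # toy unit vectors
    dist = pairwise_distances(a, metric="cosine")
    sim = cosine_similarity(a)

    # cosine_distance == 1 - cosine_similarity (up to floating-point noise)
    assert numpy.allclose(dist, 1.0 - sim)
    print(dist[0, 1], 1.0 - sim[0, 1])  # both print 0.2 for these vectors

Tightening eps toward 0 therefore keeps only near-duplicate stories in a cluster, while loosening it merges broader topics.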


@@ -8,3 +8,4 @@ databases[aiomysql]
openai
tokenizers
tqdm
scikit-learn