Compare commits
3 Commits: a49db1b71b ... ad3ef8e504

| Author | SHA1 | Date |
|---|---|---|
| | ad3ef8e504 | |
| | b3fa8edf60 | |
| | 004552f4d5 | |

cucyuqing/cmd/kmeans.py (new file, +70 lines)
@@ -0,0 +1,70 @@
import numpy
import asyncio
import json
from sklearn.cluster import KMeans

from cucyuqing.pg import pool, get_cur


async def main():
    # Fetch data from the PG database
    await pool.open()
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE NOT embedding_updated_at IS NULL
            AND time > now() - interval '14 day'
            ORDER BY time desc
            LIMIT 1000
            ;"""
        )
        rows = await cur.fetchall()
    docs = [
        {
            "id": row[0],
            "title": row[1],
            "content": row[2],
        }
        for row in rows
    ]
    embeddings = [numpy.array(json.loads(row[3])) for row in rows]

    # Set the number of clusters
    num_clusters = 50

    # Initialize the KMeans model
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)

    # Run the clustering
    kmeans.fit(embeddings)

    # Get the cluster label of each sample
    labels: list[int] = kmeans.labels_  # type: ignore

    # Compute each sample's distance to its cluster center
    distances = kmeans.transform(embeddings)
    # Find the document closest to the center of each cluster
    closest_docs = {}
    for i, label in enumerate(labels):
        distance_to_center = distances[i][label]
        if label not in closest_docs or distance_to_center < closest_docs[label][0]:
            closest_docs[label] = (distance_to_center, docs[i])

    # Print the document closest to the center of each cluster
    for label, (distance, doc) in closest_docs.items():
        print(f"Cluster {label} closest document: {doc['title']} distance: {distance}")

    sorted_samples: list[tuple[int, int]] = sorted(enumerate(labels), key=lambda x: x[1])

    # Pick one cluster at random
    random_cluster = numpy.random.choice(num_clusters)
    for i, label in sorted_samples:
        if label == random_cluster:
            print(f"Cluster {label} document: {docs[i]['title']}")




if __name__ == "__main__":
    asyncio.run(main())
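
For reference, a minimal self-contained sketch of the selection logic above: KMeans.transform() returns each sample's distance to every cluster center, and indexing that row with the sample's own label gives its distance to its assigned center, which is what the script minimizes per cluster. The vectors and titles below are synthetic stand-ins for the risk_news rows; none of this is part of the commit.

import numpy
from sklearn.cluster import KMeans

# Synthetic stand-ins for the embeddings and docs loaded from risk_news.
rng = numpy.random.default_rng(0)
embeddings = rng.normal(size=(20, 8))
docs = [{"id": i, "title": f"doc-{i}"} for i in range(20)]

kmeans = KMeans(n_clusters=4, random_state=42)
kmeans.fit(embeddings)

# One row per sample, one column per cluster center; picking the column that
# matches the sample's own label gives its distance to its assigned center.
distances = kmeans.transform(embeddings)
closest = {}
for i, label in enumerate(kmeans.labels_):
    d = distances[i][label]
    if label not in closest or d < closest[label][0]:
        closest[label] = (d, docs[i])

for label, (d, doc) in sorted(closest.items()):
    print(f"cluster {label}: representative {doc['title']} (distance {d:.3f})")
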
@@ -10,19 +10,16 @@ from cucyuqing.utils import print
 from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
 from cucyuqing.pg import pool, get_cur
 from cucyuqing.mysql import mysql
+from cucyuqing.dbscan import run_dbscan


 async def main():
     await pool.open()
-    print(await batch_risk_analyze(["You are a teacher", "I am Hatsune Miku"]))
-
-
-async def get_docs() -> list[dict]:
-    # [TODO]
-    raise NotImplemented
-    await mysql.execute("""
-    """)
-    return []
+    dbscan_result = await run_dbscan()
+    docs = [cluster[0] for cluster in dbscan_result.clusters]
+    analyze_rusult = await batch_risk_analyze([doc.title for doc in docs])
+    for result, doc in zip(analyze_rusult, docs):
+        print(f"Risk: {result} Title: {doc.title}")


 async def batch_risk_analyze(
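
The hunk above replaces the placeholder get_docs() with a pipeline: run the DBSCAN clustering, take the first (most central) document of each cluster, and send those titles to batch_risk_analyze. The signature of batch_risk_analyze is cut off in this diff, so the stub below is only an assumption used to show how the pieces compose; it is not the project's implementation.

import asyncio

from cucyuqing.dbscan import run_dbscan
from cucyuqing.pg import pool


async def batch_risk_analyze(titles: list[str]) -> list[str]:
    # Hypothetical stub: the real function is defined in the edited module
    # and is not fully visible in this hunk.
    return ["(not analyzed)"] * len(titles)


async def main() -> None:
    await pool.open()
    dbscan_result = await run_dbscan()
    # Clusters are sorted by similarity to their centroid, so cluster[0]
    # is the most representative document of each cluster.
    docs = [cluster[0] for cluster in dbscan_result.clusters]
    results = await batch_risk_analyze([doc.title for doc in docs])
    for result, doc in zip(results, docs):
        print(f"Risk: {result} Title: {doc.title}")


if __name__ == "__main__":
    asyncio.run(main())
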

cucyuqing/dbscan.py (new file, +94 lines)
@@ -0,0 +1,94 @@
from typing_extensions import Doc
import numpy
import asyncio
import json
from dataclasses import dataclass
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

from cucyuqing.pg import pool, get_cur

@dataclass
class Document:
    id: int
    title: str
    content: str
    similarity: float = 0.0

@dataclass
class DBScanResult:
    noise: list[Document]
    clusters: list[list[Document]]

from sklearn.metrics.pairwise import cosine_similarity

async def run_dbscan() -> DBScanResult:
    # Fetch data from the PG database
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE NOT embedding_updated_at IS NULL
            AND time > now() - interval '14 day'
            ORDER BY time desc
            LIMIT 10000
            ;"""
        )
        rows = await cur.fetchall()
    docs: list[Document] = [
        Document(row[0], row[1], row[2])
        for row in rows
    ]
    embeddings = [numpy.array(json.loads(row[3])) for row in rows]

    # Compute the cosine distance matrix
    cosine_distances = pairwise_distances(embeddings, metric='cosine')

    # Initialize the DBSCAN model
    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed')  # Adjust eps as needed

    # Run the clustering
    dbscan.fit(cosine_distances)

    # Get the cluster label of each sample
    labels: list[int] = dbscan.labels_  # type: ignore

    # Collect the documents of each cluster
    ret: DBScanResult = DBScanResult(noise=[], clusters=[])
    unique_labels = set(labels)
    for label in unique_labels:
        class_member_mask = (labels == label)
        cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
        cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore

        if label == -1:
            # -1 is the label for noise points
            ret.noise = cluster_docs
        else:
            # Compute the centroid
            centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
            # Compute similarity to the centroid
            similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
            # Sort by similarity
            sorted_indices = numpy.argsort(similarities)[::-1]
            sorted_cluster_docs = []
            for i in sorted_indices:
                doc = cluster_docs[i]
                doc.similarity = similarities[i]
                sorted_cluster_docs.append(doc)
            ret.clusters.append(sorted_cluster_docs)

    return ret

async def main():
    await pool.open()
    result = await run_dbscan()
    print(f"Noise documents: {len(result.noise)}")
    for i, cluster in enumerate(result.clusters):
        print('----------------')
        for doc in cluster:
            print(f"Cluster {i} document: {doc.title} similarity: {doc.similarity}")

if __name__ == '__main__':
    asyncio.run(main())
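
A minimal illustration of the precomputed-distance pattern used in run_dbscan: with metric='precomputed', DBSCAN consumes the cosine distance matrix directly, so eps=0.25 means two documents are neighbours only if their cosine distance (1 minus cosine similarity) is at most 0.25, and label -1 marks noise. The vectors below are synthetic and only demonstrate the mechanics.

import numpy
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

rng = numpy.random.default_rng(0)
# Two tight directional bundles plus one stray vector.
base_a, base_b = rng.normal(size=8), rng.normal(size=8)
embeddings = numpy.vstack(
    [base_a + rng.normal(scale=0.05, size=8) for _ in range(5)]
    + [base_b + rng.normal(scale=0.05, size=8) for _ in range(5)]
    + [rng.normal(size=8)]
)

# Precompute pairwise cosine distances once, then cluster on that matrix.
cosine_distances = pairwise_distances(embeddings, metric="cosine")
labels = DBSCAN(eps=0.25, min_samples=2, metric="precomputed").fit_predict(cosine_distances)
print(labels)  # expect two clusters (0 and 1), with the stray vector usually labelled -1
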
@@ -8,3 +8,4 @@ databases[aiomysql]
 openai
 tokenizers
 tqdm
+scikit-learn