Compare commits

dd41d6fa5f...master (26 commits)

| SHA1 |
|---|
| 2d29d8631c |
| 3dc47712e4 |
| 9dc23b714c |
| 1545a85b09 |
| 115d95bbef |
| 34ad16ff02 |
| 8a6db8f8f2 |
| cae3877048 |
| 3d36dcadf6 |
| a051674f2e |
| 47ae4dc8d5 |
| d52f37e5f0 |
| ad3ef8e504 |
| b3fa8edf60 |
| 004552f4d5 |
| a49db1b71b |
| 16f4469365 |
| 341435e603 |
| fff2a32d7e |
| 43ddd8665c |
| 1e1f92461a |
| 71fd688227 |
| 0b658dee88 |
| 92e3699cd8 |
| 4b5e59f35d |
| 7395d98ce3 |
README.md (96 lines changed)

# 中传三期大模型舆情监测项目

## LLM Risk Analysis Feature

Under *风险预警* (Risk Alerts) > *我的设置* (My Settings) you can configure the *filter keywords* and the *LLM prompts*.

A brief outline of the data processing:

1. Using the configured *filter keywords*, news is filtered from the 牛媒 (Niumedia) public-opinion data platform and imported into the database.

   The import job runs once per hour, and each run imports one hour's worth of data. To leave enough processing time for items arriving near the end of each hour, processing is delayed by one hour, so the overall import latency is within 1-2 hours. (A small sketch of this window computation follows the list below.)

   The time referred to here is the time an item entered the Niumedia platform, not its publication time, which means items from two hours ago or earlier may still be back-filled. This is especially common for target sites that the Niumedia crawler monitors less often than every 2 hours. The downstream processing logic already accounts for this.

2. Text feature extraction.

   Feature extraction runs every ten minutes and processes the news items whose text-embedding column is still empty.

3. Cluster analysis.

   DBSCAN is run over the text embeddings to cluster the news, discarding noise items (roughly half of the input) and using the item closest to each cluster's center as the representative for the subsequent analysis. Each run produces roughly 80-400 clusters. The clustering input is all news from the last 7 days.

4. LLM risk judgment.

   Using each risk type's LLM prompt, every cluster representative is sent to the LLM for a risk judgment (see the judgment sketch after this list). A prompt looks like:

   `你是一个新闻风险分析器,分析以下新闻是否包含学术不端风险。你只能回答是或否` ("You are a news risk analyzer. Determine whether the following news involves academic-misconduct risk. Answer only 是 (yes) or 否 (no).")

   The program determines the LLM's verdict by checking whether the returned text contains the keyword "是" or "否".

5. Writing the analysis results back.

   For every news item found to carry **any risk**, the program updates (overwrites) its risk-category field. After waiting about a minute for Elasticsearch to refresh its index, these categories can be filtered on the *风险监控* (Risk Monitoring) page of the web frontend.

   For older data: items that already have risk-category information but were not selected as a cluster representative in the current round are **not** updated.
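As a minimal illustration of the step 1 schedule (hourly runs with a one-hour delay), the helper below computes the ingest-time window a given run would cover. The function name and the exact rounding are illustrative assumptions, not the importer's actual code.

```python
import datetime


def current_import_window(now: datetime.datetime) -> tuple[datetime.datetime, datetime.datetime]:
    """Ingest-time window handled by the hourly run executing at `now` (1-hour delay)."""
    # Truncate to the top of the hour, then step back one hour to leave
    # time for late-arriving items before the window is processed.
    end = now.replace(minute=0, second=0, microsecond=0) - datetime.timedelta(hours=1)
    start = end - datetime.timedelta(hours=1)
    return start, end


# A run at 10:00 covers items ingested between 08:00 and 09:00,
# which is why end-to-end latency stays within 1-2 hours.
print(current_import_window(datetime.datetime(2024, 10, 21, 10, 0)))
```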
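The sketch below illustrates the step 4/5 judgment for a single cluster representative, assuming an OpenAI-compatible chat endpoint such as the one configured by `OPENAI_RISK_LLM_BASE_URL`. The names `judge_risk` and `RISK_PROMPT`, and the use of `gpt-4o-mini`, are illustrative rather than the project's actual identifiers.

```python
import openai

client = openai.AsyncOpenAI(api_key="key", base_url="https://<address>/v1")

RISK_PROMPT = "你是一个新闻风险分析器,分析以下新闻是否包含学术不端风险。你只能回答是或否"


async def judge_risk(title: str, model: str = "gpt-4o-mini") -> bool:
    """Send one cluster representative to the LLM and parse the 是/否 answer."""
    resp = await client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": RISK_PROMPT},  # the risk type's prompt
            {"role": "user", "content": title},          # the representative's text
        ],
        temperature=0,  # deterministic yes/no style answer
        max_tokens=10,  # the reply is expected to be a single character
    )
    text = resp.choices[0].message.content or ""
    # The verdict is taken from whether the reply contains the keyword "是".
    return "是" in text


# Example: asyncio.run(judge_risk("某高校教师被举报论文造假"))
```

Under these assumptions, a representative is tagged with a risk type exactly when the reply contains 是, mirroring the keyword check described in step 4.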
## Notes on the Clustering Algorithm

A text embedding is a one-dimensional float16 array of dimension 1024. Similarity between vectors is computed with the cosine distance.

Since the purpose of clustering is de-duplication, DBSCAN is a good fit. The current parameters are EPS=0.25 with a minimum cluster size of 2, so essentially any 2 duplicate or semantically similar news items are assigned to the same cluster.
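A minimal sketch of this configuration, assuming the embeddings are already loaded into a NumPy array: it runs scikit-learn's DBSCAN on a precomputed cosine-distance matrix with the parameters above. The helper name `cluster_embeddings` and the random input are illustrative.

```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances


def cluster_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """Cluster 1024-dim embeddings; returns one label per row, -1 meaning noise."""
    distances = pairwise_distances(embeddings, metric="cosine")  # cosine distance matrix
    dbscan = DBSCAN(eps=0.25, min_samples=2, metric="precomputed")
    return dbscan.fit_predict(distances)


# Example with random vectors standing in for real embeddings:
labels = cluster_embeddings(np.random.rand(100, 1024).astype(np.float32))
print("clusters:", len(set(labels) - {-1}), "noise items:", int((labels == -1).sum()))
```

Label -1 marks noise; each remaining label groups duplicate or near-duplicate stories, from which the item closest to the cluster centroid is taken as the representative.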
## Duplicate Data

Because of article rewriting, reposting, and plagiarism, the same news story may appear on multiple platforms. The Niumedia platform treats these as different news items (with different IDs). The clustering step identifies such duplicates at the semantic level (both exact copies and semantically similar items) and groups them into one cluster.
## Deployment

### Environment Variables

Settings are read from system environment variables or a `.env` file:

```
ES_API=http://<address>
PG_DSN='postgresql://username:password@address:5432/cucyuqing?sslmode=disable'
MYSQL_DSN='mysql://username:password@address:3306/niumedia'
OPENAI_EMBEDDING_API_KEY='key'
OPENAI_EMBEDDING_BASE_URL='http://<address>/v1'
OPENAI_RISK_LLM_API_KEY='key'
OPENAI_RISK_LLM_BASE_URL='https://<address>/v1'
```
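A minimal sketch of how these settings can be loaded, assuming python-dotenv (pinned in requirements_version.txt); the `must_get_env` helper shown here mirrors the pattern visible in the config hunk further down, but the exact defaults are illustrative.

```python
import os

from dotenv import load_dotenv  # python-dotenv, listed in requirements_version.txt

load_dotenv()  # read .env into the process environment (existing env vars win)


def must_get_env(key: str) -> str:
    """Fail fast when a required setting is missing."""
    value = os.environ.get(key)
    if not value:
        raise RuntimeError(f"missing required environment variable: {key}")
    return value


PG_DSN = must_get_env("PG_DSN")
ES_API = os.environ.get("ES_API", "http://<address>")  # default is a placeholder
```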
### Dependencies

Using a virtual environment:

```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
```

Or using Docker:

```bash
docker build -t <image-name>:latest .
```

### Startup

Start the ES sync program:

```bash
python -m cmd.es-sync
```

Start the text feature extraction program:

```bash
python -m cmd.embedding
```

Start the LLM analysis program:

```bash
python -m cmd.risk-analyze
```
```diff
@@ -1,29 +1,32 @@
+import traceback
 import datetime
 import asyncio
 import tqdm
-import os
 from tokenizers import Tokenizer
 import openai
 import hashlib
 from pydantic import BaseModel
 from typing import Any, Literal
 from cucyuqing.pg import pool, get_cur
-from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL
+from cucyuqing.config import OPENAI_EMBEDDING_API_KEY, OPENAI_EMBEDDING_BASE_URL
 from cucyuqing.utils import print
 
-EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-large"]
+EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-small"]
 
 embedding_client = openai.AsyncOpenAI(
-    api_key=OPENAI_API_KEY,
-    base_url=OPENAI_BASE_URL,
+    api_key=OPENAI_EMBEDDING_API_KEY,
+    base_url=OPENAI_EMBEDDING_BASE_URL,
 )
 
 
-tokenizer = Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json")
+tokenizers: dict[EmbeddingModel, Any] = {}
+tokenizers['acge-large-zh'] = Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json")
+tokenizers['text-embedding-3-small'] = Tokenizer.from_file("cucyuqing/res/cl100k_base/tokenizer.json")
 
 
-def get_token_length(text: str) -> int:
+def get_token_length(model_name: EmbeddingModel, text: str) -> int:
     """使用 openai 提供的 tokenizer **估算** token 长度"""
+    tokenizer = tokenizers[model_name]
     return len(tokenizer.encode(text).tokens)
 
 
@@ -39,8 +42,9 @@ def hash_text(text: str, model: EmbeddingModel) -> str:
     return hashlib.md5((text + "|" + model).encode()).hexdigest()
 
 
-def truncate_text(text: str, max_length: int) -> str:
+def truncate_text(model_name: EmbeddingModel, text: str, max_length: int) -> str:
     """截断文本"""
+    tokenizer = tokenizers[model_name]
     tokens = tokenizer.encode(text).tokens[0:max_length]
     return ''.join(tokens)
 
@@ -59,9 +63,8 @@ async def get_embeddings(
     - quiet: 是否关闭输出
     """
 
-    # 针对 acge-large-zh 模型,需要将文本截断 1024 - 200
-    if model == "acge-large-zh":
-        texts = [truncate_text(text, 1024 - 2) for text in texts]
+    # 针对 大多数 模型,需要将文本截断 1024 - 2
+    texts = [truncate_text(model, text, 1024 - 2) for text in texts]
 
     # 构建任务列表
     ids = list(range(len(texts)))
@@ -84,13 +87,13 @@ async def get_embeddings(
     batch_token_length = 0  # TEMP
     iter_batch: list[Task] = []  # TEMP
     for q in query:
-        batch_token_length += get_token_length(q.text)
+        batch_token_length += get_token_length(model, q.text)
 
         # 该批次已满,将该批次加入 batch_query
         if batch_token_length > max_batch_token_length or len(iter_batch) >= 32:
             batch_query.append(iter_batch)
             iter_batch = [q]
-            batch_token_length = get_token_length(q.text)
+            batch_token_length = get_token_length(model, q.text)
             continue
 
         iter_batch.append(q)
@@ -113,10 +116,10 @@ async def get_embeddings(
            input=[q.text for q in query],
            model=model,
        )
-    elif model == "text-embedding-3-large":
+    elif model == "text-embedding-3-small":
        resp = await embedding_client.embeddings.create(
            input=[q.text for q in query],
-            model="text-embedding-3-large",
+            model=model,
            dimensions=1024,
        )
    else:
@@ -160,14 +163,25 @@ async def get_embedding_from_cache(hash: str) -> list[float] | None:
 async def main():
     await pool.open()
     while True:
+        try:
             await do_update()
+        except Exception as e:
+            print(traceback.format_exc())
+        finally:
             await asyncio.sleep(60)
 
 
 async def do_update():
     while True:
         async with get_cur() as cur:
             # 这里选择 embedding_updated_at is null 使用索引避免全表扫描
-            await cur.execute("SELECT id, title, content from risk_news where embedding_updated_at is null limit 1000")
+            await cur.execute("""
+                SELECT id, title, content
+                from risk_news
+                where embedding_updated_at is null
+                and time > now() - interval '14 day'
+                order by time desc
+                limit 1000
+            """)
             docs = await cur.fetchall()
 
             # 循环出口
@@ -175,7 +189,7 @@ async def do_update():
             print(datetime.datetime.now(), "No data to update")
             break
 
-        embeddings = await get_embeddings([doc[1] + " " + doc[2] for doc in docs], "acge-large-zh")
+        embeddings = await get_embeddings([doc[1] or doc[2] for doc in docs], "acge-large-zh", threads=10)
         async with get_cur() as cur:
             for doc, embedding in tqdm.tqdm(zip(docs, embeddings), total=min(len(docs), len(embeddings)), desc="Update embeddings"):
                 await cur.execute("UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s", (embedding, doc[0]))
```
|
|||||||
ES 数据同步脚本
|
ES 数据同步脚本
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
@@ -27,13 +28,24 @@ class ESInterval(pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def format_datetime(dt: datetime.datetime) -> str:
|
def format_datetime(dt: datetime.datetime) -> str:
|
||||||
return dt.strftime("%Y%m%d%H")
|
return dt.strftime("%Y%m%d%H%M%S")
|
||||||
|
|
||||||
|
|
||||||
def parse_unixtime(unixtime: int) -> datetime.datetime:
|
def parse_unixtime(unixtime: int) -> datetime.datetime:
|
||||||
return datetime.datetime.fromtimestamp(unixtime)
|
return datetime.datetime.fromtimestamp(unixtime)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_filter_query() -> str:
|
||||||
|
row = await mysql.fetch_one(
|
||||||
|
"""
|
||||||
|
select name from risk_news_keywords order by id limit 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
if not row:
|
||||||
|
raise Exception("未找到风险关键词")
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
|
||||||
async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
|
async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
|
||||||
"""
|
"""
|
||||||
获取指定时间段内的数据,每次请求 size 条数据。这是一个递归函数,如果当前时间段内的数据量 = size,说明还有数据,继续请求
|
获取指定时间段内的数据,每次请求 size 条数据。这是一个递归函数,如果当前时间段内的数据量 = size,说明还有数据,继续请求
|
||||||
@@ -44,7 +56,7 @@ async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
|
|||||||
es_response = await post(
|
es_response = await post(
|
||||||
url,
|
url,
|
||||||
{
|
{
|
||||||
"word": "(教师|老师|教授|导师|院长) - (教育部|公告|通报|准则|建设|座谈|细则|工作|动员|专题) + (不正当|性骚扰|出轨|猥亵|不公|强迫|侮辱|举报|滥用|违法|师德|贿|造假|不端|抄袭|虚假|篡改|挪用|抑郁|威胁|霸凌|体罚)",
|
"word": await get_filter_query(),
|
||||||
"size": size,
|
"size": size,
|
||||||
"orders": 9,
|
"orders": 9,
|
||||||
"tmode": 2,
|
"tmode": 2,
|
||||||
@@ -65,6 +77,8 @@ async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
|
|||||||
f'用时 {int(duration)} 秒,获取到 {len(docs)} 条数据,最早时间 {parse_unixtime(docs[0]["crawled_at"])},最晚时间 {parse_unixtime(docs[-1]["crawled_at"])}'
|
f'用时 {int(duration)} 秒,获取到 {len(docs)} 条数据,最早时间 {parse_unixtime(docs[0]["crawled_at"])},最晚时间 {parse_unixtime(docs[-1]["crawled_at"])}'
|
||||||
)
|
)
|
||||||
for d in docs:
|
for d in docs:
|
||||||
|
d["title"] = d["title"].replace("\x00", "")
|
||||||
|
d["content"] = d["content"].replace("\x00", "")
|
||||||
yield d
|
yield d
|
||||||
# 如果当前时间度的数据量 = size 说明还有数据,继续请求
|
# 如果当前时间度的数据量 = size 说明还有数据,继续请求
|
||||||
# 这里使用递归
|
# 这里使用递归
|
||||||
@@ -180,11 +194,11 @@ async def sync():
|
|||||||
await cur.execute(
|
await cur.execute(
|
||||||
"""
|
"""
|
||||||
WITH RECURSIVE time_slots AS (
|
WITH RECURSIVE time_slots AS (
|
||||||
SELECT date_trunc('hour', now() - interval '1 hour') - interval '14 day' AS start_time
|
SELECT date_trunc('hour', now() - interval '2 hour') - interval '14 day' AS start_time
|
||||||
UNION ALL
|
UNION ALL
|
||||||
SELECT start_time + INTERVAL '1 hour'
|
SELECT start_time + INTERVAL '1 hour'
|
||||||
FROM time_slots
|
FROM time_slots
|
||||||
WHERE start_time < date_trunc('hour', now() - interval '1 hour')
|
WHERE start_time < date_trunc('hour', now() - interval '2 hour')
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
ts.start_time,
|
ts.start_time,
|
||||||
@@ -219,8 +233,14 @@ async def sync():
|
|||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
while True:
|
while True:
|
||||||
|
try:
|
||||||
await sync()
|
await sync()
|
||||||
print("同步完成,等待下一轮同步")
|
print("同步完成,等待下一轮同步")
|
||||||
|
except Exception as e:
|
||||||
|
# 打印出错误堆栈
|
||||||
|
traceback.print_exc()
|
||||||
|
print("同步出错,等待 60 秒后重试", e)
|
||||||
|
finally:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
cucyuqing/cmd/kmeans.py (new file, 70 lines)

```python
import numpy
import asyncio
import json
from sklearn.cluster import KMeans

from cucyuqing.pg import pool, get_cur


async def main():
    # 从 PG 数据库获取数据
    await pool.open()
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE NOT embedding_updated_at IS NULL
            AND time > now() - interval '14 day'
            ORDER BY time desc
            LIMIT 1000
            ;"""
        )
        rows = await cur.fetchall()
        docs = [
            {
                "id": row[0],
                "title": row[1],
                "content": row[2],
            }
            for row in rows
        ]
        embeddings = [numpy.array(json.loads(row[3])) for row in rows]

    # 设置聚类的数量
    num_clusters = 50

    # 初始化KMeans模型
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)

    # 进行聚类
    kmeans.fit(embeddings)

    # 获取每个样本的聚类标签
    labels: list[int] = kmeans.labels_  # type: ignore

    # 计算每个样本到其聚类中心的距离
    distances = kmeans.transform(embeddings)
    # 找到每个聚类中距离中心最近的文档
    closest_docs = {}
    for i, label in enumerate(labels):
        distance_to_center = distances[i][label]
        if label not in closest_docs or distance_to_center < closest_docs[label][0]:
            closest_docs[label] = (distance_to_center, docs[i])

    # 输出每个聚类中距离中心最近的文档
    for label, (distance, doc) in closest_docs.items():
        print(f"聚类 {label} 最近的文档: {doc['title']} 距离: {distance}")

    sorted_samples: list[tuple[int, int]] = sorted(enumerate(labels), key=lambda x: x[1])

    # 随机选择一个聚类
    random_cluster = numpy.random.choice(num_clusters)
    for i, label in sorted_samples:
        if label == random_cluster:
            print(f"聚类 {label} 文档: {docs[i]['title']}")


if __name__ == "__main__":
    asyncio.run(main())
```
cucyuqing/cmd/risk-analyze.py (new file, 184 lines)

```python
from dataclasses import dataclass
from typing import Iterable
import openai
import asyncio
from openai.types.chat import ChatCompletionMessageParam
import tqdm
import json
import hashlib
from cucyuqing.utils import print
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
from cucyuqing.pg import pool, get_cur
from cucyuqing.mysql import mysql
from cucyuqing.dbscan import Document, run_dbscan


async def main():
    while True:
        try:
            await do_analyze()
            await asyncio.sleep(60 * 30)
        except Exception as e:
            print(e)
            await asyncio.sleep(60 * 60)


async def do_analyze():
    await asyncio.gather(
        pool.open(),
        mysql.connect(),
    )
    # 获取一个风险类型和对应的提示词
    risk_types = await get_risk_type_prompt()
    print("共有风险类型:", len(risk_types))

    dbscan_result = await run_dbscan()
    docs = [cluster[0] for cluster in dbscan_result.clusters]
    print("共有待分析文档:", len(docs), "噪声", len(dbscan_result.noise))

    risks_to_update: dict[str, set[str]] = {}
    analyze_result = await batch_risk_analyze(docs, risk_types)
    for task in analyze_result:
        if "是" not in task.response:
            if risks_to_update.get(task.doc.id) is None:
                risks_to_update[task.doc.id] = set()
            continue
        print(f"风险: {task.risk_type.name} 标题: {task.doc.title} {task.doc.id}")

        # 合并每个文档的风险到一个set
        if task.doc.id not in risks_to_update:
            risks_to_update[task.doc.id] = set()
        risks_to_update[task.doc.id].add(task.risk_type.name)

    # 更新数据库
    for doc_id, risks in risks_to_update.items():
        await mysql.execute(
            """
            UPDATE risk_news
            SET risk_types = :risk_types, updated_at = now()
            WHERE es_id = :es_id
            """,
            {
                "es_id": doc_id,
                "risk_types": (
                    json.dumps(list(risks), ensure_ascii=False) if risks else None
                ),
            },
        )


@dataclass
class RiskType:
    name: str
    prompt: str


@dataclass
class Task:
    doc: Document
    risk_type: RiskType
    response: str = ""


async def get_risk_type_prompt() -> list[RiskType]:
    """从数据库中获取风险类型和对应的提示词"""
    rows = await mysql.fetch_all(
        """
        SELECT rp.content, rt.name
        FROM risk_prompt rp
        JOIN risk_type rt ON rp.risk_type_id = rt.id
        ORDER BY rp.id DESC
        """
    )

    return [RiskType(prompt=row[0], name=row[1]) for row in rows]


async def batch_risk_analyze(
    docs: list[Document],
    risk_types: list[RiskType],
    model: str = "gpt-4o-mini",
    threads: int = 10,
) -> list[Task]:
    """文本风险分析(并行批批处理)"""

    # 从 docs, risk_types 两个列表交叉生成任务列表
    tasks: list[Task] = [
        Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
    ]
    bar = tqdm.tqdm(total=len(tasks))
    queue = asyncio.Queue()

    async def lmm_worker():
        while True:
            task = await queue.get()
            if task is None:
                break
            task.response = await risk_analyze(task, model)
            queue.task_done()
            bar.update(1)
            if bar.n % 100 == 0:
                print(f"已完成 {bar.n} 条风险分析")

    async def producer():
        for task in tasks:
            await queue.put(task)

    workers = [asyncio.create_task(lmm_worker()) for _ in range(threads)]

    await producer()
    await queue.join()
    for _ in workers:
        await queue.put(None)
    await asyncio.gather(*workers)

    print("风险分析完成")

    return tasks


async def risk_analyze(task: Task, model: str) -> str:
    """对一条文本进行风险分析"""
    llm = openai.AsyncOpenAI(
        api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
    )
    hash = hashlib.md5(
        f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode()
    ).hexdigest()

    # 查询缓存
    async with get_cur() as cur:
        await cur.execute(
            "SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (hash,)
        )
        row = await cur.fetchone()
        if row:
            return row[0]

    messages: Iterable[ChatCompletionMessageParam] = [
        {"role": "system", "content": task.risk_type.prompt},
        {"role": "user", "content": task.doc.get_text_for_llm()},
    ]
    resp = await llm.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,
        stop="\n",
        max_tokens=10,
    )

    completions = resp.choices[0].message.content or ""
    usage = resp.usage.model_dump_json() if resp.usage else None

    # 缓存结果
    async with get_cur() as cur:
        await cur.execute(
            "INSERT INTO llm_cache (id, messages, model, response, usage) VALUES (%s, %s, %s, %s, %s)",
            (hash, json.dumps(messages, ensure_ascii=False), model, completions, usage),
        )

    return completions


if __name__ == "__main__":
    asyncio.run(main())
```
```diff
@@ -22,5 +22,7 @@ def must_get_env(key: str):
 ES_API = get_env_with_default("ES_API", "http://192.168.1.45:1444")
 PG_DSN = must_get_env("PG_DSN")
 MYSQL_DSN = must_get_env("MYSQL_DSN")
-OPENAI_API_KEY = must_get_env("OPENAI_API_KEY")
-OPENAI_BASE_URL = get_env_with_default("OPENAI_BASE_URL", "https://api.openai.com/v1")
+OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
+OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
+OPENAI_RISK_LLM_API_KEY = must_get_env("OPENAI_RISK_LLM_API_KEY")
+OPENAI_RISK_LLM_BASE_URL = get_env_with_default("OPENAI_RISK_LLM_BASE_URL", "https://api.openai.com/v1")
```
cucyuqing/dbscan.py (new file, 111 lines)

```python
from typing_extensions import Doc
import numpy
import asyncio
import json
from dataclasses import dataclass
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

from cucyuqing.pg import pool, get_cur


@dataclass
class Document:
    id: str
    """ID 是 ES 中的 32 为 hex ID"""

    title: str
    content: str
    similarity: float = 0.0

    def get_text_for_llm(self) -> str:
        """只使用标题进行风险分析

        对于空标题,在入库时已经处理过。
        如果入库时标题为空,则使用content的前20个字符或第一句中文作为标题。
        """
        return self.title


@dataclass
class DBScanResult:
    noise: list[Document]
    clusters: list[list[Document]]


from sklearn.metrics.pairwise import cosine_similarity


async def run_dbscan() -> DBScanResult:
    # 从 PG 数据库获取数据
    async with get_cur() as cur:
        await cur.execute(
            """
            SELECT id, title, content, embedding
            FROM risk_news
            WHERE NOT embedding_updated_at IS NULL
            AND time > now() - interval '7 day'
            ORDER BY time desc
            LIMIT 100000
            ;"""
        )
        rows = await cur.fetchall()
        docs: list[Document] = [
            Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
        ]
        embeddings = [numpy.array(json.loads(row[3])) for row in rows]

    # 计算余弦距离矩阵
    cosine_distances = pairwise_distances(embeddings, metric="cosine")

    # 初始化DBSCAN模型
    dbscan = DBSCAN(
        eps=0.25, min_samples=2, metric="precomputed"
    )  # Adjust eps as needed

    # 进行聚类
    dbscan.fit(cosine_distances)

    # 获取每个样本的聚类标签
    labels: list[int] = dbscan.labels_  # type: ignore

    # 输出每个聚类中的文档
    ret: DBScanResult = DBScanResult(noise=[], clusters=[])
    unique_labels = set(labels)
    for label in unique_labels:
        class_member_mask = labels == label
        cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
        cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore

        if label == -1:
            # -1 is the label for noise points
            ret.noise = cluster_docs
        else:
            # 计算质心
            centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
            # 计算相似度
            similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
            # 根据相似度排序
            sorted_indices = numpy.argsort(similarities)[::-1]
            sorted_cluster_docs = []
            for i in sorted_indices:
                doc = cluster_docs[i]
                doc.similarity = similarities[i]
                sorted_cluster_docs.append(doc)
            ret.clusters.append(sorted_cluster_docs)

    return ret


async def main():
    await pool.open()
    result = await run_dbscan()
    print(f"噪声文档: {len(result.noise)}")
    for i, cluster in enumerate(result.clusters):
        print("----------------")
        for doc in cluster:
            print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}")


if __name__ == "__main__":
    asyncio.run(main())
```
cucyuqing/res/cl100k_base/tokenizer.json (new file, 200353 lines)

File diff suppressed because it is too large.
```diff
@@ -8,3 +8,4 @@ databases[aiomysql]
 openai
 tokenizers
 tqdm
+scikit-learn
```
requirements_version.txt (new file, 48 lines)

```
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiomysql==0.2.0
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.4.0
databases==0.9.0
distro==1.9.0
fastapi==0.115.2
filelock==3.16.1
frozenlist==1.4.1
fsspec==2024.9.0
greenlet==3.1.1
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.26.0
idna==3.10
jiter==0.6.1
joblib==1.4.2
multidict==6.1.0
numpy==2.1.2
openai==1.52.0
packaging==24.1
propcache==0.2.0
psycopg==3.2.3
psycopg-binary==3.2.3
psycopg-pool==3.2.3
pydantic==2.9.2
pydantic_core==2.23.4
PyMySQL==1.1.1
python-dotenv==1.0.1
PyYAML==6.0.2
requests==2.32.3
scikit-learn==1.5.2
scipy==1.14.1
sniffio==1.3.1
SQLAlchemy==2.0.36
starlette==0.40.0
threadpoolctl==3.5.0
tokenizers==0.20.1
tqdm==4.66.5
typing_extensions==4.12.2
urllib3==2.2.3
yarl==1.15.4
```