diff --git a/cucyuqing/cmd/embedding.py b/cucyuqing/cmd/embedding.py index 2e6ff24..a894596 100644 --- a/cucyuqing/cmd/embedding.py +++ b/cucyuqing/cmd/embedding.py @@ -1,9 +1,7 @@ import traceback import datetime import asyncio -from sqlalchemy.sql.ddl import exc import tqdm -import os from tokenizers import Tokenizer import openai import hashlib @@ -65,9 +63,8 @@ async def get_embeddings( - quiet: 是否关闭输出 """ - # 针对 acge-large-zh 模型,需要将文本截断 1024 - 2 - if model == "acge-large-zh": - texts = [truncate_text(model, text, 1024 - 2) for text in texts] + # 针对 大多数 模型,需要将文本截断 1024 - 2 + texts = [truncate_text(model, text, 1024 - 2) for text in texts] # 构建任务列表 ids = list(range(len(texts))) @@ -192,7 +189,7 @@ async def do_update(): print(datetime.datetime.now(), "No data to update") break - embeddings = await get_embeddings([doc[1] + " " + doc[2] for doc in docs], "acge-large-zh") + embeddings = await get_embeddings([doc[1] or doc[2] for doc in docs], "acge-large-zh", threads=10) async with get_cur() as cur: for doc, embedding in tqdm.tqdm(zip(docs, embeddings), total=min(len(docs), len(embeddings)), desc="Update embeddings"): await cur.execute("UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s", (embedding, doc[0]))