From a49db1b71b82f27600359a4195f79feba03a31e0 Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Tue, 8 Oct 2024 14:35:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8=20title=20or=20content=20?= =?UTF-8?q?=E4=BD=9C=E4=B8=BAembedding=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cucyuqing/cmd/embedding.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cucyuqing/cmd/embedding.py b/cucyuqing/cmd/embedding.py index 2e6ff24..a894596 100644 --- a/cucyuqing/cmd/embedding.py +++ b/cucyuqing/cmd/embedding.py @@ -1,9 +1,7 @@ import traceback import datetime import asyncio -from sqlalchemy.sql.ddl import exc import tqdm -import os from tokenizers import Tokenizer import openai import hashlib @@ -65,9 +63,8 @@ async def get_embeddings( - quiet: 是否关闭输出 """ - # 针对 acge-large-zh 模型,需要将文本截断 1024 - 2 - if model == "acge-large-zh": - texts = [truncate_text(model, text, 1024 - 2) for text in texts] + # 针对 大多数 模型,需要将文本截断 1024 - 2 + texts = [truncate_text(model, text, 1024 - 2) for text in texts] # 构建任务列表 ids = list(range(len(texts))) @@ -192,7 +189,7 @@ async def do_update(): print(datetime.datetime.now(), "No data to update") break - embeddings = await get_embeddings([doc[1] + " " + doc[2] for doc in docs], "acge-large-zh") + embeddings = await get_embeddings([doc[1] or doc[2] for doc in docs], "acge-large-zh", threads=10) async with get_cur() as cur: for doc, embedding in tqdm.tqdm(zip(docs, embeddings), total=min(len(docs), len(embeddings)), desc="Update embeddings"): await cur.execute("UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s", (embedding, doc[0]))