diff --git a/cucyuqing/cmd/embedding.py b/cucyuqing/cmd/embedding.py index da2792b..0cb5bdf 100644 --- a/cucyuqing/cmd/embedding.py +++ b/cucyuqing/cmd/embedding.py @@ -158,18 +158,28 @@ async def get_embedding_from_cache(hash: str) -> list[float] | None: async def main(): await pool.open() - async with get_cur() as cur: - # 这里选择 embedding_updated_at is null 使用索引避免全表扫描 - await cur.execute("SELECT id, title, content from risk_news where embedding_updated_at is null limit 1000") - docs = await cur.fetchall() - if not docs: - print(datetime.datetime.now(), "No data to update") + while True: + await do_update() await asyncio.sleep(60) - return - embeddings = await get_embeddings([doc[1] + " " + doc[2] for doc in docs], "acge-large-zh") - async with get_cur() as cur: - for doc, embedding in tqdm.tqdm(zip(docs, embeddings), total=min(len(docs), len(embeddings)), desc="Update embeddings"): - await cur.execute("UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s", (embedding, doc[0])) + +async def do_update(): + while True: + async with get_cur() as cur: + # 这里选择 embedding_updated_at is null 使用索引避免全表扫描 + await cur.execute("SELECT id, title, content from risk_news where embedding_updated_at is null limit 1000") + docs = await cur.fetchall() + + # 循环出口 + if not docs: + print(datetime.datetime.now(), "No data to update") + break + + embeddings = await get_embeddings([doc[1] + " " + doc[2] for doc in docs], "acge-large-zh") + async with get_cur() as cur: + for doc, embedding in tqdm.tqdm(zip(docs, embeddings), total=min(len(docs), len(embeddings)), desc="Update embeddings"): + await cur.execute("UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s", (embedding, doc[0])) + + await asyncio.sleep(1) if __name__ == "__main__": asyncio.run(main())