循环更新 embedding
This commit is contained in:
@@ -158,18 +158,28 @@ async def get_embedding_from_cache(hash: str) -> list[float] | None:
|
||||
|
||||
async def main():
|
||||
await pool.open()
|
||||
while True:
|
||||
await do_update()
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def do_update():
|
||||
while True:
|
||||
async with get_cur() as cur:
|
||||
# 这里选择 embedding_updated_at is null 使用索引避免全表扫描
|
||||
await cur.execute("SELECT id, title, content from risk_news where embedding_updated_at is null limit 1000")
|
||||
docs = await cur.fetchall()
|
||||
|
||||
# 循环出口
|
||||
if not docs:
|
||||
print(datetime.datetime.now(), "No data to update")
|
||||
await asyncio.sleep(60)
|
||||
return
|
||||
break
|
||||
|
||||
embeddings = await get_embeddings([doc[1] + " " + doc[2] for doc in docs], "acge-large-zh")
|
||||
async with get_cur() as cur:
|
||||
for doc, embedding in tqdm.tqdm(zip(docs, embeddings), total=min(len(docs), len(embeddings)), desc="Update embeddings"):
|
||||
await cur.execute("UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s", (embedding, doc[0]))
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user