From 1eed30700cbef4e2e3d3a58022df0fb0af4cc17f Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Tue, 24 Sep 2024 22:51:37 +0800 Subject: [PATCH] merge bot_db to bot_chatgpt --- bot_chatgpt.py | 364 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 354 insertions(+), 10 deletions(-) diff --git a/bot_chatgpt.py b/bot_chatgpt.py index a4d9145..94e4061 100644 --- a/bot_chatgpt.py +++ b/bot_chatgpt.py @@ -2,6 +2,28 @@ import os import dotenv dotenv.load_dotenv() +import PyPDF2 +import html2text +import re +import hashlib +from nio import ( + DownloadError, + MatrixRoom, + RoomMessageAudio, + RoomMessageFile, + RoomMessageText, +) +from langchain.text_splitter import MarkdownTextSplitter +from bot import Bot, print +import asyncio +import io +import yt_dlp +import os +import subprocess +from langchain.embeddings import OpenAIEmbeddings + +from selenium import webdriver + import asyncio import jinja2 import requests @@ -93,7 +115,6 @@ async def get_reply_file_content(event): return "", 0 -@client.ignore_link @client.message_callback_common_wrapper async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None: # handle set system message @@ -143,6 +164,53 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None: ) await client.react_ok(room.room_id, event.event_id) return + + should_react = True + if event.body.startswith("!clear") or event.body.startswith("!clean"): + # save to db + async with client.db.transaction(): + await client.db.execute( + query=""" + delete from embeddings e + using room_document rd + where e.document_md5 = rd.document_md5 and + rd.room = :room_id; + """, + values={"room_id": room.room_id}, + ) + await client.db.execute( + query="delete from room_document where room = :room_id;", + values={"room_id": room.room_id}, + ) + elif event.body.startswith("!embedding"): + sp = event.body.split() + if len(sp) < 2: + return + if not sp[1].lower() in ["on", "off"]: + return + status = sp[1].lower() == "on" + await client.db.execute( + query=""" + insert into room_configs (room, embedding) + values (:room_id, :status) + on conflict (room) do update set embedding = excluded.embedding + ;""", + values={"room_id": room.room_id, "status": status}, + ) + else: + should_react = False + if should_react: + await client.room_send( + room.room_id, + "m.reaction", + { + "m.relates_to": { + "event_id": event.event_id, + "key": "😘", + "rel_type": "m.annotation", + } + }, + ) return messages: list[BaseMessage] = [] @@ -290,7 +358,7 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None: sum(client.get_token_length(m.content) for m in messages) + len(messages) * 6 ) if not model_name: - model_name = "gpt-3.5-turbo-1106" + model_name = "gpt-4o-mini" print("messages", messages) chat_model = ChatOpenAI( @@ -345,14 +413,7 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None: client.add_event_callback(message_callback, RoomMessageText) - -@client.ignore_self_message -@client.handel_no_gpt -@client.log_message -@client.with_typing -@client.replace_command_mark -@client.safe_try -async def message_file(room: MatrixRoom, event: RoomMessageFile): +async def message_file_for_chatgpt_api_web(room: MatrixRoom, event: RoomMessageFile): if not event.flattened().get("content.info.mimetype") == "application/json": print("not application/json") return @@ -396,6 +457,289 @@ async def message_file(room: MatrixRoom, event: RoomMessageFile): ) +spliter = MarkdownTextSplitter( + chunk_size=400, + chunk_overlap=100, + length_function=client.get_token_length, +) + +offices_mimetypes = [ + "application/wps-office.docx", + "application/wps-office.doc", + "application/wps-office.pptx", + "application/wps-office.ppt", + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-powerpoint", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "application/vnd.oasis.opendocument.text", + "application/vnd.oasis.opendocument.presentation", +] +mimetypes = [ + "text/plain", + "application/pdf", + "text/markdown", + "text/html", +] + offices_mimetypes + + +def allowed_file(mimetype: str): + return mimetype.lower() in mimetypes + + +async def create_embedding(room, event, md5, content, url): + transaction = await client.db.transaction() + await client.db.execute( + query="""insert into documents (md5, content, token, url) + values (:md5, :content, :token, :url) + on conflict (md5) do nothing + ;""", + values={ + "md5": md5, + "content": content, + "token": client.get_token_length(content), + "url": url, + }, + ) + + rows = await client.db.fetch_all( + query="select document_md5 from room_document where room = :room and document_md5 = :md5 limit 1;", + values={"room": room.room_id, "md5": md5}, + ) + if len(rows) > 0: + await transaction.rollback() + print("document alreadly insert in room", md5, room.room_id) + await client.room_send( + room.room_id, + "m.reaction", + { + "m.relates_to": { + "event_id": event.event_id, + "key": "👍", + "rel_type": "m.annotation", + } + }, + ) + return + + await client.db.execute( + query=""" + insert into room_document (room, document_md5) + values (:room_id, :md5) + on conflict (room, document_md5) do nothing + ;""", + values={"room_id": room.room_id, "md5": md5}, + ) + + # start embedding + chunks = spliter.split_text(content) + print("chunks", len(chunks)) + embeddings = await embeddings_model.aembed_documents(chunks, chunk_size=1600) + print("embedding finished", len(embeddings)) + if len(chunks) != len(embeddings): + raise ValueError("asdf") + insert_data: list[dict] = [] + for chunk, embedding in zip(chunks, embeddings): + insert_data.append( + { + "document_md5": md5, + "md5": hashlib.md5(chunk.encode()).hexdigest(), + "content": chunk, + "token": client.get_token_length(chunk), + "embedding": str(embedding), + } + ) + await client.db.execute_many( + query="""insert into embeddings (document_md5, md5, content, token, embedding) + values (:document_md5, :md5, :content, :token, :embedding) + on conflict (document_md5, md5) do nothing + ;""", + values=insert_data, + ) + print("insert", len(insert_data), "embedding data") + + await client.db.execute( + query=""" + insert into event_document (event, document_md5) + values (:event_id, :md5) + on conflict (event) do nothing + ;""", + values={"event_id": event.event_id, "md5": md5}, + ) + + await transaction.commit() + + await client.room_send( + room.room_id, + "m.reaction", + { + "m.relates_to": { + "event_id": event.event_id, + "key": "😘", + "rel_type": "m.annotation", + } + }, + ) + + +def clean_html(html: str) -> str: + h2t = html2text.HTML2Text() + h2t.ignore_emphasis = True + h2t.ignore_images = True + h2t.ignore_links = True + h2t.body_width = 0 + content = h2t.handle(html) + return content + + +def clean_content(content: str, mimetype: str, document_md5: str) -> str: + # clean 0x00 + content = content.replace("\x00", "") + # clean links + content = re.sub(r"\[.*?\]\(.*?\)", "", content) + content = re.sub(r"!\[.*?\]\(.*?\)", "", content) + # clean lines + lines = [i.strip() for i in content.split("\n\n")] + while "" in lines: + lines.remove("") + + content = "\n\n".join(lines) + content = "\n".join([i.strip() for i in content.split("\n")]) + + return content + + +def pdf_to_text(f) -> str: + pdf_reader = PyPDF2.PdfReader(f) + num_pages = len(pdf_reader.pages) + + content = "" + for page_number in range(num_pages): + page = pdf_reader.pages[page_number] + content += page.extract_text() + return content + + + +yt_dlp_support = ["b23.tv/", "www.bilibili.com/video/", "youtube.com/"] + + +def allow_yt_dlp(link: str) -> bool: + if not link.startswith("http://") and not link.startswith("https://"): + return False + allow = False + for u in yt_dlp_support: + if u in link: + allow = True + break + return allow + + +def allow_web(link: str) -> bool: + print("checking web url", link) + if not link.startswith("http://") and not link.startswith("https://"): + return False + return True + +@client.ignore_self_message +@client.handel_no_gpt +@client.log_message +@client.with_typing +@client.replace_command_mark +@client.safe_try +async def message_file(room: MatrixRoom, event: RoomMessageFile): + # route for chatgpt-api-web + if event.flattened().get("content.info.mimetype") == "application/json": + await message_file_for_chatgpt_api_web(room, event) + return + + print("received file") + mimetype = event.flattened().get("content.info.mimetype", "") + if not allowed_file(mimetype): + print("not allowed file", event.body) + raise ValueError("not allowed file") + resp = await client.download(event.url) + if isinstance(resp, DownloadError): + raise ValueError("file donwload error") + + assert isinstance(resp.body, bytes) + md5 = hashlib.md5(resp.body).hexdigest() + + document_fetch_result = await client.db.fetch_one( + query="select content from documents where md5 = :md5 limit 1;", + values={"md5": md5}, + ) + + # get content + content = document_fetch_result[0] if document_fetch_result else "" + # document not exists + if content: + print("document", md5, "alreadly exists") + else: + if mimetype == "text/plain" or mimetype == "text/markdown": + content = resp.body.decode() + elif mimetype == "text/html": + content = clean_html(resp.body.decode()) + elif mimetype == "application/pdf": + f = io.BytesIO(resp.body) + content = pdf_to_text(f) + elif mimetype in offices_mimetypes: + # save file to temp dir + base = event.body.rsplit(".", 1)[0] + ext = event.body.rsplit(".", 1)[1] + print("base", base) + source_filepath = os.path.join("./cache/office", event.body) + txt_filename = base + ".txt" + txt_filepath = os.path.join("./cache/office", txt_filename) + print("source_filepath", source_filepath) + with open(source_filepath, "wb") as f: + f.write(resp.body) + if ext in ["doc", "docx", "odt"]: + process = subprocess.Popen( + [ + "soffice", + "--headless", + "--convert-to", + "txt:Text", + "--outdir", + "./cache/office", + source_filepath, + ] + ) + process.wait() + with open(txt_filepath, "r") as f: + content = f.read() + elif ext in ["ppt", "pptx", "odp"]: + pdf_filename = base + ".pdf" + pdf_filepath = os.path.join("./cache/office", pdf_filename) + process = subprocess.Popen( + [ + "soffice", + "--headless", + "--convert-to", + "pdf", + "--outdir", + "./cache/office", + source_filepath, + ] + ) + process.wait() + with open(pdf_filepath, "rb") as f: + content = pdf_to_text(f) + else: + raise ValueError("unknown ext: ", ext) + print("converted txt", content) + else: + raise ValueError("unknown mimetype", mimetype) + + content = clean_content(content, mimetype, md5) + + print("content length", len(content)) + + await create_embedding(room, event, md5, content, event.url) + + client.add_event_callback(message_file, RoomMessageFile) asyncio.run(client.sync_forever())