commit 47ca0639dafeab230f3e0c637c2e3b3fff530538 Author: heimoshuiyu Date: Sat Oct 21 21:14:57 2023 +0800 init diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..d5e58ed --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +__pycache__ +/venv +/.env +.git diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d5e58ed --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +/venv +/.env +.git diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0716aff --- /dev/null +++ b/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.11-slim-buster + +# Update the system +RUN apt-get update -y + +# Install LibreOffice, Firefox ESR and pip +RUN apt-get install -y libreoffice firefox-esr + +# Set the working directory in the container to /app +WORKDIR /app + +COPY ./requirements_version.txt /app/requirements_version.txt +RUN pip3 install --no-cache-dir -r requirements_version.txt + +# cache tiktoken dict +RUN python3 -c 'import tiktoken; enc = tiktoken.get_encoding("cl100k_base")' + +# Add the current directory contents into the container at /app +COPY . 
import dotenv

dotenv.load_dotenv()
import os
import traceback
from functools import wraps
from nio import AsyncClient, MatrixRoom, RoomMessageText, SyncError, InviteEvent
import tiktoken
import databases
import builtins
import sys


def print(*args, **kwargs):
    """Shadow builtins.print so all bot logging goes to stderr."""
    kwargs["file"] = sys.stderr
    builtins.print(*args, **kwargs)


class Bot(AsyncClient):
    """Base Matrix bot.

    Owns the Postgres connection, creates the pgvector schema, runs the sync
    loop, and provides the decorator building blocks that the concrete bots
    (chatgpt / db / tts / whisper) stack onto their event callbacks.
    """

    def __init__(self, homeserver: str, user: str, device_id: str, access_token: str):
        super().__init__(homeserver)
        self.access_token = access_token
        self.user_id = self.user = user
        self.device_id = device_id
        self.welcome_message = ""
        self._joined_rooms = []
        self.db = databases.Database(os.environ["MATRIX_CHAIN_DB"])

        # token counting is only used for budgeting prompts
        self.enc = tiktoken.encoding_for_model("gpt-4")

        # auto join rooms on invite
        self.add_event_callback(self.auto_join, InviteEvent)

    async def init_db(self):
        """Create the pgvector extension and all tables (idempotent)."""
        db = self.db
        await db.execute("CREATE EXTENSION IF NOT EXISTS vector")
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS documents
            (
                md5 character(32) NOT NULL PRIMARY KEY,
                content text,
                token integer,
                url text
            )
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS embeddings
            (
                document_md5 character(32) NOT NULL,
                md5 character(32) NOT NULL,
                content text NOT NULL,
                token integer NOT NULL,
                embedding vector(1536) NOT NULL,
                PRIMARY KEY (document_md5, md5),
                FOREIGN KEY (document_md5) REFERENCES documents(md5)
            );
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS event_document
            (
                event text NOT NULL PRIMARY KEY,
                document_md5 character(32) NOT NULL,
                FOREIGN KEY (document_md5)
                    REFERENCES documents (md5)
            );
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS memories
            (
                id SERIAL PRIMARY KEY,
                root text NOT NULL,
                role integer NOT NULL,
                content text NOT NULL,
                token integer NOT NULL
            )
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS room_configs
            (
                room text NOT NULL PRIMARY KEY,
                model_name text,
                temperature float NOT NULL DEFAULT 0,
                system text,
                embedding boolean NOT NULL DEFAULT false,
                examples TEXT[] NOT NULL DEFAULT '{}'
            )
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS room_document (
                room text NOT NULL,
                document_md5 character(32) NOT NULL,
                PRIMARY KEY (room, document_md5)
            );
            """
        )

    def get_token_length(self, text: str) -> int:
        """Number of tiktoken tokens in *text*."""
        return len(self.enc.encode(text))

    async def sync_forever(self):
        """Connect to the DB, do one callback-free initial sync, then long-poll."""
        print("connecting to db")
        await self.db.connect()

        print("init db hook")
        await self.init_db()

        # remove callbacks so the initial sync does not replay old events
        callbacks = self.event_callbacks
        self.event_callbacks = []
        print("Perform initial sync")
        resp = await self.sync(timeout=30000)
        if isinstance(resp, SyncError):
            # was `raise BaseException(SyncError)`: raised the *class* itself
            # and used BaseException, which escapes `except Exception` handlers
            raise RuntimeError(f"initial sync failed: {resp}")
        self.event_callbacks = callbacks
        # set online
        print("Set online status")
        await self.set_presence("online")
        # only events after the initial sync batch are processed
        print("Start forever sync")
        return await super().sync_forever(300000, since=resp.next_batch)

    async def auto_join(self, room: MatrixRoom, event: InviteEvent):
        """Join every room we are invited to and post the welcome message once."""
        print("join", event.sender, room.room_id)
        if room.room_id in self._joined_rooms:
            return
        await self.join(room.room_id)
        self._joined_rooms.append(room.room_id)
        if self.welcome_message:
            await self.room_send(
                room.room_id,
                "m.room.message",
                {
                    "body": self.welcome_message,
                    "msgtype": "m.text",
                    # marker so sibling bots skip this message (handel_no_gpt)
                    "nogpt": True,
                },
            )

    def ignore_self_message(self, func):
        """Drop events this bot sent itself."""

        @wraps(func)
        async def ret(room: MatrixRoom, event: RoomMessageText):
            if event.sender == self.user:
                return
            return await func(room, event)

        return ret

    def log_message(self, func):
        """Log room / sender / body for every handled event."""

        @wraps(func)
        async def ret(room: MatrixRoom, event: RoomMessageText):
            print(room.room_id, event.sender, event.body)
            return await func(room, event)

        return ret

    def with_typing(self, func):
        """Show a typing notification while the wrapped handler runs."""

        @wraps(func)
        async def ret(room, *args, **kargs):
            await self.room_typing(room.room_id, True, 60000 * 3)
            resp = await func(room, *args, **kargs)
            await self.room_typing(room.room_id, False)
            return resp

        return ret

    def change_event_id_to_root_id(self, func):
        """Rewrite event.event_id to the thread root so memory is per-thread."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            root = event.event_id
            if event.flattened().get("content.m.relates_to.rel_type") == "m.thread":
                root = event.source["content"]["m.relates_to"]["event_id"]
            event.event_id = root
            return await func(room, event, *args, **kargs)

        return ret

    def ignore_not_mentioned(self, func):
        """Only handle messages that mention this bot user."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            flattened = event.flattened()
            if self.user not in flattened.get(
                "content.body", ""
            ) and self.user not in flattened.get("content.formatted_body", ""):
                return
            return await func(room, event, *args, **kargs)

        return ret

    def replace_command_mark(self, func):
        """Normalise a full-width '！' command prefix to the ASCII '!'."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            # NOTE(review): the recovered dump showed both marks as ASCII "!"
            # (a no-op); the full-width mark is the only sensible original —
            # confirm against upstream.
            if event.body.startswith("！"):
                event.body = "!" + event.body[1:]
            return await func(room, event, *args, **kargs)

        return ret

    def handel_no_gpt(self, func):
        """Skip messages flagged with the custom "nogpt" content key."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            if not event.flattened().get("content.nogpt") is None:
                return
            return await func(room, event, *args, **kargs)

        return ret

    def replace_reply_file_with_content(self, func):
        """If the message is a rich reply to an uploaded document, prepend the
        stored document content to the message body (when small enough)."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            flattened = event.flattened()
            formatted_body = flattened.get("content.formatted_body", "")
            # NOTE(review): the <mx-reply> markers were stripped by the dump
            # this file was recovered from; Matrix rich replies wrap the quoted
            # part in <mx-reply>...</mx-reply> — confirm against the original.
            if (
                formatted_body.startswith("<mx-reply>")
                and "</mx-reply>" in formatted_body
            ):
                print("replacing file content")
                formatted_body = formatted_body[
                    formatted_body.index("</mx-reply>") + len("</mx-reply>") :
                ]
                document_event_id = flattened.get(
                    "content.m.relates_to.m.in_reply_to.event_id", ""
                )
                fetch = await self.db.fetch_all(
                    query="""select d.content, d.token
                    from documents d
                    join event_document ed on d.md5 = ed.document_md5
                    where ed.event = :event_id""",
                    values={"event_id": document_event_id},
                )
                # only inline documents that fit the model context budget
                if len(fetch) > 0 and fetch[0][1] < 8192 + 4096:
                    content = fetch[0][0]
                    print(content)
                    print("-----------")
                    event.body = content + "\n\n---\n\n" + formatted_body
                else:
                    print("document not found or too large", event.event_id)

            return await func(room, event, *args, **kargs)

        return ret

    def ignore_link(self, func):
        """Skip plain-link messages (they belong to the db bot)."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            if event.body.startswith("https://") or event.body.startswith("http://"):
                return
            return await func(room, event, *args, **kargs)

        return ret

    def safe_try(self, func):
        """Catch handler exceptions, log them and react with 😵."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            try:
                return await func(room, event, *args, **kargs)
            except Exception:
                print("--------------")
                print("error:")
                print(traceback.format_exc())
                print("--------------")
                await self.room_send(
                    room.room_id,
                    "m.reaction",
                    {
                        "m.relates_to": {
                            "event_id": event.event_id,
                            "key": "😵",
                            "rel_type": "m.annotation",
                        }
                    },
                )

        return ret

    def message_callback_common_wrapper(self, func):
        """The standard decorator stack shared by the text-message bots."""

        @wraps(func)
        @self.ignore_self_message
        @self.handel_no_gpt
        @self.log_message
        @self.with_typing
        @self.replace_reply_file_with_content
        @self.change_event_id_to_root_id
        @self.replace_command_mark
        @self.safe_try
        async def ret(*args, **kargs):
            return await func(*args, **kargs)

        return ret

    async def react_ok(self, room_id: str, event_id: str):
        """Acknowledge a command with a 😘 reaction."""
        await self.room_send(
            room_id,
            "m.reaction",
            {
                "m.relates_to": {
                    "event_id": event_id,
                    "key": "😘",
                    "rel_type": "m.annotation",
                }
            },
        )
import os
import dotenv

dotenv.load_dotenv()
import asyncio
import re
import jinja2
import requests
import datetime
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from nio import DownloadError, MatrixRoom, RoomMessageFile, RoomMessageText
from langchain.chat_models import ChatOpenAI
import json
from langchain import LLMChain
from langchain.prompts import ChatPromptTemplate

from bot import Bot, print

# was `from langchain.embeddings import OpenAIEmbeddings, awa` — the stray
# ", awa" token raised ImportError at startup
from langchain.embeddings import OpenAIEmbeddings

embeddings_model = OpenAIEmbeddings(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_base=os.environ["OPENAI_API_BASE"],
    show_progress_bar=True,
)

client = Bot(
    os.environ["BOT_CHATGPT_HOMESERVER"],
    os.environ["BOT_CHATGPT_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_CHATGPT_ACCESS_TOKEN"],
)
client.welcome_message = """你好👋,我是 matrix chain 中的大语言模型插件
## 使用方式:
- 直接在房间内发送消息,GPT 会在消息列中进行回复。GPT 会单独记住每个消息列中的所有内容,每个消息列单独存在互不干扰
## 配置方式:
- 发送 "!system + 系统消息" 配置大语言模型的角色,例如发送 "!system 你是一个专业英语翻译,你要把我说的话翻译成英语。你可以调整语序结构和用词让翻译更加通顺。"
"""


class SilentUndefined(jinja2.Undefined):
    """Render undefined template variables as "" instead of raising."""

    def _fail_with_undefined_error(self, *args, **kwargs):
        print(f'jinja2.Undefined: "{self._undefined_name}" is undefined')
        return ""


def render(template: str, **kargs) -> str:
    """Render *template* with jinja2; exposes a now() -> YYYY-MM-DD global."""
    env = jinja2.Environment(undefined=SilentUndefined)
    temp = env.from_string(template)

    def now() -> str:
        return datetime.datetime.now().strftime("%Y-%m-%d")

    temp.globals["now"] = now
    return temp.render(**kargs)


async def get_reply_file_content(event):
    """When the user replies to an event, retrieve the file content (document)
    attached to that event.

    Returns (content, token_length); ("", 0) when there is no usable document.
    """
    flattened = event.flattened()
    formatted_body = flattened.get("content.formatted_body", "")
    # NOTE(review): the <mx-reply> markers were stripped by the dump this file
    # was recovered from — confirm against the original.
    if not (
        formatted_body.startswith("<mx-reply>") and "</mx-reply>" in formatted_body
    ):
        return "", 0

    print("replacing file content")
    formatted_body = formatted_body[
        formatted_body.index("</mx-reply>") + len("</mx-reply>") :
    ]
    document_event_id = flattened.get("content.m.relates_to.m.in_reply_to.event_id", "")
    fetch = await client.db.fetch_one(
        query="""select d.content, d.token
        from documents d
        join event_document ed on d.md5 = ed.document_md5
        where ed.event = :document_event_id""",
        values={
            "document_event_id": document_event_id,
        },
    )

    # only inline documents that fit the context budget
    if fetch and fetch[1] < 8192 + 4096:
        content = fetch[0]
        token = fetch[1]
        print(content)
        print(token)
        print("-----------")
        return content, token

    print("document not found or too large", event.event_id)
    return "", 0


@client.ignore_link
@client.message_callback_common_wrapper
async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None:
    """Answer a room message with the configured LLM, using per-thread memory,
    room embeddings and (when replying to a file) the document content."""
    # handle ! configuration commands
    if event.body.startswith("!"):
        if event.body.startswith("!system"):
            # was lstrip("!system"): lstrip strips a *character set*, not a
            # prefix, and can eat leading characters of the argument
            systemMessageContent = event.body.removeprefix("!system").strip()
            # save to db
            await client.db.execute(
                query="""
                insert into room_configs (room, system, examples)
                values (:room_id, :systemMessageContent, '{}')
                on conflict (room)
                do update set system = excluded.system, examples = '{}'
                """,
                values={
                    "room_id": room.room_id,
                    "systemMessageContent": systemMessageContent,
                },
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        if event.body.startswith("!model"):
            model_name = event.body.removeprefix("!model").strip()
            await client.db.execute(
                query="""
                insert into room_configs (room, model_name)
                values (:room_id, :model_name)
                on conflict (room)
                do update set model_name = excluded.model_name
                """,
                values={"room_id": room.room_id, "model_name": model_name},
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        if event.body.startswith("!temp"):
            temperature = float(event.body.removeprefix("!temp").strip())
            await client.db.execute(
                query="""
                insert into room_configs (room, temperature)
                values (:room_id, :temperature)
                on conflict (room)
                do update set temperature = excluded.temperature
                """,
                values={"room_id": room.room_id, "temperature": temperature},
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        return

    messages: list[BaseMessage] = []
    # per-room config
    db_result = await client.db.fetch_one(
        query="""
        select system, examples, model_name, temperature
        from room_configs
        where room = :room_id
        limit 1
        """,
        values={"room_id": room.room_id},
    )
    model_name: str = db_result[2] if db_result else ""
    temperature: float = db_result[3] if db_result else 0

    systemMessageContent: str = db_result[0] if db_result else ""
    systemMessageContent = systemMessageContent or ""
    if systemMessageContent:
        messages.append(SystemMessage(content=systemMessageContent))

    # examples is a TEXT[] of alternating user/assistant strings
    examples = db_result[1] if db_result else []
    for i, m in enumerate(examples):
        if not m:
            print("Warning: message is empty", m)
            continue
        # was m["content"]: the array elements are plain strings, so indexing
        # them raised TypeError (get_token_length below already treats m as str)
        if i % 2 == 0:
            messages.append(HumanMessage(content=m, example=True))
        else:
            messages.append(AIMessage(content=m, example=True))

    exampleTokens = 0
    exampleTokens += sum(client.get_token_length(m) for m in examples if m)

    # nearest embedding chunks for this room, within a 6144-token budget
    embedding_query = await client.db.fetch_all(
        query="""
        select content, distance, total_token from (
            select
                content,
                document_md5,
                distance,
                sum(token) over (partition by room order by distance) as total_token
            from (
                select
                    content,
                    rd.room,
                    e.document_md5,
                    e.embedding <#> :embedding as distance,
                    token
                from embeddings e
                join room_document rd on rd.document_md5 = e.document_md5
                join room_configs rc on rc.room = rd.room
                where rd.room = :room_id and rc.embedding
                order by distance
                limit 16
            ) as sub
        ) as sub2
        where total_token < 6144
        ;""",
        values={
            "embedding": str(await embeddings_model.aembed_query(event.body)),
            "room_id": room.room_id,
        },
    )
    print("embedding_query", embedding_query)
    embedding_token = 0
    embedding_text = ""
    if len(embedding_query) > 0:
        # closest chunk last, so it sits nearest the question in the prompt
        embedding_query.reverse()
        embedding_text = "\n\n".join([i[0] for i in embedding_query])
        embedding_token = client.get_token_length(embedding_text)

    filecontent, filetoken = await get_reply_file_content(event)

    # budget the remaining context for thread memory
    max_token = 4096 * 4
    token_margin = 4096
    system_token = client.get_token_length(systemMessageContent) + exampleTokens
    memory_token = max_token - token_margin - system_token - embedding_token - filetoken
    print(
        "system_token",
        system_token,
        "embedding_token",
        embedding_token,
        "filetoken",
        filetoken,
    )
    # was missing the outer `order by id`: SQL does not guarantee subquery order
    rows = await client.db.fetch_all(
        query="""select role, content from (
            select id, role, content, sum(token) over (partition by root order by id desc) as total_token
            from memories
            where root = :root
        ) as sub
        where total_token < :token
        order by id
        ;""",
        values={"root": event.event_id, "token": memory_token},
    )
    for role, content in rows:
        if role == 1:
            messages.append(HumanMessage(content=content))
        elif role == 2:
            messages.append(AIMessage(content=content))
        else:
            print("Unknown message role", role, content)

    temp = "{{input}}"
    if filecontent and embedding_text:
        temp = """## Reference information:

{{embedding}}

---

## Query document:

{{filecontent}}

---

{{input}}"""
    elif embedding_text:
        temp = """## Reference information:

{{embedding}}

---

{{input}}"""
    elif filecontent:
        temp = """ ## Query document:

{{filecontent}}

---

{{input}}"""
    temp = render(
        temp, input=event.body, embedding=embedding_text, filecontent=filecontent
    )
    messages.append(HumanMessage(content=temp))

    # rough total: +6 tokens of per-message chat overhead
    total_token = (
        sum(client.get_token_length(m.content) for m in messages) + len(messages) * 6
    )
    if not model_name:
        model_name = "gpt-3.5-turbo" if total_token < 3939 else "gpt-3.5-turbo-16k"

    print("messages", messages)
    chat_model = ChatOpenAI(
        openai_api_base=os.environ["OPENAI_API_BASE"],
        openai_api_key=os.environ["OPENAI_API_KEY"],
        model=model_name,
        temperature=temperature,
    )
    chain = LLMChain(llm=chat_model, prompt=ChatPromptTemplate.from_messages(messages))

    result = await chain.arun(
        {
            "input": event.body,
            "embedding": embedding_text,
            "filecontent": filecontent,
        }
    )
    print(result)

    await client.room_send(
        room.room_id,
        "m.room.message",
        {
            "body": result,
            "msgtype": "m.text",
            "m.relates_to": {
                "rel_type": "m.thread",
                "event_id": event.event_id,
            },
        },
    )

    # record both sides of the exchange as thread memory
    await client.db.execute_many(
        query="insert into memories(root, role, content, token) values (:root, :role, :content, :token)",
        values=[
            {
                "root": event.event_id,
                "role": 1,
                "content": event.body,
                "token": client.get_token_length(event.body),
            },
            {
                "root": event.event_id,
                "role": 2,
                "content": result,
                "token": client.get_token_length(result),
            },
        ],
    )


client.add_event_callback(message_callback, RoomMessageText)


@client.ignore_self_message
@client.handel_no_gpt
@client.log_message
@client.with_typing
@client.replace_command_mark
@client.safe_try
async def message_file(room: MatrixRoom, event: RoomMessageFile):
    """Import a chatgpt-api-web chat-store export (JSON) as room config."""
    if not event.flattened().get("content.info.mimetype") == "application/json":
        print("not application/json")
        return
    size = event.flattened().get("content.info.size", 1024 * 1024 + 1)
    if size > 1024 * 1024:
        print("json file too large")
        return
    print("event url", event.url)
    # was requests.get against a hard-coded homeserver media URL; download
    # through the client so it works on any homeserver (cf. bot_db.py)
    resp = await client.download(event.url)
    if isinstance(resp, DownloadError):
        raise ValueError("file download error")
    j = json.loads(resp.body)
    if j.get("chatgpt_api_web_version") is None:
        print("not chatgpt-api-web chatstore export file")
        return
    # was a lexicographic string compare ("v1.10.0" < "v1.5.0" is True);
    # compare numeric version tuples instead
    version = tuple(int(p) for p in re.findall(r"\d+", j["chatgpt_api_web_version"])[:3])
    if version < (1, 5, 0):
        raise ValueError(j["chatgpt_api_web_version"])
    examples = [m["content"] for m in j["history"] if m.get("example")]
    # NOTE(review): str(examples) produces a Python repr, not a Postgres array
    # literal — verify the DB driver binds lists for TEXT[] and pass the list
    # directly if so.
    await client.db.execute(
        query="""
        insert into room_configs (room, system, examples)
        values (:room_id, :system, :examples)
        on conflict (room) do update set system = excluded.system, examples = excluded.examples
        """,
        values={
            "room_id": room.room_id,
            "system": j["systemMessageContent"],
            "examples": str(examples),
        },
    )

    await client.room_send(
        room.room_id,
        "m.reaction",
        {
            "m.relates_to": {
                "event_id": event.event_id,
                "key": "😘",
                "rel_type": "m.annotation",
            }
        },
    )


client.add_event_callback(message_file, RoomMessageFile)

asyncio.run(client.sync_forever())
import PyPDF2
import html2text
import re
import hashlib
from nio import (
    DownloadError,
    MatrixRoom,
    RoomMessageAudio,
    RoomMessageFile,
    RoomMessageText,
)
from langchain.text_splitter import MarkdownTextSplitter
from bot import Bot, print
import asyncio
import io
import yt_dlp
import os
import subprocess
from langchain.embeddings import OpenAIEmbeddings

from selenium import webdriver

print("launching driver")
options = webdriver.FirefoxOptions()
options.add_argument("-headless")
driver = webdriver.Firefox(options=options)


async def get_html(url: str) -> str:
    """Load *url* in the shared headless Firefox and return the rendered DOM."""
    driver.get(url)
    await asyncio.sleep(3)  # crude wait for client-side rendering
    return driver.page_source or ""


import openai

embeddings_model = OpenAIEmbeddings(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_base=os.environ["OPENAI_API_BASE"],
    show_progress_bar=True,
)

client = Bot(
    os.environ["BOT_DB_HOMESERVER"],
    os.environ["BOT_DB_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_DB_ACCESS_TOKEN"],
)
client.welcome_message = """欢迎使用 matrix chain db 插件,我能将房间中的所有文件添加进embedding数据库,并为gpt提供支持
## 使用方式
- 发送文件或视频链接
  目前支持文件格式:txt / pdf / md / doc / docx / ppt / pptx
  目前支持视频链接:Bilibili / Youtube
## 配置选项
- !clean 或 !clear 清除该房间中所有的embedding信息
- !embedding on 或 !embedding off 开启或关闭房间内embedding功能 (默认关闭)"""

# chunker used before embedding; sizes are in tiktoken tokens
spliter = MarkdownTextSplitter(
    chunk_size=400,
    chunk_overlap=100,
    length_function=client.get_token_length,
)

offices_mimetypes = [
    "application/wps-office.docx",
    "application/wps-office.doc",
    "application/wps-office.pptx",
    "application/wps-office.ppt",
    "application/msword",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
    "application/vnd.ms-powerpoint",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    "application/vnd.oasis.opendocument.text",
    "application/vnd.oasis.opendocument.presentation",
]
mimetypes = [
    "text/plain",
    "application/pdf",
    "text/markdown",
    "text/html",
] + offices_mimetypes


def allowed_file(mimetype: str):
    """True when the mimetype is one we can extract text from."""
    return mimetype.lower() in mimetypes


async def create_embedding(room, event, md5, content, url):
    """Store *content* as a document, link it to the room and event, and insert
    per-chunk embeddings.

    All DB work runs in one transaction; on any failure the transaction is
    rolled back (the original leaked it) and the error re-raised for safe_try.
    """
    transaction = await client.db.transaction()
    already_in_room = False
    try:
        await client.db.execute(
            query="""insert into documents (md5, content, token, url)
                values (:md5, :content, :token, :url)
                on conflict (md5) do nothing
            ;""",
            values={
                "md5": md5,
                "content": content,
                "token": client.get_token_length(content),
                "url": url,
            },
        )

        rows = await client.db.fetch_all(
            query="select document_md5 from room_document where room = :room and document_md5 = :md5 limit 1;",
            values={"room": room.room_id, "md5": md5},
        )
        if len(rows) > 0:
            already_in_room = True
        else:
            await client.db.execute(
                query="""
                insert into room_document (room, document_md5)
                values (:room_id, :md5)
                on conflict (room, document_md5) do nothing
                ;""",
                values={"room_id": room.room_id, "md5": md5},
            )

            # chunk + embed
            chunks = spliter.split_text(content)
            print("chunks", len(chunks))
            embeddings = await embeddings_model.aembed_documents(chunks, chunk_size=1600)
            print("embedding finished", len(embeddings))
            if len(chunks) != len(embeddings):
                # was ValueError("asdf") — placeholder message
                raise ValueError("chunk/embedding count mismatch")
            insert_data: list[dict] = [
                {
                    "document_md5": md5,
                    "md5": hashlib.md5(chunk.encode()).hexdigest(),
                    "content": chunk,
                    "token": client.get_token_length(chunk),
                    "embedding": str(embedding),
                }
                for chunk, embedding in zip(chunks, embeddings)
            ]
            await client.db.execute_many(
                query="""insert into embeddings (document_md5, md5, content, token, embedding)
                    values (:document_md5, :md5, :content, :token, :embedding)
                    on conflict (document_md5, md5) do nothing
                ;""",
                values=insert_data,
            )
            print("insert", len(insert_data), "embedding data")

            await client.db.execute(
                query="""
                insert into event_document (event, document_md5)
                values (:event_id, :md5)
                on conflict (event) do nothing
                ;""",
                values={"event_id": event.event_id, "md5": md5},
            )
    except Exception:
        await transaction.rollback()
        raise

    if already_in_room:
        await transaction.rollback()
        print("document already inserted in room", md5, room.room_id)
        await client.room_send(
            room.room_id,
            "m.reaction",
            {
                "m.relates_to": {
                    "event_id": event.event_id,
                    "key": "👍",
                    "rel_type": "m.annotation",
                }
            },
        )
        return

    await transaction.commit()

    await client.room_send(
        room.room_id,
        "m.reaction",
        {
            "m.relates_to": {
                "event_id": event.event_id,
                "key": "😘",
                "rel_type": "m.annotation",
            }
        },
    )


def clean_html(html: str) -> str:
    """Convert HTML to plain markdown-ish text, dropping links and images."""
    h2t = html2text.HTML2Text()
    h2t.ignore_emphasis = True
    h2t.ignore_images = True
    h2t.ignore_links = True
    h2t.body_width = 0  # no hard wrapping
    content = h2t.handle(html)
    return content
def clean_content(content: str, mimetype: str, document_md5: str) -> str:
    """Normalise extracted document text before chunking/embedding.

    Strips NUL bytes, markdown links/images and empty paragraphs, and removes
    per-line indentation. *mimetype* and *document_md5* are currently unused
    but kept for the callers' interface.
    """
    # Postgres text columns reject NUL bytes
    content = content.replace("\x00", "")
    # remove image syntax FIRST: the link regex would otherwise consume the
    # "[alt](url)" part of "![alt](url)" and leave a stray "!" behind
    content = re.sub(r"!\[.*?\]\(.*?\)", "", content)
    content = re.sub(r"\[.*?\]\(.*?\)", "", content)
    # drop empty paragraphs
    lines = [i.strip() for i in content.split("\n\n") if i.strip()]
    content = "\n\n".join(lines)
    # strip per-line indentation
    content = "\n".join([i.strip() for i in content.split("\n")])

    return content


def pdf_to_text(f) -> str:
    """Concatenate the extracted text of every page of a PDF file object."""
    pdf_reader = PyPDF2.PdfReader(f)

    content = ""
    for page in pdf_reader.pages:
        content += page.extract_text()
    return content
@client.ignore_self_message
@client.handel_no_gpt
@client.log_message
@client.with_typing
@client.replace_command_mark
@client.safe_try
async def message_file(room: MatrixRoom, event: RoomMessageFile):
    """Download an uploaded document, extract its text and embed it."""
    print("received file")
    mimetype = event.flattened().get("content.info.mimetype", "")
    if not allowed_file(mimetype):
        print("not allowed file", event.body)
        raise ValueError("not allowed file")
    resp = await client.download(event.url)
    if isinstance(resp, DownloadError):
        raise ValueError("file download error")

    assert isinstance(resp.body, bytes)
    md5 = hashlib.md5(resp.body).hexdigest()

    # was db.execute(), which does not return rows, and the "already exists"
    # flag was computed inverted; fetch the cached document properly
    document_rows = await client.db.fetch_all(
        query="select content from documents where md5 = :md5;", values={"md5": md5}
    )

    content = ""
    if len(document_rows) > 0:
        # already extracted once — reuse the stored text
        print("document", md5, "already exists")
        content = document_rows[0][0]
    else:
        if mimetype == "text/plain" or mimetype == "text/markdown":
            content = resp.body.decode()
        elif mimetype == "text/html":
            content = clean_html(resp.body.decode())
        elif mimetype == "application/pdf":
            content = pdf_to_text(io.BytesIO(resp.body))
        elif mimetype in offices_mimetypes:
            content = _office_to_text(event.body, resp.body)
        else:
            raise ValueError("unknown mimetype", mimetype)

    content = clean_content(content, mimetype, md5)

    print("content length", len(content))

    await create_embedding(room, event, md5, content, event.url)


def _office_to_text(filename: str, body: bytes) -> str:
    """Convert an Office document to plain text via headless LibreOffice."""
    if "." not in filename:
        raise ValueError("filename has no extension", filename)
    base, ext = filename.rsplit(".", 1)
    print("base", base)
    os.makedirs("./cache/office", exist_ok=True)  # original assumed the dir existed
    source_filepath = os.path.join("./cache/office", filename)
    print("source_filepath", source_filepath)
    with open(source_filepath, "wb") as f:
        f.write(body)
    if ext in ["doc", "docx", "odt"]:
        # word-like formats convert directly to text
        subprocess.run(
            [
                "soffice",
                "--headless",
                "--convert-to",
                "txt:Text",
                "--outdir",
                "./cache/office",
                source_filepath,
            ],
            check=True,
        )
        with open(os.path.join("./cache/office", base + ".txt"), "r") as f:
            content = f.read()
    elif ext in ["ppt", "pptx", "odp"]:
        # slides keep more text when rendered to pdf first
        subprocess.run(
            [
                "soffice",
                "--headless",
                "--convert-to",
                "pdf",
                "--outdir",
                "./cache/office",
                source_filepath,
            ],
            check=True,
        )
        with open(os.path.join("./cache/office", base + ".pdf"), "rb") as f:
            content = pdf_to_text(f)
    else:
        raise ValueError("unknown ext: ", ext)
    print("converted txt", content)
    return content


client.add_event_callback(message_file, RoomMessageFile)

yt_dlp_support = ["b23.tv/", "www.bilibili.com/video/", "youtube.com/"]


def allow_yt_dlp(link: str) -> bool:
    """True when the link is a supported video-site URL."""
    if not link.startswith("http://") and not link.startswith("https://"):
        return False
    return any(u in link for u in yt_dlp_support)


def allow_web(link: str) -> bool:
    """True when the link is a plain http(s) URL (scraped via Firefox)."""
    print("checking web url", link)
    if not link.startswith("http://") and not link.startswith("https://"):
        return False
    return True


@client.message_callback_common_wrapper
async def message_text(room: MatrixRoom, event: RoomMessageText) -> None:
    """Handle !commands, video links (transcribe) and web links (scrape)."""
    if event.body.startswith("!"):
        should_react = True
        if event.body.startswith("!clear") or event.body.startswith("!clean"):
            # drop every embedding linked to this room
            async with client.db.transaction():
                await client.db.execute(
                    query="""
                    delete from embeddings e
                    using room_document rd
                    where e.document_md5 = rd.document_md5 and
                        rd.room = :room_id;
                    """,
                    values={"room_id": room.room_id},
                )
                await client.db.execute(
                    query="delete from room_document where room = :room_id;",
                    values={"room_id": room.room_id},
                )
        elif event.body.startswith("!embedding"):
            sp = event.body.split()
            if len(sp) < 2:
                return
            if not sp[1].lower() in ["on", "off"]:
                return
            status = sp[1].lower() == "on"
            await client.db.execute(
                query="""
                insert into room_configs (room, embedding)
                values (:room_id, :status)
                on conflict (room) do update set embedding = excluded.embedding
                ;""",
                values={"room_id": room.room_id, "status": status},
            )
        else:
            should_react = False
        if should_react:
            await client.room_send(
                room.room_id,
                "m.reaction",
                {
                    "m.relates_to": {
                        "event_id": event.event_id,
                        "key": "😘",
                        "rel_type": "m.annotation",
                    }
                },
            )
        return

    if allow_yt_dlp(event.body.split()[0]):
        # download best audio and transcribe it
        ydl_opts = {
            "format": "wa*",
            "postprocessors": [
                {  # Extract audio using ffmpeg
                    "key": "FFmpegExtractAudio",
                }
            ],
        }

        url = event.body.split()[0]

        os.makedirs("./cache/yt-dlp", exist_ok=True)  # original assumed the dir existed
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
        filepath = info["requested_downloads"][0]["filepath"]
        filename = info["requested_downloads"][0]["filename"]
        title = info["title"]
        realfilepath = os.path.join("./cache/yt-dlp", filename)
        os.rename(filepath, realfilepath)

        # was open() without close — leaked the file handle
        with open(realfilepath, "rb") as audio_f:
            result = openai.Audio.transcribe(
                file=audio_f,
                model="large-v2",
                prompt=title,
            )
        result = "\n".join([i.text for i in result["segments"]])
        print(event.sender, result)

        md5 = hashlib.md5(result.encode()).hexdigest()

        await create_embedding(room, event, md5, result, url)
        return

    if allow_web(event.body.split()[0]):
        url = event.body.split()[0]
        print("downloading", url)
        html = await get_html(url)
        md5 = hashlib.md5(html.encode()).hexdigest()
        content = clean_html(html)
        content = clean_content(content, "text/markdown", md5)
        if not content:
            raise ValueError("Empty content")
        print(content)
        await create_embedding(room, event, md5, content, url)
        return


client.add_event_callback(message_text, RoomMessageText)

asyncio.run(client.sync_forever())
@client.ignore_self_message
@client.handel_no_gpt
@client.log_message
@client.with_typing
@client.change_event_id_to_root_id
@client.replace_command_mark
@client.safe_try
async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None:
    """Voice a ChatGPT bot reply: synthesize event.body with Google TTS and
    post the resulting OGG/Opus audio into the same thread."""
    # Only react to messages sent by the chatgpt bot account.
    if not event.sender.startswith("@chatgpt-bot"):
        return

    audio = await tts(event.body)
    # Upload the synthesized audio to the homeserver's media store.
    resp, upload = await client.upload(BytesIO(audio), "audio/ogg", filesize=len(audio))
    content = {
        "msgtype": "m.audio",
        # Fix: event.body[16] indexed a single character; the intent is a
        # 16-character preview, i.e. the slice event.body[:16].
        "body": event.body if len(event.body) < 20 else event.body[:16] + "...",
        "info": {"mimetype": "audio/ogg", "size": len(audio)},
        "url": resp.content_uri,
        "m.relates_to": {
            "rel_type": "m.thread",
            "event_id": event.event_id,
        },
    }
    await client.room_send(room.room_id, message_type="m.room.message", content=content)
# Mimetype subtypes (e.g. the "mpeg" of "audio/mpeg") that the whisper
# backend accepts for transcription.
ALLOWED_EXTENSIONS = {
    "mp3",
    "mp4",
    "mpeg",
    "mpga",
    "m4a",
    "wav",
    "webm",
    "3gp",
    "flac",
    "ogg",
    "mkv",
}


def allowed_file(mimetype):
    """Return True when *mimetype* (e.g. "audio/mpeg") has a subtype in
    ALLOWED_EXTENSIONS.

    Fix: callers pass ``event.flattened().get("content.info.mimetype")``,
    which can be None; the bare ``"/" in mimetype`` test then raised
    TypeError. A missing/empty mimetype is now simply not allowed.
    """
    if not mimetype:
        return False
    return "/" in mimetype and mimetype.rsplit("/", 1)[1].lower() in ALLOWED_EXTENSIONS


def get_txt_filename(filename):
    """Name of the transcript file accompanying *filename*."""
    return filename + ".txt"
async def message_file(room: MatrixRoom, event: RoomMessageFile):
    """Transcribe an uploaded audio/video file and post the transcript back
    as a .txt attachment in the same thread.

    Raises Exception when the file's mimetype is not an allowed media type.
    """
    print("received file")
    if not allowed_file(event.flattened().get("content.info.mimetype")):
        print("not allowed file", event.body)
        raise Exception("not allowed file")
    resp = await client.download(event.url)
    if isinstance(resp, DownloadError):
        return
    filelikeobj = io.BytesIO(resp.body)
    filelikeobj.name = event.body

    # Thread memory (newest-first window under 3039 tokens) primes whisper.
    # Fix: databases' Database.execute() returns a scalar, not rows —
    # fetch_all() is required to iterate the result set (as already done
    # for room_configs elsewhere in this file).
    rows = await client.db.fetch_all(
        query="""
        select content from (
            select role, content, sum(token) over (partition by root order by id desc) as total_token
            from memories
            where root = :event_id
            order by id
        ) as sub
        where total_token < 3039
        ;""",
        values={"event_id": event.event_id},
    )
    prompt = "".join([i[0] for i in rows])
    print("initial_prompt", prompt)

    result = openai.Audio.transcribe(file=filelikeobj, model="large-v2", prompt=prompt)
    # NOTE(review): segment items are read as objects (.text), matching the
    # rest of this project — confirm against the whisper backend in use.
    result = "\n".join([i.text for i in result["segments"]])
    print(event.sender, result)

    # Encode once; the bytes are needed for both the upload body and its size.
    result_bytes = result.encode()
    resultfilelike = io.BytesIO(result_bytes)
    resultfilelike.name = get_txt_filename(event.body)
    resultfileSize = len(result_bytes)
    uploadResp, _ = await client.upload(
        resultfilelike, content_type="text/plain", filesize=resultfileSize
    )
    print("uri", uploadResp.content_uri)

    await client.room_send(
        room.room_id,
        "m.room.message",
        {
            "body": resultfilelike.name,
            "filename": resultfilelike.name,
            "msgtype": "m.file",
            "info": {
                "mimetype": "text/plain",
                "size": resultfileSize,
            },
            "m.relates_to": {
                "rel_type": "m.thread",
                "event_id": event.event_id,
            },
            "url": uploadResp.content_uri,
        },
    )
bot_db.py + + bot-whisper: + image: matrix-chain + env_file: + - ./.env + command: python3 bot_whisper.py + + bot-tts: + image: matrix-chain + env_file: + - ./.env + command: python3 bot_tts.py \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..31038d5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +databases +aiopg +matrix-nio +langchain +python-dotenv +tiktoken +PyPDF2 +html2text +yt_dlp +selenium +openai +langdetect +jinja2 diff --git a/requirements_version.txt b/requirements_version.txt new file mode 100644 index 0000000..cc2e13f --- /dev/null +++ b/requirements_version.txt @@ -0,0 +1,72 @@ +aiofiles==23.2.1 +aiohttp==3.8.6 +aiohttp-socks==0.7.1 +aiopg==1.4.0 +aiosignal==1.3.1 +annotated-types==0.6.0 +anyio==3.7.1 +async-timeout==4.0.3 +attrs==23.1.0 +Brotli==1.1.0 +certifi==2023.7.22 +charset-normalizer==3.3.0 +databases==0.8.0 +dataclasses-json==0.6.1 +frozenlist==1.4.0 +greenlet==3.0.0 +h11==0.14.0 +h2==4.1.0 +hpack==4.0.0 +html2text==2020.1.16 +hyperframe==6.0.1 +idna==3.4 +Jinja2==3.1.2 +jsonpatch==1.33 +jsonpointer==2.4 +jsonschema==4.19.1 +jsonschema-specifications==2023.7.1 +langchain==0.0.319 +langdetect==1.0.9 +langsmith==0.0.49 +MarkupSafe==2.1.3 +marshmallow==3.20.1 +matrix-nio==0.22.1 +multidict==6.0.4 +mutagen==1.47.0 +mypy-extensions==1.0.0 +numpy==1.26.1 +openai==0.28.1 +outcome==1.3.0 +packaging==23.2 +psycopg2-binary==2.9.9 +pycryptodome==3.19.0 +pycryptodomex==3.19.0 +pydantic==2.4.2 +pydantic_core==2.10.1 +PyPDF2==3.0.1 +PySocks==1.7.1 +python-dotenv==1.0.0 +python-socks==2.4.3 +PyYAML==6.0.1 +referencing==0.30.2 +regex==2023.10.3 +requests==2.31.0 +rpds-py==0.10.6 +selenium==4.14.0 +six==1.16.0 +sniffio==1.3.0 +sortedcontainers==2.4.0 +SQLAlchemy==1.4.49 +tenacity==8.2.3 +tiktoken==0.5.1 +tqdm==4.66.1 +trio==0.22.2 +trio-websocket==0.11.1 +typing-inspect==0.9.0 +typing_extensions==4.8.0 +unpaddedbase64==2.1.0 +urllib3==2.0.7 +websockets==11.0.3 +wsproto==1.2.0 +yarl==1.9.2 
+yt-dlp==2023.10.13