init
4 .dockerignore Normal file
@@ -0,0 +1,4 @@
__pycache__
/venv
/.env
.git
4 .gitignore vendored Normal file
@@ -0,0 +1,4 @@
__pycache__
/venv
/.env
.git
19 Dockerfile Normal file
@@ -0,0 +1,19 @@
FROM python:3.11-slim-buster

# Update the system
RUN apt-get update -y

# Install LibreOffice and Firefox ESR
RUN apt-get install -y libreoffice firefox-esr

# Set the working directory in the container to /app
WORKDIR /app

COPY ./requirements_version.txt /app/requirements_version.txt
RUN pip3 install --no-cache-dir -r requirements_version.txt

# cache tiktoken dict
RUN python3 -c 'import tiktoken; enc = tiktoken.get_encoding("cl100k_base")'

# Add the current directory contents into the container at /app
COPY . /app
315 bot.py Normal file
@@ -0,0 +1,315 @@
import dotenv

dotenv.load_dotenv()
import os
import traceback
from functools import wraps
from nio import AsyncClient, MatrixRoom, RoomMessageText, SyncError, InviteEvent
import tiktoken
import databases
import builtins
import sys


def print(*args, **kwargs):
    kwargs["file"] = sys.stderr
    builtins.print(*args, **kwargs)


class Bot(AsyncClient):
    def __init__(self, homeserver: str, user: str, device_id: str, access_token: str):
        super().__init__(homeserver)
        self.access_token = access_token
        self.user_id = self.user = user
        self.device_id = device_id
        self.welcome_message = ""
        self._joined_rooms = []
        self.db = databases.Database(os.environ["MATRIX_CHAIN_DB"])

        self.enc = tiktoken.encoding_for_model("gpt-4")

        # auto join
        self.add_event_callback(self.auto_join, InviteEvent)

    async def init_db(self):
        db = self.db
        await db.execute("CREATE EXTENSION IF NOT EXISTS vector")
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS documents
            (
                md5 character(32) NOT NULL PRIMARY KEY,
                content text,
                token integer,
                url text
            )
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS embeddings
            (
                document_md5 character(32) NOT NULL,
                md5 character(32) NOT NULL,
                content text NOT NULL,
                token integer NOT NULL,
                embedding vector(1536) NOT NULL,
                PRIMARY KEY (document_md5, md5),
                FOREIGN KEY (document_md5) REFERENCES documents(md5)
            );
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS event_document
            (
                event text NOT NULL PRIMARY KEY,
                document_md5 character(32) NOT NULL,
                FOREIGN KEY (document_md5)
                    REFERENCES documents (md5)
            );
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS memories
            (
                id SERIAL PRIMARY KEY,
                root text NOT NULL,
                role integer NOT NULL,
                content text NOT NULL,
                token integer NOT NULL
            )
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS room_configs
            (
                room text NOT NULL PRIMARY KEY,
                model_name text,
                temperature float NOT NULL DEFAULT 0,
                system text,
                embedding boolean NOT NULL DEFAULT false,
                examples TEXT[] NOT NULL DEFAULT '{}'
            )
            """
        )
        await db.execute(
            """
            CREATE TABLE IF NOT EXISTS room_document (
                room text NOT NULL,
                document_md5 character(32) NOT NULL,
                PRIMARY KEY (room, document_md5)
            );
            """
        )

    def get_token_length(self, text: str) -> int:
        return len(self.enc.encode(text))

    async def sync_forever(self):
        # init
        print("connecting to db")
        await self.db.connect()

        # init db hook
        print("init db hook")
        await self.init_db()

        # remove callbacks to perform the initial sync
        callbacks = self.event_callbacks
        self.event_callbacks = []
        # skip initial sync
        print("Perform initial sync")
        resp = await self.sync(timeout=30000)
        if isinstance(resp, SyncError):
            raise RuntimeError(f"initial sync failed: {resp}")
        self.event_callbacks = callbacks
        # set online
        print("Set online status")
        await self.set_presence("online")
        # sync
        print("Start forever sync")
        return await super().sync_forever(300000, since=resp.next_batch)

    async def auto_join(self, room: MatrixRoom, event: InviteEvent):
        print("join", event.sender, room.room_id)
        if room.room_id in self._joined_rooms:
            return
        await self.join(room.room_id)
        self._joined_rooms.append(room.room_id)
        if self.welcome_message:
            await self.room_send(
                room.room_id,
                "m.room.message",
                {
                    "body": self.welcome_message,
                    "msgtype": "m.text",
                    "nogpt": True,
                },
            )

    def ignore_self_message(self, func):
        @wraps(func)
        async def ret(room: MatrixRoom, event: RoomMessageText):
            if event.sender == self.user:
                return
            return await func(room, event)

        return ret

    def log_message(self, func):
        @wraps(func)
        async def ret(room: MatrixRoom, event: RoomMessageText):
            print(room.room_id, event.sender, event.body)
            return await func(room, event)

        return ret

    def with_typing(self, func):
        @wraps(func)
        async def ret(room, *args, **kargs):
            await self.room_typing(room.room_id, True, 60000 * 3)
            resp = await func(room, *args, **kargs)
            await self.room_typing(room.room_id, False)
            return resp

        return ret

    def change_event_id_to_root_id(self, func):
        @wraps(func)
        async def ret(room, event, *args, **kargs):
            root = event.event_id
            if event.flattened().get("content.m.relates_to.rel_type") == "m.thread":
                root = event.source["content"]["m.relates_to"]["event_id"]
            event.event_id = root
            return await func(room, event, *args, **kargs)

        return ret

    def ignore_not_mentioned(self, func):
        @wraps(func)
        async def ret(room, event, *args, **kargs):
            flattened = event.flattened()
            if self.user not in flattened.get(
                "content.body", ""
            ) and self.user not in flattened.get("content.formatted_body", ""):
                return
            return await func(room, event, *args, **kargs)

        return ret

    def replace_command_mark(self, func):
        @wraps(func)
        async def ret(room, event, *args, **kargs):
            if event.body.startswith("!"):
                event.body = "!" + event.body[1:]
            return await func(room, event, *args, **kargs)

        return ret

    def handel_no_gpt(self, func):
        @wraps(func)
        async def ret(room, event, *args, **kargs):
            if event.flattened().get("content.nogpt") is not None:
                return
            return await func(room, event, *args, **kargs)

        return ret

    def replace_reply_file_with_content(self, func):
        @wraps(func)
        async def ret(room, event, *args, **kargs):
            flattened = event.flattened()
            formatted_body = flattened.get("content.formatted_body", "")
            if (
                formatted_body.startswith("<mx-reply>")
                and "</mx-reply>" in formatted_body
            ):
                print("replacing file content")
                formatted_body = formatted_body[
                    formatted_body.index("</mx-reply>") + len("</mx-reply>") :
                ]
                document_event_id = flattened.get(
                    "content.m.relates_to.m.in_reply_to.event_id", ""
                )
                fetch = await self.db.fetch_all(
                    query="""select d.content, d.token
                        from documents d
                        join event_document ed on d.md5 = ed.document_md5
                        where ed.event = :event_id""",
                    values={"event_id": document_event_id},
                )
                if len(fetch) > 0 and fetch[0][1] < 8192 + 4096:
                    content = fetch[0][0]
                    print(content)
                    print("-----------")
                    event.body = content + "\n\n---\n\n" + formatted_body
                else:
                    print("document not found or too large", event.event_id)

            return await func(room, event, *args, **kargs)

        return ret

    def ignore_link(self, func):
        @wraps(func)
        async def ret(room, event, *args, **kargs):
            if event.body.startswith("https://") or event.body.startswith("http://"):
                return
            return await func(room, event, *args, **kargs)

        return ret

    def safe_try(self, func):
        @wraps(func)
        async def ret(room, event, *args, **kargs):
            try:
                return await func(room, event, *args, **kargs)
            except Exception:
                print("--------------")
                print("error:")
                print(traceback.format_exc())
                print("--------------")
                await self.room_send(
                    room.room_id,
                    "m.reaction",
                    {
                        "m.relates_to": {
                            "event_id": event.event_id,
                            "key": "😵",
                            "rel_type": "m.annotation",
                        }
                    },
                )

        return ret

    def message_callback_common_wrapper(self, func):
        @wraps(func)
        @self.ignore_self_message
        @self.handel_no_gpt
        @self.log_message
        @self.with_typing
        @self.replace_reply_file_with_content
        @self.change_event_id_to_root_id
        @self.replace_command_mark
        @self.safe_try
        async def ret(*args, **kargs):
            return await func(*args, **kargs)

        return ret

    async def react_ok(self, room_id: str, event_id: str):
        await self.room_send(
            room_id,
            "m.reaction",
            {
                "m.relates_to": {
                    "event_id": event_id,
                    "key": "😘",
                    "rel_type": "m.annotation",
                }
            },
        )
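For orientation, a minimal sketch of how the plugin scripts below wire a handler through this class and its common decorator stack; the echo reply and the reuse of the BOT_CHATGPT_* variables are illustrative assumptions, while the constructor, decorator, send and sync calls mirror bot_chatgpt.py.

import asyncio
import os

from nio import MatrixRoom, RoomMessageText

from bot import Bot

client = Bot(
    os.environ["BOT_CHATGPT_HOMESERVER"],
    os.environ["BOT_CHATGPT_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_CHATGPT_ACCESS_TOKEN"],
)
client.welcome_message = "Hello, I was invited here."  # assumed placeholder text


# The wrapper stacks: ignore own messages, honor "nogpt", log, typing indicator,
# reply-file expansion, thread-root rewriting, command-mark handling, error reaction.
@client.message_callback_common_wrapper
async def on_message(room: MatrixRoom, event: RoomMessageText) -> None:
    # Reply in the same thread; the real bots build an LLM answer here instead of an echo.
    await client.room_send(
        room.room_id,
        "m.room.message",
        {
            "body": f"echo: {event.body}",
            "msgtype": "m.text",
            "m.relates_to": {"rel_type": "m.thread", "event_id": event.event_id},
        },
    )


client.add_event_callback(on_message, RoomMessageText)
asyncio.run(client.sync_forever())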
401 bot_chatgpt.py Normal file
@@ -0,0 +1,401 @@
import os
import dotenv

dotenv.load_dotenv()
import asyncio
import jinja2
import requests
import datetime
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from nio import MatrixRoom, RoomMessageFile, RoomMessageText
from langchain.chat_models import ChatOpenAI
import json
from langchain import LLMChain
from langchain.prompts import ChatPromptTemplate

from bot import Bot, print

from langchain.embeddings import OpenAIEmbeddings

embeddings_model = OpenAIEmbeddings(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_base=os.environ["OPENAI_API_BASE"],
    show_progress_bar=True,
)

client = Bot(
    os.environ["BOT_CHATGPT_HOMESERVER"],
    os.environ["BOT_CHATGPT_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_CHATGPT_ACCESS_TOKEN"],
)
client.welcome_message = """你好👋,我是 matrix chain 中的大语言模型插件
## 使用方式:
- 直接在房间内发送消息,GPT 会在消息列中进行回复。GPT 会单独记住每个消息列中的所有内容,每个消息列单独存在互不干扰
## 配置方式:
- 发送 "!system + 系统消息" 配置大语言模型的角色,例如发送 "!system 你是一个专业英语翻译,你要把我说的话翻译成英语。你可以调整语序结构和用词让翻译更加通顺。"
"""


class SilentUndefined(jinja2.Undefined):
    def _fail_with_undefined_error(self, *args, **kwargs):
        print(f'jinja2.Undefined: "{self._undefined_name}" is undefined')
        return ""


def render(template: str, **kargs) -> str:
    env = jinja2.Environment(undefined=SilentUndefined)
    temp = env.from_string(template)

    def now() -> str:
        return datetime.datetime.now().strftime("%Y-%m-%d")

    temp.globals["now"] = now
    return temp.render(**kargs)


async def get_reply_file_content(event):
    """When the user replies to an event, retrieve the file content (document) of that event.

    Returns the file content and its token length.
    """
    flattened = event.flattened()
    formatted_body = flattened.get("content.formatted_body", "")
    if not (
        formatted_body.startswith("<mx-reply>") and "</mx-reply>" in formatted_body
    ):
        return "", 0

    print("replacing file content")
    formatted_body = formatted_body[
        formatted_body.index("</mx-reply>") + len("</mx-reply>") :
    ]
    document_event_id = flattened.get("content.m.relates_to.m.in_reply_to.event_id", "")
    fetch = await client.db.fetch_one(
        query="""select d.content, d.token
            from documents d
            join event_document ed on d.md5 = ed.document_md5
            where ed.event = :document_event_id""",
        values={
            "document_event_id": document_event_id,
        },
    )

    if fetch and fetch[1] < 8192 + 4096:
        content = fetch[0]
        token = fetch[1]
        print(content)
        print(token)
        print("-----------")
        return content, token

    print("document not found or too large", event.event_id)
    return "", 0


@client.ignore_link
@client.message_callback_common_wrapper
async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None:
    # handle set system message
    if event.body.startswith("!"):
        if event.body.startswith("!system"):
            systemMessageContent = event.body.removeprefix("!system").strip()
            # save to db
            await client.db.execute(
                query="""
                    insert into room_configs (room, system, examples)
                    values (:room_id, :systemMessageContent, '{}')
                    on conflict (room)
                    do update set system = excluded.system, examples = '{}'
                """,
                values={
                    "room_id": room.room_id,
                    "systemMessageContent": systemMessageContent,
                },
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        if event.body.startswith("!model"):
            model_name = event.body.removeprefix("!model").strip()
            # save to db
            await client.db.execute(
                query="""
                    insert into room_configs (room, model_name)
                    values (:room_id, :model_name)
                    on conflict (room)
                    do update set model_name = excluded.model_name
                """,
                values={"room_id": room.room_id, "model_name": model_name},
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        if event.body.startswith("!temp"):
            temperature = float(event.body.removeprefix("!temp").strip())
            # save to db
            await client.db.execute(
                query="""
                    insert into room_configs (room, temperature)
                    values (:room_id, :temperature)
                    on conflict (room)
                    do update set temperature = excluded.temperature
                """,
                values={"room_id": room.room_id, "temperature": temperature},
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        return

    messages: list[BaseMessage] = []
    # query prompt from db
    db_result = await client.db.fetch_one(
        query="""
            select system, examples, model_name, temperature
            from room_configs
            where room = :room_id
            limit 1
        """,
        values={"room_id": room.room_id},
    )
    model_name: str = db_result[2] if db_result else ""
    temperature: float = db_result[3] if db_result else 0

    systemMessageContent: str = db_result[0] if db_result else ""
    systemMessageContent = systemMessageContent or ""
    if systemMessageContent:
        messages.append(SystemMessage(content=systemMessageContent))

    examples = db_result[1] if db_result else []
    for i, m in enumerate(examples):
        if not m:
            print("Warning: message is empty", m)
            continue
        if i % 2 == 0:
            messages.append(HumanMessage(content=m["content"], example=True))
        else:
            messages.append(AIMessage(content=m["content"], example=True))

    exampleTokens = 0
    exampleTokens += sum(client.get_token_length(m) for m in examples if m)

    # get embedding
    embedding_query = await client.db.fetch_all(
        query="""
            select content, distance, total_token from (
                select
                    content,
                    document_md5,
                    distance,
                    sum(token) over (partition by room order by distance) as total_token
                from (
                    select
                        content,
                        rd.room,
                        e.document_md5,
                        e.embedding <#> :embedding as distance,
                        token
                    from embeddings e
                    join room_document rd on rd.document_md5 = e.document_md5
                    join room_configs rc on rc.room = rd.room
                    where rd.room = :room_id and rc.embedding
                    order by distance
                    limit 16
                ) as sub
            ) as sub2
            where total_token < 6144
        ;""",
        values={
            "embedding": str(await embeddings_model.aembed_query(event.body)),
            "room_id": room.room_id,
        },
    )
    print("embedding_query", embedding_query)
    embedding_token = 0
    embedding_text = ""
    if len(embedding_query) > 0:
        embedding_query.reverse()
        embedding_text = "\n\n".join([i[0] for i in embedding_query])
        embedding_token = client.get_token_length(embedding_text)

    filecontent, filetoken = await get_reply_file_content(event)

    # query memory from db
    max_token = 4096 * 4
    token_margin = 4096
    system_token = client.get_token_length(systemMessageContent) + exampleTokens
    memory_token = max_token - token_margin - system_token - embedding_token - filetoken
    print(
        "system_token",
        system_token,
        "embedding_token",
        embedding_token,
        "filetoken",
        filetoken,
    )
    rows = await client.db.fetch_all(
        query="""select role, content from (
            select role, content, sum(token) over (partition by root order by id desc) as total_token
            from memories
            where root = :root
            order by id
        ) as sub
        where total_token < :token
        ;""",
        values={"root": event.event_id, "token": memory_token},
    )
    for role, content in rows:
        if role == 1:
            messages.append(HumanMessage(content=content))
        elif role == 2:
            messages.append(AIMessage(content=content))
        else:
            print("Unknown message role", role, content)

    temp = "{{input}}"
    if filecontent and embedding_text:
        temp = """## Reference information:

{{embedding}}

---

## Query document:

{{filecontent}}

---

{{input}}"""
    elif embedding_text:
        temp = """## Reference information:

{{embedding}}

---

{{input}}"""
    elif filecontent:
        temp = """## Query document:

{{filecontent}}

---

{{input}}"""
    temp = render(
        temp, input=event.body, embedding=embedding_text, filecontent=filecontent
    )
    messages.append(HumanMessage(content=temp))

    total_token = (
        sum(client.get_token_length(m.content) for m in messages) + len(messages) * 6
    )
    if not model_name:
        model_name = "gpt-3.5-turbo" if total_token < 3939 else "gpt-3.5-turbo-16k"

    print("messages", messages)
    chat_model = ChatOpenAI(
        openai_api_base=os.environ["OPENAI_API_BASE"],
        openai_api_key=os.environ["OPENAI_API_KEY"],
        model=model_name,
        temperature=temperature,
    )
    chain = LLMChain(llm=chat_model, prompt=ChatPromptTemplate.from_messages(messages))

    result = await chain.arun(
        {
            "input": event.body,
            "embedding": embedding_text,
            "filecontent": filecontent,
        }
    )
    print(result)

    await client.room_send(
        room.room_id,
        "m.room.message",
        {
            "body": result,
            "msgtype": "m.text",
            "m.relates_to": {
                "rel_type": "m.thread",
                "event_id": event.event_id,
            },
        },
    )

    # record query and result
    await client.db.execute_many(
        query="insert into memories(root, role, content, token) values (:root, :role, :content, :token)",
        values=[
            {
                "root": event.event_id,
                "role": 1,
                "content": event.body,
                "token": client.get_token_length(event.body),
            },
            {
                "root": event.event_id,
                "role": 2,
                "content": result,
                "token": client.get_token_length(result),
            },
        ],
    )


client.add_event_callback(message_callback, RoomMessageText)


@client.ignore_self_message
@client.handel_no_gpt
@client.log_message
@client.with_typing
@client.replace_command_mark
@client.safe_try
async def message_file(room: MatrixRoom, event: RoomMessageFile):
    if not event.flattened().get("content.info.mimetype") == "application/json":
        print("not application/json")
        return
    size = event.flattened().get("content.info.size", 1024 * 1024 + 1)
    if size > 1024 * 1024:
        print("json file too large")
        return
    print("event url", event.url)
    j = requests.get(
        f'https://yongyuancv.cn/_matrix/media/r0/download/yongyuancv.cn/{event.url.rsplit("/", 1)[-1]}'
    ).json()
    if j.get("chatgpt_api_web_version") is None:
        print("not chatgpt-api-web chatstore export file")
        return
    if j["chatgpt_api_web_version"] < "v1.5.0":
        raise ValueError(j["chatgpt_api_web_version"])
    examples = [m["content"] for m in j["history"] if m["example"]]
    await client.db.execute(
        query="""
            insert into room_configs (room, system, examples)
            values (:room_id, :system, :examples)
            on conflict (room) do update set system = excluded.system, examples = excluded.examples
        """,
        values={
            "room_id": room.room_id,
            "system": j["systemMessageContent"],
            "examples": str(examples),
        },
    )

    await client.room_send(
        room.room_id,
        "m.reaction",
        {
            "m.relates_to": {
                "event_id": event.event_id,
                "key": "😘",
                "rel_type": "m.annotation",
            }
        },
    )


client.add_event_callback(message_file, RoomMessageFile)

asyncio.run(client.sync_forever())
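A worked example of the token budget that sizes how much thread history message_callback loads from the memories table; only the formula comes from the file above, the individual counts below are assumptions for illustration.

# Worked example of the memory budget in message_callback (counts are assumed):
max_token = 4096 * 4      # 16384-token overall budget
token_margin = 4096       # headroom reserved for the model's reply
system_token = 500        # system prompt plus few-shot examples (assumed)
embedding_token = 1000    # retrieved pgvector reference chunks (assumed)
filetoken = 2000          # replied-to document content (assumed)

memory_token = max_token - token_margin - system_token - embedding_token - filetoken
print(memory_token)       # 8788: prior messages in this thread may use up to this many tokens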
444 bot_db.py Normal file
@@ -0,0 +1,444 @@
import PyPDF2
import html2text
import re
import hashlib
from nio import (
    DownloadError,
    MatrixRoom,
    RoomMessageAudio,
    RoomMessageFile,
    RoomMessageText,
)
from langchain.text_splitter import MarkdownTextSplitter
from bot import Bot, print
import asyncio
import io
import yt_dlp
import os
import subprocess
from langchain.embeddings import OpenAIEmbeddings

from selenium import webdriver

print("launching driver")
options = webdriver.FirefoxOptions()
options.add_argument("-headless")
driver = webdriver.Firefox(options=options)


async def get_html(url: str) -> str:
    driver.get(url)
    await asyncio.sleep(3)
    return driver.page_source or ""


import openai

embeddings_model = OpenAIEmbeddings(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_base=os.environ["OPENAI_API_BASE"],
    show_progress_bar=True,
)

client = Bot(
    os.environ["BOT_DB_HOMESERVER"],
    os.environ["BOT_DB_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_DB_ACCESS_TOKEN"],
)
client.welcome_message = """欢迎使用 matrix chain db 插件,我能将房间中的所有文件添加进embedding数据库,并为gpt提供支持
## 使用方式
- 发送文件或视频链接
目前支持文件格式:txt / pdf / md / doc / docx / ppt / pptx
目前支持视频链接:Bilibili / Youtube
## 配置选项
- !clean 或 !clear 清除该房间中所有的embedding信息
- !embedding on 或 !embedding off 开启或关闭房间内embedding功能 (默认关闭)"""

spliter = MarkdownTextSplitter(
    chunk_size=400,
    chunk_overlap=100,
    length_function=client.get_token_length,
)

offices_mimetypes = [
    "application/wps-office.docx",
    "application/wps-office.doc",
    "application/wps-office.pptx",
    "application/wps-office.ppt",
    "application/msword",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
    "application/vnd.ms-powerpoint",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    "application/vnd.oasis.opendocument.text",
    "application/vnd.oasis.opendocument.presentation",
]
mimetypes = [
    "text/plain",
    "application/pdf",
    "text/markdown",
    "text/html",
] + offices_mimetypes


def allowed_file(mimetype: str):
    return mimetype.lower() in mimetypes


async def create_embedding(room, event, md5, content, url):
    transaction = await client.db.transaction()
    await client.db.execute(
        query="""insert into documents (md5, content, token, url)
            values (:md5, :content, :token, :url)
            on conflict (md5) do nothing
        ;""",
        values={
            "md5": md5,
            "content": content,
            "token": client.get_token_length(content),
            "url": url,
        },
    )

    rows = await client.db.fetch_all(
        query="select document_md5 from room_document where room = :room and document_md5 = :md5 limit 1;",
        values={"room": room.room_id, "md5": md5},
    )
    if len(rows) > 0:
        await transaction.rollback()
        print("document already inserted in room", md5, room.room_id)
        await client.room_send(
            room.room_id,
            "m.reaction",
            {
                "m.relates_to": {
                    "event_id": event.event_id,
                    "key": "👍",
                    "rel_type": "m.annotation",
                }
            },
        )
        return

    await client.db.execute(
        query="""
            insert into room_document (room, document_md5)
            values (:room_id, :md5)
            on conflict (room, document_md5) do nothing
        ;""",
        values={"room_id": room.room_id, "md5": md5},
    )

    # start embedding
    chunks = spliter.split_text(content)
    print("chunks", len(chunks))
    embeddings = await embeddings_model.aembed_documents(chunks, chunk_size=1600)
    print("embedding finished", len(embeddings))
    if len(chunks) != len(embeddings):
        raise ValueError("chunk and embedding counts do not match")
    insert_data: list[dict] = []
    for chunk, embedding in zip(chunks, embeddings):
        insert_data.append(
            {
                "document_md5": md5,
                "md5": hashlib.md5(chunk.encode()).hexdigest(),
                "content": chunk,
                "token": client.get_token_length(chunk),
                "embedding": str(embedding),
            }
        )
    await client.db.execute_many(
        query="""insert into embeddings (document_md5, md5, content, token, embedding)
            values (:document_md5, :md5, :content, :token, :embedding)
            on conflict (document_md5, md5) do nothing
        ;""",
        values=insert_data,
    )
    print("insert", len(insert_data), "embedding data")

    await client.db.execute(
        query="""
            insert into event_document (event, document_md5)
            values (:event_id, :md5)
            on conflict (event) do nothing
        ;""",
        values={"event_id": event.event_id, "md5": md5},
    )

    await transaction.commit()

    await client.room_send(
        room.room_id,
        "m.reaction",
        {
            "m.relates_to": {
                "event_id": event.event_id,
                "key": "😘",
                "rel_type": "m.annotation",
            }
        },
    )


def clean_html(html: str) -> str:
    h2t = html2text.HTML2Text()
    h2t.ignore_emphasis = True
    h2t.ignore_images = True
    h2t.ignore_links = True
    h2t.body_width = 0
    content = h2t.handle(html)
    return content


def clean_content(content: str, mimetype: str, document_md5: str) -> str:
    # clean 0x00
    content = content.replace("\x00", "")
    # clean links
    content = re.sub(r"\[.*?\]\(.*?\)", "", content)
    content = re.sub(r"!\[.*?\]\(.*?\)", "", content)
    # clean lines
    lines = [i.strip() for i in content.split("\n\n")]
    while "" in lines:
        lines.remove("")

    content = "\n\n".join(lines)
    content = "\n".join([i.strip() for i in content.split("\n")])

    return content


def pdf_to_text(f) -> str:
    pdf_reader = PyPDF2.PdfReader(f)
    num_pages = len(pdf_reader.pages)

    content = ""
    for page_number in range(num_pages):
        page = pdf_reader.pages[page_number]
        content += page.extract_text()
    return content


@client.ignore_self_message
@client.handel_no_gpt
@client.log_message
@client.with_typing
@client.replace_command_mark
@client.safe_try
async def message_file(room: MatrixRoom, event: RoomMessageFile):
    print("received file")
    mimetype = event.flattened().get("content.info.mimetype", "")
    if not allowed_file(mimetype):
        print("not allowed file", event.body)
        raise ValueError("not allowed file")
    resp = await client.download(event.url)
    if isinstance(resp, DownloadError):
        raise ValueError("file download error")

    assert isinstance(resp.body, bytes)
    md5 = hashlib.md5(resp.body).hexdigest()

    document_fetch_result = await client.db.fetch_all(
        query="select content from documents where md5 = :md5;", values={"md5": md5}
    )
    document_already_exists = len(document_fetch_result) > 0

    # get content
    content = ""
    # document already exists: reuse the stored content
    if document_already_exists:
        print("document", md5, "already exists")
        content = document_fetch_result[0][0]
    else:
        if mimetype == "text/plain" or mimetype == "text/markdown":
            content = resp.body.decode()
        elif mimetype == "text/html":
            content = clean_html(resp.body.decode())
        elif mimetype == "application/pdf":
            f = io.BytesIO(resp.body)
            content = pdf_to_text(f)
        elif mimetype in offices_mimetypes:
            # save file to temp dir
            base = event.body.rsplit(".", 1)[0]
            ext = event.body.rsplit(".", 1)[1]
            print("base", base)
            source_filepath = os.path.join("./cache/office", event.body)
            txt_filename = base + ".txt"
            txt_filepath = os.path.join("./cache/office", txt_filename)
            print("source_filepath", source_filepath)
            with open(source_filepath, "wb") as f:
                f.write(resp.body)
            if ext in ["doc", "docx", "odt"]:
                process = subprocess.Popen(
                    [
                        "soffice",
                        "--headless",
                        "--convert-to",
                        "txt:Text",
                        "--outdir",
                        "./cache/office",
                        source_filepath,
                    ]
                )
                process.wait()
                with open(txt_filepath, "r") as f:
                    content = f.read()
            elif ext in ["ppt", "pptx", "odp"]:
                pdf_filename = base + ".pdf"
                pdf_filepath = os.path.join("./cache/office", pdf_filename)
                process = subprocess.Popen(
                    [
                        "soffice",
                        "--headless",
                        "--convert-to",
                        "pdf",
                        "--outdir",
                        "./cache/office",
                        source_filepath,
                    ]
                )
                process.wait()
                with open(pdf_filepath, "rb") as f:
                    content = pdf_to_text(f)
            else:
                raise ValueError("unknown ext: ", ext)
            print("converted txt", content)
        else:
            raise ValueError("unknown mimetype", mimetype)

    content = clean_content(content, mimetype, md5)

    print("content length", len(content))

    await create_embedding(room, event, md5, content, event.url)


client.add_event_callback(message_file, RoomMessageFile)


yt_dlp_support = ["b23.tv/", "www.bilibili.com/video/", "youtube.com/"]


def allow_yt_dlp(link: str) -> bool:
    if not link.startswith("http://") and not link.startswith("https://"):
        return False
    allow = False
    for u in yt_dlp_support:
        if u in link:
            allow = True
            break
    return allow


def allow_web(link: str) -> bool:
    print("checking web url", link)
    if not link.startswith("http://") and not link.startswith("https://"):
        return False
    return True


@client.message_callback_common_wrapper
async def message_text(room: MatrixRoom, event: RoomMessageText) -> None:
    if event.body.startswith("!"):
        should_react = True
        if event.body.startswith("!clear") or event.body.startswith("!clean"):
            # save to db
            async with client.db.transaction():
                await client.db.execute(
                    query="""
                        delete from embeddings e
                        using room_document rd
                        where e.document_md5 = rd.document_md5 and
                            rd.room = :room_id;
                    """,
                    values={"room_id": room.room_id},
                )
                await client.db.execute(
                    query="delete from room_document where room = :room_id;",
                    values={"room_id": room.room_id},
                )
        elif event.body.startswith("!embedding"):
            sp = event.body.split()
            if len(sp) < 2:
                return
            if sp[1].lower() not in ["on", "off"]:
                return
            status = sp[1].lower() == "on"
            await client.db.execute(
                query="""
                    insert into room_configs (room, embedding)
                    values (:room_id, :status)
                    on conflict (room) do update set embedding = excluded.embedding
                ;""",
                values={"room_id": room.room_id, "status": status},
            )
        else:
            should_react = False
        if should_react:
            await client.room_send(
                room.room_id,
                "m.reaction",
                {
                    "m.relates_to": {
                        "event_id": event.event_id,
                        "key": "😘",
                        "rel_type": "m.annotation",
                    }
                },
            )
        return

    if allow_yt_dlp(event.body.split()[0]):
        # handle yt-dlp
        ydl_opts = {
            "format": "wa*",
            # ℹ️ See help(yt_dlp.postprocessor) for a list of available Postprocessors and their arguments
            "postprocessors": [
                {  # Extract audio using ffmpeg
                    "key": "FFmpegExtractAudio",
                    # 'preferredcodec': 'opus',
                    # 'preferredquality': 64,
                }
            ],
        }

        url = event.body.split()[0]

        info = None
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
        filepath = info["requested_downloads"][0]["filepath"]
        filename = info["requested_downloads"][0]["filename"]
        title = info["title"]
        realfilepath = os.path.join("./cache/yt-dlp", filename)
        os.rename(filepath, realfilepath)

        result = openai.Audio.transcribe(
            file=open(realfilepath, "rb"),
            model="large-v2",
            prompt=title,
        )
        result = "\n".join([i.text for i in result["segments"]])
        print(event.sender, result)

        md5 = hashlib.md5(result.encode()).hexdigest()

        await create_embedding(room, event, md5, result, url)
        return

    if allow_web(event.body.split()[0]):
        url = event.body.split()[0]
        print("downloading", url)
        html = await get_html(url)
        md5 = hashlib.md5(html.encode()).hexdigest()
        content = clean_html(html)
        content = clean_content(content, "text/markdown", md5)
        if not content:
            raise ValueError("Empty content")
        print(content)
        await create_embedding(room, event, md5, content, url)
        return


client.add_event_callback(message_text, RoomMessageText)

asyncio.run(client.sync_forever())
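bot_db.py splits every ingested document into overlapping chunks measured in gpt-4 tokens and keys each chunk by its MD5 before embedding; a standalone sketch of that chunking step, where the sample text and the printout are illustrative assumptions while the splitter settings match the file above.

import hashlib

import tiktoken
from langchain.text_splitter import MarkdownTextSplitter

enc = tiktoken.encoding_for_model("gpt-4")
spliter = MarkdownTextSplitter(
    chunk_size=400,       # ~400 tokens per chunk, as configured in bot_db.py
    chunk_overlap=100,
    length_function=lambda text: len(enc.encode(text)),
)

sample = "# Title\n\n" + "Some markdown body text. " * 300  # assumed sample input
chunks = spliter.split_text(sample)
for chunk in chunks[:3]:
    # Same keying as create_embedding(): per-chunk MD5 plus its token count.
    print(hashlib.md5(chunk.encode()).hexdigest(), len(enc.encode(chunk)))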
75 bot_tts.py Normal file
@@ -0,0 +1,75 @@
import asyncio
import os
import aiohttp
import json
import base64
from langdetect import detect
from bot import Bot, print
from nio import MatrixRoom, RoomMessageText
from io import BytesIO

url = f"https://texttospeech.googleapis.com/v1/text:synthesize?key={os.environ['GOOGLE_TTS_API_KEY']}"

client = Bot(
    os.environ["BOT_TTS_HOMESERVER"],
    os.environ["BOT_TTS_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_TTS_ACCESS_TOKEN"],
)


async def tts(text: str):
    lang = detect(text)
    langMap = {
        "zh-cn": {
            "languageCode": "cmn-cn",
            "name": "cmn-CN-Wavenet-B",
        },
        "en": {"languageCode": "en-US", "name": "en-US-Neural2-F"},
        "ja": {"languageCode": "ja-JP", "name": "ja-JP-Neural2-B"},
    }
    voice = langMap.get(lang, langMap["en"])
    async with aiohttp.ClientSession() as session:
        payload = {
            "input": {"text": text},
            "voice": voice,
            "audioConfig": {"audioEncoding": "OGG_OPUS", "speakingRate": 1.39},
        }
        headers = {"content-type": "application/json"}
        async with session.post(url, data=json.dumps(payload), headers=headers) as resp:
            data = await resp.json()
            audio_content = data.get("audioContent")
            decoded = base64.b64decode(audio_content)
            return decoded


@client.ignore_self_message
@client.handel_no_gpt
@client.log_message
@client.with_typing
@client.change_event_id_to_root_id
@client.replace_command_mark
@client.safe_try
async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None:
    if not event.sender.startswith("@chatgpt-bot"):
        return

    audio = await tts(event.body)
    # convert
    resp, upload = await client.upload(BytesIO(audio), "audio/ogg", filesize=len(audio))
    content = {
        "msgtype": "m.audio",
        "body": event.body if len(event.body) < 20 else event.body[:16] + "...",
        "info": {"mimetype": "audio/ogg", "size": len(audio)},
        "url": resp.content_uri,
        "m.relates_to": {
            "rel_type": "m.thread",
            "event_id": event.event_id,
        },
    }
    await client.room_send(room.room_id, message_type="m.room.message", content=content)


client.add_event_callback(message_callback, RoomMessageText)

asyncio.run(client.sync_forever())
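The tts helper above picks the Google voice from langdetect output, with the en-US voice as the fallback for anything outside the map; a small illustrative check of that detection step, where the sample strings are assumptions and detection on short strings can vary.

from langdetect import detect

# Mirrors the langMap lookup in tts(): "zh-cn" and "ja" have dedicated voices,
# any other detected language falls back to the en-US voice.
samples = ["你好,世界", "hello world", "こんにちは", "bonjour"]
for text in samples:
    lang = detect(text)
    voice = {"zh-cn": "cmn-CN-Wavenet-B", "ja": "ja-JP-Neural2-B"}.get(lang, "en-US-Neural2-F")
    print(text, lang, voice)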
169 bot_whisper.py Normal file
@@ -0,0 +1,169 @@
import os
from nio import DownloadError, MatrixRoom, RoomMessageAudio, RoomMessageFile
import asyncio
import openai
import io
from bot import Bot, print


client = Bot(
    os.environ["BOT_WHISPER_HOMESERVER"],
    os.environ["BOT_WHISPER_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_WHISPER_ACCESS_TOKEN"],
)
client.welcome_message = (
    """欢迎使用 matrix chain whisper 插件,我能将房间中的语音消息转换成文字发出,如果语音过长,我会用文件形式发出"""
)


@client.message_callback_common_wrapper
async def message_callback(room: MatrixRoom, event: RoomMessageAudio):
    print("received message")
    print(event.flattened())
    if event.flattened().get("content.info.duration", 0) > 1000 * 60 * 5:
        return await message_file(room, event)
    if event.source.get("content", {}).get("org.matrix.msc1767.audio") is None:
        # handle audio file
        return await message_file(room, event)
    resp = await client.download(event.url)
    if isinstance(resp, DownloadError):
        return

    filelikeobj = io.BytesIO(resp.body)
    filelikeobj.name = "matrixaudio.ogg"
    # get prompt
    rows = await client.db.fetch_all(
        query="""select content from (
            select role, content, sum(token) over (partition by root order by id desc) as total_token
            from memories
            where root = :event_id
            order by id
        ) as sub
        where total_token < 3039
        ;""",
        values={"event_id": event.event_id},
    )
    prompt = "".join([i[0] for i in rows])
    # no memory
    if not prompt:
        db_result = await client.db.fetch_all(
            query="select system, examples from room_configs where room = :room_id;",
            values={"room_id": room.room_id},
        )
        if len(db_result) > 0:
            systemMessageContent = db_result[0][0]
            examples = [
                m.get("content", "") for m in db_result[0][1] if m.get("example")
            ]
            while "" in examples:
                examples.remove("")
            if systemMessageContent:
                prompt += systemMessageContent + "\n\n"
            if len(examples) > 0:
                prompt += "\n\n".join(examples)

    print("initial_prompt", prompt)

    result = openai.Audio.transcribe(file=filelikeobj, model="large-v2", prompt=prompt)
    result = "\n".join([i.text for i in result["segments"]])
    print(event.sender, result)

    await client.room_send(
        room.room_id,
        "m.room.message",
        {
            "body": result,
            "msgtype": "m.text",
            "m.relates_to": {
                "rel_type": "m.thread",
                "event_id": event.event_id,
            },
        },
    )


client.add_event_callback(message_callback, RoomMessageAudio)


ALLOWED_EXTENSIONS = {
    "mp3",
    "mp4",
    "mpeg",
    "mpga",
    "m4a",
    "wav",
    "webm",
    "3gp",
    "flac",
    "ogg",
    "mkv",
}


def allowed_file(mimetype):
    return "/" in mimetype and mimetype.rsplit("/", 1)[1].lower() in ALLOWED_EXTENSIONS


def get_txt_filename(filename):
    return filename + ".txt"


async def message_file(room: MatrixRoom, event: RoomMessageFile):
    print("received file")
    if not allowed_file(event.flattened().get("content.info.mimetype")):
        print("not allowed file", event.body)
        raise Exception("not allowed file")
    resp = await client.download(event.url)
    if isinstance(resp, DownloadError):
        return
    filelikeobj = io.BytesIO(resp.body)
    filelikeobj.name = event.body

    # get prompt
    rows = await client.db.fetch_all(
        query="""
        select content from (
            select role, content, sum(token) over (partition by root order by id desc) as total_token
            from memories
            where root = :event_id
            order by id
        ) as sub
        where total_token < 3039
        ;""",
        values={"event_id": event.event_id},
    )
    prompt = "".join([i[0] for i in rows])
    print("initial_prompt", prompt)

    result = openai.Audio.transcribe(file=filelikeobj, model="large-v2", prompt=prompt)
    result = "\n".join([i.text for i in result["segments"]])
    print(event.sender, result)
    resultfilelike = io.BytesIO(result.encode())
    resultfilelike.name = get_txt_filename(event.body)
    resultfileSize = len(result.encode())
    uploadResp, _ = await client.upload(
        resultfilelike, content_type="text/plain", filesize=resultfileSize
    )
    print("uri", uploadResp.content_uri)

    await client.room_send(
        room.room_id,
        "m.room.message",
        {
            "body": resultfilelike.name,
            "filename": resultfilelike.name,
            "msgtype": "m.file",
            "info": {
                "mimetype": "text/plain",
                "size": resultfileSize,
            },
            "m.relates_to": {
                "rel_type": "m.thread",
                "event_id": event.event_id,
            },
            "url": uploadResp.content_uri,
        },
    )


asyncio.run(client.sync_forever())
25 docker-compose.yaml Normal file
@@ -0,0 +1,25 @@
version: '3'
services:
  bot-chatgpt:
    image: matrix-chain
    env_file:
      - ./.env
    command: python3 bot_chatgpt.py

  bot-db:
    image: matrix-chain
    env_file:
      - ./.env
    command: python3 bot_db.py

  bot-whisper:
    image: matrix-chain
    env_file:
      - ./.env
    command: python3 bot_whisper.py

  bot-tts:
    image: matrix-chain
    env_file:
      - ./.env
    command: python3 bot_tts.py
13 requirements.txt Normal file
@@ -0,0 +1,13 @@
databases
aiopg
matrix-nio
langchain
python-dotenv
tiktoken
PyPDF2
html2text
yt_dlp
selenium
openai
langdetect
jinja2
72 requirements_version.txt Normal file
@@ -0,0 +1,72 @@
aiofiles==23.2.1
aiohttp==3.8.6
aiohttp-socks==0.7.1
aiopg==1.4.0
aiosignal==1.3.1
annotated-types==0.6.0
anyio==3.7.1
async-timeout==4.0.3
attrs==23.1.0
Brotli==1.1.0
certifi==2023.7.22
charset-normalizer==3.3.0
databases==0.8.0
dataclasses-json==0.6.1
frozenlist==1.4.0
greenlet==3.0.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
html2text==2020.1.16
hyperframe==6.0.1
idna==3.4
Jinja2==3.1.2
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
langchain==0.0.319
langdetect==1.0.9
langsmith==0.0.49
MarkupSafe==2.1.3
marshmallow==3.20.1
matrix-nio==0.22.1
multidict==6.0.4
mutagen==1.47.0
mypy-extensions==1.0.0
numpy==1.26.1
openai==0.28.1
outcome==1.3.0
packaging==23.2
psycopg2-binary==2.9.9
pycryptodome==3.19.0
pycryptodomex==3.19.0
pydantic==2.4.2
pydantic_core==2.10.1
PyPDF2==3.0.1
PySocks==1.7.1
python-dotenv==1.0.0
python-socks==2.4.3
PyYAML==6.0.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rpds-py==0.10.6
selenium==4.14.0
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
SQLAlchemy==1.4.49
tenacity==8.2.3
tiktoken==0.5.1
tqdm==4.66.1
trio==0.22.2
trio-websocket==0.11.1
typing-inspect==0.9.0
typing_extensions==4.8.0
unpaddedbase64==2.1.0
urllib3==2.0.7
websockets==11.0.3
wsproto==1.2.0
yarl==1.9.2
yt-dlp==2023.10.13