init
This commit is contained in:
315
bot.py
Normal file
315
bot.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
import os
|
||||
import traceback
|
||||
from functools import wraps
|
||||
from nio import AsyncClient, MatrixRoom, RoomMessageText, SyncError, InviteEvent
|
||||
import tiktoken
|
||||
import databases
|
||||
import builtins
|
||||
import sys
|
||||
|
||||
|
||||
def print(*args, **kwargs):
    """Module-local ``print`` that always writes to ``sys.stderr``.

    Accepts the same arguments as the built-in ``print``; any ``file``
    keyword supplied by the caller is overridden with ``sys.stderr``.
    """
    # Build a new keyword mapping instead of mutating ``kwargs`` in place;
    # the later "file" entry wins over any caller-supplied one.
    forwarded = {**kwargs, "file": sys.stderr}
    builtins.print(*args, **forwarded)
|
||||
|
||||
|
||||
class Bot(AsyncClient):
    """Matrix client with auto-join, a database handle and handler decorators.

    The class extends ``AsyncClient`` with:

    * a ``databases.Database`` connection whose schema is created by
      :meth:`init_db`,
    * a ``tiktoken`` encoder used by :meth:`get_token_length`,
    * an invite callback that joins rooms automatically, and
    * a set of decorators (``ignore_self_message``, ``with_typing``, ...)
      that wrap ``async (room, event)`` message callbacks.
    """

    # Upper bound (exclusive) on the stored token count of a replied-to
    # document for it to be inlined into the message body.
    MAX_REPLY_DOCUMENT_TOKENS = 8192 + 4096

    def __init__(self, homeserver: str, user: str, device_id: str, access_token: str):
        """Create the client and register the auto-join callback.

        Args:
            homeserver: Homeserver URL passed to ``AsyncClient.__init__``.
            user: Stored as both ``self.user`` and ``self.user_id``.
            device_id: Stored as ``self.device_id``.
            access_token: Stored as ``self.access_token``.

        Raises:
            KeyError: If the ``MATRIX_CHAIN_DB`` environment variable is unset.
        """
        super().__init__(homeserver)
        self.access_token = access_token
        self.user_id = self.user = user
        self.device_id = device_id
        # Sent to a room right after joining it; empty string disables it.
        self.welcome_message = ""
        # Room ids joined through auto_join during this process lifetime.
        self._joined_rooms = []
        self.db = databases.Database(os.environ["MATRIX_CHAIN_DB"])

        self.enc = tiktoken.encoding_for_model("gpt-4")

        # auto join
        self.add_event_callback(self.auto_join, InviteEvent)

    async def init_db(self):
        """Create the ``vector`` extension and all tables if they are missing.

        Statements are executed one by one, in order, on ``self.db``; the
        ``documents`` table is created before the tables that reference it.
        """
        statements = (
            "CREATE EXTENSION IF NOT EXISTS vector",
            """
            CREATE TABLE IF NOT EXISTS documents
            (
                md5 character(32) NOT NULL PRIMARY KEY,
                content text,
                token integer,
                url text
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS embeddings
            (
                document_md5 character(32) NOT NULL,
                md5 character(32) NOT NULL,
                content text NOT NULL,
                token integer NOT NULL,
                embedding vector(1536) NOT NULL,
                PRIMARY KEY (document_md5, md5),
                FOREIGN KEY (document_md5) REFERENCES documents(md5)
            );
            """,
            """
            CREATE TABLE IF NOT EXISTS event_document
            (
                event text NOT NULL PRIMARY KEY,
                document_md5 character(32) NOT NULL,
                FOREIGN KEY (document_md5)
                REFERENCES documents (md5)
            );
            """,
            """
            CREATE TABLE IF NOT EXISTS memories
            (
                id SERIAL PRIMARY KEY,
                root text NOT NULL,
                role integer NOT NULL,
                content text NOT NULL,
                token integer NOT NULL
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS room_configs
            (
                room text NOT NULL PRIMARY KEY,
                model_name text,
                temperature float NOT NULL DEFAULT 0,
                system text,
                embedding boolean NOT NULL DEFAULT false,
                examples TEXT[] NOT NULL DEFAULT '{}'
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS room_document (
                room text NOT NULL,
                document_md5 character(32) NOT NULL,
                PRIMARY KEY (room, document_md5)
            );
            """,
        )
        db = self.db
        for statement in statements:
            await db.execute(statement)

    def get_token_length(self, text: str) -> int:
        """Return the number of tokens ``self.enc`` produces for ``text``."""
        return len(self.enc.encode(text))

    async def sync_forever(self):
        """Connect the database, do one silent initial sync, then sync forever.

        The registered event callbacks are detached during the initial sync
        so that events delivered by it are not dispatched to handlers, and
        are re-attached afterwards (also when the sync fails).

        Returns:
            Whatever ``AsyncClient.sync_forever`` returns.

        Raises:
            RuntimeError: If the initial sync returns a ``SyncError``.
        """
        # init
        print("connecting to db")
        await self.db.connect()

        # init db hook
        print("init db hook")
        await self.init_db()

        # Remove callbacks while performing the initial sync so its events
        # are skipped; always put them back, even if the sync raises.
        callbacks = self.event_callbacks
        self.event_callbacks = []
        try:
            print("Perform initial sync")
            resp = await self.sync(timeout=30000)
        finally:
            self.event_callbacks = callbacks

        if isinstance(resp, SyncError):
            # Raise an Exception subclass carrying the actual response
            # rather than ``BaseException`` wrapping the SyncError class.
            raise RuntimeError(f"initial sync failed: {resp}")

        # set online
        print("Set online status")
        await self.set_presence("online")

        # sync
        print("Start forever sync")
        return await super().sync_forever(300000, since=resp.next_batch)

    async def auto_join(self, room: MatrixRoom, event: InviteEvent):
        """Join ``room`` on invite and optionally send the welcome message.

        Rooms already recorded in ``self._joined_rooms`` are ignored. The
        welcome message carries ``"nogpt": True`` in its content, the key
        that :meth:`handel_no_gpt` checks for.
        """
        print("join", event.sender, room.room_id)
        if room.room_id in self._joined_rooms:
            return
        await self.join(room.room_id)
        self._joined_rooms.append(room.room_id)
        if self.welcome_message:
            await self.room_send(
                room.room_id,
                "m.room.message",
                {
                    "body": self.welcome_message,
                    "msgtype": "m.text",
                    "nogpt": True,
                },
            )

    async def _send_reaction(self, room_id: str, event_id: str, key: str):
        """Send an ``m.reaction`` annotation ``key`` on ``event_id``."""
        await self.room_send(
            room_id,
            "m.reaction",
            {
                "m.relates_to": {
                    "event_id": event_id,
                    "key": key,
                    "rel_type": "m.annotation",
                }
            },
        )

    def ignore_self_message(self, func):
        """Decorator: skip events whose sender equals ``self.user``."""

        @wraps(func)
        async def ret(room: MatrixRoom, event: RoomMessageText):
            if event.sender == self.user:
                return
            return await func(room, event)

        return ret

    def log_message(self, func):
        """Decorator: print room id, sender and body before calling ``func``."""

        @wraps(func)
        async def ret(room: MatrixRoom, event: RoomMessageText):
            print(room.room_id, event.sender, event.body)
            return await func(room, event)

        return ret

    def with_typing(self, func):
        """Decorator: show the typing indicator while ``func`` runs.

        The indicator is switched on with a 3-minute timeout and switched
        off again when ``func`` finishes, including when it raises.
        """

        @wraps(func)
        async def ret(room, *args, **kargs):
            await self.room_typing(room.room_id, True, 60000 * 3)
            try:
                return await func(room, *args, **kargs)
            finally:
                # Clear the indicator even if the handler failed.
                await self.room_typing(room.room_id, False)

        return ret

    def change_event_id_to_root_id(self, func):
        """Decorator: for thread replies, rewrite ``event.event_id`` to the root.

        When the event's relation type is ``m.thread``, ``event.event_id`` is
        overwritten with the related event id; otherwise it is left as is.
        """

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            root = event.event_id
            if event.flattened().get("content.m.relates_to.rel_type") == "m.thread":
                root = event.source["content"]["m.relates_to"]["event_id"]
            event.event_id = root
            return await func(room, event, *args, **kargs)

        return ret

    def ignore_not_mentioned(self, func):
        """Decorator: skip events that do not contain ``self.user``.

        The user string is searched for in both ``content.body`` and
        ``content.formatted_body``; a hit in either lets the event through.
        """

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            flattened = event.flattened()
            body = flattened.get("content.body", "")
            formatted_body = flattened.get("content.formatted_body", "")
            if self.user not in body and self.user not in formatted_body:
                return
            return await func(room, event, *args, **kargs)

        return ret

    def replace_command_mark(self, func):
        """Decorator: normalise the leading command mark of ``event.body``."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            # NOTE(review): as written this replaces a leading "!" with "!",
            # which leaves the body unchanged. Presumably a different leading
            # character was meant to be mapped to "!" -- confirm against the
            # original file before changing.
            if event.body.startswith("!"):
                event.body = "!" + event.body[1:]
            return await func(room, event, *args, **kargs)

        return ret

    def handel_no_gpt(self, func):
        """Decorator: skip events whose content has a non-null ``nogpt`` key."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            if event.flattened().get("content.nogpt") is not None:
                return
            return await func(room, event, *args, **kargs)

        return ret

    def replace_reply_file_with_content(self, func):
        """Decorator: inline a replied-to document into ``event.body``.

        Applies only when ``content.formatted_body`` starts with
        ``<mx-reply>`` and contains ``</mx-reply>``. The text after the
        closing tag is kept, the replied-to event id is looked up in
        ``event_document``/``documents``, and if a document with a token
        count below ``MAX_REPLY_DOCUMENT_TOKENS`` is found, ``event.body``
        becomes ``<document>\\n\\n---\\n\\n<text after the reply block>``.
        Otherwise the event is passed on unchanged.
        """
        closing_tag = "</mx-reply>"

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            flattened = event.flattened()
            formatted_body = flattened.get("content.formatted_body", "")
            if formatted_body.startswith("<mx-reply>") and closing_tag in formatted_body:
                print("replacing file content")
                # Keep only what the user wrote after the quoted reply block.
                formatted_body = formatted_body[
                    formatted_body.index(closing_tag) + len(closing_tag):
                ]
                document_event_id = flattened.get(
                    "content.m.relates_to.m.in_reply_to.event_id", ""
                )
                fetch = await self.db.fetch_all(
                    query="""select d.content, d.token
                    from documents d
                    join event_document ed on d.md5 = ed.document_md5
                    where ed.event = :event_id""",
                    values={"event_id": document_event_id},
                )
                content = token = None
                if fetch:
                    content, token = fetch[0][0], fetch[0][1]
                # ``content`` and ``token`` are nullable columns; a NULL in
                # either is treated like a missing document instead of
                # raising TypeError on the comparison or concatenation.
                if (
                    content is not None
                    and token is not None
                    and token < self.MAX_REPLY_DOCUMENT_TOKENS
                ):
                    print(content)
                    print("-----------")
                    event.body = content + "\n\n---\n\n" + formatted_body
                else:
                    print("document not found or too large", event.event_id)

            return await func(room, event, *args, **kargs)

        return ret

    def ignore_link(self, func):
        """Decorator: skip events whose body starts with an http(s) URL."""

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            if event.body.startswith(("https://", "http://")):
                return
            return await func(room, event, *args, **kargs)

        return ret

    def safe_try(self, func):
        """Decorator: log handler exceptions and react to the event with 😵.

        Any ``Exception`` raised by ``func`` is printed with its traceback
        and not re-raised; the wrapper then returns ``None``.
        """

        @wraps(func)
        async def ret(room, event, *args, **kargs):
            try:
                return await func(room, event, *args, **kargs)
            except Exception:
                print("--------------")
                print("error:")
                print(traceback.format_exc())
                print("--------------")
                await self._send_reaction(room.room_id, event.event_id, "😵")

        return ret

    def message_callback_common_wrapper(self, func):
        """Decorator: apply the standard handler pipeline to ``func``.

        From outermost to innermost the wrappers are: ignore own messages,
        honour ``nogpt``, log, typing indicator, inline replied-to document,
        thread-root id rewrite, command-mark replacement, and ``safe_try``.
        """

        @wraps(func)
        @self.ignore_self_message
        @self.handel_no_gpt
        @self.log_message
        @self.with_typing
        @self.replace_reply_file_with_content
        @self.change_event_id_to_root_id
        @self.replace_command_mark
        @self.safe_try
        async def ret(*args, **kargs):
            return await func(*args, **kargs)

        return ret

    async def react_ok(self, room_id: str, event_id: str):
        """React to ``event_id`` in ``room_id`` with 😘."""
        await self._send_reaction(room_id, event_id, "😘")
|
||||
Reference in New Issue
Block a user