diff --git a/config.json b/config.json index 7533b20..cce834a 100644 --- a/config.json +++ b/config.json @@ -1,5 +1,5 @@ { - "test": "test", + "id": "miku", "listen_port": 3900, "listen_addr": "127.0.0.1", "listen_num": 39, diff --git a/mswp.py b/mswp.py index 702c54b..b43708c 100644 --- a/mswp.py +++ b/mswp.py @@ -1,11 +1,17 @@ +import os +from config import jsondata + class Datapack: - def __init__(self, method='post', app='all', version='msw/1.0', head=None, body=b'', check_head=True): + def __init__(self, method='post', app='all', version='msw/1.0', head=None, body=b'', check_head=True, file=None): + self.id = jsondata.try_to_read_jsondata('id', 'Unknown_id') if head is None: head = {} self.head = head else: self.head = head + self.head['id'] = self.id self.method = method + self.file = file self.app = app self.version = version if not self.head and check_head: @@ -16,7 +22,11 @@ class Datapack: self.encode_data = b'' def encode(self): - self.head['length'] = str(len(self.body)) + if self.method == 'file': + self.head['length'] = str(os.path.getsize(self.head['filename'])) + else: + self.head['length'] = str(len(self.body)) + first_line = self.method.encode() + b' ' + self.app.encode() + b' ' + self.version.encode() heads = ''.encode() for i in self.head: diff --git a/plugins/input.py b/plugins/input.py index 6ee5e93..d487bb8 100644 --- a/plugins/input.py +++ b/plugins/input.py @@ -1,20 +1,36 @@ import threading import copy +import os from mswp import Datapack from forwarder import receive_queues, send_queue receive_queue = receive_queues[__name__] def main(): + file_flag = False; while True: + file_flag = False; raw_data = input() + + if raw_data[:6] == '(file)': + raw_data = raw_data[6:] + file_flag = True + first_index, last_index = find_the_last(raw_data) app = raw_data[:first_index] body = raw_data[last_index:] app = app.replace(' ', '') dp = Datapack(head={'from': __name__}) dp.app = app - dp.body = body.encode() + + if file_flag: + dp.method = 'file' + dp.body = b'' + dp.head['filename'] = body + + else: + dp.body = body.encode() + send_queue.put(dp) diff --git a/plugins/net.py b/plugins/net.py index 38b3fd0..6bafcf1 100644 --- a/plugins/net.py +++ b/plugins/net.py @@ -1,14 +1,14 @@ import threading import socket import queue -import copy +import os from mswp import Datapack from forwarder import receive_queues, send_queue from config import jsondata receive_queue = receive_queues[__name__] RECV_BUFF = jsondata.try_to_read_jsondata('recv_buff', 4096) - +ID = jsondata.try_to_read_jsondata('id', 'Unknown_ID') def main(): netrecv = Netrecv() @@ -52,8 +52,8 @@ class Netrecv: listen_num = jsondata.try_to_read_jsondata('listen_num', 39) s.listen(listen_num) self.s = s - - self.connection_list = [] # [(conn, addr), (conn, addr)...] + self.stat = {} # # # # # + self.connection_list = [] # [(conn, addr), (conn, addr)...] Very important self.connection_process_thread_list =[] self.un_enougth_list = [] self.send_queue = queue.Queue() @@ -62,10 +62,10 @@ class Netrecv: self.thread_check_send_queue = threading.Thread(target=self.check_send_queue, args=()) ######################################## - self.send_queue = queue.Queue() - raw_data = read_netlisttxt_file() + + raw_data = read_netlisttxt_file() # read file lines = raw_data.split('\n') - ips = [] + self.file_addr_list = [] for line in lines: ip_port = line.split(':') if len(ip_port) == 1: @@ -77,43 +77,71 @@ class Netrecv: port = 3900 ip = process_hostname(ip_port[0]) port = int(ip_port[1]) - ips.append((ip, port)) + self.file_addr_list.append((ip, port)) - for addr in ips: # Create connection + for addr in self.file_addr_list: # Create connection conn = connect(addr) self.connection_list.append((conn, addr)) - # create thread - - self.check_queue_thread = threading.Thread(target=self.check_send_queue, args=()) - - self.send_queue_dist = {} - - for addr in self.addr_to_thread: # start thread - self.addr_to_thread[addr].start() - self.check_queue_thread.start() # thread that check the queue and send one by one - self.thread_check_accept_connection.start() - self.thread_check_send_queue.start() - - def check_accept_connection(self): - while True: - conn, addr = self.s.accept() - self.connection_list.append((conn, addr)) + # Create receive thread and start connection_thread = threading.Thread(target=self.process_connection, args=(conn, addr)) self.connection_process_thread_list.append(connection_thread) connection_thread.start() - def process_connection_send(self, conn, addr): - pass + self.thread_check_accept_connection.start() + self.thread_check_send_queue.start() + + def check_send_queue(self): + while True: + dp = receive_queue.get() + if dp.method == 'file': + print('right') + print(dp.head) + dp.encode() + for id in self.stat: + for conn, addr in self.stat[id]: + conn.sendall(dp.encode_data) + file = open(dp.head['filename'], 'rb') + for data in file: + conn.send(data) + print('sended') + + else: + print('wrong') + dp.encode() + for id in self.stat: + for conn, addr in self.stat[id]: + conn.sendall(dp.encode_data) + + def check_accept_connection(self): + while True: + conn, addr = self.s.accept() + self.connection_list.append((conn, addr)) # # # # # + connection_thread = threading.Thread(target=self.process_connection, args=(conn, addr)) + self.connection_process_thread_list.append(connection_thread) + connection_thread.start() + + def remove_connection(self, conn, addr): + conn.close() + for id in self.stat: + if (conn, addr) in self.stat[id]: + self.stat[id].remove(conn, addr) + self.connection_list.remove((conn, addr)) + print('Removed connection', str(addr)) def process_connection(self, conn, addr): - print('Connection accpet %s' % str(addr)) + print('Connection accept %s' % str(addr)) data = b'' while True: - new_data = conn.recv(RECV_BUFF) + try: + new_data = conn.recv(RECV_BUFF) + except ConnectionResetError: + print('Connection Reset Error') + self.remove_connection(conn, addr) + return + if not new_data and not data: - conn.close() - self.connection_list.remove((conn, addr)) + self.remove_connection(conn, addr) print('return 1') return data += new_data @@ -187,7 +215,7 @@ class Netrecv: dp.encode_data = b'' send_queue.put(dp) - else: # dp.method is not 'file' + else: # normal data pack length = int(dp.head['length']) data_length = len(data) @@ -234,9 +262,21 @@ class Netrecv: dp.body = data[:length] data = data[length:] - dp.encode() - send_queue.put(dp) - print('###############\n' + dp.encode_data.decode() + '\n###############') + # net config data package + if dp.app == 'net': + dp_id = dp.head['id'] + local_id = self.stat.get(dp_id) + + if not local_id: # create if not exits + self.stat[dp_id] = [] + + if not (conn, addr) in self.stat[dp_id]: + self.stat[dp_id].append((conn, addr)) + + else: + dp.encode() + send_queue.put(dp) + print('###############\n' + dp.encode_data.decode() + '\n###############') thread = threading.Thread(target=main, args=()) diff --git a/test_tool.py b/test_tool.py index 3cb477c..f0c15b5 100644 --- a/test_tool.py +++ b/test_tool.py @@ -6,10 +6,19 @@ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(('127.0.0.1', 3966)) s.listen(100) +id = '''post net msw/1.0 +id: miku +from: test +length: 0 + +''' + +id = id.encode() def process(conn, addr): + conn.sendall(id) + print('accept connection from', str(addr)) while True: - print('accept connection from', str(addr)) data = conn.recv(4096) if not data: conn.close()