diff --git a/main.py b/main.py index ae887a7..68630bd 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ from flask.json.provider import DefaultJSONProvider import rpc import util +import protoswitch import logging @@ -159,4 +160,4 @@ if __name__ == '__main__': port = get_port() print(f'http://localhost:{port}/#{app.secret}') webbrowser.open(f'http://localhost:{port}/#{app.secret}') - app.run(port=port) + app.run(port=port, request_handler=protoswitch.SwitchingRequestHandler) diff --git a/protoswitch.py b/protoswitch.py new file mode 100644 index 0000000..8c71f02 --- /dev/null +++ b/protoswitch.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import socket +import io +from werkzeug.serving import WSGIRequestHandler + +from typing import Iterator + +import logging + +logger = logging.getLogger(__name__) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s\t%(message)s')) +logger.addHandler(handler) +logger.setLevel(logging.DEBUG) + +bEOT = b'\x04' + +class SwitchingRequestHandler(WSGIRequestHandler): + def parse_request(self): + if self.raw_requestline.startswith(b'[XWB]'): + logger.info(f"{self.client_address[0]}:{self.client_address[1]} VistA OPEN") + proxy_vista(read_from_file(self.rfile, self.raw_requestline), self.wfile, self.client_address, ('test.northport.med.va.gov', 19009)) + logger.info(f"{self.client_address[0]}:{self.client_address[1]} VistA CLOSE") + return False + return WSGIRequestHandler.parse_request(self) + +def proxy_vista(rfilegen: Iterator[bytes], wfile: io.BufferedWriter, localaddr: tuple, remoteaddr: tuple) -> None: + remotesock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + remotesock.connect(remoteaddr) + recipient = recv_from_socket(remotesock) + for n, req in enumerate(rfilegen): + logger.debug(f"{localaddr[0]}:{localaddr[1]} #{n} → {req.decode('latin-1').encode('unicode-escape').decode()}") + remotesock.send(req + bEOT) + res = next(recipient) + for line in res.decode('latin-1').splitlines(keepends=True): + logger.debug(f"{localaddr[0]}:{localaddr[1]} #{n} ← {line.encode('unicode-escape').decode()}") + wfile.write(res + bEOT) + remotesock.shutdown(socket.SHUT_RDWR) + remotesock.close() + +def read_from_file(rfile: io.BufferedReader, buf: bytes=b'', end: bytes=bEOT, minsz: int=1024, maxsz: int=32768) -> Iterator[bytes]: + if len(buf) > 0: + while (idx := buf.find(end)) >= 0: + if idx > 0: + yield buf[:idx] + buf = buf[idx + 1:] + bufsz = minsz + while len(data := rfile.read1(bufsz)) > 0: + buf += data + while (idx := buf.find(end)) >= 0: + if idx > 0: + yield buf[:idx] + bufsz = minsz + elif bufsz < maxsz: + bufsz = _x if (_x := bufsz << 1) < maxsz else maxsz + buf = buf[idx + 1:] + +def recv_from_socket(sock: socket.socket, end: bytes=bEOT, minsz: int=1024, maxsz: int=32768) -> Iterator[bytes]: + buf = b'' + bufsz = minsz + while True: + if len(data := sock.recv(bufsz)) > 0: + buf += data + while (idx := buf.find(end)) >= 0: + if idx > 0: + yield buf[:idx] + bufsz = minsz + elif bufsz < maxsz: + bufsz = _x if (_x := bufsz << 1) < maxsz else maxsz + buf = buf[idx + 1:]