nuVistA/rpc.py
2023-04-29 18:23:20 -04:00

256 lines
12 KiB
Python

#!/usr/bin/env python3
import math
import socket
import threading
import asyncio
import warnings
import logging
from collections import namedtuple
from XWBHash import encrypt0 as XWBHash_encrypt
from typing import Any, Union, Sequence
logger = logging.getLogger(__name__)


class RPCExc(Exception):
    """Root of every exception raised by this RPC broker client."""


class RPCExcFormat(ValueError, RPCExc):
    """A broker reply was not framed the way rpc_unpack_result expects."""


class RPCExcAuth(RPCExc):
    """Sign-on was rejected by the server (raised by ``authenticate``)."""


class RPCExcServerError(RPCExc):
    """The broker reply carried a server-side error message."""


class RPCExcInvalidResult(RPCExc):
    """A handshake or context-switch RPC returned an unexpected value."""
class RPCType(object):
    """Pair an RPC parameter value with an explicit parameter-type tag ("magic").

    The class constants are the single-byte type codes; pass one as *magic*
    to override the tag l_pack would otherwise choose.  Wrapping an existing
    RPCType takes only its ``value`` -- its ``magic`` is not carried over.
    """
    LITERAL = b'0'
    REFERENCE = b'1'
    LIST = b'2'
    GLOBAL = b'3'
    EMPTY = b'4'
    STREAM = b'5'

    def __init__(self, value, magic=None):
        self.magic = magic
        if isinstance(value, RPCType):
            # Unwrap one level so .value is never itself an RPCType.
            self.value = value.value
        else:
            self.value = value
# Positional fields of the XUS SIGNON SETUP reply; authenticate() merges them into self._server.
RecordServerInfo = namedtuple('RecordServerInfo', 'server volume uci device attempts skip_signon_screen domain production')
def s_pack(value: Any, encoding: str='latin-1'):
    """Return *value* encoded with *encoding*, prefixed by a single length byte.

    Raises:
        ValueError: the encoded form is longer than 255 bytes.
    """
    payload = value.encode(encoding)
    if len(payload) > 255:
        raise ValueError('cannot s-pack string longer than 255 bytes: ' + repr(value))
    return bytes([len(payload)]) + payload
def l_pack(value: Any, envelope: int=3, wrapped: bool=True, magic=None, encoding: str='latin-1'):
    """Encode one RPC parameter in the broker's length-prefixed "L-pack" form.

    * ``dict``: each key and value is L-packed (unwrapped) and the pairs are
      joined with ``b't'``.
    * any other non-``str`` iterable: same, keyed by 1-based position.
    * ``RPCType``: its ``value`` is packed using its ``magic`` as the tag.
    * anything else: ``str(value)`` encoded, prefixed by its byte length
      zero-padded to *envelope* digits.

    When *wrapped* is true the payload is framed as ``magic + payload + b'f'``.
    *magic* defaults to ``b'2'`` for containers and ``b'0'`` for literals.
    Elements inside a container are always packed unwrapped.

    Raises:
        ValueError: an encoded literal exceeds ``10**envelope - 1`` bytes.
    """
    if isinstance(value, dict):
        pairs = value.items()
    elif not isinstance(value, str) and hasattr(value, '__iter__'):
        pairs = enumerate(value, start=1)
    elif isinstance(value, RPCType):
        # Forward `wrapped`: an RPCType nested in a list/dict must be packed bare
        # like its siblings, otherwise its own tag and 'f' corrupt the list frame.
        return l_pack(value.value, envelope=envelope, wrapped=wrapped, magic=value.magic, encoding=encoding)
    else:
        encoded = str(value).encode(encoding)
        limit = 10**envelope - 1
        if len(encoded) > limit:
            raise ValueError(f'cannot l-pack string longer than {limit} bytes with an envelope of {envelope}: ' + repr(value))
        bare = str(len(encoded)).zfill(envelope).encode(encoding) + encoded
        return ((magic or b'0') + bare + b'f') if wrapped else bare
    bare = b't'.join(
        l_pack(k, envelope=envelope, wrapped=False, encoding=encoding)
        + l_pack(v, envelope=envelope, wrapped=False, encoding=encoding)
        for k, v in pairs
    )
    return ((magic or b'2') + bare + b'f') if wrapped else bare
def l_pack_maxlen(value: Any, encoding: str='latin-1'):
    """Return the largest length l_pack will have to write as a prefix for *value*.

    For a dict this is the longest encoded key or value.  For other non-``str``
    iterables it also counts the digits of the largest 1-based position key.
    An ``RPCType`` is measured by its wrapped ``value``.  Scalars are measured
    by ``len(str(value).encode(encoding))``.  Empty containers give 0.
    """
    if isinstance(value, dict):
        if len(value) == 0:
            return 0
        return max(max(l_pack_maxlen(k, encoding=encoding) for k in value.keys()), max(l_pack_maxlen(v, encoding=encoding) for v in value.values()))
    elif not isinstance(value, str) and hasattr(value, '__iter__'):
        if len(value) == 0:
            return 0
        # Keys are 1..len(value), so the widest key is len(value) itself.
        return max(len(str(len(value))), max(l_pack_maxlen(v, encoding=encoding) for v in value))
    elif isinstance(value, RPCType):
        # l_pack packs value.value, so measure that rather than the wrapper's
        # str() ("<RPCType object at 0x...>"), which bore no relation to it.
        return l_pack_maxlen(value.value, encoding=encoding)
    else:
        return len(str(value).encode(encoding))
def rpc_pack(name: str, *args: Any, command: bool=False, envelope: int=0, broker_version: str='XWB*1.1*65', encoding: str='latin-1'):
    """Build the wire request for RPC *name* called with positional *args*.

    Layout: ``[XWB]11`` + envelope digit(s) + ``0``, then ``4`` for a broker
    *command* or ``2`` + s-packed *broker_version*, then the s-packed *name*,
    ``5``, and every argument L-packed (``4f`` when there are no arguments).

    *envelope* is the digit count l_pack uses for length prefixes.  When it is
    below 1 and there are arguments, it is sized to fit the longest argument
    component.  The value used is never less than 3.
    """
    # protocol token [XWB]VTEX: [XWB] = NS broker [XWB], V = V 1, T = type 1, E = envelope size 3, X = XWBPRT 0
    if envelope < 1 and args:
        longest = max(l_pack_maxlen(arg, encoding=encoding) for arg in args)
        # Count decimal digits directly: ceil(log10(n)) is one short when n is an
        # exact power of ten (1000 needs 4 digits), which made l_pack raise.
        envelope = len(str(longest))
    envelope = max(3, envelope)
    header = b'[XWB]11' + str(envelope).encode(encoding) + b'0'
    header += b'4' if command else (b'2' + s_pack(broker_version, encoding=encoding))
    params = b''.join(l_pack(arg, envelope=envelope, encoding=encoding) for arg in args) if args else b'4f'
    return header + s_pack(name, encoding=encoding) + b'5' + params
def rpc_unpack_result(data: bytes, encoding: str='latin-1'):
    """Decode one raw broker reply (the bytes preceding the EOT terminator).

    The reply must start with two NUL bytes.  What follows decides the result:

    * CAN (0x18) first: raise RPCExcServerError with the rest as the message;
    * ends with US (0x1f): a table, split on RS (0x1e) and parsed by
      rpc_unpack_table;
    * ends with CRLF: a tuple of the CRLF-separated lines;
    * otherwise: the decoded string.

    Raises:
        RPCExcServerError: the reply carries a server error.
        RPCExcFormat: the reply does not start with two NUL bytes.
    """
    if data[:2] != b'\x00\x00':
        raise RPCExcFormat(data)
    if len(data) > 2 and data[2] == 0x18:  # 0x18 is CAN
        raise RPCExcServerError(data[3:].decode(encoding))
    if data[-1] == 0x1f:  # 0x1f is US
        return rpc_unpack_table(data[2:-1].decode(encoding).split('\x1e'))  # 0x1e is RS
    if data[-2:] == b'\r\n':
        return tuple(data[2:-2].decode(encoding).split('\r\n'))
    return data[2:].decode(encoding)
def rpc_unpack_table(rows: Sequence[str]):
    """Parse broker table rows into tuples, or into dicts when a header row is present.

    Rows are ``^``-separated columns; empty rows are skipped.  If the first row
    starts with ``I`` or ``T`` followed by five decimal digits, it is a header.
    Each header field's name is the text after its first six characters, and
    the result is one ``dict`` per remaining row.  Otherwise every row becomes
    a tuple of its columns.
    """
    # table: ROW\x1eROW\x1eROW\x1eROW\x1eROW\x1e\x1f; row: COL^COL^COL^COL^COL; header field: [IT]\d{5}.+
    hdr = rows[0] if len(rows) > 0 else ''
    # Require all five digits: a bare slice check let short data rows like
    # 'I12' pass as a header, which then swallowed that row.
    if len(hdr) >= 6 and hdr[0] in ('I', 'T') and hdr[1:6].isdecimal():
        header = [field[6:] for field in hdr.split('^')]
        return tuple(dict(zip(header, row.split('^'))) for row in rows[1:] if len(row) > 0)
    return tuple(tuple(row.split('^')) for row in rows if len(row) > 0)
def send_rpc_msg(sock: socket.socket, msg: bytes, end: bytes=b'\x04'):
    """Send *msg* followed by the *end* terminator (EOT by default) on *sock*.

    Uses ``sendall``: a single ``send`` may write only part of the buffer,
    which would leave the broker waiting for a terminator that never arrives.
    """
    sock.sendall(msg + end)
def recv_rpc_msg(sock: socket.socket, end: bytes=b'\x04', minsz: int=1024, maxsz: int=32768): # 0x04 is EOT
    """Generator yielding each non-empty message read from *sock*, split on *end*.

    Each ``next()`` blocks until a complete message (the bytes before the
    terminator) is available.  After a message is yielded the read size resets
    to *minsz*.  An empty message (terminator at the start of the buffer) is
    skipped and doubles the read size, capped at *maxsz*.

    Raises:
        ConnectionAbortedError: the peer closed the connection.
    """
    buf = b''
    bufsz = minsz
    while True:
        if len(data := sock.recv(bufsz)) > 0:
            buf += data
            while (idx := buf.find(end)) >= 0:
                if idx > 0:
                    yield buf[:idx]
                    bufsz = minsz
                elif bufsz < maxsz:
                    bufsz = min(bufsz << 1, maxsz)
                # Skip the whole terminator, not just one byte.
                buf = buf[idx + len(end):]
        else:
            # recv() returns b'' forever once the peer has closed; without this
            # the loop spun at 100% CPU and next() never returned.
            raise ConnectionAbortedError
class ClientSync(object):
    """Blocking client for the VistA RPC broker over one TCP connection.

    Undefined attributes become RPC calls via ``__getattr__``:
    ``client.XUS_INTRO_MSG()`` sends the RPC ``XUS INTRO MSG`` (``__call__``
    replaces underscores with spaces).  Each request/reply exchange holds
    ``self.lock`` (a ``threading.Lock``), so one client can be shared between
    threads, e.g. with ``keepalive`` running in a background thread.
    """

    def __init__(self, host: str, port: int, TCPConnect: bool=True):
        """Connect to *host*:*port* and optionally perform the TCPConnect handshake.

        The handshake sends this side's socket address, ``'0'`` and the local
        host name, and expects the reply ``'accept'``.

        Raises:
            OSError: the TCP connection failed (the socket is closed first).
            RPCExcInvalidResult: the handshake reply was not ``'accept'``.
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.sock.connect((host, port))
        except OSError:
            # Release the descriptor now; __del__ ignores a sock of None.
            self.sock.close()
            self.sock = None
            raise
        self.recv_rpc_msg = recv_rpc_msg(self.sock)
        self.lock = threading.Lock()
        self._server = { 'host': host, 'port': port }
        self._user = None
        # Context this client believes is active; __call__ switches it on demand.
        self.context = 'XUS SIGNON'
        if TCPConnect:
            local_addr = self.sock.getsockname()[0]
            hostname = socket.gethostname()
            if (res := self.TCPConnect(local_addr, '0', hostname)) != 'accept':
                raise RPCExcInvalidResult('TCPConnect', local_addr, '0', hostname, res)

    def __call__(self, name: str, *args: Any, command: bool=False, envelope: int=0, context: Union[Sequence, None]=None, encoding='latin-1'):
        """Invoke RPC *name* with *args* and return the unpacked reply.

        When *context* is non-empty and the active context is not among its
        entries, ``XWB CREATE CONTEXT`` is sent first with ``context[0]``
        (passed through XWBHash_encrypt) and must reply ``'1'``.

        Raises:
            RPCExcInvalidResult: the context switch was refused.
            RPCExcServerError, RPCExcFormat: propagated from rpc_unpack_result.
        """
        name = name.replace('_', ' ')
        with self.lock:
            if name != 'XWB CREATE CONTEXT' and context and self.context not in context:
                send_rpc_msg(self.sock, rpc_pack('XWB CREATE CONTEXT', XWBHash_encrypt(context[0]), envelope=envelope, encoding=encoding))
                if (res := rpc_unpack_result(next(self.recv_rpc_msg), encoding=encoding)) != '1':
                    raise RPCExcInvalidResult('XWB CREATE CONTEXT', context[0], res)
                self.context = context[0]
            # NOTE(review): args are logged verbatim at WARNING level, including the
            # XWBHash_encrypt output of the access/verify code sent by authenticate();
            # consider a lower level or redaction.
            logger.warning(f'RPC: {name} [{self.context}] {args}' if context else f'{name} {args}')
            send_rpc_msg(self.sock, rpc_pack(name, *args, command=command, envelope=envelope, encoding=encoding))
            return rpc_unpack_result(next(self.recv_rpc_msg), encoding=encoding)

    def __getattr__(self, key: str, commands: frozenset=frozenset({'TCPConnect'})):
        """Create, cache on the instance and return an RPC thunk for attribute *key*.

        The thunk forwards to ``self(key, *args, **kw)``.  Names in *commands*
        are sent with ``command=True``.  Dunder names are refused so that
        protocol probes (``hasattr(obj, '__iter__')``, pickle/copy hooks looked
        up on the instance) see them as missing instead of issuing an RPC.

        Raises:
            AttributeError: *key* is a dunder name.
        """
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        command = key in commands

        def thunk(*args, **kw):
            return self(key, *args, **kw, command=command)
        setattr(self, key, thunk)
        return thunk

    def __del__(self):
        """Best-effort ``#BYE#`` and socket close for a client that was never closed."""
        # Read __dict__ directly: getattr() on a missing attribute would go
        # through __getattr__ and fabricate an RPC thunk.
        sock = self.__dict__.get('sock')
        if not isinstance(sock, socket.socket):
            return
        try:
            self('#BYE#', command=True)
        except Exception:
            # Finalizers cannot propagate errors; the peer may already be gone.
            logger.debug('RPC #BYE# during finalization failed', exc_info=True)
        finally:
            sock.close()

    def close(self):
        """Send ``#BYE#``, then shut down and close the socket even if that fails.

        Warns if the reply is not ``'#BYE#'``.  Returns the reply.
        """
        try:
            if (res := self('#BYE#', command=True)) != '#BYE#':
                warnings.warn(f'RPC #BYE# returned {repr(res)} instead of \'#BYE#\'')
        finally:
            sock, self.sock, self.recv_rpc_msg = self.sock, None, None
            try:
                sock.shutdown(socket.SHUT_RDWR)
            finally:
                sock.close()
        return res

    def authenticate(self, identity: str, *, context=('XUS SIGNON',)):
        """Sign on with *identity* and return the sign-on reply.

        First runs ``XUS SIGNON SETUP`` and merges its fields into
        ``self._server``.  An *identity* starting with an XML declaration is
        sent to ``XUS ESSO VALIDATE`` as 200-character chunks tagged
        ``RPCType.GLOBAL``.  Anything else is passed through XWBHash_encrypt
        and sent to ``XUS AV CODE``.

        Raises:
            RPCExcAuth: ``res[0] == '0'`` or ``res[2] != '0'``; ``res[3]`` is the message.
        """
        self._server.update(RecordServerInfo(*self('XUS SIGNON SETUP', '', '1', context=context))._asdict())
        if identity.startswith('<?xml version="1.0" encoding="UTF-8"?>'):
            chunks = tuple(identity[i:i+200] for i in range(0, len(identity), 200))
            res = self('XUS ESSO VALIDATE', RPCType(chunks, RPCType.GLOBAL))
        else:
            res = self('XUS AV CODE', XWBHash_encrypt(identity))
        if res[0] == '0' or res[2] != '0':
            raise RPCExcAuth(res[3], res)
        self._user = res
        return res

    def keepalive(self, interval=None, *, context=('XUS SIGNON',)):
        """Call ``XWB IM HERE`` forever, sleeping *interval* seconds between calls.

        Without *interval*, it sleeps 0.45 times the first element of the
        ``XWB GET BROKER INFO`` reply (presumably the broker timeout in seconds;
        confirm).  Never returns, so run it in a background thread.
        """
        import time
        interval = interval or 0.45*float(self.XWB_GET_BROKER_INFO(context=context)[0])
        while True:
            time.sleep(interval)
            self.XWB_IM_HERE()
async def asend_rpc_msg(writer: asyncio.StreamWriter, msg: bytes, end: bytes=b'\x04'):
    """Write *msg* plus the *end* terminator (EOT by default), then await ``drain()``.

    ``drain()`` applies the stream's flow control before returning.
    """
    frame = msg + end
    writer.write(frame)
    await writer.drain()
async def arecv_rpc_msg(reader: asyncio.StreamReader, end: bytes=b'\x04', minsz: int=1024, maxsz: int=32768): # \x04 is EOT
    """Async generator yielding each non-empty message read from *reader*, split on *end*.

    After a message is yielded the read size resets to *minsz*.  An empty
    message (terminator at the start of the buffer) is skipped and doubles the
    read size, capped at *maxsz*.

    Raises:
        ConnectionAbortedError: the stream reached EOF (``read`` returned no data).
    """
    buf = b''
    bufsz = minsz
    while True:
        if len(data := await reader.read(bufsz)) > 0:
            buf += data
            while (idx := buf.find(end)) >= 0:
                if idx > 0:
                    yield buf[:idx]
                    bufsz = minsz
                elif bufsz < maxsz:
                    bufsz = min(bufsz << 1, maxsz)
                # Skip the whole terminator; idx + 1 was only right for 1-byte terminators.
                buf = buf[idx + len(end):]
        else:
            raise ConnectionAbortedError
class ClientAsync(object):
    """asyncio client for the VistA RPC broker.

    Create it with ``client = await ClientAsync(host, port)``.  ``__new__`` is a
    coroutine that awaits the asynchronous ``__init__`` and returns the
    instance.  Undefined attributes become RPC coroutine functions via
    ``__getattr__``, and an ``asyncio.Lock`` serializes request/reply exchanges.
    """

    async def __new__(cls, *args, **kw):
        # type.__call__ skips __init__ because this returns a coroutine, not an
        # instance, so the async __init__ is awaited here explicitly.
        await (self := super(ClientAsync, cls).__new__(cls)).__init__(*args, **kw)
        return self

    async def __init__(self, host: str, port: int, TCPConnect: bool=True):
        """Open a stream to *host*:*port* and optionally perform the TCPConnect handshake.

        Raises:
            OSError: the connection could not be opened.
            RPCExcInvalidResult: the handshake reply was not ``'accept'``.
        """
        self.reader, self.writer = await asyncio.open_connection(host, port)
        self.arecv_rpc_msg = arecv_rpc_msg(self.reader)
        self.lock = asyncio.Lock()
        self._server = { 'host': host, 'port': port, 'info': None }
        self._user = None
        # Context this client believes is active; __call__ switches it on demand.
        self.context = 'XUS SIGNON'
        if TCPConnect:
            local_addr = self.writer.get_extra_info('sockname')[0]
            hostname = socket.gethostname()
            if (res := await self.TCPConnect(local_addr, '0', hostname)) != 'accept':
                raise RPCExcInvalidResult('TCPConnect', local_addr, '0', hostname, res)

    async def __call__(self, name: str, *args: Any, command: bool=False, envelope: int=0, context: Union[Sequence, None]=None, encoding='latin-1'):
        """Invoke RPC *name* with *args* and return the unpacked reply.

        Underscores in *name* become spaces.  When *context* is non-empty and
        the active context is not among its entries, ``XWB CREATE CONTEXT`` is
        sent first with ``context[0]`` (passed through XWBHash_encrypt) and must
        reply ``'1'``.

        Raises:
            RPCExcInvalidResult: the context switch was refused.
            RPCExcServerError, RPCExcFormat: propagated from rpc_unpack_result.
            ConnectionAbortedError: the stream closed while awaiting a reply.
        """
        name = name.replace('_', ' ')
        async with self.lock:
            if name != 'XWB CREATE CONTEXT' and context and self.context not in context:
                await asend_rpc_msg(self.writer, rpc_pack('XWB CREATE CONTEXT', XWBHash_encrypt(context[0]), envelope=envelope, encoding=encoding))
                if (res := rpc_unpack_result(await self.arecv_rpc_msg.__anext__(), encoding=encoding)) != '1':
                    raise RPCExcInvalidResult('XWB CREATE CONTEXT', context[0], res)
                self.context = context[0]
            # NOTE(review): args (including authenticate()'s XWBHash_encrypt'd
            # access/verify code) are logged verbatim at WARNING level.
            logger.warning(f'RPC: {name} [{self.context}] {args}' if context else f'{name} {args}')
            await asend_rpc_msg(self.writer, rpc_pack(name, *args, command=command, envelope=envelope, encoding=encoding))
            return rpc_unpack_result(await self.arecv_rpc_msg.__anext__(), encoding=encoding)

    def __getattr__(self, key: str, commands: frozenset=frozenset({'TCPConnect'})):
        """Create, cache on the instance and return an RPC coroutine function for *key*.

        Names in *commands* are sent with ``command=True``.  Dunder names are
        refused so protocol probes (e.g. ``hasattr(obj, '__await__')``) see
        them as missing instead of producing an RPC thunk.

        Raises:
            AttributeError: *key* is a dunder name.
        """
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        command = key in commands

        async def thunk(*args, **kw):
            return await self(key, *args, **kw, command=command)
        setattr(self, key, thunk)
        return thunk

    def __del__(self):
        """Best-effort close of the stream writer when the client was never closed."""
        # Read __dict__ directly: getattr() would route a missing 'writer'
        # through __getattr__ and fabricate an RPC thunk.
        writer = self.__dict__.get('writer')
        if isinstance(writer, asyncio.StreamWriter):
            try:
                writer.close()
            except RuntimeError:
                pass  # deliberate: the event loop may already be closed at finalization
        self.reader = self.writer = None

    async def close(self):
        """Send ``#BYE#``, then close the stream even if that fails.

        Warns if the reply is not ``'#BYE#'``.  Returns the reply.
        """
        try:
            if (res := await self('#BYE#', command=True)) != '#BYE#':
                warnings.warn(f'RPC #BYE# returned {repr(res)} instead of \'#BYE#\'')
        finally:
            writer, self.reader, self.writer = self.writer, None, None
            writer.close()
            await writer.wait_closed()
        return res

    async def authenticate(self, identity: str, *, context=('XUS SIGNON',)):
        """Sign on with *identity* and return the sign-on reply.

        First runs ``XUS SIGNON SETUP`` and merges its fields into
        ``self._server``.  An *identity* starting with an XML declaration goes
        to ``XUS ESSO VALIDATE`` as 200-character chunks tagged
        ``RPCType.GLOBAL``.  Anything else is passed through XWBHash_encrypt
        and sent to ``XUS AV CODE``.

        Raises:
            RPCExcAuth: ``res[0] == '0'`` or ``res[2] != '0'``; ``res[3]`` is the message.
        """
        self._server.update(RecordServerInfo(*await self('XUS SIGNON SETUP', '', '1', context=context))._asdict())
        if identity.startswith('<?xml version="1.0" encoding="UTF-8"?>'):
            chunks = tuple(identity[i:i+200] for i in range(0, len(identity), 200))
            res = await self('XUS ESSO VALIDATE', RPCType(chunks, RPCType.GLOBAL))
        else:
            res = await self('XUS AV CODE', XWBHash_encrypt(identity))
        if res[0] == '0' or res[2] != '0':
            raise RPCExcAuth(res[3], res)
        self._user = res
        return res

    async def keepalive(self, interval=None, *, context=('XUS SIGNON',)):
        """Call ``XWB IM HERE`` forever, sleeping *interval* seconds between calls.

        Without *interval*, it sleeps 0.45 times the first element of the
        ``XWB GET BROKER INFO`` reply (presumably the broker timeout in seconds;
        confirm).  Never returns, so run it as a separate task.
        """
        interval = interval or 0.45*float((await self.XWB_GET_BROKER_INFO(context=context))[0])
        while True:
            await asyncio.sleep(interval)
            await self.XWB_IM_HERE()
if __name__ == '__main__':
    # Interactive smoke test: connect, start keepalive, sign on, then open a REPL.
    import getpass, code
    # Project-local helper; presumably returns an SSO token, or something falsy when unavailable (verify).
    from auth import XUIAMSSOi_MySsoTokenVBA
    client = ClientSync(host='test.northport.med.va.gov', port=19009)
    #client = ClientSync(host='vista.northport.med.va.gov', port=19209)
    # Daemon thread: the endless keepalive loop must not keep the process alive on exit.
    threading.Thread(target=client.keepalive, daemon=True).start()
    # Print the intro message lines joined with CRLF.
    print('\r\n'.join(client.XUS_INTRO_MSG()))
    if token := XUIAMSSOi_MySsoTokenVBA():
        print('authenticate', repr(client.authenticate(token)))
    else:
        # Fall back to access/verify codes, sent as "ACCESS;VERIFY".
        print('authenticate', repr(client.authenticate(f"{getpass.getpass('ACCESS CODE: ')};{getpass.getpass('VERIFY CODE: ')}")))
    # Expose the module globals (including `client`) in an interactive console.
    code.interact(local=globals())