Compare commits

..

2 Commits

Author SHA1 Message Date
132c85c1fd Better SSO login experience 2023-04-25 19:42:00 -04:00
24291804a2 RPC value typecasting method 2023-04-25 19:37:22 -04:00

29
rpc.py
View File

@ -17,7 +17,16 @@ class RPCExcAuth(RPCExc): pass
class RPCExcServerError(RPCExc): pass class RPCExcServerError(RPCExc): pass
class RPCExcInvalidResult(RPCExc): pass class RPCExcInvalidResult(RPCExc): pass
class MReference(str): pass class RPCType(object):
LITERAL = b'0'
REFERENCE = b'1'
LIST = b'2'
GLOBAL = b'3'
EMPTY = b'4'
STREAM = b'5'
def __init__(self, value, magic=None):
self.magic = magic
self.value = value.value if isinstance(value, RPCType) else value
RecordServerInfo = namedtuple('RecordServerInfo', ('server', 'volume', 'uci', 'device', 'attempts', 'skip_signon_screen', 'domain', 'production')) RecordServerInfo = namedtuple('RecordServerInfo', ('server', 'volume', 'uci', 'device', 'attempts', 'skip_signon_screen', 'domain', 'production'))
@ -27,20 +36,20 @@ def s_pack(value: Any, encoding: str='latin-1'):
return bytes((len(encoded),)) + encoded return bytes((len(encoded),)) + encoded
raise ValueError('cannot s-pack string longer than 255 bytes: ' + repr(value)) raise ValueError('cannot s-pack string longer than 255 bytes: ' + repr(value))
def l_pack(value: Any, envelope: int=3, wrapped: bool=True, basictype=b'0', encoding: str='latin-1'): def l_pack(value: Any, envelope: int=3, wrapped: bool=True, magic=None, encoding: str='latin-1'):
if isinstance(value, dict): if isinstance(value, dict):
bare = b't'.join(l_pack(k, envelope=envelope, wrapped=False, encoding=encoding) + l_pack(v, envelope=envelope, wrapped=False, encoding=encoding) for k, v in value.items()) bare = b't'.join(l_pack(k, envelope=envelope, wrapped=False, encoding=encoding) + l_pack(v, envelope=envelope, wrapped=False, encoding=encoding) for k, v in value.items())
return (b'2' + bare + b'f') if wrapped else bare return ((magic or b'2') + bare + b'f') if wrapped else bare
elif not isinstance(value, str) and hasattr(value, '__iter__'): elif not isinstance(value, str) and hasattr(value, '__iter__'):
bare = b't'.join(l_pack(k, envelope=envelope, wrapped=False, encoding=encoding) + l_pack(v, envelope=envelope, wrapped=False, encoding=encoding) for k, v in enumerate(value, start=1)) bare = b't'.join(l_pack(k, envelope=envelope, wrapped=False, encoding=encoding) + l_pack(v, envelope=envelope, wrapped=False, encoding=encoding) for k, v in enumerate(value, start=1))
return (b'2' + bare + b'f') if wrapped else bare return ((magic or b'2') + bare + b'f') if wrapped else bare
elif isinstance(value, MReference): elif isinstance(value, RPCType):
return l_pack(str(value), envelope=envelope, basictype=b'1', encoding=encoding) return l_pack(value.value, envelope=envelope, magic=value.magic, encoding=encoding)
else: else:
encoded = str(value).encode(encoding) encoded = str(value).encode(encoding)
if len(encoded) <= 10**envelope - 1: if len(encoded) <= 10**envelope - 1:
bare = str(len(encoded)).zfill(envelope).encode(encoding) + encoded bare = str(len(encoded)).zfill(envelope).encode(encoding) + encoded
return (basictype + bare + b'f') if wrapped else bare return ((magic or b'0') + bare + b'f') if wrapped else bare
raise ValueError(f'cannot l-pack string longer than {10**envelope - 1} bytes with an envelope of {envelope}: ' + repr(value)) raise ValueError(f'cannot l-pack string longer than {10**envelope - 1} bytes with an envelope of {envelope}: ' + repr(value))
def l_pack_maxlen(value: Any, encoding: str='latin-1'): def l_pack_maxlen(value: Any, encoding: str='latin-1'):
@ -131,6 +140,9 @@ class ClientSync(object):
return res return res
def authenticate(self, identity: str, *, context=('XUS SIGNON',)): def authenticate(self, identity: str, *, context=('XUS SIGNON',)):
self._server.update(RecordServerInfo(*self('XUS SIGNON SETUP', '', '1', context=context))._asdict()) self._server.update(RecordServerInfo(*self('XUS SIGNON SETUP', '', '1', context=context))._asdict())
if identity.startswith('<?xml version="1.0" encoding="UTF-8"?>'):
res = self('XUS ESSO VALIDATE', RPCType(tuple(identity[i:i+200] for i in range(0, len(identity), 200)), RPCType.GLOBAL))
else:
res = self('XUS AV CODE', XWBHash_encrypt(identity)) res = self('XUS AV CODE', XWBHash_encrypt(identity))
if res[0] == '0' or res[2] != '0': if res[0] == '0' or res[2] != '0':
raise RPCExcAuth(res[3], res) raise RPCExcAuth(res[3], res)
@ -208,6 +220,9 @@ class ClientAsync(object):
return res return res
async def authenticate(self, identity: str, *, context=('XUS SIGNON',)): async def authenticate(self, identity: str, *, context=('XUS SIGNON',)):
self._server.update(RecordServerInfo(*await self('XUS SIGNON SETUP', '', '1', context=context))._asdict()) self._server.update(RecordServerInfo(*await self('XUS SIGNON SETUP', '', '1', context=context))._asdict())
if identity.startswith('<?xml version="1.0" encoding="UTF-8"?>'):
res = await self('XUS ESSO VALIDATE', RPCType(tuple(identity[i:i+200] for i in range(0, len(identity), 200)), RPCType.GLOBAL))
else:
res = await self('XUS AV CODE', XWBHash_encrypt(identity)) res = await self('XUS AV CODE', XWBHash_encrypt(identity))
if res[0] == '0' or res[2] != '0': if res[0] == '0' or res[2] != '0':
raise RPCExcAuth(res[3], res) raise RPCExcAuth(res[3], res)