286 lines
12 KiB
Python
286 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import sys
|
|
import getpass
|
|
import re
|
|
import codecs
|
|
import asyncio
|
|
import contextlib
|
|
import logging
|
|
from collections import namedtuple
|
|
|
|
from typing import Optional, Union, Sequence, NamedTuple, Callable
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
ExpectMatch = namedtuple('PatternMatch', ('batch', 'index', 'pattern', 'match', 'groups', 'groupdict', 'before'))
|
|
ExpectMatch.__new__.__defaults__ = (None,)*len(ExpectMatch._fields)
|
|
class ExpectQ(object):
|
|
"""Provide an expect-like interface over an asyncio queue"""
|
|
def __init__(self, pipequeue: asyncio.Queue, timeout_settle: float=1):
|
|
self.pipequeue = pipequeue
|
|
self.buffer = ''
|
|
self.timeout_settle = timeout_settle
|
|
def set_timeout(self, timeout_settle: float=1):
|
|
"""Set default timeout"""
|
|
self.timeout_settle = timeout_settle
|
|
def reset(self, buffer: str=''):
|
|
"""Clear or restore buffer"""
|
|
self.buffer = buffer
|
|
clear = reset
|
|
async def prompts(self, endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False):
|
|
len_endl = len(endl)
|
|
while True:
|
|
if (pos := self.buffer.rfind(endl)) >= 0:
|
|
buffer = self.buffer
|
|
self.buffer = ''
|
|
yield buffer, pos + len_endl
|
|
while True:
|
|
try:
|
|
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
|
|
break
|
|
except asyncio.TimeoutError: # no more data
|
|
if throw:
|
|
raise
|
|
yield None, None
|
|
async def promptmatches(self, *mappings: Union[str, re.Pattern, tuple, list], endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False):
|
|
for i, mapping in enumerate(mappings):
|
|
try:
|
|
match mapping:
|
|
case (str() as pattern, response) if response is None or isinstance(response, str) or callable(response):
|
|
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
|
|
if pattern == buffer[pos:]:
|
|
yield (m := self.ExactMatch(batch=i, index=0, pattern=mapping, match=mapping, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response)
|
|
break
|
|
else:
|
|
self.reset(buffer)
|
|
case (re.Pattern() as pattern, response) if response is None or isinstance(response, str) or callable(response):
|
|
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
|
|
if match := pattern.search(buffer[pos:]):
|
|
yield (m := self.PatternMatch(batch=i, index=0, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response)
|
|
break
|
|
else:
|
|
self.reset(buffer)
|
|
case (*_,) as components:
|
|
exact = {}
|
|
expr = {}
|
|
for j, component in enumerate(components):
|
|
match component:
|
|
case (str() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response):
|
|
exact[pattern] = (j, response, None if len(rest) < 1 else rest[0])
|
|
case (re.Pattern() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response):
|
|
expr[pattern] = (j, response, None if len(rest) < 1 else rest[0])
|
|
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
|
|
if buffer is not None:
|
|
prompt = buffer[pos:]
|
|
if prompt in exact:
|
|
j, response, end = exact[prompt]
|
|
interrupt = yield (m := self.ExactMatch(batch=i, index=j, pattern=prompt, match=prompt, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response)
|
|
else:
|
|
for pattern in expr:
|
|
if match := pattern.search(prompt):
|
|
j, response, end = expr[pattern]
|
|
interrupt = yield (m := self.PatternMatch(batch=i, index=j, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response)
|
|
break
|
|
else:
|
|
self.reset(buffer)
|
|
continue
|
|
if interrupt:
|
|
yield
|
|
break
|
|
elif end:
|
|
break
|
|
except asyncio.TimeoutError as ex: # no more data
|
|
if throw:
|
|
raise asyncio.TimeoutError(*(ex.args + (i, mapping)))
|
|
yield None, None
|
|
async def earliest(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
|
|
"""Wait for any string or regular expression pattern match, specified in *patterns, and optionally raise exception upon timeout"""
|
|
try:
|
|
while True:
|
|
for i, pattern in enumerate(patterns): # try every option
|
|
if isinstance(pattern, str):
|
|
if (pos := self.buffer.find(pattern)) >= 0: # found it
|
|
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:pos])
|
|
self.buffer = self.buffer[pos + len(pattern):]
|
|
return res
|
|
else:
|
|
if match := pattern.search(self.buffer): # found it
|
|
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
|
|
self.buffer = self.buffer[match.end():]
|
|
return res
|
|
else: # fetch more data
|
|
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
|
|
except asyncio.TimeoutError: # no more data
|
|
if throw:
|
|
raise
|
|
return None
|
|
async def startswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
|
|
"""Wait for any string or regular expression pattern match, specified in *patterns, at the start of the stream and optionally raise exception upon timeout"""
|
|
try:
|
|
while True:
|
|
for i, pattern in enumerate(patterns): # try every option
|
|
if isinstance(pattern, str):
|
|
if self.buffer.startswith(pattern): # found it
|
|
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before='')
|
|
self.buffer = self.buffer[len(pattern):]
|
|
return res
|
|
else:
|
|
if match := pattern.match(self.buffer): # found it
|
|
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
|
|
self.buffer = self.buffer[match.end():]
|
|
return res
|
|
else: # fetch more data
|
|
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
|
|
except asyncio.TimeoutError: # no more data
|
|
if throw:
|
|
raise
|
|
return None
|
|
async def endswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
|
|
"""Wait for any string or regular expression pattern match, specified in *patterns, at the end of the stream and optionally raise exception upon timeout"""
|
|
try:
|
|
while True:
|
|
for i, pattern in enumerate(patterns): # try every option
|
|
if isinstance(pattern, str):
|
|
if self.buffer.endswith(pattern): # found it
|
|
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:-len(pattern)])
|
|
self.buffer = ''
|
|
return res
|
|
else:
|
|
if match := pattern.search(self.buffer): # found it
|
|
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
|
|
self.buffer = self.buffer[match.end():]
|
|
return res
|
|
else: # fetch more data
|
|
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
|
|
except asyncio.TimeoutError: # no more data
|
|
if throw:
|
|
raise
|
|
return None
|
|
__call__ = earliest
|
|
ExactMatch = type('ExactMatch', (ExpectMatch,), {})
|
|
PatternMatch = type('PatternMatch', (ExpectMatch,), {})
|
|
|
|
class LockableCallable(object):
|
|
def __init__(self, func: Callable, lock: asyncio.Lock=None):
|
|
if lock is None:
|
|
lock = asyncio.Lock()
|
|
self.lock = lock
|
|
self.locked = lock.locked
|
|
self.acquire = lock.acquire
|
|
self.release = lock.release
|
|
self.func = func
|
|
self.__name__ = func.__name__
|
|
self.__doc__ = func.__doc__
|
|
def __call__(self, *args, **kw):
|
|
return self.func(*args, **kw)
|
|
async def __aenter__(self):
|
|
await self.lock.acquire()
|
|
async def __aexit__(self, exc_type, exc, tb):
|
|
self.lock.release()
|
|
async def withlock(self, *args, **kw):
|
|
async with self.lock:
|
|
return self.func(*args, **kw)
|
|
|
|
async def create_instrumented_subprocess_exec(*args: str, stdin_endl=b'\n', drain_timeout: Optional[float]=1.0, **kw) -> asyncio.subprocess.Process:
    """Create asyncio subprocess, coupled to host stdio, with ability to attach tasks that could inspect its stdout and inject into its stdin

    *args and **kw are forwarded to asyncio.create_subprocess_exec(). A background
    task reads ``process.stdout`` (so **kw must request ``stdout=asyncio.subprocess.PIPE``)
    and fans the decoded text out to subscribed queues. The returned process gains:

    - ``create_task(...)``: like asyncio.create_task(), but the task is tracked and
      cancelled once the process has been waited for;
    - ``subscribe(pipequeue=None)``: register a queue (a new asyncio.Queue by default)
      for stdout text; the queue gets an idempotent ``unsubscribe()`` method;
    - ``sendline(data=None, endl=None)``: a LockableCallable writing the UTF-8 encoding
      of *data* followed by *endl* (``stdin_endl`` when None) to stdin, which requires
      ``stdin=asyncio.subprocess.PIPE``;
    - ``wait()``: waits for exit, gives the stdout reader up to *drain_timeout* seconds
      (None = no limit) to forward remaining output, cancels the other tracked tasks,
      and returns the exit code.
    """
    process = await asyncio.create_subprocess_exec(*args, **kw)
    tasks = set()  # tracked tasks, cancelled by wait()
    queues = set()  # subscriber queues fed by the stdout reader

    def create_task(*task_args, **task_kw):
        task = asyncio.create_task(*task_args, **task_kw)
        tasks.add(task)
        task.add_done_callback(tasks.discard)
        return task

    process.create_task = create_task

    def subscribe(pipequeue=None):
        if pipequeue is None:
            pipequeue = asyncio.Queue()
        queues.add(pipequeue)
        pipequeue.unsubscribe = lambda: queues.discard(pipequeue)
        return pipequeue

    process.subscribe = subscribe

    def sendline(data=None, endl=None):
        # only None selects the default, so endl=b'' sends data without a terminator
        terminator = stdin_endl if endl is None else endl
        process.stdin.write(terminator if data is None else data.encode('utf-8') + terminator)

    process.sendline = LockableCallable(sendline)

    reader = create_task(stdout_writer(process.stdout, queues), name='@task:stdout-writer')  # stdout

    process_wait = process.wait

    async def wait_wrapper():  # clean up tasks at the end
        returncode = await process_wait()
        proc_id = id(process)
        logger.debug('SHUTDOWN [proc#%d]: cleaning up', proc_id)
        if not reader.done():
            # the process has exited, but what it wrote may still sit unread in process.stdout
            await asyncio.wait({reader}, timeout=drain_timeout)
        current = asyncio.current_task()
        # snapshot: done callbacks shrink `tasks`; never cancel/await the task running this cleanup
        pending = [task for task in tasks if not task.done() and task is not current]
        for task in pending:
            logger.debug('SHUTDOWN [proc#%d]: stopping [task#%d] %r', proc_id, id(task), task)
            task.cancel()
        if pending:
            # asyncio.wait() neither raises the tasks' exceptions nor swallows our own cancellation
            await asyncio.wait(pending)
        for task in pending:
            if not task.cancelled() and task.exception() is not None:
                logger.error('SHUTDOWN [proc#%d]: [task#%d] failed', proc_id, id(task), exc_info=task.exception())
            logger.debug('SHUTDOWN [proc#%d]: stopped [task#%d]', proc_id, id(task))
        logger.debug('SHUTDOWN [proc#%d]: done', proc_id)
        return returncode

    process.wait = wait_wrapper
    return process
|
|
|
|
async def stdout_writer(pipe: asyncio.StreamReader, subscribers: set[asyncio.Queue], chunksize: int=4096, echo: bool=True):
    """Read data from pipe, decode into Unicode strings, and send to subscribers

    Bytes are decoded incrementally as UTF-8 with errors='replace', so characters
    split across reads are reassembled and an incomplete sequence left at EOF becomes
    U+FFFD. Each non-empty piece of text is optionally echoed to sys.stdout and
    put() into every queue in *subscribers*. Returns once EOF has been processed.
    """
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')

    async def publish(text):
        if not text:  # the decoder may hold back a partial character and return ''
            return
        if echo:  # echo to stdout
            sys.stdout.write(text)
            sys.stdout.flush()
        # iterate a snapshot: put() on a bounded queue can suspend while someone (un)subscribes
        for queue in tuple(subscribers):
            await queue.put(text)

    try:
        while True:
            try:
                chunk = await pipe.read(chunksize)  # fetch a bunch of bytes
                # an empty chunk means EOF: flush whatever the decoder still buffers
                text = decoder.decode(chunk, final=not chunk)
            except asyncio.TimeoutError:
                continue
            except UnicodeDecodeError:  # should not encounter errors with errors='replace'
                logger.exception('stdout_writer')
                break  # bail on error
            await publish(text)
            if not chunk:  # EOF
                break
    except KeyboardInterrupt:
        logger.info('KeyboardInterrupt: stdout_writer')
|
|
|
|
@contextlib.contextmanager
def subscribe(proc):
    """Subscribe to *proc*'s output for the duration of a ``with`` block.

    Yields the queue returned by ``proc.subscribe()``, with ``proc.sendline``
    attached to it as ``sendline``. The queue's ``unsubscribe()`` runs on exit,
    including when the block raises.
    """
    channel = proc.subscribe()
    channel.sendline = proc.sendline
    try:
        yield channel
    finally:
        channel.unsubscribe()
|
|
|
|
@contextlib.asynccontextmanager
async def subscribe_async(proc):
    """Subscribe to *proc*'s output for the duration of an ``async with`` block.

    Yields the queue returned by ``proc.subscribe()``, with ``proc.sendline``
    attached to it as ``sendline``. The queue's ``unsubscribe()`` runs on exit,
    including when the block raises.
    """
    channel = proc.subscribe()
    channel.sendline = proc.sendline
    try:
        yield channel
    finally:
        channel.unsubscribe()
|
|
|
|
@contextlib.contextmanager
def expect(proc):
    """Attach an ExpectQ to *proc*'s output for the duration of a ``with`` block.

    The ExpectQ wraps a fresh ``proc.subscribe()`` queue and carries
    ``proc.sendline`` as its ``sendline`` attribute. The queue is unsubscribed on
    exit, including when the block raises.
    """
    subscription = proc.subscribe()
    session = ExpectQ(subscription)
    session.sendline = proc.sendline
    try:
        yield session
    finally:
        subscription.unsubscribe()
|
|
|
|
@contextlib.asynccontextmanager
async def expect_async(proc):
    """Attach an ExpectQ to *proc*'s output for the duration of an ``async with`` block.

    The ExpectQ wraps a fresh ``proc.subscribe()`` queue and carries
    ``proc.sendline`` as its ``sendline`` attribute. The queue is unsubscribed on
    exit, including when the block raises.
    """
    subscription = proc.subscribe()
    session = ExpectQ(subscription)
    session.sendline = proc.sendline
    try:
        yield session
    finally:
        subscription.unsubscribe()
|