#!/usr/bin/env python3 import sys import getpass import re import codecs import asyncio import contextlib import logging from collections import namedtuple from typing import Optional, Union, Sequence, NamedTuple, Callable logger = logging.getLogger(__name__) ExpectMatch = namedtuple('PatternMatch', ('batch', 'index', 'pattern', 'match', 'groups', 'groupdict', 'before')) ExpectMatch.__new__.__defaults__ = (None,)*len(ExpectMatch._fields) class ExpectQ(object): """Provide an expect-like interface over an asyncio queue""" def __init__(self, pipequeue: asyncio.Queue, timeout_settle: float=1): self.pipequeue = pipequeue self.buffer = '' self.timeout_settle = timeout_settle def set_timeout(self, timeout_settle: float=1): """Set default timeout""" self.timeout_settle = timeout_settle def reset(self, buffer: str=''): """Clear or restore buffer""" self.buffer = buffer clear = reset async def prompts(self, endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False): len_endl = len(endl) while True: if (pos := self.buffer.rfind(endl)) >= 0: buffer = self.buffer self.buffer = '' yield buffer, pos + len_endl while True: try: self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) break except asyncio.TimeoutError: # no more data if throw: raise yield None, None async def promptmatches(self, *mappings: Union[str, re.Pattern, tuple, list], endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False): for i, mapping in enumerate(mappings): try: match mapping: case (str() as pattern, response) if response is None or isinstance(response, str) or callable(response): async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True): if pattern == buffer[pos:]: yield (m := self.ExactMatch(batch=i, index=0, pattern=mapping, match=mapping, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response) break else: self.reset(buffer) case (re.Pattern() as pattern, response) if response is None or isinstance(response, str) or callable(response): async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True): if match := pattern.search(buffer[pos:]): yield (m := self.PatternMatch(batch=i, index=0, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response) break else: self.reset(buffer) case (*_,) as components: exact = {} expr = {} for j, component in enumerate(components): match component: case (str() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response): exact[pattern] = (j, response, None if len(rest) < 1 else rest[0]) case (re.Pattern() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response): expr[pattern] = (j, response, None if len(rest) < 1 else rest[0]) async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True): if buffer is not None: prompt = buffer[pos:] if prompt in exact: j, response, end = exact[prompt] interrupt = yield (m := self.ExactMatch(batch=i, index=j, pattern=prompt, match=prompt, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response) else: for pattern in expr: if match := pattern.search(prompt): j, response, end = expr[pattern] interrupt = yield (m := self.PatternMatch(batch=i, index=j, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response) break else: self.reset(buffer) continue if interrupt: yield break elif end: break except asyncio.TimeoutError as ex: # no more data if throw: raise asyncio.TimeoutError(*(ex.args + (i, mapping))) yield None, None async def earliest(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]: """Wait for any string or regular expression pattern match, specified in *patterns, and optionally raise exception upon timeout""" try: while True: for i, pattern in enumerate(patterns): # try every option if isinstance(pattern, str): if (pos := self.buffer.find(pattern)) >= 0: # found it res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:pos]) self.buffer = self.buffer[pos + len(pattern):] return res else: if match := pattern.search(self.buffer): # found it res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()]) self.buffer = self.buffer[match.end():] return res else: # fetch more data self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) except asyncio.TimeoutError: # no more data if throw: raise return None async def startswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]: """Wait for any string or regular expression pattern match, specified in *patterns, at the start of the stream and optionally raise exception upon timeout""" try: while True: for i, pattern in enumerate(patterns): # try every option if isinstance(pattern, str): if self.buffer.startswith(pattern): # found it res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before='') self.buffer = self.buffer[len(pattern):] return res else: if match := pattern.match(self.buffer): # found it res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()]) self.buffer = self.buffer[match.end():] return res else: # fetch more data self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) except asyncio.TimeoutError: # no more data if throw: raise return None async def endswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]: """Wait for any string or regular expression pattern match, specified in *patterns, at the end of the stream and optionally raise exception upon timeout""" try: while True: for i, pattern in enumerate(patterns): # try every option if isinstance(pattern, str): if self.buffer.endswith(pattern): # found it res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:-len(pattern)]) self.buffer = '' return res else: if match := pattern.search(self.buffer): # found it res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()]) self.buffer = self.buffer[match.end():] return res else: # fetch more data self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) except asyncio.TimeoutError: # no more data if throw: raise return None __call__ = earliest ExactMatch = type('ExactMatch', (ExpectMatch,), {}) PatternMatch = type('PatternMatch', (ExpectMatch,), {}) class LockableCallable(object): def __init__(self, func: Callable, lock: asyncio.Lock=None): if lock is None: lock = asyncio.Lock() self.lock = lock self.locked = lock.locked self.acquire = lock.acquire self.release = lock.release self.func = func self.__name__ = func.__name__ self.__doc__ = func.__doc__ def __call__(self, *args, **kw): return self.func(*args, **kw) async def __aenter__(self): await self.lock.acquire() async def __aexit__(self, exc_type, exc, tb): self.lock.release() async def withlock(self, *args, **kw): async with self.lock: return self.func(*args, **kw) async def create_instrumented_subprocess_exec(*args: str, stdin_endl=b'\n', **kw) -> asyncio.subprocess.Process: """Create asyncio subprocess, coupled to host stdio, with ability to attach tasks that could inspect its stdout and inject into its stdin""" process = await asyncio.create_subprocess_exec(*args, **kw) tasks = set() queues = set() def create_task(*args, **kw): tasks.add(item := asyncio.create_task(*args, **kw)) item.add_done_callback(tasks.remove) return item process.create_task = create_task def subscribe(pipequeue=None): queues.add(pipequeue := pipequeue or asyncio.Queue()) pipequeue.unsubscribe = lambda: queues.remove(pipequeue) return pipequeue process.subscribe = subscribe def sendline(data=None, endl=None): if data is not None: process.stdin.write(data.encode('utf-8') + (endl or stdin_endl)) else: process.stdin.write(endl or stdin_endl) process.sendline = LockableCallable(sendline) create_task(stdout_writer(process.stdout, queues), name='@task:stdout-writer') # stdout process_wait = process.wait async def wait_wrapper(): # clean up tasks at the end await process_wait() proc_id = id(process) logger.debug('SHUTDOWN [proc#%d]: cleaning up'%proc_id) for item in set(tasks): # copy set to avoid RuntimeError: Set changed size during iteration if not item.done(): item.cancel() try: logger.debug('SHUTDOWN [proc#%d]: stopping [task#%d] %r'%(proc_id, id(item), item)) await item except asyncio.CancelledError: pass logger.debug('SHUTDOWN [proc#%d]: stopped [task#%d]'%(proc_id, id(item))) logger.debug('SHUTDOWN [proc#%d]: done'%proc_id) process.wait = wait_wrapper return process async def stdout_writer(pipe: asyncio.StreamWriter, subscribers: Sequence[asyncio.Task], chunksize: int=4096, echo: bool=True): """Read data from pipe, decode into Unicode strings, and send to subscribers""" try: decoder = codecs.getincrementaldecoder('utf-8')(errors='replace') while True: try: chunk = await pipe.read(chunksize) # fetch a bunch of bytes if not chunk: # EOF break text = decoder.decode(chunk) except asyncio.TimeoutError: continue except UnicodeDecodeError: # should not encounter errors with errors='replace' logger.exception('stdout_writer') break # bail on error else: if echo: # echo to stdout sys.stdout.write(text) sys.stdout.flush() for item in subscribers: # distribute to subscribers await item.put(text) except KeyboardInterrupt: logger.info('KeyboardInterrupt: stdout_writer') @contextlib.contextmanager def subscribe(proc): queue = proc.subscribe() queue.sendline = proc.sendline try: yield queue finally: queue.unsubscribe() @contextlib.asynccontextmanager async def subscribe_async(proc): queue = proc.subscribe() queue.sendline = proc.sendline try: yield queue finally: queue.unsubscribe() @contextlib.contextmanager def expect(proc): queue = proc.subscribe() expect = ExpectQ(queue) expect.sendline = proc.sendline try: yield expect finally: queue.unsubscribe() @contextlib.asynccontextmanager async def expect_async(proc): queue = proc.subscribe() expect = ExpectQ(queue) expect.sendline = proc.sendline try: yield expect finally: queue.unsubscribe()