#!/usr/bin/env python3 import sys import getpass import re import codecs import asyncio import contextlib import logging from collections import namedtuple from typing import Any, Optional, Union, Sequence, NamedTuple, Callable, AsyncGenerator logger = logging.getLogger(__name__) ExpectMatch = namedtuple('PatternMatch', ('batch', 'index', 'pattern', 'match', 'groups', 'groupdict', 'before')) ExpectMatch.__new__.__defaults__ = (None,)*len(ExpectMatch._fields) class ExpectQ(object): """Provide an expect-like interface over an asyncio queue""" def __init__(self, pipequeue: asyncio.Queue, timeout_settle: float=1): self.pipequeue = pipequeue self.buffer = '' self.timeout_settle = timeout_settle def set_timeout(self, timeout_settle: float=1): """Set default timeout""" self.timeout_settle = timeout_settle def reset(self, buffer: str=''): """Clear or restore buffer""" self.buffer = buffer clear = reset async def prompts(self, endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False) -> AsyncGenerator[tuple[Optional[str], Optional[int]], None]: len_endl = len(endl) while True: if (pos := self.buffer.rfind(endl)) >= 0: buffer = self.buffer self.buffer = '' yield buffer, pos + len_endl while True: try: self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) break except asyncio.TimeoutError: # no more data if throw: raise yield None, None async def promptmatches(self, *mappings: Union[str, re.Pattern, tuple, list], endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False) -> AsyncGenerator[tuple[Optional[ExpectMatch], Any], Optional[bool]]: for i, mapping in enumerate(mappings): try: match mapping: case (str() as pattern, response) if response is None or isinstance(response, str) or callable(response): async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True): if pattern == buffer[pos:]: yield (m := self.ExactMatch(batch=i, index=0, pattern=mapping, match=mapping, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response) break else: self.reset(buffer) case (re.Pattern() as pattern, response) if response is None or isinstance(response, str) or callable(response): async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True): if match := pattern.search(buffer[pos:]): yield (m := self.PatternMatch(batch=i, index=0, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response) break else: self.reset(buffer) case (*_,) as components: exact = {} expr = {} for j, component in enumerate(components): match component: case (str() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response): exact[pattern] = (j, response, None if len(rest) < 1 else rest[0]) case (re.Pattern() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response): expr[pattern] = (j, response, None if len(rest) < 1 else rest[0]) async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True): if buffer is not None: prompt = buffer[pos:] if prompt in exact: j, response, end = exact[prompt] interrupt = yield (m := self.ExactMatch(batch=i, index=j, pattern=prompt, match=prompt, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response) else: for pattern in expr: if match := pattern.search(prompt): j, response, end = expr[pattern] interrupt = yield (m := self.PatternMatch(batch=i, index=j, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response) break else: self.reset(buffer) continue if interrupt: yield break elif end: break except asyncio.TimeoutError as ex: # no more data if throw: raise asyncio.TimeoutError(*(ex.args + (i, mapping))) yield None, None async def earliest(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]: """Wait for any string or regular expression pattern match, specified in *patterns, and optionally raise exception upon timeout""" try: while True: for i, pattern in enumerate(patterns): # try every option if isinstance(pattern, str): if (pos := self.buffer.find(pattern)) >= 0: # found it res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:pos]) self.buffer = self.buffer[pos + len(pattern):] return res else: if match := pattern.search(self.buffer): # found it res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()]) self.buffer = self.buffer[match.end():] return res else: # fetch more data self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) except asyncio.TimeoutError: # no more data if throw: raise return None async def startswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]: """Wait for any string or regular expression pattern match, specified in *patterns, at the start of the stream and optionally raise exception upon timeout""" try: while True: for i, pattern in enumerate(patterns): # try every option if isinstance(pattern, str): if self.buffer.startswith(pattern): # found it res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before='') self.buffer = self.buffer[len(pattern):] return res else: if match := pattern.match(self.buffer): # found it res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()]) self.buffer = self.buffer[match.end():] return res else: # fetch more data self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) except asyncio.TimeoutError: # no more data if throw: raise return None async def endswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]: """Wait for any string or regular expression pattern match, specified in *patterns, at the end of the stream and optionally raise exception upon timeout""" try: while True: for i, pattern in enumerate(patterns): # try every option if isinstance(pattern, str): if self.buffer.endswith(pattern): # found it res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:-len(pattern)]) self.buffer = '' return res else: if match := pattern.search(self.buffer): # found it res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()]) self.buffer = self.buffer[match.end():] return res else: # fetch more data self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle)) except asyncio.TimeoutError: # no more data if throw: raise return None __call__ = earliest ExactMatch = type('ExactMatch', (ExpectMatch,), {}) PatternMatch = type('PatternMatch', (ExpectMatch,), {}) class LockableCallable(object): def __init__(self, func: Callable, lock: asyncio.Lock=None): if lock is None: lock = asyncio.Lock() self.lock = lock self.locked = lock.locked self.acquire = lock.acquire self.release = lock.release self.func = func self.__name__ = func.__name__ self.__doc__ = func.__doc__ def __call__(self, *args, **kw): return self.func(*args, **kw) async def __aenter__(self): await self.lock.acquire() async def __aexit__(self, exc_type, exc, tb): self.lock.release() async def withlock(self, *args, **kw): async with self.lock: return self.func(*args, **kw) async def create_instrumented_subprocess_exec(*args: str, stdin_endl=b'\n', **kw) -> asyncio.subprocess.Process: """Create asyncio subprocess, coupled to host stdio, with ability to attach tasks that could inspect its stdout and inject into its stdin""" process = await asyncio.create_subprocess_exec(*args, **kw) tasks = set() queues = set() def create_task(*args, **kw): tasks.add(item := asyncio.create_task(*args, **kw)) item.add_done_callback(tasks.remove) return item process.create_task = create_task def subscribe(pipequeue=None): queues.add(pipequeue := pipequeue or asyncio.Queue()) pipequeue.unsubscribe = lambda: queues.remove(pipequeue) return pipequeue process.subscribe = subscribe def sendline(data=None, endl=None): if data is not None: process.stdin.write(data.encode('utf-8') + (endl or stdin_endl)) else: process.stdin.write(endl or stdin_endl) process.sendline = LockableCallable(sendline) create_task(stdout_writer(process.stdout, queues), name='@task:stdout-writer') # stdout process_wait = process.wait async def wait_wrapper(): # clean up tasks at the end await process_wait() proc_id = id(process) logger.debug('SHUTDOWN [proc#%d]: cleaning up'%proc_id) for item in set(tasks): # copy set to avoid RuntimeError: Set changed size during iteration if not item.done(): item.cancel() try: logger.debug('SHUTDOWN [proc#%d]: stopping [task#%d] %r'%(proc_id, id(item), item)) await item except asyncio.CancelledError: pass logger.debug('SHUTDOWN [proc#%d]: stopped [task#%d]'%(proc_id, id(item))) logger.debug('SHUTDOWN [proc#%d]: done'%proc_id) process.wait = wait_wrapper return process async def stdout_writer(pipe: asyncio.StreamWriter, subscribers: Sequence[asyncio.Task], chunksize: int=4096, echo: bool=True): """Read data from pipe, decode into Unicode strings, and send to subscribers""" try: decoder = codecs.getincrementaldecoder('utf-8')(errors='replace') while True: try: chunk = await pipe.read(chunksize) # fetch a bunch of bytes if not chunk: # EOF break text = decoder.decode(chunk) except asyncio.TimeoutError: continue except UnicodeDecodeError: # should not encounter errors with errors='replace' logger.exception('stdout_writer') break # bail on error else: if echo: # echo to stdout sys.stdout.write(text) sys.stdout.flush() for item in subscribers: # distribute to subscribers await item.put(text) except KeyboardInterrupt: logger.info('KeyboardInterrupt: stdout_writer') @contextlib.contextmanager def subscribe(proc): queue = proc.subscribe() queue.sendline = proc.sendline try: yield queue finally: queue.unsubscribe() @contextlib.asynccontextmanager async def subscribe_async(proc): queue = proc.subscribe() queue.sendline = proc.sendline try: yield queue finally: queue.unsubscribe() @contextlib.contextmanager def expect(proc): queue = proc.subscribe() expect = ExpectQ(queue) expect.sendline = proc.sendline try: yield expect finally: queue.unsubscribe() @contextlib.asynccontextmanager async def expect_async(proc): queue = proc.subscribe() expect = ExpectQ(queue) expect.sendline = proc.sendline try: yield expect finally: queue.unsubscribe()