vistassh-py/autoproc.py

286 lines
12 KiB
Python

#!/usr/bin/env python3
import sys
import getpass
import re
import codecs
import asyncio
import contextlib
import logging
from collections import namedtuple
from typing import Any, Optional, Union, Sequence, NamedTuple, Callable, AsyncGenerator
logger = logging.getLogger(__name__)
ExpectMatch = namedtuple('PatternMatch', ('batch', 'index', 'pattern', 'match', 'groups', 'groupdict', 'before'))
ExpectMatch.__new__.__defaults__ = (None,)*len(ExpectMatch._fields)
class ExpectQ(object):
"""Provide an expect-like interface over an asyncio queue"""
def __init__(self, pipequeue: asyncio.Queue, timeout_settle: float=1):
self.pipequeue = pipequeue
self.buffer = ''
self.timeout_settle = timeout_settle
def set_timeout(self, timeout_settle: float=1):
"""Set default timeout"""
self.timeout_settle = timeout_settle
def reset(self, buffer: str=''):
"""Clear or restore buffer"""
self.buffer = buffer
clear = reset
async def prompts(self, endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False) -> AsyncGenerator[tuple[Optional[str], Optional[int]], None]:
len_endl = len(endl)
while True:
if (pos := self.buffer.rfind(endl)) >= 0:
buffer = self.buffer
self.buffer = ''
yield buffer, pos + len_endl
while True:
try:
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
break
except asyncio.TimeoutError: # no more data
if throw:
raise
yield None, None
async def promptmatches(self, *mappings: Union[str, re.Pattern, tuple, list], endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False) -> AsyncGenerator[tuple[Optional[ExpectMatch], Any], Optional[bool]]:
for i, mapping in enumerate(mappings):
try:
match mapping:
case (str() as pattern, response) if response is None or isinstance(response, str) or callable(response):
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
if pattern == buffer[pos:]:
yield (m := self.ExactMatch(batch=i, index=0, pattern=mapping, match=mapping, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response)
break
else:
self.reset(buffer)
case (re.Pattern() as pattern, response) if response is None or isinstance(response, str) or callable(response):
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
if match := pattern.search(buffer[pos:]):
yield (m := self.PatternMatch(batch=i, index=0, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response)
break
else:
self.reset(buffer)
case (*_,) as components:
exact = {}
expr = {}
for j, component in enumerate(components):
match component:
case (str() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response):
exact[pattern] = (j, response, None if len(rest) < 1 else rest[0])
case (re.Pattern() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response):
expr[pattern] = (j, response, None if len(rest) < 1 else rest[0])
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
if buffer is not None:
prompt = buffer[pos:]
if prompt in exact:
j, response, end = exact[prompt]
interrupt = yield (m := self.ExactMatch(batch=i, index=j, pattern=prompt, match=prompt, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response)
else:
for pattern in expr:
if match := pattern.search(prompt):
j, response, end = expr[pattern]
interrupt = yield (m := self.PatternMatch(batch=i, index=j, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response)
break
else:
self.reset(buffer)
continue
if interrupt:
yield
break
elif end:
break
except asyncio.TimeoutError as ex: # no more data
if throw:
raise asyncio.TimeoutError(*(ex.args + (i, mapping)))
yield None, None
async def earliest(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
"""Wait for any string or regular expression pattern match, specified in *patterns, and optionally raise exception upon timeout"""
try:
while True:
for i, pattern in enumerate(patterns): # try every option
if isinstance(pattern, str):
if (pos := self.buffer.find(pattern)) >= 0: # found it
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:pos])
self.buffer = self.buffer[pos + len(pattern):]
return res
else:
if match := pattern.search(self.buffer): # found it
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
self.buffer = self.buffer[match.end():]
return res
else: # fetch more data
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
except asyncio.TimeoutError: # no more data
if throw:
raise
return None
async def startswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
"""Wait for any string or regular expression pattern match, specified in *patterns, at the start of the stream and optionally raise exception upon timeout"""
try:
while True:
for i, pattern in enumerate(patterns): # try every option
if isinstance(pattern, str):
if self.buffer.startswith(pattern): # found it
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before='')
self.buffer = self.buffer[len(pattern):]
return res
else:
if match := pattern.match(self.buffer): # found it
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
self.buffer = self.buffer[match.end():]
return res
else: # fetch more data
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
except asyncio.TimeoutError: # no more data
if throw:
raise
return None
async def endswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
"""Wait for any string or regular expression pattern match, specified in *patterns, at the end of the stream and optionally raise exception upon timeout"""
try:
while True:
for i, pattern in enumerate(patterns): # try every option
if isinstance(pattern, str):
if self.buffer.endswith(pattern): # found it
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:-len(pattern)])
self.buffer = ''
return res
else:
if match := pattern.search(self.buffer): # found it
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
self.buffer = self.buffer[match.end():]
return res
else: # fetch more data
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
except asyncio.TimeoutError: # no more data
if throw:
raise
return None
__call__ = earliest
ExactMatch = type('ExactMatch', (ExpectMatch,), {})
PatternMatch = type('PatternMatch', (ExpectMatch,), {})
class LockableCallable(object):
    """Wrap *func* so it can be called directly or while holding an asyncio lock.

    The wrapper is callable like *func*, exposes the lock's ``locked``,
    ``acquire`` and ``release`` methods, works as an ``async with`` context
    manager that holds the lock (yielding the wrapper itself), and offers
    ``withlock(...)`` to call *func* under the lock.
    """

    def __init__(self, func: Callable, lock: Optional[asyncio.Lock]=None):
        if lock is None:
            lock = asyncio.Lock()
        self.lock = lock
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release
        self.func = func
        # functools.partial and other callable objects have no __name__
        self.__name__ = getattr(func, '__name__', type(func).__name__)
        self.__doc__ = getattr(func, '__doc__', None)

    def __call__(self, *args, **kw):
        return self.func(*args, **kw)

    async def __aenter__(self):
        await self.lock.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.lock.release()

    async def withlock(self, *args, **kw):
        """Call *func* while holding the lock and return its result.

        NOTE: the result is not awaited, so if *func* returns a coroutine it
        will run after the lock has been released.
        """
        async with self.lock:
            return self.func(*args, **kw)
async def create_instrumented_subprocess_exec(*args: str, stdin_endl=b'\n', **kw) -> asyncio.subprocess.Process:
    """Create an asyncio subprocess whose stdout can be inspected and whose stdin can be injected into.

    *args* and *kw* are passed to ``asyncio.create_subprocess_exec``.  The
    caller presumably passes ``stdout=PIPE`` and ``stdin=PIPE`` (stdout is
    read by ``stdout_writer`` and stdin is written by ``sendline``) -- TODO
    confirm.  The returned process gains:

    * ``create_task(coro, ...)`` -- start a task tracked with the process and
      cancelled once ``wait()`` sees the process exit;
    * ``subscribe(queue=None)`` -- register a queue (a new ``asyncio.Queue``
      by default) that receives decoded stdout text; the queue gets an
      ``unsubscribe()`` method (safe to call more than once);
    * ``sendline(data=None, endl=None)`` -- write *data* UTF-8 encoded plus a
      line ending (*endl*, or *stdin_endl* when *endl* is falsy) to stdin,
      wrapped in a LockableCallable so writers can serialize on its lock;
    * ``wait()`` -- wait for exit, cancel and await still-running tracked
      tasks, then return the exit code.
    """
    process = await asyncio.create_subprocess_exec(*args, **kw)
    tasks = set()
    queues = set()

    def create_task(*args, **kw):
        item = asyncio.create_task(*args, **kw)
        tasks.add(item)
        item.add_done_callback(tasks.discard)
        return item
    process.create_task = create_task

    def subscribe(pipequeue=None):
        if pipequeue is None:
            pipequeue = asyncio.Queue()
        queues.add(pipequeue)
        # discard, not remove: a second unsubscribe() is a harmless no-op
        pipequeue.unsubscribe = lambda: queues.discard(pipequeue)
        return pipequeue
    process.subscribe = subscribe

    def sendline(data=None, endl=None):
        line_end = endl or stdin_endl
        process.stdin.write(line_end if data is None else data.encode('utf-8') + line_end)
    process.sendline = LockableCallable(sendline)

    create_task(stdout_writer(process.stdout, queues), name='@task:stdout-writer')  # stdout
    process_wait = process.wait

    async def wait_wrapper():  # clean up tasks at the end
        returncode = await process_wait()
        proc_id = id(process)
        logger.debug('SHUTDOWN [proc#%d]: cleaning up', proc_id)
        # NOTE(review): if the stdout-writer task has not drained the pipe yet,
        # cancelling it here may drop trailing output -- confirm.
        for item in set(tasks):  # snapshot: done callbacks shrink `tasks` while we await
            if not item.done():
                item.cancel()
                try:
                    logger.debug('SHUTDOWN [proc#%d]: stopping [task#%d] %r', proc_id, id(item), item)
                    await item
                except asyncio.CancelledError:
                    pass
                logger.debug('SHUTDOWN [proc#%d]: stopped [task#%d]', proc_id, id(item))
        logger.debug('SHUTDOWN [proc#%d]: done', proc_id)
        return returncode
    process.wait = wait_wrapper
    return process
async def stdout_writer(pipe: asyncio.StreamReader, subscribers: set[asyncio.Queue], chunksize: int=4096, echo: bool=True):
    """Read bytes from *pipe*, decode them as UTF-8 and fan the text out.

    Each non-empty decoded piece is optionally echoed to ``sys.stdout`` and put
    on every queue in *subscribers*.  Invalid bytes become U+FFFD; a multi-byte
    sequence split across reads is emitted once its remaining bytes arrive,
    and an incomplete sequence left at EOF is flushed as U+FFFD rather than
    dropped.  Returns at EOF.
    """
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')

    async def publish(text):
        if not text:  # decoder is still holding a partial multi-byte sequence
            return
        if echo:  # echo to stdout
            sys.stdout.write(text)
            sys.stdout.flush()
        for item in tuple(subscribers):  # snapshot: subscribers may change while put() awaits
            await item.put(text)

    try:
        while True:
            try:
                chunk = await pipe.read(chunksize)  # fetch a bunch of bytes
                if not chunk:  # EOF
                    break
                text = decoder.decode(chunk)
            except asyncio.TimeoutError:
                continue
            except UnicodeDecodeError:  # should not encounter errors with errors='replace'
                logger.exception('stdout_writer')
                return  # bail on error
            await publish(text)
        await publish(decoder.decode(b'', final=True))  # flush any dangling partial sequence
    except KeyboardInterrupt:
        logger.info('KeyboardInterrupt: stdout_writer')
@contextlib.contextmanager
def subscribe(proc):
    """Subscribe to *proc*'s output for the duration of a ``with`` block.

    Yields the queue returned by ``proc.subscribe()`` with ``proc.sendline``
    attached as its ``sendline`` attribute; the queue is unsubscribed when the
    block exits, whether normally or by an exception.
    """
    subscription = proc.subscribe()
    subscription.sendline = proc.sendline
    try:
        yield subscription
    finally:
        subscription.unsubscribe()
@contextlib.asynccontextmanager
async def subscribe_async(proc):
    """Async counterpart of ``subscribe``: subscribe for the span of an ``async with`` block.

    Yields the queue returned by ``proc.subscribe()`` with ``proc.sendline``
    attached as its ``sendline`` attribute; the queue is unsubscribed when the
    block exits, whether normally or by an exception.
    """
    subscription = proc.subscribe()
    subscription.sendline = proc.sendline
    try:
        yield subscription
    finally:
        subscription.unsubscribe()
@contextlib.contextmanager
def expect(proc):
    """Wrap a fresh subscription to *proc*'s output in an ExpectQ for a ``with`` block.

    The yielded ExpectQ reads from the queue returned by ``proc.subscribe()``
    and carries ``proc.sendline`` as its ``sendline`` attribute; the underlying
    queue is unsubscribed when the block exits.
    """
    subscription = proc.subscribe()
    session = ExpectQ(subscription)
    session.sendline = proc.sendline
    try:
        yield session
    finally:
        subscription.unsubscribe()
@contextlib.asynccontextmanager
async def expect_async(proc):
    """Async counterpart of ``expect``: an ExpectQ over *proc*'s output for an ``async with`` block.

    The yielded ExpectQ reads from the queue returned by ``proc.subscribe()``
    and carries ``proc.sendline`` as its ``sendline`` attribute; the underlying
    queue is unsubscribed when the block exits.
    """
    subscription = proc.subscribe()
    session = ExpectQ(subscription)
    session.sendline = proc.sendline
    try:
        yield session
    finally:
        subscription.unsubscribe()