vistassh-py/autoproc.py

286 lines
12 KiB
Python
Raw Permalink Normal View History

2024-03-02 00:34:29 -05:00
#!/usr/bin/env python3
import sys
import getpass
import re
import codecs
import asyncio
import contextlib
import logging
from collections import namedtuple
from typing import Optional, Union, Sequence, NamedTuple, Callable, Iterable
logger = logging.getLogger(__name__)
ExpectMatch = namedtuple('PatternMatch', ('batch', 'index', 'pattern', 'match', 'groups', 'groupdict', 'before'))
ExpectMatch.__new__.__defaults__ = (None,)*len(ExpectMatch._fields)
class ExpectQ(object):
"""Provide an expect-like interface over an asyncio queue"""
def __init__(self, pipequeue: asyncio.Queue, timeout_settle: float=1):
self.pipequeue = pipequeue
self.buffer = ''
self.timeout_settle = timeout_settle
def set_timeout(self, timeout_settle: float=1):
"""Set default timeout"""
self.timeout_settle = timeout_settle
def reset(self, buffer: str=''):
"""Clear or restore buffer"""
self.buffer = buffer
clear = reset
async def prompts(self, endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False):
len_endl = len(endl)
while True:
if (pos := self.buffer.rfind(endl)) >= 0:
buffer = self.buffer
self.buffer = ''
yield buffer, pos + len_endl
while True:
try:
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
break
except asyncio.TimeoutError: # no more data
if throw:
raise
yield None, None
async def promptmatches(self, *mappings: Union[str, re.Pattern, tuple, list], endl: str='\r\n', timeout_settle: Optional[float]=None, throw: bool=False):
for i, mapping in enumerate(mappings):
try:
match mapping:
case (str() as pattern, response) if response is None or isinstance(response, str) or callable(response):
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
if pattern == buffer[pos:]:
yield (m := self.ExactMatch(batch=i, index=0, pattern=mapping, match=mapping, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response)
break
else:
self.reset(buffer)
case (re.Pattern() as pattern, response) if response is None or isinstance(response, str) or callable(response):
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
if match := pattern.search(buffer[pos:]):
yield (m := self.PatternMatch(batch=i, index=0, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response)
break
else:
self.reset(buffer)
case (*_,) as components:
exact = {}
expr = {}
for j, component in enumerate(components):
match component:
case (str() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response):
exact[pattern] = (j, response, None if len(rest) < 1 else rest[0])
case (re.Pattern() as pattern, response, *rest) if response is None or isinstance(response, str) or callable(response):
expr[pattern] = (j, response, None if len(rest) < 1 else rest[0])
async for buffer, pos in self.prompts(endl=endl, timeout_settle=timeout_settle, throw=True):
if buffer is not None:
prompt = buffer[pos:]
if prompt in exact:
j, response, end = exact[prompt]
interrupt = yield (m := self.ExactMatch(batch=i, index=j, pattern=prompt, match=prompt, groups=None, groupdict=None, before=buffer[:pos])), (response(m) if callable(response) else response)
else:
for pattern in expr:
if match := pattern.search(prompt):
j, response, end = expr[pattern]
interrupt = yield (m := self.PatternMatch(batch=i, index=j, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=buffer[:pos])), (response(m) if callable(response) else response)
break
else:
self.reset(buffer)
continue
if interrupt:
yield
break
elif end:
break
except asyncio.TimeoutError as ex: # no more data
if throw:
raise asyncio.TimeoutError(*(ex.args + (i, mapping)))
yield None, None
async def earliest(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
"""Wait for any string or regular expression pattern match, specified in *patterns, and optionally raise exception upon timeout"""
try:
while True:
for i, pattern in enumerate(patterns): # try every option
if isinstance(pattern, str):
if (pos := self.buffer.find(pattern)) >= 0: # found it
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:pos])
self.buffer = self.buffer[pos + len(pattern):]
return res
else:
if match := pattern.search(self.buffer): # found it
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
self.buffer = self.buffer[match.end():]
return res
else: # fetch more data
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
except asyncio.TimeoutError: # no more data
if throw:
raise
return None
async def startswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
"""Wait for any string or regular expression pattern match, specified in *patterns, at the start of the stream and optionally raise exception upon timeout"""
try:
while True:
for i, pattern in enumerate(patterns): # try every option
if isinstance(pattern, str):
if self.buffer.startswith(pattern): # found it
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before='')
self.buffer = self.buffer[len(pattern):]
return res
else:
if match := pattern.match(self.buffer): # found it
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
self.buffer = self.buffer[match.end():]
return res
else: # fetch more data
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
except asyncio.TimeoutError: # no more data
if throw:
raise
return None
async def endswith(self, *patterns: Union[str, re.Pattern], timeout_settle: Optional[float]=None, throw: bool=False) -> Optional[NamedTuple]:
"""Wait for any string or regular expression pattern match, specified in *patterns, at the end of the stream and optionally raise exception upon timeout"""
try:
while True:
for i, pattern in enumerate(patterns): # try every option
if isinstance(pattern, str):
if self.buffer.endswith(pattern): # found it
res = self.ExactMatch(index=i, pattern=pattern, match=pattern, groups=None, groupdict=None, before=self.buffer[:-len(pattern)])
self.buffer = ''
return res
else:
if match := pattern.search(self.buffer): # found it
res = self.PatternMatch(index=i, pattern=pattern, match=match, groups=match.groups(), groupdict=match.groupdict(), before=self.buffer[:match.start()])
self.buffer = self.buffer[match.end():]
return res
else: # fetch more data
self.buffer += await asyncio.wait_for(self.pipequeue.get(), timeout=(timeout_settle or self.timeout_settle))
except asyncio.TimeoutError: # no more data
if throw:
raise
return None
__call__ = earliest
ExactMatch = type('ExactMatch', (ExpectMatch,), {})
PatternMatch = type('PatternMatch', (ExpectMatch,), {})
class LockableCallable(object):
def __init__(self, func: Callable, lock: asyncio.Lock=None):
if lock is None:
lock = asyncio.Lock()
self.lock = lock
self.locked = lock.locked
self.acquire = lock.acquire
self.release = lock.release
self.func = func
self.__name__ = func.__name__
self.__doc__ = func.__doc__
def __call__(self, *args, **kw):
return self.func(*args, **kw)
async def __aenter__(self):
await self.lock.acquire()
async def __aexit__(self, exc_type, exc, tb):
self.lock.release()
async def withlock(self, *args, **kw):
async with self.lock:
return self.func(*args, **kw)
async def create_instrumented_subprocess_exec(*args: str, stdin_endl=b'\n', **kw) -> asyncio.subprocess.Process:
    """Create asyncio subprocess, coupled to host stdio, with ability to attach tasks that could inspect its stdout and inject into its stdin

    The process comes from ``asyncio.create_subprocess_exec(*args, **kw)`` and
    is decorated with:

    * ``create_task(...)`` -- start a helper task that is cancelled when the
      process exits;
    * ``subscribe(pipequeue=None)`` -- register a queue (a new asyncio.Queue by
      default) that receives decoded stdout text; the queue gains an
      ``unsubscribe()`` method;
    * ``sendline(data=None, endl=None)`` -- a LockableCallable that writes
      UTF-8 ``data`` plus a line ending (``endl`` or ``stdin_endl``) to stdin;
    * ``wait()`` -- waits for exit, cancels the remaining helper tasks, and
      returns the exit code like the original ``Process.wait()``.

    ``process.stdout`` is handed to ``stdout_writer``, which reads from it, so
    ``kw`` must request a stdout pipe (and a stdin pipe for ``sendline``).
    """
    process = await asyncio.create_subprocess_exec(*args, **kw)
    tasks = set()  # helper tasks tied to this process
    queues = set()  # subscriber queues fed by stdout_writer

    def create_task(*args, **kw):
        tasks.add(item := asyncio.create_task(*args, **kw))
        item.add_done_callback(tasks.discard)  # forget finished tasks
        return item
    process.create_task = create_task

    def subscribe(pipequeue=None):
        if pipequeue is None:
            pipequeue = asyncio.Queue()
        queues.add(pipequeue)
        # discard: unsubscribing twice is harmless instead of raising KeyError
        pipequeue.unsubscribe = lambda: queues.discard(pipequeue)
        return pipequeue
    process.subscribe = subscribe

    def sendline(data=None, endl=None):
        if data is not None:
            process.stdin.write(data.encode('utf-8') + (endl or stdin_endl))
        else:
            process.stdin.write(endl or stdin_endl)
    process.sendline = LockableCallable(sendline)
    create_task(stdout_writer(process.stdout, queues), name='@task:stdout-writer')  # stdout
    process_wait = process.wait

    async def wait_wrapper():  # clean up tasks at the end
        returncode = await process_wait()
        proc_id = id(process)
        logger.debug('SHUTDOWN [proc#%d]: cleaning up', proc_id)
        for item in set(tasks):  # copy: done callbacks shrink `tasks` while we iterate
            if not item.done():
                item.cancel()
                try:
                    logger.debug('SHUTDOWN [proc#%d]: stopping [task#%d] %r', proc_id, id(item), item)
                    await item
                except asyncio.CancelledError:
                    pass
                logger.debug('SHUTDOWN [proc#%d]: stopped [task#%d]', proc_id, id(item))
        logger.debug('SHUTDOWN [proc#%d]: done', proc_id)
        return returncode  # preserve Process.wait()'s contract
    process.wait = wait_wrapper
    return process
async def stdout_writer(pipe: asyncio.StreamReader, subscribers: Iterable[asyncio.Queue], chunksize: int=4096, echo: bool=True):
    """Read data from pipe, decode into Unicode strings, and send to subscribers

    Bytes are read in chunks of up to ``chunksize`` and decoded incrementally as
    UTF-8 (undecodable bytes become U+FFFD).  Each decoded chunk is optionally
    echoed to ``sys.stdout`` and put on every subscriber queue.  Returns at EOF.
    """
    try:
        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        while True:
            try:
                chunk = await pipe.read(chunksize)  # fetch a bunch of bytes
                eof = not chunk
                # at EOF flush the decoder so a truncated multi-byte sequence yields U+FFFD instead of vanishing
                text = decoder.decode(chunk, final=eof)
            except asyncio.TimeoutError:
                continue
            except UnicodeDecodeError:  # should not encounter errors with errors='replace'
                logger.exception('stdout_writer')
                break  # bail on error
            if text or not eof:  # forward every regular chunk; at EOF only leftover text
                if echo:  # echo to stdout
                    sys.stdout.write(text)
                    sys.stdout.flush()
                # snapshot: a subscriber may unsubscribe while a put() is suspended
                for item in tuple(subscribers):  # distribute to subscribers
                    await item.put(text)
            if eof:
                break
    except KeyboardInterrupt:
        logger.info('KeyboardInterrupt: stdout_writer')
@contextlib.contextmanager
def subscribe(proc):
    """Subscribe to *proc*'s output for the duration of a ``with`` block.

    Yields the queue produced by ``proc.subscribe()`` with ``proc.sendline``
    attached to it as ``sendline``; the queue is unsubscribed on exit, whether
    the block finishes normally or raises.
    """
    channel = proc.subscribe()
    channel.sendline = proc.sendline
    try:
        yield channel
    finally:
        channel.unsubscribe()
@contextlib.asynccontextmanager
async def subscribe_async(proc):
    """Async counterpart of ``subscribe``: use with ``async with``.

    Yields the queue from ``proc.subscribe()`` carrying ``proc.sendline`` as its
    ``sendline`` attribute, and unsubscribes it when the block exits, even on
    error.
    """
    channel = proc.subscribe()
    channel.sendline = proc.sendline
    try:
        yield channel
    finally:
        channel.unsubscribe()
@contextlib.contextmanager
def expect(proc):
    """Yield an ExpectQ reading *proc*'s output for the duration of a ``with`` block.

    The ExpectQ wraps a fresh subscription queue and exposes ``proc.sendline``
    as its ``sendline`` attribute; the queue is unsubscribed on exit, even if
    the block raises.
    """
    channel = proc.subscribe()
    matcher = ExpectQ(channel)
    matcher.sendline = proc.sendline
    try:
        yield matcher
    finally:
        channel.unsubscribe()
@contextlib.asynccontextmanager
async def expect_async(proc):
    """Async counterpart of ``expect``: use with ``async with``.

    Yields an ExpectQ over a fresh subscription queue, with ``proc.sendline``
    attached as ``sendline``, and unsubscribes the queue when the block exits,
    even on error.
    """
    channel = proc.subscribe()
    matcher = ExpectQ(channel)
    matcher.sendline = proc.sendline
    try:
        yield matcher
    finally:
        channel.unsubscribe()