#!/usr/bin/env python3
import sys
import getpass
import re
import codecs
import asyncio
import contextlib
import logging
from collections import namedtuple
from typing import Any , Optional , Union , Sequence , NamedTuple , Callable , AsyncGenerator
logger = logging.getLogger(__name__)

# Fields of a match result; every field defaults to None.
#   batch              -- index of the mapping that matched in ExpectQ.promptmatches()
#   index              -- index of the pattern / component that matched
#   pattern            -- the pattern that matched
#   match              -- the matched text (string patterns) or re.Match (regex patterns)
#   groups, groupdict  -- regex groups (None for string patterns)
#   before             -- buffered text preceding the match
_EXPECT_MATCH_FIELDS = ('batch', 'index', 'pattern', 'match', 'groups', 'groupdict', 'before')
# The typename must equal the bound name: repr() of direct instances and
# pickling (which looks the class up by name in this module) depend on it.
ExpectMatch = namedtuple('ExpectMatch', _EXPECT_MATCH_FIELDS, defaults=(None,) * len(_EXPECT_MATCH_FIELDS))
class ExpectQ ( object ) :
""" Provide an expect-like interface over an asyncio queue """
def __init__ ( self , pipequeue : asyncio . Queue , timeout_settle : float = 1 ) :
self . pipequeue = pipequeue
self . buffer = ' '
self . timeout_settle = timeout_settle
def set_timeout ( self , timeout_settle : float = 1 ) :
""" Set default timeout """
self . timeout_settle = timeout_settle
def reset ( self , buffer : str = ' ' ) :
""" Clear or restore buffer """
self . buffer = buffer
clear = reset
2025-01-14 22:42:44 -05:00
async def prompts ( self , endl : str = ' \r \n ' , timeout_settle : Optional [ float ] = None , throw : bool = False ) - > AsyncGenerator [ tuple [ Optional [ str ] , Optional [ int ] ] , None ] :
2024-03-02 00:34:29 -05:00
len_endl = len ( endl )
while True :
if ( pos := self . buffer . rfind ( endl ) ) > = 0 :
buffer = self . buffer
self . buffer = ' '
yield buffer , pos + len_endl
while True :
try :
self . buffer + = await asyncio . wait_for ( self . pipequeue . get ( ) , timeout = ( timeout_settle or self . timeout_settle ) )
break
except asyncio . TimeoutError : # no more data
if throw :
raise
yield None , None
2025-01-14 22:42:44 -05:00
async def promptmatches ( self , * mappings : Union [ str , re . Pattern , tuple , list ] , endl : str = ' \r \n ' , timeout_settle : Optional [ float ] = None , throw : bool = False ) - > AsyncGenerator [ tuple [ Optional [ ExpectMatch ] , Any ] , Optional [ bool ] ] :
2024-03-02 00:34:29 -05:00
for i , mapping in enumerate ( mappings ) :
try :
match mapping :
case ( str ( ) as pattern , response ) if response is None or isinstance ( response , str ) or callable ( response ) :
async for buffer , pos in self . prompts ( endl = endl , timeout_settle = timeout_settle , throw = True ) :
if pattern == buffer [ pos : ] :
yield ( m := self . ExactMatch ( batch = i , index = 0 , pattern = mapping , match = mapping , groups = None , groupdict = None , before = buffer [ : pos ] ) ) , ( response ( m ) if callable ( response ) else response )
break
else :
self . reset ( buffer )
case ( re . Pattern ( ) as pattern , response ) if response is None or isinstance ( response , str ) or callable ( response ) :
async for buffer , pos in self . prompts ( endl = endl , timeout_settle = timeout_settle , throw = True ) :
if match := pattern . search ( buffer [ pos : ] ) :
yield ( m := self . PatternMatch ( batch = i , index = 0 , pattern = pattern , match = match , groups = match . groups ( ) , groupdict = match . groupdict ( ) , before = buffer [ : pos ] ) ) , ( response ( m ) if callable ( response ) else response )
break
else :
self . reset ( buffer )
case ( * _ , ) as components :
exact = { }
expr = { }
for j , component in enumerate ( components ) :
match component :
case ( str ( ) as pattern , response , * rest ) if response is None or isinstance ( response , str ) or callable ( response ) :
exact [ pattern ] = ( j , response , None if len ( rest ) < 1 else rest [ 0 ] )
case ( re . Pattern ( ) as pattern , response , * rest ) if response is None or isinstance ( response , str ) or callable ( response ) :
expr [ pattern ] = ( j , response , None if len ( rest ) < 1 else rest [ 0 ] )
async for buffer , pos in self . prompts ( endl = endl , timeout_settle = timeout_settle , throw = True ) :
if buffer is not None :
prompt = buffer [ pos : ]
if prompt in exact :
j , response , end = exact [ prompt ]
interrupt = yield ( m := self . ExactMatch ( batch = i , index = j , pattern = prompt , match = prompt , groups = None , groupdict = None , before = buffer [ : pos ] ) ) , ( response ( m ) if callable ( response ) else response )
else :
for pattern in expr :
if match := pattern . search ( prompt ) :
j , response , end = expr [ pattern ]
interrupt = yield ( m := self . PatternMatch ( batch = i , index = j , pattern = pattern , match = match , groups = match . groups ( ) , groupdict = match . groupdict ( ) , before = buffer [ : pos ] ) ) , ( response ( m ) if callable ( response ) else response )
break
else :
self . reset ( buffer )
continue
if interrupt :
yield
break
elif end :
break
except asyncio . TimeoutError as ex : # no more data
if throw :
raise asyncio . TimeoutError ( * ( ex . args + ( i , mapping ) ) )
yield None , None
async def earliest ( self , * patterns : Union [ str , re . Pattern ] , timeout_settle : Optional [ float ] = None , throw : bool = False ) - > Optional [ NamedTuple ] :
""" Wait for any string or regular expression pattern match, specified in *patterns, and optionally raise exception upon timeout """
try :
while True :
for i , pattern in enumerate ( patterns ) : # try every option
if isinstance ( pattern , str ) :
if ( pos := self . buffer . find ( pattern ) ) > = 0 : # found it
res = self . ExactMatch ( index = i , pattern = pattern , match = pattern , groups = None , groupdict = None , before = self . buffer [ : pos ] )
self . buffer = self . buffer [ pos + len ( pattern ) : ]
return res
else :
if match := pattern . search ( self . buffer ) : # found it
res = self . PatternMatch ( index = i , pattern = pattern , match = match , groups = match . groups ( ) , groupdict = match . groupdict ( ) , before = self . buffer [ : match . start ( ) ] )
self . buffer = self . buffer [ match . end ( ) : ]
return res
else : # fetch more data
self . buffer + = await asyncio . wait_for ( self . pipequeue . get ( ) , timeout = ( timeout_settle or self . timeout_settle ) )
except asyncio . TimeoutError : # no more data
if throw :
raise
return None
async def startswith ( self , * patterns : Union [ str , re . Pattern ] , timeout_settle : Optional [ float ] = None , throw : bool = False ) - > Optional [ NamedTuple ] :
""" Wait for any string or regular expression pattern match, specified in *patterns, at the start of the stream and optionally raise exception upon timeout """
try :
while True :
for i , pattern in enumerate ( patterns ) : # try every option
if isinstance ( pattern , str ) :
if self . buffer . startswith ( pattern ) : # found it
res = self . ExactMatch ( index = i , pattern = pattern , match = pattern , groups = None , groupdict = None , before = ' ' )
self . buffer = self . buffer [ len ( pattern ) : ]
return res
else :
if match := pattern . match ( self . buffer ) : # found it
res = self . PatternMatch ( index = i , pattern = pattern , match = match , groups = match . groups ( ) , groupdict = match . groupdict ( ) , before = self . buffer [ : match . start ( ) ] )
self . buffer = self . buffer [ match . end ( ) : ]
return res
else : # fetch more data
self . buffer + = await asyncio . wait_for ( self . pipequeue . get ( ) , timeout = ( timeout_settle or self . timeout_settle ) )
except asyncio . TimeoutError : # no more data
if throw :
raise
return None
async def endswith ( self , * patterns : Union [ str , re . Pattern ] , timeout_settle : Optional [ float ] = None , throw : bool = False ) - > Optional [ NamedTuple ] :
""" Wait for any string or regular expression pattern match, specified in *patterns, at the end of the stream and optionally raise exception upon timeout """
try :
while True :
for i , pattern in enumerate ( patterns ) : # try every option
if isinstance ( pattern , str ) :
if self . buffer . endswith ( pattern ) : # found it
res = self . ExactMatch ( index = i , pattern = pattern , match = pattern , groups = None , groupdict = None , before = self . buffer [ : - len ( pattern ) ] )
self . buffer = ' '
return res
else :
if match := pattern . search ( self . buffer ) : # found it
res = self . PatternMatch ( index = i , pattern = pattern , match = match , groups = match . groups ( ) , groupdict = match . groupdict ( ) , before = self . buffer [ : match . start ( ) ] )
self . buffer = self . buffer [ match . end ( ) : ]
return res
else : # fetch more data
self . buffer + = await asyncio . wait_for ( self . pipequeue . get ( ) , timeout = ( timeout_settle or self . timeout_settle ) )
except asyncio . TimeoutError : # no more data
if throw :
raise
return None
__call__ = earliest
ExactMatch = type ( ' ExactMatch ' , ( ExpectMatch , ) , { } )
PatternMatch = type ( ' PatternMatch ' , ( ExpectMatch , ) , { } )
class LockableCallable(object):
    """Wrap a callable together with an asyncio.Lock.

    Calling the wrapper calls ``func`` directly, without taking the lock.
    ``async with wrapper:`` holds the lock for the block, and
    ``await wrapper.withlock(...)`` calls ``func`` while holding it.
    """

    def __init__(self, func: Callable, lock: Optional[asyncio.Lock] = None):
        if lock is None:
            lock = asyncio.Lock()
        self.lock = lock
        # Re-export the lock's own API on the wrapper.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release
        self.func = func
        # Not every callable has __name__ (e.g. functools.partial); fall back to its type name.
        self.__name__ = getattr(func, '__name__', type(func).__name__)
        self.__doc__ = func.__doc__

    def __call__(self, *args, **kw):
        """Call the wrapped function without locking."""
        return self.func(*args, **kw)

    async def __aenter__(self):
        await self.lock.acquire()

    async def __aexit__(self, exc_type, exc, tb):
        self.lock.release()

    async def withlock(self, *args, **kw):
        """Call the wrapped function while holding the lock and return its result."""
        async with self.lock:
            return self.func(*args, **kw)
async def create_instrumented_subprocess_exec(*args: str, stdin_endl=b'\n', **kw) -> asyncio.subprocess.Process:
    """Create asyncio subprocess, coupled to host stdio, with ability to attach tasks that could inspect its stdout and inject into its stdin.

    ``args`` and ``kw`` are passed to ``asyncio.create_subprocess_exec``. The
    returned process object is augmented with:

    * ``create_task(...)`` -- start a task that the wrapped ``wait()`` cancels
      if it is still running when the process exits;
    * ``subscribe(pipequeue=None)`` -- register a queue (a new asyncio.Queue
      by default) that ``stdout_writer`` feeds with decoded stdout text; the
      queue gains an ``unsubscribe()`` method;
    * ``sendline(data=None, endl=None)`` -- write ``data`` encoded as UTF-8
      followed by ``endl`` (default ``stdin_endl``) to stdin, wrapped in a
      LockableCallable;
    * ``wait()`` -- wait for exit, cancel remaining tasks, return the exit code.

    NOTE(review): assumes the caller opened stdout (and stdin, for sendline)
    as pipes via ``kw`` -- confirm at call sites.
    """
    process = await asyncio.create_subprocess_exec(*args, **kw)
    tasks = set()
    queues = set()

    def create_task(*args, **kw):
        tasks.add(item := asyncio.create_task(*args, **kw))
        item.add_done_callback(tasks.remove)
        return item
    process.create_task = create_task

    def subscribe(pipequeue=None):
        if pipequeue is None:
            pipequeue = asyncio.Queue()
        queues.add(pipequeue)
        # discard() keeps unsubscribe() safe to call more than once
        pipequeue.unsubscribe = lambda: queues.discard(pipequeue)
        return pipequeue
    process.subscribe = subscribe

    def sendline(data=None, endl=None):
        # Only None selects the default, so endl=b'' sends data without a line ending.
        line_end = stdin_endl if endl is None else endl
        if data is not None:
            process.stdin.write(data.encode('utf-8') + line_end)
        else:
            process.stdin.write(line_end)
    process.sendline = LockableCallable(sendline)

    create_task(stdout_writer(process.stdout, queues), name='@task:stdout-writer')  # stdout

    process_wait = process.wait

    async def wait_wrapper():  # clean up tasks at the end
        returncode = await process_wait()
        proc_id = id(process)
        logger.debug('SHUTDOWN [proc#%d]: cleaning up', proc_id)
        for item in set(tasks):  # copy: done callbacks remove tasks from the set
            if not item.done():
                item.cancel()
                try:
                    logger.debug('SHUTDOWN [proc#%d]: stopping [task#%d] %r', proc_id, id(item), item)
                    await item
                except asyncio.CancelledError:
                    pass
                logger.debug('SHUTDOWN [proc#%d]: stopped [task#%d]', proc_id, id(item))
        logger.debug('SHUTDOWN [proc#%d]: done', proc_id)
        return returncode  # keep Process.wait()'s contract of returning the exit code
    process.wait = wait_wrapper
    return process
async def stdout_writer(pipe: asyncio.StreamReader, subscribers: Union[Sequence[asyncio.Queue], set[asyncio.Queue]], chunksize: int = 4096, echo: bool = True):
    """Read data from pipe, decode into Unicode strings, and send to subscribers.

    Bytes are read in chunks of up to ``chunksize`` and decoded incrementally
    as UTF-8 with ``errors='replace'``. Each non-empty piece of text is
    optionally echoed to ``sys.stdout`` and then put on every subscriber
    queue. Returns at EOF, after flushing any incomplete trailing byte
    sequence (which decodes to U+FFFD).
    """
    async def publish(text):
        if not text:  # a chunk holding only part of a multi-byte character decodes to ''
            return
        if echo:  # echo to stdout
            sys.stdout.write(text)
            sys.stdout.flush()
        # Iterate a snapshot: put() on a bounded queue can suspend, and a
        # subscriber may unsubscribe (mutating the set) meanwhile.
        for queue in tuple(subscribers):
            await queue.put(text)

    try:
        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        while True:
            try:
                chunk = await pipe.read(chunksize)  # fetch a bunch of bytes
                eof = not chunk
                text = decoder.decode(chunk, final=eof)  # final=True flushes pending bytes
            except asyncio.TimeoutError:
                continue
            except UnicodeDecodeError:  # should not encounter errors with errors='replace'
                logger.exception('stdout_writer')
                break  # bail on error
            await publish(text)
            if eof:
                break
    except KeyboardInterrupt:
        logger.info('KeyboardInterrupt: stdout_writer')
@contextlib.contextmanager
def subscribe(proc):
    """Subscribe to ``proc``'s output for the duration of a ``with`` block.

    Yields the queue returned by ``proc.subscribe()``, with ``proc.sendline``
    attached as its ``sendline`` attribute. The queue's ``unsubscribe()`` is
    called when the block exits, whether normally or by exception.
    """
    subscription = proc.subscribe()
    subscription.sendline = proc.sendline
    try:
        yield subscription
    finally:
        subscription.unsubscribe()
@contextlib.asynccontextmanager
async def subscribe_async(proc):
    """Async counterpart of ``subscribe``: use with ``async with``.

    Yields the queue returned by ``proc.subscribe()``, with ``proc.sendline``
    attached as its ``sendline`` attribute, and calls the queue's
    ``unsubscribe()`` when the block exits, whether normally or by exception.
    """
    subscription = proc.subscribe()
    subscription.sendline = proc.sendline
    try:
        yield subscription
    finally:
        subscription.unsubscribe()
@contextlib.contextmanager
def expect(proc):
    """Attach an ``ExpectQ`` to ``proc``'s output for a ``with`` block.

    A fresh subscriber queue from ``proc.subscribe()`` is wrapped in an
    ExpectQ, which also gets ``proc.sendline`` as its ``sendline`` attribute.
    The queue's ``unsubscribe()`` is called when the block exits, whether
    normally or by exception.
    """
    subscription = proc.subscribe()
    expecter = ExpectQ(subscription)
    expecter.sendline = proc.sendline
    try:
        yield expecter
    finally:
        subscription.unsubscribe()
@contextlib.asynccontextmanager
async def expect_async(proc):
    """Async counterpart of ``expect``: use with ``async with``.

    Wraps a fresh subscriber queue from ``proc.subscribe()`` in an ExpectQ,
    gives it ``proc.sendline`` as its ``sendline`` attribute, and calls the
    queue's ``unsubscribe()`` when the block exits, whether normally or by
    exception.
    """
    subscription = proc.subscribe()
    expecter = ExpectQ(subscription)
    expecter.sendline = proc.sendline
    try:
        yield expecter
    finally:
        subscription.unsubscribe()