261 lines
10 KiB
Python
261 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import sys
|
|
import sqlite3
|
|
import ctypes
|
|
import functools
|
|
import inspect
|
|
import traceback
|
|
|
|
import direct
|
|
from direct import libsqlite3, Py_IncRef, Py_DecRef, c_sqlite3_module, c_sqlite3_value_p_p, map_sqlite3_value, map_sqlite3_result, SQLITE_OK, SQLITE_ERROR, SQLITE_CONSTRAINT, SQLITE_INDEX_CONSTRAINT_EQ
|
|
|
|
# Version gate used by vtab_xBestIndex: SQLite 3.26.0+ may be told "no usable plan"
# by returning SQLITE_CONSTRAINT; older libraries get a very expensive plan instead.
# NOTE(review): this reads the SQLite linked into the stdlib sqlite3 module; confirm
# it is the same library that `direct.libsqlite3` loads.
USE_SQLITE_CONSTRAINT = sqlite3.sqlite_version_info >= (3, 26, 0)
|
|
|
|
@direct.c_sqlite3_module._types_.xConnect
def vtab_xConnect(pDB, pAux, argc, argv, ppVTab, pzErr): # sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr
    """
    Invoked whenever a database connection attaches to or reparses a schema.

    A virtual table is eponymous if its xCreate method is the exact same function as the xConnect method, or if the xCreate method is NULL.

    pAux carries the Python table factory given to sqlite3_create_module. Its
    columns and parameters are declared to SQLite, then a zeroed sqlite3_vtab
    holding an owned reference (Py_IncRef) to the factory is stored in *ppVTab.

    Returns SQLITE_OK on success, the sqlite3_declare_vtab result code if the
    declaration is rejected, SQLITE_NOMEM if the allocation fails, or
    SQLITE_ERROR if building the declaration raises.
    """
    SQLITE_NOMEM = 7  # value from sqlite3.h; not imported from `direct` in this file

    tableFactory = ctypes.cast(pAux, ctypes.py_object).value

    try:
        declaration = f'CREATE TABLE x({vtab_columns_declaration(tableFactory)})'
    except Exception:
        # An exception escaping a ctypes callback leaves the C return value
        # undefined, so report the failure to SQLite explicitly.
        traceback.print_exc()
        return SQLITE_ERROR

    res = libsqlite3.sqlite3_declare_vtab(pDB, declaration.encode('utf-8'))
    if res != SQLITE_OK:
        return res

    szVTab = ctypes.sizeof(direct.c_sqlite3_vtab)
    pRaw = libsqlite3.sqlite3_malloc(szVTab)
    if not pRaw:
        # Never memset/dereference a NULL allocation.
        return SQLITE_NOMEM
    pVTab = ctypes.cast(pRaw, direct.c_sqlite3_vtab_p)
    ctypes.memset(pVTab, 0, szVTab)

    pVTab.contents.tableFactory = ctypes.cast(pAux, ctypes.py_object)
    # The vtab memory keeps its own reference; vtab_xDisconnect releases it.
    Py_IncRef(pVTab.contents.tableFactory)

    ppVTab[0] = pVTab
    return SQLITE_OK
|
|
|
|
@direct.c_sqlite3_module._types_.xBestIndex
def vtab_xBestIndex(pVTab, pIdxInfo): # sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo
    """Used to determine the best way to access the virtual table.

    Each usable equality constraint on a hidden parameter column (the columns
    declared after the regular ones by vtab_columns_declaration) becomes an
    xFilter argument (argvIndex) and is marked omit, so SQLite does not
    re-check it. At most one constraint per parameter is consumed; any others
    are left for SQLite to evaluate. The comma-separated parameter names, in
    argv order, are passed to xFilter through idxStr.

    Returns SQLITE_OK; SQLITE_CONSTRAINT when no parameter can be supplied and
    the SQLite version supports that result; SQLITE_NOMEM if idxStr cannot be
    allocated; SQLITE_ERROR if planning raises.
    """
    SQLITE_NOMEM = 7  # value from sqlite3.h; not imported from `direct` in this file

    try:
        tableFactory = pVTab.contents.tableFactory
        vIdxInfo = pIdxInfo.contents
        params = tableFactory._vtparams_
        nParams = len(params)
        nColumns = len(tableFactory._vtcolumns_)

        columns = []
        used = set()
        for i in range(vIdxInfo.nConstraint):
            vConstraint = vIdxInfo.aConstraint[i]
            if not vConstraint.usable or vConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ:
                continue
            iParam = vConstraint.iColumn - nColumns
            # A negative index means a regular column or the rowid. Without this
            # check it would wrap around to the wrong parameter.
            if iParam < 0 or iParam >= nParams or iParam in used:
                continue
            used.add(iParam)
            columns.append(params[iParam])
            vUsage = vIdxInfo.aConstraintUsage[i]
            vUsage.argvIndex = len(columns)
            vUsage.omit = 1
        nArg = len(columns)

        if nArg == 0 and nParams > 0:
            if USE_SQLITE_CONSTRAINT:
                return SQLITE_CONSTRAINT
            # Pre-3.26 SQLite: make this plan as unattractive as possible (DBL_MAX).
            vIdxInfo.estimatedCost = sys.float_info.max
            vIdxInfo.estimatedRows = 100000
            return SQLITE_OK

        if nArg == nParams:
            # All parameters are present, this is ideal.
            vIdxInfo.estimatedCost = 1.0
            vIdxInfo.estimatedRows = 10
        else:
            # Penalize score based on number of missing params.
            nMissing = nParams - nArg
            vIdxInfo.estimatedCost = 1e13 * nMissing
            vIdxInfo.estimatedRows = 10 ** nMissing

        # Hand the parameter names to xFilter as a NUL-terminated, SQLite-owned string.
        encoded = ','.join(columns).encode('utf-8') + b'\x00'
        pStr = libsqlite3.sqlite3_malloc(len(encoded))
        if not pStr:
            return SQLITE_NOMEM
        ctypes.memmove(pStr, encoded, len(encoded))
        vIdxInfo.idxStr = ctypes.cast(pStr, ctypes.c_char_p)
        # The string came from sqlite3_malloc, so let SQLite sqlite3_free it
        # (0 would leak it on every planning call).
        vIdxInfo.needToFreeIdxStr = 1
        return SQLITE_OK
    except Exception:
        traceback.print_exc()
        return SQLITE_ERROR
|
|
|
|
@direct.c_sqlite3_module._types_.xDisconnect
def vtab_xDisconnect(pVTab): # sqlite3_vtab *pVTab
    """Releases a connection to a virtual table.

    Drops the table-factory reference that xConnect took with Py_IncRef, then
    hands the sqlite3_malloc'd vtab memory back to sqlite3_free.
    """
    factory = pVTab.contents.tableFactory
    Py_DecRef(factory)
    libsqlite3.sqlite3_free(pVTab)
    return SQLITE_OK
|
|
|
|
@direct.c_sqlite3_module._types_.xOpen
def vtab_xOpen(pVTab, ppCursor): # sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor
    """Creates a new cursor used for accessing (read and/or writing) a virtual table.

    The cursor is allocated with sqlite3_malloc, zeroed, and given an owned
    reference (Py_IncRef) to the vtab's table factory. rowData and
    tableIterator start as None rather than a NULL py_object, so later
    callbacks can test them safely. The row counter starts at 0.

    Returns SQLITE_OK, or SQLITE_NOMEM if the allocation fails.
    """
    SQLITE_NOMEM = 7  # value from sqlite3.h; not imported from `direct` in this file

    szCursor = ctypes.sizeof(direct.c_sqlite3_vtab_cursor)
    pRaw = libsqlite3.sqlite3_malloc(szCursor)
    if not pRaw:
        # Never memset/dereference a NULL allocation.
        return SQLITE_NOMEM
    pCursor = ctypes.cast(pRaw, direct.c_sqlite3_vtab_cursor_p)
    ctypes.memset(pCursor, 0, szCursor)

    vCursor = pCursor.contents
    vCursor.tableFactory = pVTab.contents.tableFactory
    Py_IncRef(vCursor.tableFactory)  # released by vtab_xClose
    vCursor.rowData = None
    vCursor.tableIterator = None
    vCursor.rowIdx = 0

    ppCursor[0] = pCursor
    return SQLITE_OK
|
|
|
|
@direct.c_sqlite3_module._types_.xClose
def vtab_xClose(pCursor): # sqlite3_vtab_cursor *pCursor
    """Closes a cursor previously opened by xOpen.

    Releases every owned reference the cursor may still hold, then frees the
    cursor memory. That covers the current row, an iterator that was not run to
    exhaustion (e.g. a scan stopped early by LIMIT or an error), and the table
    factory.
    """
    vCursor = pCursor.contents
    for field in ('rowData', 'tableIterator'):
        try:
            obj = getattr(vCursor, field)
        except ValueError:
            # ctypes raises ValueError when reading a py_object field that still
            # holds the NULL left by xOpen's memset; there is nothing to release.
            continue
        if obj is not None:
            Py_DecRef(obj)
    Py_DecRef(vCursor.tableFactory)
    libsqlite3.sqlite3_free(pCursor)
    return SQLITE_OK
|
|
|
|
def vtab_iterate(vCursor):
    """Advance *vCursor* to the next row produced by its table iterator.

    Releases the owned reference to the previous row, fetches the next row and
    takes an owned reference (Py_IncRef) to it, and bumps the row counter.
    When the iterator is exhausted, its owned reference is released and
    ``tableIterator`` is set to None, which vtab_xEof reports as end of data.

    Returns SQLITE_OK, or SQLITE_ERROR (after printing the traceback) if the
    iterator raised.
    """
    # Compare against None rather than testing truthiness. A row may be falsy
    # (an empty sequence, whose reference would then leak) or have no truth
    # value at all (e.g. a multi-element numpy array raises on bool()).
    if vCursor.rowData is not None:
        Py_DecRef(vCursor.rowData)
        vCursor.rowData = None

    try:
        vCursor.rowData = next(vCursor.tableIterator)
        Py_IncRef(vCursor.rowData)
    except StopIteration:
        Py_DecRef(vCursor.tableIterator)
        vCursor.tableIterator = None
    except BaseException:
        # Nothing may propagate into SQLite's C frame; report via the result code.
        traceback.print_exc()
        return SQLITE_ERROR
    else:
        vCursor.rowIdx += 1

    return SQLITE_OK
|
|
|
|
def _vtab_release_object(vCursor, field):
    """Drop the owned reference held in py_object *field* of *vCursor* and reset it to None."""
    try:
        obj = getattr(vCursor, field)
    except ValueError:
        # ctypes raises ValueError when reading a py_object field that still
        # holds NULL (never assigned since the cursor memory was zeroed).
        return
    if obj is not None:
        Py_DecRef(obj)
        setattr(vCursor, field, None)


@direct.c_sqlite3_module._types_.xFilter
def vtab_xFilter(pCursor, idxNum, idxStr, argc, argv): # sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv
    """Begins a search of a virtual table.

    idxStr is the comma-separated parameter list built by xBestIndex, and argv
    holds the matching values in the same order. The table factory is called
    with them as keyword arguments. The resulting iterator is stored on the
    cursor with an owned reference and advanced to the first row.

    Returns the vtab_iterate result code, or SQLITE_ERROR if xBestIndex
    produced no plan, required parameters are absent, or building the iterator
    raises.
    """
    vCursor = pCursor.contents
    tableFactory = vCursor.tableFactory

    # xFilter may run repeatedly on one cursor (e.g. as the inner loop of a
    # join). Release whatever the previous scan left behind instead of leaking it.
    _vtab_release_object(vCursor, 'rowData')
    _vtab_release_object(vCursor, 'tableIterator')

    idxStr = ctypes.cast(idxStr, ctypes.c_char_p).value  # avoid null terminator
    # A NULL idxStr means xBestIndex chose no parameter plan. An empty idxStr is
    # valid: xBestIndex stores exactly that for a table with no parameters.
    if idxStr is None or (argc == 0 and len(tableFactory._vtparams_)):
        return SQLITE_ERROR
    params = idxStr.decode('utf-8').split(',') if idxStr else []

    try:
        py_args = []
        for i in range(argc):
            pValue = argv[i]
            py_args.append(map_sqlite3_value[libsqlite3.sqlite3_value_type(pValue)](pValue))
        tableIterator = iter(tableFactory(**dict(zip(params, py_args))))
    except BaseException:
        # Nothing may propagate into SQLite's C frame; report via the result code.
        traceback.print_exc()
        return SQLITE_ERROR

    vCursor.tableIterator = tableIterator
    Py_IncRef(vCursor.tableIterator)  # released on exhaustion, re-filter, or xClose

    return vtab_iterate(vCursor)
|
|
|
|
@direct.c_sqlite3_module._types_.xNext
def vtab_xNext(pCursor): # sqlite3_vtab_cursor *pCursor
    """Advances a virtual table cursor to the next row of a result set initiated by xFilter.

    Delegates to vtab_iterate and passes its SQLite result code straight back.
    """
    vCursor = pCursor.contents
    return vtab_iterate(vCursor)
|
|
|
|
@direct.c_sqlite3_module._types_.xEof
def vtab_xEof(pCursor): # sqlite3_vtab_cursor *pCursor
    """Returns false (zero) if the specified cursor currently points to a valid row of data, or true (non-zero) otherwise.

    End of data is signalled by ``tableIterator`` having been reset to None
    once the iterator is exhausted.
    """
    exhausted = pCursor.contents.tableIterator is None
    return int(exhausted)
|
|
|
|
@direct.c_sqlite3_module._types_.xColumn
def vtab_xColumn(pCursor, pContext, iCol): # sqlite3_vtab_cursor *pCursor, sqlite3_context *pContext, int iCol
    """Invoked to retrieve the N-th column of the current row.

    iCol == -1 reports the cursor's row counter as an int64. Otherwise the
    value at position iCol of the current row is converted with the
    map_sqlite3_result entry for its Python type.

    Any failure is reported through sqlite3_result_error together with an
    SQLITE_ERROR return. Failures include:
    - no current row;
    - an index the row does not have (presumably e.g. a hidden parameter
      column when rows only carry the regular columns);
    - a value type with no converter.
    """
    vCursor = pCursor.contents

    if iCol == -1:
        libsqlite3.sqlite3_result_int64(pContext, vCursor.rowIdx)
        return SQLITE_OK

    rowData = vCursor.rowData
    # `is None`, not truthiness: some row objects (e.g. numpy arrays) raise on bool().
    if rowData is None:
        libsqlite3.sqlite3_result_error(pContext, 'no row data'.encode('utf-8'), -1)
        return SQLITE_ERROR

    try:
        value = rowData[iCol]
        map_sqlite3_result[type(value)](pContext, value)
    except BaseException as exc:
        # Nothing may propagate into SQLite's C frame; surface the error as the column result.
        traceback.print_exc()
        message = f'{type(exc).__name__}: {exc}'.encode('utf-8', errors='replace')
        libsqlite3.sqlite3_result_error(pContext, message, -1)
        return SQLITE_ERROR

    return SQLITE_OK
|
|
|
|
@direct.c_sqlite3_module._types_.xRowid
def vtab_xRowid(pCursor, pRowid): # sqlite3_vtab_cursor *pCursor, sqlite3_int64 *pRowid
    """Fills *pRowid with the rowid of the current row.

    The rowid is the cursor's running count of fetched rows, incremented by
    vtab_iterate. Like every other callback here this returns an int result
    code. Falling off the end would return None, which ctypes cannot convert
    to the callback's int return type, leaving the C result undefined.
    """
    pRowid[0] = pCursor.contents.rowIdx
    return SQLITE_OK
|
|
|
|
def new_sqlite3_module_vtabfunc():
    """Build the zeroed sqlite3_module (iVersion 0) shared by the tables in this file.

    Only the read path is installed:
    - connect, plan, disconnect;
    - cursor open/close, filter/next/eof;
    - column and rowid retrieval.

    Every other slot stays NULL: xCreate, xDestroy, xUpdate, the transaction
    hooks, xFindFunction and xRename. A NULL xCreate makes the tables eponymous
    (see the xConnect docstring).
    """
    installed = (
        ('xConnect', vtab_xConnect),
        ('xBestIndex', vtab_xBestIndex),
        ('xDisconnect', vtab_xDisconnect),
        ('xOpen', vtab_xOpen),
        ('xClose', vtab_xClose),
        ('xFilter', vtab_xFilter),
        ('xNext', vtab_xNext),
        ('xEof', vtab_xEof),
        ('xColumn', vtab_xColumn),
        ('xRowid', vtab_xRowid),
    )

    module = direct.c_sqlite3_module()
    ctypes.memset(ctypes.byref(module), 0, ctypes.sizeof(direct.c_sqlite3_module))
    module.iVersion = 0
    for slot, callback in installed:
        setattr(module, slot, callback)
    return module
|
|
|
|
def vtab_columns_declaration(factory):
    """Build the column list for the ``CREATE TABLE`` passed to sqlite3_declare_vtab.

    Regular columns come from ``factory._vtcolumns_``. Each entry is either a
    bare column name (str) or a ``(name, type)`` 2-tuple. Every name in
    ``factory._vtparams_`` is then appended as a HIDDEN column, so the
    parameters occupy the column indices that follow the regular columns.

    Returns the comma-separated declaration body, e.g.
    ``'a, b INTEGER, p HIDDEN'``.

    Raises:
        ValueError: if a column entry is neither a str nor a 2-tuple.
    """
    acc = []
    for column in factory._vtcolumns_:
        if isinstance(column, str):
            acc.append(column)
        elif isinstance(column, tuple) and len(column) == 2:
            # Previously `!= 2`: valid (name, type) pairs were rejected and
            # other tuple lengths crashed inside the % formatting.
            colName, colType = column
            acc.append(f'{colName} {colType}')
        else:
            raise ValueError('Column must be either a string or a 2-tuple of name, type')
    acc.extend(f'{param} HIDDEN' for param in factory._vtparams_)
    return ', '.join(acc)
|
|
|
|
class VTableIterable(object):
    """Base class for Python iterables exposed to SQLite through the shared virtual-table module.

    The class itself is the table factory: SQLite query parameters are passed
    to its constructor as keyword arguments. Configuration lives in these
    class attributes:
    - ``_vtname_``: module name; falls back to the class name.
    - ``_vtparams_``: hidden parameter columns; inferred from the constructor
      signature on first register.
    - ``_vtcolumns_``: result columns; defaults to a single ``value`` column.
    """

    # sqlite3_module shared by every registered class; built once at class creation.
    _vtmod_ = new_sqlite3_module_vtabfunc()
    _vtname_ = None
    _vtparams_ = None
    _vtcolumns_ = None

    @classmethod
    def register(cls, database, name=None):
        """Register *cls* as a virtual-table module on *database*; return the SQLite result code."""
        if not cls._vtparams_:
            cls._vtparams_ = tuple(inspect.signature(cls).parameters)
        if not cls._vtcolumns_:
            cls._vtcolumns_ = ('value',)
        connection = direct.connect(database)
        moduleName = (name or cls._vtname_ or cls.__name__).encode('utf-8')
        return libsqlite3.sqlite3_create_module(connection, moduleName, ctypes.byref(cls._vtmod_), ctypes.py_object(cls))

    @classmethod
    def unregister(cls, database, name=None):
        """Drop the module named for *cls* by re-creating it with a NULL module pointer."""
        connection = direct.connect(database)
        moduleName = (name or cls._vtname_ or cls.__name__).encode('utf-8')
        return libsqlite3.sqlite3_create_module(connection, moduleName, None, None)

    def __iter__(self):
        # Subclasses must provide iteration over their rows.
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
|
|
|
|
def vtab_wrap(dst, src, shift=0, name=None, params=None, columns=None):
    """Copy virtual-table metadata from *src* onto *dst* and give *dst* register helpers.

    Each of ``_vtname_``, ``_vtparams_`` and ``_vtcolumns_`` is taken from the
    first truthy source, in this order:
    1. the explicit argument;
    2. the same attribute on *src*;
    3. a default:
       - ``_vtname_``: ``src.__name__``;
       - ``_vtparams_``: *src*'s signature parameters after the first *shift*;
       - ``_vtcolumns_``: ``('value',)``.

    ``dst.register`` creates the shared module with *dst* as the table factory.
    ``dst.unregister`` is reused from *src* when present; otherwise it is a
    NULL-module create. ``__doc__`` and ``__name__`` are copied from *src*.
    Returns *dst*.
    """
    def inherited(explicit, attr):
        return explicit or getattr(src, attr, None)

    dst._vtname_ = inherited(name, '_vtname_') or src.__name__
    dst._vtparams_ = inherited(params, '_vtparams_') or tuple(inspect.signature(src).parameters)[shift:]
    dst._vtcolumns_ = inherited(columns, '_vtcolumns_') or ('value',)

    def register(database, name=None):
        return libsqlite3.sqlite3_create_module(
            direct.connect(database),
            (name or dst._vtname_).encode('utf-8'),
            ctypes.byref(VTableIterable._vtmod_),
            ctypes.py_object(dst),
        )

    def drop(database, name=None):
        return libsqlite3.sqlite3_create_module(
            direct.connect(database),
            (name or dst._vtname_).encode('utf-8'),
            None,
            None,
        )

    dst.register = register
    dst.unregister = getattr(src, 'unregister', None) or drop
    dst.__doc__ = src.__doc__
    dst.__name__ = src.__name__
    return dst
|
|
|
|
class vtab_descriptor(object):
    """Wrap a method-style function so it can be registered as a virtual table.

    Calling the descriptor calls the wrapped function unchanged. Attribute
    access returns a ``functools.partial`` with the accessing object bound as
    the first argument (None when accessed on the class). The partial carries
    the same table metadata.
    """

    def __init__(self, func, name=None, params=None, columns=None):
        self._vtfunc_ = func
        # shift=1 drops the leading self/cls parameter from the inferred params.
        vtab_wrap(self, func, 1, name=name, params=params, columns=columns)

    def __call__(self, *args, **kwargs):
        func = self._vtfunc_
        return func(*args, **kwargs)

    def __get__(self, obj, cls):
        bound = functools.partial(self._vtfunc_, obj)
        return vtab_wrap(bound, self)
|
|
|
|
def vtab_iterable(func=None, *, name=None, params=None, columns=None, member=False):
    """Decorator exposing a callable that returns an iterable of rows as a virtual table.

    Usable bare (``@vtab_iterable``) or with keyword options. A function whose
    first parameter is named ``self`` or ``cls`` becomes a vtab_descriptor, as
    does any function when ``member=True``. Every other callable is wrapped in a
    plain proxy function that carries the table metadata.
    """
    options = dict(name=name, params=params, columns=columns)

    def decorator(src):
        try:
            isMember = member or inspect.getfullargspec(src).args[0] in ('self', 'cls')
            if isMember:
                return vtab_descriptor(src, **options)
        except IndexError:
            # No positional parameters at all: treat it as a plain function.
            pass

        def proxy(*args, **kw):
            return src(*args, **kw)

        return vtab_wrap(proxy, src, **options)

    return decorator if func is None else decorator(func)
|
|
|
|
# Public API: the class-based and the decorator-based way to define a table.
__all__ = [
    'VTableIterable',
    'vtab_iterable',
]
|