sqlite3tricks-py/vtab.py
2020-03-30 08:00:00 -05:00

261 lines
10 KiB
Python

#!/usr/bin/env python3
import sys
import sqlite3
import ctypes
import functools
import inspect
import traceback
import direct
from direct import libsqlite3, Py_IncRef, Py_DecRef, c_sqlite3_module, c_sqlite3_value_p_p, map_sqlite3_value, map_sqlite3_result, SQLITE_OK, SQLITE_ERROR, SQLITE_CONSTRAINT, SQLITE_INDEX_CONSTRAINT_EQ
# True when the linked SQLite is 3.26.0 or newer; vtab_xBestIndex then rejects
# an unusable plan by returning SQLITE_CONSTRAINT instead of pricing it at DBL_MAX.
USE_SQLITE_CONSTRAINT = sqlite3.sqlite_version_info >= (3, 26, 0)
@direct.c_sqlite3_module._types_.xConnect
def vtab_xConnect(pDB, pAux, argc, argv, ppVTab, pzErr): # sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr
    """
    Invoked whenever a database connection attaches to or reparses a schema.
    A virtual table is eponymous if its xCreate method is the exact same function as the xConnect method, or if the xCreate method is NULL.

    ``pAux`` is the Python table factory given to sqlite3_create_module. Its
    columns are declared to SQLite, then a zeroed sqlite3_vtab is allocated
    that holds a new reference to the factory, and it is stored in *ppVTab.

    Returns SQLITE_ERROR if building the declaration raises (for example a
    malformed ``_vtcolumns_``), SQLITE_NOMEM if the allocation fails, the
    sqlite3_declare_vtab result if that fails, and SQLITE_OK otherwise.
    """
    try:
        tableFactory = ctypes.cast(pAux, ctypes.py_object).value
        declaration = f'CREATE TABLE x({vtab_columns_declaration(tableFactory)})'.encode('utf-8')
    except BaseException:  # nothing may propagate out through the C callback
        traceback.print_exc()
        return SQLITE_ERROR
    res = libsqlite3.sqlite3_declare_vtab(pDB, declaration)
    if res != SQLITE_OK:
        return res
    szVTab = ctypes.sizeof(direct.c_sqlite3_vtab)
    pMem = libsqlite3.sqlite3_malloc(szVTab)
    if not pMem:
        return 7  # SQLITE_NOMEM (not among the names imported from direct)
    pVTab = ctypes.cast(pMem, direct.c_sqlite3_vtab_p)
    ctypes.memset(pVTab, 0, szVTab)
    pVTab.contents.tableFactory = ctypes.cast(pAux, ctypes.py_object)
    # The struct lives in C memory, so ctypes does not keep the factory alive
    # on its behalf; vtab_xDisconnect releases this reference.
    Py_IncRef(pVTab.contents.tableFactory)
    ppVTab[0] = pVTab
    return SQLITE_OK
@direct.c_sqlite3_module._types_.xBestIndex
def vtab_xBestIndex(pVTab, pIdxInfo): # sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo
    """Used to determine the best way to access the virtual table.

    Each usable equality constraint on a HIDDEN parameter column becomes an
    xFilter argument. It gets the next argvIndex and is marked ``omit`` so
    SQLite does not re-check it. The matching parameter names are joined with
    ',' into idxStr, which lets xFilter rebuild the keyword arguments.

    A plan that supplies every parameter costs 1.0. A plan that is missing
    parameters is penalised by how many are missing. A plan with no arguments
    for a table that needs some is rejected with SQLITE_CONSTRAINT when
    USE_SQLITE_CONSTRAINT is set, and is otherwise priced at the largest float.
    """
    tableFactory = pVTab.contents.tableFactory
    vIdxInfo = pIdxInfo.contents
    params = tableFactory._vtparams_
    nParams = len(params)
    # HIDDEN parameter columns are declared after the regular columns (see
    # vtab_columns_declaration), so column nColumns + k is parameter k.
    nColumns = len(tableFactory._vtcolumns_)
    columns = []
    for i in range(vIdxInfo.nConstraint):
        vConstraint = vIdxInfo.aConstraint[i]
        if not vConstraint.usable or vConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ:
            continue
        iParam = vConstraint.iColumn - nColumns
        # Constraints on regular columns (iParam < 0) or on the rowid
        # (iColumn == -1) are not function arguments. Leave them for SQLite to
        # evaluate instead of mis-mapping them onto params[-k]. A repeated
        # constraint on an already-bound parameter is also left for SQLite.
        if not 0 <= iParam < nParams or params[iParam] in columns:
            continue
        columns.append(params[iParam])
        vIdxInfo.aConstraintUsage[i].argvIndex = len(columns)
        vIdxInfo.aConstraintUsage[i].omit = 1
    nArg = len(columns)
    if nArg > 0 or nParams == 0:
        if nArg == nParams:
            # All parameters are present, this is ideal.
            vIdxInfo.estimatedCost = 1.0
            vIdxInfo.estimatedRows = 10
        else:
            # Penalize score based on number of missing params.
            vIdxInfo.estimatedCost = 1e13 * (nParams - nArg)
            vIdxInfo.estimatedRows = 10 ** (nParams - nArg)
        # Pass the parameter names to xFilter as a NUL-terminated idxStr.
        joinedCols = ','.join(columns).encode('utf-8')
        size = len(joinedCols) + 1
        pStr = libsqlite3.sqlite3_malloc(size * ctypes.sizeof(ctypes.c_char))
        if not pStr:
            return 7  # SQLITE_NOMEM (not among the names imported from direct)
        buf = (ctypes.c_char * size).from_address(pStr)
        ctypes.memmove(buf, joinedCols, len(joinedCols)) # memcpy
        buf[len(joinedCols)] = b'\x00'
        vIdxInfo.idxStr = ctypes.cast(buf, ctypes.c_char_p)
        # The buffer came from sqlite3_malloc and nothing else frees it. Ask
        # SQLite to release it with sqlite3_free when the plan is discarded.
        vIdxInfo.needToFreeIdxStr = 1
    elif USE_SQLITE_CONSTRAINT:
        return SQLITE_CONSTRAINT
    else:
        # DBL_MAX: make this plan as unattractive as possible.
        vIdxInfo.estimatedCost = sys.float_info.max
        vIdxInfo.estimatedRows = 100000
    return SQLITE_OK
@direct.c_sqlite3_module._types_.xDisconnect
def vtab_xDisconnect(pVTab): # sqlite3_vtab *pVTab
    """Tear down a table object created by xConnect.

    Drops the factory reference taken in xConnect, then returns the struct's
    memory to sqlite3_free. Always reports SQLITE_OK.
    """
    factory = pVTab.contents.tableFactory
    Py_DecRef(factory)
    libsqlite3.sqlite3_free(pVTab)
    return SQLITE_OK
@direct.c_sqlite3_module._types_.xOpen
def vtab_xOpen(pVTab, ppCursor): # sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor
    """Creates a new cursor used for accessing (read and/or writing) a virtual table.

    Allocates a zeroed cursor that holds its own reference to the table
    factory. Its Python-object fields are initialised to None so later
    callbacks never read a NULL py_object field, which raises ValueError in
    ctypes. Returns SQLITE_NOMEM if the allocation fails, else SQLITE_OK.
    """
    szCursor = ctypes.sizeof(direct.c_sqlite3_vtab_cursor)
    pMem = libsqlite3.sqlite3_malloc(szCursor)
    if not pMem:
        return 7  # SQLITE_NOMEM (not among the names imported from direct)
    pCursor = ctypes.cast(pMem, direct.c_sqlite3_vtab_cursor_p)
    ctypes.memset(pCursor, 0, szCursor)
    vCursor = pCursor.contents
    vCursor.tableFactory = pVTab.contents.tableFactory
    # Released in vtab_xClose.
    Py_IncRef(vCursor.tableFactory)
    vCursor.rowData = None
    vCursor.tableIterator = None
    vCursor.rowIdx = 0
    ppCursor[0] = pCursor
    return SQLITE_OK
@direct.c_sqlite3_module._types_.xClose
def vtab_xClose(pCursor): # sqlite3_vtab_cursor *pCursor
    """Closes a cursor previously opened by xOpen.

    Releases every Python reference the cursor owns before freeing it. That
    means the current row and the table iterator, which are still held when a
    scan is abandoned before EOF (e.g. by LIMIT), plus the factory.
    """
    vCursor = pCursor.contents
    for field in ('rowData', 'tableIterator'):
        try:
            held = getattr(vCursor, field)
        except ValueError:  # py_object field never assigned: still NULL from memset
            continue
        if held is not None:
            Py_DecRef(held)
    Py_DecRef(vCursor.tableFactory)
    libsqlite3.sqlite3_free(pCursor)
    return SQLITE_OK
def vtab_iterate(vCursor):
    """Advance ``vCursor`` to the next row produced by its table iterator.

    Releases the reference held on the previous row, then stores the next row
    (taking a reference to it) and increments ``rowIdx``. When the iterator is
    exhausted, its reference is released and ``tableIterator`` is set to None,
    which vtab_xEof reports as end-of-data.

    Returns SQLITE_OK, or SQLITE_ERROR after printing the traceback if the
    iterator raises.
    """
    previous = vCursor.rowData
    # Compare with None rather than truthiness: an empty row such as [] is a
    # real row whose reference must still be released.
    if previous is not None:
        Py_DecRef(previous)
        vCursor.rowData = None
    try:
        row = next(vCursor.tableIterator)
    except StopIteration:
        Py_DecRef(vCursor.tableIterator)
        vCursor.tableIterator = None
        return SQLITE_OK
    except BaseException:  # nothing may escape into the calling C code
        traceback.print_exc()
        return SQLITE_ERROR
    vCursor.rowData = row
    Py_IncRef(row)
    vCursor.rowIdx += 1
    return SQLITE_OK
@direct.c_sqlite3_module._types_.xFilter
def vtab_xFilter(pCursor, idxNum, idxStr, argc, argv): # sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv
    """Begins a search of a virtual table.

    ``idxStr`` is the comma-separated parameter list chosen by xBestIndex, and
    ``argv`` holds the matching constraint values. They are zipped into
    keyword arguments for the table factory, and the factory's result is
    iterated to produce rows. The first row is fetched immediately.

    Returns SQLITE_ERROR when there is no plan string (NULL idxStr), when the
    factory needs parameters but none were supplied, or when converting the
    arguments or calling the factory raises.
    """
    vCursor = pCursor.contents
    tableFactory = vCursor.tableFactory
    idxStr = ctypes.cast(idxStr, ctypes.c_char_p).value # avoid null terminator
    # An empty idxStr is valid (a factory with no parameters). Only a missing
    # (NULL) one means xBestIndex produced no usable plan.
    if idxStr is None or (argc == 0 and len(tableFactory._vtparams_)):
        return SQLITE_ERROR
    params = idxStr.decode('utf-8').split(',') if idxStr else []
    try:
        py_args = tuple(map_sqlite3_value[libsqlite3.sqlite3_value_type(arg := argv[i])](arg) for i in range(argc))
        tableIterator = iter(tableFactory(**dict(zip(params, py_args))))
    except BaseException:  # nothing may propagate out through the C callback
        traceback.print_exc()
        return SQLITE_ERROR
    # xFilter can run more than once on the same cursor (e.g. for the inner
    # table of a join). Release the previous scan's iterator. The field is
    # still NULL, and reading it raises ValueError, on a cursor never filtered.
    try:
        previous = vCursor.tableIterator
    except ValueError:
        previous = None
    if previous is not None:
        Py_DecRef(previous)
    vCursor.tableIterator = tableIterator
    Py_IncRef(tableIterator)
    return vtab_iterate(vCursor)
@direct.c_sqlite3_module._types_.xNext
def vtab_xNext(pCursor): # sqlite3_vtab_cursor *pCursor
    """Step the cursor to the following row of the scan begun by xFilter (delegates to vtab_iterate)."""
    vCursor = pCursor.contents
    return vtab_iterate(vCursor)
@direct.c_sqlite3_module._types_.xEof
def vtab_xEof(pCursor): # sqlite3_vtab_cursor *pCursor
    """Report end-of-data to SQLite.

    Returns 1 once the cursor's iterator has been cleared (set to None when it
    is exhausted), and 0 while the cursor still points at a row.
    """
    exhausted = pCursor.contents.tableIterator is None
    return int(exhausted)
@direct.c_sqlite3_module._types_.xColumn
def vtab_xColumn(pCursor, pContext, iCol): # sqlite3_vtab_cursor *pCursor, sqlite3_context *pContext, int iCol
    """Invoked to retrieve the N-th column of the current row.

    ``iCol == -1`` yields the rowid counter. Otherwise ``rowData[iCol]`` is
    converted by the ``map_sqlite3_result`` entry for its exact type.

    Every failure is reported through sqlite3_result_error plus SQLITE_ERROR
    instead of escaping from the C callback. That covers: no current row,
    index out of range, and a value type with no converter.
    """
    vCursor = pCursor.contents
    if iCol == -1:
        libsqlite3.sqlite3_result_int64(pContext, vCursor.rowIdx)
        return SQLITE_OK
    row = vCursor.rowData
    if row is None:
        libsqlite3.sqlite3_result_error(pContext, 'no row data'.encode('utf-8'), -1)
        return SQLITE_ERROR
    try:
        value = row[iCol]
        map_sqlite3_result[type(value)](pContext, value)
    except Exception as exc:
        message = f'cannot read column {iCol}: {exc!r}'
        libsqlite3.sqlite3_result_error(pContext, message.encode('utf-8'), -1)
        return SQLITE_ERROR
    return SQLITE_OK
@direct.c_sqlite3_module._types_.xRowid
def vtab_xRowid(pCursor, pRowid): # sqlite3_vtab_cursor *pCursor, sqlite3_int64 *pRowid
    """Fills *pRowid with the rowid of the current row.

    The rowid is the cursor's row counter: 1 for the first row fetched by
    vtab_iterate, 2 for the next, and so on.
    """
    pRowid[0] = pCursor.contents.rowIdx
    # Like every other callback, return an int status code. A bare `return`
    # yields None, which ctypes cannot convert to the int result.
    return SQLITE_OK
def new_sqlite3_module_vtabfunc():
    """Build the sqlite3_module method table used by VTableIterable and vtab_wrap.

    Only the read-only, connect-only subset of methods is filled in. All other
    slots stay NULL: xCreate, xDestroy, xUpdate, xBegin, xSync, xCommit,
    xRollback, xFindFunction and xRename. Leaving xCreate NULL makes the
    tables eponymous-only.
    """
    module = direct.c_sqlite3_module()
    ctypes.memset(ctypes.byref(module), 0, ctypes.sizeof(direct.c_sqlite3_module))
    module.iVersion = 0
    slots = (
        ('xConnect', vtab_xConnect),
        ('xBestIndex', vtab_xBestIndex),
        ('xDisconnect', vtab_xDisconnect),
        ('xOpen', vtab_xOpen),
        ('xClose', vtab_xClose),
        ('xFilter', vtab_xFilter),
        ('xNext', vtab_xNext),
        ('xEof', vtab_xEof),
        ('xColumn', vtab_xColumn),
        ('xRowid', vtab_xRowid),
    )
    for slot_name, callback in slots:
        setattr(module, slot_name, callback)
    return module
def vtab_columns_declaration(factory):
    """Return the column list for the ``CREATE TABLE`` declaration of ``factory``.

    Each entry of ``factory._vtcolumns_`` is either a bare column name or a
    ``(name, type)`` 2-tuple. Each name in ``factory._vtparams_`` is then
    appended as a ``HIDDEN`` column, so it can be supplied as an argument.

    Raises:
        ValueError: if a column entry is neither a string nor a 2-tuple.
    """
    declarations = []
    for column in factory._vtcolumns_:
        if isinstance(column, str):
            declarations.append(column)
        elif isinstance(column, tuple) and len(column) == 2:
            declarations.append('%s %s' % column)
        else:
            raise ValueError('Column must be either a string or a 2-tuple of name, type')
    declarations.extend('%s HIDDEN' % param for param in factory._vtparams_)
    return ', '.join(declarations)
class VTableIterable(object):
    """Base class whose subclasses are exposed to SQLite as virtual-table modules.

    When a query runs, xFilter calls the registered class with the query's
    parameter values as keyword arguments and iterates the resulting object.
    Subclasses therefore override ``__iter__``/``__next__``; the versions here
    raise NotImplementedError.
    """

    @classmethod
    def register(cls, database, name=None):
        """Register ``cls`` on ``database`` and return the sqlite3_create_module status code.

        The module name is ``name``, else ``_vtname_``, else the class name.
        ``_vtparams_`` is filled from the constructor signature and
        ``_vtcolumns_`` with ``('value',)`` when they are unset.
        """
        if not cls._vtparams_:
            cls._vtparams_ = tuple(inspect.signature(cls).parameters)
        if not cls._vtcolumns_:
            cls._vtcolumns_ = ('value',)
        db = direct.connect(database)
        module_name = (name or cls._vtname_ or cls.__name__).encode('utf-8')
        return libsqlite3.sqlite3_create_module(db, module_name, ctypes.byref(cls._vtmod_), ctypes.py_object(cls))

    @classmethod
    def unregister(cls, database, name=None):
        """Drop the module named like ``register`` would name it, by passing a NULL module; return the status code."""
        db = direct.connect(database)
        module_name = (name or cls._vtname_ or cls.__name__).encode('utf-8')
        return libsqlite3.sqlite3_create_module(db, module_name, None, None)

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError

    # Method table shared by every subclass and by vtab_wrap'd callables.
    _vtmod_ = new_sqlite3_module_vtabfunc()
    # Module name; None falls back to the class name.
    _vtname_ = None
    # HIDDEN parameter column names; None means "derive from the constructor".
    _vtparams_ = None
    # Result column declarations; None means ('value',).
    _vtcolumns_ = None
def vtab_wrap(dst, src, shift=0, name=None, params=None, columns=None):
    """Copy virtual-table metadata from ``src`` onto ``dst`` and attach register/unregister helpers.

    Each of name, params and columns is resolved in this order:
      1. the explicit argument;
      2. ``src``'s matching ``_vt*_`` attribute;
      3. a default — ``src.__name__`` for the name, ``src``'s parameter names
         minus the first ``shift`` for params, and ``('value',)`` for columns.

    ``dst.unregister`` reuses ``src.unregister`` when ``src`` has one.
    ``__doc__`` and ``__name__`` are copied as well. Returns ``dst``.
    """
    dst._vtname_ = name or getattr(src, '_vtname_', None) or src.__name__
    dst._vtparams_ = (
        params
        or getattr(src, '_vtparams_', None)
        or tuple(inspect.signature(src).parameters)[shift:]
    )
    dst._vtcolumns_ = columns or getattr(src, '_vtcolumns_', None) or ('value',)

    def register(database, name=None):
        return libsqlite3.sqlite3_create_module(
            direct.connect(database),
            (name or dst._vtname_).encode('utf-8'),
            ctypes.byref(VTableIterable._vtmod_),
            ctypes.py_object(dst),
        )

    def unregister(database, name=None):
        return libsqlite3.sqlite3_create_module(
            direct.connect(database),
            (name or dst._vtname_).encode('utf-8'),
            None,
            None,
        )

    dst.register = register
    dst.unregister = getattr(src, 'unregister', None) or unregister
    dst.__doc__ = src.__doc__
    dst.__name__ = src.__name__
    return dst
class vtab_descriptor(object):
    """Virtual-table factory for a function defined inside a class body.

    ``vtab_iterable`` builds one of these when the decorated function's first
    parameter is named ``self`` or ``cls`` (or when ``member=True``). That
    first parameter is excluded from the table's parameters (``shift=1``) and
    is bound on attribute access.
    """
    def __init__(self, func, name=None, params=None, columns=None):
        self._vtfunc_ = func
        try:
            first = next(iter(inspect.signature(func).parameters), None)
        except (TypeError, ValueError):  # no introspectable signature
            first = None
        # A first parameter named 'cls' receives the owner class, like a
        # classmethod. Anything else receives the instance, like a method.
        self._vtbindcls_ = first == 'cls'
        vtab_wrap(self, func, 1, name=name, params=params, columns=columns)

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with no implicit first argument."""
        return self._vtfunc_(*args, **kwargs)

    def __get__(self, obj, cls):
        """Return a registrable factory whose first argument is bound.

        The bound value is the owner class for ``cls``-style functions and
        ``obj`` (the instance, or None on class access) otherwise.
        """
        target = cls if self._vtbindcls_ else obj
        return vtab_wrap(functools.partial(self._vtfunc_, target), self)
def vtab_iterable(func=None, *, name=None, params=None, columns=None, member=False):
    """Decorator that turns a callable returning an iterable of rows into a virtual-table factory.

    Usable bare (``@vtab_iterable``) or with options (``@vtab_iterable(name=...)``).
    A function whose first positional parameter is ``self`` or ``cls``, or any
    function when ``member=True``, becomes a vtab_descriptor. Any other
    function is wrapped in a forwarding proxy carrying the table metadata.
    """
    def decorator(src):
        is_member = member
        if not is_member:
            try:
                is_member = inspect.getfullargspec(src)[0][0] in {'self', 'cls'}
            except IndexError:  # no positional parameters at all
                pass
        if is_member:
            return vtab_descriptor(src, name=name, params=params, columns=columns)

        def proxy(*args, **kw):
            return src(*args, **kw)

        return vtab_wrap(proxy, src, name=name, params=params, columns=columns)

    return decorator if func is None else decorator(func)
# Public API: the base class and the decorator; everything else is internal plumbing.
__all__ = [
    'VTableIterable',
    'vtab_iterable',
]