From 54b6bf1467aadc62a5898a3303f861296f6ca37b Mon Sep 17 00:00:00 2001 From: Jiang Yio Date: Mon, 30 Mar 2020 08:00:00 -0500 Subject: [PATCH] Initial commit --- .gitignore | 154 +++++++++++++++++++ README.md | 3 + __init__.py | 5 + async_function.py | 29 ++++ direct.py | 369 ++++++++++++++++++++++++++++++++++++++++++++++ test.py | 148 +++++++++++++++++++ vtab.py | 260 ++++++++++++++++++++++++++++++++ 7 files changed, 968 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 __init__.py create mode 100644 async_function.py create mode 100644 direct.py create mode 100644 test.py create mode 100644 vtab.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..55be276 --- /dev/null +++ b/.gitignore @@ -0,0 +1,154 @@ +# ---> Python +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + diff --git a/README.md b/README.md new file mode 100644 index 0000000..915b651 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# sqlite3tricks-py + +Teach Python's SQLite3 library new tricks using ctypes. \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..778c1f4 --- /dev/null +++ b/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +from direct import connect, execute +from vtab import VTableIterable, vtab_iterable +from async_function import create_function diff --git a/async_function.py b/async_function.py new file mode 100644 index 0000000..731ac57 --- /dev/null +++ b/async_function.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +import sys +import asyncio + +def get_event_loop(): + if get_event_loop.loop: + return get_event_loop.loop + try: + loop = asyncio.get_running_loop() + except RuntimeError: + if sys.platform != 'win32': + loop = asyncio.new_event_loop() + else: + loop = asyncio.ProactorEventLoop() + get_event_loop.loop = loop + return loop +get_event_loop.loop = None + +def create_function(conn, name, num_params, func, deterministic=False): + original_create_function = conn._create_function if hasattr(conn, '_create_function') else conn.create_function + if conn.create_function == create_function: + fn = conn._create_function + if asyncio.iscoroutinefunction(func): + return original_create_function(name, num_params, (lambda *args: get_event_loop().run_until_complete(func(*args))), deterministic=deterministic) + else: + return original_create_function(name, num_params, func, deterministic=deterministic) + +__all__ = ['create_function'] diff --git a/direct.py b/direct.py new file mode 100644 index 0000000..ace22c8 --- /dev/null +++ b/direct.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +import sys +import sqlite3 +import ctypes +import contextlib +from collections import namedtuple + +maxsize32 = 2**31 - 1 + +# Flags for file open operations +# https://www.sqlite.org/c3ref/c_open_autoproxy.html +SQLITE_OPEN_READWRITE = 0x00000002 + +# Constants defining special destructor behavior +# https://www.sqlite.org/c3ref/c_static.html +SQLITE_TRANSIENT = -1 + +# Virtual table constraint operator codes +# https://www.sqlite.org/c3ref/c_index_constraint_eq.html +SQLITE_INDEX_CONSTRAINT_EQ = 2 + +# Result codes +# https://www.sqlite.org/rescode.html +SQLITE_OK = 0 +SQLITE_ERROR = 1 +SQLITE_CONSTRAINT = 19 +SQLITE_ROW = 100 +SQLITE_DONE = 101 + +# Fundamental datatypes +# https://www.sqlite.org/c3ref/c_blob.html +SQLITE_INTEGER = 1 +SQLITE_FLOAT = 2 +SQLITE_TEXT = 3 +SQLITE_BLOB = 4 +SQLITE_NULL = 5 + +# Custom pointer types +c_char_p_p = ctypes.POINTER(ctypes.c_char_p) +c_int64_p = ctypes.POINTER(ctypes.c_int64) +c_void_p_p = ctypes.POINTER(ctypes.c_void_p) + +class c_PyObject(ctypes.Structure): + """All Python object types are extensions of this type.""" + # https://stackoverflow.com/questions/35438103/referencing-a-pointer-of-array-in-a-struct-with-ctypes + _fields_ = [ + ('ob_refcnt', ctypes.c_ssize_t), + ('ob_type', ctypes.c_void_p), + ] + +class c_sqlite3_p(ctypes.c_void_p): + def execute(self, sql, parameters=()): + return execute(self, sql, parameters) + +class c_pysqlite_connection(ctypes.Structure): + """Just enough struct to get at the pointer to the SQLite db + Defined at https://github.com/python/cpython/blob/main/Modules/_sqlite/connection.h + """ + _fields_ = [ + ('ob_base', c_PyObject), + ('db', c_sqlite3_p), + ] + +class c_sqlite3_module(ctypes.Structure): pass # forward declaration +c_sqlite3_module_p = ctypes.POINTER(c_sqlite3_module) + +class c_sqlite3_vtab(ctypes.Structure): + # https://www.sqlite.org/c3ref/vtab.html + _fields_ = [ + ('pModule', c_sqlite3_module_p), + ('nRef', ctypes.c_int), + ('zErrMsg', ctypes.c_char_p), + # Custom fields + ('tableFactory', ctypes.py_object), + ] +c_sqlite3_vtab_p = ctypes.POINTER(c_sqlite3_vtab) +c_sqlite3_vtab_p_p = ctypes.POINTER(c_sqlite3_vtab_p) + +class c_sqlite3_index_constraint(ctypes.Structure): + _fields_ = [ + ('iColumn', ctypes.c_int), + ('op', ctypes.c_ubyte), + ('usable', ctypes.c_ubyte), + ('iTermOffset', ctypes.c_int), + ] +c_sqlite3_index_constraint_p = ctypes.POINTER(c_sqlite3_index_constraint) + +class c_sqlite3_index_orderby(ctypes.Structure): + _fields_ = [ + ('iColumn', ctypes.c_int), + ('desc', ctypes.c_ubyte), + ] +c_sqlite3_index_orderby_p = ctypes.POINTER(c_sqlite3_index_orderby) + +class c_sqlite3_index_constraint_usage(ctypes.Structure): + _fields_ = [ + ('argvIndex', ctypes.c_int), + ('omit', ctypes.c_ubyte), + ] +c_sqlite3_index_constraint_usage_p = ctypes.POINTER(c_sqlite3_index_constraint_usage) + +class c_sqlite3_index_info(ctypes.Structure): + # https://www.sqlite.org/c3ref/index_info.html + _fields_ = [ + # Inputs + ('nConstraint', ctypes.c_int), + ('aConstraint', c_sqlite3_index_constraint_p), + ('nOrderBy', ctypes.c_int), + ('aOrderBy', c_sqlite3_index_orderby_p), + # Outputs + ('aConstraintUsage', c_sqlite3_index_constraint_usage_p), + ('idxNum', ctypes.c_int), + ('idxStr', ctypes.c_char_p), + ('needToFreeIdxStr', ctypes.c_int), + ('orderByConsumed', ctypes.c_int), + ('estimatedCost', ctypes.c_double), + # Fields below are only available in SQLite 3.8.2 and later + ('estimatedRows', ctypes.c_int64), + # Fields below are only available in SQLite 3.9.0 and later + ('idxFlags', ctypes.c_int), + # Fields below are only available in SQLite 3.10.0 and later + ('colUsed', ctypes.c_uint64), + ] +c_sqlite3_index_info_p = ctypes.POINTER(c_sqlite3_index_info) + +class c_sqlite3_vtab_cursor(ctypes.Structure): + # https://www.sqlite.org/c3ref/vtab_cursor.html + _fields_ = [ + ('pVtab', c_sqlite3_vtab_p), + # Custom fields + ('tableFactory', ctypes.py_object), + ('tableIterator', ctypes.py_object), + ('rowData', ctypes.py_object), + ('rowIdx', ctypes.c_int64), + ] +c_sqlite3_vtab_cursor_p = ctypes.POINTER(c_sqlite3_vtab_cursor) +c_sqlite3_vtab_cursor_p_p = ctypes.POINTER(c_sqlite3_vtab_cursor_p) + +class c_sqlite3_context(ctypes.Structure): + pass +c_sqlite3_context_p = ctypes.POINTER(c_sqlite3_context) + +class c_sqlite3_value(ctypes.Structure): + pass +c_sqlite3_value_p = ctypes.POINTER(c_sqlite3_value) +c_sqlite3_value_p_p = ctypes.POINTER(c_sqlite3_value_p) + +# https://www.sqlite.org/c3ref/module.html +c_sqlite3_module._fields_ = [ + ('iVersion', ctypes.c_int), + # int (*xCreate)(sqlite3*, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char**); + ('xCreate', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_p, ctypes.c_void_p, ctypes.c_int, c_char_p_p, c_sqlite3_vtab_p_p, c_char_p_p)), + # int (*xConnect)(sqlite3*, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char**); + ('xConnect', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_p, ctypes.c_void_p, ctypes.c_int, c_char_p_p, c_sqlite3_vtab_p_p, c_char_p_p)), + # int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*); + ('xBestIndex', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_index_info_p)), + # int (*xDisconnect)(sqlite3_vtab *pVTab); + ('xDisconnect', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)), + # int (*xDestroy)(sqlite3_vtab *pVTab); + ('xDestroy', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)), + # int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor); + ('xOpen', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_vtab_cursor_p_p)), + # int (*xClose)(sqlite3_vtab_cursor*); + ('xClose', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)), + # int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr, int argc, sqlite3_value **argv); + ('xFilter', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, c_sqlite3_value_p_p)), + # int (*xNext)(sqlite3_vtab_cursor*); + ('xNext', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)), + # int (*xEof)(sqlite3_vtab_cursor*); + ('xEof', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)), + # int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context*, int); + ('xColumn', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, c_sqlite3_context_p, ctypes.c_int)), + # int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid); + ('xRowid', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, c_int64_p)), + # int (*xUpdate)(sqlite3_vtab *, int, sqlite3_value **, sqlite3_int64 *); + ('xUpdate', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_value_p_p, c_int64_p)), + # int (*xBegin)(sqlite3_vtab *pVTab); + ('xBegin', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)), + # int (*xSync)(sqlite3_vtab *pVTab); + ('xSync', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)), + # int (*xCommit)(sqlite3_vtab *pVTab); + ('xCommit', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)), + # int (*xRollback)(sqlite3_vtab *pVTab); + ('xRollback', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)), + # int (*xFindFunction)(sqlite3_vtab *pVtab, int nArg, const char *zName, void (**pxFunc)(sqlite3_context*,int,sqlite3_value**), void **ppArg); + ('xFindFunction', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int, ctypes.c_char_p, ctypes.CFUNCTYPE(c_void_p_p, c_sqlite3_context_p, ctypes.c_int, c_sqlite3_value_p_p), c_void_p_p)), + # int (*xRename)(sqlite3_vtab *pVtab, const char *zNew); + ('xRename', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_char_p)), + # The methods above are in version 1 of the sqlite_module object. Those below are for version 2 and greater. + # int (*xSavepoint)(sqlite3_vtab *pVTab, int); + ('xSavepoint', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)), + # int (*xRelease)(sqlite3_vtab *pVTab, int); + ('xRelease', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)), + # int (*xRollbackTo)(sqlite3_vtab *pVTab, int); + ('xRollbackTo', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)), + # The methods above are in versions 1 and 2 of the sqlite_module object. Those below are for version 3 and greater. + # int (*xShadowName)(const char*); + ('xShadowName', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p)), +] +c_sqlite3_module._types_ = namedtuple('c_sqlite3_module', (k for k, v in c_sqlite3_module._fields_))(*(v for k, v in c_sqlite3_module._fields_)) + +# Reference counting +import copy +pythonapi = copy.deepcopy(ctypes.pythonapi) +Py_IncRef = pythonapi.Py_IncRef +Py_DecRef = pythonapi.Py_DecRef +Py_IncRef.argtypes = [ctypes.py_object] +Py_DecRef.argtypes = [ctypes.py_object] + +try: + import ctypes.util + libsqlite3 = ctypes.cdll.LoadLibrary(ctypes.util.find_library('sqlite3')) +except TypeError: + import os + libsqlite3 = ctypes.cdll.LoadLibrary(os.path.join(os.path.dirname(sys.executable), 'DLLs', 'sqlite3')) + +# https://news.ycombinator.com/item?id=22030336 +#sqlite = ctypes.CDLL(ctypes.util.find_library('sqlite3')) + +# bind value to SQLite statement +libsqlite3.sqlite3_bind_int64.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_int64) +libsqlite3.sqlite3_bind_double.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_double) +map_sqlite3_bind = { + type(0): libsqlite3.sqlite3_bind_int64, + type(0.0): libsqlite3.sqlite3_bind_double, + type(''): lambda stmt, i, value: libsqlite3.sqlite3_bind_text(stmt, i, value.encode('utf-8'), len(value.encode('utf-8')), SQLITE_TRANSIENT), + type(b''): lambda stmt, i, value: libsqlite3.sqlite3_bind_blob(stmt, i, value, len(value), SQLITE_TRANSIENT), + type(None): lambda stmt, i, _: libsqlite3.sqlite3_bind_null(stmt, i), +} +if sys.maxsize <= maxsize32: + map_sqlite3_bind[type(0)] = lambda stmt, i, value: libsqlite3.sqlite3_bind_int(stmt, i, value) if value <= maxsize32 else libsqlite3.sqlite3_bind_int64(stmt, i, value) + +# extract SQLite column value +libsqlite3.sqlite3_column_name.restype = ctypes.c_char_p +libsqlite3.sqlite3_column_type.restype = ctypes.c_int +libsqlite3.sqlite3_column_int64.restype = ctypes.c_int64 +libsqlite3.sqlite3_column_double.restype = ctypes.c_double +libsqlite3.sqlite3_column_blob.restype = ctypes.c_void_p +libsqlite3.sqlite3_column_bytes.restype = ctypes.c_int64 +map_sqlite3_column = { + SQLITE_INTEGER: libsqlite3.sqlite3_column_int64, + SQLITE_FLOAT: libsqlite3.sqlite3_column_double, + SQLITE_TEXT: lambda stmt, i: ctypes.string_at( + libsqlite3.sqlite3_column_blob(stmt, i), + libsqlite3.sqlite3_column_bytes(stmt, i), + ).decode(), + SQLITE_BLOB: lambda stmt, i: ctypes.string_at( + libsqlite3.sqlite3_column_blob(stmt, i), + libsqlite3.sqlite3_column_bytes(stmt, i), + ), + SQLITE_NULL: lambda stmt, i: None, +} + +# set SQLite result +libsqlite3.sqlite3_result_int64.argtypes = (c_sqlite3_context_p, ctypes.c_int64) +libsqlite3.sqlite3_result_double.argtypes = (c_sqlite3_context_p, ctypes.c_double) +libsqlite3.sqlite3_result_text64.argtypes = (c_sqlite3_context_p, ctypes.c_char_p, ctypes.c_uint64, ctypes.c_void_p, ctypes.c_ubyte) +libsqlite3.sqlite3_result_blob64.argtypes = (c_sqlite3_context_p, ctypes.c_void_p, ctypes.c_uint64, ctypes.c_void_p) +libsqlite3.sqlite3_result_null.argtypes = (c_sqlite3_context_p,) +map_sqlite3_result = { + type(0): libsqlite3.sqlite3_result_int64, + type(0.0): libsqlite3.sqlite3_result_double, + type(''): lambda ctx, value: libsqlite3.sqlite3_result_text64(ctx, s := value.encode('utf-8'), len(s), SQLITE_TRANSIENT, SQLITE_UTF8), + type(b''): lambda ctx, value: libsqlite3.sqlite3_result_blob64(ctx, value, len(value), -1), + type(None): lambda ctx, _: libsqlite3.sqlite3_result_null(ctx), +} +if sys.maxsize <= maxsize32: + map_sqlite3_result[type(0)] = lambda ctx, value: libsqlite3.sqlite3_result_int(ctx, value) if value <= maxsize32 else libsqlite3.sqlite3_result_int64(ctx, value) + +# extract SQLite value +libsqlite3.sqlite3_value_type.restype = ctypes.c_int +libsqlite3.sqlite3_value_type.argtypes = (c_sqlite3_value_p,) +libsqlite3.sqlite3_value_int64.restype = ctypes.c_int64 +libsqlite3.sqlite3_value_int64.argtypes = (c_sqlite3_value_p,) +libsqlite3.sqlite3_value_double.restype = ctypes.c_double +libsqlite3.sqlite3_value_double.argtypes = (c_sqlite3_value_p,) +libsqlite3.sqlite3_value_blob.restype = ctypes.c_void_p +libsqlite3.sqlite3_value_blob.argtypes = (c_sqlite3_value_p,) +libsqlite3.sqlite3_value_bytes.restype = ctypes.c_int64 +libsqlite3.sqlite3_value_bytes.argtypes = (c_sqlite3_value_p,) +map_sqlite3_value = { + SQLITE_INTEGER: libsqlite3.sqlite3_value_int64, + SQLITE_FLOAT: libsqlite3.sqlite3_value_double, + SQLITE_TEXT: lambda x: ctypes.string_at( + libsqlite3.sqlite3_value_blob(x), + libsqlite3.sqlite3_value_bytes(x), + ).decode(), + SQLITE_BLOB: lambda x: ctypes.string_at( + libsqlite3.sqlite3_value_blob(x), + libsqlite3.sqlite3_value_bytes(x), + ), + SQLITE_NULL: lambda x: None, +} + +# Query execution + +# https://gist.github.com/michalc/a3147997e21665896836e0f4157975cb +libsqlite3.sqlite3_errstr.restype = ctypes.c_char_p +libsqlite3.sqlite3_errmsg.restype = ctypes.c_char_p +libsqlite3.sqlite3_declare_vtab.restype = ctypes.c_int +libsqlite3.sqlite3_malloc.restype = ctypes.c_void_p +libsqlite3.sqlite3_create_module.restype = ctypes.c_int + +def connect(database, allownew=True, throw=True): + """Retrieves raw connection or opens new connection to SQLite database.""" + res = None + if isinstance(database, c_sqlite3_p): + res = database + elif isinstance(database, sqlite3.Connection): + res = c_pysqlite_connection.from_address(id(database)).db + elif isinstance(database, str) and allownew: + pDB = c_sqlite3_p() + callAPI(libsqlite3.sqlite3_open_v2, database.encode('utf-8'), ctypes.byref(pDB), SQLITE_OPEN_READWRITE, None) + res = pDB + if res: + return res + elif throw: + raise IOError('Invalid database') + +def callAPI(func, *args, pDB=None): + """Calls SQLite3 API with exceptions.""" + res = func(*args) + if res != 0: + if pDB is None: + raise sqlite3.OperationalError(libsqlite3.sqlite3_errstr(res).decode()) + else: + raise sqlite3.OperationalError(libsqlite3.sqlite3_errmsg(pDB).decode()) + +@contextlib.contextmanager +def connection(database, allownew=True): + """Manages connection to SQLite database.""" + pDB = connect(database, allownew=allownew) + try: + yield pDB + finally: + if isinstance(database, sqlite3.Connection): + database.close() + else: + callAPI(libsqlite3.sqlite3_close, pDB, pDB=pDB) + +@contextlib.contextmanager +def prepared_statement(pDB, sql): + stmt = ctypes.c_void_p() + callAPI(libsqlite3.sqlite3_prepare_v3, pDB, sql.encode('utf-8'), -1, 0, ctypes.byref(stmt), None, pDB=pDB) + try: + yield stmt + finally: + callAPI(libsqlite3.sqlite3_finalize, stmt, pDB=pDB) + +def execute(pDB, sql, parameters=()): + with prepared_statement(pDB, sql) as stmt: + for i, param in enumerate(parameters, start=1): + callAPI(map_sqlite3_bind[type(param)], stmt, i, param, pDB=pDB) + row_factory = namedtuple('Row', ( + libsqlite3.sqlite3_column_name(stmt, i).decode() for i in range(libsqlite3.sqlite3_column_count(stmt)) + )) + while True: + res = libsqlite3.sqlite3_step(stmt) + if res == SQLITE_ROW: + yield row_factory(*( + map_sqlite3_column[libsqlite3.sqlite3_column_type(stmt, i)](stmt, i) for i in range(0, len(row_factory._fields)) + )) + elif res == SQLITE_DONE: + break + else: + raise sqlite3.OperationalError(libsqlite3.sqlite3_errstr(res).decode()) + +__all__ = ['connect', 'execute'] diff --git a/test.py b/test.py new file mode 100644 index 0000000..792d00f --- /dev/null +++ b/test.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +from __init__ import * + +class RangeIterable(VTableIterable): + def __init__(self, start=0, stop=float('inf'), step=1): + self.start = start + self.stop = stop + self.step = step + self.curr = self.start + def __iter__(self): + print(self.__class__.__name__, self) + return self + def __next__(self): + if self.curr <= self.stop: + res = self.curr + self.curr += self.step + return (res,) + else: + raise StopIteration + +class RangeIterableGenerator(VTableIterable): + def __init__(self, start=0, stop=float('inf'), step=1): + self.start = start + self.stop = stop + self.step = step + def __iter__(self): + print(self.__class__.__name__, self) + curr = self.start + while curr <= self.stop: + yield (curr,) + curr += self.step + +class RangeIterableMember(object): + def __init__(self, extra): + self.extra = extra + @vtab_iterable(name='RangeIterableMember') + def generator(self, start=0, stop=float('inf'), step=1): + print(self.__class__.__name__, self, self.generator, self.extra) + for curr in range(start, stop, step): + yield (curr,) + +class RangeIterableFancy(VTableIterable): + _vtname_ = 'RangeIterableRenamed' + _vtparams_ = ('start', 'stop', 'step', 'extra') + _vtcolumns_ = ('once', 'twice', 'thrice') + def __init__(self, start=0, stop=float('inf'), step=1, **kw): + self.start = start + self.stop = stop + self.step = step + self.kw = kw + def __iter__(self): + print(self.__class__.__name__, self, self.kw) + curr = self.start + while curr <= self.stop: + yield (curr, curr*2, curr*3) + curr += self.step + +@vtab_iterable +def RangeGenerator(start=0, stop=float('inf'), step=1): + print(RangeGenerator.__name__, RangeGenerator) + for curr in range(start, stop, step): + yield (curr,) + +@vtab_iterable(name='RangeGeneratorRenamed', params=('start', 'stop', 'step', 'extra'), columns=('once', 'twice', 'thrice')) +def RangeGeneratorFancy(start=0, stop=float('inf'), step=1, **kw): + print(RangeGeneratorFancy.__name__, RangeGeneratorFancy, kw) + for curr in range(start, stop, step): + yield (curr, curr*2, curr*3) + +if __name__ == '__main__': + import sqlite3 + + conn = sqlite3.connect(':memory:') + pDB = connect(conn) + + conn.execute('CREATE TABLE test (a INTEGER, b INTEGER)') + conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (1, 2)) + conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (3, 4)) + conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (5, 6)) + conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (7, 8)) + + for row in conn.execute('SELECT * FROM test'): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM test'): + print('tricks:', row) + print() + + for row in conn.execute('SELECT * FROM test WHERE a = ?', (1,)): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM test WHERE a = ?', (1,)): + print('tricks:', row) + print() + + RangeIterable.register(conn) + for row in conn.execute('SELECT * FROM RangeIterable(10, 16, 2)'): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM RangeIterable(10, 16, 2)'): + print('tricks:', row) + print() + + RangeIterableGenerator.register(conn) + for row in conn.execute('SELECT * FROM RangeIterableGenerator(20, 26, 2)'): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM RangeIterableGenerator(20, 26, 2)'): + print('tricks:', row) + print() + + obj = RangeIterableMember('hello') + obj.generator.register(conn) + for row in conn.execute('SELECT * FROM RangeIterableMember(30, 36, 2)'): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM RangeIterableMember(30, 36, 2)'): + print('tricks:', row) + print() + + RangeIterableFancy.register(conn) + for row in conn.execute('SELECT * FROM RangeIterableRenamed(40, 46, 2, datetime())'): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM RangeIterableRenamed(40, 46, 2, datetime())'): + print('tricks:', row) + print() + + RangeGenerator.register(conn) + for row in conn.execute('SELECT * FROM RangeGenerator(50, 56, 2)'): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM RangeGenerator(50, 56, 2)'): + print('tricks:', row) + print() + + RangeGeneratorFancy.register(conn) + for row in conn.execute('SELECT * FROM RangeGeneratorRenamed(60, 66, 2, datetime())'): + print('sqlite3:', row) + for row in pDB.execute('SELECT * FROM RangeGeneratorRenamed(60, 66, 2, datetime())'): + print('tricks:', row) + print() + + import hashlib + + async def md5sum(t): + return hashlib.md5(str(t).encode('utf-8')).hexdigest() + + create_function(conn, 'md5', 1, md5sum) + + for row in conn.execute('SELECT md5(value) FROM RangeIterable(10, 16, 2)'): + print(row) + for row in pDB.execute('SELECT md5(value) AS md5_value FROM RangeIterable(10, 16, 2)'): + print(row) diff --git a/vtab.py b/vtab.py new file mode 100644 index 0000000..905bafb --- /dev/null +++ b/vtab.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 + +import sys +import sqlite3 +import ctypes +import functools +import inspect +import traceback + +import direct +from direct import libsqlite3, Py_IncRef, Py_DecRef, c_sqlite3_module, c_sqlite3_value_p_p, map_sqlite3_value, map_sqlite3_result, SQLITE_OK, SQLITE_ERROR, SQLITE_CONSTRAINT, SQLITE_INDEX_CONSTRAINT_EQ + +USE_SQLITE_CONSTRAINT = tuple(map(int, sqlite3.sqlite_version.split('.'))) >= (3, 26, 0) + +@direct.c_sqlite3_module._types_.xConnect +def vtab_xConnect(pDB, pAux, argc, argv, ppVTab, pzErr): # sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr + """ + Invoked whenever a database connection attaches to or reparses a schema. + A virtual table is eponymous if its xCreate method is the exact same function as the xConnect method, or if the xCreate method is NULL. + """ + tableFactory = ctypes.cast(pAux, ctypes.py_object).value + res = libsqlite3.sqlite3_declare_vtab(pDB, f'CREATE TABLE x({vtab_columns_declaration(tableFactory)})'.encode('utf-8')) + if res == SQLITE_OK: + szVTab = ctypes.sizeof(direct.c_sqlite3_vtab) + pVTab = ctypes.cast(libsqlite3.sqlite3_malloc(szVTab), direct.c_sqlite3_vtab_p) + ctypes.memset(pVTab, 0, szVTab) + pVTab.contents.tableFactory = ctypes.cast(pAux, ctypes.py_object) + Py_IncRef(pVTab.contents.tableFactory) + ppVTab[0] = pVTab + return res + +@direct.c_sqlite3_module._types_.xBestIndex +def vtab_xBestIndex(pVTab, pIdxInfo): # sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo + """Used to determine the best way to access the virtual table.""" + tableFactory = pVTab.contents.tableFactory + vIdxInfo = pIdxInfo.contents + nArg = 0 + columns = [] + nParams = len(tableFactory._vtparams_) + nColumns = len(tableFactory._vtcolumns_) + for i in range(vIdxInfo.nConstraint): + vConstraint = vIdxInfo.aConstraint[i] + if vConstraint.usable and vConstraint.op == SQLITE_INDEX_CONSTRAINT_EQ: + columns.append(tableFactory._vtparams_[vConstraint.iColumn - nColumns]) + nArg += 1 + vIdxInfo.aConstraintUsage[i].argvIndex = nArg + vIdxInfo.aConstraintUsage[i].omit = 1 + if nArg > 0 or nParams == 0: + if nArg == nParams: + # All parameters are present, this is ideal. + vIdxInfo.estimatedCost = 1.0 + vIdxInfo.estimatedRows = 10 + else: + # Penalize score based on number of missing params. + vIdxInfo.estimatedCost = 1e13 * (nParams - nArg) + vIdxInfo.estimatedRows = 10 ** (nParams - nArg) + # Store a reference to the columns in the index info structure. + joinedCols = ','.join(columns).encode('utf-8') + idxStr = libsqlite3.sqlite3_malloc((len(joinedCols) + 1) * ctypes.sizeof(ctypes.c_char)) + idxStr = (ctypes.c_char*(len(joinedCols) + 1)).from_address(idxStr) + ctypes.memmove(idxStr, joinedCols, len(joinedCols)) # memcpy + idxStr[len(joinedCols)] = b'\x00' + vIdxInfo.idxStr = ctypes.cast(idxStr, ctypes.c_char_p) + vIdxInfo.needToFreeIdxStr = 0 + elif USE_SQLITE_CONSTRAINT: + return SQLITE_CONSTRAINT + else: + vIdxInfo.estimatedCost = DBL_MAX + vIdxInfo.estimatedRows = 100000 + return SQLITE_OK + +@direct.c_sqlite3_module._types_.xDisconnect +def vtab_xDisconnect(pVTab): # sqlite3_vtab *pVTab + """Releases a connection to a virtual table.""" + Py_DecRef(pVTab.contents.tableFactory) + libsqlite3.sqlite3_free(pVTab) + return SQLITE_OK + +@direct.c_sqlite3_module._types_.xOpen +def vtab_xOpen(pVTab, ppCursor): # sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor + """Creates a new cursor used for accessing (read and/or writing) a virtual table.""" +# tableFactory = pVTab.contents.tableFactory + szCursor = ctypes.sizeof(direct.c_sqlite3_vtab_cursor) + pCursor = ctypes.cast(libsqlite3.sqlite3_malloc(szCursor), direct.c_sqlite3_vtab_cursor_p) + ctypes.memset(pCursor, 0, szCursor) + vCursor = pCursor.contents + vCursor.tableFactory = pVTab.contents.tableFactory + Py_IncRef(vCursor.tableFactory) + vCursor.rowData = None + vCursor.rowIdx = 0 + ppCursor[0] = pCursor + return SQLITE_OK + +@direct.c_sqlite3_module._types_.xClose +def vtab_xClose(pCursor): # sqlite3_vtab_cursor *pCursor + """Closes a cursor previously opened by xOpen.""" + Py_DecRef(pCursor.contents.tableFactory) + libsqlite3.sqlite3_free(pCursor) + return SQLITE_OK + +def vtab_iterate(vCursor): + if vCursor.rowData: + Py_DecRef(vCursor.rowData) + vCursor.rowData = None + try: + vCursor.rowData = next(vCursor.tableIterator) + Py_IncRef(vCursor.rowData) + except StopIteration: + Py_DecRef(vCursor.tableIterator) + vCursor.tableIterator = None + except: + traceback.print_exc() + return SQLITE_ERROR + else: + vCursor.rowIdx += 1 + return SQLITE_OK + +@direct.c_sqlite3_module._types_.xFilter +def vtab_xFilter(pCursor, idxNum, idxStr, argc, argv): # sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv + """Begins a search of a virtual table.""" + vCursor = pCursor.contents + tableFactory = vCursor.tableFactory + idxStr = ctypes.cast(idxStr, ctypes.c_char_p).value # avoid null terminator + if not idxStr or argc == 0 and len(tableFactory._vtparams_): + return SQLITE_ERROR + elif len(idxStr): + params = idxStr.decode().split(',') + else: + params = [] + py_args = tuple(map_sqlite3_value[libsqlite3.sqlite3_value_type(arg := argv[i])](arg) for i in range(argc)) + query = dict(zip(params, py_args)) + try: + tableIterator = iter(tableFactory(**query)) + except: + traceback.print_exc() + return SQLITE_ERROR + vCursor.tableIterator = tableIterator + Py_IncRef(vCursor.tableIterator) + return vtab_iterate(vCursor) + +@direct.c_sqlite3_module._types_.xNext +def vtab_xNext(pCursor): # sqlite3_vtab_cursor *pCursor + """Advances a virtual table cursor to the next row of a result set initiated by xFilter.""" + return vtab_iterate(pCursor.contents) + +@direct.c_sqlite3_module._types_.xEof +def vtab_xEof(pCursor): # sqlite3_vtab_cursor *pCursor + """Returns false (zero) if the specified cursor currently points to a valid row of data, or true (non-zero) otherwise.""" + return 1 if pCursor.contents.tableIterator is None else 0 + +@direct.c_sqlite3_module._types_.xColumn +def vtab_xColumn(pCursor, pContext, iCol): # sqlite3_vtab_cursor *pCursor, sqlite3_context *pContext, int iCol + """Invoked to retrieve the N-th column of the current row.""" + if iCol == -1: + libsqlite3.sqlite3_result_int64(pContext, pCursor.contents.rowIdx) + return SQLITE_OK + if not pCursor.contents.rowData: + libsqlite3.sqlite3_result_error(pContext, 'no row data'.encode('utf-8'), -1) + return SQLITE_ERROR + map_sqlite3_result[type(value := pCursor.contents.rowData[iCol])](pContext, value) + return SQLITE_OK + +@direct.c_sqlite3_module._types_.xRowid +def vtab_xRowid(pCursor, pRowid): # sqlite3_vtab_cursor *pCursor, sqlite3_int64 *pRowid + """Fills *pRowid with the rowid of the current row.""" + pRowid[0] = pCursor.contents.rowIdx + +def new_sqlite3_module_vtabfunc(): + module = direct.c_sqlite3_module() + ctypes.memset(ctypes.byref(module), 0, ctypes.sizeof(direct.c_sqlite3_module)) + module.iVersion = 0 + #module.xCreate = direct.c_sqlite3_module._types_.xCreate(0) + module.xConnect = vtab_xConnect + module.xBestIndex = vtab_xBestIndex + module.xDisconnect = vtab_xDisconnect + #module.xDestroy = direct.c_sqlite3_module._types_.xDestroy(0) + module.xOpen = vtab_xOpen + module.xClose = vtab_xClose + module.xFilter = vtab_xFilter + module.xNext = vtab_xNext + module.xEof = vtab_xEof + module.xColumn = vtab_xColumn + module.xRowid = vtab_xRowid + #module.xUpdate = direct.c_sqlite3_module._types_.xUpdate(0) + #module.xBegin = direct.c_sqlite3_module._types_.xBegin(0) + #module.xSync = direct.c_sqlite3_module._types_.xSync(0) + #module.xCommit = direct.c_sqlite3_module._types_.xCommit(0) + #module.xRollback = direct.c_sqlite3_module._types_.xRollback(0) + #module.xFindFunction = direct.c_sqlite3_module._types_.xFindFunction(0) + #module.xRename = direct.c_sqlite3_module._types_.xRename(0) + return module + +def vtab_columns_declaration(factory): + acc = [] + for column in factory._vtcolumns_: + if isinstance(column, str): + acc.append(column) + elif isinstance(column, tuple) and len(column) != 2: + acc.append('%s %s'%column) + else: + raise ValueError('Column must be either a string or a 2-tuple of name, type') + for param in factory._vtparams_: + acc.append('%s HIDDEN'%param) + return ', '.join(acc) + +class VTableIterable(object): + @classmethod + def register(cls, database, name=None): + if not cls._vtparams_: + cls._vtparams_ = tuple(k for k in inspect.signature(cls).parameters) + if not cls._vtcolumns_: + cls._vtcolumns_ = ('value',) + return libsqlite3.sqlite3_create_module(direct.connect(database), (name or cls._vtname_ or cls.__name__).encode('utf-8'), ctypes.byref(cls._vtmod_), ctypes.py_object(cls)) + @classmethod + def unregister(cls, database, name=None): + return libsqlite3.sqlite3_create_module(direct.connect(database), (name or cls._vtname_ or cls.__name__).encode('utf-8'), None, None) + def __iter__(self): + raise NotImplementedError + def __next__(self): + raise NotImplementedError + _vtmod_ = new_sqlite3_module_vtabfunc() + _vtname_ = None + _vtparams_ = None + _vtcolumns_ = None + +def vtab_wrap(dst, src, shift=0, name=None, params=None, columns=None): + dst._vtname_ = name or getattr(src, '_vtname_', None) or src.__name__ + dst._vtparams_ = params or getattr(src, '_vtparams_', None) or tuple(inspect.signature(src).parameters)[shift:] + dst._vtcolumns_ = columns or getattr(src, '_vtcolumns_', None) or ('value',) + dst.register = lambda database, name=None: libsqlite3.sqlite3_create_module(direct.connect(database), (name or dst._vtname_).encode('utf-8'), ctypes.byref(VTableIterable._vtmod_), ctypes.py_object(dst)) + dst.unregister = getattr(src, 'unregister', None) or (lambda database, name=None: libsqlite3.sqlite3_create_module(direct.connect(database), (name or dst._vtname_).encode('utf-8'), None, None)) + dst.__doc__ = src.__doc__ + dst.__name__ = src.__name__ + return dst + +class vtab_descriptor(object): + def __init__(self, func, name=None, params=None, columns=None): + self._vtfunc_ = func + vtab_wrap(self, func, 1, name=name, params=params, columns=columns) + def __call__(self, *args, **kwargs): + return self._vtfunc_(*args, **kwargs) + def __get__(self, obj, cls): + return vtab_wrap(functools.partial(self._vtfunc_, obj), self) + +def vtab_iterable(func=None, *, name=None, params=None, columns=None, member=False): + def decorator(src): + try: + if member or inspect.getfullargspec(src)[0][0] in {'self', 'cls'}: + return vtab_descriptor(src, name=name, params=params, columns=columns) + except IndexError: + pass + def proxy(*args, **kw): + return src(*args, **kw) + return vtab_wrap(proxy, src, name=name, params=params, columns=columns) + if func is None: + return decorator + else: + return decorator(func) + +__all__ = ['VTableIterable', 'vtab_iterable']