Initial commit

This commit is contained in:
Jiang Yio 2020-03-30 08:00:00 -05:00
commit 54b6bf1467
7 changed files with 968 additions and 0 deletions

154
.gitignore vendored Normal file
View File

@ -0,0 +1,154 @@
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

3
README.md Normal file
View File

@ -0,0 +1,3 @@
# sqlite3tricks-py
Teach Python's SQLite3 library new tricks using ctypes.

5
__init__.py Normal file
View File

@ -0,0 +1,5 @@
#!/usr/bin/env python3
from direct import connect, execute
from vtab import VTableIterable, vtab_iterable
from async_function import create_function

29
async_function.py Normal file
View File

@ -0,0 +1,29 @@
#!/usr/bin/env python3
import sys
import asyncio
def get_event_loop():
if get_event_loop.loop:
return get_event_loop.loop
try:
loop = asyncio.get_running_loop()
except RuntimeError:
if sys.platform != 'win32':
loop = asyncio.new_event_loop()
else:
loop = asyncio.ProactorEventLoop()
get_event_loop.loop = loop
return loop
get_event_loop.loop = None
def create_function(conn, name, num_params, func, deterministic=False):
original_create_function = conn._create_function if hasattr(conn, '_create_function') else conn.create_function
if conn.create_function == create_function:
fn = conn._create_function
if asyncio.iscoroutinefunction(func):
return original_create_function(name, num_params, (lambda *args: get_event_loop().run_until_complete(func(*args))), deterministic=deterministic)
else:
return original_create_function(name, num_params, func, deterministic=deterministic)
__all__ = ['create_function']

369
direct.py Normal file
View File

@ -0,0 +1,369 @@
#!/usr/bin/env python3
import sys
import sqlite3
import ctypes
import contextlib
from collections import namedtuple
maxsize32 = 2**31 - 1
# Flags for file open operations
# https://www.sqlite.org/c3ref/c_open_autoproxy.html
SQLITE_OPEN_READWRITE = 0x00000002
# Constants defining special destructor behavior
# https://www.sqlite.org/c3ref/c_static.html
SQLITE_TRANSIENT = -1
# Virtual table constraint operator codes
# https://www.sqlite.org/c3ref/c_index_constraint_eq.html
SQLITE_INDEX_CONSTRAINT_EQ = 2
# Result codes
# https://www.sqlite.org/rescode.html
SQLITE_OK = 0
SQLITE_ERROR = 1
SQLITE_CONSTRAINT = 19
SQLITE_ROW = 100
SQLITE_DONE = 101
# Fundamental datatypes
# https://www.sqlite.org/c3ref/c_blob.html
SQLITE_INTEGER = 1
SQLITE_FLOAT = 2
SQLITE_TEXT = 3
SQLITE_BLOB = 4
SQLITE_NULL = 5
# Custom pointer types
c_char_p_p = ctypes.POINTER(ctypes.c_char_p)
c_int64_p = ctypes.POINTER(ctypes.c_int64)
c_void_p_p = ctypes.POINTER(ctypes.c_void_p)
class c_PyObject(ctypes.Structure):
"""All Python object types are extensions of this type."""
# https://stackoverflow.com/questions/35438103/referencing-a-pointer-of-array-in-a-struct-with-ctypes
_fields_ = [
('ob_refcnt', ctypes.c_ssize_t),
('ob_type', ctypes.c_void_p),
]
class c_sqlite3_p(ctypes.c_void_p):
def execute(self, sql, parameters=()):
return execute(self, sql, parameters)
class c_pysqlite_connection(ctypes.Structure):
"""Just enough struct to get at the pointer to the SQLite db
Defined at https://github.com/python/cpython/blob/main/Modules/_sqlite/connection.h
"""
_fields_ = [
('ob_base', c_PyObject),
('db', c_sqlite3_p),
]
class c_sqlite3_module(ctypes.Structure): pass # forward declaration
c_sqlite3_module_p = ctypes.POINTER(c_sqlite3_module)
class c_sqlite3_vtab(ctypes.Structure):
# https://www.sqlite.org/c3ref/vtab.html
_fields_ = [
('pModule', c_sqlite3_module_p),
('nRef', ctypes.c_int),
('zErrMsg', ctypes.c_char_p),
# Custom fields
('tableFactory', ctypes.py_object),
]
c_sqlite3_vtab_p = ctypes.POINTER(c_sqlite3_vtab)
c_sqlite3_vtab_p_p = ctypes.POINTER(c_sqlite3_vtab_p)
class c_sqlite3_index_constraint(ctypes.Structure):
_fields_ = [
('iColumn', ctypes.c_int),
('op', ctypes.c_ubyte),
('usable', ctypes.c_ubyte),
('iTermOffset', ctypes.c_int),
]
c_sqlite3_index_constraint_p = ctypes.POINTER(c_sqlite3_index_constraint)
class c_sqlite3_index_orderby(ctypes.Structure):
_fields_ = [
('iColumn', ctypes.c_int),
('desc', ctypes.c_ubyte),
]
c_sqlite3_index_orderby_p = ctypes.POINTER(c_sqlite3_index_orderby)
class c_sqlite3_index_constraint_usage(ctypes.Structure):
_fields_ = [
('argvIndex', ctypes.c_int),
('omit', ctypes.c_ubyte),
]
c_sqlite3_index_constraint_usage_p = ctypes.POINTER(c_sqlite3_index_constraint_usage)
class c_sqlite3_index_info(ctypes.Structure):
# https://www.sqlite.org/c3ref/index_info.html
_fields_ = [
# Inputs
('nConstraint', ctypes.c_int),
('aConstraint', c_sqlite3_index_constraint_p),
('nOrderBy', ctypes.c_int),
('aOrderBy', c_sqlite3_index_orderby_p),
# Outputs
('aConstraintUsage', c_sqlite3_index_constraint_usage_p),
('idxNum', ctypes.c_int),
('idxStr', ctypes.c_char_p),
('needToFreeIdxStr', ctypes.c_int),
('orderByConsumed', ctypes.c_int),
('estimatedCost', ctypes.c_double),
# Fields below are only available in SQLite 3.8.2 and later
('estimatedRows', ctypes.c_int64),
# Fields below are only available in SQLite 3.9.0 and later
('idxFlags', ctypes.c_int),
# Fields below are only available in SQLite 3.10.0 and later
('colUsed', ctypes.c_uint64),
]
c_sqlite3_index_info_p = ctypes.POINTER(c_sqlite3_index_info)
class c_sqlite3_vtab_cursor(ctypes.Structure):
# https://www.sqlite.org/c3ref/vtab_cursor.html
_fields_ = [
('pVtab', c_sqlite3_vtab_p),
# Custom fields
('tableFactory', ctypes.py_object),
('tableIterator', ctypes.py_object),
('rowData', ctypes.py_object),
('rowIdx', ctypes.c_int64),
]
c_sqlite3_vtab_cursor_p = ctypes.POINTER(c_sqlite3_vtab_cursor)
c_sqlite3_vtab_cursor_p_p = ctypes.POINTER(c_sqlite3_vtab_cursor_p)
class c_sqlite3_context(ctypes.Structure):
pass
c_sqlite3_context_p = ctypes.POINTER(c_sqlite3_context)
class c_sqlite3_value(ctypes.Structure):
pass
c_sqlite3_value_p = ctypes.POINTER(c_sqlite3_value)
c_sqlite3_value_p_p = ctypes.POINTER(c_sqlite3_value_p)
# https://www.sqlite.org/c3ref/module.html
c_sqlite3_module._fields_ = [
('iVersion', ctypes.c_int),
# int (*xCreate)(sqlite3*, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char**);
('xCreate', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_p, ctypes.c_void_p, ctypes.c_int, c_char_p_p, c_sqlite3_vtab_p_p, c_char_p_p)),
# int (*xConnect)(sqlite3*, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char**);
('xConnect', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_p, ctypes.c_void_p, ctypes.c_int, c_char_p_p, c_sqlite3_vtab_p_p, c_char_p_p)),
# int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*);
('xBestIndex', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_index_info_p)),
# int (*xDisconnect)(sqlite3_vtab *pVTab);
('xDisconnect', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
# int (*xDestroy)(sqlite3_vtab *pVTab);
('xDestroy', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
# int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor);
('xOpen', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_vtab_cursor_p_p)),
# int (*xClose)(sqlite3_vtab_cursor*);
('xClose', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)),
# int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr, int argc, sqlite3_value **argv);
('xFilter', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, c_sqlite3_value_p_p)),
# int (*xNext)(sqlite3_vtab_cursor*);
('xNext', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)),
# int (*xEof)(sqlite3_vtab_cursor*);
('xEof', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)),
# int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context*, int);
('xColumn', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, c_sqlite3_context_p, ctypes.c_int)),
# int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid);
('xRowid', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, c_int64_p)),
# int (*xUpdate)(sqlite3_vtab *, int, sqlite3_value **, sqlite3_int64 *);
('xUpdate', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_value_p_p, c_int64_p)),
# int (*xBegin)(sqlite3_vtab *pVTab);
('xBegin', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
# int (*xSync)(sqlite3_vtab *pVTab);
('xSync', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
# int (*xCommit)(sqlite3_vtab *pVTab);
('xCommit', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
# int (*xRollback)(sqlite3_vtab *pVTab);
('xRollback', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
# int (*xFindFunction)(sqlite3_vtab *pVtab, int nArg, const char *zName, void (**pxFunc)(sqlite3_context*,int,sqlite3_value**), void **ppArg);
('xFindFunction', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int, ctypes.c_char_p, ctypes.CFUNCTYPE(c_void_p_p, c_sqlite3_context_p, ctypes.c_int, c_sqlite3_value_p_p), c_void_p_p)),
# int (*xRename)(sqlite3_vtab *pVtab, const char *zNew);
('xRename', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_char_p)),
# The methods above are in version 1 of the sqlite_module object. Those below are for version 2 and greater.
# int (*xSavepoint)(sqlite3_vtab *pVTab, int);
('xSavepoint', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)),
# int (*xRelease)(sqlite3_vtab *pVTab, int);
('xRelease', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)),
# int (*xRollbackTo)(sqlite3_vtab *pVTab, int);
('xRollbackTo', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)),
# The methods above are in versions 1 and 2 of the sqlite_module object. Those below are for version 3 and greater.
# int (*xShadowName)(const char*);
('xShadowName', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p)),
]
c_sqlite3_module._types_ = namedtuple('c_sqlite3_module', (k for k, v in c_sqlite3_module._fields_))(*(v for k, v in c_sqlite3_module._fields_))
# Reference counting
import copy
pythonapi = copy.deepcopy(ctypes.pythonapi)
Py_IncRef = pythonapi.Py_IncRef
Py_DecRef = pythonapi.Py_DecRef
Py_IncRef.argtypes = [ctypes.py_object]
Py_DecRef.argtypes = [ctypes.py_object]
try:
import ctypes.util
libsqlite3 = ctypes.cdll.LoadLibrary(ctypes.util.find_library('sqlite3'))
except TypeError:
import os
libsqlite3 = ctypes.cdll.LoadLibrary(os.path.join(os.path.dirname(sys.executable), 'DLLs', 'sqlite3'))
# https://news.ycombinator.com/item?id=22030336
#sqlite = ctypes.CDLL(ctypes.util.find_library('sqlite3'))
# bind value to SQLite statement
libsqlite3.sqlite3_bind_int64.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_int64)
libsqlite3.sqlite3_bind_double.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_double)
map_sqlite3_bind = {
type(0): libsqlite3.sqlite3_bind_int64,
type(0.0): libsqlite3.sqlite3_bind_double,
type(''): lambda stmt, i, value: libsqlite3.sqlite3_bind_text(stmt, i, value.encode('utf-8'), len(value.encode('utf-8')), SQLITE_TRANSIENT),
type(b''): lambda stmt, i, value: libsqlite3.sqlite3_bind_blob(stmt, i, value, len(value), SQLITE_TRANSIENT),
type(None): lambda stmt, i, _: libsqlite3.sqlite3_bind_null(stmt, i),
}
if sys.maxsize <= maxsize32:
map_sqlite3_bind[type(0)] = lambda stmt, i, value: libsqlite3.sqlite3_bind_int(stmt, i, value) if value <= maxsize32 else libsqlite3.sqlite3_bind_int64(stmt, i, value)
# extract SQLite column value
libsqlite3.sqlite3_column_name.restype = ctypes.c_char_p
libsqlite3.sqlite3_column_type.restype = ctypes.c_int
libsqlite3.sqlite3_column_int64.restype = ctypes.c_int64
libsqlite3.sqlite3_column_double.restype = ctypes.c_double
libsqlite3.sqlite3_column_blob.restype = ctypes.c_void_p
libsqlite3.sqlite3_column_bytes.restype = ctypes.c_int64
map_sqlite3_column = {
SQLITE_INTEGER: libsqlite3.sqlite3_column_int64,
SQLITE_FLOAT: libsqlite3.sqlite3_column_double,
SQLITE_TEXT: lambda stmt, i: ctypes.string_at(
libsqlite3.sqlite3_column_blob(stmt, i),
libsqlite3.sqlite3_column_bytes(stmt, i),
).decode(),
SQLITE_BLOB: lambda stmt, i: ctypes.string_at(
libsqlite3.sqlite3_column_blob(stmt, i),
libsqlite3.sqlite3_column_bytes(stmt, i),
),
SQLITE_NULL: lambda stmt, i: None,
}
# set SQLite result
libsqlite3.sqlite3_result_int64.argtypes = (c_sqlite3_context_p, ctypes.c_int64)
libsqlite3.sqlite3_result_double.argtypes = (c_sqlite3_context_p, ctypes.c_double)
libsqlite3.sqlite3_result_text64.argtypes = (c_sqlite3_context_p, ctypes.c_char_p, ctypes.c_uint64, ctypes.c_void_p, ctypes.c_ubyte)
libsqlite3.sqlite3_result_blob64.argtypes = (c_sqlite3_context_p, ctypes.c_void_p, ctypes.c_uint64, ctypes.c_void_p)
libsqlite3.sqlite3_result_null.argtypes = (c_sqlite3_context_p,)
map_sqlite3_result = {
type(0): libsqlite3.sqlite3_result_int64,
type(0.0): libsqlite3.sqlite3_result_double,
type(''): lambda ctx, value: libsqlite3.sqlite3_result_text64(ctx, s := value.encode('utf-8'), len(s), SQLITE_TRANSIENT, SQLITE_UTF8),
type(b''): lambda ctx, value: libsqlite3.sqlite3_result_blob64(ctx, value, len(value), -1),
type(None): lambda ctx, _: libsqlite3.sqlite3_result_null(ctx),
}
if sys.maxsize <= maxsize32:
map_sqlite3_result[type(0)] = lambda ctx, value: libsqlite3.sqlite3_result_int(ctx, value) if value <= maxsize32 else libsqlite3.sqlite3_result_int64(ctx, value)
# extract SQLite value
libsqlite3.sqlite3_value_type.restype = ctypes.c_int
libsqlite3.sqlite3_value_type.argtypes = (c_sqlite3_value_p,)
libsqlite3.sqlite3_value_int64.restype = ctypes.c_int64
libsqlite3.sqlite3_value_int64.argtypes = (c_sqlite3_value_p,)
libsqlite3.sqlite3_value_double.restype = ctypes.c_double
libsqlite3.sqlite3_value_double.argtypes = (c_sqlite3_value_p,)
libsqlite3.sqlite3_value_blob.restype = ctypes.c_void_p
libsqlite3.sqlite3_value_blob.argtypes = (c_sqlite3_value_p,)
libsqlite3.sqlite3_value_bytes.restype = ctypes.c_int64
libsqlite3.sqlite3_value_bytes.argtypes = (c_sqlite3_value_p,)
map_sqlite3_value = {
SQLITE_INTEGER: libsqlite3.sqlite3_value_int64,
SQLITE_FLOAT: libsqlite3.sqlite3_value_double,
SQLITE_TEXT: lambda x: ctypes.string_at(
libsqlite3.sqlite3_value_blob(x),
libsqlite3.sqlite3_value_bytes(x),
).decode(),
SQLITE_BLOB: lambda x: ctypes.string_at(
libsqlite3.sqlite3_value_blob(x),
libsqlite3.sqlite3_value_bytes(x),
),
SQLITE_NULL: lambda x: None,
}
# Query execution
# https://gist.github.com/michalc/a3147997e21665896836e0f4157975cb
libsqlite3.sqlite3_errstr.restype = ctypes.c_char_p
libsqlite3.sqlite3_errmsg.restype = ctypes.c_char_p
libsqlite3.sqlite3_declare_vtab.restype = ctypes.c_int
libsqlite3.sqlite3_malloc.restype = ctypes.c_void_p
libsqlite3.sqlite3_create_module.restype = ctypes.c_int
def connect(database, allownew=True, throw=True):
"""Retrieves raw connection or opens new connection to SQLite database."""
res = None
if isinstance(database, c_sqlite3_p):
res = database
elif isinstance(database, sqlite3.Connection):
res = c_pysqlite_connection.from_address(id(database)).db
elif isinstance(database, str) and allownew:
pDB = c_sqlite3_p()
callAPI(libsqlite3.sqlite3_open_v2, database.encode('utf-8'), ctypes.byref(pDB), SQLITE_OPEN_READWRITE, None)
res = pDB
if res:
return res
elif throw:
raise IOError('Invalid database')
def callAPI(func, *args, pDB=None):
"""Calls SQLite3 API with exceptions."""
res = func(*args)
if res != 0:
if pDB is None:
raise sqlite3.OperationalError(libsqlite3.sqlite3_errstr(res).decode())
else:
raise sqlite3.OperationalError(libsqlite3.sqlite3_errmsg(pDB).decode())
@contextlib.contextmanager
def connection(database, allownew=True):
"""Manages connection to SQLite database."""
pDB = connect(database, allownew=allownew)
try:
yield pDB
finally:
if isinstance(database, sqlite3.Connection):
database.close()
else:
callAPI(libsqlite3.sqlite3_close, pDB, pDB=pDB)
@contextlib.contextmanager
def prepared_statement(pDB, sql):
stmt = ctypes.c_void_p()
callAPI(libsqlite3.sqlite3_prepare_v3, pDB, sql.encode('utf-8'), -1, 0, ctypes.byref(stmt), None, pDB=pDB)
try:
yield stmt
finally:
callAPI(libsqlite3.sqlite3_finalize, stmt, pDB=pDB)
def execute(pDB, sql, parameters=()):
with prepared_statement(pDB, sql) as stmt:
for i, param in enumerate(parameters, start=1):
callAPI(map_sqlite3_bind[type(param)], stmt, i, param, pDB=pDB)
row_factory = namedtuple('Row', (
libsqlite3.sqlite3_column_name(stmt, i).decode() for i in range(libsqlite3.sqlite3_column_count(stmt))
))
while True:
res = libsqlite3.sqlite3_step(stmt)
if res == SQLITE_ROW:
yield row_factory(*(
map_sqlite3_column[libsqlite3.sqlite3_column_type(stmt, i)](stmt, i) for i in range(0, len(row_factory._fields))
))
elif res == SQLITE_DONE:
break
else:
raise sqlite3.OperationalError(libsqlite3.sqlite3_errstr(res).decode())
__all__ = ['connect', 'execute']

148
test.py Normal file
View File

@ -0,0 +1,148 @@
#!/usr/bin/env python3
from __init__ import *
class RangeIterable(VTableIterable):
def __init__(self, start=0, stop=float('inf'), step=1):
self.start = start
self.stop = stop
self.step = step
self.curr = self.start
def __iter__(self):
print(self.__class__.__name__, self)
return self
def __next__(self):
if self.curr <= self.stop:
res = self.curr
self.curr += self.step
return (res,)
else:
raise StopIteration
class RangeIterableGenerator(VTableIterable):
def __init__(self, start=0, stop=float('inf'), step=1):
self.start = start
self.stop = stop
self.step = step
def __iter__(self):
print(self.__class__.__name__, self)
curr = self.start
while curr <= self.stop:
yield (curr,)
curr += self.step
class RangeIterableMember(object):
def __init__(self, extra):
self.extra = extra
@vtab_iterable(name='RangeIterableMember')
def generator(self, start=0, stop=float('inf'), step=1):
print(self.__class__.__name__, self, self.generator, self.extra)
for curr in range(start, stop, step):
yield (curr,)
class RangeIterableFancy(VTableIterable):
_vtname_ = 'RangeIterableRenamed'
_vtparams_ = ('start', 'stop', 'step', 'extra')
_vtcolumns_ = ('once', 'twice', 'thrice')
def __init__(self, start=0, stop=float('inf'), step=1, **kw):
self.start = start
self.stop = stop
self.step = step
self.kw = kw
def __iter__(self):
print(self.__class__.__name__, self, self.kw)
curr = self.start
while curr <= self.stop:
yield (curr, curr*2, curr*3)
curr += self.step
@vtab_iterable
def RangeGenerator(start=0, stop=float('inf'), step=1):
print(RangeGenerator.__name__, RangeGenerator)
for curr in range(start, stop, step):
yield (curr,)
@vtab_iterable(name='RangeGeneratorRenamed', params=('start', 'stop', 'step', 'extra'), columns=('once', 'twice', 'thrice'))
def RangeGeneratorFancy(start=0, stop=float('inf'), step=1, **kw):
print(RangeGeneratorFancy.__name__, RangeGeneratorFancy, kw)
for curr in range(start, stop, step):
yield (curr, curr*2, curr*3)
if __name__ == '__main__':
import sqlite3
conn = sqlite3.connect(':memory:')
pDB = connect(conn)
conn.execute('CREATE TABLE test (a INTEGER, b INTEGER)')
conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (1, 2))
conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (3, 4))
conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (5, 6))
conn.execute('INSERT INTO test (a, b) VALUES (?, ?)', (7, 8))
for row in conn.execute('SELECT * FROM test'):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM test'):
print('tricks:', row)
print()
for row in conn.execute('SELECT * FROM test WHERE a = ?', (1,)):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM test WHERE a = ?', (1,)):
print('tricks:', row)
print()
RangeIterable.register(conn)
for row in conn.execute('SELECT * FROM RangeIterable(10, 16, 2)'):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM RangeIterable(10, 16, 2)'):
print('tricks:', row)
print()
RangeIterableGenerator.register(conn)
for row in conn.execute('SELECT * FROM RangeIterableGenerator(20, 26, 2)'):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM RangeIterableGenerator(20, 26, 2)'):
print('tricks:', row)
print()
obj = RangeIterableMember('hello')
obj.generator.register(conn)
for row in conn.execute('SELECT * FROM RangeIterableMember(30, 36, 2)'):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM RangeIterableMember(30, 36, 2)'):
print('tricks:', row)
print()
RangeIterableFancy.register(conn)
for row in conn.execute('SELECT * FROM RangeIterableRenamed(40, 46, 2, datetime())'):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM RangeIterableRenamed(40, 46, 2, datetime())'):
print('tricks:', row)
print()
RangeGenerator.register(conn)
for row in conn.execute('SELECT * FROM RangeGenerator(50, 56, 2)'):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM RangeGenerator(50, 56, 2)'):
print('tricks:', row)
print()
RangeGeneratorFancy.register(conn)
for row in conn.execute('SELECT * FROM RangeGeneratorRenamed(60, 66, 2, datetime())'):
print('sqlite3:', row)
for row in pDB.execute('SELECT * FROM RangeGeneratorRenamed(60, 66, 2, datetime())'):
print('tricks:', row)
print()
import hashlib
async def md5sum(t):
return hashlib.md5(str(t).encode('utf-8')).hexdigest()
create_function(conn, 'md5', 1, md5sum)
for row in conn.execute('SELECT md5(value) FROM RangeIterable(10, 16, 2)'):
print(row)
for row in pDB.execute('SELECT md5(value) AS md5_value FROM RangeIterable(10, 16, 2)'):
print(row)

260
vtab.py Normal file
View File

@ -0,0 +1,260 @@
#!/usr/bin/env python3
import sys
import sqlite3
import ctypes
import functools
import inspect
import traceback
import direct
from direct import libsqlite3, Py_IncRef, Py_DecRef, c_sqlite3_module, c_sqlite3_value_p_p, map_sqlite3_value, map_sqlite3_result, SQLITE_OK, SQLITE_ERROR, SQLITE_CONSTRAINT, SQLITE_INDEX_CONSTRAINT_EQ
USE_SQLITE_CONSTRAINT = tuple(map(int, sqlite3.sqlite_version.split('.'))) >= (3, 26, 0)
@direct.c_sqlite3_module._types_.xConnect
def vtab_xConnect(pDB, pAux, argc, argv, ppVTab, pzErr): # sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr
"""
Invoked whenever a database connection attaches to or reparses a schema.
A virtual table is eponymous if its xCreate method is the exact same function as the xConnect method, or if the xCreate method is NULL.
"""
tableFactory = ctypes.cast(pAux, ctypes.py_object).value
res = libsqlite3.sqlite3_declare_vtab(pDB, f'CREATE TABLE x({vtab_columns_declaration(tableFactory)})'.encode('utf-8'))
if res == SQLITE_OK:
szVTab = ctypes.sizeof(direct.c_sqlite3_vtab)
pVTab = ctypes.cast(libsqlite3.sqlite3_malloc(szVTab), direct.c_sqlite3_vtab_p)
ctypes.memset(pVTab, 0, szVTab)
pVTab.contents.tableFactory = ctypes.cast(pAux, ctypes.py_object)
Py_IncRef(pVTab.contents.tableFactory)
ppVTab[0] = pVTab
return res
@direct.c_sqlite3_module._types_.xBestIndex
def vtab_xBestIndex(pVTab, pIdxInfo): # sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo
"""Used to determine the best way to access the virtual table."""
tableFactory = pVTab.contents.tableFactory
vIdxInfo = pIdxInfo.contents
nArg = 0
columns = []
nParams = len(tableFactory._vtparams_)
nColumns = len(tableFactory._vtcolumns_)
for i in range(vIdxInfo.nConstraint):
vConstraint = vIdxInfo.aConstraint[i]
if vConstraint.usable and vConstraint.op == SQLITE_INDEX_CONSTRAINT_EQ:
columns.append(tableFactory._vtparams_[vConstraint.iColumn - nColumns])
nArg += 1
vIdxInfo.aConstraintUsage[i].argvIndex = nArg
vIdxInfo.aConstraintUsage[i].omit = 1
if nArg > 0 or nParams == 0:
if nArg == nParams:
# All parameters are present, this is ideal.
vIdxInfo.estimatedCost = 1.0
vIdxInfo.estimatedRows = 10
else:
# Penalize score based on number of missing params.
vIdxInfo.estimatedCost = 1e13 * (nParams - nArg)
vIdxInfo.estimatedRows = 10 ** (nParams - nArg)
# Store a reference to the columns in the index info structure.
joinedCols = ','.join(columns).encode('utf-8')
idxStr = libsqlite3.sqlite3_malloc((len(joinedCols) + 1) * ctypes.sizeof(ctypes.c_char))
idxStr = (ctypes.c_char*(len(joinedCols) + 1)).from_address(idxStr)
ctypes.memmove(idxStr, joinedCols, len(joinedCols)) # memcpy
idxStr[len(joinedCols)] = b'\x00'
vIdxInfo.idxStr = ctypes.cast(idxStr, ctypes.c_char_p)
vIdxInfo.needToFreeIdxStr = 0
elif USE_SQLITE_CONSTRAINT:
return SQLITE_CONSTRAINT
else:
vIdxInfo.estimatedCost = DBL_MAX
vIdxInfo.estimatedRows = 100000
return SQLITE_OK
@direct.c_sqlite3_module._types_.xDisconnect
def vtab_xDisconnect(pVTab): # sqlite3_vtab *pVTab
"""Releases a connection to a virtual table."""
Py_DecRef(pVTab.contents.tableFactory)
libsqlite3.sqlite3_free(pVTab)
return SQLITE_OK
@direct.c_sqlite3_module._types_.xOpen
def vtab_xOpen(pVTab, ppCursor): # sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor
"""Creates a new cursor used for accessing (read and/or writing) a virtual table."""
# tableFactory = pVTab.contents.tableFactory
szCursor = ctypes.sizeof(direct.c_sqlite3_vtab_cursor)
pCursor = ctypes.cast(libsqlite3.sqlite3_malloc(szCursor), direct.c_sqlite3_vtab_cursor_p)
ctypes.memset(pCursor, 0, szCursor)
vCursor = pCursor.contents
vCursor.tableFactory = pVTab.contents.tableFactory
Py_IncRef(vCursor.tableFactory)
vCursor.rowData = None
vCursor.rowIdx = 0
ppCursor[0] = pCursor
return SQLITE_OK
@direct.c_sqlite3_module._types_.xClose
def vtab_xClose(pCursor): # sqlite3_vtab_cursor *pCursor
"""Closes a cursor previously opened by xOpen."""
Py_DecRef(pCursor.contents.tableFactory)
libsqlite3.sqlite3_free(pCursor)
return SQLITE_OK
def vtab_iterate(vCursor):
if vCursor.rowData:
Py_DecRef(vCursor.rowData)
vCursor.rowData = None
try:
vCursor.rowData = next(vCursor.tableIterator)
Py_IncRef(vCursor.rowData)
except StopIteration:
Py_DecRef(vCursor.tableIterator)
vCursor.tableIterator = None
except:
traceback.print_exc()
return SQLITE_ERROR
else:
vCursor.rowIdx += 1
return SQLITE_OK
@direct.c_sqlite3_module._types_.xFilter
def vtab_xFilter(pCursor, idxNum, idxStr, argc, argv): # sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv
"""Begins a search of a virtual table."""
vCursor = pCursor.contents
tableFactory = vCursor.tableFactory
idxStr = ctypes.cast(idxStr, ctypes.c_char_p).value # avoid null terminator
if not idxStr or argc == 0 and len(tableFactory._vtparams_):
return SQLITE_ERROR
elif len(idxStr):
params = idxStr.decode().split(',')
else:
params = []
py_args = tuple(map_sqlite3_value[libsqlite3.sqlite3_value_type(arg := argv[i])](arg) for i in range(argc))
query = dict(zip(params, py_args))
try:
tableIterator = iter(tableFactory(**query))
except:
traceback.print_exc()
return SQLITE_ERROR
vCursor.tableIterator = tableIterator
Py_IncRef(vCursor.tableIterator)
return vtab_iterate(vCursor)
@direct.c_sqlite3_module._types_.xNext
def vtab_xNext(pCursor): # sqlite3_vtab_cursor *pCursor
"""Advances a virtual table cursor to the next row of a result set initiated by xFilter."""
return vtab_iterate(pCursor.contents)
@direct.c_sqlite3_module._types_.xEof
def vtab_xEof(pCursor): # sqlite3_vtab_cursor *pCursor
"""Returns false (zero) if the specified cursor currently points to a valid row of data, or true (non-zero) otherwise."""
return 1 if pCursor.contents.tableIterator is None else 0
@direct.c_sqlite3_module._types_.xColumn
def vtab_xColumn(pCursor, pContext, iCol): # sqlite3_vtab_cursor *pCursor, sqlite3_context *pContext, int iCol
"""Invoked to retrieve the N-th column of the current row."""
if iCol == -1:
libsqlite3.sqlite3_result_int64(pContext, pCursor.contents.rowIdx)
return SQLITE_OK
if not pCursor.contents.rowData:
libsqlite3.sqlite3_result_error(pContext, 'no row data'.encode('utf-8'), -1)
return SQLITE_ERROR
map_sqlite3_result[type(value := pCursor.contents.rowData[iCol])](pContext, value)
return SQLITE_OK
@direct.c_sqlite3_module._types_.xRowid
def vtab_xRowid(pCursor, pRowid): # sqlite3_vtab_cursor *pCursor, sqlite3_int64 *pRowid
"""Fills *pRowid with the rowid of the current row."""
pRowid[0] = pCursor.contents.rowIdx
def new_sqlite3_module_vtabfunc():
module = direct.c_sqlite3_module()
ctypes.memset(ctypes.byref(module), 0, ctypes.sizeof(direct.c_sqlite3_module))
module.iVersion = 0
#module.xCreate = direct.c_sqlite3_module._types_.xCreate(0)
module.xConnect = vtab_xConnect
module.xBestIndex = vtab_xBestIndex
module.xDisconnect = vtab_xDisconnect
#module.xDestroy = direct.c_sqlite3_module._types_.xDestroy(0)
module.xOpen = vtab_xOpen
module.xClose = vtab_xClose
module.xFilter = vtab_xFilter
module.xNext = vtab_xNext
module.xEof = vtab_xEof
module.xColumn = vtab_xColumn
module.xRowid = vtab_xRowid
#module.xUpdate = direct.c_sqlite3_module._types_.xUpdate(0)
#module.xBegin = direct.c_sqlite3_module._types_.xBegin(0)
#module.xSync = direct.c_sqlite3_module._types_.xSync(0)
#module.xCommit = direct.c_sqlite3_module._types_.xCommit(0)
#module.xRollback = direct.c_sqlite3_module._types_.xRollback(0)
#module.xFindFunction = direct.c_sqlite3_module._types_.xFindFunction(0)
#module.xRename = direct.c_sqlite3_module._types_.xRename(0)
return module
def vtab_columns_declaration(factory):
acc = []
for column in factory._vtcolumns_:
if isinstance(column, str):
acc.append(column)
elif isinstance(column, tuple) and len(column) != 2:
acc.append('%s %s'%column)
else:
raise ValueError('Column must be either a string or a 2-tuple of name, type')
for param in factory._vtparams_:
acc.append('%s HIDDEN'%param)
return ', '.join(acc)
class VTableIterable(object):
@classmethod
def register(cls, database, name=None):
if not cls._vtparams_:
cls._vtparams_ = tuple(k for k in inspect.signature(cls).parameters)
if not cls._vtcolumns_:
cls._vtcolumns_ = ('value',)
return libsqlite3.sqlite3_create_module(direct.connect(database), (name or cls._vtname_ or cls.__name__).encode('utf-8'), ctypes.byref(cls._vtmod_), ctypes.py_object(cls))
@classmethod
def unregister(cls, database, name=None):
return libsqlite3.sqlite3_create_module(direct.connect(database), (name or cls._vtname_ or cls.__name__).encode('utf-8'), None, None)
def __iter__(self):
raise NotImplementedError
def __next__(self):
raise NotImplementedError
_vtmod_ = new_sqlite3_module_vtabfunc()
_vtname_ = None
_vtparams_ = None
_vtcolumns_ = None
def vtab_wrap(dst, src, shift=0, name=None, params=None, columns=None):
dst._vtname_ = name or getattr(src, '_vtname_', None) or src.__name__
dst._vtparams_ = params or getattr(src, '_vtparams_', None) or tuple(inspect.signature(src).parameters)[shift:]
dst._vtcolumns_ = columns or getattr(src, '_vtcolumns_', None) or ('value',)
dst.register = lambda database, name=None: libsqlite3.sqlite3_create_module(direct.connect(database), (name or dst._vtname_).encode('utf-8'), ctypes.byref(VTableIterable._vtmod_), ctypes.py_object(dst))
dst.unregister = getattr(src, 'unregister', None) or (lambda database, name=None: libsqlite3.sqlite3_create_module(direct.connect(database), (name or dst._vtname_).encode('utf-8'), None, None))
dst.__doc__ = src.__doc__
dst.__name__ = src.__name__
return dst
class vtab_descriptor(object):
def __init__(self, func, name=None, params=None, columns=None):
self._vtfunc_ = func
vtab_wrap(self, func, 1, name=name, params=params, columns=columns)
def __call__(self, *args, **kwargs):
return self._vtfunc_(*args, **kwargs)
def __get__(self, obj, cls):
return vtab_wrap(functools.partial(self._vtfunc_, obj), self)
def vtab_iterable(func=None, *, name=None, params=None, columns=None, member=False):
def decorator(src):
try:
if member or inspect.getfullargspec(src)[0][0] in {'self', 'cls'}:
return vtab_descriptor(src, name=name, params=params, columns=columns)
except IndexError:
pass
def proxy(*args, **kw):
return src(*args, **kw)
return vtab_wrap(proxy, src, name=name, params=params, columns=columns)
if func is None:
return decorator
else:
return decorator(func)
__all__ = ['VTableIterable', 'vtab_iterable']