sqlite3tricks-py/direct.py

370 lines
15 KiB
Python
Raw Normal View History

2020-03-30 09:00:00 -04:00
#!/usr/bin/env python3
import sys
import sqlite3
import ctypes
import contextlib
from collections import namedtuple
# Largest value that fits in a signed 32-bit C int.
maxsize32 = (1 << 31) - 1

# ---- SQLite C API constants (only the subset this module uses) ----

# File-open flags: https://www.sqlite.org/c3ref/c_open_autoproxy.html
SQLITE_OPEN_READWRITE = 0x00000002

# Special destructor sentinel: https://www.sqlite.org/c3ref/c_static.html
SQLITE_TRANSIENT = -1

# Virtual-table constraint operators:
# https://www.sqlite.org/c3ref/c_index_constraint_eq.html
SQLITE_INDEX_CONSTRAINT_EQ = 2

# Result codes: https://www.sqlite.org/rescode.html
SQLITE_OK, SQLITE_ERROR, SQLITE_CONSTRAINT = 0, 1, 19
SQLITE_ROW, SQLITE_DONE = 100, 101

# Fundamental datatypes (consecutive codes 1..5):
# https://www.sqlite.org/c3ref/c_blob.html
SQLITE_INTEGER, SQLITE_FLOAT, SQLITE_TEXT, SQLITE_BLOB, SQLITE_NULL = range(1, 6)

# Pointer types used in the C signatures declared further down.
c_char_p_p, c_int64_p, c_void_p_p = map(
    ctypes.POINTER, (ctypes.c_char_p, ctypes.c_int64, ctypes.c_void_p)
)
class c_PyObject(ctypes.Structure):
    """Header shared by every CPython object (``PyObject``).

    Only the reference count and the type pointer are declared; other
    structures in this module embed this as their first member.
    NOTE(review): assumes a standard CPython build layout (no free-threading,
    no Py_TRACE_REFS) — confirm for the interpreters this must support.
    """

    # Reference: https://stackoverflow.com/questions/35438103/referencing-a-pointer-of-array-in-a-struct-with-ctypes
    _fields_ = [
        ('ob_refcnt', ctypes.c_ssize_t),  # reference count
        ('ob_type', ctypes.c_void_p),     # PyTypeObject *
    ]
class c_sqlite3_p(ctypes.c_void_p):
    """Opaque ``sqlite3 *`` database connection handle.

    Carries a convenience ``execute`` method so queries can be issued directly
    on a raw handle.
    """

    def execute(self, sql, parameters=()):
        """Run *sql* on this handle by delegating to the module-level ``execute``."""
        return execute(self, sql, parameters)
class c_pysqlite_connection(ctypes.Structure):
    """Leading members of CPython's ``pysqlite_Connection`` object.

    Declares only what is needed to reach the ``sqlite3 *db`` pointer held by
    a ``sqlite3.Connection``. Layout source:
    https://github.com/python/cpython/blob/main/Modules/_sqlite/connection.h
    NOTE(review): relies on ``db`` directly following the object header in
    every supported CPython version — confirm when targeting a new release.
    """

    _fields_ = [
        ('ob_base', c_PyObject),  # PyObject_HEAD
        ('db', c_sqlite3_p),      # sqlite3 *db
    ]
class c_sqlite3_module(ctypes.Structure):
    """Forward declaration of ``struct sqlite3_module``; ``_fields_`` is assigned later in this file."""


c_sqlite3_module_p = ctypes.POINTER(c_sqlite3_module)


class c_sqlite3_vtab(ctypes.Structure):
    """``struct sqlite3_vtab`` followed by module-private members.

    Standard members per https://www.sqlite.org/c3ref/vtab.html come first;
    the custom member is appended after them.
    """

    _fields_ = [
        ('pModule', c_sqlite3_module_p),
        ('nRef', ctypes.c_int),
        ('zErrMsg', ctypes.c_char_p),
        # -- custom members --
        ('tableFactory', ctypes.py_object),
    ]


c_sqlite3_vtab_p = ctypes.POINTER(c_sqlite3_vtab)
c_sqlite3_vtab_p_p = ctypes.POINTER(c_sqlite3_vtab_p)
class c_sqlite3_index_constraint(ctypes.Structure):
    """Element of ``sqlite3_index_info.aConstraint``."""

    _fields_ = [
        ('iColumn', ctypes.c_int),
        ('op', ctypes.c_ubyte),
        ('usable', ctypes.c_ubyte),
        ('iTermOffset', ctypes.c_int),
    ]


class c_sqlite3_index_orderby(ctypes.Structure):
    """Element of ``sqlite3_index_info.aOrderBy``."""

    _fields_ = [
        ('iColumn', ctypes.c_int),
        ('desc', ctypes.c_ubyte),
    ]


class c_sqlite3_index_constraint_usage(ctypes.Structure):
    """Element of ``sqlite3_index_info.aConstraintUsage``."""

    _fields_ = [
        ('argvIndex', ctypes.c_int),
        ('omit', ctypes.c_ubyte),
    ]


c_sqlite3_index_constraint_p = ctypes.POINTER(c_sqlite3_index_constraint)
c_sqlite3_index_orderby_p = ctypes.POINTER(c_sqlite3_index_orderby)
c_sqlite3_index_constraint_usage_p = ctypes.POINTER(c_sqlite3_index_constraint_usage)
class c_sqlite3_index_info(ctypes.Structure):
    """ctypes mirror of ``struct sqlite3_index_info`` (the xBestIndex argument).

    Members follow the C declaration order documented at
    https://www.sqlite.org/c3ref/index_info.html. The last three only exist in
    newer SQLite releases (versions noted inline).
    """

    _fields_ = [
        # ---- inputs ----
        ('nConstraint', ctypes.c_int),
        ('aConstraint', c_sqlite3_index_constraint_p),
        ('nOrderBy', ctypes.c_int),
        ('aOrderBy', c_sqlite3_index_orderby_p),
        # ---- outputs ----
        ('aConstraintUsage', c_sqlite3_index_constraint_usage_p),
        ('idxNum', ctypes.c_int),
        ('idxStr', ctypes.c_char_p),
        ('needToFreeIdxStr', ctypes.c_int),
        ('orderByConsumed', ctypes.c_int),
        ('estimatedCost', ctypes.c_double),
        # ---- later additions ----
        ('estimatedRows', ctypes.c_int64),  # SQLite >= 3.8.2
        ('idxFlags', ctypes.c_int),         # SQLite >= 3.9.0
        ('colUsed', ctypes.c_uint64),       # SQLite >= 3.10.0
    ]


c_sqlite3_index_info_p = ctypes.POINTER(c_sqlite3_index_info)
class c_sqlite3_vtab_cursor(ctypes.Structure):
    """``struct sqlite3_vtab_cursor`` extended with Python-side cursor state.

    SQLite itself defines only ``pVtab``
    (https://www.sqlite.org/c3ref/vtab_cursor.html); everything after it is
    private to this module.
    """

    _fields_ = [
        ('pVtab', c_sqlite3_vtab_p),
        # -- custom members --
        ('tableFactory', ctypes.py_object),
        ('tableIterator', ctypes.py_object),
        ('rowData', ctypes.py_object),
        ('rowIdx', ctypes.c_int64),
    ]


c_sqlite3_vtab_cursor_p = ctypes.POINTER(c_sqlite3_vtab_cursor)
c_sqlite3_vtab_cursor_p_p = ctypes.POINTER(c_sqlite3_vtab_cursor_p)
class c_sqlite3_context(ctypes.Structure):
    """Opaque ``sqlite3_context``; no members are declared, it is only used through pointers."""


class c_sqlite3_value(ctypes.Structure):
    """Opaque ``sqlite3_value``; no members are declared, it is only used through pointers."""


c_sqlite3_context_p = ctypes.POINTER(c_sqlite3_context)
c_sqlite3_value_p = ctypes.POINTER(c_sqlite3_value)
c_sqlite3_value_p_p = ctypes.POINTER(c_sqlite3_value_p)
# https://www.sqlite.org/c3ref/module.html
# Members mirror struct sqlite3_module in declaration order (SQLite reads the
# slots by position). Each CFUNCTYPE prototype must match the C signature in
# the comment above it exactly, or callbacks receive misaligned arguments.
c_sqlite3_module._fields_ = [
    ('iVersion', ctypes.c_int),
    # int (*xCreate)(sqlite3*, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char**);
    ('xCreate', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_p, ctypes.c_void_p, ctypes.c_int, c_char_p_p, c_sqlite3_vtab_p_p, c_char_p_p)),
    # int (*xConnect)(sqlite3*, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char**);
    ('xConnect', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_p, ctypes.c_void_p, ctypes.c_int, c_char_p_p, c_sqlite3_vtab_p_p, c_char_p_p)),
    # int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*);
    ('xBestIndex', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_index_info_p)),
    # int (*xDisconnect)(sqlite3_vtab *pVTab);
    ('xDisconnect', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
    # int (*xDestroy)(sqlite3_vtab *pVTab);
    ('xDestroy', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
    # int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor);
    ('xOpen', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, c_sqlite3_vtab_cursor_p_p)),
    # int (*xClose)(sqlite3_vtab_cursor*);
    ('xClose', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)),
    # int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr, int argc, sqlite3_value **argv);
    ('xFilter', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, c_sqlite3_value_p_p)),
    # int (*xNext)(sqlite3_vtab_cursor*);
    ('xNext', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)),
    # int (*xEof)(sqlite3_vtab_cursor*);
    ('xEof', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p)),
    # int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context*, int);
    ('xColumn', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, c_sqlite3_context_p, ctypes.c_int)),
    # int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid);
    ('xRowid', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_cursor_p, c_int64_p)),
    # int (*xUpdate)(sqlite3_vtab *, int, sqlite3_value **, sqlite3_int64 *);
    # The `int` is argc (number of entries in the sqlite3_value array).
    ('xUpdate', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int, c_sqlite3_value_p_p, c_int64_p)),
    # int (*xBegin)(sqlite3_vtab *pVTab);
    ('xBegin', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
    # int (*xSync)(sqlite3_vtab *pVTab);
    ('xSync', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
    # int (*xCommit)(sqlite3_vtab *pVTab);
    ('xCommit', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
    # int (*xRollback)(sqlite3_vtab *pVTab);
    ('xRollback', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p)),
    # int (*xFindFunction)(sqlite3_vtab *pVtab, int nArg, const char *zName, void (**pxFunc)(sqlite3_context*,int,sqlite3_value**), void **ppArg);
    # pxFunc is a *pointer to* a function pointer, and that function returns void.
    ('xFindFunction', ctypes.CFUNCTYPE(
        ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int, ctypes.c_char_p,
        ctypes.POINTER(ctypes.CFUNCTYPE(None, c_sqlite3_context_p, ctypes.c_int, c_sqlite3_value_p_p)),
        c_void_p_p,
    )),
    # int (*xRename)(sqlite3_vtab *pVtab, const char *zNew);
    ('xRename', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_char_p)),
    # The methods above are in version 1 of the sqlite_module object. Those below are for version 2 and greater.
    # int (*xSavepoint)(sqlite3_vtab *pVTab, int);
    ('xSavepoint', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)),
    # int (*xRelease)(sqlite3_vtab *pVTab, int);
    ('xRelease', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)),
    # int (*xRollbackTo)(sqlite3_vtab *pVTab, int);
    ('xRollbackTo', ctypes.CFUNCTYPE(ctypes.c_int, c_sqlite3_vtab_p, ctypes.c_int)),
    # The methods above are in versions 1 and 2 of the sqlite_module object. Those below are for version 3 and greater.
    # int (*xShadowName)(const char*);
    ('xShadowName', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p)),
]
# Name -> ctypes type lookup for the members above, e.g. _types_.xOpen(callback)
# wraps a Python callable in the matching function-pointer prototype.
c_sqlite3_module._types_ = namedtuple(
    'c_sqlite3_module', [name for name, _ in c_sqlite3_module._fields_]
)(*[ctype for _, ctype in c_sqlite3_module._fields_])
# Reference counting
import copy
# Private handle on the Python C API. A separate PyDLL instance is used so the
# argtypes/restype set below do not alter the function objects cached on the
# shared ``ctypes.pythonapi``. It reuses the already-loaded library through the
# documented ``_name``/``_handle`` attributes. Earlier versions used
# ``copy.deepcopy(ctypes.pythonapi)``, which can fail once ``ctypes.pythonapi``
# holds cached foreign-function attributes, because ctypes function pointers
# cannot be copied.
pythonapi = ctypes.PyDLL(ctypes.pythonapi._name, handle=ctypes.pythonapi._handle)
Py_IncRef = pythonapi.Py_IncRef
Py_DecRef = pythonapi.Py_DecRef
# C signatures: void Py_IncRef(PyObject *o); void Py_DecRef(PyObject *o);
Py_IncRef.argtypes = [ctypes.py_object]
Py_DecRef.argtypes = [ctypes.py_object]
Py_IncRef.restype = None
Py_DecRef.restype = None
# Load the SQLite shared library.
# ctypes.util.find_library returns None when it cannot locate the library. The
# TypeError handler presumably exists because passing None to CDLL raises
# TypeError on Windows, where the fallback is the sqlite3 DLL in the
# interpreter's "DLLs" directory.
# NOTE(review): on POSIX, CDLL(None) does not raise. It opens the running
# program itself, so the fallback is never reached there and a missing library
# only surfaces later as AttributeError. Confirm this is intended.
# Background: https://news.ycombinator.com/item?id=22030336
import ctypes.util
try:
    libsqlite3 = ctypes.CDLL(ctypes.util.find_library('sqlite3'))
except TypeError:
    import os
    _python_dlls = os.path.join(os.path.dirname(sys.executable), 'DLLs')
    libsqlite3 = ctypes.CDLL(os.path.join(_python_dlls, 'sqlite3'))
    del _python_dlls
# bind value to SQLite statement
# https://www.sqlite.org/c3ref/bind_blob.html
libsqlite3.sqlite3_bind_int64.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_int64)
libsqlite3.sqlite3_bind_double.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_double)
# Exact Python type -> callable(stmt, index, value) returning an SQLite result code.
# The destructor parameter of bind_text/bind_blob is a pointer, so
# SQLITE_TRANSIENT is passed as a pointer-sized c_void_p instead of a plain C int
# the ABI would have to widen. SQLITE_TRANSIENT makes SQLite copy the buffer.
map_sqlite3_bind = {
    type(0): libsqlite3.sqlite3_bind_int64,
    type(0.0): libsqlite3.sqlite3_bind_double,
    type(''): lambda stmt, i, value: libsqlite3.sqlite3_bind_text(
        stmt, i, b := value.encode('utf-8'), len(b), ctypes.c_void_p(SQLITE_TRANSIENT)),
    type(b''): lambda stmt, i, value: libsqlite3.sqlite3_bind_blob(
        stmt, i, value, len(value), ctypes.c_void_p(SQLITE_TRANSIENT)),
    type(None): lambda stmt, i, _: libsqlite3.sqlite3_bind_null(stmt, i),
}
if sys.maxsize <= maxsize32:
    # 32-bit builds: use sqlite3_bind_int only when the value fits a signed
    # 32-bit int in BOTH directions; anything else goes through bind_int64.
    map_sqlite3_bind[type(0)] = lambda stmt, i, value: (
        libsqlite3.sqlite3_bind_int(stmt, i, value)
        if -maxsize32 - 1 <= value <= maxsize32
        else libsqlite3.sqlite3_bind_int64(stmt, i, value)
    )
# type(True) is not type(0), so bool needs its own entry; it is bound as the integer 0/1.
map_sqlite3_bind[type(True)] = map_sqlite3_bind[type(0)]
# extract SQLite column value
# https://www.sqlite.org/c3ref/column_blob.html
libsqlite3.sqlite3_column_name.restype = ctypes.c_char_p
libsqlite3.sqlite3_column_type.restype = ctypes.c_int
libsqlite3.sqlite3_column_int64.restype = ctypes.c_int64
libsqlite3.sqlite3_column_double.restype = ctypes.c_double
libsqlite3.sqlite3_column_blob.restype = ctypes.c_void_p
# sqlite3_column_bytes is declared `int` in C. Reading it as a 64-bit value may
# pick up undefined upper bits of the return register on some ABIs.
libsqlite3.sqlite3_column_bytes.restype = ctypes.c_int
# SQLite fundamental type code -> callable(stmt, column_index) returning a Python value.
# For TEXT/BLOB the data pointer is fetched before the byte count (argument
# evaluation order), as the SQLite docs recommend.
map_sqlite3_column = {
    SQLITE_INTEGER: libsqlite3.sqlite3_column_int64,
    SQLITE_FLOAT: libsqlite3.sqlite3_column_double,
    SQLITE_TEXT: lambda stmt, i: ctypes.string_at(
        libsqlite3.sqlite3_column_blob(stmt, i),
        libsqlite3.sqlite3_column_bytes(stmt, i),
    ).decode('utf-8'),
    SQLITE_BLOB: lambda stmt, i: ctypes.string_at(
        libsqlite3.sqlite3_column_blob(stmt, i),
        libsqlite3.sqlite3_column_bytes(stmt, i),
    ),
    SQLITE_NULL: lambda stmt, i: None,
}
# set SQLite result
# https://www.sqlite.org/c3ref/result_blob.html
# Text encoding code for sqlite3_result_text64 (https://www.sqlite.org/c3ref/c_any.html).
SQLITE_UTF8 = 1
libsqlite3.sqlite3_result_int64.argtypes = (c_sqlite3_context_p, ctypes.c_int64)
libsqlite3.sqlite3_result_double.argtypes = (c_sqlite3_context_p, ctypes.c_double)
libsqlite3.sqlite3_result_text64.argtypes = (c_sqlite3_context_p, ctypes.c_char_p, ctypes.c_uint64, ctypes.c_void_p, ctypes.c_ubyte)
libsqlite3.sqlite3_result_blob64.argtypes = (c_sqlite3_context_p, ctypes.c_void_p, ctypes.c_uint64, ctypes.c_void_p)
libsqlite3.sqlite3_result_null.argtypes = (c_sqlite3_context_p,)
# Exact Python type -> callable(ctx, value) that stores value as the result of ctx.
# SQLITE_TRANSIENT makes SQLite copy the buffer, so the temporary bytes may be freed.
map_sqlite3_result = {
    type(0): libsqlite3.sqlite3_result_int64,
    type(0.0): libsqlite3.sqlite3_result_double,
    type(''): lambda ctx, value: libsqlite3.sqlite3_result_text64(ctx, s := value.encode('utf-8'), len(s), SQLITE_TRANSIENT, SQLITE_UTF8),
    type(b''): lambda ctx, value: libsqlite3.sqlite3_result_blob64(ctx, value, len(value), SQLITE_TRANSIENT),
    type(None): lambda ctx, _: libsqlite3.sqlite3_result_null(ctx),
}
if sys.maxsize <= maxsize32:
    # 32-bit builds: use sqlite3_result_int only when the value fits a signed
    # 32-bit int in BOTH directions; anything else goes through result_int64.
    map_sqlite3_result[type(0)] = lambda ctx, value: (
        libsqlite3.sqlite3_result_int(ctx, value)
        if -maxsize32 - 1 <= value <= maxsize32
        else libsqlite3.sqlite3_result_int64(ctx, value)
    )
# type(True) is not type(0), so bool needs its own entry; it is returned as the integer 0/1.
map_sqlite3_result[type(True)] = map_sqlite3_result[type(0)]
# extract SQLite value
# https://www.sqlite.org/c3ref/value_blob.html
libsqlite3.sqlite3_value_type.restype = ctypes.c_int
libsqlite3.sqlite3_value_type.argtypes = (c_sqlite3_value_p,)
libsqlite3.sqlite3_value_int64.restype = ctypes.c_int64
libsqlite3.sqlite3_value_int64.argtypes = (c_sqlite3_value_p,)
libsqlite3.sqlite3_value_double.restype = ctypes.c_double
libsqlite3.sqlite3_value_double.argtypes = (c_sqlite3_value_p,)
libsqlite3.sqlite3_value_blob.restype = ctypes.c_void_p
libsqlite3.sqlite3_value_blob.argtypes = (c_sqlite3_value_p,)
# sqlite3_value_bytes is declared `int` in C. Reading it as a 64-bit value may
# pick up undefined upper bits of the return register on some ABIs.
libsqlite3.sqlite3_value_bytes.restype = ctypes.c_int
libsqlite3.sqlite3_value_bytes.argtypes = (c_sqlite3_value_p,)
# SQLite fundamental type code -> callable(sqlite3_value*) returning a Python value.
# For TEXT/BLOB the data pointer is fetched before the byte count (argument
# evaluation order), as the SQLite docs recommend.
map_sqlite3_value = {
    SQLITE_INTEGER: libsqlite3.sqlite3_value_int64,
    SQLITE_FLOAT: libsqlite3.sqlite3_value_double,
    SQLITE_TEXT: lambda x: ctypes.string_at(
        libsqlite3.sqlite3_value_blob(x),
        libsqlite3.sqlite3_value_bytes(x),
    ).decode('utf-8'),
    SQLITE_BLOB: lambda x: ctypes.string_at(
        libsqlite3.sqlite3_value_blob(x),
        libsqlite3.sqlite3_value_bytes(x),
    ),
    SQLITE_NULL: lambda x: None,
}
# Query execution
# https://gist.github.com/michalc/a3147997e21665896836e0f4157975cb
# Return types for the remaining SQLite entry points this module calls.
for _symbol, _restype in (
    ('sqlite3_errstr', ctypes.c_char_p),
    ('sqlite3_errmsg', ctypes.c_char_p),
    ('sqlite3_declare_vtab', ctypes.c_int),
    ('sqlite3_malloc', ctypes.c_void_p),
    ('sqlite3_create_module', ctypes.c_int),
):
    getattr(libsqlite3, _symbol).restype = _restype
del _symbol, _restype
def connect(database, allownew=True, throw=True):
    """Return a raw ``sqlite3*`` handle (``c_sqlite3_p``) for *database*.

    *database* may be:

    * a ``c_sqlite3_p`` — returned unchanged;
    * a ``sqlite3.Connection`` — the handle is read out of the CPython object
      through ``c_pysqlite_connection``;
    * a ``str`` path — opened read/write with ``sqlite3_open_v2``, only when
      *allownew* is true.

    Returns ``None`` when no usable handle was obtained and *throw* is false.

    :raises IOError: no usable handle was obtained and *throw* is true.
    :raises sqlite3.OperationalError: ``sqlite3_open_v2`` failed.
    """
    res = None
    if isinstance(database, c_sqlite3_p):
        res = database
    elif isinstance(database, sqlite3.Connection):
        res = c_pysqlite_connection.from_address(id(database)).db
    elif isinstance(database, str) and allownew:
        pDB = c_sqlite3_p()
        try:
            callAPI(libsqlite3.sqlite3_open_v2, database.encode('utf-8'), ctypes.byref(pDB), SQLITE_OPEN_READWRITE, None)
        except sqlite3.OperationalError:
            # sqlite3_open_v2 normally allocates a handle even when it fails, and
            # that handle must still be released. sqlite3_close(NULL) is a
            # harmless no-op if no handle was allocated.
            libsqlite3.sqlite3_close(pDB)
            raise
        res = pDB
    if res:
        return res
    elif throw:
        raise IOError('Invalid database')
def callAPI(func, *args, pDB=None):
    """Call an SQLite C API function and raise if it reports an error.

    :param func: callable returning an SQLite result code.
    :param args: positional arguments forwarded unchanged to *func*.
    :param pDB: optional connection handle. When given, the error text comes
        from ``sqlite3_errmsg(pDB)``; otherwise from ``sqlite3_errstr(code)``.
    :raises sqlite3.OperationalError: *func* returned anything other than 0.
    """
    code = func(*args)
    if code == 0:
        return None
    message = (
        libsqlite3.sqlite3_errstr(code)
        if pDB is None
        else libsqlite3.sqlite3_errmsg(pDB)
    )
    raise sqlite3.OperationalError(message.decode())
@contextlib.contextmanager
def connection(database, allownew=True):
    """Context manager yielding a raw SQLite handle that is closed on exit.

    The handle comes from ``connect(database, allownew=allownew)``. On exit,
    whether normal or via an exception:

    * a ``sqlite3.Connection`` argument is closed with its own ``close()``;
    * any other handle is closed with ``sqlite3_close``, raising
      ``sqlite3.OperationalError`` if that fails.
    """
    pDB = connect(database, allownew=allownew)
    is_python_connection = isinstance(database, sqlite3.Connection)
    try:
        yield pDB
    finally:
        if not is_python_connection:
            callAPI(libsqlite3.sqlite3_close, pDB, pDB=pDB)
        else:
            database.close()
@contextlib.contextmanager
def prepared_statement(pDB, sql):
    """Compile *sql* on connection *pDB* and finalize the statement on exit.

    Yields the ``sqlite3_stmt*`` as a ``ctypes.c_void_p``. No tail pointer is
    requested, so any text after the first statement is ignored.

    :raises sqlite3.OperationalError: preparation or finalization failed. The
        message comes from the connection (``sqlite3_errmsg``).
    """
    handle = ctypes.c_void_p()
    encoded_sql = sql.encode('utf-8')
    callAPI(libsqlite3.sqlite3_prepare_v3, pDB, encoded_sql, -1, 0, ctypes.byref(handle), None, pDB=pDB)
    try:
        yield handle
    finally:
        callAPI(libsqlite3.sqlite3_finalize, handle, pDB=pDB)
def execute(pDB, sql, parameters=()):
    """Execute *sql* on the raw connection *pDB*, yielding one ``Row`` per result row.

    *parameters* are bound by position (index 1, 2, ...). Each must be of an
    exact type present in ``map_sqlite3_bind``. ``Row`` is a namedtuple whose
    fields are the result column names. Names that are not valid namedtuple
    fields (e.g. ``count(*)``, duplicates, or names starting with ``_``) are
    replaced by positional names ``_<index>``. Only the first statement in
    *sql* is run. SQL containing no statement yields nothing.

    This is a generator: the statement is prepared on first iteration and
    finalized when iteration ends or the generator is closed.

    :raises KeyError: a parameter has an unsupported type.
    :raises sqlite3.OperationalError: preparing, binding or stepping failed.
    """
    with prepared_statement(pDB, sql) as stmt:
        for i, param in enumerate(parameters, start=1):
            callAPI(map_sqlite3_bind[type(param)], stmt, i, param, pDB=pDB)
        if not stmt:
            # SQL with no statement in it (empty, whitespace, comments only)
            # prepares to a NULL handle; stepping that would be an API misuse.
            return
        column_count = libsqlite3.sqlite3_column_count(stmt)
        # rename=True: without it namedtuple raises ValueError for column names
        # such as "count(*)", duplicates, or names starting with "_".
        row_factory = namedtuple(
            'Row',
            [libsqlite3.sqlite3_column_name(stmt, i).decode() for i in range(column_count)],
            rename=True,
        )
        while True:
            res = libsqlite3.sqlite3_step(stmt)
            if res == SQLITE_ROW:
                yield row_factory(*(
                    map_sqlite3_column[libsqlite3.sqlite3_column_type(stmt, i)](stmt, i)
                    for i in range(column_count)
                ))
            elif res == SQLITE_DONE:
                break
            else:
                # sqlite3_errmsg gives the connection-specific message (e.g. which
                # constraint failed) instead of the generic text for the code.
                raise sqlite3.OperationalError(libsqlite3.sqlite3_errmsg(pDB).decode())
# Names exported by a star-import of this module.
__all__ = [
    'connect',
    'execute',
]