D7net
Home
Console
Upload
information
Create File
Create Folder
About
Tools
:
/
opt
/
alt
/
python38
/
lib64
/
python3.8
/
site-packages
/
Filename :
peewee.py
back
Copy
# May you do good and not evil
# May you find forgiveness for yourself and forgive others
# May you share freely, never taking more than you give.  -- SQLite source code
#
# As we enjoy great advantages from the inventions of others, we should be glad
# of an opportunity to serve others by an invention of ours, and this we should
# do freely and generously.  -- Ben Franklin
#
#     (\
#     (  \  /(o)\     caw!
#     (   \/  ()/ /)
#      (   `;.))'".)
#       `(/////.-'
#    =====))=))===()
#      ///'
#     //
#    '
import calendar
import datetime
import decimal
import hashlib
import itertools
import logging
import operator
import re
import sys
import threading
import time
import uuid
import weakref
from bisect import bisect_left
from bisect import bisect_right
from collections import deque
from collections import namedtuple
try:
    from collections import OrderedDict
except ImportError:
    OrderedDict = dict
from copy import deepcopy
from functools import wraps
from inspect import isclass

__version__ = '2.8.4'

__all__ = [
    'BareField',
    'BigIntegerField',
    'BlobField',
    'BooleanField',
    'CharField',
    'Check',
    'Clause',
    'CompositeKey',
    'DatabaseError',
    'DataError',
    'DateField',
    'DateTimeField',
    'DecimalField',
    'DeferredRelation',
    'DoesNotExist',
    'DoubleField',
    'DQ',
    'Field',
    'FixedCharField',
    'FloatField',
    'fn',
    'ForeignKeyField',
    'ImproperlyConfigured',
    'IntegerField',
    'IntegrityError',
    'InterfaceError',
    'InternalError',
    'JOIN',
    'JOIN_FULL',
    'JOIN_INNER',
    'JOIN_LEFT_OUTER',
    'Model',
    'MySQLDatabase',
    'NotSupportedError',
    'OperationalError',
    'Param',
    'PostgresqlDatabase',
    'prefetch',
    'PrimaryKeyField',
    'ProgrammingError',
    'Proxy',
    'R',
    'SmallIntegerField',
    'SqliteDatabase',
    'SQL',
    'TextField',
    'TimeField',
    'TimestampField',
    'Using',
    'UUIDField',
    'Window',
]

# Set default logging handler to avoid "No handlers could be found for logger
# "peewee"" warnings.
try:  # Python 2.7+
    from logging import NullHandler
except ImportError:
    class NullHandler(logging.Handler):
        """No-op logging handler used when `logging.NullHandler` is missing."""
        def emit(self, record):
            pass

# All peewee-generated logs are logged to this namespace.
logger = logging.getLogger('peewee')
logger.addHandler(NullHandler())

# Python 2/3 compatibility helpers. These helpers are used internally and are
# not exported.
_METACLASS_ = '_metaclass_helper_'
def with_metaclass(meta, base=object):
    """Return a class named `_METACLASS_`, created by `meta`, deriving `base`."""
    return meta(_METACLASS_, (base,), {})

PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY26 = sys.version_info[:2] == (2, 6)
if PY3:
    import builtins
    try:
        # `collections.Callable` was removed in Python 3.10; the ABC lives in
        # `collections.abc` since Python 3.3.
        from collections.abc import Callable
    except ImportError:
        from collections import Callable
    from functools import reduce
    callable = lambda c: isinstance(c, Callable)
    unicode_type = str
    string_type = bytes
    basestring = str
    print_ = getattr(builtins, 'print')
    binary_construct = lambda s: bytes(s.encode('raw_unicode_escape'))
    long = int
    def reraise(tp, value, tb=None):
        """Re-raise `value`, attaching traceback `tb` if it differs."""
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value
elif PY2:
    unicode_type = unicode
    string_type = basestring
    binary_construct = buffer
    def print_(s):
        sys.stdout.write(s)
        sys.stdout.write('\n')
    # The three-argument raise is a syntax error on Python 3, so it has to be
    # compiled lazily from a string.
    exec('def reraise(tp, value, tb=None): raise tp, value, tb')
else:
    raise RuntimeError('Unsupported python version.')

if PY26:
    _M = 10**6
    total_seconds = lambda t: (t.microseconds + 0.0 +
                               (t.seconds + t.days * 24 * 3600) * _M) / _M
else:
    total_seconds = lambda t: t.total_seconds()

# By default, peewee supports Sqlite, MySQL and Postgresql.
try:
    from pysqlite2 import dbapi2 as pysq3
except ImportError:
    pysq3 = None
try:
    import sqlite3
except ImportError:
    sqlite3 = pysq3
else:
    # Prefer pysqlite2 when it is linked against an equal or newer SQLite.
    if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info:
        sqlite3 = pysq3

try:
    from psycopg2cffi import compat
    compat.register()
except ImportError:
    pass
try:
    import psycopg2
    from psycopg2 import extensions as pg_extensions
except ImportError:
    psycopg2 = None
try:
    import MySQLdb as mysql  # prefer the C module.
except ImportError:
    try:
        import pymysql as mysql
    except ImportError:
        mysql = None

try:
    from playhouse._speedups import format_date_time
    from playhouse._speedups import sort_models_topologically
    from playhouse._speedups import strip_parens
except ImportError:
    def format_date_time(value, formats, post_process=None):
        """
        Parse `value` with the first matching format in `formats`.

        Returns `post_process(parsed_datetime)` on the first format that
        parses, or `value` unchanged when none of them match.
        """
        post_process = post_process or (lambda x: x)
        for fmt in formats:
            try:
                return post_process(datetime.datetime.strptime(value, fmt))
            except ValueError:
                pass
        return value

    def sort_models_topologically(models):
        """Sort models topologically so that parents will precede children."""
        models = set(models)
        seen = set()
        ordering = []
        def dfs(model):
            if model in models and model not in seen:
                seen.add(model)
                for foreign_key in model._meta.reverse_rel.values():
                    dfs(foreign_key.model_class)
                ordering.append(model)  # parent will follow descendants
        # Order models by name and table initially to guarantee total ordering.
        names = lambda m: (m._meta.name, m._meta.db_table)
        for m in sorted(models, key=names, reverse=True):
            dfs(m)
        return list(reversed(ordering))

    def strip_parens(s):
        """
        Strip redundant outer parentheses from the string `s`.

        Matching "(" ... ")" pairs are counted from both ends; pairs that are
        needed to keep the interior balanced are kept. Returns `s` unchanged
        when it is empty or does not start with "(".
        """
        # Quick sanity check.
        if not s or s[0] != '(':
            return s

        ct = i = 0
        l = len(s)
        while i < l:
            if s[i] == '(' and s[l - 1] == ')':
                ct += 1
                i += 1
                l -= 1
            else:
                break
        if ct:
            # If we ever end up with negatively-balanced parentheses, then we
            # know that one of the outer parentheses was required.
            unbalanced_ct = 0
            required = 0
            # `l` was already reduced by `ct` above, so the interior is
            # exactly s[ct:l]; scan all of it.
            for i in range(ct, l):
                if s[i] == '(':
                    unbalanced_ct += 1
                elif s[i] == ')':
                    unbalanced_ct -= 1
                if unbalanced_ct < 0:
                    required += 1
                    unbalanced_ct = 0
                if required == ct:
                    break
            ct -= required
        if ct > 0:
            return s[ct:-ct]
        return s

try:
    from playhouse._speedups import _DictQueryResultWrapper
    from playhouse._speedups import _ModelQueryResultWrapper
    from playhouse._speedups import _SortedFieldList
    from playhouse._speedups import _TuplesQueryResultWrapper
except ImportError:
    _DictQueryResultWrapper = _ModelQueryResultWrapper = _SortedFieldList =\
            _TuplesQueryResultWrapper = None

if sqlite3:
    sqlite3.register_adapter(decimal.Decimal, str)
    sqlite3.register_adapter(datetime.date, str)
    sqlite3.register_adapter(datetime.time, str)

DATETIME_PARTS = ['year', 'month', 'day', 'hour', 'minute', 'second']
DATETIME_LOOKUPS = set(DATETIME_PARTS)

# Sqlite does not support the `date_part` SQL function, so we will define an
# implementation in python.
SQLITE_DATETIME_FORMATS = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M')

def _sqlite_date_part(lookup_type, datetime_string):
    """Return attribute `lookup_type` of the parsed `datetime_string`."""
    assert lookup_type in DATETIME_LOOKUPS
    if not datetime_string:
        return
    dt = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
    return getattr(dt, lookup_type)

SQLITE_DATE_TRUNC_MAPPING = {
    'year': '%Y',
    'month': '%Y-%m',
    'day': '%Y-%m-%d',
    'hour': '%Y-%m-%d %H',
    'minute': '%Y-%m-%d %H:%M',
    'second': '%Y-%m-%d %H:%M:%S'}
MYSQL_DATE_TRUNC_MAPPING = SQLITE_DATE_TRUNC_MAPPING.copy()
MYSQL_DATE_TRUNC_MAPPING['minute'] = '%Y-%m-%d %H:%i'
MYSQL_DATE_TRUNC_MAPPING['second'] = '%Y-%m-%d %H:%i:%S'

def _sqlite_date_trunc(lookup_type, datetime_string):
    """Format the parsed `datetime_string` truncated to `lookup_type`."""
    assert lookup_type in SQLITE_DATE_TRUNC_MAPPING
    if not datetime_string:
        return
    dt = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
    return dt.strftime(SQLITE_DATE_TRUNC_MAPPING[lookup_type])

def _sqlite_regexp(regex, value):
    """Case-insensitive `re.search` of `regex` in `value`, as a bool."""
    return re.search(regex, value, re.I) is not None

class attrdict(dict):
    """Dict whose keys can also be read as attributes."""
    def __getattr__(self, attr):
        return self[attr]

# Operators used in binary expressions.
OP = attrdict(
    AND='and',
    OR='or',
    ADD='+',
    SUB='-',
    MUL='*',
    DIV='/',
    BIN_AND='&',
    BIN_OR='|',
    XOR='^',
    MOD='%',
    EQ='=',
    LT='<',
    LTE='<=',
    GT='>',
    GTE='>=',
    NE='!=',
    IN='in',
    NOT_IN='not in',
    IS='is',
    IS_NOT='is not',
    LIKE='like',
    ILIKE='ilike',
    BETWEEN='between',
    REGEXP='regexp',
    CONCAT='||',
)

JOIN = attrdict(
    INNER='INNER',
    LEFT_OUTER='LEFT OUTER',
    RIGHT_OUTER='RIGHT OUTER',
    FULL='FULL',
)
JOIN_INNER = JOIN.INNER
JOIN_LEFT_OUTER = JOIN.LEFT_OUTER
JOIN_FULL = JOIN.FULL

RESULTS_NAIVE = 1
RESULTS_MODELS = 2
RESULTS_TUPLES = 3
RESULTS_DICTS = 4
RESULTS_AGGREGATE_MODELS = 5

# To support "django-style" double-underscore filters, create a mapping between
# operation name and operation code, e.g. "__eq" == OP.EQ.
DJANGO_MAP = {
    'eq': OP.EQ,
    'lt': OP.LT,
    'lte': OP.LTE,
    'gt': OP.GT,
    'gte': OP.GTE,
    'ne': OP.NE,
    'in': OP.IN,
    'is': OP.IS,
    'like': OP.LIKE,
    'ilike': OP.ILIKE,
    'regexp': OP.REGEXP,
}

# Helper functions that are used in various parts of the codebase.
def merge_dict(source, overrides):
    """Return a copy of `source` updated with `overrides`."""
    merged = source.copy()
    merged.update(overrides)
    return merged

def returns_clone(func):
    """
    Method decorator that will "clone" the object before applying the given
    method.  This ensures that state is mutated in a more predictable fashion,
    and promotes the use of method-chaining.
    """
    @wraps(func)  # Keep the wrapped method's name and docstring.
    def inner(self, *args, **kwargs):
        clone = self.clone()  # Assumes object implements `clone`.
        func(clone, *args, **kwargs)
        return clone
    inner.call_local = func  # Provide a way to call without cloning.
    return inner

def not_allowed(func):
    """
    Method decorator to indicate a method is not allowed to be called.  Will
    raise a `NotImplementedError`.
    """
    @wraps(func)  # Keep the wrapped method's name and docstring.
    def inner(self, *args, **kwargs):
        raise NotImplementedError('%s is not allowed on %s instances' % (
            func, type(self).__name__))
    return inner

class Proxy(object):
    """
    Proxy class useful for situations when you wish to defer the initialization
    of an object.
    """
    __slots__ = ['obj', '_callbacks']

    def __init__(self):
        self._callbacks = []
        self.initialize(None)

    def initialize(self, obj):
        """Set the proxied object and invoke every attached callback with it."""
        self.obj = obj
        for callback in self._callbacks:
            callback(obj)

    def attach_callback(self, callback):
        """Register `callback` to run on `initialize`; returns `callback`."""
        self._callbacks.append(callback)
        return callback

    def __getattr__(self, attr):
        if self.obj is None:
            raise AttributeError('Cannot use uninitialized Proxy.')
        return getattr(self.obj, attr)

    def __setattr__(self, attr, value):
        # Only the proxy's own slots may be assigned.
        if attr not in self.__slots__:
            raise AttributeError('Cannot set attribute on proxy.')
        return super(Proxy, self).__setattr__(attr, value)

class DeferredRelation(object):
    """Placeholder for a related model that is supplied later by name."""
    _unresolved = set()

    def __init__(self, rel_model_name=None):
        self.fields = []
        if rel_model_name is not None:
            self._rel_model_name = rel_model_name.lower()
            self._unresolved.add(self)

    def set_field(self, model_class, field, name):
        """Remember a `(model_class, field, name)` triple to bind later."""
        self.fields.append((model_class, field, name))

    def set_model(self, rel_model):
        """Point every remembered field at `rel_model` and add it to its class."""
        for model, field, name in self.fields:
            field.rel_model = rel_model
            field.add_to_class(model, name)

    @staticmethod
    def resolve(model_cls):
        """Resolve pending relations whose name matches `model_cls.__name__`."""
        # Iterate over a copy since the set is modified while resolving.
        unresolved = list(DeferredRelation._unresolved)
        for dr in unresolved:
            if dr._rel_model_name == model_cls.__name__.lower():
                dr.set_model(model_cls)
                DeferredRelation._unresolved.discard(dr)


class _CDescriptor(object):
    """Descriptor returning `Entity(instance._alias)` on instance access."""
    def __get__(self, instance, instance_type=None):
        if instance is not None:
            return Entity(instance._alias)
        return self

# Classes representing the query tree.

class Node(object):
    """Base-class for any part of a query which shall be composable."""
    c = _CDescriptor()
    _node_type = 'node'

    def __init__(self):
        self._negated = False
        self._alias = None
        self._bind_to = None
        self._ordering = None  # ASC or DESC.

    @classmethod
    def extend(cls, name=None, clone=False):
        """
        Return a decorator that attaches a function to `cls` as a method.

        The method is stored under `name` (default: the function's own name)
        and is wrapped with `returns_clone` when `clone` is true.
        """
        def decorator(method):
            method_name = name or method.__name__
            if clone:
                method = returns_clone(method)
            setattr(cls, method_name, method)
            return method
        return decorator

    def clone_base(self):
        return type(self)()

    def clone(self):
        """Return a copy carrying the negation, alias, ordering and binding."""
        inst = self.clone_base()
        inst._negated = self._negated
        inst._alias = self._alias
        inst._ordering = self._ordering
        inst._bind_to = self._bind_to
        return inst

    @returns_clone
    def __invert__(self):
        self._negated = not self._negated

    @returns_clone
    def alias(self, a=None):
        self._alias = a

    @returns_clone
    def bind_to(self, bt):
        """
        Bind the results of an expression to a specific model type. Useful
        when adding expressions to a select, where the result of the expression
        should be placed on a joined instance.
        """
        self._bind_to = bt

    @returns_clone
    def asc(self):
        self._ordering = 'ASC'

    @returns_clone
    def desc(self):
        self._ordering = 'DESC'

    def __pos__(self):
        return self.asc()

    def __neg__(self):
        return self.desc()

    def _e(op, inv=False):
        """
        Lightweight factory which returns a method that builds an Expression
        consisting of the left-hand and right-hand operands, using `op`.
        """
        def inner(self, rhs):
            if inv:
                return Expression(rhs, op, self)
            return Expression(self, op, rhs)
        return inner
    __and__ = _e(OP.AND)
    __or__ = _e(OP.OR)

    __add__ = _e(OP.ADD)
    __sub__ = _e(OP.SUB)
    __mul__ = _e(OP.MUL)
    __div__ = __truediv__ = _e(OP.DIV)
    __xor__ = _e(OP.XOR)
    __radd__ = _e(OP.ADD, inv=True)
    __rsub__ = _e(OP.SUB, inv=True)
    __rmul__ = _e(OP.MUL, inv=True)
    __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True)
    __rand__ = _e(OP.AND, inv=True)
    __ror__ = _e(OP.OR, inv=True)
    __rxor__ = _e(OP.XOR, inv=True)

    def __eq__(self, rhs):
        # Comparison against None is expressed with OP.IS rather than OP.EQ.
        if rhs is None:
            return Expression(self, OP.IS, None)
        return Expression(self, OP.EQ, rhs)
    def __ne__(self, rhs):
        if rhs is None:
            return Expression(self, OP.IS_NOT, None)
        return Expression(self, OP.NE, rhs)

    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lshift__ = _e(OP.IN)
    __rshift__ = _e(OP.IS)
    __mod__ = _e(OP.LIKE)
    __pow__ = _e(OP.ILIKE)

    bin_and = _e(OP.BIN_AND)
    bin_or = _e(OP.BIN_OR)

    # Special expressions.
    def in_(self, rhs):
        return Expression(self, OP.IN, rhs)
    def not_in(self, rhs):
        return Expression(self, OP.NOT_IN, rhs)
    def is_null(self, is_null=True):
        if is_null:
            return Expression(self, OP.IS, None)
        return Expression(self, OP.IS_NOT, None)
    def contains(self, rhs):
        return Expression(self, OP.ILIKE, '%%%s%%' % rhs)
    def startswith(self, rhs):
        return Expression(self, OP.ILIKE, '%s%%' % rhs)
    def endswith(self, rhs):
        return Expression(self, OP.ILIKE, '%%%s' % rhs)
    def between(self, low, high):
        return Expression(self, OP.BETWEEN, Clause(low, R('AND'), high))
    def regexp(self, expression):
        return Expression(self, OP.REGEXP, expression)
    def concat(self, rhs):
        return Expression(self, OP.CONCAT, rhs)

class SQL(Node):
    """An unescaped SQL string, with optional parameters."""
    _node_type = 'sql'

    def __init__(self, value, *params):
        self.value = value
        self.params = params
        super(SQL, self).__init__()

    def clone_base(self):
        return SQL(self.value, *self.params)
R = SQL  # backwards-compat.
class Entity(Node): """A quoted-name or entity, e.g. "table"."column".""" _node_type = 'entity' def __init__(self, *path): super(Entity, self).__init__() self.path = path def clone_base(self): return Entity(*self.path) def __getattr__(self, attr): return Entity(*filter(None, self.path + (attr,))) class Func(Node): """An arbitrary SQL function call.""" _node_type = 'func' _no_coerce = set(('count', 'sum')) def __init__(self, name, *arguments): self.name = name self.arguments = arguments self._coerce = (name.lower() not in self._no_coerce) if name else False super(Func, self).__init__() @returns_clone def coerce(self, coerce=True): self._coerce = coerce def clone_base(self): res = Func(self.name, *self.arguments) res._coerce = self._coerce return res def over(self, partition_by=None, order_by=None, window=None): if isinstance(partition_by, Window) and window is None: window = partition_by if window is None: sql = Window( partition_by=partition_by, order_by=order_by).__sql__() else: sql = SQL(window._alias) return Clause(self, SQL('OVER'), sql) def __getattr__(self, attr): def dec(*args, **kwargs): return Func(attr, *args, **kwargs) return dec # fn is a factory for creating `Func` objects and supports a more friendly # API. So instead of `Func("LOWER", param)`, `fn.LOWER(param)`. fn = Func(None) class Expression(Node): """A binary expression, e.g `foo + 1` or `bar < 7`.""" _node_type = 'expression' def __init__(self, lhs, op, rhs, flat=False): super(Expression, self).__init__() self.lhs = lhs self.op = op self.rhs = rhs self.flat = flat def clone_base(self): return Expression(self.lhs, self.op, self.rhs, self.flat) class Param(Node): """ Arbitrary parameter passed into a query. Instructs the query compiler to specifically treat this value as a parameter, useful for `list` which is special-cased for `IN` lookups. 
""" _node_type = 'param' def __init__(self, value, adapt=None): self.value = value self.adapt = adapt super(Param, self).__init__() def clone_base(self): return Param(self.value, self.adapt) class Passthrough(Param): _node_type = 'passthrough' class Clause(Node): """A SQL clause, one or more Node objects joined by spaces.""" _node_type = 'clause' glue = ' ' parens = False def __init__(self, *nodes, **kwargs): if 'glue' in kwargs: self.glue = kwargs['glue'] if 'parens' in kwargs: self.parens = kwargs['parens'] super(Clause, self).__init__() self.nodes = list(nodes) def clone_base(self): clone = Clause(*self.nodes) clone.glue = self.glue clone.parens = self.parens return clone class CommaClause(Clause): """One or more Node objects joined by commas, no parens.""" glue = ', ' class EnclosedClause(CommaClause): """One or more Node objects joined by commas and enclosed in parens.""" parens = True class Window(Node): def __init__(self, partition_by=None, order_by=None): super(Window, self).__init__() self.partition_by = partition_by self.order_by = order_by self._alias = self._alias or 'w' def __sql__(self): over_clauses = [] if self.partition_by: over_clauses.append(Clause( SQL('PARTITION BY'), CommaClause(*self.partition_by))) if self.order_by: over_clauses.append(Clause( SQL('ORDER BY'), CommaClause(*self.order_by))) return EnclosedClause(Clause(*over_clauses)) def clone_base(self): return Window(self.partition_by, self.order_by) def Check(value): return SQL('CHECK (%s)' % value) class DQ(Node): """A "django-style" filter expression, e.g. {'foo__eq': 'x'}.""" def __init__(self, **query): super(DQ, self).__init__() self.query = query def clone_base(self): return DQ(**self.query) class _StripParens(Node): _node_type = 'strip_parens' def __init__(self, node): super(_StripParens, self).__init__() self.node = node JoinMetadata = namedtuple('JoinMetadata', ( 'src_model', # Source Model class. 'dest_model', # Dest Model class. 
'src', # Source, may be Model, ModelAlias 'dest', # Dest, may be Model, ModelAlias, or SelectQuery. 'attr', # Attribute name joined instance(s) should be assigned to. 'primary_key', # Primary key being joined on. 'foreign_key', # Foreign key being joined from. 'is_backref', # Is this a backref, i.e. 1 -> N. 'alias', # Explicit alias given to join expression. 'is_self_join', # Is this a self-join? 'is_expression', # Is the join ON clause an Expression? )) class Join(namedtuple('_Join', ('src', 'dest', 'join_type', 'on'))): def get_foreign_key(self, source, dest, field=None): if isinstance(source, SelectQuery) or isinstance(dest, SelectQuery): return None, None fk_field = source._meta.rel_for_model(dest, field) if fk_field is not None: return fk_field, False reverse_rel = source._meta.reverse_rel_for_model(dest, field) if reverse_rel is not None: return reverse_rel, True return None, None def get_join_type(self): return self.join_type or JOIN.INNER def model_from_alias(self, model_or_alias): if isinstance(model_or_alias, ModelAlias): return model_or_alias.model_class elif isinstance(model_or_alias, SelectQuery): return model_or_alias.model_class return model_or_alias def _join_metadata(self): # Get the actual tables being joined. 
src = self.model_from_alias(self.src) dest = self.model_from_alias(self.dest) join_alias = isinstance(self.on, Node) and self.on._alias or None is_expression = isinstance(self.on, (Expression, Func, SQL)) on_field = isinstance(self.on, (Field, FieldProxy)) and self.on or None if on_field: fk_field = on_field is_backref = on_field.name not in src._meta.fields else: fk_field, is_backref = self.get_foreign_key(src, dest, self.on) if fk_field is None and self.on is not None: fk_field, is_backref = self.get_foreign_key(src, dest) if fk_field is not None: primary_key = fk_field.to_field else: primary_key = None if not join_alias: if fk_field is not None: if is_backref: target_attr = dest._meta.db_table else: target_attr = fk_field.name else: try: target_attr = self.on.lhs.name except AttributeError: target_attr = dest._meta.db_table else: target_attr = None return JoinMetadata( src_model=src, dest_model=dest, src=self.src, dest=self.dest, attr=join_alias or target_attr, primary_key=primary_key, foreign_key=fk_field, is_backref=is_backref, alias=join_alias, is_self_join=src is dest, is_expression=is_expression) @property def metadata(self): if not hasattr(self, '_cached_metadata'): self._cached_metadata = self._join_metadata() return self._cached_metadata class FieldDescriptor(object): # Fields are exposed as descriptors in order to control access to the # underlying "raw" data. 
def __init__(self, field): self.field = field self.att_name = self.field.name def __get__(self, instance, instance_type=None): if instance is not None: return instance._data.get(self.att_name) return self.field def __set__(self, instance, value): instance._data[self.att_name] = value instance._dirty.add(self.att_name) class Field(Node): """A column on a table.""" _field_counter = 0 _order = 0 _node_type = 'field' db_field = 'unknown' def __init__(self, null=False, index=False, unique=False, verbose_name=None, help_text=None, db_column=None, default=None, choices=None, primary_key=False, sequence=None, constraints=None, schema=None): self.null = null self.index = index self.unique = unique self.verbose_name = verbose_name self.help_text = help_text self.db_column = db_column self.default = default self.choices = choices # Used for metadata purposes, not enforced. self.primary_key = primary_key self.sequence = sequence # Name of sequence, e.g. foo_id_seq. self.constraints = constraints # List of column constraints. self.schema = schema # Name of schema, e.g. 'public'. # Used internally for recovering the order in which Fields were defined # on the Model class. Field._field_counter += 1 self._order = Field._field_counter self._sort_key = (self.primary_key and 1 or 2), self._order self._is_bound = False # Whether the Field is "bound" to a Model. super(Field, self).__init__() def clone_base(self, **kwargs): inst = type(self)( null=self.null, index=self.index, unique=self.unique, verbose_name=self.verbose_name, help_text=self.help_text, db_column=self.db_column, default=self.default, choices=self.choices, primary_key=self.primary_key, sequence=self.sequence, constraints=self.constraints, schema=self.schema, **kwargs) if self._is_bound: inst.name = self.name inst.model_class = self.model_class inst._is_bound = self._is_bound return inst def add_to_class(self, model_class, name): """ Hook that replaces the `Field` attribute on a class with a named `FieldDescriptor`. 
Called by the metaclass during construction of the `Model`. """ self.name = name self.model_class = model_class self.db_column = self.db_column or self.name if not self.verbose_name: self.verbose_name = re.sub('_+', ' ', name).title() model_class._meta.add_field(self) setattr(model_class, name, FieldDescriptor(self)) self._is_bound = True def get_database(self): return self.model_class._meta.database def get_column_type(self): field_type = self.get_db_field() return self.get_database().compiler().get_column_type(field_type) def get_db_field(self): return self.db_field def get_modifiers(self): return None def coerce(self, value): return value def db_value(self, value): """Convert the python value for storage in the database.""" return value if value is None else self.coerce(value) def python_value(self, value): """Convert the database value to a pythonic value.""" return value if value is None else self.coerce(value) def as_entity(self, with_table=False): if with_table: return Entity(self.model_class._meta.db_table, self.db_column) return Entity(self.db_column) def __ddl_column__(self, column_type): """Return the column type, e.g. VARCHAR(255) or REAL.""" modifiers = self.get_modifiers() if modifiers: return SQL( '%s(%s)' % (column_type, ', '.join(map(str, modifiers)))) return SQL(column_type) def __ddl__(self, column_type): """Return a list of Node instances that defines the column.""" ddl = [self.as_entity(), self.__ddl_column__(column_type)] if not self.null: ddl.append(SQL('NOT NULL')) if self.primary_key: ddl.append(SQL('PRIMARY KEY')) if self.sequence: ddl.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) if self.constraints: ddl.extend(self.constraints) return ddl def __hash__(self): return hash(self.name + '.' 
+ self.model_class.__name__) class BareField(Field): db_field = 'bare' def __init__(self, coerce=None, *args, **kwargs): super(BareField, self).__init__(*args, **kwargs) if coerce is not None: self.coerce = coerce def clone_base(self, **kwargs): return super(BareField, self).clone_base(coerce=self.coerce, **kwargs) class IntegerField(Field): db_field = 'int' coerce = int class BigIntegerField(IntegerField): db_field = 'bigint' class SmallIntegerField(IntegerField): db_field = 'smallint' class PrimaryKeyField(IntegerField): db_field = 'primary_key' def __init__(self, *args, **kwargs): kwargs['primary_key'] = True super(PrimaryKeyField, self).__init__(*args, **kwargs) class _AutoPrimaryKeyField(PrimaryKeyField): _column_name = None def add_to_class(self, model_class, name): if name != self._column_name: raise ValueError('%s must be named `%s`.' % (type(self), name)) super(_AutoPrimaryKeyField, self).add_to_class(model_class, name) class FloatField(Field): db_field = 'float' coerce = float class DoubleField(FloatField): db_field = 'double' class DecimalField(Field): db_field = 'decimal' def __init__(self, max_digits=10, decimal_places=5, auto_round=False, rounding=None, *args, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places self.auto_round = auto_round self.rounding = rounding or decimal.DefaultContext.rounding super(DecimalField, self).__init__(*args, **kwargs) def clone_base(self, **kwargs): return super(DecimalField, self).clone_base( max_digits=self.max_digits, decimal_places=self.decimal_places, auto_round=self.auto_round, rounding=self.rounding, **kwargs) def get_modifiers(self): return [self.max_digits, self.decimal_places] def db_value(self, value): D = decimal.Decimal if not value: return value if value is None else D(0) if self.auto_round: exp = D(10) ** (-self.decimal_places) rounding = self.rounding return D(str(value)).quantize(exp, rounding=rounding) return value def python_value(self, value): if value is not None: if 
isinstance(value, decimal.Decimal): return value return decimal.Decimal(str(value)) def coerce_to_unicode(s, encoding='utf-8'): if isinstance(s, unicode_type): return s elif isinstance(s, string_type): return s.decode(encoding) return unicode_type(s) class CharField(Field): db_field = 'string' def __init__(self, max_length=255, *args, **kwargs): self.max_length = max_length super(CharField, self).__init__(*args, **kwargs) def clone_base(self, **kwargs): return super(CharField, self).clone_base( max_length=self.max_length, **kwargs) def get_modifiers(self): return self.max_length and [self.max_length] or None def coerce(self, value): return coerce_to_unicode(value or '') class FixedCharField(CharField): db_field = 'fixed_char' def python_value(self, value): value = super(FixedCharField, self).python_value(value) if value: value = value.strip() return value class TextField(Field): db_field = 'text' def coerce(self, value): return coerce_to_unicode(value or '') class BlobField(Field): db_field = 'blob' _constructor = binary_construct def add_to_class(self, model_class, name): if isinstance(model_class._meta.database, Proxy): model_class._meta.database.attach_callback(self._set_constructor) return super(BlobField, self).add_to_class(model_class, name) def _set_constructor(self, database): self._constructor = database.get_binary_type() def db_value(self, value): if isinstance(value, unicode_type): value = value.encode('raw_unicode_escape') if isinstance(value, basestring): return self._constructor(value) return value class UUIDField(Field): db_field = 'uuid' def db_value(self, value): if isinstance(value, uuid.UUID): return value.hex try: return uuid.UUID(value).hex except: return value def python_value(self, value): if isinstance(value, uuid.UUID): return value return None if value is None else uuid.UUID(value) def _date_part(date_part): def dec(self): return self.model_class._meta.database.extract_date(date_part, self) return dec class _BaseFormattedField(Field): 
formats = None def __init__(self, formats=None, *args, **kwargs): if formats is not None: self.formats = formats super(_BaseFormattedField, self).__init__(*args, **kwargs) def clone_base(self, **kwargs): return super(_BaseFormattedField, self).clone_base( formats=self.formats, **kwargs) class DateTimeField(_BaseFormattedField): db_field = 'datetime' formats = [ '%Y-%m-%d %H:%M:%S.%f', '%Y-%m-%d %H:%M:%S', '%Y-%m-%d', ] def python_value(self, value): if value and isinstance(value, basestring): return format_date_time(value, self.formats) return value year = property(_date_part('year')) month = property(_date_part('month')) day = property(_date_part('day')) hour = property(_date_part('hour')) minute = property(_date_part('minute')) second = property(_date_part('second')) class DateField(_BaseFormattedField): db_field = 'date' formats = [ '%Y-%m-%d', '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M:%S.%f', ] def python_value(self, value): if value and isinstance(value, basestring): pp = lambda x: x.date() return format_date_time(value, self.formats, pp) elif value and isinstance(value, datetime.datetime): return value.date() return value year = property(_date_part('year')) month = property(_date_part('month')) day = property(_date_part('day')) class TimeField(_BaseFormattedField): db_field = 'time' formats = [ '%H:%M:%S.%f', '%H:%M:%S', '%H:%M', '%Y-%m-%d %H:%M:%S.%f', '%Y-%m-%d %H:%M:%S', ] def python_value(self, value): if value: if isinstance(value, basestring): pp = lambda x: x.time() return format_date_time(value, self.formats, pp) elif isinstance(value, datetime.datetime): return value.time() if value is not None and isinstance(value, datetime.timedelta): return (datetime.datetime.min + value).time() return value hour = property(_date_part('hour')) minute = property(_date_part('minute')) second = property(_date_part('second')) class TimestampField(IntegerField): # Support second -> microsecond resolution. 
valid_resolutions = [10**i for i in range(7)] def __init__(self, *args, **kwargs): self.resolution = kwargs.pop('resolution', 1) or 1 if self.resolution not in self.valid_resolutions: raise ValueError('TimestampField resolution must be one of: %s' % ', '.join(str(i) for i in self.valid_resolutions)) self.utc = kwargs.pop('utc', False) or False _dt = datetime.datetime self._conv = _dt.utcfromtimestamp if self.utc else _dt.fromtimestamp _default = _dt.utcnow if self.utc else _dt.now kwargs.setdefault('default', _default) super(TimestampField, self).__init__(*args, **kwargs) def get_db_field(self): # For second resolution we can get away (for a while) with using # 4 bytes to store the timestamp (as long as they're not > ~2038). # Otherwise we'll need to use a BigInteger type. return (self.db_field if self.resolution == 1 else BigIntegerField.db_field) def db_value(self, value): if value is None: return if isinstance(value, datetime.datetime): pass elif isinstance(value, datetime.date): value = datetime.datetime(value.year, value.month, value.day) else: return int(round(value * self.resolution)) if self.utc: timestamp = calendar.timegm(value.utctimetuple()) else: timestamp = time.mktime(value.timetuple()) timestamp += (value.microsecond * .000001) if self.resolution > 1: timestamp *= self.resolution return int(round(timestamp)) def python_value(self, value): if value is not None and isinstance(value, (int, float, long)): if value == 0: return elif self.resolution > 1: ticks_to_microsecond = 1000000 // self.resolution value, ticks = divmod(value, self.resolution) microseconds = ticks * ticks_to_microsecond return self._conv(value).replace(microsecond=microseconds) else: return self._conv(value) return value class BooleanField(Field): db_field = 'bool' coerce = bool class RelationDescriptor(FieldDescriptor): """Foreign-key abstraction to replace a related PK with a related model.""" def __init__(self, field, rel_model): self.rel_model = rel_model 
super(RelationDescriptor, self).__init__(field) def get_object_or_id(self, instance): rel_id = instance._data.get(self.att_name) if rel_id is not None or self.att_name in instance._obj_cache: if self.att_name not in instance._obj_cache: obj = self.rel_model.get(self.field.to_field == rel_id) instance._obj_cache[self.att_name] = obj return instance._obj_cache[self.att_name] elif not self.field.null: raise self.rel_model.DoesNotExist return rel_id def __get__(self, instance, instance_type=None): if instance is not None: return self.get_object_or_id(instance) return self.field def __set__(self, instance, value): if isinstance(value, self.rel_model): instance._data[self.att_name] = getattr( value, self.field.to_field.name) instance._obj_cache[self.att_name] = value else: orig_value = instance._data.get(self.att_name) instance._data[self.att_name] = value if orig_value != value and self.att_name in instance._obj_cache: del instance._obj_cache[self.att_name] instance._dirty.add(self.att_name) class ReverseRelationDescriptor(object): """Back-reference to expose related objects as a `SelectQuery`.""" def __init__(self, field): self.field = field self.rel_model = field.model_class def __get__(self, instance, instance_type=None): if instance is not None: return self.rel_model.select().where( self.field == getattr(instance, self.field.to_field.name)) return self class ObjectIdDescriptor(object): """Gives direct access to the underlying id""" def __init__(self, field): self.attr_name = field.name self.field = weakref.ref(field) def __get__(self, instance, instance_type=None): if instance is not None: return instance._data.get(self.attr_name) return self.field() def __set__(self, instance, value): setattr(instance, self.attr_name, value) class ForeignKeyField(IntegerField): def __init__(self, rel_model, related_name=None, on_delete=None, on_update=None, extra=None, to_field=None, *args, **kwargs): if rel_model != 'self' and not \ isinstance(rel_model, (Proxy, DeferredRelation)) 
and not \ issubclass(rel_model, Model): raise TypeError('Unexpected value for `rel_model`. Expected ' '`Model`, `Proxy`, `DeferredRelation`, or "self"') self.rel_model = rel_model self._related_name = related_name self.deferred = isinstance(rel_model, (Proxy, DeferredRelation)) self.on_delete = on_delete self.on_update = on_update self.extra = extra self.to_field = to_field super(ForeignKeyField, self).__init__(*args, **kwargs) def clone_base(self, **kwargs): return super(ForeignKeyField, self).clone_base( rel_model=self.rel_model, related_name=self._get_related_name(), on_delete=self.on_delete, on_update=self.on_update, extra=self.extra, to_field=self.to_field, **kwargs) def _get_descriptor(self): return RelationDescriptor(self, self.rel_model) def _get_id_descriptor(self): return ObjectIdDescriptor(self) def _get_backref_descriptor(self): return ReverseRelationDescriptor(self) def _get_related_name(self): if self._related_name and callable(self._related_name): return self._related_name(self) return self._related_name or ('%s_set' % self.model_class._meta.name) def add_to_class(self, model_class, name): if isinstance(self.rel_model, Proxy): def callback(rel_model): self.rel_model = rel_model self.add_to_class(model_class, name) self.rel_model.attach_callback(callback) return elif isinstance(self.rel_model, DeferredRelation): self.rel_model.set_field(model_class, self, name) return self.name = name self.model_class = model_class self.db_column = obj_id_name = self.db_column or '%s_id' % self.name if obj_id_name == self.name: obj_id_name += '_id' if not self.verbose_name: self.verbose_name = re.sub('_+', ' ', name).title() model_class._meta.add_field(self) self.related_name = self._get_related_name() if self.rel_model == 'self': self.rel_model = self.model_class if self.to_field is not None: if not isinstance(self.to_field, Field): self.to_field = getattr(self.rel_model, self.to_field) else: self.to_field = self.rel_model._meta.primary_key # TODO: factor into 
separate method. if model_class._meta.validate_backrefs: def invalid(msg, **context): context.update( field='%s.%s' % (model_class._meta.name, name), backref=self.related_name, obj_id_name=obj_id_name) raise AttributeError(msg % context) if self.related_name in self.rel_model._meta.fields: invalid('The related_name of %(field)s ("%(backref)s") ' 'conflicts with a field of the same name.') elif self.related_name in self.rel_model._meta.reverse_rel: invalid('The related_name of %(field)s ("%(backref)s") ' 'is already in use by another foreign key.') if obj_id_name in model_class._meta.fields: invalid('The object id descriptor of %(field)s conflicts ' 'with a field named %(obj_id_name)s') elif obj_id_name in model_class.__dict__: invalid('Model attribute "%(obj_id_name)s" would be shadowed ' 'by the object id descriptor of %(field)s.') setattr(model_class, name, self._get_descriptor()) setattr(model_class, obj_id_name, self._get_id_descriptor()) setattr(self.rel_model, self.related_name, self._get_backref_descriptor()) self._is_bound = True model_class._meta.rel[self.name] = self self.rel_model._meta.reverse_rel[self.related_name] = self def get_db_field(self): """ Overridden to ensure Foreign Keys use same column type as the primary key they point to. 
""" if not isinstance(self.to_field, PrimaryKeyField): return self.to_field.get_db_field() return super(ForeignKeyField, self).get_db_field() def get_modifiers(self): if not isinstance(self.to_field, PrimaryKeyField): return self.to_field.get_modifiers() return super(ForeignKeyField, self).get_modifiers() def coerce(self, value): return self.to_field.coerce(value) def db_value(self, value): if isinstance(value, self.rel_model): value = value._get_pk_value() return self.to_field.db_value(value) def python_value(self, value): if isinstance(value, self.rel_model): return value return self.to_field.python_value(value) class CompositeKey(object): """A primary key composed of multiple columns.""" sequence = None def __init__(self, *field_names): self.field_names = field_names def add_to_class(self, model_class, name): self.name = name self.model_class = model_class setattr(model_class, name, self) def __get__(self, instance, instance_type=None): if instance is not None: return tuple([getattr(instance, field_name) for field_name in self.field_names]) return self def __set__(self, instance, value): pass def __eq__(self, other): expressions = [(self.model_class._meta.fields[field] == value) for field, value in zip(self.field_names, other)] return reduce(operator.and_, expressions) def __hash__(self): return hash((self.model_class.__name__, self.field_names)) class AliasMap(object): prefix = 't' def __init__(self, start=0): self._alias_map = {} self._counter = start def __repr__(self): return '<AliasMap: %s>' % self._alias_map def add(self, obj, alias=None): if obj in self._alias_map: return self._counter += 1 self._alias_map[obj] = alias or '%s%s' % (self.prefix, self._counter) def __getitem__(self, obj): if obj not in self._alias_map: self.add(obj) return self._alias_map[obj] def __contains__(self, obj): return obj in self._alias_map def update(self, alias_map): if alias_map: for obj, alias in alias_map._alias_map.items(): if obj not in self: self._alias_map[obj] = alias 
return self class QueryCompiler(object): # Mapping of `db_type` to actual column type used by database driver. # Database classes may provide additional column types or overrides. field_map = { 'bare': '', 'bigint': 'BIGINT', 'blob': 'BLOB', 'bool': 'SMALLINT', 'date': 'DATE', 'datetime': 'DATETIME', 'decimal': 'DECIMAL', 'double': 'REAL', 'fixed_char': 'CHAR', 'float': 'REAL', 'int': 'INTEGER', 'primary_key': 'INTEGER', 'smallint': 'SMALLINT', 'string': 'VARCHAR', 'text': 'TEXT', 'time': 'TIME', } # Mapping of OP. to actual SQL operation. For most databases this will be # the same, but some column types or databases may support additional ops. # Like `field_map`, Database classes may extend or override these. op_map = { OP.EQ: '=', OP.LT: '<', OP.LTE: '<=', OP.GT: '>', OP.GTE: '>=', OP.NE: '!=', OP.IN: 'IN', OP.NOT_IN: 'NOT IN', OP.IS: 'IS', OP.IS_NOT: 'IS NOT', OP.BIN_AND: '&', OP.BIN_OR: '|', OP.LIKE: 'LIKE', OP.ILIKE: 'ILIKE', OP.BETWEEN: 'BETWEEN', OP.ADD: '+', OP.SUB: '-', OP.MUL: '*', OP.DIV: '/', OP.XOR: '#', OP.AND: 'AND', OP.OR: 'OR', OP.MOD: '%', OP.REGEXP: 'REGEXP', OP.CONCAT: '||', } join_map = { JOIN.INNER: 'INNER JOIN', JOIN.LEFT_OUTER: 'LEFT OUTER JOIN', JOIN.RIGHT_OUTER: 'RIGHT OUTER JOIN', JOIN.FULL: 'FULL JOIN', } alias_map_class = AliasMap def __init__(self, quote_char='"', interpolation='?', field_overrides=None, op_overrides=None): self.quote_char = quote_char self.interpolation = interpolation self._field_map = merge_dict(self.field_map, field_overrides or {}) self._op_map = merge_dict(self.op_map, op_overrides or {}) self._parse_map = self.get_parse_map() self._unknown_types = set(['param']) def get_parse_map(self): # To avoid O(n) lookups when parsing nodes, use a lookup table for # common node types O(1). 
return { 'expression': self._parse_expression, 'param': self._parse_param, 'passthrough': self._parse_passthrough, 'func': self._parse_func, 'clause': self._parse_clause, 'entity': self._parse_entity, 'field': self._parse_field, 'sql': self._parse_sql, 'select_query': self._parse_select_query, 'compound_select_query': self._parse_compound_select_query, 'strip_parens': self._parse_strip_parens, } def quote(self, s): return '%s%s%s' % (self.quote_char, s, self.quote_char) def get_column_type(self, f): return self._field_map[f] if f in self._field_map else f.upper() def get_op(self, q): return self._op_map[q] def _sorted_fields(self, field_dict): return sorted(field_dict.items(), key=lambda i: i[0]._sort_key) def _parse_default(self, node, alias_map, conv): return self.interpolation, [node] def _parse_expression(self, node, alias_map, conv): if isinstance(node.lhs, Field): conv = node.lhs lhs, lparams = self.parse_node(node.lhs, alias_map, conv) rhs, rparams = self.parse_node(node.rhs, alias_map, conv) if node.op == OP.IN and rhs == '()' and not rparams: return ('0 = 1' if node.flat else '(0 = 1)'), [] template = '%s %s %s' if node.flat else '(%s %s %s)' sql = template % (lhs, self.get_op(node.op), rhs) return sql, lparams + rparams def _parse_passthrough(self, node, alias_map, conv): if node.adapt: return self.parse_node(node.adapt(node.value), alias_map, None) return self.interpolation, [node.value] def _parse_param(self, node, alias_map, conv): if node.adapt: if conv and conv.db_value is node.adapt: conv = None return self.parse_node(node.adapt(node.value), alias_map, conv) elif conv is not None: return self.parse_node(conv.db_value(node.value), alias_map) else: return self.interpolation, [node.value] def _parse_func(self, node, alias_map, conv): conv = node._coerce and conv or None sql, params = self.parse_node_list(node.arguments, alias_map, conv) return '%s(%s)' % (node.name, strip_parens(sql)), params def _parse_clause(self, node, alias_map, conv): sql, params 
= self.parse_node_list( node.nodes, alias_map, conv, node.glue) if node.parens: sql = '(%s)' % strip_parens(sql) return sql, params def _parse_entity(self, node, alias_map, conv): return '.'.join(map(self.quote, node.path)), [] def _parse_sql(self, node, alias_map, conv): return node.value, list(node.params) def _parse_field(self, node, alias_map, conv): if alias_map: sql = '.'.join(( self.quote(alias_map[node.model_class]), self.quote(node.db_column))) else: sql = self.quote(node.db_column) return sql, [] def _parse_compound_select_query(self, node, alias_map, conv): csq = 'compound_select_query' if node.rhs._node_type == csq and node.lhs._node_type != csq: first_q, second_q = node.rhs, node.lhs inv = True else: first_q, second_q = node.lhs, node.rhs inv = False new_map = self.alias_map_class() if first_q._node_type == csq: new_map._counter = alias_map._counter first, first_p = self.generate_select(first_q, new_map) second, second_p = self.generate_select( second_q, self.calculate_alias_map(second_q, new_map)) if inv: l, lp, r, rp = second, second_p, first, first_p else: l, lp, r, rp = first, first_p , second, second_p # We add outer parentheses in the event the compound query is used in # the `from_()` clause, in which case we'll need them. 
if node.database.compound_select_parentheses: sql = '((%s) %s (%s))' % (l, node.operator, r) else: sql = '(%s %s %s)' % (l, node.operator, r) return sql, lp + rp def _parse_select_query(self, node, alias_map, conv): clone = node.clone() if not node._explicit_selection: if conv and isinstance(conv, ForeignKeyField): select_field = conv.to_field else: select_field = clone.model_class._meta.primary_key clone._select = (select_field,) sub, params = self.generate_select(clone, alias_map) return '(%s)' % strip_parens(sub), params def _parse_strip_parens(self, node, alias_map, conv): sql, params = self.parse_node(node.node, alias_map, conv) return strip_parens(sql), params def _parse(self, node, alias_map, conv): # By default treat the incoming node as a raw value that should be # parameterized. node_type = getattr(node, '_node_type', None) unknown = False if node_type in self._parse_map: sql, params = self._parse_map[node_type](node, alias_map, conv) unknown = (node_type in self._unknown_types and node.adapt is None and conv is None) elif isinstance(node, (list, tuple, set)): # If you're wondering how to pass a list into your query, simply # wrap it in Param(). 
sql, params = self.parse_node_list(node, alias_map, conv) sql = '(%s)' % sql elif isinstance(node, Model): sql = self.interpolation if conv and isinstance(conv, ForeignKeyField): to_field = conv.to_field if isinstance(to_field, ForeignKeyField): value = conv.db_value(node) else: value = to_field.db_value(getattr(node, to_field.name)) else: value = node._get_pk_value() params = [value] elif (isclass(node) and issubclass(node, Model)) or \ isinstance(node, ModelAlias): entity = node.as_entity().alias(alias_map[node]) sql, params = self.parse_node(entity, alias_map, conv) elif conv is not None: value = conv.db_value(node) sql, params, _ = self._parse(value, alias_map, None) else: sql, params = self._parse_default(node, alias_map, None) unknown = True return sql, params, unknown def parse_node(self, node, alias_map=None, conv=None): sql, params, unknown = self._parse(node, alias_map, conv) if unknown and (conv is not None) and params: params = [conv.db_value(i) for i in params] if isinstance(node, Node): if node._negated: sql = 'NOT %s' % sql if node._alias: sql = ' '.join((sql, 'AS', node._alias)) if node._ordering: sql = ' '.join((sql, node._ordering)) if params and any(isinstance(p, Node) for p in params): clean_params = [] clean_sql = [] for idx, param in enumerate(params): if isinstance(param, Node): csql, cparams = self.parse_node(param) return sql, params def parse_node_list(self, nodes, alias_map, conv=None, glue=', '): sql = [] params = [] for node in nodes: node_sql, node_params = self.parse_node(node, alias_map, conv) sql.append(node_sql) params.extend(node_params) return glue.join(sql), params def calculate_alias_map(self, query, alias_map=None): new_map = self.alias_map_class() if alias_map is not None: new_map._counter = alias_map._counter new_map.add(query.model_class, query.model_class._meta.table_alias) for src_model, joined_models in query._joins.items(): new_map.add(src_model, src_model._meta.table_alias) for join_obj in joined_models: if 
isinstance(join_obj.dest, Node): new_map.add(join_obj.dest, join_obj.dest.alias) else: new_map.add(join_obj.dest, join_obj.dest._meta.table_alias) return new_map.update(alias_map) def build_query(self, clauses, alias_map=None): return self.parse_node(Clause(*clauses), alias_map) def generate_joins(self, joins, model_class, alias_map): # Joins are implemented as an adjancency-list graph. Perform a # depth-first search of the graph to generate all the necessary JOINs. clauses = [] seen = set() q = [model_class] while q: curr = q.pop() if curr not in joins or curr in seen: continue seen.add(curr) for join in joins[curr]: src = curr dest = join.dest if isinstance(join.on, (Expression, Func, Clause, Entity)): # Clear any alias on the join expression. constraint = join.on.clone().alias() else: metadata = join.metadata if metadata.is_backref: fk_model = join.dest pk_model = join.src else: fk_model = join.src pk_model = join.dest fk = metadata.foreign_key if fk: lhs = getattr(fk_model, fk.name) rhs = getattr(pk_model, fk.to_field.name) if metadata.is_backref: lhs, rhs = rhs, lhs constraint = (lhs == rhs) else: raise ValueError('Missing required join predicate.') if isinstance(dest, Node): # TODO: ensure alias? 
dest_n = dest else: q.append(dest) dest_n = dest.as_entity().alias(alias_map[dest]) join_type = join.get_join_type() if join_type in self.join_map: join_sql = SQL(self.join_map[join_type]) else: join_sql = SQL(join_type) clauses.append( Clause(join_sql, dest_n, SQL('ON'), constraint)) return clauses def generate_select(self, query, alias_map=None): model = query.model_class db = model._meta.database alias_map = self.calculate_alias_map(query, alias_map) if isinstance(query, CompoundSelect): clauses = [_StripParens(query)] else: if not query._distinct: clauses = [SQL('SELECT')] else: clauses = [SQL('SELECT DISTINCT')] if query._distinct not in (True, False): clauses += [SQL('ON'), EnclosedClause(*query._distinct)] select_clause = Clause(*query._select) select_clause.glue = ', ' clauses.extend((select_clause, SQL('FROM'))) if query._from is None: clauses.append(model.as_entity().alias(alias_map[model])) else: clauses.append(CommaClause(*query._from)) if query._windows is not None: clauses.append(SQL('WINDOW')) clauses.append(CommaClause(*[ Clause( SQL(window._alias), SQL('AS'), window.__sql__()) for window in query._windows])) join_clauses = self.generate_joins(query._joins, model, alias_map) if join_clauses: clauses.extend(join_clauses) if query._where is not None: clauses.extend([SQL('WHERE'), query._where]) if query._group_by: clauses.extend([SQL('GROUP BY'), CommaClause(*query._group_by)]) if query._having: clauses.extend([SQL('HAVING'), query._having]) if query._order_by: clauses.extend([SQL('ORDER BY'), CommaClause(*query._order_by)]) if query._limit is not None or (query._offset and db.limit_max): limit = query._limit if query._limit is not None else db.limit_max clauses.append(SQL('LIMIT %s' % limit)) if query._offset is not None: clauses.append(SQL('OFFSET %s' % query._offset)) for_update, no_wait = query._for_update if for_update: stmt = 'FOR UPDATE NOWAIT' if no_wait else 'FOR UPDATE' clauses.append(SQL(stmt)) return self.build_query(clauses, alias_map) 
def generate_update(self, query): model = query.model_class alias_map = self.alias_map_class() alias_map.add(model, model._meta.db_table) if query._on_conflict: statement = 'UPDATE OR %s' % query._on_conflict else: statement = 'UPDATE' clauses = [SQL(statement), model.as_entity(), SQL('SET')] update = [] for field, value in self._sorted_fields(query._update): if not isinstance(value, (Node, Model)): value = Param(value, adapt=field.db_value) update.append(Expression( field.as_entity(with_table=False), OP.EQ, value, flat=True)) # No outer parens, no table alias. clauses.append(CommaClause(*update)) if query._where: clauses.extend([SQL('WHERE'), query._where]) if query._returning is not None: returning_clause = Clause(*query._returning) returning_clause.glue = ', ' clauses.extend([SQL('RETURNING'), returning_clause]) return self.build_query(clauses, alias_map) def _get_field_clause(self, fields, clause_type=EnclosedClause): return clause_type(*[ field.as_entity(with_table=False) for field in fields]) def generate_insert(self, query): model = query.model_class meta = model._meta alias_map = self.alias_map_class() alias_map.add(model, model._meta.db_table) if query._upsert: statement = meta.database.upsert_sql elif query._on_conflict: statement = 'INSERT OR %s INTO' % query._on_conflict else: statement = 'INSERT INTO' clauses = [SQL(statement), model.as_entity()] if query._query is not None: # This INSERT query is of the form INSERT INTO ... SELECT FROM. 
if query._fields: clauses.append(self._get_field_clause(query._fields)) clauses.append(_StripParens(query._query)) elif query._rows is not None: fields, value_clauses = [], [] have_fields = False for row_dict in query._iter_rows(): if not have_fields: fields = sorted( row_dict.keys(), key=operator.attrgetter('_sort_key')) have_fields = True values = [] for field in fields: value = row_dict[field] if not isinstance(value, (Node, Model)): value = Param(value, adapt=field.db_value) values.append(value) value_clauses.append(EnclosedClause(*values)) if fields: clauses.extend([ self._get_field_clause(fields), SQL('VALUES'), CommaClause(*value_clauses)]) elif query.model_class._meta.auto_increment: # Bare insert, use default value for primary key. clauses.append(query.database.default_insert_clause( query.model_class)) if query.is_insert_returning: clauses.extend([ SQL('RETURNING'), self._get_field_clause( meta.get_primary_key_fields(), clause_type=CommaClause)]) elif query._returning is not None: returning_clause = Clause(*query._returning) returning_clause.glue = ', ' clauses.extend([SQL('RETURNING'), returning_clause]) return self.build_query(clauses, alias_map) def generate_delete(self, query): model = query.model_class clauses = [SQL('DELETE FROM'), model.as_entity()] if query._where: clauses.extend([SQL('WHERE'), query._where]) if query._returning is not None: returning_clause = Clause(*query._returning) returning_clause.glue = ', ' clauses.extend([SQL('RETURNING'), returning_clause]) return self.build_query(clauses) def field_definition(self, field): column_type = self.get_column_type(field.get_db_field()) ddl = field.__ddl__(column_type) return Clause(*ddl) def foreign_key_constraint(self, field): ddl = [ SQL('FOREIGN KEY'), EnclosedClause(field.as_entity()), SQL('REFERENCES'), field.rel_model.as_entity(), EnclosedClause(field.to_field.as_entity())] if field.on_delete: ddl.append(SQL('ON DELETE %s' % field.on_delete)) if field.on_update: ddl.append(SQL('ON UPDATE 
%s' % field.on_update)) return Clause(*ddl) def return_parsed_node(function_name): # TODO: treat all `generate_` functions as returning clauses, instead # of SQL/params. def inner(self, *args, **kwargs): fn = getattr(self, function_name) return self.parse_node(fn(*args, **kwargs)) return inner def _create_foreign_key(self, model_class, field, constraint=None): constraint = constraint or 'fk_%s_%s_refs_%s' % ( model_class._meta.db_table, field.db_column, field.rel_model._meta.db_table) fk_clause = self.foreign_key_constraint(field) return Clause( SQL('ALTER TABLE'), model_class.as_entity(), SQL('ADD CONSTRAINT'), Entity(constraint), *fk_clause.nodes) create_foreign_key = return_parsed_node('_create_foreign_key') def _create_table(self, model_class, safe=False): statement = 'CREATE TABLE IF NOT EXISTS' if safe else 'CREATE TABLE' meta = model_class._meta columns, constraints = [], [] if meta.composite_key: pk_cols = [meta.fields[f].as_entity() for f in meta.primary_key.field_names] constraints.append(Clause( SQL('PRIMARY KEY'), EnclosedClause(*pk_cols))) for field in meta.declared_fields: columns.append(self.field_definition(field)) if isinstance(field, ForeignKeyField) and not field.deferred: constraints.append(self.foreign_key_constraint(field)) if model_class._meta.constraints: for constraint in model_class._meta.constraints: if not isinstance(constraint, Node): constraint = SQL(constraint) constraints.append(constraint) return Clause( SQL(statement), model_class.as_entity(), EnclosedClause(*(columns + constraints))) create_table = return_parsed_node('_create_table') def _drop_table(self, model_class, fail_silently=False, cascade=False): statement = 'DROP TABLE IF EXISTS' if fail_silently else 'DROP TABLE' ddl = [SQL(statement), model_class.as_entity()] if cascade: ddl.append(SQL('CASCADE')) return Clause(*ddl) drop_table = return_parsed_node('_drop_table') def _truncate_table(self, model_class, restart_identity=False, cascade=False): ddl = [SQL('TRUNCATE TABLE'), 
model_class.as_entity()] if restart_identity: ddl.append(SQL('RESTART IDENTITY')) if cascade: ddl.append(SQL('CASCADE')) return Clause(*ddl) truncate_table = return_parsed_node('_truncate_table') def index_name(self, table, columns): index = '%s_%s' % (table, '_'.join(columns)) if len(index) > 64: index_hash = hashlib.md5(index.encode('utf-8')).hexdigest() index = '%s_%s' % (table[:55], index_hash[:8]) # 55 + 1 + 8 = 64 return index def _create_index(self, model_class, fields, unique, *extra): tbl_name = model_class._meta.db_table statement = 'CREATE UNIQUE INDEX' if unique else 'CREATE INDEX' index_name = self.index_name(tbl_name, [f.db_column for f in fields]) return Clause( SQL(statement), Entity(index_name), SQL('ON'), model_class.as_entity(), EnclosedClause(*[field.as_entity() for field in fields]), *extra) create_index = return_parsed_node('_create_index') def _drop_index(self, model_class, fields, fail_silently=False): tbl_name = model_class._meta.db_table statement = 'DROP INDEX IF EXISTS' if fail_silently else 'DROP INDEX' index_name = self.index_name(tbl_name, [f.db_column for f in fields]) return Clause(SQL(statement), Entity(index_name)) drop_index = return_parsed_node('_drop_index') def _create_sequence(self, sequence_name): return Clause(SQL('CREATE SEQUENCE'), Entity(sequence_name)) create_sequence = return_parsed_node('_create_sequence') def _drop_sequence(self, sequence_name): return Clause(SQL('DROP SEQUENCE'), Entity(sequence_name)) drop_sequence = return_parsed_node('_drop_sequence') class SqliteQueryCompiler(QueryCompiler): def truncate_table(self, model_class, restart_identity=False, cascade=False): return model_class.delete().sql() class ResultIterator(object): def __init__(self, qrw): self.qrw = qrw self._idx = 0 def next(self): if self._idx < self.qrw._ct: obj = self.qrw._result_cache[self._idx] elif not self.qrw._populated: obj = self.qrw.iterate() self.qrw._result_cache.append(obj) self.qrw._ct += 1 else: raise StopIteration self._idx += 
1 return obj __next__ = next class QueryResultWrapper(object): """ Provides an iterator over the results of a raw Query, additionally doing two things: - converts rows from the database into python representations - ensures that multiple iterations do not result in multiple queries """ def __init__(self, model, cursor, meta=None): self.model = model self.cursor = cursor self._ct = 0 self._idx = 0 self._result_cache = [] self._populated = False self._initialized = False if meta is not None: self.column_meta, self.join_meta = meta else: self.column_meta = self.join_meta = None def __iter__(self): if self._populated: return iter(self._result_cache) else: return ResultIterator(self) @property def count(self): self.fill_cache() return self._ct def __len__(self): return self.count def process_row(self, row): return row def iterate(self): row = self.cursor.fetchone() if not row: self._populated = True if not getattr(self.cursor, 'name', None): self.cursor.close() raise StopIteration elif not self._initialized: self.initialize(self.cursor.description) self._initialized = True return self.process_row(row) def iterator(self): while True: yield self.iterate() def next(self): if self._idx < self._ct: inst = self._result_cache[self._idx] self._idx += 1 return inst elif self._populated: raise StopIteration obj = self.iterate() self._result_cache.append(obj) self._ct += 1 self._idx += 1 return obj __next__ = next def fill_cache(self, n=None): n = n or float('Inf') if n < 0: raise ValueError('Negative values are not supported.') self._idx = self._ct while not self._populated and (n > self._ct): try: next(self) except StopIteration: break class ExtQueryResultWrapper(QueryResultWrapper): def initialize(self, description): n_cols = len(description) self.conv = conv = [] if self.column_meta is not None: n_meta = len(self.column_meta) for i, node in enumerate(self.column_meta): if not self._initialize_node(node, i): self._initialize_by_name(description[i][0], i) if n_cols == n_meta: 
return else: i = 0 for i in range(i, n_cols): self._initialize_by_name(description[i][0], i) def _initialize_by_name(self, name, i): model_cols = self.model._meta.columns if name in model_cols: field = model_cols[name] self.conv.append((i, field.name, field.python_value)) else: self.conv.append((i, name, None)) def _initialize_node(self, node, i): if isinstance(node, Field): self.conv.append((i, node._alias or node.name, node.python_value)) return True elif isinstance(node, Func) and len(node.arguments): arg = node.arguments[0] if isinstance(arg, Field): name = node._alias or arg._alias or arg.name func = node._coerce and arg.python_value or None self.conv.append((i, name, func)) return True return False class TuplesQueryResultWrapper(ExtQueryResultWrapper): def process_row(self, row): return tuple([col if self.conv[i][2] is None else self.conv[i][2](col) for i, col in enumerate(row)]) if _TuplesQueryResultWrapper is None: _TuplesQueryResultWrapper = TuplesQueryResultWrapper class NaiveQueryResultWrapper(ExtQueryResultWrapper): def process_row(self, row): instance = self.model() for i, column, f in self.conv: setattr(instance, column, f(row[i]) if f is not None else row[i]) instance._prepare_instance() return instance if _ModelQueryResultWrapper is None: _ModelQueryResultWrapper = NaiveQueryResultWrapper class DictQueryResultWrapper(ExtQueryResultWrapper): def process_row(self, row): res = {} for i, column, f in self.conv: res[column] = f(row[i]) if f is not None else row[i] return res if _DictQueryResultWrapper is None: _DictQueryResultWrapper = DictQueryResultWrapper class ModelQueryResultWrapper(QueryResultWrapper): def initialize(self, description): self.column_map, model_set = self.generate_column_map() self._col_set = set(col for col in self.column_meta if isinstance(col, Field)) self.join_list = self.generate_join_list(model_set) def generate_column_map(self): column_map = [] models = set([self.model]) for i, node in enumerate(self.column_meta): attr = conv 
= None if isinstance(node, Field): if isinstance(node, FieldProxy): key = node._model_alias constructor = node.model conv = node.field_instance.python_value else: key = constructor = node.model_class conv = node.python_value attr = node._alias or node.name else: if node._bind_to is None: key = constructor = self.model else: key = constructor = node._bind_to if isinstance(node, Node) and node._alias: attr = node._alias elif isinstance(node, Entity): attr = node.path[-1] column_map.append((key, constructor, attr, conv)) models.add(key) return column_map, models def generate_join_list(self, models): join_list = [] joins = self.join_meta stack = [self.model] while stack: current = stack.pop() if current not in joins: continue for join in joins[current]: metadata = join.metadata if metadata.dest in models or metadata.dest_model in models: if metadata.foreign_key is not None: fk_present = metadata.foreign_key in self._col_set pk_present = metadata.primary_key in self._col_set check = metadata.foreign_key.null and (fk_present or pk_present) else: check = fk_present = pk_present = False join_list.append(( metadata, check, fk_present, pk_present)) stack.append(join.dest) return join_list def process_row(self, row): collected = self.construct_instances(row) instances = self.follow_joins(collected) for i in instances: i._prepare_instance() return instances[0] def construct_instances(self, row, keys=None): collected_models = {} for i, (key, constructor, attr, conv) in enumerate(self.column_map): if keys is not None and key not in keys: continue value = row[i] if key not in collected_models: collected_models[key] = constructor() instance = collected_models[key] if attr is None: attr = self.cursor.description[i][0] setattr(instance, attr, value if conv is None else conv(value)) return collected_models def follow_joins(self, collected): prepared = [collected[self.model]] for (metadata, check_null, fk_present, pk_present) in self.join_list: inst = collected[metadata.src] try: 
joined_inst = collected[metadata.dest] except KeyError: joined_inst = collected[metadata.dest_model] has_fk = True if check_null: if fk_present: has_fk = inst._data.get(metadata.foreign_key.name) elif pk_present: has_fk = joined_inst._data.get(metadata.primary_key.name) if not has_fk: continue # Can we populate a value on the joined instance using the current? mpk = metadata.primary_key is not None can_populate_joined_pk = ( mpk and (metadata.attr in inst._data) and (getattr(joined_inst, metadata.primary_key.name) is None)) if can_populate_joined_pk: setattr( joined_inst, metadata.primary_key.name, inst._data[metadata.attr]) if metadata.is_backref: can_populate_joined_fk = ( mpk and (metadata.foreign_key is not None) and (getattr(inst, metadata.primary_key.name) is not None) and (joined_inst._data.get(metadata.foreign_key.name) is None)) if can_populate_joined_fk: setattr( joined_inst, metadata.foreign_key.name, inst) setattr(inst, metadata.attr, joined_inst) prepared.append(joined_inst) return prepared JoinCache = namedtuple('JoinCache', ('metadata', 'attr')) class AggregateQueryResultWrapper(ModelQueryResultWrapper): def __init__(self, *args, **kwargs): self._row = [] super(AggregateQueryResultWrapper, self).__init__(*args, **kwargs) def initialize(self, description): super(AggregateQueryResultWrapper, self).initialize(description) # Collect the set of all models (and ModelAlias objects) queried. self.all_models = set() for key, _, _, _ in self.column_map: self.all_models.add(key) # Prepare data structures for analyzing unique rows. Also cache # foreign key and attribute names for joined models. 
self.models_with_aggregate = set() self.back_references = {} self.source_to_dest = {} self.dest_to_source = {} for (metadata, _, _, _) in self.join_list: if metadata.is_backref: att_name = metadata.foreign_key.related_name else: att_name = metadata.attr is_backref = metadata.is_backref or metadata.is_self_join if is_backref: self.models_with_aggregate.add(metadata.src) else: self.dest_to_source.setdefault(metadata.dest, set()) self.dest_to_source[metadata.dest].add(metadata.src) self.source_to_dest.setdefault(metadata.src, {}) self.source_to_dest[metadata.src][metadata.dest] = JoinCache( metadata=metadata, attr=metadata.alias or att_name) # Determine which columns could contain "duplicate" data, e.g. if # getting Users and their Tweets, this would be the User columns. self.columns_to_compare = {} key_to_columns = {} for idx, (key, model_class, col_name, _) in enumerate(self.column_map): if key in self.models_with_aggregate: self.columns_to_compare.setdefault(key, []) self.columns_to_compare[key].append((idx, col_name)) key_to_columns.setdefault(key, []) key_to_columns[key].append((idx, col_name)) # Also compare columns for joins -> many-related model. 
def read_model_data(self, row):
    """Extract from ``row`` the values of the columns being compared.

    Returns a dict mapping each key of ``self.columns_to_compare`` to the
    list of row values found at that key's column indexes, in order.
    """
    return dict(
        (model_key, [row[idx] for idx, _ in columns])
        for model_key, columns in self.columns_to_compare.items())
all_none = True for value in instance._data.values(): if value is not None: all_none = False if not all_none: identity_map[model_or_alias][_get_pk(instance)] = instance stack = [self.model] instances = [primary_instance] while stack: current = stack.pop() if current not in self.join_meta: continue for join in self.join_meta[current]: try: metadata, attr = self.source_to_dest[current][join.dest] except KeyError: continue if metadata.is_backref or metadata.is_self_join: for instance in identity_map[current].values(): setattr(instance, attr, []) if join.dest not in identity_map: continue for pk, inst in identity_map[join.dest].items(): if pk is None: continue try: # XXX: if no FK exists, unable to join. joined_inst = identity_map[current][ inst._data[metadata.foreign_key.name]] except KeyError: continue getattr(joined_inst, attr).append(inst) instances.append(inst) elif attr: if join.dest not in identity_map: continue for pk, instance in identity_map[current].items(): # XXX: if no FK exists, unable to join. joined_inst = identity_map[join.dest][ instance._data[metadata.foreign_key.name]] setattr( instance, metadata.foreign_key.name, joined_inst) instances.append(joined_inst) stack.append(join.dest) for instance in instances: instance._prepare_instance() return primary_instance class Query(Node): """Base class representing a database query on one or more tables.""" require_commit = True def __init__(self, model_class): super(Query, self).__init__() self.model_class = model_class self.database = model_class._meta.database self._dirty = True self._query_ctx = model_class self._joins = {self.model_class: []} # Join graph as adjacency list. 
self._where = None def __repr__(self): sql, params = self.sql() return '%s %s %s' % (self.model_class, sql, params) def clone(self): query = type(self)(self.model_class) query.database = self.database return self._clone_attributes(query) def _clone_attributes(self, query): if self._where is not None: query._where = self._where.clone() query._joins = self._clone_joins() query._query_ctx = self._query_ctx return query def _clone_joins(self): return dict( (mc, list(j)) for mc, j in self._joins.items()) def _add_query_clauses(self, initial, expressions, conjunction=None): reduced = reduce(operator.and_, expressions) if initial is None: return reduced conjunction = conjunction or operator.and_ return conjunction(initial, reduced) def _model_shorthand(self, args): accum = [] for arg in args: if isinstance(arg, Node): accum.append(arg) elif isinstance(arg, Query): accum.append(arg) elif isinstance(arg, ModelAlias): accum.extend(arg.get_proxy_fields()) elif isclass(arg) and issubclass(arg, Model): accum.extend(arg._meta.declared_fields) return accum @returns_clone def where(self, *expressions): self._where = self._add_query_clauses(self._where, expressions) @returns_clone def orwhere(self, *expressions): self._where = self._add_query_clauses( self._where, expressions, operator.or_) @returns_clone def join(self, dest, join_type=None, on=None): src = self._query_ctx if not on: require_join_condition = ( isinstance(dest, SelectQuery) or (isclass(dest) and not src._meta.rel_exists(dest))) if require_join_condition: raise ValueError('A join condition must be specified.') elif isinstance(on, basestring): on = src._meta.fields[on] self._joins.setdefault(src, []) self._joins[src].append(Join(src, dest, join_type, on)) if not isinstance(dest, SelectQuery): self._query_ctx = dest @returns_clone def switch(self, model_class=None): """Change or reset the query context.""" self._query_ctx = model_class or self.model_class def ensure_join(self, lm, rm, on=None, **join_kwargs): ctx = 
def convert_dict_to_node(self, qdict):
    """Translate Django-style keyword lookups into expressions and joins.

    Each key is a ``__``-separated attribute path, optionally ending in an
    operation suffix found in ``DJANGO_MAP`` (``OP.EQ`` is used when there
    is no such suffix). Returns ``(expressions, joins)`` where ``joins``
    lists every relationship attribute traversed along the paths.
    """
    relationship = (ForeignKeyField, ReverseRelationDescriptor)
    expressions = []
    traversed = []
    for lookup, value in sorted(qdict.items()):
        path, separator, suffix = lookup.rpartition('__')
        if separator and suffix in DJANGO_MAP:
            op = DJANGO_MAP[suffix]
        else:
            # No recognised operation suffix: the whole key is the path.
            path, op = lookup, OP.EQ
        current = self.model_class
        for part in path.split('__'):
            model_attr = getattr(current, part)
            if isinstance(model_attr, relationship):
                current = model_attr.rel_model
                traversed.append(model_attr)
        expressions.append(Expression(model_attr, op, value))
    return expressions, traversed
def scalar(self, as_tuple=False, convert=False):
    """Run the query and return the first row, or its first column.

    With ``convert`` the row comes from ``self.tuples().first()``;
    otherwise it is fetched straight from the executed cursor. The whole
    row is returned when ``as_tuple`` is set or when no row was found.
    """
    row = self.tuples().first() if convert else self._execute().fetchone()
    if as_tuple or not row:
        return row
    return row[0]
def allow_extend(orig, new_val, **kwargs):
    """Return the new value for a list-valued clause, optionally extending.

    ``extend=True`` appends ``new_val`` to ``orig`` (treated as empty when
    falsy); otherwise ``new_val`` replaces it. An empty result is reported
    as None. Any keyword other than ``extend`` raises ``ValueError``.
    """
    should_extend = kwargs.pop('extend', False)
    if kwargs:
        raise ValueError('"extend" is the only valid keyword argument.')
    if not should_extend:
        return new_val if new_val else None
    combined = (orig if orig else []) + new_val
    return combined if combined else None
def compound_op(operator):
    """Build a binary method that combines two queries with ``operator``.

    The returned function raises ``ValueError`` when the database of the
    left-hand query does not list ``operator`` among its
    ``compound_operations``; otherwise it returns a ``CompoundSelect``.
    """
    def inner(self, other):
        database = self.model_class._meta.database
        if operator in database.compound_operations:
            return CompoundSelect(self.model_class, self, operator, other)
        raise ValueError('Your database does not support %s' % operator)
    return inner
@returns_clone
def paginate(self, page, paginate_by=20):
    """Limit the query to one page of ``paginate_by`` rows.

    Positive page numbers are treated as 1-based; zero or negative values
    are used as the page index unchanged.
    """
    page_index = page - 1 if page > 0 else page
    self._limit = paginate_by
    self._offset = page_index * paginate_by
def peek(self, n=1):
    """Return the first ``n`` results without consuming the query.

    Executes the query, fills the result cache with up to ``n`` rows and
    returns a single item when ``n == 1``, a list otherwise, or None when
    nothing was cached.
    """
    wrapper = self.execute()
    wrapper.fill_cache(n)
    head = wrapper._result_cache[:n]
    if not head:
        return None
    if n == 1:
        return head[0]
    return head
def execute(self):
    """Run the query if needed and return its result wrapper.

    The cached wrapper is reused as long as the query is not dirty;
    otherwise the query is executed, wrapped by the class chosen by
    ``_get_result_wrapper`` and the dirty flag is cleared.
    """
    if not self._dirty and self._qr is not None:
        return self._qr
    model_class = self.model_class
    query_meta = self.get_query_meta()
    wrapper_class = self._get_result_wrapper()
    self._qr = wrapper_class(model_class, self._execute(), query_meta)
    self._dirty = False
    return self._qr
def requires_returning(method):
    """Decorator: refuse to run ``method`` without RETURNING support.

    The wrapper raises ``ValueError`` when the model's database has a
    falsy ``returning_clause``; otherwise it calls ``method`` unchanged.
    ``functools.wraps`` keeps the wrapped method's name and docstring so
    the decorated API stays introspectable.
    """
    @wraps(method)
    def inner(self, *args, **kwargs):
        db = self.model_class._meta.database
        if not db.returning_clause:
            raise ValueError('RETURNING is not supported by your '
                             'database: %s' % type(db))
        return method(self, *args, **kwargs)
    return inner
class UpdateQuery(_WriteQuery):
    """UPDATE query; ``update`` holds the values to be written."""

    def __init__(self, model_class, update=None):
        self._update = update
        self._on_conflict = None
        super(UpdateQuery, self).__init__(model_class)

    def _clone_attributes(self, query):
        """Copy update data and conflict action onto the cloned ``query``."""
        query = super(UpdateQuery, self)._clone_attributes(query)
        # ``update`` defaults to None in __init__; dict(None) would raise
        # TypeError, so only copy when there is something to copy.
        if self._update is not None:
            query._update = dict(self._update)
        else:
            query._update = None
        query._on_conflict = self._on_conflict
        return query

    @returns_clone
    def on_conflict(self, action=None):
        self._on_conflict = action

    join = not_allowed('joining')

    def sql(self):
        return self.compiler().generate_update(self)

    def execute(self):
        """Execute the update.

        Returns a result wrapper when a RETURNING selection is set (reusing
        the cached wrapper on later calls); otherwise returns what the
        database reports from ``rows_affected`` for the executed cursor.
        """
        if self._returning is not None and self._qr is None:
            return self._execute_with_result_wrapper()
        elif self._qr is not None:
            return self._qr
        else:
            return self.database.rows_affected(self._execute())

    def __iter__(self):
        if not self.model_class._meta.database.returning_clause:
            raise ValueError('UPDATE queries cannot be iterated over unless '
                             'they specify a RETURNING clause, which is not '
                             'supported by your database.')
        return iter(self.execute())

    def iterator(self):
        return iter(self.execute().iterator())
def execute(self):
    """Execute the insert and return a value that depends on its mode.

    * Multi-row insert on a database without ``insert_many`` (and with no
      source query or RETURNING selection): delegated to
      ``_insert_with_loop``.
    * RETURNING selection set: a result wrapper (cached after first use).
    * Single-row insert: the primary key value -- read from the cursor
      when the database has ``insert_returning`` (a list for composite
      keys), otherwise obtained from ``database.last_insert_id``.
    * Multi-row insert with ``return_id_list``: a list holding the first
      column of every returned row.
    * Any other multi-row insert: True.
    """
    insert_with_loop = (
        self._is_multi_row_insert and
        self._query is None and
        self._returning is None and
        not self.database.insert_many)
    if insert_with_loop:
        return self._insert_with_loop()

    if self._returning is not None and self._qr is None:
        return self._execute_with_result_wrapper()
    elif self._qr is not None:
        return self._qr
    else:
        cursor = self._execute()
        if not self._is_multi_row_insert:
            if self.database.insert_returning:
                pk_row = cursor.fetchone()
                meta = self.model_class._meta
                clean_data = [
                    field.python_value(column)
                    for field, column
                    in zip(meta.get_primary_key_fields(), pk_row)]
                if self.model_class._meta.composite_key:
                    return clean_data
                return clean_data[0]
            return self.database.last_insert_id(cursor, self.model_class)
        elif self._return_id_list:
            # Build a real list: on Python 3 ``map`` is a lazy one-shot
            # iterator, unlike the list returned by _insert_with_loop.
            return [row[0] for row in cursor.fetchall()]
        else:
            return True
def init(self, database, **connect_kwargs):
    """(Re)point this object at ``database``.

    Closes the current connection first if one is open, marks the database
    as deferred when ``database`` is None, and merges ``connect_kwargs``
    into the stored connection arguments.
    """
    still_open = not self.is_closed()
    if still_open:
        self.close()
    self.database = database
    self.deferred = (database is None)
    self.connect_kwargs.update(connect_kwargs)
def get_conn(self):
    """Return the connection to use for the current execution.

    The connection of the innermost entry on the context stack wins when
    it is set. Otherwise the local connection is returned, calling
    ``connect()`` first if it is marked closed.
    """
    local = self._local
    stack = local.context_stack
    context_conn = stack[-1].connection if stack else None
    if context_conn is not None:
        return context_conn
    if local.closed:
        self.connect()
    return local.conn
def commit_on_success(self, func):
    """Decorator that runs ``func`` inside ``self.transaction()``.

    The wrapped function's return value is passed through; its metadata
    is preserved with ``functools.wraps``.
    """
    @wraps(func)
    def inner(*args, **kwargs):
        with self.transaction():
            result = func(*args, **kwargs)
        return result
    return inner
get_columns(self, table, schema=None): raise NotImplementedError def get_primary_keys(self, table, schema=None): raise NotImplementedError def get_foreign_keys(self, table, schema=None): raise NotImplementedError def sequence_exists(self, seq): raise NotImplementedError def create_table(self, model_class, safe=False): qc = self.compiler() return self.execute_sql(*qc.create_table(model_class, safe)) def create_tables(self, models, safe=False): create_model_tables(models, fail_silently=safe) def create_index(self, model_class, fields, unique=False): qc = self.compiler() if not isinstance(fields, (list, tuple)): raise ValueError('Fields passed to "create_index" must be a list ' 'or tuple: "%s"' % fields) fobjs = [ model_class._meta.fields[f] if isinstance(f, basestring) else f for f in fields] return self.execute_sql(*qc.create_index(model_class, fobjs, unique)) def drop_index(self, model_class, fields, safe=False): qc = self.compiler() if not isinstance(fields, (list, tuple)): raise ValueError('Fields passed to "drop_index" must be a list ' 'or tuple: "%s"' % fields) fobjs = [ model_class._meta.fields[f] if isinstance(f, basestring) else f for f in fields] return self.execute_sql(*qc.drop_index(model_class, fobjs, safe)) def create_foreign_key(self, model_class, field, constraint=None): qc = self.compiler() return self.execute_sql(*qc.create_foreign_key( model_class, field, constraint)) def create_sequence(self, seq): if self.sequences: qc = self.compiler() return self.execute_sql(*qc.create_sequence(seq)) def drop_table(self, model_class, fail_silently=False, cascade=False): qc = self.compiler() return self.execute_sql(*qc.drop_table( model_class, fail_silently, cascade)) def drop_tables(self, models, safe=False, cascade=False): drop_model_tables(models, fail_silently=safe, cascade=cascade) def truncate_table(self, model_class, restart_identity=False, cascade=False): qc = self.compiler() return self.execute_sql(*qc.truncate_table( model_class, restart_identity, 
cascade)) def truncate_tables(self, models, restart_identity=False, cascade=False): for model in reversed(sort_models_topologically(models)): model.truncate_table(restart_identity, cascade) def drop_sequence(self, seq): if self.sequences: qc = self.compiler() return self.execute_sql(*qc.drop_sequence(seq)) def extract_date(self, date_part, date_field): return fn.EXTRACT(Clause(date_part, R('FROM'), date_field)) def truncate_date(self, date_part, date_field): return fn.DATE_TRUNC(date_part, date_field) def default_insert_clause(self, model_class): return SQL('DEFAULT VALUES') def get_noop_sql(self): return 'SELECT 0 WHERE 0' def get_binary_type(self): return binary_construct class SqliteDatabase(Database): compiler_class = SqliteQueryCompiler field_overrides = { 'bool': 'INTEGER', 'smallint': 'INTEGER', 'uuid': 'TEXT', } foreign_keys = False insert_many = sqlite3 and sqlite3.sqlite_version_info >= (3, 7, 11, 0) limit_max = -1 op_overrides = { OP.LIKE: 'GLOB', OP.ILIKE: 'LIKE', } upsert_sql = 'INSERT OR REPLACE INTO' def __init__(self, database, pragmas=None, *args, **kwargs): self._pragmas = pragmas or [] journal_mode = kwargs.pop('journal_mode', None) # Backwards-compat. 
if journal_mode: self._pragmas.append(('journal_mode', journal_mode)) super(SqliteDatabase, self).__init__(database, *args, **kwargs) def _connect(self, database, **kwargs): if not sqlite3: raise ImproperlyConfigured('pysqlite or sqlite3 must be installed.') conn = sqlite3.connect(database, **kwargs) conn.isolation_level = None try: self._add_conn_hooks(conn) except: conn.close() raise return conn def _add_conn_hooks(self, conn): self._set_pragmas(conn) conn.create_function('date_part', 2, _sqlite_date_part) conn.create_function('date_trunc', 2, _sqlite_date_trunc) conn.create_function('regexp', 2, _sqlite_regexp) def _set_pragmas(self, conn): if self._pragmas: cursor = conn.cursor() for pragma, value in self._pragmas: cursor.execute('PRAGMA %s = %s;' % (pragma, value)) cursor.close() def begin(self, lock_type='DEFERRED'): self.execute_sql('BEGIN %s' % lock_type, require_commit=False) def create_foreign_key(self, model_class, field, constraint=None): raise OperationalError('SQLite does not support ALTER TABLE ' 'statements to add constraints.') def get_tables(self, schema=None): cursor = self.execute_sql('SELECT name FROM sqlite_master WHERE ' 'type = ? ORDER BY name;', ('table',)) return [row[0] for row in cursor.fetchall()] def get_indexes(self, table, schema=None): query = ('SELECT name, sql FROM sqlite_master ' 'WHERE tbl_name = ? AND type = ? ORDER BY name') cursor = self.execute_sql(query, (table, 'index')) index_to_sql = dict(cursor.fetchall()) # Determine which indexes have a unique constraint. unique_indexes = set() cursor = self.execute_sql('PRAGMA index_list("%s")' % table) for row in cursor.fetchall(): name = row[1] is_unique = int(row[2]) == 1 if is_unique: unique_indexes.add(name) # Retrieve the indexed columns. 
index_columns = {} for index_name in sorted(index_to_sql): cursor = self.execute_sql('PRAGMA index_info("%s")' % index_name) index_columns[index_name] = [row[2] for row in cursor.fetchall()] return [ IndexMetadata( name, index_to_sql[name], index_columns[name], name in unique_indexes, table) for name in sorted(index_to_sql)] def get_columns(self, table, schema=None): cursor = self.execute_sql('PRAGMA table_info("%s")' % table) return [ColumnMetadata(row[1], row[2], not row[3], bool(row[5]), table) for row in cursor.fetchall()] def get_primary_keys(self, table, schema=None): cursor = self.execute_sql('PRAGMA table_info("%s")' % table) return [row[1] for row in cursor.fetchall() if row[-1]] def get_foreign_keys(self, table, schema=None): cursor = self.execute_sql('PRAGMA foreign_key_list("%s")' % table) return [ForeignKeyMetadata(row[3], row[2], row[4], table) for row in cursor.fetchall()] def savepoint(self, sid=None): return savepoint_sqlite(self, sid) def extract_date(self, date_part, date_field): return fn.date_part(date_part, date_field) def truncate_date(self, date_part, date_field): return fn.strftime(SQLITE_DATE_TRUNC_MAPPING[date_part], date_field) def get_binary_type(self): return sqlite3.Binary class PostgresqlDatabase(Database): commit_select = True compound_select_parentheses = True distinct_on = True drop_cascade = True field_overrides = { 'blob': 'BYTEA', 'bool': 'BOOLEAN', 'datetime': 'TIMESTAMP', 'decimal': 'NUMERIC', 'double': 'DOUBLE PRECISION', 'primary_key': 'SERIAL', 'uuid': 'UUID', } for_update = True for_update_nowait = True insert_returning = True interpolation = '%s' op_overrides = { OP.REGEXP: '~', } reserved_tables = ['user'] returning_clause = True sequences = True window_functions = True register_unicode = True def _connect(self, database, encoding=None, **kwargs): if not psycopg2: raise ImproperlyConfigured('psycopg2 must be installed.') conn = psycopg2.connect(database=database, **kwargs) if self.register_unicode: 
pg_extensions.register_type(pg_extensions.UNICODE, conn) pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) if encoding: conn.set_client_encoding(encoding) return conn def _get_pk_sequence(self, model): meta = model._meta if meta.primary_key is not False and meta.primary_key.sequence: return meta.primary_key.sequence elif meta.auto_increment: return '%s_%s_seq' % (meta.db_table, meta.primary_key.db_column) def last_insert_id(self, cursor, model): sequence = self._get_pk_sequence(model) if not sequence: return meta = model._meta if meta.schema: schema = '%s.' % meta.schema else: schema = '' cursor.execute("SELECT CURRVAL('%s\"%s\"')" % (schema, sequence)) result = cursor.fetchone()[0] if self.get_autocommit(): self.commit() return result def get_tables(self, schema='public'): query = ('SELECT tablename FROM pg_catalog.pg_tables ' 'WHERE schemaname = %s ORDER BY tablename') return [r for r, in self.execute_sql(query, (schema,)).fetchall()] def get_indexes(self, table, schema='public'): query = """ SELECT i.relname, idxs.indexdef, idx.indisunique, array_to_string(array_agg(cols.attname), ',') FROM pg_catalog.pg_class AS t INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid INNER JOIN pg_catalog.pg_indexes AS idxs ON (idxs.tablename = t.relname AND idxs.indexname = i.relname) LEFT OUTER JOIN pg_catalog.pg_attribute AS cols ON (cols.attrelid = t.oid AND cols.attnum = ANY(idx.indkey)) WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s GROUP BY i.relname, idxs.indexdef, idx.indisunique ORDER BY idx.indisunique DESC, i.relname;""" cursor = self.execute_sql(query, (table, 'r', schema)) return [IndexMetadata(row[0], row[1], row[3].split(','), row[2], table) for row in cursor.fetchall()] def get_columns(self, table, schema='public'): query = """ SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_name = %s AND table_schema = %s ORDER BY 
ordinal_position""" cursor = self.execute_sql(query, (table, schema)) pks = set(self.get_primary_keys(table, schema)) return [ColumnMetadata(name, dt, null == 'YES', name in pks, table) for name, null, dt in cursor.fetchall()] def get_primary_keys(self, table, schema='public'): query = """ SELECT kc.column_name FROM information_schema.table_constraints AS tc INNER JOIN information_schema.key_column_usage AS kc ON ( tc.table_name = kc.table_name AND tc.table_schema = kc.table_schema AND tc.constraint_name = kc.constraint_name) WHERE tc.constraint_type = %s AND tc.table_name = %s AND tc.table_schema = %s""" cursor = self.execute_sql(query, ('PRIMARY KEY', table, schema)) return [row for row, in cursor.fetchall()] def get_foreign_keys(self, table, schema='public'): sql = """ SELECT kcu.column_name, ccu.table_name, ccu.column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON (tc.constraint_name = kcu.constraint_name AND tc.constraint_schema = kcu.constraint_schema) JOIN information_schema.constraint_column_usage AS ccu ON (ccu.constraint_name = tc.constraint_name AND ccu.constraint_schema = tc.constraint_schema) WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = %s AND tc.table_schema = %s""" cursor = self.execute_sql(sql, (table, schema)) return [ForeignKeyMetadata(row[0], row[1], row[2], table) for row in cursor.fetchall()] def sequence_exists(self, sequence): res = self.execute_sql(""" SELECT COUNT(*) FROM pg_class, pg_namespace WHERE relkind='S' AND pg_class.relnamespace = pg_namespace.oid AND relname=%s""", (sequence,)) return bool(res.fetchone()[0]) def set_search_path(self, *search_path): path_params = ','.join(['%s'] * len(search_path)) self.execute_sql('SET search_path TO %s' % path_params, search_path) def get_noop_sql(self): return 'SELECT 0 WHERE false' def get_binary_type(self): return psycopg2.Binary class MySQLDatabase(Database): commit_select = True compound_operations = ['UNION', 'UNION 
ALL'] field_overrides = { 'bool': 'BOOL', 'decimal': 'NUMERIC', 'double': 'DOUBLE PRECISION', 'float': 'FLOAT', 'primary_key': 'INTEGER AUTO_INCREMENT', 'text': 'LONGTEXT', 'uuid': 'VARCHAR(40)', } for_update = True interpolation = '%s' limit_max = 2 ** 64 - 1 # MySQL quirk op_overrides = { OP.LIKE: 'LIKE BINARY', OP.ILIKE: 'LIKE', OP.XOR: 'XOR', } quote_char = '`' subquery_delete_same_table = False upsert_sql = 'REPLACE INTO' def _connect(self, database, **kwargs): if not mysql: raise ImproperlyConfigured('MySQLdb or PyMySQL must be installed.') conn_kwargs = { 'charset': 'utf8', 'use_unicode': True, } conn_kwargs.update(kwargs) if 'password' in conn_kwargs: conn_kwargs['passwd'] = conn_kwargs.pop('password') return mysql.connect(db=database, **conn_kwargs) def get_tables(self, schema=None): return [row for row, in self.execute_sql('SHOW TABLES')] def get_indexes(self, table, schema=None): cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) unique = set() indexes = {} for row in cursor.fetchall(): if not row[1]: unique.add(row[2]) indexes.setdefault(row[2], []) indexes[row[2]].append(row[4]) return [IndexMetadata(name, None, indexes[name], name in unique, table) for name in indexes] def get_columns(self, table, schema=None): sql = """ SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_name = %s AND table_schema = DATABASE()""" cursor = self.execute_sql(sql, (table,)) pks = set(self.get_primary_keys(table)) return [ColumnMetadata(name, dt, null == 'YES', name in pks, table) for name, null, dt in cursor.fetchall()] def get_primary_keys(self, table, schema=None): cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) return [row[4] for row in cursor.fetchall() if row[2] == 'PRIMARY'] def get_foreign_keys(self, table, schema=None): query = """ SELECT column_name, referenced_table_name, referenced_column_name FROM information_schema.key_column_usage WHERE table_name = %s AND table_schema = DATABASE() AND 
referenced_table_name IS NOT NULL AND referenced_column_name IS NOT NULL""" cursor = self.execute_sql(query, (table,)) return [ ForeignKeyMetadata(column, dest_table, dest_column, table) for column, dest_table, dest_column in cursor.fetchall()] def extract_date(self, date_part, date_field): return fn.EXTRACT(Clause(R(date_part), R('FROM'), date_field)) def truncate_date(self, date_part, date_field): return fn.DATE_FORMAT(date_field, MYSQL_DATE_TRUNC_MAPPING[date_part]) def default_insert_clause(self, model_class): return Clause( EnclosedClause(model_class._meta.primary_key), SQL('VALUES (DEFAULT)')) def get_noop_sql(self): return 'DO 0' def get_binary_type(self): return mysql.Binary class _callable_context_manager(object): def __call__(self, fn): @wraps(fn) def inner(*args, **kwargs): with self: return fn(*args, **kwargs) return inner class ExecutionContext(_callable_context_manager): def __init__(self, database, with_transaction=True): self.database = database self.with_transaction = with_transaction self.connection = None def __enter__(self): with self.database._conn_lock: self.database.push_execution_context(self) self.connection = self.database._connect( self.database.database, **self.database.connect_kwargs) if self.with_transaction: self.txn = self.database.transaction() self.txn.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb): with self.database._conn_lock: if self.connection is None: self.database.pop_execution_context() else: try: if self.with_transaction: if not exc_type: self.txn.commit(False) self.txn.__exit__(exc_type, exc_val, exc_tb) finally: self.database.pop_execution_context() self.database._close(self.connection) class Using(ExecutionContext): def __init__(self, database, models, with_transaction=True): super(Using, self).__init__(database, with_transaction) self.models = models def __enter__(self): self._orig = [] for model in self.models: self._orig.append(model._meta.database) model._meta.database = self.database return 
super(Using, self).__enter__() def __exit__(self, exc_type, exc_val, exc_tb): super(Using, self).__exit__(exc_type, exc_val, exc_tb) for i, model in enumerate(self.models): model._meta.database = self._orig[i] class _atomic(_callable_context_manager): def __init__(self, db): self.db = db def __enter__(self): if self.db.transaction_depth() == 0: self._helper = self.db.transaction() else: self._helper = self.db.savepoint() return self._helper.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): return self._helper.__exit__(exc_type, exc_val, exc_tb) class transaction(_callable_context_manager): def __init__(self, db): self.db = db def _begin(self): self.db.begin() def commit(self, begin=True): self.db.commit() if begin: self._begin() def rollback(self, begin=True): self.db.rollback() if begin: self._begin() def __enter__(self): self._orig = self.db.get_autocommit() self.db.set_autocommit(False) if self.db.transaction_depth() == 0: self._begin() self.db.push_transaction(self) return self def __exit__(self, exc_type, exc_val, exc_tb): try: if exc_type: self.rollback(False) elif self.db.transaction_depth() == 1: try: self.commit(False) except: self.rollback(False) raise finally: self.db.set_autocommit(self._orig) self.db.pop_transaction() class savepoint(_callable_context_manager): def __init__(self, db, sid=None): self.db = db _compiler = db.compiler() self.sid = sid or 's' + uuid.uuid4().hex self.quoted_sid = _compiler.quote(self.sid) def _execute(self, query): self.db.execute_sql(query, require_commit=False) def commit(self): self._execute('RELEASE SAVEPOINT %s;' % self.quoted_sid) def rollback(self): self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) def __enter__(self): self._orig_autocommit = self.db.get_autocommit() self.db.set_autocommit(False) self._execute('SAVEPOINT %s;' % self.quoted_sid) return self def __exit__(self, exc_type, exc_val, exc_tb): try: if exc_type: self.rollback() else: try: self.commit() except: self.rollback() raise finally: 
self.db.set_autocommit(self._orig_autocommit) class savepoint_sqlite(savepoint): def __enter__(self): conn = self.db.get_conn() # For sqlite, the connection's isolation_level *must* be set to None. # The act of setting it, though, will break any existing savepoints, # so only write to it if necessary. if conn.isolation_level is not None: self._orig_isolation_level = conn.isolation_level conn.isolation_level = None else: self._orig_isolation_level = None return super(savepoint_sqlite, self).__enter__() def __exit__(self, exc_type, exc_val, exc_tb): try: return super(savepoint_sqlite, self).__exit__( exc_type, exc_val, exc_tb) finally: if self._orig_isolation_level is not None: self.db.get_conn().isolation_level = self._orig_isolation_level class FieldProxy(Field): def __init__(self, alias, field_instance): self._model_alias = alias self.model = self._model_alias.model_class self.field_instance = field_instance def clone_base(self): return FieldProxy(self._model_alias, self.field_instance) def coerce(self, value): return self.field_instance.coerce(value) def python_value(self, value): return self.field_instance.python_value(value) def db_value(self, value): return self.field_instance.db_value(value) def __getattr__(self, attr): if attr == 'model_class': return self._model_alias return getattr(self.field_instance, attr) class ModelAlias(object): def __init__(self, model_class): self.__dict__['model_class'] = model_class def __getattr__(self, attr): model_attr = getattr(self.model_class, attr) if isinstance(model_attr, Field): return FieldProxy(self, model_attr) return model_attr def __setattr__(self, attr, value): raise AttributeError('Cannot set attributes on ModelAlias instances') def get_proxy_fields(self, declared_fields=False): mm = self.model_class._meta fields = mm.declared_fields if declared_fields else mm.sorted_fields return [FieldProxy(self, f) for f in fields] def select(self, *selection): if not selection: selection = self.get_proxy_fields() query = 
SelectQuery(self, *selection) if self._meta.order_by: query = query.order_by(*self._meta.order_by) return query def __call__(self, **kwargs): return self.model_class(**kwargs) if _SortedFieldList is None: class _SortedFieldList(object): __slots__ = ('_keys', '_items') def __init__(self): self._keys = [] self._items = [] def __getitem__(self, i): return self._items[i] def __iter__(self): return iter(self._items) def __contains__(self, item): k = item._sort_key i = bisect_left(self._keys, k) j = bisect_right(self._keys, k) return item in self._items[i:j] def index(self, field): return self._keys.index(field._sort_key) def insert(self, item): k = item._sort_key i = bisect_left(self._keys, k) self._keys.insert(i, k) self._items.insert(i, item) def remove(self, item): idx = self.index(item) del self._items[idx] del self._keys[idx] class DoesNotExist(Exception): pass if sqlite3: default_database = SqliteDatabase('peewee.db') else: default_database = None class ModelOptions(object): def __init__(self, cls, database=None, db_table=None, db_table_func=None, indexes=None, order_by=None, primary_key=None, table_alias=None, constraints=None, schema=None, validate_backrefs=True, only_save_dirty=False, **kwargs): self.model_class = cls self.name = cls.__name__.lower() self.fields = {} self.columns = {} self.defaults = {} self._default_by_name = {} self._default_dict = {} self._default_callables = {} self._default_callable_list = [] self._sorted_field_list = _SortedFieldList() self.sorted_fields = [] self.sorted_field_names = [] self.valid_fields = set() self.declared_fields = [] self.database = database if database is not None else default_database self.db_table = db_table self.db_table_func = db_table_func self.indexes = list(indexes or []) self.order_by = order_by self.primary_key = primary_key self.table_alias = table_alias self.constraints = constraints self.schema = schema self.validate_backrefs = validate_backrefs self.only_save_dirty = only_save_dirty self.auto_increment 
= None self.composite_key = False self.rel = {} self.reverse_rel = {} for key, value in kwargs.items(): setattr(self, key, value) self._additional_keys = set(kwargs.keys()) if self.db_table_func and not self.db_table: self.db_table = self.db_table_func(cls) def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.name) def prepared(self): if self.order_by: norm_order_by = [] for item in self.order_by: if isinstance(item, Field): prefix = '-' if item._ordering == 'DESC' else '' item = prefix + item.name field = self.fields[item.lstrip('-')] if item.startswith('-'): norm_order_by.append(field.desc()) else: norm_order_by.append(field.asc()) self.order_by = norm_order_by def _update_field_lists(self): self.sorted_fields = list(self._sorted_field_list) self.sorted_field_names = [f.name for f in self.sorted_fields] self.valid_fields = (set(self.fields.keys()) | set(self.fields.values()) | set((self.primary_key,))) self.declared_fields = [field for field in self.sorted_fields if not isinstance(field, _AutoPrimaryKeyField)] def add_field(self, field): self.remove_field(field.name) self.fields[field.name] = field self.columns[field.db_column] = field self._sorted_field_list.insert(field) self._update_field_lists() if field.default is not None: self.defaults[field] = field.default if callable(field.default): self._default_callables[field] = field.default self._default_callable_list.append((field.name, field.default)) else: self._default_dict[field] = field.default self._default_by_name[field.name] = field.default def remove_field(self, field_name): if field_name not in self.fields: return original = self.fields.pop(field_name) del self.columns[original.db_column] self._sorted_field_list.remove(original) self._update_field_lists() if original.default is not None: del self.defaults[original] if self._default_callables.pop(original, None): for i, (name, _) in enumerate(self._default_callable_list): if name == field_name: self._default_callable_list.pop(i) break 
else: self._default_dict.pop(original, None) self._default_by_name.pop(original.name, None) def get_default_dict(self): dd = self._default_by_name.copy() for field_name, default in self._default_callable_list: dd[field_name] = default() return dd def get_field_index(self, field): try: return self._sorted_field_list.index(field) except ValueError: return -1 def get_primary_key_fields(self): if self.composite_key: return [ self.fields[field_name] for field_name in self.primary_key.field_names] return [self.primary_key] def rel_for_model(self, model, field_obj=None, multi=False): is_field = isinstance(field_obj, Field) is_node = not is_field and isinstance(field_obj, Node) if multi: accum = [] for field in self.sorted_fields: if isinstance(field, ForeignKeyField) and field.rel_model == model: is_match = ( (field_obj is None) or (is_field and field_obj.name == field.name) or (is_node and field_obj._alias == field.name)) if is_match: if not multi: return field accum.append(field) if multi: return accum def reverse_rel_for_model(self, model, field_obj=None, multi=False): return model._meta.rel_for_model(self.model_class, field_obj, multi) def rel_exists(self, model): return self.rel_for_model(model) or self.reverse_rel_for_model(model) def related_models(self, backrefs=False): models = [] stack = [self.model_class] while stack: model = stack.pop() if model in models: continue models.append(model) for fk in model._meta.rel.values(): stack.append(fk.rel_model) if backrefs: for fk in model._meta.reverse_rel.values(): stack.append(fk.model_class) return models class BaseModel(type): inheritable = set([ 'constraints', 'database', 'db_table_func', 'indexes', 'order_by', 'primary_key', 'schema', 'validate_backrefs', 'only_save_dirty']) def __new__(cls, name, bases, attrs): if name == _METACLASS_ or bases[0].__name__ == _METACLASS_: return super(BaseModel, cls).__new__(cls, name, bases, attrs) meta_options = {} meta = attrs.pop('Meta', None) if meta: for k, v in 
meta.__dict__.items(): if not k.startswith('_'): meta_options[k] = v model_pk = getattr(meta, 'primary_key', None) parent_pk = None # inherit any field descriptors by deep copying the underlying field # into the attrs of the new model, additionally see if the bases define # inheritable model options and swipe them for b in bases: if not hasattr(b, '_meta'): continue base_meta = getattr(b, '_meta') if parent_pk is None: parent_pk = deepcopy(base_meta.primary_key) all_inheritable = cls.inheritable | base_meta._additional_keys for (k, v) in base_meta.__dict__.items(): if k in all_inheritable and k not in meta_options: meta_options[k] = v for (k, v) in b.__dict__.items(): if k in attrs: continue if isinstance(v, FieldDescriptor): if not v.field.primary_key: attrs[k] = deepcopy(v.field) # initialize the new class and set the magic attributes cls = super(BaseModel, cls).__new__(cls, name, bases, attrs) ModelOptionsBase = meta_options.get('model_options_base', ModelOptions) cls._meta = ModelOptionsBase(cls, **meta_options) cls._data = None cls._meta.indexes = list(cls._meta.indexes) if not cls._meta.db_table: cls._meta.db_table = re.sub('[^\w]+', '_', cls.__name__.lower()) # replace fields with field descriptors, calling the add_to_class hook fields = [] for name, attr in cls.__dict__.items(): if isinstance(attr, Field): if attr.primary_key and model_pk: raise ValueError('primary key is overdetermined.') elif attr.primary_key: model_pk, pk_name = attr, name else: fields.append((attr, name)) composite_key = False if model_pk is None: if parent_pk: model_pk, pk_name = parent_pk, parent_pk.name else: model_pk, pk_name = PrimaryKeyField(primary_key=True), 'id' elif isinstance(model_pk, CompositeKey): pk_name = '_composite_key' composite_key = True if model_pk is not False: model_pk.add_to_class(cls, pk_name) cls._meta.primary_key = model_pk cls._meta.auto_increment = ( isinstance(model_pk, PrimaryKeyField) or bool(model_pk.sequence)) cls._meta.composite_key = composite_key 
for field, name in fields: field.add_to_class(cls, name) # create a repr and error class before finalizing if hasattr(cls, '__unicode__'): setattr(cls, '__repr__', lambda self: '<%s: %r>' % ( cls.__name__, self.__unicode__())) exc_name = '%sDoesNotExist' % cls.__name__ exc_attrs = {'__module__': cls.__module__} exception_class = type(exc_name, (DoesNotExist,), exc_attrs) cls.DoesNotExist = exception_class cls._meta.prepared() if hasattr(cls, 'validate_model'): cls.validate_model() DeferredRelation.resolve(cls) return cls def __iter__(self): return iter(self.select()) class Model(with_metaclass(BaseModel)): def __init__(self, *args, **kwargs): self._data = self._meta.get_default_dict() self._dirty = set(self._data) self._obj_cache = {} for k, v in kwargs.items(): setattr(self, k, v) @classmethod def alias(cls): return ModelAlias(cls) @classmethod def select(cls, *selection): query = SelectQuery(cls, *selection) if cls._meta.order_by: query = query.order_by(*cls._meta.order_by) return query @classmethod def update(cls, __data=None, **update): fdict = __data or {} fdict.update([(cls._meta.fields[f], update[f]) for f in update]) return UpdateQuery(cls, fdict) @classmethod def insert(cls, __data=None, **insert): fdict = __data or {} fdict.update([(cls._meta.fields[f], insert[f]) for f in insert]) return InsertQuery(cls, fdict) @classmethod def insert_many(cls, rows, validate_fields=True): return InsertQuery(cls, rows=rows, validate_fields=validate_fields) @classmethod def insert_from(cls, fields, query): return InsertQuery(cls, fields=fields, query=query) @classmethod def delete(cls): return DeleteQuery(cls) @classmethod def raw(cls, sql, *params): return RawQuery(cls, sql, *params) @classmethod def create(cls, **query): inst = cls(**query) inst.save(force_insert=True) inst._prepare_instance() return inst @classmethod def get(cls, *query, **kwargs): sq = cls.select().naive() if query: sq = sq.where(*query) if kwargs: sq = sq.filter(**kwargs) return sq.get() @classmethod 
def get_or_create(cls, **kwargs): defaults = kwargs.pop('defaults', {}) query = cls.select() for field, value in kwargs.items(): if '__' in field: query = query.filter(**{field: value}) else: query = query.where(getattr(cls, field) == value) try: return query.get(), False except cls.DoesNotExist: try: params = dict((k, v) for k, v in kwargs.items() if '__' not in k) params.update(defaults) with cls._meta.database.atomic(): return cls.create(**params), True except IntegrityError as exc: try: return query.get(), False except cls.DoesNotExist: raise exc @classmethod def create_or_get(cls, **kwargs): try: with cls._meta.database.atomic(): return cls.create(**kwargs), True except IntegrityError: query = [] # TODO: multi-column unique constraints. for field_name, value in kwargs.items(): field = getattr(cls, field_name) if field.unique or field.primary_key: query.append(field == value) return cls.get(*query), False @classmethod def filter(cls, *dq, **query): return cls.select().filter(*dq, **query) @classmethod def table_exists(cls): kwargs = {} if cls._meta.schema: kwargs['schema'] = cls._meta.schema return cls._meta.db_table in cls._meta.database.get_tables(**kwargs) @classmethod def create_table(cls, fail_silently=False): if fail_silently and cls.table_exists(): return db = cls._meta.database pk = cls._meta.primary_key if db.sequences and pk is not False and pk.sequence: if not db.sequence_exists(pk.sequence): db.create_sequence(pk.sequence) db.create_table(cls) cls._create_indexes() @classmethod def _fields_to_index(cls): fields = [] for field in cls._meta.sorted_fields: if field.primary_key: continue requires_index = any(( field.index, field.unique, isinstance(field, ForeignKeyField))) if requires_index: fields.append(field) return fields @classmethod def _index_data(cls): return itertools.chain( [((field,), field.unique) for field in cls._fields_to_index()], cls._meta.indexes or ()) @classmethod def _create_indexes(cls): for field_list, is_unique in 
cls._index_data(): cls._meta.database.create_index(cls, field_list, is_unique) @classmethod def _drop_indexes(cls, safe=False): for field_list, is_unique in cls._index_data(): cls._meta.database.drop_index(cls, field_list, safe) @classmethod def sqlall(cls): queries = [] compiler = cls._meta.database.compiler() pk = cls._meta.primary_key if cls._meta.database.sequences and pk.sequence: queries.append(compiler.create_sequence(pk.sequence)) queries.append(compiler.create_table(cls)) for field in cls._fields_to_index(): queries.append(compiler.create_index(cls, [field], field.unique)) if cls._meta.indexes: for field_names, unique in cls._meta.indexes: fields = [cls._meta.fields[f] for f in field_names] queries.append(compiler.create_index(cls, fields, unique)) return [sql for sql, _ in queries] @classmethod def drop_table(cls, fail_silently=False, cascade=False): cls._meta.database.drop_table(cls, fail_silently, cascade) @classmethod def truncate_table(cls, restart_identity=False, cascade=False): cls._meta.database.truncate_table(cls, restart_identity, cascade) @classmethod def as_entity(cls): if cls._meta.schema: return Entity(cls._meta.schema, cls._meta.db_table) return Entity(cls._meta.db_table) @classmethod def noop(cls, *args, **kwargs): return NoopSelectQuery(cls, *args, **kwargs) def _get_pk_value(self): return getattr(self, self._meta.primary_key.name) get_id = _get_pk_value # Backwards-compatibility. def _set_pk_value(self, value): if not self._meta.composite_key: setattr(self, self._meta.primary_key.name, value) set_id = _set_pk_value # Backwards-compatibility. 
def _pk_expr(self):
    """Return the expression ``primary_key == <this instance's pk value>``."""
    return self._meta.primary_key == self._get_pk_value()

def _prepare_instance(self):
    """Clear the dirty-field set, then invoke the ``prepared()`` hook."""
    self._dirty.clear()
    self.prepared()

def prepared(self):
    """Hook called by ``_prepare_instance()``; does nothing by default."""
    pass

def _prune_fields(self, field_dict, only):
    """Return a new dict holding only the entries named by ``only``.

    ``only`` is an iterable of field objects; fields whose ``name`` is
    missing from ``field_dict`` are skipped.
    """
    new_data = {}
    for field in only:
        if field.name in field_dict:
            new_data[field.name] = field_dict[field.name]
    return new_data

def _populate_unsaved_relations(self, field_dict):
    """Refresh ``field_dict`` entries for related objects cached on self.

    Applies to keys of ``_meta.rel`` that are dirty, present in
    ``field_dict`` with value ``None``, and have a cached related object.
    """
    for key in self._meta.rel:
        conditions = (
            key in self._dirty and
            key in field_dict and
            field_dict[key] is None and
            self._obj_cache.get(key) is not None)
        if conditions:
            # Re-assigning the attribute to itself goes through the
            # attribute's setter, after which the stored value is re-read.
            setattr(self, key, getattr(self, key))
            field_dict[key] = self._data[key]

def save(self, force_insert=False, only=None):
    """Persist the instance with an UPDATE or an INSERT.

    An UPDATE is issued when the primary-key value is not ``None`` and
    ``force_insert`` is false; otherwise an INSERT is issued.

    Returns ``False`` when ``only_save_dirty`` is enabled and nothing
    is dirty, the result of the UPDATE's ``execute()`` for updates, and
    ``1`` for inserts.
    """
    field_dict = dict(self._data)
    if self._meta.primary_key is not False:
        pk_field = self._meta.primary_key
        pk_value = self._get_pk_value()
    else:
        pk_field = pk_value = None
    if only:
        field_dict = self._prune_fields(field_dict, only)
    elif self._meta.only_save_dirty and not force_insert:
        field_dict = self._prune_fields(
            field_dict,
            self.dirty_fields)
        if not field_dict:
            self._dirty.clear()
            return False

    self._populate_unsaved_relations(field_dict)
    if pk_value is not None and not force_insert:
        # The primary key identifies the row in the WHERE clause, so it is
        # removed from the values being SET.
        if self._meta.composite_key:
            for pk_part_name in pk_field.field_names:
                field_dict.pop(pk_part_name, None)
        else:
            field_dict.pop(pk_field.name, None)
        rows = self.update(**field_dict).where(self._pk_expr()).execute()
    elif pk_field is None:
        self.insert(**field_dict).execute()
        rows = 1
    else:
        pk_from_cursor = self.insert(**field_dict).execute()
        if pk_from_cursor is not None:
            pk_value = pk_from_cursor
        self._set_pk_value(pk_value)
        rows = 1
    self._dirty.clear()
    return rows

def is_dirty(self):
    """Return ``True`` when at least one field is marked dirty."""
    return bool(self._dirty)

@property
def dirty_fields(self):
    """List of field objects, in ``sorted_fields`` order, that are dirty."""
    return [f for f in self._meta.sorted_fields if f.name in self._dirty]

def dependencies(self, search_nullable=False):
    """Yield ``(expression, foreign_key)`` pairs for rows depending on self.

    Walks ``_meta.reverse_rel`` starting from this instance's class. A
    related model is traversed further only when the foreign key is not
    nullable or ``search_nullable`` is true; each model class is expanded
    at most once.
    """
    model_class = type(self)
    query = self.select().where(self._pk_expr())
    stack = [(type(self), query)]
    seen = set()

    while stack:
        klass, query = stack.pop()
        if klass in seen:
            continue
        seen.add(klass)
        for rel_name, fk in klass._meta.reverse_rel.items():
            rel_model = fk.model_class
            if fk.rel_model is model_class:
                # Direct dependency: compare against the stored value.
                node = (fk == self._data[fk.to_field.name])
            else:
                # Indirect dependency: ``<<`` against the parent query.
                node = fk << query
            subquery = rel_model.select().where(node)
            if not fk.null or search_nullable:
                stack.append((rel_model, subquery))
            yield (node, fk)

def delete_instance(self, recursive=False, delete_nullable=False):
    """Delete this row, optionally handling dependent rows first.

    With ``recursive=True`` the dependencies are processed in reverse
    discovery order: rows behind a nullable foreign key are set to
    ``None`` unless ``delete_nullable`` is true, all others are deleted.

    Returns the result of executing the DELETE for this instance.
    """
    if recursive:
        dependencies = self.dependencies(delete_nullable)
        for query, fk in reversed(list(dependencies)):
            model = fk.model_class
            if fk.null and not delete_nullable:
                model.update(**{fk.name: None}).where(query).execute()
            else:
                model.delete().where(query).execute()
    return self.delete().where(self._pk_expr()).execute()

def __hash__(self):
    """Hash on ``(class, primary-key value)``."""
    return hash((self.__class__, self._get_pk_value()))

def __eq__(self, other):
    """Equal when same class and same, non-``None`` primary-key value."""
    return (
        other.__class__ == self.__class__ and
        self._get_pk_value() is not None and
        other._get_pk_value() == self._get_pk_value())

def __ne__(self, other):
    """Inverse of ``__eq__``."""
    return not self == other


# NOTE(review): everything above takes ``self`` and reads like the tail of a
# class body whose header is outside this chunk; the definitions below do not
# take ``self``/``cls`` and are presumably module-level -- confirm the nesting
# against the upstream file when restoring indentation.

def clean_prefetch_subquery(query):
    """Return a clone of ``query`` with GROUP BY and HAVING cleared."""
    query = query.clone()
    query._group_by = query._having = None
    return query


def prefetch_add_subquery(sq, subqueries):
    """Build the list of ``PrefetchResult`` objects for a prefetch.

    ``sq`` is the outermost query. Each item of ``subqueries`` is a query,
    a model class, or a ``(query_or_model, target_model)`` tuple that pins
    the parent the item must be related to. Each item is matched against
    the already-processed queries, most recent first, and filtered down to
    rows related to its parent.

    Raises ``AttributeError`` when no usable relation is found.
    """
    fixed_queries = [PrefetchResult(sq)]
    for i, subquery in enumerate(subqueries):
        if isinstance(subquery, tuple):
            subquery, target_model = subquery
        else:
            target_model = None
        if not isinstance(subquery, Query) and issubclass(subquery, Model):
            subquery = subquery.select()
        subquery_model = subquery.model_class
        for j in reversed(range(i + 1)):
            # Reset on every candidate parent: a relation found against a
            # parent that ``target_model`` rejects must not leak into the
            # next iteration (or past the end of the loop).
            fks = backrefs = None
            prefetch_result = fixed_queries[j]
            last_query = prefetch_result.query
            last_model = prefetch_result.model
            rels = subquery_model._meta.rel_for_model(last_model, multi=True)
            if rels:
                fks = [getattr(subquery_model, fk.name) for fk in rels]
                pks = [getattr(last_model, fk.to_field.name) for fk in rels]
            else:
                backrefs = last_model._meta.rel_for_model(
                    subquery_model,
                    multi=True)

            if (fks or backrefs) and ((target_model is last_model) or
                                      (target_model is None)):
                break
        else:
            # No candidate parent was accepted.
            tgt_err = ' using %s' % target_model if target_model else ''
            raise AttributeError('Error: unable to find foreign key for '
                                 'query: %s%s' % (subquery, tgt_err))

        if fks:
            cleaned = clean_prefetch_subquery(last_query)
            expr = reduce(operator.or_, [
                (fk << cleaned.select(pk))
                for (fk, pk) in zip(fks, pks)])
            subquery = subquery.where(expr)
            fixed_queries.append(PrefetchResult(subquery, fks, False))
        elif backrefs:
            cleaned = clean_prefetch_subquery(last_query)
            expr = reduce(operator.or_, [
                (backref.to_field << cleaned.select(backref))
                for backref in backrefs])
            subquery = subquery.where(expr)
            fixed_queries.append(PrefetchResult(subquery, backrefs, True))

    return fixed_queries


__prefetched = namedtuple('__prefetched', (
    'query', 'fields', 'backref', 'rel_models', 'field_to_name', 'model'))


class PrefetchResult(__prefetched):
    """One query taking part in a prefetch, plus its linking fields.

    ``rel_models``, ``field_to_name`` and ``model`` are derived in
    ``__new__``; values passed for them are overwritten (``model``
    always, the other two whenever ``fields`` is non-empty).
    """

    def __new__(cls, query, fields=None, backref=None, rel_models=None,
                field_to_name=None, model=None):
        if fields:
            if backref:
                rel_models = [field.model_class for field in fields]
                foreign_key_attrs = [field.to_field.name for field in fields]
            else:
                rel_models = [field.rel_model for field in fields]
                foreign_key_attrs = [field.name for field in fields]
            field_to_name = list(zip(fields, foreign_key_attrs))
        model = query.model_class
        return super(PrefetchResult, cls).__new__(
            cls, query, fields, backref, rel_models, field_to_name, model)

    def populate_instance(self, instance, id_map):
        """Attach objects recorded in ``id_map`` to ``instance``.

        For a backref result the stored object is assigned to the
        attribute named after each field. Otherwise the stored list is
        assigned to ``<related_name>_prefetch`` and each listed object
        gets ``instance`` assigned to its linking attribute.
        """
        if self.backref:
            for field in self.fields:
                identifier = instance._data[field.name]
                key = (field, identifier)
                if key in id_map:
                    setattr(instance, field.name, id_map[key])
        else:
            for field, attname in self.field_to_name:
                identifier = instance._data[field.to_field.name]
                key = (field, identifier)
                rel_instances = id_map.get(key, [])
                dest = '%s_prefetch' % field.related_name
                for inst in rel_instances:
                    setattr(inst, attname, instance)
                setattr(instance, dest, rel_instances)

    def store_instance(self, instance, id_map):
        """Record ``instance`` in ``id_map`` under ``(field, value)`` keys.

        Backref results keep a single instance per key; other results
        accumulate a list of instances per key.
        """
        for field, attname in self.field_to_name:
            identity = field.to_field.python_value(instance._data[attname])
            key = (field, identity)
            if self.backref:
                id_map[key] = instance
            else:
                id_map.setdefault(key, []).append(instance)


def prefetch(sq, *subqueries):
    """Run ``sq`` and ``subqueries`` and wire related instances together.

    Queries are iterated innermost first, so that by the time a parent
    row is seen its related rows have already been recorded. Returns
    ``sq`` untouched when no subqueries are given, otherwise the query
    object of the outermost ``PrefetchResult``.
    """
    if not subqueries:
        return sq
    fixed_queries = prefetch_add_subquery(sq, subqueries)

    deps = {}
    rel_map = {}
    for prefetch_result in reversed(fixed_queries):
        query_model = prefetch_result.model
        if prefetch_result.fields:
            for rel_model in prefetch_result.rel_models:
                rel_map.setdefault(rel_model, []).append(prefetch_result)

        id_map = deps[query_model] = {}
        has_relations = bool(rel_map.get(query_model))

        for instance in prefetch_result.query:
            if prefetch_result.fields:
                prefetch_result.store_instance(instance, id_map)
            if has_relations:
                for rel in rel_map[query_model]:
                    rel.populate_instance(instance, deps[rel.model])

    # The loop runs over the reversed list, so the last result seen is the
    # outermost query.
    return prefetch_result.query


def create_model_tables(models, **create_table_kwargs):
    """Create tables for all given models (in the right order)."""
    for m in sort_models_topologically(models):
        m.create_table(**create_table_kwargs)


def drop_model_tables(models, **drop_table_kwargs):
    """Drop tables for all given models (in the right order)."""
    for m in reversed(sort_models_topologically(models)):
        m.drop_table(**drop_table_kwargs)