diff --git a/README.md b/README.md index 9027af9b16d379115f04a56c0d213503f57d7500..1a773dd20bcd44a66abc051809ea64507aa13f23 100644 --- a/README.md +++ b/README.md @@ -28,3 +28,4 @@ This project uses - [mygpoclient](http://gpodder.org/mygpoclient/) by Thomas Perl and others - [html2text](https://github.com/Alir3z4/html2text) by Aaron Schwartz and others - [mutagen](https://github.com/quodlibet/mutagen) by Christoph Reiter and Joe Wreschnig +- [peewee](https://github.com/coleifer/peewee) diff --git a/python/peewee.py b/python/peewee.py index ad5eb7f28e9012702b43793f9c01d68960ffa1c6..774044c017b919098086d1e93d5254cb14f18545 100644 --- a/python/peewee.py +++ b/python/peewee.py @@ -70,8 +70,9 @@ except ImportError: mysql = None -__version__ = '3.14.4' +__version__ = '3.14.9' __all__ = [ + 'AnyField', 'AsIs', 'AutoField', 'BareField', @@ -1180,6 +1181,9 @@ class ColumnBase(Node): __mod__ = _e(OP.LIKE) __pow__ = _e(OP.ILIKE) + like = _e(OP.LIKE) + ilike = _e(OP.ILIKE) + bin_and = _e(OP.BIN_AND) bin_or = _e(OP.BIN_OR) in_ = _e(OP.IN) @@ -2112,6 +2116,7 @@ class Query(BaseQuery): def __compound_select__(operation, inverted=False): + @__bind_database__ def method(self, other): if inverted: self, other = other, self @@ -2320,6 +2325,9 @@ class Select(SelectBase): item = self._from_list.pop() self._from_list.append(Join(item, dest, join_type, on)) + def left_outer_join(self, dest, on=None): + return self.join(dest, JOIN.LEFT_OUTER, on) + @Node.copy def group_by(self, *columns): grouping = [] @@ -2520,8 +2528,15 @@ class Update(_WriteQuery): v = k.to_value(v) else: v = Value(v, unpack=False) + elif isinstance(v, Model) and isinstance(k, ForeignKeyField): + # NB: we want to ensure that when passed a model instance + # in the context of a foreign-key, we apply the fk-specific + # adaptation of the model. + v = k.to_value(v) + if not isinstance(v, Value): v = qualify_names(v) + expressions.append(NodeList((k, SQL('='), v))) (ctx @@ -2634,6 +2649,7 @@ class Insert(_WriteQuery): if col not in seen: columns.append(col) + fk_fields = set() nullable_columns = set() value_lookups = {} for column in columns: @@ -2643,6 +2659,8 @@ class Insert(_WriteQuery): lookups.append(column.column_name) if column.null: nullable_columns.add(column) + if isinstance(column, ForeignKeyField): + fk_fields.add(column) value_lookups[column] = lookups ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') @@ -2681,7 +2699,8 @@ class Insert(_WriteQuery): else: raise ValueError('Missing value for %s.' % column.name) - if not isinstance(val, Node): + if not isinstance(val, Node) or (isinstance(val, Model) and + column in fk_fields): val = Value(val, converter=converter, unpack=False) values.append(val) @@ -3307,6 +3326,10 @@ class Database(_callable_context_manager): yield obj def table_exists(self, table_name, schema=None): + if is_model(table_name): + model = table_name + table_name = model._meta.table_name + schema = model._meta.schema return table_name in self.get_tables(schema=schema) def get_tables(self, schema=None): @@ -3773,7 +3796,16 @@ class PostgresqlDatabase(Database): def _connect(self): if psycopg2 is None: raise ImproperlyConfigured('Postgres driver not installed!') - conn = psycopg2.connect(database=self.database, **self.connect_params) + + # Handle connection-strings nicely, since psycopg2 will accept them, + # and they may be easier when lots of parameters are specified. + params = self.connect_params.copy() + if self.database.startswith('postgresql://'): + params.setdefault('dsn', self.database) + else: + params.setdefault('dbname', self.database) + + conn = psycopg2.connect(**params) if self._register_unicode: pg_extensions.register_type(pg_extensions.UNICODE, conn) pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) @@ -3902,7 +3934,13 @@ class PostgresqlDatabase(Database): def conflict_update(self, oc, query): action = oc._action.lower() if oc._action else '' if action in ('ignore', 'nothing'): - return SQL('ON CONFLICT DO NOTHING') + parts = [SQL('ON CONFLICT')] + if oc._conflict_target: + parts.append(EnclosedNodeList([ + Entity(col) if isinstance(col, basestring) else col + for col in oc._conflict_target])) + parts.append(SQL('DO NOTHING')) + return NodeList(parts) elif action and action != 'update': raise ValueError('The only supported actions for conflict ' 'resolution with Postgresql are "ignore" or ' @@ -4001,6 +4039,18 @@ class MySQLDatabase(Database): warnings.warn('Unable to determine MySQL version: "%s"' % version) return (0, 0, 0) # Unable to determine version! + def is_connection_usable(self): + if self._state.closed: + return False + + conn = self._state.conn + if hasattr(conn, 'ping'): + try: + conn.ping(False) + except Exception: + return False + return True + def default_values_insert(self, ctx): return ctx.literal('() VALUES ()') @@ -4327,7 +4377,7 @@ class CursorWrapper(object): class DictCursorWrapper(CursorWrapper): def _initialize_columns(self): description = self.cursor.description - self.columns = [t[0][t[0].find('.') + 1:].strip('")') + self.columns = [t[0][t[0].rfind('.') + 1:].strip('()"`') for t in description] self.ncols = len(description) @@ -4345,9 +4395,8 @@ class DictCursorWrapper(CursorWrapper): class NamedTupleCursorWrapper(CursorWrapper): def initialize(self): description = self.cursor.description - self.tuple_class = collections.namedtuple( - 'Row', - [col[0][col[0].find('.') + 1:].strip('"') for col in description]) + self.tuple_class = collections.namedtuple('Row', [ + t[0][t[0].rfind('.') + 1:].strip('()"`') for t in description]) def process_row(self, row): return self.tuple_class(*row) @@ -4414,7 +4463,7 @@ class ForeignKeyAccessor(FieldAccessor): obj = self.rel_model.get(self.field.rel_field == value) instance.__rel__[self.name] = obj return instance.__rel__.get(self.name, value) - elif not self.field.null: + elif not self.field.null and self.field.lazy_load: raise self.rel_model.DoesNotExist return value @@ -4430,7 +4479,8 @@ class ForeignKeyAccessor(FieldAccessor): else: fk_value = instance.__data__.get(self.name) instance.__data__[self.name] = obj - if obj != fk_value and self.name in instance.__rel__: + if (obj != fk_value or obj is None) and \ + self.name in instance.__rel__: del instance.__rel__[self.name] instance._dirty.add(self.name) @@ -4586,6 +4636,10 @@ class Field(ColumnBase): return NodeList(accum) +class AnyField(Field): + field_type = 'ANY' + + class IntegerField(Field): field_type = 'INT' @@ -5148,6 +5202,7 @@ class BareField(Field): class ForeignKeyField(Field): accessor_class = ForeignKeyAccessor + backref_accessor_class = BackrefAccessor def __init__(self, model, field=None, backref=None, on_delete=None, on_update=None, deferrable=None, _deferred=None, @@ -5243,7 +5298,8 @@ class ForeignKeyField(Field): if set_attribute: setattr(model, self.object_id_name, ObjectIdAccessor(self)) if self.backref not in '!+': - setattr(self.rel_model, self.backref, BackrefAccessor(self)) + setattr(self.rel_model, self.backref, + self.backref_accessor_class(self)) def foreign_key_constraint(self): parts = [] @@ -5282,7 +5338,8 @@ class DeferredForeignKey(Field): DeferredForeignKey._unresolved.add(self) super(DeferredForeignKey, self).__init__( column_name=kwargs.get('column_name'), - null=kwargs.get('null')) + null=kwargs.get('null'), + primary_key=kwargs.get('primary_key')) __hash__ = object.__hash__ @@ -5291,7 +5348,11 @@ class DeferredForeignKey(Field): def set_model(self, rel_model): field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) - self.model._meta.add_field(self.name, field) + if field.primary_key: + # NOTE: this calls add_field() under-the-hood. + self.model._meta.set_primary_key(self.name, field) + else: + self.model._meta.add_field(self.name, field) @staticmethod def resolve(model_cls): @@ -5615,8 +5676,11 @@ class SchemaManager(object): raise ValueError('table_settings must be strings') ctx.literal(' ').literal(setting) - if meta.without_rowid: - ctx.literal(' WITHOUT ROWID') + extra_opts = [] + if meta.strict_tables: extra_opts.append('STRICT') + if meta.without_rowid: extra_opts.append('WITHOUT ROWID') + if extra_opts: + ctx.literal(' %s' % ', '.join(extra_opts)) return ctx def _create_table_option_sql(self, options): @@ -5644,7 +5708,7 @@ class SchemaManager(object): if safe: ctx.literal('IF NOT EXISTS ') return (ctx - .sql(Entity(table_name)) + .sql(Entity(*ensure_tuple(table_name))) .literal(' AS ') .sql(query)) @@ -5800,8 +5864,8 @@ class Metadata(object): primary_key=None, constraints=None, schema=None, only_save_dirty=False, depends_on=None, options=None, db_table=None, table_function=None, table_settings=None, - without_rowid=False, temporary=False, legacy_table_names=True, - **kwargs): + without_rowid=False, temporary=False, strict_tables=None, + legacy_table_names=True, **kwargs): if db_table is not None: __deprecated__('"db_table" has been deprecated in favor of ' '"table_name" for Models.') @@ -5842,6 +5906,7 @@ class Metadata(object): self.depends_on = depends_on self.table_settings = table_settings self.without_rowid = without_rowid + self.strict_tables = strict_tables self.temporary = temporary self.refs = {} @@ -6075,7 +6140,11 @@ class Metadata(object): self.model._schema._database = database del self.table - # Apply any hooks that have been registered. + # Apply any hooks that have been registered. If we have an + # uninitialized proxy object, we will treat that as `None`. + if isinstance(database, Proxy) and database.obj is None: + database = None + for hook in self._db_hooks: hook(database) @@ -6103,7 +6172,7 @@ class ModelBase(type): inheritable = set(['constraints', 'database', 'indexes', 'primary_key', 'options', 'schema', 'table_function', 'temporary', 'only_save_dirty', 'legacy_table_names', - 'table_settings']) + 'table_settings', 'strict_tables']) def __new__(cls, name, bases, attrs): if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE: @@ -6133,6 +6202,7 @@ class ModelBase(type): for k in base_meta.__dict__: if k in all_inheritable and k not in meta_options: meta_options[k] = base_meta.__dict__[k] + meta_options.setdefault('database', base_meta.database) meta_options.setdefault('schema', base_meta.schema) for (k, v) in b.__dict__.items(): @@ -6557,10 +6627,12 @@ class Model(with_metaclass(ModelBase, Node)): if pk is not None and (self._meta.auto_increment or pk_value is None): self._pk = pk + # Although we set the primary-key, do not mark it as dirty. + self._dirty.discard(pk_field.name) else: self.insert(**field_dict).execute() - self._dirty.clear() + self._dirty -= set(field_dict) # Remove any fields we saved. return rows def is_dirty(self): @@ -6888,6 +6960,12 @@ class BaseModelSelect(_ModelQueryHelper): 'not exist:\nSQL: %s\nParams: %s' % (clone.model, sql, params)) + def get_or_none(self, database=None): + try: + return self.get(database=database) + except self.model.DoesNotExist: + pass + @Node.copy def group_by(self, *columns): grouping = [] @@ -7117,6 +7195,9 @@ class ModelSelect(BaseModelSelect, Select): item = self._from_list.pop() self._from_list.append(Join(item, dest, join_type, on)) + def left_outer_join(self, dest, on=None, src=None, attr=None): + return self.join(dest, JOIN.LEFT_OUTER, on, src, attr) + def join_from(self, src, dest, join_type=JOIN.INNER, on=None, attr=None): return self.join(dest, join_type, on, src, attr) @@ -7382,13 +7463,20 @@ class BaseModelCursorWrapper(DictCursorWrapper): self.fields = fields = [None] * self.ncols for idx, description_item in enumerate(description): - column = description_item[0] - dot_index = column.find('.') + column = orig_column = description_item[0] + + # Try to clean-up messy column descriptions when people do not + # provide an alias. The idea is that we take something like: + # SUM("t1"."price") -> "price") -> price + dot_index = column.rfind('.') if dot_index != -1: column = column[dot_index + 1:] - - column = column.strip('")') + column = column.strip('()"`') self.columns.append(column) + + # Now we'll see what they selected and see if we can improve the + # column-name being returned - e.g. by mapping it to the selected + # field's name. try: raw_node = self.select[idx] except IndexError: @@ -7399,6 +7487,12 @@ class BaseModelCursorWrapper(DictCursorWrapper): else: node = raw_node.unwrap() + # If this column was given an alias, then we will use whatever + # alias was returned by the cursor. + is_alias = raw_node.is_alias() + if is_alias: + self.columns[idx] = orig_column + # Heuristics used to attempt to get the field associated with a # given SELECT column, so that we can accurately convert the value # returned by the database-cursor into a Python object. @@ -7406,7 +7500,7 @@ class BaseModelCursorWrapper(DictCursorWrapper): if raw_node._coerce: converters[idx] = node.python_value fields[idx] = node - if not raw_node.is_alias(): + if not is_alias: self.columns[idx] = node.name elif isinstance(node, ColumnBase) and raw_node._converter: converters[idx] = raw_node._converter