From dcc5d3c85829ee308292624993916b2fe743e302 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Thu, 13 Jun 2013 16:41:04 -0700 Subject: [PATCH 01/18] Mongomallard --- mongoengine/base/document.py | 436 +++++------------------ mongoengine/base/fields.py | 133 +++---- mongoengine/base/metaclasses.py | 10 +- mongoengine/dereference.py | 14 +- mongoengine/document.py | 165 ++++----- mongoengine/fields.py | 429 +++++++++------------- mongoengine/queryset/queryset.py | 2 +- tests/document/delta.py | 77 ++-- tests/document/dynamic.py | 1 + tests/document/indexes.py | 1 + tests/document/inheritance.py | 8 +- tests/document/instance.py | 63 ++-- tests/document/json_serialisation.py | 1 - tests/document/validation.py | 5 +- tests/fields/fields.py | 176 +++------ tests/fields/file_tests.py | 1 + tests/fields/geo.py | 1 + tests/migration/__init__.py | 1 - tests/migration/decimalfield_as_float.py | 50 --- tests/queryset/geo.py | 1 + tests/queryset/queryset.py | 32 +- tests/queryset/transform.py | 1 + tests/test_dereference.py | 20 +- tests/test_signals.py | 58 +-- 24 files changed, 584 insertions(+), 1102 deletions(-) delete mode 100644 tests/migration/decimalfield_as_float.py diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index ca154a2..1195bc4 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -23,137 +23,40 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' +_set = object.__setattr__ + class BaseDocument(object): - _dynamic = False - _created = True - _dynamic_lock = True + #_dynamic = False + #_dynamic_lock = True _initialised = False - def __init__(self, *args, **values): + def __init__(self, _son=None, **values): """ Initialise a document or embedded document - :param __auto_convert: Try and will cast python objects to Object types :param values: A dictionary of values for the document """ - if args: - # Combine positional arguments with named arguments. - # We only want named arguments. - field = iter(self._fields_ordered) - for value in args: - name = next(field) - if name in values: - raise TypeError("Multiple values for keyword argument '" + name + "'") - values[name] = value - __auto_convert = values.pop("__auto_convert", True) - signals.pre_init.send(self.__class__, document=self, values=values) + _set(self, '_db_data', _son) + _set(self, '_lazy', False) + _set(self, '_internal_data', {}) + _set(self, '_changed_fields', set()) + if values: + pk = values.pop('pk', None) + for field in set(self._fields.keys()).intersection(values.keys()): + setattr(self, field, values[field]) + if pk != None: + self.pk = pk - self._data = {} + def __delattr__(self, name): + default = self._fields[name].default + value = default() if callable(default) else default + setattr(self, name, value) - # Assign default values to instance - for key, field in self._fields.iteritems(): - if self._db_field_map.get(key, key) in values: - continue - value = getattr(self, key, None) - setattr(self, key, value) - - # Set passed values after initialisation - if self._dynamic: - self._dynamic_fields = {} - dynamic_data = {} - for key, value in values.iteritems(): - if key in self._fields or key == '_id': - setattr(self, key, value) - elif self._dynamic: - dynamic_data[key] = value - else: - FileField = _import_class('FileField') - for key, value in values.iteritems(): - if key == '__auto_convert': - continue - key = self._reverse_db_field_map.get(key, key) - if key in self._fields or key in ('id', 'pk', '_cls'): - if __auto_convert and value is not None: - field = self._fields.get(key) - if field and not isinstance(field, FileField): - value = field.to_python(value) - setattr(self, key, value) - else: - self._data[key] = value - - # Set any get_fieldname_display methods - self.__set_field_display() - - if self._dynamic: - self._dynamic_lock = False - for key, value in dynamic_data.iteritems(): - setattr(self, key, value) - - # Flag initialised - self._initialised = True - signals.post_init.send(self.__class__, document=self) - - def __delattr__(self, *args, **kwargs): - """Handle deletions of fields""" - field_name = args[0] - if field_name in self._fields: - default = self._fields[field_name].default - if callable(default): - default = default() - setattr(self, field_name, default) - else: - super(BaseDocument, self).__delattr__(*args, **kwargs) - - def __setattr__(self, name, value): - # Handle dynamic data only if an initialised dynamic document - if self._dynamic and not self._dynamic_lock: - - field = None - if not hasattr(self, name) and not name.startswith('_'): - DynamicField = _import_class("DynamicField") - field = DynamicField(db_field=name) - field.name = name - self._dynamic_fields[name] = field - - if not name.startswith('_'): - value = self.__expand_dynamic_values(name, value) - - # Handle marking data as changed - if name in self._dynamic_fields: - self._data[name] = value - if hasattr(self, '_changed_fields'): - self._mark_as_changed(name) - - if (self._is_document and not self._created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): - OperationError = _import_class('OperationError') - msg = "Shard Keys are immutable. Tried to update %s" % name - raise OperationError(msg) - - # Check if the user has created a new instance of a class - if (self._is_document and self._initialised - and self._created and name == self._meta['id_field']): - super(BaseDocument, self).__setattr__('_created', False) - - super(BaseDocument, self).__setattr__(name, value) - - def __getstate__(self): - data = {} - for k in ('_changed_fields', '_initialised', '_created'): - if hasattr(self, k): - data[k] = getattr(self, k) - data['_data'] = self.to_mongo() - return data - - def __setstate__(self, data): - if isinstance(data["_data"], SON): - data["_data"] = self.__class__._from_son(data["_data"])._data - for k in ('_changed_fields', '_initialised', '_created', '_data'): - if k in data: - setattr(self, k, data[k]) + @property + def _created(self): + return self._db_data != None or self._lazy def __iter__(self): if 'id' in self._fields and 'id' not in self._fields_ordered: @@ -186,8 +89,8 @@ class BaseDocument(object): except AttributeError: return False - def __len__(self): - return len(self._data) + def __unicode__(self): + return u'%s object' % self.__class__.__name__ def __repr__(self): try: @@ -206,8 +109,8 @@ class BaseDocument(object): return txt_type('%s object' % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'id'): - if self.id == other.id: + if isinstance(other, self.__class__) and hasattr(other, 'pk'): + if self.pk == other.pk: return True return False @@ -234,47 +137,16 @@ class BaseDocument(object): def to_mongo(self): """Return as SON data ready for use with MongoDB. """ - data = SON() - data["_id"] = None - data['_cls'] = self._class_name + sets, unsets = self._delta(full=True) + son = SON(data=sets) + allow_inheritance = self._meta.get('allow_inheritance', + ALLOW_INHERITANCE) + if allow_inheritance: + son['_cls'] = self._class_name + return son - for field_name in self: - value = self._data.get(field_name, None) - field = self._fields.get(field_name) - - if value is not None: - value = field.to_mongo(value) - - # Handle self generating fields - if value is None and field._auto_gen: - value = field.generate() - self._data[field_name] = value - - if value is not None: - data[field.db_field] = value - - # If "_id" has not been set, then try and set it - if data["_id"] is None: - data["_id"] = self._data.get("id", None) - - if data['_id'] is None: - data.pop('_id') - - # Only add _cls if allow_inheritance is True - if (not hasattr(self, '_meta') or - not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): - data.pop('_cls') - - if not self._dynamic: - return data - - # Sort dynamic fields by key - dynamic_fields = sorted(self._dynamic_fields.iteritems(), - key=operator.itemgetter(0)) - for name, field in dynamic_fields: - data[name] = field.to_mongo(self._data.get(name, None)) - - return data + def to_dict(self): + return dict((field, getattr(self, field)) for field in self._fields) def validate(self, clean=True): """Ensure that all fields' values are valid and that required fields @@ -289,11 +161,11 @@ class BaseDocument(object): errors[NON_FIELD_ERRORS] = error # Get a list of tuples of field names and their current values - fields = [(field, self._data.get(name)) + fields = [(field, getattr(self, name)) for name, field in self._fields.items()] - if self._dynamic: - fields += [(field, self._data.get(name)) - for name, field in self._dynamic_fields.items()] + #if self._dynamic: + # fields += [(field, self._data.get(name)) + # for name, field in self._dynamic_fields.items()] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") @@ -369,15 +241,32 @@ class BaseDocument(object): def _mark_as_changed(self, key): """Marks a key as explicitly changed by the user """ - if not key: - return - key = self._db_field_map.get(key, key) - if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): - self._changed_fields.append(key) + + if key: + self._changed_fields.add(key) + + def _get_changed_fields(self): + changed_fields = set(self._changed_fields) + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + for field_name, field in self._fields.iteritems(): + if (isinstance(field, ComplexBaseField) and + isinstance(field.field, EmbeddedDocumentField)): + field_value = getattr(self, field_name, None) + if field_value: + for idx in (field_value if isinstance(field_value, dict) + else xrange(len(field_value))): + if field_value[idx]._get_changed_fields(): + changed_fields.add(field_name) + continue + elif isinstance(field, EmbeddedDocumentField): + field_value = getattr(self, field_name, None) + if field_value: + if field_value._get_changed_fields(): + changed_fields.add(field_name) + return changed_fields def _clear_changed_fields(self): - self._changed_fields = [] + _set(self, '_changed_fields', set()) EmbeddedDocumentField = _import_class("EmbeddedDocumentField") for field_name, field in self._fields.iteritems(): if (isinstance(field, ComplexBaseField) and @@ -392,135 +281,33 @@ class BaseDocument(object): if field_value: field_value._clear_changed_fields() - def _get_changed_fields(self, inspected=None): - """Returns a list of all fields that have explicitly been changed. - """ - EmbeddedDocument = _import_class("EmbeddedDocument") - DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") - _changed_fields = [] - _changed_fields += getattr(self, '_changed_fields', []) + def _delta(self, full=False): + sets = {} + unsets = {} - inspected = inspected or set() - if hasattr(self, 'id'): - if self.id in inspected: - return _changed_fields - inspected.add(self.id) - - field_list = self._fields.copy() - if self._dynamic: - field_list.update(self._dynamic_fields) - - for field_name in field_list: - - db_field_name = self._db_field_map.get(field_name, field_name) - key = '%s.' % db_field_name - field = self._data.get(field_name, None) - if hasattr(field, 'id'): - if field.id in inspected: - continue - inspected.add(field.id) - - if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) - and db_field_name not in _changed_fields): - # Find all embedded fields that have been changed - changed = field._get_changed_fields(inspected) - _changed_fields += ["%s%s" % (key, k) for k in changed if k] - elif (isinstance(field, (list, tuple, dict)) and - db_field_name not in _changed_fields): - # Loop list / dict fields as they contain documents - # Determine the iterator to use - if not hasattr(field, 'items'): - iterator = enumerate(field) - else: - iterator = field.iteritems() - for index, value in iterator: - if not hasattr(value, '_get_changed_fields'): - continue - list_key = "%s%s." % (key, index) - changed = value._get_changed_fields(inspected) - _changed_fields += ["%s%s" % (list_key, k) - for k in changed if k] - return _changed_fields - - def _delta(self): - """Returns the delta (set, unset) of the changes for a document. - Gets any values that have been explicitly changed. - """ - # Handles cases where not loaded from_son but has _id - doc = self.to_mongo() - - set_fields = self._get_changed_fields() - set_data = {} - unset_data = {} - parts = [] - if hasattr(self, '_changed_fields'): - set_data = {} - # Fetch each set item from its path - for path in set_fields: - parts = path.split('.') - d = doc - new_path = [] - for p in parts: - if isinstance(d, DBRef): - break - elif isinstance(d, list) and p.isdigit(): - d = d[int(p)] - elif hasattr(d, 'get'): - d = d.get(p) - new_path.append(p) - path = '.'.join(new_path) - set_data[path] = d + if full or not self._created: + fields = self._fields.iteritems() else: - set_data = doc - if '_id' in set_data: - del(set_data['_id']) + fields = ((field_name, self._fields[field_name]) for field_name in self._get_changed_fields()) - # Determine if any changed items were actually unset. - for path, value in set_data.items(): - if value or isinstance(value, (numbers.Number, bool)): - continue + def get(field_name, field): + value = getattr(self, field_name) + if value is None: + value = field.default() if callable(field.default) else field.default + return value - # If we've set a value that ain't the default value dont unset it. - default = None - if (self._dynamic and len(parts) and parts[0] in - self._dynamic_fields): - del(set_data[path]) - unset_data[path] = 1 - continue - elif path in self._fields: - default = self._fields[path].default - else: # Perform a full lookup for lists / embedded lookups - d = self - parts = path.split('.') - db_field_name = parts.pop() - for p in parts: - if isinstance(d, list) and p.isdigit(): - d = d[int(p)] - elif (hasattr(d, '__getattribute__') and - not isinstance(d, dict)): - real_path = d._reverse_db_field_map.get(p, p) - d = getattr(d, real_path) - else: - d = d.get(p) + data = (( + self._db_field_map.get(field_name, field_name), + field.to_mongo(get(field_name, field))) + for field_name, field in fields) - if hasattr(d, '_fields'): - field_name = d._reverse_db_field_map.get(db_field_name, - db_field_name) - if field_name in d._fields: - default = d._fields.get(field_name).default - else: - default = None + for db_field_name, db_value in data: + if db_value == None: + unsets[db_field_name] = 1 + else: + sets[db_field_name] = db_value - if default is not None: - if callable(default): - default = default() - - if default != value: - continue - - del(set_data[path]) - unset_data[path] = 1 - return set_data, unset_data + return sets, unsets @classmethod def _get_collection_name(cls): @@ -529,61 +316,16 @@ class BaseDocument(object): return cls._meta.get('collection', None) @classmethod - def _from_son(cls, son, _auto_dereference=True): - """Create an instance of a Document (subclass) from a PyMongo SON. - """ - + def _from_son(cls, son, _auto_dereference=False): # get the class name from the document, falling back to the given # class if unavailable class_name = son.get('_cls', cls._class_name) - data = dict(("%s" % key, value) for key, value in son.iteritems()) - if not UNICODE_KWARGS: - # python 2.6.4 and lower cannot handle unicode keys - # passed to class constructor example: cls(**data) - to_str_keys_recursive(data) # Return correct subclass for document type if class_name != cls._class_name: cls = get_document(class_name) - changed_fields = [] - errors_dict = {} - - fields = cls._fields - if not _auto_dereference: - fields = copy.copy(fields) - - for field_name, field in fields.iteritems(): - field._auto_dereference = _auto_dereference - if field.db_field in data: - value = data[field.db_field] - try: - data[field_name] = (value if value is None - else field.to_python(value)) - if field_name != field.db_field: - del data[field.db_field] - except (AttributeError, ValueError), e: - errors_dict[field_name] = e - elif field.default: - default = field.default - if callable(default): - default = default() - if isinstance(default, BaseDocument): - changed_fields.append(field_name) - - if errors_dict: - errors = "\n".join(["%s - %s" % (k, v) - for k, v in errors_dict.items()]) - msg = ("Invalid data to create a `%s` instance.\n%s" - % (cls._class_name, errors)) - raise InvalidDocumentError(msg) - - obj = cls(__auto_convert=False, **data) - obj._changed_fields = changed_fields - obj._created = False - if not _auto_dereference: - obj._fields = fields - return obj + return cls(_son=son) @classmethod def _build_index_specs(cls, meta_indexes): @@ -773,9 +515,9 @@ class BaseDocument(object): field_name = cls._meta['id_field'] if field_name in cls._fields: field = cls._fields[field_name] - elif cls._dynamic: - DynamicField = _import_class('DynamicField') - field = DynamicField(db_field=field_name) + #elif cls._dynamic: + # DynamicField = _import_class('DynamicField') + # field = DynamicField(db_field=field_name) else: raise LookUpError('Cannot resolve field "%s"' % field_name) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index eda9b3c..168c063 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -59,15 +59,17 @@ class BaseField(object): :param help_text: (optional) The help text for this field and is often used when generating model forms from the document model. """ - self.db_field = (db_field or name) if not primary_key else '_id' - if name: - msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" - warnings.warn(msg, DeprecationWarning) + self.name = None # filled in by document + self.db_field = db_field self.required = required or primary_key self.default = default self.unique = bool(unique or unique_with) self.unique_with = unique_with self.primary_key = primary_key + if self.primary_key: + if self.db_field: + raise ValueError("Can't use primary_key in conjunction with db_field.") + self.db_field = '_id' self.validation = validation self.choices = choices self.verbose_name = verbose_name @@ -82,41 +84,52 @@ class BaseField(object): BaseField.creation_counter += 1 def __get__(self, instance, owner): - """Descriptor for retrieving a value from a field in a document. - """ if instance is None: # Document class being used rather than a document object return self + else: + name = self.name + data = instance._internal_data + if not name in data: + if instance._lazy and name != instance._meta['id_field']: + # We need to fetch the doc from the database. + instance.reload() + db_field = instance._db_field_map.get(name, name) + try: + db_value = instance._db_data[db_field] + except (TypeError, KeyError): + value = self.default() if callable(self.default) else self.default + else: + value = self.to_python(db_value) - # Get value from document instance if available - value = instance._data.get(self.name) + if hasattr(self, 'value_for_instance'): + value = self.value_for_instance(value, instance) + data[name] = value - EmbeddedDocument = _import_class('EmbeddedDocument') - if isinstance(value, EmbeddedDocument) and value._instance is None: - value._instance = weakref.proxy(instance) - return value + return data[name] def __set__(self, instance, value): """Descriptor for assigning a value to a field in a document. """ - # If setting to None and theres a default - # Then set the value to the default value - if value is None and self.default is not None: - value = self.default - if callable(value): - value = value() + if instance._lazy: + # Fetch the from the database before we assign to a lazy object. + instance.reload() - if instance._initialised: - try: - if (self.name not in instance._data or - instance._data[self.name] != value): - instance._mark_as_changed(self.name) - except: - # Values cant be compared eg: naive and tz datetimes - # So mark it as changed - instance._mark_as_changed(self.name) - instance._data[self.name] = value + name = self.name + + value = self.from_python(value) + if hasattr(self, 'value_for_instance'): + value = self.value_for_instance(value, instance) + try: + has_changed = name not in instance._internal_data or instance._internal_data[name] != value + except: # Values can't be compared eg: naive and tz datetimes + has_changed = True + + if has_changed: + instance._mark_as_changed(name) + + instance._internal_data[name] = value def error(self, message="", errors=None, field_name=None): """Raises a ValidationError. @@ -132,7 +145,15 @@ class BaseField(object): def to_mongo(self, value): """Convert a Python type to a MongoDB-compatible type. """ - return self.to_python(value) + return value + + def from_python(self, value): + """Convert a raw Python value (in an assignment) to the internal + Python representation. + """ + if value == None: + return self.default() if callable(self.default) else self.default + return value def prepare_query_value(self, op, value): """Prepare a value that is being used in a query for PyMongo. @@ -186,49 +207,6 @@ class ComplexBaseField(BaseField): """ field = None - __dereference = False - - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') - dereference = (self._auto_dereference and - (self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)))) - - self._auto_dereference = instance._fields[self.name]._auto_dereference - if not self.__dereference and instance._initialised and dereference: - instance._data[self.name] = self._dereference( - instance._data.get(self.name), max_depth=1, instance=instance, - name=self.name - ) - - value = super(ComplexBaseField, self).__get__(instance, owner) - - # Convert lists / values so we can watch for any changes on them - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): - value = BaseList(value, instance, self.name) - instance._data[self.name] = value - elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, instance, self.name) - instance._data[self.name] = value - - if (self._auto_dereference and instance._initialised and - isinstance(value, (BaseList, BaseDict)) - and not value._dereferenced): - value = self._dereference( - value, max_depth=1, instance=instance, name=self.name - ) - value._dereferenced = True - instance._data[self.name] = value - - return value def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. @@ -382,25 +360,16 @@ class ComplexBaseField(BaseField): owner_document = property(_get_owner_document, _set_owner_document) - @property - def _dereference(self,): - if not self.__dereference: - DeReference = _import_class("DeReference") - self.__dereference = DeReference() # Cached - return self.__dereference - class ObjectIdField(BaseField): """A field wrapper around MongoDB's ObjectIds. """ def to_python(self, value): - if not isinstance(value, ObjectId): - value = ObjectId(value) return value def to_mongo(self, value): - if not isinstance(value, ObjectId): + if value and not isinstance(value, ObjectId): try: return ObjectId(unicode(value)) except Exception, e: diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 444d9a2..34a8a51 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -359,10 +359,14 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Set primary key if not defined by the document if not new_class._meta.get('id_field'): - new_class._meta['id_field'] = 'id' - new_class._fields['id'] = ObjectIdField(db_field='_id') - new_class._fields['id'].name = 'id' + id_field = ObjectIdField(primary_key=True) + id_field.name = 'id' + id_field._auto_gen = True + new_class._fields['id'] = id_field new_class.id = new_class._fields['id'] + new_class._meta['id_field'] = 'id' + new_class._db_field_map['id'] = id_field.db_field + # Merge in exceptions with parent hierarchy exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index e5e8886..b9d79e6 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -86,7 +86,7 @@ class DeReference(object): for k, item in iterator: if isinstance(item, Document): for field_name, field in item._fields.iteritems(): - v = item._data.get(field_name, None) + v = getattr(item, field_name) if isinstance(v, (DBRef)): reference_map.setdefault(field.document_type, []).append(v.id) elif isinstance(v, (dict, SON)) and '_ref' in v: @@ -169,7 +169,7 @@ class DeReference(object): return self.object_map.get(items['_ref'].id, items) elif '_cls' in items: doc = get_document(items['_cls'])._from_son(items) - doc._data = self._attach_objects(doc._data, depth, doc, None) + doc._internal_data = self._attach_objects(doc._internal_data, depth, doc, None) return doc if not hasattr(items, 'items'): @@ -193,15 +193,15 @@ class DeReference(object): data[k] = self.object_map[k] elif isinstance(v, Document): for field_name, field in v._fields.iteritems(): - v = data[k]._data.get(field_name, None) + v = data[k]._internal_data.get(field_name, None) if isinstance(v, (DBRef)): - data[k]._data[field_name] = self.object_map.get(v.id, v) + data[k]._internal_data[field_name] = self.object_map.get(v.id, v) elif isinstance(v, (dict, SON)) and '_ref' in v: - data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) + data[k]._internal_data[field_name] = self.object_map.get(v['_ref'].id, v) elif isinstance(v, dict) and depth <= self.max_depth: - data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) + data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (list, tuple)) and depth <= self.max_depth: - data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) + data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) elif hasattr(v, 'id'): diff --git a/mongoengine/document.py b/mongoengine/document.py index a61ed07..dca9ccb 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -12,7 +12,7 @@ from mongoengine.common import _import_class from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) -from mongoengine.queryset import OperationError, NotUniqueError, QuerySet +from mongoengine.queryset import OperationError, NotUniqueError, QuerySet, DoesNotExist from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.context_managers import switch_db, switch_collection @@ -20,6 +20,7 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') +_set = object.__setattr__ def includes_cls(fields): """ Helper function used for ensuring and comparing indexes @@ -62,11 +63,11 @@ class EmbeddedDocument(BaseDocument): def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) - self._changed_fields = [] + self._changed_fields = set() def __eq__(self, other): if isinstance(other, self.__class__): - return self._data == other._data + return self.to_dict() == other.to_dict() return False def __ne__(self, other): @@ -177,15 +178,13 @@ class Document(BaseDocument): cls.ensure_indexes() return cls._collection - def save(self, force_insert=False, validate=True, clean=True, + def save(self, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, **kwargs): + _refs=None, full=False, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. - :param force_insert: only try to create a new document, don't allow - updates of existing documents :param validate: validates the document; set to ``False`` to skip. :param clean: call the document clean method, requires `validate` to be True. @@ -202,6 +201,7 @@ class Document(BaseDocument): :param cascade_kwargs: (optional) kwargs dictionary to be passed throw to cascading saves. Implies ``cascade=True``. :param _refs: A list of processed references used in cascading saves + :param full: Save all model fields instead of just changed ones. .. versionchanged:: 0.5 In existing documents it only saves changed fields using @@ -217,62 +217,52 @@ class Document(BaseDocument): the cascade save using cascade_kwargs which overwrites the existing kwargs with custom values. """ + signals.pre_save.send(self.__class__, document=self) if validate: self.validate(clean=clean) - if write_concern is None: - write_concern = {"w": 1} - - doc = self.to_mongo() - - created = ('_id' not in doc or self._created or force_insert) - - signals.pre_save_post_validation.send(self.__class__, document=self, created=created) + if not write_concern: + write_concern = {'w': 1} + collection = self._get_collection() try: - collection = self._get_collection() - if created: - if force_insert: - object_id = collection.insert(doc, **write_concern) - else: - object_id = collection.save(doc, **write_concern) - else: - object_id = doc['_id'] - updates, removals = self._delta() - # Need to add shard key to query, or you get an error - select_dict = {'_id': object_id} - shard_key = self.__class__._meta.get('shard_key', tuple()) - for k in shard_key: - actual_key = self._db_field_map.get(k, k) - select_dict[actual_key] = doc[actual_key] - - def is_new_object(last_error): - if last_error is not None: - updated = last_error.get("updatedExisting") - if updated is not None: - return not updated - return created + if self._created: + # Update: Get delta. + sets, unsets = self._delta(full) + db_id_field = self._fields[self._meta['id_field']].db_field + sets.pop(db_id_field, None) update_query = {} + if sets: + update_query['$set'] = sets + if unsets: + update_query['$unset'] = unsets - if updates: - update_query["$set"] = updates - if removals: - update_query["$unset"] = removals - if updates or removals: - last_error = collection.update(select_dict, update_query, - upsert=True, **write_concern) - created = is_new_object(last_error) + if update_query: + collection.update(self._db_object_key, update_query, **write_concern) + created = False + else: + # Insert: Get full SON. + doc = self.to_mongo() + object_id = collection.insert(doc, **write_concern) + # Fix pymongo's "return return_one and ids[0] or ids": + # If the ID is 0, pymongo wraps it in a list. + if isinstance(object_id, list) and not object_id[0]: + object_id = object_id[0] - if cascade is None: - cascade = self._meta.get('cascade', False) or cascade_kwargs is not None + id_field = self._meta['id_field'] + del self._internal_data[id_field] + _set(self, '_db_data', doc) + doc['_id'] = object_id + created = True + cascade = (self._meta.get('cascade', False) + if cascade is None else cascade) if cascade: kwargs = { - "force_insert": force_insert, "validate": validate, "write_concern": write_concern, "cascade": cascade @@ -290,12 +280,9 @@ class Document(BaseDocument): message = u'Tried to save duplicate unique keys (%s)' raise NotUniqueError(message % unicode(err)) raise OperationError(message % unicode(err)) - id_field = self._meta['id_field'] - if id_field not in self._meta.get('shard_key', []): - self[id_field] = self._fields[id_field].to_python(object_id) self._clear_changed_fields() - self._created = False + signals.post_save.send(self.__class__, document=self, created=created) return self @@ -312,14 +299,17 @@ class Document(BaseDocument): GenericReferenceField)): continue - ref = self._data.get(name) + ref = getattr(self, name) if not ref or isinstance(ref, DBRef): continue if not getattr(ref, '_changed_fields', True): continue - ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data)) + if getattr(ref, '_lazy', False): + continue + + ref_id = "%s,%s" % (ref.__class__.__name__, str(ref.to_dict())) if ref and ref_id not in _refs: _refs.append(ref_id) kwargs["_refs"] = _refs @@ -345,6 +335,16 @@ class Document(BaseDocument): select_dict[k] = getattr(self, k) return select_dict + @property + def _db_object_key(self): + field = self._fields[self._meta['id_field']] + select_dict = {field.db_field: field.to_mongo(self.pk)} + shard_key = self.__class__._meta.get('shard_key', tuple()) + for k in shard_key: + actual_key = self._db_field_map.get(k, k) + select_dict[actual_key] = self._fields[k].to_mongo(getattr(self, k)) + return select_dict + def update(self, **kwargs): """Performs an update on the :class:`~mongoengine.Document` A convenience wrapper to :meth:`~mongoengine.QuerySet.update`. @@ -371,6 +371,9 @@ class Document(BaseDocument): """ signals.pre_delete.send(self.__class__, document=self) + if not write_concern: + write_concern = {'w': 1} + try: self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) except pymongo.errors.OperationFailure, err: @@ -399,7 +402,7 @@ class Document(BaseDocument): self._get_collection = lambda: collection self._get_db = lambda: db self._collection = collection - self._created = True + #self._created = True self.__objects = self._qs self.__objects._collection_obj = collection return self @@ -424,7 +427,7 @@ class Document(BaseDocument): collection = cls._get_collection() self._get_collection = lambda: collection self._collection = collection - self._created = True + #self._created = True self.__objects = self._qs self.__objects._collection_obj = collection return self @@ -436,46 +439,22 @@ class Document(BaseDocument): .. versionadded:: 0.5 """ import dereference - self._data = dereference.DeReference()(self._data, max_depth) + self._internal_data = dereference.DeReference()(self._internal_data, max_depth) return self - def reload(self, max_depth=1): + def reload(self): """Reloads all attributes from the database. - - .. versionadded:: 0.1.2 - .. versionchanged:: 0.6 Now chainable """ id_field = self._meta['id_field'] - obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( - **{id_field: self[id_field]}).limit(1).select_related(max_depth=max_depth) - - if obj: - obj = obj[0] - else: - msg = "Reloaded document has been deleted" - raise OperationError(msg) - for field in self._fields: - setattr(self, field, self._reload(field, obj[field])) - if self._dynamic: - for name in self._dynamic_fields.keys(): - setattr(self, name, self._reload(name, obj._data[name])) - self._changed_fields = obj._changed_fields - self._created = False - return obj - - def _reload(self, key, value): - """Used by :meth:`~mongoengine.Document.reload` to ensure the - correct instance is linked to self. - """ - if isinstance(value, BaseDict): - value = [(k, self._reload(k, v)) for k, v in value.items()] - value = BaseDict(value, self, key) - elif isinstance(value, BaseList): - value = [self._reload(key, v) for v in value] - value = BaseList(value, self, key) - elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): - value._changed_fields = [] - return value + collection = self._get_collection() + son = collection.find_one(self._db_object_key) + if son == None: + raise DoesNotExist('Document has been deleted.') + _set(self, '_db_data', son) + _set(self, '_internal_data', {}) + _set(self, '_lazy', False) + self._clear_changed_fields() + return self def to_dbref(self): """Returns an instance of :class:`~bson.dbref.DBRef` useful in @@ -675,6 +654,8 @@ class DynamicDocument(Document): _dynamic = True + # TODO + def __delattr__(self, *args, **kwargs): """Deletes the attribute by setting to None and allowing _delta to unset it""" @@ -698,6 +679,8 @@ class DynamicEmbeddedDocument(EmbeddedDocument): _dynamic = True + # TODO + def __delattr__(self, *args, **kwargs): """Deletes the attribute by setting to None and allowing _delta to unset it""" diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 451f7ac..8811f7f 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -3,6 +3,7 @@ import decimal import itertools import re import time +import types import urllib2 import uuid import warnings @@ -22,8 +23,10 @@ from bson import Binary, DBRef, SON, ObjectId from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) -from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, +from mongoengine.base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, get_document, BaseDocument) +from mongoengine.base.datastructures import BaseList, BaseDict +from mongoengine.queryset import DoesNotExist from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument from connection import get_db, DEFAULT_CONNECTION_NAME @@ -34,11 +37,12 @@ except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', - 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', +__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', + 'FloatField', 'BooleanField', 'DateTimeField', 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', 'SortedListField', 'DictField', 'MapField', 'ReferenceField', + 'SafeReferenceField', 'SafeReferenceListField', 'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', @@ -58,14 +62,8 @@ class StringField(BaseField): self.min_length = min_length super(StringField, self).__init__(**kwargs) - def to_python(self, value): - if isinstance(value, unicode): - return value - try: - value = value.decode('utf-8') - except: - pass - return value + def to_mongo(self, value): + return value or None def validate(self, value): if not isinstance(value, basestring): @@ -121,8 +119,7 @@ class URLField(StringField): r'(?::\d+)?' # optional port r'(?:/?|[/?]\S+)$', re.IGNORECASE) - def __init__(self, verify_exists=False, url_regex=None, **kwargs): - self.verify_exists = verify_exists + def __init__(self, url_regex=None, **kwargs): self.url_regex = url_regex or self._URL_REGEX super(URLField, self).__init__(**kwargs) @@ -131,51 +128,29 @@ class URLField(StringField): self.error('Invalid URL: %s' % value) return - if self.verify_exists: - warnings.warn( - "The URLField verify_exists argument has intractable security " - "and performance issues. Accordingly, it has been deprecated.", - DeprecationWarning) - try: - request = urllib2.Request(value) - urllib2.urlopen(request) - except Exception, e: - self.error('This URL appears to be a broken link: %s' % e) - class EmailField(StringField): - """A field that validates input as an E-Mail-Address. + """A field that validates input as an email address. .. versionadded:: 0.4 """ - EMAIL_REGEX = re.compile( - r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain - ) + EMAIL_REGEX = re.compile(r'^.+@[^.].*\.[a-z]{2,10}$', re.IGNORECASE) def validate(self, value): if not EmailField.EMAIL_REGEX.match(value): - self.error('Invalid Mail-address: %s' % value) + self.error('Invalid email address: %s' % value) super(EmailField, self).validate(value) class IntField(BaseField): - """An 32-bit integer field. + """An integer field. """ def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value super(IntField, self).__init__(**kwargs) - def to_python(self, value): - try: - value = int(value) - except ValueError: - pass - return value - def validate(self, value): try: value = int(value) @@ -188,62 +163,15 @@ class IntField(BaseField): if self.max_value is not None and value > self.max_value: self.error('Integer value is too large') - def prepare_query_value(self, op, value): - if value is None: - return value - - return int(value) - - -class LongField(BaseField): - """An 64-bit integer field. - """ - - def __init__(self, min_value=None, max_value=None, **kwargs): - self.min_value, self.max_value = min_value, max_value - super(LongField, self).__init__(**kwargs) - - def to_python(self, value): - try: - value = long(value) - except ValueError: - pass - return value - - def validate(self, value): - try: - value = long(value) - except: - self.error('%s could not be converted to long' % value) - - if self.min_value is not None and value < self.min_value: - self.error('Long value is too small') - - if self.max_value is not None and value > self.max_value: - self.error('Long value is too large') - - def prepare_query_value(self, op, value): - if value is None: - return value - - return long(value) - class FloatField(BaseField): - """An floating point number field. + """A floating point number field. """ def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value super(FloatField, self).__init__(**kwargs) - def to_python(self, value): - try: - value = float(value) - except ValueError: - pass - return value - def validate(self, value): if isinstance(value, int): value = float(value) @@ -256,82 +184,6 @@ class FloatField(BaseField): if self.max_value is not None and value > self.max_value: self.error('Float value is too large') - def prepare_query_value(self, op, value): - if value is None: - return value - - return float(value) - - -class DecimalField(BaseField): - """A fixed-point decimal number field. - - .. versionchanged:: 0.8 - .. versionadded:: 0.3 - """ - - def __init__(self, min_value=None, max_value=None, force_string=False, - precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs): - """ - :param min_value: Validation rule for the minimum acceptable value. - :param max_value: Validation rule for the maximum acceptable value. - :param force_string: Store as a string. - :param precision: Number of decimal places to store. - :param rounding: The rounding rule from the python decimal libary: - - - decimial.ROUND_CEILING (towards Infinity) - - decimial.ROUND_DOWN (towards zero) - - decimial.ROUND_FLOOR (towards -Infinity) - - decimial.ROUND_HALF_DOWN (to nearest with ties going towards zero) - - decimial.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer) - - decimial.ROUND_HALF_UP (to nearest with ties going away from zero) - - decimial.ROUND_UP (away from zero) - - decimial.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero) - - Defaults to: ``decimal.ROUND_HALF_UP`` - - """ - self.min_value = min_value - self.max_value = max_value - self.force_string = force_string - self.precision = decimal.Decimal(".%s" % ("0" * precision)) - self.rounding = rounding - - super(DecimalField, self).__init__(**kwargs) - - def to_python(self, value): - if value is None: - return value - - # Convert to string for python 2.6 before casting to Decimal - value = decimal.Decimal("%s" % value) - return value.quantize(self.precision, rounding=self.rounding) - - def to_mongo(self, value): - if value is None: - return value - if self.force_string: - return unicode(value) - return float(self.to_python(value)) - - def validate(self, value): - if not isinstance(value, decimal.Decimal): - if not isinstance(value, basestring): - value = unicode(value) - try: - value = decimal.Decimal(value) - except Exception, exc: - self.error('Could not convert value to decimal: %s' % exc) - - if self.min_value is not None and value < self.min_value: - self.error('Decimal value is too small') - - if self.max_value is not None and value > self.max_value: - self.error('Decimal value is too large') - - def prepare_query_value(self, op, value): - return self.to_mongo(value) - class BooleanField(BaseField): """A boolean field type. @@ -339,13 +191,6 @@ class BooleanField(BaseField): .. versionadded:: 0.1.2 """ - def to_python(self, value): - try: - value = bool(value) - except ValueError: - pass - return value - def validate(self, value): if not isinstance(value, bool): self.error('BooleanField only accepts boolean values') @@ -366,11 +211,13 @@ class DateTimeField(BaseField): """ def validate(self, value): - new_value = self.to_mongo(value) - if not isinstance(new_value, (datetime.datetime, datetime.date)): + if not isinstance(value, (datetime.datetime, datetime.date)): self.error(u'cannot parse date "%s"' % value) - def to_mongo(self, value): + def from_python(self, value): + return self.prepare_query_value(None, value) or value + + def prepare_query_value(self, op, value): if value is None: return value if isinstance(value, datetime.datetime): @@ -414,9 +261,6 @@ class DateTimeField(BaseField): except ValueError: return None - def prepare_query_value(self, op, value): - return self.to_mongo(value) - class ComplexDateTimeField(StringField): """ @@ -437,6 +281,8 @@ class ComplexDateTimeField(StringField): .. versionadded:: 0.5 """ + # TODO + def __init__(self, separator=',', **kwargs): self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond'] @@ -542,15 +388,11 @@ class EmbeddedDocumentField(BaseField): self.document_type_obj = get_document(self.document_type_obj) return self.document_type_obj - def to_python(self, value): - if not isinstance(value, self.document_type): - return self.document_type._from_son(value) - return value + def to_python(self, val): + return self.document_type._from_son(val) - def to_mongo(self, value): - if not isinstance(value, self.document_type): - return value - return self.document_type.to_mongo(value) + def to_mongo(self, val): + return val and val.to_mongo() def validate(self, value, clean=True): """Make sure that the document instance is an instance of the @@ -584,9 +426,8 @@ class GenericEmbeddedDocumentField(BaseField): return self.to_mongo(value) def to_python(self, value): - if isinstance(value, dict): - doc_cls = get_document(value['_cls']) - value = doc_cls._from_son(value) + doc_cls = get_document(value['_cls']) + value = doc_cls._from_son(value) return value @@ -674,6 +515,21 @@ class ListField(ComplexBaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) + def value_for_instance(self, value, instance): + return BaseList(value, instance, self.name) + + def from_python(self, val): + from_python = getattr(self.field, 'from_python', None) + return [from_python(v) for v in val] if from_python else val + + def to_python(self, val): + to_python = getattr(self.field, 'to_python', None) + return [to_python(v) for v in val] if to_python else val + + def to_mongo(self, val): + to_mongo = getattr(self.field, 'to_mongo', None) + return [to_mongo(v) for v in val] if to_mongo and val else val or None + def validate(self, value): """Make sure that a list of valid fields is being used. """ @@ -689,6 +545,9 @@ class ListField(ComplexBaseField): and hasattr(value, '__iter__')): return [self.field.prepare_query_value(op, v) for v in value] return self.field.prepare_query_value(op, value) + else: + if op in ('set', 'unset'): + return value return super(ListField, self).prepare_query_value(op, value) @@ -719,10 +578,13 @@ class SortedListField(ListField): def to_mongo(self, value): value = super(SortedListField, self).to_mongo(value) - if self._ordering is not None: - return sorted(value, key=itemgetter(self._ordering), - reverse=self._order_reverse) - return sorted(value, reverse=self._order_reverse) + if value: + if self._ordering is not None: + return sorted(value, key=itemgetter(self._ordering), + reverse=self._order_reverse) + return sorted(value, reverse=self._order_reverse) + else: + return value class DictField(ComplexBaseField): @@ -744,6 +606,21 @@ class DictField(ComplexBaseField): kwargs.setdefault('default', lambda: {}) super(DictField, self).__init__(*args, **kwargs) + def from_python(self, val): + from_python = getattr(self.field, 'from_python', None) + return {k: from_python(v) for k, v in val.iteritems()} if from_python else val + + def to_python(self, val): + to_python = getattr(self.field, 'to_python', None) + return {k: to_python(v) for k, v in val.iteritems()} if to_python else val + + def value_for_instance(self, value, instance): + return BaseDict(value, instance, self.name) + + def to_mongo(self, val): + to_mongo = getattr(self.field, 'to_mongo', None) + return {k: to_mongo(v) for k, v in val.iteritems()} if to_mongo and val else val or None + def validate(self, value): """Make sure that a list of valid fields is being used. """ @@ -851,69 +728,72 @@ class ReferenceField(BaseField): self.document_type_obj = get_document(self.document_type_obj) return self.document_type_obj - def __get__(self, instance, owner): - """Descriptor to allow lazy dereferencing. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value = instance._data.get(self.name) - self._auto_dereference = instance._fields[self.name]._auto_dereference - # Dereference DBRefs - if self._auto_dereference and isinstance(value, DBRef): - value = self.document_type._get_db().dereference(value) - if value is not None: - instance._data[self.name] = self.document_type._from_son(value) - - return super(ReferenceField, self).__get__(instance, owner) - - def to_mongo(self, document): - if isinstance(document, DBRef): - if not self.dbref: - return document.id - return document - - id_field_name = self.document_type._meta['id_field'] - id_field = self.document_type._fields[id_field_name] - - if isinstance(document, Document): + def to_mongo(self, value): + if isinstance(value, DBRef): + if self.dbref: + return value + else: + return value.id + elif isinstance(value, Document): + document_type = self.document_type # We need the id from the saved object to create the DBRef - id_ = document.pk - if id_ is None: + pk = value.pk + if pk is None: self.error('You can only reference documents once they have' ' been saved to the database') - else: - id_ = document - - id_ = id_field.to_mongo(id_) - if self.dbref: - collection = self.document_type._get_collection_name() - return DBRef(collection, id_) - - return id_ + id_field_name = document_type._meta['id_field'] + id_field = document_type._fields[id_field_name] + pk = id_field.to_mongo(pk) + if self.dbref: + collection = document_type._get_collection_name() + return DBRef(collection, pk) + else: + return pk + elif value != None: # string ID + document_type = self.document_type + collection = document_type._get_collection_name() + return DBRef(collection, value) def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ - if (not self.dbref and - not isinstance(value, (DBRef, Document, EmbeddedDocument))): - collection = self.document_type._get_collection_name() - value = DBRef(collection, self.document_type.id.to_python(value)) - return value + if value != None: + document_type = self.document_type + if self.dbref: + obj = document_type(pk=value.id) + else: + if isinstance(value, DBRef): + obj = document_type(pk=value.id) + else: + obj = document_type(pk=value) + obj._lazy = True + return obj + + def from_python(self, value): + if isinstance(value, BaseDocument): + return value + elif value == None: + return super(ReferenceField, self).from_python(value) + else: + # Support for werkzeug.local.LocalProxy + if hasattr(value, '_get_current_object'): + return value._get_current_object() + else: + # DBRef or ID + document_type = self.document_type + if isinstance(value, DBRef): + obj = document_type(pk=value.id) + else: + obj = document_type(pk=value) + obj._lazy = True + return obj def prepare_query_value(self, op, value): - if value is None: - return None - return self.to_mongo(value) + return self.to_mongo(self.from_python(value)) def validate(self, value): - if not isinstance(value, (self.document_type, DBRef)): self.error("A ReferenceField only accepts DBRef or documents") - if isinstance(value, Document) and value.id is None: + if isinstance(value, Document) and value.pk is None: self.error('You can only reference documents once they have been ' 'saved to the database') @@ -921,6 +801,52 @@ class ReferenceField(BaseField): return self.document_type._fields.get(member_name) +class SafeReferenceField(ReferenceField): + """ + Like a ReferenceField, but doesn't return non-existing references when + dereferencing, i.e. no DBRefs are returned. This means that the next time + an object is saved, the non-existing references are removed and application + code can rely on having only valid dereferenced objects. + + When the field is referenced, the referenced object is loaded from the + database. + """ + + def to_python(self, value): + obj = super(SafeReferenceField, self).to_python(value) + if obj: + # Must dereference so we don't get an invalid ObjectId back. + try: + obj.reload() + except DoesNotExist: + return None + return obj + + +class SafeReferenceListField(ListField): + """ + Like a ListField, but doesn't return non-existing references when + dereferencing, i.e. no DBRefs are returned. This means that the next time + an object is saved, the non-existing references are removed and application + code can rely on having only valid dereferenced objects. + + When the field is referenced, all referenced objects are loaded from the + database. + + Must use ReferenceField as its field class. + """ + + def __init__(self, field, **kwargs): + if not isinstance(field, ReferenceField): + raise ValueError('Field argument must be a ReferenceField instance.') + return super(SafeReferenceListField, self).__init__(field, **kwargs) + + def to_python(self, value): + result = super(SafeReferenceListField, self).to_python(value) + if result: + objs = self.field.document_type.objects.in_bulk([obj.id for obj in result]) + return filter(None, [objs.get(obj.id) for obj in result]) + class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -935,17 +861,6 @@ class GenericReferenceField(BaseField): .. versionadded:: 0.3 """ - def __get__(self, instance, owner): - if instance is None: - return self - - value = instance._data.get(self.name) - self._auto_dereference = instance._fields[self.name]._auto_dereference - if self._auto_dereference and isinstance(value, (dict, SON)): - instance._data[self.name] = self.dereference(value) - - return super(GenericReferenceField, self).__get__(instance, owner) - def validate(self, value): if not isinstance(value, (Document, DBRef, dict, SON)): self.error('GenericReferences can only contain documents') @@ -967,6 +882,14 @@ class GenericReferenceField(BaseField): doc = doc_cls._from_son(doc) return doc + def to_python(self, value): + if value != None: + doc_cls = get_document(value['_cls']) + reference = value['_ref'] + obj = doc_cls(pk=reference.id) + obj._lazy = True + return obj + def to_mongo(self, document): if document is None: return None diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index d58a13b..21b23c1 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -363,7 +363,7 @@ class QuerySet(object): msg = ("Some documents inserted aren't instances of %s" % str(self._document)) raise OperationError(msg) - if doc.pk and not doc._created: + if doc.pk and doc._created: msg = "Some documents have ObjectIds use doc.update() instead" raise OperationError(msg) raw.append(doc.to_mongo()) diff --git a/tests/document/delta.py b/tests/document/delta.py index 16ab609..355717f 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -48,41 +48,42 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) + self.assertEqual(doc._get_changed_fields(), set(['string_field'])) self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) - doc._changed_fields = [] + doc._changed_fields = set() doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) + self.assertEqual(doc._get_changed_fields(), set(['int_field'])) self.assertEqual(doc._delta(), ({'int_field': 1}, {})) - doc._changed_fields = [] + doc._changed_fields = set() dict_value = {'hello': 'world', 'ping': 'pong'} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) - doc._changed_fields = [] + doc._changed_fields = set() list_value = ['1', 2, {'hello': 'world'}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) # Test unsetting - doc._changed_fields = [] + doc._changed_fields = set() doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) - doc._changed_fields = [] + doc._changed_fields = set() doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + @unittest.skip("not fully implemented") def test_delta_recursive(self): self.delta_recursive(Document, EmbeddedDocument) self.delta_recursive(DynamicDocument, EmbeddedDocument) @@ -109,7 +110,7 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) embedded_1 = Embedded() @@ -119,7 +120,7 @@ class DeltaTest(unittest.TestCase): embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ['embedded_field']) + self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) embedded_delta = { 'string_field': 'hello', @@ -136,7 +137,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.dict_field = {} self.assertEqual(doc._get_changed_fields(), - ['embedded_field.dict_field']) + set(['embedded_field.dict_field'])) self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) doc.save() @@ -145,7 +146,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = [] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + set(['embedded_field.list_field'])) self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) doc.save() @@ -160,7 +161,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = ['1', 2, embedded_2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + set(['embedded_field.list_field'])) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { @@ -192,7 +193,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'world' self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field.2.string_field']) + set(['embedded_field.list_field.2.string_field'])) self.assertEqual(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) self.assertEqual(doc._delta(), @@ -206,7 +207,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + set(['embedded_field.list_field'])) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { '_cls': 'Embedded', @@ -269,7 +270,7 @@ class DeltaTest(unittest.TestCase): doc.dict_field['Embedded'].string_field = 'Hello World' self.assertEqual(doc._get_changed_fields(), - ['dict_field.Embedded.string_field']) + set(['dict_field.Embedded.string_field'])) self.assertEqual(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {})) @@ -371,39 +372,39 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['db_string_field']) + self.assertEqual(doc._get_changed_fields(), set(['string_field'])) self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) - doc._changed_fields = [] + doc._changed_fields = set() doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['db_int_field']) + self.assertEqual(doc._get_changed_fields(), set(['int_field'])) self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) - doc._changed_fields = [] + doc._changed_fields = set() dict_value = {'hello': 'world', 'ping': 'pong'} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) - doc._changed_fields = [] + doc._changed_fields = set() list_value = ['1', 2, {'hello': 'world'}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) # Test unsetting - doc._changed_fields = [] + doc._changed_fields = set() doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) - doc._changed_fields = [] + doc._changed_fields = set() doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) # Test it saves that data @@ -415,13 +416,15 @@ class DeltaTest(unittest.TestCase): doc.dict_field = {'hello': 'world'} doc.list_field = ['1', 2, {'hello': 'world'}] doc.save() - doc = doc.reload(10) + #doc = doc.reload(10) + doc = doc.reload() self.assertEqual(doc.string_field, 'hello') self.assertEqual(doc.int_field, 1) self.assertEqual(doc.dict_field, {'hello': 'world'}) self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) + @unittest.skip("not fully implemented") def test_delta_recursive_db_field(self): self.delta_recursive_db_field(Document, EmbeddedDocument) self.delta_recursive_db_field(Document, DynamicEmbeddedDocument) @@ -449,7 +452,7 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) embedded_1 = Embedded() @@ -459,7 +462,7 @@ class DeltaTest(unittest.TestCase): embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) + self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) embedded_delta = { 'db_string_field': 'hello', @@ -487,7 +490,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = [] self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field']) + set(['db_embedded_field.db_list_field'])) self.assertEqual(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) self.assertEqual(doc._delta(), @@ -605,6 +608,7 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) + @unittest.skip("DynamicDocument not implemented") def test_delta_for_dynamic_documents(self): class Person(DynamicDocument): name = StringField() @@ -640,6 +644,7 @@ class DeltaTest(unittest.TestCase): p.save() self.assertEqual(1, self.Person.objects(age=24).count()) + @unittest.skip("DynamicDocument not implemented") def test_dynamic_delta(self): class Doc(DynamicDocument): diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 6263e68..05870ee 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -8,6 +8,7 @@ from mongoengine.connection import get_db __all__ = ("DynamicTest", ) +@unittest.skip("DynamicDocument not implemented") class DynamicTest(unittest.TestCase): def setUp(self): diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 04d5632..49fd7cb 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -632,6 +632,7 @@ class IndexesTest(unittest.TestCase): pass Customer.drop_collection() + @unittest.skip("behavior differs") def test_unique_and_primary(self): """If you set a field as primary, then unexpected behaviour can occur. You won't create a duplicate but you will update an existing document. diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 5a48f75..d311538 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -182,10 +182,10 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(['age', 'id', 'name', 'salary'], sorted(Employee._fields.keys())) - self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), - ['_cls', 'name', 'age']) - self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ['_cls', 'name', 'age', 'salary']) + self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), + set(['_cls', 'name', 'age'])) + self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), + set(['_cls', 'name', 'age', 'salary'])) self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) diff --git a/tests/document/instance.py b/tests/document/instance.py index 81734aa..1df5d90 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -390,24 +390,27 @@ class InstanceTest(unittest.TestCase): doc.embedded_field = embedded_1 doc.save() - doc = doc.reload(10) + doc = doc.reload() doc.list_field.append(1) doc.dict_field['woot'] = "woot" doc.embedded_field.list_field.append(1) doc.embedded_field.dict_field['woot'] = "woot" - self.assertEqual(doc._get_changed_fields(), [ - 'list_field', 'dict_field', 'embedded_field.list_field', - 'embedded_field.dict_field']) + self.assertEqual(doc._get_changed_fields(), set([ + 'list_field', 'dict_field', 'embedded_field'])) + #self.assertEqual(doc._get_changed_fields(), [ + # 'list_field', 'dict_field', 'embedded_field.list_field', + # 'embedded_field.dict_field']) doc.save() - doc = doc.reload(10) - self.assertEqual(doc._get_changed_fields(), []) + doc = doc.reload() + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(len(doc.list_field), 4) self.assertEqual(len(doc.dict_field), 2) self.assertEqual(len(doc.embedded_field.list_field), 4) self.assertEqual(len(doc.embedded_field.dict_field), 2) + @unittest.skip("not implemented") def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. """ @@ -438,10 +441,10 @@ class InstanceTest(unittest.TestCase): class Employee(Person): salary = IntField() - self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), - ['_cls', 'name', 'age']) - self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ['_cls', 'name', 'age', 'salary']) + self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), + set(['_cls', 'name', 'age'])) + self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), + set(['_cls', 'name', 'age', 'salary'])) def test_embedded_document(self): """Ensure that embedded documents are set up correctly. @@ -452,6 +455,7 @@ class InstanceTest(unittest.TestCase): self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) + @unittest.skip("not implemented") def test_embedded_document_instance(self): """Ensure that embedded documents can reference parent instance """ @@ -460,6 +464,7 @@ class InstanceTest(unittest.TestCase): class Doc(Document): embedded_field = EmbeddedDocumentField(Embedded) + meta = { 'cascade': True } Doc.drop_collection() Doc(embedded_field=Embedded(string="Hi")).save() @@ -467,6 +472,7 @@ class InstanceTest(unittest.TestCase): doc = Doc.objects.get() self.assertEqual(doc, doc.embedded_field._instance) + @unittest.skip("not implemented") def test_embedded_document_complex_instance(self): """Ensure that embedded documents in complex fields can reference parent instance""" @@ -623,6 +629,7 @@ class InstanceTest(unittest.TestCase): p0.name = 'wpjunior' p0.save() + @unittest.skip("FileField not implemented") def test_save_max_recursion_not_hit_with_file_field(self): class Foo(Document): @@ -771,6 +778,7 @@ class InstanceTest(unittest.TestCase): p1.reload() self.assertEqual(p1.name, p.parent.name) + @unittest.skip("not implemented") def test_update(self): """Ensure that an existing document is updated instead of be overwritten.""" @@ -885,7 +893,6 @@ class InstanceTest(unittest.TestCase): reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) - decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) @@ -1054,9 +1061,9 @@ class InstanceTest(unittest.TestCase): user = User.objects.first() # Even if stored as ObjectId's internally mongoengine uses DBRefs # As ObjectId's aren't automatically derefenced - self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) + #self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) self.assertTrue(isinstance(user.orgs[0], Organization)) - self.assertTrue(isinstance(user._data['orgs'][0], Organization)) + #self.assertTrue(isinstance(user._data['orgs'][0], Organization)) # Changing a value with query_counter() as q: @@ -1136,6 +1143,7 @@ class InstanceTest(unittest.TestCase): foo.save() self.assertEqual(1, q) + @unittest.skip("not implemented") def test_save_only_changed_fields_recursive(self): """Ensure save only sets / unsets changed fields """ @@ -1433,8 +1441,8 @@ class InstanceTest(unittest.TestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertTrue(isinstance(post_obj._data['author'], - bson.DBRef)) + #self.assertTrue(isinstance(post_obj._data['author'], + # bson.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') @@ -1458,6 +1466,7 @@ class InstanceTest(unittest.TestCase): self.assertRaises(InvalidDocumentError, throw_invalid_document_error) + @unittest.skip("not implemented") def test_invalid_son(self): """Raise an error if loading invalid data""" class Occurrence(EmbeddedDocument): @@ -1801,6 +1810,7 @@ class InstanceTest(unittest.TestCase): self.assertTrue(u1 in all_user_set) + @unittest.skip("not implemented") def test_picklable(self): pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) @@ -1827,6 +1837,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(pickle_doc.string, "Two") self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) + @unittest.skip("not implemented") def test_picklable_on_signals(self): pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleEmbedded() @@ -1887,6 +1898,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(Doc.objects(archived=False).count(), 1) + @unittest.skip("DynamicDocument not implemented") def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs""" class Doc(DynamicDocument): @@ -2026,6 +2038,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual('testdb-1', B._meta.get('db_alias')) + @unittest.skip("not implemented") def test_db_ref_usage(self): """ DB Ref usage in dict_fields""" @@ -2104,6 +2117,7 @@ class InstanceTest(unittest.TestCase): })]), "1,2") + @unittest.skip("not implemented") def test_switch_db_instance(self): register_connection('testdb-1', 'mongoenginetest2') @@ -2175,9 +2189,10 @@ class InstanceTest(unittest.TestCase): user = User.objects.first() self.assertEqual("Ross", user.username) self.assertEqual(True, user.foo) - self.assertEqual("Bar", user._data["foo"]) - self.assertEqual([1, 2, 3], user._data["data"]) + self.assertEqual("Bar", user._db_data["foo"]) + self.assertEqual([1, 2, 3], user._db_data["data"]) + @unittest.skip("DynamicDocument not implemented") def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): @@ -2194,6 +2209,7 @@ class InstanceTest(unittest.TestCase): one = Doc.objects.filter(**{'hello world': 1}).count() self.assertEqual(1, one) + @unittest.skip("not implemented") def test_shard_key(self): class LogEntry(Document): machine = StringField() @@ -2217,6 +2233,7 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + @unittest.skip("not implemented") def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True) @@ -2240,6 +2257,7 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + @unittest.skip("not implemented") def test_kwargs_simple(self): class Embedded(EmbeddedDocument): @@ -2254,8 +2272,9 @@ class InstanceTest(unittest.TestCase): "doc": {"name": "embedded doc"}}) self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) + @unittest.skip("not implemented") def test_kwargs_complex(self): class Embedded(EmbeddedDocument): @@ -2273,8 +2292,9 @@ class InstanceTest(unittest.TestCase): {"name": "embedded doc2"}]}) self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) + @unittest.skip("not implemented") def test_positional_creation(self): """Ensure that document may be created using positional arguments. """ @@ -2282,6 +2302,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 42) + @unittest.skip("not implemented") def test_mixed_creation(self): """Ensure that document may be created using mixed arguments. """ @@ -2307,8 +2328,8 @@ class InstanceTest(unittest.TestCase): Person(name="Harry Potter").save() person = Person.objects.first() - self.assertTrue('id' in person._data.keys()) - self.assertEqual(person._data.get('id'), person.id) + self.assertTrue('id' in person.to_dict().keys()) + self.assertEqual(person.to_dict().get('id'), person.id) def test_complex_nesting_document_and_embedded_document(self): diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index dbc09d8..1f2d5c8 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -58,7 +58,6 @@ class TestJson(unittest.TestCase): reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) - decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) diff --git a/tests/document/validation.py b/tests/document/validation.py index d3f3fd7..b8480a9 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -53,11 +53,12 @@ class ValidatorErrorTest(unittest.TestCase): self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") def test_model_validation(self): - class User(Document): username = StringField(primary_key=True) name = StringField(required=True) + User.drop_collection() + try: User().validate() except ValidationError, e: @@ -128,7 +129,7 @@ class ValidatorErrorTest(unittest.TestCase): Doc(id="test", e=SubDoc(val=15)).save() doc = Doc.objects.first() - keys = doc._data.keys() + keys = doc.to_dict().keys() self.assertEqual(2, len(keys)) self.assertTrue('e' in keys) self.assertTrue('id' in keys) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 3e48a21..e6a1a37 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -56,10 +56,11 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -88,10 +89,11 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -123,10 +125,12 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -157,10 +161,12 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -266,17 +272,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) - def test_long_ne_operator(self): - class TestDocument(Document): - long_fld = LongField() - - TestDocument.drop_collection() - - TestDocument(long_fld=None).save() - TestDocument(long_fld=1).save() - - self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) - def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to string fields. """ @@ -350,23 +345,6 @@ class FieldTest(unittest.TestCase): person.age = 'ten' self.assertRaises(ValidationError, person.validate) - def test_long_validation(self): - """Ensure that invalid values cannot be assigned to long fields. - """ - class TestDocument(Document): - value = LongField(min_value=0, max_value=110) - - doc = TestDocument() - doc.value = 50 - doc.validate() - - doc.value = -1 - self.assertRaises(ValidationError, doc.validate) - doc.age = 120 - self.assertRaises(ValidationError, doc.validate) - doc.age = 'ten' - self.assertRaises(ValidationError, doc.validate) - def test_float_validation(self): """Ensure that invalid values cannot be assigned to float fields. """ @@ -384,69 +362,6 @@ class FieldTest(unittest.TestCase): person.height = 4.0 self.assertRaises(ValidationError, person.validate) - def test_decimal_validation(self): - """Ensure that invalid values cannot be assigned to decimal fields. - """ - class Person(Document): - height = DecimalField(min_value=Decimal('0.1'), - max_value=Decimal('3.5')) - - Person.drop_collection() - - Person(height=Decimal('1.89')).save() - person = Person.objects.first() - self.assertEqual(person.height, Decimal('1.89')) - - person.height = '2.0' - person.save() - person.height = 0.01 - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('0.01') - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('4.0') - self.assertRaises(ValidationError, person.validate) - - Person.drop_collection() - - def test_decimal_comparison(self): - - class Person(Document): - money = DecimalField() - - Person.drop_collection() - - Person(money=6).save() - Person(money=8).save() - Person(money=10).save() - - self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) - self.assertEqual(2, Person.objects(money__gt=7).count()) - self.assertEqual(2, Person.objects(money__gt="7").count()) - - def test_decimal_storage(self): - class Person(Document): - btc = DecimalField(precision=4) - - Person.drop_collection() - Person(btc=10).save() - Person(btc=10.1).save() - Person(btc=10.11).save() - Person(btc="10.111").save() - Person(btc=Decimal("10.1111")).save() - Person(btc=Decimal("10.11111")).save() - - # How its stored - expected = [{'btc': 10.0}, {'btc': 10.1}, {'btc': 10.11}, - {'btc': 10.111}, {'btc': 10.1111}, {'btc': 10.1111}] - actual = list(Person.objects.exclude('id').as_pymongo()) - self.assertEqual(expected, actual) - - # How it comes out locally - expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), - Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] - actual = list(Person.objects().scalar('btc')) - self.assertEqual(expected, actual) - def test_boolean_validation(self): """Ensure that invalid values cannot be assigned to boolean fields. """ @@ -532,10 +447,10 @@ class FieldTest(unittest.TestCase): log.time = datetime.datetime.now().isoformat('T') log.validate() - log.time = -1 - self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' - self.assertRaises(ValidationError, log.validate) + #log.time = -1 + #self.assertRaises(ValidationError, log.validate) + #log.time = 'ABC' + #self.assertRaises(ValidationError, log.validate) def test_datetime_tz_aware_mark_as_changed(self): from mongoengine import connection @@ -556,7 +471,7 @@ class FieldTest(unittest.TestCase): log = LogEntry.objects.first() log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(['time'], log._changed_fields) + self.assertEqual(set(['time']), log._changed_fields) def test_datetime(self): """Tests showing pymongo datetime fields handling of microseconds. @@ -791,8 +706,8 @@ class FieldTest(unittest.TestCase): post = BlogPost(content='Went for a walk today...') post.validate() - post.tags = 'fun' - self.assertRaises(ValidationError, post.validate) + #post.tags = 'fun' + #self.assertRaises(ValidationError, post.validate) post.tags = [1, 2] self.assertRaises(ValidationError, post.validate) @@ -903,11 +818,11 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() post = BlogPost() - post.info = 'my post' - self.assertRaises(ValidationError, post.validate) + #post.info = 'my post' + #self.assertRaises(ValidationError, post.validate) - post.info = {'title': 'test'} - self.assertRaises(ValidationError, post.validate) + #post.info = {'title': 'test'} + #self.assertRaises(ValidationError, post.validate) post.info = ['test'] post.save() @@ -964,6 +879,7 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() + @unittest.skip("different behavior") def test_list_field_rejects_strings(self): """Strings aren't valid list field data types""" @@ -1008,7 +924,7 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() e = Simple().save() e.mapping = [] - self.assertEqual([], e._changed_fields) + self.assertEqual(set([]), e._changed_fields) class Simple(Document): mapping = DictField() @@ -1016,8 +932,9 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() e = Simple().save() e.mapping = {} - self.assertEqual([], e._changed_fields) + self.assertEqual(set([]), e._changed_fields) + @unittest.skip("complex types not implemented") def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" @@ -1074,11 +991,11 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() post = BlogPost() - post.info = 'my post' - self.assertRaises(ValidationError, post.validate) + #post.info = 'my post' + #self.assertRaises(ValidationError, post.validate) - post.info = ['test', 'test'] - self.assertRaises(ValidationError, post.validate) + #post.info = ['test', 'test'] + #self.assertRaises(ValidationError, post.validate) post.info = {'$title': 'test'} self.assertRaises(ValidationError, post.validate) @@ -1136,6 +1053,7 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() + @unittest.skip("complex types not implemented") def test_dictfield_complex(self): """Ensure that the dict field can handle the complex types.""" @@ -1953,6 +1871,7 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + @unittest.skip("not implemented") def test_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -2005,6 +1924,7 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + @unittest.skip("not implemented") def test_simple_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -2084,6 +2004,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) + @unittest.skip("SequenceField not implemented") def test_sequence_field(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2109,6 +2030,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(c['next'], 1000) + @unittest.skip("SequenceField not implemented") def test_sequence_field_get_next_value(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2140,6 +2062,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(Person.id.get_next_value(), '1') + @unittest.skip("SequenceField not implemented") def test_sequence_field_sequence_name(self): class Person(Document): id = SequenceField(primary_key=True, sequence_name='jelly') @@ -2164,6 +2087,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) self.assertEqual(c['next'], 1000) + @unittest.skip("SequenceField not implemented") def test_multiple_sequence_fields(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2196,6 +2120,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) self.assertEqual(c['next'], 999) + @unittest.skip("SequenceField not implemented") def test_sequence_fields_reload(self): class Animal(Document): counter = SequenceField() @@ -2221,6 +2146,7 @@ class FieldTest(unittest.TestCase): a.reload() self.assertEqual(a.counter, 2) + @unittest.skip("SequenceField not implemented") def test_multiple_sequence_fields_on_docs(self): class Animal(Document): @@ -2255,6 +2181,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) self.assertEqual(c['next'], 10) + @unittest.skip("SequenceField not implemented") def test_sequence_field_value_decorator(self): class Person(Document): id = SequenceField(primary_key=True, value_decorator=str) @@ -2276,6 +2203,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) + @unittest.skip("SequenceField not implemented") def test_embedded_sequence_field(self): class Comment(EmbeddedDocument): id = SequenceField() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index dfef9ee..9ad3fdd 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -24,6 +24,7 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') +@unittest.skip("FileField not implemented") class FileTest(unittest.TestCase): def setUp(self): diff --git a/tests/fields/geo.py b/tests/fields/geo.py index 31ded26..81f8a69 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -10,6 +10,7 @@ from mongoengine.connection import get_db __all__ = ("GeoFieldTest", ) +@unittest.skip("geo fields not implemented") class GeoFieldTest(unittest.TestCase): def setUp(self): diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py index 6fc83e0..bff50c3 100644 --- a/tests/migration/__init__.py +++ b/tests/migration/__init__.py @@ -1,5 +1,4 @@ from convert_to_new_inheritance_model import * -from decimalfield_as_float import * from refrencefield_dbref_to_object_id import * from turn_off_inheritance import * from uuidfield_to_binary import * diff --git a/tests/migration/decimalfield_as_float.py b/tests/migration/decimalfield_as_float.py deleted file mode 100644 index 3903c91..0000000 --- a/tests/migration/decimalfield_as_float.py +++ /dev/null @@ -1,50 +0,0 @@ - # -*- coding: utf-8 -*- -import unittest -import decimal -from decimal import Decimal - -from mongoengine import Document, connect -from mongoengine.connection import get_db -from mongoengine.fields import StringField, DecimalField, ListField - -__all__ = ('ConvertDecimalField', ) - - -class ConvertDecimalField(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def test_how_to_convert_decimal_fields(self): - """Demonstrates migrating from 0.7 to 0.8 - """ - - # 1. Old definition - using dbrefs - class Person(Document): - name = StringField() - money = DecimalField(force_string=True) - monies = ListField(DecimalField(force_string=True)) - - Person.drop_collection() - Person(name="Wilson Jr", money=Decimal("2.50"), - monies=[Decimal("2.10"), Decimal("5.00")]).save() - - # 2. Start the migration by changing the schema - # Change DecimalField - add precision and rounding settings - class Person(Document): - name = StringField() - money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) - monies = ListField(DecimalField(precision=2, - rounding=decimal.ROUND_HALF_UP)) - - # 3. Loop all the objects and mark parent as changed - for p in Person.objects: - p._mark_as_changed('money') - p._mark_as_changed('monies') - p.save() - - # 4. Confirmation of the fix! - wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] - self.assertTrue(isinstance(wilson['money'], float)) - self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index f564896..7e1c5df 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -8,6 +8,7 @@ from mongoengine import * __all__ = ("GeoQueriesTest",) +@unittest.skip("geo queries not implemented") class GeoQueriesTest(unittest.TestCase): def setUp(self): diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 21df22c..1ddfc27 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -760,10 +760,10 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.employees.append(p2) # Dereferences + fresh_o1.employees.append(p2) fresh_o1.save(cascade=False) # Saves - self.assertEqual(q, 3) + self.assertEqual(q, 2) def test_slave_okay(self): """Ensures that a query can take slave_okay syntax @@ -2879,19 +2879,6 @@ class QuerySetTest(unittest.TestCase): (u'Wilson Jr', 19, u'Corumba-GO'), (u'Gabriel Falcao', 23, u'New York')]) - def test_scalar_decimal(self): - from decimal import Decimal - class Person(Document): - name = StringField() - rating = DecimalField() - - Person.drop_collection() - Person(name="Wilson Jr", rating=Decimal('1.0')).save() - - ulist = list(Person.objects.scalar('name', 'rating')) - self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))]) - - def test_scalar_reference_field(self): class State(Document): name = StringField() @@ -3144,7 +3131,6 @@ class QuerySetTest(unittest.TestCase): objectid_field = ObjectIdField(default=ObjectId) reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) - decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) @@ -3175,30 +3161,25 @@ class QuerySetTest(unittest.TestCase): id = ObjectIdField('_id') name = StringField() age = IntField() - price = DecimalField() User.drop_collection() - User(name="Bob Dole", age=89, price=Decimal('1.11')).save() - User(name="Barack Obama", age=51, price=Decimal('2.22')).save() + User(name="Bob Dole", age=89).save() + User(name="Barack Obama", age=51).save() - users = User.objects.only('name', 'price').as_pymongo() + users = User.objects.only('name').as_pymongo() results = list(users) self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') - self.assertEqual(results[0]['price'], 1.11) self.assertEqual(results[1]['name'], 'Barack Obama') - self.assertEqual(results[1]['price'], 2.22) # Test coerce_types - users = User.objects.only('name', 'price').as_pymongo(coerce_types=True) + users = User.objects.only('name').as_pymongo(coerce_types=True) results = list(users) self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') - self.assertEqual(results[0]['price'], Decimal('1.11')) self.assertEqual(results[1]['name'], 'Barack Obama') - self.assertEqual(results[1]['price'], Decimal('2.22')) def test_as_pymongo_json_limit_fields(self): @@ -3222,6 +3203,7 @@ class QuerySetTest(unittest.TestCase): serialized_user = User.objects.exclude('password_salt').only('email').to_json() self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + @unittest.skip("not implemented") def test_no_dereference(self): class Organization(Document): diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 7886965..53c1660 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -63,6 +63,7 @@ class TransformTest(unittest.TestCase): BlogPost.drop_collection() + @unittest.skip("unsupported") def test_query_pk_field_name(self): """Ensure that the correct "primary key" field name is used when querying diff --git a/tests/test_dereference.py b/tests/test_dereference.py index e146963..95bfbe9 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -16,6 +16,7 @@ class FieldTest(unittest.TestCase): connect(db='mongoenginetest') self.db = get_db() + @unittest.skip("select_related currently doesn't dereference lists") def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -74,6 +75,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() + @unittest.skip("select_related currently doesn't dereference lists") def test_list_item_dereference_dref_false(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -146,6 +148,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(Group._get_collection().find_one()['members'], [1]) self.assertEqual(group.members, [user]) + @unittest.skip('currently not implemented') def test_handle_old_style_references(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -179,6 +182,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(group.members[0].name, 'user 1') self.assertEqual(group.members[-1].name, 'String!') + @unittest.skip('currently not implemented') def test_migrate_references(self): """Example of migrating ReferenceField storage """ @@ -225,6 +229,7 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(raw_data['author'], ObjectId)) self.assertTrue(isinstance(raw_data['members'][0], ObjectId)) + @unittest.skip("select_related currently doesn't dereference lists") def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ @@ -259,9 +264,15 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 1) peter.boss - self.assertEqual(q, 2) + self.assertEqual(q, 1) peter.friends + self.assertEqual(q, 1) + + peter.boss.name + self.assertEqual(q, 2) + + peter.friends[0].name self.assertEqual(q, 3) # Document select_related @@ -391,6 +402,7 @@ class FieldTest(unittest.TestCase): "%s" % Person.objects() ) + @unittest.skip("not implemented") def test_generic_reference(self): class UserA(Document): @@ -482,6 +494,7 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() + @unittest.skip("not implemented") def test_list_field_complex(self): class UserA(Document): @@ -573,6 +586,7 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() + @unittest.skip('MapField not fully implemented') def test_map_field_reference(self): class User(Document): @@ -638,6 +652,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() + @unittest.skip("not implemented") def test_dict_field(self): class UserA(Document): @@ -741,6 +756,7 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() + @unittest.skip("not implemented") def test_dict_field_no_field_inheritance(self): class UserA(Document): @@ -817,6 +833,7 @@ class FieldTest(unittest.TestCase): UserA.drop_collection() Group.drop_collection() + @unittest.skip("select_related currently doesn't dereference lists") def test_generic_reference_map_field(self): class UserA(Document): @@ -942,6 +959,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(root.children, [company]) self.assertEqual(company.parents, [root]) + @unittest.skip("not implemented") def test_dict_in_dbref_instance(self): class Person(Document): diff --git a/tests/test_signals.py b/tests/test_signals.py index 50e5e6b..da217c0 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -30,28 +30,10 @@ class SignalTests(unittest.TestCase): def __unicode__(self): return self.name - @classmethod - def pre_init(cls, sender, document, *args, **kwargs): - signal_output.append('pre_init signal, %s' % cls.__name__) - signal_output.append(str(kwargs['values'])) - - @classmethod - def post_init(cls, sender, document, **kwargs): - signal_output.append('post_init signal, %s' % document) - @classmethod def pre_save(cls, sender, document, **kwargs): signal_output.append('pre_save signal, %s' % document) - @classmethod - def pre_save_post_validation(cls, sender, document, **kwargs): - signal_output.append('pre_save_post_validation signal, %s' % document) - if 'created' in kwargs: - if kwargs['created']: - signal_output.append('Is created') - else: - signal_output.append('Is updated') - @classmethod def post_save(cls, sender, document, **kwargs): signal_output.append('post_save signal, %s' % document) @@ -118,10 +100,7 @@ class SignalTests(unittest.TestCase): # Save up the number of connected signals so that we can check at the # end that all the signals we register get properly unregistered self.pre_signals = ( - len(signals.pre_init.receivers), - len(signals.post_init.receivers), len(signals.pre_save.receivers), - len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -129,10 +108,7 @@ class SignalTests(unittest.TestCase): len(signals.post_bulk_insert.receivers), ) - signals.pre_init.connect(Author.pre_init, sender=Author) - signals.post_init.connect(Author.post_init, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author) - signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author) signals.post_save.connect(Author.post_save, sender=Author) signals.pre_delete.connect(Author.pre_delete, sender=Author) signals.post_delete.connect(Author.post_delete, sender=Author) @@ -145,12 +121,9 @@ class SignalTests(unittest.TestCase): signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) def tearDown(self): - signals.pre_init.disconnect(self.Author.pre_init) - signals.post_init.disconnect(self.Author.post_init) signals.post_delete.disconnect(self.Author.post_delete) signals.pre_delete.disconnect(self.Author.pre_delete) signals.post_save.disconnect(self.Author.post_save) - signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation) signals.pre_save.disconnect(self.Author.pre_save) signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) @@ -162,10 +135,7 @@ class SignalTests(unittest.TestCase): # Check that all our signals got disconnected properly. post_signals = ( - len(signals.pre_init.receivers), - len(signals.post_init.receivers), len(signals.pre_save.receivers), - len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -180,9 +150,6 @@ class SignalTests(unittest.TestCase): def test_model_signals(self): """ Model saves should throw some signals. """ - def create_author(): - self.Author(name='Bill Shakespeare') - def bulk_create_author_with_load(): a1 = self.Author(name='Bill Shakespeare') self.Author.objects.insert([a1], load_bulk=True) @@ -191,17 +158,9 @@ class SignalTests(unittest.TestCase): a1 = self.Author(name='Bill Shakespeare') self.Author.objects.insert([a1], load_bulk=False) - self.assertEqual(self.get_signal_output(create_author), [ - "pre_init signal, Author", - "{'name': 'Bill Shakespeare'}", - "post_init signal, Bill Shakespeare", - ]) - a1 = self.Author(name='Bill Shakespeare') self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, Bill Shakespeare", - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", "post_save signal, Bill Shakespeare", "Is created" ]) @@ -210,8 +169,6 @@ class SignalTests(unittest.TestCase): a1.name = 'William Shakespeare' self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, William Shakespeare", - "pre_save_post_validation signal, William Shakespeare", - "Is updated", "post_save signal, William Shakespeare", "Is updated" ]) @@ -223,18 +180,13 @@ class SignalTests(unittest.TestCase): signal_output = self.get_signal_output(bulk_create_author_with_load) - # The output of this signal is not entirely deterministic. The reloaded - # object will have an object ID. Hence, we only check part of the output - self.assertEqual(signal_output[3], - "pre_bulk_insert signal, []") - self.assertEqual(signal_output[-2:], - ["post_bulk_insert signal, []", - "Is loaded",]) + self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ + "pre_bulk_insert signal, []", + "post_bulk_insert signal, []", + "Is loaded", + ]) self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ - "pre_init signal, Author", - "{'name': 'Bill Shakespeare'}", - "post_init signal, Bill Shakespeare", "pre_bulk_insert signal, []", "post_bulk_insert signal, []", "Not loaded", From 929a1b222c871890a4f606d795146de73aa4cb24 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Fri, 14 Jun 2013 20:06:12 -0700 Subject: [PATCH 02/18] Add support for DocumentProxy so we can access obj.ref.pk without fetching ref and still get the correct type even if ref can be inherited. --- mongoengine/base/document.py | 10 +- mongoengine/base/proxy.py | 193 +++++++++++++++++++++++++++++++++++ mongoengine/fields.py | 35 ++++--- setup.py | 5 +- 4 files changed, 226 insertions(+), 17 deletions(-) create mode 100644 mongoengine/base/proxy.py diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 1195bc4..bdc97ec 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -15,6 +15,7 @@ from mongoengine.errors import (ValidationError, InvalidDocumentError, from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, to_str_keys_recursive) +from mongoengine.base.proxy import DocumentProxy from mongoengine.base.common import get_document, ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.fields import ComplexBaseField @@ -109,9 +110,12 @@ class BaseDocument(object): return txt_type('%s object' % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'pk'): - if self.pk == other.pk: - return True + if isinstance(other, DocumentProxy) and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk') and self.pk == other.pk: + return True + + if isinstance(other, self.__class__) and hasattr(other, 'pk') and self.pk == other.pk: + return True + return False def __ne__(self, other): diff --git a/mongoengine/base/proxy.py b/mongoengine/base/proxy.py new file mode 100644 index 0000000..4d92462 --- /dev/null +++ b/mongoengine/base/proxy.py @@ -0,0 +1,193 @@ +from mongoengine.queryset import OperationError +from bson.dbref import DBRef + +class LocalProxy(object): + # From werkzeug/local.py + + """ Forwards all operations to + a proxied object. The only operations not supported for forwarding + are right handed operands and any kind of assignment. + """ + + __slots__ = ('__local', '__dict__', '__name__') + + def __init__(self, local, name=None): + object.__setattr__(self, '_LocalProxy__local', local) + object.__setattr__(self, '__name__', name) + + def _get_current_object(self): + """Return the current object. This is useful if you want the real + object behind the proxy at a time for performance reasons or because + you want to pass the object into a different context. + """ + if not hasattr(self.__local, '__release_local__'): + return self.__local() + try: + return getattr(self.__local, self.__name__) + except AttributeError: + raise RuntimeError('no object bound to %s' % self.__name__) + + @property + def __dict__(self): + try: + return self._get_current_object().__dict__ + except RuntimeError: + raise AttributeError('__dict__') + + def __repr__(self): + try: + obj = self._get_current_object() + except RuntimeError: + return '<%s unbound>' % self.__class__.__name__ + return repr(obj) + + def __nonzero__(self): + try: + return bool(self._get_current_object()) + except RuntimeError: + return False + + def __unicode__(self): + try: + return unicode(self._get_current_object()) + except RuntimeError: + return repr(self) + + def __dir__(self): + try: + return dir(self._get_current_object()) + except RuntimeError: + return [] + + def __getattr__(self, name): + if name == '__members__': + return dir(self._get_current_object()) + return getattr(self._get_current_object(), name) + + def __setitem__(self, key, value): + self._get_current_object()[key] = value + + def __delitem__(self, key): + del self._get_current_object()[key] + + def __setslice__(self, i, j, seq): + self._get_current_object()[i:j] = seq + + def __delslice__(self, i, j): + del self._get_current_object()[i:j] + + __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v) + __delattr__ = lambda x, n: delattr(x._get_current_object(), n) + __str__ = lambda x: str(x._get_current_object()) + __lt__ = lambda x, o: x._get_current_object() < o + __le__ = lambda x, o: x._get_current_object() <= o + __eq__ = lambda x, o: x._get_current_object() == o + __ne__ = lambda x, o: x._get_current_object() != o + __gt__ = lambda x, o: x._get_current_object() > o + __ge__ = lambda x, o: x._get_current_object() >= o + __cmp__ = lambda x, o: cmp(x._get_current_object(), o) + __hash__ = lambda x: hash(x._get_current_object()) + __call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw) + __len__ = lambda x: len(x._get_current_object()) + __getitem__ = lambda x, i: x._get_current_object()[i] + __iter__ = lambda x: iter(x._get_current_object()) + __contains__ = lambda x, i: i in x._get_current_object() + __getslice__ = lambda x, i, j: x._get_current_object()[i:j] + __add__ = lambda x, o: x._get_current_object() + o + __sub__ = lambda x, o: x._get_current_object() - o + __mul__ = lambda x, o: x._get_current_object() * o + __floordiv__ = lambda x, o: x._get_current_object() // o + __mod__ = lambda x, o: x._get_current_object() % o + __divmod__ = lambda x, o: x._get_current_object().__divmod__(o) + __pow__ = lambda x, o: x._get_current_object() ** o + __lshift__ = lambda x, o: x._get_current_object() << o + __rshift__ = lambda x, o: x._get_current_object() >> o + __and__ = lambda x, o: x._get_current_object() & o + __xor__ = lambda x, o: x._get_current_object() ^ o + __or__ = lambda x, o: x._get_current_object() | o + __div__ = lambda x, o: x._get_current_object().__div__(o) + __truediv__ = lambda x, o: x._get_current_object().__truediv__(o) + __neg__ = lambda x: -(x._get_current_object()) + __pos__ = lambda x: +(x._get_current_object()) + __abs__ = lambda x: abs(x._get_current_object()) + __invert__ = lambda x: ~(x._get_current_object()) + __complex__ = lambda x: complex(x._get_current_object()) + __int__ = lambda x: int(x._get_current_object()) + __long__ = lambda x: long(x._get_current_object()) + __float__ = lambda x: float(x._get_current_object()) + __oct__ = lambda x: oct(x._get_current_object()) + __hex__ = lambda x: hex(x._get_current_object()) + __index__ = lambda x: x._get_current_object().__index__() + __coerce__ = lambda x, o: x.__coerce__(x, o) + __enter__ = lambda x: x.__enter__() + __exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw) + + +class DocumentProxy(LocalProxy): + __slots__ = ('__document_type', '__document', '__pk') + + def __init__(self, document_type, pk): + object.__setattr__(self, '_DocumentProxy__document_type', document_type) + object.__setattr__(self, '_DocumentProxy__document', None) + object.__setattr__(self, '_DocumentProxy__pk', pk) + object.__setattr__(self, document_type._meta['id_field'], self.pk) + + @property + def __class__(self): + # We need to fetch the object to determine to which class it belongs. + return self._get_current_object().__class__ + + def _lazy(): + def fget(self): + return self.__document._lazy if self.__document else True + def fset(self, value): + self._get_current_object()._lazy = value + return property(fget, fset) + _lazy = _lazy() + + # copy normally updates __dict__ which would result in errors + def __setstate__(self, state): + for k, v in state[1].iteritems(): + object.__setattr__(self, k, v) + + def _get_collection_name(self): + return self.__document_type._meta.get('collection', None) + + def __eq__(self, other): + if other and hasattr(other, '_get_collection_name') and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk'): + if self.pk == other.pk: + return True + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def to_dbref(self): + """Returns an instance of :class:`~bson.dbref.DBRef` useful in + `__raw__` queries.""" + if not self.pk: + msg = "Only saved documents can have a valid dbref" + raise OperationError(msg) + return DBRef(self._get_collection_name(), self.pk) + + def pk(): + def fget(self): + return self.__document.pk if self.__document else self.__pk + def fset(self, value): + self._get_current_object().pk = value + return property(fget, fset) + pk = pk() + + def _get_current_object(self): + if self.__document == None: + #print 'fetching', self.__document_type, self.__pk + #import traceback + #traceback.print_stack() + collection = self.__document_type._get_collection() + son = collection.find_one({'_id': self.__pk}) + document = self.__document_type._from_son(son) + object.__setattr__(self, '_DocumentProxy__document', document) + return self.__document + + def __nonzero__(self): + return True diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 8811f7f..0c2056a 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -26,6 +26,7 @@ from mongoengine.python_support import (PY3, bin_type, txt_type, from mongoengine.base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, get_document, BaseDocument) from mongoengine.base.datastructures import BaseList, BaseDict +from mongoengine.base.proxy import DocumentProxy from mongoengine.queryset import DoesNotExist from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument @@ -516,7 +517,7 @@ class ListField(ComplexBaseField): super(ListField, self).__init__(**kwargs) def value_for_instance(self, value, instance): - return BaseList(value, instance, self.name) + return BaseList(value or [], instance, self.name) def from_python(self, val): from_python = getattr(self.field, 'from_python', None) @@ -615,7 +616,7 @@ class DictField(ComplexBaseField): return {k: to_python(v) for k, v in val.iteritems()} if to_python else val def value_for_instance(self, value, instance): - return BaseDict(value, instance, self.name) + return BaseDict(value or {}, instance, self.name) def to_mongo(self, val): to_mongo = getattr(self.field, 'to_mongo', None) @@ -734,7 +735,7 @@ class ReferenceField(BaseField): return value else: return value.id - elif isinstance(value, Document): + elif isinstance(value, (Document, DocumentProxy)): document_type = self.document_type # We need the id from the saved object to create the DBRef pk = value.pk @@ -758,17 +759,22 @@ class ReferenceField(BaseField): if value != None: document_type = self.document_type if self.dbref: - obj = document_type(pk=value.id) + pk = value.id else: if isinstance(value, DBRef): - obj = document_type(pk=value.id) + pk = value.id else: - obj = document_type(pk=value) - obj._lazy = True + pk = value + if document_type._meta['allow_inheritance']: + # We don't know of which type the object will be. + obj = DocumentProxy(document_type, pk) + else: + obj = document_type(pk=pk) + obj._lazy = True return obj def from_python(self, value): - if isinstance(value, BaseDocument): + if isinstance(value, (BaseDocument, DocumentProxy)): return value elif value == None: return super(ReferenceField, self).from_python(value) @@ -780,17 +786,22 @@ class ReferenceField(BaseField): # DBRef or ID document_type = self.document_type if isinstance(value, DBRef): - obj = document_type(pk=value.id) + pk = value.id else: - obj = document_type(pk=value) - obj._lazy = True + pk = value + if document_type._meta['allow_inheritance']: + # We don't know of which type the object will be. + obj = DocumentProxy(document_type, pk) + else: + obj = document_type(pk=pk) + obj._lazy = True return obj def prepare_query_value(self, op, value): return self.to_mongo(self.from_python(value)) def validate(self, value): - if not isinstance(value, (self.document_type, DBRef)): + if not isinstance(value, (self.document_type, DBRef, DocumentProxy)): self.error("A ReferenceField only accepts DBRef or documents") if isinstance(value, Document) and value.pk is None: diff --git a/setup.py b/setup.py index effb6f1..7a9f360 100644 --- a/setup.py +++ b/setup.py @@ -51,13 +51,14 @@ CLASSIFIERS = [ extra_opts = {} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6'] extra_opts['packages'] = find_packages(exclude=('tests',)) if "test" in sys.argv or "nosetests" in sys.argv: extra_opts['packages'].append("tests") extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil'] + #extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6'] extra_opts['packages'] = find_packages(exclude=('tests',)) setup(name='mongoengine', From 6017b426b494ee69c772c464b84e32e6f54ccb6c Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sat, 15 Jun 2013 11:58:20 -0700 Subject: [PATCH 03/18] Remove StringField.to_mongo --- mongoengine/fields.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0c2056a..42971b8 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -63,9 +63,6 @@ class StringField(BaseField): self.min_length = min_length super(StringField, self).__init__(**kwargs) - def to_mongo(self, value): - return value or None - def validate(self, value): if not isinstance(value, basestring): self.error('StringField only accepts string values') From 0217ca184c7d265df0ca6c2cda8b54060f48848d Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sun, 16 Jun 2013 00:03:12 -0700 Subject: [PATCH 04/18] Recursive value_for_instance --- mongoengine/fields.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 42971b8..b270984 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -513,8 +513,13 @@ class ListField(ComplexBaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def value_for_instance(self, value, instance): - return BaseList(value or [], instance, self.name) + def value_for_instance(self, value, instance, name=None): + name = name or self.name + if value and self.field: + value_for_instance = getattr(self.field, 'value_for_instance', None) + if value_for_instance: + value = [value_for_instance(v, instance, name) for v in value] + return BaseList(value or [], instance, name) def from_python(self, val): from_python = getattr(self.field, 'from_python', None) @@ -612,8 +617,13 @@ class DictField(ComplexBaseField): to_python = getattr(self.field, 'to_python', None) return {k: to_python(v) for k, v in val.iteritems()} if to_python else val - def value_for_instance(self, value, instance): - return BaseDict(value or {}, instance, self.name) + def value_for_instance(self, value, instance, name=None): + name = name or self.name + if value and self.field: + value_for_instance = getattr(self.field, 'value_for_instance', None) + if value_for_instance: + value = {k: value_for_instance(v, instance, name) for k, v in value.iteritems()} + return BaseDict(value or {}, instance, name) def to_mongo(self, val): to_mongo = getattr(self.field, 'to_mongo', None) From 6fbd931ffafbd40efb11aee522553f340a612750 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sun, 16 Jun 2013 14:13:49 -0700 Subject: [PATCH 05/18] Adding (passing) unit test from https://github.com/MongoEngine/mongoengine/pull/368 --- tests/test_dereference.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 95bfbe9..2758ce7 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -302,6 +302,31 @@ class FieldTest(unittest.TestCase): self.assertEqual(employee.friends, friends) self.assertEqual(q, 2) + def test_list_of_lists_of_references(self): + + class User(Document): + name = StringField() + + class Post(Document): + user_lists = ListField(ListField(ReferenceField(User))) + + class SimpleList(Document): + users = ListField(ReferenceField(User)) + + User.drop_collection() + Post.drop_collection() + + u1 = User.objects.create(name='u1') + u2 = User.objects.create(name='u2') + u3 = User.objects.create(name='u3') + + SimpleList.objects.create(users=[u1, u2, u3]) + self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) + + Post.objects.create(user_lists=[[u1, u2], [u3]]) + self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]]) + + def test_circular_reference(self): """Ensure you can handle circular references """ From 04c30028892a179ffe370ed868dde7832db33d5b Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sun, 16 Jun 2013 14:15:26 -0700 Subject: [PATCH 06/18] Add SimpleList.drop_collection to previous unit test --- tests/test_dereference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 2758ce7..f13f291 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -315,6 +315,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Post.drop_collection() + SimpleList.drop_collection() u1 = User.objects.create(name='u1') u2 = User.objects.create(name='u2') From a3ba2f53b192f1f8697b89ede0b93e80a800c700 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sun, 16 Jun 2013 14:34:49 -0700 Subject: [PATCH 07/18] Removing unneeded signals --- mongoengine/signals.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 06fb8b4..1c256da 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -__all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', - 'post_save', 'pre_delete', 'post_delete'] +__all__ = ['pre_save', 'post_save', 'pre_delete', 'post_delete'] signals_available = False try: @@ -36,10 +35,7 @@ except ImportError: # not put signals in here. Create your own namespace instead. _signals = Namespace() -pre_init = _signals.signal('pre_init') -post_init = _signals.signal('post_init') pre_save = _signals.signal('pre_save') -pre_save_post_validation = _signals.signal('pre_save_post_validation') post_save = _signals.signal('post_save') pre_delete = _signals.signal('pre_delete') post_delete = _signals.signal('post_delete') From d189b6f0703456c3a0c7ce4ef010d23ce30bc72a Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sun, 16 Jun 2013 14:44:25 -0700 Subject: [PATCH 08/18] Fix setup.py and AUTHORS --- AUTHORS | 4 ++-- setup.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/AUTHORS b/AUTHORS index 7788139..e16f23e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -9,6 +9,7 @@ Steve Challis Wilson Júnior Dan Crosta https://github.com/dcrosta Laine Herron https://github.com/LaineHerron +Thomas Steinacher http://thomasst.ch/ CONTRIBUTORS @@ -114,7 +115,6 @@ that much better: * Alexander Koshelev * Jaime Irurzun * Alexandre González - * Thomas Steinacher * Tommi Komulainen * Peter Landry * biszkoptwielki @@ -169,4 +169,4 @@ that much better: * Massimo Santini (https://github.com/mapio) * Nigel McNie (https://github.com/nigelmcnie) * ygbourhis (https://github.com/ygbourhis) - * Bob Dickinson (https://github.com/BobDickinson) \ No newline at end of file + * Bob Dickinson (https://github.com/BobDickinson) diff --git a/setup.py b/setup.py index 7a9f360..effb6f1 100644 --- a/setup.py +++ b/setup.py @@ -51,14 +51,13 @@ CLASSIFIERS = [ extra_opts = {} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6'] extra_opts['packages'] = find_packages(exclude=('tests',)) if "test" in sys.argv or "nosetests" in sys.argv: extra_opts['packages'].append("tests") extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - #extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6'] - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil'] extra_opts['packages'] = find_packages(exclude=('tests',)) setup(name='mongoengine', From 8616c17700b5923d8c78381868e7691097e21b9d Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sun, 16 Jun 2013 14:46:06 -0700 Subject: [PATCH 09/18] Adding DIFFERENCES file to compare Mongomallard and Mongoengine --- DIFFERENCES.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 DIFFERENCES.md diff --git a/DIFFERENCES.md b/DIFFERENCES.md new file mode 100644 index 0000000..d25f6ca --- /dev/null +++ b/DIFFERENCES.md @@ -0,0 +1,36 @@ +Differences between Mongomallard and Mongoengine +----- + +* All document fields are lazily evaluated, resulting in much faster object initialization time. +* `_data` is removed due to lazy evaluation. `to_dict()` can be used to convert a document to a dictionary, and `_internal_data` contains previously evaluated data. +* Field methods `to_python`, `from_python`, `to_mongo`, `value_for_instance`: + * `to_python` is called when converting from a MongoDB type to a document Python type only. + * `from_python` is called when converting an assignment in Python to the document Python type. + * `to_mongo` is called when converting from a document Python type to a MongoDB type. + * `value_for_instance` is called just before returning a value in Python allowing for instance-specific transformations. +* `pre_init`, `post_init`, `pre_save_post_validation` signals are removed to ensure fast object initialization. +* `DecimalField` is removed since there is no corresponding MongoDB type +* `LongField` is removed since it is equivalent with `IntField` +* Adding `SafeReferenceField` which returns None if the reference does not exist. +* Adding `SafeReferenceListField` which doesn't return references that don't exist. +* Accessing a `ListField(ReferenceField)` doesn't automatically dereference all objects since they are lazily evaluated. A `SafeReferenceListField` may be used instead. +* Accessing a related object's id doesn't fetch the object from the database, e.g. `book.author.id` where author is a `ReferenceField` will not make a database lookup except when using a `SafeReferenceField`. When inheritance is allowed, a proxy object will be returned, otherwise a lazy object from the referenced document class will be returned. +* The primary key is only stored as `_id` in the database and is referenced in Python as `pk` or as the name of the primary key field. +* Saves are not cascaded by default. +* `Document.save()` supports `full=True` keyword argument to force saving all model fields. +* `_get_changed_fields()` / `_changed_fields` returns a set +* Simplified `EmailField` email regex to be more compatible + +Untested / not implemented yet: +----- + +* Delta updates for lists / embedded documents (fields that have changed on the document are fully updated) +* Dynamic documents / `DynamicField`, dynamic addition/deletion of fields +* Field display name methods +* `SequenceField` +* Pickling documents +* `FileField` +* All Geo fields +* `no_dereference()` +* using `SafeReferenceListField` with `GenericReferenceField` +* `max_depth` argument for `doc.reload()` From 628a8dacfaeea98e94882620b334f097a0911fb1 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sun, 16 Jun 2013 23:28:31 -0700 Subject: [PATCH 10/18] Fix select_related for documents --- mongoengine/dereference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index b9d79e6..79f755f 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -89,6 +89,8 @@ class DeReference(object): v = getattr(item, field_name) if isinstance(v, (DBRef)): reference_map.setdefault(field.document_type, []).append(v.id) + elif isinstance(v, Document) and getattr(v, '_lazy', False): + reference_map.setdefault(field.document_type, []).append(v.pk) elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: @@ -196,6 +198,8 @@ class DeReference(object): v = data[k]._internal_data.get(field_name, None) if isinstance(v, (DBRef)): data[k]._internal_data[field_name] = self.object_map.get(v.id, v) + elif isinstance(v, Document) and getattr(v, '_lazy', False): + data[k]._internal_data[field_name] = self.object_map.get(v.pk, v) elif isinstance(v, (dict, SON)) and '_ref' in v: data[k]._internal_data[field_name] = self.object_map.get(v['_ref'].id, v) elif isinstance(v, dict) and depth <= self.max_depth: From 17a1844b6ff7a26aa3522a38a21a6b586dca5b26 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Mon, 17 Jun 2013 06:35:52 -0700 Subject: [PATCH 11/18] Don't raise exception if ListField/DictField are stored as None --- mongoengine/fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index b270984..ec206d0 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -527,7 +527,7 @@ class ListField(ComplexBaseField): def to_python(self, val): to_python = getattr(self.field, 'to_python', None) - return [to_python(v) for v in val] if to_python else val + return [to_python(v) for v in val] if to_python and val else val or None def to_mongo(self, val): to_mongo = getattr(self.field, 'to_mongo', None) @@ -615,7 +615,7 @@ class DictField(ComplexBaseField): def to_python(self, val): to_python = getattr(self.field, 'to_python', None) - return {k: to_python(v) for k, v in val.iteritems()} if to_python else val + return {k: to_python(v) for k, v in val.iteritems()} if to_python and val else val or None def value_for_instance(self, value, instance, name=None): name = name or self.name From 069aedabc1a7fd6cdb7aed6196ee0aad43457560 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Mon, 17 Jun 2013 11:15:45 -0700 Subject: [PATCH 12/18] IntField: Automatically convert strings to integers --- mongoengine/fields.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index ec206d0..57c4154 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -149,6 +149,9 @@ class IntField(BaseField): self.min_value, self.max_value = min_value, max_value super(IntField, self).__init__(**kwargs) + def from_python(self, value): + return self.prepare_query_value(None, value) + def validate(self, value): try: value = int(value) @@ -161,6 +164,12 @@ class IntField(BaseField): if self.max_value is not None and value > self.max_value: self.error('Integer value is too large') + def prepare_query_value(self, op, value): + if value is None: + return value + else: + return int(value) + class FloatField(BaseField): """A floating point number field. From 2a93a69085fc6bc080c31371b152331195bef679 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Tue, 18 Jun 2013 18:31:11 -0700 Subject: [PATCH 13/18] Fix IntField assignment unit tests --- DIFFERENCES.md | 1 + tests/document/validation.py | 9 ++------- tests/fields/fields.py | 9 +++------ 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/DIFFERENCES.md b/DIFFERENCES.md index d25f6ca..4b51bd7 100644 --- a/DIFFERENCES.md +++ b/DIFFERENCES.md @@ -20,6 +20,7 @@ Differences between Mongomallard and Mongoengine * `Document.save()` supports `full=True` keyword argument to force saving all model fields. * `_get_changed_fields()` / `_changed_fields` returns a set * Simplified `EmailField` email regex to be more compatible +* Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError` Untested / not implemented yet: ----- diff --git a/tests/document/validation.py b/tests/document/validation.py index b8480a9..4637dee 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -134,13 +134,8 @@ class ValidatorErrorTest(unittest.TestCase): self.assertTrue('e' in keys) self.assertTrue('id' in keys) - doc.e.val = "OK" - try: - doc.save() - except ValidationError, e: - self.assertTrue("Doc:test" in e.message) - self.assertEqual(e.to_dict(), { - "e": {'val': 'OK could not be converted to int'}}) + with self.assertRaises(ValueError): + doc.e.val = "OK" if __name__ == '__main__': diff --git a/tests/fields/fields.py b/tests/fields/fields.py index e6a1a37..1eea2ac 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -342,8 +342,8 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, person.validate) person.age = 120 self.assertRaises(ValidationError, person.validate) - person.age = 'ten' - self.assertRaises(ValidationError, person.validate) + with self.assertRaises(ValueError): + person.age = 'ten' def test_float_validation(self): """Ensure that invalid values cannot be assigned to float fields. @@ -871,11 +871,8 @@ class FieldTest(unittest.TestCase): e.mapping = [1] e.save() - def create_invalid_mapping(): + with self.assertRaises(ValueError): e.mapping = ["abc"] - e.save() - - self.assertRaises(ValidationError, create_invalid_mapping) Simple.drop_collection() From 5e787404f570efa30edd40666b473b372eeb6071 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Tue, 18 Jun 2013 18:31:24 -0700 Subject: [PATCH 14/18] Allow to reset ordering by calling order_by() with no arguments --- DIFFERENCES.md | 1 + mongoengine/queryset/queryset.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/DIFFERENCES.md b/DIFFERENCES.md index 4b51bd7..3c524ef 100644 --- a/DIFFERENCES.md +++ b/DIFFERENCES.md @@ -21,6 +21,7 @@ Differences between Mongomallard and Mongoengine * `_get_changed_fields()` / `_changed_fields` returns a set * Simplified `EmailField` email regex to be more compatible * Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError` +* `order_by()` without an argument resets the ordering (no ordering will be applied) Untested / not implemented yet: ----- diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 21b23c1..3b376db 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -53,7 +53,7 @@ class QuerySet(object): self._initial_query = {} self._where_clause = None self._loaded_fields = QueryFieldList() - self._ordering = [] + self._ordering = None self._snapshot = False self._timeout = True self._class_check = True @@ -1211,7 +1211,7 @@ class QuerySet(object): if self._ordering: # Apply query ordering self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: + elif self._ordering == None and self._document._meta['ordering']: # Otherwise, apply the ordering from the document model order = self._get_order_by(self._document._meta['ordering']) self._cursor_obj.sort(order) From 2689bcf953eabbdea13c95644094a165a075a79a Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Fri, 21 Jun 2013 17:33:53 -0700 Subject: [PATCH 15/18] Transform certain $or queries into $in queries to boost performance. --- mongoengine/queryset/visitor.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 41f4ebf..3783e7a 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -28,7 +28,7 @@ class DuplicateQueryConditionsError(InvalidQueryError): class SimplificationVisitor(QNodeVisitor): - """Simplifies query trees by combinging unnecessary 'and' connection nodes + """Simplifies query trees by combining unnecessary 'and' connection nodes into a single Q-object. """ @@ -73,6 +73,16 @@ class QueryCompilerVisitor(QNodeVisitor): def visit_combination(self, combination): operator = "$and" if combination.operation == combination.OR: + keys = set([key for q in combination.children for key in q.keys()]) + if len(keys) == 1: + field = keys.pop() + if not field.startswith('$') and not any([isinstance(q[field], dict) for q in combination.children]): + return { + field: { + '$in': [q[field] for q in combination.children if field in q] + } + } + operator = "$or" return {operator: combination.children} From 478062cb0f77e610966d58ffa162d1681e30ffdc Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Sat, 22 Jun 2013 16:10:23 -0700 Subject: [PATCH 16/18] Support for partial updates. --- DIFFERENCES.md | 3 +- mongoengine/base/document.py | 84 ++++++++++++++++++++++++------------ tests/document/instance.py | 6 +-- 3 files changed, 60 insertions(+), 33 deletions(-) diff --git a/DIFFERENCES.md b/DIFFERENCES.md index 3c524ef..01ae0f2 100644 --- a/DIFFERENCES.md +++ b/DIFFERENCES.md @@ -18,7 +18,7 @@ Differences between Mongomallard and Mongoengine * The primary key is only stored as `_id` in the database and is referenced in Python as `pk` or as the name of the primary key field. * Saves are not cascaded by default. * `Document.save()` supports `full=True` keyword argument to force saving all model fields. -* `_get_changed_fields()` / `_changed_fields` returns a set +* `_get_changed_fields()` / `_changed_fields` returns a set of field names (not db field names) * Simplified `EmailField` email regex to be more compatible * Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError` * `order_by()` without an argument resets the ordering (no ordering will be applied) @@ -26,7 +26,6 @@ Differences between Mongomallard and Mongoengine Untested / not implemented yet: ----- -* Delta updates for lists / embedded documents (fields that have changed on the document are fully updated) * Dynamic documents / `DynamicField`, dynamic addition/deletion of fields * Field display name methods * `SequenceField` diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index bdc97ec..b5e357d 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -243,30 +243,36 @@ class BaseDocument(object): return value def _mark_as_changed(self, key): - """Marks a key as explicitly changed by the user + """Marks a key as explicitly changed by the user. """ if key: self._changed_fields.add(key) def _get_changed_fields(self): + """Returns a list of all fields that have explicitly been changed. + """ changed_fields = set(self._changed_fields) EmbeddedDocumentField = _import_class("EmbeddedDocumentField") for field_name, field in self._fields.iteritems(): - if (isinstance(field, ComplexBaseField) and - isinstance(field.field, EmbeddedDocumentField)): - field_value = getattr(self, field_name, None) - if field_value: - for idx in (field_value if isinstance(field_value, dict) - else xrange(len(field_value))): - if field_value[idx]._get_changed_fields(): - changed_fields.add(field_name) - continue - elif isinstance(field, EmbeddedDocumentField): - field_value = getattr(self, field_name, None) - if field_value: - if field_value._get_changed_fields(): - changed_fields.add(field_name) + if field_name not in changed_fields: + if (isinstance(field, ComplexBaseField) and + isinstance(field.field, EmbeddedDocumentField)): + field_value = getattr(self, field_name, None) + if field_value: + for idx in (field_value if isinstance(field_value, dict) + else xrange(len(field_value))): + changed_subfields = field_value[idx]._get_changed_fields() + if changed_subfields: + changed_fields |= set(['.'.join([field_name, str(idx), subfield_name]) + for subfield_name in changed_subfields]) + elif isinstance(field, EmbeddedDocumentField): + field_value = getattr(self, field_name, None) + if field_value: + changed_subfields = field_value._get_changed_fields() + if changed_subfields: + changed_fields |= set(['.'.join([field_name, subfield_name]) + for subfield_name in changed_subfields]) return changed_fields def _clear_changed_fields(self): @@ -289,23 +295,47 @@ class BaseDocument(object): sets = {} unsets = {} - if full or not self._created: - fields = self._fields.iteritems() - else: - fields = ((field_name, self._fields[field_name]) for field_name in self._get_changed_fields()) - def get(field_name, field): - value = getattr(self, field_name) + def get_db_value(field, value): if value is None: value = field.default() if callable(field.default) else field.default - return value + return field.to_mongo(value) - data = (( - self._db_field_map.get(field_name, field_name), - field.to_mongo(get(field_name, field))) - for field_name, field in fields) - for db_field_name, db_value in data: + if full or not self._created: + fields = self._fields.iteritems() + db_data = ((self._db_field_map.get(field_name, field_name), + get_db_value(field, getattr(self, field_name))) + for field_name, field in fields) + + else: + # List of (db_field_name, db_value) tuples. + db_data = [] + + for field_name in self._get_changed_fields(): + parts = field_name.split('.') + + db_field_parts = [] + + value = self + for part in parts: + if isinstance(value, list) and part.isdigit(): + db_field_parts.append(part) + field = field.field + value = value[int(part)] + elif isinstance(value, dict): + db_field_parts.append(part) + field = field.field + value = value[part] + else: # It's a document + obj = value + field = obj._fields[part] + db_field_parts.append(obj._db_field_map.get(part, part)) + value = getattr(obj, part) + + db_data.append(('.'.join(db_field_parts), get_db_value(field, value))) + + for db_field_name, db_value in db_data: if db_value == None: unsets[db_field_name] = 1 else: diff --git a/tests/document/instance.py b/tests/document/instance.py index 1df5d90..80a6130 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -397,10 +397,8 @@ class InstanceTest(unittest.TestCase): doc.embedded_field.dict_field['woot'] = "woot" self.assertEqual(doc._get_changed_fields(), set([ - 'list_field', 'dict_field', 'embedded_field'])) - #self.assertEqual(doc._get_changed_fields(), [ - # 'list_field', 'dict_field', 'embedded_field.list_field', - # 'embedded_field.dict_field']) + 'list_field', 'dict_field', 'embedded_field.list_field', + 'embedded_field.dict_field'])) doc.save() doc = doc.reload() From d7b4ad08cbe8cf4b91bd3aa1cfaf362c059690b2 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Fri, 28 Jun 2013 14:30:20 -0700 Subject: [PATCH 17/18] MongoMallard README + benchmarks --- README.md | 85 ++++++++++++++++++++++++++ README.rst => README_MONGOENGINE.rst | 0 mongoengine/__init__.py | 1 + mongoengine/base/proxy.py | 3 - tests/benchmark.py | 90 ++++++++++++++++++++++++++++ 5 files changed, 176 insertions(+), 3 deletions(-) create mode 100644 README.md rename README.rst => README_MONGOENGINE.rst (100%) create mode 100644 tests/benchmark.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..ab8002c --- /dev/null +++ b/README.md @@ -0,0 +1,85 @@ +MongoMallard +============ + +MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine. + +* Repository: https://github.com/elasticsales/mongomallard +* See [README_MONGOENGINE](https://github.com/elasticsales/mongomallard/blob/master/README_MONGOENGINE.rst) for MongoEngine's README. +* See [DIFFERENCES](https://github.com/elasticsales/mongomallard/blob/master/DIFFERENCES.md) for differences between MongoEngine and MongoMallard. + + +Benchmarks +---------- + +Sample run on a 2.7 GHz Intel Core i5 running OS X 10.8.3 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MongoEngine 0.8.2 (ede9fcf)MongoMallard (478062c)Speedup
Doc initialization52.494us25.195us2.08x
Doc getattr1.339us0.584us2.29x
Doc setattr3.064us2.550us1.20x
Doc to mongo49.415us26.497us1.86x
Load from SON61.475us4.510us13.63x
Save to database434.389us289.972us2.29x
Load from database558.178us480.690us1.16x
Save/delete big object to database98.838ms65.789ms1.50x
Serialize big object from database31.390ms20.265ms1.55x
Load big object from database41.159ms1.400ms29.40x
+ +See [tests/benchmark.py](https://github.com/elasticsales/mongomallard/blob/master/tests/benchmark.py) for source code. diff --git a/README.rst b/README_MONGOENGINE.rst similarity index 100% rename from README.rst rename to README_MONGOENGINE.rst diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 5bd1201..875c916 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -16,6 +16,7 @@ __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + list(queryset.__all__) + signals.__all__ + list(errors.__all__)) VERSION = (0, 8, 2) +MALLARD = True def get_version(): diff --git a/mongoengine/base/proxy.py b/mongoengine/base/proxy.py index 4d92462..7d2879b 100644 --- a/mongoengine/base/proxy.py +++ b/mongoengine/base/proxy.py @@ -180,9 +180,6 @@ class DocumentProxy(LocalProxy): def _get_current_object(self): if self.__document == None: - #print 'fetching', self.__document_type, self.__pk - #import traceback - #traceback.print_stack() collection = self.__document_type._get_collection() son = collection.find_one({'_id': self.__pk}) document = self.__document_type._from_son(son) diff --git a/tests/benchmark.py b/tests/benchmark.py new file mode 100644 index 0000000..89439f3 --- /dev/null +++ b/tests/benchmark.py @@ -0,0 +1,90 @@ +from mongoengine import * +from timeit import repeat +import unittest + +conn_settings = { + 'db': 'mongomallard-test', +} + +connect(**conn_settings) + +def timeit(f, n=10000): + return min(repeat(f, repeat=3, number=n))/float(n) + +class BenchmarkTestCase(unittest.TestCase): + def setUp(self): + pass + + def test_basic(self): + class Book(Document): + name = StringField() + pages = IntField() + tags = ListField(StringField()) + is_published = BooleanField() + + Book.drop_collection() + + create_book = lambda: Book(name='Always be closing', pages=100, tags=['self-help', 'sales'], is_published=True) + print 'Doc initialization: %.3fus' % (timeit(create_book, 1000) * 10**6) + + b = create_book() + + print 'Doc getattr: %.3fus' % (timeit(lambda: b.name, 10000) * 10**6) + + print 'Doc setattr: %.3fus' % (timeit(lambda: setattr(b, 'name', 'New name'), 10000) * 10**6) + + print 'Doc to mongo: %.3fus' % (timeit(b.to_mongo, 1000) * 10**6) + + def save_book(): + b._mark_as_changed('name') + b._mark_as_changed('tags') + b.save() + + save_book() + son = b.to_mongo() + + print 'Load from SON: %.3fus' % (timeit(lambda: Book._from_son(son), 1000) * 10**6) + + print 'Save to database: %.3fus' % (timeit(save_book, 100) * 10**6) + + print 'Load from database: %.3fus' % (timeit(lambda: Book.objects[0], 100) * 10**6) + + def test_embedded(self): + class Contact(EmbeddedDocument): + name = StringField() + title = StringField() + address = StringField() + + class Company(Document): + name = StringField() + contacts = ListField(EmbeddedDocumentField(Contact)) + + Company.drop_collection() + + def get_company(): + return Company( + name='Elastic', + contacts=[ + Contact( + name='Contact %d' % x, + title='CEO', + address='Address %d' % x, + ) + for x in range(1000)] + ) + + def create_company(): + c = get_company() + c.save() + c.delete() + + print 'Save/delete big object to database: %.3fms' % (timeit(create_company, 10) * 10**3) + + c = get_company().save() + + print 'Serialize big object from database: %.3fms' % (timeit(c.to_mongo, 100) * 10**3) + print 'Load big object from database: %.3fms' % (timeit(lambda: Company.objects[0], 100) * 10**3) + + +if __name__ == '__main__': + unittest.main() From ab5dec31f8fc074bf64fcbe297e570123d07fbee Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Thu, 4 Jul 2013 19:15:08 -0700 Subject: [PATCH 18/18] QuerySet.only_classes / QuerySet.exclude_classes: Allow to limit/exclude classes for documents that can be inherited. --- mongoengine/queryset/queryset.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 3b376db..c27adef 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -589,6 +589,34 @@ class QuerySet(object): return self + def only_classes(self, *classes): + doc = self._document + if doc._meta.get('allow_inheritance') is True: + queryset = self.clone() + class_names = [cls._class_name for cls in classes] + allowed_class_names = [name for name in self._document._subclasses if name in class_names] + if len(allowed_class_names) == 1: + queryset._initial_query = {"_cls": allowed_class_names[0]} + else: + queryset._initial_query = {"_cls": {"$in": allowed_class_names}} + return queryset + else: + return self + + def exclude_classes(self, *classes): + doc = self._document + if doc._meta.get('allow_inheritance') is True: + queryset = self.clone() + class_names = [cls._class_name for cls in classes] + allowed_class_names = [name for name in self._document._subclasses if name in class_names] + if len(allowed_class_names) == 1: + queryset._initial_query = {"_cls": {"$ne": allowed_class_names[0]}} + else: + queryset._initial_query = {"_cls": {"$nin": allowed_class_names}} + return queryset + else: + return self + def clone(self): """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`