diff --git a/DIFFERENCES.md b/DIFFERENCES.md index 3c524ef..01ae0f2 100644 --- a/DIFFERENCES.md +++ b/DIFFERENCES.md @@ -18,7 +18,7 @@ Differences between Mongomallard and Mongoengine * The primary key is only stored as `_id` in the database and is referenced in Python as `pk` or as the name of the primary key field. * Saves are not cascaded by default. * `Document.save()` supports `full=True` keyword argument to force saving all model fields. -* `_get_changed_fields()` / `_changed_fields` returns a set +* `_get_changed_fields()` / `_changed_fields` returns a set of field names (not db field names) * Simplified `EmailField` email regex to be more compatible * Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError` * `order_by()` without an argument resets the ordering (no ordering will be applied) @@ -26,7 +26,6 @@ Differences between Mongomallard and Mongoengine Untested / not implemented yet: ----- -* Delta updates for lists / embedded documents (fields that have changed on the document are fully updated) * Dynamic documents / `DynamicField`, dynamic addition/deletion of fields * Field display name methods * `SequenceField` diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index bdc97ec..b5e357d 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -243,30 +243,36 @@ class BaseDocument(object): return value def _mark_as_changed(self, key): - """Marks a key as explicitly changed by the user + """Marks a key as explicitly changed by the user. """ if key: self._changed_fields.add(key) def _get_changed_fields(self): + """Returns a list of all fields that have explicitly been changed. + """ changed_fields = set(self._changed_fields) EmbeddedDocumentField = _import_class("EmbeddedDocumentField") for field_name, field in self._fields.iteritems(): - if (isinstance(field, ComplexBaseField) and - isinstance(field.field, EmbeddedDocumentField)): - field_value = getattr(self, field_name, None) - if field_value: - for idx in (field_value if isinstance(field_value, dict) - else xrange(len(field_value))): - if field_value[idx]._get_changed_fields(): - changed_fields.add(field_name) - continue - elif isinstance(field, EmbeddedDocumentField): - field_value = getattr(self, field_name, None) - if field_value: - if field_value._get_changed_fields(): - changed_fields.add(field_name) + if field_name not in changed_fields: + if (isinstance(field, ComplexBaseField) and + isinstance(field.field, EmbeddedDocumentField)): + field_value = getattr(self, field_name, None) + if field_value: + for idx in (field_value if isinstance(field_value, dict) + else xrange(len(field_value))): + changed_subfields = field_value[idx]._get_changed_fields() + if changed_subfields: + changed_fields |= set(['.'.join([field_name, str(idx), subfield_name]) + for subfield_name in changed_subfields]) + elif isinstance(field, EmbeddedDocumentField): + field_value = getattr(self, field_name, None) + if field_value: + changed_subfields = field_value._get_changed_fields() + if changed_subfields: + changed_fields |= set(['.'.join([field_name, subfield_name]) + for subfield_name in changed_subfields]) return changed_fields def _clear_changed_fields(self): @@ -289,23 +295,47 @@ class BaseDocument(object): sets = {} unsets = {} - if full or not self._created: - fields = self._fields.iteritems() - else: - fields = ((field_name, self._fields[field_name]) for field_name in self._get_changed_fields()) - def get(field_name, field): - value = getattr(self, field_name) + def get_db_value(field, value): if value is None: value = field.default() if callable(field.default) else field.default - return value + return field.to_mongo(value) - data = (( - self._db_field_map.get(field_name, field_name), - field.to_mongo(get(field_name, field))) - for field_name, field in fields) - for db_field_name, db_value in data: + if full or not self._created: + fields = self._fields.iteritems() + db_data = ((self._db_field_map.get(field_name, field_name), + get_db_value(field, getattr(self, field_name))) + for field_name, field in fields) + + else: + # List of (db_field_name, db_value) tuples. + db_data = [] + + for field_name in self._get_changed_fields(): + parts = field_name.split('.') + + db_field_parts = [] + + value = self + for part in parts: + if isinstance(value, list) and part.isdigit(): + db_field_parts.append(part) + field = field.field + value = value[int(part)] + elif isinstance(value, dict): + db_field_parts.append(part) + field = field.field + value = value[part] + else: # It's a document + obj = value + field = obj._fields[part] + db_field_parts.append(obj._db_field_map.get(part, part)) + value = getattr(obj, part) + + db_data.append(('.'.join(db_field_parts), get_db_value(field, value))) + + for db_field_name, db_value in db_data: if db_value == None: unsets[db_field_name] = 1 else: diff --git a/tests/document/instance.py b/tests/document/instance.py index 1df5d90..80a6130 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -397,10 +397,8 @@ class InstanceTest(unittest.TestCase): doc.embedded_field.dict_field['woot'] = "woot" self.assertEqual(doc._get_changed_fields(), set([ - 'list_field', 'dict_field', 'embedded_field'])) - #self.assertEqual(doc._get_changed_fields(), [ - # 'list_field', 'dict_field', 'embedded_field.list_field', - # 'embedded_field.dict_field']) + 'list_field', 'dict_field', 'embedded_field.list_field', + 'embedded_field.dict_field'])) doc.save() doc = doc.reload()