Support for partial updates.

This commit is contained in:
Thomas Steinacher 2013-06-22 16:10:23 -07:00
commit 478062cb0f
3 changed files with 60 additions and 33 deletions

View file

@ -18,7 +18,7 @@ Differences between Mongomallard and Mongoengine
* The primary key is only stored as `_id` in the database and is referenced in Python as `pk` or as the name of the primary key field.
* Saves are not cascaded by default.
* `Document.save()` supports `full=True` keyword argument to force saving all model fields.
* `_get_changed_fields()` / `_changed_fields` returns a set
* `_get_changed_fields()` / `_changed_fields` returns a set of field names (not db field names)
* Simplified `EmailField` email regex to be more compatible
* Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError`
* `order_by()` without an argument resets the ordering (no ordering will be applied)
@ -26,7 +26,6 @@ Differences between Mongomallard and Mongoengine
Untested / not implemented yet:
-----
* Delta updates for lists / embedded documents (fields that have changed on the document are fully updated)
* Dynamic documents / `DynamicField`, dynamic addition/deletion of fields
* Field display name methods
* `SequenceField`

View file

@ -243,30 +243,36 @@ class BaseDocument(object):
return value
def _mark_as_changed(self, key):
"""Marks a key as explicitly changed by the user
"""Marks a key as explicitly changed by the user.
"""
if key:
self._changed_fields.add(key)
def _get_changed_fields(self):
"""Returns a list of all fields that have explicitly been changed.
"""
changed_fields = set(self._changed_fields)
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems():
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
if field_value[idx]._get_changed_fields():
changed_fields.add(field_name)
continue
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
if field_value._get_changed_fields():
changed_fields.add(field_name)
if field_name not in changed_fields:
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
changed_subfields = field_value[idx]._get_changed_fields()
if changed_subfields:
changed_fields |= set(['.'.join([field_name, str(idx), subfield_name])
for subfield_name in changed_subfields])
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
changed_subfields = field_value._get_changed_fields()
if changed_subfields:
changed_fields |= set(['.'.join([field_name, subfield_name])
for subfield_name in changed_subfields])
return changed_fields
def _clear_changed_fields(self):
@ -289,23 +295,47 @@ class BaseDocument(object):
sets = {}
unsets = {}
if full or not self._created:
fields = self._fields.iteritems()
else:
fields = ((field_name, self._fields[field_name]) for field_name in self._get_changed_fields())
def get(field_name, field):
value = getattr(self, field_name)
def get_db_value(field, value):
if value is None:
value = field.default() if callable(field.default) else field.default
return value
return field.to_mongo(value)
data = ((
self._db_field_map.get(field_name, field_name),
field.to_mongo(get(field_name, field)))
for field_name, field in fields)
for db_field_name, db_value in data:
if full or not self._created:
fields = self._fields.iteritems()
db_data = ((self._db_field_map.get(field_name, field_name),
get_db_value(field, getattr(self, field_name)))
for field_name, field in fields)
else:
# List of (db_field_name, db_value) tuples.
db_data = []
for field_name in self._get_changed_fields():
parts = field_name.split('.')
db_field_parts = []
value = self
for part in parts:
if isinstance(value, list) and part.isdigit():
db_field_parts.append(part)
field = field.field
value = value[int(part)]
elif isinstance(value, dict):
db_field_parts.append(part)
field = field.field
value = value[part]
else: # It's a document
obj = value
field = obj._fields[part]
db_field_parts.append(obj._db_field_map.get(part, part))
value = getattr(obj, part)
db_data.append(('.'.join(db_field_parts), get_db_value(field, value)))
for db_field_name, db_value in db_data:
if db_value == None:
unsets[db_field_name] = 1
else:

View file

@ -397,10 +397,8 @@ class InstanceTest(unittest.TestCase):
doc.embedded_field.dict_field['woot'] = "woot"
self.assertEqual(doc._get_changed_fields(), set([
'list_field', 'dict_field', 'embedded_field']))
#self.assertEqual(doc._get_changed_fields(), [
# 'list_field', 'dict_field', 'embedded_field.list_field',
# 'embedded_field.dict_field'])
'list_field', 'dict_field', 'embedded_field.list_field',
'embedded_field.dict_field']))
doc.save()
doc = doc.reload()