import numbers
import operator
from functools import partial

import pymongo
from bson.dbref import DBRef

from mongoengine import signals
from mongoengine.common import _import_class
from mongoengine.errors import (ValidationError, InvalidDocumentError,
                                LookUpError)
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
                                        to_str_keys_recursive)

from .common import get_document, ALLOW_INHERITANCE
from .datastructures import BaseDict, BaseList
from .fields import ComplexBaseField

__all__ = ('BaseDocument', )


def _db_field_name(part):
    """Return the database name for one element of a `_lookup_field` result.

    `_lookup_field` returns field objects, but also plain strings for list
    indexes and for keys inside complex fields; strings are passed through
    unchanged, field objects contribute their ``db_field``.
    """
    if isinstance(part, basestring):
        return part
    return part.db_field


class BaseDocument(object):
    """Common behaviour for document classes: attribute storage, change
    tracking, validation and conversion to / from MongoDB data.

    The class relies on attributes it does not define itself (``_fields``,
    ``_meta``, ``_db_field_map``, ``_reverse_db_field_map``, ``_class_name``,
    ``_is_document``) - presumably supplied by a metaclass or subclass.
    """

    _dynamic = False
    _created = True
    _dynamic_lock = True
    _initialised = False

    def __init__(self, **values):
        """Populate the instance with field defaults, then with ``values``.

        Sends ``pre_init`` before and ``post_init`` after initialisation.
        On dynamic documents, keys that are not declared fields are stored
        as dynamic fields once the dynamic lock has been released.
        """
        signals.pre_init.send(self.__class__, document=self, values=values)

        self._data = {}

        # Assign default values to instance
        for key in self._fields:
            if self._db_field_map.get(key, key) in values:
                continue
            value = getattr(self, key, None)
            setattr(self, key, value)

        # Set passed values after initialisation
        if self._dynamic:
            self._dynamic_fields = {}
            dynamic_data = {}
            for key, value in values.items():
                if key in self._fields or key == '_id':
                    setattr(self, key, value)
                else:
                    # Deferred until the dynamic lock is released below
                    dynamic_data[key] = value
        else:
            for key, value in values.items():
                key = self._reverse_db_field_map.get(key, key)
                setattr(self, key, value)

        # Set any get_fieldname_display methods
        self.__set_field_display()

        if self._dynamic:
            self._dynamic_lock = False
            for key, value in dynamic_data.items():
                setattr(self, key, value)

        # Flag initialised
        self._initialised = True
        signals.post_init.send(self.__class__, document=self)

    def __setattr__(self, name, value):
        """Set an attribute, handling dynamic fields and shard keys.

        Raises ``OperationError`` when a shard key of an already created
        document would be changed.
        """
        # Handle dynamic data only if an initialised dynamic document
        if self._dynamic and not self._dynamic_lock:
            is_public = not name.startswith('_')

            if is_public and not hasattr(self, name):
                # First assignment of an unknown name: register a field
                DynamicField = _import_class("DynamicField")
                field = DynamicField(db_field=name)
                field.name = name
                self._dynamic_fields[name] = field

            if is_public:
                value = self.__expand_dynamic_values(name, value)

            # Handle marking data as changed
            if name in self._dynamic_fields:
                self._data[name] = value
                if hasattr(self, '_changed_fields'):
                    self._mark_as_changed(name)

        if (self._is_document and not self._created and
                name in self._meta.get('shard_key', tuple()) and
                self._data.get(name) != value):
            OperationError = _import_class('OperationError')
            msg = "Shard Keys are immutable. Tried to update %s" % name
            raise OperationError(msg)

        super(BaseDocument, self).__setattr__(name, value)

    def __getstate__(self):
        """Return the picklable state: the instance dict without the
        ``get_<field>_display`` helpers.

        The helpers are dropped from a copy, so serialising a document no
        longer strips them from the live instance; `__setstate__` recreates
        them on the restored object.
        """
        state = self.__dict__.copy()
        for name, field in self._fields.items():
            if field.choices:
                state.pop("get_%s_display" % name, None)
        return state

    def __setstate__(self, __dict__):
        """Restore pickled state and re-create the display helpers."""
        self.__dict__ = __dict__
        self.__set_field_display()

    def __iter__(self):
        """Iterate over the declared field names."""
        return iter(self._fields)

    def __getitem__(self, name):
        """Dictionary-style field access, return a field's value if present.

        Raises ``KeyError`` for unknown fields or when reading the
        attribute fails with ``AttributeError``.
        """
        try:
            if name in self._fields:
                return getattr(self, name)
        except AttributeError:
            pass
        raise KeyError(name)

    def __setitem__(self, name, value):
        """Dictionary-style field access, set a field's value.

        Raises ``KeyError`` if ``name`` is not a declared field.
        """
        # Ensure that the field exists before setting its value
        if name not in self._fields:
            raise KeyError(name)
        return setattr(self, name, value)

    def __contains__(self, name):
        """Return True if attribute ``name`` exists and is not None."""
        try:
            val = getattr(self, name)
            return val is not None
        except AttributeError:
            return False

    def __len__(self):
        """Return the number of entries in the internal data dict."""
        return len(self._data)

    def __repr__(self):
        """Return ``<ClassName: str(self)>``, tolerating bad unicode."""
        try:
            u = self.__str__()
        except (UnicodeEncodeError, UnicodeDecodeError):
            u = '[Bad Unicode data]'
        repr_type = type(u)
        return repr_type('<%s: %s>' % (self.__class__.__name__, u))

    def __str__(self):
        """Return ``__unicode__()`` when defined (utf-8 encoded on
        Python 2), otherwise ``"<ClassName> object"``.
        """
        if hasattr(self, '__unicode__'):
            if PY3:
                return self.__unicode__()
            else:
                return unicode(self).encode('utf-8')
        return txt_type('%s object' % self.__class__.__name__)

    def __eq__(self, other):
        """Compare documents of the same class by ``id``.

        A document whose ``id`` is None is only equal to itself; this keeps
        equality consistent with `__hash__`, which hashes such documents by
        identity.
        """
        if isinstance(other, self.__class__) and hasattr(other, 'id'):
            if self.id is None:
                return self is other
            return self.id == other.id
        return False

    def __ne__(self, other):
        """Inverse of `__eq__`."""
        return not self.__eq__(other)

    def __hash__(self):
        """Hash by primary key, or by identity while ``pk`` is None."""
        if self.pk is None:
            # For new object
            return super(BaseDocument, self).__hash__()
        else:
            return hash(self.pk)

    def to_mongo(self):
        """Return data dictionary ready for use with MongoDB.

        Fields whose value is None are omitted; ``_cls`` is added unless
        inheritance is disabled in ``_meta``; dynamic fields are appended
        for dynamic documents.
        """
        data = {}
        for field_name, field in self._fields.items():
            value = getattr(self, field_name, None)
            if value is not None:
                data[field.db_field] = field.to_mongo(value)

        # Only add _cls if allow_inheritance is not False
        if not (hasattr(self, '_meta') and
                self._meta.get('allow_inheritance',
                               ALLOW_INHERITANCE) == False):
            data['_cls'] = self._class_name

        if '_id' in data and data['_id'] is None:
            del data['_id']

        if not self._dynamic:
            return data

        for name, field in self._dynamic_fields.items():
            data[name] = field.to_mongo(self._data.get(name, None))
        return data

    def validate(self):
        """Ensure that all fields' values are valid and that required fields
        are present.

        Raises ``ValidationError`` carrying a dict of per-field errors.
        """
        # Get a list of tuples of field names and their current values
        fields = [(field, getattr(self, name))
                  for name, field in self._fields.items()]

        # Ensure that each field is matched to a valid value
        errors = {}
        for field, value in fields:
            if value is not None:
                try:
                    field._validate(value)
                except ValidationError as error:
                    errors[field.name] = error.errors or error
                except (ValueError, AttributeError, AssertionError) as error:
                    errors[field.name] = error
            elif field.required:
                errors[field.name] = ValidationError('Field is required',
                                                     field_name=field.name)
        if errors:
            raise ValidationError('ValidationError', errors=errors)

    def __expand_dynamic_values(self, name, value):
        """Expand any dynamic values to their correct types / values.

        Dicts carrying ``_cls`` become document instances; other dicts,
        lists and tuples are expanded recursively and wrapped in
        ``BaseDict`` / ``BaseList``. Anything else is returned unchanged.
        """
        if not isinstance(value, (dict, list, tuple)):
            return value

        is_list = False
        if not hasattr(value, 'items'):
            # Treat sequences as index -> item mappings while expanding
            is_list = True
            value = dict(enumerate(value))

        if not is_list and '_cls' in value:
            cls = get_document(value['_cls'])
            return cls(**value)

        data = {}
        for k, v in value.items():
            key = name if is_list else k
            data[k] = self.__expand_dynamic_values(key, v)

        if is_list:  # Convert back to a list
            data_items = sorted(data.items(), key=operator.itemgetter(0))
            value = [v for k, v in data_items]
        else:
            value = data

        # Convert lists / values so we can watch for any changes on them
        if (isinstance(value, (list, tuple)) and
                not isinstance(value, BaseList)):
            value = BaseList(value, self, name)
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
            value = BaseDict(value, self, name)

        return value

    def _mark_as_changed(self, key):
        """Mark a key as explicitly changed by the user.

        The key is translated to its database name; nothing happens for an
        empty key or when ``_changed_fields`` does not exist yet.
        """
        if not key:
            return
        key = self._db_field_map.get(key, key)
        if (hasattr(self, '_changed_fields') and
                key not in self._changed_fields):
            self._changed_fields.append(key)

    def _get_changed_fields(self, key='', inspected=None):
        """Return a list of all fields that have explicitly been changed.

        Embedded documents and documents held in lists / dicts are visited
        recursively and reported as dotted paths. ``key`` is accepted for
        signature compatibility but never read; ``inspected`` is the set of
        ids already visited, shared across the recursion.
        """
        EmbeddedDocument = _import_class("EmbeddedDocument")
        DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
        _changed_fields = []
        _changed_fields += getattr(self, '_changed_fields', [])

        # `is None` rather than `or`: an empty set handed down by a caller
        # must stay shared so visited ids propagate back up.
        if inspected is None:
            inspected = set()
        if hasattr(self, 'id'):
            if self.id in inspected:
                return _changed_fields
            inspected.add(self.id)

        field_list = self._fields.copy()
        if self._dynamic:
            field_list.update(self._dynamic_fields)

        for field_name in field_list:
            db_field_name = self._db_field_map.get(field_name, field_name)
            prefix = '%s.' % db_field_name
            field = self._data.get(field_name, None)
            if hasattr(field, 'id'):
                if field.id in inspected:
                    continue
                inspected.add(field.id)

            if db_field_name in _changed_fields:
                # The whole field is already marked; no need to descend
                continue

            if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)):
                # Find all embedded fields that have been changed
                changed = field._get_changed_fields(prefix, inspected)
                _changed_fields += ["%s%s" % (prefix, k)
                                    for k in changed if k]
            elif isinstance(field, (list, tuple, dict)):
                # Loop list / dict fields as they contain documents
                # Determine the iterator to use
                if not hasattr(field, 'items'):
                    iterator = enumerate(field)
                else:
                    iterator = field.items()
                for index, value in iterator:
                    if not hasattr(value, '_get_changed_fields'):
                        continue
                    list_key = "%s%s." % (prefix, index)
                    changed = value._get_changed_fields(list_key, inspected)
                    _changed_fields += ["%s%s" % (list_key, k)
                                        for k in changed if k]
        return _changed_fields

    def _delta(self):
        """Return the delta ``(set_data, unset_data)`` of changes for a
        document, based on the values that have been explicitly changed.

        ``set_data`` maps dotted database paths to new values and
        ``unset_data`` maps paths to 1. A changed value is moved to
        ``unset_data`` when it is falsy (but not a bool or a number) and
        either belongs to a dynamic field or equals the field's default.
        """
        # Handles cases where not loaded from_son but has _id
        doc = self.to_mongo()
        set_fields = self._get_changed_fields()
        set_data = {}
        unset_data = {}
        if hasattr(self, '_changed_fields'):
            # Fetch each set item from its path
            for path in set_fields:
                d = doc
                new_path = []
                for p in path.split('.'):
                    if isinstance(d, DBRef):
                        # Cannot descend into a reference
                        break
                    elif p.isdigit():
                        d = d[int(p)]
                    elif hasattr(d, 'get'):
                        d = d.get(p)
                    new_path.append(p)
                set_data['.'.join(new_path)] = d
        else:
            set_data = doc
            if '_id' in set_data:
                del set_data['_id']

        # Determine if any changed items were actually unset.
        # Iterate over a snapshot because entries are deleted on the way.
        for path, value in list(set_data.items()):
            # Numbers (including 0) and bools are real values, never unsets
            if value or isinstance(value, (numbers.Number, bool)):
                continue

            # Always derive the parts from the current path; reusing the
            # parts of a previously handled path would test the wrong field.
            parts = path.split('.')

            # If we've set a value that ain't the default value dont unset it.
            default = None
            if self._dynamic and parts[0] in self._dynamic_fields:
                del set_data[path]
                unset_data[path] = 1
                continue
            elif path in self._fields:
                default = self._fields[path].default
            else:  # Perform a full lookup for lists / embedded lookups
                d = self
                db_field_name = parts.pop()
                for p in parts:
                    if p.isdigit():
                        d = d[int(p)]
                    elif (hasattr(d, '__getattribute__') and
                          not isinstance(d, dict)):
                        real_path = d._reverse_db_field_map.get(p, p)
                        d = getattr(d, real_path)
                    else:
                        d = d.get(p)

                if hasattr(d, '_fields'):
                    field_name = d._reverse_db_field_map.get(db_field_name,
                                                             db_field_name)
                    if field_name in d._fields:
                        default = d._fields.get(field_name).default

            if default is not None and callable(default):
                default = default()

            if default != value:
                continue

            del set_data[path]
            unset_data[path] = 1
        return set_data, unset_data

    @classmethod
    def _get_collection_name(cls):
        """Return the collection name for this class (None if unset)."""
        return cls._meta.get('collection', None)

    @classmethod
    def _from_son(cls, son):
        """Create an instance of a Document (subclass) from a PyMongo SON.

        Raises ``InvalidDocumentError`` listing every field whose value
        could not be converted.
        """
        # get the class name from the document, falling back to the given
        # class if unavailable
        class_name = son.get('_cls', cls._class_name)
        data = dict(("%s" % key, value) for key, value in son.items())
        if not UNICODE_KWARGS:
            # python 2.6.4 and lower cannot handle unicode keys
            # passed to class constructor example: cls(**data)
            to_str_keys_recursive(data)

        data.pop('_cls', None)

        # Return correct subclass for document type
        if class_name != cls._class_name:
            cls = get_document(class_name)

        changed_fields = []
        errors_dict = {}

        for field_name, field in cls._fields.items():
            if field.db_field in data:
                value = data[field.db_field]
                try:
                    data[field_name] = (value if value is None
                                        else field.to_python(value))
                    if field_name != field.db_field:
                        del data[field.db_field]
                except (AttributeError, ValueError) as e:
                    errors_dict[field_name] = e
            elif field.default:
                default = field.default
                if callable(default):
                    default = default()
                # A document-valued default missing from the stored data is
                # recorded as changed
                if isinstance(default, BaseDocument):
                    changed_fields.append(field_name)

        if errors_dict:
            errors = "\n".join(["%s - %s" % (k, v)
                                for k, v in errors_dict.items()])
            msg = ("Invalid data to create a `%s` instance.\n%s"
                   % (cls._class_name, errors))
            raise InvalidDocumentError(msg)

        obj = cls(**data)
        obj._changed_fields = changed_fields
        obj._created = False
        return obj

    @classmethod
    def _build_index_spec(cls, spec):
        """Build a PyMongo index spec from a MongoEngine index spec.

        ``spec`` may be a field name, a list / tuple of field names or a
        dict with a ``'fields'`` entry. Prefixes select the direction:
        ``+`` ascending (default), ``-`` descending, ``*`` geo 2d.

        Raises ``ValueError`` for a sparse index with more than one field.
        """
        if isinstance(spec, basestring):
            spec = {'fields': [spec]}
        elif isinstance(spec, (list, tuple)):
            spec = {'fields': list(spec)}
        elif isinstance(spec, dict):
            spec = dict(spec)

        index_list = []
        direction = None

        # Check to see if we need to include _cls
        allow_inheritance = cls._meta.get('allow_inheritance',
                                          ALLOW_INHERITANCE) != False
        include_cls = allow_inheritance and not spec.get('sparse', False)

        for key in spec['fields']:
            # If inherited spec continue
            if isinstance(key, (list, tuple)):
                continue

            # ASCENDING from +,
            # DESCENDING from -
            # GEO2D from *
            direction = pymongo.ASCENDING
            if key.startswith("-"):
                direction = pymongo.DESCENDING
            elif key.startswith("*"):
                direction = pymongo.GEO2D
            if key.startswith(("+", "-", "*")):
                key = key[1:]

            # Use real field name, do it manually because we need field
            # objects for the next part (list field checking)
            parts = key.split('.')
            if parts in (['pk'], ['id'], ['_id']):
                key = '_id'
            else:
                fields = cls._lookup_field(parts)
                # The lookup may yield plain strings (list indexes, keys of
                # complex fields) next to field objects
                key = '.'.join([_db_field_name(f) for f in fields])
            index_list.append((key, direction))

        # Don't add cls to a geo index (checks the last direction seen)
        if include_cls and direction != pymongo.GEO2D:
            index_list.insert(0, ('_cls', 1))

        spec['fields'] = index_list
        if spec.get('sparse', False) and len(spec['fields']) > 1:
            raise ValueError(
                'Sparse indexes can only have one field in them. '
                'See https://jira.mongodb.org/browse/SERVER-2193')

        return spec

    @classmethod
    def _unique_with_indexes(cls, namespace=""):
        """Find and set unique indexes.

        Returns a list of index specs (lists of ``(name, direction)``
        tuples) and, as a side effect, marks unique fields and their
        ``unique_with`` partners as required.
        """
        unique_indexes = []
        for field_name, field in cls._fields.items():
            # Generate a list of indexes needed by uniqueness constraints
            if field.unique:
                field.required = True
                unique_fields = [field.db_field]

                # Add any unique_with fields to the back of the index spec
                if field.unique_with:
                    if isinstance(field.unique_with, basestring):
                        field.unique_with = [field.unique_with]

                    # Convert unique_with field names to real field names
                    unique_with = []
                    for other_name in field.unique_with:
                        # Lookup real name
                        parts = cls._lookup_field(other_name.split('.'))
                        name_parts = [_db_field_name(p) for p in parts]
                        unique_with.append('.'.join(name_parts))
                        # Unique field should be required (the last part
                        # can be a plain string, which carries no flag)
                        if not isinstance(parts[-1], basestring):
                            parts[-1].required = True
                    unique_fields += unique_with

                # Add the new index to the list
                index = [("%s%s" % (namespace, f), pymongo.ASCENDING)
                         for f in unique_fields]
                unique_indexes.append(index)

            # Grab any embedded document field unique indexes
            if (field.__class__.__name__ == "EmbeddedDocumentField" and
                    field.document_type != cls):
                field_namespace = "%s." % field_name
                doc_cls = field.document_type
                unique_indexes += doc_cls._unique_with_indexes(field_namespace)

        return unique_indexes

    @classmethod
    def _lookup_field(cls, parts):
        """Lookup a field based on its attribute and return a list containing
        the field's parents and the field.

        List indexes and unresolved keys of complex fields are returned as
        plain strings.

        Raises ``LookUpError`` when a part cannot be resolved or when the
        path crosses a reference field.
        """
        if not isinstance(parts, (list, tuple)):
            parts = [parts]
        fields = []
        field = None

        for field_name in parts:
            # Handle ListField indexing. Only applies when the previous
            # part is a container exposing `.field`; otherwise fall through
            # to the normal lookup so a proper LookUpError is raised.
            if field_name.isdigit() and hasattr(field, 'field'):
                fields.append(field_name)
                continue

            if field is None:
                # Look up first field from the document
                if field_name == 'pk':
                    # Deal with "primary key" alias
                    field_name = cls._meta['id_field']
                if field_name in cls._fields:
                    field = cls._fields[field_name]
                elif cls._dynamic:
                    DynamicField = _import_class('DynamicField')
                    field = DynamicField(db_field=field_name)
                else:
                    raise LookUpError('Cannot resolve field "%s"'
                                      % field_name)
            else:
                ReferenceField = _import_class('ReferenceField')
                GenericReferenceField = _import_class('GenericReferenceField')
                if isinstance(field, (ReferenceField, GenericReferenceField)):
                    raise LookUpError('Cannot perform join in mongoDB: %s' %
                                      '__'.join(parts))
                if hasattr(getattr(field, 'field', None), 'lookup_member'):
                    new_field = field.field.lookup_member(field_name)
                else:
                    # Look up subfield on the previous field
                    new_field = field.lookup_member(field_name)
                if not new_field and isinstance(field, ComplexBaseField):
                    fields.append(field_name)
                    continue
                elif not new_field:
                    raise LookUpError('Cannot resolve field "%s"'
                                      % field_name)
                field = new_field  # update field to the new field type
            fields.append(field)
        return fields

    @classmethod
    def _translate_field_name(cls, field, sep='.'):
        """Translate a field attribute name to a database field name.

        ``field`` is split on ``sep``; the result is always joined with
        ``'.'``.
        """
        parts = field.split(sep)
        parts = [_db_field_name(f) for f in cls._lookup_field(parts)]
        return '.'.join(parts)

    @classmethod
    def _geo_indices(cls, inspected=None):
        """Return the geo-indexed fields of this class and, recursively, of
        its embedded document types.

        ``inspected`` lists the classes already visited so that cyclic
        embeddings terminate.
        """
        if inspected is None:
            inspected = []
        geo_indices = []
        inspected.append(cls)

        EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
        GeoPointField = _import_class("GeoPointField")

        for field in cls._fields.values():
            if not isinstance(field, (EmbeddedDocumentField, GeoPointField)):
                continue
            if hasattr(field, 'document_type'):
                field_cls = field.document_type
                if field_cls in inspected:
                    continue
                if hasattr(field_cls, '_geo_indices'):
                    geo_indices += field_cls._geo_indices(inspected)
            elif field._geo_index:
                geo_indices.append(field)
        return geo_indices

    def __set_field_display(self):
        """Dynamically set the display value for a field with choices."""
        for attr_name, field in self._fields.items():
            if field.choices:
                setattr(self,
                        'get_%s_display' % attr_name,
                        partial(self.__get_field_display, field=field))

    def __get_field_display(self, field):
        """Return the display value for a choice field.

        When the choices are ``(value, label)`` pairs the label is returned,
        falling back to the raw value if it is not among the choices.
        """
        value = getattr(self, field.name)
        if field.choices and isinstance(field.choices[0], (list, tuple)):
            return dict(field.choices).get(value, value)
        return value