From 929a1b222c871890a4f606d795146de73aa4cb24 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Fri, 14 Jun 2013 20:06:12 -0700 Subject: [PATCH] Add support for DocumentProxy so we can access obj.ref.pk without fetching ref and still get the correct type even if ref can be inherited. --- mongoengine/base/document.py | 10 +- mongoengine/base/proxy.py | 193 +++++++++++++++++++++++++++++++++++ mongoengine/fields.py | 35 ++++--- setup.py | 5 +- 4 files changed, 226 insertions(+), 17 deletions(-) create mode 100644 mongoengine/base/proxy.py diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 1195bc4..bdc97ec 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -15,6 +15,7 @@ from mongoengine.errors import (ValidationError, InvalidDocumentError, from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, to_str_keys_recursive) +from mongoengine.base.proxy import DocumentProxy from mongoengine.base.common import get_document, ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.fields import ComplexBaseField @@ -109,9 +110,12 @@ class BaseDocument(object): return txt_type('%s object' % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'pk'): - if self.pk == other.pk: - return True + if isinstance(other, DocumentProxy) and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk') and self.pk == other.pk: + return True + + if isinstance(other, self.__class__) and hasattr(other, 'pk') and self.pk == other.pk: + return True + return False def __ne__(self, other): diff --git a/mongoengine/base/proxy.py b/mongoengine/base/proxy.py new file mode 100644 index 0000000..4d92462 --- /dev/null +++ b/mongoengine/base/proxy.py @@ -0,0 +1,193 @@ +from mongoengine.queryset import OperationError +from bson.dbref import DBRef + +class LocalProxy(object): + # From werkzeug/local.py + + """ Forwards all operations to + a proxied object. The only operations not supported for forwarding + are right handed operands and any kind of assignment. + """ + + __slots__ = ('__local', '__dict__', '__name__') + + def __init__(self, local, name=None): + object.__setattr__(self, '_LocalProxy__local', local) + object.__setattr__(self, '__name__', name) + + def _get_current_object(self): + """Return the current object. This is useful if you want the real + object behind the proxy at a time for performance reasons or because + you want to pass the object into a different context. + """ + if not hasattr(self.__local, '__release_local__'): + return self.__local() + try: + return getattr(self.__local, self.__name__) + except AttributeError: + raise RuntimeError('no object bound to %s' % self.__name__) + + @property + def __dict__(self): + try: + return self._get_current_object().__dict__ + except RuntimeError: + raise AttributeError('__dict__') + + def __repr__(self): + try: + obj = self._get_current_object() + except RuntimeError: + return '<%s unbound>' % self.__class__.__name__ + return repr(obj) + + def __nonzero__(self): + try: + return bool(self._get_current_object()) + except RuntimeError: + return False + + def __unicode__(self): + try: + return unicode(self._get_current_object()) + except RuntimeError: + return repr(self) + + def __dir__(self): + try: + return dir(self._get_current_object()) + except RuntimeError: + return [] + + def __getattr__(self, name): + if name == '__members__': + return dir(self._get_current_object()) + return getattr(self._get_current_object(), name) + + def __setitem__(self, key, value): + self._get_current_object()[key] = value + + def __delitem__(self, key): + del self._get_current_object()[key] + + def __setslice__(self, i, j, seq): + self._get_current_object()[i:j] = seq + + def __delslice__(self, i, j): + del self._get_current_object()[i:j] + + __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v) + __delattr__ = lambda x, n: delattr(x._get_current_object(), n) + __str__ = lambda x: str(x._get_current_object()) + __lt__ = lambda x, o: x._get_current_object() < o + __le__ = lambda x, o: x._get_current_object() <= o + __eq__ = lambda x, o: x._get_current_object() == o + __ne__ = lambda x, o: x._get_current_object() != o + __gt__ = lambda x, o: x._get_current_object() > o + __ge__ = lambda x, o: x._get_current_object() >= o + __cmp__ = lambda x, o: cmp(x._get_current_object(), o) + __hash__ = lambda x: hash(x._get_current_object()) + __call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw) + __len__ = lambda x: len(x._get_current_object()) + __getitem__ = lambda x, i: x._get_current_object()[i] + __iter__ = lambda x: iter(x._get_current_object()) + __contains__ = lambda x, i: i in x._get_current_object() + __getslice__ = lambda x, i, j: x._get_current_object()[i:j] + __add__ = lambda x, o: x._get_current_object() + o + __sub__ = lambda x, o: x._get_current_object() - o + __mul__ = lambda x, o: x._get_current_object() * o + __floordiv__ = lambda x, o: x._get_current_object() // o + __mod__ = lambda x, o: x._get_current_object() % o + __divmod__ = lambda x, o: x._get_current_object().__divmod__(o) + __pow__ = lambda x, o: x._get_current_object() ** o + __lshift__ = lambda x, o: x._get_current_object() << o + __rshift__ = lambda x, o: x._get_current_object() >> o + __and__ = lambda x, o: x._get_current_object() & o + __xor__ = lambda x, o: x._get_current_object() ^ o + __or__ = lambda x, o: x._get_current_object() | o + __div__ = lambda x, o: x._get_current_object().__div__(o) + __truediv__ = lambda x, o: x._get_current_object().__truediv__(o) + __neg__ = lambda x: -(x._get_current_object()) + __pos__ = lambda x: +(x._get_current_object()) + __abs__ = lambda x: abs(x._get_current_object()) + __invert__ = lambda x: ~(x._get_current_object()) + __complex__ = lambda x: complex(x._get_current_object()) + __int__ = lambda x: int(x._get_current_object()) + __long__ = lambda x: long(x._get_current_object()) + __float__ = lambda x: float(x._get_current_object()) + __oct__ = lambda x: oct(x._get_current_object()) + __hex__ = lambda x: hex(x._get_current_object()) + __index__ = lambda x: x._get_current_object().__index__() + __coerce__ = lambda x, o: x.__coerce__(x, o) + __enter__ = lambda x: x.__enter__() + __exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw) + + +class DocumentProxy(LocalProxy): + __slots__ = ('__document_type', '__document', '__pk') + + def __init__(self, document_type, pk): + object.__setattr__(self, '_DocumentProxy__document_type', document_type) + object.__setattr__(self, '_DocumentProxy__document', None) + object.__setattr__(self, '_DocumentProxy__pk', pk) + object.__setattr__(self, document_type._meta['id_field'], self.pk) + + @property + def __class__(self): + # We need to fetch the object to determine to which class it belongs. + return self._get_current_object().__class__ + + def _lazy(): + def fget(self): + return self.__document._lazy if self.__document else True + def fset(self, value): + self._get_current_object()._lazy = value + return property(fget, fset) + _lazy = _lazy() + + # copy normally updates __dict__ which would result in errors + def __setstate__(self, state): + for k, v in state[1].iteritems(): + object.__setattr__(self, k, v) + + def _get_collection_name(self): + return self.__document_type._meta.get('collection', None) + + def __eq__(self, other): + if other and hasattr(other, '_get_collection_name') and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk'): + if self.pk == other.pk: + return True + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def to_dbref(self): + """Returns an instance of :class:`~bson.dbref.DBRef` useful in + `__raw__` queries.""" + if not self.pk: + msg = "Only saved documents can have a valid dbref" + raise OperationError(msg) + return DBRef(self._get_collection_name(), self.pk) + + def pk(): + def fget(self): + return self.__document.pk if self.__document else self.__pk + def fset(self, value): + self._get_current_object().pk = value + return property(fget, fset) + pk = pk() + + def _get_current_object(self): + if self.__document == None: + #print 'fetching', self.__document_type, self.__pk + #import traceback + #traceback.print_stack() + collection = self.__document_type._get_collection() + son = collection.find_one({'_id': self.__pk}) + document = self.__document_type._from_son(son) + object.__setattr__(self, '_DocumentProxy__document', document) + return self.__document + + def __nonzero__(self): + return True diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 8811f7f..0c2056a 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -26,6 +26,7 @@ from mongoengine.python_support import (PY3, bin_type, txt_type, from mongoengine.base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, get_document, BaseDocument) from mongoengine.base.datastructures import BaseList, BaseDict +from mongoengine.base.proxy import DocumentProxy from mongoengine.queryset import DoesNotExist from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument @@ -516,7 +517,7 @@ class ListField(ComplexBaseField): super(ListField, self).__init__(**kwargs) def value_for_instance(self, value, instance): - return BaseList(value, instance, self.name) + return BaseList(value or [], instance, self.name) def from_python(self, val): from_python = getattr(self.field, 'from_python', None) @@ -615,7 +616,7 @@ class DictField(ComplexBaseField): return {k: to_python(v) for k, v in val.iteritems()} if to_python else val def value_for_instance(self, value, instance): - return BaseDict(value, instance, self.name) + return BaseDict(value or {}, instance, self.name) def to_mongo(self, val): to_mongo = getattr(self.field, 'to_mongo', None) @@ -734,7 +735,7 @@ class ReferenceField(BaseField): return value else: return value.id - elif isinstance(value, Document): + elif isinstance(value, (Document, DocumentProxy)): document_type = self.document_type # We need the id from the saved object to create the DBRef pk = value.pk @@ -758,17 +759,22 @@ class ReferenceField(BaseField): if value != None: document_type = self.document_type if self.dbref: - obj = document_type(pk=value.id) + pk = value.id else: if isinstance(value, DBRef): - obj = document_type(pk=value.id) + pk = value.id else: - obj = document_type(pk=value) - obj._lazy = True + pk = value + if document_type._meta['allow_inheritance']: + # We don't know of which type the object will be. + obj = DocumentProxy(document_type, pk) + else: + obj = document_type(pk=pk) + obj._lazy = True return obj def from_python(self, value): - if isinstance(value, BaseDocument): + if isinstance(value, (BaseDocument, DocumentProxy)): return value elif value == None: return super(ReferenceField, self).from_python(value) @@ -780,17 +786,22 @@ class ReferenceField(BaseField): # DBRef or ID document_type = self.document_type if isinstance(value, DBRef): - obj = document_type(pk=value.id) + pk = value.id else: - obj = document_type(pk=value) - obj._lazy = True + pk = value + if document_type._meta['allow_inheritance']: + # We don't know of which type the object will be. + obj = DocumentProxy(document_type, pk) + else: + obj = document_type(pk=pk) + obj._lazy = True return obj def prepare_query_value(self, op, value): return self.to_mongo(self.from_python(value)) def validate(self, value): - if not isinstance(value, (self.document_type, DBRef)): + if not isinstance(value, (self.document_type, DBRef, DocumentProxy)): self.error("A ReferenceField only accepts DBRef or documents") if isinstance(value, Document) and value.pk is None: diff --git a/setup.py b/setup.py index effb6f1..7a9f360 100644 --- a/setup.py +++ b/setup.py @@ -51,13 +51,14 @@ CLASSIFIERS = [ extra_opts = {} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6'] extra_opts['packages'] = find_packages(exclude=('tests',)) if "test" in sys.argv or "nosetests" in sys.argv: extra_opts['packages'].append("tests") extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil'] + #extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6'] extra_opts['packages'] = find_packages(exclude=('tests',)) setup(name='mongoengine',