Add support for DocumentProxy so we can access obj.ref.pk without fetching ref and still get the correct type even if ref can be inherited.

This commit is contained in:
Thomas Steinacher 2013-06-14 20:06:12 -07:00
commit 929a1b222c
4 changed files with 226 additions and 17 deletions

View file

@ -15,6 +15,7 @@ from mongoengine.errors import (ValidationError, InvalidDocumentError,
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
to_str_keys_recursive)
from mongoengine.base.proxy import DocumentProxy
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList
from mongoengine.base.fields import ComplexBaseField
@ -109,9 +110,12 @@ class BaseDocument(object):
return txt_type('%s object' % self.__class__.__name__)
def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'pk'):
if self.pk == other.pk:
return True
if isinstance(other, DocumentProxy) and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk') and self.pk == other.pk:
return True
if isinstance(other, self.__class__) and hasattr(other, 'pk') and self.pk == other.pk:
return True
return False
def __ne__(self, other):

193
mongoengine/base/proxy.py Normal file
View file

@ -0,0 +1,193 @@
from mongoengine.queryset import OperationError
from bson.dbref import DBRef
class LocalProxy(object):
# From werkzeug/local.py
""" Forwards all operations to
a proxied object. The only operations not supported for forwarding
are right handed operands and any kind of assignment.
"""
__slots__ = ('__local', '__dict__', '__name__')
def __init__(self, local, name=None):
object.__setattr__(self, '_LocalProxy__local', local)
object.__setattr__(self, '__name__', name)
def _get_current_object(self):
"""Return the current object. This is useful if you want the real
object behind the proxy at a time for performance reasons or because
you want to pass the object into a different context.
"""
if not hasattr(self.__local, '__release_local__'):
return self.__local()
try:
return getattr(self.__local, self.__name__)
except AttributeError:
raise RuntimeError('no object bound to %s' % self.__name__)
@property
def __dict__(self):
try:
return self._get_current_object().__dict__
except RuntimeError:
raise AttributeError('__dict__')
def __repr__(self):
try:
obj = self._get_current_object()
except RuntimeError:
return '<%s unbound>' % self.__class__.__name__
return repr(obj)
def __nonzero__(self):
try:
return bool(self._get_current_object())
except RuntimeError:
return False
def __unicode__(self):
try:
return unicode(self._get_current_object())
except RuntimeError:
return repr(self)
def __dir__(self):
try:
return dir(self._get_current_object())
except RuntimeError:
return []
def __getattr__(self, name):
if name == '__members__':
return dir(self._get_current_object())
return getattr(self._get_current_object(), name)
def __setitem__(self, key, value):
self._get_current_object()[key] = value
def __delitem__(self, key):
del self._get_current_object()[key]
def __setslice__(self, i, j, seq):
self._get_current_object()[i:j] = seq
def __delslice__(self, i, j):
del self._get_current_object()[i:j]
__setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v)
__delattr__ = lambda x, n: delattr(x._get_current_object(), n)
__str__ = lambda x: str(x._get_current_object())
__lt__ = lambda x, o: x._get_current_object() < o
__le__ = lambda x, o: x._get_current_object() <= o
__eq__ = lambda x, o: x._get_current_object() == o
__ne__ = lambda x, o: x._get_current_object() != o
__gt__ = lambda x, o: x._get_current_object() > o
__ge__ = lambda x, o: x._get_current_object() >= o
__cmp__ = lambda x, o: cmp(x._get_current_object(), o)
__hash__ = lambda x: hash(x._get_current_object())
__call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw)
__len__ = lambda x: len(x._get_current_object())
__getitem__ = lambda x, i: x._get_current_object()[i]
__iter__ = lambda x: iter(x._get_current_object())
__contains__ = lambda x, i: i in x._get_current_object()
__getslice__ = lambda x, i, j: x._get_current_object()[i:j]
__add__ = lambda x, o: x._get_current_object() + o
__sub__ = lambda x, o: x._get_current_object() - o
__mul__ = lambda x, o: x._get_current_object() * o
__floordiv__ = lambda x, o: x._get_current_object() // o
__mod__ = lambda x, o: x._get_current_object() % o
__divmod__ = lambda x, o: x._get_current_object().__divmod__(o)
__pow__ = lambda x, o: x._get_current_object() ** o
__lshift__ = lambda x, o: x._get_current_object() << o
__rshift__ = lambda x, o: x._get_current_object() >> o
__and__ = lambda x, o: x._get_current_object() & o
__xor__ = lambda x, o: x._get_current_object() ^ o
__or__ = lambda x, o: x._get_current_object() | o
__div__ = lambda x, o: x._get_current_object().__div__(o)
__truediv__ = lambda x, o: x._get_current_object().__truediv__(o)
__neg__ = lambda x: -(x._get_current_object())
__pos__ = lambda x: +(x._get_current_object())
__abs__ = lambda x: abs(x._get_current_object())
__invert__ = lambda x: ~(x._get_current_object())
__complex__ = lambda x: complex(x._get_current_object())
__int__ = lambda x: int(x._get_current_object())
__long__ = lambda x: long(x._get_current_object())
__float__ = lambda x: float(x._get_current_object())
__oct__ = lambda x: oct(x._get_current_object())
__hex__ = lambda x: hex(x._get_current_object())
__index__ = lambda x: x._get_current_object().__index__()
__coerce__ = lambda x, o: x.__coerce__(x, o)
__enter__ = lambda x: x.__enter__()
__exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw)
class DocumentProxy(LocalProxy):
__slots__ = ('__document_type', '__document', '__pk')
def __init__(self, document_type, pk):
object.__setattr__(self, '_DocumentProxy__document_type', document_type)
object.__setattr__(self, '_DocumentProxy__document', None)
object.__setattr__(self, '_DocumentProxy__pk', pk)
object.__setattr__(self, document_type._meta['id_field'], self.pk)
@property
def __class__(self):
# We need to fetch the object to determine to which class it belongs.
return self._get_current_object().__class__
def _lazy():
def fget(self):
return self.__document._lazy if self.__document else True
def fset(self, value):
self._get_current_object()._lazy = value
return property(fget, fset)
_lazy = _lazy()
# copy normally updates __dict__ which would result in errors
def __setstate__(self, state):
for k, v in state[1].iteritems():
object.__setattr__(self, k, v)
def _get_collection_name(self):
return self.__document_type._meta.get('collection', None)
def __eq__(self, other):
if other and hasattr(other, '_get_collection_name') and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk'):
if self.pk == other.pk:
return True
return False
def __ne__(self, other):
return not self.__eq__(other)
def to_dbref(self):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries."""
if not self.pk:
msg = "Only saved documents can have a valid dbref"
raise OperationError(msg)
return DBRef(self._get_collection_name(), self.pk)
def pk():
def fget(self):
return self.__document.pk if self.__document else self.__pk
def fset(self, value):
self._get_current_object().pk = value
return property(fget, fset)
pk = pk()
def _get_current_object(self):
if self.__document == None:
#print 'fetching', self.__document_type, self.__pk
#import traceback
#traceback.print_stack()
collection = self.__document_type._get_collection()
son = collection.find_one({'_id': self.__pk})
document = self.__document_type._from_son(son)
object.__setattr__(self, '_DocumentProxy__document', document)
return self.__document
def __nonzero__(self):
return True

View file

@ -26,6 +26,7 @@ from mongoengine.python_support import (PY3, bin_type, txt_type,
from mongoengine.base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField,
get_document, BaseDocument)
from mongoengine.base.datastructures import BaseList, BaseDict
from mongoengine.base.proxy import DocumentProxy
from mongoengine.queryset import DoesNotExist
from queryset import DO_NOTHING, QuerySet
from document import Document, EmbeddedDocument
@ -516,7 +517,7 @@ class ListField(ComplexBaseField):
super(ListField, self).__init__(**kwargs)
def value_for_instance(self, value, instance):
return BaseList(value, instance, self.name)
return BaseList(value or [], instance, self.name)
def from_python(self, val):
from_python = getattr(self.field, 'from_python', None)
@ -615,7 +616,7 @@ class DictField(ComplexBaseField):
return {k: to_python(v) for k, v in val.iteritems()} if to_python else val
def value_for_instance(self, value, instance):
return BaseDict(value, instance, self.name)
return BaseDict(value or {}, instance, self.name)
def to_mongo(self, val):
to_mongo = getattr(self.field, 'to_mongo', None)
@ -734,7 +735,7 @@ class ReferenceField(BaseField):
return value
else:
return value.id
elif isinstance(value, Document):
elif isinstance(value, (Document, DocumentProxy)):
document_type = self.document_type
# We need the id from the saved object to create the DBRef
pk = value.pk
@ -758,17 +759,22 @@ class ReferenceField(BaseField):
if value != None:
document_type = self.document_type
if self.dbref:
obj = document_type(pk=value.id)
pk = value.id
else:
if isinstance(value, DBRef):
obj = document_type(pk=value.id)
pk = value.id
else:
obj = document_type(pk=value)
obj._lazy = True
pk = value
if document_type._meta['allow_inheritance']:
# We don't know of which type the object will be.
obj = DocumentProxy(document_type, pk)
else:
obj = document_type(pk=pk)
obj._lazy = True
return obj
def from_python(self, value):
if isinstance(value, BaseDocument):
if isinstance(value, (BaseDocument, DocumentProxy)):
return value
elif value == None:
return super(ReferenceField, self).from_python(value)
@ -780,17 +786,22 @@ class ReferenceField(BaseField):
# DBRef or ID
document_type = self.document_type
if isinstance(value, DBRef):
obj = document_type(pk=value.id)
pk = value.id
else:
obj = document_type(pk=value)
obj._lazy = True
pk = value
if document_type._meta['allow_inheritance']:
# We don't know of which type the object will be.
obj = DocumentProxy(document_type, pk)
else:
obj = document_type(pk=pk)
obj._lazy = True
return obj
def prepare_query_value(self, op, value):
return self.to_mongo(self.from_python(value))
def validate(self, value):
if not isinstance(value, (self.document_type, DBRef)):
if not isinstance(value, (self.document_type, DBRef, DocumentProxy)):
self.error("A ReferenceField only accepts DBRef or documents")
if isinstance(value, Document) and value.pk is None:

View file

@ -51,13 +51,14 @@ CLASSIFIERS = [
extra_opts = {}
if sys.version_info[0] == 3:
extra_opts['use_2to3'] = True
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6']
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6']
extra_opts['packages'] = find_packages(exclude=('tests',))
if "test" in sys.argv or "nosetests" in sys.argv:
extra_opts['packages'].append("tests")
extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]}
else:
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil']
#extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6']
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2>=2.6']
extra_opts['packages'] = find_packages(exclude=('tests',))
setup(name='mongoengine',