diff --git a/.travis.yml b/.travis.yml index 4395107..b7c56a0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,14 +11,12 @@ env: - PYMONGO=dev DJANGO=1.4.2 - PYMONGO=2.5 DJANGO=1.5.1 - PYMONGO=2.5 DJANGO=1.4.2 - - PYMONGO=3.2 DJANGO=1.5.1 - - PYMONGO=3.3 DJANGO=1.5.1 install: - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi + - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install django==$DJANGO --use-mirrors ; true; fi - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi - - pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b - python setup.py install script: - python setup.py test diff --git a/AUTHORS b/AUTHORS index f043207..0aa6056 100644 --- a/AUTHORS +++ b/AUTHORS @@ -9,6 +9,7 @@ Steve Challis Wilson Júnior Dan Crosta https://github.com/dcrosta Laine Herron https://github.com/LaineHerron +Thomas Steinacher http://thomasst.ch/ CONTRIBUTORS @@ -16,6 +17,8 @@ Dervived from the git logs, inevitably incomplete but all of whom and others have submitted patches, reported bugs and generally helped make MongoEngine that much better: + * Harry Marr + * Ross Lawley * blackbrrr * Florian Schlachter * Vincent Driessen @@ -112,7 +115,6 @@ that much better: * Alexander Koshelev * Jaime Irurzun * Alexandre González - * Thomas Steinacher * Tommi Komulainen * Peter Landry * biszkoptwielki @@ -169,13 +171,3 @@ that much better: * ygbourhis (https://github.com/ygbourhis) * Bob Dickinson (https://github.com/BobDickinson) * Michael Bartnett (https://github.com/michaelbartnett) - * Alon Horev (https://github.com/alonho) - * Kelvin Hammond (https://github.com/kelvinhammond) - * Jatin- (https://github.com/jatin-) - * Paul Uithol (https://github.com/PaulUithol) - * Thom Knowles (https://github.com/fleat) - * Paul (https://github.com/squamous) - * Olivier Cortès (https://github.com/Karmak23) - * crazyzubr (https://github.com/crazyzubr) - * FrankSomething (https://github.com/FrankSomething) - * Alexandr Morozov (https://github.com/LK4D4) diff --git a/DIFFERENCES.md b/DIFFERENCES.md new file mode 100644 index 0000000..01ae0f2 --- /dev/null +++ b/DIFFERENCES.md @@ -0,0 +1,37 @@ +Differences between Mongomallard and Mongoengine +----- + +* All document fields are lazily evaluated, resulting in much faster object initialization time. +* `_data` is removed due to lazy evaluation. `to_dict()` can be used to convert a document to a dictionary, and `_internal_data` contains previously evaluated data. +* Field methods `to_python`, `from_python`, `to_mongo`, `value_for_instance`: + * `to_python` is called when converting from a MongoDB type to a document Python type only. + * `from_python` is called when converting an assignment in Python to the document Python type. + * `to_mongo` is called when converting from a document Python type to a MongoDB type. + * `value_for_instance` is called just before returning a value in Python allowing for instance-specific transformations. +* `pre_init`, `post_init`, `pre_save_post_validation` signals are removed to ensure fast object initialization. +* `DecimalField` is removed since there is no corresponding MongoDB type +* `LongField` is removed since it is equivalent with `IntField` +* Adding `SafeReferenceField` which returns None if the reference does not exist. +* Adding `SafeReferenceListField` which doesn't return references that don't exist. +* Accessing a `ListField(ReferenceField)` doesn't automatically dereference all objects since they are lazily evaluated. A `SafeReferenceListField` may be used instead. +* Accessing a related object's id doesn't fetch the object from the database, e.g. `book.author.id` where author is a `ReferenceField` will not make a database lookup except when using a `SafeReferenceField`. When inheritance is allowed, a proxy object will be returned, otherwise a lazy object from the referenced document class will be returned. +* The primary key is only stored as `_id` in the database and is referenced in Python as `pk` or as the name of the primary key field. +* Saves are not cascaded by default. +* `Document.save()` supports `full=True` keyword argument to force saving all model fields. +* `_get_changed_fields()` / `_changed_fields` returns a set of field names (not db field names) +* Simplified `EmailField` email regex to be more compatible +* Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError` +* `order_by()` without an argument resets the ordering (no ordering will be applied) + +Untested / not implemented yet: +----- + +* Dynamic documents / `DynamicField`, dynamic addition/deletion of fields +* Field display name methods +* `SequenceField` +* Pickling documents +* `FileField` +* All Geo fields +* `no_dereference()` +* using `SafeReferenceListField` with `GenericReferenceField` +* `max_depth` argument for `doc.reload()` diff --git a/README.md b/README.md new file mode 100644 index 0000000..ab8002c --- /dev/null +++ b/README.md @@ -0,0 +1,85 @@ +MongoMallard +============ + +MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine. + +* Repository: https://github.com/elasticsales/mongomallard +* See [README_MONGOENGINE](https://github.com/elasticsales/mongomallard/blob/master/README_MONGOENGINE.rst) for MongoEngine's README. +* See [DIFFERENCES](https://github.com/elasticsales/mongomallard/blob/master/DIFFERENCES.md) for differences between MongoEngine and MongoMallard. + + +Benchmarks +---------- + +Sample run on a 2.7 GHz Intel Core i5 running OS X 10.8.3 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MongoEngine 0.8.2 (ede9fcf)MongoMallard (478062c)Speedup
Doc initialization52.494us25.195us2.08x
Doc getattr1.339us0.584us2.29x
Doc setattr3.064us2.550us1.20x
Doc to mongo49.415us26.497us1.86x
Load from SON61.475us4.510us13.63x
Save to database434.389us289.972us2.29x
Load from database558.178us480.690us1.16x
Save/delete big object to database98.838ms65.789ms1.50x
Serialize big object from database31.390ms20.265ms1.55x
Load big object from database41.159ms1.400ms29.40x
+ +See [tests/benchmark.py](https://github.com/elasticsales/mongomallard/blob/master/tests/benchmark.py) for source code. diff --git a/README.rst b/README_MONGOENGINE.rst similarity index 100% rename from README.rst rename to README_MONGOENGINE.rst diff --git a/docs/_themes/nature/static/nature.css_t b/docs/_themes/nature/static/nature.css_t index 337760b..03b0379 100644 --- a/docs/_themes/nature/static/nature.css_t +++ b/docs/_themes/nature/static/nature.css_t @@ -2,15 +2,11 @@ * Sphinx stylesheet -- default theme * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ - + @import url("basic.css"); - -#changelog p.first {margin-bottom: 0 !important;} -#changelog p {margin-top: 0 !important; - margin-bottom: 0 !important;} - + /* -- page layout ----------------------------------------------------------- */ - + body { font-family: Arial, sans-serif; font-size: 100%; @@ -32,18 +28,18 @@ div.bodywrapper { hr{ border: 1px solid #B1B4B6; } - + div.document { background-color: #eee; } - + div.body { background-color: #ffffff; color: #3E4349; padding: 0 30px 30px 30px; font-size: 0.8em; } - + div.footer { color: #555; width: 100%; @@ -51,12 +47,12 @@ div.footer { text-align: center; font-size: 75%; } - + div.footer a { color: #444; text-decoration: underline; } - + div.related { background-color: #6BA81E; line-height: 32px; @@ -64,11 +60,11 @@ div.related { text-shadow: 0px 1px 0 #444; font-size: 0.80em; } - + div.related a { color: #E2F3CC; } - + div.sphinxsidebar { font-size: 0.75em; line-height: 1.5em; @@ -77,7 +73,7 @@ div.sphinxsidebar { div.sphinxsidebarwrapper{ padding: 20px 0; } - + div.sphinxsidebar h3, div.sphinxsidebar h4 { font-family: Arial, sans-serif; @@ -93,30 +89,30 @@ div.sphinxsidebar h4 { div.sphinxsidebar h4{ font-size: 1.1em; } - + div.sphinxsidebar h3 a { color: #444; } - - + + div.sphinxsidebar p { color: #888; padding: 5px 20px; } - + div.sphinxsidebar p.topless { } - + div.sphinxsidebar ul { margin: 10px 20px; padding: 0; color: #000; } - + div.sphinxsidebar a { color: #444; } - + div.sphinxsidebar input { border: 1px solid #ccc; font-family: sans-serif; @@ -126,19 +122,19 @@ div.sphinxsidebar input { div.sphinxsidebar input[type=text]{ margin-left: 20px; } - + /* -- body styles ----------------------------------------------------------- */ - + a { color: #005B81; text-decoration: none; } - + a:hover { color: #E32E00; text-decoration: underline; } - + div.body h1, div.body h2, div.body h3, @@ -153,30 +149,30 @@ div.body h6 { padding: 5px 0 5px 10px; text-shadow: 0px 1px 0 white } - + div.body h1 { border-top: 20px solid white; margin-top: 0; font-size: 200%; } div.body h2 { font-size: 150%; background-color: #C8D5E3; } div.body h3 { font-size: 120%; background-color: #D8DEE3; } div.body h4 { font-size: 110%; background-color: #D8DEE3; } div.body h5 { font-size: 100%; background-color: #D8DEE3; } div.body h6 { font-size: 100%; background-color: #D8DEE3; } - + a.headerlink { color: #c60f0f; font-size: 0.8em; padding: 0 4px 0 4px; text-decoration: none; } - + a.headerlink:hover { background-color: #c60f0f; color: white; } - + div.body p, div.body dd, div.body li { line-height: 1.5em; } - + div.admonition p.admonition-title + p { display: inline; } @@ -189,29 +185,29 @@ div.note { background-color: #eee; border: 1px solid #ccc; } - + div.seealso { background-color: #ffc; border: 1px solid #ff6; } - + div.topic { background-color: #eee; } - + div.warning { background-color: #ffe4e4; border: 1px solid #f66; } - + p.admonition-title { display: inline; } - + p.admonition-title:after { content: ":"; } - + pre { padding: 10px; background-color: White; @@ -223,7 +219,7 @@ pre { -webkit-box-shadow: 1px 1px 1px #d8d8d8; -moz-box-shadow: 1px 1px 1px #d8d8d8; } - + tt { background-color: #ecf0f3; color: #222; diff --git a/docs/apireference.rst b/docs/apireference.rst index 9057de5..d062727 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -44,21 +44,12 @@ Context Managers Querying ======== -.. automodule:: mongoengine.queryset - :synopsis: Queryset level operations +.. autoclass:: mongoengine.queryset.QuerySet + :members: - .. autoclass:: mongoengine.queryset.QuerySet - :members: - :inherited-members: + .. automethod:: mongoengine.queryset.QuerySet.__call__ - .. automethod:: QuerySet.__call__ - - .. autoclass:: mongoengine.queryset.QuerySetNoCache - :members: - - .. automethod:: mongoengine.queryset.QuerySetNoCache.__call__ - - .. autofunction:: mongoengine.queryset.queryset_manager +.. autofunction:: mongoengine.queryset.queryset_manager Fields ====== diff --git a/docs/changelog.rst b/docs/changelog.rst index 926fb8a..1927bee 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,49 +2,13 @@ Changelog ========= -Changes in 0.8.4 -================ -- Remove database name necessity in uri connection schema (#452) -- Fixed "$pull" semantics for nested ListFields (#447) -- Allow fields to be named the same as query operators (#445) -- Updated field filter logic - can now exclude subclass fields (#443) -- Fixed dereference issue with embedded listfield referencefields (#439) -- Fixed slice when using inheritance causing fields to be excluded (#437) -- Fixed ._get_db() attribute after a Document.switch_db() (#441) -- Dynamic Fields store and recompose Embedded Documents / Documents correctly (#449) -- Handle dynamic fieldnames that look like digits (#434) -- Added get_user_document and improve mongo_auth module (#423) -- Added str representation of GridFSProxy (#424) -- Update transform to handle docs erroneously passed to unset (#416) -- Fixed indexing - turn off _cls (#414) -- Fixed dereference threading issue in ComplexField.__get__ (#412) -- Fixed QuerySetNoCache.count() caching (#410) -- Don't follow references in _get_changed_fields (#422, #417) -- Allow args and kwargs to be passed through to_json (#420) - Changes in 0.8.3 ================ -- Fixed EmbeddedDocuments with `id` also storing `_id` (#402) -- Added get_proxy_object helper to filefields (#391) -- Added QuerySetNoCache and QuerySet.no_cache() for lower memory consumption (#365) -- Fixed sum and average mapreduce dot notation support (#375, #376, #393) -- Fixed as_pymongo to return the id (#386) -- Document.select_related() now respects `db_alias` (#377) -- Reload uses shard_key if applicable (#384) -- Dynamic fields are ordered based on creation and stored in _fields_ordered (#396) - - **Potential breaking change:** http://docs.mongoengine.org/en/latest/upgrade.html#to-0-8-3 - -- Fixed pickling dynamic documents `_dynamic_fields` (#387) -- Fixed ListField setslice and delslice dirty tracking (#390) -- Added Django 1.5 PY3 support (#392) - Added match ($elemMatch) support for EmbeddedDocuments (#379) - Fixed weakref being valid after reload (#374) - Fixed queryset.get() respecting no_dereference (#373) - Added full_result kwarg to update (#380) - - Changes in 0.8.2 ================ - Added compare_indexes helper (#361) diff --git a/docs/django.rst b/docs/django.rst index 62d4dd4..da15188 100644 --- a/docs/django.rst +++ b/docs/django.rst @@ -45,7 +45,7 @@ The :mod:`~mongoengine.django.auth` module also contains a Custom User model ================= Django 1.5 introduced `Custom user Models -`_ +` which can be used as an alternative to the MongoEngine authentication backend. The main advantage of this option is that other components relying on @@ -74,7 +74,7 @@ An additional ``MONGOENGINE_USER_DOCUMENT`` setting enables you to replace the The custom :class:`User` must be a :class:`~mongoengine.Document` class, but otherwise has the same requirements as a standard custom user model, as specified in the `Django Documentation -`_. +`. In particular, the custom class must define :attr:`USERNAME_FIELD` and :attr:`REQUIRED_FIELDS` attributes. @@ -128,7 +128,7 @@ appended to the filename until the generated filename doesn't exist. The >>> fs.listdir() ([], [u'hello.txt']) -All files will be saved and retrieved in GridFS via the :class:`FileDocument` +All files will be saved and retrieved in GridFS via the :class::`FileDocument` document, allowing easy access to the files without the GridFSStorage backend.:: @@ -137,36 +137,3 @@ backend.:: [] .. versionadded:: 0.4 - -Shortcuts -========= -Inspired by the `Django shortcut get_object_or_404 -`_, -the :func:`~mongoengine.django.shortcuts.get_document_or_404` method returns -a document or raises an Http404 exception if the document does not exist:: - - from mongoengine.django.shortcuts import get_document_or_404 - - admin_user = get_document_or_404(User, username='root') - -The first argument may be a Document or QuerySet object. All other passed arguments -and keyword arguments are used in the query:: - - foo_email = get_document_or_404(User.objects.only('email'), username='foo', is_active=True).email - -.. note:: Like with :func:`get`, a MultipleObjectsReturned will be raised if more than one - object is found. - - -Also inspired by the `Django shortcut get_list_or_404 -`_, -the :func:`~mongoengine.django.shortcuts.get_list_or_404` method returns a list of -documents or raises an Http404 exception if the list is empty:: - - from mongoengine.django.shortcuts import get_list_or_404 - - active_users = get_list_or_404(User, is_active=True) - -The first argument may be a Document or QuerySet object. All other passed -arguments and keyword arguments are used to filter the query. - diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index f681aad..854e2c3 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -23,15 +23,12 @@ arguments should be provided:: connect('project1', username='webapp', password='pwd123') -Uri style connections are also supported - just supply the uri as -the :attr:`host` to +Uri style connections are also supported as long as you include the database +name - just supply the uri as the :attr:`host` to :func:`~mongoengine.connect`:: connect('project1', host='mongodb://localhost/database_name') -Note that database name from uri has priority over name -in ::func:`~mongoengine.connect` - ReplicaSets =========== diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index ba1af33..a61d8fe 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -54,7 +54,7 @@ be saved :: There is one caveat on Dynamic Documents: fields cannot start with `_` -Dynamic fields are stored in creation order *after* any declared fields. +Dynamic fields are stored in alphabetical order *after* any declared fields. Fields ====== @@ -442,8 +442,6 @@ The following example shows a :class:`Log` document that will be limited to ip_address = StringField() meta = {'max_documents': 1000, 'max_size': 2000000} -.. defining-indexes_ - Indexes ======= @@ -487,35 +485,6 @@ If a dictionary is passed then the following options are available: Inheritance adds extra fields indices see: :ref:`document-inheritance`. -Global index default options ----------------------------- - -There are a few top level defaults for all indexes that can be set:: - - class Page(Document): - title = StringField() - rating = StringField() - meta = { - 'index_options': {}, - 'index_background': True, - 'index_drop_dups': True, - 'index_cls': False - } - - -:attr:`index_options` (Optional) - Set any default index options - see the `full options list `_ - -:attr:`index_background` (Optional) - Set the default value for if an index should be indexed in the background - -:attr:`index_drop_dups` (Optional) - Set the default value for if an index should drop duplicates - -:attr:`index_cls` (Optional) - A way to turn off a specific index for _cls. - - Compound Indexes and Indexing sub documents ------------------------------------------- @@ -589,11 +558,6 @@ documentation for more information. A common usecase might be session data:: ] } -.. warning:: TTL indexes happen on the MongoDB server and not in the application - code, therefore no signals will be fired on document deletion. - If you need signals to be fired on deletion, then you must handle the - deletion of Documents in your application code. - Comparing Indexes ----------------- @@ -689,6 +653,7 @@ document.:: .. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults to False, meaning you must set it to True to use inheritance. + Working with existing data -------------------------- As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and @@ -708,25 +673,3 @@ defining all possible field types. If you use :class:`~mongoengine.Document` and the database contains data that isn't defined then that data will be stored in the `document._data` dictionary. - -Abstract classes -================ - -If you want to add some extra functionality to a group of Document classes but -you don't need or want the overhead of inheritance you can use the -:attr:`abstract` attribute of :attr:`-mongoengine.Document.meta`. -This won't turn on :ref:`document-inheritance` but will allow you to keep your -code DRY:: - - class BaseDocument(Document): - meta = { - 'abstract': True, - } - def check_permissions(self): - ... - - class User(BaseDocument): - ... - -Now the User class will have access to the inherited `check_permissions` method -and won't store any of the extra `_cls` information. diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index f50985b..1350130 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -16,9 +16,7 @@ fetch documents from the database:: .. note:: As of MongoEngine 0.8 the querysets utilise a local cache. So iterating - it multiple times will only cause a single query. If this is not the - desired behavour you can call :class:`~mongoengine.QuerySet.no_cache` - (version **0.8.3+**) to return a non-caching queryset. + it multiple times will only cause a single query. Filtering queries ================= @@ -497,6 +495,7 @@ that you may use with these methods: * ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``inc`` -- increment a value by a given amount * ``dec`` -- decrement a value by a given amount +* ``pop`` -- remove the last item from a list * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list * ``pop`` -- remove the first or last element of a list diff --git a/docs/upgrade.rst b/docs/upgrade.rst index a1fccea..c3d3182 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -2,22 +2,12 @@ Upgrading ######### - -0.8.2 to 0.8.3 -************** - -Minor change that may impact users: - -DynamicDocument fields are now stored in creation order after any declared -fields. Previously they were stored alphabetically. - - 0.7 to 0.8 ********** There have been numerous backwards breaking changes in 0.8. The reasons for -these are to ensure that MongoEngine has sane defaults going forward and that it -performs the best it can out of the box. Where possible there have been +these are ensure that MongoEngine has sane defaults going forward and +performs the best it can out the box. Where possible there have been FutureWarnings to help get you ready for the change, but that hasn't been possible for the whole of the release. @@ -71,7 +61,7 @@ inherited classes like so: :: Document Definition ------------------- -The default for inheritance has changed - it is now off by default and +The default for inheritance has changed - its now off by default and :attr:`_cls` will not be stored automatically with the class. So if you extend your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments` you will need to declare :attr:`allow_inheritance` in the meta data like so: :: @@ -81,7 +71,7 @@ you will need to declare :attr:`allow_inheritance` in the meta data like so: :: meta = {'allow_inheritance': True} -Previously, if you had data in the database that wasn't defined in the Document +Previously, if you had data the database that wasn't defined in the Document definition, it would set it as an attribute on the document. This is no longer the case and the data is set only in the ``document._data`` dictionary: :: @@ -102,8 +92,8 @@ the case and the data is set only in the ``document._data`` dictionary: :: AttributeError: 'Animal' object has no attribute 'size' The Document class has introduced a reserved function `clean()`, which will be -called before saving the document. If your document class happens to have a method -with the same name, please try to rename it. +called before saving the document. If your document class happen to have a method +with the same name, please try rename it. def clean(self): pass @@ -111,7 +101,7 @@ with the same name, please try to rename it. ReferenceField -------------- -ReferenceFields now store ObjectIds by default - this is more efficient than +ReferenceFields now store ObjectId's by default - this is more efficient than DBRefs as we already know what Document types they reference:: # Old code @@ -157,7 +147,7 @@ UUIDFields now default to storing binary values:: class Animal(Document): uuid = UUIDField(binary=False) -To migrate all the uuids you need to touch each object and mark it as dirty +To migrate all the uuid's you need to touch each object and mark it as dirty eg:: # Doc definition @@ -175,7 +165,7 @@ eg:: DecimalField ------------ -DecimalFields now store floats - previously it was storing strings and that +DecimalField now store floats - previous it was storing strings and that made it impossible to do comparisons when querying correctly.:: # Old code @@ -186,7 +176,7 @@ made it impossible to do comparisons when querying correctly.:: class Person(Document): balance = DecimalField(force_string=True) -To migrate all the DecimalFields you need to touch each object and mark it as dirty +To migrate all the uuid's you need to touch each object and mark it as dirty eg:: # Doc definition @@ -198,7 +188,7 @@ eg:: p._mark_as_changed('balance') p.save() -.. note:: DecimalFields have also been improved with the addition of precision +.. note:: DecimalField's have also been improved with the addition of precision and rounding. See :class:`~mongoengine.fields.DecimalField` for more information. `An example test migration for DecimalFields is available on github @@ -207,7 +197,7 @@ eg:: Cascading Saves --------------- To improve performance document saves will no longer automatically cascade. -Any changes to a Document's references will either have to be saved manually or +Any changes to a Documents references will either have to be saved manually or you will have to explicitly tell it to cascade on save:: # At the class level: @@ -249,7 +239,7 @@ update your code like so: :: # Update example a) assign queryset after a change: mammals = Animal.objects(type="mammal") - carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so filter can be applied + carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so fitler can be applied [m for m in carnivores] # This will return all carnivores # Update example b) chain the queryset: @@ -276,7 +266,7 @@ queryset you should upgrade to use count:: .only() now inline with .exclude() ---------------------------------- -The behaviour of `.only()` was highly ambiguous, now it works in mirror fashion +The behaviour of `.only()` was highly ambious, now it works in the mirror fashion to `.exclude()`. Chaining `.only()` calls will increase the fields required:: # Old code @@ -440,7 +430,7 @@ main areas of changed are: choices in fields, map_reduce and collection names. Choice options: =============== -Are now expected to be an iterable of tuples, with the first element in each +Are now expected to be an iterable of tuples, with the first element in each tuple being the actual value to be stored. The second element is the human-readable name for the option. @@ -462,8 +452,8 @@ such the following have been changed: Default collection naming ========================= -Previously it was just lowercase, it's now much more pythonic and readable as -it's lowercase and underscores, previously :: +Previously it was just lowercase, its now much more pythonic and readable as +its lowercase and underscores, previously :: class MyAceDocument(Document): pass @@ -530,5 +520,5 @@ Alternatively, you can rename your collections eg :: mongodb 1.8 > 2.0 + =================== -It's been reported that indexes may need to be recreated to the newer version of indexes. +Its been reported that indexes may need to be recreated to the newer version of indexes. To do this drop indexes and call ``ensure_indexes`` on each model. diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 2b68b3c..875c916 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -15,7 +15,8 @@ import django __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + list(queryset.__all__) + signals.__all__ + list(errors.__all__)) -VERSION = (0, 8, 4) +VERSION = (0, 8, 2) +MALLARD = True def get_version(): diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 4652fb5..adcd8d0 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -108,14 +108,6 @@ class BaseList(list): self._mark_as_changed() return super(BaseList, self).__delitem__(*args, **kwargs) - def __setslice__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__setslice__(*args, **kwargs) - - def __delslice__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__delslice__(*args, **kwargs) - def __getstate__(self): self.instance = None self._dereferenced = False diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index cea2f09..b5e357d 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -4,7 +4,7 @@ import numbers from functools import partial import pymongo -from bson import json_util, ObjectId +from bson import json_util from bson.dbref import DBRef from bson.son import SON @@ -15,6 +15,7 @@ from mongoengine.errors import (ValidationError, InvalidDocumentError, from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, to_str_keys_recursive) +from mongoengine.base.proxy import DocumentProxy from mongoengine.base.common import get_document, ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.fields import ComplexBaseField @@ -23,155 +24,52 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' +_set = object.__setattr__ + class BaseDocument(object): - _dynamic = False - _created = True - _dynamic_lock = True + #_dynamic = False + #_dynamic_lock = True _initialised = False - def __init__(self, *args, **values): + def __init__(self, _son=None, **values): """ Initialise a document or embedded document - :param __auto_convert: Try and will cast python objects to Object types :param values: A dictionary of values for the document """ - if args: - # Combine positional arguments with named arguments. - # We only want named arguments. - field = iter(self._fields_ordered) - # If its an automatic id field then skip to the first defined field - if self._auto_id_field: - next(field) - for value in args: - name = next(field) - if name in values: - raise TypeError("Multiple values for keyword argument '" + name + "'") - values[name] = value - __auto_convert = values.pop("__auto_convert", True) - signals.pre_init.send(self.__class__, document=self, values=values) + _set(self, '_db_data', _son) + _set(self, '_lazy', False) + _set(self, '_internal_data', {}) + _set(self, '_changed_fields', set()) + if values: + pk = values.pop('pk', None) + for field in set(self._fields.keys()).intersection(values.keys()): + setattr(self, field, values[field]) + if pk != None: + self.pk = pk - self._data = {} - self._dynamic_fields = SON() + def __delattr__(self, name): + default = self._fields[name].default + value = default() if callable(default) else default + setattr(self, name, value) - # Assign default values to instance - for key, field in self._fields.iteritems(): - if self._db_field_map.get(key, key) in values: - continue - value = getattr(self, key, None) - setattr(self, key, value) - - # Set passed values after initialisation - if self._dynamic: - dynamic_data = {} - for key, value in values.iteritems(): - if key in self._fields or key == '_id': - setattr(self, key, value) - elif self._dynamic: - dynamic_data[key] = value - else: - FileField = _import_class('FileField') - for key, value in values.iteritems(): - if key == '__auto_convert': - continue - key = self._reverse_db_field_map.get(key, key) - if key in self._fields or key in ('id', 'pk', '_cls'): - if __auto_convert and value is not None: - field = self._fields.get(key) - if field and not isinstance(field, FileField): - value = field.to_python(value) - setattr(self, key, value) - else: - self._data[key] = value - - # Set any get_fieldname_display methods - self.__set_field_display() - - if self._dynamic: - self._dynamic_lock = False - for key, value in dynamic_data.iteritems(): - setattr(self, key, value) - - # Flag initialised - self._initialised = True - signals.post_init.send(self.__class__, document=self) - - def __delattr__(self, *args, **kwargs): - """Handle deletions of fields""" - field_name = args[0] - if field_name in self._fields: - default = self._fields[field_name].default - if callable(default): - default = default() - setattr(self, field_name, default) - else: - super(BaseDocument, self).__delattr__(*args, **kwargs) - - def __setattr__(self, name, value): - # Handle dynamic data only if an initialised dynamic document - if self._dynamic and not self._dynamic_lock: - - field = None - if not hasattr(self, name) and not name.startswith('_'): - DynamicField = _import_class("DynamicField") - field = DynamicField(db_field=name) - field.name = name - self._dynamic_fields[name] = field - self._fields_ordered += (name,) - - if not name.startswith('_'): - value = self.__expand_dynamic_values(name, value) - - # Handle marking data as changed - if name in self._dynamic_fields: - self._data[name] = value - if hasattr(self, '_changed_fields'): - self._mark_as_changed(name) - - if (self._is_document and not self._created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): - OperationError = _import_class('OperationError') - msg = "Shard Keys are immutable. Tried to update %s" % name - raise OperationError(msg) - - # Check if the user has created a new instance of a class - if (self._is_document and self._initialised - and self._created and name == self._meta['id_field']): - super(BaseDocument, self).__setattr__('_created', False) - - super(BaseDocument, self).__setattr__(name, value) - - def __getstate__(self): - data = {} - for k in ('_changed_fields', '_initialised', '_created', - '_dynamic_fields', '_fields_ordered'): - if hasattr(self, k): - data[k] = getattr(self, k) - data['_data'] = self.to_mongo() - return data - - def __setstate__(self, data): - if isinstance(data["_data"], SON): - data["_data"] = self.__class__._from_son(data["_data"])._data - for k in ('_changed_fields', '_initialised', '_created', '_data', - '_fields_ordered', '_dynamic_fields'): - if k in data: - setattr(self, k, data[k]) - dynamic_fields = data.get('_dynamic_fields') or SON() - for k in dynamic_fields.keys(): - setattr(self, k, data["_data"].get(k)) + @property + def _created(self): + return self._db_data != None or self._lazy def __iter__(self): + if 'id' in self._fields and 'id' not in self._fields_ordered: + return iter(('id', ) + self._fields_ordered) + return iter(self._fields_ordered) def __getitem__(self, name): """Dictionary-style field access, return a field's value if present. """ try: - if name in self._fields_ordered: + if name in self._fields: return getattr(self, name) except AttributeError: pass @@ -192,8 +90,8 @@ class BaseDocument(object): except AttributeError: return False - def __len__(self): - return len(self._data) + def __unicode__(self): + return u'%s object' % self.__class__.__name__ def __repr__(self): try: @@ -212,9 +110,12 @@ class BaseDocument(object): return txt_type('%s object' % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'id'): - if self.id == other.id: - return True + if isinstance(other, DocumentProxy) and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk') and self.pk == other.pk: + return True + + if isinstance(other, self.__class__) and hasattr(other, 'pk') and self.pk == other.pk: + return True + return False def __ne__(self, other): @@ -240,42 +141,16 @@ class BaseDocument(object): def to_mongo(self): """Return as SON data ready for use with MongoDB. """ - data = SON() - data["_id"] = None - data['_cls'] = self._class_name + sets, unsets = self._delta(full=True) + son = SON(data=sets) + allow_inheritance = self._meta.get('allow_inheritance', + ALLOW_INHERITANCE) + if allow_inheritance: + son['_cls'] = self._class_name + return son - for field_name in self: - value = self._data.get(field_name, None) - field = self._fields.get(field_name) - if field is None and self._dynamic: - field = self._dynamic_fields.get(field_name) - - if value is not None: - value = field.to_mongo(value) - - # Handle self generating fields - if value is None and field._auto_gen: - value = field.generate() - self._data[field_name] = value - - if value is not None: - data[field.db_field] = value - - # If "_id" has not been set, then try and set it - Document = _import_class("Document") - if isinstance(self, Document): - if data["_id"] is None: - data["_id"] = self._data.get("id", None) - - if data['_id'] is None: - data.pop('_id') - - # Only add _cls if allow_inheritance is True - if (not hasattr(self, '_meta') or - not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): - data.pop('_cls') - - return data + def to_dict(self): + return dict((field, getattr(self, field)) for field in self._fields) def validate(self, clean=True): """Ensure that all fields' values are valid and that required fields @@ -290,8 +165,11 @@ class BaseDocument(object): errors[NON_FIELD_ERRORS] = error # Get a list of tuples of field names and their current values - fields = [(self._fields.get(name, self._dynamic_fields.get(name)), - self._data.get(name)) for name in self._fields_ordered] + fields = [(field, getattr(self, name)) + for name, field in self._fields.items()] + #if self._dynamic: + # fields += [(field, self._data.get(name)) + # for name, field in self._dynamic_fields.items()] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") @@ -321,9 +199,9 @@ class BaseDocument(object): message = "ValidationError (%s:%s) " % (self._class_name, pk) raise ValidationError(message, errors=errors) - def to_json(self, *args, **kwargs): + def to_json(self): """Converts a document to JSON""" - return json_util.dumps(self.to_mongo(), *args, **kwargs) + return json_util.dumps(self.to_mongo()) @classmethod def from_json(cls, json_data): @@ -365,17 +243,40 @@ class BaseDocument(object): return value def _mark_as_changed(self, key): - """Marks a key as explicitly changed by the user + """Marks a key as explicitly changed by the user. """ - if not key: - return - key = self._db_field_map.get(key, key) - if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): - self._changed_fields.append(key) + + if key: + self._changed_fields.add(key) + + def _get_changed_fields(self): + """Returns a list of all fields that have explicitly been changed. + """ + changed_fields = set(self._changed_fields) + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + for field_name, field in self._fields.iteritems(): + if field_name not in changed_fields: + if (isinstance(field, ComplexBaseField) and + isinstance(field.field, EmbeddedDocumentField)): + field_value = getattr(self, field_name, None) + if field_value: + for idx in (field_value if isinstance(field_value, dict) + else xrange(len(field_value))): + changed_subfields = field_value[idx]._get_changed_fields() + if changed_subfields: + changed_fields |= set(['.'.join([field_name, str(idx), subfield_name]) + for subfield_name in changed_subfields]) + elif isinstance(field, EmbeddedDocumentField): + field_value = getattr(self, field_name, None) + if field_value: + changed_subfields = field_value._get_changed_fields() + if changed_subfields: + changed_fields |= set(['.'.join([field_name, subfield_name]) + for subfield_name in changed_subfields]) + return changed_fields def _clear_changed_fields(self): - self._changed_fields = [] + _set(self, '_changed_fields', set()) EmbeddedDocumentField = _import_class("EmbeddedDocumentField") for field_name, field in self._fields.iteritems(): if (isinstance(field, ComplexBaseField) and @@ -390,136 +291,57 @@ class BaseDocument(object): if field_value: field_value._clear_changed_fields() - def _get_changed_fields(self, inspected=None): - """Returns a list of all fields that have explicitly been changed. - """ - EmbeddedDocument = _import_class("EmbeddedDocument") - DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") - ReferenceField = _import_class("ReferenceField") - _changed_fields = [] - _changed_fields += getattr(self, '_changed_fields', []) + def _delta(self, full=False): + sets = {} + unsets = {} - inspected = inspected or set() - if hasattr(self, 'id'): - if self.id in inspected: - return _changed_fields - inspected.add(self.id) - for field_name in self._fields_ordered: - db_field_name = self._db_field_map.get(field_name, field_name) - key = '%s.' % db_field_name - data = self._data.get(field_name, None) - field = self._fields.get(field_name) + def get_db_value(field, value): + if value is None: + value = field.default() if callable(field.default) else field.default + return field.to_mongo(value) - if hasattr(data, 'id'): - if data.id in inspected: - continue - inspected.add(data.id) - if isinstance(field, ReferenceField): - continue - elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) - and db_field_name not in _changed_fields): - # Find all embedded fields that have been changed - changed = data._get_changed_fields(inspected) - _changed_fields += ["%s%s" % (key, k) for k in changed if k] - elif (isinstance(data, (list, tuple, dict)) and - db_field_name not in _changed_fields): - # Loop list / dict fields as they contain documents - # Determine the iterator to use - if not hasattr(data, 'items'): - iterator = enumerate(data) - else: - iterator = data.iteritems() - for index, value in iterator: - if not hasattr(value, '_get_changed_fields'): - continue - if (hasattr(field, 'field') and - isinstance(field.field, ReferenceField)): - continue - list_key = "%s%s." % (key, index) - changed = value._get_changed_fields(inspected) - _changed_fields += ["%s%s" % (list_key, k) - for k in changed if k] - return _changed_fields - def _delta(self): - """Returns the delta (set, unset) of the changes for a document. - Gets any values that have been explicitly changed. - """ - # Handles cases where not loaded from_son but has _id - doc = self.to_mongo() + if full or not self._created: + fields = self._fields.iteritems() + db_data = ((self._db_field_map.get(field_name, field_name), + get_db_value(field, getattr(self, field_name))) + for field_name, field in fields) - set_fields = self._get_changed_fields() - unset_data = {} - parts = [] - if hasattr(self, '_changed_fields'): - set_data = {} - # Fetch each set item from its path - for path in set_fields: - parts = path.split('.') - d = doc - new_path = [] - for p in parts: - if isinstance(d, (ObjectId, DBRef)): - break - elif isinstance(d, list) and p.isdigit(): - d = d[int(p)] - elif hasattr(d, 'get'): - d = d.get(p) - new_path.append(p) - path = '.'.join(new_path) - set_data[path] = d else: - set_data = doc - if '_id' in set_data: - del(set_data['_id']) + # List of (db_field_name, db_value) tuples. + db_data = [] - # Determine if any changed items were actually unset. - for path, value in set_data.items(): - if value or isinstance(value, (numbers.Number, bool)): - continue + for field_name in self._get_changed_fields(): + parts = field_name.split('.') - # If we've set a value that ain't the default value dont unset it. - default = None - if (self._dynamic and len(parts) and parts[0] in - self._dynamic_fields): - del(set_data[path]) - unset_data[path] = 1 - continue - elif path in self._fields: - default = self._fields[path].default - else: # Perform a full lookup for lists / embedded lookups - d = self - parts = path.split('.') - db_field_name = parts.pop() - for p in parts: - if isinstance(d, list) and p.isdigit(): - d = d[int(p)] - elif (hasattr(d, '__getattribute__') and - not isinstance(d, dict)): - real_path = d._reverse_db_field_map.get(p, p) - d = getattr(d, real_path) - else: - d = d.get(p) + db_field_parts = [] - if hasattr(d, '_fields'): - field_name = d._reverse_db_field_map.get(db_field_name, - db_field_name) - if field_name in d._fields: - default = d._fields.get(field_name).default - else: - default = None + value = self + for part in parts: + if isinstance(value, list) and part.isdigit(): + db_field_parts.append(part) + field = field.field + value = value[int(part)] + elif isinstance(value, dict): + db_field_parts.append(part) + field = field.field + value = value[part] + else: # It's a document + obj = value + field = obj._fields[part] + db_field_parts.append(obj._db_field_map.get(part, part)) + value = getattr(obj, part) - if default is not None: - if callable(default): - default = default() + db_data.append(('.'.join(db_field_parts), get_db_value(field, value))) - if default != value: - continue + for db_field_name, db_value in db_data: + if db_value == None: + unsets[db_field_name] = 1 + else: + sets[db_field_name] = db_value - del(set_data[path]) - unset_data[path] = 1 - return set_data, unset_data + return sets, unsets @classmethod def _get_collection_name(cls): @@ -528,61 +350,16 @@ class BaseDocument(object): return cls._meta.get('collection', None) @classmethod - def _from_son(cls, son, _auto_dereference=True): - """Create an instance of a Document (subclass) from a PyMongo SON. - """ - + def _from_son(cls, son, _auto_dereference=False): # get the class name from the document, falling back to the given # class if unavailable class_name = son.get('_cls', cls._class_name) - data = dict(("%s" % key, value) for key, value in son.iteritems()) - if not UNICODE_KWARGS: - # python 2.6.4 and lower cannot handle unicode keys - # passed to class constructor example: cls(**data) - to_str_keys_recursive(data) # Return correct subclass for document type if class_name != cls._class_name: cls = get_document(class_name) - changed_fields = [] - errors_dict = {} - - fields = cls._fields - if not _auto_dereference: - fields = copy.copy(fields) - - for field_name, field in fields.iteritems(): - field._auto_dereference = _auto_dereference - if field.db_field in data: - value = data[field.db_field] - try: - data[field_name] = (value if value is None - else field.to_python(value)) - if field_name != field.db_field: - del data[field.db_field] - except (AttributeError, ValueError), e: - errors_dict[field_name] = e - elif field.default: - default = field.default - if callable(default): - default = default() - if isinstance(default, BaseDocument): - changed_fields.append(field_name) - - if errors_dict: - errors = "\n".join(["%s - %s" % (k, v) - for k, v in errors_dict.items()]) - msg = ("Invalid data to create a `%s` instance.\n%s" - % (cls._class_name, errors)) - raise InvalidDocumentError(msg) - - obj = cls(__auto_convert=False, **data) - obj._changed_fields = changed_fields - obj._created = False - if not _auto_dereference: - obj._fields = fields - return obj + return cls(_son=son) @classmethod def _build_index_specs(cls, meta_indexes): @@ -629,10 +406,8 @@ class BaseDocument(object): # Check to see if we need to include _cls allow_inheritance = cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) - include_cls = (allow_inheritance and not spec.get('sparse', False) and - spec.get('cls', True)) - if "cls" in spec: - spec.pop('cls') + include_cls = allow_inheritance and not spec.get('sparse', False) + for key in spec['fields']: # If inherited spec continue if isinstance(key, (list, tuple)): @@ -762,7 +537,7 @@ class BaseDocument(object): for field_name in parts: # Handle ListField indexing: - if field_name.isdigit() and hasattr(field, 'field'): + if field_name.isdigit(): new_field = field.field fields.append(field_name) continue @@ -774,9 +549,9 @@ class BaseDocument(object): field_name = cls._meta['id_field'] if field_name in cls._fields: field = cls._fields[field_name] - elif cls._dynamic: - DynamicField = _import_class('DynamicField') - field = DynamicField(db_field=field_name) + #elif cls._dynamic: + # DynamicField = _import_class('DynamicField') + # field = DynamicField(db_field=field_name) else: raise LookUpError('Cannot resolve field "%s"' % field_name) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index c6abd02..168c063 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -59,15 +59,17 @@ class BaseField(object): :param help_text: (optional) The help text for this field and is often used when generating model forms from the document model. """ - self.db_field = (db_field or name) if not primary_key else '_id' - if name: - msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" - warnings.warn(msg, DeprecationWarning) + self.name = None # filled in by document + self.db_field = db_field self.required = required or primary_key self.default = default self.unique = bool(unique or unique_with) self.unique_with = unique_with self.primary_key = primary_key + if self.primary_key: + if self.db_field: + raise ValueError("Can't use primary_key in conjunction with db_field.") + self.db_field = '_id' self.validation = validation self.choices = choices self.verbose_name = verbose_name @@ -82,41 +84,52 @@ class BaseField(object): BaseField.creation_counter += 1 def __get__(self, instance, owner): - """Descriptor for retrieving a value from a field in a document. - """ if instance is None: # Document class being used rather than a document object return self + else: + name = self.name + data = instance._internal_data + if not name in data: + if instance._lazy and name != instance._meta['id_field']: + # We need to fetch the doc from the database. + instance.reload() + db_field = instance._db_field_map.get(name, name) + try: + db_value = instance._db_data[db_field] + except (TypeError, KeyError): + value = self.default() if callable(self.default) else self.default + else: + value = self.to_python(db_value) - # Get value from document instance if available - value = instance._data.get(self.name) + if hasattr(self, 'value_for_instance'): + value = self.value_for_instance(value, instance) + data[name] = value - EmbeddedDocument = _import_class('EmbeddedDocument') - if isinstance(value, EmbeddedDocument) and value._instance is None: - value._instance = weakref.proxy(instance) - return value + return data[name] def __set__(self, instance, value): """Descriptor for assigning a value to a field in a document. """ - # If setting to None and theres a default - # Then set the value to the default value - if value is None and self.default is not None: - value = self.default - if callable(value): - value = value() + if instance._lazy: + # Fetch the from the database before we assign to a lazy object. + instance.reload() - if instance._initialised: - try: - if (self.name not in instance._data or - instance._data[self.name] != value): - instance._mark_as_changed(self.name) - except: - # Values cant be compared eg: naive and tz datetimes - # So mark it as changed - instance._mark_as_changed(self.name) - instance._data[self.name] = value + name = self.name + + value = self.from_python(value) + if hasattr(self, 'value_for_instance'): + value = self.value_for_instance(value, instance) + try: + has_changed = name not in instance._internal_data or instance._internal_data[name] != value + except: # Values can't be compared eg: naive and tz datetimes + has_changed = True + + if has_changed: + instance._mark_as_changed(name) + + instance._internal_data[name] = value def error(self, message="", errors=None, field_name=None): """Raises a ValidationError. @@ -132,7 +145,15 @@ class BaseField(object): def to_mongo(self, value): """Convert a Python type to a MongoDB-compatible type. """ - return self.to_python(value) + return value + + def from_python(self, value): + """Convert a raw Python value (in an assignment) to the internal + Python representation. + """ + if value == None: + return self.default() if callable(self.default) else self.default + return value def prepare_query_value(self, op, value): """Prepare a value that is being used in a query for PyMongo. @@ -187,50 +208,6 @@ class ComplexBaseField(BaseField): field = None - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') - dereference = (self._auto_dereference and - (self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)))) - - _dereference = _import_class("DeReference")() - - self._auto_dereference = instance._fields[self.name]._auto_dereference - if instance._initialised and dereference: - instance._data[self.name] = _dereference( - instance._data.get(self.name), max_depth=1, instance=instance, - name=self.name - ) - - value = super(ComplexBaseField, self).__get__(instance, owner) - - # Convert lists / values so we can watch for any changes on them - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): - value = BaseList(value, instance, self.name) - instance._data[self.name] = value - elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, instance, self.name) - instance._data[self.name] = value - - if (self._auto_dereference and instance._initialised and - isinstance(value, (BaseList, BaseDict)) - and not value._dereferenced): - value = _dereference( - value, max_depth=1, instance=instance, name=self.name - ) - value._dereferenced = True - instance._data[self.name] = value - - return value - def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ @@ -389,12 +366,10 @@ class ObjectIdField(BaseField): """ def to_python(self, value): - if not isinstance(value, ObjectId): - value = ObjectId(value) return value def to_mongo(self, value): - if not isinstance(value, ObjectId): + if value and not isinstance(value, ObjectId): try: return ObjectId(unicode(value)) except Exception, e: diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index ff5afdd..34a8a51 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -91,12 +91,11 @@ class DocumentMetaclass(type): attrs['_fields'] = doc_fields attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) for k, v in doc_fields.iteritems()]) - attrs['_reverse_db_field_map'] = dict( - (v, k) for k, v in attrs['_db_field_map'].iteritems()) - attrs['_fields_ordered'] = tuple(i[1] for i in sorted( (v.creation_counter, v.name) for v in doc_fields.itervalues())) + attrs['_reverse_db_field_map'] = dict( + (v, k) for k, v in attrs['_db_field_map'].iteritems()) # # Set document hierarchy @@ -359,17 +358,15 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class.id = field # Set primary key if not defined by the document - new_class._auto_id_field = False if not new_class._meta.get('id_field'): - new_class._auto_id_field = True - new_class._meta['id_field'] = 'id' - new_class._fields['id'] = ObjectIdField(db_field='_id') - new_class._fields['id'].name = 'id' + id_field = ObjectIdField(primary_key=True) + id_field.name = 'id' + id_field._auto_gen = True + new_class._fields['id'] = id_field new_class.id = new_class._fields['id'] + new_class._meta['id_field'] = 'id' + new_class._db_field_map['id'] = id_field.db_field - # Prepend id field to _fields_ordered - if 'id' in new_class._fields and 'id' not in new_class._fields_ordered: - new_class._fields_ordered = ('id', ) + new_class._fields_ordered # Merge in exceptions with parent hierarchy exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) diff --git a/mongoengine/base/proxy.py b/mongoengine/base/proxy.py new file mode 100644 index 0000000..7d2879b --- /dev/null +++ b/mongoengine/base/proxy.py @@ -0,0 +1,190 @@ +from mongoengine.queryset import OperationError +from bson.dbref import DBRef + +class LocalProxy(object): + # From werkzeug/local.py + + """ Forwards all operations to + a proxied object. The only operations not supported for forwarding + are right handed operands and any kind of assignment. + """ + + __slots__ = ('__local', '__dict__', '__name__') + + def __init__(self, local, name=None): + object.__setattr__(self, '_LocalProxy__local', local) + object.__setattr__(self, '__name__', name) + + def _get_current_object(self): + """Return the current object. This is useful if you want the real + object behind the proxy at a time for performance reasons or because + you want to pass the object into a different context. + """ + if not hasattr(self.__local, '__release_local__'): + return self.__local() + try: + return getattr(self.__local, self.__name__) + except AttributeError: + raise RuntimeError('no object bound to %s' % self.__name__) + + @property + def __dict__(self): + try: + return self._get_current_object().__dict__ + except RuntimeError: + raise AttributeError('__dict__') + + def __repr__(self): + try: + obj = self._get_current_object() + except RuntimeError: + return '<%s unbound>' % self.__class__.__name__ + return repr(obj) + + def __nonzero__(self): + try: + return bool(self._get_current_object()) + except RuntimeError: + return False + + def __unicode__(self): + try: + return unicode(self._get_current_object()) + except RuntimeError: + return repr(self) + + def __dir__(self): + try: + return dir(self._get_current_object()) + except RuntimeError: + return [] + + def __getattr__(self, name): + if name == '__members__': + return dir(self._get_current_object()) + return getattr(self._get_current_object(), name) + + def __setitem__(self, key, value): + self._get_current_object()[key] = value + + def __delitem__(self, key): + del self._get_current_object()[key] + + def __setslice__(self, i, j, seq): + self._get_current_object()[i:j] = seq + + def __delslice__(self, i, j): + del self._get_current_object()[i:j] + + __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v) + __delattr__ = lambda x, n: delattr(x._get_current_object(), n) + __str__ = lambda x: str(x._get_current_object()) + __lt__ = lambda x, o: x._get_current_object() < o + __le__ = lambda x, o: x._get_current_object() <= o + __eq__ = lambda x, o: x._get_current_object() == o + __ne__ = lambda x, o: x._get_current_object() != o + __gt__ = lambda x, o: x._get_current_object() > o + __ge__ = lambda x, o: x._get_current_object() >= o + __cmp__ = lambda x, o: cmp(x._get_current_object(), o) + __hash__ = lambda x: hash(x._get_current_object()) + __call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw) + __len__ = lambda x: len(x._get_current_object()) + __getitem__ = lambda x, i: x._get_current_object()[i] + __iter__ = lambda x: iter(x._get_current_object()) + __contains__ = lambda x, i: i in x._get_current_object() + __getslice__ = lambda x, i, j: x._get_current_object()[i:j] + __add__ = lambda x, o: x._get_current_object() + o + __sub__ = lambda x, o: x._get_current_object() - o + __mul__ = lambda x, o: x._get_current_object() * o + __floordiv__ = lambda x, o: x._get_current_object() // o + __mod__ = lambda x, o: x._get_current_object() % o + __divmod__ = lambda x, o: x._get_current_object().__divmod__(o) + __pow__ = lambda x, o: x._get_current_object() ** o + __lshift__ = lambda x, o: x._get_current_object() << o + __rshift__ = lambda x, o: x._get_current_object() >> o + __and__ = lambda x, o: x._get_current_object() & o + __xor__ = lambda x, o: x._get_current_object() ^ o + __or__ = lambda x, o: x._get_current_object() | o + __div__ = lambda x, o: x._get_current_object().__div__(o) + __truediv__ = lambda x, o: x._get_current_object().__truediv__(o) + __neg__ = lambda x: -(x._get_current_object()) + __pos__ = lambda x: +(x._get_current_object()) + __abs__ = lambda x: abs(x._get_current_object()) + __invert__ = lambda x: ~(x._get_current_object()) + __complex__ = lambda x: complex(x._get_current_object()) + __int__ = lambda x: int(x._get_current_object()) + __long__ = lambda x: long(x._get_current_object()) + __float__ = lambda x: float(x._get_current_object()) + __oct__ = lambda x: oct(x._get_current_object()) + __hex__ = lambda x: hex(x._get_current_object()) + __index__ = lambda x: x._get_current_object().__index__() + __coerce__ = lambda x, o: x.__coerce__(x, o) + __enter__ = lambda x: x.__enter__() + __exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw) + + +class DocumentProxy(LocalProxy): + __slots__ = ('__document_type', '__document', '__pk') + + def __init__(self, document_type, pk): + object.__setattr__(self, '_DocumentProxy__document_type', document_type) + object.__setattr__(self, '_DocumentProxy__document', None) + object.__setattr__(self, '_DocumentProxy__pk', pk) + object.__setattr__(self, document_type._meta['id_field'], self.pk) + + @property + def __class__(self): + # We need to fetch the object to determine to which class it belongs. + return self._get_current_object().__class__ + + def _lazy(): + def fget(self): + return self.__document._lazy if self.__document else True + def fset(self, value): + self._get_current_object()._lazy = value + return property(fget, fset) + _lazy = _lazy() + + # copy normally updates __dict__ which would result in errors + def __setstate__(self, state): + for k, v in state[1].iteritems(): + object.__setattr__(self, k, v) + + def _get_collection_name(self): + return self.__document_type._meta.get('collection', None) + + def __eq__(self, other): + if other and hasattr(other, '_get_collection_name') and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk'): + if self.pk == other.pk: + return True + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def to_dbref(self): + """Returns an instance of :class:`~bson.dbref.DBRef` useful in + `__raw__` queries.""" + if not self.pk: + msg = "Only saved documents can have a valid dbref" + raise OperationError(msg) + return DBRef(self._get_collection_name(), self.pk) + + def pk(): + def fget(self): + return self.__document.pk if self.__document else self.__pk + def fset(self, value): + self._get_current_object().pk = value + return property(fget, fset) + pk = pk() + + def _get_current_object(self): + if self.__document == None: + collection = self.__document_type._get_collection() + son = collection.find_one({'_id': self.__pk}) + document = self.__document_type._from_son(son) + object.__setattr__(self, '_DocumentProxy__document', document) + return self.__document + + def __nonzero__(self): + return True diff --git a/mongoengine/common.py b/mongoengine/common.py index 6303231..20d5138 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -23,9 +23,8 @@ def _import_class(cls_name): field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', 'FileField', 'GenericReferenceField', 'GenericEmbeddedDocumentField', 'GeoPointField', - 'PointField', 'LineStringField', 'ListField', - 'PolygonField', 'ReferenceField', 'StringField', - 'ComplexBaseField') + 'PointField', 'LineStringField', 'PolygonField', + 'ReferenceField', 'StringField', 'ComplexBaseField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 4275da5..abab269 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -55,9 +55,12 @@ def register_connection(alias, name, host='localhost', port=27017, # Handle uri style connections if "://" in host: uri_dict = uri_parser.parse_uri(host) + if uri_dict.get('database') is None: + raise ConnectionError("If using URI style connection include "\ + "database name in string") conn_settings.update({ 'host': host, - 'name': uri_dict.get('database') or name, + 'name': uri_dict.get('database'), 'username': uri_dict.get('username'), 'password': uri_dict.get('password'), 'read_preference': read_preference, diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index ceda403..79f755f 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -4,7 +4,7 @@ from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document) from fields import (ReferenceField, ListField, DictField, MapField) from connection import get_db from queryset import QuerySet -from document import Document, EmbeddedDocument +from document import Document class DeReference(object): @@ -33,8 +33,7 @@ class DeReference(object): self.max_depth = max_depth doc_type = None - if instance and isinstance(instance, (Document, EmbeddedDocument, - TopLevelDocumentMetaclass)): + if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)): doc_type = instance._fields.get(name) if hasattr(doc_type, 'field'): doc_type = doc_type.field @@ -87,9 +86,11 @@ class DeReference(object): for k, item in iterator: if isinstance(item, Document): for field_name, field in item._fields.iteritems(): - v = item._data.get(field_name, None) + v = getattr(item, field_name) if isinstance(v, (DBRef)): reference_map.setdefault(field.document_type, []).append(v.id) + elif isinstance(v, Document) and getattr(v, '_lazy', False): + reference_map.setdefault(field.document_type, []).append(v.pk) elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: @@ -170,7 +171,7 @@ class DeReference(object): return self.object_map.get(items['_ref'].id, items) elif '_cls' in items: doc = get_document(items['_cls'])._from_son(items) - doc._data = self._attach_objects(doc._data, depth, doc, None) + doc._internal_data = self._attach_objects(doc._internal_data, depth, doc, None) return doc if not hasattr(items, 'items'): @@ -194,15 +195,17 @@ class DeReference(object): data[k] = self.object_map[k] elif isinstance(v, Document): for field_name, field in v._fields.iteritems(): - v = data[k]._data.get(field_name, None) + v = data[k]._internal_data.get(field_name, None) if isinstance(v, (DBRef)): - data[k]._data[field_name] = self.object_map.get(v.id, v) + data[k]._internal_data[field_name] = self.object_map.get(v.id, v) + elif isinstance(v, Document) and getattr(v, '_lazy', False): + data[k]._internal_data[field_name] = self.object_map.get(v.pk, v) elif isinstance(v, (dict, SON)) and '_ref' in v: - data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) + data[k]._internal_data[field_name] = self.object_map.get(v['_ref'].id, v) elif isinstance(v, dict) and depth <= self.max_depth: - data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) + data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (list, tuple)) and depth <= self.max_depth: - data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) + data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) elif hasattr(v, 'id'): diff --git a/mongoengine/django/mongo_auth/models.py b/mongoengine/django/mongo_auth/models.py index d4947a2..3529d8e 100644 --- a/mongoengine/django/mongo_auth/models.py +++ b/mongoengine/django/mongo_auth/models.py @@ -6,29 +6,10 @@ from django.utils.importlib import import_module from django.utils.translation import ugettext_lazy as _ -__all__ = ( - 'get_user_document', -) - - MONGOENGINE_USER_DOCUMENT = getattr( settings, 'MONGOENGINE_USER_DOCUMENT', 'mongoengine.django.auth.User') -def get_user_document(): - """Get the user document class used for authentication. - - This is the class defined in settings.MONGOENGINE_USER_DOCUMENT, which - defaults to `mongoengine.django.auth.User`. - - """ - - name = MONGOENGINE_USER_DOCUMENT - dot = name.rindex('.') - module = import_module(name[:dot]) - return getattr(module, name[dot + 1:]) - - class MongoUserManager(UserManager): """A User manager wich allows the use of MongoEngine documents in Django. @@ -63,7 +44,7 @@ class MongoUserManager(UserManager): def contribute_to_class(self, model, name): super(MongoUserManager, self).contribute_to_class(model, name) self.dj_model = self.model - self.model = get_user_document() + self.model = self._get_user_document() self.dj_model.USERNAME_FIELD = self.model.USERNAME_FIELD username = models.CharField(_('username'), max_length=30, unique=True) @@ -74,6 +55,16 @@ class MongoUserManager(UserManager): field = models.CharField(_(name), max_length=30) field.contribute_to_class(self.dj_model, name) + def _get_user_document(self): + try: + name = MONGOENGINE_USER_DOCUMENT + dot = name.rindex('.') + module = import_module(name[:dot]) + return getattr(module, name[dot + 1:]) + except ImportError: + raise ImproperlyConfigured("Error importing %s, please check " + "settings.MONGOENGINE_USER_DOCUMENT" + % name) def get(self, *args, **kwargs): try: @@ -94,14 +85,5 @@ class MongoUserManager(UserManager): class MongoUser(models.Model): - """"Dummy user model for Django. - - MongoUser is used to replace Django's UserManager with MongoUserManager. - The actual user document class is mongoengine.django.auth.User or any - other document class specified in MONGOENGINE_USER_DOCUMENT. - - To get the user document class, use `get_user_document()`. - - """ - objects = MongoUserManager() + diff --git a/mongoengine/django/sessions.py b/mongoengine/django/sessions.py index 7e4e182..c90807e 100644 --- a/mongoengine/django/sessions.py +++ b/mongoengine/django/sessions.py @@ -1,10 +1,7 @@ from django.conf import settings from django.contrib.sessions.backends.base import SessionBase, CreateError from django.core.exceptions import SuspiciousOperation -try: - from django.utils.encoding import force_unicode -except ImportError: - from django.utils.encoding import force_text as force_unicode +from django.utils.encoding import force_unicode from mongoengine.document import Document from mongoengine import fields diff --git a/mongoengine/django/storage.py b/mongoengine/django/storage.py index 9df6f9e..341455c 100644 --- a/mongoengine/django/storage.py +++ b/mongoengine/django/storage.py @@ -76,7 +76,7 @@ class GridFSStorage(Storage): """Find the documents in the store with the given name """ docs = self.document.objects - doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name] + doc = [d for d in docs if getattr(d, self.field).name == name] if doc: return doc[0] else: diff --git a/mongoengine/document.py b/mongoengine/document.py index 1bbd7b7..b553097 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -12,7 +12,7 @@ from mongoengine.common import _import_class from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) -from mongoengine.queryset import OperationError, NotUniqueError, QuerySet +from mongoengine.queryset import OperationError, NotUniqueError, QuerySet, DoesNotExist from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.context_managers import switch_db, switch_collection @@ -20,6 +20,7 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') +_set = object.__setattr__ def includes_cls(fields): """ Helper function used for ensuring and comparing indexes @@ -62,11 +63,11 @@ class EmbeddedDocument(BaseDocument): def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) - self._changed_fields = [] + self._changed_fields = set() def __eq__(self, other): if isinstance(other, self.__class__): - return self._data == other._data + return self.to_dict() == other.to_dict() return False def __ne__(self, other): @@ -177,15 +178,13 @@ class Document(BaseDocument): cls.ensure_indexes() return cls._collection - def save(self, force_insert=False, validate=True, clean=True, + def save(self, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, **kwargs): + _refs=None, full=False, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. - :param force_insert: only try to create a new document, don't allow - updates of existing documents :param validate: validates the document; set to ``False`` to skip. :param clean: call the document clean method, requires `validate` to be True. @@ -202,6 +201,7 @@ class Document(BaseDocument): :param cascade_kwargs: (optional) kwargs dictionary to be passed throw to cascading saves. Implies ``cascade=True``. :param _refs: A list of processed references used in cascading saves + :param full: Save all model fields instead of just changed ones. .. versionchanged:: 0.5 In existing documents it only saves changed fields using @@ -217,61 +217,52 @@ class Document(BaseDocument): the cascade save using cascade_kwargs which overwrites the existing kwargs with custom values. """ + signals.pre_save.send(self.__class__, document=self) if validate: self.validate(clean=clean) - if write_concern is None: - write_concern = {"w": 1} - - doc = self.to_mongo() - - created = ('_id' not in doc or self._created or force_insert) - - signals.pre_save_post_validation.send(self.__class__, document=self, created=created) + if not write_concern: + write_concern = {'w': 1} + collection = self._get_collection() try: - collection = self._get_collection() - if created: - if force_insert: - object_id = collection.insert(doc, **write_concern) - else: - object_id = collection.save(doc, **write_concern) - else: - object_id = doc['_id'] - updates, removals = self._delta() - # Need to add shard key to query, or you get an error - select_dict = {'_id': object_id} - shard_key = self.__class__._meta.get('shard_key', tuple()) - for k in shard_key: - actual_key = self._db_field_map.get(k, k) - select_dict[actual_key] = doc[actual_key] - - def is_new_object(last_error): - if last_error is not None: - updated = last_error.get("updatedExisting") - if updated is not None: - return not updated - return created + if self._created: + # Update: Get delta. + sets, unsets = self._delta(full) + db_id_field = self._fields[self._meta['id_field']].db_field + sets.pop(db_id_field, None) update_query = {} + if sets: + update_query['$set'] = sets + if unsets: + update_query['$unset'] = unsets - if updates: - update_query["$set"] = updates - if removals: - update_query["$unset"] = removals - if updates or removals: - last_error = collection.update(select_dict, update_query, - upsert=True, **write_concern) - created = is_new_object(last_error) + if update_query: + collection.update(self._db_object_key, update_query, **write_concern) - if cascade is None: - cascade = self._meta.get('cascade', False) or cascade_kwargs is not None + created = False + else: + # Insert: Get full SON. + doc = self.to_mongo() + object_id = collection.insert(doc, **write_concern) + # Fix pymongo's "return return_one and ids[0] or ids": + # If the ID is 0, pymongo wraps it in a list. + if isinstance(object_id, list) and not object_id[0]: + object_id = object_id[0] + id_field = self._meta['id_field'] + del self._internal_data[id_field] + _set(self, '_db_data', doc) + doc['_id'] = object_id + + created = True + cascade = (self._meta.get('cascade', False) + if cascade is None else cascade) if cascade: kwargs = { - "force_insert": force_insert, "validate": validate, "write_concern": write_concern, "cascade": cascade @@ -289,12 +280,9 @@ class Document(BaseDocument): message = u'Tried to save duplicate unique keys (%s)' raise NotUniqueError(message % unicode(err)) raise OperationError(message % unicode(err)) - id_field = self._meta['id_field'] - if id_field not in self._meta.get('shard_key', []): - self[id_field] = self._fields[id_field].to_python(object_id) self._clear_changed_fields() - self._created = False + signals.post_save.send(self.__class__, document=self, created=created) return self @@ -311,14 +299,17 @@ class Document(BaseDocument): GenericReferenceField)): continue - ref = self._data.get(name) + ref = getattr(self, name) if not ref or isinstance(ref, DBRef): continue if not getattr(ref, '_changed_fields', True): continue - ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data)) + if getattr(ref, '_lazy', False): + continue + + ref_id = "%s,%s" % (ref.__class__.__name__, str(ref.to_dict())) if ref and ref_id not in _refs: _refs.append(ref_id) kwargs["_refs"] = _refs @@ -344,6 +335,16 @@ class Document(BaseDocument): select_dict[k] = getattr(self, k) return select_dict + @property + def _db_object_key(self): + field = self._fields[self._meta['id_field']] + select_dict = {field.db_field: field.to_mongo(self.pk)} + shard_key = self.__class__._meta.get('shard_key', tuple()) + for k in shard_key: + actual_key = self._db_field_map.get(k, k) + select_dict[actual_key] = self._fields[k].to_mongo(getattr(self, k)) + return select_dict + def update(self, **kwargs): """Performs an update on the :class:`~mongoengine.Document` A convenience wrapper to :meth:`~mongoengine.QuerySet.update`. @@ -376,6 +377,9 @@ class Document(BaseDocument): """ signals.pre_delete.send(self.__class__, document=self) + if not write_concern: + write_concern = {'w': 1} + try: self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) except pymongo.errors.OperationFailure, err: @@ -400,11 +404,11 @@ class Document(BaseDocument): """ with switch_db(self.__class__, db_alias) as cls: collection = cls._get_collection() - db = cls._get_db() + db = cls._get_db self._get_collection = lambda: collection self._get_db = lambda: db self._collection = collection - self._created = True + #self._created = True self.__objects = self._qs self.__objects._collection_obj = collection return self @@ -429,7 +433,7 @@ class Document(BaseDocument): collection = cls._get_collection() self._get_collection = lambda: collection self._collection = collection - self._created = True + #self._created = True self.__objects = self._qs self.__objects._collection_obj = collection return self @@ -440,44 +444,22 @@ class Document(BaseDocument): .. versionadded:: 0.5 """ - DeReference = _import_class('DeReference') - DeReference()([self], max_depth + 1) + import dereference + self._internal_data = dereference.DeReference()(self._internal_data, max_depth) return self - def reload(self, max_depth=1): + def reload(self): """Reloads all attributes from the database. - - .. versionadded:: 0.1.2 - .. versionchanged:: 0.6 Now chainable """ - obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( - **self._object_key).limit(1).select_related(max_depth=max_depth) - - if obj: - obj = obj[0] - else: - msg = "Reloaded document has been deleted" - raise OperationError(msg) - for field in self._fields_ordered: - setattr(self, field, self._reload(field, obj[field])) - self._changed_fields = obj._changed_fields - self._created = False - return obj - - def _reload(self, key, value): - """Used by :meth:`~mongoengine.Document.reload` to ensure the - correct instance is linked to self. - """ - if isinstance(value, BaseDict): - value = [(k, self._reload(k, v)) for k, v in value.items()] - value = BaseDict(value, self, key) - elif isinstance(value, BaseList): - value = [self._reload(key, v) for v in value] - value = BaseList(value, self, key) - elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): - value._instance = None - value._changed_fields = [] - return value + collection = self._get_collection() + son = collection.find_one(self._db_object_key, read_preference=ReadPreference.PRIMARY) + if son == None: + raise DoesNotExist('Document has been deleted.') + _set(self, '_db_data', son) + _set(self, '_internal_data', {}) + _set(self, '_lazy', False) + self._clear_changed_fields() + return self def to_dbref(self): """Returns an instance of :class:`~bson.dbref.DBRef` useful in @@ -536,8 +518,6 @@ class Document(BaseDocument): def ensure_indexes(cls): """Checks the document meta data and ensures all the indexes exist. - Global defaults can be set in the meta - see :doc:`guide/defining-documents` - .. note:: You can disable automatic index creation by setting `auto_create_index` to False in the documents meta data """ @@ -679,6 +659,8 @@ class DynamicDocument(Document): _dynamic = True + # TODO + def __delattr__(self, *args, **kwargs): """Deletes the attribute by setting to None and allowing _delta to unset it""" @@ -702,6 +684,8 @@ class DynamicEmbeddedDocument(EmbeddedDocument): _dynamic = True + # TODO + def __delattr__(self, *args, **kwargs): """Deletes the attribute by setting to None and allowing _delta to unset it""" diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 419f2ef..57c4154 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -3,6 +3,7 @@ import decimal import itertools import re import time +import types import urllib2 import uuid import warnings @@ -22,8 +23,11 @@ from bson import Binary, DBRef, SON, ObjectId from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) -from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, +from mongoengine.base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, get_document, BaseDocument) +from mongoengine.base.datastructures import BaseList, BaseDict +from mongoengine.base.proxy import DocumentProxy +from mongoengine.queryset import DoesNotExist from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument from connection import get_db, DEFAULT_CONNECTION_NAME @@ -34,11 +38,12 @@ except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', - 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', +__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', + 'FloatField', 'BooleanField', 'DateTimeField', 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', 'SortedListField', 'DictField', 'MapField', 'ReferenceField', + 'SafeReferenceField', 'SafeReferenceListField', 'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', @@ -58,15 +63,6 @@ class StringField(BaseField): self.min_length = min_length super(StringField, self).__init__(**kwargs) - def to_python(self, value): - if isinstance(value, unicode): - return value - try: - value = value.decode('utf-8') - except: - pass - return value - def validate(self, value): if not isinstance(value, basestring): self.error('StringField only accepts string values') @@ -121,8 +117,7 @@ class URLField(StringField): r'(?::\d+)?' # optional port r'(?:/?|[/?]\S+)$', re.IGNORECASE) - def __init__(self, verify_exists=False, url_regex=None, **kwargs): - self.verify_exists = verify_exists + def __init__(self, url_regex=None, **kwargs): self.url_regex = url_regex or self._URL_REGEX super(URLField, self).__init__(**kwargs) @@ -131,50 +126,31 @@ class URLField(StringField): self.error('Invalid URL: %s' % value) return - if self.verify_exists: - warnings.warn( - "The URLField verify_exists argument has intractable security " - "and performance issues. Accordingly, it has been deprecated.", - DeprecationWarning) - try: - request = urllib2.Request(value) - urllib2.urlopen(request) - except Exception, e: - self.error('This URL appears to be a broken link: %s' % e) - class EmailField(StringField): - """A field that validates input as an E-Mail-Address. + """A field that validates input as an email address. .. versionadded:: 0.4 """ - EMAIL_REGEX = re.compile( - r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain - ) + EMAIL_REGEX = re.compile(r'^.+@[^.].*\.[a-z]{2,10}$', re.IGNORECASE) def validate(self, value): if not EmailField.EMAIL_REGEX.match(value): - self.error('Invalid Mail-address: %s' % value) + self.error('Invalid email address: %s' % value) super(EmailField, self).validate(value) class IntField(BaseField): - """An 32-bit integer field. + """An integer field. """ def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value super(IntField, self).__init__(**kwargs) - def to_python(self, value): - try: - value = int(value) - except ValueError: - pass - return value + def from_python(self, value): + return self.prepare_query_value(None, value) def validate(self, value): try: @@ -191,59 +167,18 @@ class IntField(BaseField): def prepare_query_value(self, op, value): if value is None: return value - - return int(value) - - -class LongField(BaseField): - """An 64-bit integer field. - """ - - def __init__(self, min_value=None, max_value=None, **kwargs): - self.min_value, self.max_value = min_value, max_value - super(LongField, self).__init__(**kwargs) - - def to_python(self, value): - try: - value = long(value) - except ValueError: - pass - return value - - def validate(self, value): - try: - value = long(value) - except: - self.error('%s could not be converted to long' % value) - - if self.min_value is not None and value < self.min_value: - self.error('Long value is too small') - - if self.max_value is not None and value > self.max_value: - self.error('Long value is too large') - - def prepare_query_value(self, op, value): - if value is None: - return value - - return long(value) + else: + return int(value) class FloatField(BaseField): - """An floating point number field. + """A floating point number field. """ def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value super(FloatField, self).__init__(**kwargs) - def to_python(self, value): - try: - value = float(value) - except ValueError: - pass - return value - def validate(self, value): if isinstance(value, int): value = float(value) @@ -256,82 +191,6 @@ class FloatField(BaseField): if self.max_value is not None and value > self.max_value: self.error('Float value is too large') - def prepare_query_value(self, op, value): - if value is None: - return value - - return float(value) - - -class DecimalField(BaseField): - """A fixed-point decimal number field. - - .. versionchanged:: 0.8 - .. versionadded:: 0.3 - """ - - def __init__(self, min_value=None, max_value=None, force_string=False, - precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs): - """ - :param min_value: Validation rule for the minimum acceptable value. - :param max_value: Validation rule for the maximum acceptable value. - :param force_string: Store as a string. - :param precision: Number of decimal places to store. - :param rounding: The rounding rule from the python decimal libary: - - - decimal.ROUND_CEILING (towards Infinity) - - decimal.ROUND_DOWN (towards zero) - - decimal.ROUND_FLOOR (towards -Infinity) - - decimal.ROUND_HALF_DOWN (to nearest with ties going towards zero) - - decimal.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer) - - decimal.ROUND_HALF_UP (to nearest with ties going away from zero) - - decimal.ROUND_UP (away from zero) - - decimal.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero) - - Defaults to: ``decimal.ROUND_HALF_UP`` - - """ - self.min_value = min_value - self.max_value = max_value - self.force_string = force_string - self.precision = decimal.Decimal(".%s" % ("0" * precision)) - self.rounding = rounding - - super(DecimalField, self).__init__(**kwargs) - - def to_python(self, value): - if value is None: - return value - - # Convert to string for python 2.6 before casting to Decimal - value = decimal.Decimal("%s" % value) - return value.quantize(self.precision, rounding=self.rounding) - - def to_mongo(self, value): - if value is None: - return value - if self.force_string: - return unicode(value) - return float(self.to_python(value)) - - def validate(self, value): - if not isinstance(value, decimal.Decimal): - if not isinstance(value, basestring): - value = unicode(value) - try: - value = decimal.Decimal(value) - except Exception, exc: - self.error('Could not convert value to decimal: %s' % exc) - - if self.min_value is not None and value < self.min_value: - self.error('Decimal value is too small') - - if self.max_value is not None and value > self.max_value: - self.error('Decimal value is too large') - - def prepare_query_value(self, op, value): - return self.to_mongo(value) - class BooleanField(BaseField): """A boolean field type. @@ -339,13 +198,6 @@ class BooleanField(BaseField): .. versionadded:: 0.1.2 """ - def to_python(self, value): - try: - value = bool(value) - except ValueError: - pass - return value - def validate(self, value): if not isinstance(value, bool): self.error('BooleanField only accepts boolean values') @@ -366,11 +218,13 @@ class DateTimeField(BaseField): """ def validate(self, value): - new_value = self.to_mongo(value) - if not isinstance(new_value, (datetime.datetime, datetime.date)): + if not isinstance(value, (datetime.datetime, datetime.date)): self.error(u'cannot parse date "%s"' % value) - def to_mongo(self, value): + def from_python(self, value): + return self.prepare_query_value(None, value) or value + + def prepare_query_value(self, op, value): if value is None: return value if isinstance(value, datetime.datetime): @@ -414,9 +268,6 @@ class DateTimeField(BaseField): except ValueError: return None - def prepare_query_value(self, op, value): - return self.to_mongo(value) - class ComplexDateTimeField(StringField): """ @@ -437,6 +288,8 @@ class ComplexDateTimeField(StringField): .. versionadded:: 0.5 """ + # TODO + def __init__(self, separator=',', **kwargs): self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond'] @@ -542,15 +395,11 @@ class EmbeddedDocumentField(BaseField): self.document_type_obj = get_document(self.document_type_obj) return self.document_type_obj - def to_python(self, value): - if not isinstance(value, self.document_type): - return self.document_type._from_son(value) - return value + def to_python(self, val): + return self.document_type._from_son(val) - def to_mongo(self, value): - if not isinstance(value, self.document_type): - return value - return self.document_type.to_mongo(value) + def to_mongo(self, val): + return val and val.to_mongo() def validate(self, value, clean=True): """Make sure that the document instance is an instance of the @@ -584,9 +433,8 @@ class GenericEmbeddedDocumentField(BaseField): return self.to_mongo(value) def to_python(self, value): - if isinstance(value, dict): - doc_cls = get_document(value['_cls']) - value = doc_cls._from_son(value) + doc_cls = get_document(value['_cls']) + value = doc_cls._from_son(value) return value @@ -624,9 +472,7 @@ class DynamicField(BaseField): cls = value.__class__ val = value.to_mongo() # If we its a document thats not inherited add _cls - if (isinstance(value, Document)): - val = {"_ref": value.to_dbref(), "_cls": cls.__name__} - if (isinstance(value, EmbeddedDocument)): + if (isinstance(value, (Document, EmbeddedDocument))): val['_cls'] = cls.__name__ return val @@ -647,15 +493,6 @@ class DynamicField(BaseField): value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] return value - def to_python(self, value): - if isinstance(value, dict) and '_cls' in value: - doc_cls = get_document(value['_cls']) - if '_ref' in value: - value = doc_cls._get_db().dereference(value['_ref']) - return doc_cls._from_son(value) - - return super(DynamicField, self).to_python(value) - def lookup_member(self, member_name): return member_name @@ -685,6 +522,26 @@ class ListField(ComplexBaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) + def value_for_instance(self, value, instance, name=None): + name = name or self.name + if value and self.field: + value_for_instance = getattr(self.field, 'value_for_instance', None) + if value_for_instance: + value = [value_for_instance(v, instance, name) for v in value] + return BaseList(value or [], instance, name) + + def from_python(self, val): + from_python = getattr(self.field, 'from_python', None) + return [from_python(v) for v in val] if from_python else val + + def to_python(self, val): + to_python = getattr(self.field, 'to_python', None) + return [to_python(v) for v in val] if to_python and val else val or None + + def to_mongo(self, val): + to_mongo = getattr(self.field, 'to_mongo', None) + return [to_mongo(v) for v in val] if to_mongo and val else val or None + def validate(self, value): """Make sure that a list of valid fields is being used. """ @@ -700,6 +557,9 @@ class ListField(ComplexBaseField): and hasattr(value, '__iter__')): return [self.field.prepare_query_value(op, v) for v in value] return self.field.prepare_query_value(op, value) + else: + if op in ('set', 'unset'): + return value return super(ListField, self).prepare_query_value(op, value) @@ -730,10 +590,13 @@ class SortedListField(ListField): def to_mongo(self, value): value = super(SortedListField, self).to_mongo(value) - if self._ordering is not None: - return sorted(value, key=itemgetter(self._ordering), - reverse=self._order_reverse) - return sorted(value, reverse=self._order_reverse) + if value: + if self._ordering is not None: + return sorted(value, key=itemgetter(self._ordering), + reverse=self._order_reverse) + return sorted(value, reverse=self._order_reverse) + else: + return value class DictField(ComplexBaseField): @@ -755,6 +618,26 @@ class DictField(ComplexBaseField): kwargs.setdefault('default', lambda: {}) super(DictField, self).__init__(*args, **kwargs) + def from_python(self, val): + from_python = getattr(self.field, 'from_python', None) + return {k: from_python(v) for k, v in val.iteritems()} if from_python else val + + def to_python(self, val): + to_python = getattr(self.field, 'to_python', None) + return {k: to_python(v) for k, v in val.iteritems()} if to_python and val else val or None + + def value_for_instance(self, value, instance, name=None): + name = name or self.name + if value and self.field: + value_for_instance = getattr(self.field, 'value_for_instance', None) + if value_for_instance: + value = {k: value_for_instance(v, instance, name) for k, v in value.iteritems()} + return BaseDict(value or {}, instance, name) + + def to_mongo(self, val): + to_mongo = getattr(self.field, 'to_mongo', None) + return {k: to_mongo(v) for k, v in val.iteritems()} if to_mongo and val else val or None + def validate(self, value): """Make sure that a list of valid fields is being used. """ @@ -780,10 +663,6 @@ class DictField(ComplexBaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - - if hasattr(self.field, 'field'): - return self.field.prepare_query_value(op, value) - return super(DictField, self).prepare_query_value(op, value) @@ -866,69 +745,82 @@ class ReferenceField(BaseField): self.document_type_obj = get_document(self.document_type_obj) return self.document_type_obj - def __get__(self, instance, owner): - """Descriptor to allow lazy dereferencing. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value = instance._data.get(self.name) - self._auto_dereference = instance._fields[self.name]._auto_dereference - # Dereference DBRefs - if self._auto_dereference and isinstance(value, DBRef): - value = self.document_type._get_db().dereference(value) - if value is not None: - instance._data[self.name] = self.document_type._from_son(value) - - return super(ReferenceField, self).__get__(instance, owner) - - def to_mongo(self, document): - if isinstance(document, DBRef): - if not self.dbref: - return document.id - return document - - id_field_name = self.document_type._meta['id_field'] - id_field = self.document_type._fields[id_field_name] - - if isinstance(document, Document): + def to_mongo(self, value): + if isinstance(value, DBRef): + if self.dbref: + return value + else: + return value.id + elif isinstance(value, (Document, DocumentProxy)): + document_type = self.document_type # We need the id from the saved object to create the DBRef - id_ = document.pk - if id_ is None: + pk = value.pk + if pk is None: self.error('You can only reference documents once they have' ' been saved to the database') - else: - id_ = document - - id_ = id_field.to_mongo(id_) - if self.dbref: - collection = self.document_type._get_collection_name() - return DBRef(collection, id_) - - return id_ + id_field_name = document_type._meta['id_field'] + id_field = document_type._fields[id_field_name] + pk = id_field.to_mongo(pk) + if self.dbref: + collection = document_type._get_collection_name() + return DBRef(collection, pk) + else: + return pk + elif value != None: # string ID + document_type = self.document_type + collection = document_type._get_collection_name() + return DBRef(collection, value) def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ - if (not self.dbref and - not isinstance(value, (DBRef, Document, EmbeddedDocument))): - collection = self.document_type._get_collection_name() - value = DBRef(collection, self.document_type.id.to_python(value)) - return value + if value != None: + document_type = self.document_type + if self.dbref: + pk = value.id + else: + if isinstance(value, DBRef): + pk = value.id + else: + pk = value + if document_type._meta['allow_inheritance']: + # We don't know of which type the object will be. + obj = DocumentProxy(document_type, pk) + else: + obj = document_type(pk=pk) + obj._lazy = True + return obj + + def from_python(self, value): + if isinstance(value, (BaseDocument, DocumentProxy)): + return value + elif value == None: + return super(ReferenceField, self).from_python(value) + else: + # Support for werkzeug.local.LocalProxy + if hasattr(value, '_get_current_object'): + return value._get_current_object() + else: + # DBRef or ID + document_type = self.document_type + if isinstance(value, DBRef): + pk = value.id + else: + pk = value + if document_type._meta['allow_inheritance']: + # We don't know of which type the object will be. + obj = DocumentProxy(document_type, pk) + else: + obj = document_type(pk=pk) + obj._lazy = True + return obj def prepare_query_value(self, op, value): - if value is None: - return None - return self.to_mongo(value) + return self.to_mongo(self.from_python(value)) def validate(self, value): - - if not isinstance(value, (self.document_type, DBRef)): + if not isinstance(value, (self.document_type, DBRef, DocumentProxy)): self.error("A ReferenceField only accepts DBRef or documents") - if isinstance(value, Document) and value.id is None: + if isinstance(value, Document) and value.pk is None: self.error('You can only reference documents once they have been ' 'saved to the database') @@ -936,6 +828,52 @@ class ReferenceField(BaseField): return self.document_type._fields.get(member_name) +class SafeReferenceField(ReferenceField): + """ + Like a ReferenceField, but doesn't return non-existing references when + dereferencing, i.e. no DBRefs are returned. This means that the next time + an object is saved, the non-existing references are removed and application + code can rely on having only valid dereferenced objects. + + When the field is referenced, the referenced object is loaded from the + database. + """ + + def to_python(self, value): + obj = super(SafeReferenceField, self).to_python(value) + if obj: + # Must dereference so we don't get an invalid ObjectId back. + try: + obj.reload() + except DoesNotExist: + return None + return obj + + +class SafeReferenceListField(ListField): + """ + Like a ListField, but doesn't return non-existing references when + dereferencing, i.e. no DBRefs are returned. This means that the next time + an object is saved, the non-existing references are removed and application + code can rely on having only valid dereferenced objects. + + When the field is referenced, all referenced objects are loaded from the + database. + + Must use ReferenceField as its field class. + """ + + def __init__(self, field, **kwargs): + if not isinstance(field, ReferenceField): + raise ValueError('Field argument must be a ReferenceField instance.') + return super(SafeReferenceListField, self).__init__(field, **kwargs) + + def to_python(self, value): + result = super(SafeReferenceListField, self).to_python(value) + if result: + objs = self.field.document_type.objects.in_bulk([obj.id for obj in result]) + return filter(None, [objs.get(obj.id) for obj in result]) + class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -950,17 +888,6 @@ class GenericReferenceField(BaseField): .. versionadded:: 0.3 """ - def __get__(self, instance, owner): - if instance is None: - return self - - value = instance._data.get(self.name) - self._auto_dereference = instance._fields[self.name]._auto_dereference - if self._auto_dereference and isinstance(value, (dict, SON)): - instance._data[self.name] = self.dereference(value) - - return super(GenericReferenceField, self).__get__(instance, owner) - def validate(self, value): if not isinstance(value, (Document, DBRef, dict, SON)): self.error('GenericReferences can only contain documents') @@ -982,6 +909,14 @@ class GenericReferenceField(BaseField): doc = doc_cls._from_son(doc) return doc + def to_python(self, value): + if value != None: + doc_cls = get_document(value['_cls']) + reference = value['_ref'] + obj = doc_cls(pk=reference.id) + obj._lazy = True + return obj + def to_mongo(self, document): if document is None: return None @@ -1098,10 +1033,6 @@ class GridFSProxy(object): def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.grid_id) - def __str__(self): - name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)' - return '<%s: %s>' % (self.__class__.__name__, name) - def __eq__(self, other): if isinstance(other, GridFSProxy): return ((self.grid_id == other.grid_id) and @@ -1209,7 +1140,9 @@ class FileField(BaseField): # Check if a file already exists for this model grid_file = instance._data.get(self.name) if not isinstance(grid_file, self.proxy_class): - grid_file = self.get_proxy_obj(key=self.name, instance=instance) + grid_file = self.proxy_class(key=self.name, instance=instance, + db_alias=self.db_alias, + collection_name=self.collection_name) instance._data[self.name] = grid_file if not grid_file.key: @@ -1231,23 +1164,15 @@ class FileField(BaseField): pass # Create a new proxy object as we don't already have one - instance._data[key] = self.get_proxy_obj(key=key, instance=instance) + instance._data[key] = self.proxy_class(key=key, instance=instance, + db_alias=self.db_alias, + collection_name=self.collection_name) instance._data[key].put(value) else: instance._data[key] = value instance._mark_as_changed(key) - def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None): - if db_alias is None: - db_alias = self.db_alias - if collection_name is None: - collection_name = self.collection_name - - return self.proxy_class(key=key, instance=instance, - db_alias=db_alias, - collection_name=collection_name) - def to_mongo(self, value): # Store the GridFS file id in MongoDB if isinstance(value, self.proxy_class) and value.grid_id is not None: @@ -1280,9 +1205,6 @@ class ImageGridFsProxy(GridFSProxy): applying field properties (size, thumbnail_size) """ field = self.instance._fields[self.key] - # Handle nested fields - if hasattr(field, 'field') and isinstance(field.field, FileField): - field = field.field try: img = Image.open(file_obj) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py deleted file mode 100644 index b4dad0c..0000000 --- a/mongoengine/queryset/base.py +++ /dev/null @@ -1,1494 +0,0 @@ -from __future__ import absolute_import - -import copy -import itertools -import operator -import pprint -import re -import warnings - -from bson.code import Code -from bson import json_util -import pymongo -from pymongo.common import validate_read_preference - -from mongoengine import signals -from mongoengine.common import _import_class -from mongoengine.base.common import get_document -from mongoengine.errors import (OperationError, NotUniqueError, - InvalidQueryError, LookUpError) - -from mongoengine.queryset import transform -from mongoengine.queryset.field_list import QueryFieldList -from mongoengine.queryset.visitor import Q, QNode - - -__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') - -# Delete rules -DO_NOTHING = 0 -NULLIFY = 1 -CASCADE = 2 -DENY = 3 -PULL = 4 - -RE_TYPE = type(re.compile('')) - - -class BaseQuerySet(object): - """A set of results returned from a query. Wraps a MongoDB cursor, - providing :class:`~mongoengine.Document` objects as the results. - """ - __dereference = False - _auto_dereference = True - - def __init__(self, document, collection): - self._document = document - self._collection_obj = collection - self._mongo_query = None - self._query_obj = Q() - self._initial_query = {} - self._where_clause = None - self._loaded_fields = QueryFieldList() - self._ordering = [] - self._snapshot = False - self._timeout = True - self._class_check = True - self._slave_okay = False - self._read_preference = None - self._iter = False - self._scalar = [] - self._none = False - self._as_pymongo = False - self._as_pymongo_coerce = False - - # If inheritance is allowed, only return instances and instances of - # subclasses of the class being used - if document._meta.get('allow_inheritance') is True: - if len(self._document._subclasses) == 1: - self._initial_query = {"_cls": self._document._subclasses[0]} - else: - self._initial_query = {"_cls": {"$in": self._document._subclasses}} - self._loaded_fields = QueryFieldList(always_include=['_cls']) - self._cursor_obj = None - self._limit = None - self._skip = None - self._hint = -1 # Using -1 as None is a valid value for hint - - def __call__(self, q_obj=None, class_check=True, slave_okay=False, - read_preference=None, **query): - """Filter the selected documents by calling the - :class:`~mongoengine.queryset.QuerySet` with a query. - - :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in - the query; the :class:`~mongoengine.queryset.QuerySet` is filtered - multiple times with different :class:`~mongoengine.queryset.Q` - objects, only the last one will be used - :param class_check: If set to False bypass class name check when - querying collection - :param slave_okay: if True, allows this query to be run against a - replica secondary. - :params read_preference: if set, overrides connection-level - read_preference from `ReplicaSetConnection`. - :param query: Django-style query keyword arguments - """ - query = Q(**query) - if q_obj: - # make sure proper query object is passed - if not isinstance(q_obj, QNode): - msg = ("Not a query object: %s. " - "Did you intend to use key=value?" % q_obj) - raise InvalidQueryError(msg) - query &= q_obj - - if read_preference is None: - queryset = self.clone() - else: - # Use the clone provided when setting read_preference - queryset = self.read_preference(read_preference) - - queryset._query_obj &= query - queryset._mongo_query = None - queryset._cursor_obj = None - queryset._class_check = class_check - - return queryset - - def __getitem__(self, key): - """Support skip and limit using getitem and slicing syntax. - """ - queryset = self.clone() - - # Slice provided - if isinstance(key, slice): - try: - queryset._cursor_obj = queryset._cursor[key] - queryset._skip, queryset._limit = key.start, key.stop - if key.start and key.stop: - queryset._limit = key.stop - key.start - except IndexError, err: - # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. - start = key.start or 0 - if start >= 0 and key.stop >= 0 and key.step is None: - if start == key.stop: - queryset.limit(0) - queryset._skip = key.start - queryset._limit = key.stop - start - return queryset - raise err - # Allow further QuerySet modifications to be performed - return queryset - # Integer index provided - elif isinstance(key, int): - if queryset._scalar: - return queryset._get_scalar( - queryset._document._from_son(queryset._cursor[key], - _auto_dereference=self._auto_dereference)) - if queryset._as_pymongo: - return queryset._get_as_pymongo(queryset._cursor.next()) - return queryset._document._from_son(queryset._cursor[key], - _auto_dereference=self._auto_dereference) - raise AttributeError - - def __iter__(self): - raise NotImplementedError - - # Core functions - - def all(self): - """Returns all documents.""" - return self.__call__() - - def filter(self, *q_objs, **query): - """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` - """ - return self.__call__(*q_objs, **query) - - def get(self, *q_objs, **query): - """Retrieve the the matching object raising - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` exception if multiple results - and :class:`~mongoengine.queryset.DoesNotExist` or - `DocumentName.DoesNotExist` if no results are found. - - .. versionadded:: 0.3 - """ - queryset = self.clone() - queryset = queryset.limit(2) - queryset = queryset.filter(*q_objs, **query) - - try: - result = queryset.next() - except StopIteration: - msg = ("%s matching query does not exist." - % queryset._document._class_name) - raise queryset._document.DoesNotExist(msg) - try: - queryset.next() - except StopIteration: - return result - - queryset.rewind() - message = u'%d items returned, instead of 1' % queryset.count() - raise queryset._document.MultipleObjectsReturned(message) - - def create(self, **kwargs): - """Create new object. Returns the saved object instance. - - .. versionadded:: 0.4 - """ - return self._document(**kwargs).save() - - def get_or_create(self, write_concern=None, auto_save=True, - *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a - tuple of ``(object, created)``, where ``object`` is the retrieved or - created object and ``created`` is a boolean specifying whether a new - object was created. Raises - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` if multiple results are found. - A new document will be created if the document doesn't exists; a - dictionary of default values for the new document may be provided as a - keyword argument called :attr:`defaults`. - - .. note:: This requires two separate operations and therefore a - race condition exists. Because there are no transactions in - mongoDB other approaches should be investigated, to ensure you - don't accidently duplicate data when using this method. This is - now scheduled to be removed before 1.0 - - :param write_concern: optional extra keyword arguments used if we - have to create a new document. - Passes any write_concern onto :meth:`~mongoengine.Document.save` - - :param auto_save: if the object is to be saved automatically if - not found. - - .. deprecated:: 0.8 - .. versionchanged:: 0.6 - added `auto_save` - .. versionadded:: 0.3 - """ - msg = ("get_or_create is scheduled to be deprecated. The approach is " - "flawed without transactions. Upserts should be preferred.") - warnings.warn(msg, DeprecationWarning) - - defaults = query.get('defaults', {}) - if 'defaults' in query: - del query['defaults'] - - try: - doc = self.get(*q_objs, **query) - return doc, False - except self._document.DoesNotExist: - query.update(defaults) - doc = self._document(**query) - - if auto_save: - doc.save(write_concern=write_concern) - return doc, True - - def first(self): - """Retrieve the first object matching the query. - """ - queryset = self.clone() - try: - result = queryset[0] - except IndexError: - result = None - return result - - def insert(self, doc_or_docs, load_bulk=True, write_concern=None): - """bulk insert documents - - :param docs_or_doc: a document or list of documents to be inserted - :param load_bulk (optional): If True returns the list of document - instances - :param write_concern: Extra keyword arguments are passed down to - :meth:`~pymongo.collection.Collection.insert` - which will be used as options for the resultant - ``getLastError`` command. For example, - ``insert(..., {w: 2, fsync: True})`` will wait until at least - two servers have recorded the write and will force an fsync on - each server being written to. - - By default returns document instances, set ``load_bulk`` to False to - return just ``ObjectIds`` - - .. versionadded:: 0.5 - """ - Document = _import_class('Document') - - if write_concern is None: - write_concern = {} - - docs = doc_or_docs - return_one = False - if isinstance(docs, Document) or issubclass(docs.__class__, Document): - return_one = True - docs = [docs] - - raw = [] - for doc in docs: - if not isinstance(doc, self._document): - msg = ("Some documents inserted aren't instances of %s" - % str(self._document)) - raise OperationError(msg) - if doc.pk and not doc._created: - msg = "Some documents have ObjectIds use doc.update() instead" - raise OperationError(msg) - raw.append(doc.to_mongo()) - - signals.pre_bulk_insert.send(self._document, documents=docs) - try: - ids = self._collection.insert(raw, **write_concern) - except pymongo.errors.OperationFailure, err: - message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): - # E11000 - duplicate key error index - # E11001 - duplicate key on update - message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) - - if not load_bulk: - signals.post_bulk_insert.send( - self._document, documents=docs, loaded=False) - return return_one and ids[0] or ids - - documents = self.in_bulk(ids) - results = [] - for obj_id in ids: - results.append(documents.get(obj_id)) - signals.post_bulk_insert.send( - self._document, documents=results, loaded=True) - return return_one and results[0] or results - - def count(self, with_limit_and_skip=True): - """Count the selected elements in the query. - - :param with_limit_and_skip (optional): take any :meth:`limit` or - :meth:`skip` that has been applied to this cursor into account when - getting the count - """ - if self._limit == 0 and with_limit_and_skip: - return 0 - return self._cursor.count(with_limit_and_skip=with_limit_and_skip) - - def delete(self, write_concern=None, _from_doc_delete=False): - """Delete the documents matched by the query. - - :param write_concern: Extra keyword arguments are passed down which - will be used as options for the resultant - ``getLastError`` command. For example, - ``save(..., write_concern={w: 2, fsync: True}, ...)`` will - wait until at least two servers have recorded the write and - will force an fsync on the primary server. - :param _from_doc_delete: True when called from document delete therefore - signals will have been triggered so don't loop. - """ - queryset = self.clone() - doc = queryset._document - - if write_concern is None: - write_concern = {} - - # Handle deletes where skips or limits have been applied or - # there is an untriggered delete signal - has_delete_signal = signals.signals_available and ( - signals.pre_delete.has_receivers_for(self._document) or - signals.post_delete.has_receivers_for(self._document)) - - call_document_delete = (queryset._skip or queryset._limit or - has_delete_signal) and not _from_doc_delete - - if call_document_delete: - for doc in queryset: - doc.delete(write_concern=write_concern) - return - - delete_rules = doc._meta.get('delete_rules') or {} - # Check for DENY rules before actually deleting/nullifying any other - # references - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == DENY and document_cls.objects( - **{field_name + '__in': self}).count() > 0: - msg = ("Could not delete document (%s.%s refers to it)" - % (document_cls.__name__, field_name)) - raise OperationError(msg) - - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == CASCADE: - ref_q = document_cls.objects(**{field_name + '__in': self}) - ref_q_count = ref_q.count() - if (doc != document_cls and ref_q_count > 0 - or (doc == document_cls and ref_q_count > 0)): - ref_q.delete(write_concern=write_concern) - elif rule == NULLIFY: - document_cls.objects(**{field_name + '__in': self}).update( - write_concern=write_concern, **{'unset__%s' % field_name: 1}) - elif rule == PULL: - document_cls.objects(**{field_name + '__in': self}).update( - write_concern=write_concern, - **{'pull_all__%s' % field_name: self}) - - queryset._collection.remove(queryset._query, write_concern=write_concern) - - def update(self, upsert=False, multi=True, write_concern=None, - full_result=False, **update): - """Perform an atomic update on the fields matched by the query. - - :param upsert: Any existing document with that "_id" is overwritten. - :param multi: Update multiple documents. - :param write_concern: Extra keyword arguments are passed down which - will be used as options for the resultant - ``getLastError`` command. For example, - ``save(..., write_concern={w: 2, fsync: True}, ...)`` will - wait until at least two servers have recorded the write and - will force an fsync on the primary server. - :param full_result: Return the full result rather than just the number - updated. - :param update: Django-style update keyword arguments - - .. versionadded:: 0.2 - """ - if not update and not upsert: - raise OperationError("No update parameters, would remove data") - - if write_concern is None: - write_concern = {} - - queryset = self.clone() - query = queryset._query - update = transform.update(queryset._document, **update) - - # If doing an atomic upsert on an inheritable class - # then ensure we add _cls to the update operation - if upsert and '_cls' in query: - if '$set' in update: - update["$set"]["_cls"] = queryset._document._class_name - else: - update["$set"] = {"_cls": queryset._document._class_name} - try: - result = queryset._collection.update(query, update, multi=multi, - upsert=upsert, **write_concern) - if full_result: - return result - elif result: - return result['n'] - except pymongo.errors.OperationFailure, err: - if unicode(err) == u'multi not coded yet': - message = u'update() method requires MongoDB 1.1.3+' - raise OperationError(message) - raise OperationError(u'Update failed (%s)' % unicode(err)) - - def update_one(self, upsert=False, write_concern=None, **update): - """Perform an atomic update on first field matched by the query. - - :param upsert: Any existing document with that "_id" is overwritten. - :param write_concern: Extra keyword arguments are passed down which - will be used as options for the resultant - ``getLastError`` command. For example, - ``save(..., write_concern={w: 2, fsync: True}, ...)`` will - wait until at least two servers have recorded the write and - will force an fsync on the primary server. - :param update: Django-style update keyword arguments - - .. versionadded:: 0.2 - """ - return self.update( - upsert=upsert, multi=False, write_concern=write_concern, **update) - - def with_id(self, object_id): - """Retrieve the object matching the id provided. Uses `object_id` only - and raises InvalidQueryError if a filter has been applied. Returns - `None` if no document exists with that id. - - :param object_id: the value for the id of the document to look up - - .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set - """ - queryset = self.clone() - if not queryset._query_obj.empty: - msg = "Cannot use a filter whilst using `with_id`" - raise InvalidQueryError(msg) - return queryset.filter(pk=object_id).first() - - def in_bulk(self, object_ids): - """Retrieve a set of documents by their ids. - - :param object_ids: a list or tuple of ``ObjectId``\ s - :rtype: dict of ObjectIds as keys and collection-specific - Document subclasses as values. - - .. versionadded:: 0.3 - """ - doc_map = {} - - docs = self._collection.find({'_id': {'$in': object_ids}}, - **self._cursor_args) - if self._scalar: - for doc in docs: - doc_map[doc['_id']] = self._get_scalar( - self._document._from_son(doc)) - elif self._as_pymongo: - for doc in docs: - doc_map[doc['_id']] = self._get_as_pymongo(doc) - else: - for doc in docs: - doc_map[doc['_id']] = self._document._from_son(doc) - - return doc_map - - def none(self): - """Helper that just returns a list""" - queryset = self.clone() - queryset._none = True - return queryset - - def no_sub_classes(self): - """ - Only return instances of this document and not any inherited documents - """ - if self._document._meta.get('allow_inheritance') is True: - self._initial_query = {"_cls": self._document._class_name} - - return self - - def clone(self): - """Creates a copy of the current - :class:`~mongoengine.queryset.QuerySet` - - .. versionadded:: 0.5 - """ - return self.clone_into(self.__class__(self._document, self._collection_obj)) - - def clone_into(self, cls): - """Creates a copy of the current - :class:`~mongoengine.queryset.base.BaseQuerySet` into another child class - """ - if not isinstance(cls, BaseQuerySet): - raise OperationError('%s is not a subclass of BaseQuerySet' % cls.__name__) - - copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', - '_where_clause', '_loaded_fields', '_ordering', '_snapshot', - '_timeout', '_class_check', '_slave_okay', '_read_preference', - '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', - '_limit', '_skip', '_hint', '_auto_dereference') - - for prop in copy_props: - val = getattr(self, prop) - setattr(cls, prop, copy.copy(val)) - - if self._cursor_obj: - cls._cursor_obj = self._cursor_obj.clone() - - return cls - - def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or - :class:`~bson.object_id.ObjectId` a maximum depth in order to cut down - the number queries to mongodb. - - .. versionadded:: 0.5 - """ - # Make select related work the same for querysets - max_depth += 1 - queryset = self.clone() - return queryset._dereference(queryset, max_depth=max_depth) - - def limit(self, n): - """Limit the number of returned documents to `n`. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[:5]``). - - :param n: the maximum number of objects to return - """ - queryset = self.clone() - if n == 0: - queryset._cursor.limit(1) - else: - queryset._cursor.limit(n) - queryset._limit = n - # Return self to allow chaining - return queryset - - def skip(self, n): - """Skip `n` documents before returning the results. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[5:]``). - - :param n: the number of objects to skip before returning results - """ - queryset = self.clone() - queryset._cursor.skip(n) - queryset._skip = n - return queryset - - def hint(self, index=None): - """Added 'hint' support, telling Mongo the proper index to use for the - query. - - Judicious use of hints can greatly improve query performance. When - doing a query on multiple fields (at least one of which is indexed) - pass the indexed field as a hint to the query. - - Hinting will not do anything if the corresponding index does not exist. - The last hint applied to this cursor takes precedence over all others. - - .. versionadded:: 0.5 - """ - queryset = self.clone() - queryset._cursor.hint(index) - queryset._hint = index - return queryset - - def distinct(self, field): - """Return a list of distinct values for a given field. - - :param field: the field to select distinct values from - - .. note:: This is a command and won't take ordering or limit into - account. - - .. versionadded:: 0.4 - .. versionchanged:: 0.5 - Fixed handling references - .. versionchanged:: 0.6 - Improved db_field refrence handling - """ - queryset = self.clone() - try: - field = self._fields_to_dbfields([field]).pop() - finally: - return self._dereference(queryset._cursor.distinct(field), 1, - name=field, instance=self._document) - - def only(self, *fields): - """Load only a subset of this document's fields. :: - - post = BlogPost.objects(...).only("title", "author.name") - - .. note :: `only()` is chainable and will perform a union :: - So with the following it will fetch both: `title` and `author.name`:: - - post = BlogPost.objects.only("title").only("author.name") - - :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any - field filters. - - :param fields: fields to include - - .. versionadded:: 0.3 - .. versionchanged:: 0.5 - Added subfield support - """ - fields = dict([(f, QueryFieldList.ONLY) for f in fields]) - return self.fields(True, **fields) - - def exclude(self, *fields): - """Opposite to .only(), exclude some document's fields. :: - - post = BlogPost.objects(...).exclude("comments") - - .. note :: `exclude()` is chainable and will perform a union :: - So with the following it will exclude both: `title` and `author.name`:: - - post = BlogPost.objects.exclude("title").exclude("author.name") - - :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any - field filters. - - :param fields: fields to exclude - - .. versionadded:: 0.5 - """ - fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) - return self.fields(**fields) - - def fields(self, _only_called=False, **kwargs): - """Manipulate how you load this document's fields. Used by `.only()` - and `.exclude()` to manipulate which fields to retrieve. Fields also - allows for a greater level of control for example: - - Retrieving a Subrange of Array Elements: - - You can use the $slice operator to retrieve a subrange of elements in - an array. For example to get the first 5 comments:: - - post = BlogPost.objects(...).fields(slice__comments=5) - - :param kwargs: A dictionary identifying what to include - - .. versionadded:: 0.5 - """ - - # Check for an operator and transform to mongo-style if there is - operators = ["slice"] - cleaned_fields = [] - for key, value in kwargs.items(): - parts = key.split('__') - op = None - if parts[0] in operators: - op = parts.pop(0) - value = {'$' + op: value} - key = '.'.join(parts) - cleaned_fields.append((key, value)) - - fields = sorted(cleaned_fields, key=operator.itemgetter(1)) - queryset = self.clone() - for value, group in itertools.groupby(fields, lambda x: x[1]): - fields = [field for field, value in group] - fields = queryset._fields_to_dbfields(fields) - queryset._loaded_fields += QueryFieldList(fields, value=value, _only_called=_only_called) - - return queryset - - def all_fields(self): - """Include all fields. Reset all previously calls of .only() or - .exclude(). :: - - post = BlogPost.objects.exclude("comments").all_fields() - - .. versionadded:: 0.5 - """ - queryset = self.clone() - queryset._loaded_fields = QueryFieldList( - always_include=queryset._loaded_fields.always_include) - return queryset - - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. - - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction - """ - queryset = self.clone() - queryset._ordering = queryset._get_order_by(keys) - return queryset - - def explain(self, format=False): - """Return an explain plan record for the - :class:`~mongoengine.queryset.QuerySet`\ 's cursor. - - :param format: format the plan before returning it - """ - plan = self._cursor.explain() - if format: - plan = pprint.pformat(plan) - return plan - - def snapshot(self, enabled): - """Enable or disable snapshot mode when querying. - - :param enabled: whether or not snapshot mode is enabled - - ..versionchanged:: 0.5 - made chainable - """ - queryset = self.clone() - queryset._snapshot = enabled - return queryset - - def timeout(self, enabled): - """Enable or disable the default mongod timeout when querying. - - :param enabled: whether or not the timeout is used - - ..versionchanged:: 0.5 - made chainable - """ - queryset = self.clone() - queryset._timeout = enabled - return queryset - - def slave_okay(self, enabled): - """Enable or disable the slave_okay when querying. - - :param enabled: whether or not the slave_okay is enabled - """ - queryset = self.clone() - queryset._slave_okay = enabled - return queryset - - def read_preference(self, read_preference): - """Change the read_preference when querying. - - :param read_preference: override ReplicaSetConnection-level - preference. - """ - validate_read_preference('read_preference', read_preference) - queryset = self.clone() - queryset._read_preference = read_preference - return queryset - - def scalar(self, *fields): - """Instead of returning Document instances, return either a specific - value or a tuple of values in order. - - Can be used along with - :func:`~mongoengine.queryset.QuerySet.no_dereference` to turn off - dereferencing. - - .. note:: This effects all results and can be unset by calling - ``scalar`` without arguments. Calls ``only`` automatically. - - :param fields: One or more fields to return instead of a Document. - """ - queryset = self.clone() - queryset._scalar = list(fields) - - if fields: - queryset = queryset.only(*fields) - else: - queryset = queryset.all_fields() - - return queryset - - def values_list(self, *fields): - """An alias for scalar""" - return self.scalar(*fields) - - def as_pymongo(self, coerce_types=False): - """Instead of returning Document instances, return raw values from - pymongo. - - :param coerce_type: Field types (if applicable) would be use to - coerce types. - """ - queryset = self.clone() - queryset._as_pymongo = True - queryset._as_pymongo_coerce = coerce_types - return queryset - - # JSON Helpers - - def to_json(self, *args, **kwargs): - """Converts a queryset to JSON""" - return json_util.dumps(self.as_pymongo(), *args, **kwargs) - - def from_json(self, json_data): - """Converts json data to unsaved objects""" - son_data = json_util.loads(json_data) - return [self._document._from_son(data) for data in son_data] - - # JS functionality - - def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, - scope=None): - """Perform a map/reduce query using the current query spec - and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, - it must be the last call made, as it does not return a maleable - ``QuerySet``. - - See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` - and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` - tests in ``tests.queryset.QuerySetTest`` for usage examples. - - :param map_f: map function, as :class:`~bson.code.Code` or string - :param reduce_f: reduce function, as - :class:`~bson.code.Code` or string - :param output: output collection name, if set to 'inline' will try to - use :class:`~pymongo.collection.Collection.inline_map_reduce` - This can also be a dictionary containing output options - see: http://docs.mongodb.org/manual/reference/commands/#mapReduce - :param finalize_f: finalize function, an optional function that - performs any post-reduction processing. - :param scope: values to insert into map/reduce global scope. Optional. - :param limit: number of objects from current query to provide - to map/reduce method - - Returns an iterator yielding - :class:`~mongoengine.document.MapReduceDocument`. - - .. note:: - - Map/Reduce changed in server version **>= 1.7.4**. The PyMongo - :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.11**. - - .. versionchanged:: 0.5 - - removed ``keep_temp`` keyword argument, which was only relevant - for MongoDB server versions older than 1.7.4 - - .. versionadded:: 0.3 - """ - queryset = self.clone() - - MapReduceDocument = _import_class('MapReduceDocument') - - if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.7.1") - - map_f_scope = {} - if isinstance(map_f, Code): - map_f_scope = map_f.scope - map_f = unicode(map_f) - map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) - - reduce_f_scope = {} - if isinstance(reduce_f, Code): - reduce_f_scope = reduce_f.scope - reduce_f = unicode(reduce_f) - reduce_f_code = queryset._sub_js_fields(reduce_f) - reduce_f = Code(reduce_f_code, reduce_f_scope) - - mr_args = {'query': queryset._query} - - if finalize_f: - finalize_f_scope = {} - if isinstance(finalize_f, Code): - finalize_f_scope = finalize_f.scope - finalize_f = unicode(finalize_f) - finalize_f_code = queryset._sub_js_fields(finalize_f) - finalize_f = Code(finalize_f_code, finalize_f_scope) - mr_args['finalize'] = finalize_f - - if scope: - mr_args['scope'] = scope - - if limit: - mr_args['limit'] = limit - - if output == 'inline' and not queryset._ordering: - map_reduce_function = 'inline_map_reduce' - else: - map_reduce_function = 'map_reduce' - mr_args['out'] = output - - results = getattr(queryset._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) - - if map_reduce_function == 'map_reduce': - results = results.find() - - if queryset._ordering: - results = results.sort(queryset._ordering) - - for doc in results: - yield MapReduceDocument(queryset._document, queryset._collection, - doc['_id'], doc['value']) - - def exec_js(self, code, *fields, **options): - """Execute a Javascript function on the server. A list of fields may be - provided, which will be translated to their correct names and supplied - as the arguments to the function. A few extra variables are added to - the function's scope: ``collection``, which is the name of the - collection in use; ``query``, which is an object representing the - current query; and ``options``, which is an object containing any - options specified as keyword arguments. - - As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` - constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a - field, use square-bracket notation, and prefix the MongoEngine field - name with a tilde (~). - - :param code: a string of Javascript code to execute - :param fields: fields that you will be using in your function, which - will be passed in to your function as arguments - :param options: options that you want available to the function - (accessed in Javascript through the ``options`` object) - """ - queryset = self.clone() - - code = queryset._sub_js_fields(code) - - fields = [queryset._document._translate_field_name(f) for f in fields] - collection = queryset._document._get_collection_name() - - scope = { - 'collection': collection, - 'options': options or {}, - } - - query = queryset._query - if queryset._where_clause: - query['$where'] = queryset._where_clause - - scope['query'] = query - code = Code(code, scope=scope) - - db = queryset._document._get_db() - return db.eval(code, *fields) - - def where(self, where_clause): - """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript - expression). Performs automatic field name substitution like - :meth:`mongoengine.queryset.Queryset.exec_js`. - - .. note:: When using this mode of query, the database will call your - function, or evaluate your predicate clause, for each object - in the collection. - - .. versionadded:: 0.5 - """ - queryset = self.clone() - where_clause = queryset._sub_js_fields(where_clause) - queryset._where_clause = where_clause - return queryset - - def sum(self, field): - """Sum over the values of the specified field. - - :param field: the field to sum over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'), - field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(1, item||0); - }); - } else if (typeof field != 'undefined') { - emit(1, field||0); - } - } - """ % dict(field=field) - - reduce_func = Code(""" - function(key, values) { - var sum = 0; - for (var i in values) { - sum += values[i]; - } - return sum; - } - """) - - for result in self.map_reduce(map_func, reduce_func, output='inline'): - return result.value - else: - return 0 - - def average(self, field): - """Average over the values of the specified field. - - :param field: the field to average over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'), - field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(1, {t: item||0, c: 1}); - }); - } else if (typeof field != 'undefined') { - emit(1, {t: field||0, c: 1}); - } - } - """ % dict(field=field) - - reduce_func = Code(""" - function(key, values) { - var out = {t: 0, c: 0}; - for (var i in values) { - var value = values[i]; - out.t += value.t; - out.c += value.c; - } - return out; - } - """) - - finalize_func = Code(""" - function(key, value) { - return value.t / value.c; - } - """) - - for result in self.map_reduce(map_func, reduce_func, - finalize_f=finalize_func, output='inline'): - return result.value - else: - return 0 - - def item_frequencies(self, field, normalize=False, map_reduce=True): - """Returns a dictionary of all items present in a field across - the whole queried set of documents, and their corresponding frequency. - This is useful for generating tag clouds, or searching documents. - - .. note:: - - Can only do direct simple mappings and cannot map across - :class:`~mongoengine.fields.ReferenceField` or - :class:`~mongoengine.fields.GenericReferenceField` for more complex - counting a manual map reduce call would is required. - - If the field is a :class:`~mongoengine.fields.ListField`, the items within - each list will be counted individually. - - :param field: the field to use - :param normalize: normalize the results so they add to 1.0 - :param map_reduce: Use map_reduce over exec_js - - .. versionchanged:: 0.5 defaults to map_reduce and can handle embedded - document lookups - """ - if map_reduce: - return self._item_frequencies_map_reduce(field, - normalize=normalize) - return self._item_frequencies_exec_js(field, normalize=normalize) - - # Iterator helpers - - def next(self): - """Wrap the result in a :class:`~mongoengine.Document` object. - """ - if self._limit == 0 or self._none: - raise StopIteration - - raw_doc = self._cursor.next() - if self._as_pymongo: - return self._get_as_pymongo(raw_doc) - doc = self._document._from_son(raw_doc, - _auto_dereference=self._auto_dereference) - if self._scalar: - return self._get_scalar(doc) - - return doc - - def rewind(self): - """Rewind the cursor to its unevaluated state. - - .. versionadded:: 0.3 - """ - self._iter = False - self._cursor.rewind() - - # Properties - - @property - def _collection(self): - """Property that returns the collection object. This allows us to - perform operations only if the collection is accessed. - """ - return self._collection_obj - - @property - def _cursor_args(self): - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout - } - if self._read_preference is not None: - cursor_args['read_preference'] = self._read_preference - else: - cursor_args['slave_okay'] = self._slave_okay - if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields.as_dict() - return cursor_args - - @property - def _cursor(self): - if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - where_clause = self._sub_js_fields(self._where_clause) - self._cursor_obj.where(where_clause) - - if self._ordering: - # Apply query ordering - self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model - order = self._get_order_by(self._document._meta['ordering']) - self._cursor_obj.sort(order) - - if self._limit is not None: - self._cursor_obj.limit(self._limit) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - - return self._cursor_obj - - def __deepcopy__(self, memo): - """Essential for chained queries with ReferenceFields involved""" - return self.clone() - - @property - def _query(self): - if self._mongo_query is None: - self._mongo_query = self._query_obj.to_query(self._document) - if self._class_check: - self._mongo_query.update(self._initial_query) - return self._mongo_query - - @property - def _dereference(self): - if not self.__dereference: - self.__dereference = _import_class('DeReference')() - return self.__dereference - - def no_dereference(self): - """Turn off any dereferencing for the results of this queryset. - """ - queryset = self.clone() - queryset._auto_dereference = False - return queryset - - # Helper Functions - - def _item_frequencies_map_reduce(self, field, normalize=False): - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'); - var field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(item, 1); - }); - } else if (typeof field != 'undefined') { - emit(field, 1); - } else { - emit(null, 1); - } - } - """ % dict(field=field) - reduce_func = """ - function(key, values) { - var total = 0; - var valuesSize = values.length; - for (var i=0; i < valuesSize; i++) { - total += parseInt(values[i], 10); - } - return total; - } - """ - values = self.map_reduce(map_func, reduce_func, 'inline') - frequencies = {} - for f in values: - key = f.key - if isinstance(key, float): - if int(key) == key: - key = int(key) - frequencies[key] = int(f.value) - - if normalize: - count = sum(frequencies.values()) - frequencies = dict([(k, float(v) / count) - for k, v in frequencies.items()]) - - return frequencies - - def _item_frequencies_exec_js(self, field, normalize=False): - """Uses exec_js to execute""" - freq_func = """ - function(path) { - var path = path.split('.'); - - var total = 0.0; - db[collection].find(query).forEach(function(doc) { - var field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - total += field.length; - } else { - total++; - } - }); - - var frequencies = {}; - var types = {}; - var inc = 1.0; - - db[collection].find(query).forEach(function(doc) { - field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - field.forEach(function(item) { - frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); - }); - } else { - var item = field; - types[item] = item; - frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); - } - }); - return [total, frequencies, types]; - } - """ - total, data, types = self.exec_js(freq_func, field) - values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) - - if normalize: - values = dict([(k, float(v) / total) for k, v in values.items()]) - - frequencies = {} - for k, v in values.iteritems(): - if isinstance(k, float): - if int(k) == k: - k = int(k) - - frequencies[k] = v - - return frequencies - - def _fields_to_dbfields(self, fields, subdoc=False): - """Translate fields paths to its db equivalents""" - ret = [] - subclasses = [] - document = self._document - if document._meta['allow_inheritance']: - subclasses = [get_document(x) - for x in document._subclasses][1:] - for field in fields: - try: - field = ".".join(f.db_field for f in - document._lookup_field(field.split('.'))) - ret.append(field) - except LookUpError, err: - found = False - for subdoc in subclasses: - try: - subfield = ".".join(f.db_field for f in - subdoc._lookup_field(field.split('.'))) - ret.append(subfield) - found = True - break - except LookUpError, e: - pass - - if not found: - raise err - return ret - - def _get_order_by(self, keys): - """Creates a list of order by fields - """ - key_list = [] - for key in keys: - if not key: - continue - direction = pymongo.ASCENDING - if key[0] == '-': - direction = pymongo.DESCENDING - if key[0] in ('-', '+'): - key = key[1:] - key = key.replace('__', '.') - try: - key = self._document._translate_field_name(key) - except: - pass - key_list.append((key, direction)) - - if self._cursor_obj: - self._cursor_obj.sort(key_list) - return key_list - - def _get_scalar(self, doc): - - def lookup(obj, name): - chunks = name.split('__') - for chunk in chunks: - obj = getattr(obj, chunk) - return obj - - data = [lookup(doc, n) for n in self._scalar] - if len(data) == 1: - return data[0] - - return tuple(data) - - def _get_as_pymongo(self, row): - # Extract which fields paths we should follow if .fields(...) was - # used. If not, handle all fields. - if not getattr(self, '__as_pymongo_fields', None): - self.__as_pymongo_fields = [] - - for field in self._loaded_fields.fields - set(['_cls']): - self.__as_pymongo_fields.append(field) - while '.' in field: - field, _ = field.rsplit('.', 1) - self.__as_pymongo_fields.append(field) - - all_fields = not self.__as_pymongo_fields - - def clean(data, path=None): - path = path or '' - - if isinstance(data, dict): - new_data = {} - for key, value in data.iteritems(): - new_path = '%s.%s' % (path, key) if path else key - - if all_fields: - include_field = True - elif self._loaded_fields.value == QueryFieldList.ONLY: - include_field = new_path in self.__as_pymongo_fields - else: - include_field = new_path not in self.__as_pymongo_fields - - if include_field: - new_data[key] = clean(value, path=new_path) - data = new_data - elif isinstance(data, list): - data = [clean(d, path=path) for d in data] - else: - if self._as_pymongo_coerce: - # If we need to coerce types, we need to determine the - # type of this field and use the corresponding - # .to_python(...) - from mongoengine.fields import EmbeddedDocumentField - obj = self._document - for chunk in path.split('.'): - obj = getattr(obj, chunk, None) - if obj is None: - break - elif isinstance(obj, EmbeddedDocumentField): - obj = obj.document_type - if obj and data is not None: - data = obj.to_python(data) - return data - return clean(row) - - def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be - substituted for the MongoDB name of the field (specified using the - :attr:`name` keyword argument in a field's constructor). - """ - def field_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return u'["%s"]' % fields[-1].db_field - - def field_path_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return ".".join([f.db_field for f in fields]) - - code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) - code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, - code) - return code - - # Deprecated - def ensure_index(self, **kwargs): - """Deprecated use :func:`Document.ensure_index`""" - msg = ("Doc.objects()._ensure_index() is deprecated. " - "Use Doc.ensure_index() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_index(**kwargs) - return self - - def _ensure_indexes(self): - """Deprecated use :func:`~Document.ensure_indexes`""" - msg = ("Doc.objects()._ensure_indexes() is deprecated. " - "Use Doc.ensure_indexes() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_indexes() \ No newline at end of file diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py index 140a71e..73d3cc2 100644 --- a/mongoengine/queryset/field_list.py +++ b/mongoengine/queryset/field_list.py @@ -55,8 +55,7 @@ class QueryFieldList(object): if self.always_include: if self.value is self.ONLY and self.fields: - if sorted(self.slice.keys()) != sorted(self.fields): - self.fields = self.fields.union(self.always_include) + self.fields = self.fields.union(self.always_include) else: self.fields -= self.always_include diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 1437e76..235d27b 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -1,26 +1,137 @@ -from mongoengine.errors import OperationError -from mongoengine.queryset.base import (BaseQuerySet, DO_NOTHING, NULLIFY, - CASCADE, DENY, PULL) +from __future__ import absolute_import -__all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE', - 'DENY', 'PULL') +import copy +import itertools +import operator +import pprint +import re +import warnings + +from bson.code import Code +from bson import json_util +import pymongo +from pymongo.common import validate_read_preference + +from mongoengine import signals +from mongoengine.common import _import_class +from mongoengine.errors import (OperationError, NotUniqueError, + InvalidQueryError) + +from mongoengine.queryset import transform +from mongoengine.queryset.field_list import QueryFieldList +from mongoengine.queryset.visitor import Q, QNode + + +__all__ = ('QuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 ITER_CHUNK_SIZE = 100 +# Delete rules +DO_NOTHING = 0 +NULLIFY = 1 +CASCADE = 2 +DENY = 3 +PULL = 4 -class QuerySet(BaseQuerySet): - """The default queryset, that builds queries and handles a set of results - returned from a query. +RE_TYPE = type(re.compile('')) - Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as - the results. + +class QuerySet(object): + """A set of results returned from a query. Wraps a MongoDB cursor, + providing :class:`~mongoengine.Document` objects as the results. """ + __dereference = False + _auto_dereference = True - _has_more = True - _len = None - _result_cache = None + def __init__(self, document, collection): + self._document = document + self._collection_obj = collection + self._mongo_query = None + self._query_obj = Q() + self._initial_query = {} + self._where_clause = None + self._loaded_fields = QueryFieldList() + self._ordering = None + self._snapshot = False + self._timeout = True + self._class_check = True + self._slave_okay = False + self._read_preference = None + self._iter = False + self._scalar = [] + self._none = False + self._as_pymongo = False + self._as_pymongo_coerce = False + self._result_cache = [] + self._has_more = True + self._len = None + + # If inheritance is allowed, only return instances and instances of + # subclasses of the class being used + if document._meta.get('allow_inheritance') is True: + if len(self._document._subclasses) == 1: + self._initial_query = {"_cls": self._document._subclasses[0]} + else: + self._initial_query = {"_cls": {"$in": self._document._subclasses}} + self._loaded_fields = QueryFieldList(always_include=['_cls']) + self._cursor_obj = None + self._limit = None + self._skip = None + self._hint = -1 # Using -1 as None is a valid value for hint + + def __call__(self, q_obj=None, class_check=True, slave_okay=False, + read_preference=None, **query): + """Filter the selected documents by calling the + :class:`~mongoengine.queryset.QuerySet` with a query. + + :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in + the query; the :class:`~mongoengine.queryset.QuerySet` is filtered + multiple times with different :class:`~mongoengine.queryset.Q` + objects, only the last one will be used + :param class_check: If set to False bypass class name check when + querying collection + :param slave_okay: if True, allows this query to be run against a + replica secondary. + :params read_preference: if set, overrides connection-level + read_preference from `ReplicaSetConnection`. + :param query: Django-style query keyword arguments + """ + query = Q(**query) + if q_obj: + # make sure proper query object is passed + if not isinstance(q_obj, QNode): + msg = ("Not a query object: %s. " + "Did you intend to use key=value?" % q_obj) + raise InvalidQueryError(msg) + query &= q_obj + + if read_preference is None: + queryset = self.clone() + else: + # Use the clone provided when setting read_preference + queryset = self.read_preference(read_preference) + + queryset._query_obj &= query + queryset._mongo_query = None + queryset._cursor_obj = None + queryset._class_check = class_check + + return queryset + + def __len__(self): + """Since __len__ is called quite frequently (for example, as part of + list(qs) we populate the result cache and cache the length. + """ + if self._len is not None: + return self._len + if self._has_more: + # populate the cache + list(self._iter_results()) + + self._len = len(self._result_cache) + return self._len def __iter__(self): """Iteration utilises a results cache which iterates the cursor @@ -36,39 +147,11 @@ class QuerySet(BaseQuerySet): # iterating over the cache. return iter(self._result_cache) - def __len__(self): - """Since __len__ is called quite frequently (for example, as part of - list(qs) we populate the result cache and cache the length. - """ - if self._len is not None: - return self._len - if self._has_more: - # populate the cache - list(self._iter_results()) - - self._len = len(self._result_cache) - return self._len - - def __repr__(self): - """Provides the string representation of the QuerySet - """ - if self._iter: - return '.. queryset mid-iteration ..' - - self._populate_cache() - data = self._result_cache[:REPR_OUTPUT_SIZE + 1] - if len(data) > REPR_OUTPUT_SIZE: - data[-1] = "...(remaining elements truncated)..." - return repr(data) - - def _iter_results(self): """A generator for iterating over the result cache. Also populates the cache if there are more possible results to yield. Raises StopIteration when there are no more results""" - if self._result_cache is None: - self._result_cache = [] pos = 0 while True: upper = len(self._result_cache) @@ -85,8 +168,6 @@ class QuerySet(BaseQuerySet): Populates the result cache with ``ITER_CHUNK_SIZE`` more entries (until the cursor is exhausted). """ - if self._result_cache is None: - self._result_cache = [] if self._has_more: try: for i in xrange(ITER_CHUNK_SIZE): @@ -94,6 +175,226 @@ class QuerySet(BaseQuerySet): except StopIteration: self._has_more = False + def __getitem__(self, key): + """Support skip and limit using getitem and slicing syntax. + """ + queryset = self.clone() + + # Slice provided + if isinstance(key, slice): + try: + queryset._cursor_obj = queryset._cursor[key] + queryset._skip, queryset._limit = key.start, key.stop + if key.start and key.stop: + queryset._limit = key.stop - key.start + except IndexError, err: + # PyMongo raises an error if key.start == key.stop, catch it, + # bin it, kill it. + start = key.start or 0 + if start >= 0 and key.stop >= 0 and key.step is None: + if start == key.stop: + queryset.limit(0) + queryset._skip = key.start + queryset._limit = key.stop - start + return queryset + raise err + # Allow further QuerySet modifications to be performed + return queryset + # Integer index provided + elif isinstance(key, int): + if queryset._scalar: + return queryset._get_scalar( + queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference)) + if queryset._as_pymongo: + return queryset._get_as_pymongo(queryset._cursor.next()) + return queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference) + raise AttributeError + + def __repr__(self): + """Provides the string representation of the QuerySet + """ + + if self._iter: + return '.. queryset mid-iteration ..' + + self._populate_cache() + data = self._result_cache[:REPR_OUTPUT_SIZE + 1] + if len(data) > REPR_OUTPUT_SIZE: + data[-1] = "...(remaining elements truncated)..." + return repr(data) + + # Core functions + + def all(self): + """Returns all documents.""" + return self.__call__() + + def filter(self, *q_objs, **query): + """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` + """ + return self.__call__(*q_objs, **query) + + def get(self, *q_objs, **query): + """Retrieve the the matching object raising + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` exception if multiple results + and :class:`~mongoengine.queryset.DoesNotExist` or + `DocumentName.DoesNotExist` if no results are found. + + .. versionadded:: 0.3 + """ + queryset = self.clone() + queryset = queryset.limit(2) + queryset = queryset.filter(*q_objs, **query) + + try: + result = queryset.next() + except StopIteration: + msg = ("%s matching query does not exist." + % queryset._document._class_name) + raise queryset._document.DoesNotExist(msg) + try: + queryset.next() + except StopIteration: + return result + + queryset.rewind() + message = u'%d items returned, instead of 1' % queryset.count() + raise queryset._document.MultipleObjectsReturned(message) + + def create(self, **kwargs): + """Create new object. Returns the saved object instance. + + .. versionadded:: 0.4 + """ + return self._document(**kwargs).save() + + def get_or_create(self, write_concern=None, auto_save=True, + *q_objs, **query): + """Retrieve unique object or create, if it doesn't exist. Returns a + tuple of ``(object, created)``, where ``object`` is the retrieved or + created object and ``created`` is a boolean specifying whether a new + object was created. Raises + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` if multiple results are found. + A new document will be created if the document doesn't exists; a + dictionary of default values for the new document may be provided as a + keyword argument called :attr:`defaults`. + + .. note:: This requires two separate operations and therefore a + race condition exists. Because there are no transactions in + mongoDB other approaches should be investigated, to ensure you + don't accidently duplicate data when using this method. This is + now scheduled to be removed before 1.0 + + :param write_concern: optional extra keyword arguments used if we + have to create a new document. + Passes any write_concern onto :meth:`~mongoengine.Document.save` + + :param auto_save: if the object is to be saved automatically if + not found. + + .. deprecated:: 0.8 + .. versionchanged:: 0.6 - added `auto_save` + .. versionadded:: 0.3 + """ + msg = ("get_or_create is scheduled to be deprecated. The approach is " + "flawed without transactions. Upserts should be preferred.") + warnings.warn(msg, DeprecationWarning) + + defaults = query.get('defaults', {}) + if 'defaults' in query: + del query['defaults'] + + try: + doc = self.get(*q_objs, **query) + return doc, False + except self._document.DoesNotExist: + query.update(defaults) + doc = self._document(**query) + + if auto_save: + doc.save(write_concern=write_concern) + return doc, True + + def first(self): + """Retrieve the first object matching the query. + """ + queryset = self.clone() + try: + result = queryset[0] + except IndexError: + result = None + return result + + def insert(self, doc_or_docs, load_bulk=True, write_concern=None): + """bulk insert documents + + :param docs_or_doc: a document or list of documents to be inserted + :param load_bulk (optional): If True returns the list of document + instances + :param write_concern: Extra keyword arguments are passed down to + :meth:`~pymongo.collection.Collection.insert` + which will be used as options for the resultant + ``getLastError`` command. For example, + ``insert(..., {w: 2, fsync: True})`` will wait until at least + two servers have recorded the write and will force an fsync on + each server being written to. + + By default returns document instances, set ``load_bulk`` to False to + return just ``ObjectIds`` + + .. versionadded:: 0.5 + """ + Document = _import_class('Document') + + if write_concern is None: + write_concern = {} + + docs = doc_or_docs + return_one = False + if isinstance(docs, Document) or issubclass(docs.__class__, Document): + return_one = True + docs = [docs] + + raw = [] + for doc in docs: + if not isinstance(doc, self._document): + msg = ("Some documents inserted aren't instances of %s" + % str(self._document)) + raise OperationError(msg) + if doc.pk and doc._created: + msg = "Some documents have ObjectIds use doc.update() instead" + raise OperationError(msg) + raw.append(doc.to_mongo()) + + signals.pre_bulk_insert.send(self._document, documents=docs) + try: + ids = self._collection.insert(raw, **write_concern) + except pymongo.errors.OperationFailure, err: + message = 'Could not save document (%s)' + if re.match('^E1100[01] duplicate key', unicode(err)): + # E11000 - duplicate key error index + # E11001 - duplicate key on update + message = u'Tried to save duplicate unique keys (%s)' + raise NotUniqueError(message % unicode(err)) + raise OperationError(message % unicode(err)) + + if not load_bulk: + signals.post_bulk_insert.send( + self._document, documents=docs, loaded=False) + return return_one and ids[0] or ids + + documents = self.in_bulk(ids) + results = [] + for obj_id in ids: + results.append(documents.get(obj_id)) + signals.post_bulk_insert.send( + self._document, documents=results, loaded=True) + return return_one and results[0] or results + def count(self, with_limit_and_skip=True): """Count the selected elements in the query. @@ -101,57 +402,1138 @@ class QuerySet(BaseQuerySet): :meth:`skip` that has been applied to this cursor into account when getting the count """ - if with_limit_and_skip is False: - return super(QuerySet, self).count(with_limit_and_skip) + if self._limit == 0: + return 0 + if with_limit_and_skip and self._len is not None: + return self._len + count = self._cursor.count(with_limit_and_skip=with_limit_and_skip) + if with_limit_and_skip: + self._len = count + return count - if self._len is None: - self._len = super(QuerySet, self).count(with_limit_and_skip) + def delete(self, write_concern=None, _from_doc_delete=False): + """Delete the documents matched by the query. - return self._len - - def no_cache(self): - """Convert to a non_caching queryset - - .. versionadded:: 0.8.3 Convert to non caching queryset + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param _from_doc_delete: True when called from document delete therefore + signals will have been triggered so don't loop. """ - if self._result_cache is not None: - raise OperationError("QuerySet already cached") - return self.clone_into(QuerySetNoCache(self._document, self._collection)) + queryset = self.clone() + doc = queryset._document + if write_concern is None: + write_concern = {} -class QuerySetNoCache(BaseQuerySet): - """A non caching QuerySet""" + # Handle deletes where skips or limits have been applied or + # there is an untriggered delete signal + has_delete_signal = signals.signals_available and ( + signals.pre_delete.has_receivers_for(self._document) or + signals.post_delete.has_receivers_for(self._document)) - def cache(self): - """Convert to a caching queryset + call_document_delete = (queryset._skip or queryset._limit or + has_delete_signal) and not _from_doc_delete - .. versionadded:: 0.8.3 Convert to caching queryset + if call_document_delete: + for doc in queryset: + doc.delete(write_concern=write_concern) + return + + delete_rules = doc._meta.get('delete_rules') or {} + # Check for DENY rules before actually deleting/nullifying any other + # references + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == DENY and document_cls.objects( + **{field_name + '__in': self}).count() > 0: + msg = ("Could not delete document (%s.%s refers to it)" + % (document_cls.__name__, field_name)) + raise OperationError(msg) + + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == CASCADE: + ref_q = document_cls.objects(**{field_name + '__in': self}) + ref_q_count = ref_q.count() + if (doc != document_cls and ref_q_count > 0 + or (doc == document_cls and ref_q_count > 0)): + ref_q.delete(write_concern=write_concern) + elif rule == NULLIFY: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, **{'unset__%s' % field_name: 1}) + elif rule == PULL: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, + **{'pull_all__%s' % field_name: self}) + + queryset._collection.remove(queryset._query, write_concern=write_concern) + + def update(self, upsert=False, multi=True, write_concern=None, + full_result=False, **update): + """Perform an atomic update on the fields matched by the query. + + :param upsert: Any existing document with that "_id" is overwritten. + :param multi: Update multiple documents. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param full_result: Return the full result rather than just the number + updated. + :param update: Django-style update keyword arguments + + .. versionadded:: 0.2 """ - return self.clone_into(QuerySet(self._document, self._collection)) + if not update and not upsert: + raise OperationError("No update parameters, would remove data") - def __repr__(self): - """Provides the string representation of the QuerySet + if write_concern is None: + write_concern = {} - .. versionchanged:: 0.6.13 Now doesnt modify the cursor + queryset = self.clone() + query = queryset._query + update = transform.update(queryset._document, **update) + + # If doing an atomic upsert on an inheritable class + # then ensure we add _cls to the update operation + if upsert and '_cls' in query: + if '$set' in update: + update["$set"]["_cls"] = queryset._document._class_name + else: + update["$set"] = {"_cls": queryset._document._class_name} + try: + result = queryset._collection.update(query, update, multi=multi, + upsert=upsert, **write_concern) + if full_result: + return result + elif result: + return result['n'] + except pymongo.errors.OperationFailure, err: + if unicode(err) == u'multi not coded yet': + message = u'update() method requires MongoDB 1.1.3+' + raise OperationError(message) + raise OperationError(u'Update failed (%s)' % unicode(err)) + + def update_one(self, upsert=False, write_concern=None, **update): + """Perform an atomic update on first field matched by the query. + + :param upsert: Any existing document with that "_id" is overwritten. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param update: Django-style update keyword arguments + + .. versionadded:: 0.2 """ - if self._iter: - return '.. queryset mid-iteration ..' + return self.update( + upsert=upsert, multi=False, write_concern=write_concern, **update) - data = [] - for i in xrange(REPR_OUTPUT_SIZE + 1): - try: - data.append(self.next()) - except StopIteration: - break - if len(data) > REPR_OUTPUT_SIZE: - data[-1] = "...(remaining elements truncated)..." + def with_id(self, object_id): + """Retrieve the object matching the id provided. Uses `object_id` only + and raises InvalidQueryError if a filter has been applied. Returns + `None` if no document exists with that id. - self.rewind() - return repr(data) + :param object_id: the value for the id of the document to look up - def __iter__(self): - queryset = self - if queryset._iter: - queryset = self.clone() - queryset.rewind() + .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set + """ + queryset = self.clone() + if not queryset._query_obj.empty: + msg = "Cannot use a filter whilst using `with_id`" + raise InvalidQueryError(msg) + return queryset.filter(pk=object_id).first() + + def in_bulk(self, object_ids): + """Retrieve a set of documents by their ids. + + :param object_ids: a list or tuple of ``ObjectId``\ s + :rtype: dict of ObjectIds as keys and collection-specific + Document subclasses as values. + + .. versionadded:: 0.3 + """ + doc_map = {} + + docs = self._collection.find({'_id': {'$in': object_ids}}, + **self._cursor_args) + if self._scalar: + for doc in docs: + doc_map[doc['_id']] = self._get_scalar( + self._document._from_son(doc)) + elif self._as_pymongo: + for doc in docs: + doc_map[doc['_id']] = self._get_as_pymongo(doc) + else: + for doc in docs: + doc_map[doc['_id']] = self._document._from_son(doc) + + return doc_map + + def none(self): + """Helper that just returns a list""" + queryset = self.clone() + queryset._none = True return queryset + + def no_sub_classes(self): + """ + Only return instances of this document and not any inherited documents + """ + if self._document._meta.get('allow_inheritance') is True: + self._initial_query = {"_cls": self._document._class_name} + + return self + + def only_classes(self, *classes): + doc = self._document + if doc._meta.get('allow_inheritance') is True: + queryset = self.clone() + class_names = [cls._class_name for cls in classes] + allowed_class_names = [name for name in self._document._subclasses if name in class_names] + if len(allowed_class_names) == 1: + queryset._initial_query = {"_cls": allowed_class_names[0]} + else: + queryset._initial_query = {"_cls": {"$in": allowed_class_names}} + return queryset + else: + return self + + def exclude_classes(self, *classes): + doc = self._document + if doc._meta.get('allow_inheritance') is True: + queryset = self.clone() + class_names = [cls._class_name for cls in classes] + allowed_class_names = [name for name in self._document._subclasses if name in class_names] + if len(allowed_class_names) == 1: + queryset._initial_query = {"_cls": {"$ne": allowed_class_names[0]}} + else: + queryset._initial_query = {"_cls": {"$nin": allowed_class_names}} + return queryset + else: + return self + + def clone(self): + """Creates a copy of the current + :class:`~mongoengine.queryset.QuerySet` + + .. versionadded:: 0.5 + """ + c = self.__class__(self._document, self._collection_obj) + + copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', + '_where_clause', '_loaded_fields', '_ordering', '_snapshot', + '_timeout', '_class_check', '_slave_okay', '_read_preference', + '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', + '_limit', '_skip', '_hint', '_auto_dereference') + + for prop in copy_props: + val = getattr(self, prop) + setattr(c, prop, copy.copy(val)) + + if self._cursor_obj: + c._cursor_obj = self._cursor_obj.clone() + + return c + + def select_related(self, max_depth=1): + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or + :class:`~bson.object_id.ObjectId` a maximum depth in order to cut down + the number queries to mongodb. + + .. versionadded:: 0.5 + """ + # Make select related work the same for querysets + max_depth += 1 + queryset = self.clone() + return queryset._dereference(queryset, max_depth=max_depth) + + def limit(self, n): + """Limit the number of returned documents to `n`. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[:5]``). + + :param n: the maximum number of objects to return + """ + queryset = self.clone() + if n == 0: + queryset._cursor.limit(1) + else: + queryset._cursor.limit(n) + queryset._limit = n + # Return self to allow chaining + return queryset + + def skip(self, n): + """Skip `n` documents before returning the results. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[5:]``). + + :param n: the number of objects to skip before returning results + """ + queryset = self.clone() + queryset._cursor.skip(n) + queryset._skip = n + return queryset + + def hint(self, index=None): + """Added 'hint' support, telling Mongo the proper index to use for the + query. + + Judicious use of hints can greatly improve query performance. When + doing a query on multiple fields (at least one of which is indexed) + pass the indexed field as a hint to the query. + + Hinting will not do anything if the corresponding index does not exist. + The last hint applied to this cursor takes precedence over all others. + + .. versionadded:: 0.5 + """ + queryset = self.clone() + queryset._cursor.hint(index) + queryset._hint = index + return queryset + + def distinct(self, field): + """Return a list of distinct values for a given field. + + :param field: the field to select distinct values from + + .. note:: This is a command and won't take ordering or limit into + account. + + .. versionadded:: 0.4 + .. versionchanged:: 0.5 - Fixed handling references + .. versionchanged:: 0.6 - Improved db_field refrence handling + """ + queryset = self.clone() + try: + field = self._fields_to_dbfields([field]).pop() + finally: + return self._dereference(queryset._cursor.distinct(field), 1, + name=field, instance=self._document) + + def only(self, *fields): + """Load only a subset of this document's fields. :: + + post = BlogPost.objects(...).only("title", "author.name") + + .. note :: `only()` is chainable and will perform a union :: + So with the following it will fetch both: `title` and `author.name`:: + + post = BlogPost.objects.only("title").only("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to include + + .. versionadded:: 0.3 + .. versionchanged:: 0.5 - Added subfield support + """ + fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + return self.fields(True, **fields) + + def exclude(self, *fields): + """Opposite to .only(), exclude some document's fields. :: + + post = BlogPost.objects(...).exclude("comments") + + .. note :: `exclude()` is chainable and will perform a union :: + So with the following it will exclude both: `title` and `author.name`:: + + post = BlogPost.objects.exclude("title").exclude("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to exclude + + .. versionadded:: 0.5 + """ + fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + return self.fields(**fields) + + def fields(self, _only_called=False, **kwargs): + """Manipulate how you load this document's fields. Used by `.only()` + and `.exclude()` to manipulate which fields to retrieve. Fields also + allows for a greater level of control for example: + + Retrieving a Subrange of Array Elements: + + You can use the $slice operator to retrieve a subrange of elements in + an array. For example to get the first 5 comments:: + + post = BlogPost.objects(...).fields(slice__comments=5) + + :param kwargs: A dictionary identifying what to include + + .. versionadded:: 0.5 + """ + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + queryset = self.clone() + for value, group in itertools.groupby(fields, lambda x: x[1]): + fields = [field for field, value in group] + fields = queryset._fields_to_dbfields(fields) + queryset._loaded_fields += QueryFieldList(fields, value=value, _only_called=_only_called) + + return queryset + + def all_fields(self): + """Include all fields. Reset all previously calls of .only() or + .exclude(). :: + + post = BlogPost.objects.exclude("comments").all_fields() + + .. versionadded:: 0.5 + """ + queryset = self.clone() + queryset._loaded_fields = QueryFieldList( + always_include=queryset._loaded_fields.always_include) + return queryset + + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction + """ + queryset = self.clone() + queryset._ordering = queryset._get_order_by(keys) + return queryset + + def explain(self, format=False): + """Return an explain plan record for the + :class:`~mongoengine.queryset.QuerySet`\ 's cursor. + + :param format: format the plan before returning it + """ + plan = self._cursor.explain() + if format: + plan = pprint.pformat(plan) + return plan + + def snapshot(self, enabled): + """Enable or disable snapshot mode when querying. + + :param enabled: whether or not snapshot mode is enabled + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._snapshot = enabled + return queryset + + def timeout(self, enabled): + """Enable or disable the default mongod timeout when querying. + + :param enabled: whether or not the timeout is used + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._timeout = enabled + return queryset + + def slave_okay(self, enabled): + """Enable or disable the slave_okay when querying. + + :param enabled: whether or not the slave_okay is enabled + """ + queryset = self.clone() + queryset._slave_okay = enabled + return queryset + + def read_preference(self, read_preference): + """Change the read_preference when querying. + + :param read_preference: override ReplicaSetConnection-level + preference. + """ + validate_read_preference('read_preference', read_preference) + queryset = self.clone() + queryset._read_preference = read_preference + return queryset + + def scalar(self, *fields): + """Instead of returning Document instances, return either a specific + value or a tuple of values in order. + + Can be used along with + :func:`~mongoengine.queryset.QuerySet.no_dereference` to turn off + dereferencing. + + .. note:: This effects all results and can be unset by calling + ``scalar`` without arguments. Calls ``only`` automatically. + + :param fields: One or more fields to return instead of a Document. + """ + queryset = self.clone() + queryset._scalar = list(fields) + + if fields: + queryset = queryset.only(*fields) + else: + queryset = queryset.all_fields() + + return queryset + + def values_list(self, *fields): + """An alias for scalar""" + return self.scalar(*fields) + + def as_pymongo(self, coerce_types=False): + """Instead of returning Document instances, return raw values from + pymongo. + + :param coerce_type: Field types (if applicable) would be use to + coerce types. + """ + queryset = self.clone() + queryset._as_pymongo = True + queryset._as_pymongo_coerce = coerce_types + return queryset + + # JSON Helpers + + def to_json(self): + """Converts a queryset to JSON""" + return json_util.dumps(self.as_pymongo()) + + def from_json(self, json_data): + """Converts json data to unsaved objects""" + son_data = json_util.loads(json_data) + return [self._document._from_son(data) for data in son_data] + + # JS functionality + + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, + scope=None): + """Perform a map/reduce query using the current query spec + and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, + it must be the last call made, as it does not return a maleable + ``QuerySet``. + + See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` + and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` + tests in ``tests.queryset.QuerySetTest`` for usage examples. + + :param map_f: map function, as :class:`~bson.code.Code` or string + :param reduce_f: reduce function, as + :class:`~bson.code.Code` or string + :param output: output collection name, if set to 'inline' will try to + use :class:`~pymongo.collection.Collection.inline_map_reduce` + This can also be a dictionary containing output options + see: http://docs.mongodb.org/manual/reference/commands/#mapReduce + :param finalize_f: finalize function, an optional function that + performs any post-reduction processing. + :param scope: values to insert into map/reduce global scope. Optional. + :param limit: number of objects from current query to provide + to map/reduce method + + Returns an iterator yielding + :class:`~mongoengine.document.MapReduceDocument`. + + .. note:: + + Map/Reduce changed in server version **>= 1.7.4**. The PyMongo + :meth:`~pymongo.collection.Collection.map_reduce` helper requires + PyMongo version **>= 1.11**. + + .. versionchanged:: 0.5 + - removed ``keep_temp`` keyword argument, which was only relevant + for MongoDB server versions older than 1.7.4 + + .. versionadded:: 0.3 + """ + queryset = self.clone() + + MapReduceDocument = _import_class('MapReduceDocument') + + if not hasattr(self._collection, "map_reduce"): + raise NotImplementedError("Requires MongoDB >= 1.7.1") + + map_f_scope = {} + if isinstance(map_f, Code): + map_f_scope = map_f.scope + map_f = unicode(map_f) + map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) + + reduce_f_scope = {} + if isinstance(reduce_f, Code): + reduce_f_scope = reduce_f.scope + reduce_f = unicode(reduce_f) + reduce_f_code = queryset._sub_js_fields(reduce_f) + reduce_f = Code(reduce_f_code, reduce_f_scope) + + mr_args = {'query': queryset._query} + + if finalize_f: + finalize_f_scope = {} + if isinstance(finalize_f, Code): + finalize_f_scope = finalize_f.scope + finalize_f = unicode(finalize_f) + finalize_f_code = queryset._sub_js_fields(finalize_f) + finalize_f = Code(finalize_f_code, finalize_f_scope) + mr_args['finalize'] = finalize_f + + if scope: + mr_args['scope'] = scope + + if limit: + mr_args['limit'] = limit + + if output == 'inline' and not queryset._ordering: + map_reduce_function = 'inline_map_reduce' + else: + map_reduce_function = 'map_reduce' + mr_args['out'] = output + + results = getattr(queryset._collection, map_reduce_function)( + map_f, reduce_f, **mr_args) + + if map_reduce_function == 'map_reduce': + results = results.find() + + if queryset._ordering: + results = results.sort(queryset._ordering) + + for doc in results: + yield MapReduceDocument(queryset._document, queryset._collection, + doc['_id'], doc['value']) + + def exec_js(self, code, *fields, **options): + """Execute a Javascript function on the server. A list of fields may be + provided, which will be translated to their correct names and supplied + as the arguments to the function. A few extra variables are added to + the function's scope: ``collection``, which is the name of the + collection in use; ``query``, which is an object representing the + current query; and ``options``, which is an object containing any + options specified as keyword arguments. + + As fields in MongoEngine may use different names in the database (set + using the :attr:`db_field` keyword argument to a :class:`Field` + constructor), a mechanism exists for replacing MongoEngine field names + with the database field names in Javascript code. When accessing a + field, use square-bracket notation, and prefix the MongoEngine field + name with a tilde (~). + + :param code: a string of Javascript code to execute + :param fields: fields that you will be using in your function, which + will be passed in to your function as arguments + :param options: options that you want available to the function + (accessed in Javascript through the ``options`` object) + """ + queryset = self.clone() + + code = queryset._sub_js_fields(code) + + fields = [queryset._document._translate_field_name(f) for f in fields] + collection = queryset._document._get_collection_name() + + scope = { + 'collection': collection, + 'options': options or {}, + } + + query = queryset._query + if queryset._where_clause: + query['$where'] = queryset._where_clause + + scope['query'] = query + code = Code(code, scope=scope) + + db = queryset._document._get_db() + return db.eval(code, *fields) + + def where(self, where_clause): + """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript + expression). Performs automatic field name substitution like + :meth:`mongoengine.queryset.Queryset.exec_js`. + + .. note:: When using this mode of query, the database will call your + function, or evaluate your predicate clause, for each object + in the collection. + + .. versionadded:: 0.5 + """ + queryset = self.clone() + where_clause = queryset._sub_js_fields(where_clause) + queryset._where_clause = where_clause + return queryset + + def sum(self, field): + """Sum over the values of the specified field. + + :param field: the field to sum over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. + """ + map_func = Code(""" + function() { + emit(1, this[field] || 0); + } + """, scope={'field': field}) + + reduce_func = Code(""" + function(key, values) { + var sum = 0; + for (var i in values) { + sum += values[i]; + } + return sum; + } + """) + + for result in self.map_reduce(map_func, reduce_func, output='inline'): + return result.value + else: + return 0 + + def average(self, field): + """Average over the values of the specified field. + + :param field: the field to average over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. + """ + map_func = Code(""" + function() { + if (this.hasOwnProperty(field)) + emit(1, {t: this[field] || 0, c: 1}); + } + """, scope={'field': field}) + + reduce_func = Code(""" + function(key, values) { + var out = {t: 0, c: 0}; + for (var i in values) { + var value = values[i]; + out.t += value.t; + out.c += value.c; + } + return out; + } + """) + + finalize_func = Code(""" + function(key, value) { + return value.t / value.c; + } + """) + + for result in self.map_reduce(map_func, reduce_func, + finalize_f=finalize_func, output='inline'): + return result.value + else: + return 0 + + def item_frequencies(self, field, normalize=False, map_reduce=True): + """Returns a dictionary of all items present in a field across + the whole queried set of documents, and their corresponding frequency. + This is useful for generating tag clouds, or searching documents. + + .. note:: + + Can only do direct simple mappings and cannot map across + :class:`~mongoengine.fields.ReferenceField` or + :class:`~mongoengine.fields.GenericReferenceField` for more complex + counting a manual map reduce call would is required. + + If the field is a :class:`~mongoengine.fields.ListField`, the items within + each list will be counted individually. + + :param field: the field to use + :param normalize: normalize the results so they add to 1.0 + :param map_reduce: Use map_reduce over exec_js + + .. versionchanged:: 0.5 defaults to map_reduce and can handle embedded + document lookups + """ + if map_reduce: + return self._item_frequencies_map_reduce(field, + normalize=normalize) + return self._item_frequencies_exec_js(field, normalize=normalize) + + # Iterator helpers + + def next(self): + """Wrap the result in a :class:`~mongoengine.Document` object. + """ + if self._limit == 0 or self._none: + raise StopIteration + + raw_doc = self._cursor.next() + if self._as_pymongo: + return self._get_as_pymongo(raw_doc) + doc = self._document._from_son(raw_doc, + _auto_dereference=self._auto_dereference) + if self._scalar: + return self._get_scalar(doc) + + return doc + + def rewind(self): + """Rewind the cursor to its unevaluated state. + + .. versionadded:: 0.3 + """ + self._iter = False + self._cursor.rewind() + + # Properties + + @property + def _collection(self): + """Property that returns the collection object. This allows us to + perform operations only if the collection is accessed. + """ + return self._collection_obj + + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout + } + if self._read_preference is not None: + cursor_args['read_preference'] = self._read_preference + else: + cursor_args['slave_okay'] = self._slave_okay + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + + @property + def _cursor(self): + if self._cursor_obj is None: + + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) + # Apply where clauses to cursor + if self._where_clause: + where_clause = self._sub_js_fields(self._where_clause) + self._cursor_obj.where(where_clause) + + if self._ordering: + # Apply query ordering + self._cursor_obj.sort(self._ordering) + elif self._ordering == None and self._document._meta['ordering']: + # Otherwise, apply the ordering from the document model + order = self._get_order_by(self._document._meta['ordering']) + self._cursor_obj.sort(order) + + if self._limit is not None: + self._cursor_obj.limit(self._limit) + + if self._skip is not None: + self._cursor_obj.skip(self._skip) + + if self._hint != -1: + self._cursor_obj.hint(self._hint) + + return self._cursor_obj + + def __deepcopy__(self, memo): + """Essential for chained queries with ReferenceFields involved""" + return self.clone() + + @property + def _query(self): + if self._mongo_query is None: + self._mongo_query = self._query_obj.to_query(self._document) + if self._class_check: + self._mongo_query.update(self._initial_query) + return self._mongo_query + + @property + def _dereference(self): + if not self.__dereference: + self.__dereference = _import_class('DeReference')() + return self.__dereference + + def no_dereference(self): + """Turn off any dereferencing for the results of this queryset. + """ + queryset = self.clone() + queryset._auto_dereference = False + return queryset + + # Helper Functions + + def _item_frequencies_map_reduce(self, field, normalize=False): + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'); + var field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(item, 1); + }); + } else if (typeof field != 'undefined') { + emit(field, 1); + } else { + emit(null, 1); + } + } + """ % dict(field=field) + reduce_func = """ + function(key, values) { + var total = 0; + var valuesSize = values.length; + for (var i=0; i < valuesSize; i++) { + total += parseInt(values[i], 10); + } + return total; + } + """ + values = self.map_reduce(map_func, reduce_func, 'inline') + frequencies = {} + for f in values: + key = f.key + if isinstance(key, float): + if int(key) == key: + key = int(key) + frequencies[key] = int(f.value) + + if normalize: + count = sum(frequencies.values()) + frequencies = dict([(k, float(v) / count) + for k, v in frequencies.items()]) + + return frequencies + + def _item_frequencies_exec_js(self, field, normalize=False): + """Uses exec_js to execute""" + freq_func = """ + function(path) { + var path = path.split('.'); + + var total = 0.0; + db[collection].find(query).forEach(function(doc) { + var field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + total += field.length; + } else { + total++; + } + }); + + var frequencies = {}; + var types = {}; + var inc = 1.0; + + db[collection].find(query).forEach(function(doc) { + field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); + }); + } else { + var item = field; + types[item] = item; + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); + } + }); + return [total, frequencies, types]; + } + """ + total, data, types = self.exec_js(freq_func, field) + values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) + + if normalize: + values = dict([(k, float(v) / total) for k, v in values.items()]) + + frequencies = {} + for k, v in values.iteritems(): + if isinstance(k, float): + if int(k) == k: + k = int(k) + + frequencies[k] = v + + return frequencies + + def _fields_to_dbfields(self, fields): + """Translate fields paths to its db equivalents""" + ret = [] + for field in fields: + field = ".".join(f.db_field for f in + self._document._lookup_field(field.split('.'))) + ret.append(field) + return ret + + def _get_order_by(self, keys): + """Creates a list of order by fields + """ + key_list = [] + for key in keys: + if not key: + continue + direction = pymongo.ASCENDING + if key[0] == '-': + direction = pymongo.DESCENDING + if key[0] in ('-', '+'): + key = key[1:] + key = key.replace('__', '.') + try: + key = self._document._translate_field_name(key) + except: + pass + key_list.append((key, direction)) + + if self._cursor_obj: + self._cursor_obj.sort(key_list) + return key_list + + def _get_scalar(self, doc): + + def lookup(obj, name): + chunks = name.split('__') + for chunk in chunks: + obj = getattr(obj, chunk) + return obj + + data = [lookup(doc, n) for n in self._scalar] + if len(data) == 1: + return data[0] + + return tuple(data) + + def _get_as_pymongo(self, row): + # Extract which fields paths we should follow if .fields(...) was + # used. If not, handle all fields. + if not getattr(self, '__as_pymongo_fields', None): + self.__as_pymongo_fields = [] + for field in self._loaded_fields.fields - set(['_cls', '_id']): + self.__as_pymongo_fields.append(field) + while '.' in field: + field, _ = field.rsplit('.', 1) + self.__as_pymongo_fields.append(field) + + all_fields = not self.__as_pymongo_fields + + def clean(data, path=None): + path = path or '' + + if isinstance(data, dict): + new_data = {} + for key, value in data.iteritems(): + new_path = '%s.%s' % (path, key) if path else key + + if all_fields: + include_field = True + elif self._loaded_fields.value == QueryFieldList.ONLY: + include_field = new_path in self.__as_pymongo_fields + else: + include_field = new_path not in self.__as_pymongo_fields + + if include_field: + new_data[key] = clean(value, path=new_path) + data = new_data + elif isinstance(data, list): + data = [clean(d, path=path) for d in data] + else: + if self._as_pymongo_coerce: + # If we need to coerce types, we need to determine the + # type of this field and use the corresponding + # .to_python(...) + from mongoengine.fields import EmbeddedDocumentField + obj = self._document + for chunk in path.split('.'): + obj = getattr(obj, chunk, None) + if obj is None: + break + elif isinstance(obj, EmbeddedDocumentField): + obj = obj.document_type + if obj and data is not None: + data = obj.to_python(data) + return data + return clean(row) + + def _sub_js_fields(self, code): + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be + substituted for the MongoDB name of the field (specified using the + :attr:`name` keyword argument in a field's constructor). + """ + def field_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return u'["%s"]' % fields[-1].db_field + + def field_path_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return ".".join([f.db_field for f in fields]) + + code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, + code) + return code + + # Deprecated + def ensure_index(self, **kwargs): + """Deprecated use :func:`Document.ensure_index`""" + msg = ("Doc.objects()._ensure_index() is deprecated. " + "Use Doc.ensure_index() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_index(**kwargs) + return self + + def _ensure_indexes(self): + """Deprecated use :func:`~Document.ensure_indexes`""" + msg = ("Doc.objects()._ensure_indexes() is deprecated. " + "Use Doc.ensure_indexes() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_indexes() diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 2ee7e38..352774f 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -43,11 +43,11 @@ def query(_doc_cls=None, _field_operation=False, **query): parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is op = None - if len(parts) > 1 and parts[-1] in MATCH_OPERATORS: + if parts[-1] in MATCH_OPERATORS: op = parts.pop() negate = False - if len(parts) > 1 and parts[-1] == 'not': + if parts[-1] == 'not': parts.pop() negate = True @@ -182,7 +182,6 @@ def update(_doc_cls=None, **update): parts = [] cleaned_fields = [] - appended_sub_field = False for field in fields: append_field = True if isinstance(field, basestring): @@ -194,30 +193,21 @@ def update(_doc_cls=None, **update): else: parts.append(field.db_field) if append_field: - appended_sub_field = False cleaned_fields.append(field) - if hasattr(field, 'field'): - cleaned_fields.append(field.field) - appended_sub_field = True # Convert value to proper value - if appended_sub_field: - field = cleaned_fields[-2] - else: - field = cleaned_fields[-1] + field = cleaned_fields[-1] if op in (None, 'set', 'push', 'pull'): if field.required or value is not None: value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] - elif op in ('addToSet', 'setOnInsert'): + elif op == 'addToSet': if isinstance(value, (list, tuple, set)): value = [field.prepare_query_value(op, v) for v in value] elif field.required or value is not None: value = field.prepare_query_value(op, value) - elif op == "unset": - value = 1 if match: match = '$' + match @@ -231,24 +221,11 @@ def update(_doc_cls=None, **update): if 'pull' in op and '.' in key: # Dot operators don't work on pull operations - # unless they point to a list field - # Otherwise it uses nested dict syntax + # it uses nested dict syntax if op == 'pullAll': raise InvalidQueryError("pullAll operations only support " "a single field depth") - # Look for the last list field and use dot notation until there - field_classes = [c.__class__ for c in cleaned_fields] - field_classes.reverse() - ListField = _import_class('ListField') - if ListField in field_classes: - # Join all fields via dot notation to the last ListField - # Then process as normal - last_listField = len(cleaned_fields) - field_classes.index(ListField) - key = ".".join(parts[:last_listField]) - parts = parts[last_listField:] - parts.insert(0, key) - parts.reverse() for key in parts: value = {key: value} diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 41f4ebf..3783e7a 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -28,7 +28,7 @@ class DuplicateQueryConditionsError(InvalidQueryError): class SimplificationVisitor(QNodeVisitor): - """Simplifies query trees by combinging unnecessary 'and' connection nodes + """Simplifies query trees by combining unnecessary 'and' connection nodes into a single Q-object. """ @@ -73,6 +73,16 @@ class QueryCompilerVisitor(QNodeVisitor): def visit_combination(self, combination): operator = "$and" if combination.operation == combination.OR: + keys = set([key for q in combination.children for key in q.keys()]) + if len(keys) == 1: + field = keys.pop() + if not field.startswith('$') and not any([isinstance(q[field], dict) for q in combination.children]): + return { + field: { + '$in': [q[field] for q in combination.children if field in q] + } + } + operator = "$or" return {operator: combination.children} diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 06fb8b4..1c256da 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -__all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', - 'post_save', 'pre_delete', 'post_delete'] +__all__ = ['pre_save', 'post_save', 'pre_delete', 'post_delete'] signals_available = False try: @@ -36,10 +35,7 @@ except ImportError: # not put signals in here. Create your own namespace instead. _signals = Namespace() -pre_init = _signals.signal('pre_init') -post_init = _signals.signal('post_init') pre_save = _signals.signal('pre_save') -pre_save_post_validation = _signals.signal('pre_save_post_validation') post_save = _signals.signal('post_save') pre_delete = _signals.signal('pre_delete') post_delete = _signals.signal('post_delete') diff --git a/python-mongoengine.spec b/python-mongoengine.spec index b9c45ef..4eaba4d 100644 --- a/python-mongoengine.spec +++ b/python-mongoengine.spec @@ -5,7 +5,7 @@ %define srcname mongoengine Name: python-%{srcname} -Version: 0.8.4 +Version: 0.8.2 Release: 1%{?dist} Summary: A Python Document-Object Mapper for working with MongoDB diff --git a/setup.py b/setup.py index 85707d0..effb6f1 100644 --- a/setup.py +++ b/setup.py @@ -48,15 +48,17 @@ CLASSIFIERS = [ 'Topic :: Software Development :: Libraries :: Python Modules', ] -extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} +extra_opts = {} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6'] + extra_opts['packages'] = find_packages(exclude=('tests',)) if "test" in sys.argv or "nosetests" in sys.argv: - extra_opts['packages'] = find_packages() + extra_opts['packages'].append("tests") extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6', 'python-dateutil'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil'] + extra_opts['packages'] = find_packages(exclude=('tests',)) setup(name='mongoengine', version=VERSION, diff --git a/tests/benchmark.py b/tests/benchmark.py new file mode 100644 index 0000000..89439f3 --- /dev/null +++ b/tests/benchmark.py @@ -0,0 +1,90 @@ +from mongoengine import * +from timeit import repeat +import unittest + +conn_settings = { + 'db': 'mongomallard-test', +} + +connect(**conn_settings) + +def timeit(f, n=10000): + return min(repeat(f, repeat=3, number=n))/float(n) + +class BenchmarkTestCase(unittest.TestCase): + def setUp(self): + pass + + def test_basic(self): + class Book(Document): + name = StringField() + pages = IntField() + tags = ListField(StringField()) + is_published = BooleanField() + + Book.drop_collection() + + create_book = lambda: Book(name='Always be closing', pages=100, tags=['self-help', 'sales'], is_published=True) + print 'Doc initialization: %.3fus' % (timeit(create_book, 1000) * 10**6) + + b = create_book() + + print 'Doc getattr: %.3fus' % (timeit(lambda: b.name, 10000) * 10**6) + + print 'Doc setattr: %.3fus' % (timeit(lambda: setattr(b, 'name', 'New name'), 10000) * 10**6) + + print 'Doc to mongo: %.3fus' % (timeit(b.to_mongo, 1000) * 10**6) + + def save_book(): + b._mark_as_changed('name') + b._mark_as_changed('tags') + b.save() + + save_book() + son = b.to_mongo() + + print 'Load from SON: %.3fus' % (timeit(lambda: Book._from_son(son), 1000) * 10**6) + + print 'Save to database: %.3fus' % (timeit(save_book, 100) * 10**6) + + print 'Load from database: %.3fus' % (timeit(lambda: Book.objects[0], 100) * 10**6) + + def test_embedded(self): + class Contact(EmbeddedDocument): + name = StringField() + title = StringField() + address = StringField() + + class Company(Document): + name = StringField() + contacts = ListField(EmbeddedDocumentField(Contact)) + + Company.drop_collection() + + def get_company(): + return Company( + name='Elastic', + contacts=[ + Contact( + name='Contact %d' % x, + title='CEO', + address='Address %d' % x, + ) + for x in range(1000)] + ) + + def create_company(): + c = get_company() + c.save() + c.delete() + + print 'Save/delete big object to database: %.3fms' % (timeit(create_company, 10) * 10**3) + + c = get_company().save() + + print 'Serialize big object from database: %.3fms' % (timeit(c.to_mongo, 100) * 10**3) + print 'Load big object from database: %.3fms' % (timeit(lambda: Company.objects[0], 100) * 10**3) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/delta.py b/tests/document/delta.py index b4749f3..355717f 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -3,7 +3,6 @@ import sys sys.path[0:0] = [""] import unittest -from bson import SON from mongoengine import * from mongoengine.connection import get_db @@ -49,41 +48,42 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) + self.assertEqual(doc._get_changed_fields(), set(['string_field'])) self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) - doc._changed_fields = [] + doc._changed_fields = set() doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) + self.assertEqual(doc._get_changed_fields(), set(['int_field'])) self.assertEqual(doc._delta(), ({'int_field': 1}, {})) - doc._changed_fields = [] + doc._changed_fields = set() dict_value = {'hello': 'world', 'ping': 'pong'} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) - doc._changed_fields = [] + doc._changed_fields = set() list_value = ['1', 2, {'hello': 'world'}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) # Test unsetting - doc._changed_fields = [] + doc._changed_fields = set() doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) - doc._changed_fields = [] + doc._changed_fields = set() doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + @unittest.skip("not fully implemented") def test_delta_recursive(self): self.delta_recursive(Document, EmbeddedDocument) self.delta_recursive(DynamicDocument, EmbeddedDocument) @@ -110,7 +110,7 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) embedded_1 = Embedded() @@ -120,7 +120,7 @@ class DeltaTest(unittest.TestCase): embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ['embedded_field']) + self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) embedded_delta = { 'string_field': 'hello', @@ -137,7 +137,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.dict_field = {} self.assertEqual(doc._get_changed_fields(), - ['embedded_field.dict_field']) + set(['embedded_field.dict_field'])) self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) doc.save() @@ -146,7 +146,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = [] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + set(['embedded_field.list_field'])) self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) doc.save() @@ -161,7 +161,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = ['1', 2, embedded_2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + set(['embedded_field.list_field'])) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { @@ -193,7 +193,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'world' self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field.2.string_field']) + set(['embedded_field.list_field.2.string_field'])) self.assertEqual(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) self.assertEqual(doc._delta(), @@ -207,7 +207,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + set(['embedded_field.list_field'])) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { '_cls': 'Embedded', @@ -270,7 +270,7 @@ class DeltaTest(unittest.TestCase): doc.dict_field['Embedded'].string_field = 'Hello World' self.assertEqual(doc._get_changed_fields(), - ['dict_field.Embedded.string_field']) + set(['dict_field.Embedded.string_field'])) self.assertEqual(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {})) @@ -313,24 +313,29 @@ class DeltaTest(unittest.TestCase): self.circular_reference_deltas_2(DynamicDocument, Document) self.circular_reference_deltas_2(DynamicDocument, DynamicDocument) - def circular_reference_deltas_2(self, DocClass1, DocClass2, dbref=True): + def circular_reference_deltas_2(self, DocClass1, DocClass2): class Person(DocClass1): name = StringField() - owns = ListField(ReferenceField('Organization', dbref=dbref)) - employer = ReferenceField('Organization', dbref=dbref) + owns = ListField(ReferenceField('Organization')) + employer = ReferenceField('Organization') class Organization(DocClass2): name = StringField() - owner = ReferenceField('Person', dbref=dbref) - employees = ListField(ReferenceField('Person', dbref=dbref)) + owner = ReferenceField('Person') + employees = ListField(ReferenceField('Person')) Person.drop_collection() Organization.drop_collection() - person = Person(name="owner").save() - employee = Person(name="employee").save() - organization = Organization(name="company").save() + person = Person(name="owner") + person.save() + + employee = Person(name="employee") + employee.save() + + organization = Organization(name="company") + organization.save() person.owns.append(organization) organization.owner = person @@ -350,8 +355,6 @@ class DeltaTest(unittest.TestCase): self.assertEqual(o.owner, p) self.assertEqual(e.employer, o) - return person, organization, employee - def test_delta_db_field(self): self.delta_db_field(Document) self.delta_db_field(DynamicDocument) @@ -369,39 +372,39 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['db_string_field']) + self.assertEqual(doc._get_changed_fields(), set(['string_field'])) self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) - doc._changed_fields = [] + doc._changed_fields = set() doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['db_int_field']) + self.assertEqual(doc._get_changed_fields(), set(['int_field'])) self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) - doc._changed_fields = [] + doc._changed_fields = set() dict_value = {'hello': 'world', 'ping': 'pong'} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) - doc._changed_fields = [] + doc._changed_fields = set() list_value = ['1', 2, {'hello': 'world'}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) # Test unsetting - doc._changed_fields = [] + doc._changed_fields = set() doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) - doc._changed_fields = [] + doc._changed_fields = set() doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) # Test it saves that data @@ -413,13 +416,15 @@ class DeltaTest(unittest.TestCase): doc.dict_field = {'hello': 'world'} doc.list_field = ['1', 2, {'hello': 'world'}] doc.save() - doc = doc.reload(10) + #doc = doc.reload(10) + doc = doc.reload() self.assertEqual(doc.string_field, 'hello') self.assertEqual(doc.int_field, 1) self.assertEqual(doc.dict_field, {'hello': 'world'}) self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) + @unittest.skip("not fully implemented") def test_delta_recursive_db_field(self): self.delta_recursive_db_field(Document, EmbeddedDocument) self.delta_recursive_db_field(Document, DynamicEmbeddedDocument) @@ -447,7 +452,7 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._delta(), ({}, {})) embedded_1 = Embedded() @@ -457,7 +462,7 @@ class DeltaTest(unittest.TestCase): embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) + self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) embedded_delta = { 'db_string_field': 'hello', @@ -485,7 +490,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = [] self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field']) + set(['db_embedded_field.db_list_field'])) self.assertEqual(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) self.assertEqual(doc._delta(), @@ -603,6 +608,7 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) + @unittest.skip("DynamicDocument not implemented") def test_delta_for_dynamic_documents(self): class Person(DynamicDocument): name = StringField() @@ -611,13 +617,13 @@ class DeltaTest(unittest.TestCase): Person.drop_collection() p = Person(name="James", age=34) - self.assertEqual(p._delta(), ( - SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {})) + self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', + '_cls': 'Person'}, {})) p.doc = 123 del(p.doc) - self.assertEqual(p._delta(), ( - SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {})) + self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', + '_cls': 'Person'}, {'doc': 1})) p = Person() p.name = "Dean" @@ -629,15 +635,16 @@ class DeltaTest(unittest.TestCase): self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._delta(), ({'age': 24}, {})) - p = Person.objects(age=22).get() + p = self.Person.objects(age=22).get() p.age = 24 self.assertEqual(p.age, 24) self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._delta(), ({'age': 24}, {})) p.save() - self.assertEqual(1, Person.objects(age=24).count()) + self.assertEqual(1, self.Person.objects(age=24).count()) + @unittest.skip("DynamicDocument not implemented") def test_dynamic_delta(self): class Doc(DynamicDocument): @@ -683,36 +690,6 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._delta(), ({}, {'list_field': 1})) - def test_delta_with_dbref_true(self): - person, organization, employee = self.circular_reference_deltas_2(Document, Document, True) - employee.name = 'test' - - self.assertEqual(organization._get_changed_fields(), []) - - updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertEqual({}, updates) - - organization.employees.append(person) - updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertTrue('employees' in updates) - - def test_delta_with_dbref_false(self): - person, organization, employee = self.circular_reference_deltas_2(Document, Document, False) - employee.name = 'test' - - self.assertEqual(organization._get_changed_fields(), []) - - updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertEqual({}, updates) - - organization.employees.append(person) - updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertTrue('employees' in updates) - if __name__ == '__main__': unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 6263e68..05870ee 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -8,6 +8,7 @@ from mongoengine.connection import get_db __all__ = ("DynamicTest", ) +@unittest.skip("DynamicDocument not implemented") class DynamicTest(unittest.TestCase): def setUp(self): diff --git a/tests/document/indexes.py b/tests/document/indexes.py index ccf8463..49fd7cb 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -156,25 +156,6 @@ class IndexesTest(unittest.TestCase): self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}], A._meta['index_specs']) - def test_index_no_cls(self): - """Ensure index specs are inhertited correctly""" - - class A(Document): - title = StringField() - meta = { - 'indexes': [ - {'fields': ('title',), 'cls': False}, - ], - 'allow_inheritance': True, - 'index_cls': False - } - - self.assertEqual([('title', 1)], A._meta['index_specs'][0]['fields']) - A._get_collection().drop_indexes() - A.ensure_indexes() - info = A._get_collection().index_information() - self.assertEqual(len(info.keys()), 2) - def test_build_index_spec_is_not_destructive(self): class MyDoc(Document): @@ -651,6 +632,7 @@ class IndexesTest(unittest.TestCase): pass Customer.drop_collection() + @unittest.skip("behavior differs") def test_unique_and_primary(self): """If you set a field as primary, then unexpected behaviour can occur. You won't create a duplicate but you will update an existing document. diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 5a48f75..d311538 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -182,10 +182,10 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(['age', 'id', 'name', 'salary'], sorted(Employee._fields.keys())) - self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), - ['_cls', 'name', 'age']) - self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ['_cls', 'name', 'age', 'salary']) + self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), + set(['_cls', 'name', 'age'])) + self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), + set(['_cls', 'name', 'age', 'salary'])) self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) diff --git a/tests/document/instance.py b/tests/document/instance.py index a61c439..80a6130 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -10,8 +10,7 @@ import uuid from datetime import datetime from bson import DBRef -from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, - PickleDyanmicEmbedded, PickleDynamicTest) +from tests.fixtures import PickleEmbedded, PickleTest, PickleSignalsTest from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, @@ -391,24 +390,25 @@ class InstanceTest(unittest.TestCase): doc.embedded_field = embedded_1 doc.save() - doc = doc.reload(10) + doc = doc.reload() doc.list_field.append(1) doc.dict_field['woot'] = "woot" doc.embedded_field.list_field.append(1) doc.embedded_field.dict_field['woot'] = "woot" - self.assertEqual(doc._get_changed_fields(), [ + self.assertEqual(doc._get_changed_fields(), set([ 'list_field', 'dict_field', 'embedded_field.list_field', - 'embedded_field.dict_field']) + 'embedded_field.dict_field'])) doc.save() - doc = doc.reload(10) - self.assertEqual(doc._get_changed_fields(), []) + doc = doc.reload() + self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(len(doc.list_field), 4) self.assertEqual(len(doc.dict_field), 2) self.assertEqual(len(doc.embedded_field.list_field), 4) self.assertEqual(len(doc.embedded_field.dict_field), 2) + @unittest.skip("not implemented") def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. """ @@ -439,17 +439,10 @@ class InstanceTest(unittest.TestCase): class Employee(Person): salary = IntField() - self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), - ['_cls', 'name', 'age']) - self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ['_cls', 'name', 'age', 'salary']) - - def test_embedded_document_to_mongo_id(self): - class SubDoc(EmbeddedDocument): - id = StringField(required=True) - - sub_doc = SubDoc(id="abc") - self.assertEqual(sub_doc.to_mongo().keys(), ['id']) + self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), + set(['_cls', 'name', 'age'])) + self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), + set(['_cls', 'name', 'age', 'salary'])) def test_embedded_document(self): """Ensure that embedded documents are set up correctly. @@ -460,6 +453,7 @@ class InstanceTest(unittest.TestCase): self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) + @unittest.skip("not implemented") def test_embedded_document_instance(self): """Ensure that embedded documents can reference parent instance """ @@ -468,6 +462,7 @@ class InstanceTest(unittest.TestCase): class Doc(Document): embedded_field = EmbeddedDocumentField(Embedded) + meta = { 'cascade': True } Doc.drop_collection() Doc(embedded_field=Embedded(string="Hi")).save() @@ -475,6 +470,7 @@ class InstanceTest(unittest.TestCase): doc = Doc.objects.get() self.assertEqual(doc, doc.embedded_field._instance) + @unittest.skip("not implemented") def test_embedded_document_complex_instance(self): """Ensure that embedded documents in complex fields can reference parent instance""" @@ -631,6 +627,7 @@ class InstanceTest(unittest.TestCase): p0.name = 'wpjunior' p0.save() + @unittest.skip("FileField not implemented") def test_save_max_recursion_not_hit_with_file_field(self): class Foo(Document): @@ -779,6 +776,7 @@ class InstanceTest(unittest.TestCase): p1.reload() self.assertEqual(p1.name, p.parent.name) + @unittest.skip("not implemented") def test_update(self): """Ensure that an existing document is updated instead of be overwritten.""" @@ -893,7 +891,6 @@ class InstanceTest(unittest.TestCase): reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) - decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) @@ -1062,9 +1059,9 @@ class InstanceTest(unittest.TestCase): user = User.objects.first() # Even if stored as ObjectId's internally mongoengine uses DBRefs # As ObjectId's aren't automatically derefenced - self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) + #self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) self.assertTrue(isinstance(user.orgs[0], Organization)) - self.assertTrue(isinstance(user._data['orgs'][0], Organization)) + #self.assertTrue(isinstance(user._data['orgs'][0], Organization)) # Changing a value with query_counter() as q: @@ -1144,6 +1141,7 @@ class InstanceTest(unittest.TestCase): foo.save() self.assertEqual(1, q) + @unittest.skip("not implemented") def test_save_only_changed_fields_recursive(self): """Ensure save only sets / unsets changed fields """ @@ -1441,8 +1439,8 @@ class InstanceTest(unittest.TestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertTrue(isinstance(post_obj._data['author'], - bson.DBRef)) + #self.assertTrue(isinstance(post_obj._data['author'], + # bson.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') @@ -1466,6 +1464,7 @@ class InstanceTest(unittest.TestCase): self.assertRaises(InvalidDocumentError, throw_invalid_document_error) + @unittest.skip("not implemented") def test_invalid_son(self): """Raise an error if loading invalid data""" class Occurrence(EmbeddedDocument): @@ -1809,6 +1808,7 @@ class InstanceTest(unittest.TestCase): self.assertTrue(u1 in all_user_set) + @unittest.skip("not implemented") def test_picklable(self): pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) @@ -1835,29 +1835,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(pickle_doc.string, "Two") self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) - def test_dynamic_document_pickle(self): - - pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2']) - pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") - pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved - - pickle_doc.save() - - pickled_doc = pickle.dumps(pickle_doc) - resurrected = pickle.loads(pickled_doc) - - self.assertEqual(resurrected, pickle_doc) - self.assertEqual(resurrected._fields_ordered, - pickle_doc._fields_ordered) - self.assertEqual(resurrected._dynamic_fields.keys(), - pickle_doc._dynamic_fields.keys()) - - self.assertEqual(resurrected.embedded, pickle_doc.embedded) - self.assertEqual(resurrected.embedded._fields_ordered, - pickle_doc.embedded._fields_ordered) - self.assertEqual(resurrected.embedded._dynamic_fields.keys(), - pickle_doc.embedded._dynamic_fields.keys()) - + @unittest.skip("not implemented") def test_picklable_on_signals(self): pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleEmbedded() @@ -1918,6 +1896,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(Doc.objects(archived=False).count(), 1) + @unittest.skip("DynamicDocument not implemented") def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs""" class Doc(DynamicDocument): @@ -2057,6 +2036,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual('testdb-1', B._meta.get('db_alias')) + @unittest.skip("not implemented") def test_db_ref_usage(self): """ DB Ref usage in dict_fields""" @@ -2135,6 +2115,7 @@ class InstanceTest(unittest.TestCase): })]), "1,2") + @unittest.skip("not implemented") def test_switch_db_instance(self): register_connection('testdb-1', 'mongoenginetest2') @@ -2206,9 +2187,10 @@ class InstanceTest(unittest.TestCase): user = User.objects.first() self.assertEqual("Ross", user.username) self.assertEqual(True, user.foo) - self.assertEqual("Bar", user._data["foo"]) - self.assertEqual([1, 2, 3], user._data["data"]) + self.assertEqual("Bar", user._db_data["foo"]) + self.assertEqual([1, 2, 3], user._db_data["data"]) + @unittest.skip("DynamicDocument not implemented") def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): @@ -2225,6 +2207,7 @@ class InstanceTest(unittest.TestCase): one = Doc.objects.filter(**{'hello world': 1}).count() self.assertEqual(1, one) + @unittest.skip("not implemented") def test_shard_key(self): class LogEntry(Document): machine = StringField() @@ -2248,6 +2231,7 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + @unittest.skip("not implemented") def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True) @@ -2271,6 +2255,7 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + @unittest.skip("not implemented") def test_kwargs_simple(self): class Embedded(EmbeddedDocument): @@ -2285,8 +2270,9 @@ class InstanceTest(unittest.TestCase): "doc": {"name": "embedded doc"}}) self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) + @unittest.skip("not implemented") def test_kwargs_complex(self): class Embedded(EmbeddedDocument): @@ -2304,8 +2290,9 @@ class InstanceTest(unittest.TestCase): {"name": "embedded doc2"}]}) self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) + @unittest.skip("not implemented") def test_positional_creation(self): """Ensure that document may be created using positional arguments. """ @@ -2313,6 +2300,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 42) + @unittest.skip("not implemented") def test_mixed_creation(self): """Ensure that document may be created using mixed arguments. """ @@ -2320,16 +2308,6 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 42) - def test_mixed_creation_dynamic(self): - """Ensure that document may be created using mixed arguments. - """ - class Person(DynamicDocument): - name = StringField() - - person = Person("Test User", age=42) - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 42) - def test_bad_mixed_creation(self): """Ensure that document gives correct error when duplicating arguments """ @@ -2348,8 +2326,8 @@ class InstanceTest(unittest.TestCase): Person(name="Harry Potter").save() person = Person.objects.first() - self.assertTrue('id' in person._data.keys()) - self.assertEqual(person._data.get('id'), person.id) + self.assertTrue('id' in person.to_dict().keys()) + self.assertEqual(person.to_dict().get('id'), person.id) def test_complex_nesting_document_and_embedded_document(self): diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index 2b5d9a0..1f2d5c8 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -31,10 +31,6 @@ class TestJson(unittest.TestCase): doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) - doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) - expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}""" - self.assertEqual(doc_json, expected_json) - self.assertEqual(doc, Doc.from_json(doc.to_json())) def test_json_complex(self): @@ -62,7 +58,6 @@ class TestJson(unittest.TestCase): reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) - decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) diff --git a/tests/document/validation.py b/tests/document/validation.py index d3f3fd7..4637dee 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -53,11 +53,12 @@ class ValidatorErrorTest(unittest.TestCase): self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") def test_model_validation(self): - class User(Document): username = StringField(primary_key=True) name = StringField(required=True) + User.drop_collection() + try: User().validate() except ValidationError, e: @@ -128,18 +129,13 @@ class ValidatorErrorTest(unittest.TestCase): Doc(id="test", e=SubDoc(val=15)).save() doc = Doc.objects.first() - keys = doc._data.keys() + keys = doc.to_dict().keys() self.assertEqual(2, len(keys)) self.assertTrue('e' in keys) self.assertTrue('id' in keys) - doc.e.val = "OK" - try: - doc.save() - except ValidationError, e: - self.assertTrue("Doc:test" in e.message) - self.assertEqual(e.to_dict(), { - "e": {'val': 'OK could not be converted to int'}}) + with self.assertRaises(ValueError): + doc.e.val = "OK" if __name__ == '__main__': diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 8791781..1eea2ac 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -56,10 +56,11 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -88,10 +89,11 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -123,10 +125,12 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -157,10 +161,12 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + data = person.to_dict() + + self.assertEqual(data['name'], person.name) + self.assertEqual(data['age'], person.age) + self.assertEqual(data['userid'], person.userid) + self.assertEqual(data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -266,17 +272,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) - def test_long_ne_operator(self): - class TestDocument(Document): - long_fld = LongField() - - TestDocument.drop_collection() - - TestDocument(long_fld=None).save() - TestDocument(long_fld=1).save() - - self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) - def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to string fields. """ @@ -347,25 +342,8 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, person.validate) person.age = 120 self.assertRaises(ValidationError, person.validate) - person.age = 'ten' - self.assertRaises(ValidationError, person.validate) - - def test_long_validation(self): - """Ensure that invalid values cannot be assigned to long fields. - """ - class TestDocument(Document): - value = LongField(min_value=0, max_value=110) - - doc = TestDocument() - doc.value = 50 - doc.validate() - - doc.value = -1 - self.assertRaises(ValidationError, doc.validate) - doc.age = 120 - self.assertRaises(ValidationError, doc.validate) - doc.age = 'ten' - self.assertRaises(ValidationError, doc.validate) + with self.assertRaises(ValueError): + person.age = 'ten' def test_float_validation(self): """Ensure that invalid values cannot be assigned to float fields. @@ -384,69 +362,6 @@ class FieldTest(unittest.TestCase): person.height = 4.0 self.assertRaises(ValidationError, person.validate) - def test_decimal_validation(self): - """Ensure that invalid values cannot be assigned to decimal fields. - """ - class Person(Document): - height = DecimalField(min_value=Decimal('0.1'), - max_value=Decimal('3.5')) - - Person.drop_collection() - - Person(height=Decimal('1.89')).save() - person = Person.objects.first() - self.assertEqual(person.height, Decimal('1.89')) - - person.height = '2.0' - person.save() - person.height = 0.01 - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('0.01') - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('4.0') - self.assertRaises(ValidationError, person.validate) - - Person.drop_collection() - - def test_decimal_comparison(self): - - class Person(Document): - money = DecimalField() - - Person.drop_collection() - - Person(money=6).save() - Person(money=8).save() - Person(money=10).save() - - self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) - self.assertEqual(2, Person.objects(money__gt=7).count()) - self.assertEqual(2, Person.objects(money__gt="7").count()) - - def test_decimal_storage(self): - class Person(Document): - btc = DecimalField(precision=4) - - Person.drop_collection() - Person(btc=10).save() - Person(btc=10.1).save() - Person(btc=10.11).save() - Person(btc="10.111").save() - Person(btc=Decimal("10.1111")).save() - Person(btc=Decimal("10.11111")).save() - - # How its stored - expected = [{'btc': 10.0}, {'btc': 10.1}, {'btc': 10.11}, - {'btc': 10.111}, {'btc': 10.1111}, {'btc': 10.1111}] - actual = list(Person.objects.exclude('id').as_pymongo()) - self.assertEqual(expected, actual) - - # How it comes out locally - expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), - Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] - actual = list(Person.objects().scalar('btc')) - self.assertEqual(expected, actual) - def test_boolean_validation(self): """Ensure that invalid values cannot be assigned to boolean fields. """ @@ -532,10 +447,10 @@ class FieldTest(unittest.TestCase): log.time = datetime.datetime.now().isoformat('T') log.validate() - log.time = -1 - self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' - self.assertRaises(ValidationError, log.validate) + #log.time = -1 + #self.assertRaises(ValidationError, log.validate) + #log.time = 'ABC' + #self.assertRaises(ValidationError, log.validate) def test_datetime_tz_aware_mark_as_changed(self): from mongoengine import connection @@ -556,7 +471,7 @@ class FieldTest(unittest.TestCase): log = LogEntry.objects.first() log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(['time'], log._changed_fields) + self.assertEqual(set(['time']), log._changed_fields) def test_datetime(self): """Tests showing pymongo datetime fields handling of microseconds. @@ -791,8 +706,8 @@ class FieldTest(unittest.TestCase): post = BlogPost(content='Went for a walk today...') post.validate() - post.tags = 'fun' - self.assertRaises(ValidationError, post.validate) + #post.tags = 'fun' + #self.assertRaises(ValidationError, post.validate) post.tags = [1, 2] self.assertRaises(ValidationError, post.validate) @@ -903,11 +818,11 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() post = BlogPost() - post.info = 'my post' - self.assertRaises(ValidationError, post.validate) + #post.info = 'my post' + #self.assertRaises(ValidationError, post.validate) - post.info = {'title': 'test'} - self.assertRaises(ValidationError, post.validate) + #post.info = {'title': 'test'} + #self.assertRaises(ValidationError, post.validate) post.info = ['test'] post.save() @@ -956,14 +871,12 @@ class FieldTest(unittest.TestCase): e.mapping = [1] e.save() - def create_invalid_mapping(): + with self.assertRaises(ValueError): e.mapping = ["abc"] - e.save() - - self.assertRaises(ValidationError, create_invalid_mapping) Simple.drop_collection() + @unittest.skip("different behavior") def test_list_field_rejects_strings(self): """Strings aren't valid list field data types""" @@ -1008,7 +921,7 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() e = Simple().save() e.mapping = [] - self.assertEqual([], e._changed_fields) + self.assertEqual(set([]), e._changed_fields) class Simple(Document): mapping = DictField() @@ -1016,34 +929,9 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() e = Simple().save() e.mapping = {} - self.assertEqual([], e._changed_fields) - - def test_slice_marks_field_as_changed(self): - - class Simple(Document): - widgets = ListField() - - simple = Simple(widgets=[1, 2, 3, 4]).save() - simple.widgets[:3] = [] - self.assertEqual(['widgets'], simple._changed_fields) - simple.save() - - simple = simple.reload() - self.assertEqual(simple.widgets, [4]) - - def test_del_slice_marks_field_as_changed(self): - - class Simple(Document): - widgets = ListField() - - simple = Simple(widgets=[1, 2, 3, 4]).save() - del simple.widgets[:3] - self.assertEqual(['widgets'], simple._changed_fields) - simple.save() - - simple = simple.reload() - self.assertEqual(simple.widgets, [4]) + self.assertEqual(set([]), e._changed_fields) + @unittest.skip("complex types not implemented") def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" @@ -1100,11 +988,11 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() post = BlogPost() - post.info = 'my post' - self.assertRaises(ValidationError, post.validate) + #post.info = 'my post' + #self.assertRaises(ValidationError, post.validate) - post.info = ['test', 'test'] - self.assertRaises(ValidationError, post.validate) + #post.info = ['test', 'test'] + #self.assertRaises(ValidationError, post.validate) post.info = {'$title': 'test'} self.assertRaises(ValidationError, post.validate) @@ -1162,6 +1050,7 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() + @unittest.skip("complex types not implemented") def test_dictfield_complex(self): """Ensure that the dict field can handle the complex types.""" @@ -1979,6 +1868,7 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + @unittest.skip("not implemented") def test_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -2031,6 +1921,7 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + @unittest.skip("not implemented") def test_simple_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -2110,6 +2001,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) + @unittest.skip("SequenceField not implemented") def test_sequence_field(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2135,6 +2027,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(c['next'], 1000) + @unittest.skip("SequenceField not implemented") def test_sequence_field_get_next_value(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2166,6 +2059,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(Person.id.get_next_value(), '1') + @unittest.skip("SequenceField not implemented") def test_sequence_field_sequence_name(self): class Person(Document): id = SequenceField(primary_key=True, sequence_name='jelly') @@ -2190,6 +2084,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) self.assertEqual(c['next'], 1000) + @unittest.skip("SequenceField not implemented") def test_multiple_sequence_fields(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2222,6 +2117,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) self.assertEqual(c['next'], 999) + @unittest.skip("SequenceField not implemented") def test_sequence_fields_reload(self): class Animal(Document): counter = SequenceField() @@ -2247,6 +2143,7 @@ class FieldTest(unittest.TestCase): a.reload() self.assertEqual(a.counter, 2) + @unittest.skip("SequenceField not implemented") def test_multiple_sequence_fields_on_docs(self): class Animal(Document): @@ -2281,6 +2178,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) self.assertEqual(c['next'], 10) + @unittest.skip("SequenceField not implemented") def test_sequence_field_value_decorator(self): class Person(Document): id = SequenceField(primary_key=True, value_decorator=str) @@ -2302,6 +2200,7 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) + @unittest.skip("SequenceField not implemented") def test_embedded_sequence_field(self): class Comment(EmbeddedDocument): id = SequenceField() @@ -2474,78 +2373,6 @@ class FieldTest(unittest.TestCase): user = User(email='me@example.com') self.assertTrue(user.validate() is None) - def test_tuples_as_tuples(self): - """ - Ensure that tuples remain tuples when they are - inside a ComplexBaseField - """ - from mongoengine.base import BaseField - - class EnumField(BaseField): - - def __init__(self, **kwargs): - super(EnumField, self).__init__(**kwargs) - - def to_mongo(self, value): - return value - - def to_python(self, value): - return tuple(value) - - class TestDoc(Document): - items = ListField(EnumField()) - - TestDoc.drop_collection() - tuples = [(100, 'Testing')] - doc = TestDoc() - doc.items = tuples - doc.save() - x = TestDoc.objects().get() - self.assertTrue(x is not None) - self.assertTrue(len(x.items) == 1) - self.assertTrue(tuple(x.items[0]) in tuples) - self.assertTrue(x.items[0] in tuples) - - def test_dynamic_fields_class(self): - - class Doc2(Document): - field_1 = StringField(db_field='f') - - class Doc(Document): - my_id = IntField(required=True, unique=True, primary_key=True) - embed_me = DynamicField(db_field='e') - field_x = StringField(db_field='x') - - Doc.drop_collection() - Doc2.drop_collection() - - doc2 = Doc2(field_1="hello") - doc = Doc(my_id=1, embed_me=doc2, field_x="x") - self.assertRaises(OperationError, doc.save) - - doc2.save() - doc.save() - - doc = Doc.objects.get() - self.assertEqual(doc.embed_me.field_1, "hello") - - def test_dynamic_fields_embedded_class(self): - - class Embed(EmbeddedDocument): - field_1 = StringField(db_field='f') - - class Doc(Document): - my_id = IntField(required=True, unique=True, primary_key=True) - embed_me = DynamicField(db_field='e') - field_x = StringField(db_field='x') - - Doc.drop_collection() - - Doc(my_id=1, embed_me=Embed(field_1="hello"), field_x="x").save() - - doc = Doc.objects.get() - self.assertEqual(doc.embed_me.field_1, "hello") - if __name__ == '__main__': unittest.main() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index ba601de..9ad3fdd 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -24,6 +24,7 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') +@unittest.skip("FileField not implemented") class FileTest(unittest.TestCase): def setUp(self): @@ -53,12 +54,11 @@ class FileTest(unittest.TestCase): content_type = 'text/plain' putfile = PutFile() - putfile.the_file.put(text, content_type=content_type, filename="hello") + putfile.the_file.put(text, content_type=content_type) putfile.save() result = PutFile.objects.first() self.assertTrue(putfile == result) - self.assertEqual("%s" % result.the_file, "") self.assertEqual(result.the_file.read(), text) self.assertEqual(result.the_file.content_type, content_type) result.the_file.delete() # Remove file from GridFS @@ -456,31 +456,5 @@ class FileTest(unittest.TestCase): self.assertEqual(1, TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count()) - def test_complex_field_filefield(self): - """Ensure you can add meta data to file""" - - class Animal(Document): - genus = StringField() - family = StringField() - photos = ListField(FileField()) - - Animal.drop_collection() - marmot = Animal(genus='Marmota', family='Sciuridae') - - marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk - - photos_field = marmot._fields['photos'].field - new_proxy = photos_field.get_proxy_obj('photos', marmot) - new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') - marmot_photo.close() - - marmot.photos.append(new_proxy) - marmot.save() - - marmot = Animal.objects.get() - self.assertEqual(marmot.photos[0].content_type, 'image/jpeg') - self.assertEqual(marmot.photos[0].foo, 'bar') - self.assertEqual(marmot.photos[0].get().length, 8313) - if __name__ == '__main__': unittest.main() diff --git a/tests/fields/geo.py b/tests/fields/geo.py index 31ded26..81f8a69 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -10,6 +10,7 @@ from mongoengine.connection import get_db __all__ = ("GeoFieldTest", ) +@unittest.skip("geo fields not implemented") class GeoFieldTest(unittest.TestCase): def setUp(self): diff --git a/tests/fixtures.py b/tests/fixtures.py index f1344d7..e207044 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -17,14 +17,6 @@ class PickleTest(Document): photo = FileField() -class PickleDyanmicEmbedded(DynamicEmbeddedDocument): - date = DateTimeField(default=datetime.now) - - -class PickleDynamicTest(DynamicDocument): - number = IntField() - - class PickleSignalsTest(Document): number = IntField() string = StringField(choices=(('One', '1'), ('Two', '2'))) diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py index 6fc83e0..bff50c3 100644 --- a/tests/migration/__init__.py +++ b/tests/migration/__init__.py @@ -1,5 +1,4 @@ from convert_to_new_inheritance_model import * -from decimalfield_as_float import * from refrencefield_dbref_to_object_id import * from turn_off_inheritance import * from uuidfield_to_binary import * diff --git a/tests/migration/decimalfield_as_float.py b/tests/migration/decimalfield_as_float.py deleted file mode 100644 index 3903c91..0000000 --- a/tests/migration/decimalfield_as_float.py +++ /dev/null @@ -1,50 +0,0 @@ - # -*- coding: utf-8 -*- -import unittest -import decimal -from decimal import Decimal - -from mongoengine import Document, connect -from mongoengine.connection import get_db -from mongoengine.fields import StringField, DecimalField, ListField - -__all__ = ('ConvertDecimalField', ) - - -class ConvertDecimalField(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def test_how_to_convert_decimal_fields(self): - """Demonstrates migrating from 0.7 to 0.8 - """ - - # 1. Old definition - using dbrefs - class Person(Document): - name = StringField() - money = DecimalField(force_string=True) - monies = ListField(DecimalField(force_string=True)) - - Person.drop_collection() - Person(name="Wilson Jr", money=Decimal("2.50"), - monies=[Decimal("2.10"), Decimal("5.00")]).save() - - # 2. Start the migration by changing the schema - # Change DecimalField - add precision and rounding settings - class Person(Document): - name = StringField() - money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) - monies = ListField(DecimalField(precision=2, - rounding=decimal.ROUND_HALF_UP)) - - # 3. Loop all the objects and mark parent as changed - for p in Person.objects: - p._mark_as_changed('money') - p._mark_as_changed('monies') - p.save() - - # 4. Confirmation of the fix! - wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] - self.assertTrue(isinstance(wilson['money'], float)) - self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index 7d66d26..2bdfce1 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -162,10 +162,6 @@ class OnlyExcludeAllTest(unittest.TestCase): self.assertEqual(obj.name, person.name) self.assertEqual(obj.age, person.age) - obj = self.Person.objects.only(*('id', 'name',)).get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, None) - # Check polymorphism still works class Employee(self.Person): salary = IntField(db_field='wage') @@ -399,28 +395,5 @@ class OnlyExcludeAllTest(unittest.TestCase): numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) - - def test_exclude_from_subclasses_docs(self): - - class Base(Document): - username = StringField() - - meta = {'allow_inheritance': True} - - class Anon(Base): - anon = BooleanField() - - class User(Base): - password = StringField() - wibble = StringField() - - Base.drop_collection() - User(username="mongodb", password="secret").save() - - user = Base.objects().exclude("password", "wibble").first() - self.assertEqual(user.password, None) - - self.assertRaises(LookUpError, Base.objects.exclude, "made_up") - if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index f564896..7e1c5df 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -8,6 +8,7 @@ from mongoengine import * __all__ = ("GeoQueriesTest",) +@unittest.skip("geo queries not implemented") class GeoQueriesTest(unittest.TestCase): def setUp(self): diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index b4bcf2a..1f1051b 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -30,17 +30,12 @@ class QuerySetTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') - class PersonMeta(EmbeddedDocument): - weight = IntField() - class Person(Document): name = StringField() age = IntField() - person_meta = EmbeddedDocumentField(PersonMeta) meta = {'allow_inheritance': True} Person.drop_collection() - self.PersonMeta = PersonMeta self.Person = Person def test_initialisation(self): @@ -782,10 +777,10 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.employees.append(p2) # Dereferences + fresh_o1.employees.append(p2) fresh_o1.save(cascade=False) # Saves - self.assertEqual(q, 3) + self.assertEqual(q, 2) def test_slave_okay(self): """Ensures that a query can take slave_okay syntax @@ -1497,6 +1492,9 @@ class QuerySetTest(unittest.TestCase): def test_pull_nested(self): + class User(Document): + name = StringField() + class Collaborator(EmbeddedDocument): user = StringField() @@ -1511,7 +1509,8 @@ class QuerySetTest(unittest.TestCase): Site.drop_collection() c = Collaborator(user='Esteban') - s = Site(name="test", collaborators=[c]).save() + s = Site(name="test", collaborators=[c]) + s.save() Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') self.assertEqual(Site.objects.first().collaborators, []) @@ -1521,71 +1520,6 @@ class QuerySetTest(unittest.TestCase): self.assertRaises(InvalidQueryError, pull_all) - def test_pull_from_nested_embedded(self): - - class User(EmbeddedDocument): - name = StringField() - - def __unicode__(self): - return '%s' % self.name - - class Collaborator(EmbeddedDocument): - helpful = ListField(EmbeddedDocumentField(User)) - unhelpful = ListField(EmbeddedDocumentField(User)) - - class Site(Document): - name = StringField(max_length=75, unique=True, required=True) - collaborators = EmbeddedDocumentField(Collaborator) - - - Site.drop_collection() - - c = User(name='Esteban') - f = User(name='Frank') - s = Site(name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])).save() - - Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) - self.assertEqual(Site.objects.first().collaborators['helpful'], []) - - Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'name': 'Frank'}) - self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) - - def pull_all(): - Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__name=['Ross']) - - self.assertRaises(InvalidQueryError, pull_all) - - def test_pull_from_nested_mapfield(self): - - class Collaborator(EmbeddedDocument): - user = StringField() - - def __unicode__(self): - return '%s' % self.user - - class Site(Document): - name = StringField(max_length=75, unique=True, required=True) - collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator))) - - - Site.drop_collection() - - c = Collaborator(user='Esteban') - f = Collaborator(user='Frank') - s = Site(name="test", collaborators={'helpful':[c],'unhelpful':[f]}) - s.save() - - Site.objects(id=s.id).update_one(pull__collaborators__helpful__user='Esteban') - self.assertEqual(Site.objects.first().collaborators['helpful'], []) - - Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'user':'Frank'}) - self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) - - def pull_all(): - Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__user=['Ross']) - - self.assertRaises(InvalidQueryError, pull_all) - def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -2274,19 +2208,6 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.average('age')), avg) - # dot notation - self.Person(name='person meta', person_meta=self.PersonMeta(weight=0)).save() - self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), 0) - - for i, weight in enumerate(ages): - self.Person(name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() - - self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), avg) - - self.Person(name='test meta none').save() - self.assertEqual(int(self.Person.objects.average('person_meta.weight')), avg) - - def test_sum(self): """Ensure that field can be summed over correctly. """ @@ -2299,153 +2220,6 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) - for i, age in enumerate(ages): - self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save() - - self.assertEqual(int(self.Person.objects.sum('person_meta.weight')), sum(ages)) - - self.Person(name='weightless person').save() - self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) - - def test_embedded_average(self): - class Pay(EmbeddedDocument): - value = DecimalField() - - class Doc(Document): - name = StringField() - pay = EmbeddedDocumentField( - Pay) - - Doc.drop_collection() - - Doc(name=u"Wilson Junior", - pay=Pay(value=150)).save() - - Doc(name=u"Isabella Luanna", - pay=Pay(value=530)).save() - - Doc(name=u"Tayza mariana", - pay=Pay(value=165)).save() - - Doc(name=u"Eliana Costa", - pay=Pay(value=115)).save() - - self.assertEqual( - Doc.objects.average('pay.value'), - 240) - - def test_embedded_array_average(self): - class Pay(EmbeddedDocument): - values = ListField(DecimalField()) - - class Doc(Document): - name = StringField() - pay = EmbeddedDocumentField( - Pay) - - Doc.drop_collection() - - Doc(name=u"Wilson Junior", - pay=Pay(values=[150, 100])).save() - - Doc(name=u"Isabella Luanna", - pay=Pay(values=[530, 100])).save() - - Doc(name=u"Tayza mariana", - pay=Pay(values=[165, 100])).save() - - Doc(name=u"Eliana Costa", - pay=Pay(values=[115, 100])).save() - - self.assertEqual( - Doc.objects.average('pay.values'), - 170) - - def test_array_average(self): - class Doc(Document): - values = ListField(DecimalField()) - - Doc.drop_collection() - - Doc(values=[150, 100]).save() - Doc(values=[530, 100]).save() - Doc(values=[165, 100]).save() - Doc(values=[115, 100]).save() - - self.assertEqual( - Doc.objects.average('values'), - 170) - - def test_embedded_sum(self): - class Pay(EmbeddedDocument): - value = DecimalField() - - class Doc(Document): - name = StringField() - pay = EmbeddedDocumentField( - Pay) - - Doc.drop_collection() - - Doc(name=u"Wilson Junior", - pay=Pay(value=150)).save() - - Doc(name=u"Isabella Luanna", - pay=Pay(value=530)).save() - - Doc(name=u"Tayza mariana", - pay=Pay(value=165)).save() - - Doc(name=u"Eliana Costa", - pay=Pay(value=115)).save() - - self.assertEqual( - Doc.objects.sum('pay.value'), - 960) - - - def test_embedded_array_sum(self): - class Pay(EmbeddedDocument): - values = ListField(DecimalField()) - - class Doc(Document): - name = StringField() - pay = EmbeddedDocumentField( - Pay) - - Doc.drop_collection() - - Doc(name=u"Wilson Junior", - pay=Pay(values=[150, 100])).save() - - Doc(name=u"Isabella Luanna", - pay=Pay(values=[530, 100])).save() - - Doc(name=u"Tayza mariana", - pay=Pay(values=[165, 100])).save() - - Doc(name=u"Eliana Costa", - pay=Pay(values=[115, 100])).save() - - self.assertEqual( - Doc.objects.sum('pay.values'), - 1360) - - def test_array_sum(self): - class Doc(Document): - values = ListField(DecimalField()) - - Doc.drop_collection() - - Doc(values=[150, 100]).save() - Doc(values=[530, 100]).save() - Doc(values=[165, 100]).save() - Doc(values=[115, 100]).save() - - self.assertEqual( - Doc.objects.sum('values'), - 1360) - def test_distinct(self): """Ensure that the QuerySet.distinct method works. """ @@ -3148,19 +2922,6 @@ class QuerySetTest(unittest.TestCase): (u'Wilson Jr', 19, u'Corumba-GO'), (u'Gabriel Falcao', 23, u'New York')]) - def test_scalar_decimal(self): - from decimal import Decimal - class Person(Document): - name = StringField() - rating = DecimalField() - - Person.drop_collection() - Person(name="Wilson Jr", rating=Decimal('1.0')).save() - - ulist = list(Person.objects.scalar('name', 'rating')) - self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))]) - - def test_scalar_reference_field(self): class State(Document): name = StringField() @@ -3360,13 +3121,6 @@ class QuerySetTest(unittest.TestCase): Test.objects(test='foo').update_one(upsert=True, set__test='foo') self.assertTrue('_cls' in Test._collection.find_one()) - def test_update_upsert_looks_like_a_digit(self): - class MyDoc(DynamicDocument): - pass - MyDoc.drop_collection() - self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1)) - self.assertEqual(MyDoc.objects.get()['47'], 1) - def test_read_preference(self): class Bar(Document): pass @@ -3395,7 +3149,7 @@ class QuerySetTest(unittest.TestCase): Doc(string="Bye", embedded_field=Embedded(string="Bye")).save() Doc().save() - json_data = Doc.objects.to_json(sort_keys=True, separators=(',', ':')) + json_data = Doc.objects.to_json() doc_objects = list(Doc.objects) self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) @@ -3423,7 +3177,6 @@ class QuerySetTest(unittest.TestCase): objectid_field = ObjectIdField(default=ObjectId) reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) - decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) @@ -3454,33 +3207,25 @@ class QuerySetTest(unittest.TestCase): id = ObjectIdField('_id') name = StringField() age = IntField() - price = DecimalField() User.drop_collection() - User(name="Bob Dole", age=89, price=Decimal('1.11')).save() - User(name="Barack Obama", age=51, price=Decimal('2.22')).save() + User(name="Bob Dole", age=89).save() + User(name="Barack Obama", age=51).save() - results = User.objects.only('id', 'name').as_pymongo() - self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name'])) - - users = User.objects.only('name', 'price').as_pymongo() + users = User.objects.only('name').as_pymongo() results = list(users) self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') - self.assertEqual(results[0]['price'], 1.11) self.assertEqual(results[1]['name'], 'Barack Obama') - self.assertEqual(results[1]['price'], 2.22) # Test coerce_types - users = User.objects.only('name', 'price').as_pymongo(coerce_types=True) + users = User.objects.only('name').as_pymongo(coerce_types=True) results = list(users) self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') - self.assertEqual(results[0]['price'], Decimal('1.11')) self.assertEqual(results[1]['name'], 'Barack Obama') - self.assertEqual(results[1]['price'], Decimal('2.22')) def test_as_pymongo_json_limit_fields(self): @@ -3504,6 +3249,7 @@ class QuerySetTest(unittest.TestCase): serialized_user = User.objects.exclude('password_salt').only('email').to_json() self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + @unittest.skip("not implemented") def test_no_dereference(self): class Organization(Document): @@ -3551,27 +3297,6 @@ class QuerySetTest(unittest.TestCase): people.count() # count is cached self.assertEqual(q, 1) - def test_no_cached_queryset(self): - class Person(Document): - name = StringField() - - Person.drop_collection() - for i in xrange(100): - Person(name="No: %s" % i).save() - - with query_counter() as q: - self.assertEqual(q, 0) - people = Person.objects.no_cache() - - [x for x in people] - self.assertEqual(q, 1) - - list(people) - self.assertEqual(q, 2) - - people.count() - self.assertEqual(q, 3) - def test_cache_not_cloned(self): class User(Document): @@ -3593,34 +3318,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual("%s" % users, "[]") self.assertEqual(1, len(users._result_cache)) - def test_no_cache(self): - """Ensure you can add meta data to file""" - - class Noddy(Document): - fields = DictField() - - Noddy.drop_collection() - for i in xrange(100): - noddy = Noddy() - for j in range(20): - noddy.fields["key"+str(j)] = "value "+str(j) - noddy.save() - - docs = Noddy.objects.no_cache() - - counter = len([1 for i in docs]) - self.assertEquals(counter, 100) - - self.assertEquals(len(list(docs)), 100) - self.assertRaises(TypeError, lambda: len(docs)) - - with query_counter() as q: - self.assertEqual(q, 0) - list(docs) - self.assertEqual(q, 1) - list(docs) - self.assertEqual(q, 2) - def test_nested_queryset_iterator(self): # Try iterating the same queryset twice, nested. names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George'] @@ -3752,23 +3449,6 @@ class QuerySetTest(unittest.TestCase): '_cls': 'Animal.Cat' }) - def test_can_have_field_same_name_as_query_operator(self): - - class Size(Document): - name = StringField() - - class Example(Document): - size = ReferenceField(Size) - - Size.drop_collection() - Example.drop_collection() - - instance_size = Size(name="Large").save() - Example(size=instance_size).save() - - self.assertEqual(Example.objects(size=instance_size).count(), 1) - self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) - if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index d2e8b78..53c1660 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -31,31 +31,6 @@ class TransformTest(unittest.TestCase): self.assertEqual(transform.query(name__exists=True), {'name': {'$exists': True}}) - def test_transform_update(self): - class DicDoc(Document): - dictField = DictField() - - class Doc(Document): - pass - - DicDoc.drop_collection() - Doc.drop_collection() - - doc = Doc().save() - dic_doc = DicDoc().save() - - for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): - update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) - self.assertTrue(isinstance(update[v]["dictField.test"], dict)) - - # Update special cases - update = transform.update(DicDoc, unset__dictField__test=doc) - self.assertEqual(update["$unset"]["dictField.test"], 1) - - update = transform.update(DicDoc, pull__dictField__test=doc) - self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) - - def test_query_field_name(self): """Ensure that the correct field name is used when querying. """ @@ -88,6 +63,7 @@ class TransformTest(unittest.TestCase): BlogPost.drop_collection() + @unittest.skip("unsupported") def test_query_pk_field_name(self): """Ensure that the correct "primary key" field name is used when querying diff --git a/tests/test_connection.py b/tests/test_connection.py index 62d795c..d27a66d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -59,32 +59,6 @@ class ConnectionTest(unittest.TestCase): c.admin.system.users.remove({}) c.mongoenginetest.system.users.remove({}) - def test_connect_uri_without_db(self): - """Ensure that the connect() method works properly with uri's - without database_name - """ - c = connect(db='mongoenginetest', alias='admin') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) - - c.admin.add_user("admin", "password") - c.admin.authenticate("admin", "password") - c.mongoenginetest.add_user("username", "password") - - self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') - - connect("mongoenginetest", host='mongodb://localhost/') - - conn = get_connection() - self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) - - db = get_db() - self.assertTrue(isinstance(db, pymongo.database.Database)) - self.assertEqual(db.name, 'mongoenginetest') - - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) - def test_register_connection(self): """Ensure that connections with different aliases may be registered. """ diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 6f2664a..f13f291 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -16,6 +16,7 @@ class FieldTest(unittest.TestCase): connect(db='mongoenginetest') self.db = get_db() + @unittest.skip("select_related currently doesn't dereference lists") def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -74,6 +75,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() + @unittest.skip("select_related currently doesn't dereference lists") def test_list_item_dereference_dref_false(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -146,6 +148,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(Group._get_collection().find_one()['members'], [1]) self.assertEqual(group.members, [user]) + @unittest.skip('currently not implemented') def test_handle_old_style_references(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -179,6 +182,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(group.members[0].name, 'user 1') self.assertEqual(group.members[-1].name, 'String!') + @unittest.skip('currently not implemented') def test_migrate_references(self): """Example of migrating ReferenceField storage """ @@ -225,6 +229,7 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(raw_data['author'], ObjectId)) self.assertTrue(isinstance(raw_data['members'][0], ObjectId)) + @unittest.skip("select_related currently doesn't dereference lists") def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ @@ -259,9 +264,15 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 1) peter.boss - self.assertEqual(q, 2) + self.assertEqual(q, 1) peter.friends + self.assertEqual(q, 1) + + peter.boss.name + self.assertEqual(q, 2) + + peter.friends[0].name self.assertEqual(q, 3) # Document select_related @@ -291,6 +302,32 @@ class FieldTest(unittest.TestCase): self.assertEqual(employee.friends, friends) self.assertEqual(q, 2) + def test_list_of_lists_of_references(self): + + class User(Document): + name = StringField() + + class Post(Document): + user_lists = ListField(ListField(ReferenceField(User))) + + class SimpleList(Document): + users = ListField(ReferenceField(User)) + + User.drop_collection() + Post.drop_collection() + SimpleList.drop_collection() + + u1 = User.objects.create(name='u1') + u2 = User.objects.create(name='u2') + u3 = User.objects.create(name='u3') + + SimpleList.objects.create(users=[u1, u2, u3]) + self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) + + Post.objects.create(user_lists=[[u1, u2], [u3]]) + self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]]) + + def test_circular_reference(self): """Ensure you can handle circular references """ @@ -391,6 +428,7 @@ class FieldTest(unittest.TestCase): "%s" % Person.objects() ) + @unittest.skip("not implemented") def test_generic_reference(self): class UserA(Document): @@ -482,6 +520,7 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() + @unittest.skip("not implemented") def test_list_field_complex(self): class UserA(Document): @@ -573,6 +612,7 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() + @unittest.skip('MapField not fully implemented') def test_map_field_reference(self): class User(Document): @@ -638,6 +678,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() + @unittest.skip("not implemented") def test_dict_field(self): class UserA(Document): @@ -741,6 +782,7 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() + @unittest.skip("not implemented") def test_dict_field_no_field_inheritance(self): class UserA(Document): @@ -817,6 +859,7 @@ class FieldTest(unittest.TestCase): UserA.drop_collection() Group.drop_collection() + @unittest.skip("select_related currently doesn't dereference lists") def test_generic_reference_map_field(self): class UserA(Document): @@ -942,6 +985,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(root.children, [company]) self.assertEqual(company.parents, [root]) + @unittest.skip("not implemented") def test_dict_in_dbref_instance(self): class Person(Document): @@ -1121,32 +1165,37 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) - def test_objectid_reference_across_databases(self): - # mongoenginetest - Is default connection alias from setUp() - # Register Aliases - register_connection('testdb-1', 'mongoenginetest2') + def test_tuples_as_tuples(self): + """ + Ensure that tuples remain tuples when they are + inside a ComplexBaseField + """ + from mongoengine.base import BaseField - class User(Document): - name = StringField() - meta = {"db_alias": "testdb-1"} + class EnumField(BaseField): - class Book(Document): - name = StringField() - author = ReferenceField(User) + def __init__(self, **kwargs): + super(EnumField, self).__init__(**kwargs) - # Drops - User.drop_collection() - Book.drop_collection() + def to_mongo(self, value): + return value - user = User(name="Ross").save() - Book(name="MongoEngine for pros", author=user).save() + def to_python(self, value): + return tuple(value) - # Can't use query_counter across databases - so test the _data object - book = Book.objects.first() - self.assertFalse(isinstance(book._data['author'], User)) + class TestDoc(Document): + items = ListField(EnumField()) - book.select_related() - self.assertTrue(isinstance(book._data['author'], User)) + TestDoc.drop_collection() + tuples = [(100, 'Testing')] + doc = TestDoc() + doc.items = tuples + doc.save() + x = TestDoc.objects().get() + self.assertTrue(x is not None) + self.assertTrue(len(x.items) == 1) + self.assertTrue(tuple(x.items[0]) in tuples) + self.assertTrue(x.items[0] in tuples) def test_non_ascii_pk(self): """ @@ -1171,30 +1220,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands])) - def test_dereferencing_embedded_listfield_referencefield(self): - class Tag(Document): - meta = {'collection': 'tags'} - name = StringField() - - class Post(EmbeddedDocument): - body = StringField() - tags = ListField(ReferenceField("Tag", dbref=True)) - - class Page(Document): - meta = {'collection': 'pages'} - tags = ListField(ReferenceField("Tag", dbref=True)) - posts = ListField(EmbeddedDocumentField(Post)) - - Tag.drop_collection() - Page.drop_collection() - - tag = Tag(name='test').save() - post = Post(body='test body', tags=[tag]) - Page(tags=[tag], posts=[post]).save() - - page = Page.objects.first() - self.assertEqual(page.tags[0], page.posts[0].tags[0]) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_django.py b/tests/test_django.py index 46568ac..63e3245 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -2,44 +2,48 @@ import sys sys.path[0:0] = [""] import unittest from nose.plugins.skip import SkipTest +from mongoengine.python_support import PY3 from mongoengine import * - -from mongoengine.django.shortcuts import get_document_or_404 - -from django.http import Http404 -from django.template import Context, Template -from django.conf import settings -from django.core.paginator import Paginator - -settings.configure( - USE_TZ=True, - INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'), - AUTH_USER_MODEL=('mongo_auth.MongoUser'), -) - try: - from django.contrib.auth import authenticate, get_user_model - from mongoengine.django.auth import User - from mongoengine.django.mongo_auth.models import ( - MongoUser, - MongoUserManager, - get_user_document, + from mongoengine.django.shortcuts import get_document_or_404 + + from django.http import Http404 + from django.template import Context, Template + from django.conf import settings + from django.core.paginator import Paginator + + settings.configure( + USE_TZ=True, + INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'), + AUTH_USER_MODEL=('mongo_auth.MongoUser'), ) - DJ15 = True -except Exception: - DJ15 = False -from django.contrib.sessions.tests import SessionTestsMixin -from mongoengine.django.sessions import SessionStore, MongoSession + + try: + from django.contrib.auth import authenticate, get_user_model + from mongoengine.django.auth import User + from mongoengine.django.mongo_auth.models import MongoUser, MongoUserManager + DJ15 = True + except Exception: + DJ15 = False + from django.contrib.sessions.tests import SessionTestsMixin + from mongoengine.django.sessions import SessionStore, MongoSession +except Exception, err: + if PY3: + SessionTestsMixin = type # dummy value so no error + SessionStore = None # dummy value so no error + else: + raise err + + from datetime import tzinfo, timedelta ZERO = timedelta(0) - class FixedOffset(tzinfo): """Fixed offset in minutes east from UTC.""" def __init__(self, offset, name): - self.__offset = timedelta(minutes=offset) + self.__offset = timedelta(minutes = offset) self.__name = name def utcoffset(self, dt): @@ -66,6 +70,8 @@ def activate_timezone(tz): class QuerySetTest(unittest.TestCase): def setUp(self): + if PY3: + raise SkipTest('django does not have Python 3 support') connect(db='mongoenginetest') class Person(Document): @@ -167,8 +173,6 @@ class QuerySetTest(unittest.TestCase): class Note(Document): text = StringField() - Note.drop_collection() - for i in xrange(1, 101): Note(name="Note: %s" % i).save() @@ -219,6 +223,8 @@ class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): backend = SessionStore def setUp(self): + if PY3: + raise SkipTest('django does not have Python 3 support') connect(db='mongoenginetest') MongoSession.drop_collection() super(MongoDBSessionTest, self).setUp() @@ -256,18 +262,17 @@ class MongoAuthTest(unittest.TestCase): } def setUp(self): + if PY3: + raise SkipTest('django does not have Python 3 support') if not DJ15: raise SkipTest('mongo_auth requires Django 1.5') connect(db='mongoenginetest') User.drop_collection() super(MongoAuthTest, self).setUp() - def test_get_user_model(self): + def test_user_model(self): self.assertEqual(get_user_model(), MongoUser) - def test_get_user_document(self): - self.assertEqual(get_user_document(), User) - def test_user_manager(self): manager = get_user_model()._default_manager self.assertTrue(isinstance(manager, MongoUserManager)) diff --git a/tests/test_signals.py b/tests/test_signals.py index 50e5e6b..da217c0 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -30,28 +30,10 @@ class SignalTests(unittest.TestCase): def __unicode__(self): return self.name - @classmethod - def pre_init(cls, sender, document, *args, **kwargs): - signal_output.append('pre_init signal, %s' % cls.__name__) - signal_output.append(str(kwargs['values'])) - - @classmethod - def post_init(cls, sender, document, **kwargs): - signal_output.append('post_init signal, %s' % document) - @classmethod def pre_save(cls, sender, document, **kwargs): signal_output.append('pre_save signal, %s' % document) - @classmethod - def pre_save_post_validation(cls, sender, document, **kwargs): - signal_output.append('pre_save_post_validation signal, %s' % document) - if 'created' in kwargs: - if kwargs['created']: - signal_output.append('Is created') - else: - signal_output.append('Is updated') - @classmethod def post_save(cls, sender, document, **kwargs): signal_output.append('post_save signal, %s' % document) @@ -118,10 +100,7 @@ class SignalTests(unittest.TestCase): # Save up the number of connected signals so that we can check at the # end that all the signals we register get properly unregistered self.pre_signals = ( - len(signals.pre_init.receivers), - len(signals.post_init.receivers), len(signals.pre_save.receivers), - len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -129,10 +108,7 @@ class SignalTests(unittest.TestCase): len(signals.post_bulk_insert.receivers), ) - signals.pre_init.connect(Author.pre_init, sender=Author) - signals.post_init.connect(Author.post_init, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author) - signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author) signals.post_save.connect(Author.post_save, sender=Author) signals.pre_delete.connect(Author.pre_delete, sender=Author) signals.post_delete.connect(Author.post_delete, sender=Author) @@ -145,12 +121,9 @@ class SignalTests(unittest.TestCase): signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) def tearDown(self): - signals.pre_init.disconnect(self.Author.pre_init) - signals.post_init.disconnect(self.Author.post_init) signals.post_delete.disconnect(self.Author.post_delete) signals.pre_delete.disconnect(self.Author.pre_delete) signals.post_save.disconnect(self.Author.post_save) - signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation) signals.pre_save.disconnect(self.Author.pre_save) signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) @@ -162,10 +135,7 @@ class SignalTests(unittest.TestCase): # Check that all our signals got disconnected properly. post_signals = ( - len(signals.pre_init.receivers), - len(signals.post_init.receivers), len(signals.pre_save.receivers), - len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -180,9 +150,6 @@ class SignalTests(unittest.TestCase): def test_model_signals(self): """ Model saves should throw some signals. """ - def create_author(): - self.Author(name='Bill Shakespeare') - def bulk_create_author_with_load(): a1 = self.Author(name='Bill Shakespeare') self.Author.objects.insert([a1], load_bulk=True) @@ -191,17 +158,9 @@ class SignalTests(unittest.TestCase): a1 = self.Author(name='Bill Shakespeare') self.Author.objects.insert([a1], load_bulk=False) - self.assertEqual(self.get_signal_output(create_author), [ - "pre_init signal, Author", - "{'name': 'Bill Shakespeare'}", - "post_init signal, Bill Shakespeare", - ]) - a1 = self.Author(name='Bill Shakespeare') self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, Bill Shakespeare", - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", "post_save signal, Bill Shakespeare", "Is created" ]) @@ -210,8 +169,6 @@ class SignalTests(unittest.TestCase): a1.name = 'William Shakespeare' self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, William Shakespeare", - "pre_save_post_validation signal, William Shakespeare", - "Is updated", "post_save signal, William Shakespeare", "Is updated" ]) @@ -223,18 +180,13 @@ class SignalTests(unittest.TestCase): signal_output = self.get_signal_output(bulk_create_author_with_load) - # The output of this signal is not entirely deterministic. The reloaded - # object will have an object ID. Hence, we only check part of the output - self.assertEqual(signal_output[3], - "pre_bulk_insert signal, []") - self.assertEqual(signal_output[-2:], - ["post_bulk_insert signal, []", - "Is loaded",]) + self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ + "pre_bulk_insert signal, []", + "post_bulk_insert signal, []", + "Is loaded", + ]) self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ - "pre_init signal, Author", - "{'name': 'Bill Shakespeare'}", - "post_init signal, Bill Shakespeare", "pre_bulk_insert signal, []", "post_bulk_insert signal, []", "Not loaded",