diff --git a/.travis.yml b/.travis.yml index b7c56a0..4395107 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,12 +11,14 @@ env: - PYMONGO=dev DJANGO=1.4.2 - PYMONGO=2.5 DJANGO=1.5.1 - PYMONGO=2.5 DJANGO=1.4.2 + - PYMONGO=3.2 DJANGO=1.5.1 + - PYMONGO=3.3 DJANGO=1.5.1 install: - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi - - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install django==$DJANGO --use-mirrors ; true; fi - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi + - pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b - python setup.py install script: - python setup.py test diff --git a/AUTHORS b/AUTHORS index 0aa6056..f043207 100644 --- a/AUTHORS +++ b/AUTHORS @@ -9,7 +9,6 @@ Steve Challis Wilson Júnior Dan Crosta https://github.com/dcrosta Laine Herron https://github.com/LaineHerron -Thomas Steinacher http://thomasst.ch/ CONTRIBUTORS @@ -17,8 +16,6 @@ Dervived from the git logs, inevitably incomplete but all of whom and others have submitted patches, reported bugs and generally helped make MongoEngine that much better: - * Harry Marr - * Ross Lawley * blackbrrr * Florian Schlachter * Vincent Driessen @@ -115,6 +112,7 @@ that much better: * Alexander Koshelev * Jaime Irurzun * Alexandre González + * Thomas Steinacher * Tommi Komulainen * Peter Landry * biszkoptwielki @@ -171,3 +169,13 @@ that much better: * ygbourhis (https://github.com/ygbourhis) * Bob Dickinson (https://github.com/BobDickinson) * Michael Bartnett (https://github.com/michaelbartnett) + * Alon Horev (https://github.com/alonho) + * Kelvin Hammond (https://github.com/kelvinhammond) + * Jatin- (https://github.com/jatin-) + * Paul Uithol (https://github.com/PaulUithol) + * Thom Knowles (https://github.com/fleat) + * Paul (https://github.com/squamous) + * Olivier Cortès (https://github.com/Karmak23) + * crazyzubr (https://github.com/crazyzubr) + * FrankSomething (https://github.com/FrankSomething) + * Alexandr Morozov (https://github.com/LK4D4) diff --git a/DIFFERENCES.md b/DIFFERENCES.md deleted file mode 100644 index 01ae0f2..0000000 --- a/DIFFERENCES.md +++ /dev/null @@ -1,37 +0,0 @@ -Differences between Mongomallard and Mongoengine ------ - -* All document fields are lazily evaluated, resulting in much faster object initialization time. -* `_data` is removed due to lazy evaluation. `to_dict()` can be used to convert a document to a dictionary, and `_internal_data` contains previously evaluated data. -* Field methods `to_python`, `from_python`, `to_mongo`, `value_for_instance`: - * `to_python` is called when converting from a MongoDB type to a document Python type only. - * `from_python` is called when converting an assignment in Python to the document Python type. - * `to_mongo` is called when converting from a document Python type to a MongoDB type. - * `value_for_instance` is called just before returning a value in Python allowing for instance-specific transformations. -* `pre_init`, `post_init`, `pre_save_post_validation` signals are removed to ensure fast object initialization. -* `DecimalField` is removed since there is no corresponding MongoDB type -* `LongField` is removed since it is equivalent with `IntField` -* Adding `SafeReferenceField` which returns None if the reference does not exist. -* Adding `SafeReferenceListField` which doesn't return references that don't exist. -* Accessing a `ListField(ReferenceField)` doesn't automatically dereference all objects since they are lazily evaluated. A `SafeReferenceListField` may be used instead. -* Accessing a related object's id doesn't fetch the object from the database, e.g. `book.author.id` where author is a `ReferenceField` will not make a database lookup except when using a `SafeReferenceField`. When inheritance is allowed, a proxy object will be returned, otherwise a lazy object from the referenced document class will be returned. -* The primary key is only stored as `_id` in the database and is referenced in Python as `pk` or as the name of the primary key field. -* Saves are not cascaded by default. -* `Document.save()` supports `full=True` keyword argument to force saving all model fields. -* `_get_changed_fields()` / `_changed_fields` returns a set of field names (not db field names) -* Simplified `EmailField` email regex to be more compatible -* Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError` -* `order_by()` without an argument resets the ordering (no ordering will be applied) - -Untested / not implemented yet: ------ - -* Dynamic documents / `DynamicField`, dynamic addition/deletion of fields -* Field display name methods -* `SequenceField` -* Pickling documents -* `FileField` -* All Geo fields -* `no_dereference()` -* using `SafeReferenceListField` with `GenericReferenceField` -* `max_depth` argument for `doc.reload()` diff --git a/README.md b/README.md deleted file mode 100644 index ab8002c..0000000 --- a/README.md +++ /dev/null @@ -1,85 +0,0 @@ -MongoMallard -============ - -MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine. - -* Repository: https://github.com/elasticsales/mongomallard -* See [README_MONGOENGINE](https://github.com/elasticsales/mongomallard/blob/master/README_MONGOENGINE.rst) for MongoEngine's README. -* See [DIFFERENCES](https://github.com/elasticsales/mongomallard/blob/master/DIFFERENCES.md) for differences between MongoEngine and MongoMallard. - - -Benchmarks ----------- - -Sample run on a 2.7 GHz Intel Core i5 running OS X 10.8.3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
MongoEngine 0.8.2 (ede9fcf)MongoMallard (478062c)Speedup
Doc initialization52.494us25.195us2.08x
Doc getattr1.339us0.584us2.29x
Doc setattr3.064us2.550us1.20x
Doc to mongo49.415us26.497us1.86x
Load from SON61.475us4.510us13.63x
Save to database434.389us289.972us2.29x
Load from database558.178us480.690us1.16x
Save/delete big object to database98.838ms65.789ms1.50x
Serialize big object from database31.390ms20.265ms1.55x
Load big object from database41.159ms1.400ms29.40x
- -See [tests/benchmark.py](https://github.com/elasticsales/mongomallard/blob/master/tests/benchmark.py) for source code. diff --git a/README_MONGOENGINE.rst b/README.rst similarity index 100% rename from README_MONGOENGINE.rst rename to README.rst diff --git a/docs/_themes/nature/static/nature.css_t b/docs/_themes/nature/static/nature.css_t index 03b0379..337760b 100644 --- a/docs/_themes/nature/static/nature.css_t +++ b/docs/_themes/nature/static/nature.css_t @@ -2,11 +2,15 @@ * Sphinx stylesheet -- default theme * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ - + @import url("basic.css"); - + +#changelog p.first {margin-bottom: 0 !important;} +#changelog p {margin-top: 0 !important; + margin-bottom: 0 !important;} + /* -- page layout ----------------------------------------------------------- */ - + body { font-family: Arial, sans-serif; font-size: 100%; @@ -28,18 +32,18 @@ div.bodywrapper { hr{ border: 1px solid #B1B4B6; } - + div.document { background-color: #eee; } - + div.body { background-color: #ffffff; color: #3E4349; padding: 0 30px 30px 30px; font-size: 0.8em; } - + div.footer { color: #555; width: 100%; @@ -47,12 +51,12 @@ div.footer { text-align: center; font-size: 75%; } - + div.footer a { color: #444; text-decoration: underline; } - + div.related { background-color: #6BA81E; line-height: 32px; @@ -60,11 +64,11 @@ div.related { text-shadow: 0px 1px 0 #444; font-size: 0.80em; } - + div.related a { color: #E2F3CC; } - + div.sphinxsidebar { font-size: 0.75em; line-height: 1.5em; @@ -73,7 +77,7 @@ div.sphinxsidebar { div.sphinxsidebarwrapper{ padding: 20px 0; } - + div.sphinxsidebar h3, div.sphinxsidebar h4 { font-family: Arial, sans-serif; @@ -89,30 +93,30 @@ div.sphinxsidebar h4 { div.sphinxsidebar h4{ font-size: 1.1em; } - + div.sphinxsidebar h3 a { color: #444; } - - + + div.sphinxsidebar p { color: #888; padding: 5px 20px; } - + div.sphinxsidebar p.topless { } - + div.sphinxsidebar ul { margin: 10px 20px; padding: 0; color: #000; } - + div.sphinxsidebar a { color: #444; } - + div.sphinxsidebar input { border: 1px solid #ccc; font-family: sans-serif; @@ -122,19 +126,19 @@ div.sphinxsidebar input { div.sphinxsidebar input[type=text]{ margin-left: 20px; } - + /* -- body styles ----------------------------------------------------------- */ - + a { color: #005B81; text-decoration: none; } - + a:hover { color: #E32E00; text-decoration: underline; } - + div.body h1, div.body h2, div.body h3, @@ -149,30 +153,30 @@ div.body h6 { padding: 5px 0 5px 10px; text-shadow: 0px 1px 0 white } - + div.body h1 { border-top: 20px solid white; margin-top: 0; font-size: 200%; } div.body h2 { font-size: 150%; background-color: #C8D5E3; } div.body h3 { font-size: 120%; background-color: #D8DEE3; } div.body h4 { font-size: 110%; background-color: #D8DEE3; } div.body h5 { font-size: 100%; background-color: #D8DEE3; } div.body h6 { font-size: 100%; background-color: #D8DEE3; } - + a.headerlink { color: #c60f0f; font-size: 0.8em; padding: 0 4px 0 4px; text-decoration: none; } - + a.headerlink:hover { background-color: #c60f0f; color: white; } - + div.body p, div.body dd, div.body li { line-height: 1.5em; } - + div.admonition p.admonition-title + p { display: inline; } @@ -185,29 +189,29 @@ div.note { background-color: #eee; border: 1px solid #ccc; } - + div.seealso { background-color: #ffc; border: 1px solid #ff6; } - + div.topic { background-color: #eee; } - + div.warning { background-color: #ffe4e4; border: 1px solid #f66; } - + p.admonition-title { display: inline; } - + p.admonition-title:after { content: ":"; } - + pre { padding: 10px; background-color: White; @@ -219,7 +223,7 @@ pre { -webkit-box-shadow: 1px 1px 1px #d8d8d8; -moz-box-shadow: 1px 1px 1px #d8d8d8; } - + tt { background-color: #ecf0f3; color: #222; diff --git a/docs/apireference.rst b/docs/apireference.rst index d062727..9057de5 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -44,12 +44,21 @@ Context Managers Querying ======== -.. autoclass:: mongoengine.queryset.QuerySet - :members: +.. automodule:: mongoengine.queryset + :synopsis: Queryset level operations - .. automethod:: mongoengine.queryset.QuerySet.__call__ + .. autoclass:: mongoengine.queryset.QuerySet + :members: + :inherited-members: -.. autofunction:: mongoengine.queryset.queryset_manager + .. automethod:: QuerySet.__call__ + + .. autoclass:: mongoengine.queryset.QuerySetNoCache + :members: + + .. automethod:: mongoengine.queryset.QuerySetNoCache.__call__ + + .. autofunction:: mongoengine.queryset.queryset_manager Fields ====== diff --git a/docs/changelog.rst b/docs/changelog.rst index 1927bee..926fb8a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,13 +2,49 @@ Changelog ========= +Changes in 0.8.4 +================ +- Remove database name necessity in uri connection schema (#452) +- Fixed "$pull" semantics for nested ListFields (#447) +- Allow fields to be named the same as query operators (#445) +- Updated field filter logic - can now exclude subclass fields (#443) +- Fixed dereference issue with embedded listfield referencefields (#439) +- Fixed slice when using inheritance causing fields to be excluded (#437) +- Fixed ._get_db() attribute after a Document.switch_db() (#441) +- Dynamic Fields store and recompose Embedded Documents / Documents correctly (#449) +- Handle dynamic fieldnames that look like digits (#434) +- Added get_user_document and improve mongo_auth module (#423) +- Added str representation of GridFSProxy (#424) +- Update transform to handle docs erroneously passed to unset (#416) +- Fixed indexing - turn off _cls (#414) +- Fixed dereference threading issue in ComplexField.__get__ (#412) +- Fixed QuerySetNoCache.count() caching (#410) +- Don't follow references in _get_changed_fields (#422, #417) +- Allow args and kwargs to be passed through to_json (#420) + Changes in 0.8.3 ================ +- Fixed EmbeddedDocuments with `id` also storing `_id` (#402) +- Added get_proxy_object helper to filefields (#391) +- Added QuerySetNoCache and QuerySet.no_cache() for lower memory consumption (#365) +- Fixed sum and average mapreduce dot notation support (#375, #376, #393) +- Fixed as_pymongo to return the id (#386) +- Document.select_related() now respects `db_alias` (#377) +- Reload uses shard_key if applicable (#384) +- Dynamic fields are ordered based on creation and stored in _fields_ordered (#396) + + **Potential breaking change:** http://docs.mongoengine.org/en/latest/upgrade.html#to-0-8-3 + +- Fixed pickling dynamic documents `_dynamic_fields` (#387) +- Fixed ListField setslice and delslice dirty tracking (#390) +- Added Django 1.5 PY3 support (#392) - Added match ($elemMatch) support for EmbeddedDocuments (#379) - Fixed weakref being valid after reload (#374) - Fixed queryset.get() respecting no_dereference (#373) - Added full_result kwarg to update (#380) + + Changes in 0.8.2 ================ - Added compare_indexes helper (#361) diff --git a/docs/django.rst b/docs/django.rst index da15188..62d4dd4 100644 --- a/docs/django.rst +++ b/docs/django.rst @@ -45,7 +45,7 @@ The :mod:`~mongoengine.django.auth` module also contains a Custom User model ================= Django 1.5 introduced `Custom user Models -` +`_ which can be used as an alternative to the MongoEngine authentication backend. The main advantage of this option is that other components relying on @@ -74,7 +74,7 @@ An additional ``MONGOENGINE_USER_DOCUMENT`` setting enables you to replace the The custom :class:`User` must be a :class:`~mongoengine.Document` class, but otherwise has the same requirements as a standard custom user model, as specified in the `Django Documentation -`. +`_. In particular, the custom class must define :attr:`USERNAME_FIELD` and :attr:`REQUIRED_FIELDS` attributes. @@ -128,7 +128,7 @@ appended to the filename until the generated filename doesn't exist. The >>> fs.listdir() ([], [u'hello.txt']) -All files will be saved and retrieved in GridFS via the :class::`FileDocument` +All files will be saved and retrieved in GridFS via the :class:`FileDocument` document, allowing easy access to the files without the GridFSStorage backend.:: @@ -137,3 +137,36 @@ backend.:: [] .. versionadded:: 0.4 + +Shortcuts +========= +Inspired by the `Django shortcut get_object_or_404 +`_, +the :func:`~mongoengine.django.shortcuts.get_document_or_404` method returns +a document or raises an Http404 exception if the document does not exist:: + + from mongoengine.django.shortcuts import get_document_or_404 + + admin_user = get_document_or_404(User, username='root') + +The first argument may be a Document or QuerySet object. All other passed arguments +and keyword arguments are used in the query:: + + foo_email = get_document_or_404(User.objects.only('email'), username='foo', is_active=True).email + +.. note:: Like with :func:`get`, a MultipleObjectsReturned will be raised if more than one + object is found. + + +Also inspired by the `Django shortcut get_list_or_404 +`_, +the :func:`~mongoengine.django.shortcuts.get_list_or_404` method returns a list of +documents or raises an Http404 exception if the list is empty:: + + from mongoengine.django.shortcuts import get_list_or_404 + + active_users = get_list_or_404(User, is_active=True) + +The first argument may be a Document or QuerySet object. All other passed +arguments and keyword arguments are used to filter the query. + diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 854e2c3..f681aad 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -23,12 +23,15 @@ arguments should be provided:: connect('project1', username='webapp', password='pwd123') -Uri style connections are also supported as long as you include the database -name - just supply the uri as the :attr:`host` to +Uri style connections are also supported - just supply the uri as +the :attr:`host` to :func:`~mongoengine.connect`:: connect('project1', host='mongodb://localhost/database_name') +Note that database name from uri has priority over name +in ::func:`~mongoengine.connect` + ReplicaSets =========== diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index a61d8fe..ba1af33 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -54,7 +54,7 @@ be saved :: There is one caveat on Dynamic Documents: fields cannot start with `_` -Dynamic fields are stored in alphabetical order *after* any declared fields. +Dynamic fields are stored in creation order *after* any declared fields. Fields ====== @@ -442,6 +442,8 @@ The following example shows a :class:`Log` document that will be limited to ip_address = StringField() meta = {'max_documents': 1000, 'max_size': 2000000} +.. defining-indexes_ + Indexes ======= @@ -485,6 +487,35 @@ If a dictionary is passed then the following options are available: Inheritance adds extra fields indices see: :ref:`document-inheritance`. +Global index default options +---------------------------- + +There are a few top level defaults for all indexes that can be set:: + + class Page(Document): + title = StringField() + rating = StringField() + meta = { + 'index_options': {}, + 'index_background': True, + 'index_drop_dups': True, + 'index_cls': False + } + + +:attr:`index_options` (Optional) + Set any default index options - see the `full options list `_ + +:attr:`index_background` (Optional) + Set the default value for if an index should be indexed in the background + +:attr:`index_drop_dups` (Optional) + Set the default value for if an index should drop duplicates + +:attr:`index_cls` (Optional) + A way to turn off a specific index for _cls. + + Compound Indexes and Indexing sub documents ------------------------------------------- @@ -558,6 +589,11 @@ documentation for more information. A common usecase might be session data:: ] } +.. warning:: TTL indexes happen on the MongoDB server and not in the application + code, therefore no signals will be fired on document deletion. + If you need signals to be fired on deletion, then you must handle the + deletion of Documents in your application code. + Comparing Indexes ----------------- @@ -653,7 +689,6 @@ document.:: .. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults to False, meaning you must set it to True to use inheritance. - Working with existing data -------------------------- As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and @@ -673,3 +708,25 @@ defining all possible field types. If you use :class:`~mongoengine.Document` and the database contains data that isn't defined then that data will be stored in the `document._data` dictionary. + +Abstract classes +================ + +If you want to add some extra functionality to a group of Document classes but +you don't need or want the overhead of inheritance you can use the +:attr:`abstract` attribute of :attr:`-mongoengine.Document.meta`. +This won't turn on :ref:`document-inheritance` but will allow you to keep your +code DRY:: + + class BaseDocument(Document): + meta = { + 'abstract': True, + } + def check_permissions(self): + ... + + class User(BaseDocument): + ... + +Now the User class will have access to the inherited `check_permissions` method +and won't store any of the extra `_cls` information. diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 1350130..f50985b 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -16,7 +16,9 @@ fetch documents from the database:: .. note:: As of MongoEngine 0.8 the querysets utilise a local cache. So iterating - it multiple times will only cause a single query. + it multiple times will only cause a single query. If this is not the + desired behavour you can call :class:`~mongoengine.QuerySet.no_cache` + (version **0.8.3+**) to return a non-caching queryset. Filtering queries ================= @@ -495,7 +497,6 @@ that you may use with these methods: * ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``inc`` -- increment a value by a given amount * ``dec`` -- decrement a value by a given amount -* ``pop`` -- remove the last item from a list * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list * ``pop`` -- remove the first or last element of a list diff --git a/docs/upgrade.rst b/docs/upgrade.rst index c3d3182..a1fccea 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -2,12 +2,22 @@ Upgrading ######### + +0.8.2 to 0.8.3 +************** + +Minor change that may impact users: + +DynamicDocument fields are now stored in creation order after any declared +fields. Previously they were stored alphabetically. + + 0.7 to 0.8 ********** There have been numerous backwards breaking changes in 0.8. The reasons for -these are ensure that MongoEngine has sane defaults going forward and -performs the best it can out the box. Where possible there have been +these are to ensure that MongoEngine has sane defaults going forward and that it +performs the best it can out of the box. Where possible there have been FutureWarnings to help get you ready for the change, but that hasn't been possible for the whole of the release. @@ -61,7 +71,7 @@ inherited classes like so: :: Document Definition ------------------- -The default for inheritance has changed - its now off by default and +The default for inheritance has changed - it is now off by default and :attr:`_cls` will not be stored automatically with the class. So if you extend your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments` you will need to declare :attr:`allow_inheritance` in the meta data like so: :: @@ -71,7 +81,7 @@ you will need to declare :attr:`allow_inheritance` in the meta data like so: :: meta = {'allow_inheritance': True} -Previously, if you had data the database that wasn't defined in the Document +Previously, if you had data in the database that wasn't defined in the Document definition, it would set it as an attribute on the document. This is no longer the case and the data is set only in the ``document._data`` dictionary: :: @@ -92,8 +102,8 @@ the case and the data is set only in the ``document._data`` dictionary: :: AttributeError: 'Animal' object has no attribute 'size' The Document class has introduced a reserved function `clean()`, which will be -called before saving the document. If your document class happen to have a method -with the same name, please try rename it. +called before saving the document. If your document class happens to have a method +with the same name, please try to rename it. def clean(self): pass @@ -101,7 +111,7 @@ with the same name, please try rename it. ReferenceField -------------- -ReferenceFields now store ObjectId's by default - this is more efficient than +ReferenceFields now store ObjectIds by default - this is more efficient than DBRefs as we already know what Document types they reference:: # Old code @@ -147,7 +157,7 @@ UUIDFields now default to storing binary values:: class Animal(Document): uuid = UUIDField(binary=False) -To migrate all the uuid's you need to touch each object and mark it as dirty +To migrate all the uuids you need to touch each object and mark it as dirty eg:: # Doc definition @@ -165,7 +175,7 @@ eg:: DecimalField ------------ -DecimalField now store floats - previous it was storing strings and that +DecimalFields now store floats - previously it was storing strings and that made it impossible to do comparisons when querying correctly.:: # Old code @@ -176,7 +186,7 @@ made it impossible to do comparisons when querying correctly.:: class Person(Document): balance = DecimalField(force_string=True) -To migrate all the uuid's you need to touch each object and mark it as dirty +To migrate all the DecimalFields you need to touch each object and mark it as dirty eg:: # Doc definition @@ -188,7 +198,7 @@ eg:: p._mark_as_changed('balance') p.save() -.. note:: DecimalField's have also been improved with the addition of precision +.. note:: DecimalFields have also been improved with the addition of precision and rounding. See :class:`~mongoengine.fields.DecimalField` for more information. `An example test migration for DecimalFields is available on github @@ -197,7 +207,7 @@ eg:: Cascading Saves --------------- To improve performance document saves will no longer automatically cascade. -Any changes to a Documents references will either have to be saved manually or +Any changes to a Document's references will either have to be saved manually or you will have to explicitly tell it to cascade on save:: # At the class level: @@ -239,7 +249,7 @@ update your code like so: :: # Update example a) assign queryset after a change: mammals = Animal.objects(type="mammal") - carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so fitler can be applied + carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so filter can be applied [m for m in carnivores] # This will return all carnivores # Update example b) chain the queryset: @@ -266,7 +276,7 @@ queryset you should upgrade to use count:: .only() now inline with .exclude() ---------------------------------- -The behaviour of `.only()` was highly ambious, now it works in the mirror fashion +The behaviour of `.only()` was highly ambiguous, now it works in mirror fashion to `.exclude()`. Chaining `.only()` calls will increase the fields required:: # Old code @@ -430,7 +440,7 @@ main areas of changed are: choices in fields, map_reduce and collection names. Choice options: =============== -Are now expected to be an iterable of tuples, with the first element in each +Are now expected to be an iterable of tuples, with the first element in each tuple being the actual value to be stored. The second element is the human-readable name for the option. @@ -452,8 +462,8 @@ such the following have been changed: Default collection naming ========================= -Previously it was just lowercase, its now much more pythonic and readable as -its lowercase and underscores, previously :: +Previously it was just lowercase, it's now much more pythonic and readable as +it's lowercase and underscores, previously :: class MyAceDocument(Document): pass @@ -520,5 +530,5 @@ Alternatively, you can rename your collections eg :: mongodb 1.8 > 2.0 + =================== -Its been reported that indexes may need to be recreated to the newer version of indexes. +It's been reported that indexes may need to be recreated to the newer version of indexes. To do this drop indexes and call ``ensure_indexes`` on each model. diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 875c916..2b68b3c 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -15,8 +15,7 @@ import django __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + list(queryset.__all__) + signals.__all__ + list(errors.__all__)) -VERSION = (0, 8, 2) -MALLARD = True +VERSION = (0, 8, 4) def get_version(): diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index adcd8d0..4652fb5 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -108,6 +108,14 @@ class BaseList(list): self._mark_as_changed() return super(BaseList, self).__delitem__(*args, **kwargs) + def __setslice__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__setslice__(*args, **kwargs) + + def __delslice__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__delslice__(*args, **kwargs) + def __getstate__(self): self.instance = None self._dereferenced = False diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index b5e357d..cea2f09 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -4,7 +4,7 @@ import numbers from functools import partial import pymongo -from bson import json_util +from bson import json_util, ObjectId from bson.dbref import DBRef from bson.son import SON @@ -15,7 +15,6 @@ from mongoengine.errors import (ValidationError, InvalidDocumentError, from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, to_str_keys_recursive) -from mongoengine.base.proxy import DocumentProxy from mongoengine.base.common import get_document, ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.fields import ComplexBaseField @@ -24,52 +23,155 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' -_set = object.__setattr__ - class BaseDocument(object): - #_dynamic = False - #_dynamic_lock = True + _dynamic = False + _created = True + _dynamic_lock = True _initialised = False - def __init__(self, _son=None, **values): + def __init__(self, *args, **values): """ Initialise a document or embedded document + :param __auto_convert: Try and will cast python objects to Object types :param values: A dictionary of values for the document """ - _set(self, '_db_data', _son) - _set(self, '_lazy', False) - _set(self, '_internal_data', {}) - _set(self, '_changed_fields', set()) - if values: - pk = values.pop('pk', None) - for field in set(self._fields.keys()).intersection(values.keys()): - setattr(self, field, values[field]) - if pk != None: - self.pk = pk + if args: + # Combine positional arguments with named arguments. + # We only want named arguments. + field = iter(self._fields_ordered) + # If its an automatic id field then skip to the first defined field + if self._auto_id_field: + next(field) + for value in args: + name = next(field) + if name in values: + raise TypeError("Multiple values for keyword argument '" + name + "'") + values[name] = value + __auto_convert = values.pop("__auto_convert", True) + signals.pre_init.send(self.__class__, document=self, values=values) - def __delattr__(self, name): - default = self._fields[name].default - value = default() if callable(default) else default - setattr(self, name, value) + self._data = {} + self._dynamic_fields = SON() - @property - def _created(self): - return self._db_data != None or self._lazy + # Assign default values to instance + for key, field in self._fields.iteritems(): + if self._db_field_map.get(key, key) in values: + continue + value = getattr(self, key, None) + setattr(self, key, value) + + # Set passed values after initialisation + if self._dynamic: + dynamic_data = {} + for key, value in values.iteritems(): + if key in self._fields or key == '_id': + setattr(self, key, value) + elif self._dynamic: + dynamic_data[key] = value + else: + FileField = _import_class('FileField') + for key, value in values.iteritems(): + if key == '__auto_convert': + continue + key = self._reverse_db_field_map.get(key, key) + if key in self._fields or key in ('id', 'pk', '_cls'): + if __auto_convert and value is not None: + field = self._fields.get(key) + if field and not isinstance(field, FileField): + value = field.to_python(value) + setattr(self, key, value) + else: + self._data[key] = value + + # Set any get_fieldname_display methods + self.__set_field_display() + + if self._dynamic: + self._dynamic_lock = False + for key, value in dynamic_data.iteritems(): + setattr(self, key, value) + + # Flag initialised + self._initialised = True + signals.post_init.send(self.__class__, document=self) + + def __delattr__(self, *args, **kwargs): + """Handle deletions of fields""" + field_name = args[0] + if field_name in self._fields: + default = self._fields[field_name].default + if callable(default): + default = default() + setattr(self, field_name, default) + else: + super(BaseDocument, self).__delattr__(*args, **kwargs) + + def __setattr__(self, name, value): + # Handle dynamic data only if an initialised dynamic document + if self._dynamic and not self._dynamic_lock: + + field = None + if not hasattr(self, name) and not name.startswith('_'): + DynamicField = _import_class("DynamicField") + field = DynamicField(db_field=name) + field.name = name + self._dynamic_fields[name] = field + self._fields_ordered += (name,) + + if not name.startswith('_'): + value = self.__expand_dynamic_values(name, value) + + # Handle marking data as changed + if name in self._dynamic_fields: + self._data[name] = value + if hasattr(self, '_changed_fields'): + self._mark_as_changed(name) + + if (self._is_document and not self._created and + name in self._meta.get('shard_key', tuple()) and + self._data.get(name) != value): + OperationError = _import_class('OperationError') + msg = "Shard Keys are immutable. Tried to update %s" % name + raise OperationError(msg) + + # Check if the user has created a new instance of a class + if (self._is_document and self._initialised + and self._created and name == self._meta['id_field']): + super(BaseDocument, self).__setattr__('_created', False) + + super(BaseDocument, self).__setattr__(name, value) + + def __getstate__(self): + data = {} + for k in ('_changed_fields', '_initialised', '_created', + '_dynamic_fields', '_fields_ordered'): + if hasattr(self, k): + data[k] = getattr(self, k) + data['_data'] = self.to_mongo() + return data + + def __setstate__(self, data): + if isinstance(data["_data"], SON): + data["_data"] = self.__class__._from_son(data["_data"])._data + for k in ('_changed_fields', '_initialised', '_created', '_data', + '_fields_ordered', '_dynamic_fields'): + if k in data: + setattr(self, k, data[k]) + dynamic_fields = data.get('_dynamic_fields') or SON() + for k in dynamic_fields.keys(): + setattr(self, k, data["_data"].get(k)) def __iter__(self): - if 'id' in self._fields and 'id' not in self._fields_ordered: - return iter(('id', ) + self._fields_ordered) - return iter(self._fields_ordered) def __getitem__(self, name): """Dictionary-style field access, return a field's value if present. """ try: - if name in self._fields: + if name in self._fields_ordered: return getattr(self, name) except AttributeError: pass @@ -90,8 +192,8 @@ class BaseDocument(object): except AttributeError: return False - def __unicode__(self): - return u'%s object' % self.__class__.__name__ + def __len__(self): + return len(self._data) def __repr__(self): try: @@ -110,12 +212,9 @@ class BaseDocument(object): return txt_type('%s object' % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, DocumentProxy) and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk') and self.pk == other.pk: - return True - - if isinstance(other, self.__class__) and hasattr(other, 'pk') and self.pk == other.pk: - return True - + if isinstance(other, self.__class__) and hasattr(other, 'id'): + if self.id == other.id: + return True return False def __ne__(self, other): @@ -141,16 +240,42 @@ class BaseDocument(object): def to_mongo(self): """Return as SON data ready for use with MongoDB. """ - sets, unsets = self._delta(full=True) - son = SON(data=sets) - allow_inheritance = self._meta.get('allow_inheritance', - ALLOW_INHERITANCE) - if allow_inheritance: - son['_cls'] = self._class_name - return son + data = SON() + data["_id"] = None + data['_cls'] = self._class_name - def to_dict(self): - return dict((field, getattr(self, field)) for field in self._fields) + for field_name in self: + value = self._data.get(field_name, None) + field = self._fields.get(field_name) + if field is None and self._dynamic: + field = self._dynamic_fields.get(field_name) + + if value is not None: + value = field.to_mongo(value) + + # Handle self generating fields + if value is None and field._auto_gen: + value = field.generate() + self._data[field_name] = value + + if value is not None: + data[field.db_field] = value + + # If "_id" has not been set, then try and set it + Document = _import_class("Document") + if isinstance(self, Document): + if data["_id"] is None: + data["_id"] = self._data.get("id", None) + + if data['_id'] is None: + data.pop('_id') + + # Only add _cls if allow_inheritance is True + if (not hasattr(self, '_meta') or + not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): + data.pop('_cls') + + return data def validate(self, clean=True): """Ensure that all fields' values are valid and that required fields @@ -165,11 +290,8 @@ class BaseDocument(object): errors[NON_FIELD_ERRORS] = error # Get a list of tuples of field names and their current values - fields = [(field, getattr(self, name)) - for name, field in self._fields.items()] - #if self._dynamic: - # fields += [(field, self._data.get(name)) - # for name, field in self._dynamic_fields.items()] + fields = [(self._fields.get(name, self._dynamic_fields.get(name)), + self._data.get(name)) for name in self._fields_ordered] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") @@ -199,9 +321,9 @@ class BaseDocument(object): message = "ValidationError (%s:%s) " % (self._class_name, pk) raise ValidationError(message, errors=errors) - def to_json(self): + def to_json(self, *args, **kwargs): """Converts a document to JSON""" - return json_util.dumps(self.to_mongo()) + return json_util.dumps(self.to_mongo(), *args, **kwargs) @classmethod def from_json(cls, json_data): @@ -243,40 +365,17 @@ class BaseDocument(object): return value def _mark_as_changed(self, key): - """Marks a key as explicitly changed by the user. + """Marks a key as explicitly changed by the user """ - - if key: - self._changed_fields.add(key) - - def _get_changed_fields(self): - """Returns a list of all fields that have explicitly been changed. - """ - changed_fields = set(self._changed_fields) - EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - for field_name, field in self._fields.iteritems(): - if field_name not in changed_fields: - if (isinstance(field, ComplexBaseField) and - isinstance(field.field, EmbeddedDocumentField)): - field_value = getattr(self, field_name, None) - if field_value: - for idx in (field_value if isinstance(field_value, dict) - else xrange(len(field_value))): - changed_subfields = field_value[idx]._get_changed_fields() - if changed_subfields: - changed_fields |= set(['.'.join([field_name, str(idx), subfield_name]) - for subfield_name in changed_subfields]) - elif isinstance(field, EmbeddedDocumentField): - field_value = getattr(self, field_name, None) - if field_value: - changed_subfields = field_value._get_changed_fields() - if changed_subfields: - changed_fields |= set(['.'.join([field_name, subfield_name]) - for subfield_name in changed_subfields]) - return changed_fields + if not key: + return + key = self._db_field_map.get(key, key) + if (hasattr(self, '_changed_fields') and + key not in self._changed_fields): + self._changed_fields.append(key) def _clear_changed_fields(self): - _set(self, '_changed_fields', set()) + self._changed_fields = [] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") for field_name, field in self._fields.iteritems(): if (isinstance(field, ComplexBaseField) and @@ -291,57 +390,136 @@ class BaseDocument(object): if field_value: field_value._clear_changed_fields() - def _delta(self, full=False): - sets = {} - unsets = {} + def _get_changed_fields(self, inspected=None): + """Returns a list of all fields that have explicitly been changed. + """ + EmbeddedDocument = _import_class("EmbeddedDocument") + DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") + ReferenceField = _import_class("ReferenceField") + _changed_fields = [] + _changed_fields += getattr(self, '_changed_fields', []) + inspected = inspected or set() + if hasattr(self, 'id'): + if self.id in inspected: + return _changed_fields + inspected.add(self.id) - def get_db_value(field, value): - if value is None: - value = field.default() if callable(field.default) else field.default - return field.to_mongo(value) + for field_name in self._fields_ordered: + db_field_name = self._db_field_map.get(field_name, field_name) + key = '%s.' % db_field_name + data = self._data.get(field_name, None) + field = self._fields.get(field_name) + if hasattr(data, 'id'): + if data.id in inspected: + continue + inspected.add(data.id) + if isinstance(field, ReferenceField): + continue + elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) + and db_field_name not in _changed_fields): + # Find all embedded fields that have been changed + changed = data._get_changed_fields(inspected) + _changed_fields += ["%s%s" % (key, k) for k in changed if k] + elif (isinstance(data, (list, tuple, dict)) and + db_field_name not in _changed_fields): + # Loop list / dict fields as they contain documents + # Determine the iterator to use + if not hasattr(data, 'items'): + iterator = enumerate(data) + else: + iterator = data.iteritems() + for index, value in iterator: + if not hasattr(value, '_get_changed_fields'): + continue + if (hasattr(field, 'field') and + isinstance(field.field, ReferenceField)): + continue + list_key = "%s%s." % (key, index) + changed = value._get_changed_fields(inspected) + _changed_fields += ["%s%s" % (list_key, k) + for k in changed if k] + return _changed_fields - if full or not self._created: - fields = self._fields.iteritems() - db_data = ((self._db_field_map.get(field_name, field_name), - get_db_value(field, getattr(self, field_name))) - for field_name, field in fields) + def _delta(self): + """Returns the delta (set, unset) of the changes for a document. + Gets any values that have been explicitly changed. + """ + # Handles cases where not loaded from_son but has _id + doc = self.to_mongo() + set_fields = self._get_changed_fields() + unset_data = {} + parts = [] + if hasattr(self, '_changed_fields'): + set_data = {} + # Fetch each set item from its path + for path in set_fields: + parts = path.split('.') + d = doc + new_path = [] + for p in parts: + if isinstance(d, (ObjectId, DBRef)): + break + elif isinstance(d, list) and p.isdigit(): + d = d[int(p)] + elif hasattr(d, 'get'): + d = d.get(p) + new_path.append(p) + path = '.'.join(new_path) + set_data[path] = d else: - # List of (db_field_name, db_value) tuples. - db_data = [] + set_data = doc + if '_id' in set_data: + del(set_data['_id']) - for field_name in self._get_changed_fields(): - parts = field_name.split('.') + # Determine if any changed items were actually unset. + for path, value in set_data.items(): + if value or isinstance(value, (numbers.Number, bool)): + continue - db_field_parts = [] + # If we've set a value that ain't the default value dont unset it. + default = None + if (self._dynamic and len(parts) and parts[0] in + self._dynamic_fields): + del(set_data[path]) + unset_data[path] = 1 + continue + elif path in self._fields: + default = self._fields[path].default + else: # Perform a full lookup for lists / embedded lookups + d = self + parts = path.split('.') + db_field_name = parts.pop() + for p in parts: + if isinstance(d, list) and p.isdigit(): + d = d[int(p)] + elif (hasattr(d, '__getattribute__') and + not isinstance(d, dict)): + real_path = d._reverse_db_field_map.get(p, p) + d = getattr(d, real_path) + else: + d = d.get(p) - value = self - for part in parts: - if isinstance(value, list) and part.isdigit(): - db_field_parts.append(part) - field = field.field - value = value[int(part)] - elif isinstance(value, dict): - db_field_parts.append(part) - field = field.field - value = value[part] - else: # It's a document - obj = value - field = obj._fields[part] - db_field_parts.append(obj._db_field_map.get(part, part)) - value = getattr(obj, part) + if hasattr(d, '_fields'): + field_name = d._reverse_db_field_map.get(db_field_name, + db_field_name) + if field_name in d._fields: + default = d._fields.get(field_name).default + else: + default = None - db_data.append(('.'.join(db_field_parts), get_db_value(field, value))) + if default is not None: + if callable(default): + default = default() - for db_field_name, db_value in db_data: - if db_value == None: - unsets[db_field_name] = 1 - else: - sets[db_field_name] = db_value + if default != value: + continue - return sets, unsets + del(set_data[path]) + unset_data[path] = 1 + return set_data, unset_data @classmethod def _get_collection_name(cls): @@ -350,16 +528,61 @@ class BaseDocument(object): return cls._meta.get('collection', None) @classmethod - def _from_son(cls, son, _auto_dereference=False): + def _from_son(cls, son, _auto_dereference=True): + """Create an instance of a Document (subclass) from a PyMongo SON. + """ + # get the class name from the document, falling back to the given # class if unavailable class_name = son.get('_cls', cls._class_name) + data = dict(("%s" % key, value) for key, value in son.iteritems()) + if not UNICODE_KWARGS: + # python 2.6.4 and lower cannot handle unicode keys + # passed to class constructor example: cls(**data) + to_str_keys_recursive(data) # Return correct subclass for document type if class_name != cls._class_name: cls = get_document(class_name) - return cls(_son=son) + changed_fields = [] + errors_dict = {} + + fields = cls._fields + if not _auto_dereference: + fields = copy.copy(fields) + + for field_name, field in fields.iteritems(): + field._auto_dereference = _auto_dereference + if field.db_field in data: + value = data[field.db_field] + try: + data[field_name] = (value if value is None + else field.to_python(value)) + if field_name != field.db_field: + del data[field.db_field] + except (AttributeError, ValueError), e: + errors_dict[field_name] = e + elif field.default: + default = field.default + if callable(default): + default = default() + if isinstance(default, BaseDocument): + changed_fields.append(field_name) + + if errors_dict: + errors = "\n".join(["%s - %s" % (k, v) + for k, v in errors_dict.items()]) + msg = ("Invalid data to create a `%s` instance.\n%s" + % (cls._class_name, errors)) + raise InvalidDocumentError(msg) + + obj = cls(__auto_convert=False, **data) + obj._changed_fields = changed_fields + obj._created = False + if not _auto_dereference: + obj._fields = fields + return obj @classmethod def _build_index_specs(cls, meta_indexes): @@ -406,8 +629,10 @@ class BaseDocument(object): # Check to see if we need to include _cls allow_inheritance = cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) - include_cls = allow_inheritance and not spec.get('sparse', False) - + include_cls = (allow_inheritance and not spec.get('sparse', False) and + spec.get('cls', True)) + if "cls" in spec: + spec.pop('cls') for key in spec['fields']: # If inherited spec continue if isinstance(key, (list, tuple)): @@ -537,7 +762,7 @@ class BaseDocument(object): for field_name in parts: # Handle ListField indexing: - if field_name.isdigit(): + if field_name.isdigit() and hasattr(field, 'field'): new_field = field.field fields.append(field_name) continue @@ -549,9 +774,9 @@ class BaseDocument(object): field_name = cls._meta['id_field'] if field_name in cls._fields: field = cls._fields[field_name] - #elif cls._dynamic: - # DynamicField = _import_class('DynamicField') - # field = DynamicField(db_field=field_name) + elif cls._dynamic: + DynamicField = _import_class('DynamicField') + field = DynamicField(db_field=field_name) else: raise LookUpError('Cannot resolve field "%s"' % field_name) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 168c063..c6abd02 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -59,17 +59,15 @@ class BaseField(object): :param help_text: (optional) The help text for this field and is often used when generating model forms from the document model. """ - self.name = None # filled in by document - self.db_field = db_field + self.db_field = (db_field or name) if not primary_key else '_id' + if name: + msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" + warnings.warn(msg, DeprecationWarning) self.required = required or primary_key self.default = default self.unique = bool(unique or unique_with) self.unique_with = unique_with self.primary_key = primary_key - if self.primary_key: - if self.db_field: - raise ValueError("Can't use primary_key in conjunction with db_field.") - self.db_field = '_id' self.validation = validation self.choices = choices self.verbose_name = verbose_name @@ -84,52 +82,41 @@ class BaseField(object): BaseField.creation_counter += 1 def __get__(self, instance, owner): + """Descriptor for retrieving a value from a field in a document. + """ if instance is None: # Document class being used rather than a document object return self - else: - name = self.name - data = instance._internal_data - if not name in data: - if instance._lazy and name != instance._meta['id_field']: - # We need to fetch the doc from the database. - instance.reload() - db_field = instance._db_field_map.get(name, name) - try: - db_value = instance._db_data[db_field] - except (TypeError, KeyError): - value = self.default() if callable(self.default) else self.default - else: - value = self.to_python(db_value) - if hasattr(self, 'value_for_instance'): - value = self.value_for_instance(value, instance) - data[name] = value + # Get value from document instance if available + value = instance._data.get(self.name) - return data[name] + EmbeddedDocument = _import_class('EmbeddedDocument') + if isinstance(value, EmbeddedDocument) and value._instance is None: + value._instance = weakref.proxy(instance) + return value def __set__(self, instance, value): """Descriptor for assigning a value to a field in a document. """ - if instance._lazy: - # Fetch the from the database before we assign to a lazy object. - instance.reload() + # If setting to None and theres a default + # Then set the value to the default value + if value is None and self.default is not None: + value = self.default + if callable(value): + value = value() - name = self.name - - value = self.from_python(value) - if hasattr(self, 'value_for_instance'): - value = self.value_for_instance(value, instance) - try: - has_changed = name not in instance._internal_data or instance._internal_data[name] != value - except: # Values can't be compared eg: naive and tz datetimes - has_changed = True - - if has_changed: - instance._mark_as_changed(name) - - instance._internal_data[name] = value + if instance._initialised: + try: + if (self.name not in instance._data or + instance._data[self.name] != value): + instance._mark_as_changed(self.name) + except: + # Values cant be compared eg: naive and tz datetimes + # So mark it as changed + instance._mark_as_changed(self.name) + instance._data[self.name] = value def error(self, message="", errors=None, field_name=None): """Raises a ValidationError. @@ -145,15 +132,7 @@ class BaseField(object): def to_mongo(self, value): """Convert a Python type to a MongoDB-compatible type. """ - return value - - def from_python(self, value): - """Convert a raw Python value (in an assignment) to the internal - Python representation. - """ - if value == None: - return self.default() if callable(self.default) else self.default - return value + return self.to_python(value) def prepare_query_value(self, op, value): """Prepare a value that is being used in a query for PyMongo. @@ -208,6 +187,50 @@ class ComplexBaseField(BaseField): field = None + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + dereference = (self._auto_dereference and + (self.field is None or isinstance(self.field, + (GenericReferenceField, ReferenceField)))) + + _dereference = _import_class("DeReference")() + + self._auto_dereference = instance._fields[self.name]._auto_dereference + if instance._initialised and dereference: + instance._data[self.name] = _dereference( + instance._data.get(self.name), max_depth=1, instance=instance, + name=self.name + ) + + value = super(ComplexBaseField, self).__get__(instance, owner) + + # Convert lists / values so we can watch for any changes on them + if (isinstance(value, (list, tuple)) and + not isinstance(value, BaseList)): + value = BaseList(value, instance, self.name) + instance._data[self.name] = value + elif isinstance(value, dict) and not isinstance(value, BaseDict): + value = BaseDict(value, instance, self.name) + instance._data[self.name] = value + + if (self._auto_dereference and instance._initialised and + isinstance(value, (BaseList, BaseDict)) + and not value._dereferenced): + value = _dereference( + value, max_depth=1, instance=instance, name=self.name + ) + value._dereferenced = True + instance._data[self.name] = value + + return value + def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ @@ -366,10 +389,12 @@ class ObjectIdField(BaseField): """ def to_python(self, value): + if not isinstance(value, ObjectId): + value = ObjectId(value) return value def to_mongo(self, value): - if value and not isinstance(value, ObjectId): + if not isinstance(value, ObjectId): try: return ObjectId(unicode(value)) except Exception, e: diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 34a8a51..ff5afdd 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -91,11 +91,12 @@ class DocumentMetaclass(type): attrs['_fields'] = doc_fields attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) for k, v in doc_fields.iteritems()]) + attrs['_reverse_db_field_map'] = dict( + (v, k) for k, v in attrs['_db_field_map'].iteritems()) + attrs['_fields_ordered'] = tuple(i[1] for i in sorted( (v.creation_counter, v.name) for v in doc_fields.itervalues())) - attrs['_reverse_db_field_map'] = dict( - (v, k) for k, v in attrs['_db_field_map'].iteritems()) # # Set document hierarchy @@ -358,15 +359,17 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class.id = field # Set primary key if not defined by the document + new_class._auto_id_field = False if not new_class._meta.get('id_field'): - id_field = ObjectIdField(primary_key=True) - id_field.name = 'id' - id_field._auto_gen = True - new_class._fields['id'] = id_field - new_class.id = new_class._fields['id'] + new_class._auto_id_field = True new_class._meta['id_field'] = 'id' - new_class._db_field_map['id'] = id_field.db_field + new_class._fields['id'] = ObjectIdField(db_field='_id') + new_class._fields['id'].name = 'id' + new_class.id = new_class._fields['id'] + # Prepend id field to _fields_ordered + if 'id' in new_class._fields and 'id' not in new_class._fields_ordered: + new_class._fields_ordered = ('id', ) + new_class._fields_ordered # Merge in exceptions with parent hierarchy exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) diff --git a/mongoengine/base/proxy.py b/mongoengine/base/proxy.py deleted file mode 100644 index 7d2879b..0000000 --- a/mongoengine/base/proxy.py +++ /dev/null @@ -1,190 +0,0 @@ -from mongoengine.queryset import OperationError -from bson.dbref import DBRef - -class LocalProxy(object): - # From werkzeug/local.py - - """ Forwards all operations to - a proxied object. The only operations not supported for forwarding - are right handed operands and any kind of assignment. - """ - - __slots__ = ('__local', '__dict__', '__name__') - - def __init__(self, local, name=None): - object.__setattr__(self, '_LocalProxy__local', local) - object.__setattr__(self, '__name__', name) - - def _get_current_object(self): - """Return the current object. This is useful if you want the real - object behind the proxy at a time for performance reasons or because - you want to pass the object into a different context. - """ - if not hasattr(self.__local, '__release_local__'): - return self.__local() - try: - return getattr(self.__local, self.__name__) - except AttributeError: - raise RuntimeError('no object bound to %s' % self.__name__) - - @property - def __dict__(self): - try: - return self._get_current_object().__dict__ - except RuntimeError: - raise AttributeError('__dict__') - - def __repr__(self): - try: - obj = self._get_current_object() - except RuntimeError: - return '<%s unbound>' % self.__class__.__name__ - return repr(obj) - - def __nonzero__(self): - try: - return bool(self._get_current_object()) - except RuntimeError: - return False - - def __unicode__(self): - try: - return unicode(self._get_current_object()) - except RuntimeError: - return repr(self) - - def __dir__(self): - try: - return dir(self._get_current_object()) - except RuntimeError: - return [] - - def __getattr__(self, name): - if name == '__members__': - return dir(self._get_current_object()) - return getattr(self._get_current_object(), name) - - def __setitem__(self, key, value): - self._get_current_object()[key] = value - - def __delitem__(self, key): - del self._get_current_object()[key] - - def __setslice__(self, i, j, seq): - self._get_current_object()[i:j] = seq - - def __delslice__(self, i, j): - del self._get_current_object()[i:j] - - __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v) - __delattr__ = lambda x, n: delattr(x._get_current_object(), n) - __str__ = lambda x: str(x._get_current_object()) - __lt__ = lambda x, o: x._get_current_object() < o - __le__ = lambda x, o: x._get_current_object() <= o - __eq__ = lambda x, o: x._get_current_object() == o - __ne__ = lambda x, o: x._get_current_object() != o - __gt__ = lambda x, o: x._get_current_object() > o - __ge__ = lambda x, o: x._get_current_object() >= o - __cmp__ = lambda x, o: cmp(x._get_current_object(), o) - __hash__ = lambda x: hash(x._get_current_object()) - __call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw) - __len__ = lambda x: len(x._get_current_object()) - __getitem__ = lambda x, i: x._get_current_object()[i] - __iter__ = lambda x: iter(x._get_current_object()) - __contains__ = lambda x, i: i in x._get_current_object() - __getslice__ = lambda x, i, j: x._get_current_object()[i:j] - __add__ = lambda x, o: x._get_current_object() + o - __sub__ = lambda x, o: x._get_current_object() - o - __mul__ = lambda x, o: x._get_current_object() * o - __floordiv__ = lambda x, o: x._get_current_object() // o - __mod__ = lambda x, o: x._get_current_object() % o - __divmod__ = lambda x, o: x._get_current_object().__divmod__(o) - __pow__ = lambda x, o: x._get_current_object() ** o - __lshift__ = lambda x, o: x._get_current_object() << o - __rshift__ = lambda x, o: x._get_current_object() >> o - __and__ = lambda x, o: x._get_current_object() & o - __xor__ = lambda x, o: x._get_current_object() ^ o - __or__ = lambda x, o: x._get_current_object() | o - __div__ = lambda x, o: x._get_current_object().__div__(o) - __truediv__ = lambda x, o: x._get_current_object().__truediv__(o) - __neg__ = lambda x: -(x._get_current_object()) - __pos__ = lambda x: +(x._get_current_object()) - __abs__ = lambda x: abs(x._get_current_object()) - __invert__ = lambda x: ~(x._get_current_object()) - __complex__ = lambda x: complex(x._get_current_object()) - __int__ = lambda x: int(x._get_current_object()) - __long__ = lambda x: long(x._get_current_object()) - __float__ = lambda x: float(x._get_current_object()) - __oct__ = lambda x: oct(x._get_current_object()) - __hex__ = lambda x: hex(x._get_current_object()) - __index__ = lambda x: x._get_current_object().__index__() - __coerce__ = lambda x, o: x.__coerce__(x, o) - __enter__ = lambda x: x.__enter__() - __exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw) - - -class DocumentProxy(LocalProxy): - __slots__ = ('__document_type', '__document', '__pk') - - def __init__(self, document_type, pk): - object.__setattr__(self, '_DocumentProxy__document_type', document_type) - object.__setattr__(self, '_DocumentProxy__document', None) - object.__setattr__(self, '_DocumentProxy__pk', pk) - object.__setattr__(self, document_type._meta['id_field'], self.pk) - - @property - def __class__(self): - # We need to fetch the object to determine to which class it belongs. - return self._get_current_object().__class__ - - def _lazy(): - def fget(self): - return self.__document._lazy if self.__document else True - def fset(self, value): - self._get_current_object()._lazy = value - return property(fget, fset) - _lazy = _lazy() - - # copy normally updates __dict__ which would result in errors - def __setstate__(self, state): - for k, v in state[1].iteritems(): - object.__setattr__(self, k, v) - - def _get_collection_name(self): - return self.__document_type._meta.get('collection', None) - - def __eq__(self, other): - if other and hasattr(other, '_get_collection_name') and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk'): - if self.pk == other.pk: - return True - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def to_dbref(self): - """Returns an instance of :class:`~bson.dbref.DBRef` useful in - `__raw__` queries.""" - if not self.pk: - msg = "Only saved documents can have a valid dbref" - raise OperationError(msg) - return DBRef(self._get_collection_name(), self.pk) - - def pk(): - def fget(self): - return self.__document.pk if self.__document else self.__pk - def fset(self, value): - self._get_current_object().pk = value - return property(fget, fset) - pk = pk() - - def _get_current_object(self): - if self.__document == None: - collection = self.__document_type._get_collection() - son = collection.find_one({'_id': self.__pk}) - document = self.__document_type._from_son(son) - object.__setattr__(self, '_DocumentProxy__document', document) - return self.__document - - def __nonzero__(self): - return True diff --git a/mongoengine/common.py b/mongoengine/common.py index 20d5138..6303231 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -23,8 +23,9 @@ def _import_class(cls_name): field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', 'FileField', 'GenericReferenceField', 'GenericEmbeddedDocumentField', 'GeoPointField', - 'PointField', 'LineStringField', 'PolygonField', - 'ReferenceField', 'StringField', 'ComplexBaseField') + 'PointField', 'LineStringField', 'ListField', + 'PolygonField', 'ReferenceField', 'StringField', + 'ComplexBaseField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index abab269..4275da5 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -55,12 +55,9 @@ def register_connection(alias, name, host='localhost', port=27017, # Handle uri style connections if "://" in host: uri_dict = uri_parser.parse_uri(host) - if uri_dict.get('database') is None: - raise ConnectionError("If using URI style connection include "\ - "database name in string") conn_settings.update({ 'host': host, - 'name': uri_dict.get('database'), + 'name': uri_dict.get('database') or name, 'username': uri_dict.get('username'), 'password': uri_dict.get('password'), 'read_preference': read_preference, diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 79f755f..ceda403 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -4,7 +4,7 @@ from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document) from fields import (ReferenceField, ListField, DictField, MapField) from connection import get_db from queryset import QuerySet -from document import Document +from document import Document, EmbeddedDocument class DeReference(object): @@ -33,7 +33,8 @@ class DeReference(object): self.max_depth = max_depth doc_type = None - if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)): + if instance and isinstance(instance, (Document, EmbeddedDocument, + TopLevelDocumentMetaclass)): doc_type = instance._fields.get(name) if hasattr(doc_type, 'field'): doc_type = doc_type.field @@ -86,11 +87,9 @@ class DeReference(object): for k, item in iterator: if isinstance(item, Document): for field_name, field in item._fields.iteritems(): - v = getattr(item, field_name) + v = item._data.get(field_name, None) if isinstance(v, (DBRef)): reference_map.setdefault(field.document_type, []).append(v.id) - elif isinstance(v, Document) and getattr(v, '_lazy', False): - reference_map.setdefault(field.document_type, []).append(v.pk) elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: @@ -171,7 +170,7 @@ class DeReference(object): return self.object_map.get(items['_ref'].id, items) elif '_cls' in items: doc = get_document(items['_cls'])._from_son(items) - doc._internal_data = self._attach_objects(doc._internal_data, depth, doc, None) + doc._data = self._attach_objects(doc._data, depth, doc, None) return doc if not hasattr(items, 'items'): @@ -195,17 +194,15 @@ class DeReference(object): data[k] = self.object_map[k] elif isinstance(v, Document): for field_name, field in v._fields.iteritems(): - v = data[k]._internal_data.get(field_name, None) + v = data[k]._data.get(field_name, None) if isinstance(v, (DBRef)): - data[k]._internal_data[field_name] = self.object_map.get(v.id, v) - elif isinstance(v, Document) and getattr(v, '_lazy', False): - data[k]._internal_data[field_name] = self.object_map.get(v.pk, v) + data[k]._data[field_name] = self.object_map.get(v.id, v) elif isinstance(v, (dict, SON)) and '_ref' in v: - data[k]._internal_data[field_name] = self.object_map.get(v['_ref'].id, v) + data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) elif isinstance(v, dict) and depth <= self.max_depth: - data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) + data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (list, tuple)) and depth <= self.max_depth: - data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) + data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) elif hasattr(v, 'id'): diff --git a/mongoengine/django/mongo_auth/models.py b/mongoengine/django/mongo_auth/models.py index 3529d8e..d4947a2 100644 --- a/mongoengine/django/mongo_auth/models.py +++ b/mongoengine/django/mongo_auth/models.py @@ -6,10 +6,29 @@ from django.utils.importlib import import_module from django.utils.translation import ugettext_lazy as _ +__all__ = ( + 'get_user_document', +) + + MONGOENGINE_USER_DOCUMENT = getattr( settings, 'MONGOENGINE_USER_DOCUMENT', 'mongoengine.django.auth.User') +def get_user_document(): + """Get the user document class used for authentication. + + This is the class defined in settings.MONGOENGINE_USER_DOCUMENT, which + defaults to `mongoengine.django.auth.User`. + + """ + + name = MONGOENGINE_USER_DOCUMENT + dot = name.rindex('.') + module = import_module(name[:dot]) + return getattr(module, name[dot + 1:]) + + class MongoUserManager(UserManager): """A User manager wich allows the use of MongoEngine documents in Django. @@ -44,7 +63,7 @@ class MongoUserManager(UserManager): def contribute_to_class(self, model, name): super(MongoUserManager, self).contribute_to_class(model, name) self.dj_model = self.model - self.model = self._get_user_document() + self.model = get_user_document() self.dj_model.USERNAME_FIELD = self.model.USERNAME_FIELD username = models.CharField(_('username'), max_length=30, unique=True) @@ -55,16 +74,6 @@ class MongoUserManager(UserManager): field = models.CharField(_(name), max_length=30) field.contribute_to_class(self.dj_model, name) - def _get_user_document(self): - try: - name = MONGOENGINE_USER_DOCUMENT - dot = name.rindex('.') - module = import_module(name[:dot]) - return getattr(module, name[dot + 1:]) - except ImportError: - raise ImproperlyConfigured("Error importing %s, please check " - "settings.MONGOENGINE_USER_DOCUMENT" - % name) def get(self, *args, **kwargs): try: @@ -85,5 +94,14 @@ class MongoUserManager(UserManager): class MongoUser(models.Model): - objects = MongoUserManager() + """"Dummy user model for Django. + MongoUser is used to replace Django's UserManager with MongoUserManager. + The actual user document class is mongoengine.django.auth.User or any + other document class specified in MONGOENGINE_USER_DOCUMENT. + + To get the user document class, use `get_user_document()`. + + """ + + objects = MongoUserManager() diff --git a/mongoengine/django/sessions.py b/mongoengine/django/sessions.py index c90807e..7e4e182 100644 --- a/mongoengine/django/sessions.py +++ b/mongoengine/django/sessions.py @@ -1,7 +1,10 @@ from django.conf import settings from django.contrib.sessions.backends.base import SessionBase, CreateError from django.core.exceptions import SuspiciousOperation -from django.utils.encoding import force_unicode +try: + from django.utils.encoding import force_unicode +except ImportError: + from django.utils.encoding import force_text as force_unicode from mongoengine.document import Document from mongoengine import fields diff --git a/mongoengine/django/storage.py b/mongoengine/django/storage.py index 341455c..9df6f9e 100644 --- a/mongoengine/django/storage.py +++ b/mongoengine/django/storage.py @@ -76,7 +76,7 @@ class GridFSStorage(Storage): """Find the documents in the store with the given name """ docs = self.document.objects - doc = [d for d in docs if getattr(d, self.field).name == name] + doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name] if doc: return doc[0] else: diff --git a/mongoengine/document.py b/mongoengine/document.py index b553097..1bbd7b7 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -12,7 +12,7 @@ from mongoengine.common import _import_class from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) -from mongoengine.queryset import OperationError, NotUniqueError, QuerySet, DoesNotExist +from mongoengine.queryset import OperationError, NotUniqueError, QuerySet from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.context_managers import switch_db, switch_collection @@ -20,7 +20,6 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') -_set = object.__setattr__ def includes_cls(fields): """ Helper function used for ensuring and comparing indexes @@ -63,11 +62,11 @@ class EmbeddedDocument(BaseDocument): def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) - self._changed_fields = set() + self._changed_fields = [] def __eq__(self, other): if isinstance(other, self.__class__): - return self.to_dict() == other.to_dict() + return self._data == other._data return False def __ne__(self, other): @@ -178,13 +177,15 @@ class Document(BaseDocument): cls.ensure_indexes() return cls._collection - def save(self, validate=True, clean=True, + def save(self, force_insert=False, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, full=False, **kwargs): + _refs=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. + :param force_insert: only try to create a new document, don't allow + updates of existing documents :param validate: validates the document; set to ``False`` to skip. :param clean: call the document clean method, requires `validate` to be True. @@ -201,7 +202,6 @@ class Document(BaseDocument): :param cascade_kwargs: (optional) kwargs dictionary to be passed throw to cascading saves. Implies ``cascade=True``. :param _refs: A list of processed references used in cascading saves - :param full: Save all model fields instead of just changed ones. .. versionchanged:: 0.5 In existing documents it only saves changed fields using @@ -217,52 +217,61 @@ class Document(BaseDocument): the cascade save using cascade_kwargs which overwrites the existing kwargs with custom values. """ - signals.pre_save.send(self.__class__, document=self) if validate: self.validate(clean=clean) - if not write_concern: - write_concern = {'w': 1} + if write_concern is None: + write_concern = {"w": 1} + + doc = self.to_mongo() + + created = ('_id' not in doc or self._created or force_insert) + + signals.pre_save_post_validation.send(self.__class__, document=self, created=created) - collection = self._get_collection() try: - if self._created: - # Update: Get delta. - sets, unsets = self._delta(full) - db_id_field = self._fields[self._meta['id_field']].db_field - sets.pop(db_id_field, None) + collection = self._get_collection() + if created: + if force_insert: + object_id = collection.insert(doc, **write_concern) + else: + object_id = collection.save(doc, **write_concern) + else: + object_id = doc['_id'] + updates, removals = self._delta() + # Need to add shard key to query, or you get an error + select_dict = {'_id': object_id} + shard_key = self.__class__._meta.get('shard_key', tuple()) + for k in shard_key: + actual_key = self._db_field_map.get(k, k) + select_dict[actual_key] = doc[actual_key] + + def is_new_object(last_error): + if last_error is not None: + updated = last_error.get("updatedExisting") + if updated is not None: + return not updated + return created update_query = {} - if sets: - update_query['$set'] = sets - if unsets: - update_query['$unset'] = unsets - if update_query: - collection.update(self._db_object_key, update_query, **write_concern) + if updates: + update_query["$set"] = updates + if removals: + update_query["$unset"] = removals + if updates or removals: + last_error = collection.update(select_dict, update_query, + upsert=True, **write_concern) + created = is_new_object(last_error) - created = False - else: - # Insert: Get full SON. - doc = self.to_mongo() - object_id = collection.insert(doc, **write_concern) - # Fix pymongo's "return return_one and ids[0] or ids": - # If the ID is 0, pymongo wraps it in a list. - if isinstance(object_id, list) and not object_id[0]: - object_id = object_id[0] + if cascade is None: + cascade = self._meta.get('cascade', False) or cascade_kwargs is not None - id_field = self._meta['id_field'] - del self._internal_data[id_field] - _set(self, '_db_data', doc) - doc['_id'] = object_id - - created = True - cascade = (self._meta.get('cascade', False) - if cascade is None else cascade) if cascade: kwargs = { + "force_insert": force_insert, "validate": validate, "write_concern": write_concern, "cascade": cascade @@ -280,9 +289,12 @@ class Document(BaseDocument): message = u'Tried to save duplicate unique keys (%s)' raise NotUniqueError(message % unicode(err)) raise OperationError(message % unicode(err)) + id_field = self._meta['id_field'] + if id_field not in self._meta.get('shard_key', []): + self[id_field] = self._fields[id_field].to_python(object_id) self._clear_changed_fields() - + self._created = False signals.post_save.send(self.__class__, document=self, created=created) return self @@ -299,17 +311,14 @@ class Document(BaseDocument): GenericReferenceField)): continue - ref = getattr(self, name) + ref = self._data.get(name) if not ref or isinstance(ref, DBRef): continue if not getattr(ref, '_changed_fields', True): continue - if getattr(ref, '_lazy', False): - continue - - ref_id = "%s,%s" % (ref.__class__.__name__, str(ref.to_dict())) + ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data)) if ref and ref_id not in _refs: _refs.append(ref_id) kwargs["_refs"] = _refs @@ -335,16 +344,6 @@ class Document(BaseDocument): select_dict[k] = getattr(self, k) return select_dict - @property - def _db_object_key(self): - field = self._fields[self._meta['id_field']] - select_dict = {field.db_field: field.to_mongo(self.pk)} - shard_key = self.__class__._meta.get('shard_key', tuple()) - for k in shard_key: - actual_key = self._db_field_map.get(k, k) - select_dict[actual_key] = self._fields[k].to_mongo(getattr(self, k)) - return select_dict - def update(self, **kwargs): """Performs an update on the :class:`~mongoengine.Document` A convenience wrapper to :meth:`~mongoengine.QuerySet.update`. @@ -377,9 +376,6 @@ class Document(BaseDocument): """ signals.pre_delete.send(self.__class__, document=self) - if not write_concern: - write_concern = {'w': 1} - try: self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) except pymongo.errors.OperationFailure, err: @@ -404,11 +400,11 @@ class Document(BaseDocument): """ with switch_db(self.__class__, db_alias) as cls: collection = cls._get_collection() - db = cls._get_db + db = cls._get_db() self._get_collection = lambda: collection self._get_db = lambda: db self._collection = collection - #self._created = True + self._created = True self.__objects = self._qs self.__objects._collection_obj = collection return self @@ -433,7 +429,7 @@ class Document(BaseDocument): collection = cls._get_collection() self._get_collection = lambda: collection self._collection = collection - #self._created = True + self._created = True self.__objects = self._qs self.__objects._collection_obj = collection return self @@ -444,22 +440,44 @@ class Document(BaseDocument): .. versionadded:: 0.5 """ - import dereference - self._internal_data = dereference.DeReference()(self._internal_data, max_depth) + DeReference = _import_class('DeReference') + DeReference()([self], max_depth + 1) return self - def reload(self): + def reload(self, max_depth=1): """Reloads all attributes from the database. + + .. versionadded:: 0.1.2 + .. versionchanged:: 0.6 Now chainable """ - collection = self._get_collection() - son = collection.find_one(self._db_object_key, read_preference=ReadPreference.PRIMARY) - if son == None: - raise DoesNotExist('Document has been deleted.') - _set(self, '_db_data', son) - _set(self, '_internal_data', {}) - _set(self, '_lazy', False) - self._clear_changed_fields() - return self + obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( + **self._object_key).limit(1).select_related(max_depth=max_depth) + + if obj: + obj = obj[0] + else: + msg = "Reloaded document has been deleted" + raise OperationError(msg) + for field in self._fields_ordered: + setattr(self, field, self._reload(field, obj[field])) + self._changed_fields = obj._changed_fields + self._created = False + return obj + + def _reload(self, key, value): + """Used by :meth:`~mongoengine.Document.reload` to ensure the + correct instance is linked to self. + """ + if isinstance(value, BaseDict): + value = [(k, self._reload(k, v)) for k, v in value.items()] + value = BaseDict(value, self, key) + elif isinstance(value, BaseList): + value = [self._reload(key, v) for v in value] + value = BaseList(value, self, key) + elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): + value._instance = None + value._changed_fields = [] + return value def to_dbref(self): """Returns an instance of :class:`~bson.dbref.DBRef` useful in @@ -518,6 +536,8 @@ class Document(BaseDocument): def ensure_indexes(cls): """Checks the document meta data and ensures all the indexes exist. + Global defaults can be set in the meta - see :doc:`guide/defining-documents` + .. note:: You can disable automatic index creation by setting `auto_create_index` to False in the documents meta data """ @@ -659,8 +679,6 @@ class DynamicDocument(Document): _dynamic = True - # TODO - def __delattr__(self, *args, **kwargs): """Deletes the attribute by setting to None and allowing _delta to unset it""" @@ -684,8 +702,6 @@ class DynamicEmbeddedDocument(EmbeddedDocument): _dynamic = True - # TODO - def __delattr__(self, *args, **kwargs): """Deletes the attribute by setting to None and allowing _delta to unset it""" diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 57c4154..419f2ef 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -3,7 +3,6 @@ import decimal import itertools import re import time -import types import urllib2 import uuid import warnings @@ -23,11 +22,8 @@ from bson import Binary, DBRef, SON, ObjectId from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) -from mongoengine.base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, +from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, get_document, BaseDocument) -from mongoengine.base.datastructures import BaseList, BaseDict -from mongoengine.base.proxy import DocumentProxy -from mongoengine.queryset import DoesNotExist from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument from connection import get_db, DEFAULT_CONNECTION_NAME @@ -38,12 +34,11 @@ except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', - 'FloatField', 'BooleanField', 'DateTimeField', +__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', + 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', 'SortedListField', 'DictField', 'MapField', 'ReferenceField', - 'SafeReferenceField', 'SafeReferenceListField', 'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', @@ -63,6 +58,15 @@ class StringField(BaseField): self.min_length = min_length super(StringField, self).__init__(**kwargs) + def to_python(self, value): + if isinstance(value, unicode): + return value + try: + value = value.decode('utf-8') + except: + pass + return value + def validate(self, value): if not isinstance(value, basestring): self.error('StringField only accepts string values') @@ -117,7 +121,8 @@ class URLField(StringField): r'(?::\d+)?' # optional port r'(?:/?|[/?]\S+)$', re.IGNORECASE) - def __init__(self, url_regex=None, **kwargs): + def __init__(self, verify_exists=False, url_regex=None, **kwargs): + self.verify_exists = verify_exists self.url_regex = url_regex or self._URL_REGEX super(URLField, self).__init__(**kwargs) @@ -126,31 +131,50 @@ class URLField(StringField): self.error('Invalid URL: %s' % value) return + if self.verify_exists: + warnings.warn( + "The URLField verify_exists argument has intractable security " + "and performance issues. Accordingly, it has been deprecated.", + DeprecationWarning) + try: + request = urllib2.Request(value) + urllib2.urlopen(request) + except Exception, e: + self.error('This URL appears to be a broken link: %s' % e) + class EmailField(StringField): - """A field that validates input as an email address. + """A field that validates input as an E-Mail-Address. .. versionadded:: 0.4 """ - EMAIL_REGEX = re.compile(r'^.+@[^.].*\.[a-z]{2,10}$', re.IGNORECASE) + EMAIL_REGEX = re.compile( + r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + ) def validate(self, value): if not EmailField.EMAIL_REGEX.match(value): - self.error('Invalid email address: %s' % value) + self.error('Invalid Mail-address: %s' % value) super(EmailField, self).validate(value) class IntField(BaseField): - """An integer field. + """An 32-bit integer field. """ def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value super(IntField, self).__init__(**kwargs) - def from_python(self, value): - return self.prepare_query_value(None, value) + def to_python(self, value): + try: + value = int(value) + except ValueError: + pass + return value def validate(self, value): try: @@ -167,18 +191,59 @@ class IntField(BaseField): def prepare_query_value(self, op, value): if value is None: return value - else: - return int(value) + + return int(value) + + +class LongField(BaseField): + """An 64-bit integer field. + """ + + def __init__(self, min_value=None, max_value=None, **kwargs): + self.min_value, self.max_value = min_value, max_value + super(LongField, self).__init__(**kwargs) + + def to_python(self, value): + try: + value = long(value) + except ValueError: + pass + return value + + def validate(self, value): + try: + value = long(value) + except: + self.error('%s could not be converted to long' % value) + + if self.min_value is not None and value < self.min_value: + self.error('Long value is too small') + + if self.max_value is not None and value > self.max_value: + self.error('Long value is too large') + + def prepare_query_value(self, op, value): + if value is None: + return value + + return long(value) class FloatField(BaseField): - """A floating point number field. + """An floating point number field. """ def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value super(FloatField, self).__init__(**kwargs) + def to_python(self, value): + try: + value = float(value) + except ValueError: + pass + return value + def validate(self, value): if isinstance(value, int): value = float(value) @@ -191,6 +256,82 @@ class FloatField(BaseField): if self.max_value is not None and value > self.max_value: self.error('Float value is too large') + def prepare_query_value(self, op, value): + if value is None: + return value + + return float(value) + + +class DecimalField(BaseField): + """A fixed-point decimal number field. + + .. versionchanged:: 0.8 + .. versionadded:: 0.3 + """ + + def __init__(self, min_value=None, max_value=None, force_string=False, + precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs): + """ + :param min_value: Validation rule for the minimum acceptable value. + :param max_value: Validation rule for the maximum acceptable value. + :param force_string: Store as a string. + :param precision: Number of decimal places to store. + :param rounding: The rounding rule from the python decimal libary: + + - decimal.ROUND_CEILING (towards Infinity) + - decimal.ROUND_DOWN (towards zero) + - decimal.ROUND_FLOOR (towards -Infinity) + - decimal.ROUND_HALF_DOWN (to nearest with ties going towards zero) + - decimal.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer) + - decimal.ROUND_HALF_UP (to nearest with ties going away from zero) + - decimal.ROUND_UP (away from zero) + - decimal.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero) + + Defaults to: ``decimal.ROUND_HALF_UP`` + + """ + self.min_value = min_value + self.max_value = max_value + self.force_string = force_string + self.precision = decimal.Decimal(".%s" % ("0" * precision)) + self.rounding = rounding + + super(DecimalField, self).__init__(**kwargs) + + def to_python(self, value): + if value is None: + return value + + # Convert to string for python 2.6 before casting to Decimal + value = decimal.Decimal("%s" % value) + return value.quantize(self.precision, rounding=self.rounding) + + def to_mongo(self, value): + if value is None: + return value + if self.force_string: + return unicode(value) + return float(self.to_python(value)) + + def validate(self, value): + if not isinstance(value, decimal.Decimal): + if not isinstance(value, basestring): + value = unicode(value) + try: + value = decimal.Decimal(value) + except Exception, exc: + self.error('Could not convert value to decimal: %s' % exc) + + if self.min_value is not None and value < self.min_value: + self.error('Decimal value is too small') + + if self.max_value is not None and value > self.max_value: + self.error('Decimal value is too large') + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + class BooleanField(BaseField): """A boolean field type. @@ -198,6 +339,13 @@ class BooleanField(BaseField): .. versionadded:: 0.1.2 """ + def to_python(self, value): + try: + value = bool(value) + except ValueError: + pass + return value + def validate(self, value): if not isinstance(value, bool): self.error('BooleanField only accepts boolean values') @@ -218,13 +366,11 @@ class DateTimeField(BaseField): """ def validate(self, value): - if not isinstance(value, (datetime.datetime, datetime.date)): + new_value = self.to_mongo(value) + if not isinstance(new_value, (datetime.datetime, datetime.date)): self.error(u'cannot parse date "%s"' % value) - def from_python(self, value): - return self.prepare_query_value(None, value) or value - - def prepare_query_value(self, op, value): + def to_mongo(self, value): if value is None: return value if isinstance(value, datetime.datetime): @@ -268,6 +414,9 @@ class DateTimeField(BaseField): except ValueError: return None + def prepare_query_value(self, op, value): + return self.to_mongo(value) + class ComplexDateTimeField(StringField): """ @@ -288,8 +437,6 @@ class ComplexDateTimeField(StringField): .. versionadded:: 0.5 """ - # TODO - def __init__(self, separator=',', **kwargs): self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond'] @@ -395,11 +542,15 @@ class EmbeddedDocumentField(BaseField): self.document_type_obj = get_document(self.document_type_obj) return self.document_type_obj - def to_python(self, val): - return self.document_type._from_son(val) + def to_python(self, value): + if not isinstance(value, self.document_type): + return self.document_type._from_son(value) + return value - def to_mongo(self, val): - return val and val.to_mongo() + def to_mongo(self, value): + if not isinstance(value, self.document_type): + return value + return self.document_type.to_mongo(value) def validate(self, value, clean=True): """Make sure that the document instance is an instance of the @@ -433,8 +584,9 @@ class GenericEmbeddedDocumentField(BaseField): return self.to_mongo(value) def to_python(self, value): - doc_cls = get_document(value['_cls']) - value = doc_cls._from_son(value) + if isinstance(value, dict): + doc_cls = get_document(value['_cls']) + value = doc_cls._from_son(value) return value @@ -472,7 +624,9 @@ class DynamicField(BaseField): cls = value.__class__ val = value.to_mongo() # If we its a document thats not inherited add _cls - if (isinstance(value, (Document, EmbeddedDocument))): + if (isinstance(value, Document)): + val = {"_ref": value.to_dbref(), "_cls": cls.__name__} + if (isinstance(value, EmbeddedDocument)): val['_cls'] = cls.__name__ return val @@ -493,6 +647,15 @@ class DynamicField(BaseField): value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] return value + def to_python(self, value): + if isinstance(value, dict) and '_cls' in value: + doc_cls = get_document(value['_cls']) + if '_ref' in value: + value = doc_cls._get_db().dereference(value['_ref']) + return doc_cls._from_son(value) + + return super(DynamicField, self).to_python(value) + def lookup_member(self, member_name): return member_name @@ -522,26 +685,6 @@ class ListField(ComplexBaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def value_for_instance(self, value, instance, name=None): - name = name or self.name - if value and self.field: - value_for_instance = getattr(self.field, 'value_for_instance', None) - if value_for_instance: - value = [value_for_instance(v, instance, name) for v in value] - return BaseList(value or [], instance, name) - - def from_python(self, val): - from_python = getattr(self.field, 'from_python', None) - return [from_python(v) for v in val] if from_python else val - - def to_python(self, val): - to_python = getattr(self.field, 'to_python', None) - return [to_python(v) for v in val] if to_python and val else val or None - - def to_mongo(self, val): - to_mongo = getattr(self.field, 'to_mongo', None) - return [to_mongo(v) for v in val] if to_mongo and val else val or None - def validate(self, value): """Make sure that a list of valid fields is being used. """ @@ -557,9 +700,6 @@ class ListField(ComplexBaseField): and hasattr(value, '__iter__')): return [self.field.prepare_query_value(op, v) for v in value] return self.field.prepare_query_value(op, value) - else: - if op in ('set', 'unset'): - return value return super(ListField, self).prepare_query_value(op, value) @@ -590,13 +730,10 @@ class SortedListField(ListField): def to_mongo(self, value): value = super(SortedListField, self).to_mongo(value) - if value: - if self._ordering is not None: - return sorted(value, key=itemgetter(self._ordering), - reverse=self._order_reverse) - return sorted(value, reverse=self._order_reverse) - else: - return value + if self._ordering is not None: + return sorted(value, key=itemgetter(self._ordering), + reverse=self._order_reverse) + return sorted(value, reverse=self._order_reverse) class DictField(ComplexBaseField): @@ -618,26 +755,6 @@ class DictField(ComplexBaseField): kwargs.setdefault('default', lambda: {}) super(DictField, self).__init__(*args, **kwargs) - def from_python(self, val): - from_python = getattr(self.field, 'from_python', None) - return {k: from_python(v) for k, v in val.iteritems()} if from_python else val - - def to_python(self, val): - to_python = getattr(self.field, 'to_python', None) - return {k: to_python(v) for k, v in val.iteritems()} if to_python and val else val or None - - def value_for_instance(self, value, instance, name=None): - name = name or self.name - if value and self.field: - value_for_instance = getattr(self.field, 'value_for_instance', None) - if value_for_instance: - value = {k: value_for_instance(v, instance, name) for k, v in value.iteritems()} - return BaseDict(value or {}, instance, name) - - def to_mongo(self, val): - to_mongo = getattr(self.field, 'to_mongo', None) - return {k: to_mongo(v) for k, v in val.iteritems()} if to_mongo and val else val or None - def validate(self, value): """Make sure that a list of valid fields is being used. """ @@ -663,6 +780,10 @@ class DictField(ComplexBaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) + + if hasattr(self.field, 'field'): + return self.field.prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) @@ -745,82 +866,69 @@ class ReferenceField(BaseField): self.document_type_obj = get_document(self.document_type_obj) return self.document_type_obj - def to_mongo(self, value): - if isinstance(value, DBRef): - if self.dbref: - return value - else: - return value.id - elif isinstance(value, (Document, DocumentProxy)): - document_type = self.document_type + def __get__(self, instance, owner): + """Descriptor to allow lazy dereferencing. + """ + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value = instance._data.get(self.name) + self._auto_dereference = instance._fields[self.name]._auto_dereference + # Dereference DBRefs + if self._auto_dereference and isinstance(value, DBRef): + value = self.document_type._get_db().dereference(value) + if value is not None: + instance._data[self.name] = self.document_type._from_son(value) + + return super(ReferenceField, self).__get__(instance, owner) + + def to_mongo(self, document): + if isinstance(document, DBRef): + if not self.dbref: + return document.id + return document + + id_field_name = self.document_type._meta['id_field'] + id_field = self.document_type._fields[id_field_name] + + if isinstance(document, Document): # We need the id from the saved object to create the DBRef - pk = value.pk - if pk is None: + id_ = document.pk + if id_ is None: self.error('You can only reference documents once they have' ' been saved to the database') - id_field_name = document_type._meta['id_field'] - id_field = document_type._fields[id_field_name] - pk = id_field.to_mongo(pk) - if self.dbref: - collection = document_type._get_collection_name() - return DBRef(collection, pk) - else: - return pk - elif value != None: # string ID - document_type = self.document_type - collection = document_type._get_collection_name() - return DBRef(collection, value) + else: + id_ = document + + id_ = id_field.to_mongo(id_) + if self.dbref: + collection = self.document_type._get_collection_name() + return DBRef(collection, id_) + + return id_ def to_python(self, value): - if value != None: - document_type = self.document_type - if self.dbref: - pk = value.id - else: - if isinstance(value, DBRef): - pk = value.id - else: - pk = value - if document_type._meta['allow_inheritance']: - # We don't know of which type the object will be. - obj = DocumentProxy(document_type, pk) - else: - obj = document_type(pk=pk) - obj._lazy = True - return obj - - def from_python(self, value): - if isinstance(value, (BaseDocument, DocumentProxy)): - return value - elif value == None: - return super(ReferenceField, self).from_python(value) - else: - # Support for werkzeug.local.LocalProxy - if hasattr(value, '_get_current_object'): - return value._get_current_object() - else: - # DBRef or ID - document_type = self.document_type - if isinstance(value, DBRef): - pk = value.id - else: - pk = value - if document_type._meta['allow_inheritance']: - # We don't know of which type the object will be. - obj = DocumentProxy(document_type, pk) - else: - obj = document_type(pk=pk) - obj._lazy = True - return obj + """Convert a MongoDB-compatible type to a Python type. + """ + if (not self.dbref and + not isinstance(value, (DBRef, Document, EmbeddedDocument))): + collection = self.document_type._get_collection_name() + value = DBRef(collection, self.document_type.id.to_python(value)) + return value def prepare_query_value(self, op, value): - return self.to_mongo(self.from_python(value)) + if value is None: + return None + return self.to_mongo(value) def validate(self, value): - if not isinstance(value, (self.document_type, DBRef, DocumentProxy)): + + if not isinstance(value, (self.document_type, DBRef)): self.error("A ReferenceField only accepts DBRef or documents") - if isinstance(value, Document) and value.pk is None: + if isinstance(value, Document) and value.id is None: self.error('You can only reference documents once they have been ' 'saved to the database') @@ -828,52 +936,6 @@ class ReferenceField(BaseField): return self.document_type._fields.get(member_name) -class SafeReferenceField(ReferenceField): - """ - Like a ReferenceField, but doesn't return non-existing references when - dereferencing, i.e. no DBRefs are returned. This means that the next time - an object is saved, the non-existing references are removed and application - code can rely on having only valid dereferenced objects. - - When the field is referenced, the referenced object is loaded from the - database. - """ - - def to_python(self, value): - obj = super(SafeReferenceField, self).to_python(value) - if obj: - # Must dereference so we don't get an invalid ObjectId back. - try: - obj.reload() - except DoesNotExist: - return None - return obj - - -class SafeReferenceListField(ListField): - """ - Like a ListField, but doesn't return non-existing references when - dereferencing, i.e. no DBRefs are returned. This means that the next time - an object is saved, the non-existing references are removed and application - code can rely on having only valid dereferenced objects. - - When the field is referenced, all referenced objects are loaded from the - database. - - Must use ReferenceField as its field class. - """ - - def __init__(self, field, **kwargs): - if not isinstance(field, ReferenceField): - raise ValueError('Field argument must be a ReferenceField instance.') - return super(SafeReferenceListField, self).__init__(field, **kwargs) - - def to_python(self, value): - result = super(SafeReferenceListField, self).to_python(value) - if result: - objs = self.field.document_type.objects.in_bulk([obj.id for obj in result]) - return filter(None, [objs.get(obj.id) for obj in result]) - class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -888,6 +950,17 @@ class GenericReferenceField(BaseField): .. versionadded:: 0.3 """ + def __get__(self, instance, owner): + if instance is None: + return self + + value = instance._data.get(self.name) + self._auto_dereference = instance._fields[self.name]._auto_dereference + if self._auto_dereference and isinstance(value, (dict, SON)): + instance._data[self.name] = self.dereference(value) + + return super(GenericReferenceField, self).__get__(instance, owner) + def validate(self, value): if not isinstance(value, (Document, DBRef, dict, SON)): self.error('GenericReferences can only contain documents') @@ -909,14 +982,6 @@ class GenericReferenceField(BaseField): doc = doc_cls._from_son(doc) return doc - def to_python(self, value): - if value != None: - doc_cls = get_document(value['_cls']) - reference = value['_ref'] - obj = doc_cls(pk=reference.id) - obj._lazy = True - return obj - def to_mongo(self, document): if document is None: return None @@ -1033,6 +1098,10 @@ class GridFSProxy(object): def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.grid_id) + def __str__(self): + name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)' + return '<%s: %s>' % (self.__class__.__name__, name) + def __eq__(self, other): if isinstance(other, GridFSProxy): return ((self.grid_id == other.grid_id) and @@ -1140,9 +1209,7 @@ class FileField(BaseField): # Check if a file already exists for this model grid_file = instance._data.get(self.name) if not isinstance(grid_file, self.proxy_class): - grid_file = self.proxy_class(key=self.name, instance=instance, - db_alias=self.db_alias, - collection_name=self.collection_name) + grid_file = self.get_proxy_obj(key=self.name, instance=instance) instance._data[self.name] = grid_file if not grid_file.key: @@ -1164,15 +1231,23 @@ class FileField(BaseField): pass # Create a new proxy object as we don't already have one - instance._data[key] = self.proxy_class(key=key, instance=instance, - db_alias=self.db_alias, - collection_name=self.collection_name) + instance._data[key] = self.get_proxy_obj(key=key, instance=instance) instance._data[key].put(value) else: instance._data[key] = value instance._mark_as_changed(key) + def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None): + if db_alias is None: + db_alias = self.db_alias + if collection_name is None: + collection_name = self.collection_name + + return self.proxy_class(key=key, instance=instance, + db_alias=db_alias, + collection_name=collection_name) + def to_mongo(self, value): # Store the GridFS file id in MongoDB if isinstance(value, self.proxy_class) and value.grid_id is not None: @@ -1205,6 +1280,9 @@ class ImageGridFsProxy(GridFSProxy): applying field properties (size, thumbnail_size) """ field = self.instance._fields[self.key] + # Handle nested fields + if hasattr(field, 'field') and isinstance(field.field, FileField): + field = field.field try: img = Image.open(file_obj) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py new file mode 100644 index 0000000..b4dad0c --- /dev/null +++ b/mongoengine/queryset/base.py @@ -0,0 +1,1494 @@ +from __future__ import absolute_import + +import copy +import itertools +import operator +import pprint +import re +import warnings + +from bson.code import Code +from bson import json_util +import pymongo +from pymongo.common import validate_read_preference + +from mongoengine import signals +from mongoengine.common import _import_class +from mongoengine.base.common import get_document +from mongoengine.errors import (OperationError, NotUniqueError, + InvalidQueryError, LookUpError) + +from mongoengine.queryset import transform +from mongoengine.queryset.field_list import QueryFieldList +from mongoengine.queryset.visitor import Q, QNode + + +__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') + +# Delete rules +DO_NOTHING = 0 +NULLIFY = 1 +CASCADE = 2 +DENY = 3 +PULL = 4 + +RE_TYPE = type(re.compile('')) + + +class BaseQuerySet(object): + """A set of results returned from a query. Wraps a MongoDB cursor, + providing :class:`~mongoengine.Document` objects as the results. + """ + __dereference = False + _auto_dereference = True + + def __init__(self, document, collection): + self._document = document + self._collection_obj = collection + self._mongo_query = None + self._query_obj = Q() + self._initial_query = {} + self._where_clause = None + self._loaded_fields = QueryFieldList() + self._ordering = [] + self._snapshot = False + self._timeout = True + self._class_check = True + self._slave_okay = False + self._read_preference = None + self._iter = False + self._scalar = [] + self._none = False + self._as_pymongo = False + self._as_pymongo_coerce = False + + # If inheritance is allowed, only return instances and instances of + # subclasses of the class being used + if document._meta.get('allow_inheritance') is True: + if len(self._document._subclasses) == 1: + self._initial_query = {"_cls": self._document._subclasses[0]} + else: + self._initial_query = {"_cls": {"$in": self._document._subclasses}} + self._loaded_fields = QueryFieldList(always_include=['_cls']) + self._cursor_obj = None + self._limit = None + self._skip = None + self._hint = -1 # Using -1 as None is a valid value for hint + + def __call__(self, q_obj=None, class_check=True, slave_okay=False, + read_preference=None, **query): + """Filter the selected documents by calling the + :class:`~mongoengine.queryset.QuerySet` with a query. + + :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in + the query; the :class:`~mongoengine.queryset.QuerySet` is filtered + multiple times with different :class:`~mongoengine.queryset.Q` + objects, only the last one will be used + :param class_check: If set to False bypass class name check when + querying collection + :param slave_okay: if True, allows this query to be run against a + replica secondary. + :params read_preference: if set, overrides connection-level + read_preference from `ReplicaSetConnection`. + :param query: Django-style query keyword arguments + """ + query = Q(**query) + if q_obj: + # make sure proper query object is passed + if not isinstance(q_obj, QNode): + msg = ("Not a query object: %s. " + "Did you intend to use key=value?" % q_obj) + raise InvalidQueryError(msg) + query &= q_obj + + if read_preference is None: + queryset = self.clone() + else: + # Use the clone provided when setting read_preference + queryset = self.read_preference(read_preference) + + queryset._query_obj &= query + queryset._mongo_query = None + queryset._cursor_obj = None + queryset._class_check = class_check + + return queryset + + def __getitem__(self, key): + """Support skip and limit using getitem and slicing syntax. + """ + queryset = self.clone() + + # Slice provided + if isinstance(key, slice): + try: + queryset._cursor_obj = queryset._cursor[key] + queryset._skip, queryset._limit = key.start, key.stop + if key.start and key.stop: + queryset._limit = key.stop - key.start + except IndexError, err: + # PyMongo raises an error if key.start == key.stop, catch it, + # bin it, kill it. + start = key.start or 0 + if start >= 0 and key.stop >= 0 and key.step is None: + if start == key.stop: + queryset.limit(0) + queryset._skip = key.start + queryset._limit = key.stop - start + return queryset + raise err + # Allow further QuerySet modifications to be performed + return queryset + # Integer index provided + elif isinstance(key, int): + if queryset._scalar: + return queryset._get_scalar( + queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference)) + if queryset._as_pymongo: + return queryset._get_as_pymongo(queryset._cursor.next()) + return queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference) + raise AttributeError + + def __iter__(self): + raise NotImplementedError + + # Core functions + + def all(self): + """Returns all documents.""" + return self.__call__() + + def filter(self, *q_objs, **query): + """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` + """ + return self.__call__(*q_objs, **query) + + def get(self, *q_objs, **query): + """Retrieve the the matching object raising + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` exception if multiple results + and :class:`~mongoengine.queryset.DoesNotExist` or + `DocumentName.DoesNotExist` if no results are found. + + .. versionadded:: 0.3 + """ + queryset = self.clone() + queryset = queryset.limit(2) + queryset = queryset.filter(*q_objs, **query) + + try: + result = queryset.next() + except StopIteration: + msg = ("%s matching query does not exist." + % queryset._document._class_name) + raise queryset._document.DoesNotExist(msg) + try: + queryset.next() + except StopIteration: + return result + + queryset.rewind() + message = u'%d items returned, instead of 1' % queryset.count() + raise queryset._document.MultipleObjectsReturned(message) + + def create(self, **kwargs): + """Create new object. Returns the saved object instance. + + .. versionadded:: 0.4 + """ + return self._document(**kwargs).save() + + def get_or_create(self, write_concern=None, auto_save=True, + *q_objs, **query): + """Retrieve unique object or create, if it doesn't exist. Returns a + tuple of ``(object, created)``, where ``object`` is the retrieved or + created object and ``created`` is a boolean specifying whether a new + object was created. Raises + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` if multiple results are found. + A new document will be created if the document doesn't exists; a + dictionary of default values for the new document may be provided as a + keyword argument called :attr:`defaults`. + + .. note:: This requires two separate operations and therefore a + race condition exists. Because there are no transactions in + mongoDB other approaches should be investigated, to ensure you + don't accidently duplicate data when using this method. This is + now scheduled to be removed before 1.0 + + :param write_concern: optional extra keyword arguments used if we + have to create a new document. + Passes any write_concern onto :meth:`~mongoengine.Document.save` + + :param auto_save: if the object is to be saved automatically if + not found. + + .. deprecated:: 0.8 + .. versionchanged:: 0.6 - added `auto_save` + .. versionadded:: 0.3 + """ + msg = ("get_or_create is scheduled to be deprecated. The approach is " + "flawed without transactions. Upserts should be preferred.") + warnings.warn(msg, DeprecationWarning) + + defaults = query.get('defaults', {}) + if 'defaults' in query: + del query['defaults'] + + try: + doc = self.get(*q_objs, **query) + return doc, False + except self._document.DoesNotExist: + query.update(defaults) + doc = self._document(**query) + + if auto_save: + doc.save(write_concern=write_concern) + return doc, True + + def first(self): + """Retrieve the first object matching the query. + """ + queryset = self.clone() + try: + result = queryset[0] + except IndexError: + result = None + return result + + def insert(self, doc_or_docs, load_bulk=True, write_concern=None): + """bulk insert documents + + :param docs_or_doc: a document or list of documents to be inserted + :param load_bulk (optional): If True returns the list of document + instances + :param write_concern: Extra keyword arguments are passed down to + :meth:`~pymongo.collection.Collection.insert` + which will be used as options for the resultant + ``getLastError`` command. For example, + ``insert(..., {w: 2, fsync: True})`` will wait until at least + two servers have recorded the write and will force an fsync on + each server being written to. + + By default returns document instances, set ``load_bulk`` to False to + return just ``ObjectIds`` + + .. versionadded:: 0.5 + """ + Document = _import_class('Document') + + if write_concern is None: + write_concern = {} + + docs = doc_or_docs + return_one = False + if isinstance(docs, Document) or issubclass(docs.__class__, Document): + return_one = True + docs = [docs] + + raw = [] + for doc in docs: + if not isinstance(doc, self._document): + msg = ("Some documents inserted aren't instances of %s" + % str(self._document)) + raise OperationError(msg) + if doc.pk and not doc._created: + msg = "Some documents have ObjectIds use doc.update() instead" + raise OperationError(msg) + raw.append(doc.to_mongo()) + + signals.pre_bulk_insert.send(self._document, documents=docs) + try: + ids = self._collection.insert(raw, **write_concern) + except pymongo.errors.OperationFailure, err: + message = 'Could not save document (%s)' + if re.match('^E1100[01] duplicate key', unicode(err)): + # E11000 - duplicate key error index + # E11001 - duplicate key on update + message = u'Tried to save duplicate unique keys (%s)' + raise NotUniqueError(message % unicode(err)) + raise OperationError(message % unicode(err)) + + if not load_bulk: + signals.post_bulk_insert.send( + self._document, documents=docs, loaded=False) + return return_one and ids[0] or ids + + documents = self.in_bulk(ids) + results = [] + for obj_id in ids: + results.append(documents.get(obj_id)) + signals.post_bulk_insert.send( + self._document, documents=results, loaded=True) + return return_one and results[0] or results + + def count(self, with_limit_and_skip=True): + """Count the selected elements in the query. + + :param with_limit_and_skip (optional): take any :meth:`limit` or + :meth:`skip` that has been applied to this cursor into account when + getting the count + """ + if self._limit == 0 and with_limit_and_skip: + return 0 + return self._cursor.count(with_limit_and_skip=with_limit_and_skip) + + def delete(self, write_concern=None, _from_doc_delete=False): + """Delete the documents matched by the query. + + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param _from_doc_delete: True when called from document delete therefore + signals will have been triggered so don't loop. + """ + queryset = self.clone() + doc = queryset._document + + if write_concern is None: + write_concern = {} + + # Handle deletes where skips or limits have been applied or + # there is an untriggered delete signal + has_delete_signal = signals.signals_available and ( + signals.pre_delete.has_receivers_for(self._document) or + signals.post_delete.has_receivers_for(self._document)) + + call_document_delete = (queryset._skip or queryset._limit or + has_delete_signal) and not _from_doc_delete + + if call_document_delete: + for doc in queryset: + doc.delete(write_concern=write_concern) + return + + delete_rules = doc._meta.get('delete_rules') or {} + # Check for DENY rules before actually deleting/nullifying any other + # references + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == DENY and document_cls.objects( + **{field_name + '__in': self}).count() > 0: + msg = ("Could not delete document (%s.%s refers to it)" + % (document_cls.__name__, field_name)) + raise OperationError(msg) + + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == CASCADE: + ref_q = document_cls.objects(**{field_name + '__in': self}) + ref_q_count = ref_q.count() + if (doc != document_cls and ref_q_count > 0 + or (doc == document_cls and ref_q_count > 0)): + ref_q.delete(write_concern=write_concern) + elif rule == NULLIFY: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, **{'unset__%s' % field_name: 1}) + elif rule == PULL: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, + **{'pull_all__%s' % field_name: self}) + + queryset._collection.remove(queryset._query, write_concern=write_concern) + + def update(self, upsert=False, multi=True, write_concern=None, + full_result=False, **update): + """Perform an atomic update on the fields matched by the query. + + :param upsert: Any existing document with that "_id" is overwritten. + :param multi: Update multiple documents. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param full_result: Return the full result rather than just the number + updated. + :param update: Django-style update keyword arguments + + .. versionadded:: 0.2 + """ + if not update and not upsert: + raise OperationError("No update parameters, would remove data") + + if write_concern is None: + write_concern = {} + + queryset = self.clone() + query = queryset._query + update = transform.update(queryset._document, **update) + + # If doing an atomic upsert on an inheritable class + # then ensure we add _cls to the update operation + if upsert and '_cls' in query: + if '$set' in update: + update["$set"]["_cls"] = queryset._document._class_name + else: + update["$set"] = {"_cls": queryset._document._class_name} + try: + result = queryset._collection.update(query, update, multi=multi, + upsert=upsert, **write_concern) + if full_result: + return result + elif result: + return result['n'] + except pymongo.errors.OperationFailure, err: + if unicode(err) == u'multi not coded yet': + message = u'update() method requires MongoDB 1.1.3+' + raise OperationError(message) + raise OperationError(u'Update failed (%s)' % unicode(err)) + + def update_one(self, upsert=False, write_concern=None, **update): + """Perform an atomic update on first field matched by the query. + + :param upsert: Any existing document with that "_id" is overwritten. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param update: Django-style update keyword arguments + + .. versionadded:: 0.2 + """ + return self.update( + upsert=upsert, multi=False, write_concern=write_concern, **update) + + def with_id(self, object_id): + """Retrieve the object matching the id provided. Uses `object_id` only + and raises InvalidQueryError if a filter has been applied. Returns + `None` if no document exists with that id. + + :param object_id: the value for the id of the document to look up + + .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set + """ + queryset = self.clone() + if not queryset._query_obj.empty: + msg = "Cannot use a filter whilst using `with_id`" + raise InvalidQueryError(msg) + return queryset.filter(pk=object_id).first() + + def in_bulk(self, object_ids): + """Retrieve a set of documents by their ids. + + :param object_ids: a list or tuple of ``ObjectId``\ s + :rtype: dict of ObjectIds as keys and collection-specific + Document subclasses as values. + + .. versionadded:: 0.3 + """ + doc_map = {} + + docs = self._collection.find({'_id': {'$in': object_ids}}, + **self._cursor_args) + if self._scalar: + for doc in docs: + doc_map[doc['_id']] = self._get_scalar( + self._document._from_son(doc)) + elif self._as_pymongo: + for doc in docs: + doc_map[doc['_id']] = self._get_as_pymongo(doc) + else: + for doc in docs: + doc_map[doc['_id']] = self._document._from_son(doc) + + return doc_map + + def none(self): + """Helper that just returns a list""" + queryset = self.clone() + queryset._none = True + return queryset + + def no_sub_classes(self): + """ + Only return instances of this document and not any inherited documents + """ + if self._document._meta.get('allow_inheritance') is True: + self._initial_query = {"_cls": self._document._class_name} + + return self + + def clone(self): + """Creates a copy of the current + :class:`~mongoengine.queryset.QuerySet` + + .. versionadded:: 0.5 + """ + return self.clone_into(self.__class__(self._document, self._collection_obj)) + + def clone_into(self, cls): + """Creates a copy of the current + :class:`~mongoengine.queryset.base.BaseQuerySet` into another child class + """ + if not isinstance(cls, BaseQuerySet): + raise OperationError('%s is not a subclass of BaseQuerySet' % cls.__name__) + + copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', + '_where_clause', '_loaded_fields', '_ordering', '_snapshot', + '_timeout', '_class_check', '_slave_okay', '_read_preference', + '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', + '_limit', '_skip', '_hint', '_auto_dereference') + + for prop in copy_props: + val = getattr(self, prop) + setattr(cls, prop, copy.copy(val)) + + if self._cursor_obj: + cls._cursor_obj = self._cursor_obj.clone() + + return cls + + def select_related(self, max_depth=1): + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or + :class:`~bson.object_id.ObjectId` a maximum depth in order to cut down + the number queries to mongodb. + + .. versionadded:: 0.5 + """ + # Make select related work the same for querysets + max_depth += 1 + queryset = self.clone() + return queryset._dereference(queryset, max_depth=max_depth) + + def limit(self, n): + """Limit the number of returned documents to `n`. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[:5]``). + + :param n: the maximum number of objects to return + """ + queryset = self.clone() + if n == 0: + queryset._cursor.limit(1) + else: + queryset._cursor.limit(n) + queryset._limit = n + # Return self to allow chaining + return queryset + + def skip(self, n): + """Skip `n` documents before returning the results. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[5:]``). + + :param n: the number of objects to skip before returning results + """ + queryset = self.clone() + queryset._cursor.skip(n) + queryset._skip = n + return queryset + + def hint(self, index=None): + """Added 'hint' support, telling Mongo the proper index to use for the + query. + + Judicious use of hints can greatly improve query performance. When + doing a query on multiple fields (at least one of which is indexed) + pass the indexed field as a hint to the query. + + Hinting will not do anything if the corresponding index does not exist. + The last hint applied to this cursor takes precedence over all others. + + .. versionadded:: 0.5 + """ + queryset = self.clone() + queryset._cursor.hint(index) + queryset._hint = index + return queryset + + def distinct(self, field): + """Return a list of distinct values for a given field. + + :param field: the field to select distinct values from + + .. note:: This is a command and won't take ordering or limit into + account. + + .. versionadded:: 0.4 + .. versionchanged:: 0.5 - Fixed handling references + .. versionchanged:: 0.6 - Improved db_field refrence handling + """ + queryset = self.clone() + try: + field = self._fields_to_dbfields([field]).pop() + finally: + return self._dereference(queryset._cursor.distinct(field), 1, + name=field, instance=self._document) + + def only(self, *fields): + """Load only a subset of this document's fields. :: + + post = BlogPost.objects(...).only("title", "author.name") + + .. note :: `only()` is chainable and will perform a union :: + So with the following it will fetch both: `title` and `author.name`:: + + post = BlogPost.objects.only("title").only("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to include + + .. versionadded:: 0.3 + .. versionchanged:: 0.5 - Added subfield support + """ + fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + return self.fields(True, **fields) + + def exclude(self, *fields): + """Opposite to .only(), exclude some document's fields. :: + + post = BlogPost.objects(...).exclude("comments") + + .. note :: `exclude()` is chainable and will perform a union :: + So with the following it will exclude both: `title` and `author.name`:: + + post = BlogPost.objects.exclude("title").exclude("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to exclude + + .. versionadded:: 0.5 + """ + fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + return self.fields(**fields) + + def fields(self, _only_called=False, **kwargs): + """Manipulate how you load this document's fields. Used by `.only()` + and `.exclude()` to manipulate which fields to retrieve. Fields also + allows for a greater level of control for example: + + Retrieving a Subrange of Array Elements: + + You can use the $slice operator to retrieve a subrange of elements in + an array. For example to get the first 5 comments:: + + post = BlogPost.objects(...).fields(slice__comments=5) + + :param kwargs: A dictionary identifying what to include + + .. versionadded:: 0.5 + """ + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + queryset = self.clone() + for value, group in itertools.groupby(fields, lambda x: x[1]): + fields = [field for field, value in group] + fields = queryset._fields_to_dbfields(fields) + queryset._loaded_fields += QueryFieldList(fields, value=value, _only_called=_only_called) + + return queryset + + def all_fields(self): + """Include all fields. Reset all previously calls of .only() or + .exclude(). :: + + post = BlogPost.objects.exclude("comments").all_fields() + + .. versionadded:: 0.5 + """ + queryset = self.clone() + queryset._loaded_fields = QueryFieldList( + always_include=queryset._loaded_fields.always_include) + return queryset + + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction + """ + queryset = self.clone() + queryset._ordering = queryset._get_order_by(keys) + return queryset + + def explain(self, format=False): + """Return an explain plan record for the + :class:`~mongoengine.queryset.QuerySet`\ 's cursor. + + :param format: format the plan before returning it + """ + plan = self._cursor.explain() + if format: + plan = pprint.pformat(plan) + return plan + + def snapshot(self, enabled): + """Enable or disable snapshot mode when querying. + + :param enabled: whether or not snapshot mode is enabled + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._snapshot = enabled + return queryset + + def timeout(self, enabled): + """Enable or disable the default mongod timeout when querying. + + :param enabled: whether or not the timeout is used + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._timeout = enabled + return queryset + + def slave_okay(self, enabled): + """Enable or disable the slave_okay when querying. + + :param enabled: whether or not the slave_okay is enabled + """ + queryset = self.clone() + queryset._slave_okay = enabled + return queryset + + def read_preference(self, read_preference): + """Change the read_preference when querying. + + :param read_preference: override ReplicaSetConnection-level + preference. + """ + validate_read_preference('read_preference', read_preference) + queryset = self.clone() + queryset._read_preference = read_preference + return queryset + + def scalar(self, *fields): + """Instead of returning Document instances, return either a specific + value or a tuple of values in order. + + Can be used along with + :func:`~mongoengine.queryset.QuerySet.no_dereference` to turn off + dereferencing. + + .. note:: This effects all results and can be unset by calling + ``scalar`` without arguments. Calls ``only`` automatically. + + :param fields: One or more fields to return instead of a Document. + """ + queryset = self.clone() + queryset._scalar = list(fields) + + if fields: + queryset = queryset.only(*fields) + else: + queryset = queryset.all_fields() + + return queryset + + def values_list(self, *fields): + """An alias for scalar""" + return self.scalar(*fields) + + def as_pymongo(self, coerce_types=False): + """Instead of returning Document instances, return raw values from + pymongo. + + :param coerce_type: Field types (if applicable) would be use to + coerce types. + """ + queryset = self.clone() + queryset._as_pymongo = True + queryset._as_pymongo_coerce = coerce_types + return queryset + + # JSON Helpers + + def to_json(self, *args, **kwargs): + """Converts a queryset to JSON""" + return json_util.dumps(self.as_pymongo(), *args, **kwargs) + + def from_json(self, json_data): + """Converts json data to unsaved objects""" + son_data = json_util.loads(json_data) + return [self._document._from_son(data) for data in son_data] + + # JS functionality + + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, + scope=None): + """Perform a map/reduce query using the current query spec + and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, + it must be the last call made, as it does not return a maleable + ``QuerySet``. + + See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` + and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` + tests in ``tests.queryset.QuerySetTest`` for usage examples. + + :param map_f: map function, as :class:`~bson.code.Code` or string + :param reduce_f: reduce function, as + :class:`~bson.code.Code` or string + :param output: output collection name, if set to 'inline' will try to + use :class:`~pymongo.collection.Collection.inline_map_reduce` + This can also be a dictionary containing output options + see: http://docs.mongodb.org/manual/reference/commands/#mapReduce + :param finalize_f: finalize function, an optional function that + performs any post-reduction processing. + :param scope: values to insert into map/reduce global scope. Optional. + :param limit: number of objects from current query to provide + to map/reduce method + + Returns an iterator yielding + :class:`~mongoengine.document.MapReduceDocument`. + + .. note:: + + Map/Reduce changed in server version **>= 1.7.4**. The PyMongo + :meth:`~pymongo.collection.Collection.map_reduce` helper requires + PyMongo version **>= 1.11**. + + .. versionchanged:: 0.5 + - removed ``keep_temp`` keyword argument, which was only relevant + for MongoDB server versions older than 1.7.4 + + .. versionadded:: 0.3 + """ + queryset = self.clone() + + MapReduceDocument = _import_class('MapReduceDocument') + + if not hasattr(self._collection, "map_reduce"): + raise NotImplementedError("Requires MongoDB >= 1.7.1") + + map_f_scope = {} + if isinstance(map_f, Code): + map_f_scope = map_f.scope + map_f = unicode(map_f) + map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) + + reduce_f_scope = {} + if isinstance(reduce_f, Code): + reduce_f_scope = reduce_f.scope + reduce_f = unicode(reduce_f) + reduce_f_code = queryset._sub_js_fields(reduce_f) + reduce_f = Code(reduce_f_code, reduce_f_scope) + + mr_args = {'query': queryset._query} + + if finalize_f: + finalize_f_scope = {} + if isinstance(finalize_f, Code): + finalize_f_scope = finalize_f.scope + finalize_f = unicode(finalize_f) + finalize_f_code = queryset._sub_js_fields(finalize_f) + finalize_f = Code(finalize_f_code, finalize_f_scope) + mr_args['finalize'] = finalize_f + + if scope: + mr_args['scope'] = scope + + if limit: + mr_args['limit'] = limit + + if output == 'inline' and not queryset._ordering: + map_reduce_function = 'inline_map_reduce' + else: + map_reduce_function = 'map_reduce' + mr_args['out'] = output + + results = getattr(queryset._collection, map_reduce_function)( + map_f, reduce_f, **mr_args) + + if map_reduce_function == 'map_reduce': + results = results.find() + + if queryset._ordering: + results = results.sort(queryset._ordering) + + for doc in results: + yield MapReduceDocument(queryset._document, queryset._collection, + doc['_id'], doc['value']) + + def exec_js(self, code, *fields, **options): + """Execute a Javascript function on the server. A list of fields may be + provided, which will be translated to their correct names and supplied + as the arguments to the function. A few extra variables are added to + the function's scope: ``collection``, which is the name of the + collection in use; ``query``, which is an object representing the + current query; and ``options``, which is an object containing any + options specified as keyword arguments. + + As fields in MongoEngine may use different names in the database (set + using the :attr:`db_field` keyword argument to a :class:`Field` + constructor), a mechanism exists for replacing MongoEngine field names + with the database field names in Javascript code. When accessing a + field, use square-bracket notation, and prefix the MongoEngine field + name with a tilde (~). + + :param code: a string of Javascript code to execute + :param fields: fields that you will be using in your function, which + will be passed in to your function as arguments + :param options: options that you want available to the function + (accessed in Javascript through the ``options`` object) + """ + queryset = self.clone() + + code = queryset._sub_js_fields(code) + + fields = [queryset._document._translate_field_name(f) for f in fields] + collection = queryset._document._get_collection_name() + + scope = { + 'collection': collection, + 'options': options or {}, + } + + query = queryset._query + if queryset._where_clause: + query['$where'] = queryset._where_clause + + scope['query'] = query + code = Code(code, scope=scope) + + db = queryset._document._get_db() + return db.eval(code, *fields) + + def where(self, where_clause): + """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript + expression). Performs automatic field name substitution like + :meth:`mongoengine.queryset.Queryset.exec_js`. + + .. note:: When using this mode of query, the database will call your + function, or evaluate your predicate clause, for each object + in the collection. + + .. versionadded:: 0.5 + """ + queryset = self.clone() + where_clause = queryset._sub_js_fields(where_clause) + queryset._where_clause = where_clause + return queryset + + def sum(self, field): + """Sum over the values of the specified field. + + :param field: the field to sum over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. + """ + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'), + field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(1, item||0); + }); + } else if (typeof field != 'undefined') { + emit(1, field||0); + } + } + """ % dict(field=field) + + reduce_func = Code(""" + function(key, values) { + var sum = 0; + for (var i in values) { + sum += values[i]; + } + return sum; + } + """) + + for result in self.map_reduce(map_func, reduce_func, output='inline'): + return result.value + else: + return 0 + + def average(self, field): + """Average over the values of the specified field. + + :param field: the field to average over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. + """ + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'), + field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(1, {t: item||0, c: 1}); + }); + } else if (typeof field != 'undefined') { + emit(1, {t: field||0, c: 1}); + } + } + """ % dict(field=field) + + reduce_func = Code(""" + function(key, values) { + var out = {t: 0, c: 0}; + for (var i in values) { + var value = values[i]; + out.t += value.t; + out.c += value.c; + } + return out; + } + """) + + finalize_func = Code(""" + function(key, value) { + return value.t / value.c; + } + """) + + for result in self.map_reduce(map_func, reduce_func, + finalize_f=finalize_func, output='inline'): + return result.value + else: + return 0 + + def item_frequencies(self, field, normalize=False, map_reduce=True): + """Returns a dictionary of all items present in a field across + the whole queried set of documents, and their corresponding frequency. + This is useful for generating tag clouds, or searching documents. + + .. note:: + + Can only do direct simple mappings and cannot map across + :class:`~mongoengine.fields.ReferenceField` or + :class:`~mongoengine.fields.GenericReferenceField` for more complex + counting a manual map reduce call would is required. + + If the field is a :class:`~mongoengine.fields.ListField`, the items within + each list will be counted individually. + + :param field: the field to use + :param normalize: normalize the results so they add to 1.0 + :param map_reduce: Use map_reduce over exec_js + + .. versionchanged:: 0.5 defaults to map_reduce and can handle embedded + document lookups + """ + if map_reduce: + return self._item_frequencies_map_reduce(field, + normalize=normalize) + return self._item_frequencies_exec_js(field, normalize=normalize) + + # Iterator helpers + + def next(self): + """Wrap the result in a :class:`~mongoengine.Document` object. + """ + if self._limit == 0 or self._none: + raise StopIteration + + raw_doc = self._cursor.next() + if self._as_pymongo: + return self._get_as_pymongo(raw_doc) + doc = self._document._from_son(raw_doc, + _auto_dereference=self._auto_dereference) + if self._scalar: + return self._get_scalar(doc) + + return doc + + def rewind(self): + """Rewind the cursor to its unevaluated state. + + .. versionadded:: 0.3 + """ + self._iter = False + self._cursor.rewind() + + # Properties + + @property + def _collection(self): + """Property that returns the collection object. This allows us to + perform operations only if the collection is accessed. + """ + return self._collection_obj + + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout + } + if self._read_preference is not None: + cursor_args['read_preference'] = self._read_preference + else: + cursor_args['slave_okay'] = self._slave_okay + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + + @property + def _cursor(self): + if self._cursor_obj is None: + + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) + # Apply where clauses to cursor + if self._where_clause: + where_clause = self._sub_js_fields(self._where_clause) + self._cursor_obj.where(where_clause) + + if self._ordering: + # Apply query ordering + self._cursor_obj.sort(self._ordering) + elif self._document._meta['ordering']: + # Otherwise, apply the ordering from the document model + order = self._get_order_by(self._document._meta['ordering']) + self._cursor_obj.sort(order) + + if self._limit is not None: + self._cursor_obj.limit(self._limit) + + if self._skip is not None: + self._cursor_obj.skip(self._skip) + + if self._hint != -1: + self._cursor_obj.hint(self._hint) + + return self._cursor_obj + + def __deepcopy__(self, memo): + """Essential for chained queries with ReferenceFields involved""" + return self.clone() + + @property + def _query(self): + if self._mongo_query is None: + self._mongo_query = self._query_obj.to_query(self._document) + if self._class_check: + self._mongo_query.update(self._initial_query) + return self._mongo_query + + @property + def _dereference(self): + if not self.__dereference: + self.__dereference = _import_class('DeReference')() + return self.__dereference + + def no_dereference(self): + """Turn off any dereferencing for the results of this queryset. + """ + queryset = self.clone() + queryset._auto_dereference = False + return queryset + + # Helper Functions + + def _item_frequencies_map_reduce(self, field, normalize=False): + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'); + var field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(item, 1); + }); + } else if (typeof field != 'undefined') { + emit(field, 1); + } else { + emit(null, 1); + } + } + """ % dict(field=field) + reduce_func = """ + function(key, values) { + var total = 0; + var valuesSize = values.length; + for (var i=0; i < valuesSize; i++) { + total += parseInt(values[i], 10); + } + return total; + } + """ + values = self.map_reduce(map_func, reduce_func, 'inline') + frequencies = {} + for f in values: + key = f.key + if isinstance(key, float): + if int(key) == key: + key = int(key) + frequencies[key] = int(f.value) + + if normalize: + count = sum(frequencies.values()) + frequencies = dict([(k, float(v) / count) + for k, v in frequencies.items()]) + + return frequencies + + def _item_frequencies_exec_js(self, field, normalize=False): + """Uses exec_js to execute""" + freq_func = """ + function(path) { + var path = path.split('.'); + + var total = 0.0; + db[collection].find(query).forEach(function(doc) { + var field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + total += field.length; + } else { + total++; + } + }); + + var frequencies = {}; + var types = {}; + var inc = 1.0; + + db[collection].find(query).forEach(function(doc) { + field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); + }); + } else { + var item = field; + types[item] = item; + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); + } + }); + return [total, frequencies, types]; + } + """ + total, data, types = self.exec_js(freq_func, field) + values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) + + if normalize: + values = dict([(k, float(v) / total) for k, v in values.items()]) + + frequencies = {} + for k, v in values.iteritems(): + if isinstance(k, float): + if int(k) == k: + k = int(k) + + frequencies[k] = v + + return frequencies + + def _fields_to_dbfields(self, fields, subdoc=False): + """Translate fields paths to its db equivalents""" + ret = [] + subclasses = [] + document = self._document + if document._meta['allow_inheritance']: + subclasses = [get_document(x) + for x in document._subclasses][1:] + for field in fields: + try: + field = ".".join(f.db_field for f in + document._lookup_field(field.split('.'))) + ret.append(field) + except LookUpError, err: + found = False + for subdoc in subclasses: + try: + subfield = ".".join(f.db_field for f in + subdoc._lookup_field(field.split('.'))) + ret.append(subfield) + found = True + break + except LookUpError, e: + pass + + if not found: + raise err + return ret + + def _get_order_by(self, keys): + """Creates a list of order by fields + """ + key_list = [] + for key in keys: + if not key: + continue + direction = pymongo.ASCENDING + if key[0] == '-': + direction = pymongo.DESCENDING + if key[0] in ('-', '+'): + key = key[1:] + key = key.replace('__', '.') + try: + key = self._document._translate_field_name(key) + except: + pass + key_list.append((key, direction)) + + if self._cursor_obj: + self._cursor_obj.sort(key_list) + return key_list + + def _get_scalar(self, doc): + + def lookup(obj, name): + chunks = name.split('__') + for chunk in chunks: + obj = getattr(obj, chunk) + return obj + + data = [lookup(doc, n) for n in self._scalar] + if len(data) == 1: + return data[0] + + return tuple(data) + + def _get_as_pymongo(self, row): + # Extract which fields paths we should follow if .fields(...) was + # used. If not, handle all fields. + if not getattr(self, '__as_pymongo_fields', None): + self.__as_pymongo_fields = [] + + for field in self._loaded_fields.fields - set(['_cls']): + self.__as_pymongo_fields.append(field) + while '.' in field: + field, _ = field.rsplit('.', 1) + self.__as_pymongo_fields.append(field) + + all_fields = not self.__as_pymongo_fields + + def clean(data, path=None): + path = path or '' + + if isinstance(data, dict): + new_data = {} + for key, value in data.iteritems(): + new_path = '%s.%s' % (path, key) if path else key + + if all_fields: + include_field = True + elif self._loaded_fields.value == QueryFieldList.ONLY: + include_field = new_path in self.__as_pymongo_fields + else: + include_field = new_path not in self.__as_pymongo_fields + + if include_field: + new_data[key] = clean(value, path=new_path) + data = new_data + elif isinstance(data, list): + data = [clean(d, path=path) for d in data] + else: + if self._as_pymongo_coerce: + # If we need to coerce types, we need to determine the + # type of this field and use the corresponding + # .to_python(...) + from mongoengine.fields import EmbeddedDocumentField + obj = self._document + for chunk in path.split('.'): + obj = getattr(obj, chunk, None) + if obj is None: + break + elif isinstance(obj, EmbeddedDocumentField): + obj = obj.document_type + if obj and data is not None: + data = obj.to_python(data) + return data + return clean(row) + + def _sub_js_fields(self, code): + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be + substituted for the MongoDB name of the field (specified using the + :attr:`name` keyword argument in a field's constructor). + """ + def field_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return u'["%s"]' % fields[-1].db_field + + def field_path_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return ".".join([f.db_field for f in fields]) + + code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, + code) + return code + + # Deprecated + def ensure_index(self, **kwargs): + """Deprecated use :func:`Document.ensure_index`""" + msg = ("Doc.objects()._ensure_index() is deprecated. " + "Use Doc.ensure_index() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_index(**kwargs) + return self + + def _ensure_indexes(self): + """Deprecated use :func:`~Document.ensure_indexes`""" + msg = ("Doc.objects()._ensure_indexes() is deprecated. " + "Use Doc.ensure_indexes() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_indexes() \ No newline at end of file diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py index 73d3cc2..140a71e 100644 --- a/mongoengine/queryset/field_list.py +++ b/mongoengine/queryset/field_list.py @@ -55,7 +55,8 @@ class QueryFieldList(object): if self.always_include: if self.value is self.ONLY and self.fields: - self.fields = self.fields.union(self.always_include) + if sorted(self.slice.keys()) != sorted(self.fields): + self.fields = self.fields.union(self.always_include) else: self.fields -= self.always_include diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 235d27b..1437e76 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -1,137 +1,26 @@ -from __future__ import absolute_import +from mongoengine.errors import OperationError +from mongoengine.queryset.base import (BaseQuerySet, DO_NOTHING, NULLIFY, + CASCADE, DENY, PULL) -import copy -import itertools -import operator -import pprint -import re -import warnings - -from bson.code import Code -from bson import json_util -import pymongo -from pymongo.common import validate_read_preference - -from mongoengine import signals -from mongoengine.common import _import_class -from mongoengine.errors import (OperationError, NotUniqueError, - InvalidQueryError) - -from mongoengine.queryset import transform -from mongoengine.queryset.field_list import QueryFieldList -from mongoengine.queryset.visitor import Q, QNode - - -__all__ = ('QuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') +__all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE', + 'DENY', 'PULL') # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 ITER_CHUNK_SIZE = 100 -# Delete rules -DO_NOTHING = 0 -NULLIFY = 1 -CASCADE = 2 -DENY = 3 -PULL = 4 -RE_TYPE = type(re.compile('')) +class QuerySet(BaseQuerySet): + """The default queryset, that builds queries and handles a set of results + returned from a query. - -class QuerySet(object): - """A set of results returned from a query. Wraps a MongoDB cursor, - providing :class:`~mongoengine.Document` objects as the results. + Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as + the results. """ - __dereference = False - _auto_dereference = True - def __init__(self, document, collection): - self._document = document - self._collection_obj = collection - self._mongo_query = None - self._query_obj = Q() - self._initial_query = {} - self._where_clause = None - self._loaded_fields = QueryFieldList() - self._ordering = None - self._snapshot = False - self._timeout = True - self._class_check = True - self._slave_okay = False - self._read_preference = None - self._iter = False - self._scalar = [] - self._none = False - self._as_pymongo = False - self._as_pymongo_coerce = False - self._result_cache = [] - self._has_more = True - self._len = None - - # If inheritance is allowed, only return instances and instances of - # subclasses of the class being used - if document._meta.get('allow_inheritance') is True: - if len(self._document._subclasses) == 1: - self._initial_query = {"_cls": self._document._subclasses[0]} - else: - self._initial_query = {"_cls": {"$in": self._document._subclasses}} - self._loaded_fields = QueryFieldList(always_include=['_cls']) - self._cursor_obj = None - self._limit = None - self._skip = None - self._hint = -1 # Using -1 as None is a valid value for hint - - def __call__(self, q_obj=None, class_check=True, slave_okay=False, - read_preference=None, **query): - """Filter the selected documents by calling the - :class:`~mongoengine.queryset.QuerySet` with a query. - - :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in - the query; the :class:`~mongoengine.queryset.QuerySet` is filtered - multiple times with different :class:`~mongoengine.queryset.Q` - objects, only the last one will be used - :param class_check: If set to False bypass class name check when - querying collection - :param slave_okay: if True, allows this query to be run against a - replica secondary. - :params read_preference: if set, overrides connection-level - read_preference from `ReplicaSetConnection`. - :param query: Django-style query keyword arguments - """ - query = Q(**query) - if q_obj: - # make sure proper query object is passed - if not isinstance(q_obj, QNode): - msg = ("Not a query object: %s. " - "Did you intend to use key=value?" % q_obj) - raise InvalidQueryError(msg) - query &= q_obj - - if read_preference is None: - queryset = self.clone() - else: - # Use the clone provided when setting read_preference - queryset = self.read_preference(read_preference) - - queryset._query_obj &= query - queryset._mongo_query = None - queryset._cursor_obj = None - queryset._class_check = class_check - - return queryset - - def __len__(self): - """Since __len__ is called quite frequently (for example, as part of - list(qs) we populate the result cache and cache the length. - """ - if self._len is not None: - return self._len - if self._has_more: - # populate the cache - list(self._iter_results()) - - self._len = len(self._result_cache) - return self._len + _has_more = True + _len = None + _result_cache = None def __iter__(self): """Iteration utilises a results cache which iterates the cursor @@ -147,11 +36,39 @@ class QuerySet(object): # iterating over the cache. return iter(self._result_cache) + def __len__(self): + """Since __len__ is called quite frequently (for example, as part of + list(qs) we populate the result cache and cache the length. + """ + if self._len is not None: + return self._len + if self._has_more: + # populate the cache + list(self._iter_results()) + + self._len = len(self._result_cache) + return self._len + + def __repr__(self): + """Provides the string representation of the QuerySet + """ + if self._iter: + return '.. queryset mid-iteration ..' + + self._populate_cache() + data = self._result_cache[:REPR_OUTPUT_SIZE + 1] + if len(data) > REPR_OUTPUT_SIZE: + data[-1] = "...(remaining elements truncated)..." + return repr(data) + + def _iter_results(self): """A generator for iterating over the result cache. Also populates the cache if there are more possible results to yield. Raises StopIteration when there are no more results""" + if self._result_cache is None: + self._result_cache = [] pos = 0 while True: upper = len(self._result_cache) @@ -168,6 +85,8 @@ class QuerySet(object): Populates the result cache with ``ITER_CHUNK_SIZE`` more entries (until the cursor is exhausted). """ + if self._result_cache is None: + self._result_cache = [] if self._has_more: try: for i in xrange(ITER_CHUNK_SIZE): @@ -175,226 +94,6 @@ class QuerySet(object): except StopIteration: self._has_more = False - def __getitem__(self, key): - """Support skip and limit using getitem and slicing syntax. - """ - queryset = self.clone() - - # Slice provided - if isinstance(key, slice): - try: - queryset._cursor_obj = queryset._cursor[key] - queryset._skip, queryset._limit = key.start, key.stop - if key.start and key.stop: - queryset._limit = key.stop - key.start - except IndexError, err: - # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. - start = key.start or 0 - if start >= 0 and key.stop >= 0 and key.step is None: - if start == key.stop: - queryset.limit(0) - queryset._skip = key.start - queryset._limit = key.stop - start - return queryset - raise err - # Allow further QuerySet modifications to be performed - return queryset - # Integer index provided - elif isinstance(key, int): - if queryset._scalar: - return queryset._get_scalar( - queryset._document._from_son(queryset._cursor[key], - _auto_dereference=self._auto_dereference)) - if queryset._as_pymongo: - return queryset._get_as_pymongo(queryset._cursor.next()) - return queryset._document._from_son(queryset._cursor[key], - _auto_dereference=self._auto_dereference) - raise AttributeError - - def __repr__(self): - """Provides the string representation of the QuerySet - """ - - if self._iter: - return '.. queryset mid-iteration ..' - - self._populate_cache() - data = self._result_cache[:REPR_OUTPUT_SIZE + 1] - if len(data) > REPR_OUTPUT_SIZE: - data[-1] = "...(remaining elements truncated)..." - return repr(data) - - # Core functions - - def all(self): - """Returns all documents.""" - return self.__call__() - - def filter(self, *q_objs, **query): - """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` - """ - return self.__call__(*q_objs, **query) - - def get(self, *q_objs, **query): - """Retrieve the the matching object raising - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` exception if multiple results - and :class:`~mongoengine.queryset.DoesNotExist` or - `DocumentName.DoesNotExist` if no results are found. - - .. versionadded:: 0.3 - """ - queryset = self.clone() - queryset = queryset.limit(2) - queryset = queryset.filter(*q_objs, **query) - - try: - result = queryset.next() - except StopIteration: - msg = ("%s matching query does not exist." - % queryset._document._class_name) - raise queryset._document.DoesNotExist(msg) - try: - queryset.next() - except StopIteration: - return result - - queryset.rewind() - message = u'%d items returned, instead of 1' % queryset.count() - raise queryset._document.MultipleObjectsReturned(message) - - def create(self, **kwargs): - """Create new object. Returns the saved object instance. - - .. versionadded:: 0.4 - """ - return self._document(**kwargs).save() - - def get_or_create(self, write_concern=None, auto_save=True, - *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a - tuple of ``(object, created)``, where ``object`` is the retrieved or - created object and ``created`` is a boolean specifying whether a new - object was created. Raises - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` if multiple results are found. - A new document will be created if the document doesn't exists; a - dictionary of default values for the new document may be provided as a - keyword argument called :attr:`defaults`. - - .. note:: This requires two separate operations and therefore a - race condition exists. Because there are no transactions in - mongoDB other approaches should be investigated, to ensure you - don't accidently duplicate data when using this method. This is - now scheduled to be removed before 1.0 - - :param write_concern: optional extra keyword arguments used if we - have to create a new document. - Passes any write_concern onto :meth:`~mongoengine.Document.save` - - :param auto_save: if the object is to be saved automatically if - not found. - - .. deprecated:: 0.8 - .. versionchanged:: 0.6 - added `auto_save` - .. versionadded:: 0.3 - """ - msg = ("get_or_create is scheduled to be deprecated. The approach is " - "flawed without transactions. Upserts should be preferred.") - warnings.warn(msg, DeprecationWarning) - - defaults = query.get('defaults', {}) - if 'defaults' in query: - del query['defaults'] - - try: - doc = self.get(*q_objs, **query) - return doc, False - except self._document.DoesNotExist: - query.update(defaults) - doc = self._document(**query) - - if auto_save: - doc.save(write_concern=write_concern) - return doc, True - - def first(self): - """Retrieve the first object matching the query. - """ - queryset = self.clone() - try: - result = queryset[0] - except IndexError: - result = None - return result - - def insert(self, doc_or_docs, load_bulk=True, write_concern=None): - """bulk insert documents - - :param docs_or_doc: a document or list of documents to be inserted - :param load_bulk (optional): If True returns the list of document - instances - :param write_concern: Extra keyword arguments are passed down to - :meth:`~pymongo.collection.Collection.insert` - which will be used as options for the resultant - ``getLastError`` command. For example, - ``insert(..., {w: 2, fsync: True})`` will wait until at least - two servers have recorded the write and will force an fsync on - each server being written to. - - By default returns document instances, set ``load_bulk`` to False to - return just ``ObjectIds`` - - .. versionadded:: 0.5 - """ - Document = _import_class('Document') - - if write_concern is None: - write_concern = {} - - docs = doc_or_docs - return_one = False - if isinstance(docs, Document) or issubclass(docs.__class__, Document): - return_one = True - docs = [docs] - - raw = [] - for doc in docs: - if not isinstance(doc, self._document): - msg = ("Some documents inserted aren't instances of %s" - % str(self._document)) - raise OperationError(msg) - if doc.pk and doc._created: - msg = "Some documents have ObjectIds use doc.update() instead" - raise OperationError(msg) - raw.append(doc.to_mongo()) - - signals.pre_bulk_insert.send(self._document, documents=docs) - try: - ids = self._collection.insert(raw, **write_concern) - except pymongo.errors.OperationFailure, err: - message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): - # E11000 - duplicate key error index - # E11001 - duplicate key on update - message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) - - if not load_bulk: - signals.post_bulk_insert.send( - self._document, documents=docs, loaded=False) - return return_one and ids[0] or ids - - documents = self.in_bulk(ids) - results = [] - for obj_id in ids: - results.append(documents.get(obj_id)) - signals.post_bulk_insert.send( - self._document, documents=results, loaded=True) - return return_one and results[0] or results - def count(self, with_limit_and_skip=True): """Count the selected elements in the query. @@ -402,1138 +101,57 @@ class QuerySet(object): :meth:`skip` that has been applied to this cursor into account when getting the count """ - if self._limit == 0: - return 0 - if with_limit_and_skip and self._len is not None: - return self._len - count = self._cursor.count(with_limit_and_skip=with_limit_and_skip) - if with_limit_and_skip: - self._len = count - return count + if with_limit_and_skip is False: + return super(QuerySet, self).count(with_limit_and_skip) - def delete(self, write_concern=None, _from_doc_delete=False): - """Delete the documents matched by the query. + if self._len is None: + self._len = super(QuerySet, self).count(with_limit_and_skip) - :param write_concern: Extra keyword arguments are passed down which - will be used as options for the resultant - ``getLastError`` command. For example, - ``save(..., write_concern={w: 2, fsync: True}, ...)`` will - wait until at least two servers have recorded the write and - will force an fsync on the primary server. - :param _from_doc_delete: True when called from document delete therefore - signals will have been triggered so don't loop. + return self._len + + def no_cache(self): + """Convert to a non_caching queryset + + .. versionadded:: 0.8.3 Convert to non caching queryset """ - queryset = self.clone() - doc = queryset._document + if self._result_cache is not None: + raise OperationError("QuerySet already cached") + return self.clone_into(QuerySetNoCache(self._document, self._collection)) - if write_concern is None: - write_concern = {} - # Handle deletes where skips or limits have been applied or - # there is an untriggered delete signal - has_delete_signal = signals.signals_available and ( - signals.pre_delete.has_receivers_for(self._document) or - signals.post_delete.has_receivers_for(self._document)) +class QuerySetNoCache(BaseQuerySet): + """A non caching QuerySet""" - call_document_delete = (queryset._skip or queryset._limit or - has_delete_signal) and not _from_doc_delete + def cache(self): + """Convert to a caching queryset - if call_document_delete: - for doc in queryset: - doc.delete(write_concern=write_concern) - return - - delete_rules = doc._meta.get('delete_rules') or {} - # Check for DENY rules before actually deleting/nullifying any other - # references - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == DENY and document_cls.objects( - **{field_name + '__in': self}).count() > 0: - msg = ("Could not delete document (%s.%s refers to it)" - % (document_cls.__name__, field_name)) - raise OperationError(msg) - - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == CASCADE: - ref_q = document_cls.objects(**{field_name + '__in': self}) - ref_q_count = ref_q.count() - if (doc != document_cls and ref_q_count > 0 - or (doc == document_cls and ref_q_count > 0)): - ref_q.delete(write_concern=write_concern) - elif rule == NULLIFY: - document_cls.objects(**{field_name + '__in': self}).update( - write_concern=write_concern, **{'unset__%s' % field_name: 1}) - elif rule == PULL: - document_cls.objects(**{field_name + '__in': self}).update( - write_concern=write_concern, - **{'pull_all__%s' % field_name: self}) - - queryset._collection.remove(queryset._query, write_concern=write_concern) - - def update(self, upsert=False, multi=True, write_concern=None, - full_result=False, **update): - """Perform an atomic update on the fields matched by the query. - - :param upsert: Any existing document with that "_id" is overwritten. - :param multi: Update multiple documents. - :param write_concern: Extra keyword arguments are passed down which - will be used as options for the resultant - ``getLastError`` command. For example, - ``save(..., write_concern={w: 2, fsync: True}, ...)`` will - wait until at least two servers have recorded the write and - will force an fsync on the primary server. - :param full_result: Return the full result rather than just the number - updated. - :param update: Django-style update keyword arguments - - .. versionadded:: 0.2 + .. versionadded:: 0.8.3 Convert to caching queryset """ - if not update and not upsert: - raise OperationError("No update parameters, would remove data") + return self.clone_into(QuerySet(self._document, self._collection)) - if write_concern is None: - write_concern = {} + def __repr__(self): + """Provides the string representation of the QuerySet - queryset = self.clone() - query = queryset._query - update = transform.update(queryset._document, **update) - - # If doing an atomic upsert on an inheritable class - # then ensure we add _cls to the update operation - if upsert and '_cls' in query: - if '$set' in update: - update["$set"]["_cls"] = queryset._document._class_name - else: - update["$set"] = {"_cls": queryset._document._class_name} - try: - result = queryset._collection.update(query, update, multi=multi, - upsert=upsert, **write_concern) - if full_result: - return result - elif result: - return result['n'] - except pymongo.errors.OperationFailure, err: - if unicode(err) == u'multi not coded yet': - message = u'update() method requires MongoDB 1.1.3+' - raise OperationError(message) - raise OperationError(u'Update failed (%s)' % unicode(err)) - - def update_one(self, upsert=False, write_concern=None, **update): - """Perform an atomic update on first field matched by the query. - - :param upsert: Any existing document with that "_id" is overwritten. - :param write_concern: Extra keyword arguments are passed down which - will be used as options for the resultant - ``getLastError`` command. For example, - ``save(..., write_concern={w: 2, fsync: True}, ...)`` will - wait until at least two servers have recorded the write and - will force an fsync on the primary server. - :param update: Django-style update keyword arguments - - .. versionadded:: 0.2 + .. versionchanged:: 0.6.13 Now doesnt modify the cursor """ - return self.update( - upsert=upsert, multi=False, write_concern=write_concern, **update) + if self._iter: + return '.. queryset mid-iteration ..' - def with_id(self, object_id): - """Retrieve the object matching the id provided. Uses `object_id` only - and raises InvalidQueryError if a filter has been applied. Returns - `None` if no document exists with that id. - - :param object_id: the value for the id of the document to look up - - .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set - """ - queryset = self.clone() - if not queryset._query_obj.empty: - msg = "Cannot use a filter whilst using `with_id`" - raise InvalidQueryError(msg) - return queryset.filter(pk=object_id).first() - - def in_bulk(self, object_ids): - """Retrieve a set of documents by their ids. - - :param object_ids: a list or tuple of ``ObjectId``\ s - :rtype: dict of ObjectIds as keys and collection-specific - Document subclasses as values. - - .. versionadded:: 0.3 - """ - doc_map = {} - - docs = self._collection.find({'_id': {'$in': object_ids}}, - **self._cursor_args) - if self._scalar: - for doc in docs: - doc_map[doc['_id']] = self._get_scalar( - self._document._from_son(doc)) - elif self._as_pymongo: - for doc in docs: - doc_map[doc['_id']] = self._get_as_pymongo(doc) - else: - for doc in docs: - doc_map[doc['_id']] = self._document._from_son(doc) - - return doc_map - - def none(self): - """Helper that just returns a list""" - queryset = self.clone() - queryset._none = True - return queryset - - def no_sub_classes(self): - """ - Only return instances of this document and not any inherited documents - """ - if self._document._meta.get('allow_inheritance') is True: - self._initial_query = {"_cls": self._document._class_name} - - return self - - def only_classes(self, *classes): - doc = self._document - if doc._meta.get('allow_inheritance') is True: - queryset = self.clone() - class_names = [cls._class_name for cls in classes] - allowed_class_names = [name for name in self._document._subclasses if name in class_names] - if len(allowed_class_names) == 1: - queryset._initial_query = {"_cls": allowed_class_names[0]} - else: - queryset._initial_query = {"_cls": {"$in": allowed_class_names}} - return queryset - else: - return self - - def exclude_classes(self, *classes): - doc = self._document - if doc._meta.get('allow_inheritance') is True: - queryset = self.clone() - class_names = [cls._class_name for cls in classes] - allowed_class_names = [name for name in self._document._subclasses if name in class_names] - if len(allowed_class_names) == 1: - queryset._initial_query = {"_cls": {"$ne": allowed_class_names[0]}} - else: - queryset._initial_query = {"_cls": {"$nin": allowed_class_names}} - return queryset - else: - return self - - def clone(self): - """Creates a copy of the current - :class:`~mongoengine.queryset.QuerySet` - - .. versionadded:: 0.5 - """ - c = self.__class__(self._document, self._collection_obj) - - copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', - '_where_clause', '_loaded_fields', '_ordering', '_snapshot', - '_timeout', '_class_check', '_slave_okay', '_read_preference', - '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', - '_limit', '_skip', '_hint', '_auto_dereference') - - for prop in copy_props: - val = getattr(self, prop) - setattr(c, prop, copy.copy(val)) - - if self._cursor_obj: - c._cursor_obj = self._cursor_obj.clone() - - return c - - def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or - :class:`~bson.object_id.ObjectId` a maximum depth in order to cut down - the number queries to mongodb. - - .. versionadded:: 0.5 - """ - # Make select related work the same for querysets - max_depth += 1 - queryset = self.clone() - return queryset._dereference(queryset, max_depth=max_depth) - - def limit(self, n): - """Limit the number of returned documents to `n`. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[:5]``). - - :param n: the maximum number of objects to return - """ - queryset = self.clone() - if n == 0: - queryset._cursor.limit(1) - else: - queryset._cursor.limit(n) - queryset._limit = n - # Return self to allow chaining - return queryset - - def skip(self, n): - """Skip `n` documents before returning the results. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[5:]``). - - :param n: the number of objects to skip before returning results - """ - queryset = self.clone() - queryset._cursor.skip(n) - queryset._skip = n - return queryset - - def hint(self, index=None): - """Added 'hint' support, telling Mongo the proper index to use for the - query. - - Judicious use of hints can greatly improve query performance. When - doing a query on multiple fields (at least one of which is indexed) - pass the indexed field as a hint to the query. - - Hinting will not do anything if the corresponding index does not exist. - The last hint applied to this cursor takes precedence over all others. - - .. versionadded:: 0.5 - """ - queryset = self.clone() - queryset._cursor.hint(index) - queryset._hint = index - return queryset - - def distinct(self, field): - """Return a list of distinct values for a given field. - - :param field: the field to select distinct values from - - .. note:: This is a command and won't take ordering or limit into - account. - - .. versionadded:: 0.4 - .. versionchanged:: 0.5 - Fixed handling references - .. versionchanged:: 0.6 - Improved db_field refrence handling - """ - queryset = self.clone() - try: - field = self._fields_to_dbfields([field]).pop() - finally: - return self._dereference(queryset._cursor.distinct(field), 1, - name=field, instance=self._document) - - def only(self, *fields): - """Load only a subset of this document's fields. :: - - post = BlogPost.objects(...).only("title", "author.name") - - .. note :: `only()` is chainable and will perform a union :: - So with the following it will fetch both: `title` and `author.name`:: - - post = BlogPost.objects.only("title").only("author.name") - - :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any - field filters. - - :param fields: fields to include - - .. versionadded:: 0.3 - .. versionchanged:: 0.5 - Added subfield support - """ - fields = dict([(f, QueryFieldList.ONLY) for f in fields]) - return self.fields(True, **fields) - - def exclude(self, *fields): - """Opposite to .only(), exclude some document's fields. :: - - post = BlogPost.objects(...).exclude("comments") - - .. note :: `exclude()` is chainable and will perform a union :: - So with the following it will exclude both: `title` and `author.name`:: - - post = BlogPost.objects.exclude("title").exclude("author.name") - - :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any - field filters. - - :param fields: fields to exclude - - .. versionadded:: 0.5 - """ - fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) - return self.fields(**fields) - - def fields(self, _only_called=False, **kwargs): - """Manipulate how you load this document's fields. Used by `.only()` - and `.exclude()` to manipulate which fields to retrieve. Fields also - allows for a greater level of control for example: - - Retrieving a Subrange of Array Elements: - - You can use the $slice operator to retrieve a subrange of elements in - an array. For example to get the first 5 comments:: - - post = BlogPost.objects(...).fields(slice__comments=5) - - :param kwargs: A dictionary identifying what to include - - .. versionadded:: 0.5 - """ - - # Check for an operator and transform to mongo-style if there is - operators = ["slice"] - cleaned_fields = [] - for key, value in kwargs.items(): - parts = key.split('__') - op = None - if parts[0] in operators: - op = parts.pop(0) - value = {'$' + op: value} - key = '.'.join(parts) - cleaned_fields.append((key, value)) - - fields = sorted(cleaned_fields, key=operator.itemgetter(1)) - queryset = self.clone() - for value, group in itertools.groupby(fields, lambda x: x[1]): - fields = [field for field, value in group] - fields = queryset._fields_to_dbfields(fields) - queryset._loaded_fields += QueryFieldList(fields, value=value, _only_called=_only_called) - - return queryset - - def all_fields(self): - """Include all fields. Reset all previously calls of .only() or - .exclude(). :: - - post = BlogPost.objects.exclude("comments").all_fields() - - .. versionadded:: 0.5 - """ - queryset = self.clone() - queryset._loaded_fields = QueryFieldList( - always_include=queryset._loaded_fields.always_include) - return queryset - - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. - - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction - """ - queryset = self.clone() - queryset._ordering = queryset._get_order_by(keys) - return queryset - - def explain(self, format=False): - """Return an explain plan record for the - :class:`~mongoengine.queryset.QuerySet`\ 's cursor. - - :param format: format the plan before returning it - """ - plan = self._cursor.explain() - if format: - plan = pprint.pformat(plan) - return plan - - def snapshot(self, enabled): - """Enable or disable snapshot mode when querying. - - :param enabled: whether or not snapshot mode is enabled - - ..versionchanged:: 0.5 - made chainable - """ - queryset = self.clone() - queryset._snapshot = enabled - return queryset - - def timeout(self, enabled): - """Enable or disable the default mongod timeout when querying. - - :param enabled: whether or not the timeout is used - - ..versionchanged:: 0.5 - made chainable - """ - queryset = self.clone() - queryset._timeout = enabled - return queryset - - def slave_okay(self, enabled): - """Enable or disable the slave_okay when querying. - - :param enabled: whether or not the slave_okay is enabled - """ - queryset = self.clone() - queryset._slave_okay = enabled - return queryset - - def read_preference(self, read_preference): - """Change the read_preference when querying. - - :param read_preference: override ReplicaSetConnection-level - preference. - """ - validate_read_preference('read_preference', read_preference) - queryset = self.clone() - queryset._read_preference = read_preference - return queryset - - def scalar(self, *fields): - """Instead of returning Document instances, return either a specific - value or a tuple of values in order. - - Can be used along with - :func:`~mongoengine.queryset.QuerySet.no_dereference` to turn off - dereferencing. - - .. note:: This effects all results and can be unset by calling - ``scalar`` without arguments. Calls ``only`` automatically. - - :param fields: One or more fields to return instead of a Document. - """ - queryset = self.clone() - queryset._scalar = list(fields) - - if fields: - queryset = queryset.only(*fields) - else: - queryset = queryset.all_fields() - - return queryset - - def values_list(self, *fields): - """An alias for scalar""" - return self.scalar(*fields) - - def as_pymongo(self, coerce_types=False): - """Instead of returning Document instances, return raw values from - pymongo. - - :param coerce_type: Field types (if applicable) would be use to - coerce types. - """ - queryset = self.clone() - queryset._as_pymongo = True - queryset._as_pymongo_coerce = coerce_types - return queryset - - # JSON Helpers - - def to_json(self): - """Converts a queryset to JSON""" - return json_util.dumps(self.as_pymongo()) - - def from_json(self, json_data): - """Converts json data to unsaved objects""" - son_data = json_util.loads(json_data) - return [self._document._from_son(data) for data in son_data] - - # JS functionality - - def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, - scope=None): - """Perform a map/reduce query using the current query spec - and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, - it must be the last call made, as it does not return a maleable - ``QuerySet``. - - See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` - and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` - tests in ``tests.queryset.QuerySetTest`` for usage examples. - - :param map_f: map function, as :class:`~bson.code.Code` or string - :param reduce_f: reduce function, as - :class:`~bson.code.Code` or string - :param output: output collection name, if set to 'inline' will try to - use :class:`~pymongo.collection.Collection.inline_map_reduce` - This can also be a dictionary containing output options - see: http://docs.mongodb.org/manual/reference/commands/#mapReduce - :param finalize_f: finalize function, an optional function that - performs any post-reduction processing. - :param scope: values to insert into map/reduce global scope. Optional. - :param limit: number of objects from current query to provide - to map/reduce method - - Returns an iterator yielding - :class:`~mongoengine.document.MapReduceDocument`. - - .. note:: - - Map/Reduce changed in server version **>= 1.7.4**. The PyMongo - :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.11**. - - .. versionchanged:: 0.5 - - removed ``keep_temp`` keyword argument, which was only relevant - for MongoDB server versions older than 1.7.4 - - .. versionadded:: 0.3 - """ - queryset = self.clone() - - MapReduceDocument = _import_class('MapReduceDocument') - - if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.7.1") - - map_f_scope = {} - if isinstance(map_f, Code): - map_f_scope = map_f.scope - map_f = unicode(map_f) - map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) - - reduce_f_scope = {} - if isinstance(reduce_f, Code): - reduce_f_scope = reduce_f.scope - reduce_f = unicode(reduce_f) - reduce_f_code = queryset._sub_js_fields(reduce_f) - reduce_f = Code(reduce_f_code, reduce_f_scope) - - mr_args = {'query': queryset._query} - - if finalize_f: - finalize_f_scope = {} - if isinstance(finalize_f, Code): - finalize_f_scope = finalize_f.scope - finalize_f = unicode(finalize_f) - finalize_f_code = queryset._sub_js_fields(finalize_f) - finalize_f = Code(finalize_f_code, finalize_f_scope) - mr_args['finalize'] = finalize_f - - if scope: - mr_args['scope'] = scope - - if limit: - mr_args['limit'] = limit - - if output == 'inline' and not queryset._ordering: - map_reduce_function = 'inline_map_reduce' - else: - map_reduce_function = 'map_reduce' - mr_args['out'] = output - - results = getattr(queryset._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) - - if map_reduce_function == 'map_reduce': - results = results.find() - - if queryset._ordering: - results = results.sort(queryset._ordering) - - for doc in results: - yield MapReduceDocument(queryset._document, queryset._collection, - doc['_id'], doc['value']) - - def exec_js(self, code, *fields, **options): - """Execute a Javascript function on the server. A list of fields may be - provided, which will be translated to their correct names and supplied - as the arguments to the function. A few extra variables are added to - the function's scope: ``collection``, which is the name of the - collection in use; ``query``, which is an object representing the - current query; and ``options``, which is an object containing any - options specified as keyword arguments. - - As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` - constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a - field, use square-bracket notation, and prefix the MongoEngine field - name with a tilde (~). - - :param code: a string of Javascript code to execute - :param fields: fields that you will be using in your function, which - will be passed in to your function as arguments - :param options: options that you want available to the function - (accessed in Javascript through the ``options`` object) - """ - queryset = self.clone() - - code = queryset._sub_js_fields(code) - - fields = [queryset._document._translate_field_name(f) for f in fields] - collection = queryset._document._get_collection_name() - - scope = { - 'collection': collection, - 'options': options or {}, - } - - query = queryset._query - if queryset._where_clause: - query['$where'] = queryset._where_clause - - scope['query'] = query - code = Code(code, scope=scope) - - db = queryset._document._get_db() - return db.eval(code, *fields) - - def where(self, where_clause): - """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript - expression). Performs automatic field name substitution like - :meth:`mongoengine.queryset.Queryset.exec_js`. - - .. note:: When using this mode of query, the database will call your - function, or evaluate your predicate clause, for each object - in the collection. - - .. versionadded:: 0.5 - """ - queryset = self.clone() - where_clause = queryset._sub_js_fields(where_clause) - queryset._where_clause = where_clause - return queryset - - def sum(self, field): - """Sum over the values of the specified field. - - :param field: the field to sum over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = Code(""" - function() { - emit(1, this[field] || 0); - } - """, scope={'field': field}) - - reduce_func = Code(""" - function(key, values) { - var sum = 0; - for (var i in values) { - sum += values[i]; - } - return sum; - } - """) - - for result in self.map_reduce(map_func, reduce_func, output='inline'): - return result.value - else: - return 0 - - def average(self, field): - """Average over the values of the specified field. - - :param field: the field to average over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = Code(""" - function() { - if (this.hasOwnProperty(field)) - emit(1, {t: this[field] || 0, c: 1}); - } - """, scope={'field': field}) - - reduce_func = Code(""" - function(key, values) { - var out = {t: 0, c: 0}; - for (var i in values) { - var value = values[i]; - out.t += value.t; - out.c += value.c; - } - return out; - } - """) - - finalize_func = Code(""" - function(key, value) { - return value.t / value.c; - } - """) - - for result in self.map_reduce(map_func, reduce_func, - finalize_f=finalize_func, output='inline'): - return result.value - else: - return 0 - - def item_frequencies(self, field, normalize=False, map_reduce=True): - """Returns a dictionary of all items present in a field across - the whole queried set of documents, and their corresponding frequency. - This is useful for generating tag clouds, or searching documents. - - .. note:: - - Can only do direct simple mappings and cannot map across - :class:`~mongoengine.fields.ReferenceField` or - :class:`~mongoengine.fields.GenericReferenceField` for more complex - counting a manual map reduce call would is required. - - If the field is a :class:`~mongoengine.fields.ListField`, the items within - each list will be counted individually. - - :param field: the field to use - :param normalize: normalize the results so they add to 1.0 - :param map_reduce: Use map_reduce over exec_js - - .. versionchanged:: 0.5 defaults to map_reduce and can handle embedded - document lookups - """ - if map_reduce: - return self._item_frequencies_map_reduce(field, - normalize=normalize) - return self._item_frequencies_exec_js(field, normalize=normalize) - - # Iterator helpers - - def next(self): - """Wrap the result in a :class:`~mongoengine.Document` object. - """ - if self._limit == 0 or self._none: - raise StopIteration - - raw_doc = self._cursor.next() - if self._as_pymongo: - return self._get_as_pymongo(raw_doc) - doc = self._document._from_son(raw_doc, - _auto_dereference=self._auto_dereference) - if self._scalar: - return self._get_scalar(doc) - - return doc - - def rewind(self): - """Rewind the cursor to its unevaluated state. - - .. versionadded:: 0.3 - """ - self._iter = False - self._cursor.rewind() - - # Properties - - @property - def _collection(self): - """Property that returns the collection object. This allows us to - perform operations only if the collection is accessed. - """ - return self._collection_obj - - @property - def _cursor_args(self): - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout - } - if self._read_preference is not None: - cursor_args['read_preference'] = self._read_preference - else: - cursor_args['slave_okay'] = self._slave_okay - if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields.as_dict() - return cursor_args - - @property - def _cursor(self): - if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - where_clause = self._sub_js_fields(self._where_clause) - self._cursor_obj.where(where_clause) - - if self._ordering: - # Apply query ordering - self._cursor_obj.sort(self._ordering) - elif self._ordering == None and self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model - order = self._get_order_by(self._document._meta['ordering']) - self._cursor_obj.sort(order) - - if self._limit is not None: - self._cursor_obj.limit(self._limit) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - - return self._cursor_obj - - def __deepcopy__(self, memo): - """Essential for chained queries with ReferenceFields involved""" - return self.clone() - - @property - def _query(self): - if self._mongo_query is None: - self._mongo_query = self._query_obj.to_query(self._document) - if self._class_check: - self._mongo_query.update(self._initial_query) - return self._mongo_query - - @property - def _dereference(self): - if not self.__dereference: - self.__dereference = _import_class('DeReference')() - return self.__dereference - - def no_dereference(self): - """Turn off any dereferencing for the results of this queryset. - """ - queryset = self.clone() - queryset._auto_dereference = False - return queryset - - # Helper Functions - - def _item_frequencies_map_reduce(self, field, normalize=False): - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'); - var field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(item, 1); - }); - } else if (typeof field != 'undefined') { - emit(field, 1); - } else { - emit(null, 1); - } - } - """ % dict(field=field) - reduce_func = """ - function(key, values) { - var total = 0; - var valuesSize = values.length; - for (var i=0; i < valuesSize; i++) { - total += parseInt(values[i], 10); - } - return total; - } - """ - values = self.map_reduce(map_func, reduce_func, 'inline') - frequencies = {} - for f in values: - key = f.key - if isinstance(key, float): - if int(key) == key: - key = int(key) - frequencies[key] = int(f.value) - - if normalize: - count = sum(frequencies.values()) - frequencies = dict([(k, float(v) / count) - for k, v in frequencies.items()]) - - return frequencies - - def _item_frequencies_exec_js(self, field, normalize=False): - """Uses exec_js to execute""" - freq_func = """ - function(path) { - var path = path.split('.'); - - var total = 0.0; - db[collection].find(query).forEach(function(doc) { - var field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - total += field.length; - } else { - total++; - } - }); - - var frequencies = {}; - var types = {}; - var inc = 1.0; - - db[collection].find(query).forEach(function(doc) { - field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - field.forEach(function(item) { - frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); - }); - } else { - var item = field; - types[item] = item; - frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); - } - }); - return [total, frequencies, types]; - } - """ - total, data, types = self.exec_js(freq_func, field) - values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) - - if normalize: - values = dict([(k, float(v) / total) for k, v in values.items()]) - - frequencies = {} - for k, v in values.iteritems(): - if isinstance(k, float): - if int(k) == k: - k = int(k) - - frequencies[k] = v - - return frequencies - - def _fields_to_dbfields(self, fields): - """Translate fields paths to its db equivalents""" - ret = [] - for field in fields: - field = ".".join(f.db_field for f in - self._document._lookup_field(field.split('.'))) - ret.append(field) - return ret - - def _get_order_by(self, keys): - """Creates a list of order by fields - """ - key_list = [] - for key in keys: - if not key: - continue - direction = pymongo.ASCENDING - if key[0] == '-': - direction = pymongo.DESCENDING - if key[0] in ('-', '+'): - key = key[1:] - key = key.replace('__', '.') + data = [] + for i in xrange(REPR_OUTPUT_SIZE + 1): try: - key = self._document._translate_field_name(key) - except: - pass - key_list.append((key, direction)) + data.append(self.next()) + except StopIteration: + break + if len(data) > REPR_OUTPUT_SIZE: + data[-1] = "...(remaining elements truncated)..." - if self._cursor_obj: - self._cursor_obj.sort(key_list) - return key_list + self.rewind() + return repr(data) - def _get_scalar(self, doc): - - def lookup(obj, name): - chunks = name.split('__') - for chunk in chunks: - obj = getattr(obj, chunk) - return obj - - data = [lookup(doc, n) for n in self._scalar] - if len(data) == 1: - return data[0] - - return tuple(data) - - def _get_as_pymongo(self, row): - # Extract which fields paths we should follow if .fields(...) was - # used. If not, handle all fields. - if not getattr(self, '__as_pymongo_fields', None): - self.__as_pymongo_fields = [] - for field in self._loaded_fields.fields - set(['_cls', '_id']): - self.__as_pymongo_fields.append(field) - while '.' in field: - field, _ = field.rsplit('.', 1) - self.__as_pymongo_fields.append(field) - - all_fields = not self.__as_pymongo_fields - - def clean(data, path=None): - path = path or '' - - if isinstance(data, dict): - new_data = {} - for key, value in data.iteritems(): - new_path = '%s.%s' % (path, key) if path else key - - if all_fields: - include_field = True - elif self._loaded_fields.value == QueryFieldList.ONLY: - include_field = new_path in self.__as_pymongo_fields - else: - include_field = new_path not in self.__as_pymongo_fields - - if include_field: - new_data[key] = clean(value, path=new_path) - data = new_data - elif isinstance(data, list): - data = [clean(d, path=path) for d in data] - else: - if self._as_pymongo_coerce: - # If we need to coerce types, we need to determine the - # type of this field and use the corresponding - # .to_python(...) - from mongoengine.fields import EmbeddedDocumentField - obj = self._document - for chunk in path.split('.'): - obj = getattr(obj, chunk, None) - if obj is None: - break - elif isinstance(obj, EmbeddedDocumentField): - obj = obj.document_type - if obj and data is not None: - data = obj.to_python(data) - return data - return clean(row) - - def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be - substituted for the MongoDB name of the field (specified using the - :attr:`name` keyword argument in a field's constructor). - """ - def field_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return u'["%s"]' % fields[-1].db_field - - def field_path_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return ".".join([f.db_field for f in fields]) - - code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) - code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, - code) - return code - - # Deprecated - def ensure_index(self, **kwargs): - """Deprecated use :func:`Document.ensure_index`""" - msg = ("Doc.objects()._ensure_index() is deprecated. " - "Use Doc.ensure_index() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_index(**kwargs) - return self - - def _ensure_indexes(self): - """Deprecated use :func:`~Document.ensure_indexes`""" - msg = ("Doc.objects()._ensure_indexes() is deprecated. " - "Use Doc.ensure_indexes() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_indexes() + def __iter__(self): + queryset = self + if queryset._iter: + queryset = self.clone() + queryset.rewind() + return queryset diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 352774f..2ee7e38 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -43,11 +43,11 @@ def query(_doc_cls=None, _field_operation=False, **query): parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is op = None - if parts[-1] in MATCH_OPERATORS: + if len(parts) > 1 and parts[-1] in MATCH_OPERATORS: op = parts.pop() negate = False - if parts[-1] == 'not': + if len(parts) > 1 and parts[-1] == 'not': parts.pop() negate = True @@ -182,6 +182,7 @@ def update(_doc_cls=None, **update): parts = [] cleaned_fields = [] + appended_sub_field = False for field in fields: append_field = True if isinstance(field, basestring): @@ -193,21 +194,30 @@ def update(_doc_cls=None, **update): else: parts.append(field.db_field) if append_field: + appended_sub_field = False cleaned_fields.append(field) + if hasattr(field, 'field'): + cleaned_fields.append(field.field) + appended_sub_field = True # Convert value to proper value - field = cleaned_fields[-1] + if appended_sub_field: + field = cleaned_fields[-2] + else: + field = cleaned_fields[-1] if op in (None, 'set', 'push', 'pull'): if field.required or value is not None: value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] - elif op == 'addToSet': + elif op in ('addToSet', 'setOnInsert'): if isinstance(value, (list, tuple, set)): value = [field.prepare_query_value(op, v) for v in value] elif field.required or value is not None: value = field.prepare_query_value(op, value) + elif op == "unset": + value = 1 if match: match = '$' + match @@ -221,11 +231,24 @@ def update(_doc_cls=None, **update): if 'pull' in op and '.' in key: # Dot operators don't work on pull operations - # it uses nested dict syntax + # unless they point to a list field + # Otherwise it uses nested dict syntax if op == 'pullAll': raise InvalidQueryError("pullAll operations only support " "a single field depth") + # Look for the last list field and use dot notation until there + field_classes = [c.__class__ for c in cleaned_fields] + field_classes.reverse() + ListField = _import_class('ListField') + if ListField in field_classes: + # Join all fields via dot notation to the last ListField + # Then process as normal + last_listField = len(cleaned_fields) - field_classes.index(ListField) + key = ".".join(parts[:last_listField]) + parts = parts[last_listField:] + parts.insert(0, key) + parts.reverse() for key in parts: value = {key: value} diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 3783e7a..41f4ebf 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -28,7 +28,7 @@ class DuplicateQueryConditionsError(InvalidQueryError): class SimplificationVisitor(QNodeVisitor): - """Simplifies query trees by combining unnecessary 'and' connection nodes + """Simplifies query trees by combinging unnecessary 'and' connection nodes into a single Q-object. """ @@ -73,16 +73,6 @@ class QueryCompilerVisitor(QNodeVisitor): def visit_combination(self, combination): operator = "$and" if combination.operation == combination.OR: - keys = set([key for q in combination.children for key in q.keys()]) - if len(keys) == 1: - field = keys.pop() - if not field.startswith('$') and not any([isinstance(q[field], dict) for q in combination.children]): - return { - field: { - '$in': [q[field] for q in combination.children if field in q] - } - } - operator = "$or" return {operator: combination.children} diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 1c256da..06fb8b4 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -__all__ = ['pre_save', 'post_save', 'pre_delete', 'post_delete'] +__all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', + 'post_save', 'pre_delete', 'post_delete'] signals_available = False try: @@ -35,7 +36,10 @@ except ImportError: # not put signals in here. Create your own namespace instead. _signals = Namespace() +pre_init = _signals.signal('pre_init') +post_init = _signals.signal('post_init') pre_save = _signals.signal('pre_save') +pre_save_post_validation = _signals.signal('pre_save_post_validation') post_save = _signals.signal('post_save') pre_delete = _signals.signal('pre_delete') post_delete = _signals.signal('post_delete') diff --git a/python-mongoengine.spec b/python-mongoengine.spec index 4eaba4d..b9c45ef 100644 --- a/python-mongoengine.spec +++ b/python-mongoengine.spec @@ -5,7 +5,7 @@ %define srcname mongoengine Name: python-%{srcname} -Version: 0.8.2 +Version: 0.8.4 Release: 1%{?dist} Summary: A Python Document-Object Mapper for working with MongoDB diff --git a/setup.py b/setup.py index effb6f1..85707d0 100644 --- a/setup.py +++ b/setup.py @@ -48,17 +48,15 @@ CLASSIFIERS = [ 'Topic :: Software Development :: Libraries :: Python Modules', ] -extra_opts = {} +extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6'] - extra_opts['packages'] = find_packages(exclude=('tests',)) + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1'] if "test" in sys.argv or "nosetests" in sys.argv: - extra_opts['packages'].append("tests") + extra_opts['packages'] = find_packages() extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil'] - extra_opts['packages'] = find_packages(exclude=('tests',)) + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6', 'python-dateutil'] setup(name='mongoengine', version=VERSION, diff --git a/tests/benchmark.py b/tests/benchmark.py deleted file mode 100644 index 89439f3..0000000 --- a/tests/benchmark.py +++ /dev/null @@ -1,90 +0,0 @@ -from mongoengine import * -from timeit import repeat -import unittest - -conn_settings = { - 'db': 'mongomallard-test', -} - -connect(**conn_settings) - -def timeit(f, n=10000): - return min(repeat(f, repeat=3, number=n))/float(n) - -class BenchmarkTestCase(unittest.TestCase): - def setUp(self): - pass - - def test_basic(self): - class Book(Document): - name = StringField() - pages = IntField() - tags = ListField(StringField()) - is_published = BooleanField() - - Book.drop_collection() - - create_book = lambda: Book(name='Always be closing', pages=100, tags=['self-help', 'sales'], is_published=True) - print 'Doc initialization: %.3fus' % (timeit(create_book, 1000) * 10**6) - - b = create_book() - - print 'Doc getattr: %.3fus' % (timeit(lambda: b.name, 10000) * 10**6) - - print 'Doc setattr: %.3fus' % (timeit(lambda: setattr(b, 'name', 'New name'), 10000) * 10**6) - - print 'Doc to mongo: %.3fus' % (timeit(b.to_mongo, 1000) * 10**6) - - def save_book(): - b._mark_as_changed('name') - b._mark_as_changed('tags') - b.save() - - save_book() - son = b.to_mongo() - - print 'Load from SON: %.3fus' % (timeit(lambda: Book._from_son(son), 1000) * 10**6) - - print 'Save to database: %.3fus' % (timeit(save_book, 100) * 10**6) - - print 'Load from database: %.3fus' % (timeit(lambda: Book.objects[0], 100) * 10**6) - - def test_embedded(self): - class Contact(EmbeddedDocument): - name = StringField() - title = StringField() - address = StringField() - - class Company(Document): - name = StringField() - contacts = ListField(EmbeddedDocumentField(Contact)) - - Company.drop_collection() - - def get_company(): - return Company( - name='Elastic', - contacts=[ - Contact( - name='Contact %d' % x, - title='CEO', - address='Address %d' % x, - ) - for x in range(1000)] - ) - - def create_company(): - c = get_company() - c.save() - c.delete() - - print 'Save/delete big object to database: %.3fms' % (timeit(create_company, 10) * 10**3) - - c = get_company().save() - - print 'Serialize big object from database: %.3fms' % (timeit(c.to_mongo, 100) * 10**3) - print 'Load big object from database: %.3fms' % (timeit(lambda: Company.objects[0], 100) * 10**3) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/document/delta.py b/tests/document/delta.py index 355717f..b4749f3 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -3,6 +3,7 @@ import sys sys.path[0:0] = [""] import unittest +from bson import SON from mongoengine import * from mongoengine.connection import get_db @@ -48,42 +49,41 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), set()) + self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._delta(), ({}, {})) doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), set(['string_field'])) + self.assertEqual(doc._get_changed_fields(), ['string_field']) self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) - doc._changed_fields = set() + doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), set(['int_field'])) + self.assertEqual(doc._get_changed_fields(), ['int_field']) self.assertEqual(doc._delta(), ({'int_field': 1}, {})) - doc._changed_fields = set() + doc._changed_fields = [] dict_value = {'hello': 'world', 'ping': 'pong'} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) + self.assertEqual(doc._get_changed_fields(), ['dict_field']) self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) - doc._changed_fields = set() + doc._changed_fields = [] list_value = ['1', 2, {'hello': 'world'}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), set(['list_field'])) + self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) # Test unsetting - doc._changed_fields = set() + doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) + self.assertEqual(doc._get_changed_fields(), ['dict_field']) self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) - doc._changed_fields = set() + doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), set(['list_field'])) + self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._delta(), ({}, {'list_field': 1})) - @unittest.skip("not fully implemented") def test_delta_recursive(self): self.delta_recursive(Document, EmbeddedDocument) self.delta_recursive(DynamicDocument, EmbeddedDocument) @@ -110,7 +110,7 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), set()) + self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._delta(), ({}, {})) embedded_1 = Embedded() @@ -120,7 +120,7 @@ class DeltaTest(unittest.TestCase): embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) + self.assertEqual(doc._get_changed_fields(), ['embedded_field']) embedded_delta = { 'string_field': 'hello', @@ -137,7 +137,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.dict_field = {} self.assertEqual(doc._get_changed_fields(), - set(['embedded_field.dict_field'])) + ['embedded_field.dict_field']) self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) doc.save() @@ -146,7 +146,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = [] self.assertEqual(doc._get_changed_fields(), - set(['embedded_field.list_field'])) + ['embedded_field.list_field']) self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) doc.save() @@ -161,7 +161,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = ['1', 2, embedded_2] self.assertEqual(doc._get_changed_fields(), - set(['embedded_field.list_field'])) + ['embedded_field.list_field']) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { @@ -193,7 +193,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'world' self.assertEqual(doc._get_changed_fields(), - set(['embedded_field.list_field.2.string_field'])) + ['embedded_field.list_field.2.string_field']) self.assertEqual(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) self.assertEqual(doc._delta(), @@ -207,7 +207,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - set(['embedded_field.list_field'])) + ['embedded_field.list_field']) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { '_cls': 'Embedded', @@ -270,7 +270,7 @@ class DeltaTest(unittest.TestCase): doc.dict_field['Embedded'].string_field = 'Hello World' self.assertEqual(doc._get_changed_fields(), - set(['dict_field.Embedded.string_field'])) + ['dict_field.Embedded.string_field']) self.assertEqual(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {})) @@ -313,29 +313,24 @@ class DeltaTest(unittest.TestCase): self.circular_reference_deltas_2(DynamicDocument, Document) self.circular_reference_deltas_2(DynamicDocument, DynamicDocument) - def circular_reference_deltas_2(self, DocClass1, DocClass2): + def circular_reference_deltas_2(self, DocClass1, DocClass2, dbref=True): class Person(DocClass1): name = StringField() - owns = ListField(ReferenceField('Organization')) - employer = ReferenceField('Organization') + owns = ListField(ReferenceField('Organization', dbref=dbref)) + employer = ReferenceField('Organization', dbref=dbref) class Organization(DocClass2): name = StringField() - owner = ReferenceField('Person') - employees = ListField(ReferenceField('Person')) + owner = ReferenceField('Person', dbref=dbref) + employees = ListField(ReferenceField('Person', dbref=dbref)) Person.drop_collection() Organization.drop_collection() - person = Person(name="owner") - person.save() - - employee = Person(name="employee") - employee.save() - - organization = Organization(name="company") - organization.save() + person = Person(name="owner").save() + employee = Person(name="employee").save() + organization = Organization(name="company").save() person.owns.append(organization) organization.owner = person @@ -355,6 +350,8 @@ class DeltaTest(unittest.TestCase): self.assertEqual(o.owner, p) self.assertEqual(e.employer, o) + return person, organization, employee + def test_delta_db_field(self): self.delta_db_field(Document) self.delta_db_field(DynamicDocument) @@ -372,39 +369,39 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), set()) + self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._delta(), ({}, {})) doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), set(['string_field'])) + self.assertEqual(doc._get_changed_fields(), ['db_string_field']) self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) - doc._changed_fields = set() + doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), set(['int_field'])) + self.assertEqual(doc._get_changed_fields(), ['db_int_field']) self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) - doc._changed_fields = set() + doc._changed_fields = [] dict_value = {'hello': 'world', 'ping': 'pong'} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) + self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) - doc._changed_fields = set() + doc._changed_fields = [] list_value = ['1', 2, {'hello': 'world'}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), set(['list_field'])) + self.assertEqual(doc._get_changed_fields(), ['db_list_field']) self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) # Test unsetting - doc._changed_fields = set() + doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) + self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) - doc._changed_fields = set() + doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), set(['list_field'])) + self.assertEqual(doc._get_changed_fields(), ['db_list_field']) self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) # Test it saves that data @@ -416,15 +413,13 @@ class DeltaTest(unittest.TestCase): doc.dict_field = {'hello': 'world'} doc.list_field = ['1', 2, {'hello': 'world'}] doc.save() - #doc = doc.reload(10) - doc = doc.reload() + doc = doc.reload(10) self.assertEqual(doc.string_field, 'hello') self.assertEqual(doc.int_field, 1) self.assertEqual(doc.dict_field, {'hello': 'world'}) self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) - @unittest.skip("not fully implemented") def test_delta_recursive_db_field(self): self.delta_recursive_db_field(Document, EmbeddedDocument) self.delta_recursive_db_field(Document, DynamicEmbeddedDocument) @@ -452,7 +447,7 @@ class DeltaTest(unittest.TestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), set()) + self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._delta(), ({}, {})) embedded_1 = Embedded() @@ -462,7 +457,7 @@ class DeltaTest(unittest.TestCase): embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) + self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) embedded_delta = { 'db_string_field': 'hello', @@ -490,7 +485,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = [] self.assertEqual(doc._get_changed_fields(), - set(['db_embedded_field.db_list_field'])) + ['db_embedded_field.db_list_field']) self.assertEqual(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) self.assertEqual(doc._delta(), @@ -608,7 +603,6 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) - @unittest.skip("DynamicDocument not implemented") def test_delta_for_dynamic_documents(self): class Person(DynamicDocument): name = StringField() @@ -617,13 +611,13 @@ class DeltaTest(unittest.TestCase): Person.drop_collection() p = Person(name="James", age=34) - self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', - '_cls': 'Person'}, {})) + self.assertEqual(p._delta(), ( + SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {})) p.doc = 123 del(p.doc) - self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', - '_cls': 'Person'}, {'doc': 1})) + self.assertEqual(p._delta(), ( + SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {})) p = Person() p.name = "Dean" @@ -635,16 +629,15 @@ class DeltaTest(unittest.TestCase): self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._delta(), ({'age': 24}, {})) - p = self.Person.objects(age=22).get() + p = Person.objects(age=22).get() p.age = 24 self.assertEqual(p.age, 24) self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._delta(), ({'age': 24}, {})) p.save() - self.assertEqual(1, self.Person.objects(age=24).count()) + self.assertEqual(1, Person.objects(age=24).count()) - @unittest.skip("DynamicDocument not implemented") def test_dynamic_delta(self): class Doc(DynamicDocument): @@ -690,6 +683,36 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + def test_delta_with_dbref_true(self): + person, organization, employee = self.circular_reference_deltas_2(Document, Document, True) + employee.name = 'test' + + self.assertEqual(organization._get_changed_fields(), []) + + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertEqual({}, updates) + + organization.employees.append(person) + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertTrue('employees' in updates) + + def test_delta_with_dbref_false(self): + person, organization, employee = self.circular_reference_deltas_2(Document, Document, False) + employee.name = 'test' + + self.assertEqual(organization._get_changed_fields(), []) + + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertEqual({}, updates) + + organization.employees.append(person) + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertTrue('employees' in updates) + if __name__ == '__main__': unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 05870ee..6263e68 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -8,7 +8,6 @@ from mongoengine.connection import get_db __all__ = ("DynamicTest", ) -@unittest.skip("DynamicDocument not implemented") class DynamicTest(unittest.TestCase): def setUp(self): diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 49fd7cb..ccf8463 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -156,6 +156,25 @@ class IndexesTest(unittest.TestCase): self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}], A._meta['index_specs']) + def test_index_no_cls(self): + """Ensure index specs are inhertited correctly""" + + class A(Document): + title = StringField() + meta = { + 'indexes': [ + {'fields': ('title',), 'cls': False}, + ], + 'allow_inheritance': True, + 'index_cls': False + } + + self.assertEqual([('title', 1)], A._meta['index_specs'][0]['fields']) + A._get_collection().drop_indexes() + A.ensure_indexes() + info = A._get_collection().index_information() + self.assertEqual(len(info.keys()), 2) + def test_build_index_spec_is_not_destructive(self): class MyDoc(Document): @@ -632,7 +651,6 @@ class IndexesTest(unittest.TestCase): pass Customer.drop_collection() - @unittest.skip("behavior differs") def test_unique_and_primary(self): """If you set a field as primary, then unexpected behaviour can occur. You won't create a duplicate but you will update an existing document. diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index d311538..5a48f75 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -182,10 +182,10 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(['age', 'id', 'name', 'salary'], sorted(Employee._fields.keys())) - self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), - set(['_cls', 'name', 'age'])) - self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), - set(['_cls', 'name', 'age', 'salary'])) + self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), + ['_cls', 'name', 'age']) + self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ['_cls', 'name', 'age', 'salary']) self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) diff --git a/tests/document/instance.py b/tests/document/instance.py index 80a6130..a61c439 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -10,7 +10,8 @@ import uuid from datetime import datetime from bson import DBRef -from tests.fixtures import PickleEmbedded, PickleTest, PickleSignalsTest +from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, + PickleDyanmicEmbedded, PickleDynamicTest) from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, @@ -390,25 +391,24 @@ class InstanceTest(unittest.TestCase): doc.embedded_field = embedded_1 doc.save() - doc = doc.reload() + doc = doc.reload(10) doc.list_field.append(1) doc.dict_field['woot'] = "woot" doc.embedded_field.list_field.append(1) doc.embedded_field.dict_field['woot'] = "woot" - self.assertEqual(doc._get_changed_fields(), set([ + self.assertEqual(doc._get_changed_fields(), [ 'list_field', 'dict_field', 'embedded_field.list_field', - 'embedded_field.dict_field'])) + 'embedded_field.dict_field']) doc.save() - doc = doc.reload() - self.assertEqual(doc._get_changed_fields(), set()) + doc = doc.reload(10) + self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(len(doc.list_field), 4) self.assertEqual(len(doc.dict_field), 2) self.assertEqual(len(doc.embedded_field.list_field), 4) self.assertEqual(len(doc.embedded_field.dict_field), 2) - @unittest.skip("not implemented") def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. """ @@ -439,10 +439,17 @@ class InstanceTest(unittest.TestCase): class Employee(Person): salary = IntField() - self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), - set(['_cls', 'name', 'age'])) - self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), - set(['_cls', 'name', 'age', 'salary'])) + self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), + ['_cls', 'name', 'age']) + self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ['_cls', 'name', 'age', 'salary']) + + def test_embedded_document_to_mongo_id(self): + class SubDoc(EmbeddedDocument): + id = StringField(required=True) + + sub_doc = SubDoc(id="abc") + self.assertEqual(sub_doc.to_mongo().keys(), ['id']) def test_embedded_document(self): """Ensure that embedded documents are set up correctly. @@ -453,7 +460,6 @@ class InstanceTest(unittest.TestCase): self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) - @unittest.skip("not implemented") def test_embedded_document_instance(self): """Ensure that embedded documents can reference parent instance """ @@ -462,7 +468,6 @@ class InstanceTest(unittest.TestCase): class Doc(Document): embedded_field = EmbeddedDocumentField(Embedded) - meta = { 'cascade': True } Doc.drop_collection() Doc(embedded_field=Embedded(string="Hi")).save() @@ -470,7 +475,6 @@ class InstanceTest(unittest.TestCase): doc = Doc.objects.get() self.assertEqual(doc, doc.embedded_field._instance) - @unittest.skip("not implemented") def test_embedded_document_complex_instance(self): """Ensure that embedded documents in complex fields can reference parent instance""" @@ -627,7 +631,6 @@ class InstanceTest(unittest.TestCase): p0.name = 'wpjunior' p0.save() - @unittest.skip("FileField not implemented") def test_save_max_recursion_not_hit_with_file_field(self): class Foo(Document): @@ -776,7 +779,6 @@ class InstanceTest(unittest.TestCase): p1.reload() self.assertEqual(p1.name, p.parent.name) - @unittest.skip("not implemented") def test_update(self): """Ensure that an existing document is updated instead of be overwritten.""" @@ -891,6 +893,7 @@ class InstanceTest(unittest.TestCase): reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) + decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) @@ -1059,9 +1062,9 @@ class InstanceTest(unittest.TestCase): user = User.objects.first() # Even if stored as ObjectId's internally mongoengine uses DBRefs # As ObjectId's aren't automatically derefenced - #self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) + self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) self.assertTrue(isinstance(user.orgs[0], Organization)) - #self.assertTrue(isinstance(user._data['orgs'][0], Organization)) + self.assertTrue(isinstance(user._data['orgs'][0], Organization)) # Changing a value with query_counter() as q: @@ -1141,7 +1144,6 @@ class InstanceTest(unittest.TestCase): foo.save() self.assertEqual(1, q) - @unittest.skip("not implemented") def test_save_only_changed_fields_recursive(self): """Ensure save only sets / unsets changed fields """ @@ -1439,8 +1441,8 @@ class InstanceTest(unittest.TestCase): post_obj = BlogPost.objects.first() # Test laziness - #self.assertTrue(isinstance(post_obj._data['author'], - # bson.DBRef)) + self.assertTrue(isinstance(post_obj._data['author'], + bson.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') @@ -1464,7 +1466,6 @@ class InstanceTest(unittest.TestCase): self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - @unittest.skip("not implemented") def test_invalid_son(self): """Raise an error if loading invalid data""" class Occurrence(EmbeddedDocument): @@ -1808,7 +1809,6 @@ class InstanceTest(unittest.TestCase): self.assertTrue(u1 in all_user_set) - @unittest.skip("not implemented") def test_picklable(self): pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) @@ -1835,7 +1835,29 @@ class InstanceTest(unittest.TestCase): self.assertEqual(pickle_doc.string, "Two") self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) - @unittest.skip("not implemented") + def test_dynamic_document_pickle(self): + + pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2']) + pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") + pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved + + pickle_doc.save() + + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + self.assertEqual(resurrected, pickle_doc) + self.assertEqual(resurrected._fields_ordered, + pickle_doc._fields_ordered) + self.assertEqual(resurrected._dynamic_fields.keys(), + pickle_doc._dynamic_fields.keys()) + + self.assertEqual(resurrected.embedded, pickle_doc.embedded) + self.assertEqual(resurrected.embedded._fields_ordered, + pickle_doc.embedded._fields_ordered) + self.assertEqual(resurrected.embedded._dynamic_fields.keys(), + pickle_doc.embedded._dynamic_fields.keys()) + def test_picklable_on_signals(self): pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleEmbedded() @@ -1896,7 +1918,6 @@ class InstanceTest(unittest.TestCase): self.assertEqual(Doc.objects(archived=False).count(), 1) - @unittest.skip("DynamicDocument not implemented") def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs""" class Doc(DynamicDocument): @@ -2036,7 +2057,6 @@ class InstanceTest(unittest.TestCase): self.assertEqual('testdb-1', B._meta.get('db_alias')) - @unittest.skip("not implemented") def test_db_ref_usage(self): """ DB Ref usage in dict_fields""" @@ -2115,7 +2135,6 @@ class InstanceTest(unittest.TestCase): })]), "1,2") - @unittest.skip("not implemented") def test_switch_db_instance(self): register_connection('testdb-1', 'mongoenginetest2') @@ -2187,10 +2206,9 @@ class InstanceTest(unittest.TestCase): user = User.objects.first() self.assertEqual("Ross", user.username) self.assertEqual(True, user.foo) - self.assertEqual("Bar", user._db_data["foo"]) - self.assertEqual([1, 2, 3], user._db_data["data"]) + self.assertEqual("Bar", user._data["foo"]) + self.assertEqual([1, 2, 3], user._data["data"]) - @unittest.skip("DynamicDocument not implemented") def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): @@ -2207,7 +2225,6 @@ class InstanceTest(unittest.TestCase): one = Doc.objects.filter(**{'hello world': 1}).count() self.assertEqual(1, one) - @unittest.skip("not implemented") def test_shard_key(self): class LogEntry(Document): machine = StringField() @@ -2231,7 +2248,6 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) - @unittest.skip("not implemented") def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True) @@ -2255,7 +2271,6 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) - @unittest.skip("not implemented") def test_kwargs_simple(self): class Embedded(EmbeddedDocument): @@ -2270,9 +2285,8 @@ class InstanceTest(unittest.TestCase): "doc": {"name": "embedded doc"}}) self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) + self.assertEqual(classic_doc._data, dict_doc._data) - @unittest.skip("not implemented") def test_kwargs_complex(self): class Embedded(EmbeddedDocument): @@ -2290,9 +2304,8 @@ class InstanceTest(unittest.TestCase): {"name": "embedded doc2"}]}) self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) + self.assertEqual(classic_doc._data, dict_doc._data) - @unittest.skip("not implemented") def test_positional_creation(self): """Ensure that document may be created using positional arguments. """ @@ -2300,7 +2313,6 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 42) - @unittest.skip("not implemented") def test_mixed_creation(self): """Ensure that document may be created using mixed arguments. """ @@ -2308,6 +2320,16 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 42) + def test_mixed_creation_dynamic(self): + """Ensure that document may be created using mixed arguments. + """ + class Person(DynamicDocument): + name = StringField() + + person = Person("Test User", age=42) + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 42) + def test_bad_mixed_creation(self): """Ensure that document gives correct error when duplicating arguments """ @@ -2326,8 +2348,8 @@ class InstanceTest(unittest.TestCase): Person(name="Harry Potter").save() person = Person.objects.first() - self.assertTrue('id' in person.to_dict().keys()) - self.assertEqual(person.to_dict().get('id'), person.id) + self.assertTrue('id' in person._data.keys()) + self.assertEqual(person._data.get('id'), person.id) def test_complex_nesting_document_and_embedded_document(self): diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index 1f2d5c8..2b5d9a0 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -31,6 +31,10 @@ class TestJson(unittest.TestCase): doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) + doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) + expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}""" + self.assertEqual(doc_json, expected_json) + self.assertEqual(doc, Doc.from_json(doc.to_json())) def test_json_complex(self): @@ -58,6 +62,7 @@ class TestJson(unittest.TestCase): reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) + decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) diff --git a/tests/document/validation.py b/tests/document/validation.py index 4637dee..d3f3fd7 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -53,12 +53,11 @@ class ValidatorErrorTest(unittest.TestCase): self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") def test_model_validation(self): + class User(Document): username = StringField(primary_key=True) name = StringField(required=True) - User.drop_collection() - try: User().validate() except ValidationError, e: @@ -129,13 +128,18 @@ class ValidatorErrorTest(unittest.TestCase): Doc(id="test", e=SubDoc(val=15)).save() doc = Doc.objects.first() - keys = doc.to_dict().keys() + keys = doc._data.keys() self.assertEqual(2, len(keys)) self.assertTrue('e' in keys) self.assertTrue('id' in keys) - with self.assertRaises(ValueError): - doc.e.val = "OK" + doc.e.val = "OK" + try: + doc.save() + except ValidationError, e: + self.assertTrue("Doc:test" in e.message) + self.assertEqual(e.to_dict(), { + "e": {'val': 'OK could not be converted to int'}}) if __name__ == '__main__': diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 1eea2ac..8791781 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -56,11 +56,10 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - data = person.to_dict() - self.assertEqual(data['name'], person.name) - self.assertEqual(data['age'], person.age) - self.assertEqual(data['userid'], person.userid) - self.assertEqual(data['created'], person.created) + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -89,11 +88,10 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - data = person.to_dict() - self.assertEqual(data['name'], person.name) - self.assertEqual(data['age'], person.age) - self.assertEqual(data['userid'], person.userid) - self.assertEqual(data['created'], person.created) + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -125,12 +123,10 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - data = person.to_dict() - - self.assertEqual(data['name'], person.name) - self.assertEqual(data['age'], person.age) - self.assertEqual(data['userid'], person.userid) - self.assertEqual(data['created'], person.created) + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -161,12 +157,10 @@ class FieldTest(unittest.TestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - data = person.to_dict() - - self.assertEqual(data['name'], person.name) - self.assertEqual(data['age'], person.age) - self.assertEqual(data['userid'], person.userid) - self.assertEqual(data['created'], person.created) + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) @@ -272,6 +266,17 @@ class FieldTest(unittest.TestCase): self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) + def test_long_ne_operator(self): + class TestDocument(Document): + long_fld = LongField() + + TestDocument.drop_collection() + + TestDocument(long_fld=None).save() + TestDocument(long_fld=1).save() + + self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) + def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to string fields. """ @@ -342,8 +347,25 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, person.validate) person.age = 120 self.assertRaises(ValidationError, person.validate) - with self.assertRaises(ValueError): - person.age = 'ten' + person.age = 'ten' + self.assertRaises(ValidationError, person.validate) + + def test_long_validation(self): + """Ensure that invalid values cannot be assigned to long fields. + """ + class TestDocument(Document): + value = LongField(min_value=0, max_value=110) + + doc = TestDocument() + doc.value = 50 + doc.validate() + + doc.value = -1 + self.assertRaises(ValidationError, doc.validate) + doc.age = 120 + self.assertRaises(ValidationError, doc.validate) + doc.age = 'ten' + self.assertRaises(ValidationError, doc.validate) def test_float_validation(self): """Ensure that invalid values cannot be assigned to float fields. @@ -362,6 +384,69 @@ class FieldTest(unittest.TestCase): person.height = 4.0 self.assertRaises(ValidationError, person.validate) + def test_decimal_validation(self): + """Ensure that invalid values cannot be assigned to decimal fields. + """ + class Person(Document): + height = DecimalField(min_value=Decimal('0.1'), + max_value=Decimal('3.5')) + + Person.drop_collection() + + Person(height=Decimal('1.89')).save() + person = Person.objects.first() + self.assertEqual(person.height, Decimal('1.89')) + + person.height = '2.0' + person.save() + person.height = 0.01 + self.assertRaises(ValidationError, person.validate) + person.height = Decimal('0.01') + self.assertRaises(ValidationError, person.validate) + person.height = Decimal('4.0') + self.assertRaises(ValidationError, person.validate) + + Person.drop_collection() + + def test_decimal_comparison(self): + + class Person(Document): + money = DecimalField() + + Person.drop_collection() + + Person(money=6).save() + Person(money=8).save() + Person(money=10).save() + + self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) + self.assertEqual(2, Person.objects(money__gt=7).count()) + self.assertEqual(2, Person.objects(money__gt="7").count()) + + def test_decimal_storage(self): + class Person(Document): + btc = DecimalField(precision=4) + + Person.drop_collection() + Person(btc=10).save() + Person(btc=10.1).save() + Person(btc=10.11).save() + Person(btc="10.111").save() + Person(btc=Decimal("10.1111")).save() + Person(btc=Decimal("10.11111")).save() + + # How its stored + expected = [{'btc': 10.0}, {'btc': 10.1}, {'btc': 10.11}, + {'btc': 10.111}, {'btc': 10.1111}, {'btc': 10.1111}] + actual = list(Person.objects.exclude('id').as_pymongo()) + self.assertEqual(expected, actual) + + # How it comes out locally + expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), + Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] + actual = list(Person.objects().scalar('btc')) + self.assertEqual(expected, actual) + def test_boolean_validation(self): """Ensure that invalid values cannot be assigned to boolean fields. """ @@ -447,10 +532,10 @@ class FieldTest(unittest.TestCase): log.time = datetime.datetime.now().isoformat('T') log.validate() - #log.time = -1 - #self.assertRaises(ValidationError, log.validate) - #log.time = 'ABC' - #self.assertRaises(ValidationError, log.validate) + log.time = -1 + self.assertRaises(ValidationError, log.validate) + log.time = 'ABC' + self.assertRaises(ValidationError, log.validate) def test_datetime_tz_aware_mark_as_changed(self): from mongoengine import connection @@ -471,7 +556,7 @@ class FieldTest(unittest.TestCase): log = LogEntry.objects.first() log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(set(['time']), log._changed_fields) + self.assertEqual(['time'], log._changed_fields) def test_datetime(self): """Tests showing pymongo datetime fields handling of microseconds. @@ -706,8 +791,8 @@ class FieldTest(unittest.TestCase): post = BlogPost(content='Went for a walk today...') post.validate() - #post.tags = 'fun' - #self.assertRaises(ValidationError, post.validate) + post.tags = 'fun' + self.assertRaises(ValidationError, post.validate) post.tags = [1, 2] self.assertRaises(ValidationError, post.validate) @@ -818,11 +903,11 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() post = BlogPost() - #post.info = 'my post' - #self.assertRaises(ValidationError, post.validate) + post.info = 'my post' + self.assertRaises(ValidationError, post.validate) - #post.info = {'title': 'test'} - #self.assertRaises(ValidationError, post.validate) + post.info = {'title': 'test'} + self.assertRaises(ValidationError, post.validate) post.info = ['test'] post.save() @@ -871,12 +956,14 @@ class FieldTest(unittest.TestCase): e.mapping = [1] e.save() - with self.assertRaises(ValueError): + def create_invalid_mapping(): e.mapping = ["abc"] + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) Simple.drop_collection() - @unittest.skip("different behavior") def test_list_field_rejects_strings(self): """Strings aren't valid list field data types""" @@ -921,7 +1008,7 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() e = Simple().save() e.mapping = [] - self.assertEqual(set([]), e._changed_fields) + self.assertEqual([], e._changed_fields) class Simple(Document): mapping = DictField() @@ -929,9 +1016,34 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() e = Simple().save() e.mapping = {} - self.assertEqual(set([]), e._changed_fields) + self.assertEqual([], e._changed_fields) + + def test_slice_marks_field_as_changed(self): + + class Simple(Document): + widgets = ListField() + + simple = Simple(widgets=[1, 2, 3, 4]).save() + simple.widgets[:3] = [] + self.assertEqual(['widgets'], simple._changed_fields) + simple.save() + + simple = simple.reload() + self.assertEqual(simple.widgets, [4]) + + def test_del_slice_marks_field_as_changed(self): + + class Simple(Document): + widgets = ListField() + + simple = Simple(widgets=[1, 2, 3, 4]).save() + del simple.widgets[:3] + self.assertEqual(['widgets'], simple._changed_fields) + simple.save() + + simple = simple.reload() + self.assertEqual(simple.widgets, [4]) - @unittest.skip("complex types not implemented") def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" @@ -988,11 +1100,11 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() post = BlogPost() - #post.info = 'my post' - #self.assertRaises(ValidationError, post.validate) + post.info = 'my post' + self.assertRaises(ValidationError, post.validate) - #post.info = ['test', 'test'] - #self.assertRaises(ValidationError, post.validate) + post.info = ['test', 'test'] + self.assertRaises(ValidationError, post.validate) post.info = {'$title': 'test'} self.assertRaises(ValidationError, post.validate) @@ -1050,7 +1162,6 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() - @unittest.skip("complex types not implemented") def test_dictfield_complex(self): """Ensure that the dict field can handle the complex types.""" @@ -1868,7 +1979,6 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() - @unittest.skip("not implemented") def test_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -1921,7 +2031,6 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() - @unittest.skip("not implemented") def test_simple_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -2001,7 +2110,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) - @unittest.skip("SequenceField not implemented") def test_sequence_field(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2027,7 +2135,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(c['next'], 1000) - @unittest.skip("SequenceField not implemented") def test_sequence_field_get_next_value(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2059,7 +2166,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(Person.id.get_next_value(), '1') - @unittest.skip("SequenceField not implemented") def test_sequence_field_sequence_name(self): class Person(Document): id = SequenceField(primary_key=True, sequence_name='jelly') @@ -2084,7 +2190,6 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) self.assertEqual(c['next'], 1000) - @unittest.skip("SequenceField not implemented") def test_multiple_sequence_fields(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2117,7 +2222,6 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) self.assertEqual(c['next'], 999) - @unittest.skip("SequenceField not implemented") def test_sequence_fields_reload(self): class Animal(Document): counter = SequenceField() @@ -2143,7 +2247,6 @@ class FieldTest(unittest.TestCase): a.reload() self.assertEqual(a.counter, 2) - @unittest.skip("SequenceField not implemented") def test_multiple_sequence_fields_on_docs(self): class Animal(Document): @@ -2178,7 +2281,6 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) self.assertEqual(c['next'], 10) - @unittest.skip("SequenceField not implemented") def test_sequence_field_value_decorator(self): class Person(Document): id = SequenceField(primary_key=True, value_decorator=str) @@ -2200,7 +2302,6 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) - @unittest.skip("SequenceField not implemented") def test_embedded_sequence_field(self): class Comment(EmbeddedDocument): id = SequenceField() @@ -2373,6 +2474,78 @@ class FieldTest(unittest.TestCase): user = User(email='me@example.com') self.assertTrue(user.validate() is None) + def test_tuples_as_tuples(self): + """ + Ensure that tuples remain tuples when they are + inside a ComplexBaseField + """ + from mongoengine.base import BaseField + + class EnumField(BaseField): + + def __init__(self, **kwargs): + super(EnumField, self).__init__(**kwargs) + + def to_mongo(self, value): + return value + + def to_python(self, value): + return tuple(value) + + class TestDoc(Document): + items = ListField(EnumField()) + + TestDoc.drop_collection() + tuples = [(100, 'Testing')] + doc = TestDoc() + doc.items = tuples + doc.save() + x = TestDoc.objects().get() + self.assertTrue(x is not None) + self.assertTrue(len(x.items) == 1) + self.assertTrue(tuple(x.items[0]) in tuples) + self.assertTrue(x.items[0] in tuples) + + def test_dynamic_fields_class(self): + + class Doc2(Document): + field_1 = StringField(db_field='f') + + class Doc(Document): + my_id = IntField(required=True, unique=True, primary_key=True) + embed_me = DynamicField(db_field='e') + field_x = StringField(db_field='x') + + Doc.drop_collection() + Doc2.drop_collection() + + doc2 = Doc2(field_1="hello") + doc = Doc(my_id=1, embed_me=doc2, field_x="x") + self.assertRaises(OperationError, doc.save) + + doc2.save() + doc.save() + + doc = Doc.objects.get() + self.assertEqual(doc.embed_me.field_1, "hello") + + def test_dynamic_fields_embedded_class(self): + + class Embed(EmbeddedDocument): + field_1 = StringField(db_field='f') + + class Doc(Document): + my_id = IntField(required=True, unique=True, primary_key=True) + embed_me = DynamicField(db_field='e') + field_x = StringField(db_field='x') + + Doc.drop_collection() + + Doc(my_id=1, embed_me=Embed(field_1="hello"), field_x="x").save() + + doc = Doc.objects.get() + self.assertEqual(doc.embed_me.field_1, "hello") + if __name__ == '__main__': unittest.main() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 9ad3fdd..ba601de 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -24,7 +24,6 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') -@unittest.skip("FileField not implemented") class FileTest(unittest.TestCase): def setUp(self): @@ -54,11 +53,12 @@ class FileTest(unittest.TestCase): content_type = 'text/plain' putfile = PutFile() - putfile.the_file.put(text, content_type=content_type) + putfile.the_file.put(text, content_type=content_type, filename="hello") putfile.save() result = PutFile.objects.first() self.assertTrue(putfile == result) + self.assertEqual("%s" % result.the_file, "") self.assertEqual(result.the_file.read(), text) self.assertEqual(result.the_file.content_type, content_type) result.the_file.delete() # Remove file from GridFS @@ -456,5 +456,31 @@ class FileTest(unittest.TestCase): self.assertEqual(1, TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count()) + def test_complex_field_filefield(self): + """Ensure you can add meta data to file""" + + class Animal(Document): + genus = StringField() + family = StringField() + photos = ListField(FileField()) + + Animal.drop_collection() + marmot = Animal(genus='Marmota', family='Sciuridae') + + marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk + + photos_field = marmot._fields['photos'].field + new_proxy = photos_field.get_proxy_obj('photos', marmot) + new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') + marmot_photo.close() + + marmot.photos.append(new_proxy) + marmot.save() + + marmot = Animal.objects.get() + self.assertEqual(marmot.photos[0].content_type, 'image/jpeg') + self.assertEqual(marmot.photos[0].foo, 'bar') + self.assertEqual(marmot.photos[0].get().length, 8313) + if __name__ == '__main__': unittest.main() diff --git a/tests/fields/geo.py b/tests/fields/geo.py index 81f8a69..31ded26 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -10,7 +10,6 @@ from mongoengine.connection import get_db __all__ = ("GeoFieldTest", ) -@unittest.skip("geo fields not implemented") class GeoFieldTest(unittest.TestCase): def setUp(self): diff --git a/tests/fixtures.py b/tests/fixtures.py index e207044..f1344d7 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -17,6 +17,14 @@ class PickleTest(Document): photo = FileField() +class PickleDyanmicEmbedded(DynamicEmbeddedDocument): + date = DateTimeField(default=datetime.now) + + +class PickleDynamicTest(DynamicDocument): + number = IntField() + + class PickleSignalsTest(Document): number = IntField() string = StringField(choices=(('One', '1'), ('Two', '2'))) diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py index bff50c3..6fc83e0 100644 --- a/tests/migration/__init__.py +++ b/tests/migration/__init__.py @@ -1,4 +1,5 @@ from convert_to_new_inheritance_model import * +from decimalfield_as_float import * from refrencefield_dbref_to_object_id import * from turn_off_inheritance import * from uuidfield_to_binary import * diff --git a/tests/migration/decimalfield_as_float.py b/tests/migration/decimalfield_as_float.py new file mode 100644 index 0000000..3903c91 --- /dev/null +++ b/tests/migration/decimalfield_as_float.py @@ -0,0 +1,50 @@ + # -*- coding: utf-8 -*- +import unittest +import decimal +from decimal import Decimal + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField, DecimalField, ListField + +__all__ = ('ConvertDecimalField', ) + + +class ConvertDecimalField(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def test_how_to_convert_decimal_fields(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. Old definition - using dbrefs + class Person(Document): + name = StringField() + money = DecimalField(force_string=True) + monies = ListField(DecimalField(force_string=True)) + + Person.drop_collection() + Person(name="Wilson Jr", money=Decimal("2.50"), + monies=[Decimal("2.10"), Decimal("5.00")]).save() + + # 2. Start the migration by changing the schema + # Change DecimalField - add precision and rounding settings + class Person(Document): + name = StringField() + money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) + monies = ListField(DecimalField(precision=2, + rounding=decimal.ROUND_HALF_UP)) + + # 3. Loop all the objects and mark parent as changed + for p in Person.objects: + p._mark_as_changed('money') + p._mark_as_changed('monies') + p.save() + + # 4. Confirmation of the fix! + wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] + self.assertTrue(isinstance(wilson['money'], float)) + self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index 2bdfce1..7d66d26 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -162,6 +162,10 @@ class OnlyExcludeAllTest(unittest.TestCase): self.assertEqual(obj.name, person.name) self.assertEqual(obj.age, person.age) + obj = self.Person.objects.only(*('id', 'name',)).get() + self.assertEqual(obj.name, person.name) + self.assertEqual(obj.age, None) + # Check polymorphism still works class Employee(self.Person): salary = IntField(db_field='wage') @@ -395,5 +399,28 @@ class OnlyExcludeAllTest(unittest.TestCase): numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + def test_exclude_from_subclasses_docs(self): + + class Base(Document): + username = StringField() + + meta = {'allow_inheritance': True} + + class Anon(Base): + anon = BooleanField() + + class User(Base): + password = StringField() + wibble = StringField() + + Base.drop_collection() + User(username="mongodb", password="secret").save() + + user = Base.objects().exclude("password", "wibble").first() + self.assertEqual(user.password, None) + + self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index 7e1c5df..f564896 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -8,7 +8,6 @@ from mongoengine import * __all__ = ("GeoQueriesTest",) -@unittest.skip("geo queries not implemented") class GeoQueriesTest(unittest.TestCase): def setUp(self): diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 1f1051b..b4bcf2a 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -30,12 +30,17 @@ class QuerySetTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') + class PersonMeta(EmbeddedDocument): + weight = IntField() + class Person(Document): name = StringField() age = IntField() + person_meta = EmbeddedDocumentField(PersonMeta) meta = {'allow_inheritance': True} Person.drop_collection() + self.PersonMeta = PersonMeta self.Person = Person def test_initialisation(self): @@ -777,10 +782,10 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.employees.append(p2) + fresh_o1.employees.append(p2) # Dereferences fresh_o1.save(cascade=False) # Saves - self.assertEqual(q, 2) + self.assertEqual(q, 3) def test_slave_okay(self): """Ensures that a query can take slave_okay syntax @@ -1492,9 +1497,6 @@ class QuerySetTest(unittest.TestCase): def test_pull_nested(self): - class User(Document): - name = StringField() - class Collaborator(EmbeddedDocument): user = StringField() @@ -1509,8 +1511,7 @@ class QuerySetTest(unittest.TestCase): Site.drop_collection() c = Collaborator(user='Esteban') - s = Site(name="test", collaborators=[c]) - s.save() + s = Site(name="test", collaborators=[c]).save() Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') self.assertEqual(Site.objects.first().collaborators, []) @@ -1520,6 +1521,71 @@ class QuerySetTest(unittest.TestCase): self.assertRaises(InvalidQueryError, pull_all) + def test_pull_from_nested_embedded(self): + + class User(EmbeddedDocument): + name = StringField() + + def __unicode__(self): + return '%s' % self.name + + class Collaborator(EmbeddedDocument): + helpful = ListField(EmbeddedDocumentField(User)) + unhelpful = ListField(EmbeddedDocumentField(User)) + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = EmbeddedDocumentField(Collaborator) + + + Site.drop_collection() + + c = User(name='Esteban') + f = User(name='Frank') + s = Site(name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])).save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) + self.assertEqual(Site.objects.first().collaborators['helpful'], []) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'name': 'Frank'}) + self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + + def pull_all(): + Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__name=['Ross']) + + self.assertRaises(InvalidQueryError, pull_all) + + def test_pull_from_nested_mapfield(self): + + class Collaborator(EmbeddedDocument): + user = StringField() + + def __unicode__(self): + return '%s' % self.user + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator))) + + + Site.drop_collection() + + c = Collaborator(user='Esteban') + f = Collaborator(user='Frank') + s = Site(name="test", collaborators={'helpful':[c],'unhelpful':[f]}) + s.save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful__user='Esteban') + self.assertEqual(Site.objects.first().collaborators['helpful'], []) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'user':'Frank'}) + self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + + def pull_all(): + Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__user=['Ross']) + + self.assertRaises(InvalidQueryError, pull_all) + def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -2208,6 +2274,19 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.average('age')), avg) + # dot notation + self.Person(name='person meta', person_meta=self.PersonMeta(weight=0)).save() + self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), 0) + + for i, weight in enumerate(ages): + self.Person(name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() + + self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), avg) + + self.Person(name='test meta none').save() + self.assertEqual(int(self.Person.objects.average('person_meta.weight')), avg) + + def test_sum(self): """Ensure that field can be summed over correctly. """ @@ -2220,6 +2299,153 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + for i, age in enumerate(ages): + self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save() + + self.assertEqual(int(self.Person.objects.sum('person_meta.weight')), sum(ages)) + + self.Person(name='weightless person').save() + self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + + def test_embedded_average(self): + class Pay(EmbeddedDocument): + value = DecimalField() + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(value=150)).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(value=530)).save() + + Doc(name=u"Tayza mariana", + pay=Pay(value=165)).save() + + Doc(name=u"Eliana Costa", + pay=Pay(value=115)).save() + + self.assertEqual( + Doc.objects.average('pay.value'), + 240) + + def test_embedded_array_average(self): + class Pay(EmbeddedDocument): + values = ListField(DecimalField()) + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(values=[150, 100])).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(values=[530, 100])).save() + + Doc(name=u"Tayza mariana", + pay=Pay(values=[165, 100])).save() + + Doc(name=u"Eliana Costa", + pay=Pay(values=[115, 100])).save() + + self.assertEqual( + Doc.objects.average('pay.values'), + 170) + + def test_array_average(self): + class Doc(Document): + values = ListField(DecimalField()) + + Doc.drop_collection() + + Doc(values=[150, 100]).save() + Doc(values=[530, 100]).save() + Doc(values=[165, 100]).save() + Doc(values=[115, 100]).save() + + self.assertEqual( + Doc.objects.average('values'), + 170) + + def test_embedded_sum(self): + class Pay(EmbeddedDocument): + value = DecimalField() + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(value=150)).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(value=530)).save() + + Doc(name=u"Tayza mariana", + pay=Pay(value=165)).save() + + Doc(name=u"Eliana Costa", + pay=Pay(value=115)).save() + + self.assertEqual( + Doc.objects.sum('pay.value'), + 960) + + + def test_embedded_array_sum(self): + class Pay(EmbeddedDocument): + values = ListField(DecimalField()) + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(values=[150, 100])).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(values=[530, 100])).save() + + Doc(name=u"Tayza mariana", + pay=Pay(values=[165, 100])).save() + + Doc(name=u"Eliana Costa", + pay=Pay(values=[115, 100])).save() + + self.assertEqual( + Doc.objects.sum('pay.values'), + 1360) + + def test_array_sum(self): + class Doc(Document): + values = ListField(DecimalField()) + + Doc.drop_collection() + + Doc(values=[150, 100]).save() + Doc(values=[530, 100]).save() + Doc(values=[165, 100]).save() + Doc(values=[115, 100]).save() + + self.assertEqual( + Doc.objects.sum('values'), + 1360) + def test_distinct(self): """Ensure that the QuerySet.distinct method works. """ @@ -2922,6 +3148,19 @@ class QuerySetTest(unittest.TestCase): (u'Wilson Jr', 19, u'Corumba-GO'), (u'Gabriel Falcao', 23, u'New York')]) + def test_scalar_decimal(self): + from decimal import Decimal + class Person(Document): + name = StringField() + rating = DecimalField() + + Person.drop_collection() + Person(name="Wilson Jr", rating=Decimal('1.0')).save() + + ulist = list(Person.objects.scalar('name', 'rating')) + self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))]) + + def test_scalar_reference_field(self): class State(Document): name = StringField() @@ -3121,6 +3360,13 @@ class QuerySetTest(unittest.TestCase): Test.objects(test='foo').update_one(upsert=True, set__test='foo') self.assertTrue('_cls' in Test._collection.find_one()) + def test_update_upsert_looks_like_a_digit(self): + class MyDoc(DynamicDocument): + pass + MyDoc.drop_collection() + self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1)) + self.assertEqual(MyDoc.objects.get()['47'], 1) + def test_read_preference(self): class Bar(Document): pass @@ -3149,7 +3395,7 @@ class QuerySetTest(unittest.TestCase): Doc(string="Bye", embedded_field=Embedded(string="Bye")).save() Doc().save() - json_data = Doc.objects.to_json() + json_data = Doc.objects.to_json(sort_keys=True, separators=(',', ':')) doc_objects = list(Doc.objects) self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) @@ -3177,6 +3423,7 @@ class QuerySetTest(unittest.TestCase): objectid_field = ObjectIdField(default=ObjectId) reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) + decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) @@ -3207,25 +3454,33 @@ class QuerySetTest(unittest.TestCase): id = ObjectIdField('_id') name = StringField() age = IntField() + price = DecimalField() User.drop_collection() - User(name="Bob Dole", age=89).save() - User(name="Barack Obama", age=51).save() + User(name="Bob Dole", age=89, price=Decimal('1.11')).save() + User(name="Barack Obama", age=51, price=Decimal('2.22')).save() - users = User.objects.only('name').as_pymongo() + results = User.objects.only('id', 'name').as_pymongo() + self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name'])) + + users = User.objects.only('name', 'price').as_pymongo() results = list(users) self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') + self.assertEqual(results[0]['price'], 1.11) self.assertEqual(results[1]['name'], 'Barack Obama') + self.assertEqual(results[1]['price'], 2.22) # Test coerce_types - users = User.objects.only('name').as_pymongo(coerce_types=True) + users = User.objects.only('name', 'price').as_pymongo(coerce_types=True) results = list(users) self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') + self.assertEqual(results[0]['price'], Decimal('1.11')) self.assertEqual(results[1]['name'], 'Barack Obama') + self.assertEqual(results[1]['price'], Decimal('2.22')) def test_as_pymongo_json_limit_fields(self): @@ -3249,7 +3504,6 @@ class QuerySetTest(unittest.TestCase): serialized_user = User.objects.exclude('password_salt').only('email').to_json() self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) - @unittest.skip("not implemented") def test_no_dereference(self): class Organization(Document): @@ -3297,6 +3551,27 @@ class QuerySetTest(unittest.TestCase): people.count() # count is cached self.assertEqual(q, 1) + def test_no_cached_queryset(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + for i in xrange(100): + Person(name="No: %s" % i).save() + + with query_counter() as q: + self.assertEqual(q, 0) + people = Person.objects.no_cache() + + [x for x in people] + self.assertEqual(q, 1) + + list(people) + self.assertEqual(q, 2) + + people.count() + self.assertEqual(q, 3) + def test_cache_not_cloned(self): class User(Document): @@ -3318,6 +3593,34 @@ class QuerySetTest(unittest.TestCase): self.assertEqual("%s" % users, "[]") self.assertEqual(1, len(users._result_cache)) + def test_no_cache(self): + """Ensure you can add meta data to file""" + + class Noddy(Document): + fields = DictField() + + Noddy.drop_collection() + for i in xrange(100): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save() + + docs = Noddy.objects.no_cache() + + counter = len([1 for i in docs]) + self.assertEquals(counter, 100) + + self.assertEquals(len(list(docs)), 100) + self.assertRaises(TypeError, lambda: len(docs)) + + with query_counter() as q: + self.assertEqual(q, 0) + list(docs) + self.assertEqual(q, 1) + list(docs) + self.assertEqual(q, 2) + def test_nested_queryset_iterator(self): # Try iterating the same queryset twice, nested. names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George'] @@ -3449,6 +3752,23 @@ class QuerySetTest(unittest.TestCase): '_cls': 'Animal.Cat' }) + def test_can_have_field_same_name_as_query_operator(self): + + class Size(Document): + name = StringField() + + class Example(Document): + size = ReferenceField(Size) + + Size.drop_collection() + Example.drop_collection() + + instance_size = Size(name="Large").save() + Example(size=instance_size).save() + + self.assertEqual(Example.objects(size=instance_size).count(), 1) + self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 53c1660..d2e8b78 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -31,6 +31,31 @@ class TransformTest(unittest.TestCase): self.assertEqual(transform.query(name__exists=True), {'name': {'$exists': True}}) + def test_transform_update(self): + class DicDoc(Document): + dictField = DictField() + + class Doc(Document): + pass + + DicDoc.drop_collection() + Doc.drop_collection() + + doc = Doc().save() + dic_doc = DicDoc().save() + + for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): + update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) + self.assertTrue(isinstance(update[v]["dictField.test"], dict)) + + # Update special cases + update = transform.update(DicDoc, unset__dictField__test=doc) + self.assertEqual(update["$unset"]["dictField.test"], 1) + + update = transform.update(DicDoc, pull__dictField__test=doc) + self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) + + def test_query_field_name(self): """Ensure that the correct field name is used when querying. """ @@ -63,7 +88,6 @@ class TransformTest(unittest.TestCase): BlogPost.drop_collection() - @unittest.skip("unsupported") def test_query_pk_field_name(self): """Ensure that the correct "primary key" field name is used when querying diff --git a/tests/test_connection.py b/tests/test_connection.py index d27a66d..62d795c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -59,6 +59,32 @@ class ConnectionTest(unittest.TestCase): c.admin.system.users.remove({}) c.mongoenginetest.system.users.remove({}) + def test_connect_uri_without_db(self): + """Ensure that the connect() method works properly with uri's + without database_name + """ + c = connect(db='mongoenginetest', alias='admin') + c.admin.system.users.remove({}) + c.mongoenginetest.system.users.remove({}) + + c.admin.add_user("admin", "password") + c.admin.authenticate("admin", "password") + c.mongoenginetest.add_user("username", "password") + + self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') + + connect("mongoenginetest", host='mongodb://localhost/') + + conn = get_connection() + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) + + db = get_db() + self.assertTrue(isinstance(db, pymongo.database.Database)) + self.assertEqual(db.name, 'mongoenginetest') + + c.admin.system.users.remove({}) + c.mongoenginetest.system.users.remove({}) + def test_register_connection(self): """Ensure that connections with different aliases may be registered. """ diff --git a/tests/test_dereference.py b/tests/test_dereference.py index f13f291..6f2664a 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -16,7 +16,6 @@ class FieldTest(unittest.TestCase): connect(db='mongoenginetest') self.db = get_db() - @unittest.skip("select_related currently doesn't dereference lists") def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -75,7 +74,6 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() - @unittest.skip("select_related currently doesn't dereference lists") def test_list_item_dereference_dref_false(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -148,7 +146,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(Group._get_collection().find_one()['members'], [1]) self.assertEqual(group.members, [user]) - @unittest.skip('currently not implemented') def test_handle_old_style_references(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -182,7 +179,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(group.members[0].name, 'user 1') self.assertEqual(group.members[-1].name, 'String!') - @unittest.skip('currently not implemented') def test_migrate_references(self): """Example of migrating ReferenceField storage """ @@ -229,7 +225,6 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(raw_data['author'], ObjectId)) self.assertTrue(isinstance(raw_data['members'][0], ObjectId)) - @unittest.skip("select_related currently doesn't dereference lists") def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ @@ -264,15 +259,9 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 1) peter.boss - self.assertEqual(q, 1) - - peter.friends - self.assertEqual(q, 1) - - peter.boss.name self.assertEqual(q, 2) - peter.friends[0].name + peter.friends self.assertEqual(q, 3) # Document select_related @@ -302,32 +291,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(employee.friends, friends) self.assertEqual(q, 2) - def test_list_of_lists_of_references(self): - - class User(Document): - name = StringField() - - class Post(Document): - user_lists = ListField(ListField(ReferenceField(User))) - - class SimpleList(Document): - users = ListField(ReferenceField(User)) - - User.drop_collection() - Post.drop_collection() - SimpleList.drop_collection() - - u1 = User.objects.create(name='u1') - u2 = User.objects.create(name='u2') - u3 = User.objects.create(name='u3') - - SimpleList.objects.create(users=[u1, u2, u3]) - self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) - - Post.objects.create(user_lists=[[u1, u2], [u3]]) - self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]]) - - def test_circular_reference(self): """Ensure you can handle circular references """ @@ -428,7 +391,6 @@ class FieldTest(unittest.TestCase): "%s" % Person.objects() ) - @unittest.skip("not implemented") def test_generic_reference(self): class UserA(Document): @@ -520,7 +482,6 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() - @unittest.skip("not implemented") def test_list_field_complex(self): class UserA(Document): @@ -612,7 +573,6 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() - @unittest.skip('MapField not fully implemented') def test_map_field_reference(self): class User(Document): @@ -678,7 +638,6 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() - @unittest.skip("not implemented") def test_dict_field(self): class UserA(Document): @@ -782,7 +741,6 @@ class FieldTest(unittest.TestCase): UserC.drop_collection() Group.drop_collection() - @unittest.skip("not implemented") def test_dict_field_no_field_inheritance(self): class UserA(Document): @@ -859,7 +817,6 @@ class FieldTest(unittest.TestCase): UserA.drop_collection() Group.drop_collection() - @unittest.skip("select_related currently doesn't dereference lists") def test_generic_reference_map_field(self): class UserA(Document): @@ -985,7 +942,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(root.children, [company]) self.assertEqual(company.parents, [root]) - @unittest.skip("not implemented") def test_dict_in_dbref_instance(self): class Person(Document): @@ -1165,37 +1121,32 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) - def test_tuples_as_tuples(self): - """ - Ensure that tuples remain tuples when they are - inside a ComplexBaseField - """ - from mongoengine.base import BaseField + def test_objectid_reference_across_databases(self): + # mongoenginetest - Is default connection alias from setUp() + # Register Aliases + register_connection('testdb-1', 'mongoenginetest2') - class EnumField(BaseField): + class User(Document): + name = StringField() + meta = {"db_alias": "testdb-1"} - def __init__(self, **kwargs): - super(EnumField, self).__init__(**kwargs) + class Book(Document): + name = StringField() + author = ReferenceField(User) - def to_mongo(self, value): - return value + # Drops + User.drop_collection() + Book.drop_collection() - def to_python(self, value): - return tuple(value) + user = User(name="Ross").save() + Book(name="MongoEngine for pros", author=user).save() - class TestDoc(Document): - items = ListField(EnumField()) + # Can't use query_counter across databases - so test the _data object + book = Book.objects.first() + self.assertFalse(isinstance(book._data['author'], User)) - TestDoc.drop_collection() - tuples = [(100, 'Testing')] - doc = TestDoc() - doc.items = tuples - doc.save() - x = TestDoc.objects().get() - self.assertTrue(x is not None) - self.assertTrue(len(x.items) == 1) - self.assertTrue(tuple(x.items[0]) in tuples) - self.assertTrue(x.items[0] in tuples) + book.select_related() + self.assertTrue(isinstance(book._data['author'], User)) def test_non_ascii_pk(self): """ @@ -1220,6 +1171,30 @@ class FieldTest(unittest.TestCase): self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands])) + def test_dereferencing_embedded_listfield_referencefield(self): + class Tag(Document): + meta = {'collection': 'tags'} + name = StringField() + + class Post(EmbeddedDocument): + body = StringField() + tags = ListField(ReferenceField("Tag", dbref=True)) + + class Page(Document): + meta = {'collection': 'pages'} + tags = ListField(ReferenceField("Tag", dbref=True)) + posts = ListField(EmbeddedDocumentField(Post)) + + Tag.drop_collection() + Page.drop_collection() + + tag = Tag(name='test').save() + post = Post(body='test body', tags=[tag]) + Page(tags=[tag], posts=[post]).save() + + page = Page.objects.first() + self.assertEqual(page.tags[0], page.posts[0].tags[0]) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_django.py b/tests/test_django.py index 63e3245..46568ac 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -2,48 +2,44 @@ import sys sys.path[0:0] = [""] import unittest from nose.plugins.skip import SkipTest -from mongoengine.python_support import PY3 from mongoengine import * + +from mongoengine.django.shortcuts import get_document_or_404 + +from django.http import Http404 +from django.template import Context, Template +from django.conf import settings +from django.core.paginator import Paginator + +settings.configure( + USE_TZ=True, + INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'), + AUTH_USER_MODEL=('mongo_auth.MongoUser'), +) + try: - from mongoengine.django.shortcuts import get_document_or_404 - - from django.http import Http404 - from django.template import Context, Template - from django.conf import settings - from django.core.paginator import Paginator - - settings.configure( - USE_TZ=True, - INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'), - AUTH_USER_MODEL=('mongo_auth.MongoUser'), + from django.contrib.auth import authenticate, get_user_model + from mongoengine.django.auth import User + from mongoengine.django.mongo_auth.models import ( + MongoUser, + MongoUserManager, + get_user_document, ) - - try: - from django.contrib.auth import authenticate, get_user_model - from mongoengine.django.auth import User - from mongoengine.django.mongo_auth.models import MongoUser, MongoUserManager - DJ15 = True - except Exception: - DJ15 = False - from django.contrib.sessions.tests import SessionTestsMixin - from mongoengine.django.sessions import SessionStore, MongoSession -except Exception, err: - if PY3: - SessionTestsMixin = type # dummy value so no error - SessionStore = None # dummy value so no error - else: - raise err - - + DJ15 = True +except Exception: + DJ15 = False +from django.contrib.sessions.tests import SessionTestsMixin +from mongoengine.django.sessions import SessionStore, MongoSession from datetime import tzinfo, timedelta ZERO = timedelta(0) + class FixedOffset(tzinfo): """Fixed offset in minutes east from UTC.""" def __init__(self, offset, name): - self.__offset = timedelta(minutes = offset) + self.__offset = timedelta(minutes=offset) self.__name = name def utcoffset(self, dt): @@ -70,8 +66,6 @@ def activate_timezone(tz): class QuerySetTest(unittest.TestCase): def setUp(self): - if PY3: - raise SkipTest('django does not have Python 3 support') connect(db='mongoenginetest') class Person(Document): @@ -173,6 +167,8 @@ class QuerySetTest(unittest.TestCase): class Note(Document): text = StringField() + Note.drop_collection() + for i in xrange(1, 101): Note(name="Note: %s" % i).save() @@ -223,8 +219,6 @@ class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): backend = SessionStore def setUp(self): - if PY3: - raise SkipTest('django does not have Python 3 support') connect(db='mongoenginetest') MongoSession.drop_collection() super(MongoDBSessionTest, self).setUp() @@ -262,17 +256,18 @@ class MongoAuthTest(unittest.TestCase): } def setUp(self): - if PY3: - raise SkipTest('django does not have Python 3 support') if not DJ15: raise SkipTest('mongo_auth requires Django 1.5') connect(db='mongoenginetest') User.drop_collection() super(MongoAuthTest, self).setUp() - def test_user_model(self): + def test_get_user_model(self): self.assertEqual(get_user_model(), MongoUser) + def test_get_user_document(self): + self.assertEqual(get_user_document(), User) + def test_user_manager(self): manager = get_user_model()._default_manager self.assertTrue(isinstance(manager, MongoUserManager)) diff --git a/tests/test_signals.py b/tests/test_signals.py index da217c0..50e5e6b 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -30,10 +30,28 @@ class SignalTests(unittest.TestCase): def __unicode__(self): return self.name + @classmethod + def pre_init(cls, sender, document, *args, **kwargs): + signal_output.append('pre_init signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, sender, document, **kwargs): + signal_output.append('post_init signal, %s' % document) + @classmethod def pre_save(cls, sender, document, **kwargs): signal_output.append('pre_save signal, %s' % document) + @classmethod + def pre_save_post_validation(cls, sender, document, **kwargs): + signal_output.append('pre_save_post_validation signal, %s' % document) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + @classmethod def post_save(cls, sender, document, **kwargs): signal_output.append('post_save signal, %s' % document) @@ -100,7 +118,10 @@ class SignalTests(unittest.TestCase): # Save up the number of connected signals so that we can check at the # end that all the signals we register get properly unregistered self.pre_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), len(signals.pre_save.receivers), + len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -108,7 +129,10 @@ class SignalTests(unittest.TestCase): len(signals.post_bulk_insert.receivers), ) + signals.pre_init.connect(Author.pre_init, sender=Author) + signals.post_init.connect(Author.post_init, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author) + signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author) signals.post_save.connect(Author.post_save, sender=Author) signals.pre_delete.connect(Author.pre_delete, sender=Author) signals.post_delete.connect(Author.post_delete, sender=Author) @@ -121,9 +145,12 @@ class SignalTests(unittest.TestCase): signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) def tearDown(self): + signals.pre_init.disconnect(self.Author.pre_init) + signals.post_init.disconnect(self.Author.post_init) signals.post_delete.disconnect(self.Author.post_delete) signals.pre_delete.disconnect(self.Author.pre_delete) signals.post_save.disconnect(self.Author.post_save) + signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation) signals.pre_save.disconnect(self.Author.pre_save) signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) @@ -135,7 +162,10 @@ class SignalTests(unittest.TestCase): # Check that all our signals got disconnected properly. post_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), len(signals.pre_save.receivers), + len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -150,6 +180,9 @@ class SignalTests(unittest.TestCase): def test_model_signals(self): """ Model saves should throw some signals. """ + def create_author(): + self.Author(name='Bill Shakespeare') + def bulk_create_author_with_load(): a1 = self.Author(name='Bill Shakespeare') self.Author.objects.insert([a1], load_bulk=True) @@ -158,9 +191,17 @@ class SignalTests(unittest.TestCase): a1 = self.Author(name='Bill Shakespeare') self.Author.objects.insert([a1], load_bulk=False) + self.assertEqual(self.get_signal_output(create_author), [ + "pre_init signal, Author", + "{'name': 'Bill Shakespeare'}", + "post_init signal, Bill Shakespeare", + ]) + a1 = self.Author(name='Bill Shakespeare') self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, Bill Shakespeare", + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", "post_save signal, Bill Shakespeare", "Is created" ]) @@ -169,6 +210,8 @@ class SignalTests(unittest.TestCase): a1.name = 'William Shakespeare' self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, William Shakespeare", + "pre_save_post_validation signal, William Shakespeare", + "Is updated", "post_save signal, William Shakespeare", "Is updated" ]) @@ -180,13 +223,18 @@ class SignalTests(unittest.TestCase): signal_output = self.get_signal_output(bulk_create_author_with_load) - self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ - "pre_bulk_insert signal, []", - "post_bulk_insert signal, []", - "Is loaded", - ]) + # The output of this signal is not entirely deterministic. The reloaded + # object will have an object ID. Hence, we only check part of the output + self.assertEqual(signal_output[3], + "pre_bulk_insert signal, []") + self.assertEqual(signal_output[-2:], + ["post_bulk_insert signal, []", + "Is loaded",]) self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ + "pre_init signal, Author", + "{'name': 'Bill Shakespeare'}", + "post_init signal, Bill Shakespeare", "pre_bulk_insert signal, []", "post_bulk_insert signal, []", "Not loaded",