Compare commits

..

106 commits

Author SHA1 Message Date
Joey Payne
654cca82a9 Fixes AttributeError when using storage.exists() on a non-existing file. 2013-09-18 11:38:38 -06:00
Ross Lawley
bcbe740598 Updated setup.py 2013-08-23 13:41:15 +00:00
Ross Lawley
86c8929d77 0.8.4 is a go 2013-08-23 10:03:10 +00:00
Ross Lawley
6738a9433b Updated travis 2013-08-23 09:36:33 +00:00
Ross Lawley
23843ec86e Updated travis config 2013-08-23 09:06:57 +00:00
Ross Lawley
f4db0da585 Update changelog add LK4D4 to authors (#452) 2013-08-23 09:03:51 +00:00
Ross Lawley
9ee3b796cd Merge pull request #452 from LK4D4/master
Remove database name necessity in uri connection schema
2013-08-23 02:02:12 -07:00
Alexandr Morozov
f57569f553 Remove database name necessity in uri connection schema 2013-08-21 13:52:24 +04:00
Ross Lawley
fffd0e8990 Fixed error raise 2013-08-20 18:54:14 +00:00
Ross Lawley
200e52bab5 Added documentation about abstract meta
Refs #438
2013-08-20 18:44:12 +00:00
Ross Lawley
a0ef649dd8 Update travis.yml 2013-08-20 18:31:33 +00:00
Ross Lawley
0dd01bda01 Fixed "$pull" semantics for nested ListFields (#447) 2013-08-20 15:54:42 +00:00
Ross Lawley
a707598042 Allow fields to be named the same as query operators (#445) 2013-08-20 13:13:17 +00:00
Ross Lawley
8a3171308a Merge remote-tracking branch 'origin/pr/445' 2013-08-20 13:04:20 +00:00
Ross Lawley
29c887f30b Updated field filter logic - can now exclude subclass fields (#443) 2013-08-20 12:21:20 +00:00
Ross Lawley
661398d891 Fixed dereference issue with embedded listfield referencefields (#439) 2013-08-20 10:22:06 +00:00
Ross Lawley
2cd722d751 Updated setup.py 2013-08-20 10:20:05 +00:00
Ross Lawley
49f5b4fa5c Fix Queryset docs (#448) 2013-08-20 09:45:00 +00:00
Ross Lawley
67baf465f4 Fixed slice when using inheritance causing fields to be excluded (#437) 2013-08-20 09:14:58 +00:00
Ross Lawley
ee7666ddea Update AUTHORS and Changelog (#441) 2013-08-20 08:31:56 +00:00
Ross Lawley
02fc41ff1c Merge branch 'master' of github.com:MongoEngine/mongoengine 2013-08-20 08:30:33 +00:00
Ross Lawley
d07a9d2ef8 Dynamic Fields store and recompose Embedded Documents / Documents correctly (#449) 2013-08-20 08:30:20 +00:00
Ross Lawley
3622ebfabd Merge pull request #441 from Karmak23/patch-1
Fix the ._get_db() attribute after a Document.switch_db()
2013-08-20 01:19:26 -07:00
crazyzubr
70b320633f permit the establishment of a field with the name of size or other
Example:

# model
class Example(Document):
    size = ReferenceField(Size, verbose_name='Size')

# query

examples = Example.objects(size=instance_size)

# caused an error

"""
File ".../mongoengine/queryset/transform.py", line 50, in query
if parts[-1] == 'not':
IndexError: list index out of range
"""
2013-08-15 19:32:13 +08:00
Olivier Cortès
f30208f345 Fix the ._get_db() attribute after a Document.switch_db()
Without this patch, I've got:

```
myobj._get_db()

<bound method TopLevelDocumentMetaclass._get_db of <class 'oneflow.core.models.nonrel.Article'>>

```

I need to `myobj._get_db()()` to get the database.

I felt this like a bug.

regards,
2013-08-12 19:12:53 +02:00
Ross Lawley
5bcc454678 Handle dynamic fieldnames that look like digits (#434) 2013-08-07 09:07:57 +00:00
Ross Lawley
473110568f Merge branch 'master' of github.com:MongoEngine/mongoengine 2013-08-06 11:07:25 +00:00
Ross Lawley
88ca0f8196 Merge remote-tracking branch 'origin/pr/432'
Conflicts:
	tests/test_django.py
2013-08-06 11:05:52 +00:00
Ross Lawley
a171005010 Merge pull request #428 from devoto13/patch-1
Removed duplicated line for 'pop' method in documentation
2013-08-06 03:20:50 -07:00
Ross Lawley
f56ad2fa58 Merge pull request #426 from laurentpayot/Django_shortcuts_docs
updated docs for django shortcuts get_object_or_404 and get_list_or_404
2013-08-06 03:17:56 -07:00
Nicolas Cortot
a0d255369a Add a test case for get_user_document 2013-08-04 11:29:16 +02:00
Nicolas Cortot
40b0a15b35 Fixing typos 2013-08-04 11:03:34 +02:00
Nicolas Cortot
b98b06ff79 Fix an error in get_user_document 2013-08-04 11:01:09 +02:00
devoto13
a448c9aebf removed duplicated method 2013-08-01 17:54:41 +03:00
Laurent Payot
b3f462a39d updated docs for django shortcuts get_object_or_404 and get_list_or_404 2013-08-01 03:51:10 +02:00
Ross Lawley
7ce34ca019 Merge branch 'master' of github.com:MongoEngine/mongoengine 2013-07-31 09:44:50 +00:00
Ross Lawley
719bb53c3a Updated changelog (#423) 2013-07-31 09:44:15 +00:00
Ross Lawley
214415969f Merge pull request #423 from ncortot/get_user_document
Add get_user_document and improve mongo_auth module
2013-07-31 02:43:19 -07:00
Ross Lawley
7431b1f123 Updated AUTHORS (#424) 2013-07-31 09:31:04 +00:00
Ross Lawley
d8ffa843a9 Added str representation of GridFSProxy (#424) 2013-07-31 09:29:41 +00:00
Paul
a69db231cc Pretty-print GridFSProxy objects 2013-07-31 11:26:23 +10:00
Nicolas Cortot
c17f94422f Add get_user_document and improve mongo_auth module
* Added a get_user_document() method to access the actual Document class
    used for authentication.
  * Clarified the docstring on MongoUser to prevent its use when the user
    Document class should be used.
  * Removed the masking of exceptions when loading the user document class.
2013-07-30 20:48:25 +02:00
Ross Lawley
b4777f7f4f Fix test 2013-07-30 15:04:52 +00:00
Ross Lawley
a57d9a9303 Added regression test (#418) 2013-07-30 13:28:05 +00:00
Ross Lawley
5e70e1bcb2 Update transform to handle docs erroneously passed to unset (#416) 2013-07-30 13:17:38 +00:00
Ross Lawley
0c43787996 Fixed indexing - turn off _cls (#414) 2013-07-30 11:43:52 +00:00
Ross Lawley
dc310b99f9 Updated docs about TTL indexes and signals (#413) 2013-07-30 10:54:04 +00:00
Ross Lawley
e98c5e10bc Fixed dereference threading issue in ComplexField.__get__ (#412) 2013-07-30 10:49:08 +00:00
Ross Lawley
f1b1090263 Merge remote-tracking branch 'origin/pr/412' into 412
Conflicts:
	AUTHORS
2013-07-30 10:32:07 +00:00
Ross Lawley
6efd6faa3f Fixed QuerySetNoCache.count() caching (#410) 2013-07-30 10:30:16 +00:00
Ross Lawley
1e4d48d371 Don't follow references in _get_changed_fields (#422, #417)
A better fix so we dont follow down a references rabbit hole.
2013-07-29 17:22:24 +00:00
Ross Lawley
93a2adb3e6 Updating changelog and authors #417 2013-07-29 15:43:54 +00:00
Ross Lawley
a66d516777 Merge pull request #417 from ProgressiveCompany/delta-dbref-false-bug
BaseDocument._delta doesn't properly end its path at Documents when using `dbref=False`
2013-07-29 08:41:09 -07:00
Ross Lawley
7a97d42338 to_json test updates #420 2013-07-29 15:38:08 +00:00
Ross Lawley
b66cdc8fa0 Merge branch 'master' of github.com:MongoEngine/mongoengine 2013-07-29 15:30:21 +00:00
Ross Lawley
67f43b2aad Allow args and kwargs to be passed through to_json (#420) 2013-07-29 15:29:48 +00:00
Paul Uithol
d143e50238 Replace assertIn with an assertTrue; apparently missing in Python 2.6 2013-07-25 15:34:58 +02:00
Paul Uithol
e27439be6a Fix BaseDocument._delta when working with plain ObjectIds instead of DBRefs 2013-07-25 14:52:03 +02:00
Paul Uithol
2ad5ffbda2 Add asserts to test_delta_with_dbref_*, instead of relying on exceptions 2013-07-25 14:51:09 +02:00
Paul Uithol
dae9e662a5 Create test case for failing saves (wrong delta) with dbref=False 2013-07-25 14:30:20 +02:00
Ross Lawley
f22737d6a4 Merge pull request #409 from bool-dev/master
Corrected mistakes in upgrade.rst documentation
2013-07-23 01:19:36 -07:00
Ross Lawley
a458d5a176 Docs update #406 2013-07-23 08:16:06 +00:00
Ross Lawley
d92ed04538 Docs update #406 2013-07-23 08:13:52 +00:00
Thom Knowles
80b3df8953 dereference instance not thread-safe 2013-07-22 20:07:57 -04:00
bool-dev
bcf83ec761 Corrected spelling mistakes, some grammar, and UUID/DecimalField error in upgrade.rst 2013-07-18 09:17:28 +05:30
bool.dev
e44e72bce3 Merge remote-tracking branch 'upstream/master' 2013-07-18 08:49:02 +05:30
Ross Lawley
35f2781518 Update changelog 2013-07-12 09:11:27 +00:00
Ross Lawley
dc5512e403 Upgrade warning for 0.8.3 2013-07-12 09:01:11 +00:00
Ross Lawley
48ef176e28 0.8.3 is a go 2013-07-12 08:41:56 +00:00
Ross Lawley
1aa2b86df3 travis install python-dateutil direct 2013-07-11 09:38:59 +00:00
Ross Lawley
73026047e9 Trying to fix dateutil 2013-07-11 09:29:06 +00:00
Ross Lawley
6c2c33cac8 Add Jatin- to Authors, changelog update 2013-07-11 08:12:27 +00:00
Ross Lawley
d593f7e04b Fixed EmbeddedDocuments with id also storing _id (#402) 2013-07-11 08:11:00 +00:00
Ross Lawley
6c599ef506 Fix edge case where _dynamic_keys stored as None (#387, #401) 2013-07-11 07:15:34 +00:00
Ross Lawley
f48a0b7b7d Trying to fix travis 2013-07-10 21:30:29 +00:00
Ross Lawley
d9f538170b Added get_proxy_object helper to filefields (#391) 2013-07-10 21:19:11 +00:00
Ross Lawley
1785ced655 Merge branch 'master' into 391 2013-07-10 20:35:21 +00:00
Ross Lawley
e155e1fa86 Add a default for previously pickled versions 2013-07-10 20:10:01 +00:00
Ross Lawley
e28fab0550 Merge remote-tracking branch 'origin/pr/400' 2013-07-10 19:56:15 +00:00
Ross Lawley
fb0dd2c1ca Updated changelog 2013-07-10 19:54:30 +00:00
Ross Lawley
6e89e736b7 Merge remote-tracking branch 'origin/pr/393' into 393
Conflicts:
	mongoengine/queryset/queryset.py
	tests/queryset/queryset.py
2013-07-10 19:53:13 +00:00
Ross Lawley
634b874c46 Added QuerySetNoCache and QuerySet.no_cache() for lower memory consumption (#365) 2013-07-10 19:40:57 +00:00
Ross Lawley
9d16364394 Merge remote-tracking branch 'origin/pr/391' into 391 2013-07-10 14:10:09 +00:00
Wilson Júnior
daeecef59e Update fields.py
Typo in documentation for DecimalField
2013-07-10 10:59:41 -03:00
Ross Lawley
8131f0a752 Fixed sum and average mapreduce dot notation support (#375, #376) 2013-07-10 13:53:18 +00:00
Ross Lawley
f4ea1ad517 Merge remote-tracking branch 'origin/pr/376'
Conflicts:
	AUTHORS
2013-07-10 13:50:52 +00:00
Ross Lawley
f34e8a0ff6 Fixed as_pymongo to return the id (#386) 2013-07-10 13:38:53 +00:00
Ross Lawley
4209d61b13 Document.select_related() now respects db_alias (#377) 2013-07-10 12:49:19 +00:00
Ross Lawley
fa83fba637 Reload uses shard_key if applicable (#384) 2013-07-10 11:18:49 +00:00
Ross Lawley
af86aee970 _dynamic field updates - fixed pickling and creation order
Dynamic fields are ordered based on creation and stored in _fields_ordered (#396)
Fixed pickling dynamic documents `_dynamic_fields` (#387)
2013-07-10 10:57:24 +00:00
Ross Lawley
f26f1a526c Merge branch 'master' of github.com:MongoEngine/mongoengine 2013-07-10 09:12:36 +00:00
Ross Lawley
7cb46d0761 Fixed ListField setslice and delslice dirty tracking (#390) 2013-07-10 09:11:50 +00:00
Ross Lawley
0cb4070364 Added Django 1.5 PY3 support (#392) 2013-07-10 08:53:56 +00:00
Ross Lawley
bc008c2597 Merge remote-tracking branch 'origin/pr/392' into 392 2013-07-10 08:44:10 +00:00
Ross Lawley
a1d142d3a4 Prep for django and py3 support 2013-07-10 08:38:13 +00:00
Ross Lawley
aa00dc1031 Merge pull request #392 from lig/patch-1
Fix crash on Python 3.x and Django >= 1.5
2013-07-10 01:37:40 -07:00
Wilson Júnior
592c654916 extending support for queryset.sum and queryset.average methods 2013-07-05 10:36:11 -03:00
Serge Matveenko
5021b10535 Fix crash on Python 3.x and Django >= 1.5 2013-07-03 01:17:40 +04:00
Jan Schrewe
43d6e64cfa Added a get_proxy_obj method to FileField and handle FileFields in container fields properly in ImageGridFsProxy. 2013-07-02 17:04:15 +02:00
kelvinhammond
caff44c663 Fixed sum and average queryset function
* Fixed sum and average map reduce functions for sum and average so that
        it works with mongo dot notation.

* Added unittest cases / updated them for the new changes
2013-06-21 09:39:11 -04:00
kelvinhammond
e0d2fab3c3 Merge branch 'master' of https://github.com/MongoEngine/mongoengine
Conflicts:
	AUTHORS
2013-06-21 07:26:40 -04:00
kelvinhammond
c31d6a6898 Fixed sum and average mapreduce function for issue #375 2013-06-19 10:34:33 -04:00
bool.dev
5cfd8909a8 Merge remote-tracking branch 'upstream/master' 2013-04-28 13:40:58 +05:30
bool.dev
d92f992c01 Removed merge trackers in code, merged correctly now. 2013-04-14 13:48:11 +05:30
bool.dev
20a5d9051d Merge conflicts resolved. 2013-04-14 13:39:54 +05:30
bool.dev
782d48594a Fixes resolving to db_field from class field name, in distinct() query. 2013-04-04 09:02:30 +05:30
57 changed files with 3766 additions and 2765 deletions

View file

@ -11,12 +11,14 @@ env:
- PYMONGO=dev DJANGO=1.4.2 - PYMONGO=dev DJANGO=1.4.2
- PYMONGO=2.5 DJANGO=1.5.1 - PYMONGO=2.5 DJANGO=1.5.1
- PYMONGO=2.5 DJANGO=1.4.2 - PYMONGO=2.5 DJANGO=1.4.2
- PYMONGO=3.2 DJANGO=1.5.1
- PYMONGO=3.3 DJANGO=1.5.1
install: install:
- if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi
- if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi
- if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install django==$DJANGO --use-mirrors ; true; fi
- if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi
- if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi
- pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b
- python setup.py install - python setup.py install
script: script:
- python setup.py test - python setup.py test

14
AUTHORS
View file

@ -9,7 +9,6 @@ Steve Challis <steve@stevechallis.com>
Wilson Júnior <wilsonpjunior@gmail.com> Wilson Júnior <wilsonpjunior@gmail.com>
Dan Crosta https://github.com/dcrosta Dan Crosta https://github.com/dcrosta
Laine Herron https://github.com/LaineHerron Laine Herron https://github.com/LaineHerron
Thomas Steinacher http://thomasst.ch/
CONTRIBUTORS CONTRIBUTORS
@ -17,8 +16,6 @@ Derived from the git logs, inevitably incomplete but all of whom and others
have submitted patches, reported bugs and generally helped make MongoEngine have submitted patches, reported bugs and generally helped make MongoEngine
that much better: that much better:
* Harry Marr
* Ross Lawley
* blackbrrr * blackbrrr
* Florian Schlachter * Florian Schlachter
* Vincent Driessen * Vincent Driessen
@ -115,6 +112,7 @@ that much better:
* Alexander Koshelev * Alexander Koshelev
* Jaime Irurzun * Jaime Irurzun
* Alexandre González * Alexandre González
* Thomas Steinacher
* Tommi Komulainen * Tommi Komulainen
* Peter Landry * Peter Landry
* biszkoptwielki * biszkoptwielki
@ -171,3 +169,13 @@ that much better:
* ygbourhis (https://github.com/ygbourhis) * ygbourhis (https://github.com/ygbourhis)
* Bob Dickinson (https://github.com/BobDickinson) * Bob Dickinson (https://github.com/BobDickinson)
* Michael Bartnett (https://github.com/michaelbartnett) * Michael Bartnett (https://github.com/michaelbartnett)
* Alon Horev (https://github.com/alonho)
* Kelvin Hammond (https://github.com/kelvinhammond)
* Jatin- (https://github.com/jatin-)
* Paul Uithol (https://github.com/PaulUithol)
* Thom Knowles (https://github.com/fleat)
* Paul (https://github.com/squamous)
* Olivier Cortès (https://github.com/Karmak23)
* crazyzubr (https://github.com/crazyzubr)
* FrankSomething (https://github.com/FrankSomething)
* Alexandr Morozov (https://github.com/LK4D4)

View file

@ -1,37 +0,0 @@
Differences between Mongomallard and Mongoengine
-----
* All document fields are lazily evaluated, resulting in much faster object initialization time.
* `_data` is removed due to lazy evaluation. `to_dict()` can be used to convert a document to a dictionary, and `_internal_data` contains previously evaluated data.
* Field methods `to_python`, `from_python`, `to_mongo`, `value_for_instance`:
* `to_python` is called when converting from a MongoDB type to a document Python type only.
* `from_python` is called when converting an assignment in Python to the document Python type.
* `to_mongo` is called when converting from a document Python type to a MongoDB type.
* `value_for_instance` is called just before returning a value in Python allowing for instance-specific transformations.
* `pre_init`, `post_init`, `pre_save_post_validation` signals are removed to ensure fast object initialization.
* `DecimalField` is removed since there is no corresponding MongoDB type
* `LongField` is removed since it is equivalent with `IntField`
* Adding `SafeReferenceField` which returns None if the reference does not exist.
* Adding `SafeReferenceListField` which doesn't return references that don't exist.
* Accessing a `ListField(ReferenceField)` doesn't automatically dereference all objects since they are lazily evaluated. A `SafeReferenceListField` may be used instead.
* Accessing a related object's id doesn't fetch the object from the database, e.g. `book.author.id` where author is a `ReferenceField` will not make a database lookup except when using a `SafeReferenceField`. When inheritance is allowed, a proxy object will be returned, otherwise a lazy object from the referenced document class will be returned.
* The primary key is only stored as `_id` in the database and is referenced in Python as `pk` or as the name of the primary key field.
* Saves are not cascaded by default.
* `Document.save()` supports `full=True` keyword argument to force saving all model fields.
* `_get_changed_fields()` / `_changed_fields` returns a set of field names (not db field names)
* Simplified `EmailField` email regex to be more compatible
* Assigning invalid types (e.g. an invalid string to `IntField`) raises immediately a `ValueError`
* `order_by()` without an argument resets the ordering (no ordering will be applied)
Untested / not implemented yet:
-----
* Dynamic documents / `DynamicField`, dynamic addition/deletion of fields
* Field display name methods
* `SequenceField`
* Pickling documents
* `FileField`
* All Geo fields
* `no_dereference()`
* using `SafeReferenceListField` with `GenericReferenceField`
* `max_depth` argument for `doc.reload()`

View file

@ -1,85 +0,0 @@
MongoMallard
============
MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine.
* Repository: https://github.com/elasticsales/mongomallard
* See [README_MONGOENGINE](https://github.com/elasticsales/mongomallard/blob/master/README_MONGOENGINE.rst) for MongoEngine's README.
* See [DIFFERENCES](https://github.com/elasticsales/mongomallard/blob/master/DIFFERENCES.md) for differences between MongoEngine and MongoMallard.
Benchmarks
----------
Sample run on a 2.7 GHz Intel Core i5 running OS X 10.8.3
<table>
<tr>
<th></th>
<th>MongoEngine 0.8.2 (ede9fcf)</th>
<th>MongoMallard (478062c)</th>
<th>Speedup</th>
</tr>
<tr>
<td>Doc initialization</td>
<td>52.494us</td>
<td>25.195us</td>
<td>2.08x</td>
</tr>
<tr>
<td>Doc getattr</td>
<td>1.339us</td>
<td>0.584us</td>
<td>2.29x</td>
</tr>
<tr>
<td>Doc setattr</td>
<td>3.064us</td>
<td>2.550us</td>
<td>1.20x</td>
</tr>
<tr>
<td>Doc to mongo</td>
<td>49.415us</td>
<td>26.497us</td>
<td>1.86x</td>
</tr>
<tr>
<td>Load from SON</td>
<td>61.475us</td>
<td>4.510us</td>
<td>13.63x</td>
</tr>
<tr>
<td>Save to database</td>
<td>434.389us</td>
<td>289.972us</td>
<td>2.29x</td>
</tr>
<tr>
<td>Load from database</td>
<td>558.178us</td>
<td>480.690us</td>
<td>1.16x</td>
</tr>
<tr>
<td>Save/delete big object to database</td>
<td>98.838ms</td>
<td>65.789ms</td>
<td>1.50x</td>
</tr>
<tr>
<td>Serialize big object from database</td>
<td>31.390ms</td>
<td>20.265ms</td>
<td>1.55x</td>
</tr>
<tr>
<td>Load big object from database</td>
<td>41.159ms</td>
<td>1.400ms</td>
<td>29.40x</td>
</tr>
</table>
See [tests/benchmark.py](https://github.com/elasticsales/mongomallard/blob/master/tests/benchmark.py) for source code.

View file

@ -5,6 +5,10 @@
@import url("basic.css"); @import url("basic.css");
#changelog p.first {margin-bottom: 0 !important;}
#changelog p {margin-top: 0 !important;
margin-bottom: 0 !important;}
/* -- page layout ----------------------------------------------------------- */ /* -- page layout ----------------------------------------------------------- */
body { body {

View file

@ -44,12 +44,21 @@ Context Managers
Querying Querying
======== ========
.. autoclass:: mongoengine.queryset.QuerySet .. automodule:: mongoengine.queryset
:members: :synopsis: Queryset level operations
.. automethod:: mongoengine.queryset.QuerySet.__call__ .. autoclass:: mongoengine.queryset.QuerySet
:members:
:inherited-members:
.. autofunction:: mongoengine.queryset.queryset_manager .. automethod:: QuerySet.__call__
.. autoclass:: mongoengine.queryset.QuerySetNoCache
:members:
.. automethod:: mongoengine.queryset.QuerySetNoCache.__call__
.. autofunction:: mongoengine.queryset.queryset_manager
Fields Fields
====== ======

View file

@ -2,13 +2,49 @@
Changelog Changelog
========= =========
Changes in 0.8.4
================
- Remove database name necessity in uri connection schema (#452)
- Fixed "$pull" semantics for nested ListFields (#447)
- Allow fields to be named the same as query operators (#445)
- Updated field filter logic - can now exclude subclass fields (#443)
- Fixed dereference issue with embedded listfield referencefields (#439)
- Fixed slice when using inheritance causing fields to be excluded (#437)
- Fixed ._get_db() attribute after a Document.switch_db() (#441)
- Dynamic Fields store and recompose Embedded Documents / Documents correctly (#449)
- Handle dynamic fieldnames that look like digits (#434)
- Added get_user_document and improve mongo_auth module (#423)
- Added str representation of GridFSProxy (#424)
- Update transform to handle docs erroneously passed to unset (#416)
- Fixed indexing - turn off _cls (#414)
- Fixed dereference threading issue in ComplexField.__get__ (#412)
- Fixed QuerySetNoCache.count() caching (#410)
- Don't follow references in _get_changed_fields (#422, #417)
- Allow args and kwargs to be passed through to_json (#420)
Changes in 0.8.3 Changes in 0.8.3
================ ================
- Fixed EmbeddedDocuments with `id` also storing `_id` (#402)
- Added get_proxy_object helper to filefields (#391)
- Added QuerySetNoCache and QuerySet.no_cache() for lower memory consumption (#365)
- Fixed sum and average mapreduce dot notation support (#375, #376, #393)
- Fixed as_pymongo to return the id (#386)
- Document.select_related() now respects `db_alias` (#377)
- Reload uses shard_key if applicable (#384)
- Dynamic fields are ordered based on creation and stored in _fields_ordered (#396)
**Potential breaking change:** http://docs.mongoengine.org/en/latest/upgrade.html#to-0-8-3
- Fixed pickling dynamic documents `_dynamic_fields` (#387)
- Fixed ListField setslice and delslice dirty tracking (#390)
- Added Django 1.5 PY3 support (#392)
- Added match ($elemMatch) support for EmbeddedDocuments (#379) - Added match ($elemMatch) support for EmbeddedDocuments (#379)
- Fixed weakref being valid after reload (#374) - Fixed weakref being valid after reload (#374)
- Fixed queryset.get() respecting no_dereference (#373) - Fixed queryset.get() respecting no_dereference (#373)
- Added full_result kwarg to update (#380) - Added full_result kwarg to update (#380)
Changes in 0.8.2 Changes in 0.8.2
================ ================
- Added compare_indexes helper (#361) - Added compare_indexes helper (#361)

View file

@ -45,7 +45,7 @@ The :mod:`~mongoengine.django.auth` module also contains a
Custom User model Custom User model
================= =================
Django 1.5 introduced `Custom user Models Django 1.5 introduced `Custom user Models
<https://docs.djangoproject.com/en/dev/topics/auth/customizing/#auth-custom-user>` <https://docs.djangoproject.com/en/dev/topics/auth/customizing/#auth-custom-user>`_
which can be used as an alternative to the MongoEngine authentication backend. which can be used as an alternative to the MongoEngine authentication backend.
The main advantage of this option is that other components relying on The main advantage of this option is that other components relying on
@ -74,7 +74,7 @@ An additional ``MONGOENGINE_USER_DOCUMENT`` setting enables you to replace the
The custom :class:`User` must be a :class:`~mongoengine.Document` class, but The custom :class:`User` must be a :class:`~mongoengine.Document` class, but
otherwise has the same requirements as a standard custom user model, otherwise has the same requirements as a standard custom user model,
as specified in the `Django Documentation as specified in the `Django Documentation
<https://docs.djangoproject.com/en/dev/topics/auth/customizing/>`. <https://docs.djangoproject.com/en/dev/topics/auth/customizing/>`_.
In particular, the custom class must define :attr:`USERNAME_FIELD` and In particular, the custom class must define :attr:`USERNAME_FIELD` and
:attr:`REQUIRED_FIELDS` attributes. :attr:`REQUIRED_FIELDS` attributes.
@ -128,7 +128,7 @@ appended to the filename until the generated filename doesn't exist. The
>>> fs.listdir() >>> fs.listdir()
([], [u'hello.txt']) ([], [u'hello.txt'])
All files will be saved and retrieved in GridFS via the :class::`FileDocument` All files will be saved and retrieved in GridFS via the :class:`FileDocument`
document, allowing easy access to the files without the GridFSStorage document, allowing easy access to the files without the GridFSStorage
backend.:: backend.::
@ -137,3 +137,36 @@ backend.::
[<FileDocument: FileDocument object>] [<FileDocument: FileDocument object>]
.. versionadded:: 0.4 .. versionadded:: 0.4
Shortcuts
=========
Inspired by the `Django shortcut get_object_or_404
<https://docs.djangoproject.com/en/dev/topics/http/shortcuts/#get-object-or-404>`_,
the :func:`~mongoengine.django.shortcuts.get_document_or_404` method returns
a document or raises an Http404 exception if the document does not exist::
from mongoengine.django.shortcuts import get_document_or_404
admin_user = get_document_or_404(User, username='root')
The first argument may be a Document or QuerySet object. All other passed arguments
and keyword arguments are used in the query::
foo_email = get_document_or_404(User.objects.only('email'), username='foo', is_active=True).email
.. note:: Like with :func:`get`, a MultipleObjectsReturned will be raised if more than one
object is found.
Also inspired by the `Django shortcut get_list_or_404
<https://docs.djangoproject.com/en/dev/topics/http/shortcuts/#get-list-or-404>`_,
the :func:`~mongoengine.django.shortcuts.get_list_or_404` method returns a list of
documents or raises an Http404 exception if the list is empty::
from mongoengine.django.shortcuts import get_list_or_404
active_users = get_list_or_404(User, is_active=True)
The first argument may be a Document or QuerySet object. All other passed
arguments and keyword arguments are used to filter the query.

View file

@ -23,12 +23,15 @@ arguments should be provided::
connect('project1', username='webapp', password='pwd123') connect('project1', username='webapp', password='pwd123')
Uri style connections are also supported as long as you include the database Uri style connections are also supported - just supply the uri as
name - just supply the uri as the :attr:`host` to the :attr:`host` to
:func:`~mongoengine.connect`:: :func:`~mongoengine.connect`::
connect('project1', host='mongodb://localhost/database_name') connect('project1', host='mongodb://localhost/database_name')
Note that the database name from the uri has priority over the name
passed to :func:`~mongoengine.connect`
ReplicaSets ReplicaSets
=========== ===========

View file

@ -54,7 +54,7 @@ be saved ::
There is one caveat on Dynamic Documents: fields cannot start with `_` There is one caveat on Dynamic Documents: fields cannot start with `_`
Dynamic fields are stored in alphabetical order *after* any declared fields. Dynamic fields are stored in creation order *after* any declared fields.
Fields Fields
====== ======
@ -442,6 +442,8 @@ The following example shows a :class:`Log` document that will be limited to
ip_address = StringField() ip_address = StringField()
meta = {'max_documents': 1000, 'max_size': 2000000} meta = {'max_documents': 1000, 'max_size': 2000000}
.. _defining-indexes:
Indexes Indexes
======= =======
@ -485,6 +487,35 @@ If a dictionary is passed then the following options are available:
Inheritance adds extra fields indices see: :ref:`document-inheritance`. Inheritance adds extra fields indices see: :ref:`document-inheritance`.
Global index default options
----------------------------
There are a few top level defaults for all indexes that can be set::
class Page(Document):
title = StringField()
rating = StringField()
meta = {
'index_options': {},
'index_background': True,
'index_drop_dups': True,
'index_cls': False
}
:attr:`index_options` (Optional)
Set any default index options - see the `full options list <http://docs.mongodb.org/manual/reference/method/db.collection.ensureIndex/#db.collection.ensureIndex>`_
:attr:`index_background` (Optional)
Set the default value for if an index should be indexed in the background
:attr:`index_drop_dups` (Optional)
Set the default value for if an index should drop duplicates
:attr:`index_cls` (Optional)
A way to turn off a specific index for _cls.
Compound Indexes and Indexing sub documents Compound Indexes and Indexing sub documents
------------------------------------------- -------------------------------------------
@ -558,6 +589,11 @@ documentation for more information. A common usecase might be session data::
] ]
} }
.. warning:: TTL indexes happen on the MongoDB server and not in the application
code, therefore no signals will be fired on document deletion.
If you need signals to be fired on deletion, then you must handle the
deletion of Documents in your application code.
Comparing Indexes Comparing Indexes
----------------- -----------------
@ -653,7 +689,6 @@ document.::
.. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults .. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults
to False, meaning you must set it to True to use inheritance. to False, meaning you must set it to True to use inheritance.
Working with existing data Working with existing data
-------------------------- --------------------------
As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and
@ -673,3 +708,25 @@ defining all possible field types.
If you use :class:`~mongoengine.Document` and the database contains data that If you use :class:`~mongoengine.Document` and the database contains data that
isn't defined then that data will be stored in the `document._data` dictionary. isn't defined then that data will be stored in the `document._data` dictionary.
Abstract classes
================
If you want to add some extra functionality to a group of Document classes but
you don't need or want the overhead of inheritance you can use the
:attr:`abstract` attribute of :attr:`~mongoengine.Document.meta`.
This won't turn on :ref:`document-inheritance` but will allow you to keep your
code DRY::
class BaseDocument(Document):
meta = {
'abstract': True,
}
def check_permissions(self):
...
class User(BaseDocument):
...
Now the User class will have access to the inherited `check_permissions` method
and won't store any of the extra `_cls` information.

View file

@ -16,7 +16,9 @@ fetch documents from the database::
.. note:: .. note::
As of MongoEngine 0.8 the querysets utilise a local cache. So iterating As of MongoEngine 0.8 the querysets utilise a local cache. So iterating
it multiple times will only cause a single query. it multiple times will only cause a single query. If this is not the
desired behaviour you can call :class:`~mongoengine.QuerySet.no_cache`
(version **0.8.3+**) to return a non-caching queryset.
Filtering queries Filtering queries
================= =================
@ -495,7 +497,6 @@ that you may use with these methods:
* ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``unset`` -- delete a particular value (since MongoDB v1.3+)
* ``inc`` -- increment a value by a given amount * ``inc`` -- increment a value by a given amount
* ``dec`` -- decrement a value by a given amount * ``dec`` -- decrement a value by a given amount
* ``pop`` -- remove the last item from a list
* ``push`` -- append a value to a list * ``push`` -- append a value to a list
* ``push_all`` -- append several values to a list * ``push_all`` -- append several values to a list
* ``pop`` -- remove the first or last element of a list * ``pop`` -- remove the first or last element of a list

View file

@ -2,12 +2,22 @@
Upgrading Upgrading
######### #########
0.8.2 to 0.8.3
**************
Minor change that may impact users:
DynamicDocument fields are now stored in creation order after any declared
fields. Previously they were stored alphabetically.
0.7 to 0.8 0.7 to 0.8
********** **********
There have been numerous backwards breaking changes in 0.8. The reasons for There have been numerous backwards breaking changes in 0.8. The reasons for
these are ensure that MongoEngine has sane defaults going forward and these are to ensure that MongoEngine has sane defaults going forward and that it
performs the best it can out the box. Where possible there have been performs the best it can out of the box. Where possible there have been
FutureWarnings to help get you ready for the change, but that hasn't been FutureWarnings to help get you ready for the change, but that hasn't been
possible for the whole of the release. possible for the whole of the release.
@ -61,7 +71,7 @@ inherited classes like so: ::
Document Definition Document Definition
------------------- -------------------
The default for inheritance has changed - its now off by default and The default for inheritance has changed - it is now off by default and
:attr:`_cls` will not be stored automatically with the class. So if you extend :attr:`_cls` will not be stored automatically with the class. So if you extend
your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments` your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments`
you will need to declare :attr:`allow_inheritance` in the meta data like so: :: you will need to declare :attr:`allow_inheritance` in the meta data like so: ::
@ -71,7 +81,7 @@ you will need to declare :attr:`allow_inheritance` in the meta data like so: ::
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
Previously, if you had data the database that wasn't defined in the Document Previously, if you had data in the database that wasn't defined in the Document
definition, it would set it as an attribute on the document. This is no longer definition, it would set it as an attribute on the document. This is no longer
the case and the data is set only in the ``document._data`` dictionary: :: the case and the data is set only in the ``document._data`` dictionary: ::
@ -92,8 +102,8 @@ the case and the data is set only in the ``document._data`` dictionary: ::
AttributeError: 'Animal' object has no attribute 'size' AttributeError: 'Animal' object has no attribute 'size'
The Document class has introduced a reserved function `clean()`, which will be The Document class has introduced a reserved function `clean()`, which will be
called before saving the document. If your document class happen to have a method called before saving the document. If your document class happens to have a method
with the same name, please try rename it. with the same name, please try to rename it.
def clean(self): def clean(self):
pass pass
@ -101,7 +111,7 @@ with the same name, please try rename it.
ReferenceField ReferenceField
-------------- --------------
ReferenceFields now store ObjectId's by default - this is more efficient than ReferenceFields now store ObjectIds by default - this is more efficient than
DBRefs as we already know what Document types they reference:: DBRefs as we already know what Document types they reference::
# Old code # Old code
@ -147,7 +157,7 @@ UUIDFields now default to storing binary values::
class Animal(Document): class Animal(Document):
uuid = UUIDField(binary=False) uuid = UUIDField(binary=False)
To migrate all the uuid's you need to touch each object and mark it as dirty To migrate all the uuids you need to touch each object and mark it as dirty
eg:: eg::
# Doc definition # Doc definition
@ -165,7 +175,7 @@ eg::
DecimalField DecimalField
------------ ------------
DecimalField now store floats - previous it was storing strings and that DecimalFields now store floats - previously it was storing strings and that
made it impossible to do comparisons when querying correctly.:: made it impossible to do comparisons when querying correctly.::
# Old code # Old code
@ -176,7 +186,7 @@ made it impossible to do comparisons when querying correctly.::
class Person(Document): class Person(Document):
balance = DecimalField(force_string=True) balance = DecimalField(force_string=True)
To migrate all the uuid's you need to touch each object and mark it as dirty To migrate all the DecimalFields you need to touch each object and mark it as dirty
eg:: eg::
# Doc definition # Doc definition
@ -188,7 +198,7 @@ eg::
p._mark_as_changed('balance') p._mark_as_changed('balance')
p.save() p.save()
.. note:: DecimalField's have also been improved with the addition of precision .. note:: DecimalFields have also been improved with the addition of precision
and rounding. See :class:`~mongoengine.fields.DecimalField` for more information. and rounding. See :class:`~mongoengine.fields.DecimalField` for more information.
`An example test migration for DecimalFields is available on github `An example test migration for DecimalFields is available on github
@ -197,7 +207,7 @@ eg::
Cascading Saves Cascading Saves
--------------- ---------------
To improve performance document saves will no longer automatically cascade. To improve performance document saves will no longer automatically cascade.
Any changes to a Documents references will either have to be saved manually or Any changes to a Document's references will either have to be saved manually or
you will have to explicitly tell it to cascade on save:: you will have to explicitly tell it to cascade on save::
# At the class level: # At the class level:
@ -239,7 +249,7 @@ update your code like so: ::
# Update example a) assign queryset after a change: # Update example a) assign queryset after a change:
mammals = Animal.objects(type="mammal") mammals = Animal.objects(type="mammal")
carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so fitler can be applied carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so filter can be applied
[m for m in carnivores] # This will return all carnivores [m for m in carnivores] # This will return all carnivores
# Update example b) chain the queryset: # Update example b) chain the queryset:
@ -266,7 +276,7 @@ queryset you should upgrade to use count::
.only() now inline with .exclude() .only() now inline with .exclude()
---------------------------------- ----------------------------------
The behaviour of `.only()` was highly ambious, now it works in the mirror fashion The behaviour of `.only()` was highly ambiguous, now it works in mirror fashion
to `.exclude()`. Chaining `.only()` calls will increase the fields required:: to `.exclude()`. Chaining `.only()` calls will increase the fields required::
# Old code # Old code
@ -430,7 +440,7 @@ main areas of changed are: choices in fields, map_reduce and collection names.
Choice options: Choice options:
=============== ===============
Are now expected to be an iterable of tuples, with the first element in each Are now expected to be an iterable of tuples, with the first element in each
tuple being the actual value to be stored. The second element is the tuple being the actual value to be stored. The second element is the
human-readable name for the option. human-readable name for the option.
@ -452,8 +462,8 @@ such the following have been changed:
Default collection naming Default collection naming
========================= =========================
Previously it was just lowercase, its now much more pythonic and readable as Previously it was just lowercase, it's now much more pythonic and readable as
its lowercase and underscores, previously :: it's lowercase and underscores, previously ::
class MyAceDocument(Document): class MyAceDocument(Document):
pass pass
@ -520,5 +530,5 @@ Alternatively, you can rename your collections eg ::
mongodb 1.8 > 2.0 + mongodb 1.8 > 2.0 +
=================== ===================
Its been reported that indexes may need to be recreated to the newer version of indexes. It's been reported that indexes may need to be recreated to the newer version of indexes.
To do this drop indexes and call ``ensure_indexes`` on each model. To do this drop indexes and call ``ensure_indexes`` on each model.

View file

@ -15,8 +15,7 @@ import django
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
list(queryset.__all__) + signals.__all__ + list(errors.__all__)) list(queryset.__all__) + signals.__all__ + list(errors.__all__))
VERSION = (0, 8, 2) VERSION = (0, 8, 4)
MALLARD = True
def get_version(): def get_version():

View file

@ -108,6 +108,14 @@ class BaseList(list):
self._mark_as_changed() self._mark_as_changed()
return super(BaseList, self).__delitem__(*args, **kwargs) return super(BaseList, self).__delitem__(*args, **kwargs)
def __setslice__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).__setslice__(*args, **kwargs)
def __delslice__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).__delslice__(*args, **kwargs)
def __getstate__(self): def __getstate__(self):
self.instance = None self.instance = None
self._dereferenced = False self._dereferenced = False

View file

@ -4,7 +4,7 @@ import numbers
from functools import partial from functools import partial
import pymongo import pymongo
from bson import json_util from bson import json_util, ObjectId
from bson.dbref import DBRef from bson.dbref import DBRef
from bson.son import SON from bson.son import SON
@ -15,7 +15,6 @@ from mongoengine.errors import (ValidationError, InvalidDocumentError,
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
to_str_keys_recursive) to_str_keys_recursive)
from mongoengine.base.proxy import DocumentProxy
from mongoengine.base.common import get_document, ALLOW_INHERITANCE from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.datastructures import BaseDict, BaseList
from mongoengine.base.fields import ComplexBaseField from mongoengine.base.fields import ComplexBaseField
@ -24,52 +23,155 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__' NON_FIELD_ERRORS = '__all__'
_set = object.__setattr__
class BaseDocument(object): class BaseDocument(object):
#_dynamic = False _dynamic = False
#_dynamic_lock = True _created = True
_dynamic_lock = True
_initialised = False _initialised = False
def __init__(self, _son=None, **values): def __init__(self, *args, **values):
""" """
Initialise a document or embedded document Initialise a document or embedded document
:param __auto_convert: Try and will cast python objects to Object types
:param values: A dictionary of values for the document :param values: A dictionary of values for the document
""" """
_set(self, '_db_data', _son) if args:
_set(self, '_lazy', False) # Combine positional arguments with named arguments.
_set(self, '_internal_data', {}) # We only want named arguments.
_set(self, '_changed_fields', set()) field = iter(self._fields_ordered)
if values: # If its an automatic id field then skip to the first defined field
pk = values.pop('pk', None) if self._auto_id_field:
for field in set(self._fields.keys()).intersection(values.keys()): next(field)
setattr(self, field, values[field]) for value in args:
if pk != None: name = next(field)
self.pk = pk if name in values:
raise TypeError("Multiple values for keyword argument '" + name + "'")
values[name] = value
__auto_convert = values.pop("__auto_convert", True)
signals.pre_init.send(self.__class__, document=self, values=values)
def __delattr__(self, name): self._data = {}
default = self._fields[name].default self._dynamic_fields = SON()
value = default() if callable(default) else default
setattr(self, name, value)
@property # Assign default values to instance
def _created(self): for key, field in self._fields.iteritems():
return self._db_data != None or self._lazy if self._db_field_map.get(key, key) in values:
continue
value = getattr(self, key, None)
setattr(self, key, value)
# Set passed values after initialisation
if self._dynamic:
dynamic_data = {}
for key, value in values.iteritems():
if key in self._fields or key == '_id':
setattr(self, key, value)
elif self._dynamic:
dynamic_data[key] = value
else:
FileField = _import_class('FileField')
for key, value in values.iteritems():
if key == '__auto_convert':
continue
key = self._reverse_db_field_map.get(key, key)
if key in self._fields or key in ('id', 'pk', '_cls'):
if __auto_convert and value is not None:
field = self._fields.get(key)
if field and not isinstance(field, FileField):
value = field.to_python(value)
setattr(self, key, value)
else:
self._data[key] = value
# Set any get_fieldname_display methods
self.__set_field_display()
if self._dynamic:
self._dynamic_lock = False
for key, value in dynamic_data.iteritems():
setattr(self, key, value)
# Flag initialised
self._initialised = True
signals.post_init.send(self.__class__, document=self)
def __delattr__(self, *args, **kwargs):
"""Handle deletions of fields"""
field_name = args[0]
if field_name in self._fields:
default = self._fields[field_name].default
if callable(default):
default = default()
setattr(self, field_name, default)
else:
super(BaseDocument, self).__delattr__(*args, **kwargs)
def __setattr__(self, name, value):
# Handle dynamic data only if an initialised dynamic document
if self._dynamic and not self._dynamic_lock:
field = None
if not hasattr(self, name) and not name.startswith('_'):
DynamicField = _import_class("DynamicField")
field = DynamicField(db_field=name)
field.name = name
self._dynamic_fields[name] = field
self._fields_ordered += (name,)
if not name.startswith('_'):
value = self.__expand_dynamic_values(name, value)
# Handle marking data as changed
if name in self._dynamic_fields:
self._data[name] = value
if hasattr(self, '_changed_fields'):
self._mark_as_changed(name)
if (self._is_document and not self._created and
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
OperationError = _import_class('OperationError')
msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value)
def __getstate__(self):
data = {}
for k in ('_changed_fields', '_initialised', '_created',
'_dynamic_fields', '_fields_ordered'):
if hasattr(self, k):
data[k] = getattr(self, k)
data['_data'] = self.to_mongo()
return data
def __setstate__(self, data):
if isinstance(data["_data"], SON):
data["_data"] = self.__class__._from_son(data["_data"])._data
for k in ('_changed_fields', '_initialised', '_created', '_data',
'_fields_ordered', '_dynamic_fields'):
if k in data:
setattr(self, k, data[k])
dynamic_fields = data.get('_dynamic_fields') or SON()
for k in dynamic_fields.keys():
setattr(self, k, data["_data"].get(k))
def __iter__(self): def __iter__(self):
if 'id' in self._fields and 'id' not in self._fields_ordered:
return iter(('id', ) + self._fields_ordered)
return iter(self._fields_ordered) return iter(self._fields_ordered)
def __getitem__(self, name): def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present. """Dictionary-style field access, return a field's value if present.
""" """
try: try:
if name in self._fields: if name in self._fields_ordered:
return getattr(self, name) return getattr(self, name)
except AttributeError: except AttributeError:
pass pass
@ -90,8 +192,8 @@ class BaseDocument(object):
except AttributeError: except AttributeError:
return False return False
def __unicode__(self): def __len__(self):
return u'%s object' % self.__class__.__name__ return len(self._data)
def __repr__(self): def __repr__(self):
try: try:
@ -110,12 +212,9 @@ class BaseDocument(object):
return txt_type('%s object' % self.__class__.__name__) return txt_type('%s object' % self.__class__.__name__)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, DocumentProxy) and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk') and self.pk == other.pk: if isinstance(other, self.__class__) and hasattr(other, 'id'):
return True if self.id == other.id:
return True
if isinstance(other, self.__class__) and hasattr(other, 'pk') and self.pk == other.pk:
return True
return False return False
def __ne__(self, other): def __ne__(self, other):
@ -141,16 +240,42 @@ class BaseDocument(object):
def to_mongo(self): def to_mongo(self):
"""Return as SON data ready for use with MongoDB. """Return as SON data ready for use with MongoDB.
""" """
sets, unsets = self._delta(full=True) data = SON()
son = SON(data=sets) data["_id"] = None
allow_inheritance = self._meta.get('allow_inheritance', data['_cls'] = self._class_name
ALLOW_INHERITANCE)
if allow_inheritance:
son['_cls'] = self._class_name
return son
def to_dict(self): for field_name in self:
return dict((field, getattr(self, field)) for field in self._fields) value = self._data.get(field_name, None)
field = self._fields.get(field_name)
if field is None and self._dynamic:
field = self._dynamic_fields.get(field_name)
if value is not None:
value = field.to_mongo(value)
# Handle self generating fields
if value is None and field._auto_gen:
value = field.generate()
self._data[field_name] = value
if value is not None:
data[field.db_field] = value
# If "_id" has not been set, then try and set it
Document = _import_class("Document")
if isinstance(self, Document):
if data["_id"] is None:
data["_id"] = self._data.get("id", None)
if data['_id'] is None:
data.pop('_id')
# Only add _cls if allow_inheritance is True
if (not hasattr(self, '_meta') or
not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)):
data.pop('_cls')
return data
def validate(self, clean=True): def validate(self, clean=True):
"""Ensure that all fields' values are valid and that required fields """Ensure that all fields' values are valid and that required fields
@ -165,11 +290,8 @@ class BaseDocument(object):
errors[NON_FIELD_ERRORS] = error errors[NON_FIELD_ERRORS] = error
# Get a list of tuples of field names and their current values # Get a list of tuples of field names and their current values
fields = [(field, getattr(self, name)) fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
for name, field in self._fields.items()] self._data.get(name)) for name in self._fields_ordered]
#if self._dynamic:
# fields += [(field, self._data.get(name))
# for name, field in self._dynamic_fields.items()]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField") EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
@ -199,9 +321,9 @@ class BaseDocument(object):
message = "ValidationError (%s:%s) " % (self._class_name, pk) message = "ValidationError (%s:%s) " % (self._class_name, pk)
raise ValidationError(message, errors=errors) raise ValidationError(message, errors=errors)
def to_json(self): def to_json(self, *args, **kwargs):
"""Converts a document to JSON""" """Converts a document to JSON"""
return json_util.dumps(self.to_mongo()) return json_util.dumps(self.to_mongo(), *args, **kwargs)
@classmethod @classmethod
def from_json(cls, json_data): def from_json(cls, json_data):
@ -243,40 +365,17 @@ class BaseDocument(object):
return value return value
def _mark_as_changed(self, key): def _mark_as_changed(self, key):
"""Marks a key as explicitly changed by the user. """Marks a key as explicitly changed by the user
""" """
if not key:
if key: return
self._changed_fields.add(key) key = self._db_field_map.get(key, key)
if (hasattr(self, '_changed_fields') and
def _get_changed_fields(self): key not in self._changed_fields):
"""Returns a list of all fields that have explicitly been changed. self._changed_fields.append(key)
"""
changed_fields = set(self._changed_fields)
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems():
if field_name not in changed_fields:
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
changed_subfields = field_value[idx]._get_changed_fields()
if changed_subfields:
changed_fields |= set(['.'.join([field_name, str(idx), subfield_name])
for subfield_name in changed_subfields])
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
changed_subfields = field_value._get_changed_fields()
if changed_subfields:
changed_fields |= set(['.'.join([field_name, subfield_name])
for subfield_name in changed_subfields])
return changed_fields
def _clear_changed_fields(self): def _clear_changed_fields(self):
_set(self, '_changed_fields', set()) self._changed_fields = []
EmbeddedDocumentField = _import_class("EmbeddedDocumentField") EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems(): for field_name, field in self._fields.iteritems():
if (isinstance(field, ComplexBaseField) and if (isinstance(field, ComplexBaseField) and
@ -291,57 +390,136 @@ class BaseDocument(object):
if field_value: if field_value:
field_value._clear_changed_fields() field_value._clear_changed_fields()
def _delta(self, full=False): def _get_changed_fields(self, inspected=None):
sets = {} """Returns a list of all fields that have explicitly been changed.
unsets = {} """
EmbeddedDocument = _import_class("EmbeddedDocument")
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
ReferenceField = _import_class("ReferenceField")
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or set()
if hasattr(self, 'id'):
if self.id in inspected:
return _changed_fields
inspected.add(self.id)
def get_db_value(field, value): for field_name in self._fields_ordered:
if value is None: db_field_name = self._db_field_map.get(field_name, field_name)
value = field.default() if callable(field.default) else field.default key = '%s.' % db_field_name
return field.to_mongo(value) data = self._data.get(field_name, None)
field = self._fields.get(field_name)
if hasattr(data, 'id'):
if data.id in inspected:
continue
inspected.add(data.id)
if isinstance(field, ReferenceField):
continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in _changed_fields):
# Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in _changed_fields):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = data.iteritems()
for index, value in iterator:
if not hasattr(value, '_get_changed_fields'):
continue
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
list_key = "%s%s." % (key, index)
changed = value._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
return _changed_fields
if full or not self._created: def _delta(self):
fields = self._fields.iteritems() """Returns the delta (set, unset) of the changes for a document.
db_data = ((self._db_field_map.get(field_name, field_name), Gets any values that have been explicitly changed.
get_db_value(field, getattr(self, field_name))) """
for field_name, field in fields) # Handles cases where not loaded from_son but has _id
doc = self.to_mongo()
set_fields = self._get_changed_fields()
unset_data = {}
parts = []
if hasattr(self, '_changed_fields'):
set_data = {}
# Fetch each set item from its path
for path in set_fields:
parts = path.split('.')
d = doc
new_path = []
for p in parts:
if isinstance(d, (ObjectId, DBRef)):
break
elif isinstance(d, list) and p.isdigit():
d = d[int(p)]
elif hasattr(d, 'get'):
d = d.get(p)
new_path.append(p)
path = '.'.join(new_path)
set_data[path] = d
else: else:
# List of (db_field_name, db_value) tuples. set_data = doc
db_data = [] if '_id' in set_data:
del(set_data['_id'])
for field_name in self._get_changed_fields(): # Determine if any changed items were actually unset.
parts = field_name.split('.') for path, value in set_data.items():
if value or isinstance(value, (numbers.Number, bool)):
continue
db_field_parts = [] # If we've set a value that ain't the default value dont unset it.
default = None
if (self._dynamic and len(parts) and parts[0] in
self._dynamic_fields):
del(set_data[path])
unset_data[path] = 1
continue
elif path in self._fields:
default = self._fields[path].default
else: # Perform a full lookup for lists / embedded lookups
d = self
parts = path.split('.')
db_field_name = parts.pop()
for p in parts:
if isinstance(d, list) and p.isdigit():
d = d[int(p)]
elif (hasattr(d, '__getattribute__') and
not isinstance(d, dict)):
real_path = d._reverse_db_field_map.get(p, p)
d = getattr(d, real_path)
else:
d = d.get(p)
value = self if hasattr(d, '_fields'):
for part in parts: field_name = d._reverse_db_field_map.get(db_field_name,
if isinstance(value, list) and part.isdigit(): db_field_name)
db_field_parts.append(part) if field_name in d._fields:
field = field.field default = d._fields.get(field_name).default
value = value[int(part)] else:
elif isinstance(value, dict): default = None
db_field_parts.append(part)
field = field.field
value = value[part]
else: # It's a document
obj = value
field = obj._fields[part]
db_field_parts.append(obj._db_field_map.get(part, part))
value = getattr(obj, part)
db_data.append(('.'.join(db_field_parts), get_db_value(field, value))) if default is not None:
if callable(default):
default = default()
for db_field_name, db_value in db_data: if default != value:
if db_value == None: continue
unsets[db_field_name] = 1
else:
sets[db_field_name] = db_value
return sets, unsets del(set_data[path])
unset_data[path] = 1
return set_data, unset_data
@classmethod @classmethod
def _get_collection_name(cls): def _get_collection_name(cls):
@ -350,16 +528,61 @@ class BaseDocument(object):
return cls._meta.get('collection', None) return cls._meta.get('collection', None)
@classmethod @classmethod
def _from_son(cls, son, _auto_dereference=False): def _from_son(cls, son, _auto_dereference=True):
"""Create an instance of a Document (subclass) from a PyMongo SON.
"""
# get the class name from the document, falling back to the given # get the class name from the document, falling back to the given
# class if unavailable # class if unavailable
class_name = son.get('_cls', cls._class_name) class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.iteritems())
if not UNICODE_KWARGS:
# python 2.6.4 and lower cannot handle unicode keys
# passed to class constructor example: cls(**data)
to_str_keys_recursive(data)
# Return correct subclass for document type # Return correct subclass for document type
if class_name != cls._class_name: if class_name != cls._class_name:
cls = get_document(class_name) cls = get_document(class_name)
return cls(_son=son) changed_fields = []
errors_dict = {}
fields = cls._fields
if not _auto_dereference:
fields = copy.copy(fields)
for field_name, field in fields.iteritems():
field._auto_dereference = _auto_dereference
if field.db_field in data:
value = data[field.db_field]
try:
data[field_name] = (value if value is None
else field.to_python(value))
if field_name != field.db_field:
del data[field.db_field]
except (AttributeError, ValueError), e:
errors_dict[field_name] = e
elif field.default:
default = field.default
if callable(default):
default = default()
if isinstance(default, BaseDocument):
changed_fields.append(field_name)
if errors_dict:
errors = "\n".join(["%s - %s" % (k, v)
for k, v in errors_dict.items()])
msg = ("Invalid data to create a `%s` instance.\n%s"
% (cls._class_name, errors))
raise InvalidDocumentError(msg)
obj = cls(__auto_convert=False, **data)
obj._changed_fields = changed_fields
obj._created = False
if not _auto_dereference:
obj._fields = fields
return obj
@classmethod @classmethod
def _build_index_specs(cls, meta_indexes): def _build_index_specs(cls, meta_indexes):
@ -406,8 +629,10 @@ class BaseDocument(object):
# Check to see if we need to include _cls # Check to see if we need to include _cls
allow_inheritance = cls._meta.get('allow_inheritance', allow_inheritance = cls._meta.get('allow_inheritance',
ALLOW_INHERITANCE) ALLOW_INHERITANCE)
include_cls = allow_inheritance and not spec.get('sparse', False) include_cls = (allow_inheritance and not spec.get('sparse', False) and
spec.get('cls', True))
if "cls" in spec:
spec.pop('cls')
for key in spec['fields']: for key in spec['fields']:
# If inherited spec continue # If inherited spec continue
if isinstance(key, (list, tuple)): if isinstance(key, (list, tuple)):
@ -537,7 +762,7 @@ class BaseDocument(object):
for field_name in parts: for field_name in parts:
# Handle ListField indexing: # Handle ListField indexing:
if field_name.isdigit(): if field_name.isdigit() and hasattr(field, 'field'):
new_field = field.field new_field = field.field
fields.append(field_name) fields.append(field_name)
continue continue
@ -549,9 +774,9 @@ class BaseDocument(object):
field_name = cls._meta['id_field'] field_name = cls._meta['id_field']
if field_name in cls._fields: if field_name in cls._fields:
field = cls._fields[field_name] field = cls._fields[field_name]
#elif cls._dynamic: elif cls._dynamic:
# DynamicField = _import_class('DynamicField') DynamicField = _import_class('DynamicField')
# field = DynamicField(db_field=field_name) field = DynamicField(db_field=field_name)
else: else:
raise LookUpError('Cannot resolve field "%s"' raise LookUpError('Cannot resolve field "%s"'
% field_name) % field_name)

View file

@ -59,17 +59,15 @@ class BaseField(object):
:param help_text: (optional) The help text for this field and is often :param help_text: (optional) The help text for this field and is often
used when generating model forms from the document model. used when generating model forms from the document model.
""" """
self.name = None # filled in by document self.db_field = (db_field or name) if not primary_key else '_id'
self.db_field = db_field if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
warnings.warn(msg, DeprecationWarning)
self.required = required or primary_key self.required = required or primary_key
self.default = default self.default = default
self.unique = bool(unique or unique_with) self.unique = bool(unique or unique_with)
self.unique_with = unique_with self.unique_with = unique_with
self.primary_key = primary_key self.primary_key = primary_key
if self.primary_key:
if self.db_field:
raise ValueError("Can't use primary_key in conjunction with db_field.")
self.db_field = '_id'
self.validation = validation self.validation = validation
self.choices = choices self.choices = choices
self.verbose_name = verbose_name self.verbose_name = verbose_name
@ -84,52 +82,41 @@ class BaseField(object):
BaseField.creation_counter += 1 BaseField.creation_counter += 1
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document.
"""
if instance is None: if instance is None:
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
else:
name = self.name
data = instance._internal_data
if not name in data:
if instance._lazy and name != instance._meta['id_field']:
# We need to fetch the doc from the database.
instance.reload()
db_field = instance._db_field_map.get(name, name)
try:
db_value = instance._db_data[db_field]
except (TypeError, KeyError):
value = self.default() if callable(self.default) else self.default
else:
value = self.to_python(db_value)
if hasattr(self, 'value_for_instance'): # Get value from document instance if available
value = self.value_for_instance(value, instance) value = instance._data.get(self.name)
data[name] = value
return data[name] EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = weakref.proxy(instance)
return value
def __set__(self, instance, value): def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document. """Descriptor for assigning a value to a field in a document.
""" """
if instance._lazy: # If setting to None and theres a default
# Fetch the from the database before we assign to a lazy object. # Then set the value to the default value
instance.reload() if value is None and self.default is not None:
value = self.default
if callable(value):
value = value()
name = self.name if instance._initialised:
try:
value = self.from_python(value) if (self.name not in instance._data or
if hasattr(self, 'value_for_instance'): instance._data[self.name] != value):
value = self.value_for_instance(value, instance) instance._mark_as_changed(self.name)
try: except:
has_changed = name not in instance._internal_data or instance._internal_data[name] != value # Values cant be compared eg: naive and tz datetimes
except: # Values can't be compared eg: naive and tz datetimes # So mark it as changed
has_changed = True instance._mark_as_changed(self.name)
instance._data[self.name] = value
if has_changed:
instance._mark_as_changed(name)
instance._internal_data[name] = value
def error(self, message="", errors=None, field_name=None): def error(self, message="", errors=None, field_name=None):
"""Raises a ValidationError. """Raises a ValidationError.
@ -145,15 +132,7 @@ class BaseField(object):
def to_mongo(self, value): def to_mongo(self, value):
"""Convert a Python type to a MongoDB-compatible type. """Convert a Python type to a MongoDB-compatible type.
""" """
return value return self.to_python(value)
def from_python(self, value):
"""Convert a raw Python value (in an assignment) to the internal
Python representation.
"""
if value == None:
return self.default() if callable(self.default) else self.default
return value
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
"""Prepare a value that is being used in a query for PyMongo. """Prepare a value that is being used in a query for PyMongo.
@ -208,6 +187,50 @@ class ComplexBaseField(BaseField):
field = None field = None
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
dereference = (self._auto_dereference and
(self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField))))
_dereference = _import_class("DeReference")()
self._auto_dereference = instance._fields[self.name]._auto_dereference
if instance._initialised and dereference:
instance._data[self.name] = _dereference(
instance._data.get(self.name), max_depth=1, instance=instance,
name=self.name
)
value = super(ComplexBaseField, self).__get__(instance, owner)
# Convert lists / values so we can watch for any changes on them
if (isinstance(value, (list, tuple)) and
not isinstance(value, BaseList)):
value = BaseList(value, instance, self.name)
instance._data[self.name] = value
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, instance, self.name)
instance._data[self.name] = value
if (self._auto_dereference and instance._initialised and
isinstance(value, (BaseList, BaseDict))
and not value._dereferenced):
value = _dereference(
value, max_depth=1, instance=instance, name=self.name
)
value._dereferenced = True
instance._data[self.name] = value
return value
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type.
""" """
@ -366,10 +389,12 @@ class ObjectIdField(BaseField):
""" """
def to_python(self, value): def to_python(self, value):
if not isinstance(value, ObjectId):
value = ObjectId(value)
return value return value
def to_mongo(self, value): def to_mongo(self, value):
if value and not isinstance(value, ObjectId): if not isinstance(value, ObjectId):
try: try:
return ObjectId(unicode(value)) return ObjectId(unicode(value))
except Exception, e: except Exception, e:

View file

@ -91,11 +91,12 @@ class DocumentMetaclass(type):
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems()]) for k, v in doc_fields.iteritems()])
attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems())
attrs['_fields_ordered'] = tuple(i[1] for i in sorted( attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
(v.creation_counter, v.name) (v.creation_counter, v.name)
for v in doc_fields.itervalues())) for v in doc_fields.itervalues()))
attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems())
# #
# Set document hierarchy # Set document hierarchy
@ -358,15 +359,17 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class.id = field new_class.id = field
# Set primary key if not defined by the document # Set primary key if not defined by the document
new_class._auto_id_field = False
if not new_class._meta.get('id_field'): if not new_class._meta.get('id_field'):
id_field = ObjectIdField(primary_key=True) new_class._auto_id_field = True
id_field.name = 'id'
id_field._auto_gen = True
new_class._fields['id'] = id_field
new_class.id = new_class._fields['id']
new_class._meta['id_field'] = 'id' new_class._meta['id_field'] = 'id'
new_class._db_field_map['id'] = id_field.db_field new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class._fields['id'].name = 'id'
new_class.id = new_class._fields['id']
# Prepend id field to _fields_ordered
if 'id' in new_class._fields and 'id' not in new_class._fields_ordered:
new_class._fields_ordered = ('id', ) + new_class._fields_ordered
# Merge in exceptions with parent hierarchy # Merge in exceptions with parent hierarchy
exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned)

View file

@ -1,190 +0,0 @@
from mongoengine.queryset import OperationError
from bson.dbref import DBRef
class LocalProxy(object):
# From werkzeug/local.py
""" Forwards all operations to
a proxied object. The only operations not supported for forwarding
are right handed operands and any kind of assignment.
"""
__slots__ = ('__local', '__dict__', '__name__')
def __init__(self, local, name=None):
object.__setattr__(self, '_LocalProxy__local', local)
object.__setattr__(self, '__name__', name)
def _get_current_object(self):
"""Return the current object. This is useful if you want the real
object behind the proxy at a time for performance reasons or because
you want to pass the object into a different context.
"""
if not hasattr(self.__local, '__release_local__'):
return self.__local()
try:
return getattr(self.__local, self.__name__)
except AttributeError:
raise RuntimeError('no object bound to %s' % self.__name__)
@property
def __dict__(self):
try:
return self._get_current_object().__dict__
except RuntimeError:
raise AttributeError('__dict__')
def __repr__(self):
try:
obj = self._get_current_object()
except RuntimeError:
return '<%s unbound>' % self.__class__.__name__
return repr(obj)
def __nonzero__(self):
try:
return bool(self._get_current_object())
except RuntimeError:
return False
def __unicode__(self):
try:
return unicode(self._get_current_object())
except RuntimeError:
return repr(self)
def __dir__(self):
try:
return dir(self._get_current_object())
except RuntimeError:
return []
def __getattr__(self, name):
if name == '__members__':
return dir(self._get_current_object())
return getattr(self._get_current_object(), name)
def __setitem__(self, key, value):
self._get_current_object()[key] = value
def __delitem__(self, key):
del self._get_current_object()[key]
def __setslice__(self, i, j, seq):
self._get_current_object()[i:j] = seq
def __delslice__(self, i, j):
del self._get_current_object()[i:j]
__setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v)
__delattr__ = lambda x, n: delattr(x._get_current_object(), n)
__str__ = lambda x: str(x._get_current_object())
__lt__ = lambda x, o: x._get_current_object() < o
__le__ = lambda x, o: x._get_current_object() <= o
__eq__ = lambda x, o: x._get_current_object() == o
__ne__ = lambda x, o: x._get_current_object() != o
__gt__ = lambda x, o: x._get_current_object() > o
__ge__ = lambda x, o: x._get_current_object() >= o
__cmp__ = lambda x, o: cmp(x._get_current_object(), o)
__hash__ = lambda x: hash(x._get_current_object())
__call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw)
__len__ = lambda x: len(x._get_current_object())
__getitem__ = lambda x, i: x._get_current_object()[i]
__iter__ = lambda x: iter(x._get_current_object())
__contains__ = lambda x, i: i in x._get_current_object()
__getslice__ = lambda x, i, j: x._get_current_object()[i:j]
__add__ = lambda x, o: x._get_current_object() + o
__sub__ = lambda x, o: x._get_current_object() - o
__mul__ = lambda x, o: x._get_current_object() * o
__floordiv__ = lambda x, o: x._get_current_object() // o
__mod__ = lambda x, o: x._get_current_object() % o
__divmod__ = lambda x, o: x._get_current_object().__divmod__(o)
__pow__ = lambda x, o: x._get_current_object() ** o
__lshift__ = lambda x, o: x._get_current_object() << o
__rshift__ = lambda x, o: x._get_current_object() >> o
__and__ = lambda x, o: x._get_current_object() & o
__xor__ = lambda x, o: x._get_current_object() ^ o
__or__ = lambda x, o: x._get_current_object() | o
__div__ = lambda x, o: x._get_current_object().__div__(o)
__truediv__ = lambda x, o: x._get_current_object().__truediv__(o)
__neg__ = lambda x: -(x._get_current_object())
__pos__ = lambda x: +(x._get_current_object())
__abs__ = lambda x: abs(x._get_current_object())
__invert__ = lambda x: ~(x._get_current_object())
__complex__ = lambda x: complex(x._get_current_object())
__int__ = lambda x: int(x._get_current_object())
__long__ = lambda x: long(x._get_current_object())
__float__ = lambda x: float(x._get_current_object())
__oct__ = lambda x: oct(x._get_current_object())
__hex__ = lambda x: hex(x._get_current_object())
__index__ = lambda x: x._get_current_object().__index__()
__coerce__ = lambda x, o: x.__coerce__(x, o)
__enter__ = lambda x: x.__enter__()
__exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw)
class DocumentProxy(LocalProxy):
__slots__ = ('__document_type', '__document', '__pk')
def __init__(self, document_type, pk):
object.__setattr__(self, '_DocumentProxy__document_type', document_type)
object.__setattr__(self, '_DocumentProxy__document', None)
object.__setattr__(self, '_DocumentProxy__pk', pk)
object.__setattr__(self, document_type._meta['id_field'], self.pk)
@property
def __class__(self):
# We need to fetch the object to determine to which class it belongs.
return self._get_current_object().__class__
def _lazy():
def fget(self):
return self.__document._lazy if self.__document else True
def fset(self, value):
self._get_current_object()._lazy = value
return property(fget, fset)
_lazy = _lazy()
# copy normally updates __dict__ which would result in errors
def __setstate__(self, state):
for k, v in state[1].iteritems():
object.__setattr__(self, k, v)
def _get_collection_name(self):
return self.__document_type._meta.get('collection', None)
def __eq__(self, other):
if other and hasattr(other, '_get_collection_name') and other._get_collection_name() == self._get_collection_name() and hasattr(other, 'pk'):
if self.pk == other.pk:
return True
return False
def __ne__(self, other):
return not self.__eq__(other)
def to_dbref(self):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries."""
if not self.pk:
msg = "Only saved documents can have a valid dbref"
raise OperationError(msg)
return DBRef(self._get_collection_name(), self.pk)
def pk():
def fget(self):
return self.__document.pk if self.__document else self.__pk
def fset(self, value):
self._get_current_object().pk = value
return property(fget, fset)
pk = pk()
def _get_current_object(self):
if self.__document == None:
collection = self.__document_type._get_collection()
son = collection.find_one({'_id': self.__pk})
document = self.__document_type._from_son(son)
object.__setattr__(self, '_DocumentProxy__document', document)
return self.__document
def __nonzero__(self):
return True

View file

@ -23,8 +23,9 @@ def _import_class(cls_name):
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField', 'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField', 'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'PolygonField', 'PointField', 'LineStringField', 'ListField',
'ReferenceField', 'StringField', 'ComplexBaseField') 'PolygonField', 'ReferenceField', 'StringField',
'ComplexBaseField')
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)

View file

@ -55,12 +55,9 @@ def register_connection(alias, name, host='localhost', port=27017,
# Handle uri style connections # Handle uri style connections
if "://" in host: if "://" in host:
uri_dict = uri_parser.parse_uri(host) uri_dict = uri_parser.parse_uri(host)
if uri_dict.get('database') is None:
raise ConnectionError("If using URI style connection include "\
"database name in string")
conn_settings.update({ conn_settings.update({
'host': host, 'host': host,
'name': uri_dict.get('database'), 'name': uri_dict.get('database') or name,
'username': uri_dict.get('username'), 'username': uri_dict.get('username'),
'password': uri_dict.get('password'), 'password': uri_dict.get('password'),
'read_preference': read_preference, 'read_preference': read_preference,

View file

@ -4,7 +4,7 @@ from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document)
from fields import (ReferenceField, ListField, DictField, MapField) from fields import (ReferenceField, ListField, DictField, MapField)
from connection import get_db from connection import get_db
from queryset import QuerySet from queryset import QuerySet
from document import Document from document import Document, EmbeddedDocument
class DeReference(object): class DeReference(object):
@ -33,7 +33,8 @@ class DeReference(object):
self.max_depth = max_depth self.max_depth = max_depth
doc_type = None doc_type = None
if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)): if instance and isinstance(instance, (Document, EmbeddedDocument,
TopLevelDocumentMetaclass)):
doc_type = instance._fields.get(name) doc_type = instance._fields.get(name)
if hasattr(doc_type, 'field'): if hasattr(doc_type, 'field'):
doc_type = doc_type.field doc_type = doc_type.field
@ -86,11 +87,9 @@ class DeReference(object):
for k, item in iterator: for k, item in iterator:
if isinstance(item, Document): if isinstance(item, Document):
for field_name, field in item._fields.iteritems(): for field_name, field in item._fields.iteritems():
v = getattr(item, field_name) v = item._data.get(field_name, None)
if isinstance(v, (DBRef)): if isinstance(v, (DBRef)):
reference_map.setdefault(field.document_type, []).append(v.id) reference_map.setdefault(field.document_type, []).append(v.id)
elif isinstance(v, Document) and getattr(v, '_lazy', False):
reference_map.setdefault(field.document_type, []).append(v.pk)
elif isinstance(v, (dict, SON)) and '_ref' in v: elif isinstance(v, (dict, SON)) and '_ref' in v:
reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
@ -171,7 +170,7 @@ class DeReference(object):
return self.object_map.get(items['_ref'].id, items) return self.object_map.get(items['_ref'].id, items)
elif '_cls' in items: elif '_cls' in items:
doc = get_document(items['_cls'])._from_son(items) doc = get_document(items['_cls'])._from_son(items)
doc._internal_data = self._attach_objects(doc._internal_data, depth, doc, None) doc._data = self._attach_objects(doc._data, depth, doc, None)
return doc return doc
if not hasattr(items, 'items'): if not hasattr(items, 'items'):
@ -195,17 +194,15 @@ class DeReference(object):
data[k] = self.object_map[k] data[k] = self.object_map[k]
elif isinstance(v, Document): elif isinstance(v, Document):
for field_name, field in v._fields.iteritems(): for field_name, field in v._fields.iteritems():
v = data[k]._internal_data.get(field_name, None) v = data[k]._data.get(field_name, None)
if isinstance(v, (DBRef)): if isinstance(v, (DBRef)):
data[k]._internal_data[field_name] = self.object_map.get(v.id, v) data[k]._data[field_name] = self.object_map.get(v.id, v)
elif isinstance(v, Document) and getattr(v, '_lazy', False):
data[k]._internal_data[field_name] = self.object_map.get(v.pk, v)
elif isinstance(v, (dict, SON)) and '_ref' in v: elif isinstance(v, (dict, SON)) and '_ref' in v:
data[k]._internal_data[field_name] = self.object_map.get(v['_ref'].id, v) data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v)
elif isinstance(v, dict) and depth <= self.max_depth: elif isinstance(v, dict) and depth <= self.max_depth:
data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name)
elif isinstance(v, (list, tuple)) and depth <= self.max_depth: elif isinstance(v, (list, tuple)) and depth <= self.max_depth:
data[k]._internal_data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name)
elif hasattr(v, 'id'): elif hasattr(v, 'id'):

View file

@ -6,10 +6,29 @@ from django.utils.importlib import import_module
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
__all__ = (
'get_user_document',
)
MONGOENGINE_USER_DOCUMENT = getattr( MONGOENGINE_USER_DOCUMENT = getattr(
settings, 'MONGOENGINE_USER_DOCUMENT', 'mongoengine.django.auth.User') settings, 'MONGOENGINE_USER_DOCUMENT', 'mongoengine.django.auth.User')
def get_user_document():
"""Get the user document class used for authentication.
This is the class defined in settings.MONGOENGINE_USER_DOCUMENT, which
defaults to `mongoengine.django.auth.User`.
"""
name = MONGOENGINE_USER_DOCUMENT
dot = name.rindex('.')
module = import_module(name[:dot])
return getattr(module, name[dot + 1:])
class MongoUserManager(UserManager): class MongoUserManager(UserManager):
"""A User manager wich allows the use of MongoEngine documents in Django. """A User manager wich allows the use of MongoEngine documents in Django.
@ -44,7 +63,7 @@ class MongoUserManager(UserManager):
def contribute_to_class(self, model, name): def contribute_to_class(self, model, name):
super(MongoUserManager, self).contribute_to_class(model, name) super(MongoUserManager, self).contribute_to_class(model, name)
self.dj_model = self.model self.dj_model = self.model
self.model = self._get_user_document() self.model = get_user_document()
self.dj_model.USERNAME_FIELD = self.model.USERNAME_FIELD self.dj_model.USERNAME_FIELD = self.model.USERNAME_FIELD
username = models.CharField(_('username'), max_length=30, unique=True) username = models.CharField(_('username'), max_length=30, unique=True)
@ -55,16 +74,6 @@ class MongoUserManager(UserManager):
field = models.CharField(_(name), max_length=30) field = models.CharField(_(name), max_length=30)
field.contribute_to_class(self.dj_model, name) field.contribute_to_class(self.dj_model, name)
def _get_user_document(self):
try:
name = MONGOENGINE_USER_DOCUMENT
dot = name.rindex('.')
module = import_module(name[:dot])
return getattr(module, name[dot + 1:])
except ImportError:
raise ImproperlyConfigured("Error importing %s, please check "
"settings.MONGOENGINE_USER_DOCUMENT"
% name)
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
try: try:
@ -85,5 +94,14 @@ class MongoUserManager(UserManager):
class MongoUser(models.Model): class MongoUser(models.Model):
objects = MongoUserManager() """"Dummy user model for Django.
MongoUser is used to replace Django's UserManager with MongoUserManager.
The actual user document class is mongoengine.django.auth.User or any
other document class specified in MONGOENGINE_USER_DOCUMENT.
To get the user document class, use `get_user_document()`.
"""
objects = MongoUserManager()

View file

@ -1,7 +1,10 @@
from django.conf import settings from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase, CreateError from django.contrib.sessions.backends.base import SessionBase, CreateError
from django.core.exceptions import SuspiciousOperation from django.core.exceptions import SuspiciousOperation
from django.utils.encoding import force_unicode try:
from django.utils.encoding import force_unicode
except ImportError:
from django.utils.encoding import force_text as force_unicode
from mongoengine.document import Document from mongoengine.document import Document
from mongoengine import fields from mongoengine import fields

View file

@ -76,7 +76,7 @@ class GridFSStorage(Storage):
"""Find the documents in the store with the given name """Find the documents in the store with the given name
""" """
docs = self.document.objects docs = self.document.objects
doc = [d for d in docs if getattr(d, self.field).name == name] doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name]
if doc: if doc:
return doc[0] return doc[0]
else: else:

View file

@ -12,7 +12,7 @@ from mongoengine.common import _import_class
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList, BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document) ALLOW_INHERITANCE, get_document)
from mongoengine.queryset import OperationError, NotUniqueError, QuerySet, DoesNotExist from mongoengine.queryset import OperationError, NotUniqueError, QuerySet
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import switch_db, switch_collection from mongoengine.context_managers import switch_db, switch_collection
@ -20,7 +20,6 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
'DynamicEmbeddedDocument', 'OperationError', 'DynamicEmbeddedDocument', 'OperationError',
'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument')
_set = object.__setattr__
def includes_cls(fields): def includes_cls(fields):
""" Helper function used for ensuring and comparing indexes """ Helper function used for ensuring and comparing indexes
@ -63,11 +62,11 @@ class EmbeddedDocument(BaseDocument):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs) super(EmbeddedDocument, self).__init__(*args, **kwargs)
self._changed_fields = set() self._changed_fields = []
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return self.to_dict() == other.to_dict() return self._data == other._data
return False return False
def __ne__(self, other): def __ne__(self, other):
@ -178,13 +177,15 @@ class Document(BaseDocument):
cls.ensure_indexes() cls.ensure_indexes()
return cls._collection return cls._collection
def save(self, validate=True, clean=True, def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None, write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, full=False, **kwargs): _refs=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the """Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
:param force_insert: only try to create a new document, don't allow
updates of existing documents
:param validate: validates the document; set to ``False`` to skip. :param validate: validates the document; set to ``False`` to skip.
:param clean: call the document clean method, requires `validate` to be :param clean: call the document clean method, requires `validate` to be
True. True.
@ -201,7 +202,6 @@ class Document(BaseDocument):
:param cascade_kwargs: (optional) kwargs dictionary to be passed throw :param cascade_kwargs: (optional) kwargs dictionary to be passed throw
to cascading saves. Implies ``cascade=True``. to cascading saves. Implies ``cascade=True``.
:param _refs: A list of processed references used in cascading saves :param _refs: A list of processed references used in cascading saves
:param full: Save all model fields instead of just changed ones.
.. versionchanged:: 0.5 .. versionchanged:: 0.5
In existing documents it only saves changed fields using In existing documents it only saves changed fields using
@ -217,52 +217,61 @@ class Document(BaseDocument):
the cascade save using cascade_kwargs which overwrites the the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values. existing kwargs with custom values.
""" """
signals.pre_save.send(self.__class__, document=self) signals.pre_save.send(self.__class__, document=self)
if validate: if validate:
self.validate(clean=clean) self.validate(clean=clean)
if not write_concern: if write_concern is None:
write_concern = {'w': 1} write_concern = {"w": 1}
doc = self.to_mongo()
created = ('_id' not in doc or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
collection = self._get_collection()
try: try:
if self._created: collection = self._get_collection()
# Update: Get delta. if created:
sets, unsets = self._delta(full) if force_insert:
db_id_field = self._fields[self._meta['id_field']].db_field object_id = collection.insert(doc, **write_concern)
sets.pop(db_id_field, None) else:
object_id = collection.save(doc, **write_concern)
else:
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
select_dict = {'_id': object_id}
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
select_dict[actual_key] = doc[actual_key]
def is_new_object(last_error):
if last_error is not None:
updated = last_error.get("updatedExisting")
if updated is not None:
return not updated
return created
update_query = {} update_query = {}
if sets:
update_query['$set'] = sets
if unsets:
update_query['$unset'] = unsets
if update_query: if updates:
collection.update(self._db_object_key, update_query, **write_concern) update_query["$set"] = updates
if removals:
update_query["$unset"] = removals
if updates or removals:
last_error = collection.update(select_dict, update_query,
upsert=True, **write_concern)
created = is_new_object(last_error)
created = False if cascade is None:
else: cascade = self._meta.get('cascade', False) or cascade_kwargs is not None
# Insert: Get full SON.
doc = self.to_mongo()
object_id = collection.insert(doc, **write_concern)
# Fix pymongo's "return return_one and ids[0] or ids":
# If the ID is 0, pymongo wraps it in a list.
if isinstance(object_id, list) and not object_id[0]:
object_id = object_id[0]
id_field = self._meta['id_field']
del self._internal_data[id_field]
_set(self, '_db_data', doc)
doc['_id'] = object_id
created = True
cascade = (self._meta.get('cascade', False)
if cascade is None else cascade)
if cascade: if cascade:
kwargs = { kwargs = {
"force_insert": force_insert,
"validate": validate, "validate": validate,
"write_concern": write_concern, "write_concern": write_concern,
"cascade": cascade "cascade": cascade
@ -280,9 +289,12 @@ class Document(BaseDocument):
message = u'Tried to save duplicate unique keys (%s)' message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % unicode(err)) raise NotUniqueError(message % unicode(err))
raise OperationError(message % unicode(err)) raise OperationError(message % unicode(err))
id_field = self._meta['id_field']
if id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
self._clear_changed_fields() self._clear_changed_fields()
self._created = False
signals.post_save.send(self.__class__, document=self, created=created) signals.post_save.send(self.__class__, document=self, created=created)
return self return self
@ -299,17 +311,14 @@ class Document(BaseDocument):
GenericReferenceField)): GenericReferenceField)):
continue continue
ref = getattr(self, name) ref = self._data.get(name)
if not ref or isinstance(ref, DBRef): if not ref or isinstance(ref, DBRef):
continue continue
if not getattr(ref, '_changed_fields', True): if not getattr(ref, '_changed_fields', True):
continue continue
if getattr(ref, '_lazy', False): ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data))
continue
ref_id = "%s,%s" % (ref.__class__.__name__, str(ref.to_dict()))
if ref and ref_id not in _refs: if ref and ref_id not in _refs:
_refs.append(ref_id) _refs.append(ref_id)
kwargs["_refs"] = _refs kwargs["_refs"] = _refs
@ -335,16 +344,6 @@ class Document(BaseDocument):
select_dict[k] = getattr(self, k) select_dict[k] = getattr(self, k)
return select_dict return select_dict
@property
def _db_object_key(self):
field = self._fields[self._meta['id_field']]
select_dict = {field.db_field: field.to_mongo(self.pk)}
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
select_dict[actual_key] = self._fields[k].to_mongo(getattr(self, k))
return select_dict
def update(self, **kwargs): def update(self, **kwargs):
"""Performs an update on the :class:`~mongoengine.Document` """Performs an update on the :class:`~mongoengine.Document`
A convenience wrapper to :meth:`~mongoengine.QuerySet.update`. A convenience wrapper to :meth:`~mongoengine.QuerySet.update`.
@ -377,9 +376,6 @@ class Document(BaseDocument):
""" """
signals.pre_delete.send(self.__class__, document=self) signals.pre_delete.send(self.__class__, document=self)
if not write_concern:
write_concern = {'w': 1}
try: try:
self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
@ -404,11 +400,11 @@ class Document(BaseDocument):
""" """
with switch_db(self.__class__, db_alias) as cls: with switch_db(self.__class__, db_alias) as cls:
collection = cls._get_collection() collection = cls._get_collection()
db = cls._get_db db = cls._get_db()
self._get_collection = lambda: collection self._get_collection = lambda: collection
self._get_db = lambda: db self._get_db = lambda: db
self._collection = collection self._collection = collection
#self._created = True self._created = True
self.__objects = self._qs self.__objects = self._qs
self.__objects._collection_obj = collection self.__objects._collection_obj = collection
return self return self
@ -433,7 +429,7 @@ class Document(BaseDocument):
collection = cls._get_collection() collection = cls._get_collection()
self._get_collection = lambda: collection self._get_collection = lambda: collection
self._collection = collection self._collection = collection
#self._created = True self._created = True
self.__objects = self._qs self.__objects = self._qs
self.__objects._collection_obj = collection self.__objects._collection_obj = collection
return self return self
@ -444,22 +440,44 @@ class Document(BaseDocument):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
import dereference DeReference = _import_class('DeReference')
self._internal_data = dereference.DeReference()(self._internal_data, max_depth) DeReference()([self], max_depth + 1)
return self return self
def reload(self): def reload(self, max_depth=1):
"""Reloads all attributes from the database. """Reloads all attributes from the database.
.. versionadded:: 0.1.2
.. versionchanged:: 0.6 Now chainable
""" """
collection = self._get_collection() obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
son = collection.find_one(self._db_object_key, read_preference=ReadPreference.PRIMARY) **self._object_key).limit(1).select_related(max_depth=max_depth)
if son == None:
raise DoesNotExist('Document has been deleted.') if obj:
_set(self, '_db_data', son) obj = obj[0]
_set(self, '_internal_data', {}) else:
_set(self, '_lazy', False) msg = "Reloaded document has been deleted"
self._clear_changed_fields() raise OperationError(msg)
return self for field in self._fields_ordered:
setattr(self, field, self._reload(field, obj[field]))
self._changed_fields = obj._changed_fields
self._created = False
return obj
def _reload(self, key, value):
"""Used by :meth:`~mongoengine.Document.reload` to ensure the
correct instance is linked to self.
"""
if isinstance(value, BaseDict):
value = [(k, self._reload(k, v)) for k, v in value.items()]
value = BaseDict(value, self, key)
elif isinstance(value, BaseList):
value = [self._reload(key, v) for v in value]
value = BaseList(value, self, key)
elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)):
value._instance = None
value._changed_fields = []
return value
def to_dbref(self): def to_dbref(self):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in """Returns an instance of :class:`~bson.dbref.DBRef` useful in
@ -518,6 +536,8 @@ class Document(BaseDocument):
def ensure_indexes(cls): def ensure_indexes(cls):
"""Checks the document meta data and ensures all the indexes exist. """Checks the document meta data and ensures all the indexes exist.
Global defaults can be set in the meta - see :doc:`guide/defining-documents`
.. note:: You can disable automatic index creation by setting .. note:: You can disable automatic index creation by setting
`auto_create_index` to False in the documents meta data `auto_create_index` to False in the documents meta data
""" """
@ -659,8 +679,6 @@ class DynamicDocument(Document):
_dynamic = True _dynamic = True
# TODO
def __delattr__(self, *args, **kwargs): def __delattr__(self, *args, **kwargs):
"""Deletes the attribute by setting to None and allowing _delta to unset """Deletes the attribute by setting to None and allowing _delta to unset
it""" it"""
@ -684,8 +702,6 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
_dynamic = True _dynamic = True
# TODO
def __delattr__(self, *args, **kwargs): def __delattr__(self, *args, **kwargs):
"""Deletes the attribute by setting to None and allowing _delta to unset """Deletes the attribute by setting to None and allowing _delta to unset
it""" it"""

View file

@ -3,7 +3,6 @@ import decimal
import itertools import itertools
import re import re
import time import time
import types
import urllib2 import urllib2
import uuid import uuid
import warnings import warnings
@ -23,11 +22,8 @@ from bson import Binary, DBRef, SON, ObjectId
from mongoengine.errors import ValidationError from mongoengine.errors import ValidationError
from mongoengine.python_support import (PY3, bin_type, txt_type, from mongoengine.python_support import (PY3, bin_type, txt_type,
str_types, StringIO) str_types, StringIO)
from mongoengine.base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField,
get_document, BaseDocument) get_document, BaseDocument)
from mongoengine.base.datastructures import BaseList, BaseDict
from mongoengine.base.proxy import DocumentProxy
from mongoengine.queryset import DoesNotExist
from queryset import DO_NOTHING, QuerySet from queryset import DO_NOTHING, QuerySet
from document import Document, EmbeddedDocument from document import Document, EmbeddedDocument
from connection import get_db, DEFAULT_CONNECTION_NAME from connection import get_db, DEFAULT_CONNECTION_NAME
@ -38,12 +34,11 @@ except ImportError:
Image = None Image = None
ImageOps = None ImageOps = None
__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', __all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'FloatField', 'BooleanField', 'DateTimeField', 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField',
'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField',
'SortedListField', 'DictField', 'MapField', 'ReferenceField', 'SortedListField', 'DictField', 'MapField', 'ReferenceField',
'SafeReferenceField', 'SafeReferenceListField',
'GenericReferenceField', 'BinaryField', 'GridFSError', 'GenericReferenceField', 'BinaryField', 'GridFSError',
'GridFSProxy', 'FileField', 'ImageGridFsProxy', 'GridFSProxy', 'FileField', 'ImageGridFsProxy',
'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField',
@ -63,6 +58,15 @@ class StringField(BaseField):
self.min_length = min_length self.min_length = min_length
super(StringField, self).__init__(**kwargs) super(StringField, self).__init__(**kwargs)
def to_python(self, value):
if isinstance(value, unicode):
return value
try:
value = value.decode('utf-8')
except:
pass
return value
def validate(self, value): def validate(self, value):
if not isinstance(value, basestring): if not isinstance(value, basestring):
self.error('StringField only accepts string values') self.error('StringField only accepts string values')
@ -117,7 +121,8 @@ class URLField(StringField):
r'(?::\d+)?' # optional port r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE) r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def __init__(self, url_regex=None, **kwargs): def __init__(self, verify_exists=False, url_regex=None, **kwargs):
self.verify_exists = verify_exists
self.url_regex = url_regex or self._URL_REGEX self.url_regex = url_regex or self._URL_REGEX
super(URLField, self).__init__(**kwargs) super(URLField, self).__init__(**kwargs)
@ -126,31 +131,50 @@ class URLField(StringField):
self.error('Invalid URL: %s' % value) self.error('Invalid URL: %s' % value)
return return
if self.verify_exists:
warnings.warn(
"The URLField verify_exists argument has intractable security "
"and performance issues. Accordingly, it has been deprecated.",
DeprecationWarning)
try:
request = urllib2.Request(value)
urllib2.urlopen(request)
except Exception, e:
self.error('This URL appears to be a broken link: %s' % e)
class EmailField(StringField): class EmailField(StringField):
"""A field that validates input as an email address. """A field that validates input as an E-Mail-Address.
.. versionadded:: 0.4 .. versionadded:: 0.4
""" """
EMAIL_REGEX = re.compile(r'^.+@[^.].*\.[a-z]{2,10}$', re.IGNORECASE) EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
)
def validate(self, value): def validate(self, value):
if not EmailField.EMAIL_REGEX.match(value): if not EmailField.EMAIL_REGEX.match(value):
self.error('Invalid email address: %s' % value) self.error('Invalid Mail-address: %s' % value)
super(EmailField, self).validate(value) super(EmailField, self).validate(value)
class IntField(BaseField): class IntField(BaseField):
"""An integer field. """An 32-bit integer field.
""" """
def __init__(self, min_value=None, max_value=None, **kwargs): def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value self.min_value, self.max_value = min_value, max_value
super(IntField, self).__init__(**kwargs) super(IntField, self).__init__(**kwargs)
def from_python(self, value): def to_python(self, value):
return self.prepare_query_value(None, value) try:
value = int(value)
except ValueError:
pass
return value
def validate(self, value): def validate(self, value):
try: try:
@ -167,18 +191,59 @@ class IntField(BaseField):
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if value is None: if value is None:
return value return value
else:
return int(value) return int(value)
class LongField(BaseField):
"""An 64-bit integer field.
"""
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
super(LongField, self).__init__(**kwargs)
def to_python(self, value):
try:
value = long(value)
except ValueError:
pass
return value
def validate(self, value):
try:
value = long(value)
except:
self.error('%s could not be converted to long' % value)
if self.min_value is not None and value < self.min_value:
self.error('Long value is too small')
if self.max_value is not None and value > self.max_value:
self.error('Long value is too large')
def prepare_query_value(self, op, value):
if value is None:
return value
return long(value)
class FloatField(BaseField): class FloatField(BaseField):
"""A floating point number field. """An floating point number field.
""" """
def __init__(self, min_value=None, max_value=None, **kwargs): def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value self.min_value, self.max_value = min_value, max_value
super(FloatField, self).__init__(**kwargs) super(FloatField, self).__init__(**kwargs)
def to_python(self, value):
try:
value = float(value)
except ValueError:
pass
return value
def validate(self, value): def validate(self, value):
if isinstance(value, int): if isinstance(value, int):
value = float(value) value = float(value)
@ -191,6 +256,82 @@ class FloatField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
self.error('Float value is too large') self.error('Float value is too large')
def prepare_query_value(self, op, value):
if value is None:
return value
return float(value)
class DecimalField(BaseField):
"""A fixed-point decimal number field.
.. versionchanged:: 0.8
.. versionadded:: 0.3
"""
def __init__(self, min_value=None, max_value=None, force_string=False,
precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs):
"""
:param min_value: Validation rule for the minimum acceptable value.
:param max_value: Validation rule for the maximum acceptable value.
:param force_string: Store as a string.
:param precision: Number of decimal places to store.
:param rounding: The rounding rule from the python decimal libary:
- decimal.ROUND_CEILING (towards Infinity)
- decimal.ROUND_DOWN (towards zero)
- decimal.ROUND_FLOOR (towards -Infinity)
- decimal.ROUND_HALF_DOWN (to nearest with ties going towards zero)
- decimal.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer)
- decimal.ROUND_HALF_UP (to nearest with ties going away from zero)
- decimal.ROUND_UP (away from zero)
- decimal.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero)
Defaults to: ``decimal.ROUND_HALF_UP``
"""
self.min_value = min_value
self.max_value = max_value
self.force_string = force_string
self.precision = decimal.Decimal(".%s" % ("0" * precision))
self.rounding = rounding
super(DecimalField, self).__init__(**kwargs)
def to_python(self, value):
if value is None:
return value
# Convert to string for python 2.6 before casting to Decimal
value = decimal.Decimal("%s" % value)
return value.quantize(self.precision, rounding=self.rounding)
def to_mongo(self, value):
if value is None:
return value
if self.force_string:
return unicode(value)
return float(self.to_python(value))
def validate(self, value):
if not isinstance(value, decimal.Decimal):
if not isinstance(value, basestring):
value = unicode(value)
try:
value = decimal.Decimal(value)
except Exception, exc:
self.error('Could not convert value to decimal: %s' % exc)
if self.min_value is not None and value < self.min_value:
self.error('Decimal value is too small')
if self.max_value is not None and value > self.max_value:
self.error('Decimal value is too large')
def prepare_query_value(self, op, value):
return self.to_mongo(value)
class BooleanField(BaseField): class BooleanField(BaseField):
"""A boolean field type. """A boolean field type.
@ -198,6 +339,13 @@ class BooleanField(BaseField):
.. versionadded:: 0.1.2 .. versionadded:: 0.1.2
""" """
def to_python(self, value):
try:
value = bool(value)
except ValueError:
pass
return value
def validate(self, value): def validate(self, value):
if not isinstance(value, bool): if not isinstance(value, bool):
self.error('BooleanField only accepts boolean values') self.error('BooleanField only accepts boolean values')
@ -218,13 +366,11 @@ class DateTimeField(BaseField):
""" """
def validate(self, value): def validate(self, value):
if not isinstance(value, (datetime.datetime, datetime.date)): new_value = self.to_mongo(value)
if not isinstance(new_value, (datetime.datetime, datetime.date)):
self.error(u'cannot parse date "%s"' % value) self.error(u'cannot parse date "%s"' % value)
def from_python(self, value): def to_mongo(self, value):
return self.prepare_query_value(None, value) or value
def prepare_query_value(self, op, value):
if value is None: if value is None:
return value return value
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
@ -268,6 +414,9 @@ class DateTimeField(BaseField):
except ValueError: except ValueError:
return None return None
def prepare_query_value(self, op, value):
return self.to_mongo(value)
class ComplexDateTimeField(StringField): class ComplexDateTimeField(StringField):
""" """
@ -288,8 +437,6 @@ class ComplexDateTimeField(StringField):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
# TODO
def __init__(self, separator=',', **kwargs): def __init__(self, separator=',', **kwargs):
self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', self.names = ['year', 'month', 'day', 'hour', 'minute', 'second',
'microsecond'] 'microsecond']
@ -395,11 +542,15 @@ class EmbeddedDocumentField(BaseField):
self.document_type_obj = get_document(self.document_type_obj) self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj return self.document_type_obj
def to_python(self, val): def to_python(self, value):
return self.document_type._from_son(val) if not isinstance(value, self.document_type):
return self.document_type._from_son(value)
return value
def to_mongo(self, val): def to_mongo(self, value):
return val and val.to_mongo() if not isinstance(value, self.document_type):
return value
return self.document_type.to_mongo(value)
def validate(self, value, clean=True): def validate(self, value, clean=True):
"""Make sure that the document instance is an instance of the """Make sure that the document instance is an instance of the
@ -433,8 +584,9 @@ class GenericEmbeddedDocumentField(BaseField):
return self.to_mongo(value) return self.to_mongo(value)
def to_python(self, value): def to_python(self, value):
doc_cls = get_document(value['_cls']) if isinstance(value, dict):
value = doc_cls._from_son(value) doc_cls = get_document(value['_cls'])
value = doc_cls._from_son(value)
return value return value
@ -472,7 +624,9 @@ class DynamicField(BaseField):
cls = value.__class__ cls = value.__class__
val = value.to_mongo() val = value.to_mongo()
# If we its a document thats not inherited add _cls # If we its a document thats not inherited add _cls
if (isinstance(value, (Document, EmbeddedDocument))): if (isinstance(value, Document)):
val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
if (isinstance(value, EmbeddedDocument)):
val['_cls'] = cls.__name__ val['_cls'] = cls.__name__
return val return val
@ -493,6 +647,15 @@ class DynamicField(BaseField):
value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))]
return value return value
def to_python(self, value):
if isinstance(value, dict) and '_cls' in value:
doc_cls = get_document(value['_cls'])
if '_ref' in value:
value = doc_cls._get_db().dereference(value['_ref'])
return doc_cls._from_son(value)
return super(DynamicField, self).to_python(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return member_name return member_name
@ -522,26 +685,6 @@ class ListField(ComplexBaseField):
kwargs.setdefault('default', lambda: []) kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs) super(ListField, self).__init__(**kwargs)
def value_for_instance(self, value, instance, name=None):
name = name or self.name
if value and self.field:
value_for_instance = getattr(self.field, 'value_for_instance', None)
if value_for_instance:
value = [value_for_instance(v, instance, name) for v in value]
return BaseList(value or [], instance, name)
def from_python(self, val):
from_python = getattr(self.field, 'from_python', None)
return [from_python(v) for v in val] if from_python else val
def to_python(self, val):
to_python = getattr(self.field, 'to_python', None)
return [to_python(v) for v in val] if to_python and val else val or None
def to_mongo(self, val):
to_mongo = getattr(self.field, 'to_mongo', None)
return [to_mongo(v) for v in val] if to_mongo and val else val or None
def validate(self, value): def validate(self, value):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
@ -557,9 +700,6 @@ class ListField(ComplexBaseField):
and hasattr(value, '__iter__')): and hasattr(value, '__iter__')):
return [self.field.prepare_query_value(op, v) for v in value] return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value) return self.field.prepare_query_value(op, value)
else:
if op in ('set', 'unset'):
return value
return super(ListField, self).prepare_query_value(op, value) return super(ListField, self).prepare_query_value(op, value)
@ -590,13 +730,10 @@ class SortedListField(ListField):
def to_mongo(self, value): def to_mongo(self, value):
value = super(SortedListField, self).to_mongo(value) value = super(SortedListField, self).to_mongo(value)
if value: if self._ordering is not None:
if self._ordering is not None: return sorted(value, key=itemgetter(self._ordering),
return sorted(value, key=itemgetter(self._ordering), reverse=self._order_reverse)
reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse)
return sorted(value, reverse=self._order_reverse)
else:
return value
class DictField(ComplexBaseField): class DictField(ComplexBaseField):
@ -618,26 +755,6 @@ class DictField(ComplexBaseField):
kwargs.setdefault('default', lambda: {}) kwargs.setdefault('default', lambda: {})
super(DictField, self).__init__(*args, **kwargs) super(DictField, self).__init__(*args, **kwargs)
def from_python(self, val):
from_python = getattr(self.field, 'from_python', None)
return {k: from_python(v) for k, v in val.iteritems()} if from_python else val
def to_python(self, val):
to_python = getattr(self.field, 'to_python', None)
return {k: to_python(v) for k, v in val.iteritems()} if to_python and val else val or None
def value_for_instance(self, value, instance, name=None):
name = name or self.name
if value and self.field:
value_for_instance = getattr(self.field, 'value_for_instance', None)
if value_for_instance:
value = {k: value_for_instance(v, instance, name) for k, v in value.iteritems()}
return BaseDict(value or {}, instance, name)
def to_mongo(self, val):
to_mongo = getattr(self.field, 'to_mongo', None)
return {k: to_mongo(v) for k, v in val.iteritems()} if to_mongo and val else val or None
def validate(self, value): def validate(self, value):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
@ -663,6 +780,10 @@ class DictField(ComplexBaseField):
if op in match_operators and isinstance(value, basestring): if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value) return StringField().prepare_query_value(op, value)
if hasattr(self.field, 'field'):
return self.field.prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value) return super(DictField, self).prepare_query_value(op, value)
@ -745,82 +866,69 @@ class ReferenceField(BaseField):
self.document_type_obj = get_document(self.document_type_obj) self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj return self.document_type_obj
def to_mongo(self, value): def __get__(self, instance, owner):
if isinstance(value, DBRef): """Descriptor to allow lazy dereferencing.
if self.dbref: """
return value if instance is None:
else: # Document class being used rather than a document object
return value.id return self
elif isinstance(value, (Document, DocumentProxy)):
document_type = self.document_type # Get value from document instance if available
value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value)
if value is not None:
instance._data[self.name] = self.document_type._from_son(value)
return super(ReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
if isinstance(document, DBRef):
if not self.dbref:
return document.id
return document
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
pk = value.pk id_ = document.pk
if pk is None: if id_ is None:
self.error('You can only reference documents once they have' self.error('You can only reference documents once they have'
' been saved to the database') ' been saved to the database')
id_field_name = document_type._meta['id_field'] else:
id_field = document_type._fields[id_field_name] id_ = document
pk = id_field.to_mongo(pk)
if self.dbref: id_ = id_field.to_mongo(id_)
collection = document_type._get_collection_name() if self.dbref:
return DBRef(collection, pk) collection = self.document_type._get_collection_name()
else: return DBRef(collection, id_)
return pk
elif value != None: # string ID return id_
document_type = self.document_type
collection = document_type._get_collection_name()
return DBRef(collection, value)
def to_python(self, value): def to_python(self, value):
if value != None: """Convert a MongoDB-compatible type to a Python type.
document_type = self.document_type """
if self.dbref: if (not self.dbref and
pk = value.id not isinstance(value, (DBRef, Document, EmbeddedDocument))):
else: collection = self.document_type._get_collection_name()
if isinstance(value, DBRef): value = DBRef(collection, self.document_type.id.to_python(value))
pk = value.id return value
else:
pk = value
if document_type._meta['allow_inheritance']:
# We don't know of which type the object will be.
obj = DocumentProxy(document_type, pk)
else:
obj = document_type(pk=pk)
obj._lazy = True
return obj
def from_python(self, value):
if isinstance(value, (BaseDocument, DocumentProxy)):
return value
elif value == None:
return super(ReferenceField, self).from_python(value)
else:
# Support for werkzeug.local.LocalProxy
if hasattr(value, '_get_current_object'):
return value._get_current_object()
else:
# DBRef or ID
document_type = self.document_type
if isinstance(value, DBRef):
pk = value.id
else:
pk = value
if document_type._meta['allow_inheritance']:
# We don't know of which type the object will be.
obj = DocumentProxy(document_type, pk)
else:
obj = document_type(pk=pk)
obj._lazy = True
return obj
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(self.from_python(value)) if value is None:
return None
return self.to_mongo(value)
def validate(self, value): def validate(self, value):
if not isinstance(value, (self.document_type, DBRef, DocumentProxy)):
if not isinstance(value, (self.document_type, DBRef)):
self.error("A ReferenceField only accepts DBRef or documents") self.error("A ReferenceField only accepts DBRef or documents")
if isinstance(value, Document) and value.pk is None: if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been ' self.error('You can only reference documents once they have been '
'saved to the database') 'saved to the database')
@ -828,52 +936,6 @@ class ReferenceField(BaseField):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
class SafeReferenceField(ReferenceField):
"""
Like a ReferenceField, but doesn't return non-existing references when
dereferencing, i.e. no DBRefs are returned. This means that the next time
an object is saved, the non-existing references are removed and application
code can rely on having only valid dereferenced objects.
When the field is referenced, the referenced object is loaded from the
database.
"""
def to_python(self, value):
obj = super(SafeReferenceField, self).to_python(value)
if obj:
# Must dereference so we don't get an invalid ObjectId back.
try:
obj.reload()
except DoesNotExist:
return None
return obj
class SafeReferenceListField(ListField):
"""
Like a ListField, but doesn't return non-existing references when
dereferencing, i.e. no DBRefs are returned. This means that the next time
an object is saved, the non-existing references are removed and application
code can rely on having only valid dereferenced objects.
When the field is referenced, all referenced objects are loaded from the
database.
Must use ReferenceField as its field class.
"""
def __init__(self, field, **kwargs):
if not isinstance(field, ReferenceField):
raise ValueError('Field argument must be a ReferenceField instance.')
return super(SafeReferenceListField, self).__init__(field, **kwargs)
def to_python(self, value):
result = super(SafeReferenceListField, self).to_python(value)
if result:
objs = self.field.document_type.objects.in_bulk([obj.id for obj in result])
return filter(None, [objs.get(obj.id) for obj in result])
class GenericReferenceField(BaseField): class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
@ -888,6 +950,17 @@ class GenericReferenceField(BaseField):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
def __get__(self, instance, owner):
if instance is None:
return self
value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference
if self._auto_dereference and isinstance(value, (dict, SON)):
instance._data[self.name] = self.dereference(value)
return super(GenericReferenceField, self).__get__(instance, owner)
def validate(self, value): def validate(self, value):
if not isinstance(value, (Document, DBRef, dict, SON)): if not isinstance(value, (Document, DBRef, dict, SON)):
self.error('GenericReferences can only contain documents') self.error('GenericReferences can only contain documents')
@ -909,14 +982,6 @@ class GenericReferenceField(BaseField):
doc = doc_cls._from_son(doc) doc = doc_cls._from_son(doc)
return doc return doc
def to_python(self, value):
if value != None:
doc_cls = get_document(value['_cls'])
reference = value['_ref']
obj = doc_cls(pk=reference.id)
obj._lazy = True
return obj
def to_mongo(self, document): def to_mongo(self, document):
if document is None: if document is None:
return None return None
@ -1033,6 +1098,10 @@ class GridFSProxy(object):
def __repr__(self): def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.grid_id) return '<%s: %s>' % (self.__class__.__name__, self.grid_id)
def __str__(self):
name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)'
return '<%s: %s>' % (self.__class__.__name__, name)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, GridFSProxy): if isinstance(other, GridFSProxy):
return ((self.grid_id == other.grid_id) and return ((self.grid_id == other.grid_id) and
@ -1140,9 +1209,7 @@ class FileField(BaseField):
# Check if a file already exists for this model # Check if a file already exists for this model
grid_file = instance._data.get(self.name) grid_file = instance._data.get(self.name)
if not isinstance(grid_file, self.proxy_class): if not isinstance(grid_file, self.proxy_class):
grid_file = self.proxy_class(key=self.name, instance=instance, grid_file = self.get_proxy_obj(key=self.name, instance=instance)
db_alias=self.db_alias,
collection_name=self.collection_name)
instance._data[self.name] = grid_file instance._data[self.name] = grid_file
if not grid_file.key: if not grid_file.key:
@ -1164,15 +1231,23 @@ class FileField(BaseField):
pass pass
# Create a new proxy object as we don't already have one # Create a new proxy object as we don't already have one
instance._data[key] = self.proxy_class(key=key, instance=instance, instance._data[key] = self.get_proxy_obj(key=key, instance=instance)
db_alias=self.db_alias,
collection_name=self.collection_name)
instance._data[key].put(value) instance._data[key].put(value)
else: else:
instance._data[key] = value instance._data[key] = value
instance._mark_as_changed(key) instance._mark_as_changed(key)
def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None):
if db_alias is None:
db_alias = self.db_alias
if collection_name is None:
collection_name = self.collection_name
return self.proxy_class(key=key, instance=instance,
db_alias=db_alias,
collection_name=collection_name)
def to_mongo(self, value): def to_mongo(self, value):
# Store the GridFS file id in MongoDB # Store the GridFS file id in MongoDB
if isinstance(value, self.proxy_class) and value.grid_id is not None: if isinstance(value, self.proxy_class) and value.grid_id is not None:
@ -1205,6 +1280,9 @@ class ImageGridFsProxy(GridFSProxy):
applying field properties (size, thumbnail_size) applying field properties (size, thumbnail_size)
""" """
field = self.instance._fields[self.key] field = self.instance._fields[self.key]
# Handle nested fields
if hasattr(field, 'field') and isinstance(field.field, FileField):
field = field.field
try: try:
img = Image.open(file_obj) img = Image.open(file_obj)

1494
mongoengine/queryset/base.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -55,7 +55,8 @@ class QueryFieldList(object):
if self.always_include: if self.always_include:
if self.value is self.ONLY and self.fields: if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include) if sorted(self.slice.keys()) != sorted(self.fields):
self.fields = self.fields.union(self.always_include)
else: else:
self.fields -= self.always_include self.fields -= self.always_include

File diff suppressed because it is too large Load diff

View file

@ -43,11 +43,11 @@ def query(_doc_cls=None, _field_operation=False, **query):
parts = [part for part in parts if not part.isdigit()] parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is # Check for an operator and transform to mongo-style if there is
op = None op = None
if parts[-1] in MATCH_OPERATORS: if len(parts) > 1 and parts[-1] in MATCH_OPERATORS:
op = parts.pop() op = parts.pop()
negate = False negate = False
if parts[-1] == 'not': if len(parts) > 1 and parts[-1] == 'not':
parts.pop() parts.pop()
negate = True negate = True
@ -182,6 +182,7 @@ def update(_doc_cls=None, **update):
parts = [] parts = []
cleaned_fields = [] cleaned_fields = []
appended_sub_field = False
for field in fields: for field in fields:
append_field = True append_field = True
if isinstance(field, basestring): if isinstance(field, basestring):
@ -193,21 +194,30 @@ def update(_doc_cls=None, **update):
else: else:
parts.append(field.db_field) parts.append(field.db_field)
if append_field: if append_field:
appended_sub_field = False
cleaned_fields.append(field) cleaned_fields.append(field)
if hasattr(field, 'field'):
cleaned_fields.append(field.field)
appended_sub_field = True
# Convert value to proper value # Convert value to proper value
field = cleaned_fields[-1] if appended_sub_field:
field = cleaned_fields[-2]
else:
field = cleaned_fields[-1]
if op in (None, 'set', 'push', 'pull'): if op in (None, 'set', 'push', 'pull'):
if field.required or value is not None: if field.required or value is not None:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'): elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
elif op == 'addToSet': elif op in ('addToSet', 'setOnInsert'):
if isinstance(value, (list, tuple, set)): if isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
elif field.required or value is not None: elif field.required or value is not None:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op == "unset":
value = 1
if match: if match:
match = '$' + match match = '$' + match
@ -221,11 +231,24 @@ def update(_doc_cls=None, **update):
if 'pull' in op and '.' in key: if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations # Dot operators don't work on pull operations
# it uses nested dict syntax # unless they point to a list field
# Otherwise it uses nested dict syntax
if op == 'pullAll': if op == 'pullAll':
raise InvalidQueryError("pullAll operations only support " raise InvalidQueryError("pullAll operations only support "
"a single field depth") "a single field depth")
# Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields]
field_classes.reverse()
ListField = _import_class('ListField')
if ListField in field_classes:
# Join all fields via dot notation to the last ListField
# Then process as normal
last_listField = len(cleaned_fields) - field_classes.index(ListField)
key = ".".join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
parts.reverse() parts.reverse()
for key in parts: for key in parts:
value = {key: value} value = {key: value}

View file

@ -28,7 +28,7 @@ class DuplicateQueryConditionsError(InvalidQueryError):
class SimplificationVisitor(QNodeVisitor): class SimplificationVisitor(QNodeVisitor):
"""Simplifies query trees by combining unnecessary 'and' connection nodes """Simplifies query trees by combinging unnecessary 'and' connection nodes
into a single Q-object. into a single Q-object.
""" """
@ -73,16 +73,6 @@ class QueryCompilerVisitor(QNodeVisitor):
def visit_combination(self, combination): def visit_combination(self, combination):
operator = "$and" operator = "$and"
if combination.operation == combination.OR: if combination.operation == combination.OR:
keys = set([key for q in combination.children for key in q.keys()])
if len(keys) == 1:
field = keys.pop()
if not field.startswith('$') and not any([isinstance(q[field], dict) for q in combination.children]):
return {
field: {
'$in': [q[field] for q in combination.children if field in q]
}
}
operator = "$or" operator = "$or"
return {operator: combination.children} return {operator: combination.children}

View file

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
__all__ = ['pre_save', 'post_save', 'pre_delete', 'post_delete'] __all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation',
'post_save', 'pre_delete', 'post_delete']
signals_available = False signals_available = False
try: try:
@ -35,7 +36,10 @@ except ImportError:
# not put signals in here. Create your own namespace instead. # not put signals in here. Create your own namespace instead.
_signals = Namespace() _signals = Namespace()
pre_init = _signals.signal('pre_init')
post_init = _signals.signal('post_init')
pre_save = _signals.signal('pre_save') pre_save = _signals.signal('pre_save')
pre_save_post_validation = _signals.signal('pre_save_post_validation')
post_save = _signals.signal('post_save') post_save = _signals.signal('post_save')
pre_delete = _signals.signal('pre_delete') pre_delete = _signals.signal('pre_delete')
post_delete = _signals.signal('post_delete') post_delete = _signals.signal('post_delete')

View file

@ -5,7 +5,7 @@
%define srcname mongoengine %define srcname mongoengine
Name: python-%{srcname} Name: python-%{srcname}
Version: 0.8.2 Version: 0.8.4
Release: 1%{?dist} Release: 1%{?dist}
Summary: A Python Document-Object Mapper for working with MongoDB Summary: A Python Document-Object Mapper for working with MongoDB

View file

@ -48,17 +48,15 @@ CLASSIFIERS = [
'Topic :: Software Development :: Libraries :: Python Modules', 'Topic :: Software Development :: Libraries :: Python Modules',
] ]
extra_opts = {} extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])}
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
extra_opts['use_2to3'] = True extra_opts['use_2to3'] = True
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6'] extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1']
extra_opts['packages'] = find_packages(exclude=('tests',))
if "test" in sys.argv or "nosetests" in sys.argv: if "test" in sys.argv or "nosetests" in sys.argv:
extra_opts['packages'].append("tests") extra_opts['packages'] = find_packages()
extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]}
else: else:
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil'] extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6', 'python-dateutil']
extra_opts['packages'] = find_packages(exclude=('tests',))
setup(name='mongoengine', setup(name='mongoengine',
version=VERSION, version=VERSION,

View file

@ -1,90 +0,0 @@
from mongoengine import *
from timeit import repeat
import unittest
conn_settings = {
'db': 'mongomallard-test',
}
connect(**conn_settings)
def timeit(f, n=10000):
return min(repeat(f, repeat=3, number=n))/float(n)
class BenchmarkTestCase(unittest.TestCase):
def setUp(self):
pass
def test_basic(self):
class Book(Document):
name = StringField()
pages = IntField()
tags = ListField(StringField())
is_published = BooleanField()
Book.drop_collection()
create_book = lambda: Book(name='Always be closing', pages=100, tags=['self-help', 'sales'], is_published=True)
print 'Doc initialization: %.3fus' % (timeit(create_book, 1000) * 10**6)
b = create_book()
print 'Doc getattr: %.3fus' % (timeit(lambda: b.name, 10000) * 10**6)
print 'Doc setattr: %.3fus' % (timeit(lambda: setattr(b, 'name', 'New name'), 10000) * 10**6)
print 'Doc to mongo: %.3fus' % (timeit(b.to_mongo, 1000) * 10**6)
def save_book():
b._mark_as_changed('name')
b._mark_as_changed('tags')
b.save()
save_book()
son = b.to_mongo()
print 'Load from SON: %.3fus' % (timeit(lambda: Book._from_son(son), 1000) * 10**6)
print 'Save to database: %.3fus' % (timeit(save_book, 100) * 10**6)
print 'Load from database: %.3fus' % (timeit(lambda: Book.objects[0], 100) * 10**6)
def test_embedded(self):
class Contact(EmbeddedDocument):
name = StringField()
title = StringField()
address = StringField()
class Company(Document):
name = StringField()
contacts = ListField(EmbeddedDocumentField(Contact))
Company.drop_collection()
def get_company():
return Company(
name='Elastic',
contacts=[
Contact(
name='Contact %d' % x,
title='CEO',
address='Address %d' % x,
)
for x in range(1000)]
)
def create_company():
c = get_company()
c.save()
c.delete()
print 'Save/delete big object to database: %.3fms' % (timeit(create_company, 10) * 10**3)
c = get_company().save()
print 'Serialize big object from database: %.3fms' % (timeit(c.to_mongo, 100) * 10**3)
print 'Load big object from database: %.3fms' % (timeit(lambda: Company.objects[0], 100) * 10**3)
if __name__ == '__main__':
unittest.main()

View file

@ -3,6 +3,7 @@ import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
from bson import SON
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
@ -48,42 +49,41 @@ class DeltaTest(unittest.TestCase):
doc.save() doc.save()
doc = Doc.objects.first() doc = Doc.objects.first()
self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(doc._delta(), ({}, {})) self.assertEqual(doc._delta(), ({}, {}))
doc.string_field = 'hello' doc.string_field = 'hello'
self.assertEqual(doc._get_changed_fields(), set(['string_field'])) self.assertEqual(doc._get_changed_fields(), ['string_field'])
self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {}))
doc._changed_fields = set() doc._changed_fields = []
doc.int_field = 1 doc.int_field = 1
self.assertEqual(doc._get_changed_fields(), set(['int_field'])) self.assertEqual(doc._get_changed_fields(), ['int_field'])
self.assertEqual(doc._delta(), ({'int_field': 1}, {})) self.assertEqual(doc._delta(), ({'int_field': 1}, {}))
doc._changed_fields = set() doc._changed_fields = []
dict_value = {'hello': 'world', 'ping': 'pong'} dict_value = {'hello': 'world', 'ping': 'pong'}
doc.dict_field = dict_value doc.dict_field = dict_value
self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._get_changed_fields(), ['dict_field'])
self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {}))
doc._changed_fields = set() doc._changed_fields = []
list_value = ['1', 2, {'hello': 'world'}] list_value = ['1', 2, {'hello': 'world'}]
doc.list_field = list_value doc.list_field = list_value
self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._get_changed_fields(), ['list_field'])
self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) self.assertEqual(doc._delta(), ({'list_field': list_value}, {}))
# Test unsetting # Test unsetting
doc._changed_fields = set() doc._changed_fields = []
doc.dict_field = {} doc.dict_field = {}
self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._get_changed_fields(), ['dict_field'])
self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) self.assertEqual(doc._delta(), ({}, {'dict_field': 1}))
doc._changed_fields = set() doc._changed_fields = []
doc.list_field = [] doc.list_field = []
self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._get_changed_fields(), ['list_field'])
self.assertEqual(doc._delta(), ({}, {'list_field': 1})) self.assertEqual(doc._delta(), ({}, {'list_field': 1}))
@unittest.skip("not fully implemented")
def test_delta_recursive(self): def test_delta_recursive(self):
self.delta_recursive(Document, EmbeddedDocument) self.delta_recursive(Document, EmbeddedDocument)
self.delta_recursive(DynamicDocument, EmbeddedDocument) self.delta_recursive(DynamicDocument, EmbeddedDocument)
@ -110,7 +110,7 @@ class DeltaTest(unittest.TestCase):
doc.save() doc.save()
doc = Doc.objects.first() doc = Doc.objects.first()
self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(doc._delta(), ({}, {})) self.assertEqual(doc._delta(), ({}, {}))
embedded_1 = Embedded() embedded_1 = Embedded()
@ -120,7 +120,7 @@ class DeltaTest(unittest.TestCase):
embedded_1.list_field = ['1', 2, {'hello': 'world'}] embedded_1.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field = embedded_1 doc.embedded_field = embedded_1
self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) self.assertEqual(doc._get_changed_fields(), ['embedded_field'])
embedded_delta = { embedded_delta = {
'string_field': 'hello', 'string_field': 'hello',
@ -137,7 +137,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.dict_field = {} doc.embedded_field.dict_field = {}
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
set(['embedded_field.dict_field'])) ['embedded_field.dict_field'])
self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1}))
self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1}))
doc.save() doc.save()
@ -146,7 +146,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field = [] doc.embedded_field.list_field = []
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
set(['embedded_field.list_field'])) ['embedded_field.list_field'])
self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1}))
self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1}))
doc.save() doc.save()
@ -161,7 +161,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field = ['1', 2, embedded_2] doc.embedded_field.list_field = ['1', 2, embedded_2]
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
set(['embedded_field.list_field'])) ['embedded_field.list_field'])
self.assertEqual(doc.embedded_field._delta(), ({ self.assertEqual(doc.embedded_field._delta(), ({
'list_field': ['1', 2, { 'list_field': ['1', 2, {
@ -193,7 +193,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field[2].string_field = 'world' doc.embedded_field.list_field[2].string_field = 'world'
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
set(['embedded_field.list_field.2.string_field'])) ['embedded_field.list_field.2.string_field'])
self.assertEqual(doc.embedded_field._delta(), self.assertEqual(doc.embedded_field._delta(),
({'list_field.2.string_field': 'world'}, {})) ({'list_field.2.string_field': 'world'}, {}))
self.assertEqual(doc._delta(), self.assertEqual(doc._delta(),
@ -207,7 +207,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2].string_field = 'hello world'
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
set(['embedded_field.list_field'])) ['embedded_field.list_field'])
self.assertEqual(doc.embedded_field._delta(), ({ self.assertEqual(doc.embedded_field._delta(), ({
'list_field': ['1', 2, { 'list_field': ['1', 2, {
'_cls': 'Embedded', '_cls': 'Embedded',
@ -270,7 +270,7 @@ class DeltaTest(unittest.TestCase):
doc.dict_field['Embedded'].string_field = 'Hello World' doc.dict_field['Embedded'].string_field = 'Hello World'
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
set(['dict_field.Embedded.string_field'])) ['dict_field.Embedded.string_field'])
self.assertEqual(doc._delta(), self.assertEqual(doc._delta(),
({'dict_field.Embedded.string_field': 'Hello World'}, {})) ({'dict_field.Embedded.string_field': 'Hello World'}, {}))
@ -313,29 +313,24 @@ class DeltaTest(unittest.TestCase):
self.circular_reference_deltas_2(DynamicDocument, Document) self.circular_reference_deltas_2(DynamicDocument, Document)
self.circular_reference_deltas_2(DynamicDocument, DynamicDocument) self.circular_reference_deltas_2(DynamicDocument, DynamicDocument)
def circular_reference_deltas_2(self, DocClass1, DocClass2): def circular_reference_deltas_2(self, DocClass1, DocClass2, dbref=True):
class Person(DocClass1): class Person(DocClass1):
name = StringField() name = StringField()
owns = ListField(ReferenceField('Organization')) owns = ListField(ReferenceField('Organization', dbref=dbref))
employer = ReferenceField('Organization') employer = ReferenceField('Organization', dbref=dbref)
class Organization(DocClass2): class Organization(DocClass2):
name = StringField() name = StringField()
owner = ReferenceField('Person') owner = ReferenceField('Person', dbref=dbref)
employees = ListField(ReferenceField('Person')) employees = ListField(ReferenceField('Person', dbref=dbref))
Person.drop_collection() Person.drop_collection()
Organization.drop_collection() Organization.drop_collection()
person = Person(name="owner") person = Person(name="owner").save()
person.save() employee = Person(name="employee").save()
organization = Organization(name="company").save()
employee = Person(name="employee")
employee.save()
organization = Organization(name="company")
organization.save()
person.owns.append(organization) person.owns.append(organization)
organization.owner = person organization.owner = person
@ -355,6 +350,8 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(o.owner, p) self.assertEqual(o.owner, p)
self.assertEqual(e.employer, o) self.assertEqual(e.employer, o)
return person, organization, employee
def test_delta_db_field(self): def test_delta_db_field(self):
self.delta_db_field(Document) self.delta_db_field(Document)
self.delta_db_field(DynamicDocument) self.delta_db_field(DynamicDocument)
@ -372,39 +369,39 @@ class DeltaTest(unittest.TestCase):
doc.save() doc.save()
doc = Doc.objects.first() doc = Doc.objects.first()
self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(doc._delta(), ({}, {})) self.assertEqual(doc._delta(), ({}, {}))
doc.string_field = 'hello' doc.string_field = 'hello'
self.assertEqual(doc._get_changed_fields(), set(['string_field'])) self.assertEqual(doc._get_changed_fields(), ['db_string_field'])
self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {}))
doc._changed_fields = set() doc._changed_fields = []
doc.int_field = 1 doc.int_field = 1
self.assertEqual(doc._get_changed_fields(), set(['int_field'])) self.assertEqual(doc._get_changed_fields(), ['db_int_field'])
self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) self.assertEqual(doc._delta(), ({'db_int_field': 1}, {}))
doc._changed_fields = set() doc._changed_fields = []
dict_value = {'hello': 'world', 'ping': 'pong'} dict_value = {'hello': 'world', 'ping': 'pong'}
doc.dict_field = dict_value doc.dict_field = dict_value
self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._get_changed_fields(), ['db_dict_field'])
self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {}))
doc._changed_fields = set() doc._changed_fields = []
list_value = ['1', 2, {'hello': 'world'}] list_value = ['1', 2, {'hello': 'world'}]
doc.list_field = list_value doc.list_field = list_value
self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._get_changed_fields(), ['db_list_field'])
self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {}))
# Test unsetting # Test unsetting
doc._changed_fields = set() doc._changed_fields = []
doc.dict_field = {} doc.dict_field = {}
self.assertEqual(doc._get_changed_fields(), set(['dict_field'])) self.assertEqual(doc._get_changed_fields(), ['db_dict_field'])
self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1}))
doc._changed_fields = set() doc._changed_fields = []
doc.list_field = [] doc.list_field = []
self.assertEqual(doc._get_changed_fields(), set(['list_field'])) self.assertEqual(doc._get_changed_fields(), ['db_list_field'])
self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) self.assertEqual(doc._delta(), ({}, {'db_list_field': 1}))
# Test it saves that data # Test it saves that data
@ -416,15 +413,13 @@ class DeltaTest(unittest.TestCase):
doc.dict_field = {'hello': 'world'} doc.dict_field = {'hello': 'world'}
doc.list_field = ['1', 2, {'hello': 'world'}] doc.list_field = ['1', 2, {'hello': 'world'}]
doc.save() doc.save()
#doc = doc.reload(10) doc = doc.reload(10)
doc = doc.reload()
self.assertEqual(doc.string_field, 'hello') self.assertEqual(doc.string_field, 'hello')
self.assertEqual(doc.int_field, 1) self.assertEqual(doc.int_field, 1)
self.assertEqual(doc.dict_field, {'hello': 'world'}) self.assertEqual(doc.dict_field, {'hello': 'world'})
self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}])
@unittest.skip("not fully implemented")
def test_delta_recursive_db_field(self): def test_delta_recursive_db_field(self):
self.delta_recursive_db_field(Document, EmbeddedDocument) self.delta_recursive_db_field(Document, EmbeddedDocument)
self.delta_recursive_db_field(Document, DynamicEmbeddedDocument) self.delta_recursive_db_field(Document, DynamicEmbeddedDocument)
@ -452,7 +447,7 @@ class DeltaTest(unittest.TestCase):
doc.save() doc.save()
doc = Doc.objects.first() doc = Doc.objects.first()
self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(doc._delta(), ({}, {})) self.assertEqual(doc._delta(), ({}, {}))
embedded_1 = Embedded() embedded_1 = Embedded()
@ -462,7 +457,7 @@ class DeltaTest(unittest.TestCase):
embedded_1.list_field = ['1', 2, {'hello': 'world'}] embedded_1.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field = embedded_1 doc.embedded_field = embedded_1
self.assertEqual(doc._get_changed_fields(), set(['embedded_field'])) self.assertEqual(doc._get_changed_fields(), ['db_embedded_field'])
embedded_delta = { embedded_delta = {
'db_string_field': 'hello', 'db_string_field': 'hello',
@ -490,7 +485,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field = [] doc.embedded_field.list_field = []
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
set(['db_embedded_field.db_list_field'])) ['db_embedded_field.db_list_field'])
self.assertEqual(doc.embedded_field._delta(), self.assertEqual(doc.embedded_field._delta(),
({}, {'db_list_field': 1})) ({}, {'db_list_field': 1}))
self.assertEqual(doc._delta(), self.assertEqual(doc._delta(),
@ -608,7 +603,6 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(doc._delta(), ({}, self.assertEqual(doc._delta(), ({},
{'db_embedded_field.db_list_field.2.db_list_field': 1})) {'db_embedded_field.db_list_field.2.db_list_field': 1}))
@unittest.skip("DynamicDocument not implemented")
def test_delta_for_dynamic_documents(self): def test_delta_for_dynamic_documents(self):
class Person(DynamicDocument): class Person(DynamicDocument):
name = StringField() name = StringField()
@ -617,13 +611,13 @@ class DeltaTest(unittest.TestCase):
Person.drop_collection() Person.drop_collection()
p = Person(name="James", age=34) p = Person(name="James", age=34)
self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', self.assertEqual(p._delta(), (
'_cls': 'Person'}, {})) SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {}))
p.doc = 123 p.doc = 123
del(p.doc) del(p.doc)
self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', self.assertEqual(p._delta(), (
'_cls': 'Person'}, {'doc': 1})) SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {}))
p = Person() p = Person()
p.name = "Dean" p.name = "Dean"
@ -635,16 +629,15 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._get_changed_fields(), ['age'])
self.assertEqual(p._delta(), ({'age': 24}, {})) self.assertEqual(p._delta(), ({'age': 24}, {}))
p = self.Person.objects(age=22).get() p = Person.objects(age=22).get()
p.age = 24 p.age = 24
self.assertEqual(p.age, 24) self.assertEqual(p.age, 24)
self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._get_changed_fields(), ['age'])
self.assertEqual(p._delta(), ({'age': 24}, {})) self.assertEqual(p._delta(), ({'age': 24}, {}))
p.save() p.save()
self.assertEqual(1, self.Person.objects(age=24).count()) self.assertEqual(1, Person.objects(age=24).count())
@unittest.skip("DynamicDocument not implemented")
def test_dynamic_delta(self): def test_dynamic_delta(self):
class Doc(DynamicDocument): class Doc(DynamicDocument):
@ -690,6 +683,36 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._get_changed_fields(), ['list_field'])
self.assertEqual(doc._delta(), ({}, {'list_field': 1})) self.assertEqual(doc._delta(), ({}, {'list_field': 1}))
def test_delta_with_dbref_true(self):
person, organization, employee = self.circular_reference_deltas_2(Document, Document, True)
employee.name = 'test'
self.assertEqual(organization._get_changed_fields(), [])
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertEqual({}, updates)
organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees' in updates)
def test_delta_with_dbref_false(self):
person, organization, employee = self.circular_reference_deltas_2(Document, Document, False)
employee.name = 'test'
self.assertEqual(organization._get_changed_fields(), [])
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertEqual({}, updates)
organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees' in updates)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -8,7 +8,6 @@ from mongoengine.connection import get_db
__all__ = ("DynamicTest", ) __all__ = ("DynamicTest", )
@unittest.skip("DynamicDocument not implemented")
class DynamicTest(unittest.TestCase): class DynamicTest(unittest.TestCase):
def setUp(self): def setUp(self):

View file

@ -156,6 +156,25 @@ class IndexesTest(unittest.TestCase):
self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}], self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}],
A._meta['index_specs']) A._meta['index_specs'])
def test_index_no_cls(self):
"""Ensure index specs are inhertited correctly"""
class A(Document):
title = StringField()
meta = {
'indexes': [
{'fields': ('title',), 'cls': False},
],
'allow_inheritance': True,
'index_cls': False
}
self.assertEqual([('title', 1)], A._meta['index_specs'][0]['fields'])
A._get_collection().drop_indexes()
A.ensure_indexes()
info = A._get_collection().index_information()
self.assertEqual(len(info.keys()), 2)
def test_build_index_spec_is_not_destructive(self): def test_build_index_spec_is_not_destructive(self):
class MyDoc(Document): class MyDoc(Document):
@ -632,7 +651,6 @@ class IndexesTest(unittest.TestCase):
pass pass
Customer.drop_collection() Customer.drop_collection()
@unittest.skip("behavior differs")
def test_unique_and_primary(self): def test_unique_and_primary(self):
"""If you set a field as primary, then unexpected behaviour can occur. """If you set a field as primary, then unexpected behaviour can occur.
You won't create a duplicate but you will update an existing document. You won't create a duplicate but you will update an existing document.

View file

@ -182,10 +182,10 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(['age', 'id', 'name', 'salary'], self.assertEqual(['age', 'id', 'name', 'salary'],
sorted(Employee._fields.keys())) sorted(Employee._fields.keys()))
self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(),
set(['_cls', 'name', 'age'])) ['_cls', 'name', 'age'])
self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
set(['_cls', 'name', 'age', 'salary'])) ['_cls', 'name', 'age', 'salary'])
self.assertEqual(Employee._get_collection_name(), self.assertEqual(Employee._get_collection_name(),
Person._get_collection_name()) Person._get_collection_name())

View file

@ -10,7 +10,8 @@ import uuid
from datetime import datetime from datetime import datetime
from bson import DBRef from bson import DBRef
from tests.fixtures import PickleEmbedded, PickleTest, PickleSignalsTest from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
PickleDyanmicEmbedded, PickleDynamicTest)
from mongoengine import * from mongoengine import *
from mongoengine.errors import (NotRegistered, InvalidDocumentError, from mongoengine.errors import (NotRegistered, InvalidDocumentError,
@ -390,25 +391,24 @@ class InstanceTest(unittest.TestCase):
doc.embedded_field = embedded_1 doc.embedded_field = embedded_1
doc.save() doc.save()
doc = doc.reload() doc = doc.reload(10)
doc.list_field.append(1) doc.list_field.append(1)
doc.dict_field['woot'] = "woot" doc.dict_field['woot'] = "woot"
doc.embedded_field.list_field.append(1) doc.embedded_field.list_field.append(1)
doc.embedded_field.dict_field['woot'] = "woot" doc.embedded_field.dict_field['woot'] = "woot"
self.assertEqual(doc._get_changed_fields(), set([ self.assertEqual(doc._get_changed_fields(), [
'list_field', 'dict_field', 'embedded_field.list_field', 'list_field', 'dict_field', 'embedded_field.list_field',
'embedded_field.dict_field'])) 'embedded_field.dict_field'])
doc.save() doc.save()
doc = doc.reload() doc = doc.reload(10)
self.assertEqual(doc._get_changed_fields(), set()) self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(len(doc.list_field), 4) self.assertEqual(len(doc.list_field), 4)
self.assertEqual(len(doc.dict_field), 2) self.assertEqual(len(doc.dict_field), 2)
self.assertEqual(len(doc.embedded_field.list_field), 4) self.assertEqual(len(doc.embedded_field.list_field), 4)
self.assertEqual(len(doc.embedded_field.dict_field), 2) self.assertEqual(len(doc.embedded_field.dict_field), 2)
@unittest.skip("not implemented")
def test_dictionary_access(self): def test_dictionary_access(self):
"""Ensure that dictionary-style field access works properly. """Ensure that dictionary-style field access works properly.
""" """
@ -439,10 +439,17 @@ class InstanceTest(unittest.TestCase):
class Employee(Person): class Employee(Person):
salary = IntField() salary = IntField()
self.assertEqual(set(Person(name="Bob", age=35).to_mongo().keys()), self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(),
set(['_cls', 'name', 'age'])) ['_cls', 'name', 'age'])
self.assertEqual(set(Employee(name="Bob", age=35, salary=0).to_mongo().keys()), self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
set(['_cls', 'name', 'age', 'salary'])) ['_cls', 'name', 'age', 'salary'])
def test_embedded_document_to_mongo_id(self):
class SubDoc(EmbeddedDocument):
id = StringField(required=True)
sub_doc = SubDoc(id="abc")
self.assertEqual(sub_doc.to_mongo().keys(), ['id'])
def test_embedded_document(self): def test_embedded_document(self):
"""Ensure that embedded documents are set up correctly. """Ensure that embedded documents are set up correctly.
@ -453,7 +460,6 @@ class InstanceTest(unittest.TestCase):
self.assertTrue('content' in Comment._fields) self.assertTrue('content' in Comment._fields)
self.assertFalse('id' in Comment._fields) self.assertFalse('id' in Comment._fields)
@unittest.skip("not implemented")
def test_embedded_document_instance(self): def test_embedded_document_instance(self):
"""Ensure that embedded documents can reference parent instance """Ensure that embedded documents can reference parent instance
""" """
@ -462,7 +468,6 @@ class InstanceTest(unittest.TestCase):
class Doc(Document): class Doc(Document):
embedded_field = EmbeddedDocumentField(Embedded) embedded_field = EmbeddedDocumentField(Embedded)
meta = { 'cascade': True }
Doc.drop_collection() Doc.drop_collection()
Doc(embedded_field=Embedded(string="Hi")).save() Doc(embedded_field=Embedded(string="Hi")).save()
@ -470,7 +475,6 @@ class InstanceTest(unittest.TestCase):
doc = Doc.objects.get() doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field._instance) self.assertEqual(doc, doc.embedded_field._instance)
@unittest.skip("not implemented")
def test_embedded_document_complex_instance(self): def test_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference """Ensure that embedded documents in complex fields can reference
parent instance""" parent instance"""
@ -627,7 +631,6 @@ class InstanceTest(unittest.TestCase):
p0.name = 'wpjunior' p0.name = 'wpjunior'
p0.save() p0.save()
@unittest.skip("FileField not implemented")
def test_save_max_recursion_not_hit_with_file_field(self): def test_save_max_recursion_not_hit_with_file_field(self):
class Foo(Document): class Foo(Document):
@ -776,7 +779,6 @@ class InstanceTest(unittest.TestCase):
p1.reload() p1.reload()
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
@unittest.skip("not implemented")
def test_update(self): def test_update(self):
"""Ensure that an existing document is updated instead of be """Ensure that an existing document is updated instead of be
overwritten.""" overwritten."""
@ -891,6 +893,7 @@ class InstanceTest(unittest.TestCase):
reference_field = ReferenceField(Simple, default=lambda: reference_field = ReferenceField(Simple, default=lambda:
Simple().save()) Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1}) map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now) complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org") url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1) dynamic_field = DynamicField(default=1)
@ -1059,9 +1062,9 @@ class InstanceTest(unittest.TestCase):
user = User.objects.first() user = User.objects.first()
# Even if stored as ObjectId's internally mongoengine uses DBRefs # Even if stored as ObjectId's internally mongoengine uses DBRefs
# As ObjectId's aren't automatically derefenced # As ObjectId's aren't automatically derefenced
#self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) self.assertTrue(isinstance(user._data['orgs'][0], DBRef))
self.assertTrue(isinstance(user.orgs[0], Organization)) self.assertTrue(isinstance(user.orgs[0], Organization))
#self.assertTrue(isinstance(user._data['orgs'][0], Organization)) self.assertTrue(isinstance(user._data['orgs'][0], Organization))
# Changing a value # Changing a value
with query_counter() as q: with query_counter() as q:
@ -1141,7 +1144,6 @@ class InstanceTest(unittest.TestCase):
foo.save() foo.save()
self.assertEqual(1, q) self.assertEqual(1, q)
@unittest.skip("not implemented")
def test_save_only_changed_fields_recursive(self): def test_save_only_changed_fields_recursive(self):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields
""" """
@ -1439,8 +1441,8 @@ class InstanceTest(unittest.TestCase):
post_obj = BlogPost.objects.first() post_obj = BlogPost.objects.first()
# Test laziness # Test laziness
#self.assertTrue(isinstance(post_obj._data['author'], self.assertTrue(isinstance(post_obj._data['author'],
# bson.DBRef)) bson.DBRef))
self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertTrue(isinstance(post_obj.author, self.Person))
self.assertEqual(post_obj.author.name, 'Test User') self.assertEqual(post_obj.author.name, 'Test User')
@ -1464,7 +1466,6 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(InvalidDocumentError, throw_invalid_document_error) self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
@unittest.skip("not implemented")
def test_invalid_son(self): def test_invalid_son(self):
"""Raise an error if loading invalid data""" """Raise an error if loading invalid data"""
class Occurrence(EmbeddedDocument): class Occurrence(EmbeddedDocument):
@ -1808,7 +1809,6 @@ class InstanceTest(unittest.TestCase):
self.assertTrue(u1 in all_user_set) self.assertTrue(u1 in all_user_set)
@unittest.skip("not implemented")
def test_picklable(self): def test_picklable(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
@ -1835,7 +1835,29 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(pickle_doc.string, "Two") self.assertEqual(pickle_doc.string, "Two")
self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) self.assertEqual(pickle_doc.lists, ["1", "2", "3"])
@unittest.skip("not implemented") def test_dynamic_document_pickle(self):
pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar")
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
resurrected = pickle.loads(pickled_doc)
self.assertEqual(resurrected, pickle_doc)
self.assertEqual(resurrected._fields_ordered,
pickle_doc._fields_ordered)
self.assertEqual(resurrected._dynamic_fields.keys(),
pickle_doc._dynamic_fields.keys())
self.assertEqual(resurrected.embedded, pickle_doc.embedded)
self.assertEqual(resurrected.embedded._fields_ordered,
pickle_doc.embedded._fields_ordered)
self.assertEqual(resurrected.embedded._dynamic_fields.keys(),
pickle_doc.embedded._dynamic_fields.keys())
def test_picklable_on_signals(self): def test_picklable_on_signals(self):
pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded() pickle_doc.embedded = PickleEmbedded()
@ -1896,7 +1918,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Doc.objects(archived=False).count(), 1) self.assertEqual(Doc.objects(archived=False).count(), 1)
@unittest.skip("DynamicDocument not implemented")
def test_can_save_false_values_dynamic(self): def test_can_save_false_values_dynamic(self):
"""Ensures you can save False values on dynamic docs""" """Ensures you can save False values on dynamic docs"""
class Doc(DynamicDocument): class Doc(DynamicDocument):
@ -2036,7 +2057,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('testdb-1', B._meta.get('db_alias')) self.assertEqual('testdb-1', B._meta.get('db_alias'))
@unittest.skip("not implemented")
def test_db_ref_usage(self): def test_db_ref_usage(self):
""" DB Ref usage in dict_fields""" """ DB Ref usage in dict_fields"""
@ -2115,7 +2135,6 @@ class InstanceTest(unittest.TestCase):
})]), })]),
"1,2") "1,2")
@unittest.skip("not implemented")
def test_switch_db_instance(self): def test_switch_db_instance(self):
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
@ -2187,10 +2206,9 @@ class InstanceTest(unittest.TestCase):
user = User.objects.first() user = User.objects.first()
self.assertEqual("Ross", user.username) self.assertEqual("Ross", user.username)
self.assertEqual(True, user.foo) self.assertEqual(True, user.foo)
self.assertEqual("Bar", user._db_data["foo"]) self.assertEqual("Bar", user._data["foo"])
self.assertEqual([1, 2, 3], user._db_data["data"]) self.assertEqual([1, 2, 3], user._data["data"])
@unittest.skip("DynamicDocument not implemented")
def test_spaces_in_keys(self): def test_spaces_in_keys(self):
class Embedded(DynamicEmbeddedDocument): class Embedded(DynamicEmbeddedDocument):
@ -2207,7 +2225,6 @@ class InstanceTest(unittest.TestCase):
one = Doc.objects.filter(**{'hello world': 1}).count() one = Doc.objects.filter(**{'hello world': 1}).count()
self.assertEqual(1, one) self.assertEqual(1, one)
@unittest.skip("not implemented")
def test_shard_key(self): def test_shard_key(self):
class LogEntry(Document): class LogEntry(Document):
machine = StringField() machine = StringField()
@ -2231,7 +2248,6 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(OperationError, change_shard_key) self.assertRaises(OperationError, change_shard_key)
@unittest.skip("not implemented")
def test_shard_key_primary(self): def test_shard_key_primary(self):
class LogEntry(Document): class LogEntry(Document):
machine = StringField(primary_key=True) machine = StringField(primary_key=True)
@ -2255,7 +2271,6 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(OperationError, change_shard_key) self.assertRaises(OperationError, change_shard_key)
@unittest.skip("not implemented")
def test_kwargs_simple(self): def test_kwargs_simple(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
@ -2270,9 +2285,8 @@ class InstanceTest(unittest.TestCase):
"doc": {"name": "embedded doc"}}) "doc": {"name": "embedded doc"}})
self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc, dict_doc)
self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) self.assertEqual(classic_doc._data, dict_doc._data)
@unittest.skip("not implemented")
def test_kwargs_complex(self): def test_kwargs_complex(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
@ -2290,9 +2304,8 @@ class InstanceTest(unittest.TestCase):
{"name": "embedded doc2"}]}) {"name": "embedded doc2"}]})
self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc, dict_doc)
self.assertEqual(classic_doc.to_dict(), dict_doc.to_dict()) self.assertEqual(classic_doc._data, dict_doc._data)
@unittest.skip("not implemented")
def test_positional_creation(self): def test_positional_creation(self):
"""Ensure that document may be created using positional arguments. """Ensure that document may be created using positional arguments.
""" """
@ -2300,7 +2313,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
@unittest.skip("not implemented")
def test_mixed_creation(self): def test_mixed_creation(self):
"""Ensure that document may be created using mixed arguments. """Ensure that document may be created using mixed arguments.
""" """
@ -2308,6 +2320,16 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_mixed_creation_dynamic(self):
"""Ensure that document may be created using mixed arguments.
"""
class Person(DynamicDocument):
name = StringField()
person = Person("Test User", age=42)
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42)
def test_bad_mixed_creation(self): def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments """Ensure that document gives correct error when duplicating arguments
""" """
@ -2326,8 +2348,8 @@ class InstanceTest(unittest.TestCase):
Person(name="Harry Potter").save() Person(name="Harry Potter").save()
person = Person.objects.first() person = Person.objects.first()
self.assertTrue('id' in person.to_dict().keys()) self.assertTrue('id' in person._data.keys())
self.assertEqual(person.to_dict().get('id'), person.id) self.assertEqual(person._data.get('id'), person.id)
def test_complex_nesting_document_and_embedded_document(self): def test_complex_nesting_document_and_embedded_document(self):

View file

@ -31,6 +31,10 @@ class TestJson(unittest.TestCase):
doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) doc = Doc(string="Hi", embedded_field=Embedded(string="Hi"))
doc_json = doc.to_json(sort_keys=True, separators=(',', ':'))
expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}"""
self.assertEqual(doc_json, expected_json)
self.assertEqual(doc, Doc.from_json(doc.to_json())) self.assertEqual(doc, Doc.from_json(doc.to_json()))
def test_json_complex(self): def test_json_complex(self):
@ -58,6 +62,7 @@ class TestJson(unittest.TestCase):
reference_field = ReferenceField(Simple, default=lambda: reference_field = ReferenceField(Simple, default=lambda:
Simple().save()) Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1}) map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now) complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org") url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1) dynamic_field = DynamicField(default=1)

View file

@ -53,12 +53,11 @@ class ValidatorErrorTest(unittest.TestCase):
self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])")
def test_model_validation(self): def test_model_validation(self):
class User(Document): class User(Document):
username = StringField(primary_key=True) username = StringField(primary_key=True)
name = StringField(required=True) name = StringField(required=True)
User.drop_collection()
try: try:
User().validate() User().validate()
except ValidationError, e: except ValidationError, e:
@ -129,13 +128,18 @@ class ValidatorErrorTest(unittest.TestCase):
Doc(id="test", e=SubDoc(val=15)).save() Doc(id="test", e=SubDoc(val=15)).save()
doc = Doc.objects.first() doc = Doc.objects.first()
keys = doc.to_dict().keys() keys = doc._data.keys()
self.assertEqual(2, len(keys)) self.assertEqual(2, len(keys))
self.assertTrue('e' in keys) self.assertTrue('e' in keys)
self.assertTrue('id' in keys) self.assertTrue('id' in keys)
with self.assertRaises(ValueError): doc.e.val = "OK"
doc.e.val = "OK" try:
doc.save()
except ValidationError, e:
self.assertTrue("Doc:test" in e.message)
self.assertEqual(e.to_dict(), {
"e": {'val': 'OK could not be converted to int'}})
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -56,11 +56,10 @@ class FieldTest(unittest.TestCase):
self.assertEqual(person.userid, person.userid) self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created) self.assertEqual(person.created, person.created)
data = person.to_dict() self.assertEqual(person._data['name'], person.name)
self.assertEqual(data['name'], person.name) self.assertEqual(person._data['age'], person.age)
self.assertEqual(data['age'], person.age) self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(data['userid'], person.userid) self.assertEqual(person._data['created'], person.created)
self.assertEqual(data['created'], person.created)
# Confirm introspection changes nothing # Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys()) data_to_be_saved = sorted(person.to_mongo().keys())
@ -89,11 +88,10 @@ class FieldTest(unittest.TestCase):
self.assertEqual(person.userid, person.userid) self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created) self.assertEqual(person.created, person.created)
data = person.to_dict() self.assertEqual(person._data['name'], person.name)
self.assertEqual(data['name'], person.name) self.assertEqual(person._data['age'], person.age)
self.assertEqual(data['age'], person.age) self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(data['userid'], person.userid) self.assertEqual(person._data['created'], person.created)
self.assertEqual(data['created'], person.created)
# Confirm introspection changes nothing # Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys()) data_to_be_saved = sorted(person.to_mongo().keys())
@ -125,12 +123,10 @@ class FieldTest(unittest.TestCase):
self.assertEqual(person.userid, person.userid) self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created) self.assertEqual(person.created, person.created)
data = person.to_dict() self.assertEqual(person._data['name'], person.name)
self.assertEqual(person._data['age'], person.age)
self.assertEqual(data['name'], person.name) self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(data['age'], person.age) self.assertEqual(person._data['created'], person.created)
self.assertEqual(data['userid'], person.userid)
self.assertEqual(data['created'], person.created)
# Confirm introspection changes nothing # Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys()) data_to_be_saved = sorted(person.to_mongo().keys())
@ -161,12 +157,10 @@ class FieldTest(unittest.TestCase):
self.assertEqual(person.userid, person.userid) self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created) self.assertEqual(person.created, person.created)
data = person.to_dict() self.assertEqual(person._data['name'], person.name)
self.assertEqual(person._data['age'], person.age)
self.assertEqual(data['name'], person.name) self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(data['age'], person.age) self.assertEqual(person._data['created'], person.created)
self.assertEqual(data['userid'], person.userid)
self.assertEqual(data['created'], person.created)
# Confirm introspection changes nothing # Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys()) data_to_be_saved = sorted(person.to_mongo().keys())
@ -272,6 +266,17 @@ class FieldTest(unittest.TestCase):
self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count())
self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count())
def test_long_ne_operator(self):
class TestDocument(Document):
long_fld = LongField()
TestDocument.drop_collection()
TestDocument(long_fld=None).save()
TestDocument(long_fld=1).save()
self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count())
def test_object_id_validation(self): def test_object_id_validation(self):
"""Ensure that invalid values cannot be assigned to string fields. """Ensure that invalid values cannot be assigned to string fields.
""" """
@ -342,8 +347,25 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
person.age = 120 person.age = 120
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
with self.assertRaises(ValueError): person.age = 'ten'
person.age = 'ten' self.assertRaises(ValidationError, person.validate)
def test_long_validation(self):
"""Ensure that invalid values cannot be assigned to long fields.
"""
class TestDocument(Document):
value = LongField(min_value=0, max_value=110)
doc = TestDocument()
doc.value = 50
doc.validate()
doc.value = -1
self.assertRaises(ValidationError, doc.validate)
doc.age = 120
self.assertRaises(ValidationError, doc.validate)
doc.age = 'ten'
self.assertRaises(ValidationError, doc.validate)
def test_float_validation(self): def test_float_validation(self):
"""Ensure that invalid values cannot be assigned to float fields. """Ensure that invalid values cannot be assigned to float fields.
@ -362,6 +384,69 @@ class FieldTest(unittest.TestCase):
person.height = 4.0 person.height = 4.0
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
def test_decimal_validation(self):
"""Ensure that invalid values cannot be assigned to decimal fields.
"""
class Person(Document):
height = DecimalField(min_value=Decimal('0.1'),
max_value=Decimal('3.5'))
Person.drop_collection()
Person(height=Decimal('1.89')).save()
person = Person.objects.first()
self.assertEqual(person.height, Decimal('1.89'))
person.height = '2.0'
person.save()
person.height = 0.01
self.assertRaises(ValidationError, person.validate)
person.height = Decimal('0.01')
self.assertRaises(ValidationError, person.validate)
person.height = Decimal('4.0')
self.assertRaises(ValidationError, person.validate)
Person.drop_collection()
def test_decimal_comparison(self):
class Person(Document):
money = DecimalField()
Person.drop_collection()
Person(money=6).save()
Person(money=8).save()
Person(money=10).save()
self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count())
self.assertEqual(2, Person.objects(money__gt=7).count())
self.assertEqual(2, Person.objects(money__gt="7").count())
def test_decimal_storage(self):
class Person(Document):
btc = DecimalField(precision=4)
Person.drop_collection()
Person(btc=10).save()
Person(btc=10.1).save()
Person(btc=10.11).save()
Person(btc="10.111").save()
Person(btc=Decimal("10.1111")).save()
Person(btc=Decimal("10.11111")).save()
# How its stored
expected = [{'btc': 10.0}, {'btc': 10.1}, {'btc': 10.11},
{'btc': 10.111}, {'btc': 10.1111}, {'btc': 10.1111}]
actual = list(Person.objects.exclude('id').as_pymongo())
self.assertEqual(expected, actual)
# How it comes out locally
expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'),
Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')]
actual = list(Person.objects().scalar('btc'))
self.assertEqual(expected, actual)
def test_boolean_validation(self): def test_boolean_validation(self):
"""Ensure that invalid values cannot be assigned to boolean fields. """Ensure that invalid values cannot be assigned to boolean fields.
""" """
@ -447,10 +532,10 @@ class FieldTest(unittest.TestCase):
log.time = datetime.datetime.now().isoformat('T') log.time = datetime.datetime.now().isoformat('T')
log.validate() log.validate()
#log.time = -1 log.time = -1
#self.assertRaises(ValidationError, log.validate) self.assertRaises(ValidationError, log.validate)
#log.time = 'ABC' log.time = 'ABC'
#self.assertRaises(ValidationError, log.validate) self.assertRaises(ValidationError, log.validate)
def test_datetime_tz_aware_mark_as_changed(self): def test_datetime_tz_aware_mark_as_changed(self):
from mongoengine import connection from mongoengine import connection
@ -471,7 +556,7 @@ class FieldTest(unittest.TestCase):
log = LogEntry.objects.first() log = LogEntry.objects.first()
log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) log.time = datetime.datetime(2013, 1, 1, 0, 0, 0)
self.assertEqual(set(['time']), log._changed_fields) self.assertEqual(['time'], log._changed_fields)
def test_datetime(self): def test_datetime(self):
"""Tests showing pymongo datetime fields handling of microseconds. """Tests showing pymongo datetime fields handling of microseconds.
@ -706,8 +791,8 @@ class FieldTest(unittest.TestCase):
post = BlogPost(content='Went for a walk today...') post = BlogPost(content='Went for a walk today...')
post.validate() post.validate()
#post.tags = 'fun' post.tags = 'fun'
#self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
post.tags = [1, 2] post.tags = [1, 2]
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
@ -818,11 +903,11 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
post = BlogPost() post = BlogPost()
#post.info = 'my post' post.info = 'my post'
#self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
#post.info = {'title': 'test'} post.info = {'title': 'test'}
#self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
post.info = ['test'] post.info = ['test']
post.save() post.save()
@ -871,12 +956,14 @@ class FieldTest(unittest.TestCase):
e.mapping = [1] e.mapping = [1]
e.save() e.save()
with self.assertRaises(ValueError): def create_invalid_mapping():
e.mapping = ["abc"] e.mapping = ["abc"]
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection() Simple.drop_collection()
@unittest.skip("different behavior")
def test_list_field_rejects_strings(self): def test_list_field_rejects_strings(self):
"""Strings aren't valid list field data types""" """Strings aren't valid list field data types"""
@ -921,7 +1008,7 @@ class FieldTest(unittest.TestCase):
Simple.drop_collection() Simple.drop_collection()
e = Simple().save() e = Simple().save()
e.mapping = [] e.mapping = []
self.assertEqual(set([]), e._changed_fields) self.assertEqual([], e._changed_fields)
class Simple(Document): class Simple(Document):
mapping = DictField() mapping = DictField()
@ -929,9 +1016,34 @@ class FieldTest(unittest.TestCase):
Simple.drop_collection() Simple.drop_collection()
e = Simple().save() e = Simple().save()
e.mapping = {} e.mapping = {}
self.assertEqual(set([]), e._changed_fields) self.assertEqual([], e._changed_fields)
def test_slice_marks_field_as_changed(self):
class Simple(Document):
widgets = ListField()
simple = Simple(widgets=[1, 2, 3, 4]).save()
simple.widgets[:3] = []
self.assertEqual(['widgets'], simple._changed_fields)
simple.save()
simple = simple.reload()
self.assertEqual(simple.widgets, [4])
def test_del_slice_marks_field_as_changed(self):
class Simple(Document):
widgets = ListField()
simple = Simple(widgets=[1, 2, 3, 4]).save()
del simple.widgets[:3]
self.assertEqual(['widgets'], simple._changed_fields)
simple.save()
simple = simple.reload()
self.assertEqual(simple.widgets, [4])
@unittest.skip("complex types not implemented")
def test_list_field_complex(self): def test_list_field_complex(self):
"""Ensure that the list fields can handle the complex types.""" """Ensure that the list fields can handle the complex types."""
@ -988,11 +1100,11 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
post = BlogPost() post = BlogPost()
#post.info = 'my post' post.info = 'my post'
#self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
#post.info = ['test', 'test'] post.info = ['test', 'test']
#self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
post.info = {'$title': 'test'} post.info = {'$title': 'test'}
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
@ -1050,7 +1162,6 @@ class FieldTest(unittest.TestCase):
Simple.drop_collection() Simple.drop_collection()
@unittest.skip("complex types not implemented")
def test_dictfield_complex(self): def test_dictfield_complex(self):
"""Ensure that the dict field can handle the complex types.""" """Ensure that the dict field can handle the complex types."""
@ -1868,7 +1979,6 @@ class FieldTest(unittest.TestCase):
Shirt.drop_collection() Shirt.drop_collection()
@unittest.skip("not implemented")
def test_choices_get_field_display(self): def test_choices_get_field_display(self):
"""Test dynamic helper for returning the display value of a choices """Test dynamic helper for returning the display value of a choices
field. field.
@ -1921,7 +2031,6 @@ class FieldTest(unittest.TestCase):
Shirt.drop_collection() Shirt.drop_collection()
@unittest.skip("not implemented")
def test_simple_choices_get_field_display(self): def test_simple_choices_get_field_display(self):
"""Test dynamic helper for returning the display value of a choices """Test dynamic helper for returning the display value of a choices
field. field.
@ -2001,7 +2110,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(d2.data, {}) self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {}) self.assertEqual(d2.data2, {})
@unittest.skip("SequenceField not implemented")
def test_sequence_field(self): def test_sequence_field(self):
class Person(Document): class Person(Document):
id = SequenceField(primary_key=True) id = SequenceField(primary_key=True)
@ -2027,7 +2135,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(c['next'], 1000) self.assertEqual(c['next'], 1000)
@unittest.skip("SequenceField not implemented")
def test_sequence_field_get_next_value(self): def test_sequence_field_get_next_value(self):
class Person(Document): class Person(Document):
id = SequenceField(primary_key=True) id = SequenceField(primary_key=True)
@ -2059,7 +2166,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(Person.id.get_next_value(), '1') self.assertEqual(Person.id.get_next_value(), '1')
@unittest.skip("SequenceField not implemented")
def test_sequence_field_sequence_name(self): def test_sequence_field_sequence_name(self):
class Person(Document): class Person(Document):
id = SequenceField(primary_key=True, sequence_name='jelly') id = SequenceField(primary_key=True, sequence_name='jelly')
@ -2084,7 +2190,6 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'})
self.assertEqual(c['next'], 1000) self.assertEqual(c['next'], 1000)
@unittest.skip("SequenceField not implemented")
def test_multiple_sequence_fields(self): def test_multiple_sequence_fields(self):
class Person(Document): class Person(Document):
id = SequenceField(primary_key=True) id = SequenceField(primary_key=True)
@ -2117,7 +2222,6 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'})
self.assertEqual(c['next'], 999) self.assertEqual(c['next'], 999)
@unittest.skip("SequenceField not implemented")
def test_sequence_fields_reload(self): def test_sequence_fields_reload(self):
class Animal(Document): class Animal(Document):
counter = SequenceField() counter = SequenceField()
@ -2143,7 +2247,6 @@ class FieldTest(unittest.TestCase):
a.reload() a.reload()
self.assertEqual(a.counter, 2) self.assertEqual(a.counter, 2)
@unittest.skip("SequenceField not implemented")
def test_multiple_sequence_fields_on_docs(self): def test_multiple_sequence_fields_on_docs(self):
class Animal(Document): class Animal(Document):
@ -2178,7 +2281,6 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'})
self.assertEqual(c['next'], 10) self.assertEqual(c['next'], 10)
@unittest.skip("SequenceField not implemented")
def test_sequence_field_value_decorator(self): def test_sequence_field_value_decorator(self):
class Person(Document): class Person(Document):
id = SequenceField(primary_key=True, value_decorator=str) id = SequenceField(primary_key=True, value_decorator=str)
@ -2200,7 +2302,6 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
self.assertEqual(c['next'], 10) self.assertEqual(c['next'], 10)
@unittest.skip("SequenceField not implemented")
def test_embedded_sequence_field(self): def test_embedded_sequence_field(self):
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
id = SequenceField() id = SequenceField()
@ -2373,6 +2474,78 @@ class FieldTest(unittest.TestCase):
user = User(email='me@example.com') user = User(email='me@example.com')
self.assertTrue(user.validate() is None) self.assertTrue(user.validate() is None)
def test_tuples_as_tuples(self):
"""
Ensure that tuples remain tuples when they are
inside a ComplexBaseField
"""
from mongoengine.base import BaseField
class EnumField(BaseField):
def __init__(self, **kwargs):
super(EnumField, self).__init__(**kwargs)
def to_mongo(self, value):
return value
def to_python(self, value):
return tuple(value)
class TestDoc(Document):
items = ListField(EnumField())
TestDoc.drop_collection()
tuples = [(100, 'Testing')]
doc = TestDoc()
doc.items = tuples
doc.save()
x = TestDoc.objects().get()
self.assertTrue(x is not None)
self.assertTrue(len(x.items) == 1)
self.assertTrue(tuple(x.items[0]) in tuples)
self.assertTrue(x.items[0] in tuples)
def test_dynamic_fields_class(self):
class Doc2(Document):
field_1 = StringField(db_field='f')
class Doc(Document):
my_id = IntField(required=True, unique=True, primary_key=True)
embed_me = DynamicField(db_field='e')
field_x = StringField(db_field='x')
Doc.drop_collection()
Doc2.drop_collection()
doc2 = Doc2(field_1="hello")
doc = Doc(my_id=1, embed_me=doc2, field_x="x")
self.assertRaises(OperationError, doc.save)
doc2.save()
doc.save()
doc = Doc.objects.get()
self.assertEqual(doc.embed_me.field_1, "hello")
def test_dynamic_fields_embedded_class(self):
class Embed(EmbeddedDocument):
field_1 = StringField(db_field='f')
class Doc(Document):
my_id = IntField(required=True, unique=True, primary_key=True)
embed_me = DynamicField(db_field='e')
field_x = StringField(db_field='x')
Doc.drop_collection()
Doc(my_id=1, embed_me=Embed(field_1="hello"), field_x="x").save()
doc = Doc.objects.get()
self.assertEqual(doc.embed_me.field_1, "hello")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -24,7 +24,6 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png')
@unittest.skip("FileField not implemented")
class FileTest(unittest.TestCase): class FileTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -54,11 +53,12 @@ class FileTest(unittest.TestCase):
content_type = 'text/plain' content_type = 'text/plain'
putfile = PutFile() putfile = PutFile()
putfile.the_file.put(text, content_type=content_type) putfile.the_file.put(text, content_type=content_type, filename="hello")
putfile.save() putfile.save()
result = PutFile.objects.first() result = PutFile.objects.first()
self.assertTrue(putfile == result) self.assertTrue(putfile == result)
self.assertEqual("%s" % result.the_file, "<GridFSProxy: hello>")
self.assertEqual(result.the_file.read(), text) self.assertEqual(result.the_file.read(), text)
self.assertEqual(result.the_file.content_type, content_type) self.assertEqual(result.the_file.content_type, content_type)
result.the_file.delete() # Remove file from GridFS result.the_file.delete() # Remove file from GridFS
@ -456,5 +456,31 @@ class FileTest(unittest.TestCase):
self.assertEqual(1, TestImage.objects(Q(image1=grid_id) self.assertEqual(1, TestImage.objects(Q(image1=grid_id)
or Q(image2=grid_id)).count()) or Q(image2=grid_id)).count())
def test_complex_field_filefield(self):
"""Ensure you can add meta data to file"""
class Animal(Document):
genus = StringField()
family = StringField()
photos = ListField(FileField())
Animal.drop_collection()
marmot = Animal(genus='Marmota', family='Sciuridae')
marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk
photos_field = marmot._fields['photos'].field
new_proxy = photos_field.get_proxy_obj('photos', marmot)
new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar')
marmot_photo.close()
marmot.photos.append(new_proxy)
marmot.save()
marmot = Animal.objects.get()
self.assertEqual(marmot.photos[0].content_type, 'image/jpeg')
self.assertEqual(marmot.photos[0].foo, 'bar')
self.assertEqual(marmot.photos[0].get().length, 8313)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -10,7 +10,6 @@ from mongoengine.connection import get_db
__all__ = ("GeoFieldTest", ) __all__ = ("GeoFieldTest", )
@unittest.skip("geo fields not implemented")
class GeoFieldTest(unittest.TestCase): class GeoFieldTest(unittest.TestCase):
def setUp(self): def setUp(self):

View file

@ -17,6 +17,14 @@ class PickleTest(Document):
photo = FileField() photo = FileField()
class PickleDyanmicEmbedded(DynamicEmbeddedDocument):
date = DateTimeField(default=datetime.now)
class PickleDynamicTest(DynamicDocument):
number = IntField()
class PickleSignalsTest(Document): class PickleSignalsTest(Document):
number = IntField() number = IntField()
string = StringField(choices=(('One', '1'), ('Two', '2'))) string = StringField(choices=(('One', '1'), ('Two', '2')))

View file

@ -1,4 +1,5 @@
from convert_to_new_inheritance_model import * from convert_to_new_inheritance_model import *
from decimalfield_as_float import *
from refrencefield_dbref_to_object_id import * from refrencefield_dbref_to_object_id import *
from turn_off_inheritance import * from turn_off_inheritance import *
from uuidfield_to_binary import * from uuidfield_to_binary import *

View file

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
import unittest
import decimal
from decimal import Decimal
from mongoengine import Document, connect
from mongoengine.connection import get_db
from mongoengine.fields import StringField, DecimalField, ListField
__all__ = ('ConvertDecimalField', )
class ConvertDecimalField(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def test_how_to_convert_decimal_fields(self):
"""Demonstrates migrating from 0.7 to 0.8
"""
# 1. Old definition - using dbrefs
class Person(Document):
name = StringField()
money = DecimalField(force_string=True)
monies = ListField(DecimalField(force_string=True))
Person.drop_collection()
Person(name="Wilson Jr", money=Decimal("2.50"),
monies=[Decimal("2.10"), Decimal("5.00")]).save()
# 2. Start the migration by changing the schema
# Change DecimalField - add precision and rounding settings
class Person(Document):
name = StringField()
money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP)
monies = ListField(DecimalField(precision=2,
rounding=decimal.ROUND_HALF_UP))
# 3. Loop all the objects and mark parent as changed
for p in Person.objects:
p._mark_as_changed('money')
p._mark_as_changed('monies')
p.save()
# 4. Confirmation of the fix!
wilson = Person.objects(name="Wilson Jr").as_pymongo()[0]
self.assertTrue(isinstance(wilson['money'], float))
self.assertTrue(all([isinstance(m, float) for m in wilson['monies']]))

View file

@ -162,6 +162,10 @@ class OnlyExcludeAllTest(unittest.TestCase):
self.assertEqual(obj.name, person.name) self.assertEqual(obj.name, person.name)
self.assertEqual(obj.age, person.age) self.assertEqual(obj.age, person.age)
obj = self.Person.objects.only(*('id', 'name',)).get()
self.assertEqual(obj.name, person.name)
self.assertEqual(obj.age, None)
# Check polymorphism still works # Check polymorphism still works
class Employee(self.Person): class Employee(self.Person):
salary = IntField(db_field='wage') salary = IntField(db_field='wage')
@ -395,5 +399,28 @@ class OnlyExcludeAllTest(unittest.TestCase):
numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get()
self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1])
def test_exclude_from_subclasses_docs(self):
class Base(Document):
username = StringField()
meta = {'allow_inheritance': True}
class Anon(Base):
anon = BooleanField()
class User(Base):
password = StringField()
wibble = StringField()
Base.drop_collection()
User(username="mongodb", password="secret").save()
user = Base.objects().exclude("password", "wibble").first()
self.assertEqual(user.password, None)
self.assertRaises(LookUpError, Base.objects.exclude, "made_up")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -8,7 +8,6 @@ from mongoengine import *
__all__ = ("GeoQueriesTest",) __all__ = ("GeoQueriesTest",)
@unittest.skip("geo queries not implemented")
class GeoQueriesTest(unittest.TestCase): class GeoQueriesTest(unittest.TestCase):
def setUp(self): def setUp(self):

View file

@ -30,12 +30,17 @@ class QuerySetTest(unittest.TestCase):
def setUp(self): def setUp(self):
connect(db='mongoenginetest') connect(db='mongoenginetest')
class PersonMeta(EmbeddedDocument):
weight = IntField()
class Person(Document): class Person(Document):
name = StringField() name = StringField()
age = IntField() age = IntField()
person_meta = EmbeddedDocumentField(PersonMeta)
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
Person.drop_collection() Person.drop_collection()
self.PersonMeta = PersonMeta
self.Person = Person self.Person = Person
def test_initialisation(self): def test_initialisation(self):
@ -777,10 +782,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(q, 0) self.assertEqual(q, 0)
fresh_o1 = Organization.objects.get(id=o1.id) fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.employees.append(p2) fresh_o1.employees.append(p2) # Dereferences
fresh_o1.save(cascade=False) # Saves fresh_o1.save(cascade=False) # Saves
self.assertEqual(q, 2) self.assertEqual(q, 3)
def test_slave_okay(self): def test_slave_okay(self):
"""Ensures that a query can take slave_okay syntax """Ensures that a query can take slave_okay syntax
@ -1492,9 +1497,6 @@ class QuerySetTest(unittest.TestCase):
def test_pull_nested(self): def test_pull_nested(self):
class User(Document):
name = StringField()
class Collaborator(EmbeddedDocument): class Collaborator(EmbeddedDocument):
user = StringField() user = StringField()
@ -1509,8 +1511,7 @@ class QuerySetTest(unittest.TestCase):
Site.drop_collection() Site.drop_collection()
c = Collaborator(user='Esteban') c = Collaborator(user='Esteban')
s = Site(name="test", collaborators=[c]) s = Site(name="test", collaborators=[c]).save()
s.save()
Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
self.assertEqual(Site.objects.first().collaborators, []) self.assertEqual(Site.objects.first().collaborators, [])
@ -1520,6 +1521,71 @@ class QuerySetTest(unittest.TestCase):
self.assertRaises(InvalidQueryError, pull_all) self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_embedded(self):
class User(EmbeddedDocument):
name = StringField()
def __unicode__(self):
return '%s' % self.name
class Collaborator(EmbeddedDocument):
helpful = ListField(EmbeddedDocumentField(User))
unhelpful = ListField(EmbeddedDocumentField(User))
class Site(Document):
name = StringField(max_length=75, unique=True, required=True)
collaborators = EmbeddedDocumentField(Collaborator)
Site.drop_collection()
c = User(name='Esteban')
f = User(name='Frank')
s = Site(name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])).save()
Site.objects(id=s.id).update_one(pull__collaborators__helpful=c)
self.assertEqual(Site.objects.first().collaborators['helpful'], [])
Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'name': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__name=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_mapfield(self):
class Collaborator(EmbeddedDocument):
user = StringField()
def __unicode__(self):
return '%s' % self.user
class Site(Document):
name = StringField(max_length=75, unique=True, required=True)
collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator)))
Site.drop_collection()
c = Collaborator(user='Esteban')
f = Collaborator(user='Frank')
s = Site(name="test", collaborators={'helpful':[c],'unhelpful':[f]})
s.save()
Site.objects(id=s.id).update_one(pull__collaborators__helpful__user='Esteban')
self.assertEqual(Site.objects.first().collaborators['helpful'], [])
Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'user':'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_update_one_pop_generic_reference(self): def test_update_one_pop_generic_reference(self):
class BlogTag(Document): class BlogTag(Document):
@ -2208,6 +2274,19 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='ageless person').save() self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.average('age')), avg) self.assertEqual(int(self.Person.objects.average('age')), avg)
# dot notation
self.Person(name='person meta', person_meta=self.PersonMeta(weight=0)).save()
self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), 0)
for i, weight in enumerate(ages):
self.Person(name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save()
self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), avg)
self.Person(name='test meta none').save()
self.assertEqual(int(self.Person.objects.average('person_meta.weight')), avg)
def test_sum(self): def test_sum(self):
"""Ensure that field can be summed over correctly. """Ensure that field can be summed over correctly.
""" """
@ -2220,6 +2299,153 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='ageless person').save() self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
for i, age in enumerate(ages):
self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save()
self.assertEqual(int(self.Person.objects.sum('person_meta.weight')), sum(ages))
self.Person(name='weightless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
def test_embedded_average(self):
class Pay(EmbeddedDocument):
value = DecimalField()
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(value=150)).save()
Doc(name=u"Isabella Luanna",
pay=Pay(value=530)).save()
Doc(name=u"Tayza mariana",
pay=Pay(value=165)).save()
Doc(name=u"Eliana Costa",
pay=Pay(value=115)).save()
self.assertEqual(
Doc.objects.average('pay.value'),
240)
def test_embedded_array_average(self):
class Pay(EmbeddedDocument):
values = ListField(DecimalField())
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(values=[150, 100])).save()
Doc(name=u"Isabella Luanna",
pay=Pay(values=[530, 100])).save()
Doc(name=u"Tayza mariana",
pay=Pay(values=[165, 100])).save()
Doc(name=u"Eliana Costa",
pay=Pay(values=[115, 100])).save()
self.assertEqual(
Doc.objects.average('pay.values'),
170)
def test_array_average(self):
class Doc(Document):
values = ListField(DecimalField())
Doc.drop_collection()
Doc(values=[150, 100]).save()
Doc(values=[530, 100]).save()
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
self.assertEqual(
Doc.objects.average('values'),
170)
def test_embedded_sum(self):
class Pay(EmbeddedDocument):
value = DecimalField()
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(value=150)).save()
Doc(name=u"Isabella Luanna",
pay=Pay(value=530)).save()
Doc(name=u"Tayza mariana",
pay=Pay(value=165)).save()
Doc(name=u"Eliana Costa",
pay=Pay(value=115)).save()
self.assertEqual(
Doc.objects.sum('pay.value'),
960)
def test_embedded_array_sum(self):
class Pay(EmbeddedDocument):
values = ListField(DecimalField())
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(values=[150, 100])).save()
Doc(name=u"Isabella Luanna",
pay=Pay(values=[530, 100])).save()
Doc(name=u"Tayza mariana",
pay=Pay(values=[165, 100])).save()
Doc(name=u"Eliana Costa",
pay=Pay(values=[115, 100])).save()
self.assertEqual(
Doc.objects.sum('pay.values'),
1360)
def test_array_sum(self):
class Doc(Document):
values = ListField(DecimalField())
Doc.drop_collection()
Doc(values=[150, 100]).save()
Doc(values=[530, 100]).save()
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
self.assertEqual(
Doc.objects.sum('values'),
1360)
def test_distinct(self): def test_distinct(self):
"""Ensure that the QuerySet.distinct method works. """Ensure that the QuerySet.distinct method works.
""" """
@ -2922,6 +3148,19 @@ class QuerySetTest(unittest.TestCase):
(u'Wilson Jr', 19, u'Corumba-GO'), (u'Wilson Jr', 19, u'Corumba-GO'),
(u'Gabriel Falcao', 23, u'New York')]) (u'Gabriel Falcao', 23, u'New York')])
def test_scalar_decimal(self):
from decimal import Decimal
class Person(Document):
name = StringField()
rating = DecimalField()
Person.drop_collection()
Person(name="Wilson Jr", rating=Decimal('1.0')).save()
ulist = list(Person.objects.scalar('name', 'rating'))
self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))])
def test_scalar_reference_field(self): def test_scalar_reference_field(self):
class State(Document): class State(Document):
name = StringField() name = StringField()
@ -3121,6 +3360,13 @@ class QuerySetTest(unittest.TestCase):
Test.objects(test='foo').update_one(upsert=True, set__test='foo') Test.objects(test='foo').update_one(upsert=True, set__test='foo')
self.assertTrue('_cls' in Test._collection.find_one()) self.assertTrue('_cls' in Test._collection.find_one())
def test_update_upsert_looks_like_a_digit(self):
class MyDoc(DynamicDocument):
pass
MyDoc.drop_collection()
self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1))
self.assertEqual(MyDoc.objects.get()['47'], 1)
def test_read_preference(self): def test_read_preference(self):
class Bar(Document): class Bar(Document):
pass pass
@ -3149,7 +3395,7 @@ class QuerySetTest(unittest.TestCase):
Doc(string="Bye", embedded_field=Embedded(string="Bye")).save() Doc(string="Bye", embedded_field=Embedded(string="Bye")).save()
Doc().save() Doc().save()
json_data = Doc.objects.to_json() json_data = Doc.objects.to_json(sort_keys=True, separators=(',', ':'))
doc_objects = list(Doc.objects) doc_objects = list(Doc.objects)
self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) self.assertEqual(doc_objects, Doc.objects.from_json(json_data))
@ -3177,6 +3423,7 @@ class QuerySetTest(unittest.TestCase):
objectid_field = ObjectIdField(default=ObjectId) objectid_field = ObjectIdField(default=ObjectId)
reference_field = ReferenceField(Simple, default=lambda: Simple().save()) reference_field = ReferenceField(Simple, default=lambda: Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1}) map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now) complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org") url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1) dynamic_field = DynamicField(default=1)
@ -3207,25 +3454,33 @@ class QuerySetTest(unittest.TestCase):
id = ObjectIdField('_id') id = ObjectIdField('_id')
name = StringField() name = StringField()
age = IntField() age = IntField()
price = DecimalField()
User.drop_collection() User.drop_collection()
User(name="Bob Dole", age=89).save() User(name="Bob Dole", age=89, price=Decimal('1.11')).save()
User(name="Barack Obama", age=51).save() User(name="Barack Obama", age=51, price=Decimal('2.22')).save()
users = User.objects.only('name').as_pymongo() results = User.objects.only('id', 'name').as_pymongo()
self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name']))
users = User.objects.only('name', 'price').as_pymongo()
results = list(users) results = list(users)
self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[0], dict))
self.assertTrue(isinstance(results[1], dict)) self.assertTrue(isinstance(results[1], dict))
self.assertEqual(results[0]['name'], 'Bob Dole') self.assertEqual(results[0]['name'], 'Bob Dole')
self.assertEqual(results[0]['price'], 1.11)
self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1]['price'], 2.22)
# Test coerce_types # Test coerce_types
users = User.objects.only('name').as_pymongo(coerce_types=True) users = User.objects.only('name', 'price').as_pymongo(coerce_types=True)
results = list(users) results = list(users)
self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[0], dict))
self.assertTrue(isinstance(results[1], dict)) self.assertTrue(isinstance(results[1], dict))
self.assertEqual(results[0]['name'], 'Bob Dole') self.assertEqual(results[0]['name'], 'Bob Dole')
self.assertEqual(results[0]['price'], Decimal('1.11'))
self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1]['price'], Decimal('2.22'))
def test_as_pymongo_json_limit_fields(self): def test_as_pymongo_json_limit_fields(self):
@ -3249,7 +3504,6 @@ class QuerySetTest(unittest.TestCase):
serialized_user = User.objects.exclude('password_salt').only('email').to_json() serialized_user = User.objects.exclude('password_salt').only('email').to_json()
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
@unittest.skip("not implemented")
def test_no_dereference(self): def test_no_dereference(self):
class Organization(Document): class Organization(Document):
@ -3297,6 +3551,27 @@ class QuerySetTest(unittest.TestCase):
people.count() # count is cached people.count() # count is cached
self.assertEqual(q, 1) self.assertEqual(q, 1)
def test_no_cached_queryset(self):
class Person(Document):
name = StringField()
Person.drop_collection()
for i in xrange(100):
Person(name="No: %s" % i).save()
with query_counter() as q:
self.assertEqual(q, 0)
people = Person.objects.no_cache()
[x for x in people]
self.assertEqual(q, 1)
list(people)
self.assertEqual(q, 2)
people.count()
self.assertEqual(q, 3)
def test_cache_not_cloned(self): def test_cache_not_cloned(self):
class User(Document): class User(Document):
@ -3318,6 +3593,34 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual("%s" % users, "[<User: Bob>]") self.assertEqual("%s" % users, "[<User: Bob>]")
self.assertEqual(1, len(users._result_cache)) self.assertEqual(1, len(users._result_cache))
def test_no_cache(self):
"""Ensure you can add meta data to file"""
class Noddy(Document):
fields = DictField()
Noddy.drop_collection()
for i in xrange(100):
noddy = Noddy()
for j in range(20):
noddy.fields["key"+str(j)] = "value "+str(j)
noddy.save()
docs = Noddy.objects.no_cache()
counter = len([1 for i in docs])
self.assertEquals(counter, 100)
self.assertEquals(len(list(docs)), 100)
self.assertRaises(TypeError, lambda: len(docs))
with query_counter() as q:
self.assertEqual(q, 0)
list(docs)
self.assertEqual(q, 1)
list(docs)
self.assertEqual(q, 2)
def test_nested_queryset_iterator(self): def test_nested_queryset_iterator(self):
# Try iterating the same queryset twice, nested. # Try iterating the same queryset twice, nested.
names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George'] names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George']
@ -3449,6 +3752,23 @@ class QuerySetTest(unittest.TestCase):
'_cls': 'Animal.Cat' '_cls': 'Animal.Cat'
}) })
def test_can_have_field_same_name_as_query_operator(self):
class Size(Document):
name = StringField()
class Example(Document):
size = ReferenceField(Size)
Size.drop_collection()
Example.drop_collection()
instance_size = Size(name="Large").save()
Example(size=instance_size).save()
self.assertEqual(Example.objects(size=instance_size).count(), 1)
self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -31,6 +31,31 @@ class TransformTest(unittest.TestCase):
self.assertEqual(transform.query(name__exists=True), self.assertEqual(transform.query(name__exists=True),
{'name': {'$exists': True}}) {'name': {'$exists': True}})
def test_transform_update(self):
class DicDoc(Document):
dictField = DictField()
class Doc(Document):
pass
DicDoc.drop_collection()
Doc.drop_collection()
doc = Doc().save()
dic_doc = DicDoc().save()
for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")):
update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc})
self.assertTrue(isinstance(update[v]["dictField.test"], dict))
# Update special cases
update = transform.update(DicDoc, unset__dictField__test=doc)
self.assertEqual(update["$unset"]["dictField.test"], 1)
update = transform.update(DicDoc, pull__dictField__test=doc)
self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict))
def test_query_field_name(self): def test_query_field_name(self):
"""Ensure that the correct field name is used when querying. """Ensure that the correct field name is used when querying.
""" """
@ -63,7 +88,6 @@ class TransformTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
@unittest.skip("unsupported")
def test_query_pk_field_name(self): def test_query_pk_field_name(self):
"""Ensure that the correct "primary key" field name is used when """Ensure that the correct "primary key" field name is used when
querying querying

View file

@ -59,6 +59,32 @@ class ConnectionTest(unittest.TestCase):
c.admin.system.users.remove({}) c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({}) c.mongoenginetest.system.users.remove({})
def test_connect_uri_without_db(self):
"""Ensure that the connect() method works properly with uri's
without database_name
"""
c = connect(db='mongoenginetest', alias='admin')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
c.admin.add_user("admin", "password")
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("mongoenginetest", host='mongodb://localhost/')
conn = get_connection()
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
db = get_db()
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
def test_register_connection(self): def test_register_connection(self):
"""Ensure that connections with different aliases may be registered. """Ensure that connections with different aliases may be registered.
""" """

View file

@ -16,7 +16,6 @@ class FieldTest(unittest.TestCase):
connect(db='mongoenginetest') connect(db='mongoenginetest')
self.db = get_db() self.db = get_db()
@unittest.skip("select_related currently doesn't dereference lists")
def test_list_item_dereference(self): def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced. """Ensure that DBRef items in ListFields are dereferenced.
""" """
@ -75,7 +74,6 @@ class FieldTest(unittest.TestCase):
User.drop_collection() User.drop_collection()
Group.drop_collection() Group.drop_collection()
@unittest.skip("select_related currently doesn't dereference lists")
def test_list_item_dereference_dref_false(self): def test_list_item_dereference_dref_false(self):
"""Ensure that DBRef items in ListFields are dereferenced. """Ensure that DBRef items in ListFields are dereferenced.
""" """
@ -148,7 +146,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(Group._get_collection().find_one()['members'], [1]) self.assertEqual(Group._get_collection().find_one()['members'], [1])
self.assertEqual(group.members, [user]) self.assertEqual(group.members, [user])
@unittest.skip('currently not implemented')
def test_handle_old_style_references(self): def test_handle_old_style_references(self):
"""Ensure that DBRef items in ListFields are dereferenced. """Ensure that DBRef items in ListFields are dereferenced.
""" """
@ -182,7 +179,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(group.members[0].name, 'user 1') self.assertEqual(group.members[0].name, 'user 1')
self.assertEqual(group.members[-1].name, 'String!') self.assertEqual(group.members[-1].name, 'String!')
@unittest.skip('currently not implemented')
def test_migrate_references(self): def test_migrate_references(self):
"""Example of migrating ReferenceField storage """Example of migrating ReferenceField storage
""" """
@ -229,7 +225,6 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(raw_data['author'], ObjectId)) self.assertTrue(isinstance(raw_data['author'], ObjectId))
self.assertTrue(isinstance(raw_data['members'][0], ObjectId)) self.assertTrue(isinstance(raw_data['members'][0], ObjectId))
@unittest.skip("select_related currently doesn't dereference lists")
def test_recursive_reference(self): def test_recursive_reference(self):
"""Ensure that ReferenceFields can reference their own documents. """Ensure that ReferenceFields can reference their own documents.
""" """
@ -264,15 +259,9 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 1) self.assertEqual(q, 1)
peter.boss peter.boss
self.assertEqual(q, 1)
peter.friends
self.assertEqual(q, 1)
peter.boss.name
self.assertEqual(q, 2) self.assertEqual(q, 2)
peter.friends[0].name peter.friends
self.assertEqual(q, 3) self.assertEqual(q, 3)
# Document select_related # Document select_related
@ -302,32 +291,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(employee.friends, friends) self.assertEqual(employee.friends, friends)
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_list_of_lists_of_references(self):
class User(Document):
name = StringField()
class Post(Document):
user_lists = ListField(ListField(ReferenceField(User)))
class SimpleList(Document):
users = ListField(ReferenceField(User))
User.drop_collection()
Post.drop_collection()
SimpleList.drop_collection()
u1 = User.objects.create(name='u1')
u2 = User.objects.create(name='u2')
u3 = User.objects.create(name='u3')
SimpleList.objects.create(users=[u1, u2, u3])
self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3])
Post.objects.create(user_lists=[[u1, u2], [u3]])
self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]])
def test_circular_reference(self): def test_circular_reference(self):
"""Ensure you can handle circular references """Ensure you can handle circular references
""" """
@ -428,7 +391,6 @@ class FieldTest(unittest.TestCase):
"%s" % Person.objects() "%s" % Person.objects()
) )
@unittest.skip("not implemented")
def test_generic_reference(self): def test_generic_reference(self):
class UserA(Document): class UserA(Document):
@ -520,7 +482,6 @@ class FieldTest(unittest.TestCase):
UserC.drop_collection() UserC.drop_collection()
Group.drop_collection() Group.drop_collection()
@unittest.skip("not implemented")
def test_list_field_complex(self): def test_list_field_complex(self):
class UserA(Document): class UserA(Document):
@ -612,7 +573,6 @@ class FieldTest(unittest.TestCase):
UserC.drop_collection() UserC.drop_collection()
Group.drop_collection() Group.drop_collection()
@unittest.skip('MapField not fully implemented')
def test_map_field_reference(self): def test_map_field_reference(self):
class User(Document): class User(Document):
@ -678,7 +638,6 @@ class FieldTest(unittest.TestCase):
User.drop_collection() User.drop_collection()
Group.drop_collection() Group.drop_collection()
@unittest.skip("not implemented")
def test_dict_field(self): def test_dict_field(self):
class UserA(Document): class UserA(Document):
@ -782,7 +741,6 @@ class FieldTest(unittest.TestCase):
UserC.drop_collection() UserC.drop_collection()
Group.drop_collection() Group.drop_collection()
@unittest.skip("not implemented")
def test_dict_field_no_field_inheritance(self): def test_dict_field_no_field_inheritance(self):
class UserA(Document): class UserA(Document):
@ -859,7 +817,6 @@ class FieldTest(unittest.TestCase):
UserA.drop_collection() UserA.drop_collection()
Group.drop_collection() Group.drop_collection()
@unittest.skip("select_related currently doesn't dereference lists")
def test_generic_reference_map_field(self): def test_generic_reference_map_field(self):
class UserA(Document): class UserA(Document):
@ -985,7 +942,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(root.children, [company]) self.assertEqual(root.children, [company])
self.assertEqual(company.parents, [root]) self.assertEqual(company.parents, [root])
@unittest.skip("not implemented")
def test_dict_in_dbref_instance(self): def test_dict_in_dbref_instance(self):
class Person(Document): class Person(Document):
@ -1165,37 +1121,32 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_tuples_as_tuples(self): def test_objectid_reference_across_databases(self):
""" # mongoenginetest - Is default connection alias from setUp()
Ensure that tuples remain tuples when they are # Register Aliases
inside a ComplexBaseField register_connection('testdb-1', 'mongoenginetest2')
"""
from mongoengine.base import BaseField
class EnumField(BaseField): class User(Document):
name = StringField()
meta = {"db_alias": "testdb-1"}
def __init__(self, **kwargs): class Book(Document):
super(EnumField, self).__init__(**kwargs) name = StringField()
author = ReferenceField(User)
def to_mongo(self, value): # Drops
return value User.drop_collection()
Book.drop_collection()
def to_python(self, value): user = User(name="Ross").save()
return tuple(value) Book(name="MongoEngine for pros", author=user).save()
class TestDoc(Document): # Can't use query_counter across databases - so test the _data object
items = ListField(EnumField()) book = Book.objects.first()
self.assertFalse(isinstance(book._data['author'], User))
TestDoc.drop_collection() book.select_related()
tuples = [(100, 'Testing')] self.assertTrue(isinstance(book._data['author'], User))
doc = TestDoc()
doc.items = tuples
doc.save()
x = TestDoc.objects().get()
self.assertTrue(x is not None)
self.assertTrue(len(x.items) == 1)
self.assertTrue(tuple(x.items[0]) in tuples)
self.assertTrue(x.items[0] in tuples)
def test_non_ascii_pk(self): def test_non_ascii_pk(self):
""" """
@ -1220,6 +1171,30 @@ class FieldTest(unittest.TestCase):
self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands])) self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands]))
def test_dereferencing_embedded_listfield_referencefield(self):
class Tag(Document):
meta = {'collection': 'tags'}
name = StringField()
class Post(EmbeddedDocument):
body = StringField()
tags = ListField(ReferenceField("Tag", dbref=True))
class Page(Document):
meta = {'collection': 'pages'}
tags = ListField(ReferenceField("Tag", dbref=True))
posts = ListField(EmbeddedDocumentField(Post))
Tag.drop_collection()
Page.drop_collection()
tag = Tag(name='test').save()
post = Post(body='test body', tags=[tag])
Page(tags=[tag], posts=[post]).save()
page = Page.objects.first()
self.assertEqual(page.tags[0], page.posts[0].tags[0])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -2,48 +2,44 @@ import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from mongoengine.python_support import PY3
from mongoengine import * from mongoengine import *
from mongoengine.django.shortcuts import get_document_or_404
from django.http import Http404
from django.template import Context, Template
from django.conf import settings
from django.core.paginator import Paginator
settings.configure(
USE_TZ=True,
INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'),
AUTH_USER_MODEL=('mongo_auth.MongoUser'),
)
try: try:
from mongoengine.django.shortcuts import get_document_or_404 from django.contrib.auth import authenticate, get_user_model
from mongoengine.django.auth import User
from django.http import Http404 from mongoengine.django.mongo_auth.models import (
from django.template import Context, Template MongoUser,
from django.conf import settings MongoUserManager,
from django.core.paginator import Paginator get_user_document,
settings.configure(
USE_TZ=True,
INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'),
AUTH_USER_MODEL=('mongo_auth.MongoUser'),
) )
DJ15 = True
try: except Exception:
from django.contrib.auth import authenticate, get_user_model DJ15 = False
from mongoengine.django.auth import User from django.contrib.sessions.tests import SessionTestsMixin
from mongoengine.django.mongo_auth.models import MongoUser, MongoUserManager from mongoengine.django.sessions import SessionStore, MongoSession
DJ15 = True
except Exception:
DJ15 = False
from django.contrib.sessions.tests import SessionTestsMixin
from mongoengine.django.sessions import SessionStore, MongoSession
except Exception, err:
if PY3:
SessionTestsMixin = type # dummy value so no error
SessionStore = None # dummy value so no error
else:
raise err
from datetime import tzinfo, timedelta from datetime import tzinfo, timedelta
ZERO = timedelta(0) ZERO = timedelta(0)
class FixedOffset(tzinfo): class FixedOffset(tzinfo):
"""Fixed offset in minutes east from UTC.""" """Fixed offset in minutes east from UTC."""
def __init__(self, offset, name): def __init__(self, offset, name):
self.__offset = timedelta(minutes = offset) self.__offset = timedelta(minutes=offset)
self.__name = name self.__name = name
def utcoffset(self, dt): def utcoffset(self, dt):
@ -70,8 +66,6 @@ def activate_timezone(tz):
class QuerySetTest(unittest.TestCase): class QuerySetTest(unittest.TestCase):
def setUp(self): def setUp(self):
if PY3:
raise SkipTest('django does not have Python 3 support')
connect(db='mongoenginetest') connect(db='mongoenginetest')
class Person(Document): class Person(Document):
@ -173,6 +167,8 @@ class QuerySetTest(unittest.TestCase):
class Note(Document): class Note(Document):
text = StringField() text = StringField()
Note.drop_collection()
for i in xrange(1, 101): for i in xrange(1, 101):
Note(name="Note: %s" % i).save() Note(name="Note: %s" % i).save()
@ -223,8 +219,6 @@ class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase):
backend = SessionStore backend = SessionStore
def setUp(self): def setUp(self):
if PY3:
raise SkipTest('django does not have Python 3 support')
connect(db='mongoenginetest') connect(db='mongoenginetest')
MongoSession.drop_collection() MongoSession.drop_collection()
super(MongoDBSessionTest, self).setUp() super(MongoDBSessionTest, self).setUp()
@ -262,17 +256,18 @@ class MongoAuthTest(unittest.TestCase):
} }
def setUp(self): def setUp(self):
if PY3:
raise SkipTest('django does not have Python 3 support')
if not DJ15: if not DJ15:
raise SkipTest('mongo_auth requires Django 1.5') raise SkipTest('mongo_auth requires Django 1.5')
connect(db='mongoenginetest') connect(db='mongoenginetest')
User.drop_collection() User.drop_collection()
super(MongoAuthTest, self).setUp() super(MongoAuthTest, self).setUp()
def test_user_model(self): def test_get_user_model(self):
self.assertEqual(get_user_model(), MongoUser) self.assertEqual(get_user_model(), MongoUser)
def test_get_user_document(self):
self.assertEqual(get_user_document(), User)
def test_user_manager(self): def test_user_manager(self):
manager = get_user_model()._default_manager manager = get_user_model()._default_manager
self.assertTrue(isinstance(manager, MongoUserManager)) self.assertTrue(isinstance(manager, MongoUserManager))

View file

@ -30,10 +30,28 @@ class SignalTests(unittest.TestCase):
def __unicode__(self): def __unicode__(self):
return self.name return self.name
@classmethod
def pre_init(cls, sender, document, *args, **kwargs):
signal_output.append('pre_init signal, %s' % cls.__name__)
signal_output.append(str(kwargs['values']))
@classmethod
def post_init(cls, sender, document, **kwargs):
signal_output.append('post_init signal, %s' % document)
@classmethod @classmethod
def pre_save(cls, sender, document, **kwargs): def pre_save(cls, sender, document, **kwargs):
signal_output.append('pre_save signal, %s' % document) signal_output.append('pre_save signal, %s' % document)
@classmethod
def pre_save_post_validation(cls, sender, document, **kwargs):
signal_output.append('pre_save_post_validation signal, %s' % document)
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
else:
signal_output.append('Is updated')
@classmethod @classmethod
def post_save(cls, sender, document, **kwargs): def post_save(cls, sender, document, **kwargs):
signal_output.append('post_save signal, %s' % document) signal_output.append('post_save signal, %s' % document)
@ -100,7 +118,10 @@ class SignalTests(unittest.TestCase):
# Save up the number of connected signals so that we can check at the # Save up the number of connected signals so that we can check at the
# end that all the signals we register get properly unregistered # end that all the signals we register get properly unregistered
self.pre_signals = ( self.pre_signals = (
len(signals.pre_init.receivers),
len(signals.post_init.receivers),
len(signals.pre_save.receivers), len(signals.pre_save.receivers),
len(signals.pre_save_post_validation.receivers),
len(signals.post_save.receivers), len(signals.post_save.receivers),
len(signals.pre_delete.receivers), len(signals.pre_delete.receivers),
len(signals.post_delete.receivers), len(signals.post_delete.receivers),
@ -108,7 +129,10 @@ class SignalTests(unittest.TestCase):
len(signals.post_bulk_insert.receivers), len(signals.post_bulk_insert.receivers),
) )
signals.pre_init.connect(Author.pre_init, sender=Author)
signals.post_init.connect(Author.post_init, sender=Author)
signals.pre_save.connect(Author.pre_save, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author)
signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author)
signals.post_save.connect(Author.post_save, sender=Author) signals.post_save.connect(Author.post_save, sender=Author)
signals.pre_delete.connect(Author.pre_delete, sender=Author) signals.pre_delete.connect(Author.pre_delete, sender=Author)
signals.post_delete.connect(Author.post_delete, sender=Author) signals.post_delete.connect(Author.post_delete, sender=Author)
@ -121,9 +145,12 @@ class SignalTests(unittest.TestCase):
signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId)
def tearDown(self): def tearDown(self):
signals.pre_init.disconnect(self.Author.pre_init)
signals.post_init.disconnect(self.Author.post_init)
signals.post_delete.disconnect(self.Author.post_delete) signals.post_delete.disconnect(self.Author.post_delete)
signals.pre_delete.disconnect(self.Author.pre_delete) signals.pre_delete.disconnect(self.Author.pre_delete)
signals.post_save.disconnect(self.Author.post_save) signals.post_save.disconnect(self.Author.post_save)
signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation)
signals.pre_save.disconnect(self.Author.pre_save) signals.pre_save.disconnect(self.Author.pre_save)
signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert)
signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert)
@ -135,7 +162,10 @@ class SignalTests(unittest.TestCase):
# Check that all our signals got disconnected properly. # Check that all our signals got disconnected properly.
post_signals = ( post_signals = (
len(signals.pre_init.receivers),
len(signals.post_init.receivers),
len(signals.pre_save.receivers), len(signals.pre_save.receivers),
len(signals.pre_save_post_validation.receivers),
len(signals.post_save.receivers), len(signals.post_save.receivers),
len(signals.pre_delete.receivers), len(signals.pre_delete.receivers),
len(signals.post_delete.receivers), len(signals.post_delete.receivers),
@ -150,6 +180,9 @@ class SignalTests(unittest.TestCase):
def test_model_signals(self): def test_model_signals(self):
""" Model saves should throw some signals. """ """ Model saves should throw some signals. """
def create_author():
self.Author(name='Bill Shakespeare')
def bulk_create_author_with_load(): def bulk_create_author_with_load():
a1 = self.Author(name='Bill Shakespeare') a1 = self.Author(name='Bill Shakespeare')
self.Author.objects.insert([a1], load_bulk=True) self.Author.objects.insert([a1], load_bulk=True)
@ -158,9 +191,17 @@ class SignalTests(unittest.TestCase):
a1 = self.Author(name='Bill Shakespeare') a1 = self.Author(name='Bill Shakespeare')
self.Author.objects.insert([a1], load_bulk=False) self.Author.objects.insert([a1], load_bulk=False)
self.assertEqual(self.get_signal_output(create_author), [
"pre_init signal, Author",
"{'name': 'Bill Shakespeare'}",
"post_init signal, Bill Shakespeare",
])
a1 = self.Author(name='Bill Shakespeare') a1 = self.Author(name='Bill Shakespeare')
self.assertEqual(self.get_signal_output(a1.save), [ self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, Bill Shakespeare", "pre_save signal, Bill Shakespeare",
"pre_save_post_validation signal, Bill Shakespeare",
"Is created",
"post_save signal, Bill Shakespeare", "post_save signal, Bill Shakespeare",
"Is created" "Is created"
]) ])
@ -169,6 +210,8 @@ class SignalTests(unittest.TestCase):
a1.name = 'William Shakespeare' a1.name = 'William Shakespeare'
self.assertEqual(self.get_signal_output(a1.save), [ self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, William Shakespeare", "pre_save signal, William Shakespeare",
"pre_save_post_validation signal, William Shakespeare",
"Is updated",
"post_save signal, William Shakespeare", "post_save signal, William Shakespeare",
"Is updated" "Is updated"
]) ])
@ -180,13 +223,18 @@ class SignalTests(unittest.TestCase):
signal_output = self.get_signal_output(bulk_create_author_with_load) signal_output = self.get_signal_output(bulk_create_author_with_load)
self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ # The output of this signal is not entirely deterministic. The reloaded
"pre_bulk_insert signal, [<Author: Bill Shakespeare>]", # object will have an object ID. Hence, we only check part of the output
"post_bulk_insert signal, [<Author: Bill Shakespeare>]", self.assertEqual(signal_output[3],
"Is loaded", "pre_bulk_insert signal, [<Author: Bill Shakespeare>]")
]) self.assertEqual(signal_output[-2:],
["post_bulk_insert signal, [<Author: Bill Shakespeare>]",
"Is loaded",])
self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [
"pre_init signal, Author",
"{'name': 'Bill Shakespeare'}",
"post_init signal, Bill Shakespeare",
"pre_bulk_insert signal, [<Author: Bill Shakespeare>]", "pre_bulk_insert signal, [<Author: Bill Shakespeare>]",
"post_bulk_insert signal, [<Author: Bill Shakespeare>]", "post_bulk_insert signal, [<Author: Bill Shakespeare>]",
"Not loaded", "Not loaded",