diff --git a/.travis.yml b/.travis.yml index 1b9f5b7..4395107 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,14 +7,18 @@ python: - "3.2" - "3.3" env: - - PYMONGO=dev - - PYMONGO=2.5 - - PYMONGO=2.4.2 + - PYMONGO=dev DJANGO=1.5.1 + - PYMONGO=dev DJANGO=1.4.2 + - PYMONGO=2.5 DJANGO=1.5.1 + - PYMONGO=2.5 DJANGO=1.4.2 + - PYMONGO=3.2 DJANGO=1.5.1 + - PYMONGO=3.3 DJANGO=1.5.1 install: - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi + - pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b - python setup.py install script: - python setup.py test @@ -23,4 +27,3 @@ notifications: branches: only: - master - - "0.8" diff --git a/AUTHORS b/AUTHORS index e388a04..f043207 100644 --- a/AUTHORS +++ b/AUTHORS @@ -16,8 +16,6 @@ Dervived from the git logs, inevitably incomplete but all of whom and others have submitted patches, reported bugs and generally helped make MongoEngine that much better: - * Harry Marr - * Ross Lawley * blackbrrr * Florian Schlachter * Vincent Driessen @@ -25,7 +23,7 @@ that much better: * flosch * Deepak Thukral * Colin Howe - * Wilson Júnior + * Wilson Júnior (https://github.com/wpjunior) * Alistair Roche * Dan Crosta * Viktor Kerkez @@ -77,7 +75,7 @@ that much better: * Adam Parrish * jpfarias * jonrscott - * Alice Zoë Bevan-McGregor + * Alice Zoë Bevan-McGregor (https://github.com/amcgregor/) * Stephen Young * tkloc * aid @@ -157,4 +155,27 @@ that much better: * Kenneth Falck * Lukasz Balcerzak * Nicolas Cortot - + * Alex (https://github.com/kelsta) + * Jin Zhang + * Daniel Axtens + * Leo-Naeka + * Ryan Witt (https://github.com/ryanwitt) + * Jiequan (https://github.com/Jiequan) + * hensom (https://github.com/hensom) + * zhy0216 (https://github.com/zhy0216) + * istinspring (https://github.com/istinspring) + * Massimo Santini (https://github.com/mapio) + * Nigel McNie (https://github.com/nigelmcnie) + * ygbourhis (https://github.com/ygbourhis) + * Bob Dickinson (https://github.com/BobDickinson) + * Michael Bartnett (https://github.com/michaelbartnett) + * Alon Horev (https://github.com/alonho) + * Kelvin Hammond (https://github.com/kelvinhammond) + * Jatin- (https://github.com/jatin-) + * Paul Uithol (https://github.com/PaulUithol) + * Thom Knowles (https://github.com/fleat) + * Paul (https://github.com/squamous) + * Olivier Cortès (https://github.com/Karmak23) + * crazyzubr (https://github.com/crazyzubr) + * FrankSomething (https://github.com/FrankSomething) + * Alexandr Morozov (https://github.com/LK4D4) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 9688339..8754896 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -20,7 +20,7 @@ post to the `user group ` Supported Interpreters ---------------------- -PyMongo supports CPython 2.5 and newer. Language +MongoEngine supports CPython 2.6 and newer. Language features not supported by all interpreters can not be used. Please also ensure that your code is properly converted by `2to3 `_ for Python 3 support. 
@@ -46,7 +46,7 @@ General Guidelines - Write tests and make sure they pass (make sure you have a mongod running on the default port, then execute ``python setup.py test`` from the cmd line to run the test suite). -- Add yourself to AUTHORS.rst :) +- Add yourself to AUTHORS :) Documentation ------------- diff --git a/README.rst b/README.rst index 5eab502..ea4b505 100644 --- a/README.rst +++ b/README.rst @@ -26,7 +26,7 @@ setup.py install``. Dependencies ============ -- pymongo 2.1.1+ +- pymongo 2.5+ - sphinx (optional - for documentation generation) Examples diff --git a/benchmark.py b/benchmark.py index 0197e1d..16b2fd4 100644 --- a/benchmark.py +++ b/benchmark.py @@ -86,17 +86,43 @@ def main(): ---------------------------------------------------------------------------------------------------- Creating 10000 dictionaries - MongoEngine, force=True 8.36906409264 + 0.8.X + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - Pymongo + 3.69964408875 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - Pymongo write_concern={"w": 0} + 3.5526599884 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine + 7.00959801674 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries without continual assign - MongoEngine + 5.60943293571 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade=True + 6.715102911 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True + 5.50644683838 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False + 4.69851183891 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False + 4.68946313858 + ---------------------------------------------------------------------------------------------------- """ setup = """ -from pymongo import Connection -connection = Connection() +from pymongo import MongoClient +connection = MongoClient() connection.drop_database('timeit_test') """ stmt = """ -from pymongo import Connection -connection = Connection() +from pymongo import MongoClient +connection = MongoClient() db = connection.timeit_test noddy = db.noddy @@ -106,7 +132,7 @@ for i in xrange(10000): for j in range(20): example['fields']["key"+str(j)] = "value "+str(j) - noddy.insert(example) + noddy.save(example) myNoddys = noddy.find() [n for n in myNoddys] # iterate @@ -117,9 +143,32 @@ myNoddys = noddy.find() t = timeit.Timer(stmt=stmt, setup=setup) print t.timeit(1) + stmt = """ +from pymongo import MongoClient +connection = MongoClient() + +db = connection.timeit_test +noddy = db.noddy + +for i in xrange(10000): + example = {'fields': {}} + for j in range(20): + example['fields']["key"+str(j)] = "value "+str(j) + + noddy.save(example, write_concern={"w": 0}) + +myNoddys = noddy.find() +[n for n in myNoddys] # iterate 
+""" + + print "-" * 100 + print """Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + setup = """ -from pymongo import Connection -connection = Connection() +from pymongo import MongoClient +connection = MongoClient() connection.drop_database('timeit_test') connection.disconnect() @@ -149,33 +198,18 @@ myNoddys = Noddy.objects() stmt = """ for i in xrange(10000): noddy = Noddy() + fields = {} for j in range(20): - noddy.fields["key"+str(j)] = "value "+str(j) - noddy.save(safe=False, validate=False) + fields["key"+str(j)] = "value "+str(j) + noddy.fields = fields + noddy.save() myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, safe=False, validate=False""" - t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) - - - stmt = """ -for i in xrange(10000): - noddy = Noddy() - for j in range(20): - noddy.fields["key"+str(j)] = "value "+str(j) - noddy.save(safe=False, validate=False, cascade=False) - -myNoddys = Noddy.objects() -[n for n in myNoddys] # iterate -""" - - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False""" + print """Creating 10000 dictionaries without continual assign - MongoEngine""" t = timeit.Timer(stmt=stmt, setup=setup) print t.timeit(1) @@ -184,16 +218,65 @@ for i in xrange(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) - noddy.save(force_insert=True, safe=False, validate=False, cascade=False) + noddy.save(write_concern={"w": 0}, cascade=True) myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, force=True""" + print """Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""" t = timeit.Timer(stmt=stmt, setup=setup) print t.timeit(1) + stmt = """ +for i in xrange(10000): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save(write_concern={"w": 0}, validate=False, cascade=True) + +myNoddys = Noddy.objects() +[n for n in myNoddys] # iterate +""" + + print "-" * 100 + print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + + stmt = """ +for i in xrange(10000): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save(validate=False, write_concern={"w": 0}) + +myNoddys = Noddy.objects() +[n for n in myNoddys] # iterate +""" + + print "-" * 100 + print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + + stmt = """ +for i in xrange(10000): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save(force_insert=True, write_concern={"w": 0}, validate=False) + +myNoddys = Noddy.objects() +[n for n in myNoddys] # iterate +""" + + print "-" * 100 + print """Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + + if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/docs/_themes/nature/static/nature.css_t b/docs/_themes/nature/static/nature.css_t index 03b0379..337760b 100644 --- a/docs/_themes/nature/static/nature.css_t +++ 
b/docs/_themes/nature/static/nature.css_t @@ -2,11 +2,15 @@ * Sphinx stylesheet -- default theme * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ - + @import url("basic.css"); - + +#changelog p.first {margin-bottom: 0 !important;} +#changelog p {margin-top: 0 !important; + margin-bottom: 0 !important;} + /* -- page layout ----------------------------------------------------------- */ - + body { font-family: Arial, sans-serif; font-size: 100%; @@ -28,18 +32,18 @@ div.bodywrapper { hr{ border: 1px solid #B1B4B6; } - + div.document { background-color: #eee; } - + div.body { background-color: #ffffff; color: #3E4349; padding: 0 30px 30px 30px; font-size: 0.8em; } - + div.footer { color: #555; width: 100%; @@ -47,12 +51,12 @@ div.footer { text-align: center; font-size: 75%; } - + div.footer a { color: #444; text-decoration: underline; } - + div.related { background-color: #6BA81E; line-height: 32px; @@ -60,11 +64,11 @@ div.related { text-shadow: 0px 1px 0 #444; font-size: 0.80em; } - + div.related a { color: #E2F3CC; } - + div.sphinxsidebar { font-size: 0.75em; line-height: 1.5em; @@ -73,7 +77,7 @@ div.sphinxsidebar { div.sphinxsidebarwrapper{ padding: 20px 0; } - + div.sphinxsidebar h3, div.sphinxsidebar h4 { font-family: Arial, sans-serif; @@ -89,30 +93,30 @@ div.sphinxsidebar h4 { div.sphinxsidebar h4{ font-size: 1.1em; } - + div.sphinxsidebar h3 a { color: #444; } - - + + div.sphinxsidebar p { color: #888; padding: 5px 20px; } - + div.sphinxsidebar p.topless { } - + div.sphinxsidebar ul { margin: 10px 20px; padding: 0; color: #000; } - + div.sphinxsidebar a { color: #444; } - + div.sphinxsidebar input { border: 1px solid #ccc; font-family: sans-serif; @@ -122,19 +126,19 @@ div.sphinxsidebar input { div.sphinxsidebar input[type=text]{ margin-left: 20px; } - + /* -- body styles ----------------------------------------------------------- */ - + a { color: #005B81; text-decoration: none; } - + a:hover { color: #E32E00; text-decoration: underline; } - + div.body h1, div.body h2, div.body h3, @@ -149,30 +153,30 @@ div.body h6 { padding: 5px 0 5px 10px; text-shadow: 0px 1px 0 white } - + div.body h1 { border-top: 20px solid white; margin-top: 0; font-size: 200%; } div.body h2 { font-size: 150%; background-color: #C8D5E3; } div.body h3 { font-size: 120%; background-color: #D8DEE3; } div.body h4 { font-size: 110%; background-color: #D8DEE3; } div.body h5 { font-size: 100%; background-color: #D8DEE3; } div.body h6 { font-size: 100%; background-color: #D8DEE3; } - + a.headerlink { color: #c60f0f; font-size: 0.8em; padding: 0 4px 0 4px; text-decoration: none; } - + a.headerlink:hover { background-color: #c60f0f; color: white; } - + div.body p, div.body dd, div.body li { line-height: 1.5em; } - + div.admonition p.admonition-title + p { display: inline; } @@ -185,29 +189,29 @@ div.note { background-color: #eee; border: 1px solid #ccc; } - + div.seealso { background-color: #ffc; border: 1px solid #ff6; } - + div.topic { background-color: #eee; } - + div.warning { background-color: #ffe4e4; border: 1px solid #f66; } - + p.admonition-title { display: inline; } - + p.admonition-title:after { content: ":"; } - + pre { padding: 10px; background-color: White; @@ -219,7 +223,7 @@ pre { -webkit-box-shadow: 1px 1px 1px #d8d8d8; -moz-box-shadow: 1px 1px 1px #d8d8d8; } - + tt { background-color: #ecf0f3; color: #222; diff --git a/docs/apireference.rst b/docs/apireference.rst index 049cc30..9057de5 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -44,38 +44,61 @@ Context Managers Querying ======== -.. 
autoclass:: mongoengine.queryset.QuerySet - :members: +.. automodule:: mongoengine.queryset + :synopsis: Queryset level operations - .. automethod:: mongoengine.queryset.QuerySet.__call__ + .. autoclass:: mongoengine.queryset.QuerySet + :members: + :inherited-members: -.. autofunction:: mongoengine.queryset.queryset_manager + .. automethod:: QuerySet.__call__ + + .. autoclass:: mongoengine.queryset.QuerySetNoCache + :members: + + .. automethod:: mongoengine.queryset.QuerySetNoCache.__call__ + + .. autofunction:: mongoengine.queryset.queryset_manager Fields ====== -.. autoclass:: mongoengine.BinaryField -.. autoclass:: mongoengine.BooleanField -.. autoclass:: mongoengine.ComplexDateTimeField -.. autoclass:: mongoengine.DateTimeField -.. autoclass:: mongoengine.DecimalField -.. autoclass:: mongoengine.DictField -.. autoclass:: mongoengine.DynamicField -.. autoclass:: mongoengine.EmailField -.. autoclass:: mongoengine.EmbeddedDocumentField -.. autoclass:: mongoengine.FileField -.. autoclass:: mongoengine.FloatField -.. autoclass:: mongoengine.GenericEmbeddedDocumentField -.. autoclass:: mongoengine.GenericReferenceField -.. autoclass:: mongoengine.GeoPointField -.. autoclass:: mongoengine.ImageField -.. autoclass:: mongoengine.IntField -.. autoclass:: mongoengine.ListField -.. autoclass:: mongoengine.MapField -.. autoclass:: mongoengine.ObjectIdField -.. autoclass:: mongoengine.ReferenceField -.. autoclass:: mongoengine.SequenceField -.. autoclass:: mongoengine.SortedListField -.. autoclass:: mongoengine.StringField -.. autoclass:: mongoengine.URLField -.. autoclass:: mongoengine.UUIDField +.. autoclass:: mongoengine.base.fields.BaseField +.. autoclass:: mongoengine.fields.StringField +.. autoclass:: mongoengine.fields.URLField +.. autoclass:: mongoengine.fields.EmailField +.. autoclass:: mongoengine.fields.IntField +.. autoclass:: mongoengine.fields.LongField +.. autoclass:: mongoengine.fields.FloatField +.. autoclass:: mongoengine.fields.DecimalField +.. autoclass:: mongoengine.fields.BooleanField +.. autoclass:: mongoengine.fields.DateTimeField +.. autoclass:: mongoengine.fields.ComplexDateTimeField +.. autoclass:: mongoengine.fields.EmbeddedDocumentField +.. autoclass:: mongoengine.fields.GenericEmbeddedDocumentField +.. autoclass:: mongoengine.fields.DynamicField +.. autoclass:: mongoengine.fields.ListField +.. autoclass:: mongoengine.fields.SortedListField +.. autoclass:: mongoengine.fields.DictField +.. autoclass:: mongoengine.fields.MapField +.. autoclass:: mongoengine.fields.ReferenceField +.. autoclass:: mongoengine.fields.GenericReferenceField +.. autoclass:: mongoengine.fields.BinaryField +.. autoclass:: mongoengine.fields.FileField +.. autoclass:: mongoengine.fields.ImageField +.. autoclass:: mongoengine.fields.SequenceField +.. autoclass:: mongoengine.fields.ObjectIdField +.. autoclass:: mongoengine.fields.UUIDField +.. autoclass:: mongoengine.fields.GeoPointField +.. autoclass:: mongoengine.fields.PointField +.. autoclass:: mongoengine.fields.LineStringField +.. autoclass:: mongoengine.fields.PolygonField +.. autoclass:: mongoengine.fields.GridFSError +.. autoclass:: mongoengine.fields.GridFSProxy +.. autoclass:: mongoengine.fields.ImageGridFsProxy +.. autoclass:: mongoengine.fields.ImproperlyConfigured + +Misc +==== + +.. 
autofunction:: mongoengine.common._import_class diff --git a/docs/changelog.rst b/docs/changelog.rst index 4547000..926fb8a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,8 +2,107 @@ Changelog ========= -Changes in 0.8.X +Changes in 0.8.4 ================ +- Remove database name necessity in uri connection schema (#452) +- Fixed "$pull" semantics for nested ListFields (#447) +- Allow fields to be named the same as query operators (#445) +- Updated field filter logic - can now exclude subclass fields (#443) +- Fixed dereference issue with embedded listfield referencefields (#439) +- Fixed slice when using inheritance causing fields to be excluded (#437) +- Fixed ._get_db() attribute after a Document.switch_db() (#441) +- Dynamic Fields store and recompose Embedded Documents / Documents correctly (#449) +- Handle dynamic fieldnames that look like digits (#434) +- Added get_user_document and improve mongo_auth module (#423) +- Added str representation of GridFSProxy (#424) +- Update transform to handle docs erroneously passed to unset (#416) +- Fixed indexing - turn off _cls (#414) +- Fixed dereference threading issue in ComplexField.__get__ (#412) +- Fixed QuerySetNoCache.count() caching (#410) +- Don't follow references in _get_changed_fields (#422, #417) +- Allow args and kwargs to be passed through to_json (#420) + +Changes in 0.8.3 +================ +- Fixed EmbeddedDocuments with `id` also storing `_id` (#402) +- Added get_proxy_object helper to filefields (#391) +- Added QuerySetNoCache and QuerySet.no_cache() for lower memory consumption (#365) +- Fixed sum and average mapreduce dot notation support (#375, #376, #393) +- Fixed as_pymongo to return the id (#386) +- Document.select_related() now respects `db_alias` (#377) +- Reload uses shard_key if applicable (#384) +- Dynamic fields are ordered based on creation and stored in _fields_ordered (#396) + + **Potential breaking change:** http://docs.mongoengine.org/en/latest/upgrade.html#to-0-8-3 + +- Fixed pickling dynamic documents `_dynamic_fields` (#387) +- Fixed ListField setslice and delslice dirty tracking (#390) +- Added Django 1.5 PY3 support (#392) +- Added match ($elemMatch) support for EmbeddedDocuments (#379) +- Fixed weakref being valid after reload (#374) +- Fixed queryset.get() respecting no_dereference (#373) +- Added full_result kwarg to update (#380) + + + +Changes in 0.8.2 +================ +- Added compare_indexes helper (#361) +- Fixed cascading saves which weren't turned off as planned (#291) +- Fixed Datastructures so instances are a Document or EmbeddedDocument (#363) +- Improved cascading saves write performance (#361) +- Fixed ambiguity and differing behaviour regarding field defaults (#349) +- ImageFields now include PIL error messages if invalid error (#353) +- Added lock when calling doc.Delete() for when signals have no sender (#350) +- Reload forces read preference to be PRIMARY (#355) +- Querysets are now lest restrictive when querying duplicate fields (#332, #333) +- FileField now honouring db_alias (#341) +- Removed customised __set__ change tracking in ComplexBaseField (#344) +- Removed unused var in _get_changed_fields (#347) +- Added pre_save_post_validation signal (#345) +- DateTimeField now auto converts valid datetime isostrings into dates (#343) +- DateTimeField now uses dateutil for parsing if available (#343) +- Fixed Doc.objects(read_preference=X) not setting read preference (#352) +- Django session ttl index expiry fixed (#329) +- Fixed pickle.loads (#342) +- Documentation fixes 
+ +Changes in 0.8.1 +================ +- Fixed Python 2.6 django auth importlib issue (#326) +- Fixed pickle unsaved document regression (#327) + +Changes in 0.8.0 +================ +- Fixed querying ReferenceField custom_id (#317) +- Fixed pickle issues with collections (#316) +- Added `get_next_value` preview for SequenceFields (#319) +- Added no_sub_classes context manager and queryset helper (#312) +- Querysets now utilises a local cache +- Changed __len__ behavour in the queryset (#247, #311) +- Fixed querying string versions of ObjectIds issue with ReferenceField (#307) +- Added $setOnInsert support for upserts (#308) +- Upserts now possible with just query parameters (#309) +- Upserting is the only way to ensure docs are saved correctly (#306) +- Fixed register_delete_rule inheritance issue +- Fix cloning of sliced querysets (#303) +- Fixed update_one write concern (#302) +- Updated minimum requirement for pymongo to 2.5 +- Add support for new geojson fields, indexes and queries (#299) +- If values cant be compared mark as changed (#287) +- Ensure as_pymongo() and to_json honour only() and exclude() (#293) +- Document serialization uses field order to ensure a strict order is set (#296) +- DecimalField now stores as float not string (#289) +- UUIDField now stores as a binary by default (#292) +- Added Custom User Model for Django 1.5 (#285) +- Cascading saves now default to off (#291) +- ReferenceField now store ObjectId's by default rather than DBRef (#290) +- Added ImageField support for inline replacements (#86) +- Added SequenceField.set_next_value(value) helper (#159) +- Updated .only() behaviour - now like exclude it is chainable (#202) +- Added with_limit_and_skip support to count() (#235) +- Objects queryset manager now inherited (#256) +- Updated connection to use MongoClient (#262, #274) - Fixed db_alias and inherited Documents (#143) - Documentation update for document errors (#124) - Deprecated `get_or_create` (#35) diff --git a/docs/code/tumblelog.py b/docs/code/tumblelog.py index 6ba1eee..0e40e89 100644 --- a/docs/code/tumblelog.py +++ b/docs/code/tumblelog.py @@ -45,7 +45,7 @@ print 'ALL POSTS' print for post in Post.objects: print post.title - print '=' * len(post.title) + print '=' * post.title.count() if isinstance(post, TextPost): print post.content diff --git a/docs/conf.py b/docs/conf.py index 3cfcef5..40c1f43 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -132,7 +132,11 @@ html_theme_path = ['_themes'] html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +html_sidebars = { + 'index': ['globaltoc.html', 'searchbox.html'], + '**': ['localtoc.html', 'relations.html', 'searchbox.html'] +} + # Additional templates that should be rendered to pages, maps page names to # template names. @@ -173,8 +177,8 @@ latex_paper_size = 'a4' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'MongoEngine.tex', u'MongoEngine Documentation', - u'Harry Marr', 'manual'), + ('index', 'MongoEngine.tex', 'MongoEngine Documentation', + 'Ross Lawley', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -193,3 +197,6 @@ latex_documents = [ # If false, no module index is generated. 
#latex_use_modindex = True + +autoclass_content = 'both' + diff --git a/docs/django.rst b/docs/django.rst index 6f27b90..62d4dd4 100644 --- a/docs/django.rst +++ b/docs/django.rst @@ -1,8 +1,8 @@ -============================= -Using MongoEngine with Django -============================= +============== +Django Support +============== -.. note:: Updated to support Django 1.4 +.. note:: Updated to support Django 1.5 Connecting ========== @@ -10,7 +10,7 @@ In your **settings.py** file, ignore the standard database settings (unless you also plan to use the ORM in your project), and instead call :func:`~mongoengine.connect` somewhere in the settings module. -.. note :: +.. note:: If you are not using another Database backend you may need to add a dummy database backend to ``settings.py`` eg:: @@ -27,9 +27,9 @@ MongoEngine includes a Django authentication backend, which uses MongoDB. The :class:`~mongoengine.Document`, but implements most of the methods and attributes that the standard Django :class:`User` model does - so the two are moderately compatible. Using this backend will allow you to store users in -MongoDB but still use many of the Django authentication infrastucture (such as +MongoDB but still use many of the Django authentication infrastructure (such as the :func:`login_required` decorator and the :func:`authenticate` function). To -enable the MongoEngine auth backend, add the following to you **settings.py** +enable the MongoEngine auth backend, add the following to your **settings.py** file:: AUTHENTICATION_BACKENDS = ( @@ -42,27 +42,63 @@ The :mod:`~mongoengine.django.auth` module also contains a .. versionadded:: 0.1.3 +Custom User model +================= +Django 1.5 introduced `Custom user Models +`_ +which can be used as an alternative to the MongoEngine authentication backend. + +The main advantage of this option is that other components relying on +:mod:`django.contrib.auth` and supporting the new swappable user model are more +likely to work. For example, you can use the ``createsuperuser`` management +command as usual. + +To enable the custom User model in Django, add ``'mongoengine.django.mongo_auth'`` +in your ``INSTALLED_APPS`` and set ``'mongo_auth.MongoUser'`` as the custom user +user model to use. In your **settings.py** file you will have:: + + INSTALLED_APPS = ( + ... + 'django.contrib.auth', + 'mongoengine.django.mongo_auth', + ... + ) + + AUTH_USER_MODEL = 'mongo_auth.MongoUser' + +An additional ``MONGOENGINE_USER_DOCUMENT`` setting enables you to replace the +:class:`~mongoengine.django.auth.User` class with another class of your choice:: + + MONGOENGINE_USER_DOCUMENT = 'mongoengine.django.auth.User' + +The custom :class:`User` must be a :class:`~mongoengine.Document` class, but +otherwise has the same requirements as a standard custom user model, +as specified in the `Django Documentation +`_. +In particular, the custom class must define :attr:`USERNAME_FIELD` and +:attr:`REQUIRED_FIELDS` attributes. + Sessions ======== Django allows the use of different backend stores for its sessions. MongoEngine provides a MongoDB-based session backend for Django, which allows you to use -sessions in you Django application with just MongoDB. To enable the MongoEngine +sessions in your Django application with just MongoDB. To enable the MongoEngine session backend, ensure that your settings module has ``'django.contrib.sessions.middleware.SessionMiddleware'`` in the ``MIDDLEWARE_CLASSES`` field and ``'django.contrib.sessions'`` in your ``INSTALLED_APPS``. 
From there, all you need to do is add the following line -into you settings module:: +into your settings module:: SESSION_ENGINE = 'mongoengine.django.sessions' -Django provides session cookie, which expires after ```SESSION_COOKIE_AGE``` seconds, but doesnt delete cookie at sessions backend, so ``'mongoengine.django.sessions'`` supports `mongodb TTL +Django provides session cookie, which expires after ```SESSION_COOKIE_AGE``` seconds, but doesn't delete cookie at sessions backend, so ``'mongoengine.django.sessions'`` supports `mongodb TTL `_. .. versionadded:: 0.2.1 Storage ======= -With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`, +With MongoEngine's support for GridFS via the :class:`~mongoengine.fields.FileField`, it is useful to have a Django file storage backend that wraps this. The new storage module is called :class:`~mongoengine.django.storage.GridFSStorage`. Using it is very similar to using the default FileSystemStorage.:: @@ -92,7 +128,7 @@ appended to the filename until the generated filename doesn't exist. The >>> fs.listdir() ([], [u'hello.txt']) -All files will be saved and retrieved in GridFS via the :class::`FileDocument` +All files will be saved and retrieved in GridFS via the :class:`FileDocument` document, allowing easy access to the files without the GridFSStorage backend.:: @@ -101,3 +137,36 @@ backend.:: [] .. versionadded:: 0.4 + +Shortcuts +========= +Inspired by the `Django shortcut get_object_or_404 +`_, +the :func:`~mongoengine.django.shortcuts.get_document_or_404` method returns +a document or raises an Http404 exception if the document does not exist:: + + from mongoengine.django.shortcuts import get_document_or_404 + + admin_user = get_document_or_404(User, username='root') + +The first argument may be a Document or QuerySet object. All other passed arguments +and keyword arguments are used in the query:: + + foo_email = get_document_or_404(User.objects.only('email'), username='foo', is_active=True).email + +.. note:: Like with :func:`get`, a MultipleObjectsReturned will be raised if more than one + object is found. + + +Also inspired by the `Django shortcut get_list_or_404 +`_, +the :func:`~mongoengine.django.shortcuts.get_list_or_404` method returns a list of +documents or raises an Http404 exception if the list is empty:: + + from mongoengine.django.shortcuts import get_list_or_404 + + active_users = get_list_or_404(User, is_active=True) + +The first argument may be a Document or QuerySet object. All other passed +arguments and keyword arguments are used to filter the query. + diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index ebd61a9..f681aad 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -6,34 +6,40 @@ Connecting to MongoDB To connect to a running instance of :program:`mongod`, use the :func:`~mongoengine.connect` function. The first argument is the name of the -database to connect to. If the database does not exist, it will be created. If -the database requires authentication, :attr:`username` and :attr:`password` -arguments may be provided:: +database to connect to:: from mongoengine import connect - connect('project1', username='webapp', password='pwd123') + connect('project1') By default, MongoEngine assumes that the :program:`mongod` instance is running -on **localhost** on port **27017**. If MongoDB is running elsewhere, you may -provide :attr:`host` and :attr:`port` arguments to +on **localhost** on port **27017**. 
If MongoDB is running elsewhere, you should +provide the :attr:`host` and :attr:`port` arguments to :func:`~mongoengine.connect`:: connect('project1', host='192.168.1.35', port=12345) -Uri style connections are also supported as long as you include the database -name - just supply the uri as the :attr:`host` to +If the database requires authentication, :attr:`username` and :attr:`password` +arguments should be provided:: + + connect('project1', username='webapp', password='pwd123') + +Uri style connections are also supported - just supply the uri as +the :attr:`host` to :func:`~mongoengine.connect`:: connect('project1', host='mongodb://localhost/database_name') +Note that database name from uri has priority over name +in ::func:`~mongoengine.connect` + ReplicaSets =========== -MongoEngine now supports :func:`~pymongo.replica_set_connection.ReplicaSetConnection` +MongoEngine supports :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` to use them please use a URI style connection and provide the `replicaSet` name in the connection kwargs. -Read preferences are supported throught the connection or via individual +Read preferences are supported through the connection or via individual queries by passing the read_preference :: Bar.objects().read_preference(ReadPreference.PRIMARY) @@ -74,9 +80,13 @@ to point across databases and collections. Below is an example schema, using Switch Database Context Manager =============================== -Sometimes you might want to switch the database to query against for a class. +Sometimes you may want to switch the database to query against for a class +for example, archiving older data into a separate database for performance +reasons. + The :class:`~mongoengine.context_managers.switch_db` context manager allows -you to change the database alias for a class eg :: +you to change the database alias for a given class allowing quick and easy +access to the same User document across databases:: from mongoengine.context_managers import switch_db @@ -87,3 +97,6 @@ you to change the database alias for a class eg :: with switch_db(User, 'archive-user-db') as User: User(name="Ross").save() # Saves the 'archive-user-db' + +.. note:: Make sure any aliases have been registered with + :func:`~mongoengine.register_connection` before using the context manager. diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 350ba67..ba1af33 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -24,6 +24,9 @@ objects** as class attributes to the document class:: title = StringField(max_length=200, required=True) date_modified = DateTimeField(default=datetime.datetime.now) +As BSON (the binary format for storing data in mongodb) is order dependent, +documents are serialized based on their field order. + Dynamic document schemas ======================== One of the benefits of MongoDb is dynamic schemas for a collection, whilst data @@ -51,6 +54,7 @@ be saved :: There is one caveat on Dynamic Documents: fields cannot start with `_` +Dynamic fields are stored in creation order *after* any declared fields. Fields ====== @@ -62,31 +66,31 @@ not provided. Default values may optionally be a callable, which will be called to retrieve the value (such as in the above example). 
The field types available are as follows: -* :class:`~mongoengine.BinaryField` -* :class:`~mongoengine.BooleanField` -* :class:`~mongoengine.ComplexDateTimeField` -* :class:`~mongoengine.DateTimeField` -* :class:`~mongoengine.DecimalField` -* :class:`~mongoengine.DictField` -* :class:`~mongoengine.DynamicField` -* :class:`~mongoengine.EmailField` -* :class:`~mongoengine.EmbeddedDocumentField` -* :class:`~mongoengine.FileField` -* :class:`~mongoengine.FloatField` -* :class:`~mongoengine.GenericEmbeddedDocumentField` -* :class:`~mongoengine.GenericReferenceField` -* :class:`~mongoengine.GeoPointField` -* :class:`~mongoengine.ImageField` -* :class:`~mongoengine.IntField` -* :class:`~mongoengine.ListField` -* :class:`~mongoengine.MapField` -* :class:`~mongoengine.ObjectIdField` -* :class:`~mongoengine.ReferenceField` -* :class:`~mongoengine.SequenceField` -* :class:`~mongoengine.SortedListField` -* :class:`~mongoengine.StringField` -* :class:`~mongoengine.URLField` -* :class:`~mongoengine.UUIDField` +* :class:`~mongoengine.fields.BinaryField` +* :class:`~mongoengine.fields.BooleanField` +* :class:`~mongoengine.fields.ComplexDateTimeField` +* :class:`~mongoengine.fields.DateTimeField` +* :class:`~mongoengine.fields.DecimalField` +* :class:`~mongoengine.fields.DictField` +* :class:`~mongoengine.fields.DynamicField` +* :class:`~mongoengine.fields.EmailField` +* :class:`~mongoengine.fields.EmbeddedDocumentField` +* :class:`~mongoengine.fields.FileField` +* :class:`~mongoengine.fields.FloatField` +* :class:`~mongoengine.fields.GenericEmbeddedDocumentField` +* :class:`~mongoengine.fields.GenericReferenceField` +* :class:`~mongoengine.fields.GeoPointField` +* :class:`~mongoengine.fields.ImageField` +* :class:`~mongoengine.fields.IntField` +* :class:`~mongoengine.fields.ListField` +* :class:`~mongoengine.fields.MapField` +* :class:`~mongoengine.fields.ObjectIdField` +* :class:`~mongoengine.fields.ReferenceField` +* :class:`~mongoengine.fields.SequenceField` +* :class:`~mongoengine.fields.SortedListField` +* :class:`~mongoengine.fields.StringField` +* :class:`~mongoengine.fields.URLField` +* :class:`~mongoengine.fields.UUIDField` Field arguments --------------- @@ -96,9 +100,6 @@ arguments can be set on all fields: :attr:`db_field` (Default: None) The MongoDB field name. -:attr:`name` (Default: None) - The mongoengine field name. - :attr:`required` (Default: False) If set to True and the field is not set on the document instance, a :class:`~mongoengine.ValidationError` will be raised when the document is @@ -110,7 +111,7 @@ arguments can be set on all fields: The definion of default parameters follow `the general rules on Python `__, which means that some care should be taken when dealing with default mutable objects - (like in :class:`~mongoengine.ListField` or :class:`~mongoengine.DictField`):: + (like in :class:`~mongoengine.fields.ListField` or :class:`~mongoengine.fields.DictField`):: class ExampleFirst(Document): # Default an empty list @@ -125,6 +126,7 @@ arguments can be set on all fields: # instead to just an object values = ListField(IntField(), default=[1,2,3]) + .. note:: Unsetting a field with a default value will revert back to the default. :attr:`unique` (Default: False) When True, no documents in the collection will have the same value for this @@ -172,8 +174,8 @@ arguments can be set on all fields: List fields ----------- MongoDB allows the storage of lists of items. To add a list of items to a -:class:`~mongoengine.Document`, use the :class:`~mongoengine.ListField` field -type. 
:class:`~mongoengine.ListField` takes another field object as its first +:class:`~mongoengine.Document`, use the :class:`~mongoengine.fields.ListField` field +type. :class:`~mongoengine.fields.ListField` takes another field object as its first argument, which specifies which type elements may be stored within the list:: class Page(Document): @@ -191,7 +193,7 @@ inherit from :class:`~mongoengine.EmbeddedDocument` rather than content = StringField() To embed the document within another document, use the -:class:`~mongoengine.EmbeddedDocumentField` field type, providing the embedded +:class:`~mongoengine.fields.EmbeddedDocumentField` field type, providing the embedded document class as the first argument:: class Page(Document): @@ -206,7 +208,7 @@ Dictionary Fields Often, an embedded document may be used instead of a dictionary -- generally this is recommended as dictionaries don't support validation or custom field types. However, sometimes you will not know the structure of what you want to -store; in this situation a :class:`~mongoengine.DictField` is appropriate:: +store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate:: class SurveyResponse(Document): date = DateTimeField() @@ -224,7 +226,7 @@ other objects, so are the most flexible field type available. Reference fields ---------------- References may be stored to other documents in the database using the -:class:`~mongoengine.ReferenceField`. Pass in another document class as the +:class:`~mongoengine.fields.ReferenceField`. Pass in another document class as the first argument to the constructor, then simply assign document objects to the field:: @@ -245,9 +247,9 @@ field:: The :class:`User` object is automatically turned into a reference behind the scenes, and dereferenced when the :class:`Page` object is retrieved. -To add a :class:`~mongoengine.ReferenceField` that references the document +To add a :class:`~mongoengine.fields.ReferenceField` that references the document being defined, use the string ``'self'`` in place of the document class as the -argument to :class:`~mongoengine.ReferenceField`'s constructor. To reference a +argument to :class:`~mongoengine.fields.ReferenceField`'s constructor. To reference a document that has not yet been defined, use the name of the undefined document as the constructor's argument:: @@ -325,7 +327,7 @@ Its value can take any of the following constants: :const:`mongoengine.PULL` Removes the reference to the object (using MongoDB's "pull" operation) from any object's fields of - :class:`~mongoengine.ListField` (:class:`~mongoengine.ReferenceField`). + :class:`~mongoengine.fields.ListField` (:class:`~mongoengine.fields.ReferenceField`). .. warning:: @@ -352,7 +354,7 @@ Its value can take any of the following constants: Generic reference fields '''''''''''''''''''''''' A second kind of reference field also exists, -:class:`~mongoengine.GenericReferenceField`. This allows you to reference any +:class:`~mongoengine.fields.GenericReferenceField`. This allows you to reference any kind of :class:`~mongoengine.Document`, and hence doesn't take a :class:`~mongoengine.Document` subclass as a constructor argument:: @@ -376,15 +378,15 @@ kind of :class:`~mongoengine.Document`, and hence doesn't take a .. 
note:: - Using :class:`~mongoengine.GenericReferenceField`\ s is slightly less - efficient than the standard :class:`~mongoengine.ReferenceField`\ s, so if + Using :class:`~mongoengine.fields.GenericReferenceField`\ s is slightly less + efficient than the standard :class:`~mongoengine.fields.ReferenceField`\ s, so if you will only be referencing one document type, prefer the standard - :class:`~mongoengine.ReferenceField`. + :class:`~mongoengine.fields.ReferenceField`. Uniqueness constraints ---------------------- MongoEngine allows you to specify that a field should be unique across a -collection by providing ``unique=True`` to a :class:`~mongoengine.Field`\ 's +collection by providing ``unique=True`` to a :class:`~mongoengine.fields.Field`\ 's constructor. If you try to save a document that has the same value for a unique field as a document that is already in the database, a :class:`~mongoengine.OperationError` will be raised. You may also specify @@ -399,7 +401,7 @@ either a single field name, or a list or tuple of field names:: Skipping Document validation on save ------------------------------------ You can also skip the whole document validation process by setting -``validate=False`` when caling the :meth:`~mongoengine.document.Document.save` +``validate=False`` when calling the :meth:`~mongoengine.document.Document.save` method:: class Recipient(Document): @@ -440,6 +442,8 @@ The following example shows a :class:`Log` document that will be limited to ip_address = StringField() meta = {'max_documents': 1000, 'max_size': 2000000} +.. defining-indexes_ + Indexes ======= @@ -448,8 +452,8 @@ by creating a list of index specifications called :attr:`indexes` in the :attr:`~mongoengine.Document.meta` dictionary, where an index specification may either be a single field name, a tuple containing multiple field names, or a dictionary containing a full index definition. A direction may be specified on -fields by prefixing the field name with a **+** or a **-** sign. Note that -direction only matters on multi-field indexes. :: +fields by prefixing the field name with a **+** (for ascending) or a **-** sign +(for descending). Note that direction only matters on multi-field indexes. :: class Page(Document): title = StringField() @@ -475,28 +479,89 @@ If a dictionary is passed then the following options are available: :attr:`unique` (Default: False) Whether the index should be unique. +:attr:`expireAfterSeconds` (Optional) + Allows you to automatically expire data from a collection by setting the + time in seconds to expire the a field. + .. note:: Inheritance adds extra fields indices see: :ref:`document-inheritance`. +Global index default options +---------------------------- + +There are a few top level defaults for all indexes that can be set:: + + class Page(Document): + title = StringField() + rating = StringField() + meta = { + 'index_options': {}, + 'index_background': True, + 'index_drop_dups': True, + 'index_cls': False + } + + +:attr:`index_options` (Optional) + Set any default index options - see the `full options list `_ + +:attr:`index_background` (Optional) + Set the default value for if an index should be indexed in the background + +:attr:`index_drop_dups` (Optional) + Set the default value for if an index should drop duplicates + +:attr:`index_cls` (Optional) + A way to turn off a specific index for _cls. 
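The global index defaults above can sit alongside an explicit ``indexes`` list. A minimal sketch, with illustrative field names (not taken from the patch), combining a background default, a compound index and a dictionary-form specification::

    from mongoengine import Document, IntField, StringField

    class Page(Document):
        category = IntField()
        title = StringField()
        rating = StringField()

        meta = {
            'index_background': True,                    # default applied to every index below
            'indexes': [
                ('category', '-rating'),                 # compound index, rating descending
                {'fields': ['title'], 'unique': True},   # dictionary form with per-index options
            ],
        }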
+ + Compound Indexes and Indexing sub documents ------------------------------------------- Compound indexes can be created by adding the Embedded field or dictionary field name to the index definition. -Sometimes its more efficient to index parts of Embeedded / dictionary fields, +Sometimes its more efficient to index parts of Embedded / dictionary fields, in this case use 'dot' notation to identify the value to index eg: `rank.title` Geospatial indexes ------------------ +The best geo index for mongodb is the new "2dsphere", which has an improved +spherical model and provides better performance and more options when querying. +The following fields will explicitly add a "2dsphere" index: + + - :class:`~mongoengine.fields.PointField` + - :class:`~mongoengine.fields.LineStringField` + - :class:`~mongoengine.fields.PolygonField` + +As "2dsphere" indexes can be part of a compound index, you may not want the +automatic index but would prefer a compound index. In this example we turn off +auto indexing and explicitly declare a compound index on ``location`` and ``datetime``:: + + class Log(Document): + location = PointField(auto_index=False) + datetime = DateTimeField() + + meta = { + 'indexes': [[("location", "2dsphere"), ("datetime", 1)]] + } + + +Pre MongoDB 2.4 Geo +''''''''''''''''''' + +.. note:: For MongoDB < 2.4 this is still current, however the new 2dsphere + index is a big improvement over the previous 2D model - so upgrading is + advised. + Geospatial indexes will be automatically created for all -:class:`~mongoengine.GeoPointField`\ s +:class:`~mongoengine.fields.GeoPointField`\ s It is also possible to explicitly define geospatial indexes. This is useful if you need to define a geospatial index on a subfield of a -:class:`~mongoengine.DictField` or a custom field that contains a +:class:`~mongoengine.fields.DictField` or a custom field that contains a point. To create a geospatial index you must prefix the field with the ***** sign. :: @@ -508,6 +573,35 @@ point. To create a geospatial index you must prefix the field with the ], } +Time To Live indexes +-------------------- + +A special index type that allows you to automatically expire data from a +collection after a given period. See the official +`ttl `_ +documentation for more information. A common usecase might be session data:: + + class Session(Document): + created = DateTimeField(default=datetime.now) + meta = { + 'indexes': [ + {'fields': ['created'], 'expireAfterSeconds': 3600} + ] + } + +.. warning:: TTL indexes happen on the MongoDB server and not in the application + code, therefore no signals will be fired on document deletion. + If you need signals to be fired on deletion, then you must handle the + deletion of Documents in your application code. + +Comparing Indexes +----------------- + +Use :func:`mongoengine.Document.compare_indexes` to compare actual indexes in +the database to those that your document definitions define. This is useful +for maintenance purposes and ensuring you have the correct indexes for your +schema. + Ordering ======== A default ordering can be specified for your @@ -595,7 +689,6 @@ document.:: .. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults to False, meaning you must set it to True to use inheritance. - Working with existing data -------------------------- As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and @@ -615,3 +708,25 @@ defining all possible field types. 
If you use :class:`~mongoengine.Document` and the database contains data that isn't defined then that data will be stored in the `document._data` dictionary. + +Abstract classes +================ + +If you want to add some extra functionality to a group of Document classes but +you don't need or want the overhead of inheritance you can use the +:attr:`abstract` attribute of :attr:`-mongoengine.Document.meta`. +This won't turn on :ref:`document-inheritance` but will allow you to keep your +code DRY:: + + class BaseDocument(Document): + meta = { + 'abstract': True, + } + def check_permissions(self): + ... + + class User(BaseDocument): + ... + +Now the User class will have access to the inherited `check_permissions` method +and won't store any of the extra `_cls` information. diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index e8e7d63..f9a6610 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -30,11 +30,14 @@ already exist, then any changes will be updated atomically. For example:: .. note:: - Changes to documents are tracked and on the whole perform `set` operations. + Changes to documents are tracked and on the whole perform ``set`` operations. - * ``list_field.pop(0)`` - *sets* the resulting list + * ``list_field.push(0)`` - *sets* the resulting list * ``del(list_field)`` - *unsets* whole list + With lists its preferable to use ``Doc.update(push__list_field=0)`` as + this stops the whole list being updated - stopping any race conditions. + .. seealso:: :ref:`guide-atomic-updates` @@ -68,11 +71,12 @@ document values for example:: Cascading Saves --------------- -If your document contains :class:`~mongoengine.ReferenceField` or -:class:`~mongoengine.GenericReferenceField` objects, then by default the -:meth:`~mongoengine.Document.save` method will automatically save any changes to -those objects as well. If this is not desired passing :attr:`cascade` as False -to the save method turns this feature off. +If your document contains :class:`~mongoengine.fields.ReferenceField` or +:class:`~mongoengine.fields.GenericReferenceField` objects, then by default the +:meth:`~mongoengine.Document.save` method will not save any changes to +those objects. If you want all references to also be saved also, noting each +save is a separate query, then passing :attr:`cascade` as True +to the save method will cascade any saves. Deleting documents ------------------ diff --git a/docs/guide/gridfs.rst b/docs/guide/gridfs.rst index 1125947..d81bb92 100644 --- a/docs/guide/gridfs.rst +++ b/docs/guide/gridfs.rst @@ -7,7 +7,7 @@ GridFS Writing ------- -GridFS support comes in the form of the :class:`~mongoengine.FileField` field +GridFS support comes in the form of the :class:`~mongoengine.fields.FileField` field object. This field acts as a file-like object and provides a couple of different ways of inserting and retrieving data. Arbitrary metadata such as content type can also be stored alongside the files. In the following example, @@ -27,7 +27,7 @@ a document is created to store details about animals, including a photo:: Retrieval --------- -So using the :class:`~mongoengine.FileField` is just like using any other +So using the :class:`~mongoengine.fields.FileField` is just like using any other field. The file can also be retrieved just as easily:: marmot = Animal.objects(genus='Marmota').first() @@ -37,7 +37,7 @@ field. 
The file can also be retrieved just as easily:: Streaming --------- -Streaming data into a :class:`~mongoengine.FileField` is achieved in a +Streaming data into a :class:`~mongoengine.fields.FileField` is achieved in a slightly different manner. First, a new file must be created by calling the :func:`new_file` method. Data can then be written using :func:`write`:: diff --git a/docs/guide/installing.rst b/docs/guide/installing.rst index f15d3db..e93f048 100644 --- a/docs/guide/installing.rst +++ b/docs/guide/installing.rst @@ -22,10 +22,10 @@ Alternatively, if you don't have setuptools installed, `download it from PyPi $ python setup.py install To use the bleeding-edge version of MongoEngine, you can get the source from -`GitHub `_ and install it as above: +`GitHub `_ and install it as above: .. code-block:: console - $ git clone git://github.com/hmarr/mongoengine + $ git clone git://github.com/mongoengine/mongoengine $ cd mongoengine $ python setup.py install diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 3279853..f50985b 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -15,11 +15,10 @@ fetch documents from the database:: .. note:: - Once the iteration finishes (when :class:`StopIteration` is raised), - :meth:`~mongoengine.queryset.QuerySet.rewind` will be called so that the - :class:`~mongoengine.queryset.QuerySet` may be iterated over again. The - results of the first iteration are *not* cached, so the database will be hit - each time the :class:`~mongoengine.queryset.QuerySet` is iterated over. + As of MongoEngine 0.8 the querysets utilise a local cache. So iterating + it multiple times will only cause a single query. If this is not the + desired behavour you can call :class:`~mongoengine.QuerySet.no_cache` + (version **0.8.3+**) to return a non-caching queryset. Filtering queries ================= @@ -65,6 +64,9 @@ Available operators are as follows: * ``size`` -- the size of the array is * ``exists`` -- value for field exists +String queries +-------------- + The following operators are available as shortcuts to querying with regular expressions: @@ -78,8 +80,71 @@ expressions: * ``iendswith`` -- string field ends with value (case insensitive) * ``match`` -- performs an $elemMatch so you can match an entire document within an array -There are a few special operators for performing geographical queries, that -may used with :class:`~mongoengine.GeoPointField`\ s: + +Geo queries +----------- + +There are a few special operators for performing geographical queries. The following +were added in 0.8 for: :class:`~mongoengine.fields.PointField`, +:class:`~mongoengine.fields.LineStringField` and +:class:`~mongoengine.fields.PolygonField`: + +* ``geo_within`` -- Check if a geometry is within a polygon. 
For ease of use + it accepts either a geojson geometry or just the polygon coordinates eg:: + + loc.objects(point__geo_with=[[[40, 5], [40, 6], [41, 6], [40, 5]]]) + loc.objects(point__geo_with={"type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}) + +* ``geo_within_box`` - simplified geo_within searching with a box eg:: + + loc.objects(point__geo_within_box=[(-125.0, 35.0), (-100.0, 40.0)]) + loc.objects(point__geo_within_box=[, ]) + +* ``geo_within_polygon`` -- simplified geo_within searching within a simple polygon eg:: + + loc.objects(point__geo_within_polygon=[[40, 5], [40, 6], [41, 6], [40, 5]]) + loc.objects(point__geo_within_polygon=[ [ , ] , + [ , ] , + [ , ] ]) + +* ``geo_within_center`` -- simplified geo_within the flat circle radius of a point eg:: + + loc.objects(point__geo_within_center=[(-125.0, 35.0), 1]) + loc.objects(point__geo_within_center=[ [ , ] , ]) + +* ``geo_within_sphere`` -- simplified geo_within the spherical circle radius of a point eg:: + + loc.objects(point__geo_within_sphere=[(-125.0, 35.0), 1]) + loc.objects(point__geo_within_sphere=[ [ , ] , ]) + +* ``geo_intersects`` -- selects all locations that intersect with a geometry eg:: + + # Inferred from provided points lists: + loc.objects(poly__geo_intersects=[40, 6]) + loc.objects(poly__geo_intersects=[[40, 5], [40, 6]]) + loc.objects(poly__geo_intersects=[[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]) + + # With geoJson style objects + loc.objects(poly__geo_intersects={"type": "Point", "coordinates": [40, 6]}) + loc.objects(poly__geo_intersects={"type": "LineString", + "coordinates": [[40, 5], [40, 6]]}) + loc.objects(poly__geo_intersects={"type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}) + +* ``near`` -- Find all the locations near a given point:: + + loc.objects(point__near=[40, 5]) + loc.objects(point__near={"type": "Point", "coordinates": [40, 5]}) + + + You can also set the maximum distance in meters as well:: + + loc.objects(point__near=[40, 5], point__max_distance=1000) + + +The older 2D indexes are still supported with the +:class:`~mongoengine.fields.GeoPointField`: * ``within_distance`` -- provide a list containing a point and a maximum distance (e.g. [(41.342, -87.653), 5]) @@ -91,7 +156,9 @@ may used with :class:`~mongoengine.GeoPointField`\ s: [(35.0, -125.0), (40.0, -100.0)]) * ``within_polygon`` -- filter documents to those within a given polygon (e.g. [(41.91,-87.69), (41.92,-87.68), (41.91,-87.65), (41.89,-87.65)]). + .. note:: Requires Mongo Server 2.0 + * ``max_distance`` -- can be added to your location queries to set a maximum distance. @@ -100,7 +167,7 @@ Querying lists -------------- On most fields, this syntax will look up documents where the field specified matches the given value exactly, but when the field refers to a -:class:`~mongoengine.ListField`, a single item may be provided, in which case +:class:`~mongoengine.fields.ListField`, a single item may be provided, in which case lists that contain that item will be matched:: class Page(Document): @@ -319,7 +386,7 @@ Retrieving a subset of fields Sometimes a subset of fields on a :class:`~mongoengine.Document` is required, and for efficiency only these should be retrieved from the database. This issue is especially important for MongoDB, as fields may often be extremely large -(e.g. a :class:`~mongoengine.ListField` of +(e.g. a :class:`~mongoengine.fields.ListField` of :class:`~mongoengine.EmbeddedDocument`\ s, which represent the comments on a blog post. 
To select only a subset of fields, use :meth:`~mongoengine.queryset.QuerySet.only`, specifying the fields you want to @@ -351,14 +418,14 @@ If you later need the missing fields, just call Getting related data -------------------- -When iterating the results of :class:`~mongoengine.ListField` or -:class:`~mongoengine.DictField` we automatically dereference any +When iterating the results of :class:`~mongoengine.fields.ListField` or +:class:`~mongoengine.fields.DictField` we automatically dereference any :class:`~pymongo.dbref.DBRef` objects as efficiently as possible, reducing the number the queries to mongo. There are times when that efficiency is not enough, documents that have -:class:`~mongoengine.ReferenceField` objects or -:class:`~mongoengine.GenericReferenceField` objects at the top level are +:class:`~mongoengine.fields.ReferenceField` objects or +:class:`~mongoengine.fields.GenericReferenceField` objects at the top level are expensive as the number of queries to MongoDB can quickly rise. To limit the number of queries use @@ -392,6 +459,7 @@ You can also turn off all dereferencing for a fixed period by using the Advanced queries ================ + Sometimes calling a :class:`~mongoengine.queryset.QuerySet` object with keyword arguments can't fully express the query you want to use -- for example if you need to combine a number of constraints using *and* and *or*. This is made @@ -410,6 +478,11 @@ calling it with keyword arguments:: # Get top posts Post.objects((Q(featured=True) & Q(hits__gte=1000)) | Q(hits__gte=5000)) +.. warning:: You have to use bitwise operators. You cannot use ``or``, ``and`` + to combine queries as ``Q(a=a) or Q(b=b)`` is not the same as + ``Q(a=a) | Q(b=b)``. As ``Q(a=a)`` equates to true ``Q(a=a) or Q(b=b)`` is + the same as ``Q(a=a)``. + .. _guide-atomic-updates: Atomic updates @@ -424,7 +497,6 @@ that you may use with these methods: * ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``inc`` -- increment a value by a given amount * ``dec`` -- decrement a value by a given amount -* ``pop`` -- remove the last item from a list * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list * ``pop`` -- remove the first or last element of a list @@ -450,7 +522,7 @@ modifier comes before the field, not after it:: >>> post.tags ['database', 'nosql'] -.. note :: +.. note:: In version 0.5 the :meth:`~mongoengine.Document.save` runs atomic updates on changed documents by tracking changes to that document. @@ -466,7 +538,7 @@ cannot use the `$` syntax in keyword arguments it has been mapped to `S`:: >>> post.tags ['database', 'mongodb'] -.. note :: +.. note:: Currently only top level lists are handled, future versions of mongodb / pymongo plan to support nested positional operators. See `The $ positional operator `_. @@ -535,7 +607,7 @@ Javascript code. When accessing a field on a collection object, use square-bracket notation, and prefix the MongoEngine field name with a tilde. The field name that follows the tilde will be translated to the name used in the database. Note that when referring to fields on embedded documents, -the name of the :class:`~mongoengine.EmbeddedDocumentField`, followed by a dot, +the name of the :class:`~mongoengine.fields.EmbeddedDocumentField`, followed by a dot, should be used before the name of the field on the embedded document. 
The following example shows how the substitutions are made:: diff --git a/docs/guide/signals.rst b/docs/guide/signals.rst index 75f81e2..dc295d4 100644 --- a/docs/guide/signals.rst +++ b/docs/guide/signals.rst @@ -1,5 +1,6 @@ .. _signals: +======= Signals ======= @@ -7,32 +8,95 @@ Signals .. note:: - Signal support is provided by the excellent `blinker`_ library and - will gracefully fall back if it is not available. + Signal support is provided by the excellent `blinker`_ library. If you wish + to enable signal support this library must be installed, though it is not + required for MongoEngine to function. +Overview +-------- -The following document signals exist in MongoEngine and are pretty self-explanatory: +Signals are found within the `mongoengine.signals` module. Unless +specified signals receive no additional arguments beyond the `sender` class and +`document` instance. Post-signals are only called if there were no exceptions +raised during the processing of their related function. - * `mongoengine.signals.pre_init` - * `mongoengine.signals.post_init` - * `mongoengine.signals.pre_save` - * `mongoengine.signals.post_save` - * `mongoengine.signals.pre_delete` - * `mongoengine.signals.post_delete` - * `mongoengine.signals.pre_bulk_insert` - * `mongoengine.signals.post_bulk_insert` +Available signals include: -Example usage:: +`pre_init` + Called during the creation of a new :class:`~mongoengine.Document` or + :class:`~mongoengine.EmbeddedDocument` instance, after the constructor + arguments have been collected but before any additional processing has been + done to them. (I.e. assignment of default values.) Handlers for this signal + are passed the dictionary of arguments using the `values` keyword argument + and may modify this dictionary prior to returning. + +`post_init` + Called after all processing of a new :class:`~mongoengine.Document` or + :class:`~mongoengine.EmbeddedDocument` instance has been completed. + +`pre_save` + Called within :meth:`~mongoengine.document.Document.save` prior to performing + any actions. + +`pre_save_post_validation` + Called within :meth:`~mongoengine.document.Document.save` after validation + has taken place but before saving. + +`post_save` + Called within :meth:`~mongoengine.document.Document.save` after all actions + (validation, insert/update, cascades, clearing dirty flags) have completed + successfully. Passed the additional boolean keyword argument `created` to + indicate if the save was an insert or an update. + +`pre_delete` + Called within :meth:`~mongoengine.document.Document.delete` prior to + attempting the delete operation. + +`post_delete` + Called within :meth:`~mongoengine.document.Document.delete` upon successful + deletion of the record. + +`pre_bulk_insert` + Called after validation of the documents to insert, but prior to any data + being written. In this case, the `document` argument is replaced by a + `documents` argument representing the list of documents being inserted. + +`post_bulk_insert` + Called after a successful bulk insert operation. As per `pre_bulk_insert`, + the `document` argument is omitted and replaced with a `documents` argument. + An additional boolean argument, `loaded`, identifies the contents of + `documents` as either :class:`~mongoengine.Document` instances when `True` or + simply a list of primary key values for the inserted records if `False`. 
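A handler for the bulk insert signals described above receives a ``documents`` list rather than a single ``document``. As a rough sketch only (the ``BlogPost`` class and the logging behaviour are illustrative assumptions, not part of this change)::

    import logging

    from mongoengine import *
    from mongoengine import signals

    class BlogPost(Document):
        title = StringField()

    def bulk_insert_logger(sender, documents, **kwargs):
        # `loaded` is True when `documents` holds full Document instances,
        # False when it is just a list of primary key values
        if kwargs.get('loaded'):
            logging.debug("Bulk inserted %s documents" % len(documents))
        else:
            logging.debug("Bulk inserted ids: %s" % documents)

    signals.post_bulk_insert.connect(bulk_insert_logger, sender=BlogPost)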
+ +Attaching Events +---------------- + +After writing a handler function like the following:: + + import logging + from datetime import datetime from mongoengine import * from mongoengine import signals + def update_modified(sender, document): + document.modified = datetime.utcnow() + +You attach the event handler to your :class:`~mongoengine.Document` or +:class:`~mongoengine.EmbeddedDocument` subclass:: + + class Record(Document): + modified = DateTimeField() + + signals.pre_save.connect(update_modified) + +While this is not the most elaborate document model, it does demonstrate the +concepts involved. As a more complete demonstration you can also define your +handlers within your subclass:: + class Author(Document): name = StringField() - def __unicode__(self): - return self.name - @classmethod def pre_save(cls, sender, document, **kwargs): logging.debug("Pre Save: %s" % document.name) @@ -49,12 +113,40 @@ Example usage:: signals.pre_save.connect(Author.pre_save, sender=Author) signals.post_save.connect(Author.post_save, sender=Author) +Finally, you can also use this small decorator to quickly create a number of +signals and attach them to your :class:`~mongoengine.Document` or +:class:`~mongoengine.EmbeddedDocument` subclasses as class decorators:: -ReferenceFields and signals + def handler(event): + """Signal decorator to allow use of callback functions as class decorators.""" + + def decorator(fn): + def apply(cls): + event.connect(fn, sender=cls) + return cls + + fn.apply = apply + return fn + + return decorator + +Using the first example of updating a modification time the code is now much +cleaner looking while still allowing manual execution of the callback:: + + @handler(signals.pre_save) + def update_modified(sender, document): + document.modified = datetime.utcnow() + + @update_modified.apply + class Record(Document): + modified = DateTimeField() + + +ReferenceFields and Signals --------------------------- Currently `reverse_delete_rules` do not trigger signals on the other part of -the relationship. If this is required you must manually handled the +the relationship. If this is required you must manually handle the reverse deletion. .. _blinker: http://pypi.python.org/pypi/blinker diff --git a/docs/index.rst b/docs/index.rst index f6d44b5..77f965c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,16 +7,18 @@ MongoDB. To install it, simply run .. code-block:: console - # pip install -U mongoengine + $ pip install -U mongoengine :doc:`tutorial` - Start here for a quick overview. + A quick tutorial building a tumblelog to get you up and running with + MongoEngine. :doc:`guide/index` - The Full guide to MongoEngine + The Full guide to MongoEngine - from modeling documents to storing files, + from querying for data to firing signals and *everything* between. :doc:`apireference` - The complete API documentation. + The complete API documentation --- the innards of documents, querysets and fields. :doc:`upgrade` How to upgrade MongoEngine. @@ -28,35 +30,50 @@ Community --------- To get help with using MongoEngine, use the `MongoEngine Users mailing list -`_ or come chat on the -`#mongoengine IRC channel `_. +`_ or the ever popular +`stackoverflow `_. Contributing ------------ -The source is available on `GitHub `_ and -contributions are always encouraged. Contributions can be as simple as -minor tweaks to this documentation. To contribute, fork the project on +**Yes please!** We are always looking for contributions, additions and improvements. 
+ +The source is available on `GitHub `_ +and contributions are always encouraged. Contributions can be as simple as +minor tweaks to this documentation, the website or the core. + +To contribute, fork the project on `GitHub `_ and send a pull request. -Also, you can join the developers' `mailing list -`_. - Changes ------- + See the :doc:`changelog` for a full list of changes to MongoEngine and :doc:`upgrade` for upgrade information. -.. toctree:: - :hidden: +.. note:: Always read and test the `upgrade `_ documentation before + putting updates live in production **;)** - tutorial - guide/index - apireference - django - changelog - upgrade +Offline Reading +--------------- + +Download the docs in `pdf `_ +or `epub `_ +formats for offline reading. + + +.. toctree:: + :maxdepth: 1 + :numbered: + :hidden: + + tutorial + guide/index + apireference + changelog + upgrade + django Indices and tables ------------------ diff --git a/docs/tutorial.rst b/docs/tutorial.rst index c4b69c4..0c592a0 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -1,6 +1,7 @@ ======== Tutorial ======== + This tutorial introduces **MongoEngine** by means of example --- we will walk through how to create a simple **Tumblelog** application. A Tumblelog is a type of blog where posts are not constrained to being conventional text-based posts. @@ -12,23 +13,29 @@ interface. Getting started =============== + Before we start, make sure that a copy of MongoDB is running in an accessible location --- running it locally will be easier, but if that is not an option -then it may be run on a remote server. +then it may be run on a remote server. If you haven't installed mongoengine, +simply use pip to install it like so:: + + $ pip install mongoengine Before we can start using MongoEngine, we need to tell it how to connect to our instance of :program:`mongod`. For this we use the :func:`~mongoengine.connect` -function. The only argument we need to provide is the name of the MongoDB -database to use:: +function. If running locally the only argument we need to provide is the name +of the MongoDB database to use:: from mongoengine import * connect('tumblelog') -For more information about connecting to MongoDB see :ref:`guide-connecting`. +There are lots of options for connecting to MongoDB, for more information about +them see the :ref:`guide-connecting` guide. Defining our documents ====================== + MongoDB is *schemaless*, which means that no schema is enforced by the database --- we may add and remove fields however we want and MongoDB won't complain. This makes life a lot easier in many regards, especially when there is a change @@ -39,17 +46,19 @@ define utility methods on our documents in the same way that traditional In our Tumblelog application we need to store several different types of information. We will need to have a collection of **users**, so that we may -link posts to an individual. We also need to store our different types -**posts** (text, image and link) in the database. To aid navigation of our +link posts to an individual. We also need to store our different types of +**posts** (eg: text, image and link) in the database. To aid navigation of our Tumblelog, posts may have **tags** associated with them, so that the list of posts shown to the user may be limited to posts that have been assigned a -specified tag. Finally, it would be nice if **comments** could be added to -posts. We'll start with **users**, as the others are slightly more involved. +specific tag. 
Finally, it would be nice if **comments** could be added to +posts. We'll start with **users**, as the other document models are slightly +more involved. Users ----- + Just as if we were using a relational database with an ORM, we need to define -which fields a :class:`User` may have, and what their types will be:: +which fields a :class:`User` may have, and what types of data they might store:: class User(Document): email = StringField(required=True) @@ -58,11 +67,13 @@ which fields a :class:`User` may have, and what their types will be:: This looks similar to how a the structure of a table would be defined in a regular ORM. The key difference is that this schema will never be passed on to -MongoDB --- this will only be enforced at the application level. Also, the User -documents will be stored in a MongoDB *collection* rather than a table. +MongoDB --- this will only be enforced at the application level, making future +changes easy to manage. Also, the User documents will be stored in a +MongoDB *collection* rather than a table. Posts, Comments and Tags ------------------------ + Now we'll think about how to store the rest of the information. If we were using a relational database, we would most likely have a table of **posts**, a table of **comments** and a table of **tags**. To associate the comments with @@ -75,16 +86,17 @@ of them stand out as particularly intuitive solutions. Posts ^^^^^ -But MongoDB *isn't* a relational database, so we're not going to do it that + +Happily mongoDB *isn't* a relational database, so we're not going to do it that way. As it turns out, we can use MongoDB's schemaless nature to provide us with -a much nicer solution. We will store all of the posts in *one collection* --- -each post type will just have the fields it needs. If we later want to add +a much nicer solution. We will store all of the posts in *one collection* and +each post type will only store the fields it needs. If we later want to add video posts, we don't have to modify the collection at all, we just *start using* the new fields we need to support video posts. This fits with the Object-Oriented principle of *inheritance* nicely. We can think of :class:`Post` as a base class, and :class:`TextPost`, :class:`ImagePost` and :class:`LinkPost` as subclasses of :class:`Post`. In fact, MongoEngine supports -this kind of modelling out of the box - all you need do is turn on inheritance +this kind of modelling out of the box --- all you need do is turn on inheritance by setting :attr:`allow_inheritance` to True in the :attr:`meta`:: class Post(Document): @@ -103,12 +115,13 @@ by setting :attr:`allow_inheritance` to True in the :attr:`meta`:: link_url = StringField() We are storing a reference to the author of the posts using a -:class:`~mongoengine.ReferenceField` object. These are similar to foreign key +:class:`~mongoengine.fields.ReferenceField` object. These are similar to foreign key fields in traditional ORMs, and are automatically translated into references when they are saved, and dereferenced when they are loaded. Tags ^^^^ + Now that we have our Post models figured out, how will we attach tags to them? MongoDB allows us to store lists of items natively, so rather than having a link table, we can just store a list of tags in each post. So, for both @@ -124,13 +137,16 @@ size of our database. 
So let's take a look that the code our modified author = ReferenceField(User) tags = ListField(StringField(max_length=30)) -The :class:`~mongoengine.ListField` object that is used to define a Post's tags +The :class:`~mongoengine.fields.ListField` object that is used to define a Post's tags takes a field object as its first argument --- this means that you can have -lists of any type of field (including lists). Note that we don't need to -modify the specialised post types as they all inherit from :class:`Post`. +lists of any type of field (including lists). + +.. note:: We don't need to modify the specialised post types as they all + inherit from :class:`Post`. Comments ^^^^^^^^ + A comment is typically associated with *one* post. In a relational database, to display a post with its comments, we would have to retrieve the post from the database, then query the database again for the comments associated with the @@ -158,7 +174,7 @@ We can then store a list of comment documents in our post document:: Handling deletions of references ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The :class:`~mongoengine.ReferenceField` object takes a keyword +The :class:`~mongoengine.fields.ReferenceField` object takes a keyword `reverse_delete_rule` for handling deletion rules if the reference is deleted. To delete all the posts if a user is deleted set the rule:: @@ -168,7 +184,7 @@ To delete all the posts if a user is deleted set the rule:: tags = ListField(StringField(max_length=30)) comments = ListField(EmbeddedDocumentField(Comment)) -See :class:`~mongoengine.ReferenceField` for more information. +See :class:`~mongoengine.fields.ReferenceField` for more information. .. note:: MapFields and DictFields currently don't support automatic handling of @@ -181,15 +197,15 @@ Now that we've defined how our documents will be structured, let's start adding some documents to the database. Firstly, we'll need to create a :class:`User` object:: - john = User(email='jdoe@example.com', first_name='John', last_name='Doe') - john.save() + ross = User(email='ross@example.com', first_name='Ross', last_name='Lawley').save() -Note that we could have also defined our user using attribute syntax:: +.. note:: + We could have also defined our user using attribute syntax:: - john = User(email='jdoe@example.com') - john.first_name = 'John' - john.last_name = 'Doe' - john.save() + ross = User(email='ross@example.com') + ross.first_name = 'Ross' + ross.last_name = 'Lawley' + ross.save() Now that we've got our user in the database, let's add a couple of posts:: @@ -198,16 +214,17 @@ Now that we've got our user in the database, let's add a couple of posts:: post1.tags = ['mongodb', 'mongoengine'] post1.save() - post2 = LinkPost(title='MongoEngine Documentation', author=john) - post2.link_url = 'http://tractiondigital.com/labs/mongoengine/docs' + post2 = LinkPost(title='MongoEngine Documentation', author=ross) + post2.link_url = 'http://docs.mongoengine.com/' post2.tags = ['mongoengine'] post2.save() -Note that if you change a field on a object that has already been saved, then -call :meth:`save` again, the document will be updated. +.. note:: If you change a field on a object that has already been saved, then + call :meth:`save` again, the document will be updated. Accessing our data ================== + So now we've got a couple of posts in our database, how do we display them? Each document class (i.e. 
any class that inherits either directly or indirectly from :class:`~mongoengine.Document`) has an :attr:`objects` attribute, which is @@ -219,6 +236,7 @@ class. So let's see how we can get our posts' titles:: Retrieving type-specific information ------------------------------------ + This will print the titles of our posts, one on each line. But What if we want to access the type-specific data (link_url, content, etc.)? One way is simply to use the :attr:`objects` attribute of a subclass of :class:`Post`:: @@ -257,6 +275,7 @@ text post, and "Link: " if it was a link post. Searching our posts by tag -------------------------- + The :attr:`objects` attribute of a :class:`~mongoengine.Document` is actually a :class:`~mongoengine.queryset.QuerySet` object. This lazily queries the database only when you need the data. It may also be filtered to narrow down @@ -275,3 +294,9 @@ used on :class:`~mongoengine.queryset.QuerySet` objects:: num_posts = Post.objects(tags='mongodb').count() print 'Found %d posts with tag "mongodb"' % num_posts +Learning more about mongoengine +------------------------------- + +If you got this far you've made a great start, so well done! The next step on +your mongoengine journey is the `full user guide `_, where you +can learn indepth about how to use mongoengine and mongodb. diff --git a/docs/upgrade.rst b/docs/upgrade.rst index 8724503..a1fccea 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -1,16 +1,41 @@ -========= +######### Upgrading -========= +######### + + +0.8.2 to 0.8.3 +************** + +Minor change that may impact users: + +DynamicDocument fields are now stored in creation order after any declared +fields. Previously they were stored alphabetically. + 0.7 to 0.8 +********** + +There have been numerous backwards breaking changes in 0.8. The reasons for +these are to ensure that MongoEngine has sane defaults going forward and that it +performs the best it can out of the box. Where possible there have been +FutureWarnings to help get you ready for the change, but that hasn't been +possible for the whole of the release. + +.. warning:: Breaking changes - test upgrading on a test system before putting + live. There maybe multiple manual steps in migrating and these are best honed + on a staging / test system. + +Python and PyMongo +================== + +MongoEngine requires python 2.6 (or above) and pymongo 2.5 (or above) + +Data Model ========== Inheritance ----------- -Data Model -~~~~~~~~~~ - The inheritance model has changed, we no longer need to store an array of :attr:`types` with the model we can just use the classname in :attr:`_cls`. This means that you will have to update your indexes for each of your @@ -44,9 +69,9 @@ inherited classes like so: :: Document Definition -~~~~~~~~~~~~~~~~~~~ +------------------- -The default for inheritance has changed - its now off by default and +The default for inheritance has changed - it is now off by default and :attr:`_cls` will not be stored automatically with the class. So if you extend your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments` you will need to declare :attr:`allow_inheritance` in the meta data like so: :: @@ -56,7 +81,7 @@ you will need to declare :attr:`allow_inheritance` in the meta data like so: :: meta = {'allow_inheritance': True} -Previously, if you had data the database that wasn't defined in the Document +Previously, if you had data in the database that wasn't defined in the Document definition, it would set it as an attribute on the document. 
This is no longer the case and the data is set only in the ``document._data`` dictionary: :: @@ -76,8 +101,141 @@ the case and the data is set only in the ``document._data`` dictionary: :: File "", line 1, in AttributeError: 'Animal' object has no attribute 'size' +The Document class has introduced a reserved function `clean()`, which will be +called before saving the document. If your document class happens to have a method +with the same name, please try to rename it. + + def clean(self): + pass + +ReferenceField +-------------- + +ReferenceFields now store ObjectIds by default - this is more efficient than +DBRefs as we already know what Document types they reference:: + + # Old code + class Animal(Document): + name = ReferenceField('self') + + # New code to keep dbrefs + class Animal(Document): + name = ReferenceField('self', dbref=True) + +To migrate all the references you need to touch each object and mark it as dirty +eg:: + + # Doc definition + class Person(Document): + name = StringField() + parent = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + # Mark all ReferenceFields as dirty and save + for p in Person.objects: + p._mark_as_changed('parent') + p._mark_as_changed('friends') + p.save() + +`An example test migration for ReferenceFields is available on github +`_. + +.. Note:: Internally mongoengine handles ReferenceFields the same, so they are + converted to DBRef on loading and ObjectIds or DBRefs depending on settings + on storage. + +UUIDField +--------- + +UUIDFields now default to storing binary values:: + + # Old code + class Animal(Document): + uuid = UUIDField() + + # New code + class Animal(Document): + uuid = UUIDField(binary=False) + +To migrate all the uuids you need to touch each object and mark it as dirty +eg:: + + # Doc definition + class Animal(Document): + uuid = UUIDField() + + # Mark all UUIDFields as dirty and save + for a in Animal.objects: + a._mark_as_changed('uuid') + a.save() + +`An example test migration for UUIDFields is available on github +`_. + +DecimalField +------------ + +DecimalFields now store floats - previously it was storing strings and that +made it impossible to do comparisons when querying correctly.:: + + # Old code + class Person(Document): + balance = DecimalField() + + # New code + class Person(Document): + balance = DecimalField(force_string=True) + +To migrate all the DecimalFields you need to touch each object and mark it as dirty +eg:: + + # Doc definition + class Person(Document): + balance = DecimalField() + + # Mark all DecimalField's as dirty and save + for p in Person.objects: + p._mark_as_changed('balance') + p.save() + +.. note:: DecimalFields have also been improved with the addition of precision + and rounding. See :class:`~mongoengine.fields.DecimalField` for more information. + +`An example test migration for DecimalFields is available on github +`_. + +Cascading Saves +--------------- +To improve performance document saves will no longer automatically cascade. +Any changes to a Document's references will either have to be saved manually or +you will have to explicitly tell it to cascade on save:: + + # At the class level: + class Person(Document): + meta = {'cascade': True} + + # Or on save: + my_document.save(cascade=True) + +Storage +------- + +Document and Embedded Documents are now serialized based on declared field order. 
+Previously, the data was passed to mongodb as a dictionary and which meant that +order wasn't guaranteed - so things like ``$addToSet`` operations on +:class:`~mongoengine.EmbeddedDocument` could potentially fail in unexpected +ways. + +If this impacts you, you may want to rewrite the objects using the +``doc.mark_as_dirty('field')`` pattern described above. If you are using a +compound primary key then you will need to ensure the order is fixed and match +your EmbeddedDocument to that order. + Querysets -~~~~~~~~~ +========= + +Attack of the clones +-------------------- Querysets now return clones and should no longer be considered editable in place. This brings us in line with how Django's querysets work and removes a @@ -91,15 +249,87 @@ update your code like so: :: # Update example a) assign queryset after a change: mammals = Animal.objects(type="mammal") - carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so fitler can be applied + carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so filter can be applied [m for m in carnivores] # This will return all carnivores # Update example b) chain the queryset: mammals = Animal.objects(type="mammal").filter(order="Carnivora") # The final queryset is assgined to mammals [m for m in mammals] # This will return all carnivores +Len iterates the queryset +-------------------------- + +If you ever did `len(queryset)` it previously did a `count()` under the covers, +this caused some unusual issues. As `len(queryset)` is most often used by +`list(queryset)` we now cache the queryset results and use that for the length. + +This isn't as performant as a `count()` and if you aren't iterating the +queryset you should upgrade to use count:: + + # Old code + len(Animal.objects(type="mammal")) + + # New code + Animal.objects(type="mammal").count()) + + +.only() now inline with .exclude() +---------------------------------- + +The behaviour of `.only()` was highly ambiguous, now it works in mirror fashion +to `.exclude()`. Chaining `.only()` calls will increase the fields required:: + + # Old code + Animal.objects().only(['type', 'name']).only('name', 'order') # Would have returned just `name` + + # New code + Animal.objects().only('name') + + # Note: + Animal.objects().only(['name']).only('order') # Now returns `name` *and* `order` + + +Client +====== +PyMongo 2.4 came with a new connection client; MongoClient_ and started the +depreciation of the old :class:`~pymongo.connection.Connection`. MongoEngine +now uses the latest `MongoClient` for connections. By default operations were +`safe` but if you turned them off or used the connection directly this will +impact your queries. + +Querysets +--------- + +Safe +^^^^ + +`safe` has been depreciated in the new MongoClient connection. Please use +`write_concern` instead. As `safe` always defaulted as `True` normally no code +change is required. To disable confirmation of the write just pass `{"w": 0}` +eg: :: + + # Old + Animal(name="Dinasour").save(safe=False) + + # new code: + Animal(name="Dinasour").save(write_concern={"w": 0}) + +Write Concern +^^^^^^^^^^^^^ + +`write_options` has been replaced with `write_concern` to bring it inline with +pymongo. 
To upgrade simply rename any instances where you used the `write_option` +keyword to `write_concern` like so:: + + # Old code: + Animal(name="Dinasour").save(write_options={"w": 2}) + + # new code: + Animal(name="Dinasour").save(write_concern={"w": 2}) + + Indexes -------- +======= Index methods are no longer tied to querysets but rather to the document class. Although `QuerySet._ensure_indexes` and `QuerySet.ensure_index` still exist. @@ -107,17 +337,19 @@ They should be replaced with :func:`~mongoengine.Document.ensure_indexes` / :func:`~mongoengine.Document.ensure_index`. SequenceFields --------------- +============== :class:`~mongoengine.fields.SequenceField` now inherits from `BaseField` to allow flexible storage of the calculated value. As such MIN and MAX settings are no longer handled. +.. _MongoClient: http://blog.mongodb.org/post/36666163412/introducing-mongoclient + 0.6 to 0.7 -========== +********** Cascade saves -------------- +============= Saves will raise a `FutureWarning` if they cascade and cascade hasn't been set to True. This is because in 0.8 it will default to False. If you require @@ -131,11 +363,11 @@ via `save` eg :: # Or in code: my_document.save(cascade=True) -.. note :: +.. note:: Remember: cascading saves **do not** cascade through lists. ReferenceFields ---------------- +=============== ReferenceFields now can store references as ObjectId strings instead of DBRefs. This will become the default in 0.8 and if `dbref` is not set a `FutureWarning` @@ -164,7 +396,7 @@ migrate :: item_frequencies ----------------- +================ In the 0.6 series we added support for null / zero / false values in item_frequencies. A side effect was to return keys in the value they are @@ -173,14 +405,14 @@ updated to handle native types rather than strings keys for the results of item frequency queries. BinaryFields ------------- +============ Binary fields have been updated so that they are native binary types. If you previously were doing `str` comparisons with binary field values you will have to update and wrap the value in a `str`. 0.5 to 0.6 -========== +********** Embedded Documents - if you had a `pk` field you will have to rename it from `_id` to `pk` as pk is no longer a property of Embedded Documents. @@ -200,21 +432,21 @@ don't define :attr:`allow_inheritance` in their meta. You may need to update pyMongo to 2.0 for use with Sharding. 0.4 to 0.5 -=========== +********** There have been the following backwards incompatibilities from 0.4 to 0.5. The main areas of changed are: choices in fields, map_reduce and collection names. Choice options: ---------------- +=============== -Are now expected to be an iterable of tuples, with the first element in each +Are now expected to be an iterable of tuples, with the first element in each tuple being the actual value to be stored. The second element is the human-readable name for the option. PyMongo / MongoDB ------------------ +================= map reduce now requires pymongo 1.11+- The pymongo `merge_output` and `reduce_output` parameters, have been depreciated. 
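Returning to the choice options change described above, the expected format is an iterable of ``(stored value, human readable name)`` tuples. A minimal sketch (the ``Shirt`` document and its values are purely illustrative)::

    class Shirt(Document):
        SIZES = (('S', 'Small'), ('M', 'Medium'), ('L', 'Large'))
        size = StringField(max_length=1, choices=SIZES)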
@@ -228,10 +460,10 @@ such the following have been changed: Default collection naming -------------------------- +========================= -Previously it was just lowercase, its now much more pythonic and readable as -its lowercase and underscores, previously :: +Previously it was just lowercase, it's now much more pythonic and readable as +it's lowercase and underscores, previously :: class MyAceDocument(Document): pass @@ -298,5 +530,5 @@ Alternatively, you can rename your collections eg :: mongodb 1.8 > 2.0 + =================== -Its been reported that indexes may need to be recreated to the newer version of indexes. +It's been reported that indexes may need to be recreated to the newer version of indexes. To do this drop indexes and call ``ensure_indexes`` on each model. diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index da72e53..2b68b3c 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -8,12 +8,14 @@ import queryset from queryset import * import signals from signals import * +from errors import * +import errors import django __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + - list(queryset.__all__) + signals.__all__) + list(queryset.__all__) + signals.__all__ + list(errors.__all__)) -VERSION = (0, 8, 0, '+') +VERSION = (0, 8, 4) def get_version(): diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py index ce119b3..e8d4b6a 100644 --- a/mongoengine/base/__init__.py +++ b/mongoengine/base/__init__.py @@ -3,3 +3,6 @@ from mongoengine.base.datastructures import * from mongoengine.base.document import * from mongoengine.base.fields import * from mongoengine.base.metaclasses import * + +# Help with backwards compatibility +from mongoengine.errors import * diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index c750b5b..4652fb5 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -13,7 +13,11 @@ class BaseDict(dict): _name = None def __init__(self, dict_items, instance, name): - self._instance = weakref.proxy(instance) + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') + + if isinstance(instance, (Document, EmbeddedDocument)): + self._instance = weakref.proxy(instance) self._name = name return super(BaseDict, self).__init__(dict_items) @@ -80,7 +84,11 @@ class BaseList(list): _name = None def __init__(self, list_items, instance, name): - self._instance = weakref.proxy(instance) + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') + + if isinstance(instance, (Document, EmbeddedDocument)): + self._instance = weakref.proxy(instance) self._name = name return super(BaseList, self).__init__(list_items) @@ -100,6 +108,14 @@ class BaseList(list): self._mark_as_changed() return super(BaseList, self).__delitem__(*args, **kwargs) + def __setslice__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__setslice__(*args, **kwargs) + + def __delslice__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__delslice__(*args, **kwargs) + def __getstate__(self): self.instance = None self._dereferenced = False diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 7ec672f..cea2f09 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -4,8 +4,9 @@ import numbers from functools import partial import pymongo -from bson import json_util +from bson import json_util, 
ObjectId from bson.dbref import DBRef +from bson.son import SON from mongoengine import signals from mongoengine.common import _import_class @@ -41,6 +42,9 @@ class BaseDocument(object): # Combine positional arguments with named arguments. # We only want named arguments. field = iter(self._fields_ordered) + # If its an automatic id field then skip to the first defined field + if self._auto_id_field: + next(field) for value in args: name = next(field) if name in values: @@ -50,6 +54,7 @@ class BaseDocument(object): signals.pre_init.send(self.__class__, document=self, values=values) self._data = {} + self._dynamic_fields = SON() # Assign default values to instance for key, field in self._fields.iteritems(): @@ -60,7 +65,6 @@ class BaseDocument(object): # Set passed values after initialisation if self._dynamic: - self._dynamic_fields = {} dynamic_data = {} for key, value in values.iteritems(): if key in self._fields or key == '_id': @@ -115,6 +119,7 @@ class BaseDocument(object): field = DynamicField(db_field=name) field.name = name self._dynamic_fields[name] = field + self._fields_ordered += (name,) if not name.startswith('_'): value = self.__expand_dynamic_values(name, value) @@ -140,28 +145,33 @@ class BaseDocument(object): super(BaseDocument, self).__setattr__(name, value) def __getstate__(self): - removals = ("get_%s_display" % k - for k, v in self._fields.items() if v.choices) - for k in removals: + data = {} + for k in ('_changed_fields', '_initialised', '_created', + '_dynamic_fields', '_fields_ordered'): if hasattr(self, k): - delattr(self, k) - return self.__dict__ + data[k] = getattr(self, k) + data['_data'] = self.to_mongo() + return data - def __setstate__(self, __dict__): - self.__dict__ = __dict__ - self.__set_field_display() + def __setstate__(self, data): + if isinstance(data["_data"], SON): + data["_data"] = self.__class__._from_son(data["_data"])._data + for k in ('_changed_fields', '_initialised', '_created', '_data', + '_fields_ordered', '_dynamic_fields'): + if k in data: + setattr(self, k, data[k]) + dynamic_fields = data.get('_dynamic_fields') or SON() + for k in dynamic_fields.keys(): + setattr(self, k, data["_data"].get(k)) def __iter__(self): - if 'id' in self._fields and 'id' not in self._fields_ordered: - return iter(('id', ) + self._fields_ordered) - return iter(self._fields_ordered) def __getitem__(self, name): """Dictionary-style field access, return a field's value if present. """ try: - if name in self._fields: + if name in self._fields_ordered: return getattr(self, name) except AttributeError: pass @@ -211,7 +221,7 @@ class BaseDocument(object): return not self.__eq__(other) def __hash__(self): - if self.pk is None: + if getattr(self, 'pk', None) is None: # For new object return super(BaseDocument, self).__hash__() else: @@ -228,11 +238,18 @@ class BaseDocument(object): pass def to_mongo(self): - """Return data dictionary ready for use with MongoDB. + """Return as SON data ready for use with MongoDB. 
""" - data = {} - for field_name, field in self._fields.iteritems(): + data = SON() + data["_id"] = None + data['_cls'] = self._class_name + + for field_name in self: value = self._data.get(field_name, None) + field = self._fields.get(field_name) + if field is None and self._dynamic: + field = self._dynamic_fields.get(field_name) + if value is not None: value = field.to_mongo(value) @@ -244,19 +261,20 @@ class BaseDocument(object): if value is not None: data[field.db_field] = value + # If "_id" has not been set, then try and set it + Document = _import_class("Document") + if isinstance(self, Document): + if data["_id"] is None: + data["_id"] = self._data.get("id", None) + + if data['_id'] is None: + data.pop('_id') + # Only add _cls if allow_inheritance is True - if (hasattr(self, '_meta') and - self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == True): - data['_cls'] = self._class_name + if (not hasattr(self, '_meta') or + not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): + data.pop('_cls') - if '_id' in data and data['_id'] is None: - del data['_id'] - - if not self._dynamic: - return data - - for name, field in self._dynamic_fields.items(): - data[name] = field.to_mongo(self._data.get(name, None)) return data def validate(self, clean=True): @@ -272,11 +290,8 @@ class BaseDocument(object): errors[NON_FIELD_ERRORS] = error # Get a list of tuples of field names and their current values - fields = [(field, self._data.get(name)) - for name, field in self._fields.items()] - if self._dynamic: - fields += [(field, self._data.get(name)) - for name, field in self._dynamic_fields.items()] + fields = [(self._fields.get(name, self._dynamic_fields.get(name)), + self._data.get(name)) for name in self._fields_ordered] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") @@ -306,9 +321,9 @@ class BaseDocument(object): message = "ValidationError (%s:%s) " % (self._class_name, pk) raise ValidationError(message, errors=errors) - def to_json(self): + def to_json(self, *args, **kwargs): """Converts a document to JSON""" - return json_util.dumps(self.to_mongo()) + return json_util.dumps(self.to_mongo(), *args, **kwargs) @classmethod def from_json(cls, json_data): @@ -375,11 +390,12 @@ class BaseDocument(object): if field_value: field_value._clear_changed_fields() - def _get_changed_fields(self, key='', inspected=None): + def _get_changed_fields(self, inspected=None): """Returns a list of all fields that have explicitly been changed. """ EmbeddedDocument = _import_class("EmbeddedDocument") DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") + ReferenceField = _import_class("ReferenceField") _changed_fields = [] _changed_fields += getattr(self, '_changed_fields', []) @@ -389,38 +405,39 @@ class BaseDocument(object): return _changed_fields inspected.add(self.id) - field_list = self._fields.copy() - if self._dynamic: - field_list.update(self._dynamic_fields) - - for field_name in field_list: - + for field_name in self._fields_ordered: db_field_name = self._db_field_map.get(field_name, field_name) key = '%s.' 
% db_field_name - field = self._data.get(field_name, None) - if hasattr(field, 'id'): - if field.id in inspected: - continue - inspected.add(field.id) + data = self._data.get(field_name, None) + field = self._fields.get(field_name) - if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) + if hasattr(data, 'id'): + if data.id in inspected: + continue + inspected.add(data.id) + if isinstance(field, ReferenceField): + continue + elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields): # Find all embedded fields that have been changed - changed = field._get_changed_fields(key, inspected) + changed = data._get_changed_fields(inspected) _changed_fields += ["%s%s" % (key, k) for k in changed if k] - elif (isinstance(field, (list, tuple, dict)) and + elif (isinstance(data, (list, tuple, dict)) and db_field_name not in _changed_fields): # Loop list / dict fields as they contain documents # Determine the iterator to use - if not hasattr(field, 'items'): - iterator = enumerate(field) + if not hasattr(data, 'items'): + iterator = enumerate(data) else: - iterator = field.iteritems() + iterator = data.iteritems() for index, value in iterator: if not hasattr(value, '_get_changed_fields'): continue + if (hasattr(field, 'field') and + isinstance(field.field, ReferenceField)): + continue list_key = "%s%s." % (key, index) - changed = value._get_changed_fields(list_key, inspected) + changed = value._get_changed_fields(inspected) _changed_fields += ["%s%s" % (list_key, k) for k in changed if k] return _changed_fields @@ -433,7 +450,6 @@ class BaseDocument(object): doc = self.to_mongo() set_fields = self._get_changed_fields() - set_data = {} unset_data = {} parts = [] if hasattr(self, '_changed_fields'): @@ -444,7 +460,7 @@ class BaseDocument(object): d = doc new_path = [] for p in parts: - if isinstance(d, DBRef): + if isinstance(d, (ObjectId, DBRef)): break elif isinstance(d, list) and p.isdigit(): d = d[int(p)] @@ -613,8 +629,10 @@ class BaseDocument(object): # Check to see if we need to include _cls allow_inheritance = cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) - include_cls = allow_inheritance and not spec.get('sparse', False) - + include_cls = (allow_inheritance and not spec.get('sparse', False) and + spec.get('cls', True)) + if "cls" in spec: + spec.pop('cls') for key in spec['fields']: # If inherited spec continue if isinstance(key, (list, tuple)): @@ -648,7 +666,8 @@ class BaseDocument(object): if include_cls and direction is not pymongo.GEO2D: index_list.insert(0, ('_cls', 1)) - spec['fields'] = index_list + if index_list: + spec['fields'] = index_list if spec.get('sparse', False) and len(spec['fields']) > 1: raise ValueError( 'Sparse indexes can only have one field in them. ' @@ -690,13 +709,13 @@ class BaseDocument(object): # Add the new index to the list fields = [("%s%s" % (namespace, f), pymongo.ASCENDING) - for f in unique_fields] + for f in unique_fields] index = {'fields': fields, 'unique': True, 'sparse': sparse} unique_indexes.append(index) # Grab any embedded document field unique indexes if (field.__class__.__name__ == "EmbeddedDocumentField" and - field.document_type != cls): + field.document_type != cls): field_namespace = "%s." 
% field_name doc_cls = field.document_type unique_indexes += doc_cls._unique_with_indexes(field_namespace) @@ -704,26 +723,31 @@ class BaseDocument(object): return unique_indexes @classmethod - def _geo_indices(cls, inspected=None): + def _geo_indices(cls, inspected=None, parent_field=None): inspected = inspected or [] geo_indices = [] inspected.append(cls) - EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - GeoPointField = _import_class("GeoPointField") + geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField", + "PointField", "LineStringField", "PolygonField"] + + geo_field_types = tuple([_import_class(field) for field in geo_field_type_names]) for field in cls._fields.values(): - if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): + if not isinstance(field, geo_field_types): continue if hasattr(field, 'document_type'): field_cls = field.document_type if field_cls in inspected: continue if hasattr(field_cls, '_geo_indices'): - geo_indices += field_cls._geo_indices(inspected) + geo_indices += field_cls._geo_indices(inspected, parent_field=field.db_field) elif field._geo_index: + field_name = field.db_field + if parent_field: + field_name = "%s.%s" % (parent_field, field_name) geo_indices.append({'fields': - [(field.db_field, pymongo.GEO2D)]}) + [(field_name, field._geo_index)]}) return geo_indices @classmethod @@ -738,7 +762,7 @@ class BaseDocument(object): for field_name in parts: # Handle ListField indexing: - if field_name.isdigit(): + if field_name.isdigit() and hasattr(field, 'field'): new_field = field.field fields.append(field_name) continue diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 6ebba36..c6abd02 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -2,7 +2,8 @@ import operator import warnings import weakref -from bson import DBRef, ObjectId +from bson import DBRef, ObjectId, SON +import pymongo from mongoengine.common import _import_class from mongoengine.errors import ValidationError @@ -10,7 +11,7 @@ from mongoengine.errors import ValidationError from mongoengine.base.common import ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList -__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField") +__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") class BaseField(object): @@ -35,6 +36,29 @@ class BaseField(object): unique=False, unique_with=None, primary_key=False, validation=None, choices=None, verbose_name=None, help_text=None): + """ + :param db_field: The database field to store this field in + (defaults to the name of the field) + :param name: Depreciated - use db_field + :param required: If the field is required. Whether it has to have a + value or not. Defaults to False. + :param default: (optional) The default value for this field if no value + has been set (or if the value has been unset). It Can be a + callable. + :param unique: Is the field value unique or not. Defaults to False. + :param unique_with: (optional) The other field this field should be + unique with. + :param primary_key: Mark this field as the primary key. Defaults to False. + :param validation: (optional) A callable to validate the value of the + field. Generally this is deprecated in favour of the + `FIELD.validate` method + :param choices: (optional) The valid choices + :param verbose_name: (optional) The verbose name for the field. + Designed to be human readable and is often used when generating + model forms from the document model. 
+ :param help_text: (optional) The help text for this field and is often + used when generating model forms from the document model. + """ self.db_field = (db_field or name) if not primary_key else '_id' if name: msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" @@ -58,20 +82,14 @@ class BaseField(object): BaseField.creation_counter += 1 def __get__(self, instance, owner): - """Descriptor for retrieving a value from a field in a document. Do - any necessary conversion between Python and MongoDB types. + """Descriptor for retrieving a value from a field in a document. """ if instance is None: # Document class being used rather than a document object return self - # Get value from document instance if available, if not use default - value = instance._data.get(self.name) - if value is None: - value = self.default - # Allow callable default values - if callable(value): - value = value() + # Get value from document instance if available + value = instance._data.get(self.name) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: @@ -81,13 +99,24 @@ class BaseField(object): def __set__(self, instance, value): """Descriptor for assigning a value to a field in a document. """ - changed = False - if (self.name not in instance._data or - instance._data[self.name] != value): - changed = True - instance._data[self.name] = value - if changed and instance._initialised: - instance._mark_as_changed(self.name) + + # If setting to None and theres a default + # Then set the value to the default value + if value is None and self.default is not None: + value = self.default + if callable(value): + value = value() + + if instance._initialised: + try: + if (self.name not in instance._data or + instance._data[self.name] != value): + instance._mark_as_changed(self.name) + except: + # Values cant be compared eg: naive and tz datetimes + # So mark it as changed + instance._mark_as_changed(self.name) + instance._data[self.name] = value def error(self, message="", errors=None, field_name=None): """Raises a ValidationError. @@ -157,7 +186,6 @@ class ComplexBaseField(BaseField): """ field = None - __dereference = False def __get__(self, instance, owner): """Descriptor to automatically dereference references. 
@@ -172,9 +200,11 @@ class ComplexBaseField(BaseField): (self.field is None or isinstance(self.field, (GenericReferenceField, ReferenceField)))) + _dereference = _import_class("DeReference")() + self._auto_dereference = instance._fields[self.name]._auto_dereference - if not self.__dereference and instance._initialised and dereference: - instance._data[self.name] = self._dereference( + if instance._initialised and dereference: + instance._data[self.name] = _dereference( instance._data.get(self.name), max_depth=1, instance=instance, name=self.name ) @@ -183,7 +213,7 @@ class ComplexBaseField(BaseField): # Convert lists / values so we can watch for any changes on them if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): + not isinstance(value, BaseList)): value = BaseList(value, instance, self.name) instance._data[self.name] = value elif isinstance(value, dict) and not isinstance(value, BaseDict): @@ -191,9 +221,9 @@ class ComplexBaseField(BaseField): instance._data[self.name] = value if (self._auto_dereference and instance._initialised and - isinstance(value, (BaseList, BaseDict)) - and not value._dereferenced): - value = self._dereference( + isinstance(value, (BaseList, BaseDict)) + and not value._dereferenced): + value = _dereference( value, max_depth=1, instance=instance, name=self.name ) value._dereferenced = True @@ -201,12 +231,6 @@ class ComplexBaseField(BaseField): return value - def __set__(self, instance, value): - """Descriptor for assigning a value to a field in a document. - """ - instance._data[self.name] = value - instance._mark_as_changed(self.name) - def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ @@ -228,7 +252,7 @@ class ComplexBaseField(BaseField): if self.field: value_dict = dict([(key, self.field.to_python(item)) - for key, item in value.items()]) + for key, item in value.items()]) else: value_dict = {} for k, v in value.items(): @@ -279,7 +303,7 @@ class ComplexBaseField(BaseField): if self.field: value_dict = dict([(key, self.field.to_mongo(item)) - for key, item in value.iteritems()]) + for key, item in value.iteritems()]) else: value_dict = {} for k, v in value.iteritems(): @@ -295,7 +319,7 @@ class ComplexBaseField(BaseField): meta = getattr(v, '_meta', {}) allow_inheritance = ( meta.get('allow_inheritance', ALLOW_INHERITANCE) - == True) + is True) if not allow_inheritance and not self.field: value_dict[k] = GenericReferenceField().to_mongo(v) else: @@ -359,13 +383,6 @@ class ComplexBaseField(BaseField): owner_document = property(_get_owner_document, _set_owner_document) - @property - def _dereference(self,): - if not self.__dereference: - DeReference = _import_class("DeReference") - self.__dereference = DeReference() # Cached - return self.__dereference - class ObjectIdField(BaseField): """A field wrapper around MongoDB's ObjectIds. @@ -393,3 +410,100 @@ class ObjectIdField(BaseField): ObjectId(unicode(value)) except: self.error('Invalid Object ID') + + +class GeoJsonBaseField(BaseField): + """A geo json field storing a geojson style object. + .. versionadded:: 0.8 + """ + + _geo_index = pymongo.GEOSPHERE + _type = "GeoBase" + + def __init__(self, auto_index=True, *args, **kwargs): + """ + :param auto_index: Automatically create a "2dsphere" index. Defaults + to `True`. 
+ """ + self._name = "%sField" % self._type + if not auto_index: + self._geo_index = False + super(GeoJsonBaseField, self).__init__(*args, **kwargs) + + def validate(self, value): + """Validate the GeoJson object based on its type + """ + if isinstance(value, dict): + if set(value.keys()) == set(['type', 'coordinates']): + if value['type'] != self._type: + self.error('%s type must be "%s"' % (self._name, self._type)) + return self.validate(value['coordinates']) + else: + self.error('%s can only accept a valid GeoJson dictionary' + ' or lists of (x, y)' % self._name) + return + elif not isinstance(value, (list, tuple)): + self.error('%s can only accept lists of [x, y]' % self._name) + return + + validate = getattr(self, "_validate_%s" % self._type.lower()) + error = validate(value) + if error: + self.error(error) + + def _validate_polygon(self, value): + if not isinstance(value, (list, tuple)): + return 'Polygons must contain list of linestrings' + + # Quick and dirty validator + try: + value[0][0][0] + except: + return "Invalid Polygon must contain at least one valid linestring" + + errors = [] + for val in value: + error = self._validate_linestring(val, False) + if not error and val[0] != val[-1]: + error = 'LineStrings must start and end at the same point' + if error and error not in errors: + errors.append(error) + if errors: + return "Invalid Polygon:\n%s" % ", ".join(errors) + + def _validate_linestring(self, value, top_level=True): + """Validates a linestring""" + if not isinstance(value, (list, tuple)): + return 'LineStrings must contain list of coordinate pairs' + + # Quick and dirty validator + try: + value[0][0] + except: + return "Invalid LineString must contain at least one valid point" + + errors = [] + for val in value: + error = self._validate_point(val) + if error and error not in errors: + errors.append(error) + if errors: + if top_level: + return "Invalid LineString:\n%s" % ", ".join(errors) + else: + return "%s" % ", ".join(errors) + + def _validate_point(self, value): + """Validate each set of coords""" + if not isinstance(value, (list, tuple)): + return 'Points must be a list of coordinate pairs' + elif not len(value) == 2: + return "Value (%s) must be a two-dimensional point" % repr(value) + elif (not isinstance(value[0], (float, int)) or + not isinstance(value[1], (float, int))): + return "Both values (%s) in point must be float or int" % repr(value) + + def to_mongo(self, value): + if isinstance(value, dict): + return value + return SON([("type", self._type), ("coordinates", value)]) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index a53744d..ff5afdd 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -91,11 +91,12 @@ class DocumentMetaclass(type): attrs['_fields'] = doc_fields attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) for k, v in doc_fields.iteritems()]) + attrs['_reverse_db_field_map'] = dict( + (v, k) for k, v in attrs['_db_field_map'].iteritems()) + attrs['_fields_ordered'] = tuple(i[1] for i in sorted( (v.creation_counter, v.name) for v in doc_fields.itervalues())) - attrs['_reverse_db_field_map'] = dict( - (v, k) for k, v in attrs['_db_field_map'].iteritems()) # # Set document hierarchy @@ -140,8 +141,31 @@ class DocumentMetaclass(type): base._subclasses += (_cls,) base._types = base._subclasses # TODO depreciate _types - # Handle delete rules Document, EmbeddedDocument, DictField = cls._import_classes() + + if issubclass(new_class, Document): + new_class._collection = None 
+ + # Add class to the _document_registry + _document_registry[new_class._class_name] = new_class + + # In Python 2, User-defined methods objects have special read-only + # attributes 'im_func' and 'im_self' which contain the function obj + # and class instance object respectively. With Python 3 these special + # attributes have been replaced by __func__ and __self__. The Blinker + # module continues to use im_func and im_self, so the code below + # copies __func__ into im_func and __self__ into im_self for + # classmethod objects in Document derived classes. + if PY3: + for key, val in new_class.__dict__.items(): + if isinstance(val, classmethod): + f = val.__get__(new_class) + if hasattr(f, '__func__') and not hasattr(f, 'im_func'): + f.__dict__.update({'im_func': getattr(f, '__func__')}) + if hasattr(f, '__self__') and not hasattr(f, 'im_self'): + f.__dict__.update({'im_self': getattr(f, '__self__')}) + + # Handle delete rules for field in new_class._fields.itervalues(): f = field f.owner_document = new_class @@ -167,33 +191,11 @@ class DocumentMetaclass(type): field.name, delete_rule) if (field.name and hasattr(Document, field.name) and - EmbeddedDocument not in new_class.mro()): + EmbeddedDocument not in new_class.mro()): msg = ("%s is a document method and not a valid " "field name" % field.name) raise InvalidDocumentError(msg) - if issubclass(new_class, Document): - new_class._collection = None - - # Add class to the _document_registry - _document_registry[new_class._class_name] = new_class - - # In Python 2, User-defined methods objects have special read-only - # attributes 'im_func' and 'im_self' which contain the function obj - # and class instance object respectively. With Python 3 these special - # attributes have been replaced by __func__ and __self__. The Blinker - # module continues to use im_func and im_self, so the code below - # copies __func__ into im_func and __self__ into im_self for - # classmethod objects in Document derived classes. 
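Together with the ``_reverse_db_field_map`` and ``_fields_ordered`` bookkeeping added above, every document class now records its declared field order and both directions of the name/db_field mapping, and registers itself for lookup by name. A rough sketch for a hypothetical document (the automatic ``id`` entry is prepended further down in ``TopLevelDocumentMetaclass``)::

    from mongoengine import Document, StringField, IntField
    from mongoengine.base import get_document

    class Person(Document):
        name = StringField(db_field='n')
        age = IntField()

    Person._fields_ordered              # ('id', 'name', 'age') - declaration order, auto 'id' first
    Person._db_field_map['name']        # 'n'
    Person._reverse_db_field_map['n']   # 'name'
    get_document('Person') is Person    # True - resolved via the _document_registry filled above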
- if PY3: - for key, val in new_class.__dict__.items(): - if isinstance(val, classmethod): - f = val.__get__(new_class) - if hasattr(f, '__func__') and not hasattr(f, 'im_func'): - f.__dict__.update({'im_func': getattr(f, '__func__')}) - if hasattr(f, '__self__') and not hasattr(f, 'im_self'): - f.__dict__.update({'im_self': getattr(f, '__self__')}) - return new_class def add_to_class(self, name, value): @@ -315,8 +317,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # may set allow_inheritance to False simple_class = all([b._meta.get('abstract') for b in flattened_bases if hasattr(b, '_meta')]) - if (not simple_class and meta['allow_inheritance'] == False and - not meta['abstract']): + if (not simple_class and meta['allow_inheritance'] is False and + not meta['abstract']): raise ValueError('Only direct subclasses of Document may set ' '"allow_inheritance" to False') @@ -339,9 +341,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): if callable(collection): new_class._meta['collection'] = collection(new_class) - # Provide a default queryset unless one has been set - manager = attrs.get('objects', QuerySetManager()) - new_class.objects = manager + # Provide a default queryset unless exists or one has been set + if 'objects' not in dir(new_class): + new_class.objects = QuerySetManager() # Validate the fields and set primary key if needed for field_name, field in new_class._fields.iteritems(): @@ -357,12 +359,18 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class.id = field # Set primary key if not defined by the document + new_class._auto_id_field = False if not new_class._meta.get('id_field'): + new_class._auto_id_field = True new_class._meta['id_field'] = 'id' new_class._fields['id'] = ObjectIdField(db_field='_id') new_class._fields['id'].name = 'id' new_class.id = new_class._fields['id'] + # Prepend id field to _fields_ordered + if 'id' in new_class._fields and 'id' not in new_class._fields_ordered: + new_class._fields_ordered = ('id', ) + new_class._fields_ordered + # Merge in exceptions with parent hierarchy exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) module = attrs.get('__module__') diff --git a/mongoengine/common.py b/mongoengine/common.py index 718ac0b..6303231 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -2,7 +2,19 @@ _class_registry_cache = {} def _import_class(cls_name): - """Cached mechanism for imports""" + """Cache mechanism for imports. + + Due to complications of circular imports mongoengine needs to do lots of + inline imports in functions. This is inefficient as classes are + imported repeated throughout the mongoengine code. This is + compounded by some recursive functions requiring inline imports. + + :mod:`mongoengine.common` provides a single point to import all these + classes. Circular imports aren't an issue as it dynamically imports the + class when first needed. Subsequent calls to the + :func:`~mongoengine.common._import_class` can then directly retrieve the + class from the :data:`mongoengine.common._class_registry_cache`. 
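The cached import helper described above simply fills ``_class_registry_cache`` on first use, so the many inline imports scattered through mongoengine all resolve to the same object. Roughly::

    from mongoengine.common import _import_class

    ListField = _import_class('ListField')   # first call imports mongoengine.fields and caches the class
    ListField is _import_class('ListField')  # True - later calls come straight from _class_registry_cache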
+ """ if cls_name in _class_registry_cache: return _class_registry_cache.get(cls_name) @@ -11,7 +23,9 @@ def _import_class(cls_name): field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', 'FileField', 'GenericReferenceField', 'GenericEmbeddedDocumentField', 'GeoPointField', - 'ReferenceField', 'StringField', 'ComplexBaseField') + 'PointField', 'LineStringField', 'ListField', + 'PolygonField', 'ReferenceField', 'StringField', + 'ComplexBaseField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) @@ -33,4 +47,4 @@ def _import_class(cls_name): for cls in import_classes: _class_registry_cache[cls] = getattr(module, cls) - return _class_registry_cache.get(cls_name) \ No newline at end of file + return _class_registry_cache.get(cls_name) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index a47be44..4275da5 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,5 +1,5 @@ import pymongo -from pymongo import Connection, ReplicaSetConnection, uri_parser +from pymongo import MongoClient, MongoReplicaSetClient, uri_parser __all__ = ['ConnectionError', 'connect', 'register_connection', @@ -55,12 +55,9 @@ def register_connection(alias, name, host='localhost', port=27017, # Handle uri style connections if "://" in host: uri_dict = uri_parser.parse_uri(host) - if uri_dict.get('database') is None: - raise ConnectionError("If using URI style connection include "\ - "database name in string") conn_settings.update({ 'host': host, - 'name': uri_dict.get('database'), + 'name': uri_dict.get('database') or name, 'username': uri_dict.get('username'), 'password': uri_dict.get('password'), 'read_preference': read_preference, @@ -112,15 +109,15 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings['slaves'] = slaves conn_settings.pop('read_preference', None) - connection_class = Connection + connection_class = MongoClient if 'replicaSet' in conn_settings: conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) - # Discard port since it can't be used on ReplicaSetConnection + # Discard port since it can't be used on MongoReplicaSetClient conn_settings.pop('port', None) # Discard replicaSet if not base string if not isinstance(conn_settings['replicaSet'], basestring): conn_settings.pop('replicaSet', None) - connection_class = ReplicaSetConnection + connection_class = MongoReplicaSetClient try: _connections[alias] = connection_class(**conn_settings) @@ -137,11 +134,12 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if alias not in _dbs: conn = get_connection(alias) conn_settings = _connection_settings[alias] - _dbs[alias] = conn[conn_settings['name']] + db = conn[conn_settings['name']] # Authenticate if necessary if conn_settings['username'] and conn_settings['password']: - _dbs[alias].authenticate(conn_settings['username'], - conn_settings['password']) + db.authenticate(conn_settings['username'], + conn_settings['password']) + _dbs[alias] = db return _dbs[alias] diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 76d5fbf..13ed100 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,8 +1,10 @@ from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -from mongoengine.queryset import OperationError, QuerySet +from mongoengine.queryset import QuerySet -__all__ = ("switch_db", "switch_collection", "no_dereference", "query_counter") + +__all__ = ("switch_db", 
"switch_collection", "no_dereference", + "no_sub_classes", "query_counter") class switch_db(object): @@ -130,6 +132,36 @@ class no_dereference(object): return self.cls +class no_sub_classes(object): + """ no_sub_classes context manager. + + Only returns instances of this class and no sub (inherited) classes:: + + with no_sub_classes(Group) as Group: + Group.objects.find() + + """ + + def __init__(self, cls): + """ Construct the no_sub_classes context manager. + + :param cls: the class to turn querying sub classes on + """ + self.cls = cls + + def __enter__(self): + """ change the objects default and _auto_dereference values""" + self.cls._all_subclasses = self.cls._subclasses + self.cls._subclasses = (self.cls,) + return self.cls + + def __exit__(self, t, value, traceback): + """ Reset the default and _auto_dereference values""" + self.cls._subclasses = self.cls._all_subclasses + delattr(self.cls, '_all_subclasses') + return self.cls + + class QuerySetNoDeRef(QuerySet): """Special no_dereference QuerySet""" def __dereference(items, max_depth=1, instance=None, name=None): @@ -157,7 +189,8 @@ class query_counter(object): def __eq__(self, value): """ == Compare querycounter. """ - return value == self._get_count() + counter = self._get_count() + return value == counter def __ne__(self, value): """ != Compare querycounter. """ @@ -189,6 +222,7 @@ class query_counter(object): def _get_count(self): """ Get the number of queries. """ - count = self.db.system.profile.find().count() - self.counter + ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} + count = self.db.system.profile.find(ignore_query).count() - self.counter self.counter += 1 return count diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index e5e8886..ceda403 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -4,7 +4,7 @@ from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document) from fields import (ReferenceField, ListField, DictField, MapField) from connection import get_db from queryset import QuerySet -from document import Document +from document import Document, EmbeddedDocument class DeReference(object): @@ -33,7 +33,8 @@ class DeReference(object): self.max_depth = max_depth doc_type = None - if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)): + if instance and isinstance(instance, (Document, EmbeddedDocument, + TopLevelDocumentMetaclass)): doc_type = instance._fields.get(name) if hasattr(doc_type, 'field'): doc_type = doc_type.field diff --git a/mongoengine/django/auth.py b/mongoengine/django/auth.py index d22f086..cff4b74 100644 --- a/mongoengine/django/auth.py +++ b/mongoengine/django/auth.py @@ -1,8 +1,7 @@ from mongoengine import * from django.utils.encoding import smart_str -from django.contrib.auth.models import _user_get_all_permissions -from django.contrib.auth.models import _user_has_perm +from django.contrib.auth.models import _user_has_perm, _user_get_all_permissions, _user_has_module_perms from django.db import models from django.contrib.contenttypes.models import ContentTypeManager from django.contrib import auth @@ -38,11 +37,12 @@ from .utils import datetime_now REDIRECT_FIELD_NAME = 'next' + class ContentType(Document): name = StringField(max_length=100) app_label = StringField(max_length=100) model = StringField(max_length=100, verbose_name=_('python model class name'), - unique_with='app_label') + unique_with='app_label') objects = ContentTypeManager() class Meta: @@ -72,9 +72,11 @@ class 
ContentType(Document): def natural_key(self): return (self.app_label, self.model) + class SiteProfileNotAvailable(Exception): pass + class PermissionManager(models.Manager): def get_by_natural_key(self, codename, app_label, model): return self.get( @@ -82,18 +84,28 @@ class PermissionManager(models.Manager): content_type=ContentType.objects.get_by_natural_key(app_label, model) ) + class Permission(Document): - """The permissions system provides a way to assign permissions to specific users and groups of users. + """The permissions system provides a way to assign permissions to specific + users and groups of users. - The permission system is used by the Django admin site, but may also be useful in your own code. The Django admin site uses permissions as follows: + The permission system is used by the Django admin site, but may also be + useful in your own code. The Django admin site uses permissions as follows: - - The "add" permission limits the user's ability to view the "add" form and add an object. - - The "change" permission limits a user's ability to view the change list, view the "change" form and change an object. + - The "add" permission limits the user's ability to view the "add" + form and add an object. + - The "change" permission limits a user's ability to view the change + list, view the "change" form and change an object. - The "delete" permission limits the ability to delete an object. - Permissions are set globally per type of object, not per specific object instance. It is possible to say "Mary may change news stories," but it's not currently possible to say "Mary may change news stories, but only the ones she created herself" or "Mary may only change news stories that have a certain status or publication date." + Permissions are set globally per type of object, not per specific object + instance. It is possible to say "Mary may change news stories," but it's + not currently possible to say "Mary may change news stories, but only the + ones she created herself" or "Mary may only change news stories that have + a certain status or publication date." - Three basic permissions -- add, change and delete -- are automatically created for each Django model. + Three basic permissions -- add, change and delete -- are automatically + created for each Django model. """ name = StringField(max_length=50, verbose_name=_('username')) content_type = ReferenceField(ContentType) @@ -119,15 +131,24 @@ class Permission(Document): return (self.codename,) + self.content_type.natural_key() natural_key.dependencies = ['contenttypes.contenttype'] + class Group(Document): - """Groups are a generic way of categorizing users to apply permissions, or some other label, to those users. A user can belong to any number of groups. + """Groups are a generic way of categorizing users to apply permissions, + or some other label, to those users. A user can belong to any number of + groups. - A user in a group automatically has all the permissions granted to that group. For example, if the group Site editors has the permission can_edit_home_page, any user in that group will have that permission. + A user in a group automatically has all the permissions granted to that + group. For example, if the group Site editors has the permission + can_edit_home_page, any user in that group will have that permission. - Beyond permissions, groups are a convenient way to categorize users to apply some label, or extended functionality, to them. 
For example, you could create a group 'Special users', and you could write code that would do special things to those users -- such as giving them access to a members-only portion of your site, or sending them members-only e-mail messages. + Beyond permissions, groups are a convenient way to categorize users to + apply some label, or extended functionality, to them. For example, you + could create a group 'Special users', and you could write code that would + do special things to those users -- such as giving them access to a + members-only portion of your site, or sending them members-only + e-mail messages. """ name = StringField(max_length=80, unique=True, verbose_name=_('name')) - # permissions = models.ManyToManyField(Permission, verbose_name=_('permissions'), blank=True) permissions = ListField(ReferenceField(Permission, verbose_name=_('permissions'), required=False)) class Meta: @@ -137,6 +158,7 @@ class Group(Document): def __unicode__(self): return self.name + class UserManager(models.Manager): def create_user(self, username, email, password=None): """ @@ -154,8 +176,8 @@ class UserManager(models.Manager): email = '@'.join([email_name, domain_part.lower()]) user = self.model(username=username, email=email, is_staff=False, - is_active=True, is_superuser=False, last_login=now, - date_joined=now) + is_active=True, is_superuser=False, last_login=now, + date_joined=now) user.set_password(password) user.save(using=self._db) @@ -177,7 +199,6 @@ class UserManager(models.Manager): return ''.join([choice(allowed_chars) for i in range(length)]) - class User(Document): """A User document that aims to mirror most of the API specified by Django at http://docs.djangoproject.com/en/dev/topics/auth/#users @@ -209,6 +230,9 @@ class User(Document): date_joined = DateTimeField(default=datetime_now, verbose_name=_('date joined')) + USERNAME_FIELD = 'username' + REQUIRED_FIELDS = ['email'] + meta = { 'allow_inheritance': True, 'indexes': [ @@ -248,25 +272,6 @@ class User(Document): """ return check_password(raw_password, self.password) - def get_all_permissions(self, obj=None): - return _user_get_all_permissions(self, obj) - - def has_perm(self, perm, obj=None): - """ - Returns True if the user has the specified permission. This method - queries all available auth backends, but returns immediately if any - backend returns True. Thus, a user who has permission from a single - auth backend is assumed to have permission in general. If an object is - provided, permissions for this specific object are checked. - """ - - # Active superusers have all permissions. - if self.is_active and self.is_superuser: - return True - - # Otherwise we need to check the backends. - return _user_has_perm(self, perm, obj) - @classmethod def create_user(cls, username, password, email=None): """Create (and save) a new user with the given username, password and @@ -289,68 +294,47 @@ class User(Document): user.save() return user - def get_all_permissions(self, obj=None): + def get_group_permissions(self, obj=None): + """ + Returns a list of permission strings that this user has through his/her + groups. This method queries all available auth backends. If an object + is passed in, only permissions matching this object are returned. 
+ """ permissions = set() - anon = self.is_anonymous() for backend in auth.get_backends(): - if not anon or backend.supports_anonymous_user: - if hasattr(backend, "get_all_permissions"): - if obj is not None: - if backend.supports_object_permissions: - permissions.update( - backend.get_all_permissions(user, obj) - ) - else: - permissions.update(backend.get_all_permissions(self)) + if hasattr(backend, "get_group_permissions"): + permissions.update(backend.get_group_permissions(self, obj)) return permissions - def get_and_delete_messages(self): - return [] + def get_all_permissions(self, obj=None): + return _user_get_all_permissions(self, obj) def has_perm(self, perm, obj=None): - anon = self.is_anonymous() - active = self.is_active - for backend in auth.get_backends(): - if (not active and not anon and backend.supports_inactive_user) or \ - (not anon or backend.supports_anonymous_user): - if hasattr(backend, "has_perm"): - if obj is not None: - if (backend.supports_object_permissions and - backend.has_perm(self, perm, obj)): - return True - else: - if backend.has_perm(self, perm): - return True - return False + """ + Returns True if the user has the specified permission. This method + queries all available auth backends, but returns immediately if any + backend returns True. Thus, a user who has permission from a single + auth backend is assumed to have permission in general. If an object is + provided, permissions for this specific object are checked. + """ - def has_perms(self, perm_list, obj=None): - """ - Returns True if the user has each of the specified permissions. - If object is passed, it checks if the user has all required perms - for this object. - """ - for perm in perm_list: - if not self.has_perm(perm, obj): - return False - return True + # Active superusers have all permissions. + if self.is_active and self.is_superuser: + return True + + # Otherwise we need to check the backends. + return _user_has_perm(self, perm, obj) def has_module_perms(self, app_label): - anon = self.is_anonymous() - active = self.is_active - for backend in auth.get_backends(): - if (not active and not anon and backend.supports_inactive_user) or \ - (not anon or backend.supports_anonymous_user): - if hasattr(backend, "has_module_perms"): - if backend.has_module_perms(self, app_label): - return True - return False + """ + Returns True if the user has any permissions in the given app label. + Uses pretty much the same logic as has_perm, above. + """ + # Active superusers have all permissions. + if self.is_active and self.is_superuser: + return True - def get_and_delete_messages(self): - messages = [] - for m in self.message_set.all(): - messages.append(m.message) - m.delete() - return messages + return _user_has_module_perms(self, app_label) def email_user(self, subject, message, from_email=None): "Sends an e-mail to this User." @@ -386,14 +370,6 @@ class User(Document): raise SiteProfileNotAvailable return self._profile_cache - def _get_message_set(self): - import warnings - warnings.warn('The user messaging API is deprecated. Please update' - ' your code to use the new messages framework.', - category=DeprecationWarning) - return self._message_set - message_set = property(_get_message_set) - class MongoEngineBackend(object): """Authenticate using MongoEngine and mongoengine.django.auth.User. 
diff --git a/mongoengine/django/mongo_auth/__init__.py b/mongoengine/django/mongo_auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mongoengine/django/mongo_auth/models.py b/mongoengine/django/mongo_auth/models.py new file mode 100644 index 0000000..d4947a2 --- /dev/null +++ b/mongoengine/django/mongo_auth/models.py @@ -0,0 +1,107 @@ +from django.conf import settings +from django.contrib.auth.models import UserManager +from django.core.exceptions import ImproperlyConfigured +from django.db import models +from django.utils.importlib import import_module +from django.utils.translation import ugettext_lazy as _ + + +__all__ = ( + 'get_user_document', +) + + +MONGOENGINE_USER_DOCUMENT = getattr( + settings, 'MONGOENGINE_USER_DOCUMENT', 'mongoengine.django.auth.User') + + +def get_user_document(): + """Get the user document class used for authentication. + + This is the class defined in settings.MONGOENGINE_USER_DOCUMENT, which + defaults to `mongoengine.django.auth.User`. + + """ + + name = MONGOENGINE_USER_DOCUMENT + dot = name.rindex('.') + module = import_module(name[:dot]) + return getattr(module, name[dot + 1:]) + + +class MongoUserManager(UserManager): + """A User manager wich allows the use of MongoEngine documents in Django. + + To use the manager, you must tell django.contrib.auth to use MongoUser as + the user model. In you settings.py, you need: + + INSTALLED_APPS = ( + ... + 'django.contrib.auth', + 'mongoengine.django.mongo_auth', + ... + ) + AUTH_USER_MODEL = 'mongo_auth.MongoUser' + + Django will use the model object to access the custom Manager, which will + replace the original queryset with MongoEngine querysets. + + By default, mongoengine.django.auth.User will be used to store users. You + can specify another document class in MONGOENGINE_USER_DOCUMENT in your + settings.py. + + The User Document class has the same requirements as a standard custom user + model: https://docs.djangoproject.com/en/dev/topics/auth/customizing/ + + In particular, the User Document class must define USERNAME_FIELD and + REQUIRED_FIELDS. + + `AUTH_USER_MODEL` has been added in Django 1.5. + + """ + + def contribute_to_class(self, model, name): + super(MongoUserManager, self).contribute_to_class(model, name) + self.dj_model = self.model + self.model = get_user_document() + + self.dj_model.USERNAME_FIELD = self.model.USERNAME_FIELD + username = models.CharField(_('username'), max_length=30, unique=True) + username.contribute_to_class(self.dj_model, self.dj_model.USERNAME_FIELD) + + self.dj_model.REQUIRED_FIELDS = self.model.REQUIRED_FIELDS + for name in self.dj_model.REQUIRED_FIELDS: + field = models.CharField(_(name), max_length=30) + field.contribute_to_class(self.dj_model, name) + + + def get(self, *args, **kwargs): + try: + return self.get_query_set().get(*args, **kwargs) + except self.model.DoesNotExist: + # ModelBackend expects this exception + raise self.dj_model.DoesNotExist + + @property + def db(self): + raise NotImplementedError + + def get_empty_query_set(self): + return self.model.objects.none() + + def get_query_set(self): + return self.model.objects + + +class MongoUser(models.Model): + """"Dummy user model for Django. + + MongoUser is used to replace Django's UserManager with MongoUserManager. + The actual user document class is mongoengine.django.auth.User or any + other document class specified in MONGOENGINE_USER_DOCUMENT. + + To get the user document class, use `get_user_document()`. 
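Beyond the ``AUTH_USER_MODEL`` wiring shown in the docstring above, the backing user document itself can be swapped out. A sketch, where ``myapp.documents.CustomUser`` is a hypothetical subclass of ``mongoengine.django.auth.User`` and a MongoEngine connection is assumed to be registered::

    # settings.py (sketch) - hypothetical custom document defining USERNAME_FIELD / REQUIRED_FIELDS
    MONGOENGINE_USER_DOCUMENT = 'myapp.documents.CustomUser'

    # elsewhere in the project - resolve whichever document class is configured
    from mongoengine.django.mongo_auth.models import get_user_document
    User = get_user_document()
    user = User.create_user('bob', 's3cret', email='bob@example.com')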
+ + """ + + objects = MongoUserManager() diff --git a/mongoengine/django/sessions.py b/mongoengine/django/sessions.py index 0d199a6..7e4e182 100644 --- a/mongoengine/django/sessions.py +++ b/mongoengine/django/sessions.py @@ -1,7 +1,10 @@ from django.conf import settings from django.contrib.sessions.backends.base import SessionBase, CreateError from django.core.exceptions import SuspiciousOperation -from django.utils.encoding import force_unicode +try: + from django.utils.encoding import force_unicode +except ImportError: + from django.utils.encoding import force_text as force_unicode from mongoengine.document import Document from mongoengine import fields @@ -39,7 +42,7 @@ class MongoSession(Document): 'indexes': [ { 'fields': ['expire_date'], - 'expireAfterSeconds': settings.SESSION_COOKIE_AGE + 'expireAfterSeconds': 0 } ] } @@ -88,7 +91,7 @@ class SessionStore(SessionBase): s.session_data = self._get_session(no_load=must_create) s.expire_date = self.get_expiry_date() try: - s.save(force_insert=must_create, safe=True) + s.save(force_insert=must_create) except OperationError: if must_create: raise CreateError diff --git a/mongoengine/django/storage.py b/mongoengine/django/storage.py index 341455c..9df6f9e 100644 --- a/mongoengine/django/storage.py +++ b/mongoengine/django/storage.py @@ -76,7 +76,7 @@ class GridFSStorage(Storage): """Find the documents in the store with the given name """ docs = self.document.objects - doc = [d for d in docs if getattr(d, self.field).name == name] + doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name] if doc: return doc[0] else: diff --git a/mongoengine/document.py b/mongoengine/document.py index 9057075..1bbd7b7 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,11 +1,14 @@ -from __future__ import with_statement import warnings +import hashlib import pymongo import re +from pymongo.read_preferences import ReadPreference +from bson import ObjectId from bson.dbref import DBRef from mongoengine import signals +from mongoengine.common import _import_class from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) @@ -18,6 +21,19 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') +def includes_cls(fields): + """ Helper function used for ensuring and comparing indexes + """ + + first_field = None + if len(fields): + if isinstance(fields[0], basestring): + first_field = fields[0] + elif isinstance(fields[0], (list, tuple)) and len(fields[0]): + first_field = fields[0][0] + return first_field == '_cls' + + class InvalidCollectionError(Exception): pass @@ -53,6 +69,9 @@ class EmbeddedDocument(BaseDocument): return self._data == other._data return False + def __ne__(self, other): + return not self.__eq__(other) + class Document(BaseDocument): """The base class used for defining the structure and properties of @@ -142,7 +161,7 @@ class Document(BaseDocument): options.get('size') != max_size: msg = (('Cannot create collection "%s" as a capped ' 'collection as it already exists') - % cls._collection) + % cls._collection) raise InvalidCollectionError(msg) else: # Create the collection as a capped collection @@ -158,34 +177,30 @@ class Document(BaseDocument): cls.ensure_indexes() return cls._collection - def save(self, safe=True, force_insert=False, validate=True, clean=True, - write_options=None, cascade=None, cascade_kwargs=None, + 
def save(self, force_insert=False, validate=True, clean=True, + write_concern=None, cascade=None, cascade_kwargs=None, _refs=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. - If ``safe=True`` and the operation is unsuccessful, an - :class:`~mongoengine.OperationError` will be raised. - - :param safe: check if the operation succeeded before returning :param force_insert: only try to create a new document, don't allow updates of existing documents :param validate: validates the document; set to ``False`` to skip. :param clean: call the document clean method, requires `validate` to be True. - :param write_options: Extra keyword arguments are passed down to + :param write_concern: Extra keyword arguments are passed down to :meth:`~pymongo.collection.Collection.save` OR :meth:`~pymongo.collection.Collection.insert` which will be used as options for the resultant ``getLastError`` command. For example, - ``save(..., write_options={w: 2, fsync: True}, ...)`` will + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will wait until at least two servers have recorded the write and will force an fsync on the primary server. :param cascade: Sets the flag for cascading saves. You can set a default by setting "cascade" in the document __meta__ - :param cascade_kwargs: optional kwargs dictionary to be passed throw - to cascading saves + :param cascade_kwargs: (optional) kwargs dictionary to be passed throw + to cascading saves. Implies ``cascade=True``. :param _refs: A list of processed references used in cascading saves .. versionchanged:: 0.5 @@ -194,33 +209,35 @@ class Document(BaseDocument): :class:`~bson.dbref.DBRef` objects that have changes are saved as well. .. versionchanged:: 0.6 - Cascade saves are optional = defaults to True, if you want + Added cascading saves + .. versionchanged:: 0.8 + Cascade saves are optional and default to False. If you want fine grain control then you can turn off using document - meta['cascade'] = False Also you can pass different kwargs to + meta['cascade'] = True. Also you can pass different kwargs to the cascade save using cascade_kwargs which overwrites the - existing kwargs with custom values + existing kwargs with custom values. 
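Under the new signature, the old ``safe=True`` / ``write_options`` pair collapses into a single ``write_concern`` dict and cascading saves must be requested explicitly. A brief sketch, assuming the hypothetical ``Person`` document from earlier and an open default connection::

    person = Person(name='Ross', age=30)
    person.save()                                       # defaults to write_concern={"w": 1}
    person.save(write_concern={"w": 2, "fsync": True})  # wait for two replicas plus an fsync
    person.save(cascade=True)                           # also saves changed referenced documents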
""" signals.pre_save.send(self.__class__, document=self) if validate: self.validate(clean=clean) - if not write_options: - write_options = {} + if write_concern is None: + write_concern = {"w": 1} doc = self.to_mongo() created = ('_id' not in doc or self._created or force_insert) + signals.pre_save_post_validation.send(self.__class__, document=self, created=created) + try: collection = self._get_collection() if created: if force_insert: - object_id = collection.insert(doc, safe=safe, - **write_options) + object_id = collection.insert(doc, **write_concern) else: - object_id = collection.save(doc, safe=safe, - **write_options) + object_id = collection.save(doc, **write_concern) else: object_id = doc['_id'] updates, removals = self._delta() @@ -238,7 +255,6 @@ class Document(BaseDocument): return not updated return created - upsert = self._created update_query = {} if updates: @@ -247,24 +263,23 @@ class Document(BaseDocument): update_query["$unset"] = removals if updates or removals: last_error = collection.update(select_dict, update_query, - upsert=upsert, safe=safe, **write_options) + upsert=True, **write_concern) created = is_new_object(last_error) - warn_cascade = not cascade and 'cascade' not in self._meta - cascade = (self._meta.get('cascade', True) - if cascade is None else cascade) + if cascade is None: + cascade = self._meta.get('cascade', False) or cascade_kwargs is not None + if cascade: kwargs = { - "safe": safe, "force_insert": force_insert, "validate": validate, - "write_options": write_options, + "write_concern": write_concern, "cascade": cascade } if cascade_kwargs: # Allow granular control over cascades kwargs.update(cascade_kwargs) kwargs['_refs'] = _refs - self.cascade_save(warn_cascade=warn_cascade, **kwargs) + self.cascade_save(**kwargs) except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' @@ -283,18 +298,20 @@ class Document(BaseDocument): signals.post_save.send(self.__class__, document=self, created=created) return self - def cascade_save(self, warn_cascade=None, *args, **kwargs): + def cascade_save(self, *args, **kwargs): """Recursively saves any references / generic references on an objects""" - import fields _refs = kwargs.get('_refs', []) or [] + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + for name, cls in self._fields.items(): - if not isinstance(cls, (fields.ReferenceField, - fields.GenericReferenceField)): + if not isinstance(cls, (ReferenceField, + GenericReferenceField)): continue - ref = getattr(self, name) + ref = self._data.get(name) if not ref or isinstance(ref, DBRef): continue @@ -303,10 +320,6 @@ class Document(BaseDocument): ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data)) if ref and ref_id not in _refs: - if warn_cascade: - msg = ("Cascading saves will default to off in 0.8, " - "please explicitly set `.save(cascade=True)`") - warnings.warn(msg, FutureWarning) _refs.append(ref_id) kwargs["_refs"] = _refs ref.save(**kwargs) @@ -339,25 +352,35 @@ class Document(BaseDocument): been saved. 
""" if not self.pk: - raise OperationError('attempt to update a document not yet saved') + if kwargs.get('upsert', False): + query = self.to_mongo() + if "_cls" in query: + del(query["_cls"]) + return self._qs.filter(**query).update_one(**kwargs) + else: + raise OperationError('attempt to update a document not yet saved') # Need to add shard key to query, or you get an error return self._qs.filter(**self._object_key).update_one(**kwargs) - def delete(self, safe=False): + def delete(self, **write_concern): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. - :param safe: check if the operation succeeded before returning + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. """ signals.pre_delete.send(self.__class__, document=self) try: - self._qs.filter(**self._object_key).delete(safe=safe) + self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) except pymongo.errors.OperationFailure, err: message = u'Could not delete document (%s)' % err.message raise OperationError(message) - signals.post_delete.send(self.__class__, document=self) def switch_db(self, db_alias): @@ -377,7 +400,7 @@ class Document(BaseDocument): """ with switch_db(self.__class__, db_alias) as cls: collection = cls._get_collection() - db = cls._get_db + db = cls._get_db() self._get_collection = lambda: collection self._get_db = lambda: db self._collection = collection @@ -397,7 +420,7 @@ class Document(BaseDocument): user.save() If you need to read from another database see - :class:`~mongoengine.context_managers.switch_collection` + :class:`~mongoengine.context_managers.switch_db` :param collection_name: The database alias to use for saving the document @@ -417,8 +440,8 @@ class Document(BaseDocument): .. versionadded:: 0.5 """ - import dereference - self._data = dereference.DeReference()(self._data, max_depth) + DeReference = _import_class('DeReference') + DeReference()([self], max_depth + 1) return self def reload(self, max_depth=1): @@ -427,20 +450,16 @@ class Document(BaseDocument): .. versionadded:: 0.1.2 .. versionchanged:: 0.6 Now chainable """ - id_field = self._meta['id_field'] - obj = self._qs.filter( - **{id_field: self[id_field]} - ).limit(1).select_related(max_depth=max_depth) + obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( + **self._object_key).limit(1).select_related(max_depth=max_depth) + if obj: obj = obj[0] else: msg = "Reloaded document has been deleted" raise OperationError(msg) - for field in self._fields: + for field in self._fields_ordered: setattr(self, field, self._reload(field, obj[field])) - if self._dynamic: - for name in self._dynamic_fields.keys(): - setattr(self, name, self._reload(name, obj._data[name])) self._changed_fields = obj._changed_fields self._created = False return obj @@ -456,6 +475,7 @@ class Document(BaseDocument): value = [self._reload(key, v) for v in value] value = BaseList(value, self, key) elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): + value._instance = None value._changed_fields = [] return value @@ -516,6 +536,8 @@ class Document(BaseDocument): def ensure_indexes(cls): """Checks the document meta data and ensures all the indexes exist. 
+ Global defaults can be set in the meta - see :doc:`guide/defining-documents` + .. note:: You can disable automatic index creation by setting `auto_create_index` to False in the documents meta data """ @@ -532,15 +554,6 @@ class Document(BaseDocument): # index to service queries against _cls cls_indexed = False - def includes_cls(fields): - first_field = None - if len(fields): - if isinstance(fields[0], basestring): - first_field = fields[0] - elif isinstance(fields[0], (list, tuple)) and len(fields[0]): - first_field = fields[0][0] - return first_field == '_cls' - # Ensure document-defined indexes are created if cls._meta['index_specs']: index_spec = cls._meta['index_specs'] @@ -556,10 +569,94 @@ class Document(BaseDocument): # If _cls is being used (for polymorphism), it needs an index, # only if another index doesn't begin with _cls if (index_cls and not cls_indexed and - cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) == True): + cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): collection.ensure_index('_cls', background=background, **index_opts) + @classmethod + def list_indexes(cls, go_up=True, go_down=True): + """ Lists all of the indexes that should be created for given + collection. It includes all the indexes from super- and sub-classes. + """ + + if cls._meta.get('abstract'): + return [] + + # get all the base classes, subclasses and sieblings + classes = [] + def get_classes(cls): + + if (cls not in classes and + isinstance(cls, TopLevelDocumentMetaclass)): + classes.append(cls) + + for base_cls in cls.__bases__: + if (isinstance(base_cls, TopLevelDocumentMetaclass) and + base_cls != Document and + not base_cls._meta.get('abstract') and + base_cls._get_collection().full_name == cls._get_collection().full_name and + base_cls not in classes): + classes.append(base_cls) + get_classes(base_cls) + for subclass in cls.__subclasses__(): + if (isinstance(base_cls, TopLevelDocumentMetaclass) and + subclass._get_collection().full_name == cls._get_collection().full_name and + subclass not in classes): + classes.append(subclass) + get_classes(subclass) + + get_classes(cls) + + # get the indexes spec for all of the gathered classes + def get_indexes_spec(cls): + indexes = [] + + if cls._meta['index_specs']: + index_spec = cls._meta['index_specs'] + for spec in index_spec: + spec = spec.copy() + fields = spec.pop('fields') + indexes.append(fields) + return indexes + + indexes = [] + for cls in classes: + for index in get_indexes_spec(cls): + if index not in indexes: + indexes.append(index) + + # finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed + if [(u'_id', 1)] not in indexes: + indexes.append([(u'_id', 1)]) + if (cls._meta.get('index_cls', True) and + cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): + indexes.append([(u'_cls', 1)]) + + return indexes + + @classmethod + def compare_indexes(cls): + """ Compares the indexes defined in MongoEngine with the ones existing + in the database. Returns any missing/extra indexes. 
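``list_indexes()`` and ``compare_indexes()`` (defined just below) make the index bookkeeping introspectable: the former collects every index declared across classes sharing the collection, the latter diffs that against ``index_information()`` from the live collection. For example::

    Person.ensure_indexes()
    Person.list_indexes()      # every index the class (plus co-located sub/super classes) defines
    Person.compare_indexes()   # {'missing': [...], 'extra': [...]} against the live collection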
+ """ + + required = cls.list_indexes() + existing = [info['key'] for info in cls._get_collection().index_information().values()] + missing = [index for index in required if index not in existing] + extra = [index for index in existing if index not in required] + + # if { _cls: 1 } is missing, make sure it's *really* necessary + if [(u'_cls', 1)] in missing: + cls_obsolete = False + for index in existing: + if includes_cls(index) and index not in extra: + cls_obsolete = True + break + if cls_obsolete: + missing.remove([(u'_cls', 1)]) + + return {'missing': missing, 'extra': extra} + class DynamicDocument(Document): """A Dynamic Document class allowing flexible, expandable and uncontrolled @@ -567,7 +664,7 @@ class DynamicDocument(Document): way as an ordinary document but has expando style properties. Any data passed or set against the :class:`~mongoengine.DynamicDocument` that is not a field is automatically converted into a - :class:`~mongoengine.DynamicField` and data can be attributed to that + :class:`~mongoengine.fields.DynamicField` and data can be attributed to that field. .. note:: diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 9cfcd1d..4b6b562 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -3,7 +3,9 @@ from collections import defaultdict from mongoengine.python_support import txt_type -__all__ = ('NotRegistered', 'InvalidDocumentError', 'ValidationError') +__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', + 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', + 'OperationError', 'NotUniqueError', 'ValidationError') class NotRegistered(Exception): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 690e7ac..419f2ef 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -4,19 +4,26 @@ import itertools import re import time import urllib2 -import urlparse import uuid import warnings from operator import itemgetter +try: + import dateutil +except ImportError: + dateutil = None +else: + import dateutil.parser + +import pymongo import gridfs from bson import Binary, DBRef, SON, ObjectId from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) -from base import (BaseField, ComplexBaseField, ObjectIdField, - get_document, BaseDocument, ALLOW_INHERITANCE) +from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, + get_document, BaseDocument) from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument from connection import get_db, DEFAULT_CONNECTION_NAME @@ -27,13 +34,16 @@ except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'IntField', 'LongField', 'FloatField', 'BooleanField', - 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', - 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', - 'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField', - 'GenericReferenceField', 'FileField', 'BinaryField', - 'SortedListField', 'EmailField', 'GeoPointField', 'ImageField', - 'SequenceField', 'UUIDField', 'GenericEmbeddedDocumentField'] +__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', + 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', + 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', + 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', + 'SortedListField', 'DictField', 'MapField', 'ReferenceField', + 'GenericReferenceField', 'BinaryField', 'GridFSError', 
+ 'GridFSProxy', 'FileField', 'ImageGridFsProxy', + 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', + 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField'] + RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -104,11 +114,11 @@ class URLField(StringField): """ _URL_REGEX = re.compile( - r'^(?:http|ftp)s?://' # http:// or https:// - r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain... - r'localhost|' #localhost... - r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip - r'(?::\d+)?' # optional port + r'^(?:http|ftp)s?://' # http:// or https:// + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain... + r'localhost|' # localhost... + r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip + r'(?::\d+)?' # optional port r'(?:/?|[/?]\S+)$', re.IGNORECASE) def __init__(self, verify_exists=False, url_regex=None, **kwargs): @@ -125,8 +135,7 @@ class URLField(StringField): warnings.warn( "The URLField verify_exists argument has intractable security " "and performance issues. Accordingly, it has been deprecated.", - DeprecationWarning - ) + DeprecationWarning) try: request = urllib2.Request(value) urllib2.urlopen(request) @@ -257,30 +266,58 @@ class FloatField(BaseField): class DecimalField(BaseField): """A fixed-point decimal number field. + .. versionchanged:: 0.8 .. versionadded:: 0.3 """ - def __init__(self, min_value=None, max_value=None, **kwargs): - self.min_value, self.max_value = min_value, max_value + def __init__(self, min_value=None, max_value=None, force_string=False, + precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs): + """ + :param min_value: Validation rule for the minimum acceptable value. + :param max_value: Validation rule for the maximum acceptable value. + :param force_string: Store as a string. + :param precision: Number of decimal places to store. 
+ :param rounding: The rounding rule from the python decimal libary: + + - decimal.ROUND_CEILING (towards Infinity) + - decimal.ROUND_DOWN (towards zero) + - decimal.ROUND_FLOOR (towards -Infinity) + - decimal.ROUND_HALF_DOWN (to nearest with ties going towards zero) + - decimal.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer) + - decimal.ROUND_HALF_UP (to nearest with ties going away from zero) + - decimal.ROUND_UP (away from zero) + - decimal.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero) + + Defaults to: ``decimal.ROUND_HALF_UP`` + + """ + self.min_value = min_value + self.max_value = max_value + self.force_string = force_string + self.precision = decimal.Decimal(".%s" % ("0" * precision)) + self.rounding = rounding + super(DecimalField, self).__init__(**kwargs) def to_python(self, value): - original_value = value - if not isinstance(value, basestring): - value = unicode(value) - try: - value = decimal.Decimal(value) - except ValueError: - return original_value - return value + if value is None: + return value + + # Convert to string for python 2.6 before casting to Decimal + value = decimal.Decimal("%s" % value) + return value.quantize(self.precision, rounding=self.rounding) def to_mongo(self, value): - return unicode(value) + if value is None: + return value + if self.force_string: + return unicode(value) + return float(self.to_python(value)) def validate(self, value): if not isinstance(value, decimal.Decimal): if not isinstance(value, basestring): - value = str(value) + value = unicode(value) try: value = decimal.Decimal(value) except Exception, exc: @@ -292,6 +329,9 @@ class DecimalField(BaseField): if self.max_value is not None and value > self.max_value: self.error('Decimal value is too large') + def prepare_query_value(self, op, value): + return self.to_mongo(value) + class BooleanField(BaseField): """A boolean field type. @@ -314,6 +354,11 @@ class BooleanField(BaseField): class DateTimeField(BaseField): """A datetime field. + Uses the python-dateutil library if available alternatively use time.strptime + to parse the dates. Note: python-dateutil's parser is fully featured and when + installed you can utilise it to convert varing types of date formats into valid + python datetime objects. + Note: Microseconds are rounded to the nearest millisecond. Pre UTC microsecond support is effecively broken. Use :class:`~mongoengine.fields.ComplexDateTimeField` if you @@ -321,13 +366,11 @@ class DateTimeField(BaseField): """ def validate(self, value): - if not isinstance(value, (datetime.datetime, datetime.date)): + new_value = self.to_mongo(value) + if not isinstance(new_value, (datetime.datetime, datetime.date)): self.error(u'cannot parse date "%s"' % value) def to_mongo(self, value): - return self.prepare_query_value(None, value) - - def prepare_query_value(self, op, value): if value is None: return value if isinstance(value, datetime.datetime): @@ -337,8 +380,16 @@ class DateTimeField(BaseField): if callable(value): return value() + if not isinstance(value, basestring): + return None + # Attempt to parse a datetime: - # value = smart_str(value) + if dateutil: + try: + return dateutil.parser.parse(value) + except ValueError: + return None + # split usecs, because they are not recognized by strptime. if '.' in value: try: @@ -351,7 +402,7 @@ class DateTimeField(BaseField): kwargs = {'microsecond': usecs} try: # Seconds are optional, so try converting seconds first. 
return datetime.datetime(*time.strptime(value, - '%Y-%m-%d %H:%M:%S')[:6], **kwargs) + '%Y-%m-%d %H:%M:%S')[:6], **kwargs) except ValueError: try: # Try without seconds. return datetime.datetime(*time.strptime(value, @@ -363,6 +414,9 @@ class DateTimeField(BaseField): except ValueError: return None + def prepare_query_value(self, op, value): + return self.to_mongo(value) + class ComplexDateTimeField(StringField): """ @@ -435,7 +489,7 @@ class ComplexDateTimeField(StringField): def __get__(self, instance, owner): data = super(ComplexDateTimeField, self).__get__(instance, owner) - if data == None: + if data is None: return datetime.datetime.now() if isinstance(data, datetime.datetime): return data @@ -570,7 +624,9 @@ class DynamicField(BaseField): cls = value.__class__ val = value.to_mongo() # If we its a document thats not inherited add _cls - if (isinstance(value, (Document, EmbeddedDocument))): + if (isinstance(value, Document)): + val = {"_ref": value.to_dbref(), "_cls": cls.__name__} + if (isinstance(value, EmbeddedDocument)): val['_cls'] = cls.__name__ return val @@ -591,6 +647,15 @@ class DynamicField(BaseField): value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] return value + def to_python(self, value): + if isinstance(value, dict) and '_cls' in value: + doc_cls = get_document(value['_cls']) + if '_ref' in value: + value = doc_cls._get_db().dereference(value['_ref']) + return doc_cls._from_son(value) + + return super(DynamicField, self).to_python(value) + def lookup_member(self, member_name): return member_name @@ -624,15 +689,15 @@ class ListField(ComplexBaseField): """Make sure that a list of valid fields is being used. """ if (not isinstance(value, (list, tuple, QuerySet)) or - isinstance(value, basestring)): + isinstance(value, basestring)): self.error('Only lists and tuples may be used in a list field') super(ListField, self).validate(value) def prepare_query_value(self, op, value): if self.field: if op in ('set', 'unset') and (not isinstance(value, basestring) - and not isinstance(value, BaseDocument) - and hasattr(value, '__iter__')): + and not isinstance(value, BaseDocument) + and hasattr(value, '__iter__')): return [self.field.prepare_query_value(op, v) for v in value] return self.field.prepare_query_value(op, value) return super(ListField, self).prepare_query_value(op, value) @@ -667,7 +732,7 @@ class SortedListField(ListField): value = super(SortedListField, self).to_mongo(value) if self._ordering is not None: return sorted(value, key=itemgetter(self._ordering), - reverse=self._order_reverse) + reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse) @@ -715,6 +780,10 @@ class DictField(ComplexBaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) + + if hasattr(self.field, 'field'): + return self.field.prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) @@ -748,7 +817,7 @@ class ReferenceField(BaseField): * NULLIFY - Updates the reference to null. * CASCADE - Deletes the documents associated with the reference. * DENY - Prevent the deletion of the reference object. - * PULL - Pull the reference from a :class:`~mongoengine.ListField` + * PULL - Pull the reference from a :class:`~mongoengine.fields.ListField` of references Alternative syntax for registering delete rules (useful when implementing @@ -769,7 +838,7 @@ class ReferenceField(BaseField): .. 
versionchanged:: 0.5 added `reverse_delete_rule` """ - def __init__(self, document_type, dbref=None, + def __init__(self, document_type, dbref=False, reverse_delete_rule=DO_NOTHING, **kwargs): """Initialises the Reference Field. @@ -783,12 +852,7 @@ class ReferenceField(BaseField): self.error('Argument to ReferenceField constructor must be a ' 'document class or a string') - if dbref is None: - msg = ("ReferenceFields will default to using ObjectId " - "in 0.8, set DBRef=True if this isn't desired") - warnings.warn(msg, FutureWarning) - - self.dbref = dbref if dbref is not None else True # To change in 0.8 + self.dbref = dbref self.document_type_obj = document_type self.reverse_delete_rule = reverse_delete_rule super(ReferenceField, self).__init__(**kwargs) @@ -825,8 +889,6 @@ class ReferenceField(BaseField): if not self.dbref: return document.id return document - elif not self.dbref and isinstance(document, basestring): - return document id_field_name = self.document_type._meta['id_field'] id_field = self.document_type._fields[id_field_name] @@ -851,7 +913,7 @@ class ReferenceField(BaseField): """Convert a MongoDB-compatible type to a Python type. """ if (not self.dbref and - not isinstance(value, (DBRef, Document, EmbeddedDocument))): + not isinstance(value, (DBRef, Document, EmbeddedDocument))): collection = self.document_type._get_collection_name() value = DBRef(collection, self.document_type.id.to_python(value)) return value @@ -972,7 +1034,7 @@ class BinaryField(BaseField): if not isinstance(value, (bin_type, txt_type, Binary)): self.error("BinaryField only accepts instances of " "(%s, %s, Binary)" % ( - bin_type.__name__, txt_type.__name__)) + bin_type.__name__, txt_type.__name__)) if self.max_bytes is not None and len(value) > self.max_bytes: self.error('Binary value is too long') @@ -1036,6 +1098,10 @@ class GridFSProxy(object): def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.grid_id) + def __str__(self): + name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)' + return '<%s: %s>' % (self.__class__.__name__, name) + def __eq__(self, other): if isinstance(other, GridFSProxy): return ((self.grid_id == other.grid_id) and @@ -1143,9 +1209,7 @@ class FileField(BaseField): # Check if a file already exists for this model grid_file = instance._data.get(self.name) if not isinstance(grid_file, self.proxy_class): - grid_file = self.proxy_class(key=self.name, instance=instance, - db_alias=self.db_alias, - collection_name=self.collection_name) + grid_file = self.get_proxy_obj(key=self.name, instance=instance) instance._data[self.name] = grid_file if not grid_file.key: @@ -1165,18 +1229,25 @@ class FileField(BaseField): grid_file.delete() except: pass - # Create a new file with the new data - grid_file.put(value) - else: - # Create a new proxy object as we don't already have one - instance._data[key] = self.proxy_class(key=key, instance=instance, - collection_name=self.collection_name) - instance._data[key].put(value) + + # Create a new proxy object as we don't already have one + instance._data[key] = self.get_proxy_obj(key=key, instance=instance) + instance._data[key].put(value) else: instance._data[key] = value instance._mark_as_changed(key) + def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None): + if db_alias is None: + db_alias = self.db_alias + if collection_name is None: + collection_name = self.collection_name + + return self.proxy_class(key=key, instance=instance, + db_alias=db_alias, + 
collection_name=collection_name) + def to_mongo(self, value): # Store the GridFS file id in MongoDB if isinstance(value, self.proxy_class) and value.grid_id is not None: @@ -1209,12 +1280,15 @@ class ImageGridFsProxy(GridFSProxy): applying field properties (size, thumbnail_size) """ field = self.instance._fields[self.key] + # Handle nested fields + if hasattr(field, 'field') and isinstance(field.field, FileField): + field = field.field try: img = Image.open(file_obj) img_format = img.format - except: - raise ValidationError('Invalid image') + except Exception, e: + raise ValidationError('Invalid image: %s' % e) if (field.size and (img.size[0] > field.size['width'] or img.size[1] > field.size['height'])): @@ -1235,10 +1309,7 @@ class ImageGridFsProxy(GridFSProxy): size = field.thumbnail_size if size['force']: - thumbnail = ImageOps.fit(img, - (size['width'], - size['height']), - Image.ANTIALIAS) + thumbnail = ImageOps.fit(img, (size['width'], size['height']), Image.ANTIALIAS) else: thumbnail = img.copy() thumbnail.thumbnail((size['width'], @@ -1246,8 +1317,7 @@ class ImageGridFsProxy(GridFSProxy): Image.ANTIALIAS) if thumbnail: - thumb_id = self._put_thumbnail(thumbnail, - img_format) + thumb_id = self._put_thumbnail(thumbnail, img_format) else: thumb_id = None @@ -1283,6 +1353,7 @@ class ImageGridFsProxy(GridFSProxy): height=h, format=format, **kwargs) + @property def size(self): """ @@ -1350,7 +1421,7 @@ class ImageField(FileField): if isinstance(att, (tuple, list)): if PY3: value = dict(itertools.zip_longest(params_size, att, - fillvalue=None)) + fillvalue=None)) else: value = dict(map(None, params_size, att)) @@ -1361,28 +1432,6 @@ class ImageField(FileField): **kwargs) -class GeoPointField(BaseField): - """A list storing a latitude and longitude. - - .. 
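A rough sketch of the image handling above, assuming ImageField's size and thumbnail_size accept (width, height, force) tuples as the proxy code reads them; the Profile document and file path are illustrative:

from mongoengine import Document, ImageField

class Profile(Document):
    photo = ImageField(size=(800, 600, True), thumbnail_size=(100, 100, True))

profile = Profile()
# put() validates the image with PIL, resizes it to `size` if it is too large
# and stores a thumbnail alongside it in GridFS
# profile.photo.put(open('avatar.png', 'rb'))
# profile.save()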
versionadded:: 0.4 - """ - - _geo_index = True - - def validate(self, value): - """Make sure that a geo-value is of type (x, y) - """ - if not isinstance(value, (list, tuple)): - self.error('GeoPointField can only accept tuples or lists ' - 'of (x, y)') - - if not len(value) == 2: - self.error('Value must be a two-dimensional point') - if (not isinstance(value[0], (float, int)) and - not isinstance(value[1], (float, int))): - self.error('Both values in point must be float or int') - - class SequenceField(BaseField): """Provides a sequental counter see: http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers @@ -1408,13 +1457,13 @@ class SequenceField(BaseField): COLLECTION_NAME = 'mongoengine.counters' VALUE_DECORATOR = int - def __init__(self, collection_name=None, db_alias=None, - sequence_name=None, value_decorator=None, *args, **kwargs): + def __init__(self, collection_name=None, db_alias=None, sequence_name=None, + value_decorator=None, *args, **kwargs): self.collection_name = collection_name or self.COLLECTION_NAME self.db_alias = db_alias or DEFAULT_CONNECTION_NAME self.sequence_name = sequence_name self.value_decorator = (callable(value_decorator) and - value_decorator or self.VALUE_DECORATOR) + value_decorator or self.VALUE_DECORATOR) return super(SequenceField, self).__init__(*args, **kwargs) def generate(self): @@ -1430,6 +1479,33 @@ class SequenceField(BaseField): upsert=True) return self.value_decorator(counter['next']) + def set_next_value(self, value): + """Helper method to set the next sequence value""" + sequence_name = self.get_sequence_name() + sequence_id = "%s.%s" % (sequence_name, self.name) + collection = get_db(alias=self.db_alias)[self.collection_name] + counter = collection.find_and_modify(query={"_id": sequence_id}, + update={"$set": {"next": value}}, + new=True, + upsert=True) + return self.value_decorator(counter['next']) + + def get_next_value(self): + """Helper method to get the next value for previewing. + + .. warning:: There is no guarantee this will be the next value + as it is only fixed on set. + """ + sequence_name = self.get_sequence_name() + sequence_id = "%s.%s" % (sequence_name, self.name) + collection = get_db(alias=self.db_alias)[self.collection_name] + data = collection.find_one({"_id": sequence_id}) + + if data: + return self.value_decorator(data['next']+1) + + return self.value_decorator(1) + def get_sequence_name(self): if self.sequence_name: return self.sequence_name @@ -1438,7 +1514,7 @@ class SequenceField(BaseField): return owner._get_collection_name() else: return ''.join('_%s' % c if c.isupper() else c - for c in owner._class_name).strip('_').lower() + for c in owner._class_name).strip('_').lower() def __get__(self, instance, owner): value = super(SequenceField, self).__get__(instance, owner) @@ -1469,19 +1545,15 @@ class UUIDField(BaseField): """ _binary = None - def __init__(self, binary=None, **kwargs): + def __init__(self, binary=True, **kwargs): """ Store UUID data in the database - :param binary: (optional) boolean store as binary. + :param binary: if False store as a string. + .. versionchanged:: 0.8.0 .. 
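A short sketch of SequenceField in use; the Ticket document is illustrative:

from mongoengine import Document, StringField, SequenceField

class Ticket(Document):
    subject = StringField()
    number = SequenceField()   # counter kept in the 'mongoengine.counters' collection

first = Ticket(subject='printer on fire')
first.save()          # first.number == 1
second = Ticket(subject='printer still on fire')
second.save()         # second.number == 2

# Preview the next value (not guaranteed, see the warning above)
Ticket._fields['number'].get_next_value()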
versionchanged:: 0.6.19 """ - if binary is None: - binary = False - msg = ("UUIDFields will soon default to store as binary, please " - "configure binary=False if you wish to store as a string") - warnings.warn(msg, FutureWarning) self._binary = binary super(UUIDField, self).__init__(**kwargs) @@ -1499,6 +1571,8 @@ class UUIDField(BaseField): def to_mongo(self, value): if not self._binary: return unicode(value) + elif isinstance(value, basestring): + return uuid.UUID(value) return value def prepare_query_value(self, op, value): @@ -1514,3 +1588,83 @@ class UUIDField(BaseField): value = uuid.UUID(value) except Exception, exc: self.error('Could not convert to UUID: %s' % exc) + + +class GeoPointField(BaseField): + """A list storing a latitude and longitude. + + .. versionadded:: 0.4 + """ + + _geo_index = pymongo.GEO2D + + def validate(self, value): + """Make sure that a geo-value is of type (x, y) + """ + if not isinstance(value, (list, tuple)): + self.error('GeoPointField can only accept tuples or lists ' + 'of (x, y)') + + if not len(value) == 2: + self.error("Value (%s) must be a two-dimensional point" % repr(value)) + elif (not isinstance(value[0], (float, int)) or + not isinstance(value[1], (float, int))): + self.error("Both values (%s) in point must be float or int" % repr(value)) + + +class PointField(GeoJsonBaseField): + """A geo json field storing a latitude and longitude. + + The data is represented as: + + .. code-block:: js + + { "type" : "Point" , + "coordinates" : [x, y]} + + You can either pass a dict with the full information or a list + to set the value. + + Requires mongodb >= 2.4 + .. versionadded:: 0.8 + """ + _type = "Point" + + +class LineStringField(GeoJsonBaseField): + """A geo json field storing a line of latitude and longitude coordinates. + + The data is represented as: + + .. code-block:: js + + { "type" : "LineString" , + "coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]} + + You can either pass a dict with the full information or a list of points. + + Requires mongodb >= 2.4 + .. versionadded:: 0.8 + """ + _type = "LineString" + + +class PolygonField(GeoJsonBaseField): + """A geo json field storing a polygon of latitude and longitude coordinates. + + The data is represented as: + + .. code-block:: js + + { "type" : "Polygon" , + "coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], + [[x1, y1], [x1, y1] ... [xn, yn]]} + + You can either pass a dict with the full information or a list + of LineStrings. The first LineString being the outside and the rest being + holes. + + Requires mongodb >= 2.4 + .. 
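A quick sketch contrasting the legacy GeoPointField with the new GeoJSON PointField; the Place document is illustrative:

from mongoengine import Document, StringField
from mongoengine.fields import GeoPointField, PointField

class Place(Document):
    name = StringField()
    legacy_loc = GeoPointField()   # plain [x, y] pair, 2d index
    loc = PointField()             # GeoJSON point, MongoDB >= 2.4

Place(name='cafe',
      legacy_loc=[-105.28, 40.02],
      loc=[-105.28, 40.02]).save()   # a list or a full GeoJSON dict both work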
versionadded:: 0.8 + """ + _type = "Polygon" diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py new file mode 100644 index 0000000..b4dad0c --- /dev/null +++ b/mongoengine/queryset/base.py @@ -0,0 +1,1494 @@ +from __future__ import absolute_import + +import copy +import itertools +import operator +import pprint +import re +import warnings + +from bson.code import Code +from bson import json_util +import pymongo +from pymongo.common import validate_read_preference + +from mongoengine import signals +from mongoengine.common import _import_class +from mongoengine.base.common import get_document +from mongoengine.errors import (OperationError, NotUniqueError, + InvalidQueryError, LookUpError) + +from mongoengine.queryset import transform +from mongoengine.queryset.field_list import QueryFieldList +from mongoengine.queryset.visitor import Q, QNode + + +__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') + +# Delete rules +DO_NOTHING = 0 +NULLIFY = 1 +CASCADE = 2 +DENY = 3 +PULL = 4 + +RE_TYPE = type(re.compile('')) + + +class BaseQuerySet(object): + """A set of results returned from a query. Wraps a MongoDB cursor, + providing :class:`~mongoengine.Document` objects as the results. + """ + __dereference = False + _auto_dereference = True + + def __init__(self, document, collection): + self._document = document + self._collection_obj = collection + self._mongo_query = None + self._query_obj = Q() + self._initial_query = {} + self._where_clause = None + self._loaded_fields = QueryFieldList() + self._ordering = [] + self._snapshot = False + self._timeout = True + self._class_check = True + self._slave_okay = False + self._read_preference = None + self._iter = False + self._scalar = [] + self._none = False + self._as_pymongo = False + self._as_pymongo_coerce = False + + # If inheritance is allowed, only return instances and instances of + # subclasses of the class being used + if document._meta.get('allow_inheritance') is True: + if len(self._document._subclasses) == 1: + self._initial_query = {"_cls": self._document._subclasses[0]} + else: + self._initial_query = {"_cls": {"$in": self._document._subclasses}} + self._loaded_fields = QueryFieldList(always_include=['_cls']) + self._cursor_obj = None + self._limit = None + self._skip = None + self._hint = -1 # Using -1 as None is a valid value for hint + + def __call__(self, q_obj=None, class_check=True, slave_okay=False, + read_preference=None, **query): + """Filter the selected documents by calling the + :class:`~mongoengine.queryset.QuerySet` with a query. + + :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in + the query; the :class:`~mongoengine.queryset.QuerySet` is filtered + multiple times with different :class:`~mongoengine.queryset.Q` + objects, only the last one will be used + :param class_check: If set to False bypass class name check when + querying collection + :param slave_okay: if True, allows this query to be run against a + replica secondary. + :params read_preference: if set, overrides connection-level + read_preference from `ReplicaSetConnection`. + :param query: Django-style query keyword arguments + """ + query = Q(**query) + if q_obj: + # make sure proper query object is passed + if not isinstance(q_obj, QNode): + msg = ("Not a query object: %s. " + "Did you intend to use key=value?" 
% q_obj) + raise InvalidQueryError(msg) + query &= q_obj + + if read_preference is None: + queryset = self.clone() + else: + # Use the clone provided when setting read_preference + queryset = self.read_preference(read_preference) + + queryset._query_obj &= query + queryset._mongo_query = None + queryset._cursor_obj = None + queryset._class_check = class_check + + return queryset + + def __getitem__(self, key): + """Support skip and limit using getitem and slicing syntax. + """ + queryset = self.clone() + + # Slice provided + if isinstance(key, slice): + try: + queryset._cursor_obj = queryset._cursor[key] + queryset._skip, queryset._limit = key.start, key.stop + if key.start and key.stop: + queryset._limit = key.stop - key.start + except IndexError, err: + # PyMongo raises an error if key.start == key.stop, catch it, + # bin it, kill it. + start = key.start or 0 + if start >= 0 and key.stop >= 0 and key.step is None: + if start == key.stop: + queryset.limit(0) + queryset._skip = key.start + queryset._limit = key.stop - start + return queryset + raise err + # Allow further QuerySet modifications to be performed + return queryset + # Integer index provided + elif isinstance(key, int): + if queryset._scalar: + return queryset._get_scalar( + queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference)) + if queryset._as_pymongo: + return queryset._get_as_pymongo(queryset._cursor.next()) + return queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference) + raise AttributeError + + def __iter__(self): + raise NotImplementedError + + # Core functions + + def all(self): + """Returns all documents.""" + return self.__call__() + + def filter(self, *q_objs, **query): + """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` + """ + return self.__call__(*q_objs, **query) + + def get(self, *q_objs, **query): + """Retrieve the the matching object raising + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` exception if multiple results + and :class:`~mongoengine.queryset.DoesNotExist` or + `DocumentName.DoesNotExist` if no results are found. + + .. versionadded:: 0.3 + """ + queryset = self.clone() + queryset = queryset.limit(2) + queryset = queryset.filter(*q_objs, **query) + + try: + result = queryset.next() + except StopIteration: + msg = ("%s matching query does not exist." + % queryset._document._class_name) + raise queryset._document.DoesNotExist(msg) + try: + queryset.next() + except StopIteration: + return result + + queryset.rewind() + message = u'%d items returned, instead of 1' % queryset.count() + raise queryset._document.MultipleObjectsReturned(message) + + def create(self, **kwargs): + """Create new object. Returns the saved object instance. + + .. versionadded:: 0.4 + """ + return self._document(**kwargs).save() + + def get_or_create(self, write_concern=None, auto_save=True, + *q_objs, **query): + """Retrieve unique object or create, if it doesn't exist. Returns a + tuple of ``(object, created)``, where ``object`` is the retrieved or + created object and ``created`` is a boolean specifying whether a new + object was created. Raises + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` if multiple results are found. + A new document will be created if the document doesn't exists; a + dictionary of default values for the new document may be provided as a + keyword argument called :attr:`defaults`. + + .. 
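A sketch of the querying interface above; the Post document and its fields are illustrative:

# Slicing translates to skip/limit on the underlying cursor
page_two = list(Post.objects.order_by('-id')[10:20])

# filter() is an alias for calling the queryset; get() raises the document's
# DoesNotExist or MultipleObjectsReturned exceptions
post = Post.objects(published=True).get(slug='hello-world')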
note:: This requires two separate operations and therefore a + race condition exists. Because there are no transactions in + mongoDB other approaches should be investigated, to ensure you + don't accidently duplicate data when using this method. This is + now scheduled to be removed before 1.0 + + :param write_concern: optional extra keyword arguments used if we + have to create a new document. + Passes any write_concern onto :meth:`~mongoengine.Document.save` + + :param auto_save: if the object is to be saved automatically if + not found. + + .. deprecated:: 0.8 + .. versionchanged:: 0.6 - added `auto_save` + .. versionadded:: 0.3 + """ + msg = ("get_or_create is scheduled to be deprecated. The approach is " + "flawed without transactions. Upserts should be preferred.") + warnings.warn(msg, DeprecationWarning) + + defaults = query.get('defaults', {}) + if 'defaults' in query: + del query['defaults'] + + try: + doc = self.get(*q_objs, **query) + return doc, False + except self._document.DoesNotExist: + query.update(defaults) + doc = self._document(**query) + + if auto_save: + doc.save(write_concern=write_concern) + return doc, True + + def first(self): + """Retrieve the first object matching the query. + """ + queryset = self.clone() + try: + result = queryset[0] + except IndexError: + result = None + return result + + def insert(self, doc_or_docs, load_bulk=True, write_concern=None): + """bulk insert documents + + :param docs_or_doc: a document or list of documents to be inserted + :param load_bulk (optional): If True returns the list of document + instances + :param write_concern: Extra keyword arguments are passed down to + :meth:`~pymongo.collection.Collection.insert` + which will be used as options for the resultant + ``getLastError`` command. For example, + ``insert(..., {w: 2, fsync: True})`` will wait until at least + two servers have recorded the write and will force an fsync on + each server being written to. + + By default returns document instances, set ``load_bulk`` to False to + return just ``ObjectIds`` + + .. 
versionadded:: 0.5 + """ + Document = _import_class('Document') + + if write_concern is None: + write_concern = {} + + docs = doc_or_docs + return_one = False + if isinstance(docs, Document) or issubclass(docs.__class__, Document): + return_one = True + docs = [docs] + + raw = [] + for doc in docs: + if not isinstance(doc, self._document): + msg = ("Some documents inserted aren't instances of %s" + % str(self._document)) + raise OperationError(msg) + if doc.pk and not doc._created: + msg = "Some documents have ObjectIds use doc.update() instead" + raise OperationError(msg) + raw.append(doc.to_mongo()) + + signals.pre_bulk_insert.send(self._document, documents=docs) + try: + ids = self._collection.insert(raw, **write_concern) + except pymongo.errors.OperationFailure, err: + message = 'Could not save document (%s)' + if re.match('^E1100[01] duplicate key', unicode(err)): + # E11000 - duplicate key error index + # E11001 - duplicate key on update + message = u'Tried to save duplicate unique keys (%s)' + raise NotUniqueError(message % unicode(err)) + raise OperationError(message % unicode(err)) + + if not load_bulk: + signals.post_bulk_insert.send( + self._document, documents=docs, loaded=False) + return return_one and ids[0] or ids + + documents = self.in_bulk(ids) + results = [] + for obj_id in ids: + results.append(documents.get(obj_id)) + signals.post_bulk_insert.send( + self._document, documents=results, loaded=True) + return return_one and results[0] or results + + def count(self, with_limit_and_skip=True): + """Count the selected elements in the query. + + :param with_limit_and_skip (optional): take any :meth:`limit` or + :meth:`skip` that has been applied to this cursor into account when + getting the count + """ + if self._limit == 0 and with_limit_and_skip: + return 0 + return self._cursor.count(with_limit_and_skip=with_limit_and_skip) + + def delete(self, write_concern=None, _from_doc_delete=False): + """Delete the documents matched by the query. + + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param _from_doc_delete: True when called from document delete therefore + signals will have been triggered so don't loop. 
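A sketch of the bulk insert API documented above, assuming a simple Post document:

posts = Post.objects.insert([Post(title='one'), Post(title='two')])   # instances back
ids = Post.objects.insert([Post(title='three')], load_bulk=False)     # just ObjectIds

# write_concern is handed straight to pymongo's insert
Post.objects.insert([Post(title='four')], write_concern={'w': 2, 'fsync': True})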
+ """ + queryset = self.clone() + doc = queryset._document + + if write_concern is None: + write_concern = {} + + # Handle deletes where skips or limits have been applied or + # there is an untriggered delete signal + has_delete_signal = signals.signals_available and ( + signals.pre_delete.has_receivers_for(self._document) or + signals.post_delete.has_receivers_for(self._document)) + + call_document_delete = (queryset._skip or queryset._limit or + has_delete_signal) and not _from_doc_delete + + if call_document_delete: + for doc in queryset: + doc.delete(write_concern=write_concern) + return + + delete_rules = doc._meta.get('delete_rules') or {} + # Check for DENY rules before actually deleting/nullifying any other + # references + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == DENY and document_cls.objects( + **{field_name + '__in': self}).count() > 0: + msg = ("Could not delete document (%s.%s refers to it)" + % (document_cls.__name__, field_name)) + raise OperationError(msg) + + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == CASCADE: + ref_q = document_cls.objects(**{field_name + '__in': self}) + ref_q_count = ref_q.count() + if (doc != document_cls and ref_q_count > 0 + or (doc == document_cls and ref_q_count > 0)): + ref_q.delete(write_concern=write_concern) + elif rule == NULLIFY: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, **{'unset__%s' % field_name: 1}) + elif rule == PULL: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, + **{'pull_all__%s' % field_name: self}) + + queryset._collection.remove(queryset._query, write_concern=write_concern) + + def update(self, upsert=False, multi=True, write_concern=None, + full_result=False, **update): + """Perform an atomic update on the fields matched by the query. + + :param upsert: Any existing document with that "_id" is overwritten. + :param multi: Update multiple documents. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param full_result: Return the full result rather than just the number + updated. + :param update: Django-style update keyword arguments + + .. 
versionadded:: 0.2 + """ + if not update and not upsert: + raise OperationError("No update parameters, would remove data") + + if write_concern is None: + write_concern = {} + + queryset = self.clone() + query = queryset._query + update = transform.update(queryset._document, **update) + + # If doing an atomic upsert on an inheritable class + # then ensure we add _cls to the update operation + if upsert and '_cls' in query: + if '$set' in update: + update["$set"]["_cls"] = queryset._document._class_name + else: + update["$set"] = {"_cls": queryset._document._class_name} + try: + result = queryset._collection.update(query, update, multi=multi, + upsert=upsert, **write_concern) + if full_result: + return result + elif result: + return result['n'] + except pymongo.errors.OperationFailure, err: + if unicode(err) == u'multi not coded yet': + message = u'update() method requires MongoDB 1.1.3+' + raise OperationError(message) + raise OperationError(u'Update failed (%s)' % unicode(err)) + + def update_one(self, upsert=False, write_concern=None, **update): + """Perform an atomic update on first field matched by the query. + + :param upsert: Any existing document with that "_id" is overwritten. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param update: Django-style update keyword arguments + + .. versionadded:: 0.2 + """ + return self.update( + upsert=upsert, multi=False, write_concern=write_concern, **update) + + def with_id(self, object_id): + """Retrieve the object matching the id provided. Uses `object_id` only + and raises InvalidQueryError if a filter has been applied. Returns + `None` if no document exists with that id. + + :param object_id: the value for the id of the document to look up + + .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set + """ + queryset = self.clone() + if not queryset._query_obj.empty: + msg = "Cannot use a filter whilst using `with_id`" + raise InvalidQueryError(msg) + return queryset.filter(pk=object_id).first() + + def in_bulk(self, object_ids): + """Retrieve a set of documents by their ids. + + :param object_ids: a list or tuple of ``ObjectId``\ s + :rtype: dict of ObjectIds as keys and collection-specific + Document subclasses as values. + + .. versionadded:: 0.3 + """ + doc_map = {} + + docs = self._collection.find({'_id': {'$in': object_ids}}, + **self._cursor_args) + if self._scalar: + for doc in docs: + doc_map[doc['_id']] = self._get_scalar( + self._document._from_son(doc)) + elif self._as_pymongo: + for doc in docs: + doc_map[doc['_id']] = self._get_as_pymongo(doc) + else: + for doc in docs: + doc_map[doc['_id']] = self._document._from_son(doc) + + return doc_map + + def none(self): + """Helper that just returns a list""" + queryset = self.clone() + queryset._none = True + return queryset + + def no_sub_classes(self): + """ + Only return instances of this document and not any inherited documents + """ + if self._document._meta.get('allow_inheritance') is True: + self._initial_query = {"_cls": self._document._class_name} + + return self + + def clone(self): + """Creates a copy of the current + :class:`~mongoengine.queryset.QuerySet` + + .. 
versionadded:: 0.5 + """ + return self.clone_into(self.__class__(self._document, self._collection_obj)) + + def clone_into(self, cls): + """Creates a copy of the current + :class:`~mongoengine.queryset.base.BaseQuerySet` into another child class + """ + if not isinstance(cls, BaseQuerySet): + raise OperationError('%s is not a subclass of BaseQuerySet' % cls.__name__) + + copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', + '_where_clause', '_loaded_fields', '_ordering', '_snapshot', + '_timeout', '_class_check', '_slave_okay', '_read_preference', + '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', + '_limit', '_skip', '_hint', '_auto_dereference') + + for prop in copy_props: + val = getattr(self, prop) + setattr(cls, prop, copy.copy(val)) + + if self._cursor_obj: + cls._cursor_obj = self._cursor_obj.clone() + + return cls + + def select_related(self, max_depth=1): + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or + :class:`~bson.object_id.ObjectId` a maximum depth in order to cut down + the number queries to mongodb. + + .. versionadded:: 0.5 + """ + # Make select related work the same for querysets + max_depth += 1 + queryset = self.clone() + return queryset._dereference(queryset, max_depth=max_depth) + + def limit(self, n): + """Limit the number of returned documents to `n`. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[:5]``). + + :param n: the maximum number of objects to return + """ + queryset = self.clone() + if n == 0: + queryset._cursor.limit(1) + else: + queryset._cursor.limit(n) + queryset._limit = n + # Return self to allow chaining + return queryset + + def skip(self, n): + """Skip `n` documents before returning the results. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[5:]``). + + :param n: the number of objects to skip before returning results + """ + queryset = self.clone() + queryset._cursor.skip(n) + queryset._skip = n + return queryset + + def hint(self, index=None): + """Added 'hint' support, telling Mongo the proper index to use for the + query. + + Judicious use of hints can greatly improve query performance. When + doing a query on multiple fields (at least one of which is indexed) + pass the indexed field as a hint to the query. + + Hinting will not do anything if the corresponding index does not exist. + The last hint applied to this cursor takes precedence over all others. + + .. versionadded:: 0.5 + """ + queryset = self.clone() + queryset._cursor.hint(index) + queryset._hint = index + return queryset + + def distinct(self, field): + """Return a list of distinct values for a given field. + + :param field: the field to select distinct values from + + .. note:: This is a command and won't take ordering or limit into + account. + + .. versionadded:: 0.4 + .. versionchanged:: 0.5 - Fixed handling references + .. versionchanged:: 0.6 - Improved db_field refrence handling + """ + queryset = self.clone() + try: + field = self._fields_to_dbfields([field]).pop() + finally: + return self._dereference(queryset._cursor.distinct(field), 1, + name=field, instance=self._document) + + def only(self, *fields): + """Load only a subset of this document's fields. :: + + post = BlogPost.objects(...).only("title", "author.name") + + .. 
note :: `only()` is chainable and will perform a union :: + So with the following it will fetch both: `title` and `author.name`:: + + post = BlogPost.objects.only("title").only("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to include + + .. versionadded:: 0.3 + .. versionchanged:: 0.5 - Added subfield support + """ + fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + return self.fields(True, **fields) + + def exclude(self, *fields): + """Opposite to .only(), exclude some document's fields. :: + + post = BlogPost.objects(...).exclude("comments") + + .. note :: `exclude()` is chainable and will perform a union :: + So with the following it will exclude both: `title` and `author.name`:: + + post = BlogPost.objects.exclude("title").exclude("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to exclude + + .. versionadded:: 0.5 + """ + fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + return self.fields(**fields) + + def fields(self, _only_called=False, **kwargs): + """Manipulate how you load this document's fields. Used by `.only()` + and `.exclude()` to manipulate which fields to retrieve. Fields also + allows for a greater level of control for example: + + Retrieving a Subrange of Array Elements: + + You can use the $slice operator to retrieve a subrange of elements in + an array. For example to get the first 5 comments:: + + post = BlogPost.objects(...).fields(slice__comments=5) + + :param kwargs: A dictionary identifying what to include + + .. versionadded:: 0.5 + """ + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + queryset = self.clone() + for value, group in itertools.groupby(fields, lambda x: x[1]): + fields = [field for field, value in group] + fields = queryset._fields_to_dbfields(fields) + queryset._loaded_fields += QueryFieldList(fields, value=value, _only_called=_only_called) + + return queryset + + def all_fields(self): + """Include all fields. Reset all previously calls of .only() or + .exclude(). :: + + post = BlogPost.objects.exclude("comments").all_fields() + + .. versionadded:: 0.5 + """ + queryset = self.clone() + queryset._loaded_fields = QueryFieldList( + always_include=queryset._loaded_fields.always_include) + return queryset + + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction + """ + queryset = self.clone() + queryset._ordering = queryset._get_order_by(keys) + return queryset + + def explain(self, format=False): + """Return an explain plan record for the + :class:`~mongoengine.queryset.QuerySet`\ 's cursor. + + :param format: format the plan before returning it + """ + plan = self._cursor.explain() + if format: + plan = pprint.pformat(plan) + return plan + + def snapshot(self, enabled): + """Enable or disable snapshot mode when querying. 
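A sketch of field restriction; BlogPost follows the docstring examples above:

# Repeated only() calls union their fields (see the QueryFieldList change
# later in this patch); exclude() removes fields instead
post = BlogPost.objects.only('title').only('author.name').first()
post = BlogPost.objects.exclude('comments').first()

# fields() exposes $slice, e.g. load only the first five comments
post = BlogPost.objects(title='test').fields(slice__comments=5).first()

# all_fields() resets any only()/exclude() filters
posts = BlogPost.objects.exclude('comments').all_fields()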
+ + :param enabled: whether or not snapshot mode is enabled + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._snapshot = enabled + return queryset + + def timeout(self, enabled): + """Enable or disable the default mongod timeout when querying. + + :param enabled: whether or not the timeout is used + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._timeout = enabled + return queryset + + def slave_okay(self, enabled): + """Enable or disable the slave_okay when querying. + + :param enabled: whether or not the slave_okay is enabled + """ + queryset = self.clone() + queryset._slave_okay = enabled + return queryset + + def read_preference(self, read_preference): + """Change the read_preference when querying. + + :param read_preference: override ReplicaSetConnection-level + preference. + """ + validate_read_preference('read_preference', read_preference) + queryset = self.clone() + queryset._read_preference = read_preference + return queryset + + def scalar(self, *fields): + """Instead of returning Document instances, return either a specific + value or a tuple of values in order. + + Can be used along with + :func:`~mongoengine.queryset.QuerySet.no_dereference` to turn off + dereferencing. + + .. note:: This effects all results and can be unset by calling + ``scalar`` without arguments. Calls ``only`` automatically. + + :param fields: One or more fields to return instead of a Document. + """ + queryset = self.clone() + queryset._scalar = list(fields) + + if fields: + queryset = queryset.only(*fields) + else: + queryset = queryset.all_fields() + + return queryset + + def values_list(self, *fields): + """An alias for scalar""" + return self.scalar(*fields) + + def as_pymongo(self, coerce_types=False): + """Instead of returning Document instances, return raw values from + pymongo. + + :param coerce_type: Field types (if applicable) would be use to + coerce types. + """ + queryset = self.clone() + queryset._as_pymongo = True + queryset._as_pymongo_coerce = coerce_types + return queryset + + # JSON Helpers + + def to_json(self, *args, **kwargs): + """Converts a queryset to JSON""" + return json_util.dumps(self.as_pymongo(), *args, **kwargs) + + def from_json(self, json_data): + """Converts json data to unsaved objects""" + son_data = json_util.loads(json_data) + return [self._document._from_son(data) for data in son_data] + + # JS functionality + + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, + scope=None): + """Perform a map/reduce query using the current query spec + and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, + it must be the last call made, as it does not return a maleable + ``QuerySet``. + + See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` + and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` + tests in ``tests.queryset.QuerySetTest`` for usage examples. + + :param map_f: map function, as :class:`~bson.code.Code` or string + :param reduce_f: reduce function, as + :class:`~bson.code.Code` or string + :param output: output collection name, if set to 'inline' will try to + use :class:`~pymongo.collection.Collection.inline_map_reduce` + This can also be a dictionary containing output options + see: http://docs.mongodb.org/manual/reference/commands/#mapReduce + :param finalize_f: finalize function, an optional function that + performs any post-reduction processing. + :param scope: values to insert into map/reduce global scope. Optional. 
+ :param limit: number of objects from current query to provide + to map/reduce method + + Returns an iterator yielding + :class:`~mongoengine.document.MapReduceDocument`. + + .. note:: + + Map/Reduce changed in server version **>= 1.7.4**. The PyMongo + :meth:`~pymongo.collection.Collection.map_reduce` helper requires + PyMongo version **>= 1.11**. + + .. versionchanged:: 0.5 + - removed ``keep_temp`` keyword argument, which was only relevant + for MongoDB server versions older than 1.7.4 + + .. versionadded:: 0.3 + """ + queryset = self.clone() + + MapReduceDocument = _import_class('MapReduceDocument') + + if not hasattr(self._collection, "map_reduce"): + raise NotImplementedError("Requires MongoDB >= 1.7.1") + + map_f_scope = {} + if isinstance(map_f, Code): + map_f_scope = map_f.scope + map_f = unicode(map_f) + map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) + + reduce_f_scope = {} + if isinstance(reduce_f, Code): + reduce_f_scope = reduce_f.scope + reduce_f = unicode(reduce_f) + reduce_f_code = queryset._sub_js_fields(reduce_f) + reduce_f = Code(reduce_f_code, reduce_f_scope) + + mr_args = {'query': queryset._query} + + if finalize_f: + finalize_f_scope = {} + if isinstance(finalize_f, Code): + finalize_f_scope = finalize_f.scope + finalize_f = unicode(finalize_f) + finalize_f_code = queryset._sub_js_fields(finalize_f) + finalize_f = Code(finalize_f_code, finalize_f_scope) + mr_args['finalize'] = finalize_f + + if scope: + mr_args['scope'] = scope + + if limit: + mr_args['limit'] = limit + + if output == 'inline' and not queryset._ordering: + map_reduce_function = 'inline_map_reduce' + else: + map_reduce_function = 'map_reduce' + mr_args['out'] = output + + results = getattr(queryset._collection, map_reduce_function)( + map_f, reduce_f, **mr_args) + + if map_reduce_function == 'map_reduce': + results = results.find() + + if queryset._ordering: + results = results.sort(queryset._ordering) + + for doc in results: + yield MapReduceDocument(queryset._document, queryset._collection, + doc['_id'], doc['value']) + + def exec_js(self, code, *fields, **options): + """Execute a Javascript function on the server. A list of fields may be + provided, which will be translated to their correct names and supplied + as the arguments to the function. A few extra variables are added to + the function's scope: ``collection``, which is the name of the + collection in use; ``query``, which is an object representing the + current query; and ``options``, which is an object containing any + options specified as keyword arguments. + + As fields in MongoEngine may use different names in the database (set + using the :attr:`db_field` keyword argument to a :class:`Field` + constructor), a mechanism exists for replacing MongoEngine field names + with the database field names in Javascript code. When accessing a + field, use square-bracket notation, and prefix the MongoEngine field + name with a tilde (~). 
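A sketch of exec_js using the [~fieldname] substitution described above; the Person document and its fields are illustrative:

count_adults = """
function() {
    var total = 0;
    db[collection].find(query).forEach(function(doc) {
        if (doc[~age] >= 18) { total += 1; }
    });
    return total;
}
"""
total = Person.objects(country='NZ').exec_js(count_adults)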
+ + :param code: a string of Javascript code to execute + :param fields: fields that you will be using in your function, which + will be passed in to your function as arguments + :param options: options that you want available to the function + (accessed in Javascript through the ``options`` object) + """ + queryset = self.clone() + + code = queryset._sub_js_fields(code) + + fields = [queryset._document._translate_field_name(f) for f in fields] + collection = queryset._document._get_collection_name() + + scope = { + 'collection': collection, + 'options': options or {}, + } + + query = queryset._query + if queryset._where_clause: + query['$where'] = queryset._where_clause + + scope['query'] = query + code = Code(code, scope=scope) + + db = queryset._document._get_db() + return db.eval(code, *fields) + + def where(self, where_clause): + """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript + expression). Performs automatic field name substitution like + :meth:`mongoengine.queryset.Queryset.exec_js`. + + .. note:: When using this mode of query, the database will call your + function, or evaluate your predicate clause, for each object + in the collection. + + .. versionadded:: 0.5 + """ + queryset = self.clone() + where_clause = queryset._sub_js_fields(where_clause) + queryset._where_clause = where_clause + return queryset + + def sum(self, field): + """Sum over the values of the specified field. + + :param field: the field to sum over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. + """ + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'), + field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(1, item||0); + }); + } else if (typeof field != 'undefined') { + emit(1, field||0); + } + } + """ % dict(field=field) + + reduce_func = Code(""" + function(key, values) { + var sum = 0; + for (var i in values) { + sum += values[i]; + } + return sum; + } + """) + + for result in self.map_reduce(map_func, reduce_func, output='inline'): + return result.value + else: + return 0 + + def average(self, field): + """Average over the values of the specified field. + + :param field: the field to average over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. 
+ """ + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'), + field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(1, {t: item||0, c: 1}); + }); + } else if (typeof field != 'undefined') { + emit(1, {t: field||0, c: 1}); + } + } + """ % dict(field=field) + + reduce_func = Code(""" + function(key, values) { + var out = {t: 0, c: 0}; + for (var i in values) { + var value = values[i]; + out.t += value.t; + out.c += value.c; + } + return out; + } + """) + + finalize_func = Code(""" + function(key, value) { + return value.t / value.c; + } + """) + + for result in self.map_reduce(map_func, reduce_func, + finalize_f=finalize_func, output='inline'): + return result.value + else: + return 0 + + def item_frequencies(self, field, normalize=False, map_reduce=True): + """Returns a dictionary of all items present in a field across + the whole queried set of documents, and their corresponding frequency. + This is useful for generating tag clouds, or searching documents. + + .. note:: + + Can only do direct simple mappings and cannot map across + :class:`~mongoengine.fields.ReferenceField` or + :class:`~mongoengine.fields.GenericReferenceField` for more complex + counting a manual map reduce call would is required. + + If the field is a :class:`~mongoengine.fields.ListField`, the items within + each list will be counted individually. + + :param field: the field to use + :param normalize: normalize the results so they add to 1.0 + :param map_reduce: Use map_reduce over exec_js + + .. versionchanged:: 0.5 defaults to map_reduce and can handle embedded + document lookups + """ + if map_reduce: + return self._item_frequencies_map_reduce(field, + normalize=normalize) + return self._item_frequencies_exec_js(field, normalize=normalize) + + # Iterator helpers + + def next(self): + """Wrap the result in a :class:`~mongoengine.Document` object. + """ + if self._limit == 0 or self._none: + raise StopIteration + + raw_doc = self._cursor.next() + if self._as_pymongo: + return self._get_as_pymongo(raw_doc) + doc = self._document._from_son(raw_doc, + _auto_dereference=self._auto_dereference) + if self._scalar: + return self._get_scalar(doc) + + return doc + + def rewind(self): + """Rewind the cursor to its unevaluated state. + + .. versionadded:: 0.3 + """ + self._iter = False + self._cursor.rewind() + + # Properties + + @property + def _collection(self): + """Property that returns the collection object. This allows us to + perform operations only if the collection is accessed. 
+ """ + return self._collection_obj + + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout + } + if self._read_preference is not None: + cursor_args['read_preference'] = self._read_preference + else: + cursor_args['slave_okay'] = self._slave_okay + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + + @property + def _cursor(self): + if self._cursor_obj is None: + + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) + # Apply where clauses to cursor + if self._where_clause: + where_clause = self._sub_js_fields(self._where_clause) + self._cursor_obj.where(where_clause) + + if self._ordering: + # Apply query ordering + self._cursor_obj.sort(self._ordering) + elif self._document._meta['ordering']: + # Otherwise, apply the ordering from the document model + order = self._get_order_by(self._document._meta['ordering']) + self._cursor_obj.sort(order) + + if self._limit is not None: + self._cursor_obj.limit(self._limit) + + if self._skip is not None: + self._cursor_obj.skip(self._skip) + + if self._hint != -1: + self._cursor_obj.hint(self._hint) + + return self._cursor_obj + + def __deepcopy__(self, memo): + """Essential for chained queries with ReferenceFields involved""" + return self.clone() + + @property + def _query(self): + if self._mongo_query is None: + self._mongo_query = self._query_obj.to_query(self._document) + if self._class_check: + self._mongo_query.update(self._initial_query) + return self._mongo_query + + @property + def _dereference(self): + if not self.__dereference: + self.__dereference = _import_class('DeReference')() + return self.__dereference + + def no_dereference(self): + """Turn off any dereferencing for the results of this queryset. 
+ """ + queryset = self.clone() + queryset._auto_dereference = False + return queryset + + # Helper Functions + + def _item_frequencies_map_reduce(self, field, normalize=False): + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'); + var field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(item, 1); + }); + } else if (typeof field != 'undefined') { + emit(field, 1); + } else { + emit(null, 1); + } + } + """ % dict(field=field) + reduce_func = """ + function(key, values) { + var total = 0; + var valuesSize = values.length; + for (var i=0; i < valuesSize; i++) { + total += parseInt(values[i], 10); + } + return total; + } + """ + values = self.map_reduce(map_func, reduce_func, 'inline') + frequencies = {} + for f in values: + key = f.key + if isinstance(key, float): + if int(key) == key: + key = int(key) + frequencies[key] = int(f.value) + + if normalize: + count = sum(frequencies.values()) + frequencies = dict([(k, float(v) / count) + for k, v in frequencies.items()]) + + return frequencies + + def _item_frequencies_exec_js(self, field, normalize=False): + """Uses exec_js to execute""" + freq_func = """ + function(path) { + var path = path.split('.'); + + var total = 0.0; + db[collection].find(query).forEach(function(doc) { + var field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + total += field.length; + } else { + total++; + } + }); + + var frequencies = {}; + var types = {}; + var inc = 1.0; + + db[collection].find(query).forEach(function(doc) { + field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); + }); + } else { + var item = field; + types[item] = item; + frequencies[item] = inc + (isNaN(frequencies[item]) ? 
0: frequencies[item]); + } + }); + return [total, frequencies, types]; + } + """ + total, data, types = self.exec_js(freq_func, field) + values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) + + if normalize: + values = dict([(k, float(v) / total) for k, v in values.items()]) + + frequencies = {} + for k, v in values.iteritems(): + if isinstance(k, float): + if int(k) == k: + k = int(k) + + frequencies[k] = v + + return frequencies + + def _fields_to_dbfields(self, fields, subdoc=False): + """Translate fields paths to its db equivalents""" + ret = [] + subclasses = [] + document = self._document + if document._meta['allow_inheritance']: + subclasses = [get_document(x) + for x in document._subclasses][1:] + for field in fields: + try: + field = ".".join(f.db_field for f in + document._lookup_field(field.split('.'))) + ret.append(field) + except LookUpError, err: + found = False + for subdoc in subclasses: + try: + subfield = ".".join(f.db_field for f in + subdoc._lookup_field(field.split('.'))) + ret.append(subfield) + found = True + break + except LookUpError, e: + pass + + if not found: + raise err + return ret + + def _get_order_by(self, keys): + """Creates a list of order by fields + """ + key_list = [] + for key in keys: + if not key: + continue + direction = pymongo.ASCENDING + if key[0] == '-': + direction = pymongo.DESCENDING + if key[0] in ('-', '+'): + key = key[1:] + key = key.replace('__', '.') + try: + key = self._document._translate_field_name(key) + except: + pass + key_list.append((key, direction)) + + if self._cursor_obj: + self._cursor_obj.sort(key_list) + return key_list + + def _get_scalar(self, doc): + + def lookup(obj, name): + chunks = name.split('__') + for chunk in chunks: + obj = getattr(obj, chunk) + return obj + + data = [lookup(doc, n) for n in self._scalar] + if len(data) == 1: + return data[0] + + return tuple(data) + + def _get_as_pymongo(self, row): + # Extract which fields paths we should follow if .fields(...) was + # used. If not, handle all fields. + if not getattr(self, '__as_pymongo_fields', None): + self.__as_pymongo_fields = [] + + for field in self._loaded_fields.fields - set(['_cls']): + self.__as_pymongo_fields.append(field) + while '.' in field: + field, _ = field.rsplit('.', 1) + self.__as_pymongo_fields.append(field) + + all_fields = not self.__as_pymongo_fields + + def clean(data, path=None): + path = path or '' + + if isinstance(data, dict): + new_data = {} + for key, value in data.iteritems(): + new_path = '%s.%s' % (path, key) if path else key + + if all_fields: + include_field = True + elif self._loaded_fields.value == QueryFieldList.ONLY: + include_field = new_path in self.__as_pymongo_fields + else: + include_field = new_path not in self.__as_pymongo_fields + + if include_field: + new_data[key] = clean(value, path=new_path) + data = new_data + elif isinstance(data, list): + data = [clean(d, path=path) for d in data] + else: + if self._as_pymongo_coerce: + # If we need to coerce types, we need to determine the + # type of this field and use the corresponding + # .to_python(...) 
+ from mongoengine.fields import EmbeddedDocumentField + obj = self._document + for chunk in path.split('.'): + obj = getattr(obj, chunk, None) + if obj is None: + break + elif isinstance(obj, EmbeddedDocumentField): + obj = obj.document_type + if obj and data is not None: + data = obj.to_python(data) + return data + return clean(row) + + def _sub_js_fields(self, code): + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be + substituted for the MongoDB name of the field (specified using the + :attr:`name` keyword argument in a field's constructor). + """ + def field_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return u'["%s"]' % fields[-1].db_field + + def field_path_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return ".".join([f.db_field for f in fields]) + + code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, + code) + return code + + # Deprecated + def ensure_index(self, **kwargs): + """Deprecated use :func:`Document.ensure_index`""" + msg = ("Doc.objects()._ensure_index() is deprecated. " + "Use Doc.ensure_index() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_index(**kwargs) + return self + + def _ensure_indexes(self): + """Deprecated use :func:`~Document.ensure_indexes`""" + msg = ("Doc.objects()._ensure_indexes() is deprecated. " + "Use Doc.ensure_indexes() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_indexes() \ No newline at end of file diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py index 7b2b0cb..140a71e 100644 --- a/mongoengine/queryset/field_list.py +++ b/mongoengine/queryset/field_list.py @@ -7,11 +7,20 @@ class QueryFieldList(object): ONLY = 1 EXCLUDE = 0 - def __init__(self, fields=[], value=ONLY, always_include=[]): + def __init__(self, fields=None, value=ONLY, always_include=None, _only_called=False): + """The QueryFieldList builder + + :param fields: A list of fields used in `.only()` or `.exclude()` + :param value: How to handle the fields; either `ONLY` or `EXCLUDE` + :param always_include: Any fields to always_include eg `_cls` + :param _only_called: Has `.only()` been called? If so its a set of fields + otherwise it performs a union. 
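A sketch of as_pymongo, which the coercion logic above supports; BlogPost is illustrative:

# Raw dicts straight from pymongo, skipping Document construction; with
# coerce_types=True leaf values are run through each field's to_python()
for doc in BlogPost.objects.only('title', 'author').as_pymongo(coerce_types=True):
    print doc['title']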
+ """ self.value = value - self.fields = set(fields) - self.always_include = set(always_include) + self.fields = set(fields or []) + self.always_include = set(always_include or []) self._id = None + self._only_called = _only_called self.slice = {} def __add__(self, f): @@ -26,7 +35,10 @@ class QueryFieldList(object): self.slice = {} elif self.value is self.ONLY and f.value is self.ONLY: self._clean_slice() - self.fields = self.fields.intersection(f.fields) + if self._only_called: + self.fields = self.fields.union(f.fields) + else: + self.fields = f.fields elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: self.fields = self.fields.union(f.fields) self._clean_slice() @@ -43,9 +55,13 @@ class QueryFieldList(object): if self.always_include: if self.value is self.ONLY and self.fields: - self.fields = self.fields.union(self.always_include) + if sorted(self.slice.keys()) != sorted(self.fields): + self.fields = self.fields.union(self.always_include) else: self.fields -= self.always_include + + if getattr(f, '_only_called', False): + self._only_called = True return self def __nonzero__(self): diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 15c8e63..1437e76 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -1,161 +1,133 @@ -from __future__ import absolute_import +from mongoengine.errors import OperationError +from mongoengine.queryset.base import (BaseQuerySet, DO_NOTHING, NULLIFY, + CASCADE, DENY, PULL) -import copy -import itertools -import operator -import pprint -import re -import warnings - -from bson.code import Code -from bson import json_util -import pymongo -from pymongo.common import validate_read_preference - -from mongoengine import signals -from mongoengine.common import _import_class -from mongoengine.errors import (OperationError, NotUniqueError, - InvalidQueryError) - -from mongoengine.queryset import transform -from mongoengine.queryset.field_list import QueryFieldList -from mongoengine.queryset.visitor import Q, QNode - - -__all__ = ('QuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') +__all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE', + 'DENY', 'PULL') # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 - -# Delete rules -DO_NOTHING = 0 -NULLIFY = 1 -CASCADE = 2 -DENY = 3 -PULL = 4 - -RE_TYPE = type(re.compile('')) +ITER_CHUNK_SIZE = 100 -class QuerySet(object): - """A set of results returned from a query. Wraps a MongoDB cursor, - providing :class:`~mongoengine.Document` objects as the results. +class QuerySet(BaseQuerySet): + """The default queryset, that builds queries and handles a set of results + returned from a query. + + Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as + the results. 
""" - __dereference = False - _auto_dereference = True - def __init__(self, document, collection): - self._document = document - self._collection_obj = collection - self._mongo_query = None - self._query_obj = Q() - self._initial_query = {} - self._where_clause = None - self._loaded_fields = QueryFieldList() - self._ordering = [] - self._snapshot = False - self._timeout = True - self._class_check = True - self._slave_okay = False - self._read_preference = None - self._iter = False - self._scalar = [] - self._none = False - self._as_pymongo = False - self._as_pymongo_coerce = False - - # If inheritance is allowed, only return instances and instances of - # subclasses of the class being used - if document._meta.get('allow_inheritance') == True: - self._initial_query = {"_cls": {"$in": self._document._subclasses}} - self._loaded_fields = QueryFieldList(always_include=['_cls']) - self._cursor_obj = None - self._limit = None - self._skip = None - self._slice = None - self._hint = -1 # Using -1 as None is a valid value for hint - - def __call__(self, q_obj=None, class_check=True, slave_okay=False, - read_preference=None, **query): - """Filter the selected documents by calling the - :class:`~mongoengine.queryset.QuerySet` with a query. - - :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in - the query; the :class:`~mongoengine.queryset.QuerySet` is filtered - multiple times with different :class:`~mongoengine.queryset.Q` - objects, only the last one will be used - :param class_check: If set to False bypass class name check when - querying collection - :param slave_okay: if True, allows this query to be run against a - replica secondary. - :params read_preference: if set, overrides connection-level - read_preference from `ReplicaSetConnection`. - :param query: Django-style query keyword arguments - """ - query = Q(**query) - if q_obj: - # make sure proper query object is passed - if not isinstance(q_obj, QNode): - msg = ("Not a query object: %s. " - "Did you intend to use key=value?" % q_obj) - raise InvalidQueryError(msg) - query &= q_obj - - queryset = self.clone() - queryset._query_obj &= query - queryset._mongo_query = None - queryset._cursor_obj = None - if read_preference is not None: - queryset.read_preference(read_preference) - queryset._class_check = class_check - return queryset + _has_more = True + _len = None + _result_cache = None def __iter__(self): - """Support iterator protocol""" - queryset = self - if queryset._iter: - queryset = self.clone() - queryset.rewind() - return queryset + """Iteration utilises a results cache which iterates the cursor + in batches of ``ITER_CHUNK_SIZE``. + + If ``self._has_more`` the cursor hasn't been exhausted so cache then + batch. Otherwise iterate the result_cache. + """ + self._iter = True + if self._has_more: + return self._iter_results() + + # iterating over the cache. + return iter(self._result_cache) def __len__(self): - return self.count() - - def __getitem__(self, key): - """Support skip and limit using getitem and slicing syntax. + """Since __len__ is called quite frequently (for example, as part of + list(qs) we populate the result cache and cache the length. 
""" - queryset = self.clone() + if self._len is not None: + return self._len + if self._has_more: + # populate the cache + list(self._iter_results()) - # Slice provided - if isinstance(key, slice): + self._len = len(self._result_cache) + return self._len + + def __repr__(self): + """Provides the string representation of the QuerySet + """ + if self._iter: + return '.. queryset mid-iteration ..' + + self._populate_cache() + data = self._result_cache[:REPR_OUTPUT_SIZE + 1] + if len(data) > REPR_OUTPUT_SIZE: + data[-1] = "...(remaining elements truncated)..." + return repr(data) + + + def _iter_results(self): + """A generator for iterating over the result cache. + + Also populates the cache if there are more possible results to yield. + Raises StopIteration when there are no more results""" + if self._result_cache is None: + self._result_cache = [] + pos = 0 + while True: + upper = len(self._result_cache) + while pos < upper: + yield self._result_cache[pos] + pos = pos + 1 + if not self._has_more: + raise StopIteration + if len(self._result_cache) <= pos: + self._populate_cache() + + def _populate_cache(self): + """ + Populates the result cache with ``ITER_CHUNK_SIZE`` more entries + (until the cursor is exhausted). + """ + if self._result_cache is None: + self._result_cache = [] + if self._has_more: try: - queryset._cursor_obj = queryset._cursor[key] - queryset._slice = key - queryset._skip, queryset._limit = key.start, key.stop - except IndexError, err: - # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. - start = key.start or 0 - if start >= 0 and key.stop >= 0 and key.step is None: - if start == key.stop: - queryset.limit(0) - queryset._skip = key.start - queryset._limit = key.stop - start - return queryset - raise err - # Allow further QuerySet modifications to be performed - return queryset - # Integer index provided - elif isinstance(key, int): - if queryset._scalar: - return queryset._get_scalar( - queryset._document._from_son(queryset._cursor[key], - _auto_dereference=self._auto_dereference)) - if queryset._as_pymongo: - return queryset._get_as_pymongo(queryset._cursor.next()) - return queryset._document._from_son(queryset._cursor[key], - _auto_dereference=self._auto_dereference) - raise AttributeError + for i in xrange(ITER_CHUNK_SIZE): + self._result_cache.append(self.next()) + except StopIteration: + self._has_more = False + + def count(self, with_limit_and_skip=True): + """Count the selected elements in the query. + + :param with_limit_and_skip (optional): take any :meth:`limit` or + :meth:`skip` that has been applied to this cursor into account when + getting the count + """ + if with_limit_and_skip is False: + return super(QuerySet, self).count(with_limit_and_skip) + + if self._len is None: + self._len = super(QuerySet, self).count(with_limit_and_skip) + + return self._len + + def no_cache(self): + """Convert to a non_caching queryset + + .. versionadded:: 0.8.3 Convert to non caching queryset + """ + if self._result_cache is not None: + raise OperationError("QuerySet already cached") + return self.clone_into(QuerySetNoCache(self._document, self._collection)) + + +class QuerySetNoCache(BaseQuerySet): + """A non caching QuerySet""" + + def cache(self): + """Convert to a caching queryset + + .. 
versionadded:: 0.8.3 Convert to caching queryset + """ + return self.clone_into(QuerySet(self._document, self._collection)) def __repr__(self): """Provides the string representation of the QuerySet @@ -177,1236 +149,9 @@ class QuerySet(object): self.rewind() return repr(data) - # Core functions - - def all(self): - """Returns all documents.""" - return self.__call__() - - def filter(self, *q_objs, **query): - """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` - """ - return self.__call__(*q_objs, **query) - - def get(self, *q_objs, **query): - """Retrieve the the matching object raising - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` exception if multiple results - and :class:`~mongoengine.queryset.DoesNotExist` or - `DocumentName.DoesNotExist` if no results are found. - - .. versionadded:: 0.3 - """ - queryset = self.__call__(*q_objs, **query) - queryset = queryset.limit(2) - try: - result = queryset.next() - except StopIteration: - msg = ("%s matching query does not exist." - % queryset._document._class_name) - raise queryset._document.DoesNotExist(msg) - try: - queryset.next() - except StopIteration: - return result - + def __iter__(self): + queryset = self + if queryset._iter: + queryset = self.clone() queryset.rewind() - message = u'%d items returned, instead of 1' % queryset.count() - raise queryset._document.MultipleObjectsReturned(message) - - def create(self, **kwargs): - """Create new object. Returns the saved object instance. - - .. versionadded:: 0.4 - """ - return self._document(**kwargs).save() - - def get_or_create(self, write_options=None, auto_save=True, - *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a - tuple of ``(object, created)``, where ``object`` is the retrieved or - created object and ``created`` is a boolean specifying whether a new - object was created. Raises - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` if multiple results are found. - A new document will be created if the document doesn't exists; a - dictionary of default values for the new document may be provided as a - keyword argument called :attr:`defaults`. - - .. note:: This requires two separate operations and therefore a - race condition exists. Because there are no transactions in - mongoDB other approaches should be investigated, to ensure you - don't accidently duplicate data when using this method. This is - now scheduled to be removed before 1.0 - - :param write_options: optional extra keyword arguments used if we - have to create a new document. - Passes any write_options onto :meth:`~mongoengine.Document.save` - - :param auto_save: if the object is to be saved automatically if - not found. - - .. deprecated:: 0.8 - .. versionchanged:: 0.6 - added `auto_save` - .. versionadded:: 0.3 - """ - msg = ("get_or_create is scheduled to be deprecated. The approach is " - "flawed without transactions. Upserts should be preferred.") - warnings.warn(msg, DeprecationWarning) - - defaults = query.get('defaults', {}) - if 'defaults' in query: - del query['defaults'] - - try: - doc = self.get(*q_objs, **query) - return doc, False - except self._document.DoesNotExist: - query.update(defaults) - doc = self._document(**query) - - if auto_save: - doc.save(write_options=write_options) - return doc, True - - def first(self): - """Retrieve the first object matching the query. 
- """ - queryset = self.clone() - try: - result = queryset[0] - except IndexError: - result = None - return result - - def insert(self, doc_or_docs, load_bulk=True, safe=False, - write_options=None): - """bulk insert documents - - If ``safe=True`` and the operation is unsuccessful, an - :class:`~mongoengine.OperationError` will be raised. - - :param docs_or_doc: a document or list of documents to be inserted - :param load_bulk (optional): If True returns the list of document - instances - :param safe: check if the operation succeeded before returning - :param write_options: Extra keyword arguments are passed down to - :meth:`~pymongo.collection.Collection.insert` - which will be used as options for the resultant - ``getLastError`` command. For example, - ``insert(..., {w: 2, fsync: True})`` will wait until at least - two servers have recorded the write and will force an fsync on - each server being written to. - - By default returns document instances, set ``load_bulk`` to False to - return just ``ObjectIds`` - - .. versionadded:: 0.5 - """ - Document = _import_class('Document') - - if not write_options: - write_options = {} - write_options.update({'safe': safe}) - - docs = doc_or_docs - return_one = False - if isinstance(docs, Document) or issubclass(docs.__class__, Document): - return_one = True - docs = [docs] - - raw = [] - for doc in docs: - if not isinstance(doc, self._document): - msg = ("Some documents inserted aren't instances of %s" - % str(self._document)) - raise OperationError(msg) - if doc.pk and not doc._created: - msg = "Some documents have ObjectIds use doc.update() instead" - raise OperationError(msg) - raw.append(doc.to_mongo()) - - signals.pre_bulk_insert.send(self._document, documents=docs) - try: - ids = self._collection.insert(raw, **write_options) - except pymongo.errors.OperationFailure, err: - message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): - # E11000 - duplicate key error index - # E11001 - duplicate key on update - message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) - - if not load_bulk: - signals.post_bulk_insert.send( - self._document, documents=docs, loaded=False) - return return_one and ids[0] or ids - - documents = self.in_bulk(ids) - results = [] - for obj_id in ids: - results.append(documents.get(obj_id)) - signals.post_bulk_insert.send( - self._document, documents=results, loaded=True) - return return_one and results[0] or results - - def count(self): - """Count the selected elements in the query. - """ - if self._limit == 0: - return 0 - return self._cursor.count(with_limit_and_skip=True) - - def delete(self, safe=False): - """Delete the documents matched by the query. 
- - :param safe: check if the operation succeeded before returning - """ - queryset = self.clone() - doc = queryset._document - - has_delete_signal = signals.signals_available and ( - signals.pre_delete.has_receivers_for(self._document) or - signals.post_delete.has_receivers_for(self._document)) - - # Handle deletes where skips or limits have been applied or has a - # delete signal - if queryset._skip or queryset._limit or has_delete_signal: - for doc in queryset: - doc.delete(safe=safe) - return - - delete_rules = doc._meta.get('delete_rules') or {} - # Check for DENY rules before actually deleting/nullifying any other - # references - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == DENY and document_cls.objects( - **{field_name + '__in': self}).count() > 0: - msg = ("Could not delete document (%s.%s refers to it)" - % (document_cls.__name__, field_name)) - raise OperationError(msg) - - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == CASCADE: - ref_q = document_cls.objects(**{field_name + '__in': self}) - ref_q_count = ref_q.count() - if (doc != document_cls and ref_q_count > 0 - or (doc == document_cls and ref_q_count > 0)): - ref_q.delete(safe=safe) - elif rule == NULLIFY: - document_cls.objects(**{field_name + '__in': self}).update( - safe_update=safe, - **{'unset__%s' % field_name: 1}) - elif rule == PULL: - document_cls.objects(**{field_name + '__in': self}).update( - safe_update=safe, - **{'pull_all__%s' % field_name: self}) - - queryset._collection.remove(queryset._query, safe=safe) - - def update(self, safe_update=True, upsert=False, multi=True, - write_options=None, **update): - """Perform an atomic update on the fields matched by the query. When - ``safe_update`` is used, the number of affected documents is returned. - - :param safe_update: check if the operation succeeded before returning - :param upsert: Any existing document with that "_id" is overwritten. - :param write_options: extra keyword arguments for - :meth:`~pymongo.collection.Collection.update` - - .. versionadded:: 0.2 - """ - if not update: - raise OperationError("No update parameters, would remove data") - - if not write_options: - write_options = {} - - queryset = self.clone() - query = queryset._query - update = transform.update(queryset._document, **update) - - # If doing an atomic upsert on an inheritable class - # then ensure we add _cls to the update operation - if upsert and '_cls' in query: - if '$set' in update: - update["$set"]["_cls"] = queryset._document._class_name - else: - update["$set"] = {"_cls": queryset._document._class_name} - - try: - ret = queryset._collection.update(query, update, multi=multi, - upsert=upsert, safe=safe_update, - **write_options) - if ret is not None and 'n' in ret: - return ret['n'] - except pymongo.errors.OperationFailure, err: - if unicode(err) == u'multi not coded yet': - message = u'update() method requires MongoDB 1.1.3+' - raise OperationError(message) - raise OperationError(u'Update failed (%s)' % unicode(err)) - - def update_one(self, safe_update=True, upsert=False, write_options=None, - **update): - """Perform an atomic update on first field matched by the query. When - ``safe_update`` is used, the number of affected documents is returned. - - :param safe_update: check if the operation succeeded before returning - :param upsert: Any existing document with that "_id" is overwritten. 
- :param write_options: extra keyword arguments for - :meth:`~pymongo.collection.Collection.update` - :param update: Django-style update keyword arguments - - .. versionadded:: 0.2 - """ - return self.update(safe_update=True, upsert=upsert, multi=False, - write_options=None, **update) - - def with_id(self, object_id): - """Retrieve the object matching the id provided. Uses `object_id` only - and raises InvalidQueryError if a filter has been applied. Returns - `None` if no document exists with that id. - - :param object_id: the value for the id of the document to look up - - .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set - """ - queryset = self.clone() - if not queryset._query_obj.empty: - msg = "Cannot use a filter whilst using `with_id`" - raise InvalidQueryError(msg) - return queryset.filter(pk=object_id).first() - - def in_bulk(self, object_ids): - """Retrieve a set of documents by their ids. - - :param object_ids: a list or tuple of ``ObjectId``\ s - :rtype: dict of ObjectIds as keys and collection-specific - Document subclasses as values. - - .. versionadded:: 0.3 - """ - doc_map = {} - - docs = self._collection.find({'_id': {'$in': object_ids}}, - **self._cursor_args) - if self._scalar: - for doc in docs: - doc_map[doc['_id']] = self._get_scalar( - self._document._from_son(doc)) - elif self._as_pymongo: - for doc in docs: - doc_map[doc['_id']] = self._get_as_pymongo(doc) - else: - for doc in docs: - doc_map[doc['_id']] = self._document._from_son(doc) - - return doc_map - - def none(self): - """Helper that just returns a list""" - queryset = self.clone() - queryset._none = True return queryset - - def clone(self): - """Creates a copy of the current - :class:`~mongoengine.queryset.QuerySet` - - .. versionadded:: 0.5 - """ - c = self.__class__(self._document, self._collection_obj) - - copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', - '_where_clause', '_loaded_fields', '_ordering', '_snapshot', - '_timeout', '_class_check', '_slave_okay', '_read_preference', - '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', - '_limit', '_skip', '_hint', '_auto_dereference') - - for prop in copy_props: - val = getattr(self, prop) - setattr(c, prop, copy.copy(val)) - - if self._slice: - c._slice = self._slice - - if self._cursor_obj: - c._cursor_obj = self._cursor_obj.clone() - - if self._slice: - c._cursor[self._slice] - - return c - - def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to - a maximum depth in order to cut down the number queries to mongodb. - - .. versionadded:: 0.5 - """ - # Make select related work the same for querysets - max_depth += 1 - queryset = self.clone() - return queryset._dereference(queryset, max_depth=max_depth) - - def limit(self, n): - """Limit the number of returned documents to `n`. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[:5]``). - - :param n: the maximum number of objects to return - """ - queryset = self.clone() - if n == 0: - queryset._cursor.limit(1) - else: - queryset._cursor.limit(n) - queryset._limit = n - - # Return self to allow chaining - return queryset - - def skip(self, n): - """Skip `n` documents before returning the results. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[5:]``). 
- - :param n: the number of objects to skip before returning results - """ - queryset = self.clone() - queryset._cursor.skip(n) - queryset._skip = n - return queryset - - def hint(self, index=None): - """Added 'hint' support, telling Mongo the proper index to use for the - query. - - Judicious use of hints can greatly improve query performance. When - doing a query on multiple fields (at least one of which is indexed) - pass the indexed field as a hint to the query. - - Hinting will not do anything if the corresponding index does not exist. - The last hint applied to this cursor takes precedence over all others. - - .. versionadded:: 0.5 - """ - queryset = self.clone() - queryset._cursor.hint(index) - queryset._hint = index - return queryset - - def distinct(self, field): - """Return a list of distinct values for a given field. - - :param field: the field to select distinct values from - - .. versionadded:: 0.4 - .. versionchanged:: 0.5 - Fixed handling references - .. versionchanged:: 0.6 - Improved db_field refrence handling - """ - queryset = self.clone() - try: - field = self._fields_to_dbfields([field]).pop() - finally: - return self._dereference(queryset._cursor.distinct(field), 1, - name=field, instance=self._document) - - def only(self, *fields): - """Load only a subset of this document's fields. :: - - post = BlogPost.objects(...).only("title", "author.name") - - :param fields: fields to include - - .. versionadded:: 0.3 - .. versionchanged:: 0.5 - Added subfield support - """ - fields = dict([(f, QueryFieldList.ONLY) for f in fields]) - return self.fields(**fields) - - def exclude(self, *fields): - """Opposite to .only(), exclude some document's fields. :: - - post = BlogPost.objects(...).exclude("comments") - - :param fields: fields to exclude - - .. versionadded:: 0.5 - """ - fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) - return self.fields(**fields) - - def fields(self, **kwargs): - """Manipulate how you load this document's fields. Used by `.only()` - and `.exclude()` to manipulate which fields to retrieve. Fields also - allows for a greater level of control for example: - - Retrieving a Subrange of Array Elements: - - You can use the $slice operator to retrieve a subrange of elements in - an array. For example to get the first 5 comments:: - - post = BlogPost.objects(...).fields(slice__comments=5) - - :param kwargs: A dictionary identifying what to include - - .. versionadded:: 0.5 - """ - - # Check for an operator and transform to mongo-style if there is - operators = ["slice"] - cleaned_fields = [] - for key, value in kwargs.items(): - parts = key.split('__') - op = None - if parts[0] in operators: - op = parts.pop(0) - value = {'$' + op: value} - key = '.'.join(parts) - cleaned_fields.append((key, value)) - - fields = sorted(cleaned_fields, key=operator.itemgetter(1)) - queryset = self.clone() - for value, group in itertools.groupby(fields, lambda x: x[1]): - fields = [field for field, value in group] - fields = queryset._fields_to_dbfields(fields) - queryset._loaded_fields += QueryFieldList(fields, value=value) - return queryset - - def all_fields(self): - """Include all fields. Reset all previously calls of .only() or - .exclude(). :: - - post = BlogPost.objects.exclude("comments").all_fields() - - .. 
versionadded:: 0.5 - """ - queryset = self.clone() - queryset._loaded_fields = QueryFieldList( - always_include=queryset._loaded_fields.always_include) - return queryset - - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. - - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction - """ - queryset = self.clone() - queryset._ordering = queryset._get_order_by(keys) - return queryset - - def explain(self, format=False): - """Return an explain plan record for the - :class:`~mongoengine.queryset.QuerySet`\ 's cursor. - - :param format: format the plan before returning it - """ - plan = self._cursor.explain() - if format: - plan = pprint.pformat(plan) - return plan - - def snapshot(self, enabled): - """Enable or disable snapshot mode when querying. - - :param enabled: whether or not snapshot mode is enabled - - ..versionchanged:: 0.5 - made chainable - """ - queryset = self.clone() - queryset._snapshot = enabled - return queryset - - def timeout(self, enabled): - """Enable or disable the default mongod timeout when querying. - - :param enabled: whether or not the timeout is used - - ..versionchanged:: 0.5 - made chainable - """ - queryset = self.clone() - queryset._timeout = enabled - return queryset - - def slave_okay(self, enabled): - """Enable or disable the slave_okay when querying. - - :param enabled: whether or not the slave_okay is enabled - """ - queryset = self.clone() - queryset._slave_okay = enabled - return queryset - - def read_preference(self, read_preference): - """Change the read_preference when querying. - - :param read_preference: override ReplicaSetConnection-level - preference. - """ - validate_read_preference('read_preference', read_preference) - queryset = self.clone() - queryset._read_preference = read_preference - return queryset - - def scalar(self, *fields): - """Instead of returning Document instances, return either a specific - value or a tuple of values in order. - - Can be used along with - :func:`~mongoengine.queryset.QuerySet.no_dereference` to turn off - dereferencing. - - .. note:: This effects all results and can be unset by calling - ``scalar`` without arguments. Calls ``only`` automatically. - - :param fields: One or more fields to return instead of a Document. - """ - queryset = self.clone() - queryset._scalar = list(fields) - - if fields: - queryset = queryset.only(*fields) - else: - queryset = queryset.all_fields() - - return queryset - - def values_list(self, *fields): - """An alias for scalar""" - return self.scalar(*fields) - - def as_pymongo(self, coerce_types=False): - """Instead of returning Document instances, return raw values from - pymongo. - - :param coerce_type: Field types (if applicable) would be use to - coerce types. 
- """ - queryset = self.clone() - queryset._as_pymongo = True - queryset._as_pymongo_coerce = coerce_types - return queryset - - # JSON Helpers - - def to_json(self): - """Converts a queryset to JSON""" - queryset = self.clone() - return json_util.dumps(queryset._collection_obj.find(queryset._query)) - - def from_json(self, json_data): - """Converts json data to unsaved objects""" - son_data = json_util.loads(json_data) - return [self._document._from_son(data) for data in son_data] - - # JS functionality - - def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, - scope=None): - """Perform a map/reduce query using the current query spec - and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, - it must be the last call made, as it does not return a maleable - ``QuerySet``. - - See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` - and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` - tests in ``tests.queryset.QuerySetTest`` for usage examples. - - :param map_f: map function, as :class:`~bson.code.Code` or string - :param reduce_f: reduce function, as - :class:`~bson.code.Code` or string - :param output: output collection name, if set to 'inline' will try to - use :class:`~pymongo.collection.Collection.inline_map_reduce` - This can also be a dictionary containing output options - see: http://docs.mongodb.org/manual/reference/commands/#mapReduce - :param finalize_f: finalize function, an optional function that - performs any post-reduction processing. - :param scope: values to insert into map/reduce global scope. Optional. - :param limit: number of objects from current query to provide - to map/reduce method - - Returns an iterator yielding - :class:`~mongoengine.document.MapReduceDocument`. - - .. note:: - - Map/Reduce changed in server version **>= 1.7.4**. The PyMongo - :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.11**. - - .. versionchanged:: 0.5 - - removed ``keep_temp`` keyword argument, which was only relevant - for MongoDB server versions older than 1.7.4 - - .. 
versionadded:: 0.3 - """ - queryset = self.clone() - - MapReduceDocument = _import_class('MapReduceDocument') - - if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.7.1") - - map_f_scope = {} - if isinstance(map_f, Code): - map_f_scope = map_f.scope - map_f = unicode(map_f) - map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) - - reduce_f_scope = {} - if isinstance(reduce_f, Code): - reduce_f_scope = reduce_f.scope - reduce_f = unicode(reduce_f) - reduce_f_code = queryset._sub_js_fields(reduce_f) - reduce_f = Code(reduce_f_code, reduce_f_scope) - - mr_args = {'query': queryset._query} - - if finalize_f: - finalize_f_scope = {} - if isinstance(finalize_f, Code): - finalize_f_scope = finalize_f.scope - finalize_f = unicode(finalize_f) - finalize_f_code = queryset._sub_js_fields(finalize_f) - finalize_f = Code(finalize_f_code, finalize_f_scope) - mr_args['finalize'] = finalize_f - - if scope: - mr_args['scope'] = scope - - if limit: - mr_args['limit'] = limit - - if output == 'inline' and not queryset._ordering: - map_reduce_function = 'inline_map_reduce' - else: - map_reduce_function = 'map_reduce' - mr_args['out'] = output - - results = getattr(queryset._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) - - if map_reduce_function == 'map_reduce': - results = results.find() - - if queryset._ordering: - results = results.sort(queryset._ordering) - - for doc in results: - yield MapReduceDocument(queryset._document, queryset._collection, - doc['_id'], doc['value']) - - def exec_js(self, code, *fields, **options): - """Execute a Javascript function on the server. A list of fields may be - provided, which will be translated to their correct names and supplied - as the arguments to the function. A few extra variables are added to - the function's scope: ``collection``, which is the name of the - collection in use; ``query``, which is an object representing the - current query; and ``options``, which is an object containing any - options specified as keyword arguments. - - As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` - constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a - field, use square-bracket notation, and prefix the MongoEngine field - name with a tilde (~). - - :param code: a string of Javascript code to execute - :param fields: fields that you will be using in your function, which - will be passed in to your function as arguments - :param options: options that you want available to the function - (accessed in Javascript through the ``options`` object) - """ - queryset = self.clone() - - code = queryset._sub_js_fields(code) - - fields = [queryset._document._translate_field_name(f) for f in fields] - collection = queryset._document._get_collection_name() - - scope = { - 'collection': collection, - 'options': options or {}, - } - - query = queryset._query - if queryset._where_clause: - query['$where'] = queryset._where_clause - - scope['query'] = query - code = Code(code, scope=scope) - - db = queryset._document._get_db() - return db.eval(code, *fields) - - def where(self, where_clause): - """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript - expression). Performs automatic field name substitution like - :meth:`mongoengine.queryset.Queryset.exec_js`. - - .. 
note:: When using this mode of query, the database will call your - function, or evaluate your predicate clause, for each object - in the collection. - - .. versionadded:: 0.5 - """ - queryset = self.clone() - where_clause = queryset._sub_js_fields(where_clause) - queryset._where_clause = where_clause - return queryset - - def sum(self, field): - """Sum over the values of the specified field. - - :param field: the field to sum over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = Code(""" - function() { - emit(1, this[field] || 0); - } - """, scope={'field': field}) - - reduce_func = Code(""" - function(key, values) { - var sum = 0; - for (var i in values) { - sum += values[i]; - } - return sum; - } - """) - - for result in self.map_reduce(map_func, reduce_func, output='inline'): - return result.value - else: - return 0 - - def average(self, field): - """Average over the values of the specified field. - - :param field: the field to average over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = Code(""" - function() { - if (this.hasOwnProperty(field)) - emit(1, {t: this[field] || 0, c: 1}); - } - """, scope={'field': field}) - - reduce_func = Code(""" - function(key, values) { - var out = {t: 0, c: 0}; - for (var i in values) { - var value = values[i]; - out.t += value.t; - out.c += value.c; - } - return out; - } - """) - - finalize_func = Code(""" - function(key, value) { - return value.t / value.c; - } - """) - - for result in self.map_reduce(map_func, reduce_func, - finalize_f=finalize_func, output='inline'): - return result.value - else: - return 0 - - def item_frequencies(self, field, normalize=False, map_reduce=True): - """Returns a dictionary of all items present in a field across - the whole queried set of documents, and their corresponding frequency. - This is useful for generating tag clouds, or searching documents. - - .. note:: - - Can only do direct simple mappings and cannot map across - :class:`~mongoengine.ReferenceField` or - :class:`~mongoengine.GenericReferenceField` for more complex - counting a manual map reduce call would is required. - - If the field is a :class:`~mongoengine.ListField`, the items within - each list will be counted individually. - - :param field: the field to use - :param normalize: normalize the results so they add to 1.0 - :param map_reduce: Use map_reduce over exec_js - - .. versionchanged:: 0.5 defaults to map_reduce and can handle embedded - document lookups - """ - if map_reduce: - return self._item_frequencies_map_reduce(field, - normalize=normalize) - return self._item_frequencies_exec_js(field, normalize=normalize) - - # Iterator helpers - - def next(self): - """Wrap the result in a :class:`~mongoengine.Document` object. - """ - self._iter = True - try: - if self._limit == 0 or self._none: - raise StopIteration - if self._scalar: - return self._get_scalar(self._document._from_son( - self._cursor.next())) - if self._as_pymongo: - return self._get_as_pymongo(self._cursor.next()) - - return self._document._from_son(self._cursor.next()) - except StopIteration, e: - self.rewind() - raise e - - def rewind(self): - """Rewind the cursor to its unevaluated state. - - .. 
versionadded:: 0.3 - """ - self._iter = False - self._cursor.rewind() - - # Properties - - @property - def _collection(self): - """Property that returns the collection object. This allows us to - perform operations only if the collection is accessed. - """ - return self._collection_obj - - @property - def _cursor_args(self): - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout - } - if self._read_preference is not None: - cursor_args['read_preference'] = self._read_preference - else: - cursor_args['slave_okay'] = self._slave_okay - if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields.as_dict() - return cursor_args - - @property - def _cursor(self): - if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - where_clause = self._sub_js_fields(self._where_clause) - self._cursor_obj.where(where_clause) - - if self._ordering: - # Apply query ordering - self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model - order = self._get_order_by(self._document._meta['ordering']) - self._cursor_obj.sort(order) - - if self._limit is not None: - self._cursor_obj.limit(self._limit - (self._skip or 0)) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - - return self._cursor_obj - - def __deepcopy__(self, memo): - """Essential for chained queries with ReferenceFields involved""" - return self.clone() - - @property - def _query(self): - if self._mongo_query is None: - self._mongo_query = self._query_obj.to_query(self._document) - if self._class_check: - self._mongo_query.update(self._initial_query) - return self._mongo_query - - @property - def _dereference(self): - if not self.__dereference: - self.__dereference = _import_class('DeReference')() - return self.__dereference - - def no_dereference(self): - """Turn off any dereferencing for the results of this queryset. 
- """ - queryset = self.clone() - queryset._auto_dereference = False - return queryset - - # Helper Functions - - def _item_frequencies_map_reduce(self, field, normalize=False): - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'); - var field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(item, 1); - }); - } else if (typeof field != 'undefined') { - emit(field, 1); - } else { - emit(null, 1); - } - } - """ % dict(field=field) - reduce_func = """ - function(key, values) { - var total = 0; - var valuesSize = values.length; - for (var i=0; i < valuesSize; i++) { - total += parseInt(values[i], 10); - } - return total; - } - """ - values = self.map_reduce(map_func, reduce_func, 'inline') - frequencies = {} - for f in values: - key = f.key - if isinstance(key, float): - if int(key) == key: - key = int(key) - frequencies[key] = int(f.value) - - if normalize: - count = sum(frequencies.values()) - frequencies = dict([(k, float(v) / count) - for k, v in frequencies.items()]) - - return frequencies - - def _item_frequencies_exec_js(self, field, normalize=False): - """Uses exec_js to execute""" - freq_func = """ - function(path) { - var path = path.split('.'); - - var total = 0.0; - db[collection].find(query).forEach(function(doc) { - var field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - total += field.length; - } else { - total++; - } - }); - - var frequencies = {}; - var types = {}; - var inc = 1.0; - - db[collection].find(query).forEach(function(doc) { - field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - field.forEach(function(item) { - frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); - }); - } else { - var item = field; - types[item] = item; - frequencies[item] = inc + (isNaN(frequencies[item]) ? 
0: frequencies[item]); - } - }); - return [total, frequencies, types]; - } - """ - total, data, types = self.exec_js(freq_func, field) - values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) - - if normalize: - values = dict([(k, float(v) / total) for k, v in values.items()]) - - frequencies = {} - for k, v in values.iteritems(): - if isinstance(k, float): - if int(k) == k: - k = int(k) - - frequencies[k] = v - - return frequencies - - def _fields_to_dbfields(self, fields): - """Translate fields paths to its db equivalents""" - ret = [] - for field in fields: - field = ".".join(f.db_field for f in - self._document._lookup_field(field.split('.'))) - ret.append(field) - return ret - - def _get_order_by(self, keys): - """Creates a list of order by fields - """ - key_list = [] - for key in keys: - if not key: - continue - direction = pymongo.ASCENDING - if key[0] == '-': - direction = pymongo.DESCENDING - if key[0] in ('-', '+'): - key = key[1:] - key = key.replace('__', '.') - try: - key = self._document._translate_field_name(key) - except: - pass - key_list.append((key, direction)) - - if self._cursor_obj: - self._cursor_obj.sort(key_list) - return key_list - - def _get_scalar(self, doc): - - def lookup(obj, name): - chunks = name.split('__') - for chunk in chunks: - obj = getattr(obj, chunk) - return obj - - data = [lookup(doc, n) for n in self._scalar] - if len(data) == 1: - return data[0] - - return tuple(data) - - def _get_as_pymongo(self, row): - # Extract which fields paths we should follow if .fields(...) was - # used. If not, handle all fields. - if not getattr(self, '__as_pymongo_fields', None): - self.__as_pymongo_fields = [] - for field in self._loaded_fields.fields - set(['_cls', '_id']): - self.__as_pymongo_fields.append(field) - while '.' in field: - field, _ = field.rsplit('.', 1) - self.__as_pymongo_fields.append(field) - - all_fields = not self.__as_pymongo_fields - - def clean(data, path=None): - path = path or '' - - if isinstance(data, dict): - new_data = {} - for key, value in data.iteritems(): - new_path = '%s.%s' % (path, key) if path else key - if all_fields or new_path in self.__as_pymongo_fields: - new_data[key] = clean(value, path=new_path) - data = new_data - elif isinstance(data, list): - data = [clean(d, path=path) for d in data] - else: - if self._as_pymongo_coerce: - # If we need to coerce types, we need to determine the - # type of this field and use the corresponding - # .to_python(...) - from mongoengine.fields import EmbeddedDocumentField - obj = self._document - for chunk in path.split('.'): - obj = getattr(obj, chunk, None) - if obj is None: - break - elif isinstance(obj, EmbeddedDocumentField): - obj = obj.document_type - if obj and data is not None: - data = obj.to_python(data) - return data - return clean(row) - - def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be - substituted for the MongoDB name of the field (specified using the - :attr:`name` keyword argument in a field's constructor). 
- """ - def field_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return u'["%s"]' % fields[-1].db_field - - def field_path_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return ".".join([f.db_field for f in fields]) - - code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) - code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, - code) - return code - - # Deprecated - - def ensure_index(self, **kwargs): - """Deprecated use :func:`~Document.ensure_index`""" - msg = ("Doc.objects()._ensure_index() is deprecated. " - "Use Doc.ensure_index() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_index(**kwargs) - return self - - def _ensure_indexes(self): - """Deprecated use :func:`~Document.ensure_indexes`""" - msg = ("Doc.objects()._ensure_indexes() is deprecated. " - "Use Doc.ensure_indexes() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_indexes() diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 71f12e3..2ee7e38 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -1,5 +1,6 @@ from collections import defaultdict +import pymongo from bson import SON from mongoengine.common import _import_class @@ -9,10 +10,12 @@ __all__ = ('query', 'update') COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'not') + 'all', 'size', 'exists', 'not') GEO_OPERATORS = ('within_distance', 'within_spherical_distance', 'within_box', 'within_polygon', 'near', 'near_sphere', - 'max_distance') + 'max_distance', 'geo_within', 'geo_within_box', + 'geo_within_polygon', 'geo_within_center', + 'geo_within_sphere', 'geo_intersects') STRING_OPERATORS = ('contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact') @@ -21,7 +24,8 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS) UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', - 'push_all', 'pull', 'pull_all', 'add_to_set') + 'push_all', 'pull', 'pull_all', 'add_to_set', + 'set_on_insert') def query(_doc_cls=None, _field_operation=False, **query): @@ -39,11 +43,11 @@ def query(_doc_cls=None, _field_operation=False, **query): parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is op = None - if parts[-1] in MATCH_OPERATORS: + if len(parts) > 1 and parts[-1] in MATCH_OPERATORS: op = parts.pop() negate = False - if parts[-1] == 'not': + if len(parts) > 1 and parts[-1] == 'not': parts.pop() negate = True @@ -74,39 +78,24 @@ def query(_doc_cls=None, _field_operation=False, **query): if op in singular_ops: if isinstance(field, basestring): if (op in STRING_OPERATORS and - isinstance(value, basestring)): + isinstance(value, basestring)): StringField = _import_class('StringField') value = StringField.prepare_query_value(op, value) else: value = field else: value = field.prepare_query_value(op, value) - elif op in ('in', 'nin', 'all', 'near'): + elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): # 'in', 
'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] # if op and op not in COMPARISON_OPERATORS: if op: if op in GEO_OPERATORS: - if op == "within_distance": - value = {'$within': {'$center': value}} - elif op == "within_spherical_distance": - value = {'$within': {'$centerSphere': value}} - elif op == "within_polygon": - value = {'$within': {'$polygon': value}} - elif op == "near": - value = {'$near': value} - elif op == "near_sphere": - value = {'$nearSphere': value} - elif op == 'within_box': - value = {'$within': {'$box': value}} - elif op == "max_distance": - value = {'$maxDistance': value} - else: - raise NotImplementedError("Geo method '%s' has not " - "been implemented" % op) + value = _geo_operator(field, op, value) elif op in CUSTOM_OPERATORS: if op == 'match': + value = field.prepare_query_value(op, value) value = {"$elemMatch": value} else: NotImplementedError("Custom method '%s' has not " @@ -144,7 +133,7 @@ def query(_doc_cls=None, _field_operation=False, **query): merge_query[k].append(mongo_query[k]) del mongo_query[k] if isinstance(v, list): - value = [{k:val} for val in v] + value = [{k: val} for val in v] if '$and' in mongo_query.keys(): mongo_query['$and'].append(value) else: @@ -176,7 +165,9 @@ def update(_doc_cls=None, **update): if value > 0: value = -value elif op == 'add_to_set': - op = op.replace('_to_set', 'ToSet') + op = 'addToSet' + elif op == 'set_on_insert': + op = "setOnInsert" match = None if parts[-1] in COMPARISON_OPERATORS: @@ -191,6 +182,7 @@ def update(_doc_cls=None, **update): parts = [] cleaned_fields = [] + appended_sub_field = False for field in fields: append_field = True if isinstance(field, basestring): @@ -202,21 +194,30 @@ def update(_doc_cls=None, **update): else: parts.append(field.db_field) if append_field: + appended_sub_field = False cleaned_fields.append(field) + if hasattr(field, 'field'): + cleaned_fields.append(field.field) + appended_sub_field = True # Convert value to proper value - field = cleaned_fields[-1] + if appended_sub_field: + field = cleaned_fields[-2] + else: + field = cleaned_fields[-1] if op in (None, 'set', 'push', 'pull'): if field.required or value is not None: value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] - elif op == 'addToSet': + elif op in ('addToSet', 'setOnInsert'): if isinstance(value, (list, tuple, set)): value = [field.prepare_query_value(op, v) for v in value] elif field.required or value is not None: value = field.prepare_query_value(op, value) + elif op == "unset": + value = 1 if match: match = '$' + match @@ -230,11 +231,24 @@ def update(_doc_cls=None, **update): if 'pull' in op and '.' 
in key: # Dot operators don't work on pull operations - # it uses nested dict syntax + # unless they point to a list field + # Otherwise it uses nested dict syntax if op == 'pullAll': raise InvalidQueryError("pullAll operations only support " "a single field depth") + # Look for the last list field and use dot notation until there + field_classes = [c.__class__ for c in cleaned_fields] + field_classes.reverse() + ListField = _import_class('ListField') + if ListField in field_classes: + # Join all fields via dot notation to the last ListField + # Then process as normal + last_listField = len(cleaned_fields) - field_classes.index(ListField) + key = ".".join(parts[:last_listField]) + parts = parts[last_listField:] + parts.insert(0, key) + parts.reverse() for key in parts: value = {key: value} @@ -250,3 +264,76 @@ def update(_doc_cls=None, **update): mongo_update[key].update(value) return mongo_update + + +def _geo_operator(field, op, value): + """Helper to return the query for a given geo query""" + if field._geo_index == pymongo.GEO2D: + if op == "within_distance": + value = {'$within': {'$center': value}} + elif op == "within_spherical_distance": + value = {'$within': {'$centerSphere': value}} + elif op == "within_polygon": + value = {'$within': {'$polygon': value}} + elif op == "near": + value = {'$near': value} + elif op == "near_sphere": + value = {'$nearSphere': value} + elif op == 'within_box': + value = {'$within': {'$box': value}} + elif op == "max_distance": + value = {'$maxDistance': value} + else: + raise NotImplementedError("Geo method '%s' has not " + "been implemented for a GeoPointField" % op) + else: + if op == "geo_within": + value = {"$geoWithin": _infer_geometry(value)} + elif op == "geo_within_box": + value = {"$geoWithin": {"$box": value}} + elif op == "geo_within_polygon": + value = {"$geoWithin": {"$polygon": value}} + elif op == "geo_within_center": + value = {"$geoWithin": {"$center": value}} + elif op == "geo_within_sphere": + value = {"$geoWithin": {"$centerSphere": value}} + elif op == "geo_intersects": + value = {"$geoIntersects": _infer_geometry(value)} + elif op == "near": + value = {'$near': _infer_geometry(value)} + elif op == "max_distance": + value = {'$maxDistance': value} + else: + raise NotImplementedError("Geo method '%s' has not " + "been implemented for a %s " % (op, field._name)) + return value + + +def _infer_geometry(value): + """Helper method that tries to infer the $geometry shape for a given value""" + if isinstance(value, dict): + if "$geometry" in value: + return value + elif 'coordinates' in value and 'type' in value: + return {"$geometry": value} + raise InvalidQueryError("Invalid $geometry dictionary should have " + "type and coordinates keys") + elif isinstance(value, (list, set)): + try: + value[0][0][0] + return {"$geometry": {"type": "Polygon", "coordinates": value}} + except: + pass + try: + value[0][0] + return {"$geometry": {"type": "LineString", "coordinates": value}} + except: + pass + try: + value[0] + return {"$geometry": {"type": "Point", "coordinates": value}} + except: + pass + + raise InvalidQueryError("Invalid $geometry data. 
Can be either a dictionary " + "or (nested) lists of coordinate(s)") diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 95d11e8..41f4ebf 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -23,6 +23,10 @@ class QNodeVisitor(object): return query +class DuplicateQueryConditionsError(InvalidQueryError): + pass + + class SimplificationVisitor(QNodeVisitor): """Simplifies query trees by combinging unnecessary 'and' connection nodes into a single Q-object. @@ -33,7 +37,11 @@ class SimplificationVisitor(QNodeVisitor): # The simplification only applies to 'simple' queries if all(isinstance(node, Q) for node in combination.children): queries = [n.query for n in combination.children] - return Q(**self._query_conjunction(queries)) + try: + return Q(**self._query_conjunction(queries)) + except DuplicateQueryConditionsError: + # Cannot be simplified + pass return combination def _query_conjunction(self, queries): @@ -47,8 +55,7 @@ class SimplificationVisitor(QNodeVisitor): # to a single field intersection = ops.intersection(query_ops) if intersection: - msg = 'Duplicate query conditions: ' - raise InvalidQueryError(msg + ', '.join(intersection)) + raise DuplicateQueryConditionsError() query_ops.update(ops) combined_query.update(copy.deepcopy(query)) @@ -122,8 +129,7 @@ class QCombination(QNode): # If the child is a combination of the same type, we can merge its # children directly into this combinations children if isinstance(node, QCombination) and node.operation == operation: - # self.children += node.children - self.children.append(node) + self.children += node.children else: self.children.append(node) diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 52ef312..06fb8b4 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', - 'pre_delete', 'post_delete'] +__all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', + 'post_save', 'pre_delete', 'post_delete'] signals_available = False try: @@ -39,6 +39,7 @@ _signals = Namespace() pre_init = _signals.signal('pre_init') post_init = _signals.signal('post_init') pre_save = _signals.signal('pre_save') +pre_save_post_validation = _signals.signal('pre_save_post_validation') post_save = _signals.signal('post_save') pre_delete = _signals.signal('pre_delete') post_delete = _signals.signal('post_delete') diff --git a/python-mongoengine.spec b/python-mongoengine.spec index b1ec336..b9c45ef 100644 --- a/python-mongoengine.spec +++ b/python-mongoengine.spec @@ -5,7 +5,7 @@ %define srcname mongoengine Name: python-%{srcname} -Version: 0.7.9 +Version: 0.8.4 Release: 1%{?dist} Summary: A Python Document-Object Mapper for working with MongoDB @@ -51,4 +51,4 @@ rm -rf $RPM_BUILD_ROOT # %{python_sitearch}/* %changelog -* See: http://readthedocs.org/docs/mongoengine-odm/en/latest/changelog.html \ No newline at end of file +* See: http://docs.mongoengine.org/en/latest/changelog.html \ No newline at end of file diff --git a/setup.py b/setup.py index ba538fa..85707d0 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,6 @@ CLASSIFIERS = [ 'Operating System :: OS Independent', 'Programming Language :: Python', "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.5", "Programming Language :: Python :: 2.6", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", @@ -49,17 +48,15 @@ CLASSIFIERS = [ 'Topic :: Software 
Development :: Libraries :: Python Modules', ] -extra_opts = {} +extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker'] - extra_opts['packages'] = find_packages(exclude=('tests',)) + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1'] if "test" in sys.argv or "nosetests" in sys.argv: - extra_opts['packages'].append("tests") - extra_opts['package_data'] = {"tests": ["fields/mongoengine.png"]} + extra_opts['packages'] = find_packages() + extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django==1.4.2', 'PIL'] - extra_opts['packages'] = find_packages(exclude=('tests',)) + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6', 'python-dateutil'] setup(name='mongoengine', version=VERSION, @@ -75,7 +72,7 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo'], + install_requires=['pymongo>=2.5'], test_suite='nose.collector', **extra_opts ) diff --git a/tests/__init__.py b/tests/__init__.py index 152a8ce..b24df5d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,5 @@ from all_warnings import AllWarnings from document import * -from queryset import * \ No newline at end of file +from queryset import * +from fields import * +from migration import * diff --git a/tests/all_warnings/__init__.py b/tests/all_warnings/__init__.py index 74533de..53ce638 100644 --- a/tests/all_warnings/__init__.py +++ b/tests/all_warnings/__init__.py @@ -17,7 +17,7 @@ __all__ = ('AllWarnings', ) class AllWarnings(unittest.TestCase): def setUp(self): - conn = connect(db='mongoenginetest') + connect(db='mongoenginetest') self.warning_list = [] self.showwarning_default = warnings.showwarning warnings.showwarning = self.append_to_warning_list @@ -30,53 +30,6 @@ class AllWarnings(unittest.TestCase): # restore default handling of warnings warnings.showwarning = self.showwarning_default - def test_dbref_reference_field_future_warning(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person() - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.save(cascade=False) - - self.assertTrue(len(self.warning_list) > 0) - warning = self.warning_list[0] - self.assertEqual(FutureWarning, warning["category"]) - self.assertTrue("ReferenceFields will default to using ObjectId" - in str(warning["message"])) - - def test_document_save_cascade_future_warning(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person(name="Wilson Snr") - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.parent.name = "Poppa Wilson" - p2.save() - - self.assertTrue(len(self.warning_list) > 0) - if len(self.warning_list) > 1: - print self.warning_list - warning = self.warning_list[0] - self.assertEqual(FutureWarning, warning["category"]) - self.assertTrue("Cascading saves will default to off in 0.8" - in str(warning["message"])) - def test_document_collection_syntax_warning(self): class NonAbstractBase(Document): @@ -89,6 +42,3 @@ class AllWarnings(unittest.TestCase): self.assertEqual(SyntaxWarning, 
warning["category"]) self.assertEqual('non_abstract_base', InheritedDocumentFailTest._get_collection_name()) - -import sys -sys.path[0:0] = [""] diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 83e68ff..52e3794 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- -from __future__ import with_statement import sys sys.path[0:0] = [""] import unittest from mongoengine import * -from mongoengine.queryset import NULLIFY +from mongoengine.queryset import NULLIFY, PULL from mongoengine.connection import get_db __all__ = ("ClassMethodsTest", ) @@ -86,6 +85,172 @@ class ClassMethodsTest(unittest.TestCase): self.assertEqual(self.Person._meta['delete_rules'], {(Job, 'employee'): NULLIFY}) + def test_compare_indexes(self): + """ Ensure that the indexes are properly created and that + compare_indexes identifies the missing/extra indexes + """ + + class BlogPost(Document): + author = StringField() + title = StringField() + description = StringField() + tags = StringField() + + meta = { + 'indexes': [('author', 'title')] + } + + BlogPost.drop_collection() + + BlogPost.ensure_indexes() + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + + BlogPost.ensure_index(['author', 'description']) + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('author', 1), ('description', 1)]] }) + + BlogPost._get_collection().drop_index('author_1_description_1') + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + + BlogPost._get_collection().drop_index('author_1_title_1') + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('author', 1), ('title', 1)]], 'extra': [] }) + + def test_compare_indexes_inheritance(self): + """ Ensure that the indexes are properly created and that + compare_indexes identifies the missing/extra indexes for subclassed + documents (_cls included) + """ + + class BlogPost(Document): + author = StringField() + title = StringField() + description = StringField() + + meta = { + 'allow_inheritance': True + } + + class BlogPostWithTags(BlogPost): + tags = StringField() + tag_list = ListField(StringField()) + + meta = { + 'indexes': [('author', 'tags')] + } + + BlogPost.drop_collection() + + BlogPost.ensure_indexes() + BlogPostWithTags.ensure_indexes() + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + + BlogPostWithTags.ensure_index(['author', 'tag_list']) + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]] }) + + BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tag_list_1') + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + + BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tags_1') + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': [] }) + + def test_compare_indexes_multiple_subclasses(self): + """ Ensure that compare_indexes behaves correctly if called from a + class, which base class has multiple subclasses + """ + + class BlogPost(Document): + author = StringField() + title = StringField() + description = StringField() + + meta = { + 'allow_inheritance': True + } + + class BlogPostWithTags(BlogPost): + tags = StringField() + tag_list = ListField(StringField()) + + meta = { + 'indexes': [('author', 'tags')] + } + + class BlogPostWithCustomField(BlogPost): + custom = DictField() + + meta = { + 
'indexes': [('author', 'custom')] + } + + BlogPost.ensure_indexes() + BlogPostWithTags.ensure_indexes() + BlogPostWithCustomField.ensure_indexes() + + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPostWithTags.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPostWithCustomField.compare_indexes(), { 'missing': [], 'extra': [] }) + + def test_list_indexes_inheritance(self): + """ ensure that all of the indexes are listed regardless of the super- + or sub-class that we call it from + """ + + class BlogPost(Document): + author = StringField() + title = StringField() + description = StringField() + + meta = { + 'allow_inheritance': True + } + + class BlogPostWithTags(BlogPost): + tags = StringField() + + meta = { + 'indexes': [('author', 'tags')] + } + + class BlogPostWithTagsAndExtraText(BlogPostWithTags): + extra_text = StringField() + + meta = { + 'indexes': [('author', 'tags', 'extra_text')] + } + + BlogPost.drop_collection() + + BlogPost.ensure_indexes() + BlogPostWithTags.ensure_indexes() + BlogPostWithTagsAndExtraText.ensure_indexes() + + self.assertEqual(BlogPost.list_indexes(), + BlogPostWithTags.list_indexes()) + self.assertEqual(BlogPost.list_indexes(), + BlogPostWithTagsAndExtraText.list_indexes()) + self.assertEqual(BlogPost.list_indexes(), + [[('_cls', 1), ('author', 1), ('tags', 1)], + [('_cls', 1), ('author', 1), ('tags', 1), ('extra_text', 1)], + [(u'_id', 1)], [('_cls', 1)]]) + + def test_register_delete_rule_inherited(self): + + class Vaccine(Document): + name = StringField(required=True) + + meta = {"indexes": ["name"]} + + class Animal(Document): + family = StringField(required=True) + vaccine_made = ListField(ReferenceField("Vaccine", reverse_delete_rule=PULL)) + + meta = {"allow_inheritance": True, "indexes": ["family"]} + + class Cat(Animal): + name = StringField(required=True) + + self.assertEqual(Vaccine._meta['delete_rules'][(Animal, 'vaccine_made')], PULL) + self.assertEqual(Vaccine._meta['delete_rules'][(Cat, 'vaccine_made')], PULL) + def test_collection_naming(self): """Ensure that a collection with a specified name may be used. 
""" diff --git a/tests/document/delta.py b/tests/document/delta.py index c6191d9..b4749f3 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -3,6 +3,7 @@ import sys sys.path[0:0] = [""] import unittest +from bson import SON from mongoengine import * from mongoengine.connection import get_db @@ -129,14 +130,14 @@ class DeltaTest(unittest.TestCase): } self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) self.assertEqual(doc._delta(), - ({'embedded_field': embedded_delta}, {})) + ({'embedded_field': embedded_delta}, {})) doc.save() doc = doc.reload(10) doc.embedded_field.dict_field = {} self.assertEqual(doc._get_changed_fields(), - ['embedded_field.dict_field']) + ['embedded_field.dict_field']) self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) doc.save() @@ -145,7 +146,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = [] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + ['embedded_field.list_field']) self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) doc.save() @@ -160,7 +161,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = ['1', 2, embedded_2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + ['embedded_field.list_field']) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { @@ -192,11 +193,11 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'world' self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field.2.string_field']) + ['embedded_field.list_field.2.string_field']) self.assertEqual(doc.embedded_field._delta(), - ({'list_field.2.string_field': 'world'}, {})) + ({'list_field.2.string_field': 'world'}, {})) self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.string_field': 'world'}, {})) + ({'embedded_field.list_field.2.string_field': 'world'}, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, @@ -206,7 +207,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + ['embedded_field.list_field']) self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { '_cls': 'Embedded', @@ -225,40 +226,40 @@ class DeltaTest(unittest.TestCase): doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, - 'hello world') + 'hello world') # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': - [2, {'hello': 'world'}]}, {})) + ({'embedded_field.list_field.2.list_field': + [2, {'hello': 'world'}]}, {})) doc.save() doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': - [2, {'hello': 'world'}, 1]}, {})) + ({'embedded_field.list_field.2.list_field': + [2, {'hello': 'world'}, 1]}, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].list_field, - [2, {'hello': 'world'}, 1]) + [2, {'hello': 'world'}, 1]) doc.embedded_field.list_field[2].list_field.sort(key=str) doc.save() doc = doc.reload(10) 
self.assertEqual(doc.embedded_field.list_field[2].list_field, - [1, 2, {'hello': 'world'}]) + [1, 2, {'hello': 'world'}]) del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) doc.save() doc = doc.reload(10) del(doc.embedded_field.list_field[2].list_field) self.assertEqual(doc._delta(), - ({}, {'embedded_field.list_field.2.list_field': 1})) + ({}, {'embedded_field.list_field.2.list_field': 1})) doc.save() doc = doc.reload(10) @@ -269,9 +270,9 @@ class DeltaTest(unittest.TestCase): doc.dict_field['Embedded'].string_field = 'Hello World' self.assertEqual(doc._get_changed_fields(), - ['dict_field.Embedded.string_field']) + ['dict_field.Embedded.string_field']) self.assertEqual(doc._delta(), - ({'dict_field.Embedded.string_field': 'Hello World'}, {})) + ({'dict_field.Embedded.string_field': 'Hello World'}, {})) def test_circular_reference_deltas(self): self.circular_reference_deltas(Document, Document) @@ -289,10 +290,11 @@ class DeltaTest(unittest.TestCase): name = StringField() owner = ReferenceField('Person') - person = Person(name="owner") - person.save() - organization = Organization(name="company") - organization.save() + Person.drop_collection() + Organization.drop_collection() + + person = Person(name="owner").save() + organization = Organization(name="company").save() person.owns.append(organization) organization.owner = person @@ -311,29 +313,24 @@ class DeltaTest(unittest.TestCase): self.circular_reference_deltas_2(DynamicDocument, Document) self.circular_reference_deltas_2(DynamicDocument, DynamicDocument) - def circular_reference_deltas_2(self, DocClass1, DocClass2): + def circular_reference_deltas_2(self, DocClass1, DocClass2, dbref=True): class Person(DocClass1): name = StringField() - owns = ListField(ReferenceField('Organization')) - employer = ReferenceField('Organization') + owns = ListField(ReferenceField('Organization', dbref=dbref)) + employer = ReferenceField('Organization', dbref=dbref) class Organization(DocClass2): name = StringField() - owner = ReferenceField('Person') - employees = ListField(ReferenceField('Person')) + owner = ReferenceField('Person', dbref=dbref) + employees = ListField(ReferenceField('Person', dbref=dbref)) Person.drop_collection() Organization.drop_collection() - person = Person(name="owner") - person.save() - - employee = Person(name="employee") - employee.save() - - organization = Organization(name="company") - organization.save() + person = Person(name="owner").save() + employee = Person(name="employee").save() + organization = Organization(name="company").save() person.owns.append(organization) organization.owner = person @@ -353,6 +350,8 @@ class DeltaTest(unittest.TestCase): self.assertEqual(o.owner, p) self.assertEqual(e.employer, o) + return person, organization, employee + def test_delta_db_field(self): self.delta_db_field(Document) self.delta_db_field(DynamicDocument) @@ -612,13 +611,13 @@ class DeltaTest(unittest.TestCase): Person.drop_collection() p = Person(name="James", age=34) - self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', - '_cls': 'Person'}, {})) + self.assertEqual(p._delta(), ( + SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {})) p.doc = 123 del(p.doc) - self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', - '_cls': 'Person'}, {'doc': 1})) + self.assertEqual(p._delta(), ( + SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), 
{})) p = Person() p.name = "Dean" @@ -630,14 +629,14 @@ class DeltaTest(unittest.TestCase): self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._delta(), ({'age': 24}, {})) - p = self.Person.objects(age=22).get() + p = Person.objects(age=22).get() p.age = 24 self.assertEqual(p.age, 24) self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._delta(), ({'age': 24}, {})) p.save() - self.assertEqual(1, self.Person.objects(age=24).count()) + self.assertEqual(1, Person.objects(age=24).count()) def test_dynamic_delta(self): @@ -684,6 +683,36 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + def test_delta_with_dbref_true(self): + person, organization, employee = self.circular_reference_deltas_2(Document, Document, True) + employee.name = 'test' + + self.assertEqual(organization._get_changed_fields(), []) + + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertEqual({}, updates) + + organization.employees.append(person) + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertTrue('employees' in updates) + + def test_delta_with_dbref_false(self): + person, organization, employee = self.circular_reference_deltas_2(Document, Document, False) + employee.name = 'test' + + self.assertEqual(organization._get_changed_fields(), []) + + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertEqual({}, updates) + + organization.employees.append(person) + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertTrue('employees' in updates) + if __name__ == '__main__': unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 5881cd0..6263e68 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -31,8 +31,9 @@ class DynamicTest(unittest.TestCase): self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", "age": 34}) - + self.assertEqual(p.to_mongo().keys(), ["_cls", "name", "age"]) p.save() + self.assertEqual(p.to_mongo().keys(), ["_id", "_cls", "name", "age"]) self.assertEqual(self.Person.objects.first().age, 34) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index ff08ef1..ccf8463 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -from __future__ import with_statement import unittest import sys sys.path[0:0] = [""] @@ -157,6 +156,25 @@ class IndexesTest(unittest.TestCase): self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}], A._meta['index_specs']) + def test_index_no_cls(self): + """Ensure index specs are inhertited correctly""" + + class A(Document): + title = StringField() + meta = { + 'indexes': [ + {'fields': ('title',), 'cls': False}, + ], + 'allow_inheritance': True, + 'index_cls': False + } + + self.assertEqual([('title', 1)], A._meta['index_specs'][0]['fields']) + A._get_collection().drop_indexes() + A.ensure_indexes() + info = A._get_collection().index_information() + self.assertEqual(len(info.keys()), 2) + def test_build_index_spec_is_not_destructive(self): class MyDoc(Document): @@ -217,7 +235,7 @@ class IndexesTest(unittest.TestCase): } self.assertEqual([{'fields': [('location.point', '2d')]}], - Place._meta['index_specs']) + Place._meta['index_specs']) Place.ensure_indexes() info = Place._get_collection().index_information() @@ -231,8 +249,7 @@ class IndexesTest(unittest.TestCase): 
location = DictField() class Place(Document): - current = DictField( - field=EmbeddedDocumentField('EmbeddedLocation')) + current = DictField(field=EmbeddedDocumentField('EmbeddedLocation')) meta = { 'allow_inheritance': True, 'indexes': [ @@ -241,7 +258,7 @@ class IndexesTest(unittest.TestCase): } self.assertEqual([{'fields': [('current.location.point', '2d')]}], - Place._meta['index_specs']) + Place._meta['index_specs']) Place.ensure_indexes() info = Place._get_collection().index_information() @@ -264,7 +281,7 @@ class IndexesTest(unittest.TestCase): self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, 'sparse': True}], - BlogPost._meta['index_specs']) + BlogPost._meta['index_specs']) BlogPost.drop_collection() @@ -314,19 +331,27 @@ class IndexesTest(unittest.TestCase): """ class User(Document): meta = { + 'allow_inheritance': True, 'indexes': ['user_guid'], 'auto_create_index': False } user_guid = StringField(required=True) + class MongoUser(User): + pass + User.drop_collection() - u = User(user_guid='123') - u.save() + User(user_guid='123').save() + MongoUser(user_guid='123').save() - self.assertEqual(1, User.objects.count()) + self.assertEqual(2, User.objects.count()) info = User.objects._collection.index_information() self.assertEqual(info.keys(), ['_id_']) + + User.ensure_indexes() + info = User.objects._collection.index_information() + self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_']) User.drop_collection() def test_embedded_document_index(self): @@ -374,8 +399,7 @@ class IndexesTest(unittest.TestCase): self.assertEqual(sorted(info.keys()), ['_id_', 'tags.tag_1']) post1 = BlogPost(title="Embedded Indexes tests in place", - tags=[Tag(name="about"), Tag(name="time")] - ) + tags=[Tag(name="about"), Tag(name="time")]) post1.save() BlogPost.drop_collection() @@ -392,29 +416,6 @@ class IndexesTest(unittest.TestCase): info = RecursiveDocument._get_collection().index_information() self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_']) - def test_geo_indexes_recursion(self): - - class Location(Document): - name = StringField() - location = GeoPointField() - - class Parent(Document): - name = StringField() - location = ReferenceField(Location, dbref=False) - - Location.drop_collection() - Parent.drop_collection() - - list(Parent.objects) - - collection = Parent._get_collection() - info = collection.index_information() - - self.assertFalse('location_2d' in info) - - self.assertEqual(len(Parent._geo_indices()), 0) - self.assertEqual(len(Location._geo_indices()), 1) - def test_covered_index(self): """Ensure that covered indexes can be used """ @@ -425,7 +426,7 @@ class IndexesTest(unittest.TestCase): meta = { 'indexes': ['a'], 'allow_inheritance': False - } + } Test.drop_collection() @@ -625,7 +626,7 @@ class IndexesTest(unittest.TestCase): list(Log.objects) info = Log.objects._collection.index_information() self.assertEqual(3600, - info['created_1']['expireAfterSeconds']) + info['created_1']['expireAfterSeconds']) def test_unique_and_indexes(self): """Ensure that 'unique' constraints aren't overridden by diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 3b550f1..5a48f75 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -143,7 +143,7 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(Animal._superclasses, ()) self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish', - 'Animal.Fish.Pike')) + 'Animal.Fish.Pike')) self.assertEqual(Fish._superclasses, ('Animal', )) 
self.assertEqual(Fish._subclasses, ('Animal.Fish', 'Animal.Fish.Pike')) @@ -168,6 +168,61 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) + def test_inheritance_to_mongo_keys(self): + """Ensure that document may inherit fields from a superclass document. + """ + class Person(Document): + name = StringField() + age = IntField() + + meta = {'allow_inheritance': True} + + class Employee(Person): + salary = IntField() + + self.assertEqual(['age', 'id', 'name', 'salary'], + sorted(Employee._fields.keys())) + self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), + ['_cls', 'name', 'age']) + self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ['_cls', 'name', 'age', 'salary']) + self.assertEqual(Employee._get_collection_name(), + Person._get_collection_name()) + + def test_indexes_and_multiple_inheritance(self): + """ Ensure that all of the indexes are created for a document with + multiple inheritance. + """ + + class A(Document): + a = StringField() + + meta = { + 'allow_inheritance': True, + 'indexes': ['a'] + } + + class B(Document): + b = StringField() + + meta = { + 'allow_inheritance': True, + 'indexes': ['b'] + } + + class C(A, B): + pass + + A.drop_collection() + B.drop_collection() + C.drop_collection() + + C.ensure_indexes() + + self.assertEqual( + sorted([idx['key'] for idx in C._get_collection().index_information().values()]), + sorted([[(u'_cls', 1), (u'b', 1)], [(u'_id', 1)], [(u'_cls', 1), (u'a', 1)]]) + ) def test_polymorphic_queries(self): """Ensure that the correct subclasses are returned from a query @@ -197,7 +252,6 @@ class InheritanceTest(unittest.TestCase): classes = [obj.__class__ for obj in Human.objects] self.assertEqual(classes, [Human]) - def test_allow_inheritance(self): """Ensure that inheritance may be disabled on simple classes and that _cls and _subclasses will not be used. 
@@ -213,8 +267,8 @@ class InheritanceTest(unittest.TestCase): self.assertRaises(ValueError, create_dog_class) # Check that _cls etc aren't present on simple documents - dog = Animal(name='dog') - dog.save() + dog = Animal(name='dog').save() + self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) collection = self.db[Animal._get_collection_name()] obj = collection.find_one() diff --git a/tests/document/instance.py b/tests/document/instance.py index 07991c1..a61c439 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -from __future__ import with_statement import sys sys.path[0:0] = [""] @@ -10,7 +9,9 @@ import unittest import uuid from datetime import datetime -from tests.fixtures import PickleEmbedded, PickleTest +from bson import DBRef +from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, + PickleDyanmicEmbedded, PickleDynamicTest) from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, @@ -65,11 +66,11 @@ class InstanceTest(unittest.TestCase): for _ in range(10): Log().save() - self.assertEqual(len(Log.objects), 10) + self.assertEqual(Log.objects.count(), 10) # Check that extra documents don't increase the size Log().save() - self.assertEqual(len(Log.objects), 10) + self.assertEqual(Log.objects.count(), 10) options = Log.objects._collection.options() self.assertEqual(options['capped'], True) @@ -320,8 +321,8 @@ class InstanceTest(unittest.TestCase): Location.drop_collection() - self.assertEquals(Area, get_document("Area")) - self.assertEquals(Area, get_document("Location.Area")) + self.assertEqual(Area, get_document("Area")) + self.assertEqual(Area, get_document("Location.Area")) def test_creation(self): """Ensure that document may be created using keyword arguments. @@ -428,6 +429,28 @@ class InstanceTest(unittest.TestCase): self.assertFalse('age' in person) self.assertFalse('nationality' in person) + def test_embedded_document_to_mongo(self): + class Person(EmbeddedDocument): + name = StringField() + age = IntField() + + meta = {"allow_inheritance": True} + + class Employee(Person): + salary = IntField() + + self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), + ['_cls', 'name', 'age']) + self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ['_cls', 'name', 'age', 'salary']) + + def test_embedded_document_to_mongo_id(self): + class SubDoc(EmbeddedDocument): + id = StringField(required=True) + + sub_doc = SubDoc(id="abc") + self.assertEqual(sub_doc.to_mongo().keys(), ['id']) + def test_embedded_document(self): """Ensure that embedded documents are set up correctly. 
""" @@ -494,12 +517,12 @@ class InstanceTest(unittest.TestCase): t = TestDocument(status="published") t.save(clean=False) - self.assertEquals(t.pub_date, None) + self.assertEqual(t.pub_date, None) t = TestDocument(status="published") t.save(clean=True) - self.assertEquals(type(t.pub_date), datetime) + self.assertEqual(type(t.pub_date), datetime) def test_document_embedded_clean(self): class TestEmbeddedDocument(EmbeddedDocument): @@ -531,7 +554,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() - self.assertEquals(t.doc.z, 35) + self.assertEqual(t.doc.z, 35) # Asserts not raises t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) @@ -650,7 +673,7 @@ class InstanceTest(unittest.TestCase): p = Person.objects(name="Wilson Jr").get() p.parent.name = "Daddy Wilson" - p.save() + p.save(cascade=True) p1.reload() self.assertEqual(p1.name, p.parent.name) @@ -669,16 +692,14 @@ class InstanceTest(unittest.TestCase): p2 = Person(name="Wilson Jr") p2.parent = p1 + p1.name = "Daddy Wilson" p2.save(force_insert=True, cascade_kwargs={"force_insert": False}) - p = Person.objects(name="Wilson Jr").get() - p.parent.name = "Daddy Wilson" - p.save() - p1.reload() - self.assertEqual(p1.name, p.parent.name) + p2.reload() + self.assertEqual(p1.name, p2.parent.name) - def test_save_cascade_meta(self): + def test_save_cascade_meta_false(self): class Person(Document): name = StringField() @@ -707,6 +728,31 @@ class InstanceTest(unittest.TestCase): p1.reload() self.assertEqual(p1.name, p.parent.name) + def test_save_cascade_meta_true(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + meta = {'cascade': False} + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save(cascade=True) + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertNotEqual(p1.name, p.parent.name) + def test_save_cascades_generically(self): class Person(Document): @@ -726,6 +772,10 @@ class InstanceTest(unittest.TestCase): p.parent.name = "Daddy Wilson" p.save() + p1.reload() + self.assertNotEqual(p1.name, p.parent.name) + + p.save(cascade=True) p1.reload() self.assertEqual(p1.name, p.parent.name) @@ -813,6 +863,14 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, None) self.assertEqual(person.age, None) + def test_inserts_if_you_set_the_pk(self): + p1 = self.Person(name='p1', id=bson.ObjectId()).save() + p2 = self.Person(name='p2') + p2.id = bson.ObjectId() + p2.save() + + self.assertEqual(2, self.Person.objects.count()) + def test_can_save_if_not_included(self): class EmbeddedDoc(EmbeddedDocument): @@ -971,6 +1029,99 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.age, 21) self.assertEqual(person.active, False) + def test_query_count_when_saving(self): + """Ensure references don't cause extra fetches when saving""" + class Organization(Document): + name = StringField() + + class User(Document): + name = StringField() + orgs = ListField(ReferenceField('Organization')) + + class Feed(Document): + name = StringField() + + class UserSubscription(Document): + name = StringField() + user = ReferenceField(User) + feed = ReferenceField(Feed) + + Organization.drop_collection() + User.drop_collection() + Feed.drop_collection() + UserSubscription.drop_collection() + + o1 = 
Organization(name="o1").save() + o2 = Organization(name="o2").save() + + u1 = User(name="Ross", orgs=[o1, o2]).save() + f1 = Feed(name="MongoEngine").save() + + sub = UserSubscription(user=u1, feed=f1).save() + + user = User.objects.first() + # Even if stored as ObjectId's internally mongoengine uses DBRefs + # As ObjectId's aren't automatically derefenced + self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) + self.assertTrue(isinstance(user.orgs[0], Organization)) + self.assertTrue(isinstance(user._data['orgs'][0], Organization)) + + # Changing a value + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription.objects.first() + self.assertEqual(q, 1) + sub.name = "Test Sub" + sub.save() + self.assertEqual(q, 2) + + # Changing a value that will cascade + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription.objects.first() + self.assertEqual(q, 1) + sub.user.name = "Test" + self.assertEqual(q, 2) + sub.save(cascade=True) + self.assertEqual(q, 3) + + # Changing a value and one that will cascade + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription.objects.first() + sub.name = "Test Sub 2" + self.assertEqual(q, 1) + sub.user.name = "Test 2" + self.assertEqual(q, 2) + sub.save(cascade=True) + self.assertEqual(q, 4) # One for the UserSub and one for the User + + # Saving with just the refs + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription(user=u1.pk, feed=f1.pk) + self.assertEqual(q, 0) + sub.save() + self.assertEqual(q, 1) + + # Saving with just the refs on a ListField + with query_counter() as q: + self.assertEqual(q, 0) + User(name="Bob", orgs=[o1.pk, o2.pk]).save() + self.assertEqual(q, 1) + + # Saving new objects + with query_counter() as q: + self.assertEqual(q, 0) + user = User.objects.first() + self.assertEqual(q, 1) + feed = Feed.objects.first() + self.assertEqual(q, 2) + sub = UserSubscription(user=user, feed=feed) + self.assertEqual(q, 2) # Check no change + sub.save() + self.assertEqual(q, 3) + def test_set_unset_one_operation(self): """Ensure that $set and $unset actions are performed in the same operation. @@ -1040,9 +1191,9 @@ class InstanceTest(unittest.TestCase): """ person = self.Person(name="Test User", age=30) person.save() - self.assertEqual(len(self.Person.objects), 1) + self.assertEqual(self.Person.objects.count(), 1) person.delete() - self.assertEqual(len(self.Person.objects), 0) + self.assertEqual(self.Person.objects.count(), 0) def test_save_custom_id(self): """Ensure that a document may be saved with a custom _id. 
@@ -1356,12 +1507,12 @@ class InstanceTest(unittest.TestCase): post.save() reviewer.delete() - self.assertEqual(len(BlogPost.objects), 1) # No effect on the BlogPost + self.assertEqual(BlogPost.objects.count(), 1) # No effect on the BlogPost self.assertEqual(BlogPost.objects.get().reviewer, None) # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() - self.assertEqual(len(BlogPost.objects), 0) + self.assertEqual(BlogPost.objects.count(), 0) def test_reverse_delete_rule_with_document_inheritance(self): """Ensure that a referenced document is also deleted upon deletion @@ -1391,12 +1542,12 @@ class InstanceTest(unittest.TestCase): post.save() reviewer.delete() - self.assertEqual(len(BlogPost.objects), 1) + self.assertEqual(BlogPost.objects.count(), 1) self.assertEqual(BlogPost.objects.get().reviewer, None) # Delete the Writer should lead to deletion of the BlogPost author.delete() - self.assertEqual(len(BlogPost.objects), 0) + self.assertEqual(BlogPost.objects.count(), 0) def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): """Ensure that a referenced document is also deleted upon deletion for @@ -1425,12 +1576,12 @@ class InstanceTest(unittest.TestCase): # Deleting the reviewer should have no effect on the BlogPost reviewer.delete() - self.assertEqual(len(BlogPost.objects), 1) + self.assertEqual(BlogPost.objects.count(), 1) self.assertEqual(BlogPost.objects.get().reviewers, []) # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() - self.assertEqual(len(BlogPost.objects), 0) + self.assertEqual(BlogPost.objects.count(), 0) def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): ''' ensure the pre_delete signal is triggered upon a cascading deletion @@ -1498,7 +1649,7 @@ class InstanceTest(unittest.TestCase): f.delete() - self.assertEqual(len(Bar.objects), 1) # No effect on the BlogPost + self.assertEqual(Bar.objects.count(), 1) # No effect on the BlogPost self.assertEqual(Bar.objects.get().foo, None) def test_invalid_reverse_delete_rules_raise_errors(self): @@ -1549,7 +1700,7 @@ class InstanceTest(unittest.TestCase): # Delete the Person, which should lead to deletion of the BlogPost, and, # recursively to the Comment, too author.delete() - self.assertEqual(len(Comment.objects), 0) + self.assertEqual(Comment.objects.count(), 0) self.Person.drop_collection() BlogPost.drop_collection() @@ -1576,16 +1727,16 @@ class InstanceTest(unittest.TestCase): # Delete the Person should be denied self.assertRaises(OperationError, author.delete) # Should raise denied error - self.assertEqual(len(BlogPost.objects), 1) # No objects may have been deleted - self.assertEqual(len(self.Person.objects), 1) + self.assertEqual(BlogPost.objects.count(), 1) # No objects may have been deleted + self.assertEqual(self.Person.objects.count(), 1) # Other users, that don't have BlogPosts must be removable, like normal author = self.Person(name='Another User') author.save() - self.assertEqual(len(self.Person.objects), 2) + self.assertEqual(self.Person.objects.count(), 2) author.delete() - self.assertEqual(len(self.Person.objects), 1) + self.assertEqual(self.Person.objects.count(), 1) self.Person.drop_collection() BlogPost.drop_collection() @@ -1662,6 +1813,7 @@ class InstanceTest(unittest.TestCase): pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleEmbedded() + pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved pickle_doc.save() pickled_doc = 
pickle.dumps(pickle_doc) @@ -1669,11 +1821,48 @@ class InstanceTest(unittest.TestCase): self.assertEqual(resurrected, pickle_doc) + # Test pickling changed data + pickle_doc.lists.append("3") + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + self.assertEqual(resurrected, pickle_doc) resurrected.string = "Two" resurrected.save() - pickle_doc = pickle_doc.reload() + pickle_doc = PickleTest.objects.first() self.assertEqual(resurrected, pickle_doc) + self.assertEqual(pickle_doc.string, "Two") + self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) + + def test_dynamic_document_pickle(self): + + pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2']) + pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") + pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved + + pickle_doc.save() + + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + self.assertEqual(resurrected, pickle_doc) + self.assertEqual(resurrected._fields_ordered, + pickle_doc._fields_ordered) + self.assertEqual(resurrected._dynamic_fields.keys(), + pickle_doc._dynamic_fields.keys()) + + self.assertEqual(resurrected.embedded, pickle_doc.embedded) + self.assertEqual(resurrected.embedded._fields_ordered, + pickle_doc.embedded._fields_ordered) + self.assertEqual(resurrected.embedded._dynamic_fields.keys(), + pickle_doc.embedded._dynamic_fields.keys()) + + def test_picklable_on_signals(self): + pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2']) + pickle_doc.embedded = PickleEmbedded() + pickle_doc.save() + pickle_doc.delete() def test_throw_invalid_document_error(self): @@ -1848,15 +2037,17 @@ class InstanceTest(unittest.TestCase): A.objects.all() - self.assertEquals('testdb-2', B._meta.get('db_alias')) - self.assertEquals('mongoenginetest', - A._get_collection().database.name) - self.assertEquals('mongoenginetest2', - B._get_collection().database.name) + self.assertEqual('testdb-2', B._meta.get('db_alias')) + self.assertEqual('mongoenginetest', + A._get_collection().database.name) + self.assertEqual('mongoenginetest2', + B._get_collection().database.name) def test_db_alias_propagates(self): """db_alias propagates? """ + register_connection('testdb-1', 'mongoenginetest2') + class A(Document): name = StringField() meta = {"db_alias": "testdb-1", "allow_inheritance": True} @@ -2129,6 +2320,16 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 42) + def test_mixed_creation_dynamic(self): + """Ensure that document may be created using mixed arguments. 
+ """ + class Person(DynamicDocument): + name = StringField() + + person = Person("Test User", age=42) + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 42) + def test_bad_mixed_creation(self): """Ensure that document gives correct error when duplicating arguments """ diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index dbc09d8..2b5d9a0 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -31,6 +31,10 @@ class TestJson(unittest.TestCase): doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) + doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) + expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}""" + self.assertEqual(doc_json, expected_json) + self.assertEqual(doc, Doc.from_json(doc.to_json())) def test_json_complex(self): diff --git a/tests/fields/__init__.py b/tests/fields/__init__.py index 0731838..be70aaa 100644 --- a/tests/fields/__init__.py +++ b/tests/fields/__init__.py @@ -1,2 +1,3 @@ from fields import * -from file_tests import * \ No newline at end of file +from file_tests import * +from geo import * \ No newline at end of file diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 9a7b82f..8791781 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -from __future__ import with_statement import sys sys.path[0:0] = [""] @@ -7,6 +6,11 @@ import datetime import unittest import uuid +try: + import dateutil +except ImportError: + dateutil = None + from decimal import Decimal from bson import Binary, DBRef, ObjectId @@ -30,20 +34,137 @@ class FieldTest(unittest.TestCase): self.db.drop_collection('fs.files') self.db.drop_collection('fs.chunks') - def test_default_values(self): + def test_default_values_nothing_set(self): """Ensure that default field values are used when creating a document. 
""" class Person(Document): name = StringField() - age = IntField(default=30, help_text="Your real age") - userid = StringField(default=lambda: 'test', verbose_name="User Identity") + age = IntField(default=30, required=False) + userid = StringField(default=lambda: 'test', required=True) + created = DateTimeField(default=datetime.datetime.utcnow) - person = Person(name='Test Person') - self.assertEqual(person._data['age'], 30) - self.assertEqual(person._data['userid'], 'test') - self.assertEqual(person._fields['name'].help_text, None) - self.assertEqual(person._fields['age'].help_text, "Your real age") - self.assertEqual(person._fields['userid'].verbose_name, "User Identity") + person = Person(name="Ross") + + # Confirm saving now would store values + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) + + self.assertTrue(person.validate() is None) + + self.assertEqual(person.name, person.name) + self.assertEqual(person.age, person.age) + self.assertEqual(person.userid, person.userid) + self.assertEqual(person.created, person.created) + + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) + + # Confirm introspection changes nothing + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) + + def test_default_values_set_to_None(self): + """Ensure that default field values are used when creating a document. + """ + class Person(Document): + name = StringField() + age = IntField(default=30, required=False) + userid = StringField(default=lambda: 'test', required=True) + created = DateTimeField(default=datetime.datetime.utcnow) + + # Trying setting values to None + person = Person(name=None, age=None, userid=None, created=None) + + # Confirm saving now would store values + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + + self.assertTrue(person.validate() is None) + + self.assertEqual(person.name, person.name) + self.assertEqual(person.age, person.age) + self.assertEqual(person.userid, person.userid) + self.assertEqual(person.created, person.created) + + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) + + # Confirm introspection changes nothing + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + + def test_default_values_when_setting_to_None(self): + """Ensure that default field values are used when creating a document. 
+ """ + class Person(Document): + name = StringField() + age = IntField(default=30, required=False) + userid = StringField(default=lambda: 'test', required=True) + created = DateTimeField(default=datetime.datetime.utcnow) + + person = Person() + person.name = None + person.age = None + person.userid = None + person.created = None + + # Confirm saving now would store values + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + + self.assertTrue(person.validate() is None) + + self.assertEqual(person.name, person.name) + self.assertEqual(person.age, person.age) + self.assertEqual(person.userid, person.userid) + self.assertEqual(person.created, person.created) + + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) + + # Confirm introspection changes nothing + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + + def test_default_values_when_deleting_value(self): + """Ensure that default field values are used when creating a document. + """ + class Person(Document): + name = StringField() + age = IntField(default=30, required=False) + userid = StringField(default=lambda: 'test', required=True) + created = DateTimeField(default=datetime.datetime.utcnow) + + person = Person(name="Ross") + del person.name + del person.age + del person.userid + del person.created + + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + + self.assertTrue(person.validate() is None) + + self.assertEqual(person.name, person.name) + self.assertEqual(person.age, person.age) + self.assertEqual(person.userid, person.userid) + self.assertEqual(person.created, person.created) + + self.assertEqual(person._data['name'], person.name) + self.assertEqual(person._data['age'], person.age) + self.assertEqual(person._data['userid'], person.userid) + self.assertEqual(person._data['created'], person.created) + + # Confirm introspection changes nothing + data_to_be_saved = sorted(person.to_mongo().keys()) + self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) def test_required_values(self): """Ensure that required field constraints are enforced. 
@@ -272,10 +393,8 @@ class FieldTest(unittest.TestCase): Person.drop_collection() - person = Person() - person.height = Decimal('1.89') - person.save() - person.reload() + Person(height=Decimal('1.89')).save() + person = Person.objects.first() self.assertEqual(person.height, Decimal('1.89')) person.height = '2.0' @@ -289,6 +408,45 @@ class FieldTest(unittest.TestCase): Person.drop_collection() + def test_decimal_comparison(self): + + class Person(Document): + money = DecimalField() + + Person.drop_collection() + + Person(money=6).save() + Person(money=8).save() + Person(money=10).save() + + self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) + self.assertEqual(2, Person.objects(money__gt=7).count()) + self.assertEqual(2, Person.objects(money__gt="7").count()) + + def test_decimal_storage(self): + class Person(Document): + btc = DecimalField(precision=4) + + Person.drop_collection() + Person(btc=10).save() + Person(btc=10.1).save() + Person(btc=10.11).save() + Person(btc="10.111").save() + Person(btc=Decimal("10.1111")).save() + Person(btc=Decimal("10.11111")).save() + + # How its stored + expected = [{'btc': 10.0}, {'btc': 10.1}, {'btc': 10.11}, + {'btc': 10.111}, {'btc': 10.1111}, {'btc': 10.1111}] + actual = list(Person.objects.exclude('id').as_pymongo()) + self.assertEqual(expected, actual) + + # How it comes out locally + expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), + Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] + actual = list(Person.objects().scalar('btc')) + self.assertEqual(expected, actual) + def test_boolean_validation(self): """Ensure that invalid values cannot be assigned to boolean fields. """ @@ -354,7 +512,6 @@ class FieldTest(unittest.TestCase): person.api_key = api_key self.assertRaises(ValidationError, person.validate) - def test_datetime_validation(self): """Ensure that invalid values cannot be assigned to datetime fields. """ @@ -368,11 +525,39 @@ class FieldTest(unittest.TestCase): log.time = datetime.date.today() log.validate() + log.time = datetime.datetime.now().isoformat(' ') + log.validate() + + if dateutil: + log.time = datetime.datetime.now().isoformat('T') + log.validate() + log.time = -1 self.assertRaises(ValidationError, log.validate) - log.time = '1pm' + log.time = 'ABC' self.assertRaises(ValidationError, log.validate) + def test_datetime_tz_aware_mark_as_changed(self): + from mongoengine import connection + + # Reset the connections + connection._connection_settings = {} + connection._connections = {} + connection._dbs = {} + + connect(db='mongoenginetest', tz_aware=True) + + class LogEntry(Document): + time = DateTimeField() + + LogEntry.drop_collection() + + LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save() + + log = LogEntry.objects.first() + log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) + self.assertEqual(['time'], log._changed_fields) + def test_datetime(self): """Tests showing pymongo datetime fields handling of microseconds. 
Microseconds are rounded to the nearest millisecond and pre UTC @@ -426,6 +611,66 @@ class FieldTest(unittest.TestCase): LogEntry.drop_collection() + def test_datetime_usage(self): + """Tests for regular datetime fields""" + class LogEntry(Document): + date = DateTimeField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01) + log = LogEntry() + log.date = d1 + log.validate() + log.save() + + for query in (d1, d1.isoformat(' ')): + log1 = LogEntry.objects.get(date=query) + self.assertEqual(log, log1) + + if dateutil: + log1 = LogEntry.objects.get(date=d1.isoformat('T')) + self.assertEqual(log, log1) + + LogEntry.drop_collection() + + # create 60 log entries + for i in xrange(1950, 2010): + d = datetime.datetime(i, 01, 01, 00, 00, 01) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 60) + + # Test ordering + logs = LogEntry.objects.order_by("date") + count = logs.count() + i = 0 + while i < count - 1: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + count = logs.count() + i = 0 + while i < count - 1: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2011, 1, 1), + date__gte=datetime.datetime(2000, 1, 1), + ) + self.assertEqual(logs.count(), 10) + + LogEntry.drop_collection() + def test_complexdatetime_storage(self): """Tests for complex datetime fields - which can handle microseconds without rounding. @@ -752,6 +997,53 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, e.save) + def test_complex_field_same_value_not_changed(self): + """ + If a complex field is set to the same value, it should not be marked as + changed.
+ """ + class Simple(Document): + mapping = ListField() + + Simple.drop_collection() + e = Simple().save() + e.mapping = [] + self.assertEqual([], e._changed_fields) + + class Simple(Document): + mapping = DictField() + + Simple.drop_collection() + e = Simple().save() + e.mapping = {} + self.assertEqual([], e._changed_fields) + + def test_slice_marks_field_as_changed(self): + + class Simple(Document): + widgets = ListField() + + simple = Simple(widgets=[1, 2, 3, 4]).save() + simple.widgets[:3] = [] + self.assertEqual(['widgets'], simple._changed_fields) + simple.save() + + simple = simple.reload() + self.assertEqual(simple.widgets, [4]) + + def test_del_slice_marks_field_as_changed(self): + + class Simple(Document): + widgets = ListField() + + simple = Simple(widgets=[1, 2, 3, 4]).save() + del simple.widgets[:3] + self.assertEqual(['widgets'], simple._changed_fields) + simple.save() + + simple = simple.reload() + self.assertEqual(simple.widgets, [4]) + def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" @@ -1805,343 +2097,6 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() - def test_file_fields(self): - """Ensure that file fields can be written to and their data retrieved - """ - class PutFile(Document): - the_file = FileField() - - class StreamFile(Document): - the_file = FileField() - - class SetFile(Document): - the_file = FileField() - - text = b('Hello, World!') - more_text = b('Foo Bar') - content_type = 'text/plain' - - PutFile.drop_collection() - StreamFile.drop_collection() - SetFile.drop_collection() - - putfile = PutFile() - putfile.the_file.put(text, content_type=content_type) - putfile.save() - putfile.validate() - result = PutFile.objects.first() - self.assertTrue(putfile == result) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) - result.the_file.delete() # Remove file from GridFS - PutFile.objects.delete() - - # Ensure file-like objects are stored - putfile = PutFile() - putstring = StringIO() - putstring.write(text) - putstring.seek(0) - putfile.the_file.put(putstring, content_type=content_type) - putfile.save() - putfile.validate() - result = PutFile.objects.first() - self.assertTrue(putfile == result) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) - result.the_file.delete() - - streamfile = StreamFile() - streamfile.the_file.new_file(content_type=content_type) - streamfile.the_file.write(text) - streamfile.the_file.write(more_text) - streamfile.the_file.close() - streamfile.save() - streamfile.validate() - result = StreamFile.objects.first() - self.assertTrue(streamfile == result) - self.assertEqual(result.the_file.read(), text + more_text) - self.assertEqual(result.the_file.content_type, content_type) - result.the_file.seek(0) - self.assertEqual(result.the_file.tell(), 0) - self.assertEqual(result.the_file.read(len(text)), text) - self.assertEqual(result.the_file.tell(), len(text)) - self.assertEqual(result.the_file.read(len(more_text)), more_text) - self.assertEqual(result.the_file.tell(), len(text + more_text)) - result.the_file.delete() - - # Ensure deleted file returns None - self.assertTrue(result.the_file.read() == None) - - setfile = SetFile() - setfile.the_file = text - setfile.save() - setfile.validate() - result = SetFile.objects.first() - self.assertTrue(setfile == result) - self.assertEqual(result.the_file.read(), text) - - # Try replacing file with new one - 
result.the_file.replace(more_text) - result.save() - result.validate() - result = SetFile.objects.first() - self.assertTrue(setfile == result) - self.assertEqual(result.the_file.read(), more_text) - result.the_file.delete() - - PutFile.drop_collection() - StreamFile.drop_collection() - SetFile.drop_collection() - - # Make sure FileField is optional and not required - class DemoFile(Document): - the_file = FileField() - DemoFile.objects.create() - - - def test_file_field_no_default(self): - - class GridDocument(Document): - the_file = FileField() - - GridDocument.drop_collection() - - with tempfile.TemporaryFile() as f: - f.write(b("Hello World!")) - f.flush() - - # Test without default - doc_a = GridDocument() - doc_a.save() - - - doc_b = GridDocument.objects.with_id(doc_a.id) - doc_b.the_file.replace(f, filename='doc_b') - doc_b.save() - self.assertNotEqual(doc_b.the_file.grid_id, None) - - # Test it matches - doc_c = GridDocument.objects.with_id(doc_b.id) - self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) - - # Test with default - doc_d = GridDocument(the_file=b('')) - doc_d.save() - - doc_e = GridDocument.objects.with_id(doc_d.id) - self.assertEqual(doc_d.the_file.grid_id, doc_e.the_file.grid_id) - - doc_e.the_file.replace(f, filename='doc_e') - doc_e.save() - - doc_f = GridDocument.objects.with_id(doc_e.id) - self.assertEqual(doc_e.the_file.grid_id, doc_f.the_file.grid_id) - - db = GridDocument._get_db() - grid_fs = gridfs.GridFS(db) - self.assertEqual(['doc_b', 'doc_e'], grid_fs.list()) - - def test_file_uniqueness(self): - """Ensure that each instance of a FileField is unique - """ - class TestFile(Document): - name = StringField() - the_file = FileField() - - # First instance - test_file = TestFile() - test_file.name = "Hello, World!" 
- test_file.the_file.put(b('Hello, World!')) - test_file.save() - - # Second instance - test_file_dupe = TestFile() - data = test_file_dupe.the_file.read() # Should be None - - self.assertTrue(test_file.name != test_file_dupe.name) - self.assertTrue(test_file.the_file.read() != data) - - TestFile.drop_collection() - - def test_file_boolean(self): - """Ensure that a boolean test of a FileField indicates its presence - """ - class TestFile(Document): - the_file = FileField() - - test_file = TestFile() - self.assertFalse(bool(test_file.the_file)) - test_file.the_file = b('Hello, World!') - test_file.the_file.content_type = 'text/plain' - test_file.save() - self.assertTrue(bool(test_file.the_file)) - - TestFile.drop_collection() - - def test_file_cmp(self): - """Test comparing against other types""" - class TestFile(Document): - the_file = FileField() - - test_file = TestFile() - self.assertFalse(test_file.the_file in [{"test": 1}]) - - def test_image_field(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField() - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.format, 'PNG') - - w, h = t.image.size - self.assertEqual(w, 371) - self.assertEqual(h, 76) - - t.image.delete() - - def test_image_field_resize(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField(size=(185, 37)) - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.format, 'PNG') - w, h = t.image.size - - self.assertEqual(w, 185) - self.assertEqual(h, 37) - - t.image.delete() - - def test_image_field_resize_force(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField(size=(185, 37, True)) - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.format, 'PNG') - w, h = t.image.size - - self.assertEqual(w, 185) - self.assertEqual(h, 37) - - t.image.delete() - - def test_image_field_thumbnail(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField(thumbnail_size=(92, 18)) - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.thumbnail.format, 'PNG') - self.assertEqual(t.image.thumbnail.width, 92) - self.assertEqual(t.image.thumbnail.height, 18) - - t.image.delete() - - def test_file_multidb(self): - register_connection('test_files', 'test_files') - class TestFile(Document): - name = StringField() - the_file = FileField(db_alias="test_files", - collection_name="macumba") - - TestFile.drop_collection() - - # delete old filesystem - get_db("test_files").macumba.files.drop() - get_db("test_files").macumba.chunks.drop() - - # First instance - test_file = TestFile() - test_file.name = "Hello, World!" 
- test_file.the_file.put(b('Hello, World!'), - name="hello.txt") - test_file.save() - - data = get_db("test_files").macumba.files.find_one() - self.assertEqual(data.get('name'), 'hello.txt') - - test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), - b('Hello, World!')) - - def test_geo_indexes(self): - """Ensure that indexes are created automatically for GeoPointFields. - """ - class Event(Document): - title = StringField() - location = GeoPointField() - - Event.drop_collection() - event = Event(title="Coltrane Motion @ Double Door", - location=[41.909889, -87.677137]) - event.save() - - info = Event.objects._collection.index_information() - self.assertTrue(u'location_2d' in info) - self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')]) - - Event.drop_collection() - - def test_geo_embedded_indexes(self): - """Ensure that indexes are created automatically for GeoPointFields on - embedded documents. - """ - class Venue(EmbeddedDocument): - location = GeoPointField() - name = StringField() - - class Event(Document): - title = StringField() - venue = EmbeddedDocumentField(Venue) - - Event.drop_collection() - venue = Venue(name="Double Door", location=[41.909889, -87.677137]) - event = Event(title="Coltrane Motion", venue=venue) - event.save() - - info = Event.objects._collection.index_information() - self.assertTrue(u'location_2d' in info) - self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')]) - def test_ensure_unique_default_instances(self): """Ensure that every field has it's own unique default instance.""" class D(Document): @@ -2164,8 +2119,7 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - p = Person(name="Person %s" % x) - p.save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) @@ -2176,6 +2130,42 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + + + def test_sequence_field_get_next_value(self): + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in xrange(10): + Person(name="Person %s" % x).save() + + self.assertEqual(Person.id.get_next_value(), 11) + self.db['mongoengine.counters'].drop() + + self.assertEqual(Person.id.get_next_value(), 1) + + class Person(Document): + id = SequenceField(primary_key=True, value_decorator=str) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in xrange(10): + Person(name="Person %s" % x).save() + + self.assertEqual(Person.id.get_next_value(), '11') + self.db['mongoengine.counters'].drop() + + self.assertEqual(Person.id.get_next_value(), '1') + def test_sequence_field_sequence_name(self): class Person(Document): id = SequenceField(primary_key=True, sequence_name='jelly') @@ -2185,8 +2175,7 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - p = Person(name="Person %s" % x) - p.save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) self.assertEqual(c['next'], 10) @@ -2197,6 +2186,10 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) 
self.assertEqual(c['next'], 10) + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 1000) + def test_multiple_sequence_fields(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2207,8 +2200,7 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - p = Person(name="Person %s" % x) - p.save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) @@ -2222,6 +2214,14 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + + Person.counter.set_next_value(999) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) + self.assertEqual(c['next'], 999) + def test_sequence_fields_reload(self): class Animal(Document): counter = SequenceField() @@ -2230,8 +2230,7 @@ class FieldTest(unittest.TestCase): self.db['mongoengine.counters'].drop() Animal.drop_collection() - a = Animal(name="Boi") - a.save() + a = Animal(name="Boi").save() self.assertEqual(a.counter, 1) a.reload() @@ -2261,10 +2260,8 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - a = Animal(name="Animal %s" % x) - a.save() - p = Person(name="Person %s" % x) - p.save() + Animal(name="Animal %s" % x).save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) @@ -2477,6 +2474,78 @@ class FieldTest(unittest.TestCase): user = User(email='me@example.com') self.assertTrue(user.validate() is None) + def test_tuples_as_tuples(self): + """ + Ensure that tuples remain tuples when they are + inside a ComplexBaseField + """ + from mongoengine.base import BaseField + + class EnumField(BaseField): + + def __init__(self, **kwargs): + super(EnumField, self).__init__(**kwargs) + + def to_mongo(self, value): + return value + + def to_python(self, value): + return tuple(value) + + class TestDoc(Document): + items = ListField(EnumField()) + + TestDoc.drop_collection() + tuples = [(100, 'Testing')] + doc = TestDoc() + doc.items = tuples + doc.save() + x = TestDoc.objects().get() + self.assertTrue(x is not None) + self.assertTrue(len(x.items) == 1) + self.assertTrue(tuple(x.items[0]) in tuples) + self.assertTrue(x.items[0] in tuples) + + def test_dynamic_fields_class(self): + + class Doc2(Document): + field_1 = StringField(db_field='f') + + class Doc(Document): + my_id = IntField(required=True, unique=True, primary_key=True) + embed_me = DynamicField(db_field='e') + field_x = StringField(db_field='x') + + Doc.drop_collection() + Doc2.drop_collection() + + doc2 = Doc2(field_1="hello") + doc = Doc(my_id=1, embed_me=doc2, field_x="x") + self.assertRaises(OperationError, doc.save) + + doc2.save() + doc.save() + + doc = Doc.objects.get() + self.assertEqual(doc.embed_me.field_1, "hello") + + def test_dynamic_fields_embedded_class(self): + + class Embed(EmbeddedDocument): + field_1 = StringField(db_field='f') + + class Doc(Document): + my_id = IntField(required=True, unique=True, primary_key=True) + embed_me = DynamicField(db_field='e') + field_x = StringField(db_field='x') + + Doc.drop_collection() + + Doc(my_id=1, embed_me=Embed(field_1="hello"), field_x="x").save() + + doc = Doc.objects.get() + 
self.assertEqual(doc.embed_me.field_1, "hello") + if __name__ == '__main__': unittest.main() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 44d2862..ba601de 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -from __future__ import with_statement import sys sys.path[0:0] = [""] @@ -15,7 +14,14 @@ from mongoengine import * from mongoengine.connection import get_db from mongoengine.python_support import PY3, b, StringIO +try: + from PIL import Image + HAS_PIL = True +except ImportError: + HAS_PIL = False + TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') +TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') class FileTest(unittest.TestCase): @@ -47,11 +53,12 @@ class FileTest(unittest.TestCase): content_type = 'text/plain' putfile = PutFile() - putfile.the_file.put(text, content_type=content_type) + putfile.the_file.put(text, content_type=content_type, filename="hello") putfile.save() result = PutFile.objects.first() self.assertTrue(putfile == result) + self.assertEqual("%s" % result.the_file, "<GridFSProxy: hello>") self.assertEqual(result.the_file.read(), text) self.assertEqual(result.the_file.content_type, content_type) result.the_file.delete() # Remove file from GridFS @@ -217,6 +224,19 @@ class FileTest(unittest.TestCase): self.assertEqual(marmot.photo.content_type, 'image/jpeg') self.assertEqual(marmot.photo.foo, 'bar') + def test_file_reassigning(self): + class TestFile(Document): + the_file = FileField() + TestFile.drop_collection() + + test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + self.assertEqual(test_file.the_file.get().length, 8313) + + test_file = TestFile.objects.first() + test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.save() + self.assertEqual(test_file.the_file.get().length, 4971) + def test_file_boolean(self): """Ensure that a boolean test of a FileField indicates its presence """ @@ -242,14 +262,25 @@ class FileTest(unittest.TestCase): self.assertFalse(test_file.the_file in [{"test": 1}]) def test_image_field(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') + if not HAS_PIL: + raise SkipTest('PIL not installed') class TestImage(Document): image = ImageField() TestImage.drop_collection() + with tempfile.TemporaryFile() as f: + f.write(b("Hello World!")) + f.flush() + + t = TestImage() + try: + t.image.put(f) + self.fail("Should have raised an invalidation error") + except ValidationError, e: + self.assertEquals("%s" % e, "Invalid image: cannot identify image file") + t = TestImage() t.image.put(open(TEST_IMAGE_PATH, 'rb')) t.save() @@ -264,9 +295,25 @@ class FileTest(unittest.TestCase): t.image.delete() + def test_image_field_reassigning(self): + if not HAS_PIL: + raise SkipTest('PIL not installed') + + class TestFile(Document): + the_file = ImageField() + TestFile.drop_collection() + + test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + self.assertEqual(test_file.the_file.size, (371, 76)) + + test_file = TestFile.objects.first() + test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.save() + self.assertEqual(test_file.the_file.size, (45, 101)) + def test_image_field_resize(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') + if not HAS_PIL: + raise SkipTest('PIL not installed') class TestImage(Document): image = ImageField(size=(185, 37)) @@ -288,8 +335,8 @@ class FileTest(unittest.TestCase): t.image.delete() def 
test_image_field_resize_force(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') + if not HAS_PIL: + raise SkipTest('PIL not installed') class TestImage(Document): image = ImageField(size=(185, 37, True)) @@ -311,8 +358,8 @@ class FileTest(unittest.TestCase): t.image.delete() def test_image_field_thumbnail(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') + if not HAS_PIL: + raise SkipTest('PIL not installed') class TestImage(Document): image = ImageField(thumbnail_size=(92, 18)) @@ -359,6 +406,14 @@ class FileTest(unittest.TestCase): self.assertEqual(test_file.the_file.read(), b('Hello, World!')) + test_file = TestFile.objects.first() + test_file.the_file = b('HELLO, WORLD!') + test_file.save() + + test_file = TestFile.objects.first() + self.assertEqual(test_file.the_file.read(), + b('HELLO, WORLD!')) + def test_copyable(self): class PutFile(Document): the_file = FileField() @@ -378,6 +433,54 @@ class FileTest(unittest.TestCase): self.assertEqual(putfile, copy.copy(putfile)) self.assertEqual(putfile, copy.deepcopy(putfile)) + def test_get_image_by_grid_id(self): + + if not HAS_PIL: + raise SkipTest('PIL not installed') + + class TestImage(Document): + + image1 = ImageField() + image2 = ImageField() + + TestImage.drop_collection() + + t = TestImage() + t.image1.put(open(TEST_IMAGE_PATH, 'rb')) + t.image2.put(open(TEST_IMAGE2_PATH, 'rb')) + t.save() + + test = TestImage.objects.first() + grid_id = test.image1.grid_id + + self.assertEqual(1, TestImage.objects(Q(image1=grid_id) + or Q(image2=grid_id)).count()) + + def test_complex_field_filefield(self): + """Ensure you can add meta data to file""" + + class Animal(Document): + genus = StringField() + family = StringField() + photos = ListField(FileField()) + + Animal.drop_collection() + marmot = Animal(genus='Marmota', family='Sciuridae') + + marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk + + photos_field = marmot._fields['photos'].field + new_proxy = photos_field.get_proxy_obj('photos', marmot) + new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') + marmot_photo.close() + + marmot.photos.append(new_proxy) + marmot.save() + + marmot = Animal.objects.get() + self.assertEqual(marmot.photos[0].content_type, 'image/jpeg') + self.assertEqual(marmot.photos[0].foo, 'bar') + self.assertEqual(marmot.photos[0].get().length, 8313) if __name__ == '__main__': unittest.main() diff --git a/tests/fields/geo.py b/tests/fields/geo.py new file mode 100644 index 0000000..31ded26 --- /dev/null +++ b/tests/fields/geo.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +import sys +sys.path[0:0] = [""] + +import unittest + +from mongoengine import * +from mongoengine.connection import get_db + +__all__ = ("GeoFieldTest", ) + + +class GeoFieldTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def _test_for_expected_error(self, Cls, loc, expected): + try: + Cls(loc=loc).validate() + self.fail() + except ValidationError, e: + self.assertEqual(expected, e.to_dict()['loc']) + + def test_geopoint_validation(self): + class Location(Document): + loc = GeoPointField() + + invalid_coords = [{"x": 1, "y": 2}, 5, "a"] + expected = 'GeoPointField can only accept tuples or lists of (x, y)' + + for coord in invalid_coords: + self._test_for_expected_error(Location, coord, expected) + + invalid_coords = [[], [1], [1, 2, 3]] + for coord in invalid_coords: + expected = "Value (%s) must be a two-dimensional point" % repr(coord) + 
self._test_for_expected_error(Location, coord, expected) + + invalid_coords = [[{}, {}], ("a", "b")] + for coord in invalid_coords: + expected = "Both values (%s) in point must be float or int" % repr(coord) + self._test_for_expected_error(Location, coord, expected) + + def test_point_validation(self): + class Location(Document): + loc = PointField() + + invalid_coords = {"x": 1, "y": 2} + expected = 'PointField can only accept a valid GeoJson dictionary or lists of (x, y)' + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = {"type": "MadeUp", "coordinates": []} + expected = 'PointField type must be "Point"' + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = {"type": "Point", "coordinates": [1, 2, 3]} + expected = "Value ([1, 2, 3]) must be a two-dimensional point" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [5, "a"] + expected = "PointField can only accept lists of [x, y]" + for coord in invalid_coords: + self._test_for_expected_error(Location, coord, expected) + + invalid_coords = [[], [1], [1, 2, 3]] + for coord in invalid_coords: + expected = "Value (%s) must be a two-dimensional point" % repr(coord) + self._test_for_expected_error(Location, coord, expected) + + invalid_coords = [[{}, {}], ("a", "b")] + for coord in invalid_coords: + expected = "Both values (%s) in point must be float or int" % repr(coord) + self._test_for_expected_error(Location, coord, expected) + + Location(loc=[1, 2]).validate() + + def test_linestring_validation(self): + class Location(Document): + loc = LineStringField() + + invalid_coords = {"x": 1, "y": 2} + expected = 'LineStringField can only accept a valid GeoJson dictionary or lists of (x, y)' + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = {"type": "MadeUp", "coordinates": [[]]} + expected = 'LineStringField type must be "LineString"' + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = {"type": "LineString", "coordinates": [[1, 2, 3]]} + expected = "Invalid LineString:\nValue ([1, 2, 3]) must be a two-dimensional point" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [5, "a"] + expected = "Invalid LineString must contain at least one valid point" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[1]] + expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0]) + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[1, 2, 3]] + expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0]) + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[[{}, {}]], [("a", "b")]] + for coord in invalid_coords: + expected = "Invalid LineString:\nBoth values (%s) in point must be float or int" % repr(coord[0]) + self._test_for_expected_error(Location, coord, expected) + + Location(loc=[[1, 2], [3, 4], [5, 6], [1,2]]).validate() + + def test_polygon_validation(self): + class Location(Document): + loc = PolygonField() + + invalid_coords = {"x": 1, "y": 2} + expected = 'PolygonField can only accept a valid GeoJson dictionary or lists of (x, y)' + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = {"type": "MadeUp", "coordinates": [[]]} + expected = 'PolygonField type must be "Polygon"' + 
self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = {"type": "Polygon", "coordinates": [[[1, 2, 3]]]} + expected = "Invalid Polygon:\nValue ([1, 2, 3]) must be a two-dimensional point" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[[5, "a"]]] + expected = "Invalid Polygon:\nBoth values ([5, 'a']) in point must be float or int" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[[]]] + expected = "Invalid Polygon must contain at least one valid linestring" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[[1, 2, 3]]] + expected = "Invalid Polygon:\nValue ([1, 2, 3]) must be a two-dimensional point" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[[{}, {}]], [("a", "b")]] + expected = "Invalid Polygon:\nBoth values ([{}, {}]) in point must be float or int, Both values (('a', 'b')) in point must be float or int" + self._test_for_expected_error(Location, invalid_coords, expected) + + invalid_coords = [[[1, 2], [3, 4]]] + expected = "Invalid Polygon:\nLineStrings must start and end at the same point" + self._test_for_expected_error(Location, invalid_coords, expected) + + Location(loc=[[[1, 2], [3, 4], [5, 6], [1, 2]]]).validate() + + def test_indexes_geopoint(self): + """Ensure that indexes are created automatically for GeoPointFields. + """ + class Event(Document): + title = StringField() + location = GeoPointField() + + geo_indicies = Event._geo_indices() + self.assertEqual(geo_indicies, [{'fields': [('location', '2d')]}]) + + def test_geopoint_embedded_indexes(self): + """Ensure that indexes are created automatically for GeoPointFields on + embedded documents. + """ + class Venue(EmbeddedDocument): + location = GeoPointField() + name = StringField() + + class Event(Document): + title = StringField() + venue = EmbeddedDocumentField(Venue) + + geo_indicies = Event._geo_indices() + self.assertEqual(geo_indicies, [{'fields': [('venue.location', '2d')]}]) + + def test_indexes_2dsphere(self): + """Ensure that indexes are created automatically for GeoPointFields. + """ + class Event(Document): + title = StringField() + point = PointField() + line = LineStringField() + polygon = PolygonField() + + geo_indicies = Event._geo_indices() + self.assertTrue({'fields': [('line', '2dsphere')]} in geo_indicies) + self.assertTrue({'fields': [('polygon', '2dsphere')]} in geo_indicies) + self.assertTrue({'fields': [('point', '2dsphere')]} in geo_indicies) + + def test_indexes_2dsphere_embedded(self): + """Ensure that indexes are created automatically for GeoPointFields. 
+ """ + class Venue(EmbeddedDocument): + name = StringField() + point = PointField() + line = LineStringField() + polygon = PolygonField() + + class Event(Document): + title = StringField() + venue = EmbeddedDocumentField(Venue) + + geo_indicies = Event._geo_indices() + self.assertTrue({'fields': [('venue.line', '2dsphere')]} in geo_indicies) + self.assertTrue({'fields': [('venue.polygon', '2dsphere')]} in geo_indicies) + self.assertTrue({'fields': [('venue.point', '2dsphere')]} in geo_indicies) + + def test_geo_indexes_recursion(self): + + class Location(Document): + name = StringField() + location = GeoPointField() + + class Parent(Document): + name = StringField() + location = ReferenceField(Location) + + Location.drop_collection() + Parent.drop_collection() + + list(Parent.objects) + + collection = Parent._get_collection() + info = collection.index_information() + + self.assertFalse('location_2d' in info) + + self.assertEqual(len(Parent._geo_indices()), 0) + self.assertEqual(len(Location._geo_indices()), 1) + + def test_geo_indexes_auto_index(self): + + # Test just listing the fields + class Log(Document): + location = PointField(auto_index=False) + datetime = DateTimeField() + + meta = { + 'indexes': [[("location", "2dsphere"), ("datetime", 1)]] + } + + self.assertEqual([], Log._geo_indices()) + + Log.drop_collection() + Log.ensure_indexes() + + info = Log._get_collection().index_information() + self.assertEqual(info["location_2dsphere_datetime_1"]["key"], + [('location', '2dsphere'), ('datetime', 1)]) + + # Test listing explicitly + class Log(Document): + location = PointField(auto_index=False) + datetime = DateTimeField() + + meta = { + 'indexes': [ + {'fields': [("location", "2dsphere"), ("datetime", 1)]} + ] + } + + self.assertEqual([], Log._geo_indices()) + + Log.drop_collection() + Log.ensure_indexes() + + info = Log._get_collection().index_information() + self.assertEqual(info["location_2dsphere_datetime_1"]["key"], + [('location', '2dsphere'), ('datetime', 1)]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/fields/mongodb_leaf.png b/tests/fields/mongodb_leaf.png new file mode 100644 index 0000000..36661ce Binary files /dev/null and b/tests/fields/mongodb_leaf.png differ diff --git a/tests/fixtures.py b/tests/fixtures.py index fd9062e..f1344d7 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,6 +1,8 @@ +import pickle from datetime import datetime from mongoengine import * +from mongoengine import signals class PickleEmbedded(EmbeddedDocument): @@ -15,6 +17,32 @@ class PickleTest(Document): photo = FileField() +class PickleDyanmicEmbedded(DynamicEmbeddedDocument): + date = DateTimeField(default=datetime.now) + + +class PickleDynamicTest(DynamicDocument): + number = IntField() + + +class PickleSignalsTest(Document): + number = IntField() + string = StringField(choices=(('One', '1'), ('Two', '2'))) + embedded = EmbeddedDocumentField(PickleEmbedded) + lists = ListField(StringField()) + + @classmethod + def post_save(self, sender, document, created, **kwargs): + pickled = pickle.dumps(document) + + @classmethod + def post_delete(self, sender, document, **kwargs): + pickled = pickle.dumps(document) + +signals.post_save.connect(PickleSignalsTest.post_save, sender=PickleSignalsTest) +signals.post_delete.connect(PickleSignalsTest.post_delete, sender=PickleSignalsTest) + + class Mixin(object): name = StringField() diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py index 882e737..6fc83e0 100644 --- a/tests/migration/__init__.py +++ 
b/tests/migration/__init__.py @@ -1,4 +1,8 @@ +from convert_to_new_inheritance_model import * +from decimalfield_as_float import * +from refrencefield_dbref_to_object_id import * from turn_off_inheritance import * +from uuidfield_to_binary import * if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/migration/test_convert_to_new_inheritance_model.py b/tests/migration/convert_to_new_inheritance_model.py similarity index 97% rename from tests/migration/test_convert_to_new_inheritance_model.py rename to tests/migration/convert_to_new_inheritance_model.py index d4337bf..89ee9e9 100644 --- a/tests/migration/test_convert_to_new_inheritance_model.py +++ b/tests/migration/convert_to_new_inheritance_model.py @@ -38,7 +38,7 @@ class ConvertToNewInheritanceModel(unittest.TestCase): # 3. Confirm extra data is removed count = collection.find({'_types': {"$exists": True}}).count() - assert count == 0 + self.assertEqual(0, count) # 4. Remove indexes info = collection.index_information() diff --git a/tests/migration/decimalfield_as_float.py b/tests/migration/decimalfield_as_float.py new file mode 100644 index 0000000..3903c91 --- /dev/null +++ b/tests/migration/decimalfield_as_float.py @@ -0,0 +1,50 @@ + # -*- coding: utf-8 -*- +import unittest +import decimal +from decimal import Decimal + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField, DecimalField, ListField + +__all__ = ('ConvertDecimalField', ) + + +class ConvertDecimalField(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def test_how_to_convert_decimal_fields(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. Old definition - using dbrefs + class Person(Document): + name = StringField() + money = DecimalField(force_string=True) + monies = ListField(DecimalField(force_string=True)) + + Person.drop_collection() + Person(name="Wilson Jr", money=Decimal("2.50"), + monies=[Decimal("2.10"), Decimal("5.00")]).save() + + # 2. Start the migration by changing the schema + # Change DecimalField - add precision and rounding settings + class Person(Document): + name = StringField() + money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) + monies = ListField(DecimalField(precision=2, + rounding=decimal.ROUND_HALF_UP)) + + # 3. Loop all the objects and mark parent as changed + for p in Person.objects: + p._mark_as_changed('money') + p._mark_as_changed('monies') + p.save() + + # 4. Confirmation of the fix! + wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] + self.assertTrue(isinstance(wilson['money'], float)) + self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) diff --git a/tests/migration/refrencefield_dbref_to_object_id.py b/tests/migration/refrencefield_dbref_to_object_id.py new file mode 100644 index 0000000..d3acbe9 --- /dev/null +++ b/tests/migration/refrencefield_dbref_to_object_id.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField, ReferenceField, ListField + +__all__ = ('ConvertToObjectIdsModel', ) + + +class ConvertToObjectIdsModel(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def test_how_to_convert_to_object_id_reference_fields(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. 
Old definition - using dbrefs + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=True) + friends = ListField(ReferenceField('self', dbref=True)) + + Person.drop_collection() + + p1 = Person(name="Wilson", parent=None).save() + f1 = Person(name="John", parent=None).save() + f2 = Person(name="Paul", parent=None).save() + f3 = Person(name="George", parent=None).save() + f4 = Person(name="Ringo", parent=None).save() + Person(name="Wilson Jr", parent=p1, friends=[f1, f2, f3, f4]).save() + + # 2. Start the migration by changing the schema + # Change ReferenceField as now dbref defaults to False + class Person(Document): + name = StringField() + parent = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + # 3. Loop all the objects and mark parent as changed + for p in Person.objects: + p._mark_as_changed('parent') + p._mark_as_changed('friends') + p.save() + + # 4. Confirmation of the fix! + wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] + self.assertEqual(p1.id, wilson['parent']) + self.assertEqual([f1.id, f2.id, f3.id, f4.id], wilson['friends']) diff --git a/tests/migration/uuidfield_to_binary.py b/tests/migration/uuidfield_to_binary.py new file mode 100644 index 0000000..a535e91 --- /dev/null +++ b/tests/migration/uuidfield_to_binary.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +import unittest +import uuid + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField, UUIDField, ListField + +__all__ = ('ConvertToBinaryUUID', ) + + +class ConvertToBinaryUUID(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def test_how_to_convert_to_binary_uuid_fields(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. Old definition - using dbrefs + class Person(Document): + name = StringField() + uuid = UUIDField(binary=False) + uuids = ListField(UUIDField(binary=False)) + + Person.drop_collection() + Person(name="Wilson Jr", uuid=uuid.uuid4(), + uuids=[uuid.uuid4(), uuid.uuid4()]).save() + + # 2. Start the migration by changing the schema + # Change UUIDFIeld as now binary defaults to True + class Person(Document): + name = StringField() + uuid = UUIDField() + uuids = ListField(UUIDField()) + + # 3. Loop all the objects and mark parent as changed + for p in Person.objects: + p._mark_as_changed('uuid') + p._mark_as_changed('uuids') + p.save() + + # 4. Confirmation of the fix! 
+ wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] + self.assertTrue(isinstance(wilson['uuid'], uuid.UUID)) + self.assertTrue(all([isinstance(u, uuid.UUID) for u in wilson['uuids']])) diff --git a/tests/queryset/__init__.py b/tests/queryset/__init__.py index 93cb8c2..8a93c19 100644 --- a/tests/queryset/__init__.py +++ b/tests/queryset/__init__.py @@ -1,5 +1,5 @@ - from transform import * from field_list import * from queryset import * -from visitor import * \ No newline at end of file +from visitor import * +from geo import * diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index 4a8a72b..7d66d26 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -20,47 +20,47 @@ class QueryFieldListTest(unittest.TestCase): def test_include_include(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'a': True, 'b': True}) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY, _only_called=True) + self.assertEqual(q.as_dict(), {'a': 1, 'b': 1}) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'b': True}) + self.assertEqual(q.as_dict(), {'a': 1, 'b': 1, 'c': 1}) def test_include_exclude(self): q = QueryFieldList() q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'a': True, 'b': True}) + self.assertEqual(q.as_dict(), {'a': 1, 'b': 1}) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': True}) + self.assertEqual(q.as_dict(), {'a': 1}) def test_exclude_exclude(self): q = QueryFieldList() q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': False, 'b': False}) + self.assertEqual(q.as_dict(), {'a': 0, 'b': 0}) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) + self.assertEqual(q.as_dict(), {'a': 0, 'b': 0, 'c': 0}) def test_exclude_include(self): q = QueryFieldList() q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': False, 'b': False}) + self.assertEqual(q.as_dict(), {'a': 0, 'b': 0}) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'c': True}) + self.assertEqual(q.as_dict(), {'c': 1}) def test_always_include(self): q = QueryFieldList(always_include=['x', 'y']) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) + self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1}) def test_reset(self): q = QueryFieldList(always_include=['x', 'y']) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) + self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1}) q.reset() self.assertFalse(q) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) + self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'b': 1, 'c': 1}) def test_using_a_slice(self): q = QueryFieldList() @@ -97,7 +97,7 @@ class OnlyExcludeAllTest(unittest.TestCase): qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) 
self.assertEqual(qs._loaded_fields.as_dict(), - {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) + {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) qs = qs.only(*only) self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) qs = qs.exclude(*exclude) @@ -134,15 +134,15 @@ class OnlyExcludeAllTest(unittest.TestCase): qs = qs.only(*only) qs = qs.fields(slice__b=5) self.assertEqual(qs._loaded_fields.as_dict(), - {'b': {'$slice': 5}, 'c': 1}) + {'b': {'$slice': 5}, 'c': 1}) qs = qs.fields(slice__c=[5, 1]) self.assertEqual(qs._loaded_fields.as_dict(), - {'b': {'$slice': 5}, 'c': {'$slice': [5, 1]}}) + {'b': {'$slice': 5}, 'c': {'$slice': [5, 1]}}) qs = qs.exclude('c') self.assertEqual(qs._loaded_fields.as_dict(), - {'b': {'$slice': 5}}) + {'b': {'$slice': 5}}) def test_only(self): """Ensure that QuerySet.only only returns the requested fields. @@ -162,6 +162,10 @@ class OnlyExcludeAllTest(unittest.TestCase): self.assertEqual(obj.name, person.name) self.assertEqual(obj.age, person.age) + obj = self.Person.objects.only(*('id', 'name',)).get() + self.assertEqual(obj.name, person.name) + self.assertEqual(obj.age, None) + # Check polymorphism still works class Employee(self.Person): salary = IntField(db_field='wage') @@ -328,7 +332,7 @@ class OnlyExcludeAllTest(unittest.TestCase): Numbers.drop_collection() - numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers = Numbers(n=[0, 1, 2, 3, 4, 5, -5, -4, -3, -2, -1]) numbers.save() # first three @@ -368,7 +372,7 @@ class OnlyExcludeAllTest(unittest.TestCase): Numbers.drop_collection() numbers = Numbers() - numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.embedded = EmbeddedNumber(n=[0, 1, 2, 3, 4, 5, -5, -4, -3, -2, -1]) numbers.save() # first three @@ -395,5 +399,28 @@ class OnlyExcludeAllTest(unittest.TestCase): numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + def test_exclude_from_subclasses_docs(self): + + class Base(Document): + username = StringField() + + meta = {'allow_inheritance': True} + + class Anon(Base): + anon = BooleanField() + + class User(Base): + password = StringField() + wibble = StringField() + + Base.drop_collection() + User(username="mongodb", password="secret").save() + + user = Base.objects().exclude("password", "wibble").first() + self.assertEqual(user.password, None) + + self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py new file mode 100644 index 0000000..f564896 --- /dev/null +++ b/tests/queryset/geo.py @@ -0,0 +1,418 @@ +import sys +sys.path[0:0] = [""] + +import unittest +from datetime import datetime, timedelta +from mongoengine import * + +__all__ = ("GeoQueriesTest",) + + +class GeoQueriesTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + def test_geospatial_operators(self): + """Ensure that geospatial queries are working. 
+ """ + class Event(Document): + title = StringField() + date = DateTimeField() + location = GeoPointField() + + def __unicode__(self): + return self.title + + Event.drop_collection() + + event1 = Event(title="Coltrane Motion @ Double Door", + date=datetime.now() - timedelta(days=1), + location=[-87.677137, 41.909889]).save() + event2 = Event(title="Coltrane Motion @ Bottom of the Hill", + date=datetime.now() - timedelta(days=10), + location=[-122.4194155, 37.7749295]).save() + event3 = Event(title="Coltrane Motion @ Empty Bottle", + date=datetime.now(), + location=[-87.686638, 41.900474]).save() + + # find all events "near" pitchfork office, chicago. + # note that "near" will show the san francisco event, too, + # although it sorts to last. + events = Event.objects(location__near=[-87.67892, 41.9120459]) + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event1, event3, event2]) + + # find events within 5 degrees of pitchfork office, chicago + point_and_distance = [[-87.67892, 41.9120459], 5] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 2) + events = list(events) + self.assertTrue(event2 not in events) + self.assertTrue(event1 in events) + self.assertTrue(event3 in events) + + # ensure ordering is respected by "near" + events = Event.objects(location__near=[-87.67892, 41.9120459]) + events = events.order_by("-date") + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event3, event1, event2]) + + # find events within 10 degrees of san francisco + point = [-122.415579, 37.7566023] + events = Event.objects(location__near=point, location__max_distance=10) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0], event2) + + # find events within 10 degrees of san francisco + point_and_distance = [[-122.415579, 37.7566023], 10] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0], event2) + + # find events within 1 degree of greenpoint, broolyn, nyc, ny + point_and_distance = [[-73.9509714, 40.7237134], 1] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 0) + + # ensure ordering is respected by "within_distance" + point_and_distance = [[-87.67892, 41.9120459], 10] + events = Event.objects(location__within_distance=point_and_distance) + events = events.order_by("-date") + self.assertEqual(events.count(), 2) + self.assertEqual(events[0], event3) + + # check that within_box works + box = [(-125.0, 35.0), (-100.0, 40.0)] + events = Event.objects(location__within_box=box) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0].id, event2.id) + + polygon = [ + (-87.694445, 41.912114), + (-87.69084, 41.919395), + (-87.681742, 41.927186), + (-87.654276, 41.911731), + (-87.656164, 41.898061), + ] + events = Event.objects(location__within_polygon=polygon) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0].id, event1.id) + + polygon2 = [ + (-1.742249, 54.033586), + (-1.225891, 52.792797), + (-4.40094, 53.389881) + ] + events = Event.objects(location__within_polygon=polygon2) + self.assertEqual(events.count(), 0) + + def test_geo_spatial_embedded(self): + + class Venue(EmbeddedDocument): + location = GeoPointField() + name = StringField() + + class Event(Document): + title = StringField() + venue = EmbeddedDocumentField(Venue) + + Event.drop_collection() + + venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889]) + 
venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295]) + + event1 = Event(title="Coltrane Motion @ Double Door", + venue=venue1).save() + event2 = Event(title="Coltrane Motion @ Bottom of the Hill", + venue=venue2).save() + event3 = Event(title="Coltrane Motion @ Empty Bottle", + venue=venue1).save() + + # find all events "near" pitchfork office, chicago. + # note that "near" will show the san francisco event, too, + # although it sorts to last. + events = Event.objects(venue__location__near=[-87.67892, 41.9120459]) + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event1, event3, event2]) + + def test_spherical_geospatial_operators(self): + """Ensure that spherical geospatial queries are working + """ + class Point(Document): + location = GeoPointField() + + Point.drop_collection() + + # These points are one degree apart, which (according to Google Maps) + # is about 110 km apart at this place on the Earth. + north_point = Point(location=[-122, 38]).save() # Near Concord, CA + south_point = Point(location=[-122, 37]).save() # Near Santa Cruz, CA + + earth_radius = 6378.009 # in km (needs to be a float for dividing by) + + # Finds both points because they are within 60 km of the reference + # point equidistant between them. + points = Point.objects(location__near_sphere=[-122, 37.5]) + self.assertEqual(points.count(), 2) + + # Same behavior for _within_spherical_distance + points = Point.objects( + location__within_spherical_distance=[[-122, 37.5], 60/earth_radius] + ) + self.assertEqual(points.count(), 2) + + points = Point.objects(location__near_sphere=[-122, 37.5], + location__max_distance=60 / earth_radius) + self.assertEqual(points.count(), 2) + + # Finds both points, but orders the north point first because it's + # closer to the reference point to the north. + points = Point.objects(location__near_sphere=[-122, 38.5]) + self.assertEqual(points.count(), 2) + self.assertEqual(points[0].id, north_point.id) + self.assertEqual(points[1].id, south_point.id) + + # Finds both points, but orders the south point first because it's + # closer to the reference point to the south. + points = Point.objects(location__near_sphere=[-122, 36.5]) + self.assertEqual(points.count(), 2) + self.assertEqual(points[0].id, south_point.id) + self.assertEqual(points[1].id, north_point.id) + + # Finds only one point because only the first point is within 60km of + # the reference point to the south. + points = Point.objects( + location__within_spherical_distance=[[-122, 36.5], 60/earth_radius]) + self.assertEqual(points.count(), 1) + self.assertEqual(points[0].id, south_point.id) + + def test_2dsphere_point(self): + + class Event(Document): + title = StringField() + date = DateTimeField() + location = PointField() + + def __unicode__(self): + return self.title + + Event.drop_collection() + + event1 = Event(title="Coltrane Motion @ Double Door", + date=datetime.now() - timedelta(days=1), + location=[-87.677137, 41.909889]) + event1.save() + event2 = Event(title="Coltrane Motion @ Bottom of the Hill", + date=datetime.now() - timedelta(days=10), + location=[-122.4194155, 37.7749295]).save() + event3 = Event(title="Coltrane Motion @ Empty Bottle", + date=datetime.now(), + location=[-87.686638, 41.900474]).save() + + # find all events "near" pitchfork office, chicago. + # note that "near" will show the san francisco event, too, + # although it sorts to last. 
+ events = Event.objects(location__near=[-87.67892, 41.9120459]) + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event1, event3, event2]) + + # find events within 5 degrees of pitchfork office, chicago + point_and_distance = [[-87.67892, 41.9120459], 2] + events = Event.objects(location__geo_within_center=point_and_distance) + self.assertEqual(events.count(), 2) + events = list(events) + self.assertTrue(event2 not in events) + self.assertTrue(event1 in events) + self.assertTrue(event3 in events) + + # ensure ordering is respected by "near" + events = Event.objects(location__near=[-87.67892, 41.9120459]) + events = events.order_by("-date") + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event3, event1, event2]) + + # find events within 10km of san francisco + point = [-122.415579, 37.7566023] + events = Event.objects(location__near=point, location__max_distance=10000) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0], event2) + + # find events within 1km of greenpoint, broolyn, nyc, ny + events = Event.objects(location__near=[-73.9509714, 40.7237134], location__max_distance=1000) + self.assertEqual(events.count(), 0) + + # ensure ordering is respected by "near" + events = Event.objects(location__near=[-87.67892, 41.9120459], + location__max_distance=10000).order_by("-date") + self.assertEqual(events.count(), 2) + self.assertEqual(events[0], event3) + + # check that within_box works + box = [(-125.0, 35.0), (-100.0, 40.0)] + events = Event.objects(location__geo_within_box=box) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0].id, event2.id) + + polygon = [ + (-87.694445, 41.912114), + (-87.69084, 41.919395), + (-87.681742, 41.927186), + (-87.654276, 41.911731), + (-87.656164, 41.898061), + ] + events = Event.objects(location__geo_within_polygon=polygon) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0].id, event1.id) + + polygon2 = [ + (-1.742249, 54.033586), + (-1.225891, 52.792797), + (-4.40094, 53.389881) + ] + events = Event.objects(location__geo_within_polygon=polygon2) + self.assertEqual(events.count(), 0) + + def test_2dsphere_point_embedded(self): + + class Venue(EmbeddedDocument): + location = GeoPointField() + name = StringField() + + class Event(Document): + title = StringField() + venue = EmbeddedDocumentField(Venue) + + Event.drop_collection() + + venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889]) + venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295]) + + event1 = Event(title="Coltrane Motion @ Double Door", + venue=venue1).save() + event2 = Event(title="Coltrane Motion @ Bottom of the Hill", + venue=venue2).save() + event3 = Event(title="Coltrane Motion @ Empty Bottle", + venue=venue1).save() + + # find all events "near" pitchfork office, chicago. + # note that "near" will show the san francisco event, too, + # although it sorts to last. 
+ events = Event.objects(venue__location__near=[-87.67892, 41.9120459]) + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event1, event3, event2]) + + def test_linestring(self): + + class Road(Document): + name = StringField() + line = LineStringField() + + Road.drop_collection() + + Road(name="66", line=[[40, 5], [41, 6]]).save() + + # near + point = {"type": "Point", "coordinates": [40, 5]} + roads = Road.objects.filter(line__near=point["coordinates"]).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__near=point).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__near={"$geometry": point}).count() + self.assertEqual(1, roads) + + # Within + polygon = {"type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + roads = Road.objects.filter(line__geo_within=polygon["coordinates"]).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__geo_within=polygon).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__geo_within={"$geometry": polygon}).count() + self.assertEqual(1, roads) + + # Intersects + line = {"type": "LineString", + "coordinates": [[40, 5], [40, 6]]} + roads = Road.objects.filter(line__geo_intersects=line["coordinates"]).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__geo_intersects=line).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__geo_intersects={"$geometry": line}).count() + self.assertEqual(1, roads) + + polygon = {"type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + roads = Road.objects.filter(line__geo_intersects=polygon["coordinates"]).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__geo_intersects=polygon).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(line__geo_intersects={"$geometry": polygon}).count() + self.assertEqual(1, roads) + + def test_polygon(self): + + class Road(Document): + name = StringField() + poly = PolygonField() + + Road.drop_collection() + + Road(name="66", poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]).save() + + # near + point = {"type": "Point", "coordinates": [40, 5]} + roads = Road.objects.filter(poly__near=point["coordinates"]).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(poly__near=point).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(poly__near={"$geometry": point}).count() + self.assertEqual(1, roads) + + # Within + polygon = {"type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + roads = Road.objects.filter(poly__geo_within=polygon["coordinates"]).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(poly__geo_within=polygon).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(poly__geo_within={"$geometry": polygon}).count() + self.assertEqual(1, roads) + + # Intersects + line = {"type": "LineString", + "coordinates": [[40, 5], [41, 6]]} + roads = Road.objects.filter(poly__geo_intersects=line["coordinates"]).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(poly__geo_intersects=line).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(poly__geo_intersects={"$geometry": line}).count() + self.assertEqual(1, roads) + + polygon = {"type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + roads = Road.objects.filter(poly__geo_intersects=polygon["coordinates"]).count() + self.assertEqual(1, roads) + + 
roads = Road.objects.filter(poly__geo_intersects=polygon).count() + self.assertEqual(1, roads) + + roads = Road.objects.filter(poly__geo_intersects={"$geometry": polygon}).count() + self.assertEqual(1, roads) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 37670b0..b4bcf2a 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -1,4 +1,3 @@ -from __future__ import with_statement import sys sys.path[0:0] = [""] @@ -20,7 +19,7 @@ from mongoengine.python_support import PY3 from mongoengine.context_managers import query_counter from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, - QueryFieldList, queryset_manager) + queryset_manager) from mongoengine.errors import InvalidQueryError __all__ = ("QuerySetTest",) @@ -31,12 +30,17 @@ class QuerySetTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') + class PersonMeta(EmbeddedDocument): + weight = IntField() + class Person(Document): name = StringField() age = IntField() + person_meta = EmbeddedDocumentField(PersonMeta) meta = {'allow_inheritance': True} Person.drop_collection() + self.PersonMeta = PersonMeta self.Person = Person def test_initialisation(self): @@ -65,14 +69,12 @@ class QuerySetTest(unittest.TestCase): def test_find(self): """Ensure that a query returns a valid set of results. """ - person1 = self.Person(name="User A", age=20) - person1.save() - person2 = self.Person(name="User B", age=30) - person2.save() + self.Person(name="User A", age=20).save() + self.Person(name="User B", age=30).save() # Find all people in the collection people = self.Person.objects - self.assertEqual(len(people), 2) + self.assertEqual(people.count(), 2) results = list(people) self.assertTrue(isinstance(results[0], self.Person)) self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode))) @@ -83,7 +85,7 @@ class QuerySetTest(unittest.TestCase): # Use a query to filter the people found to just person1 people = self.Person.objects(age=20) - self.assertEqual(len(people), 1) + self.assertEqual(people.count(), 1) person = people.next() self.assertEqual(person.name, "User A") self.assertEqual(person.age, 20) @@ -118,6 +120,15 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(len(people), 1) self.assertEqual(people[0].name, 'User B') + # Test slice limit and skip cursor reset + qs = self.Person.objects[1:2] + # fetch then delete the cursor + qs._cursor + qs._cursor_obj = None + people = list(qs) + self.assertEqual(len(people), 1) + self.assertEqual(people[0].name, 'User B') + people = list(self.Person.objects[1:1]) self.assertEqual(len(people), 0) @@ -130,7 +141,7 @@ class QuerySetTest(unittest.TestCase): for i in xrange(55): self.Person(name='A%s' % i, age=i).save() - self.assertEqual(len(self.Person.objects), 55) + self.assertEqual(self.Person.objects.count(), 55) self.assertEqual("Person object", "%s" % self.Person.objects[0]) self.assertEqual("[<Person: Person object>, <Person: Person object>]", "%s" % self.Person.objects[1:3]) self.assertEqual("[<Person: Person object>, <Person: Person object>]", "%s" % self.Person.objects[51:53]) @@ -211,10 +222,10 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() Blog.objects.create(tags=['a', 'b']) - self.assertEqual(len(Blog.objects(tags__0='a')), 1) - self.assertEqual(len(Blog.objects(tags__0='b')), 0) - self.assertEqual(len(Blog.objects(tags__1='a')), 0) - self.assertEqual(len(Blog.objects(tags__1='b')), 1) + self.assertEqual(Blog.objects(tags__0='a').count(), 1) + self.assertEqual(Blog.objects(tags__0='b').count(), 0) + 
self.assertEqual(Blog.objects(tags__1='a').count(), 0) + self.assertEqual(Blog.objects(tags__1='b').count(), 1) Blog.drop_collection() @@ -229,13 +240,13 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(blog, blog1) query = Blog.objects(posts__1__comments__1__name='testb') - self.assertEqual(len(query), 2) + self.assertEqual(query.count(), 2) query = Blog.objects(posts__1__comments__1__name='testa') - self.assertEqual(len(query), 0) + self.assertEqual(query.count(), 0) query = Blog.objects(posts__0__comments__1__name='testa') - self.assertEqual(len(query), 0) + self.assertEqual(query.count(), 0) Blog.drop_collection() @@ -276,28 +287,32 @@ class QuerySetTest(unittest.TestCase): a_objects = A.objects(s='test1') query = B.objects(ref__in=a_objects) query = query.filter(boolfield=True) - self.assertEquals(query.count(), 1) + self.assertEqual(query.count(), 1) - def test_update_write_options(self): - """Test that passing write_options works""" + def test_update_write_concern(self): + """Test that passing write_concern works""" self.Person.drop_collection() - write_options = {"fsync": True} + write_concern = {"fsync": True} author, created = self.Person.objects.get_or_create( - name='Test User', write_options=write_options) - author.save(write_options=write_options) + name='Test User', write_concern=write_concern) + author.save(write_concern=write_concern) - self.Person.objects.update(set__name='Ross', - write_options=write_options) + result = self.Person.objects.update( + set__name='Ross', write_concern={"w": 1}) + self.assertEqual(result, 1) + result = self.Person.objects.update( + set__name='Ross', write_concern={"w": 0}) + self.assertEqual(result, None) - author = self.Person.objects.first() - self.assertEqual(author.name, 'Ross') - - self.Person.objects.update_one(set__name='Test User', write_options=write_options) - author = self.Person.objects.first() - self.assertEqual(author.name, 'Test User') + result = self.Person.objects.update_one( + set__name='Test User', write_concern={"w": 1}) + self.assertEqual(result, 1) + result = self.Person.objects.update_one( + set__name='Test User', write_concern={"w": 0}) + self.assertEqual(result, None) def test_update_update_has_a_value(self): """Test to ensure that update is passed a value to update to""" @@ -338,24 +353,23 @@ class QuerySetTest(unittest.TestCase): comment2 = Comment(name='testb') post1 = Post(comments=[comment1, comment2]) post2 = Post(comments=[comment2, comment2]) - blog1 = Blog.objects.create(posts=[post1, post2]) - blog2 = Blog.objects.create(posts=[post2, post1]) + Blog.objects.create(posts=[post1, post2]) + Blog.objects.create(posts=[post2, post1]) # Update all of the first comments of second posts of all blogs - blog = Blog.objects().update(set__posts__1__comments__0__name="testc") + Blog.objects().update(set__posts__1__comments__0__name="testc") testc_blogs = Blog.objects(posts__1__comments__0__name="testc") - self.assertEqual(len(testc_blogs), 2) + self.assertEqual(testc_blogs.count(), 2) Blog.drop_collection() - - blog1 = Blog.objects.create(posts=[post1, post2]) - blog2 = Blog.objects.create(posts=[post2, post1]) + Blog.objects.create(posts=[post1, post2]) + Blog.objects.create(posts=[post2, post1]) # Update only the first blog returned by the query - blog = Blog.objects().update_one( + Blog.objects().update_one( set__posts__1__comments__1__name="testc") testc_blogs = Blog.objects(posts__1__comments__1__name="testc") - self.assertEqual(len(testc_blogs), 1) + self.assertEqual(testc_blogs.count(), 1) # Check that 
using this indexing syntax on a non-list fails def non_list_indexing(): @@ -527,6 +541,50 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(club.members['John']['gender'], "F") self.assertEqual(club.members['John']['age'], 14) + def test_update_results(self): + self.Person.drop_collection() + + result = self.Person(name="Bob", age=25).update(upsert=True, full_result=True) + self.assertTrue(isinstance(result, dict)) + self.assertTrue("upserted" in result) + self.assertFalse(result["updatedExisting"]) + + bob = self.Person.objects.first() + result = bob.update(set__age=30, full_result=True) + self.assertTrue(isinstance(result, dict)) + self.assertTrue(result["updatedExisting"]) + + self.Person(name="Bob", age=20).save() + result = self.Person.objects(name="Bob").update(set__name="bobby", multi=True) + self.assertEqual(result, 2) + + def test_upsert(self): + self.Person.drop_collection() + + self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True) + + bob = self.Person.objects.first() + self.assertEqual("Bob", bob.name) + self.assertEqual(30, bob.age) + + def test_upsert_one(self): + self.Person.drop_collection() + + self.Person.objects(name="Bob", age=30).update_one(upsert=True) + + bob = self.Person.objects.first() + self.assertEqual("Bob", bob.name) + self.assertEqual(30, bob.age) + + def test_set_on_insert(self): + self.Person.drop_collection() + + self.Person.objects(pk=ObjectId()).update(set__name='Bob', set_on_insert__age=30, upsert=True) + + bob = self.Person.objects.first() + self.assertEqual("Bob", bob.name) + self.assertEqual(30, bob.age) + def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new document. @@ -592,10 +650,16 @@ class QuerySetTest(unittest.TestCase): blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) Blog.objects.insert(blogs, load_bulk=False) - self.assertEqual(q, 1) # 1 for the insert + self.assertEqual(q, 1) # 1 for the insert + + Blog.drop_collection() + Blog.ensure_indexes() + + with query_counter() as q: + self.assertEqual(q, 0) Blog.objects.insert(blogs) - self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total) + self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch Blog.drop_collection() @@ -619,7 +683,7 @@ class QuerySetTest(unittest.TestCase): self.assertRaises(OperationError, throw_operation_error) # Test can insert new doc - new_post = Blog(title="code", id=ObjectId()) + new_post = Blog(title="code123", id=ObjectId()) Blog.objects.insert(new_post) # test handles other classes being inserted @@ -655,13 +719,13 @@ class QuerySetTest(unittest.TestCase): Blog.objects.insert([blog1, blog2]) def throw_operation_error_not_unique(): - Blog.objects.insert([blog2, blog3], safe=True) + Blog.objects.insert([blog2, blog3]) self.assertRaises(NotUniqueError, throw_operation_error_not_unique) self.assertEqual(Blog.objects.count(), 2) - Blog.objects.insert([blog2, blog3], write_options={ - 'continue_on_error': True}) + Blog.objects.insert([blog2, blog3], write_concern={"w": 0, + 'continue_on_error': True}) self.assertEqual(Blog.objects.count(), 3) def test_get_changed_fields_query_count(self): @@ -759,7 +823,7 @@ class QuerySetTest(unittest.TestCase): p = p.snapshot(True).slave_okay(True).timeout(True) self.assertEqual(p._cursor_args, - {'snapshot': True, 'slave_okay': True, 'timeout': True}) + {'snapshot': True, 'slave_okay': True, 'timeout': True}) def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. 
@@ -786,7 +850,7 @@ class QuerySetTest(unittest.TestCase): number = IntField() def __repr__(self): - return "" % self.number + return "" % self.number Doc.drop_collection() @@ -796,20 +860,18 @@ class QuerySetTest(unittest.TestCase): docs = Doc.objects.order_by('number') self.assertEqual(docs.count(), 1000) - self.assertEqual(len(docs), 1000) docs_string = "%s" % docs self.assertTrue("Doc: 0" in docs_string) self.assertEqual(docs.count(), 1000) - self.assertEqual(len(docs), 1000) + self.assertTrue('(remaining elements truncated)' in "%s" % docs) # Limit and skip docs = docs[1:4] self.assertEqual('[, , ]', "%s" % docs) self.assertEqual(docs.count(), 3) - self.assertEqual(len(docs), 3) for doc in docs: self.assertEqual('.. queryset mid-iteration ..', repr(docs)) @@ -938,8 +1000,10 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() def assertSequence(self, qs, expected): + qs = list(qs) + expected = list(expected) self.assertEqual(len(qs), len(expected)) - for i in range(len(qs)): + for i in xrange(len(qs)): self.assertEqual(qs[i], expected[i]) def test_ordering(self): @@ -1117,13 +1181,13 @@ class QuerySetTest(unittest.TestCase): self.Person(name="User B", age=30).save() self.Person(name="User C", age=40).save() - self.assertEqual(len(self.Person.objects), 3) + self.assertEqual(self.Person.objects.count(), 3) self.Person.objects(age__lt=30).delete() - self.assertEqual(len(self.Person.objects), 2) + self.assertEqual(self.Person.objects.count(), 2) self.Person.objects.delete() - self.assertEqual(len(self.Person.objects), 0) + self.assertEqual(self.Person.objects.count(), 0) def test_reverse_delete_rule_cascade(self): """Ensure cascading deletion of referring documents from the database. @@ -1230,7 +1294,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): content = StringField() authors = ListField(ReferenceField(self.Person, - reverse_delete_rule=PULL)) + reverse_delete_rule=PULL)) BlogPost.drop_collection() self.Person.drop_collection() @@ -1288,6 +1352,49 @@ class QuerySetTest(unittest.TestCase): self.Person.objects()[:1].delete() self.assertEqual(1, BlogPost.objects.count()) + + def test_reference_field_find(self): + """Ensure cascading deletion of referring documents from the database. + """ + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person) + + BlogPost.drop_collection() + self.Person.drop_collection() + + me = self.Person(name='Test User').save() + BlogPost(content="test 123", author=me).save() + + self.assertEqual(1, BlogPost.objects(author=me).count()) + self.assertEqual(1, BlogPost.objects(author=me.pk).count()) + self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + + self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) + self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) + self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + + def test_reference_field_find_dbref(self): + """Ensure cascading deletion of referring documents from the database. 
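test_reference_field_find above checks that a ReferenceField accepts the referenced document, its primary key, or the key as a string, both directly and via ``__in``. A short sketch under the same assumptions (local default connection, illustrative ``example_db`` name)::

    from mongoengine import Document, StringField, ReferenceField, connect

    connect('example_db')  # assumed database name

    class Person(Document):
        name = StringField()

    class BlogPost(Document):
        content = StringField()
        author = ReferenceField(Person)

    Person.drop_collection()
    BlogPost.drop_collection()

    me = Person(name='Test User').save()
    BlogPost(content='test 123', author=me).save()

    # All of these resolve to the same query on the stored ObjectId
    assert BlogPost.objects(author=me).count() == 1
    assert BlogPost.objects(author=me.pk).count() == 1
    assert BlogPost.objects(author=str(me.pk)).count() == 1
    assert BlogPost.objects(author__in=[me]).count() == 1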
+ """ + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, dbref=True) + + BlogPost.drop_collection() + self.Person.drop_collection() + + me = self.Person(name='Test User').save() + BlogPost(content="test 123", author=me).save() + + self.assertEqual(1, BlogPost.objects(author=me).count()) + self.assertEqual(1, BlogPost.objects(author=me.pk).count()) + self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + + self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) + self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) + self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + def test_update(self): """Ensure that atomic updates work properly. """ @@ -1390,9 +1497,6 @@ class QuerySetTest(unittest.TestCase): def test_pull_nested(self): - class User(Document): - name = StringField() - class Collaborator(EmbeddedDocument): user = StringField() @@ -1407,8 +1511,7 @@ class QuerySetTest(unittest.TestCase): Site.drop_collection() c = Collaborator(user='Esteban') - s = Site(name="test", collaborators=[c]) - s.save() + s = Site(name="test", collaborators=[c]).save() Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') self.assertEqual(Site.objects.first().collaborators, []) @@ -1418,6 +1521,71 @@ class QuerySetTest(unittest.TestCase): self.assertRaises(InvalidQueryError, pull_all) + def test_pull_from_nested_embedded(self): + + class User(EmbeddedDocument): + name = StringField() + + def __unicode__(self): + return '%s' % self.name + + class Collaborator(EmbeddedDocument): + helpful = ListField(EmbeddedDocumentField(User)) + unhelpful = ListField(EmbeddedDocumentField(User)) + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = EmbeddedDocumentField(Collaborator) + + + Site.drop_collection() + + c = User(name='Esteban') + f = User(name='Frank') + s = Site(name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])).save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) + self.assertEqual(Site.objects.first().collaborators['helpful'], []) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'name': 'Frank'}) + self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + + def pull_all(): + Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__name=['Ross']) + + self.assertRaises(InvalidQueryError, pull_all) + + def test_pull_from_nested_mapfield(self): + + class Collaborator(EmbeddedDocument): + user = StringField() + + def __unicode__(self): + return '%s' % self.user + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator))) + + + Site.drop_collection() + + c = Collaborator(user='Esteban') + f = Collaborator(user='Frank') + s = Site(name="test", collaborators={'helpful':[c],'unhelpful':[f]}) + s.save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful__user='Esteban') + self.assertEqual(Site.objects.first().collaborators['helpful'], []) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'user':'Frank'}) + self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + + def pull_all(): + Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__user=['Ross']) + + self.assertRaises(InvalidQueryError, pull_all) + def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -1511,6 +1679,32 @@ class 
QuerySetTest(unittest.TestCase): self.assertEqual(message.authors[1].name, "Ross") self.assertEqual(message.authors[2].name, "Adam") + def test_reload_embedded_docs_instance(self): + + class SubDoc(EmbeddedDocument): + val = IntField() + + class Doc(Document): + embedded = EmbeddedDocumentField(SubDoc) + + doc = Doc(embedded=SubDoc(val=0)).save() + doc.reload() + + self.assertEqual(doc.pk, doc.embedded._instance.pk) + + def test_reload_list_embedded_docs_instance(self): + + class SubDoc(EmbeddedDocument): + val = IntField() + + class Doc(Document): + embedded = ListField(EmbeddedDocumentField(SubDoc)) + + doc = Doc(embedded=[SubDoc(val=0)]).save() + doc.reload() + + self.assertEqual(doc.pk, doc.embedded[0]._instance.pk) + def test_order_by(self): """Ensure that QuerySets may be ordered. """ @@ -2080,6 +2274,19 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.average('age')), avg) + # dot notation + self.Person(name='person meta', person_meta=self.PersonMeta(weight=0)).save() + self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), 0) + + for i, weight in enumerate(ages): + self.Person(name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() + + self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), avg) + + self.Person(name='test meta none').save() + self.assertEqual(int(self.Person.objects.average('person_meta.weight')), avg) + + def test_sum(self): """Ensure that field can be summed over correctly. """ @@ -2092,6 +2299,153 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + for i, age in enumerate(ages): + self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save() + + self.assertEqual(int(self.Person.objects.sum('person_meta.weight')), sum(ages)) + + self.Person(name='weightless person').save() + self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + + def test_embedded_average(self): + class Pay(EmbeddedDocument): + value = DecimalField() + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(value=150)).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(value=530)).save() + + Doc(name=u"Tayza mariana", + pay=Pay(value=165)).save() + + Doc(name=u"Eliana Costa", + pay=Pay(value=115)).save() + + self.assertEqual( + Doc.objects.average('pay.value'), + 240) + + def test_embedded_array_average(self): + class Pay(EmbeddedDocument): + values = ListField(DecimalField()) + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(values=[150, 100])).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(values=[530, 100])).save() + + Doc(name=u"Tayza mariana", + pay=Pay(values=[165, 100])).save() + + Doc(name=u"Eliana Costa", + pay=Pay(values=[115, 100])).save() + + self.assertEqual( + Doc.objects.average('pay.values'), + 170) + + def test_array_average(self): + class Doc(Document): + values = ListField(DecimalField()) + + Doc.drop_collection() + + Doc(values=[150, 100]).save() + Doc(values=[530, 100]).save() + Doc(values=[165, 100]).save() + Doc(values=[115, 100]).save() + + self.assertEqual( + Doc.objects.average('values'), + 170) + + def test_embedded_sum(self): + class Pay(EmbeddedDocument): + value = DecimalField() + + class Doc(Document): 
+ name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(value=150)).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(value=530)).save() + + Doc(name=u"Tayza mariana", + pay=Pay(value=165)).save() + + Doc(name=u"Eliana Costa", + pay=Pay(value=115)).save() + + self.assertEqual( + Doc.objects.sum('pay.value'), + 960) + + + def test_embedded_array_sum(self): + class Pay(EmbeddedDocument): + values = ListField(DecimalField()) + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(values=[150, 100])).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(values=[530, 100])).save() + + Doc(name=u"Tayza mariana", + pay=Pay(values=[165, 100])).save() + + Doc(name=u"Eliana Costa", + pay=Pay(values=[115, 100])).save() + + self.assertEqual( + Doc.objects.sum('pay.values'), + 1360) + + def test_array_sum(self): + class Doc(Document): + values = ListField(DecimalField()) + + Doc.drop_collection() + + Doc(values=[150, 100]).save() + Doc(values=[530, 100]).save() + Doc(values=[165, 100]).save() + Doc(values=[115, 100]).save() + + self.assertEqual( + Doc.objects.sum('values'), + 1360) + def test_distinct(self): """Ensure that the QuerySet.distinct method works. """ @@ -2226,6 +2580,42 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(0, Foo.with_inactive.count()) self.assertEqual(1, Foo.objects.count()) + def test_inherit_objects(self): + + class Foo(Document): + meta = {'allow_inheritance': True} + active = BooleanField(default=True) + + @queryset_manager + def objects(klass, queryset): + return queryset(active=True) + + class Bar(Foo): + pass + + Bar.drop_collection() + Bar.objects.create(active=False) + self.assertEqual(0, Bar.objects.count()) + + def test_inherit_objects_override(self): + + class Foo(Document): + meta = {'allow_inheritance': True} + active = BooleanField(default=True) + + @queryset_manager + def objects(klass, queryset): + return queryset(active=True) + + class Bar(Foo): + @queryset_manager + def objects(klass, queryset): + return queryset(active=False) + + Bar.drop_collection() + Bar.objects.create(active=False) + self.assertEqual(0, Foo.objects.count()) + self.assertEqual(1, Bar.objects.count()) def test_query_value_conversion(self): """Ensure that query values are properly converted when necessary. @@ -2289,8 +2679,8 @@ class QuerySetTest(unittest.TestCase): t = Test(testdict={'f': 'Value'}) t.save() - self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1) - self.assertEqual(len(Test.objects(testdict__f='Value')), 1) + self.assertEqual(Test.objects(testdict__f__startswith='Val').count(), 1) + self.assertEqual(Test.objects(testdict__f='Value').count(), 1) Test.drop_collection() class Test(Document): @@ -2299,8 +2689,8 @@ class QuerySetTest(unittest.TestCase): t = Test(testdict={'f': 'Value'}) t.save() - self.assertEqual(len(Test.objects(testdict__f='Value')), 1) - self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1) + self.assertEqual(Test.objects(testdict__f='Value').count(), 1) + self.assertEqual(Test.objects(testdict__f__startswith='Val').count(), 1) Test.drop_collection() def test_bulk(self): @@ -2341,174 +2731,12 @@ class QuerySetTest(unittest.TestCase): def tearDown(self): self.Person.drop_collection() - def test_geospatial_operators(self): - """Ensure that geospatial queries are working. 
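The new sum/average tests above show ``QuerySet.sum()`` and ``QuerySet.average()`` accepting dotted paths that reach into embedded documents and list fields. A compact sketch of the same aggregation calls; the combined Pay class and the numbers are illustrative, chosen to mirror the test data::

    from mongoengine import (Document, EmbeddedDocument, EmbeddedDocumentField,
                             ListField, DecimalField, StringField, connect)

    connect('example_db')  # assumed database name

    class Pay(EmbeddedDocument):
        value = DecimalField()
        values = ListField(DecimalField())

    class Doc(Document):
        name = StringField()
        pay = EmbeddedDocumentField(Pay)

    Doc.drop_collection()
    Doc(name=u"Wilson Junior", pay=Pay(value=150, values=[150, 100])).save()
    Doc(name=u"Isabella Luanna", pay=Pay(value=530, values=[530, 100])).save()

    # Dotted paths reach into the embedded document and its list field
    assert Doc.objects.sum('pay.value') == 680
    assert Doc.objects.average('pay.value') == 340
    assert Doc.objects.sum('pay.values') == 880   # list elements are summed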
- """ - class Event(Document): - title = StringField() - date = DateTimeField() - location = GeoPointField() - - def __unicode__(self): - return self.title - - Event.drop_collection() - - event1 = Event(title="Coltrane Motion @ Double Door", - date=datetime.now() - timedelta(days=1), - location=[41.909889, -87.677137]) - event2 = Event(title="Coltrane Motion @ Bottom of the Hill", - date=datetime.now() - timedelta(days=10), - location=[37.7749295, -122.4194155]) - event3 = Event(title="Coltrane Motion @ Empty Bottle", - date=datetime.now(), - location=[41.900474, -87.686638]) - - event1.save() - event2.save() - event3.save() - - # find all events "near" pitchfork office, chicago. - # note that "near" will show the san francisco event, too, - # although it sorts to last. - events = Event.objects(location__near=[41.9120459, -87.67892]) - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event1, event3, event2]) - - # find events within 5 degrees of pitchfork office, chicago - point_and_distance = [[41.9120459, -87.67892], 5] - events = Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 2) - events = list(events) - self.assertTrue(event2 not in events) - self.assertTrue(event1 in events) - self.assertTrue(event3 in events) - - # ensure ordering is respected by "near" - events = Event.objects(location__near=[41.9120459, -87.67892]) - events = events.order_by("-date") - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event3, event1, event2]) - - # find events within 10 degrees of san francisco - point = [37.7566023, -122.415579] - events = Event.objects(location__near=point, location__max_distance=10) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) - - # find events within 10 degrees of san francisco - point_and_distance = [[37.7566023, -122.415579], 10] - events = Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) - - # find events within 1 degree of greenpoint, broolyn, nyc, ny - point_and_distance = [[40.7237134, -73.9509714], 1] - events = Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 0) - - # ensure ordering is respected by "within_distance" - point_and_distance = [[41.9120459, -87.67892], 10] - events = Event.objects(location__within_distance=point_and_distance) - events = events.order_by("-date") - self.assertEqual(events.count(), 2) - self.assertEqual(events[0], event3) - - # check that within_box works - box = [(35.0, -125.0), (40.0, -100.0)] - events = Event.objects(location__within_box=box) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event2.id) - - # check that polygon works for users who have a server >= 1.9 - server_version = tuple( - get_connection().server_info()['version'].split('.') - ) - required_version = tuple("1.9.0".split(".")) - if server_version >= required_version: - polygon = [ - (41.912114,-87.694445), - (41.919395,-87.69084), - (41.927186,-87.681742), - (41.911731,-87.654276), - (41.898061,-87.656164), - ] - events = Event.objects(location__within_polygon=polygon) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event1.id) - - polygon2 = [ - (54.033586,-1.742249), - (52.792797,-1.225891), - (53.389881,-4.40094) - ] - events = Event.objects(location__within_polygon=polygon2) - self.assertEqual(events.count(), 0) - - Event.drop_collection() - - def 
test_spherical_geospatial_operators(self): - """Ensure that spherical geospatial queries are working - """ - class Point(Document): - location = GeoPointField() - - Point.drop_collection() - - # These points are one degree apart, which (according to Google Maps) - # is about 110 km apart at this place on the Earth. - north_point = Point(location=[-122, 38]) # Near Concord, CA - south_point = Point(location=[-122, 37]) # Near Santa Cruz, CA - north_point.save() - south_point.save() - - earth_radius = 6378.009; # in km (needs to be a float for dividing by) - - # Finds both points because they are within 60 km of the reference - # point equidistant between them. - points = Point.objects(location__near_sphere=[-122, 37.5]) - self.assertEqual(points.count(), 2) - - # Same behavior for _within_spherical_distance - points = Point.objects( - location__within_spherical_distance=[[-122, 37.5], 60/earth_radius] - ); - self.assertEqual(points.count(), 2) - - points = Point.objects(location__near_sphere=[-122, 37.5], - location__max_distance=60 / earth_radius); - self.assertEqual(points.count(), 2) - - # Finds both points, but orders the north point first because it's - # closer to the reference point to the north. - points = Point.objects(location__near_sphere=[-122, 38.5]) - self.assertEqual(points.count(), 2) - self.assertEqual(points[0].id, north_point.id) - self.assertEqual(points[1].id, south_point.id) - - # Finds both points, but orders the south point first because it's - # closer to the reference point to the south. - points = Point.objects(location__near_sphere=[-122, 36.5]) - self.assertEqual(points.count(), 2) - self.assertEqual(points[0].id, south_point.id) - self.assertEqual(points[1].id, north_point.id) - - # Finds only one point because only the first point is within 60km of - # the reference point to the south. - points = Point.objects( - location__within_spherical_distance=[[-122, 36.5], 60/earth_radius] - ); - self.assertEqual(points.count(), 1) - self.assertEqual(points[0].id, south_point.id) - - Point.drop_collection() - def test_custom_querysets(self): """Ensure that custom QuerySet classes may be used. 
""" class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class Post(Document): meta = {'queryset_class': CustomQuerySet} @@ -2529,7 +2757,7 @@ class QuerySetTest(unittest.TestCase): class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class CustomQuerySetManager(QuerySetManager): queryset_class = CustomQuerySet @@ -2576,7 +2804,7 @@ class QuerySetTest(unittest.TestCase): class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class Base(Document): meta = {'abstract': True, 'queryset_class': CustomQuerySet} @@ -2599,7 +2827,7 @@ class QuerySetTest(unittest.TestCase): class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class CustomQuerySetManager(QuerySetManager): queryset_class = CustomQuerySet @@ -2620,6 +2848,19 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() + def test_count_limit_and_skip(self): + class Post(Document): + title = StringField() + + Post.drop_collection() + + for i in xrange(10): + Post(title="Post %s" % i).save() + + self.assertEqual(5, Post.objects.limit(5).skip(5).count()) + + self.assertEqual(10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False)) + def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """ @@ -2628,10 +2869,8 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() - post1 = Post(title="Post 1") - post1.save() - post2 = Post(title="Post 2") - post2.save() + Post(title="Post 1").save() + Post(title="Post 2").save() posts = Post.objects.all()[0:1] self.assertEqual(len(list(posts())), 1) @@ -3001,14 +3240,14 @@ class QuerySetTest(unittest.TestCase): # Find all people in the collection people = self.Person.objects.scalar('name') - self.assertEqual(len(people), 2) + self.assertEqual(people.count(), 2) results = list(people) self.assertEqual(results[0], "User A") self.assertEqual(results[1], "User B") # Use a query to filter the people found to just person1 people = self.Person.objects(age=20).scalar('name') - self.assertEqual(len(people), 1) + self.assertEqual(people.count(), 1) person = people.next() self.assertEqual(person, "User A") @@ -3054,7 +3293,7 @@ class QuerySetTest(unittest.TestCase): for i in xrange(55): self.Person(name='A%s' % i, age=i).save() - self.assertEqual(len(self.Person.objects.scalar('name')), 55) + self.assertEqual(self.Person.objects.scalar('name').count(), 55) self.assertEqual("A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) self.assertEqual("A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) if PY3: @@ -3078,7 +3317,7 @@ class QuerySetTest(unittest.TestCase): class Foo(EmbeddedDocument): shape = StringField() color = StringField() - trick = BooleanField() + thick = BooleanField() meta = {'allow_inheritance': False} class Bar(Document): @@ -3087,17 +3326,20 @@ class QuerySetTest(unittest.TestCase): Bar.drop_collection() - b1 = Bar(foo=[Foo(shape= "square", color ="purple", thick = False), - Foo(shape= "circle", color ="red", thick = True)]) + b1 = Bar(foo=[Foo(shape="square", color="purple", thick=False), + Foo(shape="circle", color="red", thick=True)]) b1.save() - b2 = Bar(foo=[Foo(shape= "square", color ="red", thick = True), - Foo(shape= "circle", color ="purple", thick = False)]) + b2 = Bar(foo=[Foo(shape="square", color="red", thick=True), + Foo(shape="circle", color="purple", thick=False)]) b2.save() ak = 
list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) + ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple"))) + self.assertEqual([b1], ak) + def test_upsert_includes_cls(self): """Upserts should include _cls information for inheritable classes """ @@ -3118,6 +3360,13 @@ class QuerySetTest(unittest.TestCase): Test.objects(test='foo').update_one(upsert=True, set__test='foo') self.assertTrue('_cls' in Test._collection.find_one()) + def test_update_upsert_looks_like_a_digit(self): + class MyDoc(DynamicDocument): + pass + MyDoc.drop_collection() + self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1)) + self.assertEqual(MyDoc.objects.get()['47'], 1) + def test_read_preference(self): class Bar(Document): pass @@ -3127,7 +3376,10 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([], bars) self.assertRaises(ConfigurationError, Bar.objects, - read_preference='Primary') + read_preference='Primary') + + bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) def test_json_simple(self): @@ -3143,7 +3395,7 @@ class QuerySetTest(unittest.TestCase): Doc(string="Bye", embedded_field=Embedded(string="Bye")).save() Doc().save() - json_data = Doc.objects.to_json() + json_data = Doc.objects.to_json(sort_keys=True, separators=(',', ':')) doc_objects = list(Doc.objects) self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) @@ -3164,20 +3416,18 @@ class QuerySetTest(unittest.TestCase): float_field = FloatField(default=1.1) boolean_field = BooleanField(default=True) datetime_field = DateTimeField(default=datetime.now) - embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, - default=lambda: EmbeddedDoc()) + embedded_document_field = EmbeddedDocumentField( + EmbeddedDoc, default=lambda: EmbeddedDoc()) list_field = ListField(default=lambda: [1, 2, 3]) dict_field = DictField(default=lambda: {"hello": "world"}) objectid_field = ObjectIdField(default=ObjectId) - reference_field = ReferenceField(Simple, default=lambda: - Simple().save()) + reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) - generic_reference_field = GenericReferenceField( - default=lambda: Simple().save()) + generic_reference_field = GenericReferenceField(default=lambda: Simple().save()) sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) email_field = EmailField(default="ross@example.com") @@ -3185,7 +3435,7 @@ class QuerySetTest(unittest.TestCase): sequence_field = SequenceField() uuid_field = UUIDField(default=uuid.uuid4) generic_embedded_document_field = GenericEmbeddedDocumentField( - default=lambda: EmbeddedDoc()) + default=lambda: EmbeddedDoc()) Simple.drop_collection() Doc.drop_collection() @@ -3210,14 +3460,17 @@ class QuerySetTest(unittest.TestCase): User(name="Bob Dole", age=89, price=Decimal('1.11')).save() User(name="Barack Obama", age=51, price=Decimal('2.22')).save() + results = User.objects.only('id', 'name').as_pymongo() + self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name'])) + users = User.objects.only('name', 'price').as_pymongo() results = list(users) self.assertTrue(isinstance(results[0], dict)) 
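The as_pymongo hunks in this area assert that ``only()``/``exclude()`` restrict the keys of the returned raw dicts, and that DecimalField values come back as floats unless ``coerce_types=True`` is passed. A brief sketch assuming a User document like the one in the tests (the ``example_db`` name is illustrative)::

    from decimal import Decimal
    from mongoengine import Document, StringField, IntField, DecimalField, connect

    connect('example_db')  # assumed database name

    class User(Document):
        name = StringField()
        age = IntField()
        price = DecimalField()

    User.drop_collection()
    User(name="Bob Dole", age=89, price=Decimal('1.11')).save()

    # as_pymongo() yields plain dicts straight from the driver, not User instances
    raw = User.objects.only('name', 'price').as_pymongo()[0]
    assert raw['name'] == 'Bob Dole'
    assert raw['price'] == 1.11  # DecimalField is stored and returned as a float

    # coerce_types=True converts values back through the field definitions
    coerced = User.objects.only('price').as_pymongo(coerce_types=True)[0]
    assert coerced['price'] == Decimal('1.11')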
self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') - self.assertEqual(results[0]['price'], '1.11') + self.assertEqual(results[0]['price'], 1.11) self.assertEqual(results[1]['name'], 'Barack Obama') - self.assertEqual(results[1]['price'], '2.22') + self.assertEqual(results[1]['price'], 2.22) # Test coerce_types users = User.objects.only('name', 'price').as_pymongo(coerce_types=True) @@ -3229,6 +3482,28 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['price'], Decimal('2.22')) + def test_as_pymongo_json_limit_fields(self): + + class User(Document): + email = EmailField(unique=True, required=True) + password_hash = StringField(db_field='password_hash', required=True) + password_salt = StringField(db_field='password_salt', required=True) + + User.drop_collection() + User(email="ross@example.com", password_salt="SomeSalt", password_hash="SomeHash").save() + + serialized_user = User.objects.exclude('password_salt', 'password_hash').as_pymongo()[0] + self.assertEqual(set(['_id', 'email']), set(serialized_user.keys())) + + serialized_user = User.objects.exclude('id', 'password_salt', 'password_hash').to_json() + self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + + serialized_user = User.objects.exclude('password_salt').only('email').as_pymongo()[0] + self.assertEqual(set(['email']), set(serialized_user.keys())) + + serialized_user = User.objects.exclude('password_salt').only('email').to_json() + self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + def test_no_dereference(self): class Organization(Document): @@ -3248,8 +3523,104 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(isinstance(qs.first().organization, Organization)) self.assertFalse(isinstance(qs.no_dereference().first().organization, Organization)) + self.assertFalse(isinstance(qs.no_dereference().get().organization, + Organization)) self.assertTrue(isinstance(qs.first().organization, Organization)) + def test_cached_queryset(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + for i in xrange(100): + Person(name="No: %s" % i).save() + + with query_counter() as q: + self.assertEqual(q, 0) + people = Person.objects + + [x for x in people] + self.assertEqual(100, len(people._result_cache)) + self.assertEqual(None, people._len) + self.assertEqual(q, 1) + + list(people) + self.assertEqual(100, people._len) # Caused by list calling len + self.assertEqual(q, 1) + + people.count() # count is cached + self.assertEqual(q, 1) + + def test_no_cached_queryset(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + for i in xrange(100): + Person(name="No: %s" % i).save() + + with query_counter() as q: + self.assertEqual(q, 0) + people = Person.objects.no_cache() + + [x for x in people] + self.assertEqual(q, 1) + + list(people) + self.assertEqual(q, 2) + + people.count() + self.assertEqual(q, 3) + + def test_cache_not_cloned(self): + + class User(Document): + name = StringField() + + def __unicode__(self): + return self.name + + User.drop_collection() + + User(name="Alice").save() + User(name="Bob").save() + + users = User.objects.all().order_by('name') + self.assertEqual("%s" % users, "[, ]") + self.assertEqual(2, len(users._result_cache)) + + users = users.filter(name="Bob") + self.assertEqual("%s" % users, "[]") + self.assertEqual(1, len(users._result_cache)) + + def test_no_cache(self): + """Ensure you can add meta data 
to file""" + + class Noddy(Document): + fields = DictField() + + Noddy.drop_collection() + for i in xrange(100): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save() + + docs = Noddy.objects.no_cache() + + counter = len([1 for i in docs]) + self.assertEquals(counter, 100) + + self.assertEquals(len(list(docs)), 100) + self.assertRaises(TypeError, lambda: len(docs)) + + with query_counter() as q: + self.assertEqual(q, 0) + list(docs) + self.assertEqual(q, 1) + list(docs) + self.assertEqual(q, 2) + def test_nested_queryset_iterator(self): # Try iterating the same queryset twice, nested. names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George'] @@ -3266,30 +3637,138 @@ class QuerySetTest(unittest.TestCase): User(name=name).save() users = User.objects.all().order_by('name') - outer_count = 0 inner_count = 0 inner_total_count = 0 - self.assertEqual(len(users), 7) + with query_counter() as q: + self.assertEqual(q, 0) - for i, outer_user in enumerate(users): - self.assertEqual(outer_user.name, names[i]) - outer_count += 1 - inner_count = 0 + self.assertEqual(users.count(), 7) - # Calling len might disrupt the inner loop if there are bugs - self.assertEqual(len(users), 7) + for i, outer_user in enumerate(users): + self.assertEqual(outer_user.name, names[i]) + outer_count += 1 + inner_count = 0 - for j, inner_user in enumerate(users): - self.assertEqual(inner_user.name, names[j]) - inner_count += 1 - inner_total_count += 1 + # Calling len might disrupt the inner loop if there are bugs + self.assertEqual(users.count(), 7) - self.assertEqual(inner_count, 7) # inner loop should always be executed seven times + for j, inner_user in enumerate(users): + self.assertEqual(inner_user.name, names[j]) + inner_count += 1 + inner_total_count += 1 + + self.assertEqual(inner_count, 7) # inner loop should always be executed seven times + + self.assertEqual(outer_count, 7) # outer loop should be executed seven times total + self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total + + self.assertEqual(q, 2) + + def test_no_sub_classes(self): + class A(Document): + x = IntField() + y = IntField() + + meta = {'allow_inheritance': True} + + class B(A): + z = IntField() + + class C(B): + zz = IntField() + + A.drop_collection() + + A(x=10, y=20).save() + A(x=15, y=30).save() + B(x=20, y=40).save() + B(x=30, y=50).save() + C(x=40, y=60).save() + + self.assertEqual(A.objects.no_sub_classes().count(), 2) + self.assertEqual(A.objects.count(), 5) + + self.assertEqual(B.objects.no_sub_classes().count(), 2) + self.assertEqual(B.objects.count(), 3) + + self.assertEqual(C.objects.no_sub_classes().count(), 1) + self.assertEqual(C.objects.count(), 1) + + for obj in A.objects.no_sub_classes(): + self.assertEqual(obj.__class__, A) + + for obj in B.objects.no_sub_classes(): + self.assertEqual(obj.__class__, B) + + for obj in C.objects.no_sub_classes(): + self.assertEqual(obj.__class__, C) + + def test_query_reference_to_custom_pk_doc(self): + + class A(Document): + id = StringField(unique=True, primary_key=True) + + class B(Document): + a = ReferenceField(A) + + A.drop_collection() + B.drop_collection() + + a = A.objects.create(id='custom_id') + + b = B.objects.create(a=a) + + self.assertEqual(B.objects.count(), 1) + self.assertEqual(B.objects.get(a=a).a, a) + self.assertEqual(B.objects.get(a=a.id).a, a) + + def test_cls_query_in_subclassed_docs(self): + + class Animal(Document): + name = StringField() + + meta = { + 
'allow_inheritance': True + } + + class Dog(Animal): + pass + + class Cat(Animal): + pass + + self.assertEqual(Animal.objects(name='Charlie')._query, { + 'name': 'Charlie', + '_cls': { '$in': ('Animal', 'Animal.Dog', 'Animal.Cat') } + }) + self.assertEqual(Dog.objects(name='Charlie')._query, { + 'name': 'Charlie', + '_cls': 'Animal.Dog' + }) + self.assertEqual(Cat.objects(name='Charlie')._query, { + 'name': 'Charlie', + '_cls': 'Animal.Cat' + }) + + def test_can_have_field_same_name_as_query_operator(self): + + class Size(Document): + name = StringField() + + class Example(Document): + size = ReferenceField(Size) + + Size.drop_collection() + Example.drop_collection() + + instance_size = Size(name="Large").save() + Example(size=instance_size).save() + + self.assertEqual(Example.objects(size=instance_size).count(), 1) + self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) - self.assertEqual(outer_count, 7) # outer loop should be executed seven times total - self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index d38cbfd..d2e8b78 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -1,4 +1,3 @@ -from __future__ import with_statement import sys sys.path[0:0] = [""] @@ -32,6 +31,31 @@ class TransformTest(unittest.TestCase): self.assertEqual(transform.query(name__exists=True), {'name': {'$exists': True}}) + def test_transform_update(self): + class DicDoc(Document): + dictField = DictField() + + class Doc(Document): + pass + + DicDoc.drop_collection() + Doc.drop_collection() + + doc = Doc().save() + dic_doc = DicDoc().save() + + for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): + update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) + self.assertTrue(isinstance(update[v]["dictField.test"], dict)) + + # Update special cases + update = transform.update(DicDoc, unset__dictField__test=doc) + self.assertEqual(update["$unset"]["dictField.test"], 1) + + update = transform.update(DicDoc, pull__dictField__test=doc) + self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) + + def test_query_field_name(self): """Ensure that the correct field name is used when querying. 
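test_cls_query_in_subclassed_docs above documents how queries are scoped when ``allow_inheritance`` is on: the base class matches itself and every subclass through a ``$in`` on ``_cls``, while a subclass query pins ``_cls`` to one value. A minimal sketch of the same behaviour (inspecting the internal ``_query`` dict, as the test does)::

    from mongoengine import Document, StringField, connect

    connect('example_db')  # assumed database name

    class Animal(Document):
        name = StringField()
        meta = {'allow_inheritance': True}

    class Dog(Animal):
        pass

    # Base-class queries cover the whole hierarchy,
    # e.g. {'name': 'Charlie', '_cls': {'$in': ('Animal', 'Animal.Dog')}}
    print(Animal.objects(name='Charlie')._query)

    # Subclass queries are pinned to a single _cls value,
    # e.g. {'name': 'Charlie', '_cls': 'Animal.Dog'}
    print(Dog.objects(name='Charlie')._query)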
""" @@ -53,14 +77,14 @@ class TransformTest(unittest.TestCase): BlogPost.objects(title=data['title'])._query) self.assertFalse('title' in BlogPost.objects(title=data['title'])._query) - self.assertEqual(len(BlogPost.objects(title=data['title'])), 1) + self.assertEqual(BlogPost.objects(title=data['title']).count(), 1) self.assertTrue('_id' in BlogPost.objects(pk=post.id)._query) - self.assertEqual(len(BlogPost.objects(pk=post.id)), 1) + self.assertEqual(BlogPost.objects(pk=post.id).count(), 1) self.assertTrue('postComments.commentContent' in BlogPost.objects(comments__content='test')._query) - self.assertEqual(len(BlogPost.objects(comments__content='test')), 1) + self.assertEqual(BlogPost.objects(comments__content='test').count(), 1) BlogPost.drop_collection() @@ -79,7 +103,7 @@ class TransformTest(unittest.TestCase): self.assertTrue('_id' in BlogPost.objects(pk=data['title'])._query) self.assertTrue('_id' in BlogPost.objects(title=data['title'])._query) - self.assertEqual(len(BlogPost.objects(pk=data['title'])), 1) + self.assertEqual(BlogPost.objects(pk=data['title']).count(), 1) BlogPost.drop_collection() diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py index 98815db..0bb6f69 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/visitor.py @@ -1,4 +1,3 @@ -from __future__ import with_statement import sys sys.path[0:0] = [""] @@ -69,11 +68,11 @@ class QTest(unittest.TestCase): x = IntField() y = StringField() - # Check than an error is raised when conflicting queries are anded - def invalid_combination(): - query = Q(x__lt=7) & Q(x__lt=3) - query.to_query(TestDoc) - self.assertRaises(InvalidQueryError, invalid_combination) + query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) + self.assertEqual(query, {'$and': [{'x': {'$lt': 7}}, {'x': {'$lt': 3}}]}) + + query = (Q(y="a") & Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) + self.assertEqual(query, {'$and': [{'y': "a"}, {'x': {'$lt': 7}}, {'x': {'$lt': 3}}]}) # Check normal cases work without an error query = Q(x__lt=7) & Q(x__gt=3) @@ -268,8 +267,8 @@ class QTest(unittest.TestCase): self.Person(name='user3', age=30).save() self.Person(name='user4', age=40).save() - self.assertEqual(len(self.Person.objects(Q(age__in=[20]))), 2) - self.assertEqual(len(self.Person.objects(Q(age__in=[20, 30]))), 3) + self.assertEqual(self.Person.objects(Q(age__in=[20])).count(), 2) + self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) # Test invalid query objs def wrong_query_objs(): @@ -311,8 +310,8 @@ class QTest(unittest.TestCase): BlogPost(tags=['python', 'mongo']).save() BlogPost(tags=['python']).save() - self.assertEqual(len(BlogPost.objects(Q(tags='mongo'))), 1) - self.assertEqual(len(BlogPost.objects(Q(tags='python'))), 2) + self.assertEqual(BlogPost.objects(Q(tags='mongo')).count(), 1) + self.assertEqual(BlogPost.objects(Q(tags='python')).count(), 2) BlogPost.drop_collection() @@ -326,10 +325,26 @@ class QTest(unittest.TestCase): pk = ObjectId() User(email='example@example.com', pk=pk).save() - self.assertEqual(1, User.objects.filter( - Q(email='example@example.com') | - Q(name='John Doe') - ).limit(2).filter(pk=pk).count()) + self.assertEqual(1, User.objects.filter(Q(email='example@example.com') | + Q(name='John Doe')).limit(2).filter(pk=pk).count()) + + def test_chained_q_or_filtering(self): + + class Post(EmbeddedDocument): + name = StringField(required=True) + + class Item(Document): + postables = ListField(EmbeddedDocumentField(Post)) + + Item.drop_collection() + + Item(postables=[Post(name="a"), 
Post(name="b")]).save() + Item(postables=[Post(name="a"), Post(name="c")]).save() + Item(postables=[Post(name="a"), Post(name="b"), Post(name="c")]).save() + + self.assertEqual(Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2) + self.assertEqual(Item.objects.filter(postables__name="a").filter(postables__name="b").count(), 2) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_connection.py b/tests/test_connection.py index 5b9743d..62d795c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,3 @@ -from __future__ import with_statement import sys sys.path[0:0] = [""] import unittest @@ -10,7 +9,6 @@ from bson.tz_util import utc from mongoengine import * import mongoengine.connection from mongoengine.connection import get_db, get_connection, ConnectionError -from mongoengine.context_managers import switch_db class ConnectionTest(unittest.TestCase): @@ -26,7 +24,7 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetest') conn = get_connection() - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) db = get_db() self.assertTrue(isinstance(db, pymongo.database.Database)) @@ -34,7 +32,7 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetest2', alias='testdb') conn = get_connection('testdb') - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) def test_connect_uri(self): """Ensure that the connect() method works properly with uri's @@ -52,12 +50,41 @@ class ConnectionTest(unittest.TestCase): connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') conn = get_connection() - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) db = get_db() self.assertTrue(isinstance(db, pymongo.database.Database)) self.assertEqual(db.name, 'mongoenginetest') + c.admin.system.users.remove({}) + c.mongoenginetest.system.users.remove({}) + + def test_connect_uri_without_db(self): + """Ensure that the connect() method works properly with uri's + without database_name + """ + c = connect(db='mongoenginetest', alias='admin') + c.admin.system.users.remove({}) + c.mongoenginetest.system.users.remove({}) + + c.admin.add_user("admin", "password") + c.admin.authenticate("admin", "password") + c.mongoenginetest.add_user("username", "password") + + self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') + + connect("mongoenginetest", host='mongodb://localhost/') + + conn = get_connection() + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) + + db = get_db() + self.assertTrue(isinstance(db, pymongo.database.Database)) + self.assertEqual(db.name, 'mongoenginetest') + + c.admin.system.users.remove({}) + c.mongoenginetest.system.users.remove({}) + def test_register_connection(self): """Ensure that connections with different aliases may be registered. 
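The visitor.py hunks above change the expectation for overlapping Q objects: instead of raising InvalidQueryError, conditions on the same field now compile to an explicit ``$and``. A short sketch of combining Q objects, assuming the same TestDoc layout as the test::

    from mongoengine import Document, IntField, StringField, Q, connect

    connect('example_db')  # assumed database name

    class TestDoc(Document):
        x = IntField()
        y = StringField()

    # Conflicting conditions on the same field are kept as an explicit $and
    query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc)
    assert query == {'$and': [{'x': {'$lt': 7}}, {'x': {'$lt': 3}}]}

    # Additional Q objects simply extend the $and list in order
    query = (Q(y="a") & Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc)
    assert query == {'$and': [{'y': "a"}, {'x': {'$lt': 7}}, {'x': {'$lt': 3}}]}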
""" @@ -65,7 +92,7 @@ class ConnectionTest(unittest.TestCase): self.assertRaises(ConnectionError, get_connection) conn = get_connection('testdb') - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) db = get_db('testdb') self.assertTrue(isinstance(db, pymongo.database.Database)) diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index eef63be..c201a5f 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -1,4 +1,3 @@ -from __future__ import with_statement import sys sys.path[0:0] = [""] import unittest @@ -6,7 +5,8 @@ import unittest from mongoengine import * from mongoengine.connection import get_db from mongoengine.context_managers import (switch_db, switch_collection, - no_dereference, query_counter) + no_sub_classes, no_dereference, + query_counter) class ContextManagersTest(unittest.TestCase): @@ -139,6 +139,54 @@ class ContextManagersTest(unittest.TestCase): self.assertTrue(isinstance(group.ref, User)) self.assertTrue(isinstance(group.generic, User)) + def test_no_sub_classes(self): + class A(Document): + x = IntField() + y = IntField() + + meta = {'allow_inheritance': True} + + class B(A): + z = IntField() + + class C(B): + zz = IntField() + + A.drop_collection() + + A(x=10, y=20).save() + A(x=15, y=30).save() + B(x=20, y=40).save() + B(x=30, y=50).save() + C(x=40, y=60).save() + + self.assertEqual(A.objects.count(), 5) + self.assertEqual(B.objects.count(), 3) + self.assertEqual(C.objects.count(), 1) + + with no_sub_classes(A) as A: + self.assertEqual(A.objects.count(), 2) + + for obj in A.objects: + self.assertEqual(obj.__class__, A) + + with no_sub_classes(B) as B: + self.assertEqual(B.objects.count(), 2) + + for obj in B.objects: + self.assertEqual(obj.__class__, B) + + with no_sub_classes(C) as C: + self.assertEqual(C.objects.count(), 1) + + for obj in C.objects: + self.assertEqual(obj.__class__, C) + + # Confirm context manager exit correctly + self.assertEqual(A.objects.count(), 5) + self.assertEqual(B.objects.count(), 3) + self.assertEqual(C.objects.count(), 1) + def test_query_counter(self): connect('mongoenginetest') db = get_db() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index ef5a10d..6f2664a 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -from __future__ import with_statement import sys sys.path[0:0] = [""] import unittest @@ -1122,37 +1121,32 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) - def test_tuples_as_tuples(self): - """ - Ensure that tuples remain tuples when they are - inside a ComplexBaseField - """ - from mongoengine.base import BaseField + def test_objectid_reference_across_databases(self): + # mongoenginetest - Is default connection alias from setUp() + # Register Aliases + register_connection('testdb-1', 'mongoenginetest2') - class EnumField(BaseField): + class User(Document): + name = StringField() + meta = {"db_alias": "testdb-1"} - def __init__(self, **kwargs): - super(EnumField, self).__init__(**kwargs) + class Book(Document): + name = StringField() + author = ReferenceField(User) - def to_mongo(self, value): - return value + # Drops + User.drop_collection() + Book.drop_collection() - def to_python(self, value): - return tuple(value) + user = User(name="Ross").save() + Book(name="MongoEngine for pros", author=user).save() - class TestDoc(Document): - items = ListField(EnumField()) + # Can't use query_counter 
across databases - so test the _data object + book = Book.objects.first() + self.assertFalse(isinstance(book._data['author'], User)) - TestDoc.drop_collection() - tuples = [(100, 'Testing')] - doc = TestDoc() - doc.items = tuples - doc.save() - x = TestDoc.objects().get() - self.assertTrue(x is not None) - self.assertTrue(len(x.items) == 1) - self.assertTrue(tuple(x.items[0]) in tuples) - self.assertTrue(x.items[0] in tuples) + book.select_related() + self.assertTrue(isinstance(book._data['author'], User)) def test_non_ascii_pk(self): """ @@ -1177,6 +1171,30 @@ class FieldTest(unittest.TestCase): self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands])) + def test_dereferencing_embedded_listfield_referencefield(self): + class Tag(Document): + meta = {'collection': 'tags'} + name = StringField() + + class Post(EmbeddedDocument): + body = StringField() + tags = ListField(ReferenceField("Tag", dbref=True)) + + class Page(Document): + meta = {'collection': 'pages'} + tags = ListField(ReferenceField("Tag", dbref=True)) + posts = ListField(EmbeddedDocumentField(Post)) + + Tag.drop_collection() + Page.drop_collection() + + tag = Tag(name='test').save() + post = Post(body='test body', tags=[tag]) + Page(tags=[tag], posts=[post]).save() + + page = Page.objects.first() + self.assertEqual(page.tags[0], page.posts[0].tags[0]) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_django.py b/tests/test_django.py index dceeba2..46568ac 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -1,39 +1,45 @@ -from __future__ import with_statement import sys sys.path[0:0] = [""] import unittest from nose.plugins.skip import SkipTest -from mongoengine.python_support import PY3 from mongoengine import * + +from mongoengine.django.shortcuts import get_document_or_404 + +from django.http import Http404 +from django.template import Context, Template +from django.conf import settings +from django.core.paginator import Paginator + +settings.configure( + USE_TZ=True, + INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'), + AUTH_USER_MODEL=('mongo_auth.MongoUser'), +) + try: - from mongoengine.django.shortcuts import get_document_or_404 - - from django.http import Http404 - from django.template import Context, Template - from django.conf import settings - from django.core.paginator import Paginator - - settings.configure(USE_TZ=True) - - from django.contrib.sessions.tests import SessionTestsMixin - from mongoengine.django.sessions import SessionStore, MongoSession -except Exception, err: - if PY3: - SessionTestsMixin = type # dummy value so no error - SessionStore = None # dummy value so no error - else: - raise err - - + from django.contrib.auth import authenticate, get_user_model + from mongoengine.django.auth import User + from mongoengine.django.mongo_auth.models import ( + MongoUser, + MongoUserManager, + get_user_document, + ) + DJ15 = True +except Exception: + DJ15 = False +from django.contrib.sessions.tests import SessionTestsMixin +from mongoengine.django.sessions import SessionStore, MongoSession from datetime import tzinfo, timedelta ZERO = timedelta(0) + class FixedOffset(tzinfo): """Fixed offset in minutes east from UTC.""" def __init__(self, offset, name): - self.__offset = timedelta(minutes = offset) + self.__offset = timedelta(minutes=offset) self.__name = name def utcoffset(self, dt): @@ -60,8 +66,6 @@ def activate_timezone(tz): class QuerySetTest(unittest.TestCase): def setUp(self): - if PY3: - raise SkipTest('django does not have 
Python 3 support') connect(db='mongoenginetest') class Person(Document): @@ -140,32 +144,91 @@ class QuerySetTest(unittest.TestCase): # Try iterating the same queryset twice, nested, in a Django template. names = ['A', 'B', 'C', 'D'] - class User(Document): + class CustomUser(Document): name = StringField() def __unicode__(self): return self.name - User.drop_collection() + CustomUser.drop_collection() for name in names: - User(name=name).save() + CustomUser(name=name).save() - users = User.objects.all().order_by('name') + users = CustomUser.objects.all().order_by('name') template = Template("{% for user in users %}{{ user.name }}{% ifequal forloop.counter 2 %} {% for inner_user in users %}{{ inner_user.name }}{% endfor %} {% endifequal %}{% endfor %}") rendered = template.render(Context({'users': users})) self.assertEqual(rendered, 'AB ABCD CD') + def test_filter(self): + """Ensure that a queryset and filters work as expected + """ + + class Note(Document): + text = StringField() + + Note.drop_collection() + + for i in xrange(1, 101): + Note(name="Note: %s" % i).save() + + # Check the count + self.assertEqual(Note.objects.count(), 100) + + # Get the first 10 and confirm + notes = Note.objects[:10] + self.assertEqual(notes.count(), 10) + + # Test djangos template filters + # self.assertEqual(length(notes), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + + # Test with skip + notes = Note.objects.skip(90) + self.assertEqual(notes.count(), 10) + + # Test djangos template filters + self.assertEqual(notes.count(), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + + # Test with limit + notes = Note.objects.skip(90) + self.assertEqual(notes.count(), 10) + + # Test djangos template filters + self.assertEqual(notes.count(), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + + # Test with skip and limit + notes = Note.objects.skip(10).limit(10) + + # Test djangos template filters + self.assertEqual(notes.count(), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + + class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): backend = SessionStore def setUp(self): - if PY3: - raise SkipTest('django does not have Python 3 support') connect(db='mongoenginetest') MongoSession.drop_collection() super(MongoDBSessionTest, self).setUp() + def assertIn(self, first, second, msg=None): + self.assertTrue(first in second, msg) + + def assertNotIn(self, first, second, msg=None): + self.assertFalse(first in second, msg) + def test_first_save(self): session = SessionStore() session['test'] = True @@ -176,7 +239,7 @@ class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): activate_timezone(FixedOffset(60, 'UTC+1')) # create and save new session session = SessionStore() - session.set_expiry(600) # expire in 600 seconds + session.set_expiry(600) # expire in 600 seconds session['test_expire'] = True session.save() # reload session with key @@ -184,5 +247,50 @@ class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): session = SessionStore(key) self.assertTrue('test_expire' in session, 'Session has expired before it is expected') + +class MongoAuthTest(unittest.TestCase): + user_data = { + 'username': 'user', + 'email': 'user@example.com', + 'password': 'test', + } + + def setUp(self): + if not DJ15: + raise SkipTest('mongo_auth requires Django 1.5') + 
connect(db='mongoenginetest') + User.drop_collection() + super(MongoAuthTest, self).setUp() + + def test_get_user_model(self): + self.assertEqual(get_user_model(), MongoUser) + + def test_get_user_document(self): + self.assertEqual(get_user_document(), User) + + def test_user_manager(self): + manager = get_user_model()._default_manager + self.assertTrue(isinstance(manager, MongoUserManager)) + + def test_user_manager_exception(self): + manager = get_user_model()._default_manager + self.assertRaises(MongoUser.DoesNotExist, manager.get, + username='not found') + + def test_create_user(self): + manager = get_user_model()._default_manager + user = manager.create_user(**self.user_data) + self.assertTrue(isinstance(user, User)) + db_user = User.objects.get(username='user') + self.assertEqual(user.id, db_user.id) + + def test_authenticate(self): + get_user_model()._default_manager.create_user(**self.user_data) + user = authenticate(username='user', password='fail') + self.assertEqual(None, user) + user = authenticate(username='user', password='test') + db_user = User.objects.get(username='user') + self.assertEqual(user.id, db_user.id) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_jinja.py b/tests/test_jinja.py new file mode 100644 index 0000000..0449f86 --- /dev/null +++ b/tests/test_jinja.py @@ -0,0 +1,47 @@ +import sys +sys.path[0:0] = [""] + +import unittest + +from mongoengine import * + +import jinja2 + + +class TemplateFilterTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + def test_jinja2(self): + env = jinja2.Environment() + + class TestData(Document): + title = StringField() + description = StringField() + + TestData.drop_collection() + + examples = [('A', '1'), + ('B', '2'), + ('C', '3')] + + for title, description in examples: + TestData(title=title, description=description).save() + + tmpl = """ +{%- for record in content -%} + {%- if loop.first -%}{ {%- endif -%} + "{{ record.title }}": "{{ record.description }}" + {%- if loop.last -%} }{%- else -%},{% endif -%} +{%- endfor -%} +""" + ctx = {'content': TestData.objects} + template = env.from_string(tmpl) + rendered = template.render(**ctx) + + self.assertEqual('{"A": "1","B": "2","C": "3"}', rendered) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_signals.py b/tests/test_signals.py index fc638cf..50e5e6b 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -43,6 +43,15 @@ class SignalTests(unittest.TestCase): def pre_save(cls, sender, document, **kwargs): signal_output.append('pre_save signal, %s' % document) + @classmethod + def pre_save_post_validation(cls, sender, document, **kwargs): + signal_output.append('pre_save_post_validation signal, %s' % document) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + @classmethod def post_save(cls, sender, document, **kwargs): signal_output.append('post_save signal, %s' % document) @@ -72,44 +81,25 @@ class SignalTests(unittest.TestCase): else: signal_output.append('Not loaded') self.Author = Author + Author.drop_collection() class Another(Document): + name = StringField() def __unicode__(self): return self.name - @classmethod - def pre_init(cls, sender, document, **kwargs): - signal_output.append('pre_init Another signal, %s' % cls.__name__) - signal_output.append(str(kwargs['values'])) - - @classmethod - def post_init(cls, sender, document, **kwargs): - signal_output.append('post_init Another signal, %s' % 
document) - - @classmethod - def pre_save(cls, sender, document, **kwargs): - signal_output.append('pre_save Another signal, %s' % document) - - @classmethod - def post_save(cls, sender, document, **kwargs): - signal_output.append('post_save Another signal, %s' % document) - if 'created' in kwargs: - if kwargs['created']: - signal_output.append('Is created') - else: - signal_output.append('Is updated') - @classmethod def pre_delete(cls, sender, document, **kwargs): - signal_output.append('pre_delete Another signal, %s' % document) + signal_output.append('pre_delete signal, %s' % document) @classmethod def post_delete(cls, sender, document, **kwargs): - signal_output.append('post_delete Another signal, %s' % document) + signal_output.append('post_delete signal, %s' % document) self.Another = Another + Another.drop_collection() class ExplicitId(Document): id = IntField(primary_key=True) @@ -123,13 +113,15 @@ class SignalTests(unittest.TestCase): signal_output.append('Is updated') self.ExplicitId = ExplicitId - self.ExplicitId.objects.delete() + ExplicitId.drop_collection() + # Save up the number of connected signals so that we can check at the # end that all the signals we register get properly unregistered self.pre_signals = ( len(signals.pre_init.receivers), len(signals.post_init.receivers), len(signals.pre_save.receivers), + len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -140,16 +132,13 @@ class SignalTests(unittest.TestCase): signals.pre_init.connect(Author.pre_init, sender=Author) signals.post_init.connect(Author.post_init, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author) + signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author) signals.post_save.connect(Author.post_save, sender=Author) signals.pre_delete.connect(Author.pre_delete, sender=Author) signals.post_delete.connect(Author.post_delete, sender=Author) signals.pre_bulk_insert.connect(Author.pre_bulk_insert, sender=Author) signals.post_bulk_insert.connect(Author.post_bulk_insert, sender=Author) - signals.pre_init.connect(Another.pre_init, sender=Another) - signals.post_init.connect(Another.post_init, sender=Another) - signals.pre_save.connect(Another.pre_save, sender=Another) - signals.post_save.connect(Another.post_save, sender=Another) signals.pre_delete.connect(Another.pre_delete, sender=Another) signals.post_delete.connect(Another.post_delete, sender=Another) @@ -161,16 +150,13 @@ class SignalTests(unittest.TestCase): signals.post_delete.disconnect(self.Author.post_delete) signals.pre_delete.disconnect(self.Author.pre_delete) signals.post_save.disconnect(self.Author.post_save) + signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation) signals.pre_save.disconnect(self.Author.pre_save) signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) - signals.pre_init.disconnect(self.Another.pre_init) - signals.post_init.disconnect(self.Another.post_init) signals.post_delete.disconnect(self.Another.post_delete) signals.pre_delete.disconnect(self.Another.pre_delete) - signals.post_save.disconnect(self.Another.post_save) - signals.pre_save.disconnect(self.Another.pre_save) signals.post_save.disconnect(self.ExplicitId.post_save) @@ -179,6 +165,7 @@ class SignalTests(unittest.TestCase): len(signals.pre_init.receivers), len(signals.post_init.receivers), 
len(signals.pre_save.receivers), + len(signals.pre_save_post_validation.receivers), len(signals.post_save.receivers), len(signals.pre_delete.receivers), len(signals.post_delete.receivers), @@ -213,6 +200,8 @@ class SignalTests(unittest.TestCase): a1 = self.Author(name='Bill Shakespeare') self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, Bill Shakespeare", + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", "post_save signal, Bill Shakespeare", "Is created" ]) @@ -221,6 +210,8 @@ class SignalTests(unittest.TestCase): a1.name = 'William Shakespeare' self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, William Shakespeare", + "pre_save_post_validation signal, William Shakespeare", + "Is updated", "post_save signal, William Shakespeare", "Is updated" ]) @@ -249,7 +240,14 @@ class SignalTests(unittest.TestCase): "Not loaded", ]) - self.Author.objects.delete() + def test_queryset_delete_signals(self): + """ Queryset delete should throw some signals. """ + + self.Another(name='Bill Shakespeare').save() + self.assertEqual(self.get_signal_output(self.Another.objects.delete), [ + 'pre_delete signal, Bill Shakespeare', + 'post_delete signal, Bill Shakespeare', + ]) def test_signals_with_explicit_doc_ids(self): """ Model saves must have a created flag the first time."""
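The signal hunks above wire up the new ``pre_save_post_validation`` signal, which fires after validation but before the write and receives the same ``created`` flag as ``post_save``. A minimal receiver sketch, assuming an Author document like the one in the tests (the ``example_db`` name is illustrative)::

    from mongoengine import Document, StringField, signals, connect

    connect('example_db')  # assumed database name

    class Author(Document):
        name = StringField()

    def log_post_validation(sender, document, **kwargs):
        # Runs once the document has validated, just before it is written
        state = 'created' if kwargs.get('created') else 'updated'
        print('pre_save_post_validation: %s (%s)' % (document.name, state))

    signals.pre_save_post_validation.connect(log_post_validation, sender=Author)

    Author(name='Bill Shakespeare').save()   # -> "... (created)"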