Merge remote-tracking branch 'upstream/master'

This commit is contained in:
bool.dev 2013-07-18 08:49:02 +05:30
commit e44e72bce3
57 changed files with 5066 additions and 2197 deletions

View File

@ -11,13 +11,15 @@ env:
- PYMONGO=dev DJANGO=1.4.2 - PYMONGO=dev DJANGO=1.4.2
- PYMONGO=2.5 DJANGO=1.5.1 - PYMONGO=2.5 DJANGO=1.5.1
- PYMONGO=2.5 DJANGO=1.4.2 - PYMONGO=2.5 DJANGO=1.4.2
- PYMONGO=2.4.2 DJANGO=1.4.2 - PYMONGO=3.2 DJANGO=1.5.1
- PYMONGO=3.3 DJANGO=1.5.1
install: install:
- if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi
- if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi
- if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install django==$DJANGO --use-mirrors ; true; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install django==$DJANGO --use-mirrors ; true; fi
- if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi
- if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi
- pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b
- python setup.py install - python setup.py install
script: script:
- python setup.py test - python setup.py test

23
AUTHORS
View File

@ -16,8 +16,6 @@ Dervived from the git logs, inevitably incomplete but all of whom and others
have submitted patches, reported bugs and generally helped make MongoEngine have submitted patches, reported bugs and generally helped make MongoEngine
that much better: that much better:
* Harry Marr
* Ross Lawley
* blackbrrr * blackbrrr
* Florian Schlachter * Florian Schlachter
* Vincent Driessen * Vincent Driessen
@ -25,7 +23,7 @@ that much better:
* flosch * flosch
* Deepak Thukral * Deepak Thukral
* Colin Howe * Colin Howe
* Wilson Júnior * Wilson Júnior (https://github.com/wpjunior)
* Alistair Roche * Alistair Roche
* Dan Crosta * Dan Crosta
* Viktor Kerkez * Viktor Kerkez
@ -77,7 +75,7 @@ that much better:
* Adam Parrish * Adam Parrish
* jpfarias * jpfarias
* jonrscott * jonrscott
* Alice Zoë Bevan-McGregor * Alice Zoë Bevan-McGregor (https://github.com/amcgregor/)
* Stephen Young * Stephen Young
* tkloc * tkloc
* aid * aid
@ -157,3 +155,20 @@ that much better:
* Kenneth Falck * Kenneth Falck
* Lukasz Balcerzak * Lukasz Balcerzak
* Nicolas Cortot * Nicolas Cortot
* Alex (https://github.com/kelsta)
* Jin Zhang
* Daniel Axtens
* Leo-Naeka
* Ryan Witt (https://github.com/ryanwitt)
* Jiequan (https://github.com/Jiequan)
* hensom (https://github.com/hensom)
* zhy0216 (https://github.com/zhy0216)
* istinspring (https://github.com/istinspring)
* Massimo Santini (https://github.com/mapio)
* Nigel McNie (https://github.com/nigelmcnie)
* ygbourhis (https://github.com/ygbourhis)
* Bob Dickinson (https://github.com/BobDickinson)
* Michael Bartnett (https://github.com/michaelbartnett)
* Alon Horev (https://github.com/alonho)
* Kelvin Hammond (https://github.com/kelvinhammond)
* Jatin- (https://github.com/jatin-)

View File

@ -20,7 +20,7 @@ post to the `user group <http://groups.google.com/group/mongoengine-users>`
Supported Interpreters Supported Interpreters
---------------------- ----------------------
PyMongo supports CPython 2.5 and newer. Language MongoEngine supports CPython 2.6 and newer. Language
features not supported by all interpreters can not be used. features not supported by all interpreters can not be used.
Please also ensure that your code is properly converted by Please also ensure that your code is properly converted by
`2to3 <http://docs.python.org/library/2to3.html>`_ for Python 3 support. `2to3 <http://docs.python.org/library/2to3.html>`_ for Python 3 support.
@ -46,7 +46,7 @@ General Guidelines
- Write tests and make sure they pass (make sure you have a mongod - Write tests and make sure they pass (make sure you have a mongod
running on the default port, then execute ``python setup.py test`` running on the default port, then execute ``python setup.py test``
from the cmd line to run the test suite). from the cmd line to run the test suite).
- Add yourself to AUTHORS.rst :) - Add yourself to AUTHORS :)
Documentation Documentation
------------- -------------

View File

@ -26,7 +26,7 @@ setup.py install``.
Dependencies Dependencies
============ ============
- pymongo 2.1.1+ - pymongo 2.5+
- sphinx (optional - for documentation generation) - sphinx (optional - for documentation generation)
Examples Examples

View File

@ -2,11 +2,15 @@
* Sphinx stylesheet -- default theme * Sphinx stylesheet -- default theme
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*/ */
@import url("basic.css"); @import url("basic.css");
#changelog p.first {margin-bottom: 0 !important;}
#changelog p {margin-top: 0 !important;
margin-bottom: 0 !important;}
/* -- page layout ----------------------------------------------------------- */ /* -- page layout ----------------------------------------------------------- */
body { body {
font-family: Arial, sans-serif; font-family: Arial, sans-serif;
font-size: 100%; font-size: 100%;
@ -28,18 +32,18 @@ div.bodywrapper {
hr{ hr{
border: 1px solid #B1B4B6; border: 1px solid #B1B4B6;
} }
div.document { div.document {
background-color: #eee; background-color: #eee;
} }
div.body { div.body {
background-color: #ffffff; background-color: #ffffff;
color: #3E4349; color: #3E4349;
padding: 0 30px 30px 30px; padding: 0 30px 30px 30px;
font-size: 0.8em; font-size: 0.8em;
} }
div.footer { div.footer {
color: #555; color: #555;
width: 100%; width: 100%;
@ -47,12 +51,12 @@ div.footer {
text-align: center; text-align: center;
font-size: 75%; font-size: 75%;
} }
div.footer a { div.footer a {
color: #444; color: #444;
text-decoration: underline; text-decoration: underline;
} }
div.related { div.related {
background-color: #6BA81E; background-color: #6BA81E;
line-height: 32px; line-height: 32px;
@ -60,11 +64,11 @@ div.related {
text-shadow: 0px 1px 0 #444; text-shadow: 0px 1px 0 #444;
font-size: 0.80em; font-size: 0.80em;
} }
div.related a { div.related a {
color: #E2F3CC; color: #E2F3CC;
} }
div.sphinxsidebar { div.sphinxsidebar {
font-size: 0.75em; font-size: 0.75em;
line-height: 1.5em; line-height: 1.5em;
@ -73,7 +77,7 @@ div.sphinxsidebar {
div.sphinxsidebarwrapper{ div.sphinxsidebarwrapper{
padding: 20px 0; padding: 20px 0;
} }
div.sphinxsidebar h3, div.sphinxsidebar h3,
div.sphinxsidebar h4 { div.sphinxsidebar h4 {
font-family: Arial, sans-serif; font-family: Arial, sans-serif;
@ -89,30 +93,30 @@ div.sphinxsidebar h4 {
div.sphinxsidebar h4{ div.sphinxsidebar h4{
font-size: 1.1em; font-size: 1.1em;
} }
div.sphinxsidebar h3 a { div.sphinxsidebar h3 a {
color: #444; color: #444;
} }
div.sphinxsidebar p { div.sphinxsidebar p {
color: #888; color: #888;
padding: 5px 20px; padding: 5px 20px;
} }
div.sphinxsidebar p.topless { div.sphinxsidebar p.topless {
} }
div.sphinxsidebar ul { div.sphinxsidebar ul {
margin: 10px 20px; margin: 10px 20px;
padding: 0; padding: 0;
color: #000; color: #000;
} }
div.sphinxsidebar a { div.sphinxsidebar a {
color: #444; color: #444;
} }
div.sphinxsidebar input { div.sphinxsidebar input {
border: 1px solid #ccc; border: 1px solid #ccc;
font-family: sans-serif; font-family: sans-serif;
@ -122,19 +126,19 @@ div.sphinxsidebar input {
div.sphinxsidebar input[type=text]{ div.sphinxsidebar input[type=text]{
margin-left: 20px; margin-left: 20px;
} }
/* -- body styles ----------------------------------------------------------- */ /* -- body styles ----------------------------------------------------------- */
a { a {
color: #005B81; color: #005B81;
text-decoration: none; text-decoration: none;
} }
a:hover { a:hover {
color: #E32E00; color: #E32E00;
text-decoration: underline; text-decoration: underline;
} }
div.body h1, div.body h1,
div.body h2, div.body h2,
div.body h3, div.body h3,
@ -149,30 +153,30 @@ div.body h6 {
padding: 5px 0 5px 10px; padding: 5px 0 5px 10px;
text-shadow: 0px 1px 0 white text-shadow: 0px 1px 0 white
} }
div.body h1 { border-top: 20px solid white; margin-top: 0; font-size: 200%; } div.body h1 { border-top: 20px solid white; margin-top: 0; font-size: 200%; }
div.body h2 { font-size: 150%; background-color: #C8D5E3; } div.body h2 { font-size: 150%; background-color: #C8D5E3; }
div.body h3 { font-size: 120%; background-color: #D8DEE3; } div.body h3 { font-size: 120%; background-color: #D8DEE3; }
div.body h4 { font-size: 110%; background-color: #D8DEE3; } div.body h4 { font-size: 110%; background-color: #D8DEE3; }
div.body h5 { font-size: 100%; background-color: #D8DEE3; } div.body h5 { font-size: 100%; background-color: #D8DEE3; }
div.body h6 { font-size: 100%; background-color: #D8DEE3; } div.body h6 { font-size: 100%; background-color: #D8DEE3; }
a.headerlink { a.headerlink {
color: #c60f0f; color: #c60f0f;
font-size: 0.8em; font-size: 0.8em;
padding: 0 4px 0 4px; padding: 0 4px 0 4px;
text-decoration: none; text-decoration: none;
} }
a.headerlink:hover { a.headerlink:hover {
background-color: #c60f0f; background-color: #c60f0f;
color: white; color: white;
} }
div.body p, div.body dd, div.body li { div.body p, div.body dd, div.body li {
line-height: 1.5em; line-height: 1.5em;
} }
div.admonition p.admonition-title + p { div.admonition p.admonition-title + p {
display: inline; display: inline;
} }
@ -185,29 +189,29 @@ div.note {
background-color: #eee; background-color: #eee;
border: 1px solid #ccc; border: 1px solid #ccc;
} }
div.seealso { div.seealso {
background-color: #ffc; background-color: #ffc;
border: 1px solid #ff6; border: 1px solid #ff6;
} }
div.topic { div.topic {
background-color: #eee; background-color: #eee;
} }
div.warning { div.warning {
background-color: #ffe4e4; background-color: #ffe4e4;
border: 1px solid #f66; border: 1px solid #f66;
} }
p.admonition-title { p.admonition-title {
display: inline; display: inline;
} }
p.admonition-title:after { p.admonition-title:after {
content: ":"; content: ":";
} }
pre { pre {
padding: 10px; padding: 10px;
background-color: White; background-color: White;
@ -219,7 +223,7 @@ pre {
-webkit-box-shadow: 1px 1px 1px #d8d8d8; -webkit-box-shadow: 1px 1px 1px #d8d8d8;
-moz-box-shadow: 1px 1px 1px #d8d8d8; -moz-box-shadow: 1px 1px 1px #d8d8d8;
} }
tt { tt {
background-color: #ecf0f3; background-color: #ecf0f3;
color: #222; color: #222;

View File

@ -49,11 +49,17 @@ Querying
.. automethod:: mongoengine.queryset.QuerySet.__call__ .. automethod:: mongoengine.queryset.QuerySet.__call__
.. autoclass:: mongoengine.queryset.QuerySetNoCache
:members:
.. automethod:: mongoengine.queryset.QuerySetNoCache.__call__
.. autofunction:: mongoengine.queryset.queryset_manager .. autofunction:: mongoengine.queryset.queryset_manager
Fields Fields
====== ======
.. autoclass:: mongoengine.base.fields.BaseField
.. autoclass:: mongoengine.fields.StringField .. autoclass:: mongoengine.fields.StringField
.. autoclass:: mongoengine.fields.URLField .. autoclass:: mongoengine.fields.URLField
.. autoclass:: mongoengine.fields.EmailField .. autoclass:: mongoengine.fields.EmailField
@ -76,11 +82,19 @@ Fields
.. autoclass:: mongoengine.fields.BinaryField .. autoclass:: mongoengine.fields.BinaryField
.. autoclass:: mongoengine.fields.FileField .. autoclass:: mongoengine.fields.FileField
.. autoclass:: mongoengine.fields.ImageField .. autoclass:: mongoengine.fields.ImageField
.. autoclass:: mongoengine.fields.GeoPointField
.. autoclass:: mongoengine.fields.SequenceField .. autoclass:: mongoengine.fields.SequenceField
.. autoclass:: mongoengine.fields.ObjectIdField .. autoclass:: mongoengine.fields.ObjectIdField
.. autoclass:: mongoengine.fields.UUIDField .. autoclass:: mongoengine.fields.UUIDField
.. autoclass:: mongoengine.fields.GeoPointField
.. autoclass:: mongoengine.fields.PointField
.. autoclass:: mongoengine.fields.LineStringField
.. autoclass:: mongoengine.fields.PolygonField
.. autoclass:: mongoengine.fields.GridFSError .. autoclass:: mongoengine.fields.GridFSError
.. autoclass:: mongoengine.fields.GridFSProxy .. autoclass:: mongoengine.fields.GridFSProxy
.. autoclass:: mongoengine.fields.ImageGridFsProxy .. autoclass:: mongoengine.fields.ImageGridFsProxy
.. autoclass:: mongoengine.fields.ImproperlyConfigured .. autoclass:: mongoengine.fields.ImproperlyConfigured
Misc
====
.. autofunction:: mongoengine.common._import_class

View File

@ -2,8 +2,75 @@
Changelog Changelog
========= =========
Changes in 0.8.X Changes in 0.8.3
================ ================
- Fixed EmbeddedDocuments with `id` also storing `_id` (#402)
- Added get_proxy_object helper to filefields (#391)
- Added QuerySetNoCache and QuerySet.no_cache() for lower memory consumption (#365)
- Fixed sum and average mapreduce dot notation support (#375, #376, #393)
- Fixed as_pymongo to return the id (#386)
- Document.select_related() now respects `db_alias` (#377)
- Reload uses shard_key if applicable (#384)
- Dynamic fields are ordered based on creation and stored in _fields_ordered (#396)
**Potential breaking change:** http://docs.mongoengine.org/en/latest/upgrade.html#to-0-8-3
- Fixed pickling dynamic documents `_dynamic_fields` (#387)
- Fixed ListField setslice and delslice dirty tracking (#390)
- Added Django 1.5 PY3 support (#392)
- Added match ($elemMatch) support for EmbeddedDocuments (#379)
- Fixed weakref being valid after reload (#374)
- Fixed queryset.get() respecting no_dereference (#373)
- Added full_result kwarg to update (#380)
Changes in 0.8.2
================
- Added compare_indexes helper (#361)
- Fixed cascading saves which weren't turned off as planned (#291)
- Fixed Datastructures so instances are a Document or EmbeddedDocument (#363)
- Improved cascading saves write performance (#361)
- Fixed ambiguity and differing behaviour regarding field defaults (#349)
- ImageFields now include PIL error messages if invalid error (#353)
- Added lock when calling doc.Delete() for when signals have no sender (#350)
- Reload forces read preference to be PRIMARY (#355)
- Querysets are now lest restrictive when querying duplicate fields (#332, #333)
- FileField now honouring db_alias (#341)
- Removed customised __set__ change tracking in ComplexBaseField (#344)
- Removed unused var in _get_changed_fields (#347)
- Added pre_save_post_validation signal (#345)
- DateTimeField now auto converts valid datetime isostrings into dates (#343)
- DateTimeField now uses dateutil for parsing if available (#343)
- Fixed Doc.objects(read_preference=X) not setting read preference (#352)
- Django session ttl index expiry fixed (#329)
- Fixed pickle.loads (#342)
- Documentation fixes
Changes in 0.8.1
================
- Fixed Python 2.6 django auth importlib issue (#326)
- Fixed pickle unsaved document regression (#327)
Changes in 0.8.0
================
- Fixed querying ReferenceField custom_id (#317)
- Fixed pickle issues with collections (#316)
- Added `get_next_value` preview for SequenceFields (#319)
- Added no_sub_classes context manager and queryset helper (#312)
- Querysets now utilises a local cache
- Changed __len__ behavour in the queryset (#247, #311)
- Fixed querying string versions of ObjectIds issue with ReferenceField (#307)
- Added $setOnInsert support for upserts (#308)
- Upserts now possible with just query parameters (#309)
- Upserting is the only way to ensure docs are saved correctly (#306)
- Fixed register_delete_rule inheritance issue
- Fix cloning of sliced querysets (#303)
- Fixed update_one write concern (#302)
- Updated minimum requirement for pymongo to 2.5
- Add support for new geojson fields, indexes and queries (#299)
- If values cant be compared mark as changed (#287)
- Ensure as_pymongo() and to_json honour only() and exclude() (#293)
- Document serialization uses field order to ensure a strict order is set (#296) - Document serialization uses field order to ensure a strict order is set (#296)
- DecimalField now stores as float not string (#289) - DecimalField now stores as float not string (#289)
- UUIDField now stores as a binary by default (#292) - UUIDField now stores as a binary by default (#292)
@ -14,7 +81,6 @@ Changes in 0.8.X
- Added SequenceField.set_next_value(value) helper (#159) - Added SequenceField.set_next_value(value) helper (#159)
- Updated .only() behaviour - now like exclude it is chainable (#202) - Updated .only() behaviour - now like exclude it is chainable (#202)
- Added with_limit_and_skip support to count() (#235) - Added with_limit_and_skip support to count() (#235)
- Removed __len__ from queryset (#247)
- Objects queryset manager now inherited (#256) - Objects queryset manager now inherited (#256)
- Updated connection to use MongoClient (#262, #274) - Updated connection to use MongoClient (#262, #274)
- Fixed db_alias and inherited Documents (#143) - Fixed db_alias and inherited Documents (#143)

View File

@ -132,7 +132,11 @@ html_theme_path = ['_themes']
html_use_smartypants = True html_use_smartypants = True
# Custom sidebar templates, maps document names to template names. # Custom sidebar templates, maps document names to template names.
#html_sidebars = {} html_sidebars = {
'index': ['globaltoc.html', 'searchbox.html'],
'**': ['localtoc.html', 'relations.html', 'searchbox.html']
}
# Additional templates that should be rendered to pages, maps page names to # Additional templates that should be rendered to pages, maps page names to
# template names. # template names.

View File

@ -1,8 +1,8 @@
============================= ==============
Using MongoEngine with Django Django Support
============================= ==============
.. note:: Updated to support Django 1.4 .. note:: Updated to support Django 1.5
Connecting Connecting
========== ==========
@ -27,9 +27,9 @@ MongoEngine includes a Django authentication backend, which uses MongoDB. The
:class:`~mongoengine.Document`, but implements most of the methods and :class:`~mongoengine.Document`, but implements most of the methods and
attributes that the standard Django :class:`User` model does - so the two are attributes that the standard Django :class:`User` model does - so the two are
moderately compatible. Using this backend will allow you to store users in moderately compatible. Using this backend will allow you to store users in
MongoDB but still use many of the Django authentication infrastucture (such as MongoDB but still use many of the Django authentication infrastructure (such as
the :func:`login_required` decorator and the :func:`authenticate` function). To the :func:`login_required` decorator and the :func:`authenticate` function). To
enable the MongoEngine auth backend, add the following to you **settings.py** enable the MongoEngine auth backend, add the following to your **settings.py**
file:: file::
AUTHENTICATION_BACKENDS = ( AUTHENTICATION_BACKENDS = (
@ -46,7 +46,7 @@ Custom User model
================= =================
Django 1.5 introduced `Custom user Models Django 1.5 introduced `Custom user Models
<https://docs.djangoproject.com/en/dev/topics/auth/customizing/#auth-custom-user>` <https://docs.djangoproject.com/en/dev/topics/auth/customizing/#auth-custom-user>`
which can be used as an alternative the Mongoengine authentication backend. which can be used as an alternative to the MongoEngine authentication backend.
The main advantage of this option is that other components relying on The main advantage of this option is that other components relying on
:mod:`django.contrib.auth` and supporting the new swappable user model are more :mod:`django.contrib.auth` and supporting the new swappable user model are more
@ -82,16 +82,16 @@ Sessions
======== ========
Django allows the use of different backend stores for its sessions. MongoEngine Django allows the use of different backend stores for its sessions. MongoEngine
provides a MongoDB-based session backend for Django, which allows you to use provides a MongoDB-based session backend for Django, which allows you to use
sessions in you Django application with just MongoDB. To enable the MongoEngine sessions in your Django application with just MongoDB. To enable the MongoEngine
session backend, ensure that your settings module has session backend, ensure that your settings module has
``'django.contrib.sessions.middleware.SessionMiddleware'`` in the ``'django.contrib.sessions.middleware.SessionMiddleware'`` in the
``MIDDLEWARE_CLASSES`` field and ``'django.contrib.sessions'`` in your ``MIDDLEWARE_CLASSES`` field and ``'django.contrib.sessions'`` in your
``INSTALLED_APPS``. From there, all you need to do is add the following line ``INSTALLED_APPS``. From there, all you need to do is add the following line
into you settings module:: into your settings module::
SESSION_ENGINE = 'mongoengine.django.sessions' SESSION_ENGINE = 'mongoengine.django.sessions'
Django provides session cookie, which expires after ```SESSION_COOKIE_AGE``` seconds, but doesnt delete cookie at sessions backend, so ``'mongoengine.django.sessions'`` supports `mongodb TTL Django provides session cookie, which expires after ```SESSION_COOKIE_AGE``` seconds, but doesn't delete cookie at sessions backend, so ``'mongoengine.django.sessions'`` supports `mongodb TTL
<http://docs.mongodb.org/manual/tutorial/expire-data/>`_. <http://docs.mongodb.org/manual/tutorial/expire-data/>`_.
.. versionadded:: 0.2.1 .. versionadded:: 0.2.1

View File

@ -36,7 +36,7 @@ MongoEngine supports :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetCl
to use them please use a URI style connection and provide the `replicaSet` name in the to use them please use a URI style connection and provide the `replicaSet` name in the
connection kwargs. connection kwargs.
Read preferences are supported throught the connection or via individual Read preferences are supported through the connection or via individual
queries by passing the read_preference :: queries by passing the read_preference ::
Bar.objects().read_preference(ReadPreference.PRIMARY) Bar.objects().read_preference(ReadPreference.PRIMARY)
@ -83,7 +83,7 @@ reasons.
The :class:`~mongoengine.context_managers.switch_db` context manager allows The :class:`~mongoengine.context_managers.switch_db` context manager allows
you to change the database alias for a given class allowing quick and easy you to change the database alias for a given class allowing quick and easy
access to the same User document across databases.eg :: access to the same User document across databases::
from mongoengine.context_managers import switch_db from mongoengine.context_managers import switch_db

View File

@ -54,7 +54,7 @@ be saved ::
There is one caveat on Dynamic Documents: fields cannot start with `_` There is one caveat on Dynamic Documents: fields cannot start with `_`
Dynamic fields are stored in alphabetical order *after* any declared fields. Dynamic fields are stored in creation order *after* any declared fields.
Fields Fields
====== ======
@ -100,9 +100,6 @@ arguments can be set on all fields:
:attr:`db_field` (Default: None) :attr:`db_field` (Default: None)
The MongoDB field name. The MongoDB field name.
:attr:`name` (Default: None)
The mongoengine field name.
:attr:`required` (Default: False) :attr:`required` (Default: False)
If set to True and the field is not set on the document instance, a If set to True and the field is not set on the document instance, a
:class:`~mongoengine.ValidationError` will be raised when the document is :class:`~mongoengine.ValidationError` will be raised when the document is
@ -129,6 +126,7 @@ arguments can be set on all fields:
# instead to just an object # instead to just an object
values = ListField(IntField(), default=[1,2,3]) values = ListField(IntField(), default=[1,2,3])
.. note:: Unsetting a field with a default value will revert back to the default.
:attr:`unique` (Default: False) :attr:`unique` (Default: False)
When True, no documents in the collection will have the same value for this When True, no documents in the collection will have the same value for this
@ -403,7 +401,7 @@ either a single field name, or a list or tuple of field names::
Skipping Document validation on save Skipping Document validation on save
------------------------------------ ------------------------------------
You can also skip the whole document validation process by setting You can also skip the whole document validation process by setting
``validate=False`` when caling the :meth:`~mongoengine.document.Document.save` ``validate=False`` when calling the :meth:`~mongoengine.document.Document.save`
method:: method::
class Recipient(Document): class Recipient(Document):
@ -452,8 +450,8 @@ by creating a list of index specifications called :attr:`indexes` in the
:attr:`~mongoengine.Document.meta` dictionary, where an index specification may :attr:`~mongoengine.Document.meta` dictionary, where an index specification may
either be a single field name, a tuple containing multiple field names, or a either be a single field name, a tuple containing multiple field names, or a
dictionary containing a full index definition. A direction may be specified on dictionary containing a full index definition. A direction may be specified on
fields by prefixing the field name with a **+** or a **-** sign. Note that fields by prefixing the field name with a **+** (for ascending) or a **-** sign
direction only matters on multi-field indexes. :: (for descending). Note that direction only matters on multi-field indexes. ::
class Page(Document): class Page(Document):
title = StringField() title = StringField()
@ -479,6 +477,10 @@ If a dictionary is passed then the following options are available:
:attr:`unique` (Default: False) :attr:`unique` (Default: False)
Whether the index should be unique. Whether the index should be unique.
:attr:`expireAfterSeconds` (Optional)
Allows you to automatically expire data from a collection by setting the
time in seconds to expire the a field.
.. note:: .. note::
Inheritance adds extra fields indices see: :ref:`document-inheritance`. Inheritance adds extra fields indices see: :ref:`document-inheritance`.
@ -489,12 +491,40 @@ Compound Indexes and Indexing sub documents
Compound indexes can be created by adding the Embedded field or dictionary Compound indexes can be created by adding the Embedded field or dictionary
field name to the index definition. field name to the index definition.
Sometimes its more efficient to index parts of Embeedded / dictionary fields, Sometimes its more efficient to index parts of Embedded / dictionary fields,
in this case use 'dot' notation to identify the value to index eg: `rank.title` in this case use 'dot' notation to identify the value to index eg: `rank.title`
Geospatial indexes Geospatial indexes
------------------ ------------------
The best geo index for mongodb is the new "2dsphere", which has an improved
spherical model and provides better performance and more options when querying.
The following fields will explicitly add a "2dsphere" index:
- :class:`~mongoengine.fields.PointField`
- :class:`~mongoengine.fields.LineStringField`
- :class:`~mongoengine.fields.PolygonField`
As "2dsphere" indexes can be part of a compound index, you may not want the
automatic index but would prefer a compound index. In this example we turn off
auto indexing and explicitly declare a compound index on ``location`` and ``datetime``::
class Log(Document):
location = PointField(auto_index=False)
datetime = DateTimeField()
meta = {
'indexes': [[("location", "2dsphere"), ("datetime", 1)]]
}
Pre MongoDB 2.4 Geo
'''''''''''''''''''
.. note:: For MongoDB < 2.4 this is still current, however the new 2dsphere
index is a big improvement over the previous 2D model - so upgrading is
advised.
Geospatial indexes will be automatically created for all Geospatial indexes will be automatically created for all
:class:`~mongoengine.fields.GeoPointField`\ s :class:`~mongoengine.fields.GeoPointField`\ s
@ -512,6 +542,30 @@ point. To create a geospatial index you must prefix the field with the
], ],
} }
Time To Live indexes
--------------------
A special index type that allows you to automatically expire data from a
collection after a given period. See the official
`ttl <http://docs.mongodb.org/manual/tutorial/expire-data/#expire-data-from-collections-by-setting-ttl>`_
documentation for more information. A common usecase might be session data::
class Session(Document):
created = DateTimeField(default=datetime.now)
meta = {
'indexes': [
{'fields': ['created'], 'expireAfterSeconds': 3600}
]
}
Comparing Indexes
-----------------
Use :func:`mongoengine.Document.compare_indexes` to compare actual indexes in
the database to those that your document definitions define. This is useful
for maintenance purposes and ensuring you have the correct indexes for your
schema.
Ordering Ordering
======== ========
A default ordering can be specified for your A default ordering can be specified for your

View File

@ -15,11 +15,10 @@ fetch documents from the database::
.. note:: .. note::
Once the iteration finishes (when :class:`StopIteration` is raised), As of MongoEngine 0.8 the querysets utilise a local cache. So iterating
:meth:`~mongoengine.queryset.QuerySet.rewind` will be called so that the it multiple times will only cause a single query. If this is not the
:class:`~mongoengine.queryset.QuerySet` may be iterated over again. The desired behavour you can call :class:`~mongoengine.QuerySet.no_cache` to
results of the first iteration are *not* cached, so the database will be hit return a non-caching queryset.
each time the :class:`~mongoengine.queryset.QuerySet` is iterated over.
Filtering queries Filtering queries
================= =================
@ -65,6 +64,9 @@ Available operators are as follows:
* ``size`` -- the size of the array is * ``size`` -- the size of the array is
* ``exists`` -- value for field exists * ``exists`` -- value for field exists
String queries
--------------
The following operators are available as shortcuts to querying with regular The following operators are available as shortcuts to querying with regular
expressions: expressions:
@ -78,8 +80,71 @@ expressions:
* ``iendswith`` -- string field ends with value (case insensitive) * ``iendswith`` -- string field ends with value (case insensitive)
* ``match`` -- performs an $elemMatch so you can match an entire document within an array * ``match`` -- performs an $elemMatch so you can match an entire document within an array
There are a few special operators for performing geographical queries, that
may used with :class:`~mongoengine.fields.GeoPointField`\ s: Geo queries
-----------
There are a few special operators for performing geographical queries. The following
were added in 0.8 for: :class:`~mongoengine.fields.PointField`,
:class:`~mongoengine.fields.LineStringField` and
:class:`~mongoengine.fields.PolygonField`:
* ``geo_within`` -- Check if a geometry is within a polygon. For ease of use
it accepts either a geojson geometry or just the polygon coordinates eg::
loc.objects(point__geo_with=[[[40, 5], [40, 6], [41, 6], [40, 5]]])
loc.objects(point__geo_with={"type": "Polygon",
"coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]})
* ``geo_within_box`` - simplified geo_within searching with a box eg::
loc.objects(point__geo_within_box=[(-125.0, 35.0), (-100.0, 40.0)])
loc.objects(point__geo_within_box=[<bottom left coordinates>, <upper right coordinates>])
* ``geo_within_polygon`` -- simplified geo_within searching within a simple polygon eg::
loc.objects(point__geo_within_polygon=[[40, 5], [40, 6], [41, 6], [40, 5]])
loc.objects(point__geo_within_polygon=[ [ <x1> , <y1> ] ,
[ <x2> , <y2> ] ,
[ <x3> , <y3> ] ])
* ``geo_within_center`` -- simplified geo_within the flat circle radius of a point eg::
loc.objects(point__geo_within_center=[(-125.0, 35.0), 1])
loc.objects(point__geo_within_center=[ [ <x>, <y> ] , <radius> ])
* ``geo_within_sphere`` -- simplified geo_within the spherical circle radius of a point eg::
loc.objects(point__geo_within_sphere=[(-125.0, 35.0), 1])
loc.objects(point__geo_within_sphere=[ [ <x>, <y> ] , <radius> ])
* ``geo_intersects`` -- selects all locations that intersect with a geometry eg::
# Inferred from provided points lists:
loc.objects(poly__geo_intersects=[40, 6])
loc.objects(poly__geo_intersects=[[40, 5], [40, 6]])
loc.objects(poly__geo_intersects=[[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]])
# With geoJson style objects
loc.objects(poly__geo_intersects={"type": "Point", "coordinates": [40, 6]})
loc.objects(poly__geo_intersects={"type": "LineString",
"coordinates": [[40, 5], [40, 6]]})
loc.objects(poly__geo_intersects={"type": "Polygon",
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]})
* ``near`` -- Find all the locations near a given point::
loc.objects(point__near=[40, 5])
loc.objects(point__near={"type": "Point", "coordinates": [40, 5]})
You can also set the maximum distance in meters as well::
loc.objects(point__near=[40, 5], point__max_distance=1000)
The older 2D indexes are still supported with the
:class:`~mongoengine.fields.GeoPointField`:
* ``within_distance`` -- provide a list containing a point and a maximum * ``within_distance`` -- provide a list containing a point and a maximum
distance (e.g. [(41.342, -87.653), 5]) distance (e.g. [(41.342, -87.653), 5])
@ -91,7 +156,9 @@ may used with :class:`~mongoengine.fields.GeoPointField`\ s:
[(35.0, -125.0), (40.0, -100.0)]) [(35.0, -125.0), (40.0, -100.0)])
* ``within_polygon`` -- filter documents to those within a given polygon (e.g. * ``within_polygon`` -- filter documents to those within a given polygon (e.g.
[(41.91,-87.69), (41.92,-87.68), (41.91,-87.65), (41.89,-87.65)]). [(41.91,-87.69), (41.92,-87.68), (41.91,-87.65), (41.89,-87.65)]).
.. note:: Requires Mongo Server 2.0 .. note:: Requires Mongo Server 2.0
* ``max_distance`` -- can be added to your location queries to set a maximum * ``max_distance`` -- can be added to your location queries to set a maximum
distance. distance.

View File

@ -1,5 +1,6 @@
.. _signals: .. _signals:
=======
Signals Signals
======= =======
@ -7,32 +8,95 @@ Signals
.. note:: .. note::
Signal support is provided by the excellent `blinker`_ library and Signal support is provided by the excellent `blinker`_ library. If you wish
will gracefully fall back if it is not available. to enable signal support this library must be installed, though it is not
required for MongoEngine to function.
Overview
--------
The following document signals exist in MongoEngine and are pretty self-explanatory: Signals are found within the `mongoengine.signals` module. Unless
specified signals receive no additional arguments beyond the `sender` class and
`document` instance. Post-signals are only called if there were no exceptions
raised during the processing of their related function.
* `mongoengine.signals.pre_init` Available signals include:
* `mongoengine.signals.post_init`
* `mongoengine.signals.pre_save`
* `mongoengine.signals.post_save`
* `mongoengine.signals.pre_delete`
* `mongoengine.signals.post_delete`
* `mongoengine.signals.pre_bulk_insert`
* `mongoengine.signals.post_bulk_insert`
Example usage:: `pre_init`
Called during the creation of a new :class:`~mongoengine.Document` or
:class:`~mongoengine.EmbeddedDocument` instance, after the constructor
arguments have been collected but before any additional processing has been
done to them. (I.e. assignment of default values.) Handlers for this signal
are passed the dictionary of arguments using the `values` keyword argument
and may modify this dictionary prior to returning.
`post_init`
Called after all processing of a new :class:`~mongoengine.Document` or
:class:`~mongoengine.EmbeddedDocument` instance has been completed.
`pre_save`
Called within :meth:`~mongoengine.document.Document.save` prior to performing
any actions.
`pre_save_post_validation`
Called within :meth:`~mongoengine.document.Document.save` after validation
has taken place but before saving.
`post_save`
Called within :meth:`~mongoengine.document.Document.save` after all actions
(validation, insert/update, cascades, clearing dirty flags) have completed
successfully. Passed the additional boolean keyword argument `created` to
indicate if the save was an insert or an update.
`pre_delete`
Called within :meth:`~mongoengine.document.Document.delete` prior to
attempting the delete operation.
`post_delete`
Called within :meth:`~mongoengine.document.Document.delete` upon successful
deletion of the record.
`pre_bulk_insert`
Called after validation of the documents to insert, but prior to any data
being written. In this case, the `document` argument is replaced by a
`documents` argument representing the list of documents being inserted.
`post_bulk_insert`
Called after a successful bulk insert operation. As per `pre_bulk_insert`,
the `document` argument is omitted and replaced with a `documents` argument.
An additional boolean argument, `loaded`, identifies the contents of
`documents` as either :class:`~mongoengine.Document` instances when `True` or
simply a list of primary key values for the inserted records if `False`.
Attaching Events
----------------
After writing a handler function like the following::
import logging
from datetime import datetime
from mongoengine import * from mongoengine import *
from mongoengine import signals from mongoengine import signals
def update_modified(sender, document):
document.modified = datetime.utcnow()
You attach the event handler to your :class:`~mongoengine.Document` or
:class:`~mongoengine.EmbeddedDocument` subclass::
class Record(Document):
modified = DateTimeField()
signals.pre_save.connect(update_modified)
While this is not the most elaborate document model, it does demonstrate the
concepts involved. As a more complete demonstration you can also define your
handlers within your subclass::
class Author(Document): class Author(Document):
name = StringField() name = StringField()
def __unicode__(self):
return self.name
@classmethod @classmethod
def pre_save(cls, sender, document, **kwargs): def pre_save(cls, sender, document, **kwargs):
logging.debug("Pre Save: %s" % document.name) logging.debug("Pre Save: %s" % document.name)
@ -49,12 +113,40 @@ Example usage::
signals.pre_save.connect(Author.pre_save, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author)
signals.post_save.connect(Author.post_save, sender=Author) signals.post_save.connect(Author.post_save, sender=Author)
Finally, you can also use this small decorator to quickly create a number of
signals and attach them to your :class:`~mongoengine.Document` or
:class:`~mongoengine.EmbeddedDocument` subclasses as class decorators::
ReferenceFields and signals def handler(event):
"""Signal decorator to allow use of callback functions as class decorators."""
def decorator(fn):
def apply(cls):
event.connect(fn, sender=cls)
return cls
fn.apply = apply
return fn
return decorator
Using the first example of updating a modification time the code is now much
cleaner looking while still allowing manual execution of the callback::
@handler(signals.pre_save)
def update_modified(sender, document):
document.modified = datetime.utcnow()
@update_modified.apply
class Record(Document):
modified = DateTimeField()
ReferenceFields and Signals
--------------------------- ---------------------------
Currently `reverse_delete_rules` do not trigger signals on the other part of Currently `reverse_delete_rules` do not trigger signals on the other part of
the relationship. If this is required you must manually handled the the relationship. If this is required you must manually handle the
reverse deletion. reverse deletion.
.. _blinker: http://pypi.python.org/pypi/blinker .. _blinker: http://pypi.python.org/pypi/blinker

View File

@ -55,15 +55,25 @@ See the :doc:`changelog` for a full list of changes to MongoEngine and
.. note:: Always read and test the `upgrade <upgrade>`_ documentation before .. note:: Always read and test the `upgrade <upgrade>`_ documentation before
putting updates live in production **;)** putting updates live in production **;)**
Offline Reading
---------------
Download the docs in `pdf <https://media.readthedocs.org/pdf/mongoengine-odm/latest/mongoengine-odm.pdf>`_
or `epub <https://media.readthedocs.org/epub/mongoengine-odm/latest/mongoengine-odm.epub>`_
formats for offline reading.
.. toctree:: .. toctree::
:maxdepth: 1
:numbered:
:hidden: :hidden:
tutorial tutorial
guide/index guide/index
apireference apireference
django
changelog changelog
upgrade upgrade
django
Indices and tables Indices and tables
------------------ ------------------

View File

@ -298,5 +298,5 @@ Learning more about mongoengine
------------------------------- -------------------------------
If you got this far you've made a great start, so well done! The next step on If you got this far you've made a great start, so well done! The next step on
your mongoengine journey is the `full user guide <guide/index>`_, where you your mongoengine journey is the `full user guide <guide/index.html>`_, where you
can learn indepth about how to use mongoengine and mongodb. can learn indepth about how to use mongoengine and mongodb.

View File

@ -2,6 +2,16 @@
Upgrading Upgrading
######### #########
0.8.2 to 0.8.3
**************
Minor change that may impact users:
DynamicDocument fields are now stored in creation order after any declared
fields. Previously they were stored alphabetically.
0.7 to 0.8 0.7 to 0.8
********** **********
@ -15,10 +25,10 @@ possible for the whole of the release.
live. There maybe multiple manual steps in migrating and these are best honed live. There maybe multiple manual steps in migrating and these are best honed
on a staging / test system. on a staging / test system.
Python Python and PyMongo
======= ==================
Support for python 2.5 has been dropped. MongoEngine requires python 2.6 (or above) and pymongo 2.5 (or above)
Data Model Data Model
========== ==========
@ -91,6 +101,13 @@ the case and the data is set only in the ``document._data`` dictionary: ::
File "<stdin>", line 1, in <module> File "<stdin>", line 1, in <module>
AttributeError: 'Animal' object has no attribute 'size' AttributeError: 'Animal' object has no attribute 'size'
The Document class has introduced a reserved function `clean()`, which will be
called before saving the document. If your document class happen to have a method
with the same name, please try rename it.
def clean(self):
pass
ReferenceField ReferenceField
-------------- --------------
@ -116,13 +133,17 @@ eg::
# Mark all ReferenceFields as dirty and save # Mark all ReferenceFields as dirty and save
for p in Person.objects: for p in Person.objects:
p._mark_as_dirty('parent') p._mark_as_changed('parent')
p._mark_as_dirty('friends') p._mark_as_changed('friends')
p.save() p.save()
`An example test migration for ReferenceFields is available on github `An example test migration for ReferenceFields is available on github
<https://github.com/MongoEngine/mongoengine/blob/master/tests/migration/refrencefield_dbref_to_object_id.py>`_. <https://github.com/MongoEngine/mongoengine/blob/master/tests/migration/refrencefield_dbref_to_object_id.py>`_.
.. Note:: Internally mongoengine handles ReferenceFields the same, so they are
converted to DBRef on loading and ObjectIds or DBRefs depending on settings
on storage.
UUIDField UUIDField
--------- ---------
@ -143,9 +164,9 @@ eg::
class Animal(Document): class Animal(Document):
uuid = UUIDField() uuid = UUIDField()
# Mark all ReferenceFields as dirty and save # Mark all UUIDFields as dirty and save
for a in Animal.objects: for a in Animal.objects:
a._mark_as_dirty('uuid') a._mark_as_changed('uuid')
a.save() a.save()
`An example test migration for UUIDFields is available on github `An example test migration for UUIDFields is available on github
@ -172,9 +193,9 @@ eg::
class Person(Document): class Person(Document):
balance = DecimalField() balance = DecimalField()
# Mark all ReferenceFields as dirty and save # Mark all DecimalField's as dirty and save
for p in Person.objects: for p in Person.objects:
p._mark_as_dirty('balance') p._mark_as_changed('balance')
p.save() p.save()
.. note:: DecimalField's have also been improved with the addition of precision .. note:: DecimalField's have also been improved with the addition of precision
@ -235,12 +256,15 @@ update your code like so: ::
mammals = Animal.objects(type="mammal").filter(order="Carnivora") # The final queryset is assgined to mammals mammals = Animal.objects(type="mammal").filter(order="Carnivora") # The final queryset is assgined to mammals
[m for m in mammals] # This will return all carnivores [m for m in mammals] # This will return all carnivores
No more len Len iterates the queryset
----------- --------------------------
If you ever did len(queryset) it previously did a count() under the covers, this If you ever did `len(queryset)` it previously did a `count()` under the covers,
caused some unusual issues - so now it has been removed in favour of the this caused some unusual issues. As `len(queryset)` is most often used by
explicit `queryset.count()` to update:: `list(queryset)` we now cache the queryset results and use that for the length.
This isn't as performant as a `count()` and if you aren't iterating the
queryset you should upgrade to use count::
# Old code # Old code
len(Animal.objects(type="mammal")) len(Animal.objects(type="mammal"))

View File

@ -15,7 +15,7 @@ import django
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
list(queryset.__all__) + signals.__all__ + list(errors.__all__)) list(queryset.__all__) + signals.__all__ + list(errors.__all__))
VERSION = (0, 8, 0, '+') VERSION = (0, 8, 3)
def get_version(): def get_version():

View File

@ -3,3 +3,6 @@ from mongoengine.base.datastructures import *
from mongoengine.base.document import * from mongoengine.base.document import *
from mongoengine.base.fields import * from mongoengine.base.fields import *
from mongoengine.base.metaclasses import * from mongoengine.base.metaclasses import *
# Help with backwards compatibility
from mongoengine.errors import *

View File

@ -13,7 +13,11 @@ class BaseDict(dict):
_name = None _name = None
def __init__(self, dict_items, instance, name): def __init__(self, dict_items, instance, name):
self._instance = weakref.proxy(instance) Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(instance, (Document, EmbeddedDocument)):
self._instance = weakref.proxy(instance)
self._name = name self._name = name
return super(BaseDict, self).__init__(dict_items) return super(BaseDict, self).__init__(dict_items)
@ -80,7 +84,11 @@ class BaseList(list):
_name = None _name = None
def __init__(self, list_items, instance, name): def __init__(self, list_items, instance, name):
self._instance = weakref.proxy(instance) Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(instance, (Document, EmbeddedDocument)):
self._instance = weakref.proxy(instance)
self._name = name self._name = name
return super(BaseList, self).__init__(list_items) return super(BaseList, self).__init__(list_items)
@ -100,6 +108,14 @@ class BaseList(list):
self._mark_as_changed() self._mark_as_changed()
return super(BaseList, self).__delitem__(*args, **kwargs) return super(BaseList, self).__delitem__(*args, **kwargs)
def __setslice__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).__setslice__(*args, **kwargs)
def __delslice__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).__delslice__(*args, **kwargs)
def __getstate__(self): def __getstate__(self):
self.instance = None self.instance = None
self._dereferenced = False self._dereferenced = False

View File

@ -42,6 +42,9 @@ class BaseDocument(object):
# Combine positional arguments with named arguments. # Combine positional arguments with named arguments.
# We only want named arguments. # We only want named arguments.
field = iter(self._fields_ordered) field = iter(self._fields_ordered)
# If its an automatic id field then skip to the first defined field
if self._auto_id_field:
next(field)
for value in args: for value in args:
name = next(field) name = next(field)
if name in values: if name in values:
@ -51,6 +54,7 @@ class BaseDocument(object):
signals.pre_init.send(self.__class__, document=self, values=values) signals.pre_init.send(self.__class__, document=self, values=values)
self._data = {} self._data = {}
self._dynamic_fields = SON()
# Assign default values to instance # Assign default values to instance
for key, field in self._fields.iteritems(): for key, field in self._fields.iteritems():
@ -61,7 +65,6 @@ class BaseDocument(object):
# Set passed values after initialisation # Set passed values after initialisation
if self._dynamic: if self._dynamic:
self._dynamic_fields = {}
dynamic_data = {} dynamic_data = {}
for key, value in values.iteritems(): for key, value in values.iteritems():
if key in self._fields or key == '_id': if key in self._fields or key == '_id':
@ -116,6 +119,7 @@ class BaseDocument(object):
field = DynamicField(db_field=name) field = DynamicField(db_field=name)
field.name = name field.name = name
self._dynamic_fields[name] = field self._dynamic_fields[name] = field
self._fields_ordered += (name,)
if not name.startswith('_'): if not name.startswith('_'):
value = self.__expand_dynamic_values(name, value) value = self.__expand_dynamic_values(name, value)
@ -141,28 +145,33 @@ class BaseDocument(object):
super(BaseDocument, self).__setattr__(name, value) super(BaseDocument, self).__setattr__(name, value)
def __getstate__(self): def __getstate__(self):
removals = ("get_%s_display" % k data = {}
for k, v in self._fields.items() if v.choices) for k in ('_changed_fields', '_initialised', '_created',
for k in removals: '_dynamic_fields', '_fields_ordered'):
if hasattr(self, k): if hasattr(self, k):
delattr(self, k) data[k] = getattr(self, k)
return self.__dict__ data['_data'] = self.to_mongo()
return data
def __setstate__(self, __dict__): def __setstate__(self, data):
self.__dict__ = __dict__ if isinstance(data["_data"], SON):
self.__set_field_display() data["_data"] = self.__class__._from_son(data["_data"])._data
for k in ('_changed_fields', '_initialised', '_created', '_data',
'_fields_ordered', '_dynamic_fields'):
if k in data:
setattr(self, k, data[k])
dynamic_fields = data.get('_dynamic_fields') or SON()
for k in dynamic_fields.keys():
setattr(self, k, data["_data"].get(k))
def __iter__(self): def __iter__(self):
if 'id' in self._fields and 'id' not in self._fields_ordered:
return iter(('id', ) + self._fields_ordered)
return iter(self._fields_ordered) return iter(self._fields_ordered)
def __getitem__(self, name): def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present. """Dictionary-style field access, return a field's value if present.
""" """
try: try:
if name in self._fields: if name in self._fields_ordered:
return getattr(self, name) return getattr(self, name)
except AttributeError: except AttributeError:
pass pass
@ -212,7 +221,7 @@ class BaseDocument(object):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self): def __hash__(self):
if self.pk is None: if getattr(self, 'pk', None) is None:
# For new object # For new object
return super(BaseDocument, self).__hash__() return super(BaseDocument, self).__hash__()
else: else:
@ -238,6 +247,8 @@ class BaseDocument(object):
for field_name in self: for field_name in self:
value = self._data.get(field_name, None) value = self._data.get(field_name, None)
field = self._fields.get(field_name) field = self._fields.get(field_name)
if field is None and self._dynamic:
field = self._dynamic_fields.get(field_name)
if value is not None: if value is not None:
value = field.to_mongo(value) value = field.to_mongo(value)
@ -251,8 +262,10 @@ class BaseDocument(object):
data[field.db_field] = value data[field.db_field] = value
# If "_id" has not been set, then try and set it # If "_id" has not been set, then try and set it
if data["_id"] is None: Document = _import_class("Document")
data["_id"] = self._data.get("id", None) if isinstance(self, Document):
if data["_id"] is None:
data["_id"] = self._data.get("id", None)
if data['_id'] is None: if data['_id'] is None:
data.pop('_id') data.pop('_id')
@ -262,15 +275,6 @@ class BaseDocument(object):
not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)):
data.pop('_cls') data.pop('_cls')
if not self._dynamic:
return data
# Sort dynamic fields by key
dynamic_fields = sorted(self._dynamic_fields.iteritems(),
key=operator.itemgetter(0))
for name, field in dynamic_fields:
data[name] = field.to_mongo(self._data.get(name, None))
return data return data
def validate(self, clean=True): def validate(self, clean=True):
@ -286,11 +290,8 @@ class BaseDocument(object):
errors[NON_FIELD_ERRORS] = error errors[NON_FIELD_ERRORS] = error
# Get a list of tuples of field names and their current values # Get a list of tuples of field names and their current values
fields = [(field, self._data.get(name)) fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
for name, field in self._fields.items()] self._data.get(name)) for name in self._fields_ordered]
if self._dynamic:
fields += [(field, self._data.get(name))
for name, field in self._dynamic_fields.items()]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField") EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
@ -389,7 +390,7 @@ class BaseDocument(object):
if field_value: if field_value:
field_value._clear_changed_fields() field_value._clear_changed_fields()
def _get_changed_fields(self, key='', inspected=None): def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed. """Returns a list of all fields that have explicitly been changed.
""" """
EmbeddedDocument = _import_class("EmbeddedDocument") EmbeddedDocument = _import_class("EmbeddedDocument")
@ -403,11 +404,7 @@ class BaseDocument(object):
return _changed_fields return _changed_fields
inspected.add(self.id) inspected.add(self.id)
field_list = self._fields.copy() for field_name in self._fields_ordered:
if self._dynamic:
field_list.update(self._dynamic_fields)
for field_name in field_list:
db_field_name = self._db_field_map.get(field_name, field_name) db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name key = '%s.' % db_field_name
@ -420,7 +417,7 @@ class BaseDocument(object):
if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in _changed_fields): and db_field_name not in _changed_fields):
# Find all embedded fields that have been changed # Find all embedded fields that have been changed
changed = field._get_changed_fields(key, inspected) changed = field._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (key, k) for k in changed if k] _changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(field, (list, tuple, dict)) and elif (isinstance(field, (list, tuple, dict)) and
db_field_name not in _changed_fields): db_field_name not in _changed_fields):
@ -434,7 +431,7 @@ class BaseDocument(object):
if not hasattr(value, '_get_changed_fields'): if not hasattr(value, '_get_changed_fields'):
continue continue
list_key = "%s%s." % (key, index) list_key = "%s%s." % (key, index)
changed = value._get_changed_fields(list_key, inspected) changed = value._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (list_key, k) _changed_fields += ["%s%s" % (list_key, k)
for k in changed if k] for k in changed if k]
return _changed_fields return _changed_fields
@ -447,7 +444,6 @@ class BaseDocument(object):
doc = self.to_mongo() doc = self.to_mongo()
set_fields = self._get_changed_fields() set_fields = self._get_changed_fields()
set_data = {}
unset_data = {} unset_data = {}
parts = [] parts = []
if hasattr(self, '_changed_fields'): if hasattr(self, '_changed_fields'):
@ -662,7 +658,8 @@ class BaseDocument(object):
if include_cls and direction is not pymongo.GEO2D: if include_cls and direction is not pymongo.GEO2D:
index_list.insert(0, ('_cls', 1)) index_list.insert(0, ('_cls', 1))
spec['fields'] = index_list if index_list:
spec['fields'] = index_list
if spec.get('sparse', False) and len(spec['fields']) > 1: if spec.get('sparse', False) and len(spec['fields']) > 1:
raise ValueError( raise ValueError(
'Sparse indexes can only have one field in them. ' 'Sparse indexes can only have one field in them. '
@ -704,13 +701,13 @@ class BaseDocument(object):
# Add the new index to the list # Add the new index to the list
fields = [("%s%s" % (namespace, f), pymongo.ASCENDING) fields = [("%s%s" % (namespace, f), pymongo.ASCENDING)
for f in unique_fields] for f in unique_fields]
index = {'fields': fields, 'unique': True, 'sparse': sparse} index = {'fields': fields, 'unique': True, 'sparse': sparse}
unique_indexes.append(index) unique_indexes.append(index)
# Grab any embedded document field unique indexes # Grab any embedded document field unique indexes
if (field.__class__.__name__ == "EmbeddedDocumentField" and if (field.__class__.__name__ == "EmbeddedDocumentField" and
field.document_type != cls): field.document_type != cls):
field_namespace = "%s." % field_name field_namespace = "%s." % field_name
doc_cls = field.document_type doc_cls = field.document_type
unique_indexes += doc_cls._unique_with_indexes(field_namespace) unique_indexes += doc_cls._unique_with_indexes(field_namespace)
@ -718,26 +715,31 @@ class BaseDocument(object):
return unique_indexes return unique_indexes
@classmethod @classmethod
def _geo_indices(cls, inspected=None): def _geo_indices(cls, inspected=None, parent_field=None):
inspected = inspected or [] inspected = inspected or []
geo_indices = [] geo_indices = []
inspected.append(cls) inspected.append(cls)
EmbeddedDocumentField = _import_class("EmbeddedDocumentField") geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField",
GeoPointField = _import_class("GeoPointField") "PointField", "LineStringField", "PolygonField"]
geo_field_types = tuple([_import_class(field) for field in geo_field_type_names])
for field in cls._fields.values(): for field in cls._fields.values():
if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): if not isinstance(field, geo_field_types):
continue continue
if hasattr(field, 'document_type'): if hasattr(field, 'document_type'):
field_cls = field.document_type field_cls = field.document_type
if field_cls in inspected: if field_cls in inspected:
continue continue
if hasattr(field_cls, '_geo_indices'): if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices(inspected) geo_indices += field_cls._geo_indices(inspected, parent_field=field.db_field)
elif field._geo_index: elif field._geo_index:
field_name = field.db_field
if parent_field:
field_name = "%s.%s" % (parent_field, field_name)
geo_indices.append({'fields': geo_indices.append({'fields':
[(field.db_field, pymongo.GEO2D)]}) [(field_name, field._geo_index)]})
return geo_indices return geo_indices
@classmethod @classmethod

View File

@ -2,7 +2,8 @@ import operator
import warnings import warnings
import weakref import weakref
from bson import DBRef, ObjectId from bson import DBRef, ObjectId, SON
import pymongo
from mongoengine.common import _import_class from mongoengine.common import _import_class
from mongoengine.errors import ValidationError from mongoengine.errors import ValidationError
@ -10,7 +11,7 @@ from mongoengine.errors import ValidationError
from mongoengine.base.common import ALLOW_INHERITANCE from mongoengine.base.common import ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.datastructures import BaseDict, BaseList
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField") __all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
class BaseField(object): class BaseField(object):
@ -35,6 +36,29 @@ class BaseField(object):
unique=False, unique_with=None, primary_key=False, unique=False, unique_with=None, primary_key=False,
validation=None, choices=None, verbose_name=None, validation=None, choices=None, verbose_name=None,
help_text=None): help_text=None):
"""
:param db_field: The database field to store this field in
(defaults to the name of the field)
:param name: Depreciated - use db_field
:param required: If the field is required. Whether it has to have a
value or not. Defaults to False.
:param default: (optional) The default value for this field if no value
has been set (or if the value has been unset). It Can be a
callable.
:param unique: Is the field value unique or not. Defaults to False.
:param unique_with: (optional) The other field this field should be
unique with.
:param primary_key: Mark this field as the primary key. Defaults to False.
:param validation: (optional) A callable to validate the value of the
field. Generally this is deprecated in favour of the
`FIELD.validate` method
:param choices: (optional) The valid choices
:param verbose_name: (optional) The verbose name for the field.
Designed to be human readable and is often used when generating
model forms from the document model.
:param help_text: (optional) The help text for this field and is often
used when generating model forms from the document model.
"""
self.db_field = (db_field or name) if not primary_key else '_id' self.db_field = (db_field or name) if not primary_key else '_id'
if name: if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
@ -58,20 +82,14 @@ class BaseField(object):
BaseField.creation_counter += 1 BaseField.creation_counter += 1
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document. Do """Descriptor for retrieving a value from a field in a document.
any necessary conversion between Python and MongoDB types.
""" """
if instance is None: if instance is None:
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
# Get value from document instance if available, if not use default
value = instance._data.get(self.name)
if value is None: # Get value from document instance if available
value = self.default value = instance._data.get(self.name)
# Allow callable default values
if callable(value):
value = value()
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None: if isinstance(value, EmbeddedDocument) and value._instance is None:
@ -81,13 +99,24 @@ class BaseField(object):
def __set__(self, instance, value): def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document. """Descriptor for assigning a value to a field in a document.
""" """
changed = False
if (self.name not in instance._data or # If setting to None and theres a default
instance._data[self.name] != value): # Then set the value to the default value
changed = True if value is None and self.default is not None:
instance._data[self.name] = value value = self.default
if changed and instance._initialised: if callable(value):
instance._mark_as_changed(self.name) value = value()
if instance._initialised:
try:
if (self.name not in instance._data or
instance._data[self.name] != value):
instance._mark_as_changed(self.name)
except:
# Values cant be compared eg: naive and tz datetimes
# So mark it as changed
instance._mark_as_changed(self.name)
instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None): def error(self, message="", errors=None, field_name=None):
"""Raises a ValidationError. """Raises a ValidationError.
@ -183,7 +212,7 @@ class ComplexBaseField(BaseField):
# Convert lists / values so we can watch for any changes on them # Convert lists / values so we can watch for any changes on them
if (isinstance(value, (list, tuple)) and if (isinstance(value, (list, tuple)) and
not isinstance(value, BaseList)): not isinstance(value, BaseList)):
value = BaseList(value, instance, self.name) value = BaseList(value, instance, self.name)
instance._data[self.name] = value instance._data[self.name] = value
elif isinstance(value, dict) and not isinstance(value, BaseDict): elif isinstance(value, dict) and not isinstance(value, BaseDict):
@ -191,8 +220,8 @@ class ComplexBaseField(BaseField):
instance._data[self.name] = value instance._data[self.name] = value
if (self._auto_dereference and instance._initialised and if (self._auto_dereference and instance._initialised and
isinstance(value, (BaseList, BaseDict)) isinstance(value, (BaseList, BaseDict))
and not value._dereferenced): and not value._dereferenced):
value = self._dereference( value = self._dereference(
value, max_depth=1, instance=instance, name=self.name value, max_depth=1, instance=instance, name=self.name
) )
@ -201,12 +230,6 @@ class ComplexBaseField(BaseField):
return value return value
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
"""
instance._data[self.name] = value
instance._mark_as_changed(self.name)
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type.
""" """
@ -228,7 +251,7 @@ class ComplexBaseField(BaseField):
if self.field: if self.field:
value_dict = dict([(key, self.field.to_python(item)) value_dict = dict([(key, self.field.to_python(item))
for key, item in value.items()]) for key, item in value.items()])
else: else:
value_dict = {} value_dict = {}
for k, v in value.items(): for k, v in value.items():
@ -279,7 +302,7 @@ class ComplexBaseField(BaseField):
if self.field: if self.field:
value_dict = dict([(key, self.field.to_mongo(item)) value_dict = dict([(key, self.field.to_mongo(item))
for key, item in value.iteritems()]) for key, item in value.iteritems()])
else: else:
value_dict = {} value_dict = {}
for k, v in value.iteritems(): for k, v in value.iteritems():
@ -393,3 +416,100 @@ class ObjectIdField(BaseField):
ObjectId(unicode(value)) ObjectId(unicode(value))
except: except:
self.error('Invalid Object ID') self.error('Invalid Object ID')
class GeoJsonBaseField(BaseField):
"""A geo json field storing a geojson style object.
.. versionadded:: 0.8
"""
_geo_index = pymongo.GEOSPHERE
_type = "GeoBase"
def __init__(self, auto_index=True, *args, **kwargs):
"""
:param auto_index: Automatically create a "2dsphere" index. Defaults
to `True`.
"""
self._name = "%sField" % self._type
if not auto_index:
self._geo_index = False
super(GeoJsonBaseField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Validate the GeoJson object based on its type
"""
if isinstance(value, dict):
if set(value.keys()) == set(['type', 'coordinates']):
if value['type'] != self._type:
self.error('%s type must be "%s"' % (self._name, self._type))
return self.validate(value['coordinates'])
else:
self.error('%s can only accept a valid GeoJson dictionary'
' or lists of (x, y)' % self._name)
return
elif not isinstance(value, (list, tuple)):
self.error('%s can only accept lists of [x, y]' % self._name)
return
validate = getattr(self, "_validate_%s" % self._type.lower())
error = validate(value)
if error:
self.error(error)
def _validate_polygon(self, value):
if not isinstance(value, (list, tuple)):
return 'Polygons must contain list of linestrings'
# Quick and dirty validator
try:
value[0][0][0]
except:
return "Invalid Polygon must contain at least one valid linestring"
errors = []
for val in value:
error = self._validate_linestring(val, False)
if not error and val[0] != val[-1]:
error = 'LineStrings must start and end at the same point'
if error and error not in errors:
errors.append(error)
if errors:
return "Invalid Polygon:\n%s" % ", ".join(errors)
def _validate_linestring(self, value, top_level=True):
"""Validates a linestring"""
if not isinstance(value, (list, tuple)):
return 'LineStrings must contain list of coordinate pairs'
# Quick and dirty validator
try:
value[0][0]
except:
return "Invalid LineString must contain at least one valid point"
errors = []
for val in value:
error = self._validate_point(val)
if error and error not in errors:
errors.append(error)
if errors:
if top_level:
return "Invalid LineString:\n%s" % ", ".join(errors)
else:
return "%s" % ", ".join(errors)
def _validate_point(self, value):
"""Validate each set of coords"""
if not isinstance(value, (list, tuple)):
return 'Points must be a list of coordinate pairs'
elif not len(value) == 2:
return "Value (%s) must be a two-dimensional point" % repr(value)
elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))):
return "Both values (%s) in point must be float or int" % repr(value)
def to_mongo(self, value):
if isinstance(value, dict):
return value
return SON([("type", self._type), ("coordinates", value)])

View File

@ -91,11 +91,12 @@ class DocumentMetaclass(type):
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems()]) for k, v in doc_fields.iteritems()])
attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems())
attrs['_fields_ordered'] = tuple(i[1] for i in sorted( attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
(v.creation_counter, v.name) (v.creation_counter, v.name)
for v in doc_fields.itervalues())) for v in doc_fields.itervalues()))
attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems())
# #
# Set document hierarchy # Set document hierarchy
@ -140,8 +141,31 @@ class DocumentMetaclass(type):
base._subclasses += (_cls,) base._subclasses += (_cls,)
base._types = base._subclasses # TODO depreciate _types base._types = base._subclasses # TODO depreciate _types
# Handle delete rules
Document, EmbeddedDocument, DictField = cls._import_classes() Document, EmbeddedDocument, DictField = cls._import_classes()
if issubclass(new_class, Document):
new_class._collection = None
# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class
# In Python 2, User-defined methods objects have special read-only
# attributes 'im_func' and 'im_self' which contain the function obj
# and class instance object respectively. With Python 3 these special
# attributes have been replaced by __func__ and __self__. The Blinker
# module continues to use im_func and im_self, so the code below
# copies __func__ into im_func and __self__ into im_self for
# classmethod objects in Document derived classes.
if PY3:
for key, val in new_class.__dict__.items():
if isinstance(val, classmethod):
f = val.__get__(new_class)
if hasattr(f, '__func__') and not hasattr(f, 'im_func'):
f.__dict__.update({'im_func': getattr(f, '__func__')})
if hasattr(f, '__self__') and not hasattr(f, 'im_self'):
f.__dict__.update({'im_self': getattr(f, '__self__')})
# Handle delete rules
for field in new_class._fields.itervalues(): for field in new_class._fields.itervalues():
f = field f = field
f.owner_document = new_class f.owner_document = new_class
@ -167,33 +191,11 @@ class DocumentMetaclass(type):
field.name, delete_rule) field.name, delete_rule)
if (field.name and hasattr(Document, field.name) and if (field.name and hasattr(Document, field.name) and
EmbeddedDocument not in new_class.mro()): EmbeddedDocument not in new_class.mro()):
msg = ("%s is a document method and not a valid " msg = ("%s is a document method and not a valid "
"field name" % field.name) "field name" % field.name)
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
if issubclass(new_class, Document):
new_class._collection = None
# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class
# In Python 2, User-defined methods objects have special read-only
# attributes 'im_func' and 'im_self' which contain the function obj
# and class instance object respectively. With Python 3 these special
# attributes have been replaced by __func__ and __self__. The Blinker
# module continues to use im_func and im_self, so the code below
# copies __func__ into im_func and __self__ into im_self for
# classmethod objects in Document derived classes.
if PY3:
for key, val in new_class.__dict__.items():
if isinstance(val, classmethod):
f = val.__get__(new_class)
if hasattr(f, '__func__') and not hasattr(f, 'im_func'):
f.__dict__.update({'im_func': getattr(f, '__func__')})
if hasattr(f, '__self__') and not hasattr(f, 'im_self'):
f.__dict__.update({'im_self': getattr(f, '__self__')})
return new_class return new_class
def add_to_class(self, name, value): def add_to_class(self, name, value):
@ -357,12 +359,18 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class.id = field new_class.id = field
# Set primary key if not defined by the document # Set primary key if not defined by the document
new_class._auto_id_field = False
if not new_class._meta.get('id_field'): if not new_class._meta.get('id_field'):
new_class._auto_id_field = True
new_class._meta['id_field'] = 'id' new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id') new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class._fields['id'].name = 'id' new_class._fields['id'].name = 'id'
new_class.id = new_class._fields['id'] new_class.id = new_class._fields['id']
# Prepend id field to _fields_ordered
if 'id' in new_class._fields and 'id' not in new_class._fields_ordered:
new_class._fields_ordered = ('id', ) + new_class._fields_ordered
# Merge in exceptions with parent hierarchy # Merge in exceptions with parent hierarchy
exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned)
module = attrs.get('__module__') module = attrs.get('__module__')

View File

@ -2,7 +2,19 @@ _class_registry_cache = {}
def _import_class(cls_name): def _import_class(cls_name):
"""Cached mechanism for imports""" """Cache mechanism for imports.
Due to complications of circular imports mongoengine needs to do lots of
inline imports in functions. This is inefficient as classes are
imported repeated throughout the mongoengine code. This is
compounded by some recursive functions requiring inline imports.
:mod:`mongoengine.common` provides a single point to import all these
classes. Circular imports aren't an issue as it dynamically imports the
class when first needed. Subsequent calls to the
:func:`~mongoengine.common._import_class` can then directly retrieve the
class from the :data:`mongoengine.common._class_registry_cache`.
"""
if cls_name in _class_registry_cache: if cls_name in _class_registry_cache:
return _class_registry_cache.get(cls_name) return _class_registry_cache.get(cls_name)
@ -11,6 +23,7 @@ def _import_class(cls_name):
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField', 'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField', 'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'PolygonField',
'ReferenceField', 'StringField', 'ComplexBaseField') 'ReferenceField', 'StringField', 'ComplexBaseField')
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)
@ -33,4 +46,4 @@ def _import_class(cls_name):
for cls in import_classes: for cls in import_classes:
_class_registry_cache[cls] = getattr(module, cls) _class_registry_cache[cls] = getattr(module, cls)
return _class_registry_cache.get(cls_name) return _class_registry_cache.get(cls_name)

View File

@ -137,11 +137,12 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
if alias not in _dbs: if alias not in _dbs:
conn = get_connection(alias) conn = get_connection(alias)
conn_settings = _connection_settings[alias] conn_settings = _connection_settings[alias]
_dbs[alias] = conn[conn_settings['name']] db = conn[conn_settings['name']]
# Authenticate if necessary # Authenticate if necessary
if conn_settings['username'] and conn_settings['password']: if conn_settings['username'] and conn_settings['password']:
_dbs[alias].authenticate(conn_settings['username'], db.authenticate(conn_settings['username'],
conn_settings['password']) conn_settings['password'])
_dbs[alias] = db
return _dbs[alias] return _dbs[alias]

View File

@ -1,8 +1,10 @@
from mongoengine.common import _import_class from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.queryset import OperationError, QuerySet from mongoengine.queryset import QuerySet
__all__ = ("switch_db", "switch_collection", "no_dereference", "query_counter")
__all__ = ("switch_db", "switch_collection", "no_dereference",
"no_sub_classes", "query_counter")
class switch_db(object): class switch_db(object):
@ -130,6 +132,36 @@ class no_dereference(object):
return self.cls return self.cls
class no_sub_classes(object):
""" no_sub_classes context manager.
Only returns instances of this class and no sub (inherited) classes::
with no_sub_classes(Group) as Group:
Group.objects.find()
"""
def __init__(self, cls):
""" Construct the no_sub_classes context manager.
:param cls: the class to turn querying sub classes on
"""
self.cls = cls
def __enter__(self):
""" change the objects default and _auto_dereference values"""
self.cls._all_subclasses = self.cls._subclasses
self.cls._subclasses = (self.cls,)
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the default and _auto_dereference values"""
self.cls._subclasses = self.cls._all_subclasses
delattr(self.cls, '_all_subclasses')
return self.cls
class QuerySetNoDeRef(QuerySet): class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet""" """Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None): def __dereference(items, max_depth=1, instance=None, name=None):
@ -157,7 +189,8 @@ class query_counter(object):
def __eq__(self, value): def __eq__(self, value):
""" == Compare querycounter. """ """ == Compare querycounter. """
return value == self._get_count() counter = self._get_count()
return value == counter
def __ne__(self, value): def __ne__(self, value):
""" != Compare querycounter. """ """ != Compare querycounter. """
@ -189,6 +222,7 @@ class query_counter(object):
def _get_count(self): def _get_count(self):
""" Get the number of queries. """ """ Get the number of queries. """
count = self.db.system.profile.find().count() - self.counter ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
count = self.db.system.profile.find(ignore_query).count() - self.counter
self.counter += 1 self.counter += 1
return count return count

View File

@ -1,9 +1,8 @@
from importlib import import_module
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import UserManager from django.contrib.auth.models import UserManager
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import models from django.db import models
from django.utils.importlib import import_module
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _

View File

@ -1,7 +1,10 @@
from django.conf import settings from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase, CreateError from django.contrib.sessions.backends.base import SessionBase, CreateError
from django.core.exceptions import SuspiciousOperation from django.core.exceptions import SuspiciousOperation
from django.utils.encoding import force_unicode try:
from django.utils.encoding import force_unicode
except ImportError:
from django.utils.encoding import force_text as force_unicode
from mongoengine.document import Document from mongoengine.document import Document
from mongoengine import fields from mongoengine import fields
@ -39,7 +42,7 @@ class MongoSession(Document):
'indexes': [ 'indexes': [
{ {
'fields': ['expire_date'], 'fields': ['expire_date'],
'expireAfterSeconds': settings.SESSION_COOKIE_AGE 'expireAfterSeconds': 0
} }
] ]
} }

View File

@ -1,11 +1,14 @@
from __future__ import with_statement
import warnings import warnings
import hashlib
import pymongo import pymongo
import re import re
from pymongo.read_preferences import ReadPreference
from bson import ObjectId
from bson.dbref import DBRef from bson.dbref import DBRef
from mongoengine import signals from mongoengine import signals
from mongoengine.common import _import_class
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList, BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document) ALLOW_INHERITANCE, get_document)
@ -18,6 +21,19 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument')
def includes_cls(fields):
""" Helper function used for ensuring and comparing indexes
"""
first_field = None
if len(fields):
if isinstance(fields[0], basestring):
first_field = fields[0]
elif isinstance(fields[0], (list, tuple)) and len(fields[0]):
first_field = fields[0][0]
return first_field == '_cls'
class InvalidCollectionError(Exception): class InvalidCollectionError(Exception):
pass pass
@ -53,6 +69,9 @@ class EmbeddedDocument(BaseDocument):
return self._data == other._data return self._data == other._data
return False return False
def __ne__(self, other):
return not self.__eq__(other)
class Document(BaseDocument): class Document(BaseDocument):
"""The base class used for defining the structure and properties of """The base class used for defining the structure and properties of
@ -180,8 +199,8 @@ class Document(BaseDocument):
will force an fsync on the primary server. will force an fsync on the primary server.
:param cascade: Sets the flag for cascading saves. You can set a :param cascade: Sets the flag for cascading saves. You can set a
default by setting "cascade" in the document __meta__ default by setting "cascade" in the document __meta__
:param cascade_kwargs: optional kwargs dictionary to be passed throw :param cascade_kwargs: (optional) kwargs dictionary to be passed throw
to cascading saves to cascading saves. Implies ``cascade=True``.
:param _refs: A list of processed references used in cascading saves :param _refs: A list of processed references used in cascading saves
.. versionchanged:: 0.5 .. versionchanged:: 0.5
@ -190,24 +209,28 @@ class Document(BaseDocument):
:class:`~bson.dbref.DBRef` objects that have changes are :class:`~bson.dbref.DBRef` objects that have changes are
saved as well. saved as well.
.. versionchanged:: 0.6 .. versionchanged:: 0.6
Cascade saves are optional = defaults to True, if you want Added cascading saves
.. versionchanged:: 0.8
Cascade saves are optional and default to False. If you want
fine grain control then you can turn off using document fine grain control then you can turn off using document
meta['cascade'] = False Also you can pass different kwargs to meta['cascade'] = True. Also you can pass different kwargs to
the cascade save using cascade_kwargs which overwrites the the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values existing kwargs with custom values.
""" """
signals.pre_save.send(self.__class__, document=self) signals.pre_save.send(self.__class__, document=self)
if validate: if validate:
self.validate(clean=clean) self.validate(clean=clean)
if not write_concern: if write_concern is None:
write_concern = {} write_concern = {"w": 1}
doc = self.to_mongo() doc = self.to_mongo()
created = ('_id' not in doc or self._created or force_insert) created = ('_id' not in doc or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
try: try:
collection = self._get_collection() collection = self._get_collection()
if created: if created:
@ -232,7 +255,6 @@ class Document(BaseDocument):
return not updated return not updated
return created return created
upsert = self._created
update_query = {} update_query = {}
if updates: if updates:
@ -241,11 +263,12 @@ class Document(BaseDocument):
update_query["$unset"] = removals update_query["$unset"] = removals
if updates or removals: if updates or removals:
last_error = collection.update(select_dict, update_query, last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern) upsert=True, **write_concern)
created = is_new_object(last_error) created = is_new_object(last_error)
cascade = (self._meta.get('cascade', True) if cascade is None:
if cascade is None else cascade) cascade = self._meta.get('cascade', False) or cascade_kwargs is not None
if cascade: if cascade:
kwargs = { kwargs = {
"force_insert": force_insert, "force_insert": force_insert,
@ -278,15 +301,17 @@ class Document(BaseDocument):
def cascade_save(self, *args, **kwargs): def cascade_save(self, *args, **kwargs):
"""Recursively saves any references / """Recursively saves any references /
generic references on an objects""" generic references on an objects"""
import fields
_refs = kwargs.get('_refs', []) or [] _refs = kwargs.get('_refs', []) or []
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
for name, cls in self._fields.items(): for name, cls in self._fields.items():
if not isinstance(cls, (fields.ReferenceField, if not isinstance(cls, (ReferenceField,
fields.GenericReferenceField)): GenericReferenceField)):
continue continue
ref = getattr(self, name) ref = self._data.get(name)
if not ref or isinstance(ref, DBRef): if not ref or isinstance(ref, DBRef):
continue continue
@ -327,7 +352,13 @@ class Document(BaseDocument):
been saved. been saved.
""" """
if not self.pk: if not self.pk:
raise OperationError('attempt to update a document not yet saved') if kwargs.get('upsert', False):
query = self.to_mongo()
if "_cls" in query:
del(query["_cls"])
return self._qs.filter(**query).update_one(**kwargs)
else:
raise OperationError('attempt to update a document not yet saved')
# Need to add shard key to query, or you get an error # Need to add shard key to query, or you get an error
return self._qs.filter(**self._object_key).update_one(**kwargs) return self._qs.filter(**self._object_key).update_one(**kwargs)
@ -346,11 +377,10 @@ class Document(BaseDocument):
signals.pre_delete.send(self.__class__, document=self) signals.pre_delete.send(self.__class__, document=self)
try: try:
self._qs.filter(**self._object_key).delete(write_concern=write_concern) self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
message = u'Could not delete document (%s)' % err.message message = u'Could not delete document (%s)' % err.message
raise OperationError(message) raise OperationError(message)
signals.post_delete.send(self.__class__, document=self) signals.post_delete.send(self.__class__, document=self)
def switch_db(self, db_alias): def switch_db(self, db_alias):
@ -390,7 +420,7 @@ class Document(BaseDocument):
user.save() user.save()
If you need to read from another database see If you need to read from another database see
:class:`~mongoengine.context_managers.switch_collection` :class:`~mongoengine.context_managers.switch_db`
:param collection_name: The database alias to use for saving the :param collection_name: The database alias to use for saving the
document document
@ -410,8 +440,8 @@ class Document(BaseDocument):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
import dereference DeReference = _import_class('DeReference')
self._data = dereference.DeReference()(self._data, max_depth) DeReference()([self], max_depth + 1)
return self return self
def reload(self, max_depth=1): def reload(self, max_depth=1):
@ -420,19 +450,16 @@ class Document(BaseDocument):
.. versionadded:: 0.1.2 .. versionadded:: 0.1.2
.. versionchanged:: 0.6 Now chainable .. versionchanged:: 0.6 Now chainable
""" """
id_field = self._meta['id_field'] obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
obj = self._qs.filter(**{id_field: self[id_field]} **self._object_key).limit(1).select_related(max_depth=max_depth)
).limit(1).select_related(max_depth=max_depth)
if obj: if obj:
obj = obj[0] obj = obj[0]
else: else:
msg = "Reloaded document has been deleted" msg = "Reloaded document has been deleted"
raise OperationError(msg) raise OperationError(msg)
for field in self._fields: for field in self._fields_ordered:
setattr(self, field, self._reload(field, obj[field])) setattr(self, field, self._reload(field, obj[field]))
if self._dynamic:
for name in self._dynamic_fields.keys():
setattr(self, name, self._reload(name, obj._data[name]))
self._changed_fields = obj._changed_fields self._changed_fields = obj._changed_fields
self._created = False self._created = False
return obj return obj
@ -448,6 +475,7 @@ class Document(BaseDocument):
value = [self._reload(key, v) for v in value] value = [self._reload(key, v) for v in value]
value = BaseList(value, self, key) value = BaseList(value, self, key)
elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)):
value._instance = None
value._changed_fields = [] value._changed_fields = []
return value return value
@ -524,15 +552,6 @@ class Document(BaseDocument):
# index to service queries against _cls # index to service queries against _cls
cls_indexed = False cls_indexed = False
def includes_cls(fields):
first_field = None
if len(fields):
if isinstance(fields[0], basestring):
first_field = fields[0]
elif isinstance(fields[0], (list, tuple)) and len(fields[0]):
first_field = fields[0][0]
return first_field == '_cls'
# Ensure document-defined indexes are created # Ensure document-defined indexes are created
if cls._meta['index_specs']: if cls._meta['index_specs']:
index_spec = cls._meta['index_specs'] index_spec = cls._meta['index_specs']
@ -552,6 +571,90 @@ class Document(BaseDocument):
collection.ensure_index('_cls', background=background, collection.ensure_index('_cls', background=background,
**index_opts) **index_opts)
@classmethod
def list_indexes(cls, go_up=True, go_down=True):
""" Lists all of the indexes that should be created for given
collection. It includes all the indexes from super- and sub-classes.
"""
if cls._meta.get('abstract'):
return []
# get all the base classes, subclasses and sieblings
classes = []
def get_classes(cls):
if (cls not in classes and
isinstance(cls, TopLevelDocumentMetaclass)):
classes.append(cls)
for base_cls in cls.__bases__:
if (isinstance(base_cls, TopLevelDocumentMetaclass) and
base_cls != Document and
not base_cls._meta.get('abstract') and
base_cls._get_collection().full_name == cls._get_collection().full_name and
base_cls not in classes):
classes.append(base_cls)
get_classes(base_cls)
for subclass in cls.__subclasses__():
if (isinstance(base_cls, TopLevelDocumentMetaclass) and
subclass._get_collection().full_name == cls._get_collection().full_name and
subclass not in classes):
classes.append(subclass)
get_classes(subclass)
get_classes(cls)
# get the indexes spec for all of the gathered classes
def get_indexes_spec(cls):
indexes = []
if cls._meta['index_specs']:
index_spec = cls._meta['index_specs']
for spec in index_spec:
spec = spec.copy()
fields = spec.pop('fields')
indexes.append(fields)
return indexes
indexes = []
for cls in classes:
for index in get_indexes_spec(cls):
if index not in indexes:
indexes.append(index)
# finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed
if [(u'_id', 1)] not in indexes:
indexes.append([(u'_id', 1)])
if (cls._meta.get('index_cls', True) and
cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True):
indexes.append([(u'_cls', 1)])
return indexes
@classmethod
def compare_indexes(cls):
""" Compares the indexes defined in MongoEngine with the ones existing
in the database. Returns any missing/extra indexes.
"""
required = cls.list_indexes()
existing = [info['key'] for info in cls._get_collection().index_information().values()]
missing = [index for index in required if index not in existing]
extra = [index for index in existing if index not in required]
# if { _cls: 1 } is missing, make sure it's *really* necessary
if [(u'_cls', 1)] in missing:
cls_obsolete = False
for index in existing:
if includes_cls(index) and index not in extra:
cls_obsolete = True
break
if cls_obsolete:
missing.remove([(u'_cls', 1)])
return {'missing': missing, 'extra': extra}
class DynamicDocument(Document): class DynamicDocument(Document):
"""A Dynamic Document class allowing flexible, expandable and uncontrolled """A Dynamic Document class allowing flexible, expandable and uncontrolled

View File

@ -8,13 +8,21 @@ import uuid
import warnings import warnings
from operator import itemgetter from operator import itemgetter
try:
import dateutil
except ImportError:
dateutil = None
else:
import dateutil.parser
import pymongo
import gridfs import gridfs
from bson import Binary, DBRef, SON, ObjectId from bson import Binary, DBRef, SON, ObjectId
from mongoengine.errors import ValidationError from mongoengine.errors import ValidationError
from mongoengine.python_support import (PY3, bin_type, txt_type, from mongoengine.python_support import (PY3, bin_type, txt_type,
str_types, StringIO) str_types, StringIO)
from base import (BaseField, ComplexBaseField, ObjectIdField, from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField,
get_document, BaseDocument) get_document, BaseDocument)
from queryset import DO_NOTHING, QuerySet from queryset import DO_NOTHING, QuerySet
from document import Document, EmbeddedDocument from document import Document, EmbeddedDocument
@ -33,9 +41,8 @@ __all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'SortedListField', 'DictField', 'MapField', 'ReferenceField', 'SortedListField', 'DictField', 'MapField', 'ReferenceField',
'GenericReferenceField', 'BinaryField', 'GridFSError', 'GenericReferenceField', 'BinaryField', 'GridFSError',
'GridFSProxy', 'FileField', 'ImageGridFsProxy', 'GridFSProxy', 'FileField', 'ImageGridFsProxy',
'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField',
'SequenceField', 'UUIDField'] 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField']
RECURSIVE_REFERENCE_CONSTANT = 'self' RECURSIVE_REFERENCE_CONSTANT = 'self'
@ -107,11 +114,11 @@ class URLField(StringField):
""" """
_URL_REGEX = re.compile( _URL_REGEX = re.compile(
r'^(?:http|ftp)s?://' # http:// or https:// r'^(?:http|ftp)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain... r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain...
r'localhost|' #localhost... r'localhost|' # localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE) r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def __init__(self, verify_exists=False, url_regex=None, **kwargs): def __init__(self, verify_exists=False, url_regex=None, **kwargs):
@ -128,8 +135,7 @@ class URLField(StringField):
warnings.warn( warnings.warn(
"The URLField verify_exists argument has intractable security " "The URLField verify_exists argument has intractable security "
"and performance issues. Accordingly, it has been deprecated.", "and performance issues. Accordingly, it has been deprecated.",
DeprecationWarning DeprecationWarning)
)
try: try:
request = urllib2.Request(value) request = urllib2.Request(value)
urllib2.urlopen(request) urllib2.urlopen(request)
@ -273,14 +279,14 @@ class DecimalField(BaseField):
:param precision: Number of decimal places to store. :param precision: Number of decimal places to store.
:param rounding: The rounding rule from the python decimal libary: :param rounding: The rounding rule from the python decimal libary:
- decimial.ROUND_CEILING (towards Infinity) - decimal.ROUND_CEILING (towards Infinity)
- decimial.ROUND_DOWN (towards zero) - decimal.ROUND_DOWN (towards zero)
- decimial.ROUND_FLOOR (towards -Infinity) - decimal.ROUND_FLOOR (towards -Infinity)
- decimial.ROUND_HALF_DOWN (to nearest with ties going towards zero) - decimal.ROUND_HALF_DOWN (to nearest with ties going towards zero)
- decimial.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer) - decimal.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer)
- decimial.ROUND_HALF_UP (to nearest with ties going away from zero) - decimal.ROUND_HALF_UP (to nearest with ties going away from zero)
- decimial.ROUND_UP (away from zero) - decimal.ROUND_UP (away from zero)
- decimial.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero) - decimal.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero)
Defaults to: ``decimal.ROUND_HALF_UP`` Defaults to: ``decimal.ROUND_HALF_UP``
@ -348,6 +354,11 @@ class BooleanField(BaseField):
class DateTimeField(BaseField): class DateTimeField(BaseField):
"""A datetime field. """A datetime field.
Uses the python-dateutil library if available alternatively use time.strptime
to parse the dates. Note: python-dateutil's parser is fully featured and when
installed you can utilise it to convert varing types of date formats into valid
python datetime objects.
Note: Microseconds are rounded to the nearest millisecond. Note: Microseconds are rounded to the nearest millisecond.
Pre UTC microsecond support is effecively broken. Pre UTC microsecond support is effecively broken.
Use :class:`~mongoengine.fields.ComplexDateTimeField` if you Use :class:`~mongoengine.fields.ComplexDateTimeField` if you
@ -355,13 +366,11 @@ class DateTimeField(BaseField):
""" """
def validate(self, value): def validate(self, value):
if not isinstance(value, (datetime.datetime, datetime.date)): new_value = self.to_mongo(value)
if not isinstance(new_value, (datetime.datetime, datetime.date)):
self.error(u'cannot parse date "%s"' % value) self.error(u'cannot parse date "%s"' % value)
def to_mongo(self, value): def to_mongo(self, value):
return self.prepare_query_value(None, value)
def prepare_query_value(self, op, value):
if value is None: if value is None:
return value return value
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
@ -371,8 +380,16 @@ class DateTimeField(BaseField):
if callable(value): if callable(value):
return value() return value()
if not isinstance(value, basestring):
return None
# Attempt to parse a datetime: # Attempt to parse a datetime:
# value = smart_str(value) if dateutil:
try:
return dateutil.parser.parse(value)
except ValueError:
return None
# split usecs, because they are not recognized by strptime. # split usecs, because they are not recognized by strptime.
if '.' in value: if '.' in value:
try: try:
@ -397,6 +414,9 @@ class DateTimeField(BaseField):
except ValueError: except ValueError:
return None return None
def prepare_query_value(self, op, value):
return self.to_mongo(value)
class ComplexDateTimeField(StringField): class ComplexDateTimeField(StringField):
""" """
@ -469,7 +489,7 @@ class ComplexDateTimeField(StringField):
def __get__(self, instance, owner): def __get__(self, instance, owner):
data = super(ComplexDateTimeField, self).__get__(instance, owner) data = super(ComplexDateTimeField, self).__get__(instance, owner)
if data == None: if data is None:
return datetime.datetime.now() return datetime.datetime.now()
if isinstance(data, datetime.datetime): if isinstance(data, datetime.datetime):
return data return data
@ -658,15 +678,15 @@ class ListField(ComplexBaseField):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
if (not isinstance(value, (list, tuple, QuerySet)) or if (not isinstance(value, (list, tuple, QuerySet)) or
isinstance(value, basestring)): isinstance(value, basestring)):
self.error('Only lists and tuples may be used in a list field') self.error('Only lists and tuples may be used in a list field')
super(ListField, self).validate(value) super(ListField, self).validate(value)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if self.field: if self.field:
if op in ('set', 'unset') and (not isinstance(value, basestring) if op in ('set', 'unset') and (not isinstance(value, basestring)
and not isinstance(value, BaseDocument) and not isinstance(value, BaseDocument)
and hasattr(value, '__iter__')): and hasattr(value, '__iter__')):
return [self.field.prepare_query_value(op, v) for v in value] return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value) return self.field.prepare_query_value(op, value)
return super(ListField, self).prepare_query_value(op, value) return super(ListField, self).prepare_query_value(op, value)
@ -701,7 +721,7 @@ class SortedListField(ListField):
value = super(SortedListField, self).to_mongo(value) value = super(SortedListField, self).to_mongo(value)
if self._ordering is not None: if self._ordering is not None:
return sorted(value, key=itemgetter(self._ordering), return sorted(value, key=itemgetter(self._ordering),
reverse=self._order_reverse) reverse=self._order_reverse)
return sorted(value, reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse)
@ -854,8 +874,6 @@ class ReferenceField(BaseField):
if not self.dbref: if not self.dbref:
return document.id return document.id
return document return document
elif not self.dbref and isinstance(document, basestring):
return document
id_field_name = self.document_type._meta['id_field'] id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name] id_field = self.document_type._fields[id_field_name]
@ -880,7 +898,7 @@ class ReferenceField(BaseField):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type.
""" """
if (not self.dbref and if (not self.dbref and
not isinstance(value, (DBRef, Document, EmbeddedDocument))): not isinstance(value, (DBRef, Document, EmbeddedDocument))):
collection = self.document_type._get_collection_name() collection = self.document_type._get_collection_name()
value = DBRef(collection, self.document_type.id.to_python(value)) value = DBRef(collection, self.document_type.id.to_python(value))
return value return value
@ -1001,7 +1019,7 @@ class BinaryField(BaseField):
if not isinstance(value, (bin_type, txt_type, Binary)): if not isinstance(value, (bin_type, txt_type, Binary)):
self.error("BinaryField only accepts instances of " self.error("BinaryField only accepts instances of "
"(%s, %s, Binary)" % ( "(%s, %s, Binary)" % (
bin_type.__name__, txt_type.__name__)) bin_type.__name__, txt_type.__name__))
if self.max_bytes is not None and len(value) > self.max_bytes: if self.max_bytes is not None and len(value) > self.max_bytes:
self.error('Binary value is too long') self.error('Binary value is too long')
@ -1172,9 +1190,7 @@ class FileField(BaseField):
# Check if a file already exists for this model # Check if a file already exists for this model
grid_file = instance._data.get(self.name) grid_file = instance._data.get(self.name)
if not isinstance(grid_file, self.proxy_class): if not isinstance(grid_file, self.proxy_class):
grid_file = self.proxy_class(key=self.name, instance=instance, grid_file = self.get_proxy_obj(key=self.name, instance=instance)
db_alias=self.db_alias,
collection_name=self.collection_name)
instance._data[self.name] = grid_file instance._data[self.name] = grid_file
if not grid_file.key: if not grid_file.key:
@ -1196,14 +1212,23 @@ class FileField(BaseField):
pass pass
# Create a new proxy object as we don't already have one # Create a new proxy object as we don't already have one
instance._data[key] = self.proxy_class(key=key, instance=instance, instance._data[key] = self.get_proxy_obj(key=key, instance=instance)
collection_name=self.collection_name)
instance._data[key].put(value) instance._data[key].put(value)
else: else:
instance._data[key] = value instance._data[key] = value
instance._mark_as_changed(key) instance._mark_as_changed(key)
def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None):
if db_alias is None:
db_alias = self.db_alias
if collection_name is None:
collection_name = self.collection_name
return self.proxy_class(key=key, instance=instance,
db_alias=db_alias,
collection_name=collection_name)
def to_mongo(self, value): def to_mongo(self, value):
# Store the GridFS file id in MongoDB # Store the GridFS file id in MongoDB
if isinstance(value, self.proxy_class) and value.grid_id is not None: if isinstance(value, self.proxy_class) and value.grid_id is not None:
@ -1235,15 +1260,16 @@ class ImageGridFsProxy(GridFSProxy):
Insert a image in database Insert a image in database
applying field properties (size, thumbnail_size) applying field properties (size, thumbnail_size)
""" """
if not self.instance:
import ipdb; ipdb.set_trace();
field = self.instance._fields[self.key] field = self.instance._fields[self.key]
# Handle nested fields
if hasattr(field, 'field') and isinstance(field.field, FileField):
field = field.field
try: try:
img = Image.open(file_obj) img = Image.open(file_obj)
img_format = img.format img_format = img.format
except: except Exception, e:
raise ValidationError('Invalid image') raise ValidationError('Invalid image: %s' % e)
if (field.size and (img.size[0] > field.size['width'] or if (field.size and (img.size[0] > field.size['width'] or
img.size[1] > field.size['height'])): img.size[1] > field.size['height'])):
@ -1308,6 +1334,7 @@ class ImageGridFsProxy(GridFSProxy):
height=h, height=h,
format=format, format=format,
**kwargs) **kwargs)
@property @property
def size(self): def size(self):
""" """
@ -1386,28 +1413,6 @@ class ImageField(FileField):
**kwargs) **kwargs)
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
.. versionadded:: 0.4
"""
_geo_index = True
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)
"""
if not isinstance(value, (list, tuple)):
self.error('GeoPointField can only accept tuples or lists '
'of (x, y)')
if not len(value) == 2:
self.error('Value must be a two-dimensional point')
if (not isinstance(value[0], (float, int)) and
not isinstance(value[1], (float, int))):
self.error('Both values in point must be float or int')
class SequenceField(BaseField): class SequenceField(BaseField):
"""Provides a sequental counter see: """Provides a sequental counter see:
http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers
@ -1466,6 +1471,22 @@ class SequenceField(BaseField):
upsert=True) upsert=True)
return self.value_decorator(counter['next']) return self.value_decorator(counter['next'])
def get_next_value(self):
"""Helper method to get the next value for previewing.
.. warning:: There is no guarantee this will be the next value
as it is only fixed on set.
"""
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
data = collection.find_one({"_id": sequence_id})
if data:
return self.value_decorator(data['next']+1)
return self.value_decorator(1)
def get_sequence_name(self): def get_sequence_name(self):
if self.sequence_name: if self.sequence_name:
return self.sequence_name return self.sequence_name
@ -1548,3 +1569,83 @@ class UUIDField(BaseField):
value = uuid.UUID(value) value = uuid.UUID(value)
except Exception, exc: except Exception, exc:
self.error('Could not convert to UUID: %s' % exc) self.error('Could not convert to UUID: %s' % exc)
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
.. versionadded:: 0.4
"""
_geo_index = pymongo.GEO2D
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)
"""
if not isinstance(value, (list, tuple)):
self.error('GeoPointField can only accept tuples or lists '
'of (x, y)')
if not len(value) == 2:
self.error("Value (%s) must be a two-dimensional point" % repr(value))
elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))):
self.error("Both values (%s) in point must be float or int" % repr(value))
class PointField(GeoJsonBaseField):
"""A geo json field storing a latitude and longitude.
The data is represented as:
.. code-block:: js
{ "type" : "Point" ,
"coordinates" : [x, y]}
You can either pass a dict with the full information or a list
to set the value.
Requires mongodb >= 2.4
.. versionadded:: 0.8
"""
_type = "Point"
class LineStringField(GeoJsonBaseField):
"""A geo json field storing a line of latitude and longitude coordinates.
The data is represented as:
.. code-block:: js
{ "type" : "LineString" ,
"coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]}
You can either pass a dict with the full information or a list of points.
Requires mongodb >= 2.4
.. versionadded:: 0.8
"""
_type = "LineString"
class PolygonField(GeoJsonBaseField):
"""A geo json field storing a polygon of latitude and longitude coordinates.
The data is represented as:
.. code-block:: js
{ "type" : "Polygon" ,
"coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]]}
You can either pass a dict with the full information or a list
of LineStrings. The first LineString being the outside and the rest being
holes.
Requires mongodb >= 2.4
.. versionadded:: 0.8
"""
_type = "Polygon"

1479
mongoengine/queryset/base.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
from collections import defaultdict from collections import defaultdict
import pymongo
from bson import SON from bson import SON
from mongoengine.common import _import_class from mongoengine.common import _import_class
@ -12,7 +13,9 @@ COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not') 'all', 'size', 'exists', 'not')
GEO_OPERATORS = ('within_distance', 'within_spherical_distance', GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
'within_box', 'within_polygon', 'near', 'near_sphere', 'within_box', 'within_polygon', 'near', 'near_sphere',
'max_distance') 'max_distance', 'geo_within', 'geo_within_box',
'geo_within_polygon', 'geo_within_center',
'geo_within_sphere', 'geo_intersects')
STRING_OPERATORS = ('contains', 'icontains', 'startswith', STRING_OPERATORS = ('contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
'exact', 'iexact') 'exact', 'iexact')
@ -21,7 +24,8 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS) STRING_OPERATORS + CUSTOM_OPERATORS)
UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
'push_all', 'pull', 'pull_all', 'add_to_set') 'push_all', 'pull', 'pull_all', 'add_to_set',
'set_on_insert')
def query(_doc_cls=None, _field_operation=False, **query): def query(_doc_cls=None, _field_operation=False, **query):
@ -81,32 +85,17 @@ def query(_doc_cls=None, _field_operation=False, **query):
value = field value = field
else: else:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all', 'near'): elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
# if op and op not in COMPARISON_OPERATORS: # if op and op not in COMPARISON_OPERATORS:
if op: if op:
if op in GEO_OPERATORS: if op in GEO_OPERATORS:
if op == "within_distance": value = _geo_operator(field, op, value)
value = {'$within': {'$center': value}}
elif op == "within_spherical_distance":
value = {'$within': {'$centerSphere': value}}
elif op == "within_polygon":
value = {'$within': {'$polygon': value}}
elif op == "near":
value = {'$near': value}
elif op == "near_sphere":
value = {'$nearSphere': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
elif op == "max_distance":
value = {'$maxDistance': value}
else:
raise NotImplementedError("Geo method '%s' has not "
"been implemented" % op)
elif op in CUSTOM_OPERATORS: elif op in CUSTOM_OPERATORS:
if op == 'match': if op == 'match':
value = field.prepare_query_value(op, value)
value = {"$elemMatch": value} value = {"$elemMatch": value}
else: else:
NotImplementedError("Custom method '%s' has not " NotImplementedError("Custom method '%s' has not "
@ -176,7 +165,9 @@ def update(_doc_cls=None, **update):
if value > 0: if value > 0:
value = -value value = -value
elif op == 'add_to_set': elif op == 'add_to_set':
op = op.replace('_to_set', 'ToSet') op = 'addToSet'
elif op == 'set_on_insert':
op = "setOnInsert"
match = None match = None
if parts[-1] in COMPARISON_OPERATORS: if parts[-1] in COMPARISON_OPERATORS:
@ -250,3 +241,76 @@ def update(_doc_cls=None, **update):
mongo_update[key].update(value) mongo_update[key].update(value)
return mongo_update return mongo_update
def _geo_operator(field, op, value):
"""Helper to return the query for a given geo query"""
if field._geo_index == pymongo.GEO2D:
if op == "within_distance":
value = {'$within': {'$center': value}}
elif op == "within_spherical_distance":
value = {'$within': {'$centerSphere': value}}
elif op == "within_polygon":
value = {'$within': {'$polygon': value}}
elif op == "near":
value = {'$near': value}
elif op == "near_sphere":
value = {'$nearSphere': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
elif op == "max_distance":
value = {'$maxDistance': value}
else:
raise NotImplementedError("Geo method '%s' has not "
"been implemented for a GeoPointField" % op)
else:
if op == "geo_within":
value = {"$geoWithin": _infer_geometry(value)}
elif op == "geo_within_box":
value = {"$geoWithin": {"$box": value}}
elif op == "geo_within_polygon":
value = {"$geoWithin": {"$polygon": value}}
elif op == "geo_within_center":
value = {"$geoWithin": {"$center": value}}
elif op == "geo_within_sphere":
value = {"$geoWithin": {"$centerSphere": value}}
elif op == "geo_intersects":
value = {"$geoIntersects": _infer_geometry(value)}
elif op == "near":
value = {'$near': _infer_geometry(value)}
elif op == "max_distance":
value = {'$maxDistance': value}
else:
raise NotImplementedError("Geo method '%s' has not "
"been implemented for a %s " % (op, field._name))
return value
def _infer_geometry(value):
"""Helper method that tries to infer the $geometry shape for a given value"""
if isinstance(value, dict):
if "$geometry" in value:
return value
elif 'coordinates' in value and 'type' in value:
return {"$geometry": value}
raise InvalidQueryError("Invalid $geometry dictionary should have "
"type and coordinates keys")
elif isinstance(value, (list, set)):
try:
value[0][0][0]
return {"$geometry": {"type": "Polygon", "coordinates": value}}
except:
pass
try:
value[0][0]
return {"$geometry": {"type": "LineString", "coordinates": value}}
except:
pass
try:
value[0]
return {"$geometry": {"type": "Point", "coordinates": value}}
except:
pass
raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary "
"or (nested) lists of coordinate(s)")

View File

@ -23,6 +23,10 @@ class QNodeVisitor(object):
return query return query
class DuplicateQueryConditionsError(InvalidQueryError):
pass
class SimplificationVisitor(QNodeVisitor): class SimplificationVisitor(QNodeVisitor):
"""Simplifies query trees by combinging unnecessary 'and' connection nodes """Simplifies query trees by combinging unnecessary 'and' connection nodes
into a single Q-object. into a single Q-object.
@ -33,7 +37,11 @@ class SimplificationVisitor(QNodeVisitor):
# The simplification only applies to 'simple' queries # The simplification only applies to 'simple' queries
if all(isinstance(node, Q) for node in combination.children): if all(isinstance(node, Q) for node in combination.children):
queries = [n.query for n in combination.children] queries = [n.query for n in combination.children]
return Q(**self._query_conjunction(queries)) try:
return Q(**self._query_conjunction(queries))
except DuplicateQueryConditionsError:
# Cannot be simplified
pass
return combination return combination
def _query_conjunction(self, queries): def _query_conjunction(self, queries):
@ -47,8 +55,7 @@ class SimplificationVisitor(QNodeVisitor):
# to a single field # to a single field
intersection = ops.intersection(query_ops) intersection = ops.intersection(query_ops)
if intersection: if intersection:
msg = 'Duplicate query conditions: ' raise DuplicateQueryConditionsError()
raise InvalidQueryError(msg + ', '.join(intersection))
query_ops.update(ops) query_ops.update(ops)
combined_query.update(copy.deepcopy(query)) combined_query.update(copy.deepcopy(query))
@ -122,8 +129,7 @@ class QCombination(QNode):
# If the child is a combination of the same type, we can merge its # If the child is a combination of the same type, we can merge its
# children directly into this combinations children # children directly into this combinations children
if isinstance(node, QCombination) and node.operation == operation: if isinstance(node, QCombination) and node.operation == operation:
# self.children += node.children self.children += node.children
self.children.append(node)
else: else:
self.children.append(node) self.children.append(node)

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', __all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation',
'pre_delete', 'post_delete'] 'post_save', 'pre_delete', 'post_delete']
signals_available = False signals_available = False
try: try:
@ -39,6 +39,7 @@ _signals = Namespace()
pre_init = _signals.signal('pre_init') pre_init = _signals.signal('pre_init')
post_init = _signals.signal('post_init') post_init = _signals.signal('post_init')
pre_save = _signals.signal('pre_save') pre_save = _signals.signal('pre_save')
pre_save_post_validation = _signals.signal('pre_save_post_validation')
post_save = _signals.signal('post_save') post_save = _signals.signal('post_save')
pre_delete = _signals.signal('pre_delete') pre_delete = _signals.signal('pre_delete')
post_delete = _signals.signal('post_delete') post_delete = _signals.signal('post_delete')

View File

@ -5,7 +5,7 @@
%define srcname mongoengine %define srcname mongoengine
Name: python-%{srcname} Name: python-%{srcname}
Version: 0.7.10 Version: 0.8.3
Release: 1%{?dist} Release: 1%{?dist}
Summary: A Python Document-Object Mapper for working with MongoDB Summary: A Python Document-Object Mapper for working with MongoDB

View File

@ -51,13 +51,13 @@ CLASSIFIERS = [
extra_opts = {} extra_opts = {}
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
extra_opts['use_2to3'] = True extra_opts['use_2to3'] = True
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker'] extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1']
extra_opts['packages'] = find_packages(exclude=('tests',)) extra_opts['packages'] = find_packages(exclude=('tests',))
if "test" in sys.argv or "nosetests" in sys.argv: if "test" in sys.argv or "nosetests" in sys.argv:
extra_opts['packages'].append("tests") extra_opts['packages'].append("tests")
extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]}
else: else:
extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL'] extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2==2.6', 'python-dateutil']
extra_opts['packages'] = find_packages(exclude=('tests',)) extra_opts['packages'] = find_packages(exclude=('tests',))
setup(name='mongoengine', setup(name='mongoengine',
@ -74,7 +74,7 @@ setup(name='mongoengine',
long_description=LONG_DESCRIPTION, long_description=LONG_DESCRIPTION,
platforms=['any'], platforms=['any'],
classifiers=CLASSIFIERS, classifiers=CLASSIFIERS,
install_requires=['pymongo'], install_requires=['pymongo>=2.5'],
test_suite='nose.collector', test_suite='nose.collector',
**extra_opts **extra_opts
) )

View File

@ -1,12 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
from mongoengine import * from mongoengine import *
from mongoengine.queryset import NULLIFY from mongoengine.queryset import NULLIFY, PULL
from mongoengine.connection import get_db from mongoengine.connection import get_db
__all__ = ("ClassMethodsTest", ) __all__ = ("ClassMethodsTest", )
@ -86,6 +85,172 @@ class ClassMethodsTest(unittest.TestCase):
self.assertEqual(self.Person._meta['delete_rules'], self.assertEqual(self.Person._meta['delete_rules'],
{(Job, 'employee'): NULLIFY}) {(Job, 'employee'): NULLIFY})
def test_compare_indexes(self):
""" Ensure that the indexes are properly created and that
compare_indexes identifies the missing/extra indexes
"""
class BlogPost(Document):
author = StringField()
title = StringField()
description = StringField()
tags = StringField()
meta = {
'indexes': [('author', 'title')]
}
BlogPost.drop_collection()
BlogPost.ensure_indexes()
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] })
BlogPost.ensure_index(['author', 'description'])
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('author', 1), ('description', 1)]] })
BlogPost._get_collection().drop_index('author_1_description_1')
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] })
BlogPost._get_collection().drop_index('author_1_title_1')
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('author', 1), ('title', 1)]], 'extra': [] })
def test_compare_indexes_inheritance(self):
""" Ensure that the indexes are properly created and that
compare_indexes identifies the missing/extra indexes for subclassed
documents (_cls included)
"""
class BlogPost(Document):
author = StringField()
title = StringField()
description = StringField()
meta = {
'allow_inheritance': True
}
class BlogPostWithTags(BlogPost):
tags = StringField()
tag_list = ListField(StringField())
meta = {
'indexes': [('author', 'tags')]
}
BlogPost.drop_collection()
BlogPost.ensure_indexes()
BlogPostWithTags.ensure_indexes()
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] })
BlogPostWithTags.ensure_index(['author', 'tag_list'])
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]] })
BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tag_list_1')
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] })
BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tags_1')
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': [] })
def test_compare_indexes_multiple_subclasses(self):
""" Ensure that compare_indexes behaves correctly if called from a
class, which base class has multiple subclasses
"""
class BlogPost(Document):
author = StringField()
title = StringField()
description = StringField()
meta = {
'allow_inheritance': True
}
class BlogPostWithTags(BlogPost):
tags = StringField()
tag_list = ListField(StringField())
meta = {
'indexes': [('author', 'tags')]
}
class BlogPostWithCustomField(BlogPost):
custom = DictField()
meta = {
'indexes': [('author', 'custom')]
}
BlogPost.ensure_indexes()
BlogPostWithTags.ensure_indexes()
BlogPostWithCustomField.ensure_indexes()
self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] })
self.assertEqual(BlogPostWithTags.compare_indexes(), { 'missing': [], 'extra': [] })
self.assertEqual(BlogPostWithCustomField.compare_indexes(), { 'missing': [], 'extra': [] })
def test_list_indexes_inheritance(self):
""" ensure that all of the indexes are listed regardless of the super-
or sub-class that we call it from
"""
class BlogPost(Document):
author = StringField()
title = StringField()
description = StringField()
meta = {
'allow_inheritance': True
}
class BlogPostWithTags(BlogPost):
tags = StringField()
meta = {
'indexes': [('author', 'tags')]
}
class BlogPostWithTagsAndExtraText(BlogPostWithTags):
extra_text = StringField()
meta = {
'indexes': [('author', 'tags', 'extra_text')]
}
BlogPost.drop_collection()
BlogPost.ensure_indexes()
BlogPostWithTags.ensure_indexes()
BlogPostWithTagsAndExtraText.ensure_indexes()
self.assertEqual(BlogPost.list_indexes(),
BlogPostWithTags.list_indexes())
self.assertEqual(BlogPost.list_indexes(),
BlogPostWithTagsAndExtraText.list_indexes())
self.assertEqual(BlogPost.list_indexes(),
[[('_cls', 1), ('author', 1), ('tags', 1)],
[('_cls', 1), ('author', 1), ('tags', 1), ('extra_text', 1)],
[(u'_id', 1)], [('_cls', 1)]])
def test_register_delete_rule_inherited(self):
class Vaccine(Document):
name = StringField(required=True)
meta = {"indexes": ["name"]}
class Animal(Document):
family = StringField(required=True)
vaccine_made = ListField(ReferenceField("Vaccine", reverse_delete_rule=PULL))
meta = {"allow_inheritance": True, "indexes": ["family"]}
class Cat(Animal):
name = StringField(required=True)
self.assertEqual(Vaccine._meta['delete_rules'][(Animal, 'vaccine_made')], PULL)
self.assertEqual(Vaccine._meta['delete_rules'][(Cat, 'vaccine_made')], PULL)
def test_collection_naming(self): def test_collection_naming(self):
"""Ensure that a collection with a specified name may be used. """Ensure that a collection with a specified name may be used.
""" """

View File

@ -3,6 +3,7 @@ import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
from bson import SON
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
@ -613,13 +614,13 @@ class DeltaTest(unittest.TestCase):
Person.drop_collection() Person.drop_collection()
p = Person(name="James", age=34) p = Person(name="James", age=34)
self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', self.assertEqual(p._delta(), (
'_cls': 'Person'}, {})) SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {}))
p.doc = 123 p.doc = 123
del(p.doc) del(p.doc)
self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', self.assertEqual(p._delta(), (
'_cls': 'Person'}, {'doc': 1})) SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {}))
p = Person() p = Person()
p.name = "Dean" p.name = "Dean"
@ -631,14 +632,14 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._get_changed_fields(), ['age'])
self.assertEqual(p._delta(), ({'age': 24}, {})) self.assertEqual(p._delta(), ({'age': 24}, {}))
p = self.Person.objects(age=22).get() p = Person.objects(age=22).get()
p.age = 24 p.age = 24
self.assertEqual(p.age, 24) self.assertEqual(p.age, 24)
self.assertEqual(p._get_changed_fields(), ['age']) self.assertEqual(p._get_changed_fields(), ['age'])
self.assertEqual(p._delta(), ({'age': 24}, {})) self.assertEqual(p._delta(), ({'age': 24}, {}))
p.save() p.save()
self.assertEqual(1, self.Person.objects(age=24).count()) self.assertEqual(1, Person.objects(age=24).count())
def test_dynamic_delta(self): def test_dynamic_delta(self):

View File

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import with_statement
import unittest import unittest
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -217,7 +216,7 @@ class IndexesTest(unittest.TestCase):
} }
self.assertEqual([{'fields': [('location.point', '2d')]}], self.assertEqual([{'fields': [('location.point', '2d')]}],
Place._meta['index_specs']) Place._meta['index_specs'])
Place.ensure_indexes() Place.ensure_indexes()
info = Place._get_collection().index_information() info = Place._get_collection().index_information()
@ -231,8 +230,7 @@ class IndexesTest(unittest.TestCase):
location = DictField() location = DictField()
class Place(Document): class Place(Document):
current = DictField( current = DictField(field=EmbeddedDocumentField('EmbeddedLocation'))
field=EmbeddedDocumentField('EmbeddedLocation'))
meta = { meta = {
'allow_inheritance': True, 'allow_inheritance': True,
'indexes': [ 'indexes': [
@ -241,7 +239,7 @@ class IndexesTest(unittest.TestCase):
} }
self.assertEqual([{'fields': [('current.location.point', '2d')]}], self.assertEqual([{'fields': [('current.location.point', '2d')]}],
Place._meta['index_specs']) Place._meta['index_specs'])
Place.ensure_indexes() Place.ensure_indexes()
info = Place._get_collection().index_information() info = Place._get_collection().index_information()
@ -264,7 +262,7 @@ class IndexesTest(unittest.TestCase):
self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, self.assertEqual([{'fields': [('addDate', -1)], 'unique': True,
'sparse': True}], 'sparse': True}],
BlogPost._meta['index_specs']) BlogPost._meta['index_specs'])
BlogPost.drop_collection() BlogPost.drop_collection()
@ -382,8 +380,7 @@ class IndexesTest(unittest.TestCase):
self.assertEqual(sorted(info.keys()), ['_id_', 'tags.tag_1']) self.assertEqual(sorted(info.keys()), ['_id_', 'tags.tag_1'])
post1 = BlogPost(title="Embedded Indexes tests in place", post1 = BlogPost(title="Embedded Indexes tests in place",
tags=[Tag(name="about"), Tag(name="time")] tags=[Tag(name="about"), Tag(name="time")])
)
post1.save() post1.save()
BlogPost.drop_collection() BlogPost.drop_collection()
@ -400,29 +397,6 @@ class IndexesTest(unittest.TestCase):
info = RecursiveDocument._get_collection().index_information() info = RecursiveDocument._get_collection().index_information()
self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_']) self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_'])
def test_geo_indexes_recursion(self):
class Location(Document):
name = StringField()
location = GeoPointField()
class Parent(Document):
name = StringField()
location = ReferenceField(Location, dbref=False)
Location.drop_collection()
Parent.drop_collection()
list(Parent.objects)
collection = Parent._get_collection()
info = collection.index_information()
self.assertFalse('location_2d' in info)
self.assertEqual(len(Parent._geo_indices()), 0)
self.assertEqual(len(Location._geo_indices()), 1)
def test_covered_index(self): def test_covered_index(self):
"""Ensure that covered indexes can be used """Ensure that covered indexes can be used
""" """
@ -433,7 +407,7 @@ class IndexesTest(unittest.TestCase):
meta = { meta = {
'indexes': ['a'], 'indexes': ['a'],
'allow_inheritance': False 'allow_inheritance': False
} }
Test.drop_collection() Test.drop_collection()
@ -633,7 +607,7 @@ class IndexesTest(unittest.TestCase):
list(Log.objects) list(Log.objects)
info = Log.objects._collection.index_information() info = Log.objects._collection.index_information()
self.assertEqual(3600, self.assertEqual(3600,
info['created_1']['expireAfterSeconds']) info['created_1']['expireAfterSeconds'])
def test_unique_and_indexes(self): def test_unique_and_indexes(self):
"""Ensure that 'unique' constraints aren't overridden by """Ensure that 'unique' constraints aren't overridden by

View File

@ -189,6 +189,41 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(Employee._get_collection_name(), self.assertEqual(Employee._get_collection_name(),
Person._get_collection_name()) Person._get_collection_name())
def test_indexes_and_multiple_inheritance(self):
""" Ensure that all of the indexes are created for a document with
multiple inheritance.
"""
class A(Document):
a = StringField()
meta = {
'allow_inheritance': True,
'indexes': ['a']
}
class B(Document):
b = StringField()
meta = {
'allow_inheritance': True,
'indexes': ['b']
}
class C(A, B):
pass
A.drop_collection()
B.drop_collection()
C.drop_collection()
C.ensure_indexes()
self.assertEqual(
sorted([idx['key'] for idx in C._get_collection().index_information().values()]),
sorted([[(u'_cls', 1), (u'b', 1)], [(u'_id', 1)], [(u'_cls', 1), (u'a', 1)]])
)
def test_polymorphic_queries(self): def test_polymorphic_queries(self):
"""Ensure that the correct subclasses are returned from a query """Ensure that the correct subclasses are returned from a query
""" """

View File

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -10,7 +9,9 @@ import unittest
import uuid import uuid
from datetime import datetime from datetime import datetime
from tests.fixtures import PickleEmbedded, PickleTest from bson import DBRef
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
PickleDyanmicEmbedded, PickleDynamicTest)
from mongoengine import * from mongoengine import *
from mongoengine.errors import (NotRegistered, InvalidDocumentError, from mongoengine.errors import (NotRegistered, InvalidDocumentError,
@ -320,8 +321,8 @@ class InstanceTest(unittest.TestCase):
Location.drop_collection() Location.drop_collection()
self.assertEquals(Area, get_document("Area")) self.assertEqual(Area, get_document("Area"))
self.assertEquals(Area, get_document("Location.Area")) self.assertEqual(Area, get_document("Location.Area"))
def test_creation(self): def test_creation(self):
"""Ensure that document may be created using keyword arguments. """Ensure that document may be created using keyword arguments.
@ -443,6 +444,13 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
['_cls', 'name', 'age', 'salary']) ['_cls', 'name', 'age', 'salary'])
def test_embedded_document_to_mongo_id(self):
class SubDoc(EmbeddedDocument):
id = StringField(required=True)
sub_doc = SubDoc(id="abc")
self.assertEqual(sub_doc.to_mongo().keys(), ['id'])
def test_embedded_document(self): def test_embedded_document(self):
"""Ensure that embedded documents are set up correctly. """Ensure that embedded documents are set up correctly.
""" """
@ -509,12 +517,12 @@ class InstanceTest(unittest.TestCase):
t = TestDocument(status="published") t = TestDocument(status="published")
t.save(clean=False) t.save(clean=False)
self.assertEquals(t.pub_date, None) self.assertEqual(t.pub_date, None)
t = TestDocument(status="published") t = TestDocument(status="published")
t.save(clean=True) t.save(clean=True)
self.assertEquals(type(t.pub_date), datetime) self.assertEqual(type(t.pub_date), datetime)
def test_document_embedded_clean(self): def test_document_embedded_clean(self):
class TestEmbeddedDocument(EmbeddedDocument): class TestEmbeddedDocument(EmbeddedDocument):
@ -546,7 +554,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}})
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save()
self.assertEquals(t.doc.z, 35) self.assertEqual(t.doc.z, 35)
# Asserts not raises # Asserts not raises
t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
@ -665,7 +673,7 @@ class InstanceTest(unittest.TestCase):
p = Person.objects(name="Wilson Jr").get() p = Person.objects(name="Wilson Jr").get()
p.parent.name = "Daddy Wilson" p.parent.name = "Daddy Wilson"
p.save() p.save(cascade=True)
p1.reload() p1.reload()
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
@ -684,14 +692,12 @@ class InstanceTest(unittest.TestCase):
p2 = Person(name="Wilson Jr") p2 = Person(name="Wilson Jr")
p2.parent = p1 p2.parent = p1
p1.name = "Daddy Wilson"
p2.save(force_insert=True, cascade_kwargs={"force_insert": False}) p2.save(force_insert=True, cascade_kwargs={"force_insert": False})
p = Person.objects(name="Wilson Jr").get()
p.parent.name = "Daddy Wilson"
p.save()
p1.reload() p1.reload()
self.assertEqual(p1.name, p.parent.name) p2.reload()
self.assertEqual(p1.name, p2.parent.name)
def test_save_cascade_meta_false(self): def test_save_cascade_meta_false(self):
@ -766,6 +772,10 @@ class InstanceTest(unittest.TestCase):
p.parent.name = "Daddy Wilson" p.parent.name = "Daddy Wilson"
p.save() p.save()
p1.reload()
self.assertNotEqual(p1.name, p.parent.name)
p.save(cascade=True)
p1.reload() p1.reload()
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
@ -853,6 +863,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, None) self.assertEqual(person.name, None)
self.assertEqual(person.age, None) self.assertEqual(person.age, None)
def test_inserts_if_you_set_the_pk(self):
p1 = self.Person(name='p1', id=bson.ObjectId()).save()
p2 = self.Person(name='p2')
p2.id = bson.ObjectId()
p2.save()
self.assertEqual(2, self.Person.objects.count())
def test_can_save_if_not_included(self): def test_can_save_if_not_included(self):
class EmbeddedDoc(EmbeddedDocument): class EmbeddedDoc(EmbeddedDocument):
@ -1011,6 +1029,99 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.age, 21) self.assertEqual(person.age, 21)
self.assertEqual(person.active, False) self.assertEqual(person.active, False)
def test_query_count_when_saving(self):
"""Ensure references don't cause extra fetches when saving"""
class Organization(Document):
name = StringField()
class User(Document):
name = StringField()
orgs = ListField(ReferenceField('Organization'))
class Feed(Document):
name = StringField()
class UserSubscription(Document):
name = StringField()
user = ReferenceField(User)
feed = ReferenceField(Feed)
Organization.drop_collection()
User.drop_collection()
Feed.drop_collection()
UserSubscription.drop_collection()
o1 = Organization(name="o1").save()
o2 = Organization(name="o2").save()
u1 = User(name="Ross", orgs=[o1, o2]).save()
f1 = Feed(name="MongoEngine").save()
sub = UserSubscription(user=u1, feed=f1).save()
user = User.objects.first()
# Even if stored as ObjectId's internally mongoengine uses DBRefs
# As ObjectId's aren't automatically derefenced
self.assertTrue(isinstance(user._data['orgs'][0], DBRef))
self.assertTrue(isinstance(user.orgs[0], Organization))
self.assertTrue(isinstance(user._data['orgs'][0], Organization))
# Changing a value
with query_counter() as q:
self.assertEqual(q, 0)
sub = UserSubscription.objects.first()
self.assertEqual(q, 1)
sub.name = "Test Sub"
sub.save()
self.assertEqual(q, 2)
# Changing a value that will cascade
with query_counter() as q:
self.assertEqual(q, 0)
sub = UserSubscription.objects.first()
self.assertEqual(q, 1)
sub.user.name = "Test"
self.assertEqual(q, 2)
sub.save(cascade=True)
self.assertEqual(q, 3)
# Changing a value and one that will cascade
with query_counter() as q:
self.assertEqual(q, 0)
sub = UserSubscription.objects.first()
sub.name = "Test Sub 2"
self.assertEqual(q, 1)
sub.user.name = "Test 2"
self.assertEqual(q, 2)
sub.save(cascade=True)
self.assertEqual(q, 4) # One for the UserSub and one for the User
# Saving with just the refs
with query_counter() as q:
self.assertEqual(q, 0)
sub = UserSubscription(user=u1.pk, feed=f1.pk)
self.assertEqual(q, 0)
sub.save()
self.assertEqual(q, 1)
# Saving with just the refs on a ListField
with query_counter() as q:
self.assertEqual(q, 0)
User(name="Bob", orgs=[o1.pk, o2.pk]).save()
self.assertEqual(q, 1)
# Saving new objects
with query_counter() as q:
self.assertEqual(q, 0)
user = User.objects.first()
self.assertEqual(q, 1)
feed = Feed.objects.first()
self.assertEqual(q, 2)
sub = UserSubscription(user=user, feed=feed)
self.assertEqual(q, 2) # Check no change
sub.save()
self.assertEqual(q, 3)
def test_set_unset_one_operation(self): def test_set_unset_one_operation(self):
"""Ensure that $set and $unset actions are performed in the same """Ensure that $set and $unset actions are performed in the same
operation. operation.
@ -1702,6 +1813,7 @@ class InstanceTest(unittest.TestCase):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded() pickle_doc.embedded = PickleEmbedded()
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
pickle_doc.save() pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc) pickled_doc = pickle.dumps(pickle_doc)
@ -1723,6 +1835,35 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(pickle_doc.string, "Two") self.assertEqual(pickle_doc.string, "Two")
self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) self.assertEqual(pickle_doc.lists, ["1", "2", "3"])
def test_dynamic_document_pickle(self):
pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar")
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
resurrected = pickle.loads(pickled_doc)
self.assertEqual(resurrected, pickle_doc)
self.assertEqual(resurrected._fields_ordered,
pickle_doc._fields_ordered)
self.assertEqual(resurrected._dynamic_fields.keys(),
pickle_doc._dynamic_fields.keys())
self.assertEqual(resurrected.embedded, pickle_doc.embedded)
self.assertEqual(resurrected.embedded._fields_ordered,
pickle_doc.embedded._fields_ordered)
self.assertEqual(resurrected.embedded._dynamic_fields.keys(),
pickle_doc.embedded._dynamic_fields.keys())
def test_picklable_on_signals(self):
pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded()
pickle_doc.save()
pickle_doc.delete()
def test_throw_invalid_document_error(self): def test_throw_invalid_document_error(self):
# test handles people trying to upsert # test handles people trying to upsert
@ -1896,11 +2037,11 @@ class InstanceTest(unittest.TestCase):
A.objects.all() A.objects.all()
self.assertEquals('testdb-2', B._meta.get('db_alias')) self.assertEqual('testdb-2', B._meta.get('db_alias'))
self.assertEquals('mongoenginetest', self.assertEqual('mongoenginetest',
A._get_collection().database.name) A._get_collection().database.name)
self.assertEquals('mongoenginetest2', self.assertEqual('mongoenginetest2',
B._get_collection().database.name) B._get_collection().database.name)
def test_db_alias_propagates(self): def test_db_alias_propagates(self):
"""db_alias propagates? """db_alias propagates?
@ -2179,6 +2320,16 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_mixed_creation_dynamic(self):
"""Ensure that document may be created using mixed arguments.
"""
class Person(DynamicDocument):
name = StringField()
person = Person("Test User", age=42)
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42)
def test_bad_mixed_creation(self): def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments """Ensure that document gives correct error when duplicating arguments
""" """

View File

@ -1,2 +1,3 @@
from fields import * from fields import *
from file_tests import * from file_tests import *
from geo import *

View File

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -7,6 +6,11 @@ import datetime
import unittest import unittest
import uuid import uuid
try:
import dateutil
except ImportError:
dateutil = None
from decimal import Decimal from decimal import Decimal
from bson import Binary, DBRef, ObjectId from bson import Binary, DBRef, ObjectId
@ -30,20 +34,137 @@ class FieldTest(unittest.TestCase):
self.db.drop_collection('fs.files') self.db.drop_collection('fs.files')
self.db.drop_collection('fs.chunks') self.db.drop_collection('fs.chunks')
def test_default_values(self): def test_default_values_nothing_set(self):
"""Ensure that default field values are used when creating a document. """Ensure that default field values are used when creating a document.
""" """
class Person(Document): class Person(Document):
name = StringField() name = StringField()
age = IntField(default=30, help_text="Your real age") age = IntField(default=30, required=False)
userid = StringField(default=lambda: 'test', verbose_name="User Identity") userid = StringField(default=lambda: 'test', required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
person = Person(name='Test Person') person = Person(name="Ross")
self.assertEqual(person._data['age'], 30)
self.assertEqual(person._data['userid'], 'test') # Confirm saving now would store values
self.assertEqual(person._fields['name'].help_text, None) data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(person._fields['age'].help_text, "Your real age") self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid'])
self.assertEqual(person._fields['userid'].verbose_name, "User Identity")
self.assertTrue(person.validate() is None)
self.assertEqual(person.name, person.name)
self.assertEqual(person.age, person.age)
self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created)
self.assertEqual(person._data['name'], person.name)
self.assertEqual(person._data['age'], person.age)
self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(person._data['created'], person.created)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid'])
def test_default_values_set_to_None(self):
"""Ensure that default field values are used when creating a document.
"""
class Person(Document):
name = StringField()
age = IntField(default=30, required=False)
userid = StringField(default=lambda: 'test', required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
# Trying setting values to None
person = Person(name=None, age=None, userid=None, created=None)
# Confirm saving now would store values
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
self.assertTrue(person.validate() is None)
self.assertEqual(person.name, person.name)
self.assertEqual(person.age, person.age)
self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created)
self.assertEqual(person._data['name'], person.name)
self.assertEqual(person._data['age'], person.age)
self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(person._data['created'], person.created)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
def test_default_values_when_setting_to_None(self):
"""Ensure that default field values are used when creating a document.
"""
class Person(Document):
name = StringField()
age = IntField(default=30, required=False)
userid = StringField(default=lambda: 'test', required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
person = Person()
person.name = None
person.age = None
person.userid = None
person.created = None
# Confirm saving now would store values
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
self.assertTrue(person.validate() is None)
self.assertEqual(person.name, person.name)
self.assertEqual(person.age, person.age)
self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created)
self.assertEqual(person._data['name'], person.name)
self.assertEqual(person._data['age'], person.age)
self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(person._data['created'], person.created)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
def test_default_values_when_deleting_value(self):
"""Ensure that default field values are used when creating a document.
"""
class Person(Document):
name = StringField()
age = IntField(default=30, required=False)
userid = StringField(default=lambda: 'test', required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
person = Person(name="Ross")
del person.name
del person.age
del person.userid
del person.created
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
self.assertTrue(person.validate() is None)
self.assertEqual(person.name, person.name)
self.assertEqual(person.age, person.age)
self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created)
self.assertEqual(person._data['name'], person.name)
self.assertEqual(person._data['age'], person.age)
self.assertEqual(person._data['userid'], person.userid)
self.assertEqual(person._data['created'], person.created)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
def test_required_values(self): def test_required_values(self):
"""Ensure that required field constraints are enforced. """Ensure that required field constraints are enforced.
@ -404,11 +525,39 @@ class FieldTest(unittest.TestCase):
log.time = datetime.date.today() log.time = datetime.date.today()
log.validate() log.validate()
log.time = datetime.datetime.now().isoformat(' ')
log.validate()
if dateutil:
log.time = datetime.datetime.now().isoformat('T')
log.validate()
log.time = -1 log.time = -1
self.assertRaises(ValidationError, log.validate) self.assertRaises(ValidationError, log.validate)
log.time = '1pm' log.time = 'ABC'
self.assertRaises(ValidationError, log.validate) self.assertRaises(ValidationError, log.validate)
def test_datetime_tz_aware_mark_as_changed(self):
from mongoengine import connection
# Reset the connections
connection._connection_settings = {}
connection._connections = {}
connection._dbs = {}
connect(db='mongoenginetest', tz_aware=True)
class LogEntry(Document):
time = DateTimeField()
LogEntry.drop_collection()
LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save()
log = LogEntry.objects.first()
log.time = datetime.datetime(2013, 1, 1, 0, 0, 0)
self.assertEqual(['time'], log._changed_fields)
def test_datetime(self): def test_datetime(self):
"""Tests showing pymongo datetime fields handling of microseconds. """Tests showing pymongo datetime fields handling of microseconds.
Microseconds are rounded to the nearest millisecond and pre UTC Microseconds are rounded to the nearest millisecond and pre UTC
@ -462,6 +611,66 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection() LogEntry.drop_collection()
def test_datetime_usage(self):
"""Tests for regular datetime fields"""
class LogEntry(Document):
date = DateTimeField()
LogEntry.drop_collection()
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01)
log = LogEntry()
log.date = d1
log.validate()
log.save()
for query in (d1, d1.isoformat(' ')):
log1 = LogEntry.objects.get(date=query)
self.assertEqual(log, log1)
if dateutil:
log1 = LogEntry.objects.get(date=d1.isoformat('T'))
self.assertEqual(log, log1)
LogEntry.drop_collection()
# create 60 log entries
for i in xrange(1950, 2010):
d = datetime.datetime(i, 01, 01, 00, 00, 01)
LogEntry(date=d).save()
self.assertEqual(LogEntry.objects.count(), 60)
# Test ordering
logs = LogEntry.objects.order_by("date")
count = logs.count()
i = 0
while i == count - 1:
self.assertTrue(logs[i].date <= logs[i + 1].date)
i += 1
logs = LogEntry.objects.order_by("-date")
count = logs.count()
i = 0
while i == count - 1:
self.assertTrue(logs[i].date >= logs[i + 1].date)
i += 1
# Test searching
logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1))
self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1))
self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter(
date__lte=datetime.datetime(2011, 1, 1),
date__gte=datetime.datetime(2000, 1, 1),
)
self.assertEqual(logs.count(), 10)
LogEntry.drop_collection()
def test_complexdatetime_storage(self): def test_complexdatetime_storage(self):
"""Tests for complex datetime fields - which can handle microseconds """Tests for complex datetime fields - which can handle microseconds
without rounding. without rounding.
@ -788,6 +997,53 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, e.save) self.assertRaises(ValidationError, e.save)
def test_complex_field_same_value_not_changed(self):
"""
If a complex field is set to the same value, it should not be marked as
changed.
"""
class Simple(Document):
mapping = ListField()
Simple.drop_collection()
e = Simple().save()
e.mapping = []
self.assertEqual([], e._changed_fields)
class Simple(Document):
mapping = DictField()
Simple.drop_collection()
e = Simple().save()
e.mapping = {}
self.assertEqual([], e._changed_fields)
def test_slice_marks_field_as_changed(self):
class Simple(Document):
widgets = ListField()
simple = Simple(widgets=[1, 2, 3, 4]).save()
simple.widgets[:3] = []
self.assertEqual(['widgets'], simple._changed_fields)
simple.save()
simple = simple.reload()
self.assertEqual(simple.widgets, [4])
def test_del_slice_marks_field_as_changed(self):
class Simple(Document):
widgets = ListField()
simple = Simple(widgets=[1, 2, 3, 4]).save()
del simple.widgets[:3]
self.assertEqual(['widgets'], simple._changed_fields)
simple.save()
simple = simple.reload()
self.assertEqual(simple.widgets, [4])
def test_list_field_complex(self): def test_list_field_complex(self):
"""Ensure that the list fields can handle the complex types.""" """Ensure that the list fields can handle the complex types."""
@ -1841,45 +2097,6 @@ class FieldTest(unittest.TestCase):
Shirt.drop_collection() Shirt.drop_collection()
def test_geo_indexes(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
class Event(Document):
title = StringField()
location = GeoPointField()
Event.drop_collection()
event = Event(title="Coltrane Motion @ Double Door",
location=[41.909889, -87.677137])
event.save()
info = Event.objects._collection.index_information()
self.assertTrue(u'location_2d' in info)
self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')])
Event.drop_collection()
def test_geo_embedded_indexes(self):
"""Ensure that indexes are created automatically for GeoPointFields on
embedded documents.
"""
class Venue(EmbeddedDocument):
location = GeoPointField()
name = StringField()
class Event(Document):
title = StringField()
venue = EmbeddedDocumentField(Venue)
Event.drop_collection()
venue = Venue(name="Double Door", location=[41.909889, -87.677137])
event = Event(title="Coltrane Motion", venue=venue)
event.save()
info = Event.objects._collection.index_information()
self.assertTrue(u'location_2d' in info)
self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')])
def test_ensure_unique_default_instances(self): def test_ensure_unique_default_instances(self):
"""Ensure that every field has it's own unique default instance.""" """Ensure that every field has it's own unique default instance."""
class D(Document): class D(Document):
@ -1917,6 +2134,38 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
self.assertEqual(c['next'], 1000) self.assertEqual(c['next'], 1000)
def test_sequence_field_get_next_value(self):
class Person(Document):
id = SequenceField(primary_key=True)
name = StringField()
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
Person(name="Person %s" % x).save()
self.assertEqual(Person.id.get_next_value(), 11)
self.db['mongoengine.counters'].drop()
self.assertEqual(Person.id.get_next_value(), 1)
class Person(Document):
id = SequenceField(primary_key=True, value_decorator=str)
name = StringField()
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
Person(name="Person %s" % x).save()
self.assertEqual(Person.id.get_next_value(), '11')
self.db['mongoengine.counters'].drop()
self.assertEqual(Person.id.get_next_value(), '1')
def test_sequence_field_sequence_name(self): def test_sequence_field_sequence_name(self):
class Person(Document): class Person(Document):
id = SequenceField(primary_key=True, sequence_name='jelly') id = SequenceField(primary_key=True, sequence_name='jelly')
@ -2225,6 +2474,37 @@ class FieldTest(unittest.TestCase):
user = User(email='me@example.com') user = User(email='me@example.com')
self.assertTrue(user.validate() is None) self.assertTrue(user.validate() is None)
def test_tuples_as_tuples(self):
"""
Ensure that tuples remain tuples when they are
inside a ComplexBaseField
"""
from mongoengine.base import BaseField
class EnumField(BaseField):
def __init__(self, **kwargs):
super(EnumField, self).__init__(**kwargs)
def to_mongo(self, value):
return value
def to_python(self, value):
return tuple(value)
class TestDoc(Document):
items = ListField(EnumField())
TestDoc.drop_collection()
tuples = [(100, 'Testing')]
doc = TestDoc()
doc.items = tuples
doc.save()
x = TestDoc.objects().get()
self.assertTrue(x is not None)
self.assertTrue(len(x.items) == 1)
self.assertTrue(tuple(x.items[0]) in tuples)
self.assertTrue(x.items[0] in tuples)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -15,6 +14,12 @@ from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.python_support import PY3, b, StringIO from mongoengine.python_support import PY3, b, StringIO
try:
from PIL import Image
HAS_PIL = True
except ImportError:
HAS_PIL = False
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png')
@ -256,14 +261,25 @@ class FileTest(unittest.TestCase):
self.assertFalse(test_file.the_file in [{"test": 1}]) self.assertFalse(test_file.the_file in [{"test": 1}])
def test_image_field(self): def test_image_field(self):
if PY3: if not HAS_PIL:
raise SkipTest('PIL does not have Python 3 support') raise SkipTest('PIL not installed')
class TestImage(Document): class TestImage(Document):
image = ImageField() image = ImageField()
TestImage.drop_collection() TestImage.drop_collection()
with tempfile.TemporaryFile() as f:
f.write(b("Hello World!"))
f.flush()
t = TestImage()
try:
t.image.put(f)
self.fail("Should have raised an invalidation error")
except ValidationError, e:
self.assertEquals("%s" % e, "Invalid image: cannot identify image file")
t = TestImage() t = TestImage()
t.image.put(open(TEST_IMAGE_PATH, 'rb')) t.image.put(open(TEST_IMAGE_PATH, 'rb'))
t.save() t.save()
@ -279,8 +295,8 @@ class FileTest(unittest.TestCase):
t.image.delete() t.image.delete()
def test_image_field_reassigning(self): def test_image_field_reassigning(self):
if PY3: if not HAS_PIL:
raise SkipTest('PIL does not have Python 3 support') raise SkipTest('PIL not installed')
class TestFile(Document): class TestFile(Document):
the_file = ImageField() the_file = ImageField()
@ -295,8 +311,8 @@ class FileTest(unittest.TestCase):
self.assertEqual(test_file.the_file.size, (45, 101)) self.assertEqual(test_file.the_file.size, (45, 101))
def test_image_field_resize(self): def test_image_field_resize(self):
if PY3: if not HAS_PIL:
raise SkipTest('PIL does not have Python 3 support') raise SkipTest('PIL not installed')
class TestImage(Document): class TestImage(Document):
image = ImageField(size=(185, 37)) image = ImageField(size=(185, 37))
@ -318,8 +334,8 @@ class FileTest(unittest.TestCase):
t.image.delete() t.image.delete()
def test_image_field_resize_force(self): def test_image_field_resize_force(self):
if PY3: if not HAS_PIL:
raise SkipTest('PIL does not have Python 3 support') raise SkipTest('PIL not installed')
class TestImage(Document): class TestImage(Document):
image = ImageField(size=(185, 37, True)) image = ImageField(size=(185, 37, True))
@ -341,8 +357,8 @@ class FileTest(unittest.TestCase):
t.image.delete() t.image.delete()
def test_image_field_thumbnail(self): def test_image_field_thumbnail(self):
if PY3: if not HAS_PIL:
raise SkipTest('PIL does not have Python 3 support') raise SkipTest('PIL not installed')
class TestImage(Document): class TestImage(Document):
image = ImageField(thumbnail_size=(92, 18)) image = ImageField(thumbnail_size=(92, 18))
@ -389,6 +405,14 @@ class FileTest(unittest.TestCase):
self.assertEqual(test_file.the_file.read(), self.assertEqual(test_file.the_file.read(),
b('Hello, World!')) b('Hello, World!'))
test_file = TestFile.objects.first()
test_file.the_file = b('HELLO, WORLD!')
test_file.save()
test_file = TestFile.objects.first()
self.assertEqual(test_file.the_file.read(),
b('HELLO, WORLD!'))
def test_copyable(self): def test_copyable(self):
class PutFile(Document): class PutFile(Document):
the_file = FileField() the_file = FileField()
@ -408,6 +432,54 @@ class FileTest(unittest.TestCase):
self.assertEqual(putfile, copy.copy(putfile)) self.assertEqual(putfile, copy.copy(putfile))
self.assertEqual(putfile, copy.deepcopy(putfile)) self.assertEqual(putfile, copy.deepcopy(putfile))
def test_get_image_by_grid_id(self):
if not HAS_PIL:
raise SkipTest('PIL not installed')
class TestImage(Document):
image1 = ImageField()
image2 = ImageField()
TestImage.drop_collection()
t = TestImage()
t.image1.put(open(TEST_IMAGE_PATH, 'rb'))
t.image2.put(open(TEST_IMAGE2_PATH, 'rb'))
t.save()
test = TestImage.objects.first()
grid_id = test.image1.grid_id
self.assertEqual(1, TestImage.objects(Q(image1=grid_id)
or Q(image2=grid_id)).count())
def test_complex_field_filefield(self):
"""Ensure you can add meta data to file"""
class Animal(Document):
genus = StringField()
family = StringField()
photos = ListField(FileField())
Animal.drop_collection()
marmot = Animal(genus='Marmota', family='Sciuridae')
marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk
photos_field = marmot._fields['photos'].field
new_proxy = photos_field.get_proxy_obj('photos', marmot)
new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar')
marmot_photo.close()
marmot.photos.append(new_proxy)
marmot.save()
marmot = Animal.objects.get()
self.assertEqual(marmot.photos[0].content_type, 'image/jpeg')
self.assertEqual(marmot.photos[0].foo, 'bar')
self.assertEqual(marmot.photos[0].get().length, 8313)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

274
tests/fields/geo.py Normal file
View File

@ -0,0 +1,274 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *
from mongoengine.connection import get_db
__all__ = ("GeoFieldTest", )
class GeoFieldTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def _test_for_expected_error(self, Cls, loc, expected):
try:
Cls(loc=loc).validate()
self.fail()
except ValidationError, e:
self.assertEqual(expected, e.to_dict()['loc'])
def test_geopoint_validation(self):
class Location(Document):
loc = GeoPointField()
invalid_coords = [{"x": 1, "y": 2}, 5, "a"]
expected = 'GeoPointField can only accept tuples or lists of (x, y)'
for coord in invalid_coords:
self._test_for_expected_error(Location, coord, expected)
invalid_coords = [[], [1], [1, 2, 3]]
for coord in invalid_coords:
expected = "Value (%s) must be a two-dimensional point" % repr(coord)
self._test_for_expected_error(Location, coord, expected)
invalid_coords = [[{}, {}], ("a", "b")]
for coord in invalid_coords:
expected = "Both values (%s) in point must be float or int" % repr(coord)
self._test_for_expected_error(Location, coord, expected)
def test_point_validation(self):
class Location(Document):
loc = PointField()
invalid_coords = {"x": 1, "y": 2}
expected = 'PointField can only accept a valid GeoJson dictionary or lists of (x, y)'
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": []}
expected = 'PointField type must be "Point"'
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "Point", "coordinates": [1, 2, 3]}
expected = "Value ([1, 2, 3]) must be a two-dimensional point"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [5, "a"]
expected = "PointField can only accept lists of [x, y]"
for coord in invalid_coords:
self._test_for_expected_error(Location, coord, expected)
invalid_coords = [[], [1], [1, 2, 3]]
for coord in invalid_coords:
expected = "Value (%s) must be a two-dimensional point" % repr(coord)
self._test_for_expected_error(Location, coord, expected)
invalid_coords = [[{}, {}], ("a", "b")]
for coord in invalid_coords:
expected = "Both values (%s) in point must be float or int" % repr(coord)
self._test_for_expected_error(Location, coord, expected)
Location(loc=[1, 2]).validate()
def test_linestring_validation(self):
class Location(Document):
loc = LineStringField()
invalid_coords = {"x": 1, "y": 2}
expected = 'LineStringField can only accept a valid GeoJson dictionary or lists of (x, y)'
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
expected = 'LineStringField type must be "LineString"'
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "LineString", "coordinates": [[1, 2, 3]]}
expected = "Invalid LineString:\nValue ([1, 2, 3]) must be a two-dimensional point"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [5, "a"]
expected = "Invalid LineString must contain at least one valid point"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[1]]
expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0])
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[1, 2, 3]]
expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0])
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[{}, {}]], [("a", "b")]]
for coord in invalid_coords:
expected = "Invalid LineString:\nBoth values (%s) in point must be float or int" % repr(coord[0])
self._test_for_expected_error(Location, coord, expected)
Location(loc=[[1, 2], [3, 4], [5, 6], [1,2]]).validate()
def test_polygon_validation(self):
class Location(Document):
loc = PolygonField()
invalid_coords = {"x": 1, "y": 2}
expected = 'PolygonField can only accept a valid GeoJson dictionary or lists of (x, y)'
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
expected = 'PolygonField type must be "Polygon"'
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "Polygon", "coordinates": [[[1, 2, 3]]]}
expected = "Invalid Polygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[5, "a"]]]
expected = "Invalid Polygon:\nBoth values ([5, 'a']) in point must be float or int"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[]]]
expected = "Invalid Polygon must contain at least one valid linestring"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[1, 2, 3]]]
expected = "Invalid Polygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[{}, {}]], [("a", "b")]]
expected = "Invalid Polygon:\nBoth values ([{}, {}]) in point must be float or int, Both values (('a', 'b')) in point must be float or int"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[1, 2], [3, 4]]]
expected = "Invalid Polygon:\nLineStrings must start and end at the same point"
self._test_for_expected_error(Location, invalid_coords, expected)
Location(loc=[[[1, 2], [3, 4], [5, 6], [1, 2]]]).validate()
def test_indexes_geopoint(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
class Event(Document):
title = StringField()
location = GeoPointField()
geo_indicies = Event._geo_indices()
self.assertEqual(geo_indicies, [{'fields': [('location', '2d')]}])
def test_geopoint_embedded_indexes(self):
"""Ensure that indexes are created automatically for GeoPointFields on
embedded documents.
"""
class Venue(EmbeddedDocument):
location = GeoPointField()
name = StringField()
class Event(Document):
title = StringField()
venue = EmbeddedDocumentField(Venue)
geo_indicies = Event._geo_indices()
self.assertEqual(geo_indicies, [{'fields': [('venue.location', '2d')]}])
def test_indexes_2dsphere(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
class Event(Document):
title = StringField()
point = PointField()
line = LineStringField()
polygon = PolygonField()
geo_indicies = Event._geo_indices()
self.assertTrue({'fields': [('line', '2dsphere')]} in geo_indicies)
self.assertTrue({'fields': [('polygon', '2dsphere')]} in geo_indicies)
self.assertTrue({'fields': [('point', '2dsphere')]} in geo_indicies)
def test_indexes_2dsphere_embedded(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
class Venue(EmbeddedDocument):
name = StringField()
point = PointField()
line = LineStringField()
polygon = PolygonField()
class Event(Document):
title = StringField()
venue = EmbeddedDocumentField(Venue)
geo_indicies = Event._geo_indices()
self.assertTrue({'fields': [('venue.line', '2dsphere')]} in geo_indicies)
self.assertTrue({'fields': [('venue.polygon', '2dsphere')]} in geo_indicies)
self.assertTrue({'fields': [('venue.point', '2dsphere')]} in geo_indicies)
def test_geo_indexes_recursion(self):
class Location(Document):
name = StringField()
location = GeoPointField()
class Parent(Document):
name = StringField()
location = ReferenceField(Location)
Location.drop_collection()
Parent.drop_collection()
list(Parent.objects)
collection = Parent._get_collection()
info = collection.index_information()
self.assertFalse('location_2d' in info)
self.assertEqual(len(Parent._geo_indices()), 0)
self.assertEqual(len(Location._geo_indices()), 1)
def test_geo_indexes_auto_index(self):
# Test just listing the fields
class Log(Document):
location = PointField(auto_index=False)
datetime = DateTimeField()
meta = {
'indexes': [[("location", "2dsphere"), ("datetime", 1)]]
}
self.assertEqual([], Log._geo_indices())
Log.drop_collection()
Log.ensure_indexes()
info = Log._get_collection().index_information()
self.assertEqual(info["location_2dsphere_datetime_1"]["key"],
[('location', '2dsphere'), ('datetime', 1)])
# Test listing explicitly
class Log(Document):
location = PointField(auto_index=False)
datetime = DateTimeField()
meta = {
'indexes': [
{'fields': [("location", "2dsphere"), ("datetime", 1)]}
]
}
self.assertEqual([], Log._geo_indices())
Log.drop_collection()
Log.ensure_indexes()
info = Log._get_collection().index_information()
self.assertEqual(info["location_2dsphere_datetime_1"]["key"],
[('location', '2dsphere'), ('datetime', 1)])
if __name__ == '__main__':
unittest.main()

View File

@ -1,6 +1,8 @@
import pickle
from datetime import datetime from datetime import datetime
from mongoengine import * from mongoengine import *
from mongoengine import signals
class PickleEmbedded(EmbeddedDocument): class PickleEmbedded(EmbeddedDocument):
@ -15,6 +17,32 @@ class PickleTest(Document):
photo = FileField() photo = FileField()
class PickleDyanmicEmbedded(DynamicEmbeddedDocument):
date = DateTimeField(default=datetime.now)
class PickleDynamicTest(DynamicDocument):
number = IntField()
class PickleSignalsTest(Document):
number = IntField()
string = StringField(choices=(('One', '1'), ('Two', '2')))
embedded = EmbeddedDocumentField(PickleEmbedded)
lists = ListField(StringField())
@classmethod
def post_save(self, sender, document, created, **kwargs):
pickled = pickle.dumps(document)
@classmethod
def post_delete(self, sender, document, **kwargs):
pickled = pickle.dumps(document)
signals.post_save.connect(PickleSignalsTest.post_save, sender=PickleSignalsTest)
signals.post_delete.connect(PickleSignalsTest.post_delete, sender=PickleSignalsTest)
class Mixin(object): class Mixin(object):
name = StringField() name = StringField()

View File

@ -1,5 +1,5 @@
from transform import * from transform import *
from field_list import * from field_list import *
from queryset import * from queryset import *
from visitor import * from visitor import *
from geo import *

418
tests/queryset/geo.py Normal file
View File

@ -0,0 +1,418 @@
import sys
sys.path[0:0] = [""]
import unittest
from datetime import datetime, timedelta
from mongoengine import *
__all__ = ("GeoQueriesTest",)
class GeoQueriesTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
def test_geospatial_operators(self):
"""Ensure that geospatial queries are working.
"""
class Event(Document):
title = StringField()
date = DateTimeField()
location = GeoPointField()
def __unicode__(self):
return self.title
Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1),
location=[-87.677137, 41.909889]).save()
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
date=datetime.now() - timedelta(days=10),
location=[-122.4194155, 37.7749295]).save()
event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(),
location=[-87.686638, 41.900474]).save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
# although it sorts to last.
events = Event.objects(location__near=[-87.67892, 41.9120459])
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2])
# find events within 5 degrees of pitchfork office, chicago
point_and_distance = [[-87.67892, 41.9120459], 5]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 2)
events = list(events)
self.assertTrue(event2 not in events)
self.assertTrue(event1 in events)
self.assertTrue(event3 in events)
# ensure ordering is respected by "near"
events = Event.objects(location__near=[-87.67892, 41.9120459])
events = events.order_by("-date")
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2])
# find events within 10 degrees of san francisco
point = [-122.415579, 37.7566023]
events = Event.objects(location__near=point, location__max_distance=10)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 10 degrees of san francisco
point_and_distance = [[-122.415579, 37.7566023], 10]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 1 degree of greenpoint, broolyn, nyc, ny
point_and_distance = [[-73.9509714, 40.7237134], 1]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance"
point_and_distance = [[-87.67892, 41.9120459], 10]
events = Event.objects(location__within_distance=point_and_distance)
events = events.order_by("-date")
self.assertEqual(events.count(), 2)
self.assertEqual(events[0], event3)
# check that within_box works
box = [(-125.0, 35.0), (-100.0, 40.0)]
events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id)
polygon = [
(-87.694445, 41.912114),
(-87.69084, 41.919395),
(-87.681742, 41.927186),
(-87.654276, 41.911731),
(-87.656164, 41.898061),
]
events = Event.objects(location__within_polygon=polygon)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event1.id)
polygon2 = [
(-1.742249, 54.033586),
(-1.225891, 52.792797),
(-4.40094, 53.389881)
]
events = Event.objects(location__within_polygon=polygon2)
self.assertEqual(events.count(), 0)
def test_geo_spatial_embedded(self):
class Venue(EmbeddedDocument):
location = GeoPointField()
name = StringField()
class Event(Document):
title = StringField()
venue = EmbeddedDocumentField(Venue)
Event.drop_collection()
venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889])
venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295])
event1 = Event(title="Coltrane Motion @ Double Door",
venue=venue1).save()
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
venue=venue2).save()
event3 = Event(title="Coltrane Motion @ Empty Bottle",
venue=venue1).save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
# although it sorts to last.
events = Event.objects(venue__location__near=[-87.67892, 41.9120459])
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2])
def test_spherical_geospatial_operators(self):
"""Ensure that spherical geospatial queries are working
"""
class Point(Document):
location = GeoPointField()
Point.drop_collection()
# These points are one degree apart, which (according to Google Maps)
# is about 110 km apart at this place on the Earth.
north_point = Point(location=[-122, 38]).save() # Near Concord, CA
south_point = Point(location=[-122, 37]).save() # Near Santa Cruz, CA
earth_radius = 6378.009 # in km (needs to be a float for dividing by)
# Finds both points because they are within 60 km of the reference
# point equidistant between them.
points = Point.objects(location__near_sphere=[-122, 37.5])
self.assertEqual(points.count(), 2)
# Same behavior for _within_spherical_distance
points = Point.objects(
location__within_spherical_distance=[[-122, 37.5], 60/earth_radius]
)
self.assertEqual(points.count(), 2)
points = Point.objects(location__near_sphere=[-122, 37.5],
location__max_distance=60 / earth_radius)
self.assertEqual(points.count(), 2)
# Finds both points, but orders the north point first because it's
# closer to the reference point to the north.
points = Point.objects(location__near_sphere=[-122, 38.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, north_point.id)
self.assertEqual(points[1].id, south_point.id)
# Finds both points, but orders the south point first because it's
# closer to the reference point to the south.
points = Point.objects(location__near_sphere=[-122, 36.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, south_point.id)
self.assertEqual(points[1].id, north_point.id)
# Finds only one point because only the first point is within 60km of
# the reference point to the south.
points = Point.objects(
location__within_spherical_distance=[[-122, 36.5], 60/earth_radius])
self.assertEqual(points.count(), 1)
self.assertEqual(points[0].id, south_point.id)
def test_2dsphere_point(self):
class Event(Document):
title = StringField()
date = DateTimeField()
location = PointField()
def __unicode__(self):
return self.title
Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1),
location=[-87.677137, 41.909889])
event1.save()
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
date=datetime.now() - timedelta(days=10),
location=[-122.4194155, 37.7749295]).save()
event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(),
location=[-87.686638, 41.900474]).save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
# although it sorts to last.
events = Event.objects(location__near=[-87.67892, 41.9120459])
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2])
# find events within 5 degrees of pitchfork office, chicago
point_and_distance = [[-87.67892, 41.9120459], 2]
events = Event.objects(location__geo_within_center=point_and_distance)
self.assertEqual(events.count(), 2)
events = list(events)
self.assertTrue(event2 not in events)
self.assertTrue(event1 in events)
self.assertTrue(event3 in events)
# ensure ordering is respected by "near"
events = Event.objects(location__near=[-87.67892, 41.9120459])
events = events.order_by("-date")
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2])
# find events within 10km of san francisco
point = [-122.415579, 37.7566023]
events = Event.objects(location__near=point, location__max_distance=10000)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 1km of greenpoint, broolyn, nyc, ny
events = Event.objects(location__near=[-73.9509714, 40.7237134], location__max_distance=1000)
self.assertEqual(events.count(), 0)
# ensure ordering is respected by "near"
events = Event.objects(location__near=[-87.67892, 41.9120459],
location__max_distance=10000).order_by("-date")
self.assertEqual(events.count(), 2)
self.assertEqual(events[0], event3)
# check that within_box works
box = [(-125.0, 35.0), (-100.0, 40.0)]
events = Event.objects(location__geo_within_box=box)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id)
polygon = [
(-87.694445, 41.912114),
(-87.69084, 41.919395),
(-87.681742, 41.927186),
(-87.654276, 41.911731),
(-87.656164, 41.898061),
]
events = Event.objects(location__geo_within_polygon=polygon)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event1.id)
polygon2 = [
(-1.742249, 54.033586),
(-1.225891, 52.792797),
(-4.40094, 53.389881)
]
events = Event.objects(location__geo_within_polygon=polygon2)
self.assertEqual(events.count(), 0)
def test_2dsphere_point_embedded(self):
class Venue(EmbeddedDocument):
location = GeoPointField()
name = StringField()
class Event(Document):
title = StringField()
venue = EmbeddedDocumentField(Venue)
Event.drop_collection()
venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889])
venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295])
event1 = Event(title="Coltrane Motion @ Double Door",
venue=venue1).save()
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
venue=venue2).save()
event3 = Event(title="Coltrane Motion @ Empty Bottle",
venue=venue1).save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
# although it sorts to last.
events = Event.objects(venue__location__near=[-87.67892, 41.9120459])
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2])
def test_linestring(self):
class Road(Document):
name = StringField()
line = LineStringField()
Road.drop_collection()
Road(name="66", line=[[40, 5], [41, 6]]).save()
# near
point = {"type": "Point", "coordinates": [40, 5]}
roads = Road.objects.filter(line__near=point["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__near=point).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__near={"$geometry": point}).count()
self.assertEqual(1, roads)
# Within
polygon = {"type": "Polygon",
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
roads = Road.objects.filter(line__geo_within=polygon["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__geo_within=polygon).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__geo_within={"$geometry": polygon}).count()
self.assertEqual(1, roads)
# Intersects
line = {"type": "LineString",
"coordinates": [[40, 5], [40, 6]]}
roads = Road.objects.filter(line__geo_intersects=line["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__geo_intersects=line).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__geo_intersects={"$geometry": line}).count()
self.assertEqual(1, roads)
polygon = {"type": "Polygon",
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
roads = Road.objects.filter(line__geo_intersects=polygon["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__geo_intersects=polygon).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(line__geo_intersects={"$geometry": polygon}).count()
self.assertEqual(1, roads)
def test_polygon(self):
class Road(Document):
name = StringField()
poly = PolygonField()
Road.drop_collection()
Road(name="66", poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]).save()
# near
point = {"type": "Point", "coordinates": [40, 5]}
roads = Road.objects.filter(poly__near=point["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__near=point).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__near={"$geometry": point}).count()
self.assertEqual(1, roads)
# Within
polygon = {"type": "Polygon",
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
roads = Road.objects.filter(poly__geo_within=polygon["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__geo_within=polygon).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__geo_within={"$geometry": polygon}).count()
self.assertEqual(1, roads)
# Intersects
line = {"type": "LineString",
"coordinates": [[40, 5], [41, 6]]}
roads = Road.objects.filter(poly__geo_intersects=line["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__geo_intersects=line).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__geo_intersects={"$geometry": line}).count()
self.assertEqual(1, roads)
polygon = {"type": "Polygon",
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
roads = Road.objects.filter(poly__geo_intersects=polygon["coordinates"]).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__geo_intersects=polygon).count()
self.assertEqual(1, roads)
roads = Road.objects.filter(poly__geo_intersects={"$geometry": polygon}).count()
self.assertEqual(1, roads)
if __name__ == '__main__':
unittest.main()

View File

@ -1,4 +1,3 @@
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -31,12 +30,17 @@ class QuerySetTest(unittest.TestCase):
def setUp(self): def setUp(self):
connect(db='mongoenginetest') connect(db='mongoenginetest')
class PersonMeta(EmbeddedDocument):
weight = IntField()
class Person(Document): class Person(Document):
name = StringField() name = StringField()
age = IntField() age = IntField()
person_meta = EmbeddedDocumentField(PersonMeta)
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
Person.drop_collection() Person.drop_collection()
self.PersonMeta = PersonMeta
self.Person = Person self.Person = Person
def test_initialisation(self): def test_initialisation(self):
@ -116,6 +120,15 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 1) self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B') self.assertEqual(people[0].name, 'User B')
# Test slice limit and skip cursor reset
qs = self.Person.objects[1:2]
# fetch then delete the cursor
qs._cursor
qs._cursor_obj = None
people = list(qs)
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
people = list(self.Person.objects[1:1]) people = list(self.Person.objects[1:1])
self.assertEqual(len(people), 0) self.assertEqual(len(people), 0)
@ -274,7 +287,7 @@ class QuerySetTest(unittest.TestCase):
a_objects = A.objects(s='test1') a_objects = A.objects(s='test1')
query = B.objects(ref__in=a_objects) query = B.objects(ref__in=a_objects)
query = query.filter(boolfield=True) query = query.filter(boolfield=True)
self.assertEquals(query.count(), 1) self.assertEqual(query.count(), 1)
def test_update_write_concern(self): def test_update_write_concern(self):
"""Test that passing write_concern works""" """Test that passing write_concern works"""
@ -287,15 +300,19 @@ class QuerySetTest(unittest.TestCase):
name='Test User', write_concern=write_concern) name='Test User', write_concern=write_concern)
author.save(write_concern=write_concern) author.save(write_concern=write_concern)
self.Person.objects.update(set__name='Ross', result = self.Person.objects.update(
write_concern=write_concern) set__name='Ross', write_concern={"w": 1})
self.assertEqual(result, 1)
result = self.Person.objects.update(
set__name='Ross', write_concern={"w": 0})
self.assertEqual(result, None)
author = self.Person.objects.first() result = self.Person.objects.update_one(
self.assertEqual(author.name, 'Ross') set__name='Test User', write_concern={"w": 1})
self.assertEqual(result, 1)
self.Person.objects.update_one(set__name='Test User', write_concern=write_concern) result = self.Person.objects.update_one(
author = self.Person.objects.first() set__name='Test User', write_concern={"w": 0})
self.assertEqual(author.name, 'Test User') self.assertEqual(result, None)
def test_update_update_has_a_value(self): def test_update_update_has_a_value(self):
"""Test to ensure that update is passed a value to update to""" """Test to ensure that update is passed a value to update to"""
@ -524,6 +541,50 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(club.members['John']['gender'], "F") self.assertEqual(club.members['John']['gender'], "F")
self.assertEqual(club.members['John']['age'], 14) self.assertEqual(club.members['John']['age'], 14)
def test_update_results(self):
self.Person.drop_collection()
result = self.Person(name="Bob", age=25).update(upsert=True, full_result=True)
self.assertTrue(isinstance(result, dict))
self.assertTrue("upserted" in result)
self.assertFalse(result["updatedExisting"])
bob = self.Person.objects.first()
result = bob.update(set__age=30, full_result=True)
self.assertTrue(isinstance(result, dict))
self.assertTrue(result["updatedExisting"])
self.Person(name="Bob", age=20).save()
result = self.Person.objects(name="Bob").update(set__name="bobby", multi=True)
self.assertEqual(result, 2)
def test_upsert(self):
self.Person.drop_collection()
self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True)
bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)
def test_upsert_one(self):
self.Person.drop_collection()
self.Person.objects(name="Bob", age=30).update_one(upsert=True)
bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)
def test_set_on_insert(self):
self.Person.drop_collection()
self.Person.objects(pk=ObjectId()).update(set__name='Bob', set_on_insert__age=30, upsert=True)
bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)
def test_get_or_create(self): def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new """Ensure that ``get_or_create`` returns one result or creates a new
document. document.
@ -592,14 +653,13 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(q, 1) # 1 for the insert self.assertEqual(q, 1) # 1 for the insert
Blog.drop_collection() Blog.drop_collection()
Blog.ensure_indexes()
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
Blog.ensure_indexes()
self.assertEqual(q, 1)
Blog.objects.insert(blogs) Blog.objects.insert(blogs)
self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total) self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch
Blog.drop_collection() Blog.drop_collection()
@ -763,7 +823,7 @@ class QuerySetTest(unittest.TestCase):
p = p.snapshot(True).slave_okay(True).timeout(True) p = p.snapshot(True).slave_okay(True).timeout(True)
self.assertEqual(p._cursor_args, self.assertEqual(p._cursor_args,
{'snapshot': True, 'slave_okay': True, 'timeout': True}) {'snapshot': True, 'slave_okay': True, 'timeout': True})
def test_repeated_iteration(self): def test_repeated_iteration(self):
"""Ensure that QuerySet rewinds itself one iteration finishes. """Ensure that QuerySet rewinds itself one iteration finishes.
@ -805,6 +865,7 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue("Doc: 0" in docs_string) self.assertTrue("Doc: 0" in docs_string)
self.assertEqual(docs.count(), 1000) self.assertEqual(docs.count(), 1000)
self.assertTrue('(remaining elements truncated)' in "%s" % docs)
# Limit and skip # Limit and skip
docs = docs[1:4] docs = docs[1:4]
@ -1233,7 +1294,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
authors = ListField(ReferenceField(self.Person, authors = ListField(ReferenceField(self.Person,
reverse_delete_rule=PULL)) reverse_delete_rule=PULL))
BlogPost.drop_collection() BlogPost.drop_collection()
self.Person.drop_collection() self.Person.drop_collection()
@ -1291,6 +1352,49 @@ class QuerySetTest(unittest.TestCase):
self.Person.objects()[:1].delete() self.Person.objects()[:1].delete()
self.assertEqual(1, BlogPost.objects.count()) self.assertEqual(1, BlogPost.objects.count())
def test_reference_field_find(self):
"""Ensure cascading deletion of referring documents from the database.
"""
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person)
BlogPost.drop_collection()
self.Person.drop_collection()
me = self.Person(name='Test User').save()
BlogPost(content="test 123", author=me).save()
self.assertEqual(1, BlogPost.objects(author=me).count())
self.assertEqual(1, BlogPost.objects(author=me.pk).count())
self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count())
self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_reference_field_find_dbref(self):
"""Ensure cascading deletion of referring documents from the database.
"""
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, dbref=True)
BlogPost.drop_collection()
self.Person.drop_collection()
me = self.Person(name='Test User').save()
BlogPost(content="test 123", author=me).save()
self.assertEqual(1, BlogPost.objects(author=me).count())
self.assertEqual(1, BlogPost.objects(author=me.pk).count())
self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count())
self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_update(self): def test_update(self):
"""Ensure that atomic updates work properly. """Ensure that atomic updates work properly.
""" """
@ -1514,6 +1618,32 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(message.authors[1].name, "Ross") self.assertEqual(message.authors[1].name, "Ross")
self.assertEqual(message.authors[2].name, "Adam") self.assertEqual(message.authors[2].name, "Adam")
def test_reload_embedded_docs_instance(self):
class SubDoc(EmbeddedDocument):
val = IntField()
class Doc(Document):
embedded = EmbeddedDocumentField(SubDoc)
doc = Doc(embedded=SubDoc(val=0)).save()
doc.reload()
self.assertEqual(doc.pk, doc.embedded._instance.pk)
def test_reload_list_embedded_docs_instance(self):
class SubDoc(EmbeddedDocument):
val = IntField()
class Doc(Document):
embedded = ListField(EmbeddedDocumentField(SubDoc))
doc = Doc(embedded=[SubDoc(val=0)]).save()
doc.reload()
self.assertEqual(doc.pk, doc.embedded[0]._instance.pk)
def test_order_by(self): def test_order_by(self):
"""Ensure that QuerySets may be ordered. """Ensure that QuerySets may be ordered.
""" """
@ -2083,6 +2213,19 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='ageless person').save() self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.average('age')), avg) self.assertEqual(int(self.Person.objects.average('age')), avg)
# dot notation
self.Person(name='person meta', person_meta=self.PersonMeta(weight=0)).save()
self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), 0)
for i, weight in enumerate(ages):
self.Person(name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save()
self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), avg)
self.Person(name='test meta none').save()
self.assertEqual(int(self.Person.objects.average('person_meta.weight')), avg)
def test_sum(self): def test_sum(self):
"""Ensure that field can be summed over correctly. """Ensure that field can be summed over correctly.
""" """
@ -2095,6 +2238,153 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='ageless person').save() self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
for i, age in enumerate(ages):
self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save()
self.assertEqual(int(self.Person.objects.sum('person_meta.weight')), sum(ages))
self.Person(name='weightless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
def test_embedded_average(self):
class Pay(EmbeddedDocument):
value = DecimalField()
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(value=150)).save()
Doc(name=u"Isabella Luanna",
pay=Pay(value=530)).save()
Doc(name=u"Tayza mariana",
pay=Pay(value=165)).save()
Doc(name=u"Eliana Costa",
pay=Pay(value=115)).save()
self.assertEqual(
Doc.objects.average('pay.value'),
240)
def test_embedded_array_average(self):
class Pay(EmbeddedDocument):
values = ListField(DecimalField())
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(values=[150, 100])).save()
Doc(name=u"Isabella Luanna",
pay=Pay(values=[530, 100])).save()
Doc(name=u"Tayza mariana",
pay=Pay(values=[165, 100])).save()
Doc(name=u"Eliana Costa",
pay=Pay(values=[115, 100])).save()
self.assertEqual(
Doc.objects.average('pay.values'),
170)
def test_array_average(self):
class Doc(Document):
values = ListField(DecimalField())
Doc.drop_collection()
Doc(values=[150, 100]).save()
Doc(values=[530, 100]).save()
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
self.assertEqual(
Doc.objects.average('values'),
170)
def test_embedded_sum(self):
class Pay(EmbeddedDocument):
value = DecimalField()
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(value=150)).save()
Doc(name=u"Isabella Luanna",
pay=Pay(value=530)).save()
Doc(name=u"Tayza mariana",
pay=Pay(value=165)).save()
Doc(name=u"Eliana Costa",
pay=Pay(value=115)).save()
self.assertEqual(
Doc.objects.sum('pay.value'),
960)
def test_embedded_array_sum(self):
class Pay(EmbeddedDocument):
values = ListField(DecimalField())
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(values=[150, 100])).save()
Doc(name=u"Isabella Luanna",
pay=Pay(values=[530, 100])).save()
Doc(name=u"Tayza mariana",
pay=Pay(values=[165, 100])).save()
Doc(name=u"Eliana Costa",
pay=Pay(values=[115, 100])).save()
self.assertEqual(
Doc.objects.sum('pay.values'),
1360)
def test_array_sum(self):
class Doc(Document):
values = ListField(DecimalField())
Doc.drop_collection()
Doc(values=[150, 100]).save()
Doc(values=[530, 100]).save()
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
self.assertEqual(
Doc.objects.sum('values'),
1360)
def test_distinct(self): def test_distinct(self):
"""Ensure that the QuerySet.distinct method works. """Ensure that the QuerySet.distinct method works.
""" """
@ -2380,167 +2670,6 @@ class QuerySetTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.Person.drop_collection() self.Person.drop_collection()
def test_geospatial_operators(self):
"""Ensure that geospatial queries are working.
"""
class Event(Document):
title = StringField()
date = DateTimeField()
location = GeoPointField()
def __unicode__(self):
return self.title
Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1),
location=[41.909889, -87.677137])
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
date=datetime.now() - timedelta(days=10),
location=[37.7749295, -122.4194155])
event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(),
location=[41.900474, -87.686638])
event1.save()
event2.save()
event3.save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
# although it sorts to last.
events = Event.objects(location__near=[41.9120459, -87.67892])
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2])
# find events within 5 degrees of pitchfork office, chicago
point_and_distance = [[41.9120459, -87.67892], 5]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 2)
events = list(events)
self.assertTrue(event2 not in events)
self.assertTrue(event1 in events)
self.assertTrue(event3 in events)
# ensure ordering is respected by "near"
events = Event.objects(location__near=[41.9120459, -87.67892])
events = events.order_by("-date")
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2])
# find events within 10 degrees of san francisco
point = [37.7566023, -122.415579]
events = Event.objects(location__near=point, location__max_distance=10)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 10 degrees of san francisco
point_and_distance = [[37.7566023, -122.415579], 10]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 1 degree of greenpoint, broolyn, nyc, ny
point_and_distance = [[40.7237134, -73.9509714], 1]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance"
point_and_distance = [[41.9120459, -87.67892], 10]
events = Event.objects(location__within_distance=point_and_distance)
events = events.order_by("-date")
self.assertEqual(events.count(), 2)
self.assertEqual(events[0], event3)
# check that within_box works
box = [(35.0, -125.0), (40.0, -100.0)]
events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id)
# check that polygon works for users who have a server >= 1.9
server_version = tuple(
get_connection().server_info()['version'].split('.')
)
required_version = tuple("1.9.0".split("."))
if server_version >= required_version:
polygon = [
(41.912114,-87.694445),
(41.919395,-87.69084),
(41.927186,-87.681742),
(41.911731,-87.654276),
(41.898061,-87.656164),
]
events = Event.objects(location__within_polygon=polygon)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event1.id)
polygon2 = [
(54.033586,-1.742249),
(52.792797,-1.225891),
(53.389881,-4.40094)
]
events = Event.objects(location__within_polygon=polygon2)
self.assertEqual(events.count(), 0)
Event.drop_collection()
def test_spherical_geospatial_operators(self):
"""Ensure that spherical geospatial queries are working
"""
class Point(Document):
location = GeoPointField()
Point.drop_collection()
# These points are one degree apart, which (according to Google Maps)
# is about 110 km apart at this place on the Earth.
north_point = Point(location=[-122, 38]) # Near Concord, CA
south_point = Point(location=[-122, 37]) # Near Santa Cruz, CA
north_point.save()
south_point.save()
earth_radius = 6378.009; # in km (needs to be a float for dividing by)
# Finds both points because they are within 60 km of the reference
# point equidistant between them.
points = Point.objects(location__near_sphere=[-122, 37.5])
self.assertEqual(points.count(), 2)
# Same behavior for _within_spherical_distance
points = Point.objects(
location__within_spherical_distance=[[-122, 37.5], 60/earth_radius]
);
self.assertEqual(points.count(), 2)
points = Point.objects(location__near_sphere=[-122, 37.5],
location__max_distance=60 / earth_radius);
self.assertEqual(points.count(), 2)
# Finds both points, but orders the north point first because it's
# closer to the reference point to the north.
points = Point.objects(location__near_sphere=[-122, 38.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, north_point.id)
self.assertEqual(points[1].id, south_point.id)
# Finds both points, but orders the south point first because it's
# closer to the reference point to the south.
points = Point.objects(location__near_sphere=[-122, 36.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, south_point.id)
self.assertEqual(points[1].id, north_point.id)
# Finds only one point because only the first point is within 60km of
# the reference point to the south.
points = Point.objects(
location__within_spherical_distance=[[-122, 36.5], 60/earth_radius])
self.assertEqual(points.count(), 1)
self.assertEqual(points[0].id, south_point.id)
Point.drop_collection()
def test_custom_querysets(self): def test_custom_querysets(self):
"""Ensure that custom QuerySet classes may be used. """Ensure that custom QuerySet classes may be used.
""" """
@ -3127,7 +3256,7 @@ class QuerySetTest(unittest.TestCase):
class Foo(EmbeddedDocument): class Foo(EmbeddedDocument):
shape = StringField() shape = StringField()
color = StringField() color = StringField()
trick = BooleanField() thick = BooleanField()
meta = {'allow_inheritance': False} meta = {'allow_inheritance': False}
class Bar(Document): class Bar(Document):
@ -3136,17 +3265,20 @@ class QuerySetTest(unittest.TestCase):
Bar.drop_collection() Bar.drop_collection()
b1 = Bar(foo=[Foo(shape= "square", color ="purple", thick = False), b1 = Bar(foo=[Foo(shape="square", color="purple", thick=False),
Foo(shape= "circle", color ="red", thick = True)]) Foo(shape="circle", color="red", thick=True)])
b1.save() b1.save()
b2 = Bar(foo=[Foo(shape= "square", color ="red", thick = True), b2 = Bar(foo=[Foo(shape="square", color="red", thick=True),
Foo(shape= "circle", color ="purple", thick = False)]) Foo(shape="circle", color="purple", thick=False)])
b2.save() b2.save()
ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"}))
self.assertEqual([b1], ak) self.assertEqual([b1], ak)
ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple")))
self.assertEqual([b1], ak)
def test_upsert_includes_cls(self): def test_upsert_includes_cls(self):
"""Upserts should include _cls information for inheritable classes """Upserts should include _cls information for inheritable classes
""" """
@ -3176,7 +3308,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual([], bars) self.assertEqual([], bars)
self.assertRaises(ConfigurationError, Bar.objects, self.assertRaises(ConfigurationError, Bar.objects,
read_preference='Primary') read_preference='Primary')
bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
def test_json_simple(self): def test_json_simple(self):
@ -3257,6 +3392,9 @@ class QuerySetTest(unittest.TestCase):
User(name="Bob Dole", age=89, price=Decimal('1.11')).save() User(name="Bob Dole", age=89, price=Decimal('1.11')).save()
User(name="Barack Obama", age=51, price=Decimal('2.22')).save() User(name="Barack Obama", age=51, price=Decimal('2.22')).save()
results = User.objects.only('id', 'name').as_pymongo()
self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name']))
users = User.objects.only('name', 'price').as_pymongo() users = User.objects.only('name', 'price').as_pymongo()
results = list(users) results = list(users)
self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[0], dict))
@ -3276,6 +3414,28 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1]['price'], Decimal('2.22')) self.assertEqual(results[1]['price'], Decimal('2.22'))
def test_as_pymongo_json_limit_fields(self):
class User(Document):
email = EmailField(unique=True, required=True)
password_hash = StringField(db_field='password_hash', required=True)
password_salt = StringField(db_field='password_salt', required=True)
User.drop_collection()
User(email="ross@example.com", password_salt="SomeSalt", password_hash="SomeHash").save()
serialized_user = User.objects.exclude('password_salt', 'password_hash').as_pymongo()[0]
self.assertEqual(set(['_id', 'email']), set(serialized_user.keys()))
serialized_user = User.objects.exclude('id', 'password_salt', 'password_hash').to_json()
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
serialized_user = User.objects.exclude('password_salt').only('email').as_pymongo()[0]
self.assertEqual(set(['email']), set(serialized_user.keys()))
serialized_user = User.objects.exclude('password_salt').only('email').to_json()
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
def test_no_dereference(self): def test_no_dereference(self):
class Organization(Document): class Organization(Document):
@ -3295,8 +3455,83 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(isinstance(qs.first().organization, Organization)) self.assertTrue(isinstance(qs.first().organization, Organization))
self.assertFalse(isinstance(qs.no_dereference().first().organization, self.assertFalse(isinstance(qs.no_dereference().first().organization,
Organization)) Organization))
self.assertFalse(isinstance(qs.no_dereference().get().organization,
Organization))
self.assertTrue(isinstance(qs.first().organization, Organization)) self.assertTrue(isinstance(qs.first().organization, Organization))
def test_cached_queryset(self):
class Person(Document):
name = StringField()
Person.drop_collection()
for i in xrange(100):
Person(name="No: %s" % i).save()
with query_counter() as q:
self.assertEqual(q, 0)
people = Person.objects
[x for x in people]
self.assertEqual(100, len(people._result_cache))
self.assertEqual(None, people._len)
self.assertEqual(q, 1)
list(people)
self.assertEqual(100, people._len) # Caused by list calling len
self.assertEqual(q, 1)
people.count() # count is cached
self.assertEqual(q, 1)
def test_cache_not_cloned(self):
class User(Document):
name = StringField()
def __unicode__(self):
return self.name
User.drop_collection()
User(name="Alice").save()
User(name="Bob").save()
users = User.objects.all().order_by('name')
self.assertEqual("%s" % users, "[<User: Alice>, <User: Bob>]")
self.assertEqual(2, len(users._result_cache))
users = users.filter(name="Bob")
self.assertEqual("%s" % users, "[<User: Bob>]")
self.assertEqual(1, len(users._result_cache))
def test_no_cache(self):
"""Ensure you can add meta data to file"""
class Noddy(Document):
fields = DictField()
Noddy.drop_collection()
for i in xrange(100):
noddy = Noddy()
for j in range(20):
noddy.fields["key"+str(j)] = "value "+str(j)
noddy.save()
docs = Noddy.objects.no_cache()
counter = len([1 for i in docs])
self.assertEquals(counter, 100)
self.assertEquals(len(list(docs)), 100)
self.assertRaises(TypeError, lambda: len(docs))
with query_counter() as q:
self.assertEqual(q, 0)
list(docs)
self.assertEqual(q, 1)
list(docs)
self.assertEqual(q, 2)
def test_nested_queryset_iterator(self): def test_nested_queryset_iterator(self):
# Try iterating the same queryset twice, nested. # Try iterating the same queryset twice, nested.
names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George'] names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George']
@ -3313,30 +3548,121 @@ class QuerySetTest(unittest.TestCase):
User(name=name).save() User(name=name).save()
users = User.objects.all().order_by('name') users = User.objects.all().order_by('name')
outer_count = 0 outer_count = 0
inner_count = 0 inner_count = 0
inner_total_count = 0 inner_total_count = 0
self.assertEqual(users.count(), 7) with query_counter() as q:
self.assertEqual(q, 0)
for i, outer_user in enumerate(users):
self.assertEqual(outer_user.name, names[i])
outer_count += 1
inner_count = 0
# Calling len might disrupt the inner loop if there are bugs
self.assertEqual(users.count(), 7) self.assertEqual(users.count(), 7)
for j, inner_user in enumerate(users): for i, outer_user in enumerate(users):
self.assertEqual(inner_user.name, names[j]) self.assertEqual(outer_user.name, names[i])
inner_count += 1 outer_count += 1
inner_total_count += 1 inner_count = 0
self.assertEqual(inner_count, 7) # inner loop should always be executed seven times # Calling len might disrupt the inner loop if there are bugs
self.assertEqual(users.count(), 7)
for j, inner_user in enumerate(users):
self.assertEqual(inner_user.name, names[j])
inner_count += 1
inner_total_count += 1
self.assertEqual(inner_count, 7) # inner loop should always be executed seven times
self.assertEqual(outer_count, 7) # outer loop should be executed seven times total
self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total
self.assertEqual(q, 2)
def test_no_sub_classes(self):
class A(Document):
x = IntField()
y = IntField()
meta = {'allow_inheritance': True}
class B(A):
z = IntField()
class C(B):
zz = IntField()
A.drop_collection()
A(x=10, y=20).save()
A(x=15, y=30).save()
B(x=20, y=40).save()
B(x=30, y=50).save()
C(x=40, y=60).save()
self.assertEqual(A.objects.no_sub_classes().count(), 2)
self.assertEqual(A.objects.count(), 5)
self.assertEqual(B.objects.no_sub_classes().count(), 2)
self.assertEqual(B.objects.count(), 3)
self.assertEqual(C.objects.no_sub_classes().count(), 1)
self.assertEqual(C.objects.count(), 1)
for obj in A.objects.no_sub_classes():
self.assertEqual(obj.__class__, A)
for obj in B.objects.no_sub_classes():
self.assertEqual(obj.__class__, B)
for obj in C.objects.no_sub_classes():
self.assertEqual(obj.__class__, C)
def test_query_reference_to_custom_pk_doc(self):
class A(Document):
id = StringField(unique=True, primary_key=True)
class B(Document):
a = ReferenceField(A)
A.drop_collection()
B.drop_collection()
a = A.objects.create(id='custom_id')
b = B.objects.create(a=a)
self.assertEqual(B.objects.count(), 1)
self.assertEqual(B.objects.get(a=a).a, a)
self.assertEqual(B.objects.get(a=a.id).a, a)
def test_cls_query_in_subclassed_docs(self):
class Animal(Document):
name = StringField()
meta = {
'allow_inheritance': True
}
class Dog(Animal):
pass
class Cat(Animal):
pass
self.assertEqual(Animal.objects(name='Charlie')._query, {
'name': 'Charlie',
'_cls': { '$in': ('Animal', 'Animal.Dog', 'Animal.Cat') }
})
self.assertEqual(Dog.objects(name='Charlie')._query, {
'name': 'Charlie',
'_cls': 'Animal.Dog'
})
self.assertEqual(Cat.objects(name='Charlie')._query, {
'name': 'Charlie',
'_cls': 'Animal.Cat'
})
self.assertEqual(outer_count, 7) # outer loop should be executed seven times total
self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,4 +1,3 @@
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]

View File

@ -1,4 +1,3 @@
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -69,11 +68,11 @@ class QTest(unittest.TestCase):
x = IntField() x = IntField()
y = StringField() y = StringField()
# Check than an error is raised when conflicting queries are anded query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc)
def invalid_combination(): self.assertEqual(query, {'$and': [{'x': {'$lt': 7}}, {'x': {'$lt': 3}}]})
query = Q(x__lt=7) & Q(x__lt=3)
query.to_query(TestDoc) query = (Q(y="a") & Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc)
self.assertRaises(InvalidQueryError, invalid_combination) self.assertEqual(query, {'$and': [{'y': "a"}, {'x': {'$lt': 7}}, {'x': {'$lt': 3}}]})
# Check normal cases work without an error # Check normal cases work without an error
query = Q(x__lt=7) & Q(x__gt=3) query = Q(x__lt=7) & Q(x__gt=3)
@ -326,10 +325,26 @@ class QTest(unittest.TestCase):
pk = ObjectId() pk = ObjectId()
User(email='example@example.com', pk=pk).save() User(email='example@example.com', pk=pk).save()
self.assertEqual(1, User.objects.filter( self.assertEqual(1, User.objects.filter(Q(email='example@example.com') |
Q(email='example@example.com') | Q(name='John Doe')).limit(2).filter(pk=pk).count())
Q(name='John Doe')
).limit(2).filter(pk=pk).count()) def test_chained_q_or_filtering(self):
class Post(EmbeddedDocument):
name = StringField(required=True)
class Item(Document):
postables = ListField(EmbeddedDocumentField(Post))
Item.drop_collection()
Item(postables=[Post(name="a"), Post(name="b")]).save()
Item(postables=[Post(name="a"), Post(name="c")]).save()
Item(postables=[Post(name="a"), Post(name="b"), Post(name="c")]).save()
self.assertEqual(Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2)
self.assertEqual(Item.objects.filter(postables__name="a").filter(postables__name="b").count(), 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,4 +1,3 @@
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
@ -57,6 +56,9 @@ class ConnectionTest(unittest.TestCase):
self.assertTrue(isinstance(db, pymongo.database.Database)) self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest') self.assertEqual(db.name, 'mongoenginetest')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
def test_register_connection(self): def test_register_connection(self):
"""Ensure that connections with different aliases may be registered. """Ensure that connections with different aliases may be registered.
""" """

View File

@ -1,4 +1,3 @@
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
@ -6,7 +5,8 @@ import unittest
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.context_managers import (switch_db, switch_collection, from mongoengine.context_managers import (switch_db, switch_collection,
no_dereference, query_counter) no_sub_classes, no_dereference,
query_counter)
class ContextManagersTest(unittest.TestCase): class ContextManagersTest(unittest.TestCase):
@ -139,6 +139,54 @@ class ContextManagersTest(unittest.TestCase):
self.assertTrue(isinstance(group.ref, User)) self.assertTrue(isinstance(group.ref, User))
self.assertTrue(isinstance(group.generic, User)) self.assertTrue(isinstance(group.generic, User))
def test_no_sub_classes(self):
class A(Document):
x = IntField()
y = IntField()
meta = {'allow_inheritance': True}
class B(A):
z = IntField()
class C(B):
zz = IntField()
A.drop_collection()
A(x=10, y=20).save()
A(x=15, y=30).save()
B(x=20, y=40).save()
B(x=30, y=50).save()
C(x=40, y=60).save()
self.assertEqual(A.objects.count(), 5)
self.assertEqual(B.objects.count(), 3)
self.assertEqual(C.objects.count(), 1)
with no_sub_classes(A) as A:
self.assertEqual(A.objects.count(), 2)
for obj in A.objects:
self.assertEqual(obj.__class__, A)
with no_sub_classes(B) as B:
self.assertEqual(B.objects.count(), 2)
for obj in B.objects:
self.assertEqual(obj.__class__, B)
with no_sub_classes(C) as C:
self.assertEqual(C.objects.count(), 1)
for obj in C.objects:
self.assertEqual(obj.__class__, C)
# Confirm context manager exit correctly
self.assertEqual(A.objects.count(), 5)
self.assertEqual(B.objects.count(), 3)
self.assertEqual(C.objects.count(), 1)
def test_query_counter(self): def test_query_counter(self):
connect('mongoenginetest') connect('mongoenginetest')
db = get_db() db = get_db()

View File

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
@ -1122,37 +1121,32 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_tuples_as_tuples(self): def test_objectid_reference_across_databases(self):
""" # mongoenginetest - Is default connection alias from setUp()
Ensure that tuples remain tuples when they are # Register Aliases
inside a ComplexBaseField register_connection('testdb-1', 'mongoenginetest2')
"""
from mongoengine.base import BaseField
class EnumField(BaseField): class User(Document):
name = StringField()
meta = {"db_alias": "testdb-1"}
def __init__(self, **kwargs): class Book(Document):
super(EnumField, self).__init__(**kwargs) name = StringField()
author = ReferenceField(User)
def to_mongo(self, value): # Drops
return value User.drop_collection()
Book.drop_collection()
def to_python(self, value): user = User(name="Ross").save()
return tuple(value) Book(name="MongoEngine for pros", author=user).save()
class TestDoc(Document): # Can't use query_counter across databases - so test the _data object
items = ListField(EnumField()) book = Book.objects.first()
self.assertFalse(isinstance(book._data['author'], User))
TestDoc.drop_collection() book.select_related()
tuples = [(100, 'Testing')] self.assertTrue(isinstance(book._data['author'], User))
doc = TestDoc()
doc.items = tuples
doc.save()
x = TestDoc.objects().get()
self.assertTrue(x is not None)
self.assertTrue(len(x.items) == 1)
self.assertTrue(tuple(x.items[0]) in tuples)
self.assertTrue(x.items[0] in tuples)
def test_non_ascii_pk(self): def test_non_ascii_pk(self):
""" """

View File

@ -1,50 +1,43 @@
from __future__ import with_statement
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import unittest import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from mongoengine.python_support import PY3
from mongoengine import * from mongoengine import *
from mongoengine.django.shortcuts import get_document_or_404
from django.http import Http404
from django.template import Context, Template
from django.conf import settings
from django.core.paginator import Paginator
settings.configure(
USE_TZ=True,
INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'),
AUTH_USER_MODEL=('mongo_auth.MongoUser'),
)
try: try:
from mongoengine.django.shortcuts import get_document_or_404 from django.contrib.auth import authenticate, get_user_model
from mongoengine.django.auth import User
from django.http import Http404 from mongoengine.django.mongo_auth.models import MongoUser, MongoUserManager
from django.template import Context, Template DJ15 = True
from django.conf import settings except Exception:
from django.core.paginator import Paginator DJ15 = False
from django.contrib.sessions.tests import SessionTestsMixin
settings.configure( from mongoengine.django.sessions import SessionStore, MongoSession
USE_TZ=True,
INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'),
AUTH_USER_MODEL=('mongo_auth.MongoUser'),
)
try:
from django.contrib.auth import authenticate, get_user_model
from mongoengine.django.auth import User
from mongoengine.django.mongo_auth.models import MongoUser, MongoUserManager
DJ15 = True
except Exception:
DJ15 = False
from django.contrib.sessions.tests import SessionTestsMixin
from mongoengine.django.sessions import SessionStore, MongoSession
except Exception, err:
if PY3:
SessionTestsMixin = type # dummy value so no error
SessionStore = None # dummy value so no error
else:
raise err
from datetime import tzinfo, timedelta from datetime import tzinfo, timedelta
ZERO = timedelta(0) ZERO = timedelta(0)
class FixedOffset(tzinfo): class FixedOffset(tzinfo):
"""Fixed offset in minutes east from UTC.""" """Fixed offset in minutes east from UTC."""
def __init__(self, offset, name): def __init__(self, offset, name):
self.__offset = timedelta(minutes = offset) self.__offset = timedelta(minutes=offset)
self.__name = name self.__name = name
def utcoffset(self, dt): def utcoffset(self, dt):
@ -71,8 +64,6 @@ def activate_timezone(tz):
class QuerySetTest(unittest.TestCase): class QuerySetTest(unittest.TestCase):
def setUp(self): def setUp(self):
if PY3:
raise SkipTest('django does not have Python 3 support')
connect(db='mongoenginetest') connect(db='mongoenginetest')
class Person(Document): class Person(Document):
@ -151,29 +142,79 @@ class QuerySetTest(unittest.TestCase):
# Try iterating the same queryset twice, nested, in a Django template. # Try iterating the same queryset twice, nested, in a Django template.
names = ['A', 'B', 'C', 'D'] names = ['A', 'B', 'C', 'D']
class User(Document): class CustomUser(Document):
name = StringField() name = StringField()
def __unicode__(self): def __unicode__(self):
return self.name return self.name
User.drop_collection() CustomUser.drop_collection()
for name in names: for name in names:
User(name=name).save() CustomUser(name=name).save()
users = User.objects.all().order_by('name') users = CustomUser.objects.all().order_by('name')
template = Template("{% for user in users %}{{ user.name }}{% ifequal forloop.counter 2 %} {% for inner_user in users %}{{ inner_user.name }}{% endfor %} {% endifequal %}{% endfor %}") template = Template("{% for user in users %}{{ user.name }}{% ifequal forloop.counter 2 %} {% for inner_user in users %}{{ inner_user.name }}{% endfor %} {% endifequal %}{% endfor %}")
rendered = template.render(Context({'users': users})) rendered = template.render(Context({'users': users}))
self.assertEqual(rendered, 'AB ABCD CD') self.assertEqual(rendered, 'AB ABCD CD')
def test_filter(self):
"""Ensure that a queryset and filters work as expected
"""
class Note(Document):
text = StringField()
for i in xrange(1, 101):
Note(name="Note: %s" % i).save()
# Check the count
self.assertEqual(Note.objects.count(), 100)
# Get the first 10 and confirm
notes = Note.objects[:10]
self.assertEqual(notes.count(), 10)
# Test djangos template filters
# self.assertEqual(length(notes), 10)
t = Template("{{ notes.count }}")
c = Context({"notes": notes})
self.assertEqual(t.render(c), "10")
# Test with skip
notes = Note.objects.skip(90)
self.assertEqual(notes.count(), 10)
# Test djangos template filters
self.assertEqual(notes.count(), 10)
t = Template("{{ notes.count }}")
c = Context({"notes": notes})
self.assertEqual(t.render(c), "10")
# Test with limit
notes = Note.objects.skip(90)
self.assertEqual(notes.count(), 10)
# Test djangos template filters
self.assertEqual(notes.count(), 10)
t = Template("{{ notes.count }}")
c = Context({"notes": notes})
self.assertEqual(t.render(c), "10")
# Test with skip and limit
notes = Note.objects.skip(10).limit(10)
# Test djangos template filters
self.assertEqual(notes.count(), 10)
t = Template("{{ notes.count }}")
c = Context({"notes": notes})
self.assertEqual(t.render(c), "10")
class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase):
backend = SessionStore backend = SessionStore
def setUp(self): def setUp(self):
if PY3:
raise SkipTest('django does not have Python 3 support')
connect(db='mongoenginetest') connect(db='mongoenginetest')
MongoSession.drop_collection() MongoSession.drop_collection()
super(MongoDBSessionTest, self).setUp() super(MongoDBSessionTest, self).setUp()
@ -211,8 +252,6 @@ class MongoAuthTest(unittest.TestCase):
} }
def setUp(self): def setUp(self):
if PY3:
raise SkipTest('django does not have Python 3 support')
if not DJ15: if not DJ15:
raise SkipTest('mongo_auth requires Django 1.5') raise SkipTest('mongo_auth requires Django 1.5')
connect(db='mongoenginetest') connect(db='mongoenginetest')
@ -224,7 +263,7 @@ class MongoAuthTest(unittest.TestCase):
def test_user_manager(self): def test_user_manager(self):
manager = get_user_model()._default_manager manager = get_user_model()._default_manager
self.assertIsInstance(manager, MongoUserManager) self.assertTrue(isinstance(manager, MongoUserManager))
def test_user_manager_exception(self): def test_user_manager_exception(self):
manager = get_user_model()._default_manager manager = get_user_model()._default_manager
@ -234,14 +273,14 @@ class MongoAuthTest(unittest.TestCase):
def test_create_user(self): def test_create_user(self):
manager = get_user_model()._default_manager manager = get_user_model()._default_manager
user = manager.create_user(**self.user_data) user = manager.create_user(**self.user_data)
self.assertIsInstance(user, User) self.assertTrue(isinstance(user, User))
db_user = User.objects.get(username='user') db_user = User.objects.get(username='user')
self.assertEqual(user.id, db_user.id) self.assertEqual(user.id, db_user.id)
def test_authenticate(self): def test_authenticate(self):
get_user_model()._default_manager.create_user(**self.user_data) get_user_model()._default_manager.create_user(**self.user_data)
user = authenticate(username='user', password='fail') user = authenticate(username='user', password='fail')
self.assertIsNone(user) self.assertEqual(None, user)
user = authenticate(username='user', password='test') user = authenticate(username='user', password='test')
db_user = User.objects.get(username='user') db_user = User.objects.get(username='user')
self.assertEqual(user.id, db_user.id) self.assertEqual(user.id, db_user.id)

47
tests/test_jinja.py Normal file
View File

@ -0,0 +1,47 @@
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *
import jinja2
class TemplateFilterTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
def test_jinja2(self):
env = jinja2.Environment()
class TestData(Document):
title = StringField()
description = StringField()
TestData.drop_collection()
examples = [('A', '1'),
('B', '2'),
('C', '3')]
for title, description in examples:
TestData(title=title, description=description).save()
tmpl = """
{%- for record in content -%}
{%- if loop.first -%}{ {%- endif -%}
"{{ record.title }}": "{{ record.description }}"
{%- if loop.last -%} }{%- else -%},{% endif -%}
{%- endfor -%}
"""
ctx = {'content': TestData.objects}
template = env.from_string(tmpl)
rendered = template.render(**ctx)
self.assertEqual('{"A": "1","B": "2","C": "3"}', rendered)
if __name__ == '__main__':
unittest.main()

View File

@ -43,6 +43,15 @@ class SignalTests(unittest.TestCase):
def pre_save(cls, sender, document, **kwargs): def pre_save(cls, sender, document, **kwargs):
signal_output.append('pre_save signal, %s' % document) signal_output.append('pre_save signal, %s' % document)
@classmethod
def pre_save_post_validation(cls, sender, document, **kwargs):
signal_output.append('pre_save_post_validation signal, %s' % document)
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
else:
signal_output.append('Is updated')
@classmethod @classmethod
def post_save(cls, sender, document, **kwargs): def post_save(cls, sender, document, **kwargs):
signal_output.append('post_save signal, %s' % document) signal_output.append('post_save signal, %s' % document)
@ -75,40 +84,19 @@ class SignalTests(unittest.TestCase):
Author.drop_collection() Author.drop_collection()
class Another(Document): class Another(Document):
name = StringField() name = StringField()
def __unicode__(self): def __unicode__(self):
return self.name return self.name
@classmethod
def pre_init(cls, sender, document, **kwargs):
signal_output.append('pre_init Another signal, %s' % cls.__name__)
signal_output.append(str(kwargs['values']))
@classmethod
def post_init(cls, sender, document, **kwargs):
signal_output.append('post_init Another signal, %s' % document)
@classmethod
def pre_save(cls, sender, document, **kwargs):
signal_output.append('pre_save Another signal, %s' % document)
@classmethod
def post_save(cls, sender, document, **kwargs):
signal_output.append('post_save Another signal, %s' % document)
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
else:
signal_output.append('Is updated')
@classmethod @classmethod
def pre_delete(cls, sender, document, **kwargs): def pre_delete(cls, sender, document, **kwargs):
signal_output.append('pre_delete Another signal, %s' % document) signal_output.append('pre_delete signal, %s' % document)
@classmethod @classmethod
def post_delete(cls, sender, document, **kwargs): def post_delete(cls, sender, document, **kwargs):
signal_output.append('post_delete Another signal, %s' % document) signal_output.append('post_delete signal, %s' % document)
self.Another = Another self.Another = Another
Another.drop_collection() Another.drop_collection()
@ -133,6 +121,7 @@ class SignalTests(unittest.TestCase):
len(signals.pre_init.receivers), len(signals.pre_init.receivers),
len(signals.post_init.receivers), len(signals.post_init.receivers),
len(signals.pre_save.receivers), len(signals.pre_save.receivers),
len(signals.pre_save_post_validation.receivers),
len(signals.post_save.receivers), len(signals.post_save.receivers),
len(signals.pre_delete.receivers), len(signals.pre_delete.receivers),
len(signals.post_delete.receivers), len(signals.post_delete.receivers),
@ -143,16 +132,13 @@ class SignalTests(unittest.TestCase):
signals.pre_init.connect(Author.pre_init, sender=Author) signals.pre_init.connect(Author.pre_init, sender=Author)
signals.post_init.connect(Author.post_init, sender=Author) signals.post_init.connect(Author.post_init, sender=Author)
signals.pre_save.connect(Author.pre_save, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author)
signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author)
signals.post_save.connect(Author.post_save, sender=Author) signals.post_save.connect(Author.post_save, sender=Author)
signals.pre_delete.connect(Author.pre_delete, sender=Author) signals.pre_delete.connect(Author.pre_delete, sender=Author)
signals.post_delete.connect(Author.post_delete, sender=Author) signals.post_delete.connect(Author.post_delete, sender=Author)
signals.pre_bulk_insert.connect(Author.pre_bulk_insert, sender=Author) signals.pre_bulk_insert.connect(Author.pre_bulk_insert, sender=Author)
signals.post_bulk_insert.connect(Author.post_bulk_insert, sender=Author) signals.post_bulk_insert.connect(Author.post_bulk_insert, sender=Author)
signals.pre_init.connect(Another.pre_init, sender=Another)
signals.post_init.connect(Another.post_init, sender=Another)
signals.pre_save.connect(Another.pre_save, sender=Another)
signals.post_save.connect(Another.post_save, sender=Another)
signals.pre_delete.connect(Another.pre_delete, sender=Another) signals.pre_delete.connect(Another.pre_delete, sender=Another)
signals.post_delete.connect(Another.post_delete, sender=Another) signals.post_delete.connect(Another.post_delete, sender=Another)
@ -164,16 +150,13 @@ class SignalTests(unittest.TestCase):
signals.post_delete.disconnect(self.Author.post_delete) signals.post_delete.disconnect(self.Author.post_delete)
signals.pre_delete.disconnect(self.Author.pre_delete) signals.pre_delete.disconnect(self.Author.pre_delete)
signals.post_save.disconnect(self.Author.post_save) signals.post_save.disconnect(self.Author.post_save)
signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation)
signals.pre_save.disconnect(self.Author.pre_save) signals.pre_save.disconnect(self.Author.pre_save)
signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert)
signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert)
signals.pre_init.disconnect(self.Another.pre_init)
signals.post_init.disconnect(self.Another.post_init)
signals.post_delete.disconnect(self.Another.post_delete) signals.post_delete.disconnect(self.Another.post_delete)
signals.pre_delete.disconnect(self.Another.pre_delete) signals.pre_delete.disconnect(self.Another.pre_delete)
signals.post_save.disconnect(self.Another.post_save)
signals.pre_save.disconnect(self.Another.pre_save)
signals.post_save.disconnect(self.ExplicitId.post_save) signals.post_save.disconnect(self.ExplicitId.post_save)
@ -182,6 +165,7 @@ class SignalTests(unittest.TestCase):
len(signals.pre_init.receivers), len(signals.pre_init.receivers),
len(signals.post_init.receivers), len(signals.post_init.receivers),
len(signals.pre_save.receivers), len(signals.pre_save.receivers),
len(signals.pre_save_post_validation.receivers),
len(signals.post_save.receivers), len(signals.post_save.receivers),
len(signals.pre_delete.receivers), len(signals.pre_delete.receivers),
len(signals.post_delete.receivers), len(signals.post_delete.receivers),
@ -216,6 +200,8 @@ class SignalTests(unittest.TestCase):
a1 = self.Author(name='Bill Shakespeare') a1 = self.Author(name='Bill Shakespeare')
self.assertEqual(self.get_signal_output(a1.save), [ self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, Bill Shakespeare", "pre_save signal, Bill Shakespeare",
"pre_save_post_validation signal, Bill Shakespeare",
"Is created",
"post_save signal, Bill Shakespeare", "post_save signal, Bill Shakespeare",
"Is created" "Is created"
]) ])
@ -224,6 +210,8 @@ class SignalTests(unittest.TestCase):
a1.name = 'William Shakespeare' a1.name = 'William Shakespeare'
self.assertEqual(self.get_signal_output(a1.save), [ self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, William Shakespeare", "pre_save signal, William Shakespeare",
"pre_save_post_validation signal, William Shakespeare",
"Is updated",
"post_save signal, William Shakespeare", "post_save signal, William Shakespeare",
"Is updated" "Is updated"
]) ])
@ -252,7 +240,14 @@ class SignalTests(unittest.TestCase):
"Not loaded", "Not loaded",
]) ])
self.Author.objects.delete() def test_queryset_delete_signals(self):
""" Queryset delete should throw some signals. """
self.Another(name='Bill Shakespeare').save()
self.assertEqual(self.get_signal_output(self.Another.objects.delete), [
'pre_delete signal, Bill Shakespeare',
'post_delete signal, Bill Shakespeare',
])
def test_signals_with_explicit_doc_ids(self): def test_signals_with_explicit_doc_ids(self):
""" Model saves must have a created flag the first time.""" """ Model saves must have a created flag the first time."""