Compare commits

..

27 Commits

Author SHA1 Message Date
Stefan Wojcik
2579e0b840 make EmbeddedDocument not hashable by default 2017-04-16 21:27:44 -04:00
Stefan Wojcik
824ec42005 bump version to v0.13.0 and fill in the changelog and the upgrade docs 2017-04-16 14:08:46 -04:00
Stefan Wójcik
466935e9a3 Unicode support in EmailField (#1527) 2017-04-16 13:58:58 -04:00
Stefan Wojcik
b52d3e3a7b added one more item to the v0.12.0 changelog 2017-04-07 10:34:04 -04:00
Stefan Wojcik
888a6da4a5 update the changelog and bump the version to v0.12.0 2017-04-07 10:18:39 -04:00
Omer Katz
972ac73dd9 Merge pull request #1497 from userlocalhost/feature/order_guarantee
added a feature to save object data in order
2017-04-07 10:49:39 +03:00
Hiroyasu OHYAMA
d8b238d5f1 Refactored the implementation of DynamicField extension for storing data in order 2017-04-06 00:42:11 +00:00
Omer Katz
63206c3da2 Merge pull request #1520 from ZoetropeLabs/fix/allow-reference-fields-take-object-ids
Allow ReferenceFields to take ObjectIds
2017-04-02 13:57:58 +03:00
Richard Fortescue-Webb
5713de8966 Use the objectid in the test 2017-03-29 11:34:57 +01:00
Richard Fortescue-Webb
58f293fef3 Allow ReferenceFields to take ObjectIds 2017-03-29 10:34:50 +01:00
Hiroyasu OHYAMA
ffbb2c9689 This is Additional tests for the container_class parameter of DynamicField
This tests DynamicField dereference with ordering guarantee.
2017-03-08 14:46:04 +00:00
Hiroyasu OHYAMA
9cd3dcdebf Added a test for the change of the condition in DeReference processing
This checks DBRef conversion using DynamicField with the ordering
guarantee.
2017-03-08 14:45:43 +00:00
Hiroyasu OHYAMA
f2fe58c3c5 Added a condition to store data to ObjectDict when the items type is it
Previous dereference implementation re-contains data as `dict` except
for the predicted type.
But the OrderedDict is not predicted, so the its data would be converted
`dict` implicitly.
As the result, the order of stored data get wrong. And this patch
prevents it.
2017-03-08 14:35:50 +00:00
Stefan Wojcik
b78010aa94 remove test_last_field_name_like_operator (it's a dupe of the same test in tests/queryset/transform.py) 2017-03-05 21:24:46 -05:00
Stefan Wójcik
49035543b9 cleanup BaseQuerySet.__getitem__ (#1502) 2017-03-05 21:17:53 -05:00
Stefan Wójcik
f9ccf635ca Respect db fields in multiple layers of embedded docs (#1501) 2017-03-05 18:20:09 -05:00
Stefan Wojcik
e8ea294964 test negative indexes (closes #1119) 2017-03-05 18:12:01 -05:00
Stefan Wojcik
19ef2be88b fix #937 2017-03-05 00:05:33 -05:00
Stefan Wojcik
30e8b8186f clean up document instance tests 2017-03-02 00:25:56 -05:00
Stefan Wójcik
741643af5f clean up field unit tests (#1498) 2017-03-02 00:05:10 -05:00
Hiroyasu OHYAMA
6aaf9ba470 removed a checking of dict order because this order is not cared (some implementation might be in ordered, but other one is not) 2017-03-01 09:32:28 +00:00
Hiroyasu OHYAMA
5957dc72eb To achive storing object data in order with minimum implementation, I
changed followings.

- added optional parameter `container_class` which enables to choose
  intermediate class at encoding Python data, instead of additional
  field class.
- removed OrderedDocument class because the equivalent feature could
  be implemented by the outside of Mongoengine.
2017-03-01 09:20:57 +00:00
Hiroyasu OHYAMA
e32a9777d7 added test for OrderedDynamicField and OrderedDocument 2017-02-28 03:35:53 +00:00
Hiroyasu OHYAMA
84a8f1eb2b added OrderedDocument class to decode BSON data to OrderedDict for retrieving data in order 2017-02-28 03:35:39 +00:00
Hiroyasu OHYAMA
6810953014 added OrderedDynamicField class to store data in the defined order because of #203 2017-02-28 03:34:42 +00:00
Ephraim Berkovitch
398964945a Document.objects.create should raise NotUniqueError upon saving duplicate primary key (#1485) 2017-02-27 09:42:44 -05:00
Stefan Wójcik
5f43c032f2 revamp the "connecting" user guide and test more ways of connecting to a replica set (#1490) 2017-02-26 21:29:06 -05:00
15 changed files with 1224 additions and 937 deletions

1
.gitignore vendored
View File

@@ -17,3 +17,4 @@ tests/test_bugfix.py
htmlcov/ htmlcov/
venv venv
venv3 venv3
scratchpad

View File

@@ -5,11 +5,27 @@ Changelog
Development Development
=========== ===========
- (Fill this out as you fix issues and develop your features). - (Fill this out as you fix issues and develop your features).
- Fixed using sets in field choices #1481
Changes in 0.13.0
=================
- POTENTIAL BREAKING CHANGE: Added Unicode support to the `EmailField`, see
docs/upgrade.rst for details.
Changes in 0.12.0
=================
- POTENTIAL BREAKING CHANGE: Fixed limit/skip/hint/batch_size chaining #1476 - POTENTIAL BREAKING CHANGE: Fixed limit/skip/hint/batch_size chaining #1476
- POTENTIAL BREAKING CHANGE: Changed a public `QuerySet.clone_into` method to a private `QuerySet._clone_into` #1476 - POTENTIAL BREAKING CHANGE: Changed a public `QuerySet.clone_into` method to a private `QuerySet._clone_into` #1476
- Fixed the way `Document.objects.create` works with duplicate IDs #1485
- Fixed connecting to a replica set with PyMongo 2.x #1436 - Fixed connecting to a replica set with PyMongo 2.x #1436
- Fixed using sets in field choices #1481
- Fixed deleting items from a `ListField` #1318
- Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237 - Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237
- Fixed behavior of a `dec` update operator #1450
- Added a `rename` update operator #1454
- Added validation for the `db_field` parameter #1448
- Fixed the error message displayed when querying an `EmbeddedDocumentField` by an invalid value #1440
- Fixed the error message displayed when validating unicode URLs #1486
- Raise an error when trying to save an abstract document #1449
Changes in 0.11.0 Changes in 0.11.0
================= =================

View File

@@ -340,14 +340,19 @@ Javascript code that is executed on the database server.
Counting results Counting results
---------------- ----------------
Just as with limiting and skipping results, there is a method on Just as with limiting and skipping results, there is a method on a
:class:`~mongoengine.queryset.QuerySet` objects -- :class:`~mongoengine.queryset.QuerySet` object --
:meth:`~mongoengine.queryset.QuerySet.count`, but there is also a more Pythonic :meth:`~mongoengine.queryset.QuerySet.count`::
way of achieving this::
num_users = len(User.objects) num_users = User.objects.count()
Even if len() is the Pythonic way of counting results, keep in mind that if you concerned about performance, :meth:`~mongoengine.queryset.QuerySet.count` is the way to go since it only execute a server side count query, while len() retrieves the results, places them in cache, and finally counts them. If we compare the performance of the two operations, len() is much slower than :meth:`~mongoengine.queryset.QuerySet.count`. You could technically use ``len(User.objects)`` to get the same result, but it
would be significantly slower than :meth:`~mongoengine.queryset.QuerySet.count`.
When you execute a server-side count query, you let MongoDB do the heavy
lifting and you receive a single integer over the wire. Meanwhile, len()
retrieves all the results, places them in a local cache, and finally counts
them. If we compare the performance of the two operations, len() is much slower
than :meth:`~mongoengine.queryset.QuerySet.count`.
Further aggregation Further aggregation
------------------- -------------------

View File

@@ -6,6 +6,20 @@ Development
*********** ***********
(Fill this out whenever you introduce breaking changes to MongoEngine) (Fill this out whenever you introduce breaking changes to MongoEngine)
0.13.0
******
This release adds Unicode support to the `EmailField` and changes its
structure significantly. Previously, email addresses containing Unicode
characters didn't work at all. Starting with v0.13.0, domains with Unicode
characters are supported out of the box, meaning some emails that previously
didn't pass validation now do. Make sure the rest of your application can
accept such email addresses. Additionally, if you subclassed the `EmailField`
in your application and overrode `EmailField.EMAIL_REGEX`, you will have to
adjust your code to override `EmailField.USER_REGEX`, `EmailField.DOMAIN_REGEX`,
and potentially `EmailField.UTF8_USER_REGEX`.
0.12.0
******
This release includes various fixes for the `BaseQuerySet` methods and how they This release includes various fixes for the `BaseQuerySet` methods and how they
are chained together. Since version 0.10.1 applying limit/skip/hint/batch_size are chained together. Since version 0.10.1 applying limit/skip/hint/batch_size
to an already-existing queryset wouldn't modify the underlying PyMongo cursor. to an already-existing queryset wouldn't modify the underlying PyMongo cursor.

View File

@@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) +
list(signals.__all__) + list(errors.__all__)) list(signals.__all__) + list(errors.__all__))
VERSION = (0, 11, 0) VERSION = (0, 13, 0)
def get_version(): def get_version():

View File

@@ -272,13 +272,6 @@ class BaseDocument(object):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self):
if getattr(self, 'pk', None) is None:
# For new object
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
def clean(self): def clean(self):
""" """
Hook for doing document level data cleaning before validation is run. Hook for doing document level data cleaning before validation is run.
@@ -684,8 +677,13 @@ class BaseDocument(object):
# class if unavailable # class if unavailable
class_name = son.get('_cls', cls._class_name) class_name = son.get('_cls', cls._class_name)
# Convert SON to a dict, making sure each key is a string # Convert SON to a data dict, making sure each key is a string and
data = {str(key): value for key, value in son.iteritems()} # corresponds to the right db field.
data = {}
for key, value in son.iteritems():
key = str(key)
key = cls._db_field_map.get(key, key)
data[key] = value
# Return correct subclass for document type # Return correct subclass for document type
if class_name != cls._class_name: if class_name != cls._class_name:

View File

@@ -1,3 +1,4 @@
from collections import OrderedDict
from bson import DBRef, SON from bson import DBRef, SON
import six import six
@@ -201,6 +202,10 @@ class DeReference(object):
as_tuple = isinstance(items, tuple) as_tuple = isinstance(items, tuple)
iterator = enumerate(items) iterator = enumerate(items)
data = [] data = []
elif isinstance(items, OrderedDict):
is_list = False
iterator = items.iteritems()
data = OrderedDict()
else: else:
is_list = False is_list = False
iterator = items.iteritems() iterator = items.iteritems()

View File

@@ -60,6 +60,12 @@ class EmbeddedDocument(BaseDocument):
my_metaclass = DocumentMetaclass my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass __metaclass__ = DocumentMetaclass
# A generic embedded document doesn't have any immutable properties
# that describe it uniquely, hence it shouldn't be hashable. You can
# define your own __hash__ method on a subclass if you need your
# embedded documents to be hashable.
__hash__ = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs) super(EmbeddedDocument, self).__init__(*args, **kwargs)
self._instance = None self._instance = None
@@ -160,6 +166,15 @@ class Document(BaseDocument):
"""Set the primary key.""" """Set the primary key."""
return setattr(self, self._meta['id_field'], value) return setattr(self, self._meta['id_field'], value)
def __hash__(self):
"""Return the hash based on the PK of this document. If it's new
and doesn't have a PK yet, return the default object hash instead.
"""
if self.pk is None:
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
@classmethod @classmethod
def _get_db(cls): def _get_db(cls):
"""Some Model using other db_alias""" """Some Model using other db_alias"""

View File

@@ -2,9 +2,11 @@ import datetime
import decimal import decimal
import itertools import itertools
import re import re
import socket
import time import time
import uuid import uuid
import warnings import warnings
from collections import Mapping
from operator import itemgetter from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON from bson import Binary, DBRef, ObjectId, SON
@@ -153,21 +155,105 @@ class EmailField(StringField):
.. versionadded:: 0.4 .. versionadded:: 0.4
""" """
USER_REGEX = re.compile(
EMAIL_REGEX = re.compile( # `dot-atom` defined in RFC 5322 Section 3.2.3.
# dot-atom r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # `quoted-string` defined in RFC 5322 Section 3.2.4.
# quoted-string r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)',
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' re.IGNORECASE
# domain (max length of an ICAAN TLD is 22 characters)
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}|[A-Z0-9-]{2,}(?<!-))$', re.IGNORECASE
) )
UTF8_USER_REGEX = re.compile(
six.u(
# RFC 6531 Section 3.3 extends `atext` (used by dot-atom) to
# include `UTF8-non-ascii`.
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+)*\Z"
# `quoted-string`
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)'
), re.IGNORECASE | re.UNICODE
)
DOMAIN_REGEX = re.compile(
r'((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z',
re.IGNORECASE
)
error_msg = u'Invalid email address: %s'
def __init__(self, domain_whitelist=None, allow_utf8_user=False,
allow_ip_domain=False, *args, **kwargs):
"""Initialize the EmailField.
Args:
domain_whitelist (list) - list of otherwise invalid domain
names which you'd like to support.
allow_utf8_user (bool) - if True, the user part of the email
address can contain UTF8 characters.
False by default.
allow_ip_domain (bool) - if True, the domain part of the email
can be a valid IPv4 or IPv6 address.
"""
self.domain_whitelist = domain_whitelist or []
self.allow_utf8_user = allow_utf8_user
self.allow_ip_domain = allow_ip_domain
super(EmailField, self).__init__(*args, **kwargs)
def validate_user_part(self, user_part):
"""Validate the user part of the email address. Return True if
valid and False otherwise.
"""
if self.allow_utf8_user:
return self.UTF8_USER_REGEX.match(user_part)
return self.USER_REGEX.match(user_part)
def validate_domain_part(self, domain_part):
"""Validate the domain part of the email address. Return True if
valid and False otherwise.
"""
# Skip domain validation if it's in the whitelist.
if domain_part in self.domain_whitelist:
return True
if self.DOMAIN_REGEX.match(domain_part):
return True
# Validate IPv4/IPv6, e.g. user@[192.168.0.1]
if (
self.allow_ip_domain and
domain_part[0] == '[' and
domain_part[-1] == ']'
):
for addr_family in (socket.AF_INET, socket.AF_INET6):
try:
socket.inet_pton(addr_family, domain_part[1:-1])
return True
except (socket.error, UnicodeEncodeError):
pass
return False
def validate(self, value): def validate(self, value):
if not EmailField.EMAIL_REGEX.match(value):
self.error('Invalid email address: %s' % value)
super(EmailField, self).validate(value) super(EmailField, self).validate(value)
if '@' not in value:
self.error(self.error_msg % value)
user_part, domain_part = value.rsplit('@', 1)
# Validate the user part.
if not self.validate_user_part(user_part):
self.error(self.error_msg % value)
# Validate the domain and, if invalid, see if it's IDN-encoded.
if not self.validate_domain_part(domain_part):
try:
domain_part = domain_part.encode('idna').decode('ascii')
except UnicodeError:
self.error(self.error_msg % value)
else:
if not self.validate_domain_part(domain_part):
self.error(self.error_msg % value)
class IntField(BaseField): class IntField(BaseField):
"""32-bit integer field.""" """32-bit integer field."""
@@ -619,6 +705,14 @@ class DynamicField(BaseField):
Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data"""
def __init__(self, container_class=dict, *args, **kwargs):
self._container_cls = container_class
if not issubclass(self._container_cls, Mapping):
self.error('The class that is specified in `container_class` parameter '
'must be a subclass of `dict`.')
super(DynamicField, self).__init__(*args, **kwargs)
def to_mongo(self, value, use_db_field=True, fields=None): def to_mongo(self, value, use_db_field=True, fields=None):
"""Convert a Python type to a MongoDB compatible type. """Convert a Python type to a MongoDB compatible type.
""" """
@@ -644,7 +738,7 @@ class DynamicField(BaseField):
is_list = True is_list = True
value = {k: v for k, v in enumerate(value)} value = {k: v for k, v in enumerate(value)}
data = {} data = self._container_cls()
for k, v in value.iteritems(): for k, v in value.iteritems():
data[k] = self.to_mongo(v, use_db_field, fields) data[k] = self.to_mongo(v, use_db_field, fields)
@@ -998,8 +1092,8 @@ class ReferenceField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, (self.document_type, DBRef)): if not isinstance(value, (self.document_type, DBRef, ObjectId)):
self.error('A ReferenceField only accepts DBRef or documents') self.error('A ReferenceField only accepts DBRef, ObjectId or documents')
if isinstance(value, Document) and value.id is None: if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been ' self.error('You can only reference documents once they have been '

View File

@@ -158,44 +158,49 @@ class BaseQuerySet(object):
# self._cursor # self._cursor
def __getitem__(self, key): def __getitem__(self, key):
"""Support skip and limit using getitem and slicing syntax.""" """Return a document instance corresponding to a given index if
the key is an integer. If the key is a slice, translate its
bounds into a skip and a limit, and return a cloned queryset
with that skip/limit applied. For example:
>>> User.objects[0]
<User: User object>
>>> User.objects[1:3]
[<User: User object>, <User: User object>]
"""
queryset = self.clone() queryset = self.clone()
# Slice provided # Handle a slice
if isinstance(key, slice): if isinstance(key, slice):
try: queryset._cursor_obj = queryset._cursor[key]
queryset._cursor_obj = queryset._cursor[key] queryset._skip, queryset._limit = key.start, key.stop
queryset._skip, queryset._limit = key.start, key.stop if key.start and key.stop:
if key.start and key.stop: queryset._limit = key.stop - key.start
queryset._limit = key.stop - key.start
except IndexError as err:
# PyMongo raises an error if key.start == key.stop, catch it,
# bin it, kill it.
start = key.start or 0
if start >= 0 and key.stop >= 0 and key.step is None:
if start == key.stop:
queryset.limit(0)
queryset._skip = key.start
queryset._limit = key.stop - start
return queryset
raise err
# Allow further QuerySet modifications to be performed # Allow further QuerySet modifications to be performed
return queryset return queryset
# Integer index provided
# Handle an index
elif isinstance(key, int): elif isinstance(key, int):
if queryset._scalar: if queryset._scalar:
return queryset._get_scalar( return queryset._get_scalar(
queryset._document._from_son(queryset._cursor[key], queryset._document._from_son(
_auto_dereference=self._auto_dereference, queryset._cursor[key],
only_fields=self.only_fields)) _auto_dereference=self._auto_dereference,
only_fields=self.only_fields
)
)
if queryset._as_pymongo: if queryset._as_pymongo:
return queryset._get_as_pymongo(queryset._cursor[key]) return queryset._get_as_pymongo(queryset._cursor[key])
return queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
raise AttributeError return queryset._document._from_son(
queryset._cursor[key],
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields
)
raise AttributeError('Provide a slice or an integer index')
def __iter__(self): def __iter__(self):
raise NotImplementedError raise NotImplementedError
@@ -286,7 +291,7 @@ class BaseQuerySet(object):
.. versionadded:: 0.4 .. versionadded:: 0.4
""" """
return self._document(**kwargs).save() return self._document(**kwargs).save(force_insert=True)
def first(self): def first(self):
"""Retrieve the first object matching the query.""" """Retrieve the first object matching the query."""

View File

@@ -412,7 +412,6 @@ class IndexesTest(unittest.TestCase):
User.ensure_indexes() User.ensure_indexes()
info = User.objects._collection.index_information() info = User.objects._collection.index_information()
self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_']) self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_'])
User.drop_collection()
def test_embedded_document_index(self): def test_embedded_document_index(self):
"""Tests settings an index on an embedded document """Tests settings an index on an embedded document
@@ -434,7 +433,6 @@ class IndexesTest(unittest.TestCase):
info = BlogPost.objects._collection.index_information() info = BlogPost.objects._collection.index_information()
self.assertEqual(sorted(info.keys()), ['_id_', 'date.yr_-1']) self.assertEqual(sorted(info.keys()), ['_id_', 'date.yr_-1'])
BlogPost.drop_collection()
def test_list_embedded_document_index(self): def test_list_embedded_document_index(self):
"""Ensure list embedded documents can be indexed """Ensure list embedded documents can be indexed
@@ -461,7 +459,6 @@ class IndexesTest(unittest.TestCase):
post1 = BlogPost(title="Embedded Indexes tests in place", post1 = BlogPost(title="Embedded Indexes tests in place",
tags=[Tag(name="about"), Tag(name="time")]) tags=[Tag(name="about"), Tag(name="time")])
post1.save() post1.save()
BlogPost.drop_collection()
def test_recursive_embedded_objects_dont_break_indexes(self): def test_recursive_embedded_objects_dont_break_indexes(self):
@@ -622,8 +619,6 @@ class IndexesTest(unittest.TestCase):
post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') post3 = BlogPost(title='test3', date=Date(year=2010), slug='test')
self.assertRaises(OperationError, post3.save) self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_embedded_document(self): def test_unique_embedded_document(self):
"""Ensure that uniqueness constraints are applied to fields on embedded documents. """Ensure that uniqueness constraints are applied to fields on embedded documents.
""" """
@@ -651,8 +646,6 @@ class IndexesTest(unittest.TestCase):
sub=SubDocument(year=2010, slug='test')) sub=SubDocument(year=2010, slug='test'))
self.assertRaises(NotUniqueError, post3.save) self.assertRaises(NotUniqueError, post3.save)
BlogPost.drop_collection()
def test_unique_embedded_document_in_list(self): def test_unique_embedded_document_in_list(self):
""" """
Ensure that the uniqueness constraints are applied to fields in Ensure that the uniqueness constraints are applied to fields in
@@ -683,8 +676,6 @@ class IndexesTest(unittest.TestCase):
self.assertRaises(NotUniqueError, post2.save) self.assertRaises(NotUniqueError, post2.save)
BlogPost.drop_collection()
def test_unique_with_embedded_document_and_embedded_unique(self): def test_unique_with_embedded_document_and_embedded_unique(self):
"""Ensure that uniqueness constraints are applied to fields on """Ensure that uniqueness constraints are applied to fields on
embedded documents. And work with unique_with as well. embedded documents. And work with unique_with as well.
@@ -718,8 +709,6 @@ class IndexesTest(unittest.TestCase):
sub=SubDocument(year=2009, slug='test-1')) sub=SubDocument(year=2009, slug='test-1'))
self.assertRaises(NotUniqueError, post3.save) self.assertRaises(NotUniqueError, post3.save)
BlogPost.drop_collection()
def test_ttl_indexes(self): def test_ttl_indexes(self):
class Log(Document): class Log(Document):
@@ -759,13 +748,11 @@ class IndexesTest(unittest.TestCase):
raise AssertionError("We saved a dupe!") raise AssertionError("We saved a dupe!")
except NotUniqueError: except NotUniqueError:
pass pass
Customer.drop_collection()
def test_unique_and_primary(self): def test_unique_and_primary(self):
"""If you set a field as primary, then unexpected behaviour can occur. """If you set a field as primary, then unexpected behaviour can occur.
You won't create a duplicate but you will update an existing document. You won't create a duplicate but you will update an existing document.
""" """
class User(Document): class User(Document):
name = StringField(primary_key=True, unique=True) name = StringField(primary_key=True, unique=True)
password = StringField() password = StringField()
@@ -781,8 +768,23 @@ class IndexesTest(unittest.TestCase):
self.assertEqual(User.objects.count(), 1) self.assertEqual(User.objects.count(), 1)
self.assertEqual(User.objects.get().password, 'secret2') self.assertEqual(User.objects.get().password, 'secret2')
def test_unique_and_primary_create(self):
"""Create a new record with a duplicate primary key
throws an exception
"""
class User(Document):
name = StringField(primary_key=True)
password = StringField()
User.drop_collection() User.drop_collection()
User.objects.create(name='huangz', password='secret')
with self.assertRaises(NotUniqueError):
User.objects.create(name='huangz', password='secret2')
self.assertEqual(User.objects.count(), 1)
self.assertEqual(User.objects.get().password, 'secret')
def test_index_with_pk(self): def test_index_with_pk(self):
"""Ensure you can use `pk` as part of a query""" """Ensure you can use `pk` as part of a query"""

View File

@@ -28,8 +28,6 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
__all__ = ("InstanceTest",) __all__ = ("InstanceTest",)
class InstanceTest(unittest.TestCase): class InstanceTest(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -72,8 +70,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(field._instance, instance) self.assertEqual(field._instance, instance)
def test_capped_collection(self): def test_capped_collection(self):
"""Ensure that capped collections work properly. """Ensure that capped collections work properly."""
"""
class Log(Document): class Log(Document):
date = DateTimeField(default=datetime.now) date = DateTimeField(default=datetime.now)
meta = { meta = {
@@ -181,8 +178,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('<Article: привет мир>', repr(doc)) self.assertEqual('<Article: привет мир>', repr(doc))
def test_repr_none(self): def test_repr_none(self):
"""Ensure None values handled correctly """Ensure None values are handled correctly."""
"""
class Article(Document): class Article(Document):
title = StringField() title = StringField()
@@ -190,25 +186,23 @@ class InstanceTest(unittest.TestCase):
return None return None
doc = Article(title=u'привет мир') doc = Article(title=u'привет мир')
self.assertEqual('<Article: None>', repr(doc)) self.assertEqual('<Article: None>', repr(doc))
def test_queryset_resurrects_dropped_collection(self): def test_queryset_resurrects_dropped_collection(self):
self.Person.drop_collection() self.Person.drop_collection()
self.assertEqual([], list(self.Person.objects())) self.assertEqual([], list(self.Person.objects()))
# Ensure works correctly with inhertited classes
class Actor(self.Person): class Actor(self.Person):
pass pass
# Ensure works correctly with inhertited classes
Actor.objects() Actor.objects()
self.Person.drop_collection() self.Person.drop_collection()
self.assertEqual([], list(Actor.objects())) self.assertEqual([], list(Actor.objects()))
def test_polymorphic_references(self): def test_polymorphic_references(self):
"""Ensure that the correct subclasses are returned from a query when """Ensure that the correct subclasses are returned from a query
using references / generic references when using references / generic references
""" """
class Animal(Document): class Animal(Document):
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
@@ -258,9 +252,6 @@ class InstanceTest(unittest.TestCase):
classes = [a.__class__ for a in Zoo.objects.first().animals] classes = [a.__class__ for a in Zoo.objects.first().animals]
self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human])
Zoo.drop_collection()
Animal.drop_collection()
def test_reference_inheritance(self): def test_reference_inheritance(self):
class Stats(Document): class Stats(Document):
created = DateTimeField(default=datetime.now) created = DateTimeField(default=datetime.now)
@@ -287,8 +278,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(list_stats, CompareStats.objects.first().stats) self.assertEqual(list_stats, CompareStats.objects.first().stats)
def test_db_field_load(self): def test_db_field_load(self):
"""Ensure we load data correctly """Ensure we load data correctly from the right db field."""
"""
class Person(Document): class Person(Document):
name = StringField(required=True) name = StringField(required=True)
_rank = StringField(required=False, db_field="rank") _rank = StringField(required=False, db_field="rank")
@@ -307,8 +297,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
def test_db_embedded_doc_field_load(self): def test_db_embedded_doc_field_load(self):
"""Ensure we load embedded document data correctly """Ensure we load embedded document data correctly."""
"""
class Rank(EmbeddedDocument): class Rank(EmbeddedDocument):
title = StringField(required=True) title = StringField(required=True)
@@ -333,8 +322,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
def test_custom_id_field(self): def test_custom_id_field(self):
"""Ensure that documents may be created with custom primary keys. """Ensure that documents may be created with custom primary keys."""
"""
class User(Document): class User(Document):
username = StringField(primary_key=True) username = StringField(primary_key=True)
name = StringField() name = StringField()
@@ -382,10 +370,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(user_son['_id'], 'mongo') self.assertEqual(user_son['_id'], 'mongo')
self.assertTrue('username' not in user_son['_id']) self.assertTrue('username' not in user_son['_id'])
User.drop_collection()
def test_document_not_registered(self): def test_document_not_registered(self):
class Place(Document): class Place(Document):
name = StringField() name = StringField()
@@ -407,7 +392,6 @@ class InstanceTest(unittest.TestCase):
list(Place.objects.all()) list(Place.objects.all())
def test_document_registry_regressions(self): def test_document_registry_regressions(self):
class Location(Document): class Location(Document):
name = StringField() name = StringField()
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
@@ -421,18 +405,16 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Area, get_document("Location.Area")) self.assertEqual(Area, get_document("Location.Area"))
def test_creation(self): def test_creation(self):
"""Ensure that document may be created using keyword arguments. """Ensure that document may be created using keyword arguments."""
"""
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 30) self.assertEqual(person.age, 30)
def test_to_dbref(self): def test_to_dbref(self):
"""Ensure that you can get a dbref of a document""" """Ensure that you can get a dbref of a document."""
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
self.assertRaises(OperationError, person.to_dbref) self.assertRaises(OperationError, person.to_dbref)
person.save() person.save()
person.to_dbref() person.to_dbref()
def test_save_abstract_document(self): def test_save_abstract_document(self):
@@ -445,8 +427,7 @@ class InstanceTest(unittest.TestCase):
Doc(name='aaa').save() Doc(name='aaa').save()
def test_reload(self): def test_reload(self):
"""Ensure that attributes may be reloaded. """Ensure that attributes may be reloaded."""
"""
person = self.Person(name="Test User", age=20) person = self.Person(name="Test User", age=20)
person.save() person.save()
@@ -479,7 +460,6 @@ class InstanceTest(unittest.TestCase):
doc = Animal(superphylum='Deuterostomia') doc = Animal(superphylum='Deuterostomia')
doc.save() doc.save()
doc.reload() doc.reload()
Animal.drop_collection()
def test_reload_sharded_nested(self): def test_reload_sharded_nested(self):
class SuperPhylum(EmbeddedDocument): class SuperPhylum(EmbeddedDocument):
@@ -493,11 +473,9 @@ class InstanceTest(unittest.TestCase):
doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
doc.save() doc.save()
doc.reload() doc.reload()
Animal.drop_collection()
def test_reload_referencing(self): def test_reload_referencing(self):
"""Ensures reloading updates weakrefs correctly """Ensures reloading updates weakrefs correctly."""
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
dict_field = DictField() dict_field = DictField()
list_field = ListField() list_field = ListField()
@@ -569,8 +547,7 @@ class InstanceTest(unittest.TestCase):
self.assertFalse("Threw wrong exception") self.assertFalse("Threw wrong exception")
def test_reload_of_non_strict_with_special_field_name(self): def test_reload_of_non_strict_with_special_field_name(self):
"""Ensures reloading works for documents with meta strict == False """Ensures reloading works for documents with meta strict == False."""
"""
class Post(Document): class Post(Document):
meta = { meta = {
'strict': False 'strict': False
@@ -591,8 +568,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(post.items, ["more lorem", "even more ipsum"]) self.assertEqual(post.items, ["more lorem", "even more ipsum"])
def test_dictionary_access(self): def test_dictionary_access(self):
"""Ensure that dictionary-style field access works properly. """Ensure that dictionary-style field access works properly."""
"""
person = self.Person(name='Test User', age=30, job=self.Job()) person = self.Person(name='Test User', age=30, job=self.Job())
self.assertEqual(person['name'], 'Test User') self.assertEqual(person['name'], 'Test User')
@@ -634,8 +610,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(sub_doc.to_mongo().keys(), ['id']) self.assertEqual(sub_doc.to_mongo().keys(), ['id'])
def test_embedded_document(self): def test_embedded_document(self):
"""Ensure that embedded documents are set up correctly. """Ensure that embedded documents are set up correctly."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
@@ -643,8 +618,7 @@ class InstanceTest(unittest.TestCase):
self.assertFalse('id' in Comment._fields) self.assertFalse('id' in Comment._fields)
def test_embedded_document_instance(self): def test_embedded_document_instance(self):
"""Ensure that embedded documents can reference parent instance """Ensure that embedded documents can reference parent instance."""
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
string = StringField() string = StringField()
@@ -652,6 +626,7 @@ class InstanceTest(unittest.TestCase):
embedded_field = EmbeddedDocumentField(Embedded) embedded_field = EmbeddedDocumentField(Embedded)
Doc.drop_collection() Doc.drop_collection()
doc = Doc(embedded_field=Embedded(string="Hi")) doc = Doc(embedded_field=Embedded(string="Hi"))
self.assertHasInstance(doc.embedded_field, doc) self.assertHasInstance(doc.embedded_field, doc)
@@ -661,7 +636,8 @@ class InstanceTest(unittest.TestCase):
def test_embedded_document_complex_instance(self): def test_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference """Ensure that embedded documents in complex fields can reference
parent instance""" parent instance.
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
string = StringField() string = StringField()
@@ -677,8 +653,7 @@ class InstanceTest(unittest.TestCase):
self.assertHasInstance(doc.embedded_field[0], doc) self.assertHasInstance(doc.embedded_field[0], doc)
def test_embedded_document_complex_instance_no_use_db_field(self): def test_embedded_document_complex_instance_no_use_db_field(self):
"""Ensure that use_db_field is propagated to list of Emb Docs """Ensure that use_db_field is propagated to list of Emb Docs."""
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
string = StringField(db_field='s') string = StringField(db_field='s')
@@ -690,7 +665,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(d['embedded_field'], [{'string': 'Hi'}]) self.assertEqual(d['embedded_field'], [{'string': 'Hi'}])
def test_instance_is_set_on_setattr(self): def test_instance_is_set_on_setattr(self):
class Email(EmbeddedDocument): class Email(EmbeddedDocument):
email = EmailField() email = EmailField()
@@ -698,6 +672,7 @@ class InstanceTest(unittest.TestCase):
email = EmbeddedDocumentField(Email) email = EmbeddedDocumentField(Email)
Account.drop_collection() Account.drop_collection()
acc = Account() acc = Account()
acc.email = Email(email='test@example.com') acc.email = Email(email='test@example.com')
self.assertHasInstance(acc._data["email"], acc) self.assertHasInstance(acc._data["email"], acc)
@@ -707,7 +682,6 @@ class InstanceTest(unittest.TestCase):
self.assertHasInstance(acc1._data["email"], acc1) self.assertHasInstance(acc1._data["email"], acc1)
def test_instance_is_set_on_setattr_on_embedded_document_list(self): def test_instance_is_set_on_setattr_on_embedded_document_list(self):
class Email(EmbeddedDocument): class Email(EmbeddedDocument):
email = EmailField() email = EmailField()
@@ -853,32 +827,28 @@ class InstanceTest(unittest.TestCase):
self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())])
def test_save(self): def test_save(self):
"""Ensure that a document may be saved in the database. """Ensure that a document may be saved in the database."""
"""
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30) person = self.Person(name='Test User', age=30)
person.save() person.save()
# Ensure that the object is in the database # Ensure that the object is in the database
collection = self.db[self.Person._get_collection_name()] collection = self.db[self.Person._get_collection_name()]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['name'], 'Test User')
self.assertEqual(person_obj['age'], 30) self.assertEqual(person_obj['age'], 30)
self.assertEqual(person_obj['_id'], person.id) self.assertEqual(person_obj['_id'], person.id)
# Test skipping validation on save
# Test skipping validation on save
class Recipient(Document): class Recipient(Document):
email = EmailField(required=True) email = EmailField(required=True)
recipient = Recipient(email='root@localhost') recipient = Recipient(email='not-an-email')
self.assertRaises(ValidationError, recipient.save) self.assertRaises(ValidationError, recipient.save)
recipient.save(validate=False)
try:
recipient.save(validate=False)
except ValidationError:
self.fail()
def test_save_to_a_value_that_equates_to_false(self): def test_save_to_a_value_that_equates_to_false(self):
class Thing(EmbeddedDocument): class Thing(EmbeddedDocument):
count = IntField() count = IntField()
@@ -898,7 +868,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(user.thing.count, 0) self.assertEqual(user.thing.count, 0)
def test_save_max_recursion_not_hit(self): def test_save_max_recursion_not_hit(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@@ -924,7 +893,6 @@ class InstanceTest(unittest.TestCase):
p0.save() p0.save()
def test_save_max_recursion_not_hit_with_file_field(self): def test_save_max_recursion_not_hit_with_file_field(self):
class Foo(Document): class Foo(Document):
name = StringField() name = StringField()
picture = FileField() picture = FileField()
@@ -948,7 +916,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture)
def test_save_cascades(self): def test_save_cascades(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@@ -971,7 +938,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_cascade_kwargs(self): def test_save_cascade_kwargs(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@@ -992,7 +958,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p2.parent.name) self.assertEqual(p1.name, p2.parent.name)
def test_save_cascade_meta_false(self): def test_save_cascade_meta_false(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@@ -1021,7 +986,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_cascade_meta_true(self): def test_save_cascade_meta_true(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@@ -1046,7 +1010,6 @@ class InstanceTest(unittest.TestCase):
self.assertNotEqual(p1.name, p.parent.name) self.assertNotEqual(p1.name, p.parent.name)
def test_save_cascades_generically(self): def test_save_cascades_generically(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = GenericReferenceField() parent = GenericReferenceField()
@@ -1072,7 +1035,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_atomicity_condition(self): def test_save_atomicity_condition(self):
class Widget(Document): class Widget(Document):
toggle = BooleanField(default=False) toggle = BooleanField(default=False)
count = IntField(default=0) count = IntField(default=0)
@@ -1150,7 +1112,8 @@ class InstanceTest(unittest.TestCase):
def test_update(self): def test_update(self):
"""Ensure that an existing document is updated instead of be """Ensure that an existing document is updated instead of be
overwritten.""" overwritten.
"""
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30) person = self.Person(name='Test User', age=30)
person.save() person.save()
@@ -1254,7 +1217,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(2, self.Person.objects.count()) self.assertEqual(2, self.Person.objects.count())
def test_can_save_if_not_included(self): def test_can_save_if_not_included(self):
class EmbeddedDoc(EmbeddedDocument): class EmbeddedDoc(EmbeddedDocument):
pass pass
@@ -1341,10 +1303,7 @@ class InstanceTest(unittest.TestCase):
doc2.update(set__name=doc1.name) doc2.update(set__name=doc1.name)
def test_embedded_update(self): def test_embedded_update(self):
""" """Test update on `EmbeddedDocumentField` fields."""
Test update on `EmbeddedDocumentField` fields
"""
class Page(EmbeddedDocument): class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message", log_message = StringField(verbose_name="Log message",
required=True) required=True)
@@ -1365,11 +1324,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_embedded_update_db_field(self): def test_embedded_update_db_field(self):
"""Test update on `EmbeddedDocumentField` fields when db_field
is other than default.
""" """
Test update on `EmbeddedDocumentField` fields when db_field is other
than default.
"""
class Page(EmbeddedDocument): class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message", log_message = StringField(verbose_name="Log message",
db_field="page_log_message", db_field="page_log_message",
@@ -1392,9 +1349,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_save_only_changed_fields(self): def test_save_only_changed_fields(self):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields."""
"""
class User(self.Person): class User(self.Person):
active = BooleanField(default=True) active = BooleanField(default=True)
@@ -1514,8 +1469,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(q, 3) self.assertEqual(q, 3)
def test_set_unset_one_operation(self): def test_set_unset_one_operation(self):
"""Ensure that $set and $unset actions are performed in the same """Ensure that $set and $unset actions are performed in the
operation. same operation.
""" """
class FooBar(Document): class FooBar(Document):
foo = StringField(default=None) foo = StringField(default=None)
@@ -1536,9 +1491,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(1, q) self.assertEqual(1, q)
def test_save_only_changed_fields_recursive(self): def test_save_only_changed_fields_recursive(self):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
published = BooleanField(default=True) published = BooleanField(default=True)
@@ -1578,8 +1531,7 @@ class InstanceTest(unittest.TestCase):
self.assertFalse(person.comments_dict['first_post'].published) self.assertFalse(person.comments_dict['first_post'].published)
def test_delete(self): def test_delete(self):
"""Ensure that document may be deleted using the delete method. """Ensure that document may be deleted using the delete method."""
"""
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
person.save() person.save()
self.assertEqual(self.Person.objects.count(), 1) self.assertEqual(self.Person.objects.count(), 1)
@@ -1587,33 +1539,34 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(self.Person.objects.count(), 0) self.assertEqual(self.Person.objects.count(), 0)
def test_save_custom_id(self): def test_save_custom_id(self):
"""Ensure that a document may be saved with a custom _id. """Ensure that a document may be saved with a custom _id."""
"""
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
id='497ce96f395f2f052a494fd4') id='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._get_collection_name()] collection = self.db[self.Person._get_collection_name()]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_custom_pk(self): def test_save_custom_pk(self):
""" """Ensure that a document may be saved with a custom _id using
Ensure that a document may be saved with a custom _id using pk alias. pk alias.
""" """
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
pk='497ce96f395f2f052a494fd4') pk='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._get_collection_name()] collection = self.db[self.Person._get_collection_name()]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_list(self): def test_save_list(self):
"""Ensure that a list field may be properly saved. """Ensure that a list field may be properly saved."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
@@ -1636,8 +1589,6 @@ class InstanceTest(unittest.TestCase):
for comment_obj, comment in zip(post_obj['comments'], comments): for comment_obj, comment in zip(post_obj['comments'], comments):
self.assertEqual(comment_obj['content'], comment['content']) self.assertEqual(comment_obj['content'], comment['content'])
BlogPost.drop_collection()
def test_list_search_by_embedded(self): def test_list_search_by_embedded(self):
class User(Document): class User(Document):
username = StringField(required=True) username = StringField(required=True)
@@ -1697,8 +1648,8 @@ class InstanceTest(unittest.TestCase):
list(Page.objects.filter(comments__user=u3))) list(Page.objects.filter(comments__user=u3)))
def test_save_embedded_document(self): def test_save_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may
saved in the database. be saved in the database.
""" """
class EmployeeDetails(EmbeddedDocument): class EmployeeDetails(EmbeddedDocument):
position = StringField() position = StringField()
@@ -1717,13 +1668,13 @@ class InstanceTest(unittest.TestCase):
employee_obj = collection.find_one({'name': 'Test Employee'}) employee_obj = collection.find_one({'name': 'Test Employee'})
self.assertEqual(employee_obj['name'], 'Test Employee') self.assertEqual(employee_obj['name'], 'Test Employee')
self.assertEqual(employee_obj['age'], 50) self.assertEqual(employee_obj['age'], 50)
# Ensure that the 'details' embedded object saved correctly # Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Developer') self.assertEqual(employee_obj['details']['position'], 'Developer')
def test_embedded_update_after_save(self): def test_embedded_update_after_save(self):
""" """Test update of `EmbeddedDocumentField` attached to a newly
Test update of `EmbeddedDocumentField` attached to a newly saved saved document.
document.
""" """
class Page(EmbeddedDocument): class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message", log_message = StringField(verbose_name="Log message",
@@ -1744,8 +1695,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_updating_an_embedded_document(self): def test_updating_an_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may
saved in the database. be saved in the database.
""" """
class EmployeeDetails(EmbeddedDocument): class EmployeeDetails(EmbeddedDocument):
position = StringField() position = StringField()
@@ -1780,7 +1731,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(promoted_employee.details, None) self.assertEqual(promoted_employee.details, None)
def test_object_mixins(self): def test_object_mixins(self):
class NameMixin(object): class NameMixin(object):
name = StringField() name = StringField()
@@ -1819,9 +1769,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(t.count, 12) self.assertEqual(t.count, 12)
def test_save_reference(self): def test_save_reference(self):
"""Ensure that a document reference field may be saved in the database. """Ensure that a document reference field may be saved in the
database.
""" """
class BlogPost(Document): class BlogPost(Document):
meta = {'collection': 'blogpost_1'} meta = {'collection': 'blogpost_1'}
content = StringField() content = StringField()
@@ -1852,8 +1802,6 @@ class InstanceTest(unittest.TestCase):
author = list(self.Person.objects(name='Test User'))[-1] author = list(self.Person.objects(name='Test User'))[-1]
self.assertEqual(author.age, 25) self.assertEqual(author.age, 25)
BlogPost.drop_collection()
def test_duplicate_db_fields_raise_invalid_document_error(self): def test_duplicate_db_fields_raise_invalid_document_error(self):
"""Ensure a InvalidDocumentError is thrown if duplicate fields """Ensure a InvalidDocumentError is thrown if duplicate fields
declare the same db_field. declare the same db_field.
@@ -1864,7 +1812,7 @@ class InstanceTest(unittest.TestCase):
name2 = StringField(db_field='name') name2 = StringField(db_field='name')
def test_invalid_son(self): def test_invalid_son(self):
"""Raise an error if loading invalid data""" """Raise an error if loading invalid data."""
class Occurrence(EmbeddedDocument): class Occurrence(EmbeddedDocument):
number = IntField() number = IntField()
@@ -1887,9 +1835,9 @@ class InstanceTest(unittest.TestCase):
Word._from_son('this is not a valid SON dict') Word._from_son('this is not a valid SON dict')
def test_reverse_delete_rule_cascade_and_nullify(self): def test_reverse_delete_rule_cascade_and_nullify(self):
"""Ensure that a referenced document is also deleted upon deletion. """Ensure that a referenced document is also deleted upon
deletion.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
@@ -1944,7 +1892,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Book.objects.count(), 0) self.assertEqual(Book.objects.count(), 0)
def test_reverse_delete_rule_with_shared_id_among_collections(self): def test_reverse_delete_rule_with_shared_id_among_collections(self):
"""Ensure that cascade delete rule doesn't mix id among collections. """Ensure that cascade delete rule doesn't mix id among
collections.
""" """
class User(Document): class User(Document):
id = IntField(primary_key=True) id = IntField(primary_key=True)
@@ -1975,10 +1924,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Book.objects.get(), book_2) self.assertEqual(Book.objects.get(), book_2)
def test_reverse_delete_rule_with_document_inheritance(self): def test_reverse_delete_rule_with_document_inheritance(self):
"""Ensure that a referenced document is also deleted upon deletion """Ensure that a referenced document is also deleted upon
of a child document. deletion of a child document.
""" """
class Writer(self.Person): class Writer(self.Person):
pass pass
@@ -2010,10 +1958,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(BlogPost.objects.count(), 0) self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): def test_reverse_delete_rule_cascade_and_nullify_complex_field(self):
"""Ensure that a referenced document is also deleted upon deletion for """Ensure that a referenced document is also deleted upon
complex fields. deletion for complex fields.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
authors = ListField(ReferenceField( authors = ListField(ReferenceField(
@@ -2022,7 +1969,6 @@ class InstanceTest(unittest.TestCase):
self.Person, reverse_delete_rule=NULLIFY)) self.Person, reverse_delete_rule=NULLIFY))
self.Person.drop_collection() self.Person.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
author = self.Person(name='Test User') author = self.Person(name='Test User')
@@ -2046,10 +1992,10 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(BlogPost.objects.count(), 0) self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self):
""" ensure the pre_delete signal is triggered upon a cascading deletion """Ensure the pre_delete signal is triggered upon a cascading
setup a blog post with content, an author and editor deletion setup a blog post with content, an author and editor
delete the author which triggers deletion of blogpost via cascade delete the author which triggers deletion of blogpost via
blog post's pre_delete signal alters an editor attribute cascade blog post's pre_delete signal alters an editor attribute.
""" """
class Editor(self.Person): class Editor(self.Person):
review_queue = IntField(default=0) review_queue = IntField(default=0)
@@ -2077,6 +2023,7 @@ class InstanceTest(unittest.TestCase):
# delete the author, the post is also deleted due to the CASCADE rule # delete the author, the post is also deleted due to the CASCADE rule
author.delete() author.delete()
# the pre-delete signal should have decremented the editor's queue # the pre-delete signal should have decremented the editor's queue
editor = Editor.objects(name='Max P.').get() editor = Editor.objects(name='Max P.').get()
self.assertEqual(editor.review_queue, 0) self.assertEqual(editor.review_queue, 0)
@@ -2085,7 +2032,6 @@ class InstanceTest(unittest.TestCase):
"""Ensure that Bi-Directional relationships work with """Ensure that Bi-Directional relationships work with
reverse_delete_rule reverse_delete_rule
""" """
class Bar(Document): class Bar(Document):
content = StringField() content = StringField()
foo = ReferenceField('Foo') foo = ReferenceField('Foo')
@@ -2131,8 +2077,8 @@ class InstanceTest(unittest.TestCase):
mother = ReferenceField('Person', reverse_delete_rule=DENY) mother = ReferenceField('Person', reverse_delete_rule=DENY)
def test_reverse_delete_rule_cascade_recurs(self): def test_reverse_delete_rule_cascade_recurs(self):
"""Ensure that a chain of documents is also deleted upon cascaded """Ensure that a chain of documents is also deleted upon
deletion. cascaded deletion.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
@@ -2162,15 +2108,10 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(Comment.objects.count(), 0) self.assertEqual(Comment.objects.count(), 0)
self.Person.drop_collection()
BlogPost.drop_collection()
Comment.drop_collection()
def test_reverse_delete_rule_deny(self): def test_reverse_delete_rule_deny(self):
"""Ensure that a document cannot be referenced if there are still """Ensure that a document cannot be referenced if there are
documents referring to it. still documents referring to it.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=DENY) author = ReferenceField(self.Person, reverse_delete_rule=DENY)
@@ -2198,11 +2139,7 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(self.Person.objects.count(), 1) self.assertEqual(self.Person.objects.count(), 1)
self.Person.drop_collection()
BlogPost.drop_collection()
def subclasses_and_unique_keys_works(self): def subclasses_and_unique_keys_works(self):
class A(Document): class A(Document):
pass pass
@@ -2218,19 +2155,16 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(A.objects.count(), 2) self.assertEqual(A.objects.count(), 2)
self.assertEqual(B.objects.count(), 1) self.assertEqual(B.objects.count(), 1)
A.drop_collection()
B.drop_collection()
def test_document_hash(self): def test_document_hash(self):
"""Test document in list, dict, set """Test document in list, dict, set."""
"""
class User(Document): class User(Document):
pass pass
class BlogPost(Document): class BlogPost(Document):
pass pass
# Clear old datas # Clear old data
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
@@ -2242,17 +2176,18 @@ class InstanceTest(unittest.TestCase):
b1 = BlogPost.objects.create() b1 = BlogPost.objects.create()
b2 = BlogPost.objects.create() b2 = BlogPost.objects.create()
# in List # Make sure docs are properly identified in a list (__eq__ is used
# for the comparison).
all_user_list = list(User.objects.all()) all_user_list = list(User.objects.all())
self.assertTrue(u1 in all_user_list) self.assertTrue(u1 in all_user_list)
self.assertTrue(u2 in all_user_list) self.assertTrue(u2 in all_user_list)
self.assertTrue(u3 in all_user_list) self.assertTrue(u3 in all_user_list)
self.assertFalse(u4 in all_user_list) # New object self.assertTrue(u4 not in all_user_list) # New object
self.assertFalse(b1 in all_user_list) # Other object self.assertTrue(b1 not in all_user_list) # Other object
self.assertFalse(b2 in all_user_list) # Other object self.assertTrue(b2 not in all_user_list) # Other object
# in Dict # Make sure docs can be used as keys in a dict (__hash__ is used
# for hashing the docs).
all_user_dic = {} all_user_dic = {}
for u in User.objects.all(): for u in User.objects.all():
all_user_dic[u] = "OK" all_user_dic[u] = "OK"
@@ -2264,13 +2199,22 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(all_user_dic.get(b1, False), False) # Other object self.assertEqual(all_user_dic.get(b1, False), False) # Other object
self.assertEqual(all_user_dic.get(b2, False), False) # Other object self.assertEqual(all_user_dic.get(b2, False), False) # Other object
# in Set # Make sure docs are properly identified in a set (__hash__ is used
# for hashing the docs).
all_user_set = set(User.objects.all()) all_user_set = set(User.objects.all())
self.assertTrue(u1 in all_user_set) self.assertTrue(u1 in all_user_set)
self.assertTrue(u4 not in all_user_set)
self.assertTrue(b1 not in all_user_list)
self.assertTrue(b2 not in all_user_list)
# Make sure duplicate docs aren't accepted in the set
self.assertEqual(len(all_user_set), 3)
all_user_set.add(u1)
all_user_set.add(u2)
all_user_set.add(u3)
self.assertEqual(len(all_user_set), 3)
def test_picklable(self): def test_picklable(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded() pickle_doc.embedded = PickleEmbedded()
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
@@ -2296,7 +2240,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) self.assertEqual(pickle_doc.lists, ["1", "2", "3"])
def test_regular_document_pickle(self): def test_regular_document_pickle(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
pickle_doc.save() pickle_doc.save()
@@ -2319,7 +2262,6 @@ class InstanceTest(unittest.TestCase):
fixtures.PickleTest = PickleTest fixtures.PickleTest = PickleTest
def test_dynamic_document_pickle(self): def test_dynamic_document_pickle(self):
pickle_doc = PickleDynamicTest( pickle_doc = PickleDynamicTest(
name="test", number=1, string="One", lists=['1', '2']) name="test", number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar") pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar")
@@ -2358,7 +2300,6 @@ class InstanceTest(unittest.TestCase):
validate = DictField() validate = DictField()
def test_mutating_documents(self): def test_mutating_documents(self):
class B(EmbeddedDocument): class B(EmbeddedDocument):
field1 = StringField(default='field1') field1 = StringField(default='field1')
@@ -2366,6 +2307,7 @@ class InstanceTest(unittest.TestCase):
b = EmbeddedDocumentField(B, default=lambda: B()) b = EmbeddedDocumentField(B, default=lambda: B())
A.drop_collection() A.drop_collection()
a = A() a = A()
a.save() a.save()
a.reload() a.reload()
@@ -2389,12 +2331,13 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(a.b.field2.c_field, 'new value') self.assertEqual(a.b.field2.c_field, 'new value')
def test_can_save_false_values(self): def test_can_save_false_values(self):
"""Ensures you can save False values on save""" """Ensures you can save False values on save."""
class Doc(Document): class Doc(Document):
foo = StringField() foo = StringField()
archived = BooleanField(default=False, required=True) archived = BooleanField(default=False, required=True)
Doc.drop_collection() Doc.drop_collection()
d = Doc() d = Doc()
d.save() d.save()
d.archived = False d.archived = False
@@ -2403,11 +2346,12 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Doc.objects(archived=False).count(), 1) self.assertEqual(Doc.objects(archived=False).count(), 1)
def test_can_save_false_values_dynamic(self): def test_can_save_false_values_dynamic(self):
"""Ensures you can save False values on dynamic docs""" """Ensures you can save False values on dynamic docs."""
class Doc(DynamicDocument): class Doc(DynamicDocument):
foo = StringField() foo = StringField()
Doc.drop_collection() Doc.drop_collection()
d = Doc() d = Doc()
d.save() d.save()
d.archived = False d.archived = False
@@ -2447,7 +2391,7 @@ class InstanceTest(unittest.TestCase):
Collection.update = orig_update Collection.update = orig_update
def test_db_alias_tests(self): def test_db_alias_tests(self):
""" DB Alias tests """ """DB Alias tests."""
# mongoenginetest - Is default connection alias from setUp() # mongoenginetest - Is default connection alias from setUp()
# Register Aliases # Register Aliases
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
@@ -2509,8 +2453,7 @@ class InstanceTest(unittest.TestCase):
get_db("testdb-3")[AuthorBooks._get_collection_name()]) get_db("testdb-3")[AuthorBooks._get_collection_name()])
def test_db_alias_overrides(self): def test_db_alias_overrides(self):
"""db_alias can be overriden """Test db_alias can be overriden."""
"""
# Register a connection with db_alias testdb-2 # Register a connection with db_alias testdb-2
register_connection('testdb-2', 'mongoenginetest2') register_connection('testdb-2', 'mongoenginetest2')
@@ -2534,8 +2477,7 @@ class InstanceTest(unittest.TestCase):
B._get_collection().database.name) B._get_collection().database.name)
def test_db_alias_propagates(self): def test_db_alias_propagates(self):
"""db_alias propagates? """db_alias propagates?"""
"""
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
class A(Document): class A(Document):
@@ -2548,8 +2490,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('testdb-1', B._meta.get('db_alias')) self.assertEqual('testdb-1', B._meta.get('db_alias'))
def test_db_ref_usage(self): def test_db_ref_usage(self):
""" DB Ref usage in dict_fields""" """DB Ref usage in dict_fields."""
class User(Document): class User(Document):
name = StringField() name = StringField()
@@ -2784,7 +2725,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(user.thing._data['data'], [1, 2, 3]) self.assertEqual(user.thing._data['data'], [1, 2, 3])
def test_spaces_in_keys(self): def test_spaces_in_keys(self):
class Embedded(DynamicEmbeddedDocument): class Embedded(DynamicEmbeddedDocument):
pass pass
@@ -2873,7 +2813,6 @@ class InstanceTest(unittest.TestCase):
log.machine = "127.0.0.1" log.machine = "127.0.0.1"
def test_kwargs_simple(self): def test_kwargs_simple(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
name = StringField() name = StringField()
@@ -2893,7 +2832,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(classic_doc._data, dict_doc._data) self.assertEqual(classic_doc._data, dict_doc._data)
def test_kwargs_complex(self): def test_kwargs_complex(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
name = StringField() name = StringField()
@@ -2916,36 +2854,35 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(classic_doc._data, dict_doc._data) self.assertEqual(classic_doc._data, dict_doc._data)
def test_positional_creation(self): def test_positional_creation(self):
"""Ensure that document may be created using positional arguments. """Ensure that document may be created using positional arguments."""
"""
person = self.Person("Test User", 42) person = self.Person("Test User", 42)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_mixed_creation(self): def test_mixed_creation(self):
"""Ensure that document may be created using mixed arguments. """Ensure that document may be created using mixed arguments."""
"""
person = self.Person("Test User", age=42) person = self.Person("Test User", age=42)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_positional_creation_embedded(self): def test_positional_creation_embedded(self):
"""Ensure that embedded document may be created using positional arguments. """Ensure that embedded document may be created using positional
arguments.
""" """
job = self.Job("Test Job", 4) job = self.Job("Test Job", 4)
self.assertEqual(job.name, "Test Job") self.assertEqual(job.name, "Test Job")
self.assertEqual(job.years, 4) self.assertEqual(job.years, 4)
def test_mixed_creation_embedded(self): def test_mixed_creation_embedded(self):
"""Ensure that embedded document may be created using mixed arguments. """Ensure that embedded document may be created using mixed
arguments.
""" """
job = self.Job("Test Job", years=4) job = self.Job("Test Job", years=4)
self.assertEqual(job.name, "Test Job") self.assertEqual(job.name, "Test Job")
self.assertEqual(job.years, 4) self.assertEqual(job.years, 4)
def test_mixed_creation_dynamic(self): def test_mixed_creation_dynamic(self):
"""Ensure that document may be created using mixed arguments. """Ensure that document may be created using mixed arguments."""
"""
class Person(DynamicDocument): class Person(DynamicDocument):
name = StringField() name = StringField()
@@ -2954,14 +2891,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_bad_mixed_creation(self): def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments """Ensure that document gives correct error when duplicating
arguments.
""" """
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
return self.Person("Test User", 42, name="Bad User") return self.Person("Test User", 42, name="Bad User")
def test_data_contains_id_field(self): def test_data_contains_id_field(self):
"""Ensure that asking for _data returns 'id' """Ensure that asking for _data returns 'id'."""
"""
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@@ -2973,7 +2910,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person._data.get('id'), person.id) self.assertEqual(person._data.get('id'), person.id)
def test_complex_nesting_document_and_embedded_document(self): def test_complex_nesting_document_and_embedded_document(self):
class Macro(EmbeddedDocument): class Macro(EmbeddedDocument):
value = DynamicField(default="UNDEFINED") value = DynamicField(default="UNDEFINED")
@@ -3016,7 +2952,6 @@ class InstanceTest(unittest.TestCase):
system.nodes["node"].parameters["param"].macros["test"].value) system.nodes["node"].parameters["param"].macros["test"].value)
def test_embedded_document_equality(self): def test_embedded_document_equality(self):
class Test(Document): class Test(Document):
field = StringField(required=True) field = StringField(required=True)
@@ -3202,8 +3137,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(idx, 2) self.assertEqual(idx, 2)
def test_falsey_pk(self): def test_falsey_pk(self):
"""Ensure that we can create and update a document with Falsey PK. """Ensure that we can create and update a document with Falsey PK."""
"""
class Person(Document): class Person(Document):
age = IntField(primary_key=True) age = IntField(primary_key=True)
height = FloatField() height = FloatField()

File diff suppressed because it is too large Load Diff

View File

@@ -4962,20 +4962,6 @@ class QuerySetTest(unittest.TestCase):
for p in Person.objects(): for p in Person.objects():
self.assertEqual(p.name, 'a') self.assertEqual(p.name, 'a')
def test_last_field_name_like_operator(self):
class EmbeddedItem(EmbeddedDocument):
type = StringField()
class Doc(Document):
item = EmbeddedDocumentField(EmbeddedItem)
Doc.drop_collection()
doc = Doc(item=EmbeddedItem(type="axe"))
doc.save()
self.assertEqual(1, Doc.objects(item__type__="axe").count())
def test_len_during_iteration(self): def test_len_during_iteration(self):
"""Tests that calling len on a queyset during iteration doesn't """Tests that calling len on a queyset during iteration doesn't
stop paging. stop paging.

View File

@@ -2,10 +2,15 @@
import unittest import unittest
from bson import DBRef, ObjectId from bson import DBRef, ObjectId
from collections import OrderedDict
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter from mongoengine.context_managers import query_counter
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.base import TopLevelDocumentMetaclass
if IS_PYMONGO_3:
from bson import CodecOptions
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@@ -1287,5 +1292,70 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_dynamic_field_dereference(self):
class Merchandise(Document):
name = StringField()
price = IntField()
class Store(Document):
merchandises = DynamicField()
Merchandise.drop_collection()
Store.drop_collection()
merchandises = {
'#1': Merchandise(name='foo', price=100).save(),
'#2': Merchandise(name='bar', price=120).save(),
'#3': Merchandise(name='baz', price=110).save(),
}
Store(merchandises=merchandises).save()
store = Store.objects().first()
for obj in store.merchandises.values():
self.assertFalse(isinstance(obj, Merchandise))
store.select_related()
for obj in store.merchandises.values():
self.assertTrue(isinstance(obj, Merchandise))
def test_dynamic_field_dereference_with_ordering_guarantee_on_pymongo3(self):
# This is because 'codec_options' is supported on pymongo3 or later
if IS_PYMONGO_3:
class OrderedDocument(Document):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
@classmethod
def _get_collection(cls):
collection = super(OrderedDocument, cls)._get_collection()
opts = CodecOptions(document_class=OrderedDict)
return collection.with_options(codec_options=opts)
class Merchandise(Document):
name = StringField()
price = IntField()
class Store(OrderedDocument):
merchandises = DynamicField(container_class=OrderedDict)
Merchandise.drop_collection()
Store.drop_collection()
merchandises = OrderedDict()
merchandises['#1'] = Merchandise(name='foo', price=100).save()
merchandises['#2'] = Merchandise(name='bar', price=120).save()
merchandises['#3'] = Merchandise(name='baz', price=110).save()
Store(merchandises=merchandises).save()
store = Store.objects().first()
store.select_related()
# confirms that the load data order is same with the one at storing
self.assertTrue(type(store.merchandises), OrderedDict)
self.assertEqual(','.join(store.merchandises.keys()), '#1,#2,#3')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()