Improve the health of this package (#1428)
This commit is contained in:
parent
3135b456be
commit
835d3c3d18
22
.landscape.yml
Normal file
22
.landscape.yml
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
pylint:
|
||||||
|
disable:
|
||||||
|
# We use this a lot (e.g. via document._meta)
|
||||||
|
- protected-access
|
||||||
|
|
||||||
|
options:
|
||||||
|
additional-builtins:
|
||||||
|
# add xrange and long as valid built-ins. In Python 3, xrange is
|
||||||
|
# translated into range and long is translated into int via 2to3 (see
|
||||||
|
# "use_2to3" in setup.py). This should be removed when we drop Python
|
||||||
|
# 2 support (which probably won't happen any time soon).
|
||||||
|
- xrange
|
||||||
|
- long
|
||||||
|
|
||||||
|
pyflakes:
|
||||||
|
disable:
|
||||||
|
# undefined variables are already covered by pylint (and exclude
|
||||||
|
# xrange & long)
|
||||||
|
- F821
|
||||||
|
|
||||||
|
ignore-paths:
|
||||||
|
- benchmark.py
|
@ -1,7 +1,6 @@
|
|||||||
language: python
|
language: python
|
||||||
|
|
||||||
python:
|
python:
|
||||||
- '2.6' # TODO remove in v0.11.0
|
|
||||||
- '2.7'
|
- '2.7'
|
||||||
- '3.3'
|
- '3.3'
|
||||||
- '3.4'
|
- '3.4'
|
||||||
@ -43,7 +42,11 @@ before_script:
|
|||||||
script:
|
script:
|
||||||
- tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage
|
- tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage
|
||||||
|
|
||||||
after_script: coveralls --verbose
|
# For now only submit coveralls for Python v2.7. Python v3.x currently shows
|
||||||
|
# 0% coverage. That's caused by 'use_2to3', which builds the py3-compatible
|
||||||
|
# code in a separate dir and runs tests on that.
|
||||||
|
after_script:
|
||||||
|
- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then coveralls --verbose; fi
|
||||||
|
|
||||||
notifications:
|
notifications:
|
||||||
irc: irc.freenode.org#mongoengine
|
irc: irc.freenode.org#mongoengine
|
||||||
|
152
benchmark.py
152
benchmark.py
@ -1,118 +1,41 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
"""
|
||||||
|
Simple benchmark comparing PyMongo and MongoEngine.
|
||||||
|
|
||||||
|
Sample run on a mid 2015 MacBook Pro (commit b282511):
|
||||||
|
|
||||||
|
Benchmarking...
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries - Pymongo
|
||||||
|
2.58979988098
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries - Pymongo write_concern={"w": 0}
|
||||||
|
1.26657605171
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries - MongoEngine
|
||||||
|
8.4351580143
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries without continual assign - MongoEngine
|
||||||
|
7.20191693306
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True
|
||||||
|
6.31104588509
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True
|
||||||
|
6.07083487511
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False
|
||||||
|
5.97704291344
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False
|
||||||
|
5.9111430645
|
||||||
|
"""
|
||||||
|
|
||||||
import timeit
|
import timeit
|
||||||
|
|
||||||
|
|
||||||
def cprofile_main():
|
|
||||||
from pymongo import Connection
|
|
||||||
connection = Connection()
|
|
||||||
connection.drop_database('timeit_test')
|
|
||||||
connection.disconnect()
|
|
||||||
|
|
||||||
from mongoengine import Document, DictField, connect
|
|
||||||
connect("timeit_test")
|
|
||||||
|
|
||||||
class Noddy(Document):
|
|
||||||
fields = DictField()
|
|
||||||
|
|
||||||
for i in range(1):
|
|
||||||
noddy = Noddy()
|
|
||||||
for j in range(20):
|
|
||||||
noddy.fields["key" + str(j)] = "value " + str(j)
|
|
||||||
noddy.save()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
|
||||||
0.4 Performance Figures ...
|
|
||||||
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - Pymongo
|
|
||||||
3.86744189262
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine
|
|
||||||
6.23374891281
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
|
|
||||||
5.33027005196
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
|
|
||||||
pass - No Cascade
|
|
||||||
|
|
||||||
0.5.X
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - Pymongo
|
|
||||||
3.89597702026
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine
|
|
||||||
21.7735359669
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
|
|
||||||
19.8670389652
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
|
|
||||||
pass - No Cascade
|
|
||||||
|
|
||||||
0.6.X
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - Pymongo
|
|
||||||
3.81559205055
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine
|
|
||||||
10.0446798801
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
|
|
||||||
9.51354718208
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
|
|
||||||
9.02567505836
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, force=True
|
|
||||||
8.44933390617
|
|
||||||
|
|
||||||
0.7.X
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - Pymongo
|
|
||||||
3.78801012039
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine
|
|
||||||
9.73050498962
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
|
|
||||||
8.33456707001
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
|
|
||||||
8.37778115273
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, force=True
|
|
||||||
8.36906409264
|
|
||||||
0.8.X
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - Pymongo
|
|
||||||
3.69964408875
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - Pymongo write_concern={"w": 0}
|
|
||||||
3.5526599884
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine
|
|
||||||
7.00959801674
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries without continual assign - MongoEngine
|
|
||||||
5.60943293571
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade=True
|
|
||||||
6.715102911
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True
|
|
||||||
5.50644683838
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False
|
|
||||||
4.69851183891
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False
|
|
||||||
4.68946313858
|
|
||||||
----------------------------------------------------------------------------------------------------
|
|
||||||
"""
|
|
||||||
print("Benchmarking...")
|
print("Benchmarking...")
|
||||||
|
|
||||||
setup = """
|
setup = """
|
||||||
@ -131,7 +54,7 @@ noddy = db.noddy
|
|||||||
for i in range(10000):
|
for i in range(10000):
|
||||||
example = {'fields': {}}
|
example = {'fields': {}}
|
||||||
for j in range(20):
|
for j in range(20):
|
||||||
example['fields']["key"+str(j)] = "value "+str(j)
|
example['fields']['key' + str(j)] = 'value ' + str(j)
|
||||||
|
|
||||||
noddy.save(example)
|
noddy.save(example)
|
||||||
|
|
||||||
@ -146,9 +69,10 @@ myNoddys = noddy.find()
|
|||||||
|
|
||||||
stmt = """
|
stmt = """
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
from pymongo.write_concern import WriteConcern
|
||||||
connection = MongoClient()
|
connection = MongoClient()
|
||||||
|
|
||||||
db = connection.timeit_test
|
db = connection.get_database('timeit_test', write_concern=WriteConcern(w=0))
|
||||||
noddy = db.noddy
|
noddy = db.noddy
|
||||||
|
|
||||||
for i in range(10000):
|
for i in range(10000):
|
||||||
@ -156,7 +80,7 @@ for i in range(10000):
|
|||||||
for j in range(20):
|
for j in range(20):
|
||||||
example['fields']["key"+str(j)] = "value "+str(j)
|
example['fields']["key"+str(j)] = "value "+str(j)
|
||||||
|
|
||||||
noddy.save(example, write_concern={"w": 0})
|
noddy.save(example)
|
||||||
|
|
||||||
myNoddys = noddy.find()
|
myNoddys = noddy.find()
|
||||||
[n for n in myNoddys] # iterate
|
[n for n in myNoddys] # iterate
|
||||||
@ -171,10 +95,10 @@ myNoddys = noddy.find()
|
|||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
connection = MongoClient()
|
connection = MongoClient()
|
||||||
connection.drop_database('timeit_test')
|
connection.drop_database('timeit_test')
|
||||||
connection.disconnect()
|
connection.close()
|
||||||
|
|
||||||
from mongoengine import Document, DictField, connect
|
from mongoengine import Document, DictField, connect
|
||||||
connect("timeit_test")
|
connect('timeit_test')
|
||||||
|
|
||||||
class Noddy(Document):
|
class Noddy(Document):
|
||||||
fields = DictField()
|
fields = DictField()
|
||||||
|
@ -4,6 +4,13 @@ Changelog
|
|||||||
|
|
||||||
Development
|
Development
|
||||||
===========
|
===========
|
||||||
|
- (Fill this out as you fix issues and develop you features).
|
||||||
|
|
||||||
|
Changes in 0.11.0
|
||||||
|
=================
|
||||||
|
- BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428
|
||||||
|
- BREAKING CHANGE: Dropped Python 2.6 support. #1428
|
||||||
|
- BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428
|
||||||
- Fixed absent rounding for DecimalField when `force_string` is set. #1103
|
- Fixed absent rounding for DecimalField when `force_string` is set. #1103
|
||||||
|
|
||||||
Changes in 0.10.8
|
Changes in 0.10.8
|
||||||
|
@ -2,6 +2,32 @@
|
|||||||
Upgrading
|
Upgrading
|
||||||
#########
|
#########
|
||||||
|
|
||||||
|
0.11.0
|
||||||
|
******
|
||||||
|
This release includes a major rehaul of MongoEngine's code quality and
|
||||||
|
introduces a few breaking changes. It also touches many different parts of
|
||||||
|
the package and although all the changes have been tested and scrutinized,
|
||||||
|
you're encouraged to thorougly test the upgrade.
|
||||||
|
|
||||||
|
First breaking change involves renaming `ConnectionError` to `MongoEngineConnectionError`.
|
||||||
|
If you import or catch this exception, you'll need to rename it in your code.
|
||||||
|
|
||||||
|
Second breaking change drops Python v2.6 support. If you run MongoEngine on
|
||||||
|
that Python version, you'll need to upgrade it first.
|
||||||
|
|
||||||
|
Third breaking change drops an old backward compatibility measure where
|
||||||
|
`from mongoengine.base import ErrorClass` would work on top of
|
||||||
|
`from mongoengine.errors import ErrorClass` (where `ErrorClass` is e.g.
|
||||||
|
`ValidationError`). If you import any exceptions from `mongoengine.base`,
|
||||||
|
change it to `mongoengine.errors`.
|
||||||
|
|
||||||
|
0.10.8
|
||||||
|
******
|
||||||
|
This version fixed an issue where specifying a MongoDB URI host would override
|
||||||
|
more information than it should. These changes are minor, but they still
|
||||||
|
subtly modify the connection logic and thus you're encouraged to test your
|
||||||
|
MongoDB connection before shipping v0.10.8 in production.
|
||||||
|
|
||||||
0.10.7
|
0.10.7
|
||||||
******
|
******
|
||||||
|
|
||||||
|
@ -1,25 +1,35 @@
|
|||||||
import connection
|
# Import submodules so that we can expose their __all__
|
||||||
from connection import *
|
from mongoengine import connection
|
||||||
import document
|
from mongoengine import document
|
||||||
from document import *
|
from mongoengine import errors
|
||||||
import errors
|
from mongoengine import fields
|
||||||
from errors import *
|
from mongoengine import queryset
|
||||||
import fields
|
from mongoengine import signals
|
||||||
from fields import *
|
|
||||||
import queryset
|
# Import everything from each submodule so that it can be accessed via
|
||||||
from queryset import *
|
# mongoengine, e.g. instead of `from mongoengine.connection import connect`,
|
||||||
import signals
|
# users can simply use `from mongoengine import connect`, or even
|
||||||
from signals import *
|
# `from mongoengine import *` and then `connect('testdb')`.
|
||||||
|
from mongoengine.connection import *
|
||||||
|
from mongoengine.document import *
|
||||||
|
from mongoengine.errors import *
|
||||||
|
from mongoengine.fields import *
|
||||||
|
from mongoengine.queryset import *
|
||||||
|
from mongoengine.signals import *
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = (list(document.__all__) + list(fields.__all__) +
|
||||||
|
list(connection.__all__) + list(queryset.__all__) +
|
||||||
|
list(signals.__all__) + list(errors.__all__))
|
||||||
|
|
||||||
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
|
|
||||||
list(queryset.__all__) + signals.__all__ + list(errors.__all__))
|
|
||||||
|
|
||||||
VERSION = (0, 10, 9)
|
VERSION = (0, 10, 9)
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version():
|
||||||
if isinstance(VERSION[-1], basestring):
|
"""Return the VERSION as a string, e.g. for VERSION == (0, 10, 7),
|
||||||
return '.'.join(map(str, VERSION[:-1])) + VERSION[-1]
|
return '0.10.7'.
|
||||||
|
"""
|
||||||
return '.'.join(map(str, VERSION))
|
return '.'.join(map(str, VERSION))
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,28 @@
|
|||||||
|
# Base module is split into several files for convenience. Files inside of
|
||||||
|
# this module should import from a specific submodule (e.g.
|
||||||
|
# `from mongoengine.base.document import BaseDocument`), but all of the
|
||||||
|
# other modules should import directly from the top-level module (e.g.
|
||||||
|
# `from mongoengine.base import BaseDocument`). This approach is cleaner and
|
||||||
|
# also helps with cyclical import errors.
|
||||||
from mongoengine.base.common import *
|
from mongoengine.base.common import *
|
||||||
from mongoengine.base.datastructures import *
|
from mongoengine.base.datastructures import *
|
||||||
from mongoengine.base.document import *
|
from mongoengine.base.document import *
|
||||||
from mongoengine.base.fields import *
|
from mongoengine.base.fields import *
|
||||||
from mongoengine.base.metaclasses import *
|
from mongoengine.base.metaclasses import *
|
||||||
|
|
||||||
# Help with backwards compatibility
|
__all__ = (
|
||||||
from mongoengine.errors import *
|
# common
|
||||||
|
'UPDATE_OPERATORS', '_document_registry', 'get_document',
|
||||||
|
|
||||||
|
# datastructures
|
||||||
|
'BaseDict', 'BaseList', 'EmbeddedDocumentList',
|
||||||
|
|
||||||
|
# document
|
||||||
|
'BaseDocument',
|
||||||
|
|
||||||
|
# fields
|
||||||
|
'BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField',
|
||||||
|
|
||||||
|
# metaclasses
|
||||||
|
'DocumentMetaclass', 'TopLevelDocumentMetaclass'
|
||||||
|
)
|
||||||
|
@ -1,13 +1,18 @@
|
|||||||
from mongoengine.errors import NotRegistered
|
from mongoengine.errors import NotRegistered
|
||||||
|
|
||||||
__all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry')
|
__all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry')
|
||||||
|
|
||||||
|
|
||||||
|
UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push',
|
||||||
|
'push_all', 'pull', 'pull_all', 'add_to_set',
|
||||||
|
'set_on_insert', 'min', 'max'])
|
||||||
|
|
||||||
ALLOW_INHERITANCE = False
|
|
||||||
|
|
||||||
_document_registry = {}
|
_document_registry = {}
|
||||||
|
|
||||||
|
|
||||||
def get_document(name):
|
def get_document(name):
|
||||||
|
"""Get a document class by name."""
|
||||||
doc = _document_registry.get(name, None)
|
doc = _document_registry.get(name, None)
|
||||||
if not doc:
|
if not doc:
|
||||||
# Possible old style name
|
# Possible old style name
|
||||||
|
@ -1,14 +1,16 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
|
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
|
||||||
|
|
||||||
__all__ = ("BaseDict", "BaseList", "EmbeddedDocumentList")
|
__all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList')
|
||||||
|
|
||||||
|
|
||||||
class BaseDict(dict):
|
class BaseDict(dict):
|
||||||
"""A special dict so we can watch any changes"""
|
"""A special dict so we can watch any changes."""
|
||||||
|
|
||||||
_dereferenced = False
|
_dereferenced = False
|
||||||
_instance = None
|
_instance = None
|
||||||
@ -93,8 +95,7 @@ class BaseDict(dict):
|
|||||||
|
|
||||||
|
|
||||||
class BaseList(list):
|
class BaseList(list):
|
||||||
"""A special list so we can watch any changes
|
"""A special list so we can watch any changes."""
|
||||||
"""
|
|
||||||
|
|
||||||
_dereferenced = False
|
_dereferenced = False
|
||||||
_instance = None
|
_instance = None
|
||||||
@ -209,17 +210,22 @@ class BaseList(list):
|
|||||||
class EmbeddedDocumentList(BaseList):
|
class EmbeddedDocumentList(BaseList):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __match_all(cls, i, kwargs):
|
def __match_all(cls, embedded_doc, kwargs):
|
||||||
items = kwargs.items()
|
"""Return True if a given embedded doc matches all the filter
|
||||||
return all([
|
kwargs. If it doesn't return False.
|
||||||
getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items
|
"""
|
||||||
])
|
for key, expected_value in kwargs.items():
|
||||||
|
doc_val = getattr(embedded_doc, key)
|
||||||
|
if doc_val != expected_value and six.text_type(doc_val) != expected_value:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __only_matches(cls, obj, kwargs):
|
def __only_matches(cls, embedded_docs, kwargs):
|
||||||
|
"""Return embedded docs that match the filter kwargs."""
|
||||||
if not kwargs:
|
if not kwargs:
|
||||||
return obj
|
return embedded_docs
|
||||||
return filter(lambda i: cls.__match_all(i, kwargs), obj)
|
return [doc for doc in embedded_docs if cls.__match_all(doc, kwargs)]
|
||||||
|
|
||||||
def __init__(self, list_items, instance, name):
|
def __init__(self, list_items, instance, name):
|
||||||
super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
|
super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
|
||||||
@ -285,18 +291,18 @@ class EmbeddedDocumentList(BaseList):
|
|||||||
values = self.__only_matches(self, kwargs)
|
values = self.__only_matches(self, kwargs)
|
||||||
if len(values) == 0:
|
if len(values) == 0:
|
||||||
raise DoesNotExist(
|
raise DoesNotExist(
|
||||||
"%s matching query does not exist." % self._name
|
'%s matching query does not exist.' % self._name
|
||||||
)
|
)
|
||||||
elif len(values) > 1:
|
elif len(values) > 1:
|
||||||
raise MultipleObjectsReturned(
|
raise MultipleObjectsReturned(
|
||||||
"%d items returned, instead of 1" % len(values)
|
'%d items returned, instead of 1' % len(values)
|
||||||
)
|
)
|
||||||
|
|
||||||
return values[0]
|
return values[0]
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
"""
|
"""Return the first embedded document in the list, or ``None``
|
||||||
Returns the first embedded document in the list, or ``None`` if empty.
|
if empty.
|
||||||
"""
|
"""
|
||||||
if len(self) > 0:
|
if len(self) > 0:
|
||||||
return self[0]
|
return self[0]
|
||||||
@ -438,7 +444,7 @@ class StrictDict(object):
|
|||||||
__slots__ = allowed_keys_tuple
|
__slots__ = allowed_keys_tuple
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
|
return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
|
||||||
|
|
||||||
cls._classes[allowed_keys] = SpecificStrictDict
|
cls._classes[allowed_keys] = SpecificStrictDict
|
||||||
return cls._classes[allowed_keys]
|
return cls._classes[allowed_keys]
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
import numbers
|
import numbers
|
||||||
import operator
|
|
||||||
from collections import Hashable
|
from collections import Hashable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -8,30 +7,27 @@ from bson import ObjectId, json_util
|
|||||||
from bson.dbref import DBRef
|
from bson.dbref import DBRef
|
||||||
from bson.son import SON
|
from bson.son import SON
|
||||||
import pymongo
|
import pymongo
|
||||||
|
import six
|
||||||
|
|
||||||
from mongoengine import signals
|
from mongoengine import signals
|
||||||
from mongoengine.base.common import ALLOW_INHERITANCE, get_document
|
from mongoengine.base.common import get_document
|
||||||
from mongoengine.base.datastructures import (
|
from mongoengine.base.datastructures import (BaseDict, BaseList,
|
||||||
BaseDict,
|
EmbeddedDocumentList,
|
||||||
BaseList,
|
SemiStrictDict, StrictDict)
|
||||||
EmbeddedDocumentList,
|
|
||||||
SemiStrictDict,
|
|
||||||
StrictDict
|
|
||||||
)
|
|
||||||
from mongoengine.base.fields import ComplexBaseField
|
from mongoengine.base.fields import ComplexBaseField
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
|
from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
|
||||||
LookUpError, ValidationError)
|
LookUpError, OperationError, ValidationError)
|
||||||
from mongoengine.python_support import PY3, txt_type
|
|
||||||
|
|
||||||
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
|
__all__ = ('BaseDocument',)
|
||||||
|
|
||||||
NON_FIELD_ERRORS = '__all__'
|
NON_FIELD_ERRORS = '__all__'
|
||||||
|
|
||||||
|
|
||||||
class BaseDocument(object):
|
class BaseDocument(object):
|
||||||
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
|
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
|
||||||
'_dynamic_fields', '_auto_id_field', '_db_field_map', '__weakref__')
|
'_dynamic_fields', '_auto_id_field', '_db_field_map',
|
||||||
|
'__weakref__')
|
||||||
|
|
||||||
_dynamic = False
|
_dynamic = False
|
||||||
_dynamic_lock = True
|
_dynamic_lock = True
|
||||||
@ -57,15 +53,15 @@ class BaseDocument(object):
|
|||||||
name = next(field)
|
name = next(field)
|
||||||
if name in values:
|
if name in values:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Multiple values for keyword argument '" + name + "'")
|
'Multiple values for keyword argument "%s"' % name)
|
||||||
values[name] = value
|
values[name] = value
|
||||||
|
|
||||||
__auto_convert = values.pop("__auto_convert", True)
|
__auto_convert = values.pop('__auto_convert', True)
|
||||||
|
|
||||||
# 399: set default values only to fields loaded from DB
|
# 399: set default values only to fields loaded from DB
|
||||||
__only_fields = set(values.pop("__only_fields", values))
|
__only_fields = set(values.pop('__only_fields', values))
|
||||||
|
|
||||||
_created = values.pop("_created", True)
|
_created = values.pop('_created', True)
|
||||||
|
|
||||||
signals.pre_init.send(self.__class__, document=self, values=values)
|
signals.pre_init.send(self.__class__, document=self, values=values)
|
||||||
|
|
||||||
@ -76,7 +72,7 @@ class BaseDocument(object):
|
|||||||
self._fields.keys() + ['id', 'pk', '_cls', '_text_score'])
|
self._fields.keys() + ['id', 'pk', '_cls', '_text_score'])
|
||||||
if _undefined_fields:
|
if _undefined_fields:
|
||||||
msg = (
|
msg = (
|
||||||
"The fields '{0}' do not exist on the document '{1}'"
|
'The fields "{0}" do not exist on the document "{1}"'
|
||||||
).format(_undefined_fields, self._class_name)
|
).format(_undefined_fields, self._class_name)
|
||||||
raise FieldDoesNotExist(msg)
|
raise FieldDoesNotExist(msg)
|
||||||
|
|
||||||
@ -95,7 +91,7 @@ class BaseDocument(object):
|
|||||||
value = getattr(self, key, None)
|
value = getattr(self, key, None)
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
if "_cls" not in values:
|
if '_cls' not in values:
|
||||||
self._cls = self._class_name
|
self._cls = self._class_name
|
||||||
|
|
||||||
# Set passed values after initialisation
|
# Set passed values after initialisation
|
||||||
@ -150,7 +146,7 @@ class BaseDocument(object):
|
|||||||
if self._dynamic and not self._dynamic_lock:
|
if self._dynamic and not self._dynamic_lock:
|
||||||
|
|
||||||
if not hasattr(self, name) and not name.startswith('_'):
|
if not hasattr(self, name) and not name.startswith('_'):
|
||||||
DynamicField = _import_class("DynamicField")
|
DynamicField = _import_class('DynamicField')
|
||||||
field = DynamicField(db_field=name)
|
field = DynamicField(db_field=name)
|
||||||
field.name = name
|
field.name = name
|
||||||
self._dynamic_fields[name] = field
|
self._dynamic_fields[name] = field
|
||||||
@ -169,11 +165,13 @@ class BaseDocument(object):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
self__created = True
|
self__created = True
|
||||||
|
|
||||||
if (self._is_document and not self__created and
|
if (
|
||||||
name in self._meta.get('shard_key', tuple()) and
|
self._is_document and
|
||||||
self._data.get(name) != value):
|
not self__created and
|
||||||
OperationError = _import_class('OperationError')
|
name in self._meta.get('shard_key', tuple()) and
|
||||||
msg = "Shard Keys are immutable. Tried to update %s" % name
|
self._data.get(name) != value
|
||||||
|
):
|
||||||
|
msg = 'Shard Keys are immutable. Tried to update %s' % name
|
||||||
raise OperationError(msg)
|
raise OperationError(msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -197,8 +195,8 @@ class BaseDocument(object):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
def __setstate__(self, data):
|
def __setstate__(self, data):
|
||||||
if isinstance(data["_data"], SON):
|
if isinstance(data['_data'], SON):
|
||||||
data["_data"] = self.__class__._from_son(data["_data"])._data
|
data['_data'] = self.__class__._from_son(data['_data'])._data
|
||||||
for k in ('_changed_fields', '_initialised', '_created', '_data',
|
for k in ('_changed_fields', '_initialised', '_created', '_data',
|
||||||
'_dynamic_fields'):
|
'_dynamic_fields'):
|
||||||
if k in data:
|
if k in data:
|
||||||
@ -212,7 +210,7 @@ class BaseDocument(object):
|
|||||||
|
|
||||||
dynamic_fields = data.get('_dynamic_fields') or SON()
|
dynamic_fields = data.get('_dynamic_fields') or SON()
|
||||||
for k in dynamic_fields.keys():
|
for k in dynamic_fields.keys():
|
||||||
setattr(self, k, data["_data"].get(k))
|
setattr(self, k, data['_data'].get(k))
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self._fields_ordered)
|
return iter(self._fields_ordered)
|
||||||
@ -254,12 +252,13 @@ class BaseDocument(object):
|
|||||||
return repr_type('<%s: %s>' % (self.__class__.__name__, u))
|
return repr_type('<%s: %s>' % (self.__class__.__name__, u))
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
# TODO this could be simpler?
|
||||||
if hasattr(self, '__unicode__'):
|
if hasattr(self, '__unicode__'):
|
||||||
if PY3:
|
if six.PY3:
|
||||||
return self.__unicode__()
|
return self.__unicode__()
|
||||||
else:
|
else:
|
||||||
return unicode(self).encode('utf-8')
|
return six.text_type(self).encode('utf-8')
|
||||||
return txt_type('%s object' % self.__class__.__name__)
|
return six.text_type('%s object' % self.__class__.__name__)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None:
|
if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None:
|
||||||
@ -308,7 +307,7 @@ class BaseDocument(object):
|
|||||||
fields = []
|
fields = []
|
||||||
|
|
||||||
data = SON()
|
data = SON()
|
||||||
data["_id"] = None
|
data['_id'] = None
|
||||||
data['_cls'] = self._class_name
|
data['_cls'] = self._class_name
|
||||||
|
|
||||||
# only root fields ['test1.a', 'test2'] => ['test1', 'test2']
|
# only root fields ['test1.a', 'test2'] => ['test1', 'test2']
|
||||||
@ -351,18 +350,8 @@ class BaseDocument(object):
|
|||||||
else:
|
else:
|
||||||
data[field.name] = value
|
data[field.name] = value
|
||||||
|
|
||||||
# If "_id" has not been set, then try and set it
|
|
||||||
Document = _import_class("Document")
|
|
||||||
if isinstance(self, Document):
|
|
||||||
if data["_id"] is None:
|
|
||||||
data["_id"] = self._data.get("id", None)
|
|
||||||
|
|
||||||
if data['_id'] is None:
|
|
||||||
data.pop('_id')
|
|
||||||
|
|
||||||
# Only add _cls if allow_inheritance is True
|
# Only add _cls if allow_inheritance is True
|
||||||
if (not hasattr(self, '_meta') or
|
if not self._meta.get('allow_inheritance'):
|
||||||
not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)):
|
|
||||||
data.pop('_cls')
|
data.pop('_cls')
|
||||||
|
|
||||||
return data
|
return data
|
||||||
@ -376,16 +365,16 @@ class BaseDocument(object):
|
|||||||
if clean:
|
if clean:
|
||||||
try:
|
try:
|
||||||
self.clean()
|
self.clean()
|
||||||
except ValidationError, error:
|
except ValidationError as error:
|
||||||
errors[NON_FIELD_ERRORS] = error
|
errors[NON_FIELD_ERRORS] = error
|
||||||
|
|
||||||
# Get a list of tuples of field names and their current values
|
# Get a list of tuples of field names and their current values
|
||||||
fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
|
fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
|
||||||
self._data.get(name)) for name in self._fields_ordered]
|
self._data.get(name)) for name in self._fields_ordered]
|
||||||
|
|
||||||
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
|
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
|
||||||
GenericEmbeddedDocumentField = _import_class(
|
GenericEmbeddedDocumentField = _import_class(
|
||||||
"GenericEmbeddedDocumentField")
|
'GenericEmbeddedDocumentField')
|
||||||
|
|
||||||
for field, value in fields:
|
for field, value in fields:
|
||||||
if value is not None:
|
if value is not None:
|
||||||
@ -395,21 +384,21 @@ class BaseDocument(object):
|
|||||||
field._validate(value, clean=clean)
|
field._validate(value, clean=clean)
|
||||||
else:
|
else:
|
||||||
field._validate(value)
|
field._validate(value)
|
||||||
except ValidationError, error:
|
except ValidationError as error:
|
||||||
errors[field.name] = error.errors or error
|
errors[field.name] = error.errors or error
|
||||||
except (ValueError, AttributeError, AssertionError), error:
|
except (ValueError, AttributeError, AssertionError) as error:
|
||||||
errors[field.name] = error
|
errors[field.name] = error
|
||||||
elif field.required and not getattr(field, '_auto_gen', False):
|
elif field.required and not getattr(field, '_auto_gen', False):
|
||||||
errors[field.name] = ValidationError('Field is required',
|
errors[field.name] = ValidationError('Field is required',
|
||||||
field_name=field.name)
|
field_name=field.name)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
pk = "None"
|
pk = 'None'
|
||||||
if hasattr(self, 'pk'):
|
if hasattr(self, 'pk'):
|
||||||
pk = self.pk
|
pk = self.pk
|
||||||
elif self._instance and hasattr(self._instance, 'pk'):
|
elif self._instance and hasattr(self._instance, 'pk'):
|
||||||
pk = self._instance.pk
|
pk = self._instance.pk
|
||||||
message = "ValidationError (%s:%s) " % (self._class_name, pk)
|
message = 'ValidationError (%s:%s) ' % (self._class_name, pk)
|
||||||
raise ValidationError(message, errors=errors)
|
raise ValidationError(message, errors=errors)
|
||||||
|
|
||||||
def to_json(self, *args, **kwargs):
|
def to_json(self, *args, **kwargs):
|
||||||
@ -426,33 +415,26 @@ class BaseDocument(object):
|
|||||||
return cls._from_son(json_util.loads(json_data), created=created)
|
return cls._from_son(json_util.loads(json_data), created=created)
|
||||||
|
|
||||||
def __expand_dynamic_values(self, name, value):
|
def __expand_dynamic_values(self, name, value):
|
||||||
"""expand any dynamic values to their correct types / values"""
|
"""Expand any dynamic values to their correct types / values."""
|
||||||
if not isinstance(value, (dict, list, tuple)):
|
if not isinstance(value, (dict, list, tuple)):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
|
# If the value is a dict with '_cls' in it, turn it into a document
|
||||||
|
is_dict = isinstance(value, dict)
|
||||||
is_list = False
|
if is_dict and '_cls' in value:
|
||||||
if not hasattr(value, 'items'):
|
|
||||||
is_list = True
|
|
||||||
value = dict([(k, v) for k, v in enumerate(value)])
|
|
||||||
|
|
||||||
if not is_list and '_cls' in value:
|
|
||||||
cls = get_document(value['_cls'])
|
cls = get_document(value['_cls'])
|
||||||
return cls(**value)
|
return cls(**value)
|
||||||
|
|
||||||
data = {}
|
if is_dict:
|
||||||
for k, v in value.items():
|
value = {
|
||||||
key = name if is_list else k
|
k: self.__expand_dynamic_values(k, v)
|
||||||
data[k] = self.__expand_dynamic_values(key, v)
|
for k, v in value.items()
|
||||||
|
}
|
||||||
if is_list: # Convert back to a list
|
|
||||||
data_items = sorted(data.items(), key=operator.itemgetter(0))
|
|
||||||
value = [v for k, v in data_items]
|
|
||||||
else:
|
else:
|
||||||
value = data
|
value = [self.__expand_dynamic_values(name, v) for v in value]
|
||||||
|
|
||||||
# Convert lists / values so we can watch for any changes on them
|
# Convert lists / values so we can watch for any changes on them
|
||||||
|
EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
|
||||||
if (isinstance(value, (list, tuple)) and
|
if (isinstance(value, (list, tuple)) and
|
||||||
not isinstance(value, BaseList)):
|
not isinstance(value, BaseList)):
|
||||||
if issubclass(type(self), EmbeddedDocumentListField):
|
if issubclass(type(self), EmbeddedDocumentListField):
|
||||||
@ -465,8 +447,7 @@ class BaseDocument(object):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def _mark_as_changed(self, key):
|
def _mark_as_changed(self, key):
|
||||||
"""Marks a key as explicitly changed by the user
|
"""Mark a key as explicitly changed by the user."""
|
||||||
"""
|
|
||||||
if not key:
|
if not key:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -496,10 +477,11 @@ class BaseDocument(object):
|
|||||||
remove(field)
|
remove(field)
|
||||||
|
|
||||||
def _clear_changed_fields(self):
|
def _clear_changed_fields(self):
|
||||||
"""Using get_changed_fields iterate and remove any fields that are
|
"""Using _get_changed_fields iterate and remove any fields that
|
||||||
marked as changed"""
|
are marked as changed.
|
||||||
|
"""
|
||||||
for changed in self._get_changed_fields():
|
for changed in self._get_changed_fields():
|
||||||
parts = changed.split(".")
|
parts = changed.split('.')
|
||||||
data = self
|
data = self
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if isinstance(data, list):
|
if isinstance(data, list):
|
||||||
@ -511,10 +493,13 @@ class BaseDocument(object):
|
|||||||
data = data.get(part, None)
|
data = data.get(part, None)
|
||||||
else:
|
else:
|
||||||
data = getattr(data, part, None)
|
data = getattr(data, part, None)
|
||||||
if hasattr(data, "_changed_fields"):
|
|
||||||
if hasattr(data, "_is_document") and data._is_document:
|
if hasattr(data, '_changed_fields'):
|
||||||
|
if getattr(data, '_is_document', False):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data._changed_fields = []
|
data._changed_fields = []
|
||||||
|
|
||||||
self._changed_fields = []
|
self._changed_fields = []
|
||||||
|
|
||||||
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
|
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
|
||||||
@ -526,26 +511,27 @@ class BaseDocument(object):
|
|||||||
iterator = data.iteritems()
|
iterator = data.iteritems()
|
||||||
|
|
||||||
for index, value in iterator:
|
for index, value in iterator:
|
||||||
list_key = "%s%s." % (key, index)
|
list_key = '%s%s.' % (key, index)
|
||||||
# don't check anything lower if this key is already marked
|
# don't check anything lower if this key is already marked
|
||||||
# as changed.
|
# as changed.
|
||||||
if list_key[:-1] in changed_fields:
|
if list_key[:-1] in changed_fields:
|
||||||
continue
|
continue
|
||||||
if hasattr(value, '_get_changed_fields'):
|
if hasattr(value, '_get_changed_fields'):
|
||||||
changed = value._get_changed_fields(inspected)
|
changed = value._get_changed_fields(inspected)
|
||||||
changed_fields += ["%s%s" % (list_key, k)
|
changed_fields += ['%s%s' % (list_key, k)
|
||||||
for k in changed if k]
|
for k in changed if k]
|
||||||
elif isinstance(value, (list, tuple, dict)):
|
elif isinstance(value, (list, tuple, dict)):
|
||||||
self._nestable_types_changed_fields(
|
self._nestable_types_changed_fields(
|
||||||
changed_fields, list_key, value, inspected)
|
changed_fields, list_key, value, inspected)
|
||||||
|
|
||||||
def _get_changed_fields(self, inspected=None):
|
def _get_changed_fields(self, inspected=None):
|
||||||
"""Returns a list of all fields that have explicitly been changed.
|
"""Return a list of all fields that have explicitly been changed.
|
||||||
"""
|
"""
|
||||||
EmbeddedDocument = _import_class("EmbeddedDocument")
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
|
DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
|
||||||
ReferenceField = _import_class("ReferenceField")
|
ReferenceField = _import_class('ReferenceField')
|
||||||
SortedListField = _import_class("SortedListField")
|
SortedListField = _import_class('SortedListField')
|
||||||
|
|
||||||
changed_fields = []
|
changed_fields = []
|
||||||
changed_fields += getattr(self, '_changed_fields', [])
|
changed_fields += getattr(self, '_changed_fields', [])
|
||||||
|
|
||||||
@ -572,7 +558,7 @@ class BaseDocument(object):
|
|||||||
):
|
):
|
||||||
# Find all embedded fields that have been changed
|
# Find all embedded fields that have been changed
|
||||||
changed = data._get_changed_fields(inspected)
|
changed = data._get_changed_fields(inspected)
|
||||||
changed_fields += ["%s%s" % (key, k) for k in changed if k]
|
changed_fields += ['%s%s' % (key, k) for k in changed if k]
|
||||||
elif (isinstance(data, (list, tuple, dict)) and
|
elif (isinstance(data, (list, tuple, dict)) and
|
||||||
db_field_name not in changed_fields):
|
db_field_name not in changed_fields):
|
||||||
if (hasattr(field, 'field') and
|
if (hasattr(field, 'field') and
|
||||||
@ -676,21 +662,25 @@ class BaseDocument(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_collection_name(cls):
|
def _get_collection_name(cls):
|
||||||
"""Returns the collection name for this class. None for abstract class
|
"""Return the collection name for this class. None for abstract
|
||||||
|
class.
|
||||||
"""
|
"""
|
||||||
return cls._meta.get('collection', None)
|
return cls._meta.get('collection', None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False):
|
def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False):
|
||||||
"""Create an instance of a Document (subclass) from a PyMongo SON.
|
"""Create an instance of a Document (subclass) from a PyMongo
|
||||||
|
SON.
|
||||||
"""
|
"""
|
||||||
if not only_fields:
|
if not only_fields:
|
||||||
only_fields = []
|
only_fields = []
|
||||||
|
|
||||||
# get the class name from the document, falling back to the given
|
# Get the class name from the document, falling back to the given
|
||||||
# class if unavailable
|
# class if unavailable
|
||||||
class_name = son.get('_cls', cls._class_name)
|
class_name = son.get('_cls', cls._class_name)
|
||||||
data = dict(("%s" % key, value) for key, value in son.iteritems())
|
|
||||||
|
# Convert SON to a dict, making sure each key is a string
|
||||||
|
data = {str(key): value for key, value in son.iteritems()}
|
||||||
|
|
||||||
# Return correct subclass for document type
|
# Return correct subclass for document type
|
||||||
if class_name != cls._class_name:
|
if class_name != cls._class_name:
|
||||||
@ -712,19 +702,20 @@ class BaseDocument(object):
|
|||||||
else field.to_python(value))
|
else field.to_python(value))
|
||||||
if field_name != field.db_field:
|
if field_name != field.db_field:
|
||||||
del data[field.db_field]
|
del data[field.db_field]
|
||||||
except (AttributeError, ValueError), e:
|
except (AttributeError, ValueError) as e:
|
||||||
errors_dict[field_name] = e
|
errors_dict[field_name] = e
|
||||||
|
|
||||||
if errors_dict:
|
if errors_dict:
|
||||||
errors = "\n".join(["%s - %s" % (k, v)
|
errors = '\n'.join(['%s - %s' % (k, v)
|
||||||
for k, v in errors_dict.items()])
|
for k, v in errors_dict.items()])
|
||||||
msg = ("Invalid data to create a `%s` instance.\n%s"
|
msg = ('Invalid data to create a `%s` instance.\n%s'
|
||||||
% (cls._class_name, errors))
|
% (cls._class_name, errors))
|
||||||
raise InvalidDocumentError(msg)
|
raise InvalidDocumentError(msg)
|
||||||
|
|
||||||
|
# In STRICT documents, remove any keys that aren't in cls._fields
|
||||||
if cls.STRICT:
|
if cls.STRICT:
|
||||||
data = dict((k, v)
|
data = {k: v for k, v in data.iteritems() if k in cls._fields}
|
||||||
for k, v in data.iteritems() if k in cls._fields)
|
|
||||||
obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data)
|
obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data)
|
||||||
obj._changed_fields = changed_fields
|
obj._changed_fields = changed_fields
|
||||||
if not _auto_dereference:
|
if not _auto_dereference:
|
||||||
@ -734,37 +725,43 @@ class BaseDocument(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_index_specs(cls, meta_indexes):
|
def _build_index_specs(cls, meta_indexes):
|
||||||
"""Generate and merge the full index specs
|
"""Generate and merge the full index specs."""
|
||||||
"""
|
|
||||||
|
|
||||||
geo_indices = cls._geo_indices()
|
geo_indices = cls._geo_indices()
|
||||||
unique_indices = cls._unique_with_indexes()
|
unique_indices = cls._unique_with_indexes()
|
||||||
index_specs = [cls._build_index_spec(spec)
|
index_specs = [cls._build_index_spec(spec) for spec in meta_indexes]
|
||||||
for spec in meta_indexes]
|
|
||||||
|
|
||||||
def merge_index_specs(index_specs, indices):
|
def merge_index_specs(index_specs, indices):
|
||||||
|
"""Helper method for merging index specs."""
|
||||||
if not indices:
|
if not indices:
|
||||||
return index_specs
|
return index_specs
|
||||||
|
|
||||||
spec_fields = [v['fields']
|
# Create a map of index fields to index spec. We're converting
|
||||||
for k, v in enumerate(index_specs)]
|
# the fields from a list to a tuple so that it's hashable.
|
||||||
# Merge unique_indexes with existing specs
|
spec_fields = {
|
||||||
for k, v in enumerate(indices):
|
tuple(index['fields']): index for index in index_specs
|
||||||
if v['fields'] in spec_fields:
|
}
|
||||||
index_specs[spec_fields.index(v['fields'])].update(v)
|
|
||||||
|
# For each new index, if there's an existing index with the same
|
||||||
|
# fields list, update the existing spec with all data from the
|
||||||
|
# new spec.
|
||||||
|
for new_index in indices:
|
||||||
|
candidate = spec_fields.get(tuple(new_index['fields']))
|
||||||
|
if candidate is None:
|
||||||
|
index_specs.append(new_index)
|
||||||
else:
|
else:
|
||||||
index_specs.append(v)
|
candidate.update(new_index)
|
||||||
|
|
||||||
return index_specs
|
return index_specs
|
||||||
|
|
||||||
|
# Merge geo indexes and unique_with indexes into the meta index specs.
|
||||||
index_specs = merge_index_specs(index_specs, geo_indices)
|
index_specs = merge_index_specs(index_specs, geo_indices)
|
||||||
index_specs = merge_index_specs(index_specs, unique_indices)
|
index_specs = merge_index_specs(index_specs, unique_indices)
|
||||||
return index_specs
|
return index_specs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_index_spec(cls, spec):
|
def _build_index_spec(cls, spec):
|
||||||
"""Build a PyMongo index spec from a MongoEngine index spec.
|
"""Build a PyMongo index spec from a MongoEngine index spec."""
|
||||||
"""
|
if isinstance(spec, six.string_types):
|
||||||
if isinstance(spec, basestring):
|
|
||||||
spec = {'fields': [spec]}
|
spec = {'fields': [spec]}
|
||||||
elif isinstance(spec, (list, tuple)):
|
elif isinstance(spec, (list, tuple)):
|
||||||
spec = {'fields': list(spec)}
|
spec = {'fields': list(spec)}
|
||||||
@ -775,8 +772,7 @@ class BaseDocument(object):
|
|||||||
direction = None
|
direction = None
|
||||||
|
|
||||||
# Check to see if we need to include _cls
|
# Check to see if we need to include _cls
|
||||||
allow_inheritance = cls._meta.get('allow_inheritance',
|
allow_inheritance = cls._meta.get('allow_inheritance')
|
||||||
ALLOW_INHERITANCE)
|
|
||||||
include_cls = (
|
include_cls = (
|
||||||
allow_inheritance and
|
allow_inheritance and
|
||||||
not spec.get('sparse', False) and
|
not spec.get('sparse', False) and
|
||||||
@ -786,7 +782,7 @@ class BaseDocument(object):
|
|||||||
|
|
||||||
# 733: don't include cls if index_cls is False unless there is an explicit cls with the index
|
# 733: don't include cls if index_cls is False unless there is an explicit cls with the index
|
||||||
include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True))
|
include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True))
|
||||||
if "cls" in spec:
|
if 'cls' in spec:
|
||||||
spec.pop('cls')
|
spec.pop('cls')
|
||||||
for key in spec['fields']:
|
for key in spec['fields']:
|
||||||
# If inherited spec continue
|
# If inherited spec continue
|
||||||
@ -801,19 +797,19 @@ class BaseDocument(object):
|
|||||||
# GEOHAYSTACK from )
|
# GEOHAYSTACK from )
|
||||||
# GEO2D from *
|
# GEO2D from *
|
||||||
direction = pymongo.ASCENDING
|
direction = pymongo.ASCENDING
|
||||||
if key.startswith("-"):
|
if key.startswith('-'):
|
||||||
direction = pymongo.DESCENDING
|
direction = pymongo.DESCENDING
|
||||||
elif key.startswith("$"):
|
elif key.startswith('$'):
|
||||||
direction = pymongo.TEXT
|
direction = pymongo.TEXT
|
||||||
elif key.startswith("#"):
|
elif key.startswith('#'):
|
||||||
direction = pymongo.HASHED
|
direction = pymongo.HASHED
|
||||||
elif key.startswith("("):
|
elif key.startswith('('):
|
||||||
direction = pymongo.GEOSPHERE
|
direction = pymongo.GEOSPHERE
|
||||||
elif key.startswith(")"):
|
elif key.startswith(')'):
|
||||||
direction = pymongo.GEOHAYSTACK
|
direction = pymongo.GEOHAYSTACK
|
||||||
elif key.startswith("*"):
|
elif key.startswith('*'):
|
||||||
direction = pymongo.GEO2D
|
direction = pymongo.GEO2D
|
||||||
if key.startswith(("+", "-", "*", "$", "#", "(", ")")):
|
if key.startswith(('+', '-', '*', '$', '#', '(', ')')):
|
||||||
key = key[1:]
|
key = key[1:]
|
||||||
|
|
||||||
# Use real field name, do it manually because we need field
|
# Use real field name, do it manually because we need field
|
||||||
@ -826,7 +822,7 @@ class BaseDocument(object):
|
|||||||
parts = []
|
parts = []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
try:
|
try:
|
||||||
if field != "_id":
|
if field != '_id':
|
||||||
field = field.db_field
|
field = field.db_field
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
@ -845,49 +841,53 @@ class BaseDocument(object):
|
|||||||
return spec
|
return spec
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _unique_with_indexes(cls, namespace=""):
|
def _unique_with_indexes(cls, namespace=''):
|
||||||
"""
|
"""Find unique indexes in the document schema and return them."""
|
||||||
Find and set unique indexes
|
|
||||||
"""
|
|
||||||
unique_indexes = []
|
unique_indexes = []
|
||||||
for field_name, field in cls._fields.items():
|
for field_name, field in cls._fields.items():
|
||||||
sparse = field.sparse
|
sparse = field.sparse
|
||||||
|
|
||||||
# Generate a list of indexes needed by uniqueness constraints
|
# Generate a list of indexes needed by uniqueness constraints
|
||||||
if field.unique:
|
if field.unique:
|
||||||
unique_fields = [field.db_field]
|
unique_fields = [field.db_field]
|
||||||
|
|
||||||
# Add any unique_with fields to the back of the index spec
|
# Add any unique_with fields to the back of the index spec
|
||||||
if field.unique_with:
|
if field.unique_with:
|
||||||
if isinstance(field.unique_with, basestring):
|
if isinstance(field.unique_with, six.string_types):
|
||||||
field.unique_with = [field.unique_with]
|
field.unique_with = [field.unique_with]
|
||||||
|
|
||||||
# Convert unique_with field names to real field names
|
# Convert unique_with field names to real field names
|
||||||
unique_with = []
|
unique_with = []
|
||||||
for other_name in field.unique_with:
|
for other_name in field.unique_with:
|
||||||
parts = other_name.split('.')
|
parts = other_name.split('.')
|
||||||
|
|
||||||
# Lookup real name
|
# Lookup real name
|
||||||
parts = cls._lookup_field(parts)
|
parts = cls._lookup_field(parts)
|
||||||
name_parts = [part.db_field for part in parts]
|
name_parts = [part.db_field for part in parts]
|
||||||
unique_with.append('.'.join(name_parts))
|
unique_with.append('.'.join(name_parts))
|
||||||
|
|
||||||
# Unique field should be required
|
# Unique field should be required
|
||||||
parts[-1].required = True
|
parts[-1].required = True
|
||||||
sparse = (not sparse and
|
sparse = (not sparse and
|
||||||
parts[-1].name not in cls.__dict__)
|
parts[-1].name not in cls.__dict__)
|
||||||
|
|
||||||
unique_fields += unique_with
|
unique_fields += unique_with
|
||||||
|
|
||||||
# Add the new index to the list
|
# Add the new index to the list
|
||||||
fields = [("%s%s" % (namespace, f), pymongo.ASCENDING)
|
fields = [
|
||||||
for f in unique_fields]
|
('%s%s' % (namespace, f), pymongo.ASCENDING)
|
||||||
|
for f in unique_fields
|
||||||
|
]
|
||||||
index = {'fields': fields, 'unique': True, 'sparse': sparse}
|
index = {'fields': fields, 'unique': True, 'sparse': sparse}
|
||||||
unique_indexes.append(index)
|
unique_indexes.append(index)
|
||||||
|
|
||||||
if field.__class__.__name__ == "ListField":
|
if field.__class__.__name__ == 'ListField':
|
||||||
field = field.field
|
field = field.field
|
||||||
|
|
||||||
# Grab any embedded document field unique indexes
|
# Grab any embedded document field unique indexes
|
||||||
if (field.__class__.__name__ == "EmbeddedDocumentField" and
|
if (field.__class__.__name__ == 'EmbeddedDocumentField' and
|
||||||
field.document_type != cls):
|
field.document_type != cls):
|
||||||
field_namespace = "%s." % field_name
|
field_namespace = '%s.' % field_name
|
||||||
doc_cls = field.document_type
|
doc_cls = field.document_type
|
||||||
unique_indexes += doc_cls._unique_with_indexes(field_namespace)
|
unique_indexes += doc_cls._unique_with_indexes(field_namespace)
|
||||||
|
|
||||||
@ -899,8 +899,9 @@ class BaseDocument(object):
|
|||||||
geo_indices = []
|
geo_indices = []
|
||||||
inspected.append(cls)
|
inspected.append(cls)
|
||||||
|
|
||||||
geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField",
|
geo_field_type_names = ('EmbeddedDocumentField', 'GeoPointField',
|
||||||
"PointField", "LineStringField", "PolygonField"]
|
'PointField', 'LineStringField',
|
||||||
|
'PolygonField')
|
||||||
|
|
||||||
geo_field_types = tuple([_import_class(field)
|
geo_field_types = tuple([_import_class(field)
|
||||||
for field in geo_field_type_names])
|
for field in geo_field_type_names])
|
||||||
@ -908,32 +909,68 @@ class BaseDocument(object):
|
|||||||
for field in cls._fields.values():
|
for field in cls._fields.values():
|
||||||
if not isinstance(field, geo_field_types):
|
if not isinstance(field, geo_field_types):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if hasattr(field, 'document_type'):
|
if hasattr(field, 'document_type'):
|
||||||
field_cls = field.document_type
|
field_cls = field.document_type
|
||||||
if field_cls in inspected:
|
if field_cls in inspected:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if hasattr(field_cls, '_geo_indices'):
|
if hasattr(field_cls, '_geo_indices'):
|
||||||
geo_indices += field_cls._geo_indices(
|
geo_indices += field_cls._geo_indices(
|
||||||
inspected, parent_field=field.db_field)
|
inspected, parent_field=field.db_field)
|
||||||
elif field._geo_index:
|
elif field._geo_index:
|
||||||
field_name = field.db_field
|
field_name = field.db_field
|
||||||
if parent_field:
|
if parent_field:
|
||||||
field_name = "%s.%s" % (parent_field, field_name)
|
field_name = '%s.%s' % (parent_field, field_name)
|
||||||
geo_indices.append({'fields':
|
geo_indices.append({
|
||||||
[(field_name, field._geo_index)]})
|
'fields': [(field_name, field._geo_index)]
|
||||||
|
})
|
||||||
|
|
||||||
return geo_indices
|
return geo_indices
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _lookup_field(cls, parts):
|
def _lookup_field(cls, parts):
|
||||||
"""Lookup a field based on its attribute and return a list containing
|
"""Given the path to a given field, return a list containing
|
||||||
the field's parents and the field.
|
the Field object associated with that field and all of its parent
|
||||||
"""
|
Field objects.
|
||||||
|
|
||||||
ListField = _import_class("ListField")
|
Args:
|
||||||
|
parts (str, list, or tuple) - path to the field. Should be a
|
||||||
|
string for simple fields existing on this document or a list
|
||||||
|
of strings for a field that exists deeper in embedded documents.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of Field instances for fields that were found or
|
||||||
|
strings for sub-fields that weren't.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> user._lookup_field('name')
|
||||||
|
[<mongoengine.fields.StringField at 0x1119bff50>]
|
||||||
|
|
||||||
|
>>> user._lookup_field('roles')
|
||||||
|
[<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>]
|
||||||
|
|
||||||
|
>>> user._lookup_field(['roles', 'role'])
|
||||||
|
[<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>,
|
||||||
|
<mongoengine.fields.StringField at 0x1119ec050>]
|
||||||
|
|
||||||
|
>>> user._lookup_field('doesnt_exist')
|
||||||
|
raises LookUpError
|
||||||
|
|
||||||
|
>>> user._lookup_field(['roles', 'doesnt_exist'])
|
||||||
|
[<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>,
|
||||||
|
'doesnt_exist']
|
||||||
|
|
||||||
|
"""
|
||||||
|
# TODO this method is WAY too complicated. Simplify it.
|
||||||
|
# TODO don't think returning a string for embedded non-existent fields is desired
|
||||||
|
|
||||||
|
ListField = _import_class('ListField')
|
||||||
DynamicField = _import_class('DynamicField')
|
DynamicField = _import_class('DynamicField')
|
||||||
|
|
||||||
if not isinstance(parts, (list, tuple)):
|
if not isinstance(parts, (list, tuple)):
|
||||||
parts = [parts]
|
parts = [parts]
|
||||||
|
|
||||||
fields = []
|
fields = []
|
||||||
field = None
|
field = None
|
||||||
|
|
||||||
@ -943,16 +980,17 @@ class BaseDocument(object):
|
|||||||
fields.append(field_name)
|
fields.append(field_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Look up first field from the document
|
||||||
if field is None:
|
if field is None:
|
||||||
# Look up first field from the document
|
|
||||||
if field_name == 'pk':
|
if field_name == 'pk':
|
||||||
# Deal with "primary key" alias
|
# Deal with "primary key" alias
|
||||||
field_name = cls._meta['id_field']
|
field_name = cls._meta['id_field']
|
||||||
|
|
||||||
if field_name in cls._fields:
|
if field_name in cls._fields:
|
||||||
field = cls._fields[field_name]
|
field = cls._fields[field_name]
|
||||||
elif cls._dynamic:
|
elif cls._dynamic:
|
||||||
field = DynamicField(db_field=field_name)
|
field = DynamicField(db_field=field_name)
|
||||||
elif cls._meta.get("allow_inheritance", False) or cls._meta.get("abstract", False):
|
elif cls._meta.get('allow_inheritance') or cls._meta.get('abstract', False):
|
||||||
# 744: in case the field is defined in a subclass
|
# 744: in case the field is defined in a subclass
|
||||||
for subcls in cls.__subclasses__():
|
for subcls in cls.__subclasses__():
|
||||||
try:
|
try:
|
||||||
@ -965,35 +1003,55 @@ class BaseDocument(object):
|
|||||||
else:
|
else:
|
||||||
raise LookUpError('Cannot resolve field "%s"' % field_name)
|
raise LookUpError('Cannot resolve field "%s"' % field_name)
|
||||||
else:
|
else:
|
||||||
raise LookUpError('Cannot resolve field "%s"'
|
raise LookUpError('Cannot resolve field "%s"' % field_name)
|
||||||
% field_name)
|
|
||||||
else:
|
else:
|
||||||
ReferenceField = _import_class('ReferenceField')
|
ReferenceField = _import_class('ReferenceField')
|
||||||
GenericReferenceField = _import_class('GenericReferenceField')
|
GenericReferenceField = _import_class('GenericReferenceField')
|
||||||
|
|
||||||
|
# If previous field was a reference, throw an error (we
|
||||||
|
# cannot look up fields that are on references).
|
||||||
if isinstance(field, (ReferenceField, GenericReferenceField)):
|
if isinstance(field, (ReferenceField, GenericReferenceField)):
|
||||||
raise LookUpError('Cannot perform join in mongoDB: %s' %
|
raise LookUpError('Cannot perform join in mongoDB: %s' %
|
||||||
'__'.join(parts))
|
'__'.join(parts))
|
||||||
|
|
||||||
|
# If the parent field has a "field" attribute which has a
|
||||||
|
# lookup_member method, call it to find the field
|
||||||
|
# corresponding to this iteration.
|
||||||
if hasattr(getattr(field, 'field', None), 'lookup_member'):
|
if hasattr(getattr(field, 'field', None), 'lookup_member'):
|
||||||
new_field = field.field.lookup_member(field_name)
|
new_field = field.field.lookup_member(field_name)
|
||||||
|
|
||||||
|
# If the parent field is a DynamicField or if it's part of
|
||||||
|
# a DynamicDocument, mark current field as a DynamicField
|
||||||
|
# with db_name equal to the field name.
|
||||||
elif cls._dynamic and (isinstance(field, DynamicField) or
|
elif cls._dynamic and (isinstance(field, DynamicField) or
|
||||||
getattr(getattr(field, 'document_type', None), '_dynamic', None)):
|
getattr(getattr(field, 'document_type', None), '_dynamic', None)):
|
||||||
new_field = DynamicField(db_field=field_name)
|
new_field = DynamicField(db_field=field_name)
|
||||||
|
|
||||||
|
# Else, try to use the parent field's lookup_member method
|
||||||
|
# to find the subfield.
|
||||||
|
elif hasattr(field, 'lookup_member'):
|
||||||
|
new_field = field.lookup_member(field_name)
|
||||||
|
|
||||||
|
# Raise a LookUpError if all the other conditions failed.
|
||||||
else:
|
else:
|
||||||
# Look up subfield on the previous field or raise
|
raise LookUpError(
|
||||||
try:
|
'Cannot resolve subfield or operator {} '
|
||||||
new_field = field.lookup_member(field_name)
|
'on the field {}'.format(field_name, field.name)
|
||||||
except AttributeError:
|
)
|
||||||
raise LookUpError('Cannot resolve subfield or operator {} '
|
|
||||||
'on the field {}'.format(
|
# If current field still wasn't found and the parent field
|
||||||
field_name, field.name))
|
# is a ComplexBaseField, add the name current field name and
|
||||||
|
# move on.
|
||||||
if not new_field and isinstance(field, ComplexBaseField):
|
if not new_field and isinstance(field, ComplexBaseField):
|
||||||
fields.append(field_name)
|
fields.append(field_name)
|
||||||
continue
|
continue
|
||||||
elif not new_field:
|
elif not new_field:
|
||||||
raise LookUpError('Cannot resolve field "%s"'
|
raise LookUpError('Cannot resolve field "%s"' % field_name)
|
||||||
% field_name)
|
|
||||||
field = new_field # update field to the new field type
|
field = new_field # update field to the new field type
|
||||||
|
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -4,21 +4,17 @@ import weakref
|
|||||||
|
|
||||||
from bson import DBRef, ObjectId, SON
|
from bson import DBRef, ObjectId, SON
|
||||||
import pymongo
|
import pymongo
|
||||||
|
import six
|
||||||
|
|
||||||
from mongoengine.base.common import ALLOW_INHERITANCE
|
from mongoengine.base.common import UPDATE_OPERATORS
|
||||||
from mongoengine.base.datastructures import (
|
from mongoengine.base.datastructures import (BaseDict, BaseList,
|
||||||
BaseDict, BaseList, EmbeddedDocumentList
|
EmbeddedDocumentList)
|
||||||
)
|
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.errors import ValidationError
|
from mongoengine.errors import ValidationError
|
||||||
|
|
||||||
__all__ = ("BaseField", "ComplexBaseField",
|
|
||||||
"ObjectIdField", "GeoJsonBaseField")
|
|
||||||
|
|
||||||
|
__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
|
||||||
UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push',
|
'GeoJsonBaseField')
|
||||||
'push_all', 'pull', 'pull_all', 'add_to_set',
|
|
||||||
'set_on_insert', 'min', 'max'])
|
|
||||||
|
|
||||||
|
|
||||||
class BaseField(object):
|
class BaseField(object):
|
||||||
@ -73,7 +69,7 @@ class BaseField(object):
|
|||||||
self.db_field = (db_field or name) if not primary_key else '_id'
|
self.db_field = (db_field or name) if not primary_key else '_id'
|
||||||
|
|
||||||
if name:
|
if name:
|
||||||
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
|
msg = 'Field\'s "name" attribute deprecated in favour of "db_field"'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
self.required = required or primary_key
|
self.required = required or primary_key
|
||||||
self.default = default
|
self.default = default
|
||||||
@ -89,7 +85,7 @@ class BaseField(object):
|
|||||||
# Detect and report conflicts between metadata and base properties.
|
# Detect and report conflicts between metadata and base properties.
|
||||||
conflicts = set(dir(self)) & set(kwargs)
|
conflicts = set(dir(self)) & set(kwargs)
|
||||||
if conflicts:
|
if conflicts:
|
||||||
raise TypeError("%s already has attribute(s): %s" % (
|
raise TypeError('%s already has attribute(s): %s' % (
|
||||||
self.__class__.__name__, ', '.join(conflicts)))
|
self.__class__.__name__, ', '.join(conflicts)))
|
||||||
|
|
||||||
# Assign metadata to the instance
|
# Assign metadata to the instance
|
||||||
@ -147,25 +143,21 @@ class BaseField(object):
|
|||||||
v._instance = weakref.proxy(instance)
|
v._instance = weakref.proxy(instance)
|
||||||
instance._data[self.name] = value
|
instance._data[self.name] = value
|
||||||
|
|
||||||
def error(self, message="", errors=None, field_name=None):
|
def error(self, message='', errors=None, field_name=None):
|
||||||
"""Raises a ValidationError.
|
"""Raise a ValidationError."""
|
||||||
"""
|
|
||||||
field_name = field_name if field_name else self.name
|
field_name = field_name if field_name else self.name
|
||||||
raise ValidationError(message, errors=errors, field_name=field_name)
|
raise ValidationError(message, errors=errors, field_name=field_name)
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
"""Convert a MongoDB-compatible type to a Python type.
|
"""Convert a MongoDB-compatible type to a Python type."""
|
||||||
"""
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
"""Convert a Python type to a MongoDB-compatible type.
|
"""Convert a Python type to a MongoDB-compatible type."""
|
||||||
"""
|
|
||||||
return self.to_python(value)
|
return self.to_python(value)
|
||||||
|
|
||||||
def _to_mongo_safe_call(self, value, use_db_field=True, fields=None):
|
def _to_mongo_safe_call(self, value, use_db_field=True, fields=None):
|
||||||
"""A helper method to call to_mongo with proper inputs
|
"""Helper method to call to_mongo with proper inputs."""
|
||||||
"""
|
|
||||||
f_inputs = self.to_mongo.__code__.co_varnames
|
f_inputs = self.to_mongo.__code__.co_varnames
|
||||||
ex_vars = {}
|
ex_vars = {}
|
||||||
if 'fields' in f_inputs:
|
if 'fields' in f_inputs:
|
||||||
@ -177,15 +169,13 @@ class BaseField(object):
|
|||||||
return self.to_mongo(value, **ex_vars)
|
return self.to_mongo(value, **ex_vars)
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
"""Prepare a value that is being used in a query for PyMongo.
|
"""Prepare a value that is being used in a query for PyMongo."""
|
||||||
"""
|
|
||||||
if op in UPDATE_OPERATORS:
|
if op in UPDATE_OPERATORS:
|
||||||
self.validate(value)
|
self.validate(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def validate(self, value, clean=True):
|
def validate(self, value, clean=True):
|
||||||
"""Perform validation on a value.
|
"""Perform validation on a value."""
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _validate_choices(self, value):
|
def _validate_choices(self, value):
|
||||||
@ -200,11 +190,13 @@ class BaseField(object):
|
|||||||
if isinstance(value, (Document, EmbeddedDocument)):
|
if isinstance(value, (Document, EmbeddedDocument)):
|
||||||
if not any(isinstance(value, c) for c in choice_list):
|
if not any(isinstance(value, c) for c in choice_list):
|
||||||
self.error(
|
self.error(
|
||||||
'Value must be instance of %s' % unicode(choice_list)
|
'Value must be an instance of %s' % (
|
||||||
|
six.text_type(choice_list)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# Choices which are types other than Documents
|
# Choices which are types other than Documents
|
||||||
elif value not in choice_list:
|
elif value not in choice_list:
|
||||||
self.error('Value must be one of %s' % unicode(choice_list))
|
self.error('Value must be one of %s' % six.text_type(choice_list))
|
||||||
|
|
||||||
def _validate(self, value, **kwargs):
|
def _validate(self, value, **kwargs):
|
||||||
# Check the Choices Constraint
|
# Check the Choices Constraint
|
||||||
@ -247,8 +239,7 @@ class ComplexBaseField(BaseField):
|
|||||||
field = None
|
field = None
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
"""Descriptor to automatically dereference references.
|
"""Descriptor to automatically dereference references."""
|
||||||
"""
|
|
||||||
if instance is None:
|
if instance is None:
|
||||||
# Document class being used rather than a document object
|
# Document class being used rather than a document object
|
||||||
return self
|
return self
|
||||||
@ -260,7 +251,7 @@ class ComplexBaseField(BaseField):
|
|||||||
(self.field is None or isinstance(self.field,
|
(self.field is None or isinstance(self.field,
|
||||||
(GenericReferenceField, ReferenceField))))
|
(GenericReferenceField, ReferenceField))))
|
||||||
|
|
||||||
_dereference = _import_class("DeReference")()
|
_dereference = _import_class('DeReference')()
|
||||||
|
|
||||||
self._auto_dereference = instance._fields[self.name]._auto_dereference
|
self._auto_dereference = instance._fields[self.name]._auto_dereference
|
||||||
if instance._initialised and dereference and instance._data.get(self.name):
|
if instance._initialised and dereference and instance._data.get(self.name):
|
||||||
@ -295,9 +286,8 @@ class ComplexBaseField(BaseField):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
"""Convert a MongoDB-compatible type to a Python type.
|
"""Convert a MongoDB-compatible type to a Python type."""
|
||||||
"""
|
if isinstance(value, six.string_types):
|
||||||
if isinstance(value, basestring):
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if hasattr(value, 'to_python'):
|
if hasattr(value, 'to_python'):
|
||||||
@ -307,14 +297,14 @@ class ComplexBaseField(BaseField):
|
|||||||
if not hasattr(value, 'items'):
|
if not hasattr(value, 'items'):
|
||||||
try:
|
try:
|
||||||
is_list = True
|
is_list = True
|
||||||
value = dict([(k, v) for k, v in enumerate(value)])
|
value = {k: v for k, v in enumerate(value)}
|
||||||
except TypeError: # Not iterable return the value
|
except TypeError: # Not iterable return the value
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if self.field:
|
if self.field:
|
||||||
self.field._auto_dereference = self._auto_dereference
|
self.field._auto_dereference = self._auto_dereference
|
||||||
value_dict = dict([(key, self.field.to_python(item))
|
value_dict = {key: self.field.to_python(item)
|
||||||
for key, item in value.items()])
|
for key, item in value.items()}
|
||||||
else:
|
else:
|
||||||
Document = _import_class('Document')
|
Document = _import_class('Document')
|
||||||
value_dict = {}
|
value_dict = {}
|
||||||
@ -337,13 +327,12 @@ class ComplexBaseField(BaseField):
|
|||||||
return value_dict
|
return value_dict
|
||||||
|
|
||||||
def to_mongo(self, value, use_db_field=True, fields=None):
|
def to_mongo(self, value, use_db_field=True, fields=None):
|
||||||
"""Convert a Python type to a MongoDB-compatible type.
|
"""Convert a Python type to a MongoDB-compatible type."""
|
||||||
"""
|
Document = _import_class('Document')
|
||||||
Document = _import_class("Document")
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
EmbeddedDocument = _import_class("EmbeddedDocument")
|
GenericReferenceField = _import_class('GenericReferenceField')
|
||||||
GenericReferenceField = _import_class("GenericReferenceField")
|
|
||||||
|
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, six.string_types):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if hasattr(value, 'to_mongo'):
|
if hasattr(value, 'to_mongo'):
|
||||||
@ -360,13 +349,15 @@ class ComplexBaseField(BaseField):
|
|||||||
if not hasattr(value, 'items'):
|
if not hasattr(value, 'items'):
|
||||||
try:
|
try:
|
||||||
is_list = True
|
is_list = True
|
||||||
value = dict([(k, v) for k, v in enumerate(value)])
|
value = {k: v for k, v in enumerate(value)}
|
||||||
except TypeError: # Not iterable return the value
|
except TypeError: # Not iterable return the value
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if self.field:
|
if self.field:
|
||||||
value_dict = dict([(key, self.field._to_mongo_safe_call(item, use_db_field, fields))
|
value_dict = {
|
||||||
for key, item in value.iteritems()])
|
key: self.field._to_mongo_safe_call(item, use_db_field, fields)
|
||||||
|
for key, item in value.iteritems()
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
value_dict = {}
|
value_dict = {}
|
||||||
for k, v in value.iteritems():
|
for k, v in value.iteritems():
|
||||||
@ -380,9 +371,7 @@ class ComplexBaseField(BaseField):
|
|||||||
# any _cls data so make it a generic reference allows
|
# any _cls data so make it a generic reference allows
|
||||||
# us to dereference
|
# us to dereference
|
||||||
meta = getattr(v, '_meta', {})
|
meta = getattr(v, '_meta', {})
|
||||||
allow_inheritance = (
|
allow_inheritance = meta.get('allow_inheritance')
|
||||||
meta.get('allow_inheritance', ALLOW_INHERITANCE)
|
|
||||||
is True)
|
|
||||||
if not allow_inheritance and not self.field:
|
if not allow_inheritance and not self.field:
|
||||||
value_dict[k] = GenericReferenceField().to_mongo(v)
|
value_dict[k] = GenericReferenceField().to_mongo(v)
|
||||||
else:
|
else:
|
||||||
@ -404,8 +393,7 @@ class ComplexBaseField(BaseField):
|
|||||||
return value_dict
|
return value_dict
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
"""If field is provided ensure the value is valid.
|
"""If field is provided ensure the value is valid."""
|
||||||
"""
|
|
||||||
errors = {}
|
errors = {}
|
||||||
if self.field:
|
if self.field:
|
||||||
if hasattr(value, 'iteritems') or hasattr(value, 'items'):
|
if hasattr(value, 'iteritems') or hasattr(value, 'items'):
|
||||||
@ -415,9 +403,9 @@ class ComplexBaseField(BaseField):
|
|||||||
for k, v in sequence:
|
for k, v in sequence:
|
||||||
try:
|
try:
|
||||||
self.field._validate(v)
|
self.field._validate(v)
|
||||||
except ValidationError, error:
|
except ValidationError as error:
|
||||||
errors[k] = error.errors or error
|
errors[k] = error.errors or error
|
||||||
except (ValueError, AssertionError), error:
|
except (ValueError, AssertionError) as error:
|
||||||
errors[k] = error
|
errors[k] = error
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
@ -443,8 +431,7 @@ class ComplexBaseField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class ObjectIdField(BaseField):
|
class ObjectIdField(BaseField):
|
||||||
"""A field wrapper around MongoDB's ObjectIds.
|
"""A field wrapper around MongoDB's ObjectIds."""
|
||||||
"""
|
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
try:
|
try:
|
||||||
@ -457,10 +444,10 @@ class ObjectIdField(BaseField):
|
|||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
if not isinstance(value, ObjectId):
|
if not isinstance(value, ObjectId):
|
||||||
try:
|
try:
|
||||||
return ObjectId(unicode(value))
|
return ObjectId(six.text_type(value))
|
||||||
except Exception, e:
|
except Exception as e:
|
||||||
# e.message attribute has been deprecated since Python 2.6
|
# e.message attribute has been deprecated since Python 2.6
|
||||||
self.error(unicode(e))
|
self.error(six.text_type(e))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
@ -468,7 +455,7 @@ class ObjectIdField(BaseField):
|
|||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
try:
|
try:
|
||||||
ObjectId(unicode(value))
|
ObjectId(six.text_type(value))
|
||||||
except Exception:
|
except Exception:
|
||||||
self.error('Invalid Object ID')
|
self.error('Invalid Object ID')
|
||||||
|
|
||||||
@ -480,21 +467,20 @@ class GeoJsonBaseField(BaseField):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_geo_index = pymongo.GEOSPHERE
|
_geo_index = pymongo.GEOSPHERE
|
||||||
_type = "GeoBase"
|
_type = 'GeoBase'
|
||||||
|
|
||||||
def __init__(self, auto_index=True, *args, **kwargs):
|
def __init__(self, auto_index=True, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param bool auto_index: Automatically create a "2dsphere" index.\
|
:param bool auto_index: Automatically create a '2dsphere' index.\
|
||||||
Defaults to `True`.
|
Defaults to `True`.
|
||||||
"""
|
"""
|
||||||
self._name = "%sField" % self._type
|
self._name = '%sField' % self._type
|
||||||
if not auto_index:
|
if not auto_index:
|
||||||
self._geo_index = False
|
self._geo_index = False
|
||||||
super(GeoJsonBaseField, self).__init__(*args, **kwargs)
|
super(GeoJsonBaseField, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
"""Validate the GeoJson object based on its type
|
"""Validate the GeoJson object based on its type."""
|
||||||
"""
|
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
if set(value.keys()) == set(['type', 'coordinates']):
|
if set(value.keys()) == set(['type', 'coordinates']):
|
||||||
if value['type'] != self._type:
|
if value['type'] != self._type:
|
||||||
@ -509,7 +495,7 @@ class GeoJsonBaseField(BaseField):
|
|||||||
self.error('%s can only accept lists of [x, y]' % self._name)
|
self.error('%s can only accept lists of [x, y]' % self._name)
|
||||||
return
|
return
|
||||||
|
|
||||||
validate = getattr(self, "_validate_%s" % self._type.lower())
|
validate = getattr(self, '_validate_%s' % self._type.lower())
|
||||||
error = validate(value)
|
error = validate(value)
|
||||||
if error:
|
if error:
|
||||||
self.error(error)
|
self.error(error)
|
||||||
@ -522,7 +508,7 @@ class GeoJsonBaseField(BaseField):
|
|||||||
try:
|
try:
|
||||||
value[0][0][0]
|
value[0][0][0]
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
return "Invalid Polygon must contain at least one valid linestring"
|
return 'Invalid Polygon must contain at least one valid linestring'
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
for val in value:
|
for val in value:
|
||||||
@ -533,12 +519,12 @@ class GeoJsonBaseField(BaseField):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
if errors:
|
if errors:
|
||||||
if top_level:
|
if top_level:
|
||||||
return "Invalid Polygon:\n%s" % ", ".join(errors)
|
return 'Invalid Polygon:\n%s' % ', '.join(errors)
|
||||||
else:
|
else:
|
||||||
return "%s" % ", ".join(errors)
|
return '%s' % ', '.join(errors)
|
||||||
|
|
||||||
def _validate_linestring(self, value, top_level=True):
|
def _validate_linestring(self, value, top_level=True):
|
||||||
"""Validates a linestring"""
|
"""Validate a linestring."""
|
||||||
if not isinstance(value, (list, tuple)):
|
if not isinstance(value, (list, tuple)):
|
||||||
return 'LineStrings must contain list of coordinate pairs'
|
return 'LineStrings must contain list of coordinate pairs'
|
||||||
|
|
||||||
@ -546,7 +532,7 @@ class GeoJsonBaseField(BaseField):
|
|||||||
try:
|
try:
|
||||||
value[0][0]
|
value[0][0]
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
return "Invalid LineString must contain at least one valid point"
|
return 'Invalid LineString must contain at least one valid point'
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
for val in value:
|
for val in value:
|
||||||
@ -555,19 +541,19 @@ class GeoJsonBaseField(BaseField):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
if errors:
|
if errors:
|
||||||
if top_level:
|
if top_level:
|
||||||
return "Invalid LineString:\n%s" % ", ".join(errors)
|
return 'Invalid LineString:\n%s' % ', '.join(errors)
|
||||||
else:
|
else:
|
||||||
return "%s" % ", ".join(errors)
|
return '%s' % ', '.join(errors)
|
||||||
|
|
||||||
def _validate_point(self, value):
|
def _validate_point(self, value):
|
||||||
"""Validate each set of coords"""
|
"""Validate each set of coords"""
|
||||||
if not isinstance(value, (list, tuple)):
|
if not isinstance(value, (list, tuple)):
|
||||||
return 'Points must be a list of coordinate pairs'
|
return 'Points must be a list of coordinate pairs'
|
||||||
elif not len(value) == 2:
|
elif not len(value) == 2:
|
||||||
return "Value (%s) must be a two-dimensional point" % repr(value)
|
return 'Value (%s) must be a two-dimensional point' % repr(value)
|
||||||
elif (not isinstance(value[0], (float, int)) or
|
elif (not isinstance(value[0], (float, int)) or
|
||||||
not isinstance(value[1], (float, int))):
|
not isinstance(value[1], (float, int))):
|
||||||
return "Both values (%s) in point must be float or int" % repr(value)
|
return 'Both values (%s) in point must be float or int' % repr(value)
|
||||||
|
|
||||||
def _validate_multipoint(self, value):
|
def _validate_multipoint(self, value):
|
||||||
if not isinstance(value, (list, tuple)):
|
if not isinstance(value, (list, tuple)):
|
||||||
@ -577,7 +563,7 @@ class GeoJsonBaseField(BaseField):
|
|||||||
try:
|
try:
|
||||||
value[0][0]
|
value[0][0]
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
return "Invalid MultiPoint must contain at least one valid point"
|
return 'Invalid MultiPoint must contain at least one valid point'
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
for point in value:
|
for point in value:
|
||||||
@ -586,7 +572,7 @@ class GeoJsonBaseField(BaseField):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return "%s" % ", ".join(errors)
|
return '%s' % ', '.join(errors)
|
||||||
|
|
||||||
def _validate_multilinestring(self, value, top_level=True):
|
def _validate_multilinestring(self, value, top_level=True):
|
||||||
if not isinstance(value, (list, tuple)):
|
if not isinstance(value, (list, tuple)):
|
||||||
@ -596,7 +582,7 @@ class GeoJsonBaseField(BaseField):
|
|||||||
try:
|
try:
|
||||||
value[0][0][0]
|
value[0][0][0]
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
return "Invalid MultiLineString must contain at least one valid linestring"
|
return 'Invalid MultiLineString must contain at least one valid linestring'
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
for linestring in value:
|
for linestring in value:
|
||||||
@ -606,9 +592,9 @@ class GeoJsonBaseField(BaseField):
|
|||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
if top_level:
|
if top_level:
|
||||||
return "Invalid MultiLineString:\n%s" % ", ".join(errors)
|
return 'Invalid MultiLineString:\n%s' % ', '.join(errors)
|
||||||
else:
|
else:
|
||||||
return "%s" % ", ".join(errors)
|
return '%s' % ', '.join(errors)
|
||||||
|
|
||||||
def _validate_multipolygon(self, value):
|
def _validate_multipolygon(self, value):
|
||||||
if not isinstance(value, (list, tuple)):
|
if not isinstance(value, (list, tuple)):
|
||||||
@ -618,7 +604,7 @@ class GeoJsonBaseField(BaseField):
|
|||||||
try:
|
try:
|
||||||
value[0][0][0][0]
|
value[0][0][0][0]
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
return "Invalid MultiPolygon must contain at least one valid Polygon"
|
return 'Invalid MultiPolygon must contain at least one valid Polygon'
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
for polygon in value:
|
for polygon in value:
|
||||||
@ -627,9 +613,9 @@ class GeoJsonBaseField(BaseField):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return "Invalid MultiPolygon:\n%s" % ", ".join(errors)
|
return 'Invalid MultiPolygon:\n%s' % ', '.join(errors)
|
||||||
|
|
||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
return value
|
return value
|
||||||
return SON([("type", self._type), ("coordinates", value)])
|
return SON([('type', self._type), ('coordinates', value)])
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from mongoengine.base.common import ALLOW_INHERITANCE, _document_registry
|
import six
|
||||||
|
|
||||||
|
from mongoengine.base.common import _document_registry
|
||||||
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField
|
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.errors import InvalidDocumentError
|
from mongoengine.errors import InvalidDocumentError
|
||||||
from mongoengine.python_support import PY3
|
|
||||||
from mongoengine.queryset import (DO_NOTHING, DoesNotExist,
|
from mongoengine.queryset import (DO_NOTHING, DoesNotExist,
|
||||||
MultipleObjectsReturned,
|
MultipleObjectsReturned,
|
||||||
QuerySetManager)
|
QuerySetManager)
|
||||||
@ -45,7 +46,8 @@ class DocumentMetaclass(type):
|
|||||||
attrs['_meta'] = meta
|
attrs['_meta'] = meta
|
||||||
attrs['_meta']['abstract'] = False # 789: EmbeddedDocument shouldn't inherit abstract
|
attrs['_meta']['abstract'] = False # 789: EmbeddedDocument shouldn't inherit abstract
|
||||||
|
|
||||||
if attrs['_meta'].get('allow_inheritance', ALLOW_INHERITANCE):
|
# If allow_inheritance is True, add a "_cls" string field to the attrs
|
||||||
|
if attrs['_meta'].get('allow_inheritance'):
|
||||||
StringField = _import_class('StringField')
|
StringField = _import_class('StringField')
|
||||||
attrs['_cls'] = StringField()
|
attrs['_cls'] = StringField()
|
||||||
|
|
||||||
@ -87,16 +89,17 @@ class DocumentMetaclass(type):
|
|||||||
# Ensure no duplicate db_fields
|
# Ensure no duplicate db_fields
|
||||||
duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
|
duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
|
||||||
if duplicate_db_fields:
|
if duplicate_db_fields:
|
||||||
msg = ("Multiple db_fields defined for: %s " %
|
msg = ('Multiple db_fields defined for: %s ' %
|
||||||
", ".join(duplicate_db_fields))
|
', '.join(duplicate_db_fields))
|
||||||
raise InvalidDocumentError(msg)
|
raise InvalidDocumentError(msg)
|
||||||
|
|
||||||
# Set _fields and db_field maps
|
# Set _fields and db_field maps
|
||||||
attrs['_fields'] = doc_fields
|
attrs['_fields'] = doc_fields
|
||||||
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
|
attrs['_db_field_map'] = {k: getattr(v, 'db_field', k)
|
||||||
for k, v in doc_fields.iteritems()])
|
for k, v in doc_fields.items()}
|
||||||
attrs['_reverse_db_field_map'] = dict(
|
attrs['_reverse_db_field_map'] = {
|
||||||
(v, k) for k, v in attrs['_db_field_map'].iteritems())
|
v: k for k, v in attrs['_db_field_map'].items()
|
||||||
|
}
|
||||||
|
|
||||||
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
|
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
|
||||||
(v.creation_counter, v.name)
|
(v.creation_counter, v.name)
|
||||||
@ -116,10 +119,8 @@ class DocumentMetaclass(type):
|
|||||||
if hasattr(base, '_meta'):
|
if hasattr(base, '_meta'):
|
||||||
# Warn if allow_inheritance isn't set and prevent
|
# Warn if allow_inheritance isn't set and prevent
|
||||||
# inheritance of classes where inheritance is set to False
|
# inheritance of classes where inheritance is set to False
|
||||||
allow_inheritance = base._meta.get('allow_inheritance',
|
allow_inheritance = base._meta.get('allow_inheritance')
|
||||||
ALLOW_INHERITANCE)
|
if not allow_inheritance and not base._meta.get('abstract'):
|
||||||
if (allow_inheritance is not True and
|
|
||||||
not base._meta.get('abstract')):
|
|
||||||
raise ValueError('Document %s may not be subclassed' %
|
raise ValueError('Document %s may not be subclassed' %
|
||||||
base.__name__)
|
base.__name__)
|
||||||
|
|
||||||
@ -161,7 +162,7 @@ class DocumentMetaclass(type):
|
|||||||
# module continues to use im_func and im_self, so the code below
|
# module continues to use im_func and im_self, so the code below
|
||||||
# copies __func__ into im_func and __self__ into im_self for
|
# copies __func__ into im_func and __self__ into im_self for
|
||||||
# classmethod objects in Document derived classes.
|
# classmethod objects in Document derived classes.
|
||||||
if PY3:
|
if six.PY3:
|
||||||
for val in new_class.__dict__.values():
|
for val in new_class.__dict__.values():
|
||||||
if isinstance(val, classmethod):
|
if isinstance(val, classmethod):
|
||||||
f = val.__get__(new_class)
|
f = val.__get__(new_class)
|
||||||
@ -179,11 +180,11 @@ class DocumentMetaclass(type):
|
|||||||
if isinstance(f, CachedReferenceField):
|
if isinstance(f, CachedReferenceField):
|
||||||
|
|
||||||
if issubclass(new_class, EmbeddedDocument):
|
if issubclass(new_class, EmbeddedDocument):
|
||||||
raise InvalidDocumentError(
|
raise InvalidDocumentError('CachedReferenceFields is not '
|
||||||
"CachedReferenceFields is not allowed in EmbeddedDocuments")
|
'allowed in EmbeddedDocuments')
|
||||||
if not f.document_type:
|
if not f.document_type:
|
||||||
raise InvalidDocumentError(
|
raise InvalidDocumentError(
|
||||||
"Document is not available to sync")
|
'Document is not available to sync')
|
||||||
|
|
||||||
if f.auto_sync:
|
if f.auto_sync:
|
||||||
f.start_listener()
|
f.start_listener()
|
||||||
@ -195,8 +196,8 @@ class DocumentMetaclass(type):
|
|||||||
'reverse_delete_rule',
|
'reverse_delete_rule',
|
||||||
DO_NOTHING)
|
DO_NOTHING)
|
||||||
if isinstance(f, DictField) and delete_rule != DO_NOTHING:
|
if isinstance(f, DictField) and delete_rule != DO_NOTHING:
|
||||||
msg = ("Reverse delete rules are not supported "
|
msg = ('Reverse delete rules are not supported '
|
||||||
"for %s (field: %s)" %
|
'for %s (field: %s)' %
|
||||||
(field.__class__.__name__, field.name))
|
(field.__class__.__name__, field.name))
|
||||||
raise InvalidDocumentError(msg)
|
raise InvalidDocumentError(msg)
|
||||||
|
|
||||||
@ -204,16 +205,16 @@ class DocumentMetaclass(type):
|
|||||||
|
|
||||||
if delete_rule != DO_NOTHING:
|
if delete_rule != DO_NOTHING:
|
||||||
if issubclass(new_class, EmbeddedDocument):
|
if issubclass(new_class, EmbeddedDocument):
|
||||||
msg = ("Reverse delete rules are not supported for "
|
msg = ('Reverse delete rules are not supported for '
|
||||||
"EmbeddedDocuments (field: %s)" % field.name)
|
'EmbeddedDocuments (field: %s)' % field.name)
|
||||||
raise InvalidDocumentError(msg)
|
raise InvalidDocumentError(msg)
|
||||||
f.document_type.register_delete_rule(new_class,
|
f.document_type.register_delete_rule(new_class,
|
||||||
field.name, delete_rule)
|
field.name, delete_rule)
|
||||||
|
|
||||||
if (field.name and hasattr(Document, field.name) and
|
if (field.name and hasattr(Document, field.name) and
|
||||||
EmbeddedDocument not in new_class.mro()):
|
EmbeddedDocument not in new_class.mro()):
|
||||||
msg = ("%s is a document method and not a valid "
|
msg = ('%s is a document method and not a valid '
|
||||||
"field name" % field.name)
|
'field name' % field.name)
|
||||||
raise InvalidDocumentError(msg)
|
raise InvalidDocumentError(msg)
|
||||||
|
|
||||||
return new_class
|
return new_class
|
||||||
@ -271,6 +272,11 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
'index_drop_dups': False,
|
'index_drop_dups': False,
|
||||||
'index_opts': None,
|
'index_opts': None,
|
||||||
'delete_rules': None,
|
'delete_rules': None,
|
||||||
|
|
||||||
|
# allow_inheritance can be True, False, and None. True means
|
||||||
|
# "allow inheritance", False means "don't allow inheritance",
|
||||||
|
# None means "do whatever your parent does, or don't allow
|
||||||
|
# inheritance if you're a top-level class".
|
||||||
'allow_inheritance': None,
|
'allow_inheritance': None,
|
||||||
}
|
}
|
||||||
attrs['_is_base_cls'] = True
|
attrs['_is_base_cls'] = True
|
||||||
@ -303,7 +309,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
# If parent wasn't an abstract class
|
# If parent wasn't an abstract class
|
||||||
if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and
|
if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and
|
||||||
not parent_doc_cls._meta.get('abstract', True)):
|
not parent_doc_cls._meta.get('abstract', True)):
|
||||||
msg = "Trying to set a collection on a subclass (%s)" % name
|
msg = 'Trying to set a collection on a subclass (%s)' % name
|
||||||
warnings.warn(msg, SyntaxWarning)
|
warnings.warn(msg, SyntaxWarning)
|
||||||
del attrs['_meta']['collection']
|
del attrs['_meta']['collection']
|
||||||
|
|
||||||
@ -311,7 +317,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
|
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
|
||||||
if (parent_doc_cls and
|
if (parent_doc_cls and
|
||||||
not parent_doc_cls._meta.get('abstract', False)):
|
not parent_doc_cls._meta.get('abstract', False)):
|
||||||
msg = "Abstract document cannot have non-abstract base"
|
msg = 'Abstract document cannot have non-abstract base'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return super_new(cls, name, bases, attrs)
|
return super_new(cls, name, bases, attrs)
|
||||||
|
|
||||||
@ -334,12 +340,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
|
|
||||||
meta.merge(attrs.get('_meta', {})) # Top level meta
|
meta.merge(attrs.get('_meta', {})) # Top level meta
|
||||||
|
|
||||||
# Only simple classes (direct subclasses of Document)
|
# Only simple classes (i.e. direct subclasses of Document) may set
|
||||||
# may set allow_inheritance to False
|
# allow_inheritance to False. If the base Document allows inheritance,
|
||||||
|
# none of its subclasses can override allow_inheritance to False.
|
||||||
simple_class = all([b._meta.get('abstract')
|
simple_class = all([b._meta.get('abstract')
|
||||||
for b in flattened_bases if hasattr(b, '_meta')])
|
for b in flattened_bases if hasattr(b, '_meta')])
|
||||||
if (not simple_class and meta['allow_inheritance'] is False and
|
if (
|
||||||
not meta['abstract']):
|
not simple_class and
|
||||||
|
meta['allow_inheritance'] is False and
|
||||||
|
not meta['abstract']
|
||||||
|
):
|
||||||
raise ValueError('Only direct subclasses of Document may set '
|
raise ValueError('Only direct subclasses of Document may set '
|
||||||
'"allow_inheritance" to False')
|
'"allow_inheritance" to False')
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
from pymongo import MongoClient, ReadPreference, uri_parser
|
from pymongo import MongoClient, ReadPreference, uri_parser
|
||||||
from mongoengine.python_support import (IS_PYMONGO_3, str_types)
|
import six
|
||||||
|
|
||||||
__all__ = ['ConnectionError', 'connect', 'register_connection',
|
from mongoengine.python_support import IS_PYMONGO_3
|
||||||
|
|
||||||
|
__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection',
|
||||||
'DEFAULT_CONNECTION_NAME']
|
'DEFAULT_CONNECTION_NAME']
|
||||||
|
|
||||||
|
|
||||||
@ -14,7 +16,10 @@ else:
|
|||||||
READ_PREFERENCE = False
|
READ_PREFERENCE = False
|
||||||
|
|
||||||
|
|
||||||
class ConnectionError(Exception):
|
class MongoEngineConnectionError(Exception):
|
||||||
|
"""Error raised when the database connection can't be established or
|
||||||
|
when a connection with a requested alias can't be retrieved.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -50,8 +55,6 @@ def register_connection(alias, name=None, host=None, port=None,
|
|||||||
|
|
||||||
.. versionchanged:: 0.10.6 - added mongomock support
|
.. versionchanged:: 0.10.6 - added mongomock support
|
||||||
"""
|
"""
|
||||||
global _connection_settings
|
|
||||||
|
|
||||||
conn_settings = {
|
conn_settings = {
|
||||||
'name': name or 'test',
|
'name': name or 'test',
|
||||||
'host': host or 'localhost',
|
'host': host or 'localhost',
|
||||||
@ -66,7 +69,7 @@ def register_connection(alias, name=None, host=None, port=None,
|
|||||||
# Handle uri style connections
|
# Handle uri style connections
|
||||||
conn_host = conn_settings['host']
|
conn_host = conn_settings['host']
|
||||||
# host can be a list or a string, so if string, force to a list
|
# host can be a list or a string, so if string, force to a list
|
||||||
if isinstance(conn_host, str_types):
|
if isinstance(conn_host, six.string_types):
|
||||||
conn_host = [conn_host]
|
conn_host = [conn_host]
|
||||||
|
|
||||||
resolved_hosts = []
|
resolved_hosts = []
|
||||||
@ -111,9 +114,7 @@ def register_connection(alias, name=None, host=None, port=None,
|
|||||||
|
|
||||||
|
|
||||||
def disconnect(alias=DEFAULT_CONNECTION_NAME):
|
def disconnect(alias=DEFAULT_CONNECTION_NAME):
|
||||||
global _connections
|
"""Close the connection with a given alias."""
|
||||||
global _dbs
|
|
||||||
|
|
||||||
if alias in _connections:
|
if alias in _connections:
|
||||||
get_connection(alias=alias).close()
|
get_connection(alias=alias).close()
|
||||||
del _connections[alias]
|
del _connections[alias]
|
||||||
@ -122,71 +123,100 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME):
|
|||||||
|
|
||||||
|
|
||||||
def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||||
global _connections
|
"""Return a connection with a given alias."""
|
||||||
|
|
||||||
# Connect to the database if not already connected
|
# Connect to the database if not already connected
|
||||||
if reconnect:
|
if reconnect:
|
||||||
disconnect(alias)
|
disconnect(alias)
|
||||||
|
|
||||||
if alias not in _connections:
|
# If the requested alias already exists in the _connections list, return
|
||||||
if alias not in _connection_settings:
|
# it immediately.
|
||||||
msg = 'Connection with alias "%s" has not been defined' % alias
|
if alias in _connections:
|
||||||
if alias == DEFAULT_CONNECTION_NAME:
|
return _connections[alias]
|
||||||
msg = 'You have not defined a default connection'
|
|
||||||
raise ConnectionError(msg)
|
|
||||||
conn_settings = _connection_settings[alias].copy()
|
|
||||||
|
|
||||||
conn_settings.pop('name', None)
|
# Validate that the requested alias exists in the _connection_settings.
|
||||||
conn_settings.pop('username', None)
|
# Raise MongoEngineConnectionError if it doesn't.
|
||||||
conn_settings.pop('password', None)
|
if alias not in _connection_settings:
|
||||||
conn_settings.pop('authentication_source', None)
|
if alias == DEFAULT_CONNECTION_NAME:
|
||||||
conn_settings.pop('authentication_mechanism', None)
|
msg = 'You have not defined a default connection'
|
||||||
|
|
||||||
is_mock = conn_settings.pop('is_mock', None)
|
|
||||||
if is_mock:
|
|
||||||
# Use MongoClient from mongomock
|
|
||||||
try:
|
|
||||||
import mongomock
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError('You need mongomock installed '
|
|
||||||
'to mock MongoEngine.')
|
|
||||||
connection_class = mongomock.MongoClient
|
|
||||||
else:
|
else:
|
||||||
# Use MongoClient from pymongo
|
msg = 'Connection with alias "%s" has not been defined' % alias
|
||||||
connection_class = MongoClient
|
raise MongoEngineConnectionError(msg)
|
||||||
|
|
||||||
|
def _clean_settings(settings_dict):
|
||||||
|
irrelevant_fields = set([
|
||||||
|
'name', 'username', 'password', 'authentication_source',
|
||||||
|
'authentication_mechanism'
|
||||||
|
])
|
||||||
|
return {
|
||||||
|
k: v for k, v in settings_dict.items()
|
||||||
|
if k not in irrelevant_fields
|
||||||
|
}
|
||||||
|
|
||||||
|
# Retrieve a copy of the connection settings associated with the requested
|
||||||
|
# alias and remove the database name and authentication info (we don't
|
||||||
|
# care about them at this point).
|
||||||
|
conn_settings = _clean_settings(_connection_settings[alias].copy())
|
||||||
|
|
||||||
|
# Determine if we should use PyMongo's or mongomock's MongoClient.
|
||||||
|
is_mock = conn_settings.pop('is_mock', False)
|
||||||
|
if is_mock:
|
||||||
|
try:
|
||||||
|
import mongomock
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError('You need mongomock installed to mock '
|
||||||
|
'MongoEngine.')
|
||||||
|
connection_class = mongomock.MongoClient
|
||||||
|
else:
|
||||||
|
connection_class = MongoClient
|
||||||
|
|
||||||
|
# Handle replica set connections
|
||||||
if 'replicaSet' in conn_settings:
|
if 'replicaSet' in conn_settings:
|
||||||
|
|
||||||
# Discard port since it can't be used on MongoReplicaSetClient
|
# Discard port since it can't be used on MongoReplicaSetClient
|
||||||
conn_settings.pop('port', None)
|
conn_settings.pop('port', None)
|
||||||
# Discard replicaSet if not base string
|
|
||||||
if not isinstance(conn_settings['replicaSet'], basestring):
|
# Discard replicaSet if it's not a string
|
||||||
conn_settings.pop('replicaSet', None)
|
if not isinstance(conn_settings['replicaSet'], six.string_types):
|
||||||
|
del conn_settings['replicaSet']
|
||||||
|
|
||||||
|
# For replica set connections with PyMongo 2.x, use
|
||||||
|
# MongoReplicaSetClient.
|
||||||
|
# TODO remove this once we stop supporting PyMongo 2.x.
|
||||||
if not IS_PYMONGO_3:
|
if not IS_PYMONGO_3:
|
||||||
connection_class = MongoReplicaSetClient
|
connection_class = MongoReplicaSetClient
|
||||||
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
||||||
|
|
||||||
try:
|
# Iterate over all of the connection settings and if a connection with
|
||||||
connection = None
|
# the same parameters is already established, use it instead of creating
|
||||||
# check for shared connections
|
# a new one.
|
||||||
connection_settings_iterator = (
|
existing_connection = None
|
||||||
(db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
|
connection_settings_iterator = (
|
||||||
for db_alias, connection_settings in connection_settings_iterator:
|
(db_alias, settings.copy())
|
||||||
connection_settings.pop('name', None)
|
for db_alias, settings in _connection_settings.items()
|
||||||
connection_settings.pop('username', None)
|
)
|
||||||
connection_settings.pop('password', None)
|
for db_alias, connection_settings in connection_settings_iterator:
|
||||||
connection_settings.pop('authentication_source', None)
|
connection_settings = _clean_settings(connection_settings)
|
||||||
connection_settings.pop('authentication_mechanism', None)
|
if conn_settings == connection_settings and _connections.get(db_alias):
|
||||||
if conn_settings == connection_settings and _connections.get(db_alias, None):
|
existing_connection = _connections[db_alias]
|
||||||
connection = _connections[db_alias]
|
break
|
||||||
break
|
|
||||||
|
# If an existing connection was found, assign it to the new alias
|
||||||
|
if existing_connection:
|
||||||
|
_connections[alias] = existing_connection
|
||||||
|
else:
|
||||||
|
# Otherwise, create the new connection for this alias. Raise
|
||||||
|
# MongoEngineConnectionError if it can't be established.
|
||||||
|
try:
|
||||||
|
_connections[alias] = connection_class(**conn_settings)
|
||||||
|
except Exception as e:
|
||||||
|
raise MongoEngineConnectionError(
|
||||||
|
'Cannot connect to database %s :\n%s' % (alias, e))
|
||||||
|
|
||||||
_connections[alias] = connection if connection else connection_class(**conn_settings)
|
|
||||||
except Exception, e:
|
|
||||||
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
|
|
||||||
return _connections[alias]
|
return _connections[alias]
|
||||||
|
|
||||||
|
|
||||||
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||||
global _dbs
|
|
||||||
if reconnect:
|
if reconnect:
|
||||||
disconnect(alias)
|
disconnect(alias)
|
||||||
|
|
||||||
@ -217,7 +247,6 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
|
|||||||
|
|
||||||
.. versionchanged:: 0.6 - added multiple database support.
|
.. versionchanged:: 0.6 - added multiple database support.
|
||||||
"""
|
"""
|
||||||
global _connections
|
|
||||||
if alias not in _connections:
|
if alias not in _connections:
|
||||||
register_connection(alias, db, **kwargs)
|
register_connection(alias, db, **kwargs)
|
||||||
|
|
||||||
|
@ -2,12 +2,12 @@ from mongoengine.common import _import_class
|
|||||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||||
|
|
||||||
|
|
||||||
__all__ = ("switch_db", "switch_collection", "no_dereference",
|
__all__ = ('switch_db', 'switch_collection', 'no_dereference',
|
||||||
"no_sub_classes", "query_counter")
|
'no_sub_classes', 'query_counter')
|
||||||
|
|
||||||
|
|
||||||
class switch_db(object):
|
class switch_db(object):
|
||||||
""" switch_db alias context manager.
|
"""switch_db alias context manager.
|
||||||
|
|
||||||
Example ::
|
Example ::
|
||||||
|
|
||||||
@ -18,15 +18,14 @@ class switch_db(object):
|
|||||||
class Group(Document):
|
class Group(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
Group(name="test").save() # Saves in the default db
|
Group(name='test').save() # Saves in the default db
|
||||||
|
|
||||||
with switch_db(Group, 'testdb-1') as Group:
|
with switch_db(Group, 'testdb-1') as Group:
|
||||||
Group(name="hello testdb!").save() # Saves in testdb-1
|
Group(name='hello testdb!').save() # Saves in testdb-1
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cls, db_alias):
|
def __init__(self, cls, db_alias):
|
||||||
""" Construct the switch_db context manager
|
"""Construct the switch_db context manager
|
||||||
|
|
||||||
:param cls: the class to change the registered db
|
:param cls: the class to change the registered db
|
||||||
:param db_alias: the name of the specific database to use
|
:param db_alias: the name of the specific database to use
|
||||||
@ -34,37 +33,36 @@ class switch_db(object):
|
|||||||
self.cls = cls
|
self.cls = cls
|
||||||
self.collection = cls._get_collection()
|
self.collection = cls._get_collection()
|
||||||
self.db_alias = db_alias
|
self.db_alias = db_alias
|
||||||
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
|
self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
""" change the db_alias and clear the cached collection """
|
"""Change the db_alias and clear the cached collection."""
|
||||||
self.cls._meta["db_alias"] = self.db_alias
|
self.cls._meta['db_alias'] = self.db_alias
|
||||||
self.cls._collection = None
|
self.cls._collection = None
|
||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
def __exit__(self, t, value, traceback):
|
def __exit__(self, t, value, traceback):
|
||||||
""" Reset the db_alias and collection """
|
"""Reset the db_alias and collection."""
|
||||||
self.cls._meta["db_alias"] = self.ori_db_alias
|
self.cls._meta['db_alias'] = self.ori_db_alias
|
||||||
self.cls._collection = self.collection
|
self.cls._collection = self.collection
|
||||||
|
|
||||||
|
|
||||||
class switch_collection(object):
|
class switch_collection(object):
|
||||||
""" switch_collection alias context manager.
|
"""switch_collection alias context manager.
|
||||||
|
|
||||||
Example ::
|
Example ::
|
||||||
|
|
||||||
class Group(Document):
|
class Group(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
Group(name="test").save() # Saves in the default db
|
Group(name='test').save() # Saves in the default db
|
||||||
|
|
||||||
with switch_collection(Group, 'group1') as Group:
|
with switch_collection(Group, 'group1') as Group:
|
||||||
Group(name="hello testdb!").save() # Saves in group1 collection
|
Group(name='hello testdb!').save() # Saves in group1 collection
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cls, collection_name):
|
def __init__(self, cls, collection_name):
|
||||||
""" Construct the switch_collection context manager
|
"""Construct the switch_collection context manager.
|
||||||
|
|
||||||
:param cls: the class to change the registered db
|
:param cls: the class to change the registered db
|
||||||
:param collection_name: the name of the collection to use
|
:param collection_name: the name of the collection to use
|
||||||
@ -75,7 +73,7 @@ class switch_collection(object):
|
|||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
""" change the _get_collection_name and clear the cached collection """
|
"""Change the _get_collection_name and clear the cached collection."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_collection_name(cls):
|
def _get_collection_name(cls):
|
||||||
@ -86,24 +84,23 @@ class switch_collection(object):
|
|||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
def __exit__(self, t, value, traceback):
|
def __exit__(self, t, value, traceback):
|
||||||
""" Reset the collection """
|
"""Reset the collection."""
|
||||||
self.cls._collection = self.ori_collection
|
self.cls._collection = self.ori_collection
|
||||||
self.cls._get_collection_name = self.ori_get_collection_name
|
self.cls._get_collection_name = self.ori_get_collection_name
|
||||||
|
|
||||||
|
|
||||||
class no_dereference(object):
|
class no_dereference(object):
|
||||||
""" no_dereference context manager.
|
"""no_dereference context manager.
|
||||||
|
|
||||||
Turns off all dereferencing in Documents for the duration of the context
|
Turns off all dereferencing in Documents for the duration of the context
|
||||||
manager::
|
manager::
|
||||||
|
|
||||||
with no_dereference(Group) as Group:
|
with no_dereference(Group) as Group:
|
||||||
Group.objects.find()
|
Group.objects.find()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cls):
|
def __init__(self, cls):
|
||||||
""" Construct the no_dereference context manager.
|
"""Construct the no_dereference context manager.
|
||||||
|
|
||||||
:param cls: the class to turn dereferencing off on
|
:param cls: the class to turn dereferencing off on
|
||||||
"""
|
"""
|
||||||
@ -119,103 +116,102 @@ class no_dereference(object):
|
|||||||
ComplexBaseField))]
|
ComplexBaseField))]
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
""" change the objects default and _auto_dereference values"""
|
"""Change the objects default and _auto_dereference values."""
|
||||||
for field in self.deref_fields:
|
for field in self.deref_fields:
|
||||||
self.cls._fields[field]._auto_dereference = False
|
self.cls._fields[field]._auto_dereference = False
|
||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
def __exit__(self, t, value, traceback):
|
def __exit__(self, t, value, traceback):
|
||||||
""" Reset the default and _auto_dereference values"""
|
"""Reset the default and _auto_dereference values."""
|
||||||
for field in self.deref_fields:
|
for field in self.deref_fields:
|
||||||
self.cls._fields[field]._auto_dereference = True
|
self.cls._fields[field]._auto_dereference = True
|
||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
|
|
||||||
class no_sub_classes(object):
|
class no_sub_classes(object):
|
||||||
""" no_sub_classes context manager.
|
"""no_sub_classes context manager.
|
||||||
|
|
||||||
Only returns instances of this class and no sub (inherited) classes::
|
Only returns instances of this class and no sub (inherited) classes::
|
||||||
|
|
||||||
with no_sub_classes(Group) as Group:
|
with no_sub_classes(Group) as Group:
|
||||||
Group.objects.find()
|
Group.objects.find()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cls):
|
def __init__(self, cls):
|
||||||
""" Construct the no_sub_classes context manager.
|
"""Construct the no_sub_classes context manager.
|
||||||
|
|
||||||
:param cls: the class to turn querying sub classes on
|
:param cls: the class to turn querying sub classes on
|
||||||
"""
|
"""
|
||||||
self.cls = cls
|
self.cls = cls
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
""" change the objects default and _auto_dereference values"""
|
"""Change the objects default and _auto_dereference values."""
|
||||||
self.cls._all_subclasses = self.cls._subclasses
|
self.cls._all_subclasses = self.cls._subclasses
|
||||||
self.cls._subclasses = (self.cls,)
|
self.cls._subclasses = (self.cls,)
|
||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
def __exit__(self, t, value, traceback):
|
def __exit__(self, t, value, traceback):
|
||||||
""" Reset the default and _auto_dereference values"""
|
"""Reset the default and _auto_dereference values."""
|
||||||
self.cls._subclasses = self.cls._all_subclasses
|
self.cls._subclasses = self.cls._all_subclasses
|
||||||
delattr(self.cls, '_all_subclasses')
|
delattr(self.cls, '_all_subclasses')
|
||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
|
|
||||||
class query_counter(object):
|
class query_counter(object):
|
||||||
""" Query_counter context manager to get the number of queries. """
|
"""Query_counter context manager to get the number of queries."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
""" Construct the query_counter. """
|
"""Construct the query_counter."""
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
self.db = get_db()
|
self.db = get_db()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
""" On every with block we need to drop the profile collection. """
|
"""On every with block we need to drop the profile collection."""
|
||||||
self.db.set_profiling_level(0)
|
self.db.set_profiling_level(0)
|
||||||
self.db.system.profile.drop()
|
self.db.system.profile.drop()
|
||||||
self.db.set_profiling_level(2)
|
self.db.set_profiling_level(2)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, t, value, traceback):
|
def __exit__(self, t, value, traceback):
|
||||||
""" Reset the profiling level. """
|
"""Reset the profiling level."""
|
||||||
self.db.set_profiling_level(0)
|
self.db.set_profiling_level(0)
|
||||||
|
|
||||||
def __eq__(self, value):
|
def __eq__(self, value):
|
||||||
""" == Compare querycounter. """
|
"""== Compare querycounter."""
|
||||||
counter = self._get_count()
|
counter = self._get_count()
|
||||||
return value == counter
|
return value == counter
|
||||||
|
|
||||||
def __ne__(self, value):
|
def __ne__(self, value):
|
||||||
""" != Compare querycounter. """
|
"""!= Compare querycounter."""
|
||||||
return not self.__eq__(value)
|
return not self.__eq__(value)
|
||||||
|
|
||||||
def __lt__(self, value):
|
def __lt__(self, value):
|
||||||
""" < Compare querycounter. """
|
"""< Compare querycounter."""
|
||||||
return self._get_count() < value
|
return self._get_count() < value
|
||||||
|
|
||||||
def __le__(self, value):
|
def __le__(self, value):
|
||||||
""" <= Compare querycounter. """
|
"""<= Compare querycounter."""
|
||||||
return self._get_count() <= value
|
return self._get_count() <= value
|
||||||
|
|
||||||
def __gt__(self, value):
|
def __gt__(self, value):
|
||||||
""" > Compare querycounter. """
|
"""> Compare querycounter."""
|
||||||
return self._get_count() > value
|
return self._get_count() > value
|
||||||
|
|
||||||
def __ge__(self, value):
|
def __ge__(self, value):
|
||||||
""" >= Compare querycounter. """
|
""">= Compare querycounter."""
|
||||||
return self._get_count() >= value
|
return self._get_count() >= value
|
||||||
|
|
||||||
def __int__(self):
|
def __int__(self):
|
||||||
""" int representation. """
|
"""int representation."""
|
||||||
return self._get_count()
|
return self._get_count()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
""" repr query_counter as the number of queries. """
|
"""repr query_counter as the number of queries."""
|
||||||
return u"%s" % self._get_count()
|
return u"%s" % self._get_count()
|
||||||
|
|
||||||
def _get_count(self):
|
def _get_count(self):
|
||||||
""" Get the number of queries. """
|
"""Get the number of queries."""
|
||||||
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
|
ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
|
||||||
count = self.db.system.profile.find(ignore_query).count() - self.counter
|
count = self.db.system.profile.find(ignore_query).count() - self.counter
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
return count
|
return count
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
from bson import DBRef, SON
|
from bson import DBRef, SON
|
||||||
|
import six
|
||||||
|
|
||||||
from .base import (
|
from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
|
||||||
BaseDict, BaseList, EmbeddedDocumentList,
|
TopLevelDocumentMetaclass, get_document)
|
||||||
TopLevelDocumentMetaclass, get_document
|
from mongoengine.connection import get_db
|
||||||
)
|
from mongoengine.document import Document, EmbeddedDocument
|
||||||
from .connection import get_db
|
from mongoengine.fields import DictField, ListField, MapField, ReferenceField
|
||||||
from .document import Document, EmbeddedDocument
|
from mongoengine.queryset import QuerySet
|
||||||
from .fields import DictField, ListField, MapField, ReferenceField
|
|
||||||
from .python_support import txt_type
|
|
||||||
from .queryset import QuerySet
|
|
||||||
|
|
||||||
|
|
||||||
class DeReference(object):
|
class DeReference(object):
|
||||||
@ -25,7 +23,7 @@ class DeReference(object):
|
|||||||
:class:`~mongoengine.base.ComplexBaseField`
|
:class:`~mongoengine.base.ComplexBaseField`
|
||||||
:param get: A boolean determining if being called by __get__
|
:param get: A boolean determining if being called by __get__
|
||||||
"""
|
"""
|
||||||
if items is None or isinstance(items, basestring):
|
if items is None or isinstance(items, six.string_types):
|
||||||
return items
|
return items
|
||||||
|
|
||||||
# cheapest way to convert a queryset to a list
|
# cheapest way to convert a queryset to a list
|
||||||
@ -68,11 +66,11 @@ class DeReference(object):
|
|||||||
|
|
||||||
items = _get_items(items)
|
items = _get_items(items)
|
||||||
else:
|
else:
|
||||||
items = dict([
|
items = {
|
||||||
(k, field.to_python(v))
|
k: (v if isinstance(v, (DBRef, Document))
|
||||||
if not isinstance(v, (DBRef, Document)) else (k, v)
|
else field.to_python(v))
|
||||||
for k, v in items.iteritems()]
|
for k, v in items.iteritems()
|
||||||
)
|
}
|
||||||
|
|
||||||
self.reference_map = self._find_references(items)
|
self.reference_map = self._find_references(items)
|
||||||
self.object_map = self._fetch_objects(doc_type=doc_type)
|
self.object_map = self._fetch_objects(doc_type=doc_type)
|
||||||
@ -90,14 +88,14 @@ class DeReference(object):
|
|||||||
return reference_map
|
return reference_map
|
||||||
|
|
||||||
# Determine the iterator to use
|
# Determine the iterator to use
|
||||||
if not hasattr(items, 'items'):
|
if isinstance(items, dict):
|
||||||
iterator = enumerate(items)
|
iterator = items.values()
|
||||||
else:
|
else:
|
||||||
iterator = items.iteritems()
|
iterator = items
|
||||||
|
|
||||||
# Recursively find dbreferences
|
# Recursively find dbreferences
|
||||||
depth += 1
|
depth += 1
|
||||||
for k, item in iterator:
|
for item in iterator:
|
||||||
if isinstance(item, (Document, EmbeddedDocument)):
|
if isinstance(item, (Document, EmbeddedDocument)):
|
||||||
for field_name, field in item._fields.iteritems():
|
for field_name, field in item._fields.iteritems():
|
||||||
v = item._data.get(field_name, None)
|
v = item._data.get(field_name, None)
|
||||||
@ -151,7 +149,7 @@ class DeReference(object):
|
|||||||
references = get_db()[collection].find({'_id': {'$in': refs}})
|
references = get_db()[collection].find({'_id': {'$in': refs}})
|
||||||
for ref in references:
|
for ref in references:
|
||||||
if '_cls' in ref:
|
if '_cls' in ref:
|
||||||
doc = get_document(ref["_cls"])._from_son(ref)
|
doc = get_document(ref['_cls'])._from_son(ref)
|
||||||
elif doc_type is None:
|
elif doc_type is None:
|
||||||
doc = get_document(
|
doc = get_document(
|
||||||
''.join(x.capitalize()
|
''.join(x.capitalize()
|
||||||
@ -218,7 +216,7 @@ class DeReference(object):
|
|||||||
if k in self.object_map and not is_list:
|
if k in self.object_map and not is_list:
|
||||||
data[k] = self.object_map[k]
|
data[k] = self.object_map[k]
|
||||||
elif isinstance(v, (Document, EmbeddedDocument)):
|
elif isinstance(v, (Document, EmbeddedDocument)):
|
||||||
for field_name, field in v._fields.iteritems():
|
for field_name in v._fields:
|
||||||
v = data[k]._data.get(field_name, None)
|
v = data[k]._data.get(field_name, None)
|
||||||
if isinstance(v, DBRef):
|
if isinstance(v, DBRef):
|
||||||
data[k]._data[field_name] = self.object_map.get(
|
data[k]._data[field_name] = self.object_map.get(
|
||||||
@ -227,7 +225,7 @@ class DeReference(object):
|
|||||||
data[k]._data[field_name] = self.object_map.get(
|
data[k]._data[field_name] = self.object_map.get(
|
||||||
(v['_ref'].collection, v['_ref'].id), v)
|
(v['_ref'].collection, v['_ref'].id), v)
|
||||||
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
|
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
|
||||||
item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name)
|
item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name)
|
||||||
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
|
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
|
||||||
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
|
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
|
||||||
item_name = '%s.%s' % (name, k) if name else name
|
item_name = '%s.%s' % (name, k) if name else name
|
||||||
|
@ -4,18 +4,12 @@ import warnings
|
|||||||
from bson.dbref import DBRef
|
from bson.dbref import DBRef
|
||||||
import pymongo
|
import pymongo
|
||||||
from pymongo.read_preferences import ReadPreference
|
from pymongo.read_preferences import ReadPreference
|
||||||
|
import six
|
||||||
|
|
||||||
from mongoengine import signals
|
from mongoengine import signals
|
||||||
from mongoengine.base import (
|
from mongoengine.base import (BaseDict, BaseDocument, BaseList,
|
||||||
ALLOW_INHERITANCE,
|
DocumentMetaclass, EmbeddedDocumentList,
|
||||||
BaseDict,
|
TopLevelDocumentMetaclass, get_document)
|
||||||
BaseDocument,
|
|
||||||
BaseList,
|
|
||||||
DocumentMetaclass,
|
|
||||||
EmbeddedDocumentList,
|
|
||||||
TopLevelDocumentMetaclass,
|
|
||||||
get_document
|
|
||||||
)
|
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||||
from mongoengine.context_managers import switch_collection, switch_db
|
from mongoengine.context_managers import switch_collection, switch_db
|
||||||
@ -31,12 +25,10 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
|
|||||||
|
|
||||||
|
|
||||||
def includes_cls(fields):
|
def includes_cls(fields):
|
||||||
""" Helper function used for ensuring and comparing indexes
|
"""Helper function used for ensuring and comparing indexes."""
|
||||||
"""
|
|
||||||
|
|
||||||
first_field = None
|
first_field = None
|
||||||
if len(fields):
|
if len(fields):
|
||||||
if isinstance(fields[0], basestring):
|
if isinstance(fields[0], six.string_types):
|
||||||
first_field = fields[0]
|
first_field = fields[0]
|
||||||
elif isinstance(fields[0], (list, tuple)) and len(fields[0]):
|
elif isinstance(fields[0], (list, tuple)) and len(fields[0]):
|
||||||
first_field = fields[0][0]
|
first_field = fields[0][0]
|
||||||
@ -57,9 +49,8 @@ class EmbeddedDocument(BaseDocument):
|
|||||||
to create a specialised version of the embedded document that will be
|
to create a specialised version of the embedded document that will be
|
||||||
stored in the same collection. To facilitate this behaviour a `_cls`
|
stored in the same collection. To facilitate this behaviour a `_cls`
|
||||||
field is added to documents (hidden though the MongoEngine interface).
|
field is added to documents (hidden though the MongoEngine interface).
|
||||||
To disable this behaviour and remove the dependence on the presence of
|
To enable this behaviour set :attr:`allow_inheritance` to ``True`` in the
|
||||||
`_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta`
|
:attr:`meta` dictionary.
|
||||||
dictionary.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ('_instance', )
|
__slots__ = ('_instance', )
|
||||||
@ -82,6 +73,15 @@ class EmbeddedDocument(BaseDocument):
|
|||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def to_mongo(self, *args, **kwargs):
|
||||||
|
data = super(EmbeddedDocument, self).to_mongo(*args, **kwargs)
|
||||||
|
|
||||||
|
# remove _id from the SON if it's in it and it's None
|
||||||
|
if '_id' in data and data['_id'] is None:
|
||||||
|
del data['_id']
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
self._instance.save(*args, **kwargs)
|
self._instance.save(*args, **kwargs)
|
||||||
|
|
||||||
@ -106,9 +106,8 @@ class Document(BaseDocument):
|
|||||||
create a specialised version of the document that will be stored in the
|
create a specialised version of the document that will be stored in the
|
||||||
same collection. To facilitate this behaviour a `_cls`
|
same collection. To facilitate this behaviour a `_cls`
|
||||||
field is added to documents (hidden though the MongoEngine interface).
|
field is added to documents (hidden though the MongoEngine interface).
|
||||||
To disable this behaviour and remove the dependence on the presence of
|
To enable this behaviourset :attr:`allow_inheritance` to ``True`` in the
|
||||||
`_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta`
|
:attr:`meta` dictionary.
|
||||||
dictionary.
|
|
||||||
|
|
||||||
A :class:`~mongoengine.Document` may use a **Capped Collection** by
|
A :class:`~mongoengine.Document` may use a **Capped Collection** by
|
||||||
specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta`
|
specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta`
|
||||||
@ -149,26 +148,22 @@ class Document(BaseDocument):
|
|||||||
|
|
||||||
__slots__ = ('__objects',)
|
__slots__ = ('__objects',)
|
||||||
|
|
||||||
def pk():
|
@property
|
||||||
"""Primary key alias
|
def pk(self):
|
||||||
"""
|
"""Get the primary key."""
|
||||||
|
if 'id_field' not in self._meta:
|
||||||
|
return None
|
||||||
|
return getattr(self, self._meta['id_field'])
|
||||||
|
|
||||||
def fget(self):
|
@pk.setter
|
||||||
if 'id_field' not in self._meta:
|
def pk(self, value):
|
||||||
return None
|
"""Set the primary key."""
|
||||||
return getattr(self, self._meta['id_field'])
|
return setattr(self, self._meta['id_field'], value)
|
||||||
|
|
||||||
def fset(self, value):
|
|
||||||
return setattr(self, self._meta['id_field'], value)
|
|
||||||
|
|
||||||
return property(fget, fset)
|
|
||||||
|
|
||||||
pk = pk()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_db(cls):
|
def _get_db(cls):
|
||||||
"""Some Model using other db_alias"""
|
"""Some Model using other db_alias"""
|
||||||
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
|
return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_collection(cls):
|
def _get_collection(cls):
|
||||||
@ -211,7 +206,20 @@ class Document(BaseDocument):
|
|||||||
cls.ensure_indexes()
|
cls.ensure_indexes()
|
||||||
return cls._collection
|
return cls._collection
|
||||||
|
|
||||||
def modify(self, query={}, **update):
|
def to_mongo(self, *args, **kwargs):
|
||||||
|
data = super(Document, self).to_mongo(*args, **kwargs)
|
||||||
|
|
||||||
|
# If '_id' is None, try and set it from self._data. If that
|
||||||
|
# doesn't exist either, remote '_id' from the SON completely.
|
||||||
|
if data['_id'] is None:
|
||||||
|
if self._data.get('id') is None:
|
||||||
|
del data['_id']
|
||||||
|
else:
|
||||||
|
data['_id'] = self._data['id']
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def modify(self, query=None, **update):
|
||||||
"""Perform an atomic update of the document in the database and reload
|
"""Perform an atomic update of the document in the database and reload
|
||||||
the document object using updated version.
|
the document object using updated version.
|
||||||
|
|
||||||
@ -225,17 +233,19 @@ class Document(BaseDocument):
|
|||||||
database matches the query
|
database matches the query
|
||||||
:param update: Django-style update keyword arguments
|
:param update: Django-style update keyword arguments
|
||||||
"""
|
"""
|
||||||
|
if query is None:
|
||||||
|
query = {}
|
||||||
|
|
||||||
if self.pk is None:
|
if self.pk is None:
|
||||||
raise InvalidDocumentError("The document does not have a primary key.")
|
raise InvalidDocumentError('The document does not have a primary key.')
|
||||||
|
|
||||||
id_field = self._meta["id_field"]
|
id_field = self._meta['id_field']
|
||||||
query = query.copy() if isinstance(query, dict) else query.to_query(self)
|
query = query.copy() if isinstance(query, dict) else query.to_query(self)
|
||||||
|
|
||||||
if id_field not in query:
|
if id_field not in query:
|
||||||
query[id_field] = self.pk
|
query[id_field] = self.pk
|
||||||
elif query[id_field] != self.pk:
|
elif query[id_field] != self.pk:
|
||||||
raise InvalidQueryError("Invalid document modify query: it must modify only this document.")
|
raise InvalidQueryError('Invalid document modify query: it must modify only this document.')
|
||||||
|
|
||||||
updated = self._qs(**query).modify(new=True, **update)
|
updated = self._qs(**query).modify(new=True, **update)
|
||||||
if updated is None:
|
if updated is None:
|
||||||
@ -310,7 +320,7 @@ class Document(BaseDocument):
|
|||||||
self.validate(clean=clean)
|
self.validate(clean=clean)
|
||||||
|
|
||||||
if write_concern is None:
|
if write_concern is None:
|
||||||
write_concern = {"w": 1}
|
write_concern = {'w': 1}
|
||||||
|
|
||||||
doc = self.to_mongo()
|
doc = self.to_mongo()
|
||||||
|
|
||||||
@ -347,7 +357,7 @@ class Document(BaseDocument):
|
|||||||
else:
|
else:
|
||||||
select_dict = {}
|
select_dict = {}
|
||||||
select_dict['_id'] = object_id
|
select_dict['_id'] = object_id
|
||||||
shard_key = self.__class__._meta.get('shard_key', tuple())
|
shard_key = self._meta.get('shard_key', tuple())
|
||||||
for k in shard_key:
|
for k in shard_key:
|
||||||
path = self._lookup_field(k.split('.'))
|
path = self._lookup_field(k.split('.'))
|
||||||
actual_key = [p.db_field for p in path]
|
actual_key = [p.db_field for p in path]
|
||||||
@ -358,7 +368,7 @@ class Document(BaseDocument):
|
|||||||
|
|
||||||
def is_new_object(last_error):
|
def is_new_object(last_error):
|
||||||
if last_error is not None:
|
if last_error is not None:
|
||||||
updated = last_error.get("updatedExisting")
|
updated = last_error.get('updatedExisting')
|
||||||
if updated is not None:
|
if updated is not None:
|
||||||
return not updated
|
return not updated
|
||||||
return created
|
return created
|
||||||
@ -366,14 +376,14 @@ class Document(BaseDocument):
|
|||||||
update_query = {}
|
update_query = {}
|
||||||
|
|
||||||
if updates:
|
if updates:
|
||||||
update_query["$set"] = updates
|
update_query['$set'] = updates
|
||||||
if removals:
|
if removals:
|
||||||
update_query["$unset"] = removals
|
update_query['$unset'] = removals
|
||||||
if updates or removals:
|
if updates or removals:
|
||||||
upsert = save_condition is None
|
upsert = save_condition is None
|
||||||
last_error = collection.update(select_dict, update_query,
|
last_error = collection.update(select_dict, update_query,
|
||||||
upsert=upsert, **write_concern)
|
upsert=upsert, **write_concern)
|
||||||
if not upsert and last_error["n"] == 0:
|
if not upsert and last_error['n'] == 0:
|
||||||
raise SaveConditionError('Race condition preventing'
|
raise SaveConditionError('Race condition preventing'
|
||||||
' document update detected')
|
' document update detected')
|
||||||
created = is_new_object(last_error)
|
created = is_new_object(last_error)
|
||||||
@ -384,26 +394,27 @@ class Document(BaseDocument):
|
|||||||
|
|
||||||
if cascade:
|
if cascade:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"force_insert": force_insert,
|
'force_insert': force_insert,
|
||||||
"validate": validate,
|
'validate': validate,
|
||||||
"write_concern": write_concern,
|
'write_concern': write_concern,
|
||||||
"cascade": cascade
|
'cascade': cascade
|
||||||
}
|
}
|
||||||
if cascade_kwargs: # Allow granular control over cascades
|
if cascade_kwargs: # Allow granular control over cascades
|
||||||
kwargs.update(cascade_kwargs)
|
kwargs.update(cascade_kwargs)
|
||||||
kwargs['_refs'] = _refs
|
kwargs['_refs'] = _refs
|
||||||
self.cascade_save(**kwargs)
|
self.cascade_save(**kwargs)
|
||||||
except pymongo.errors.DuplicateKeyError, err:
|
except pymongo.errors.DuplicateKeyError as err:
|
||||||
message = u'Tried to save duplicate unique keys (%s)'
|
message = u'Tried to save duplicate unique keys (%s)'
|
||||||
raise NotUniqueError(message % unicode(err))
|
raise NotUniqueError(message % six.text_type(err))
|
||||||
except pymongo.errors.OperationFailure, err:
|
except pymongo.errors.OperationFailure as err:
|
||||||
message = 'Could not save document (%s)'
|
message = 'Could not save document (%s)'
|
||||||
if re.match('^E1100[01] duplicate key', unicode(err)):
|
if re.match('^E1100[01] duplicate key', six.text_type(err)):
|
||||||
# E11000 - duplicate key error index
|
# E11000 - duplicate key error index
|
||||||
# E11001 - duplicate key on update
|
# E11001 - duplicate key on update
|
||||||
message = u'Tried to save duplicate unique keys (%s)'
|
message = u'Tried to save duplicate unique keys (%s)'
|
||||||
raise NotUniqueError(message % unicode(err))
|
raise NotUniqueError(message % six.text_type(err))
|
||||||
raise OperationError(message % unicode(err))
|
raise OperationError(message % six.text_type(err))
|
||||||
|
|
||||||
id_field = self._meta['id_field']
|
id_field = self._meta['id_field']
|
||||||
if created or id_field not in self._meta.get('shard_key', []):
|
if created or id_field not in self._meta.get('shard_key', []):
|
||||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||||
@ -414,10 +425,11 @@ class Document(BaseDocument):
|
|||||||
self._created = False
|
self._created = False
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def cascade_save(self, *args, **kwargs):
|
def cascade_save(self, **kwargs):
|
||||||
"""Recursively saves any references /
|
"""Recursively save any references and generic references on the
|
||||||
generic references on the document"""
|
document.
|
||||||
_refs = kwargs.get('_refs', []) or []
|
"""
|
||||||
|
_refs = kwargs.get('_refs') or []
|
||||||
|
|
||||||
ReferenceField = _import_class('ReferenceField')
|
ReferenceField = _import_class('ReferenceField')
|
||||||
GenericReferenceField = _import_class('GenericReferenceField')
|
GenericReferenceField = _import_class('GenericReferenceField')
|
||||||
@ -443,16 +455,17 @@ class Document(BaseDocument):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _qs(self):
|
def _qs(self):
|
||||||
"""
|
"""Return the queryset to use for updating / reloading / deletions."""
|
||||||
Returns the queryset to use for updating / reloading / deletions
|
|
||||||
"""
|
|
||||||
if not hasattr(self, '__objects'):
|
if not hasattr(self, '__objects'):
|
||||||
self.__objects = QuerySet(self, self._get_collection())
|
self.__objects = QuerySet(self, self._get_collection())
|
||||||
return self.__objects
|
return self.__objects
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _object_key(self):
|
def _object_key(self):
|
||||||
"""Dict to identify object in collection
|
"""Get the query dict that can be used to fetch this object from
|
||||||
|
the database. Most of the time it's a simple PK lookup, but in
|
||||||
|
case of a sharded collection with a compound shard key, it can
|
||||||
|
contain a more complex query.
|
||||||
"""
|
"""
|
||||||
select_dict = {'pk': self.pk}
|
select_dict = {'pk': self.pk}
|
||||||
shard_key = self.__class__._meta.get('shard_key', tuple())
|
shard_key = self.__class__._meta.get('shard_key', tuple())
|
||||||
@ -475,8 +488,8 @@ class Document(BaseDocument):
|
|||||||
if self.pk is None:
|
if self.pk is None:
|
||||||
if kwargs.get('upsert', False):
|
if kwargs.get('upsert', False):
|
||||||
query = self.to_mongo()
|
query = self.to_mongo()
|
||||||
if "_cls" in query:
|
if '_cls' in query:
|
||||||
del query["_cls"]
|
del query['_cls']
|
||||||
return self._qs.filter(**query).update_one(**kwargs)
|
return self._qs.filter(**query).update_one(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise OperationError(
|
raise OperationError(
|
||||||
@ -513,7 +526,7 @@ class Document(BaseDocument):
|
|||||||
try:
|
try:
|
||||||
self._qs.filter(
|
self._qs.filter(
|
||||||
**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True)
|
**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True)
|
||||||
except pymongo.errors.OperationFailure, err:
|
except pymongo.errors.OperationFailure as err:
|
||||||
message = u'Could not delete document (%s)' % err.message
|
message = u'Could not delete document (%s)' % err.message
|
||||||
raise OperationError(message)
|
raise OperationError(message)
|
||||||
signals.post_delete.send(self.__class__, document=self, **signal_kwargs)
|
signals.post_delete.send(self.__class__, document=self, **signal_kwargs)
|
||||||
@ -601,11 +614,12 @@ class Document(BaseDocument):
|
|||||||
if fields and isinstance(fields[0], int):
|
if fields and isinstance(fields[0], int):
|
||||||
max_depth = fields[0]
|
max_depth = fields[0]
|
||||||
fields = fields[1:]
|
fields = fields[1:]
|
||||||
elif "max_depth" in kwargs:
|
elif 'max_depth' in kwargs:
|
||||||
max_depth = kwargs["max_depth"]
|
max_depth = kwargs['max_depth']
|
||||||
|
|
||||||
if self.pk is None:
|
if self.pk is None:
|
||||||
raise self.DoesNotExist("Document does not exist")
|
raise self.DoesNotExist('Document does not exist')
|
||||||
|
|
||||||
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
|
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
|
||||||
**self._object_key).only(*fields).limit(
|
**self._object_key).only(*fields).limit(
|
||||||
1).select_related(max_depth=max_depth)
|
1).select_related(max_depth=max_depth)
|
||||||
@ -613,7 +627,7 @@ class Document(BaseDocument):
|
|||||||
if obj:
|
if obj:
|
||||||
obj = obj[0]
|
obj = obj[0]
|
||||||
else:
|
else:
|
||||||
raise self.DoesNotExist("Document does not exist")
|
raise self.DoesNotExist('Document does not exist')
|
||||||
|
|
||||||
for field in obj._data:
|
for field in obj._data:
|
||||||
if not fields or field in fields:
|
if not fields or field in fields:
|
||||||
@ -656,7 +670,7 @@ class Document(BaseDocument):
|
|||||||
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
|
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
|
||||||
`__raw__` queries."""
|
`__raw__` queries."""
|
||||||
if self.pk is None:
|
if self.pk is None:
|
||||||
msg = "Only saved documents can have a valid dbref"
|
msg = 'Only saved documents can have a valid dbref'
|
||||||
raise OperationError(msg)
|
raise OperationError(msg)
|
||||||
return DBRef(self.__class__._get_collection_name(), self.pk)
|
return DBRef(self.__class__._get_collection_name(), self.pk)
|
||||||
|
|
||||||
@ -711,7 +725,7 @@ class Document(BaseDocument):
|
|||||||
fields = index_spec.pop('fields')
|
fields = index_spec.pop('fields')
|
||||||
drop_dups = kwargs.get('drop_dups', False)
|
drop_dups = kwargs.get('drop_dups', False)
|
||||||
if IS_PYMONGO_3 and drop_dups:
|
if IS_PYMONGO_3 and drop_dups:
|
||||||
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
|
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
elif not IS_PYMONGO_3:
|
elif not IS_PYMONGO_3:
|
||||||
index_spec['drop_dups'] = drop_dups
|
index_spec['drop_dups'] = drop_dups
|
||||||
@ -737,7 +751,7 @@ class Document(BaseDocument):
|
|||||||
will be removed if PyMongo3+ is used
|
will be removed if PyMongo3+ is used
|
||||||
"""
|
"""
|
||||||
if IS_PYMONGO_3 and drop_dups:
|
if IS_PYMONGO_3 and drop_dups:
|
||||||
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
|
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
elif not IS_PYMONGO_3:
|
elif not IS_PYMONGO_3:
|
||||||
kwargs.update({'drop_dups': drop_dups})
|
kwargs.update({'drop_dups': drop_dups})
|
||||||
@ -757,7 +771,7 @@ class Document(BaseDocument):
|
|||||||
index_opts = cls._meta.get('index_opts') or {}
|
index_opts = cls._meta.get('index_opts') or {}
|
||||||
index_cls = cls._meta.get('index_cls', True)
|
index_cls = cls._meta.get('index_cls', True)
|
||||||
if IS_PYMONGO_3 and drop_dups:
|
if IS_PYMONGO_3 and drop_dups:
|
||||||
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
|
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
|
|
||||||
collection = cls._get_collection()
|
collection = cls._get_collection()
|
||||||
@ -795,8 +809,7 @@ class Document(BaseDocument):
|
|||||||
|
|
||||||
# If _cls is being used (for polymorphism), it needs an index,
|
# If _cls is being used (for polymorphism), it needs an index,
|
||||||
# only if another index doesn't begin with _cls
|
# only if another index doesn't begin with _cls
|
||||||
if (index_cls and not cls_indexed and
|
if index_cls and not cls_indexed and cls._meta.get('allow_inheritance'):
|
||||||
cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True):
|
|
||||||
|
|
||||||
# we shouldn't pass 'cls' to the collection.ensureIndex options
|
# we shouldn't pass 'cls' to the collection.ensureIndex options
|
||||||
# because of https://jira.mongodb.org/browse/SERVER-769
|
# because of https://jira.mongodb.org/browse/SERVER-769
|
||||||
@ -866,16 +879,15 @@ class Document(BaseDocument):
|
|||||||
# finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed
|
# finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed
|
||||||
if [(u'_id', 1)] not in indexes:
|
if [(u'_id', 1)] not in indexes:
|
||||||
indexes.append([(u'_id', 1)])
|
indexes.append([(u'_id', 1)])
|
||||||
if (cls._meta.get('index_cls', True) and
|
if cls._meta.get('index_cls', True) and cls._meta.get('allow_inheritance'):
|
||||||
cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True):
|
|
||||||
indexes.append([(u'_cls', 1)])
|
indexes.append([(u'_cls', 1)])
|
||||||
|
|
||||||
return indexes
|
return indexes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compare_indexes(cls):
|
def compare_indexes(cls):
|
||||||
""" Compares the indexes defined in MongoEngine with the ones existing
|
""" Compares the indexes defined in MongoEngine with the ones
|
||||||
in the database. Returns any missing/extra indexes.
|
existing in the database. Returns any missing/extra indexes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
required = cls.list_indexes()
|
required = cls.list_indexes()
|
||||||
@ -919,8 +931,9 @@ class DynamicDocument(Document):
|
|||||||
_dynamic = True
|
_dynamic = True
|
||||||
|
|
||||||
def __delattr__(self, *args, **kwargs):
|
def __delattr__(self, *args, **kwargs):
|
||||||
"""Deletes the attribute by setting to None and allowing _delta to unset
|
"""Delete the attribute by setting to None and allowing _delta
|
||||||
it"""
|
to unset it.
|
||||||
|
"""
|
||||||
field_name = args[0]
|
field_name = args[0]
|
||||||
if field_name in self._dynamic_fields:
|
if field_name in self._dynamic_fields:
|
||||||
setattr(self, field_name, None)
|
setattr(self, field_name, None)
|
||||||
@ -942,8 +955,9 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
|
|||||||
_dynamic = True
|
_dynamic = True
|
||||||
|
|
||||||
def __delattr__(self, *args, **kwargs):
|
def __delattr__(self, *args, **kwargs):
|
||||||
"""Deletes the attribute by setting to None and allowing _delta to unset
|
"""Delete the attribute by setting to None and allowing _delta
|
||||||
it"""
|
to unset it.
|
||||||
|
"""
|
||||||
field_name = args[0]
|
field_name = args[0]
|
||||||
if field_name in self._fields:
|
if field_name in self._fields:
|
||||||
default = self._fields[field_name].default
|
default = self._fields[field_name].default
|
||||||
@ -985,10 +999,10 @@ class MapReduceDocument(object):
|
|||||||
try:
|
try:
|
||||||
self.key = id_field_type(self.key)
|
self.key = id_field_type(self.key)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception("Could not cast key as %s" %
|
raise Exception('Could not cast key as %s' %
|
||||||
id_field_type.__name__)
|
id_field_type.__name__)
|
||||||
|
|
||||||
if not hasattr(self, "_key_object"):
|
if not hasattr(self, '_key_object'):
|
||||||
self._key_object = self._document.objects.with_id(self.key)
|
self._key_object = self._document.objects.with_id(self.key)
|
||||||
return self._key_object
|
return self._key_object
|
||||||
return self._key_object
|
return self._key_object
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from mongoengine.python_support import txt_type
|
import six
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
|
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
|
||||||
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
|
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
|
||||||
@ -71,13 +70,13 @@ class ValidationError(AssertionError):
|
|||||||
field_name = None
|
field_name = None
|
||||||
_message = None
|
_message = None
|
||||||
|
|
||||||
def __init__(self, message="", **kwargs):
|
def __init__(self, message='', **kwargs):
|
||||||
self.errors = kwargs.get('errors', {})
|
self.errors = kwargs.get('errors', {})
|
||||||
self.field_name = kwargs.get('field_name')
|
self.field_name = kwargs.get('field_name')
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return txt_type(self.message)
|
return six.text_type(self.message)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '%s(%s,)' % (self.__class__.__name__, self.message)
|
return '%s(%s,)' % (self.__class__.__name__, self.message)
|
||||||
@ -111,17 +110,20 @@ class ValidationError(AssertionError):
|
|||||||
errors_dict = {}
|
errors_dict = {}
|
||||||
if not source:
|
if not source:
|
||||||
return errors_dict
|
return errors_dict
|
||||||
|
|
||||||
if isinstance(source, dict):
|
if isinstance(source, dict):
|
||||||
for field_name, error in source.iteritems():
|
for field_name, error in source.iteritems():
|
||||||
errors_dict[field_name] = build_dict(error)
|
errors_dict[field_name] = build_dict(error)
|
||||||
elif isinstance(source, ValidationError) and source.errors:
|
elif isinstance(source, ValidationError) and source.errors:
|
||||||
return build_dict(source.errors)
|
return build_dict(source.errors)
|
||||||
else:
|
else:
|
||||||
return unicode(source)
|
return six.text_type(source)
|
||||||
|
|
||||||
return errors_dict
|
return errors_dict
|
||||||
|
|
||||||
if not self.errors:
|
if not self.errors:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
return build_dict(self.errors)
|
return build_dict(self.errors)
|
||||||
|
|
||||||
def _format_errors(self):
|
def _format_errors(self):
|
||||||
@ -134,10 +136,10 @@ class ValidationError(AssertionError):
|
|||||||
value = ' '.join(
|
value = ' '.join(
|
||||||
[generate_key(v, k) for k, v in value.iteritems()])
|
[generate_key(v, k) for k, v in value.iteritems()])
|
||||||
|
|
||||||
results = "%s.%s" % (prefix, value) if prefix else value
|
results = '%s.%s' % (prefix, value) if prefix else value
|
||||||
return results
|
return results
|
||||||
|
|
||||||
error_dict = defaultdict(list)
|
error_dict = defaultdict(list)
|
||||||
for k, v in self.to_dict().iteritems():
|
for k, v in self.to_dict().iteritems():
|
||||||
error_dict[generate_key(v)].append(k)
|
error_dict[generate_key(v)].append(k)
|
||||||
return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()])
|
return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()])
|
||||||
|
@ -3,7 +3,6 @@ import decimal
|
|||||||
import itertools
|
import itertools
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import urllib2
|
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -25,13 +24,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
Int64 = long
|
Int64 = long
|
||||||
|
|
||||||
from .base import (BaseDocument, BaseField, ComplexBaseField, GeoJsonBaseField,
|
from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField,
|
||||||
ObjectIdField, get_document)
|
GeoJsonBaseField, ObjectIdField, get_document)
|
||||||
from .connection import DEFAULT_CONNECTION_NAME, get_db
|
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||||
from .document import Document, EmbeddedDocument
|
from mongoengine.document import Document, EmbeddedDocument
|
||||||
from .errors import DoesNotExist, ValidationError
|
from mongoengine.errors import DoesNotExist, ValidationError
|
||||||
from .python_support import PY3, StringIO, bin_type, str_types, txt_type
|
from mongoengine.python_support import StringIO
|
||||||
from .queryset import DO_NOTHING, QuerySet
|
from mongoengine.queryset import DO_NOTHING, QuerySet
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
@ -39,7 +38,7 @@ except ImportError:
|
|||||||
Image = None
|
Image = None
|
||||||
ImageOps = None
|
ImageOps = None
|
||||||
|
|
||||||
__all__ = [
|
__all__ = (
|
||||||
'StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
|
'StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
|
||||||
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField',
|
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField',
|
||||||
'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
|
'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
|
||||||
@ -50,14 +49,14 @@ __all__ = [
|
|||||||
'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField',
|
'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField',
|
||||||
'GeoPointField', 'PointField', 'LineStringField', 'PolygonField',
|
'GeoPointField', 'PointField', 'LineStringField', 'PolygonField',
|
||||||
'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField',
|
'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField',
|
||||||
'MultiPolygonField', 'GeoJsonBaseField']
|
'MultiPolygonField', 'GeoJsonBaseField'
|
||||||
|
)
|
||||||
|
|
||||||
RECURSIVE_REFERENCE_CONSTANT = 'self'
|
RECURSIVE_REFERENCE_CONSTANT = 'self'
|
||||||
|
|
||||||
|
|
||||||
class StringField(BaseField):
|
class StringField(BaseField):
|
||||||
"""A unicode string field.
|
"""A unicode string field."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, regex=None, max_length=None, min_length=None, **kwargs):
|
def __init__(self, regex=None, max_length=None, min_length=None, **kwargs):
|
||||||
self.regex = re.compile(regex) if regex else None
|
self.regex = re.compile(regex) if regex else None
|
||||||
@ -66,7 +65,7 @@ class StringField(BaseField):
|
|||||||
super(StringField, self).__init__(**kwargs)
|
super(StringField, self).__init__(**kwargs)
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
if isinstance(value, unicode):
|
if isinstance(value, six.text_type):
|
||||||
return value
|
return value
|
||||||
try:
|
try:
|
||||||
value = value.decode('utf-8')
|
value = value.decode('utf-8')
|
||||||
@ -75,7 +74,7 @@ class StringField(BaseField):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
if not isinstance(value, basestring):
|
if not isinstance(value, six.string_types):
|
||||||
self.error('StringField only accepts string values')
|
self.error('StringField only accepts string values')
|
||||||
|
|
||||||
if self.max_length is not None and len(value) > self.max_length:
|
if self.max_length is not None and len(value) > self.max_length:
|
||||||
@ -91,7 +90,7 @@ class StringField(BaseField):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
if not isinstance(op, basestring):
|
if not isinstance(op, six.string_types):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'):
|
if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'):
|
||||||
@ -148,17 +147,6 @@ class URLField(StringField):
|
|||||||
self.error('Invalid URL: {}'.format(value))
|
self.error('Invalid URL: {}'.format(value))
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.verify_exists:
|
|
||||||
warnings.warn(
|
|
||||||
"The URLField verify_exists argument has intractable security "
|
|
||||||
"and performance issues. Accordingly, it has been deprecated.",
|
|
||||||
DeprecationWarning)
|
|
||||||
try:
|
|
||||||
request = urllib2.Request(value)
|
|
||||||
urllib2.urlopen(request)
|
|
||||||
except Exception, e:
|
|
||||||
self.error('This URL appears to be a broken link: %s' % e)
|
|
||||||
|
|
||||||
|
|
||||||
class EmailField(StringField):
|
class EmailField(StringField):
|
||||||
"""A field that validates input as an email address.
|
"""A field that validates input as an email address.
|
||||||
@ -182,8 +170,7 @@ class EmailField(StringField):
|
|||||||
|
|
||||||
|
|
||||||
class IntField(BaseField):
|
class IntField(BaseField):
|
||||||
"""An 32-bit integer field.
|
"""32-bit integer field."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, min_value=None, max_value=None, **kwargs):
|
def __init__(self, min_value=None, max_value=None, **kwargs):
|
||||||
self.min_value, self.max_value = min_value, max_value
|
self.min_value, self.max_value = min_value, max_value
|
||||||
@ -216,8 +203,7 @@ class IntField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class LongField(BaseField):
|
class LongField(BaseField):
|
||||||
"""An 64-bit integer field.
|
"""64-bit integer field."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, min_value=None, max_value=None, **kwargs):
|
def __init__(self, min_value=None, max_value=None, **kwargs):
|
||||||
self.min_value, self.max_value = min_value, max_value
|
self.min_value, self.max_value = min_value, max_value
|
||||||
@ -253,8 +239,7 @@ class LongField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class FloatField(BaseField):
|
class FloatField(BaseField):
|
||||||
"""An floating point number field.
|
"""Floating point number field."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, min_value=None, max_value=None, **kwargs):
|
def __init__(self, min_value=None, max_value=None, **kwargs):
|
||||||
self.min_value, self.max_value = min_value, max_value
|
self.min_value, self.max_value = min_value, max_value
|
||||||
@ -291,7 +276,7 @@ class FloatField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class DecimalField(BaseField):
|
class DecimalField(BaseField):
|
||||||
"""A fixed-point decimal number field.
|
"""Fixed-point decimal number field.
|
||||||
|
|
||||||
.. versionchanged:: 0.8
|
.. versionchanged:: 0.8
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
@ -332,25 +317,25 @@ class DecimalField(BaseField):
|
|||||||
|
|
||||||
# Convert to string for python 2.6 before casting to Decimal
|
# Convert to string for python 2.6 before casting to Decimal
|
||||||
try:
|
try:
|
||||||
value = decimal.Decimal("%s" % value)
|
value = decimal.Decimal('%s' % value)
|
||||||
except decimal.InvalidOperation:
|
except decimal.InvalidOperation:
|
||||||
return value
|
return value
|
||||||
return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding)
|
return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding)
|
||||||
|
|
||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
if self.force_string:
|
if self.force_string:
|
||||||
return unicode(self.to_python(value))
|
return six.text_type(self.to_python(value))
|
||||||
return float(self.to_python(value))
|
return float(self.to_python(value))
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
if not isinstance(value, decimal.Decimal):
|
if not isinstance(value, decimal.Decimal):
|
||||||
if not isinstance(value, basestring):
|
if not isinstance(value, six.string_types):
|
||||||
value = unicode(value)
|
value = six.text_type(value)
|
||||||
try:
|
try:
|
||||||
value = decimal.Decimal(value)
|
value = decimal.Decimal(value)
|
||||||
except Exception, exc:
|
except Exception as exc:
|
||||||
self.error('Could not convert value to decimal: %s' % exc)
|
self.error('Could not convert value to decimal: %s' % exc)
|
||||||
|
|
||||||
if self.min_value is not None and value < self.min_value:
|
if self.min_value is not None and value < self.min_value:
|
||||||
@ -364,7 +349,7 @@ class DecimalField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class BooleanField(BaseField):
|
class BooleanField(BaseField):
|
||||||
"""A boolean field type.
|
"""Boolean field type.
|
||||||
|
|
||||||
.. versionadded:: 0.1.2
|
.. versionadded:: 0.1.2
|
||||||
"""
|
"""
|
||||||
@ -382,7 +367,7 @@ class BooleanField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class DateTimeField(BaseField):
|
class DateTimeField(BaseField):
|
||||||
"""A datetime field.
|
"""Datetime field.
|
||||||
|
|
||||||
Uses the python-dateutil library if available alternatively use time.strptime
|
Uses the python-dateutil library if available alternatively use time.strptime
|
||||||
to parse the dates. Note: python-dateutil's parser is fully featured and when
|
to parse the dates. Note: python-dateutil's parser is fully featured and when
|
||||||
@ -410,7 +395,7 @@ class DateTimeField(BaseField):
|
|||||||
if callable(value):
|
if callable(value):
|
||||||
return value()
|
return value()
|
||||||
|
|
||||||
if not isinstance(value, basestring):
|
if not isinstance(value, six.string_types):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Attempt to parse a datetime:
|
# Attempt to parse a datetime:
|
||||||
@ -537,16 +522,19 @@ class EmbeddedDocumentField(BaseField):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, document_type, **kwargs):
|
def __init__(self, document_type, **kwargs):
|
||||||
if not isinstance(document_type, basestring):
|
if (
|
||||||
if not issubclass(document_type, EmbeddedDocument):
|
not isinstance(document_type, six.string_types) and
|
||||||
self.error('Invalid embedded document class provided to an '
|
not issubclass(document_type, EmbeddedDocument)
|
||||||
'EmbeddedDocumentField')
|
):
|
||||||
|
self.error('Invalid embedded document class provided to an '
|
||||||
|
'EmbeddedDocumentField')
|
||||||
|
|
||||||
self.document_type_obj = document_type
|
self.document_type_obj = document_type
|
||||||
super(EmbeddedDocumentField, self).__init__(**kwargs)
|
super(EmbeddedDocumentField, self).__init__(**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def document_type(self):
|
def document_type(self):
|
||||||
if isinstance(self.document_type_obj, basestring):
|
if isinstance(self.document_type_obj, six.string_types):
|
||||||
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
||||||
self.document_type_obj = self.owner_document
|
self.document_type_obj = self.owner_document
|
||||||
else:
|
else:
|
||||||
@ -631,7 +619,7 @@ class DynamicField(BaseField):
|
|||||||
"""Convert a Python type to a MongoDB compatible type.
|
"""Convert a Python type to a MongoDB compatible type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, six.string_types):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if hasattr(value, 'to_mongo'):
|
if hasattr(value, 'to_mongo'):
|
||||||
@ -639,7 +627,7 @@ class DynamicField(BaseField):
|
|||||||
val = value.to_mongo(use_db_field, fields)
|
val = value.to_mongo(use_db_field, fields)
|
||||||
# If we its a document thats not inherited add _cls
|
# If we its a document thats not inherited add _cls
|
||||||
if isinstance(value, Document):
|
if isinstance(value, Document):
|
||||||
val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
|
val = {'_ref': value.to_dbref(), '_cls': cls.__name__}
|
||||||
if isinstance(value, EmbeddedDocument):
|
if isinstance(value, EmbeddedDocument):
|
||||||
val['_cls'] = cls.__name__
|
val['_cls'] = cls.__name__
|
||||||
return val
|
return val
|
||||||
@ -650,7 +638,7 @@ class DynamicField(BaseField):
|
|||||||
is_list = False
|
is_list = False
|
||||||
if not hasattr(value, 'items'):
|
if not hasattr(value, 'items'):
|
||||||
is_list = True
|
is_list = True
|
||||||
value = dict([(k, v) for k, v in enumerate(value)])
|
value = {k: v for k, v in enumerate(value)}
|
||||||
|
|
||||||
data = {}
|
data = {}
|
||||||
for k, v in value.iteritems():
|
for k, v in value.iteritems():
|
||||||
@ -674,12 +662,12 @@ class DynamicField(BaseField):
|
|||||||
return member_name
|
return member_name
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, six.string_types):
|
||||||
return StringField().prepare_query_value(op, value)
|
return StringField().prepare_query_value(op, value)
|
||||||
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
|
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
|
||||||
|
|
||||||
def validate(self, value, clean=True):
|
def validate(self, value, clean=True):
|
||||||
if hasattr(value, "validate"):
|
if hasattr(value, 'validate'):
|
||||||
value.validate(clean=clean)
|
value.validate(clean=clean)
|
||||||
|
|
||||||
|
|
||||||
@ -699,21 +687,27 @@ class ListField(ComplexBaseField):
|
|||||||
super(ListField, self).__init__(**kwargs)
|
super(ListField, self).__init__(**kwargs)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
"""Make sure that a list of valid fields is being used.
|
"""Make sure that a list of valid fields is being used."""
|
||||||
"""
|
|
||||||
if (not isinstance(value, (list, tuple, QuerySet)) or
|
if (not isinstance(value, (list, tuple, QuerySet)) or
|
||||||
isinstance(value, basestring)):
|
isinstance(value, six.string_types)):
|
||||||
self.error('Only lists and tuples may be used in a list field')
|
self.error('Only lists and tuples may be used in a list field')
|
||||||
super(ListField, self).validate(value)
|
super(ListField, self).validate(value)
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
if self.field:
|
if self.field:
|
||||||
if op in ('set', 'unset', None) and (
|
|
||||||
not isinstance(value, basestring) and
|
# If the value is iterable and it's not a string nor a
|
||||||
not isinstance(value, BaseDocument) and
|
# BaseDocument, call prepare_query_value for each of its items.
|
||||||
hasattr(value, '__iter__')):
|
if (
|
||||||
|
op in ('set', 'unset', None) and
|
||||||
|
hasattr(value, '__iter__') and
|
||||||
|
not isinstance(value, six.string_types) and
|
||||||
|
not isinstance(value, BaseDocument)
|
||||||
|
):
|
||||||
return [self.field.prepare_query_value(op, v) for v in value]
|
return [self.field.prepare_query_value(op, v) for v in value]
|
||||||
|
|
||||||
return self.field.prepare_query_value(op, value)
|
return self.field.prepare_query_value(op, value)
|
||||||
|
|
||||||
return super(ListField, self).prepare_query_value(op, value)
|
return super(ListField, self).prepare_query_value(op, value)
|
||||||
|
|
||||||
|
|
||||||
@ -726,7 +720,6 @@ class EmbeddedDocumentListField(ListField):
|
|||||||
:class:`~mongoengine.EmbeddedDocument`.
|
:class:`~mongoengine.EmbeddedDocument`.
|
||||||
|
|
||||||
.. versionadded:: 0.9
|
.. versionadded:: 0.9
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, document_type, **kwargs):
|
def __init__(self, document_type, **kwargs):
|
||||||
@ -775,17 +768,17 @@ class SortedListField(ListField):
|
|||||||
|
|
||||||
|
|
||||||
def key_not_string(d):
|
def key_not_string(d):
|
||||||
""" Helper function to recursively determine if any key in a dictionary is
|
"""Helper function to recursively determine if any key in a
|
||||||
not a string.
|
dictionary is not a string.
|
||||||
"""
|
"""
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)):
|
if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def key_has_dot_or_dollar(d):
|
def key_has_dot_or_dollar(d):
|
||||||
""" Helper function to recursively determine if any key in a dictionary
|
"""Helper function to recursively determine if any key in a
|
||||||
contains a dot or a dollar sign.
|
dictionary contains a dot or a dollar sign.
|
||||||
"""
|
"""
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
|
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
|
||||||
@ -813,14 +806,13 @@ class DictField(ComplexBaseField):
|
|||||||
super(DictField, self).__init__(*args, **kwargs)
|
super(DictField, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
"""Make sure that a list of valid fields is being used.
|
"""Make sure that a list of valid fields is being used."""
|
||||||
"""
|
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
self.error('Only dictionaries may be used in a DictField')
|
self.error('Only dictionaries may be used in a DictField')
|
||||||
|
|
||||||
if key_not_string(value):
|
if key_not_string(value):
|
||||||
msg = ("Invalid dictionary key - documents must "
|
msg = ('Invalid dictionary key - documents must '
|
||||||
"have only string keys")
|
'have only string keys')
|
||||||
self.error(msg)
|
self.error(msg)
|
||||||
if key_has_dot_or_dollar(value):
|
if key_has_dot_or_dollar(value):
|
||||||
self.error('Invalid dictionary key name - keys may not contain "."'
|
self.error('Invalid dictionary key name - keys may not contain "."'
|
||||||
@ -835,14 +827,15 @@ class DictField(ComplexBaseField):
|
|||||||
'istartswith', 'endswith', 'iendswith',
|
'istartswith', 'endswith', 'iendswith',
|
||||||
'exact', 'iexact']
|
'exact', 'iexact']
|
||||||
|
|
||||||
if op in match_operators and isinstance(value, basestring):
|
if op in match_operators and isinstance(value, six.string_types):
|
||||||
return StringField().prepare_query_value(op, value)
|
return StringField().prepare_query_value(op, value)
|
||||||
|
|
||||||
if hasattr(self.field, 'field'):
|
if hasattr(self.field, 'field'):
|
||||||
if op in ('set', 'unset') and isinstance(value, dict):
|
if op in ('set', 'unset') and isinstance(value, dict):
|
||||||
return dict(
|
return {
|
||||||
(k, self.field.prepare_query_value(op, v))
|
k: self.field.prepare_query_value(op, v)
|
||||||
for k, v in value.items())
|
for k, v in value.items()
|
||||||
|
}
|
||||||
return self.field.prepare_query_value(op, value)
|
return self.field.prepare_query_value(op, value)
|
||||||
|
|
||||||
return super(DictField, self).prepare_query_value(op, value)
|
return super(DictField, self).prepare_query_value(op, value)
|
||||||
@ -911,10 +904,12 @@ class ReferenceField(BaseField):
|
|||||||
A reference to an abstract document type is always stored as a
|
A reference to an abstract document type is always stored as a
|
||||||
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
|
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(document_type, basestring):
|
if (
|
||||||
if not issubclass(document_type, (Document, basestring)):
|
not isinstance(document_type, six.string_types) and
|
||||||
self.error('Argument to ReferenceField constructor must be a '
|
not issubclass(document_type, Document)
|
||||||
'document class or a string')
|
):
|
||||||
|
self.error('Argument to ReferenceField constructor must be a '
|
||||||
|
'document class or a string')
|
||||||
|
|
||||||
self.dbref = dbref
|
self.dbref = dbref
|
||||||
self.document_type_obj = document_type
|
self.document_type_obj = document_type
|
||||||
@ -923,7 +918,7 @@ class ReferenceField(BaseField):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def document_type(self):
|
def document_type(self):
|
||||||
if isinstance(self.document_type_obj, basestring):
|
if isinstance(self.document_type_obj, six.string_types):
|
||||||
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
||||||
self.document_type_obj = self.owner_document
|
self.document_type_obj = self.owner_document
|
||||||
else:
|
else:
|
||||||
@ -931,8 +926,7 @@ class ReferenceField(BaseField):
|
|||||||
return self.document_type_obj
|
return self.document_type_obj
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
"""Descriptor to allow lazy dereferencing.
|
"""Descriptor to allow lazy dereferencing."""
|
||||||
"""
|
|
||||||
if instance is None:
|
if instance is None:
|
||||||
# Document class being used rather than a document object
|
# Document class being used rather than a document object
|
||||||
return self
|
return self
|
||||||
@ -989,8 +983,7 @@ class ReferenceField(BaseField):
|
|||||||
return id_
|
return id_
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
"""Convert a MongoDB-compatible type to a Python type.
|
"""Convert a MongoDB-compatible type to a Python type."""
|
||||||
"""
|
|
||||||
if (not self.dbref and
|
if (not self.dbref and
|
||||||
not isinstance(value, (DBRef, Document, EmbeddedDocument))):
|
not isinstance(value, (DBRef, Document, EmbeddedDocument))):
|
||||||
collection = self.document_type._get_collection_name()
|
collection = self.document_type._get_collection_name()
|
||||||
@ -1006,7 +999,7 @@ class ReferenceField(BaseField):
|
|||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
|
|
||||||
if not isinstance(value, (self.document_type, DBRef)):
|
if not isinstance(value, (self.document_type, DBRef)):
|
||||||
self.error("A ReferenceField only accepts DBRef or documents")
|
self.error('A ReferenceField only accepts DBRef or documents')
|
||||||
|
|
||||||
if isinstance(value, Document) and value.id is None:
|
if isinstance(value, Document) and value.id is None:
|
||||||
self.error('You can only reference documents once they have been '
|
self.error('You can only reference documents once they have been '
|
||||||
@ -1030,14 +1023,19 @@ class CachedReferenceField(BaseField):
|
|||||||
.. versionadded:: 0.9
|
.. versionadded:: 0.9
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, document_type, fields=[], auto_sync=True, **kwargs):
|
def __init__(self, document_type, fields=None, auto_sync=True, **kwargs):
|
||||||
"""Initialises the Cached Reference Field.
|
"""Initialises the Cached Reference Field.
|
||||||
|
|
||||||
:param fields: A list of fields to be cached in document
|
:param fields: A list of fields to be cached in document
|
||||||
:param auto_sync: if True documents are auto updated.
|
:param auto_sync: if True documents are auto updated.
|
||||||
"""
|
"""
|
||||||
if not isinstance(document_type, basestring) and \
|
if fields is None:
|
||||||
not issubclass(document_type, (Document, basestring)):
|
fields = []
|
||||||
|
|
||||||
|
if (
|
||||||
|
not isinstance(document_type, six.string_types) and
|
||||||
|
not issubclass(document_type, Document)
|
||||||
|
):
|
||||||
self.error('Argument to CachedReferenceField constructor must be a'
|
self.error('Argument to CachedReferenceField constructor must be a'
|
||||||
' document class or a string')
|
' document class or a string')
|
||||||
|
|
||||||
@ -1053,18 +1051,20 @@ class CachedReferenceField(BaseField):
|
|||||||
sender=self.document_type)
|
sender=self.document_type)
|
||||||
|
|
||||||
def on_document_pre_save(self, sender, document, created, **kwargs):
|
def on_document_pre_save(self, sender, document, created, **kwargs):
|
||||||
if not created:
|
if created:
|
||||||
update_kwargs = dict(
|
return None
|
||||||
('set__%s__%s' % (self.name, k), v)
|
|
||||||
for k, v in document._delta()[0].items()
|
|
||||||
if k in self.fields)
|
|
||||||
|
|
||||||
if update_kwargs:
|
update_kwargs = {
|
||||||
filter_kwargs = {}
|
'set__%s__%s' % (self.name, key): val
|
||||||
filter_kwargs[self.name] = document
|
for key, val in document._delta()[0].items()
|
||||||
|
if key in self.fields
|
||||||
|
}
|
||||||
|
if update_kwargs:
|
||||||
|
filter_kwargs = {}
|
||||||
|
filter_kwargs[self.name] = document
|
||||||
|
|
||||||
self.owner_document.objects(
|
self.owner_document.objects(
|
||||||
**filter_kwargs).update(**update_kwargs)
|
**filter_kwargs).update(**update_kwargs)
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
@ -1077,7 +1077,7 @@ class CachedReferenceField(BaseField):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def document_type(self):
|
def document_type(self):
|
||||||
if isinstance(self.document_type_obj, basestring):
|
if isinstance(self.document_type_obj, six.string_types):
|
||||||
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
||||||
self.document_type_obj = self.owner_document
|
self.document_type_obj = self.owner_document
|
||||||
else:
|
else:
|
||||||
@ -1117,7 +1117,7 @@ class CachedReferenceField(BaseField):
|
|||||||
# TODO: should raise here or will fail next statement
|
# TODO: should raise here or will fail next statement
|
||||||
|
|
||||||
value = SON((
|
value = SON((
|
||||||
("_id", id_field.to_mongo(id_)),
|
('_id', id_field.to_mongo(id_)),
|
||||||
))
|
))
|
||||||
|
|
||||||
if fields:
|
if fields:
|
||||||
@ -1143,7 +1143,7 @@ class CachedReferenceField(BaseField):
|
|||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
|
|
||||||
if not isinstance(value, self.document_type):
|
if not isinstance(value, self.document_type):
|
||||||
self.error("A CachedReferenceField only accepts documents")
|
self.error('A CachedReferenceField only accepts documents')
|
||||||
|
|
||||||
if isinstance(value, Document) and value.id is None:
|
if isinstance(value, Document) and value.id is None:
|
||||||
self.error('You can only reference documents once they have been '
|
self.error('You can only reference documents once they have been '
|
||||||
@ -1191,13 +1191,13 @@ class GenericReferenceField(BaseField):
|
|||||||
# Keep the choices as a list of allowed Document class names
|
# Keep the choices as a list of allowed Document class names
|
||||||
if choices:
|
if choices:
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
if isinstance(choice, basestring):
|
if isinstance(choice, six.string_types):
|
||||||
self.choices.append(choice)
|
self.choices.append(choice)
|
||||||
elif isinstance(choice, type) and issubclass(choice, Document):
|
elif isinstance(choice, type) and issubclass(choice, Document):
|
||||||
self.choices.append(choice._class_name)
|
self.choices.append(choice._class_name)
|
||||||
else:
|
else:
|
||||||
self.error('Invalid choices provided: must be a list of'
|
self.error('Invalid choices provided: must be a list of'
|
||||||
'Document subclasses and/or basestrings')
|
'Document subclasses and/or six.string_typess')
|
||||||
|
|
||||||
def _validate_choices(self, value):
|
def _validate_choices(self, value):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
@ -1280,8 +1280,7 @@ class GenericReferenceField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class BinaryField(BaseField):
|
class BinaryField(BaseField):
|
||||||
"""A binary data field.
|
"""A binary data field."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, max_bytes=None, **kwargs):
|
def __init__(self, max_bytes=None, **kwargs):
|
||||||
self.max_bytes = max_bytes
|
self.max_bytes = max_bytes
|
||||||
@ -1289,18 +1288,18 @@ class BinaryField(BaseField):
|
|||||||
|
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance, value):
|
||||||
"""Handle bytearrays in python 3.1"""
|
"""Handle bytearrays in python 3.1"""
|
||||||
if PY3 and isinstance(value, bytearray):
|
if six.PY3 and isinstance(value, bytearray):
|
||||||
value = bin_type(value)
|
value = six.binary_type(value)
|
||||||
return super(BinaryField, self).__set__(instance, value)
|
return super(BinaryField, self).__set__(instance, value)
|
||||||
|
|
||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
return Binary(value)
|
return Binary(value)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
if not isinstance(value, (bin_type, txt_type, Binary)):
|
if not isinstance(value, (six.binary_type, six.text_type, Binary)):
|
||||||
self.error("BinaryField only accepts instances of "
|
self.error('BinaryField only accepts instances of '
|
||||||
"(%s, %s, Binary)" % (
|
'(%s, %s, Binary)' % (
|
||||||
bin_type.__name__, txt_type.__name__))
|
six.binary_type.__name__, six.text_type.__name__))
|
||||||
|
|
||||||
if self.max_bytes is not None and len(value) > self.max_bytes:
|
if self.max_bytes is not None and len(value) > self.max_bytes:
|
||||||
self.error('Binary value is too long')
|
self.error('Binary value is too long')
|
||||||
@ -1384,11 +1383,13 @@ class GridFSProxy(object):
|
|||||||
get_db(self.db_alias), self.collection_name)
|
get_db(self.db_alias), self.collection_name)
|
||||||
return self._fs
|
return self._fs
|
||||||
|
|
||||||
def get(self, id=None):
|
def get(self, grid_id=None):
|
||||||
if id:
|
if grid_id:
|
||||||
self.grid_id = id
|
self.grid_id = grid_id
|
||||||
|
|
||||||
if self.grid_id is None:
|
if self.grid_id is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.gridout is None:
|
if self.gridout is None:
|
||||||
self.gridout = self.fs.get(self.grid_id)
|
self.gridout = self.fs.get(self.grid_id)
|
||||||
@ -1432,7 +1433,7 @@ class GridFSProxy(object):
|
|||||||
try:
|
try:
|
||||||
return gridout.read(size)
|
return gridout.read(size)
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ''
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
# Delete file from GridFS, FileField still remains
|
# Delete file from GridFS, FileField still remains
|
||||||
@ -1464,9 +1465,8 @@ class FileField(BaseField):
|
|||||||
"""
|
"""
|
||||||
proxy_class = GridFSProxy
|
proxy_class = GridFSProxy
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs',
|
||||||
db_alias=DEFAULT_CONNECTION_NAME,
|
**kwargs):
|
||||||
collection_name="fs", **kwargs):
|
|
||||||
super(FileField, self).__init__(**kwargs)
|
super(FileField, self).__init__(**kwargs)
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.db_alias = db_alias
|
self.db_alias = db_alias
|
||||||
@ -1488,8 +1488,10 @@ class FileField(BaseField):
|
|||||||
|
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance, value):
|
||||||
key = self.name
|
key = self.name
|
||||||
if ((hasattr(value, 'read') and not
|
if (
|
||||||
isinstance(value, GridFSProxy)) or isinstance(value, str_types)):
|
(hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or
|
||||||
|
isinstance(value, (six.binary_type, six.string_types))
|
||||||
|
):
|
||||||
# using "FileField() = file/string" notation
|
# using "FileField() = file/string" notation
|
||||||
grid_file = instance._data.get(self.name)
|
grid_file = instance._data.get(self.name)
|
||||||
# If a file already exists, delete it
|
# If a file already exists, delete it
|
||||||
@ -1558,7 +1560,7 @@ class ImageGridFsProxy(GridFSProxy):
|
|||||||
try:
|
try:
|
||||||
img = Image.open(file_obj)
|
img = Image.open(file_obj)
|
||||||
img_format = img.format
|
img_format = img.format
|
||||||
except Exception, e:
|
except Exception as e:
|
||||||
raise ValidationError('Invalid image: %s' % e)
|
raise ValidationError('Invalid image: %s' % e)
|
||||||
|
|
||||||
# Progressive JPEG
|
# Progressive JPEG
|
||||||
@ -1667,10 +1669,10 @@ class ImageGridFsProxy(GridFSProxy):
|
|||||||
return self.fs.get(out.thumbnail_id)
|
return self.fs.get(out.thumbnail_id)
|
||||||
|
|
||||||
def write(self, *args, **kwargs):
|
def write(self, *args, **kwargs):
|
||||||
raise RuntimeError("Please use \"put\" method instead")
|
raise RuntimeError('Please use "put" method instead')
|
||||||
|
|
||||||
def writelines(self, *args, **kwargs):
|
def writelines(self, *args, **kwargs):
|
||||||
raise RuntimeError("Please use \"put\" method instead")
|
raise RuntimeError('Please use "put" method instead')
|
||||||
|
|
||||||
|
|
||||||
class ImproperlyConfigured(Exception):
|
class ImproperlyConfigured(Exception):
|
||||||
@ -1695,14 +1697,17 @@ class ImageField(FileField):
|
|||||||
def __init__(self, size=None, thumbnail_size=None,
|
def __init__(self, size=None, thumbnail_size=None,
|
||||||
collection_name='images', **kwargs):
|
collection_name='images', **kwargs):
|
||||||
if not Image:
|
if not Image:
|
||||||
raise ImproperlyConfigured("PIL library was not found")
|
raise ImproperlyConfigured('PIL library was not found')
|
||||||
|
|
||||||
params_size = ('width', 'height', 'force')
|
params_size = ('width', 'height', 'force')
|
||||||
extra_args = dict(size=size, thumbnail_size=thumbnail_size)
|
extra_args = {
|
||||||
|
'size': size,
|
||||||
|
'thumbnail_size': thumbnail_size
|
||||||
|
}
|
||||||
for att_name, att in extra_args.items():
|
for att_name, att in extra_args.items():
|
||||||
value = None
|
value = None
|
||||||
if isinstance(att, (tuple, list)):
|
if isinstance(att, (tuple, list)):
|
||||||
if PY3:
|
if six.PY3:
|
||||||
value = dict(itertools.zip_longest(params_size, att,
|
value = dict(itertools.zip_longest(params_size, att,
|
||||||
fillvalue=None))
|
fillvalue=None))
|
||||||
else:
|
else:
|
||||||
@ -1763,10 +1768,10 @@ class SequenceField(BaseField):
|
|||||||
Generate and Increment the counter
|
Generate and Increment the counter
|
||||||
"""
|
"""
|
||||||
sequence_name = self.get_sequence_name()
|
sequence_name = self.get_sequence_name()
|
||||||
sequence_id = "%s.%s" % (sequence_name, self.name)
|
sequence_id = '%s.%s' % (sequence_name, self.name)
|
||||||
collection = get_db(alias=self.db_alias)[self.collection_name]
|
collection = get_db(alias=self.db_alias)[self.collection_name]
|
||||||
counter = collection.find_and_modify(query={"_id": sequence_id},
|
counter = collection.find_and_modify(query={'_id': sequence_id},
|
||||||
update={"$inc": {"next": 1}},
|
update={'$inc': {'next': 1}},
|
||||||
new=True,
|
new=True,
|
||||||
upsert=True)
|
upsert=True)
|
||||||
return self.value_decorator(counter['next'])
|
return self.value_decorator(counter['next'])
|
||||||
@ -1789,9 +1794,9 @@ class SequenceField(BaseField):
|
|||||||
as it is only fixed on set.
|
as it is only fixed on set.
|
||||||
"""
|
"""
|
||||||
sequence_name = self.get_sequence_name()
|
sequence_name = self.get_sequence_name()
|
||||||
sequence_id = "%s.%s" % (sequence_name, self.name)
|
sequence_id = '%s.%s' % (sequence_name, self.name)
|
||||||
collection = get_db(alias=self.db_alias)[self.collection_name]
|
collection = get_db(alias=self.db_alias)[self.collection_name]
|
||||||
data = collection.find_one({"_id": sequence_id})
|
data = collection.find_one({'_id': sequence_id})
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
return self.value_decorator(data['next'] + 1)
|
return self.value_decorator(data['next'] + 1)
|
||||||
@ -1861,8 +1866,8 @@ class UUIDField(BaseField):
|
|||||||
if not self._binary:
|
if not self._binary:
|
||||||
original_value = value
|
original_value = value
|
||||||
try:
|
try:
|
||||||
if not isinstance(value, basestring):
|
if not isinstance(value, six.string_types):
|
||||||
value = unicode(value)
|
value = six.text_type(value)
|
||||||
return uuid.UUID(value)
|
return uuid.UUID(value)
|
||||||
except Exception:
|
except Exception:
|
||||||
return original_value
|
return original_value
|
||||||
@ -1870,8 +1875,8 @@ class UUIDField(BaseField):
|
|||||||
|
|
||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
if not self._binary:
|
if not self._binary:
|
||||||
return unicode(value)
|
return six.text_type(value)
|
||||||
elif isinstance(value, basestring):
|
elif isinstance(value, six.string_types):
|
||||||
return uuid.UUID(value)
|
return uuid.UUID(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -1882,11 +1887,11 @@ class UUIDField(BaseField):
|
|||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
if not isinstance(value, uuid.UUID):
|
if not isinstance(value, uuid.UUID):
|
||||||
if not isinstance(value, basestring):
|
if not isinstance(value, six.string_types):
|
||||||
value = str(value)
|
value = str(value)
|
||||||
try:
|
try:
|
||||||
uuid.UUID(value)
|
uuid.UUID(value)
|
||||||
except Exception, exc:
|
except Exception as exc:
|
||||||
self.error('Could not convert to UUID: %s' % exc)
|
self.error('Could not convert to UUID: %s' % exc)
|
||||||
|
|
||||||
|
|
||||||
@ -1904,19 +1909,18 @@ class GeoPointField(BaseField):
|
|||||||
_geo_index = pymongo.GEO2D
|
_geo_index = pymongo.GEO2D
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
"""Make sure that a geo-value is of type (x, y)
|
"""Make sure that a geo-value is of type (x, y)"""
|
||||||
"""
|
|
||||||
if not isinstance(value, (list, tuple)):
|
if not isinstance(value, (list, tuple)):
|
||||||
self.error('GeoPointField can only accept tuples or lists '
|
self.error('GeoPointField can only accept tuples or lists '
|
||||||
'of (x, y)')
|
'of (x, y)')
|
||||||
|
|
||||||
if not len(value) == 2:
|
if not len(value) == 2:
|
||||||
self.error("Value (%s) must be a two-dimensional point" %
|
self.error('Value (%s) must be a two-dimensional point' %
|
||||||
repr(value))
|
repr(value))
|
||||||
elif (not isinstance(value[0], (float, int)) or
|
elif (not isinstance(value[0], (float, int)) or
|
||||||
not isinstance(value[1], (float, int))):
|
not isinstance(value[1], (float, int))):
|
||||||
self.error(
|
self.error(
|
||||||
"Both values (%s) in point must be float or int" % repr(value))
|
'Both values (%s) in point must be float or int' % repr(value))
|
||||||
|
|
||||||
|
|
||||||
class PointField(GeoJsonBaseField):
|
class PointField(GeoJsonBaseField):
|
||||||
@ -1926,8 +1930,8 @@ class PointField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. code-block:: js
|
.. code-block:: js
|
||||||
|
|
||||||
{ "type" : "Point" ,
|
{'type' : 'Point' ,
|
||||||
"coordinates" : [x, y]}
|
'coordinates' : [x, y]}
|
||||||
|
|
||||||
You can either pass a dict with the full information or a list
|
You can either pass a dict with the full information or a list
|
||||||
to set the value.
|
to set the value.
|
||||||
@ -1936,7 +1940,7 @@ class PointField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. versionadded:: 0.8
|
.. versionadded:: 0.8
|
||||||
"""
|
"""
|
||||||
_type = "Point"
|
_type = 'Point'
|
||||||
|
|
||||||
|
|
||||||
class LineStringField(GeoJsonBaseField):
|
class LineStringField(GeoJsonBaseField):
|
||||||
@ -1946,8 +1950,8 @@ class LineStringField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. code-block:: js
|
.. code-block:: js
|
||||||
|
|
||||||
{ "type" : "LineString" ,
|
{'type' : 'LineString' ,
|
||||||
"coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]}
|
'coordinates' : [[x1, y1], [x1, y1] ... [xn, yn]]}
|
||||||
|
|
||||||
You can either pass a dict with the full information or a list of points.
|
You can either pass a dict with the full information or a list of points.
|
||||||
|
|
||||||
@ -1955,7 +1959,7 @@ class LineStringField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. versionadded:: 0.8
|
.. versionadded:: 0.8
|
||||||
"""
|
"""
|
||||||
_type = "LineString"
|
_type = 'LineString'
|
||||||
|
|
||||||
|
|
||||||
class PolygonField(GeoJsonBaseField):
|
class PolygonField(GeoJsonBaseField):
|
||||||
@ -1965,9 +1969,9 @@ class PolygonField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. code-block:: js
|
.. code-block:: js
|
||||||
|
|
||||||
{ "type" : "Polygon" ,
|
{'type' : 'Polygon' ,
|
||||||
"coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]],
|
'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]],
|
||||||
[[x1, y1], [x1, y1] ... [xn, yn]]}
|
[[x1, y1], [x1, y1] ... [xn, yn]]}
|
||||||
|
|
||||||
You can either pass a dict with the full information or a list
|
You can either pass a dict with the full information or a list
|
||||||
of LineStrings. The first LineString being the outside and the rest being
|
of LineStrings. The first LineString being the outside and the rest being
|
||||||
@ -1977,7 +1981,7 @@ class PolygonField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. versionadded:: 0.8
|
.. versionadded:: 0.8
|
||||||
"""
|
"""
|
||||||
_type = "Polygon"
|
_type = 'Polygon'
|
||||||
|
|
||||||
|
|
||||||
class MultiPointField(GeoJsonBaseField):
|
class MultiPointField(GeoJsonBaseField):
|
||||||
@ -1987,8 +1991,8 @@ class MultiPointField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. code-block:: js
|
.. code-block:: js
|
||||||
|
|
||||||
{ "type" : "MultiPoint" ,
|
{'type' : 'MultiPoint' ,
|
||||||
"coordinates" : [[x1, y1], [x2, y2]]}
|
'coordinates' : [[x1, y1], [x2, y2]]}
|
||||||
|
|
||||||
You can either pass a dict with the full information or a list
|
You can either pass a dict with the full information or a list
|
||||||
to set the value.
|
to set the value.
|
||||||
@ -1997,7 +2001,7 @@ class MultiPointField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. versionadded:: 0.9
|
.. versionadded:: 0.9
|
||||||
"""
|
"""
|
||||||
_type = "MultiPoint"
|
_type = 'MultiPoint'
|
||||||
|
|
||||||
|
|
||||||
class MultiLineStringField(GeoJsonBaseField):
|
class MultiLineStringField(GeoJsonBaseField):
|
||||||
@ -2007,9 +2011,9 @@ class MultiLineStringField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. code-block:: js
|
.. code-block:: js
|
||||||
|
|
||||||
{ "type" : "MultiLineString" ,
|
{'type' : 'MultiLineString' ,
|
||||||
"coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]],
|
'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]],
|
||||||
[[x1, y1], [x1, y1] ... [xn, yn]]]}
|
[[x1, y1], [x1, y1] ... [xn, yn]]]}
|
||||||
|
|
||||||
You can either pass a dict with the full information or a list of points.
|
You can either pass a dict with the full information or a list of points.
|
||||||
|
|
||||||
@ -2017,7 +2021,7 @@ class MultiLineStringField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. versionadded:: 0.9
|
.. versionadded:: 0.9
|
||||||
"""
|
"""
|
||||||
_type = "MultiLineString"
|
_type = 'MultiLineString'
|
||||||
|
|
||||||
|
|
||||||
class MultiPolygonField(GeoJsonBaseField):
|
class MultiPolygonField(GeoJsonBaseField):
|
||||||
@ -2027,14 +2031,14 @@ class MultiPolygonField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. code-block:: js
|
.. code-block:: js
|
||||||
|
|
||||||
{ "type" : "MultiPolygon" ,
|
{'type' : 'MultiPolygon' ,
|
||||||
"coordinates" : [[
|
'coordinates' : [[
|
||||||
[[x1, y1], [x1, y1] ... [xn, yn]],
|
[[x1, y1], [x1, y1] ... [xn, yn]],
|
||||||
[[x1, y1], [x1, y1] ... [xn, yn]]
|
[[x1, y1], [x1, y1] ... [xn, yn]]
|
||||||
], [
|
], [
|
||||||
[[x1, y1], [x1, y1] ... [xn, yn]],
|
[[x1, y1], [x1, y1] ... [xn, yn]],
|
||||||
[[x1, y1], [x1, y1] ... [xn, yn]]
|
[[x1, y1], [x1, y1] ... [xn, yn]]
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
You can either pass a dict with the full information or a list
|
You can either pass a dict with the full information or a list
|
||||||
@ -2044,4 +2048,4 @@ class MultiPolygonField(GeoJsonBaseField):
|
|||||||
|
|
||||||
.. versionadded:: 0.9
|
.. versionadded:: 0.9
|
||||||
"""
|
"""
|
||||||
_type = "MultiPolygon"
|
_type = 'MultiPolygon'
|
||||||
|
@ -1,50 +1,25 @@
|
|||||||
"""Helper functions and types to aid with Python 2.6 - 3 support."""
|
"""
|
||||||
|
Helper functions, constants, and types to aid with Python v2.7 - v3.x and
|
||||||
import sys
|
PyMongo v2.7 - v3.x support.
|
||||||
import warnings
|
"""
|
||||||
|
|
||||||
import pymongo
|
import pymongo
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
# Show a deprecation warning for people using Python v2.6
|
|
||||||
# TODO remove in mongoengine v0.11.0
|
|
||||||
if sys.version_info[0] == 2 and sys.version_info[1] == 6:
|
|
||||||
warnings.warn(
|
|
||||||
'Python v2.6 support is deprecated and is going to be dropped '
|
|
||||||
'entirely in the upcoming v0.11.0 release. Update your Python '
|
|
||||||
'version if you want to have access to the latest features and '
|
|
||||||
'bug fixes in MongoEngine.',
|
|
||||||
DeprecationWarning
|
|
||||||
)
|
|
||||||
|
|
||||||
if pymongo.version_tuple[0] < 3:
|
if pymongo.version_tuple[0] < 3:
|
||||||
IS_PYMONGO_3 = False
|
IS_PYMONGO_3 = False
|
||||||
else:
|
else:
|
||||||
IS_PYMONGO_3 = True
|
IS_PYMONGO_3 = True
|
||||||
|
|
||||||
PY3 = sys.version_info[0] == 3
|
|
||||||
|
|
||||||
if PY3:
|
# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3.
|
||||||
import codecs
|
StringIO = six.BytesIO
|
||||||
from io import BytesIO as StringIO
|
|
||||||
|
|
||||||
# return s converted to binary. b('test') should be equivalent to b'test'
|
# Additionally for Py2, try to use the faster cStringIO, if available
|
||||||
def b(s):
|
if not six.PY3:
|
||||||
return codecs.latin_1_encode(s)[0]
|
|
||||||
|
|
||||||
bin_type = bytes
|
|
||||||
txt_type = str
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
from cStringIO import StringIO
|
import cStringIO
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from StringIO import StringIO
|
pass
|
||||||
|
else:
|
||||||
# Conversion to binary only necessary in Python 3
|
StringIO = cStringIO.StringIO
|
||||||
def b(s):
|
|
||||||
return s
|
|
||||||
|
|
||||||
bin_type = str
|
|
||||||
txt_type = unicode
|
|
||||||
|
|
||||||
str_types = (bin_type, txt_type)
|
|
||||||
|
@ -1,11 +1,17 @@
|
|||||||
from mongoengine.errors import (DoesNotExist, InvalidQueryError,
|
from mongoengine.errors import *
|
||||||
MultipleObjectsReturned, NotUniqueError,
|
|
||||||
OperationError)
|
|
||||||
from mongoengine.queryset.field_list import *
|
from mongoengine.queryset.field_list import *
|
||||||
from mongoengine.queryset.manager import *
|
from mongoengine.queryset.manager import *
|
||||||
from mongoengine.queryset.queryset import *
|
from mongoengine.queryset.queryset import *
|
||||||
from mongoengine.queryset.transform import *
|
from mongoengine.queryset.transform import *
|
||||||
from mongoengine.queryset.visitor import *
|
from mongoengine.queryset.visitor import *
|
||||||
|
|
||||||
__all__ = (field_list.__all__ + manager.__all__ + queryset.__all__ +
|
# Expose just the public subset of all imported objects and constants.
|
||||||
transform.__all__ + visitor.__all__)
|
__all__ = (
|
||||||
|
'QuerySet', 'QuerySetNoCache', 'Q', 'queryset_manager', 'QuerySetManager',
|
||||||
|
'QueryFieldList', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL',
|
||||||
|
|
||||||
|
# Errors that might be related to a queryset, mostly here for backward
|
||||||
|
# compatibility
|
||||||
|
'DoesNotExist', 'InvalidQueryError', 'MultipleObjectsReturned',
|
||||||
|
'NotUniqueError', 'OperationError',
|
||||||
|
)
|
||||||
|
@ -12,9 +12,10 @@ from bson.code import Code
|
|||||||
import pymongo
|
import pymongo
|
||||||
import pymongo.errors
|
import pymongo.errors
|
||||||
from pymongo.common import validate_read_preference
|
from pymongo.common import validate_read_preference
|
||||||
|
import six
|
||||||
|
|
||||||
from mongoengine import signals
|
from mongoengine import signals
|
||||||
from mongoengine.base.common import get_document
|
from mongoengine.base import get_document
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.connection import get_db
|
from mongoengine.connection import get_db
|
||||||
from mongoengine.context_managers import switch_db
|
from mongoengine.context_managers import switch_db
|
||||||
@ -73,10 +74,10 @@ class BaseQuerySet(object):
|
|||||||
# subclasses of the class being used
|
# subclasses of the class being used
|
||||||
if document._meta.get('allow_inheritance') is True:
|
if document._meta.get('allow_inheritance') is True:
|
||||||
if len(self._document._subclasses) == 1:
|
if len(self._document._subclasses) == 1:
|
||||||
self._initial_query = {"_cls": self._document._subclasses[0]}
|
self._initial_query = {'_cls': self._document._subclasses[0]}
|
||||||
else:
|
else:
|
||||||
self._initial_query = {
|
self._initial_query = {
|
||||||
"_cls": {"$in": self._document._subclasses}}
|
'_cls': {'$in': self._document._subclasses}}
|
||||||
self._loaded_fields = QueryFieldList(always_include=['_cls'])
|
self._loaded_fields = QueryFieldList(always_include=['_cls'])
|
||||||
self._cursor_obj = None
|
self._cursor_obj = None
|
||||||
self._limit = None
|
self._limit = None
|
||||||
@ -105,8 +106,8 @@ class BaseQuerySet(object):
|
|||||||
if q_obj:
|
if q_obj:
|
||||||
# make sure proper query object is passed
|
# make sure proper query object is passed
|
||||||
if not isinstance(q_obj, QNode):
|
if not isinstance(q_obj, QNode):
|
||||||
msg = ("Not a query object: %s. "
|
msg = ('Not a query object: %s. '
|
||||||
"Did you intend to use key=value?" % q_obj)
|
'Did you intend to use key=value?' % q_obj)
|
||||||
raise InvalidQueryError(msg)
|
raise InvalidQueryError(msg)
|
||||||
query &= q_obj
|
query &= q_obj
|
||||||
|
|
||||||
@ -133,10 +134,10 @@ class BaseQuerySet(object):
|
|||||||
obj_dict = self.__dict__.copy()
|
obj_dict = self.__dict__.copy()
|
||||||
|
|
||||||
# don't picke collection, instead pickle collection params
|
# don't picke collection, instead pickle collection params
|
||||||
obj_dict.pop("_collection_obj")
|
obj_dict.pop('_collection_obj')
|
||||||
|
|
||||||
# don't pickle cursor
|
# don't pickle cursor
|
||||||
obj_dict["_cursor_obj"] = None
|
obj_dict['_cursor_obj'] = None
|
||||||
|
|
||||||
return obj_dict
|
return obj_dict
|
||||||
|
|
||||||
@ -147,7 +148,7 @@ class BaseQuerySet(object):
|
|||||||
See https://github.com/MongoEngine/mongoengine/issues/442
|
See https://github.com/MongoEngine/mongoengine/issues/442
|
||||||
"""
|
"""
|
||||||
|
|
||||||
obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection()
|
obj_dict['_collection_obj'] = obj_dict['_document']._get_collection()
|
||||||
|
|
||||||
# update attributes
|
# update attributes
|
||||||
self.__dict__.update(obj_dict)
|
self.__dict__.update(obj_dict)
|
||||||
@ -166,7 +167,7 @@ class BaseQuerySet(object):
|
|||||||
queryset._skip, queryset._limit = key.start, key.stop
|
queryset._skip, queryset._limit = key.start, key.stop
|
||||||
if key.start and key.stop:
|
if key.start and key.stop:
|
||||||
queryset._limit = key.stop - key.start
|
queryset._limit = key.stop - key.start
|
||||||
except IndexError, err:
|
except IndexError as err:
|
||||||
# PyMongo raises an error if key.start == key.stop, catch it,
|
# PyMongo raises an error if key.start == key.stop, catch it,
|
||||||
# bin it, kill it.
|
# bin it, kill it.
|
||||||
start = key.start or 0
|
start = key.start or 0
|
||||||
@ -199,19 +200,16 @@ class BaseQuerySet(object):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _has_data(self):
|
def _has_data(self):
|
||||||
""" Retrieves whether cursor has any data. """
|
"""Return True if cursor has any data."""
|
||||||
|
|
||||||
queryset = self.order_by()
|
queryset = self.order_by()
|
||||||
return False if queryset.first() is None else True
|
return False if queryset.first() is None else True
|
||||||
|
|
||||||
def __nonzero__(self):
|
def __nonzero__(self):
|
||||||
""" Avoid to open all records in an if stmt in Py2. """
|
"""Avoid to open all records in an if stmt in Py2."""
|
||||||
|
|
||||||
return self._has_data()
|
return self._has_data()
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
""" Avoid to open all records in an if stmt in Py3. """
|
"""Avoid to open all records in an if stmt in Py3."""
|
||||||
|
|
||||||
return self._has_data()
|
return self._has_data()
|
||||||
|
|
||||||
# Core functions
|
# Core functions
|
||||||
@ -239,7 +237,7 @@ class BaseQuerySet(object):
|
|||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
if queryset._search_text:
|
if queryset._search_text:
|
||||||
raise OperationError(
|
raise OperationError(
|
||||||
"It is not possible to use search_text two times.")
|
'It is not possible to use search_text two times.')
|
||||||
|
|
||||||
query_kwargs = SON({'$search': text})
|
query_kwargs = SON({'$search': text})
|
||||||
if language:
|
if language:
|
||||||
@ -268,7 +266,7 @@ class BaseQuerySet(object):
|
|||||||
try:
|
try:
|
||||||
result = queryset.next()
|
result = queryset.next()
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
msg = ("%s matching query does not exist."
|
msg = ('%s matching query does not exist.'
|
||||||
% queryset._document._class_name)
|
% queryset._document._class_name)
|
||||||
raise queryset._document.DoesNotExist(msg)
|
raise queryset._document.DoesNotExist(msg)
|
||||||
try:
|
try:
|
||||||
@ -290,8 +288,7 @@ class BaseQuerySet(object):
|
|||||||
return self._document(**kwargs).save()
|
return self._document(**kwargs).save()
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
"""Retrieve the first object matching the query.
|
"""Retrieve the first object matching the query."""
|
||||||
"""
|
|
||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
try:
|
try:
|
||||||
result = queryset[0]
|
result = queryset[0]
|
||||||
@ -340,7 +337,7 @@ class BaseQuerySet(object):
|
|||||||
% str(self._document))
|
% str(self._document))
|
||||||
raise OperationError(msg)
|
raise OperationError(msg)
|
||||||
if doc.pk and not doc._created:
|
if doc.pk and not doc._created:
|
||||||
msg = "Some documents have ObjectIds use doc.update() instead"
|
msg = 'Some documents have ObjectIds use doc.update() instead'
|
||||||
raise OperationError(msg)
|
raise OperationError(msg)
|
||||||
|
|
||||||
signal_kwargs = signal_kwargs or {}
|
signal_kwargs = signal_kwargs or {}
|
||||||
@ -350,17 +347,17 @@ class BaseQuerySet(object):
|
|||||||
raw = [doc.to_mongo() for doc in docs]
|
raw = [doc.to_mongo() for doc in docs]
|
||||||
try:
|
try:
|
||||||
ids = self._collection.insert(raw, **write_concern)
|
ids = self._collection.insert(raw, **write_concern)
|
||||||
except pymongo.errors.DuplicateKeyError, err:
|
except pymongo.errors.DuplicateKeyError as err:
|
||||||
message = 'Could not save document (%s)'
|
message = 'Could not save document (%s)'
|
||||||
raise NotUniqueError(message % unicode(err))
|
raise NotUniqueError(message % six.text_type(err))
|
||||||
except pymongo.errors.OperationFailure, err:
|
except pymongo.errors.OperationFailure as err:
|
||||||
message = 'Could not save document (%s)'
|
message = 'Could not save document (%s)'
|
||||||
if re.match('^E1100[01] duplicate key', unicode(err)):
|
if re.match('^E1100[01] duplicate key', six.text_type(err)):
|
||||||
# E11000 - duplicate key error index
|
# E11000 - duplicate key error index
|
||||||
# E11001 - duplicate key on update
|
# E11001 - duplicate key on update
|
||||||
message = u'Tried to save duplicate unique keys (%s)'
|
message = u'Tried to save duplicate unique keys (%s)'
|
||||||
raise NotUniqueError(message % unicode(err))
|
raise NotUniqueError(message % six.text_type(err))
|
||||||
raise OperationError(message % unicode(err))
|
raise OperationError(message % six.text_type(err))
|
||||||
|
|
||||||
if not load_bulk:
|
if not load_bulk:
|
||||||
signals.post_bulk_insert.send(
|
signals.post_bulk_insert.send(
|
||||||
@ -386,7 +383,8 @@ class BaseQuerySet(object):
|
|||||||
return 0
|
return 0
|
||||||
return self._cursor.count(with_limit_and_skip=with_limit_and_skip)
|
return self._cursor.count(with_limit_and_skip=with_limit_and_skip)
|
||||||
|
|
||||||
def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None):
|
def delete(self, write_concern=None, _from_doc_delete=False,
|
||||||
|
cascade_refs=None):
|
||||||
"""Delete the documents matched by the query.
|
"""Delete the documents matched by the query.
|
||||||
|
|
||||||
:param write_concern: Extra keyword arguments are passed down which
|
:param write_concern: Extra keyword arguments are passed down which
|
||||||
@ -409,8 +407,9 @@ class BaseQuerySet(object):
|
|||||||
# Handle deletes where skips or limits have been applied or
|
# Handle deletes where skips or limits have been applied or
|
||||||
# there is an untriggered delete signal
|
# there is an untriggered delete signal
|
||||||
has_delete_signal = signals.signals_available and (
|
has_delete_signal = signals.signals_available and (
|
||||||
signals.pre_delete.has_receivers_for(self._document) or
|
signals.pre_delete.has_receivers_for(doc) or
|
||||||
signals.post_delete.has_receivers_for(self._document))
|
signals.post_delete.has_receivers_for(doc)
|
||||||
|
)
|
||||||
|
|
||||||
call_document_delete = (queryset._skip or queryset._limit or
|
call_document_delete = (queryset._skip or queryset._limit or
|
||||||
has_delete_signal) and not _from_doc_delete
|
has_delete_signal) and not _from_doc_delete
|
||||||
@ -423,37 +422,44 @@ class BaseQuerySet(object):
|
|||||||
return cnt
|
return cnt
|
||||||
|
|
||||||
delete_rules = doc._meta.get('delete_rules') or {}
|
delete_rules = doc._meta.get('delete_rules') or {}
|
||||||
|
delete_rules = list(delete_rules.items())
|
||||||
|
|
||||||
# Check for DENY rules before actually deleting/nullifying any other
|
# Check for DENY rules before actually deleting/nullifying any other
|
||||||
# references
|
# references
|
||||||
for rule_entry in delete_rules:
|
for rule_entry, rule in delete_rules:
|
||||||
document_cls, field_name = rule_entry
|
document_cls, field_name = rule_entry
|
||||||
if document_cls._meta.get('abstract'):
|
if document_cls._meta.get('abstract'):
|
||||||
continue
|
continue
|
||||||
rule = doc._meta['delete_rules'][rule_entry]
|
|
||||||
if rule == DENY and document_cls.objects(
|
|
||||||
**{field_name + '__in': self}).count() > 0:
|
|
||||||
msg = ("Could not delete document (%s.%s refers to it)"
|
|
||||||
% (document_cls.__name__, field_name))
|
|
||||||
raise OperationError(msg)
|
|
||||||
|
|
||||||
for rule_entry in delete_rules:
|
if rule == DENY:
|
||||||
|
refs = document_cls.objects(**{field_name + '__in': self})
|
||||||
|
if refs.limit(1).count() > 0:
|
||||||
|
raise OperationError(
|
||||||
|
'Could not delete document (%s.%s refers to it)'
|
||||||
|
% (document_cls.__name__, field_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check all the other rules
|
||||||
|
for rule_entry, rule in delete_rules:
|
||||||
document_cls, field_name = rule_entry
|
document_cls, field_name = rule_entry
|
||||||
if document_cls._meta.get('abstract'):
|
if document_cls._meta.get('abstract'):
|
||||||
continue
|
continue
|
||||||
rule = doc._meta['delete_rules'][rule_entry]
|
|
||||||
if rule == CASCADE:
|
if rule == CASCADE:
|
||||||
cascade_refs = set() if cascade_refs is None else cascade_refs
|
cascade_refs = set() if cascade_refs is None else cascade_refs
|
||||||
# Handle recursive reference
|
# Handle recursive reference
|
||||||
if doc._collection == document_cls._collection:
|
if doc._collection == document_cls._collection:
|
||||||
for ref in queryset:
|
for ref in queryset:
|
||||||
cascade_refs.add(ref.id)
|
cascade_refs.add(ref.id)
|
||||||
ref_q = document_cls.objects(**{field_name + '__in': self, 'pk__nin': cascade_refs})
|
refs = document_cls.objects(**{field_name + '__in': self,
|
||||||
ref_q_count = ref_q.count()
|
'pk__nin': cascade_refs})
|
||||||
if ref_q_count > 0:
|
if refs.count() > 0:
|
||||||
ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs)
|
refs.delete(write_concern=write_concern,
|
||||||
|
cascade_refs=cascade_refs)
|
||||||
elif rule == NULLIFY:
|
elif rule == NULLIFY:
|
||||||
document_cls.objects(**{field_name + '__in': self}).update(
|
document_cls.objects(**{field_name + '__in': self}).update(
|
||||||
write_concern=write_concern, **{'unset__%s' % field_name: 1})
|
write_concern=write_concern,
|
||||||
|
**{'unset__%s' % field_name: 1})
|
||||||
elif rule == PULL:
|
elif rule == PULL:
|
||||||
document_cls.objects(**{field_name + '__in': self}).update(
|
document_cls.objects(**{field_name + '__in': self}).update(
|
||||||
write_concern=write_concern,
|
write_concern=write_concern,
|
||||||
@ -461,7 +467,7 @@ class BaseQuerySet(object):
|
|||||||
|
|
||||||
result = queryset._collection.remove(queryset._query, **write_concern)
|
result = queryset._collection.remove(queryset._query, **write_concern)
|
||||||
if result:
|
if result:
|
||||||
return result.get("n")
|
return result.get('n')
|
||||||
|
|
||||||
def update(self, upsert=False, multi=True, write_concern=None,
|
def update(self, upsert=False, multi=True, write_concern=None,
|
||||||
full_result=False, **update):
|
full_result=False, **update):
|
||||||
@ -482,7 +488,7 @@ class BaseQuerySet(object):
|
|||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
"""
|
"""
|
||||||
if not update and not upsert:
|
if not update and not upsert:
|
||||||
raise OperationError("No update parameters, would remove data")
|
raise OperationError('No update parameters, would remove data')
|
||||||
|
|
||||||
if write_concern is None:
|
if write_concern is None:
|
||||||
write_concern = {}
|
write_concern = {}
|
||||||
@ -495,9 +501,9 @@ class BaseQuerySet(object):
|
|||||||
# then ensure we add _cls to the update operation
|
# then ensure we add _cls to the update operation
|
||||||
if upsert and '_cls' in query:
|
if upsert and '_cls' in query:
|
||||||
if '$set' in update:
|
if '$set' in update:
|
||||||
update["$set"]["_cls"] = queryset._document._class_name
|
update['$set']['_cls'] = queryset._document._class_name
|
||||||
else:
|
else:
|
||||||
update["$set"] = {"_cls": queryset._document._class_name}
|
update['$set'] = {'_cls': queryset._document._class_name}
|
||||||
try:
|
try:
|
||||||
result = queryset._collection.update(query, update, multi=multi,
|
result = queryset._collection.update(query, update, multi=multi,
|
||||||
upsert=upsert, **write_concern)
|
upsert=upsert, **write_concern)
|
||||||
@ -505,13 +511,13 @@ class BaseQuerySet(object):
|
|||||||
return result
|
return result
|
||||||
elif result:
|
elif result:
|
||||||
return result['n']
|
return result['n']
|
||||||
except pymongo.errors.DuplicateKeyError, err:
|
except pymongo.errors.DuplicateKeyError as err:
|
||||||
raise NotUniqueError(u'Update failed (%s)' % unicode(err))
|
raise NotUniqueError(u'Update failed (%s)' % six.text_type(err))
|
||||||
except pymongo.errors.OperationFailure, err:
|
except pymongo.errors.OperationFailure as err:
|
||||||
if unicode(err) == u'multi not coded yet':
|
if six.text_type(err) == u'multi not coded yet':
|
||||||
message = u'update() method requires MongoDB 1.1.3+'
|
message = u'update() method requires MongoDB 1.1.3+'
|
||||||
raise OperationError(message)
|
raise OperationError(message)
|
||||||
raise OperationError(u'Update failed (%s)' % unicode(err))
|
raise OperationError(u'Update failed (%s)' % six.text_type(err))
|
||||||
|
|
||||||
def upsert_one(self, write_concern=None, **update):
|
def upsert_one(self, write_concern=None, **update):
|
||||||
"""Overwrite or add the first document matched by the query.
|
"""Overwrite or add the first document matched by the query.
|
||||||
@ -582,11 +588,11 @@ class BaseQuerySet(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if remove and new:
|
if remove and new:
|
||||||
raise OperationError("Conflicting parameters: remove and new")
|
raise OperationError('Conflicting parameters: remove and new')
|
||||||
|
|
||||||
if not update and not upsert and not remove:
|
if not update and not upsert and not remove:
|
||||||
raise OperationError(
|
raise OperationError(
|
||||||
"No update parameters, must either update or remove")
|
'No update parameters, must either update or remove')
|
||||||
|
|
||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
query = queryset._query
|
query = queryset._query
|
||||||
@ -597,7 +603,7 @@ class BaseQuerySet(object):
|
|||||||
try:
|
try:
|
||||||
if IS_PYMONGO_3:
|
if IS_PYMONGO_3:
|
||||||
if full_response:
|
if full_response:
|
||||||
msg = "With PyMongo 3+, it is not possible anymore to get the full response."
|
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
if remove:
|
if remove:
|
||||||
result = queryset._collection.find_one_and_delete(
|
result = queryset._collection.find_one_and_delete(
|
||||||
@ -615,14 +621,14 @@ class BaseQuerySet(object):
|
|||||||
result = queryset._collection.find_and_modify(
|
result = queryset._collection.find_and_modify(
|
||||||
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
|
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
|
||||||
full_response=full_response, **self._cursor_args)
|
full_response=full_response, **self._cursor_args)
|
||||||
except pymongo.errors.DuplicateKeyError, err:
|
except pymongo.errors.DuplicateKeyError as err:
|
||||||
raise NotUniqueError(u"Update failed (%s)" % err)
|
raise NotUniqueError(u'Update failed (%s)' % err)
|
||||||
except pymongo.errors.OperationFailure, err:
|
except pymongo.errors.OperationFailure as err:
|
||||||
raise OperationError(u"Update failed (%s)" % err)
|
raise OperationError(u'Update failed (%s)' % err)
|
||||||
|
|
||||||
if full_response:
|
if full_response:
|
||||||
if result["value"] is not None:
|
if result['value'] is not None:
|
||||||
result["value"] = self._document._from_son(result["value"], only_fields=self.only_fields)
|
result['value'] = self._document._from_son(result['value'], only_fields=self.only_fields)
|
||||||
else:
|
else:
|
||||||
if result is not None:
|
if result is not None:
|
||||||
result = self._document._from_son(result, only_fields=self.only_fields)
|
result = self._document._from_son(result, only_fields=self.only_fields)
|
||||||
@ -640,7 +646,7 @@ class BaseQuerySet(object):
|
|||||||
"""
|
"""
|
||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
if not queryset._query_obj.empty:
|
if not queryset._query_obj.empty:
|
||||||
msg = "Cannot use a filter whilst using `with_id`"
|
msg = 'Cannot use a filter whilst using `with_id`'
|
||||||
raise InvalidQueryError(msg)
|
raise InvalidQueryError(msg)
|
||||||
return queryset.filter(pk=object_id).first()
|
return queryset.filter(pk=object_id).first()
|
||||||
|
|
||||||
@ -684,7 +690,7 @@ class BaseQuerySet(object):
|
|||||||
Only return instances of this document and not any inherited documents
|
Only return instances of this document and not any inherited documents
|
||||||
"""
|
"""
|
||||||
if self._document._meta.get('allow_inheritance') is True:
|
if self._document._meta.get('allow_inheritance') is True:
|
||||||
self._initial_query = {"_cls": self._document._class_name}
|
self._initial_query = {'_cls': self._document._class_name}
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -810,49 +816,56 @@ class BaseQuerySet(object):
|
|||||||
.. versionchanged:: 0.6 - Improved db_field refrence handling
|
.. versionchanged:: 0.6 - Improved db_field refrence handling
|
||||||
"""
|
"""
|
||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
field = self._fields_to_dbfields([field]).pop()
|
field = self._fields_to_dbfields([field]).pop()
|
||||||
finally:
|
except LookUpError:
|
||||||
distinct = self._dereference(queryset._cursor.distinct(field), 1,
|
pass
|
||||||
name=field, instance=self._document)
|
|
||||||
|
|
||||||
doc_field = self._document._fields.get(field.split('.', 1)[0])
|
distinct = self._dereference(queryset._cursor.distinct(field), 1,
|
||||||
instance = False
|
name=field, instance=self._document)
|
||||||
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
|
|
||||||
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
|
doc_field = self._document._fields.get(field.split('.', 1)[0])
|
||||||
ListField = _import_class('ListField')
|
instance = None
|
||||||
GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField')
|
|
||||||
if isinstance(doc_field, ListField):
|
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
|
||||||
doc_field = getattr(doc_field, "field", doc_field)
|
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
|
||||||
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
|
ListField = _import_class('ListField')
|
||||||
instance = getattr(doc_field, "document_type", False)
|
GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField')
|
||||||
# handle distinct on subdocuments
|
if isinstance(doc_field, ListField):
|
||||||
if '.' in field:
|
doc_field = getattr(doc_field, 'field', doc_field)
|
||||||
for field_part in field.split('.')[1:]:
|
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
|
||||||
# if looping on embedded document, get the document type instance
|
instance = getattr(doc_field, 'document_type', None)
|
||||||
if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
|
|
||||||
doc_field = instance
|
# handle distinct on subdocuments
|
||||||
# now get the subdocument
|
if '.' in field:
|
||||||
doc_field = getattr(doc_field, field_part, doc_field)
|
for field_part in field.split('.')[1:]:
|
||||||
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
|
# if looping on embedded document, get the document type instance
|
||||||
if isinstance(doc_field, ListField):
|
if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
|
||||||
doc_field = getattr(doc_field, "field", doc_field)
|
doc_field = instance
|
||||||
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
|
# now get the subdocument
|
||||||
instance = getattr(doc_field, "document_type", False)
|
doc_field = getattr(doc_field, field_part, doc_field)
|
||||||
if instance and isinstance(doc_field, (EmbeddedDocumentField,
|
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
|
||||||
GenericEmbeddedDocumentField)):
|
if isinstance(doc_field, ListField):
|
||||||
distinct = [instance(**doc) for doc in distinct]
|
doc_field = getattr(doc_field, 'field', doc_field)
|
||||||
return distinct
|
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
|
||||||
|
instance = getattr(doc_field, 'document_type', None)
|
||||||
|
|
||||||
|
if instance and isinstance(doc_field, (EmbeddedDocumentField,
|
||||||
|
GenericEmbeddedDocumentField)):
|
||||||
|
distinct = [instance(**doc) for doc in distinct]
|
||||||
|
|
||||||
|
return distinct
|
||||||
|
|
||||||
def only(self, *fields):
|
def only(self, *fields):
|
||||||
"""Load only a subset of this document's fields. ::
|
"""Load only a subset of this document's fields. ::
|
||||||
|
|
||||||
post = BlogPost.objects(...).only("title", "author.name")
|
post = BlogPost.objects(...).only('title', 'author.name')
|
||||||
|
|
||||||
.. note :: `only()` is chainable and will perform a union ::
|
.. note :: `only()` is chainable and will perform a union ::
|
||||||
So with the following it will fetch both: `title` and `author.name`::
|
So with the following it will fetch both: `title` and `author.name`::
|
||||||
|
|
||||||
post = BlogPost.objects.only("title").only("author.name")
|
post = BlogPost.objects.only('title').only('author.name')
|
||||||
|
|
||||||
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
|
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
|
||||||
field filters.
|
field filters.
|
||||||
@ -862,19 +875,19 @@ class BaseQuerySet(object):
|
|||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
.. versionchanged:: 0.5 - Added subfield support
|
.. versionchanged:: 0.5 - Added subfield support
|
||||||
"""
|
"""
|
||||||
fields = dict([(f, QueryFieldList.ONLY) for f in fields])
|
fields = {f: QueryFieldList.ONLY for f in fields}
|
||||||
self.only_fields = fields.keys()
|
self.only_fields = fields.keys()
|
||||||
return self.fields(True, **fields)
|
return self.fields(True, **fields)
|
||||||
|
|
||||||
def exclude(self, *fields):
|
def exclude(self, *fields):
|
||||||
"""Opposite to .only(), exclude some document's fields. ::
|
"""Opposite to .only(), exclude some document's fields. ::
|
||||||
|
|
||||||
post = BlogPost.objects(...).exclude("comments")
|
post = BlogPost.objects(...).exclude('comments')
|
||||||
|
|
||||||
.. note :: `exclude()` is chainable and will perform a union ::
|
.. note :: `exclude()` is chainable and will perform a union ::
|
||||||
So with the following it will exclude both: `title` and `author.name`::
|
So with the following it will exclude both: `title` and `author.name`::
|
||||||
|
|
||||||
post = BlogPost.objects.exclude("title").exclude("author.name")
|
post = BlogPost.objects.exclude('title').exclude('author.name')
|
||||||
|
|
||||||
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
|
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
|
||||||
field filters.
|
field filters.
|
||||||
@ -883,7 +896,7 @@ class BaseQuerySet(object):
|
|||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
|
fields = {f: QueryFieldList.EXCLUDE for f in fields}
|
||||||
return self.fields(**fields)
|
return self.fields(**fields)
|
||||||
|
|
||||||
def fields(self, _only_called=False, **kwargs):
|
def fields(self, _only_called=False, **kwargs):
|
||||||
@ -904,7 +917,7 @@ class BaseQuerySet(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Check for an operator and transform to mongo-style if there is
|
# Check for an operator and transform to mongo-style if there is
|
||||||
operators = ["slice"]
|
operators = ['slice']
|
||||||
cleaned_fields = []
|
cleaned_fields = []
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
parts = key.split('__')
|
parts = key.split('__')
|
||||||
@ -928,7 +941,7 @@ class BaseQuerySet(object):
|
|||||||
"""Include all fields. Reset all previously calls of .only() or
|
"""Include all fields. Reset all previously calls of .only() or
|
||||||
.exclude(). ::
|
.exclude(). ::
|
||||||
|
|
||||||
post = BlogPost.objects.exclude("comments").all_fields()
|
post = BlogPost.objects.exclude('comments').all_fields()
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
@ -955,7 +968,7 @@ class BaseQuerySet(object):
|
|||||||
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
|
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
|
||||||
for details.
|
for details.
|
||||||
"""
|
"""
|
||||||
return self._chainable_method("comment", text)
|
return self._chainable_method('comment', text)
|
||||||
|
|
||||||
def explain(self, format=False):
|
def explain(self, format=False):
|
||||||
"""Return an explain plan record for the
|
"""Return an explain plan record for the
|
||||||
@ -964,8 +977,15 @@ class BaseQuerySet(object):
|
|||||||
:param format: format the plan before returning it
|
:param format: format the plan before returning it
|
||||||
"""
|
"""
|
||||||
plan = self._cursor.explain()
|
plan = self._cursor.explain()
|
||||||
|
|
||||||
|
# TODO remove this option completely - it's useless. If somebody
|
||||||
|
# wants to pretty-print the output, they easily can.
|
||||||
if format:
|
if format:
|
||||||
|
msg = ('"format" param of BaseQuerySet.explain has been '
|
||||||
|
'deprecated and will be removed in future versions.')
|
||||||
|
warnings.warn(msg, DeprecationWarning)
|
||||||
plan = pprint.pformat(plan)
|
plan = pprint.pformat(plan)
|
||||||
|
|
||||||
return plan
|
return plan
|
||||||
|
|
||||||
# DEPRECATED. Has no more impact on PyMongo 3+
|
# DEPRECATED. Has no more impact on PyMongo 3+
|
||||||
@ -978,7 +998,7 @@ class BaseQuerySet(object):
|
|||||||
.. deprecated:: Ignored with PyMongo 3+
|
.. deprecated:: Ignored with PyMongo 3+
|
||||||
"""
|
"""
|
||||||
if IS_PYMONGO_3:
|
if IS_PYMONGO_3:
|
||||||
msg = "snapshot is deprecated as it has no impact when using PyMongo 3+."
|
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
queryset._snapshot = enabled
|
queryset._snapshot = enabled
|
||||||
@ -1004,7 +1024,7 @@ class BaseQuerySet(object):
|
|||||||
.. deprecated:: Ignored with PyMongo 3+
|
.. deprecated:: Ignored with PyMongo 3+
|
||||||
"""
|
"""
|
||||||
if IS_PYMONGO_3:
|
if IS_PYMONGO_3:
|
||||||
msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+."
|
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
queryset._slave_okay = enabled
|
queryset._slave_okay = enabled
|
||||||
@ -1066,7 +1086,7 @@ class BaseQuerySet(object):
|
|||||||
|
|
||||||
:param ms: the number of milliseconds before killing the query on the server
|
:param ms: the number of milliseconds before killing the query on the server
|
||||||
"""
|
"""
|
||||||
return self._chainable_method("max_time_ms", ms)
|
return self._chainable_method('max_time_ms', ms)
|
||||||
|
|
||||||
# JSON Helpers
|
# JSON Helpers
|
||||||
|
|
||||||
@ -1149,19 +1169,19 @@ class BaseQuerySet(object):
|
|||||||
|
|
||||||
MapReduceDocument = _import_class('MapReduceDocument')
|
MapReduceDocument = _import_class('MapReduceDocument')
|
||||||
|
|
||||||
if not hasattr(self._collection, "map_reduce"):
|
if not hasattr(self._collection, 'map_reduce'):
|
||||||
raise NotImplementedError("Requires MongoDB >= 1.7.1")
|
raise NotImplementedError('Requires MongoDB >= 1.7.1')
|
||||||
|
|
||||||
map_f_scope = {}
|
map_f_scope = {}
|
||||||
if isinstance(map_f, Code):
|
if isinstance(map_f, Code):
|
||||||
map_f_scope = map_f.scope
|
map_f_scope = map_f.scope
|
||||||
map_f = unicode(map_f)
|
map_f = six.text_type(map_f)
|
||||||
map_f = Code(queryset._sub_js_fields(map_f), map_f_scope)
|
map_f = Code(queryset._sub_js_fields(map_f), map_f_scope)
|
||||||
|
|
||||||
reduce_f_scope = {}
|
reduce_f_scope = {}
|
||||||
if isinstance(reduce_f, Code):
|
if isinstance(reduce_f, Code):
|
||||||
reduce_f_scope = reduce_f.scope
|
reduce_f_scope = reduce_f.scope
|
||||||
reduce_f = unicode(reduce_f)
|
reduce_f = six.text_type(reduce_f)
|
||||||
reduce_f_code = queryset._sub_js_fields(reduce_f)
|
reduce_f_code = queryset._sub_js_fields(reduce_f)
|
||||||
reduce_f = Code(reduce_f_code, reduce_f_scope)
|
reduce_f = Code(reduce_f_code, reduce_f_scope)
|
||||||
|
|
||||||
@ -1171,7 +1191,7 @@ class BaseQuerySet(object):
|
|||||||
finalize_f_scope = {}
|
finalize_f_scope = {}
|
||||||
if isinstance(finalize_f, Code):
|
if isinstance(finalize_f, Code):
|
||||||
finalize_f_scope = finalize_f.scope
|
finalize_f_scope = finalize_f.scope
|
||||||
finalize_f = unicode(finalize_f)
|
finalize_f = six.text_type(finalize_f)
|
||||||
finalize_f_code = queryset._sub_js_fields(finalize_f)
|
finalize_f_code = queryset._sub_js_fields(finalize_f)
|
||||||
finalize_f = Code(finalize_f_code, finalize_f_scope)
|
finalize_f = Code(finalize_f_code, finalize_f_scope)
|
||||||
mr_args['finalize'] = finalize_f
|
mr_args['finalize'] = finalize_f
|
||||||
@ -1187,7 +1207,7 @@ class BaseQuerySet(object):
|
|||||||
else:
|
else:
|
||||||
map_reduce_function = 'map_reduce'
|
map_reduce_function = 'map_reduce'
|
||||||
|
|
||||||
if isinstance(output, basestring):
|
if isinstance(output, six.string_types):
|
||||||
mr_args['out'] = output
|
mr_args['out'] = output
|
||||||
|
|
||||||
elif isinstance(output, dict):
|
elif isinstance(output, dict):
|
||||||
@ -1200,7 +1220,7 @@ class BaseQuerySet(object):
|
|||||||
break
|
break
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise OperationError("actionData not specified for output")
|
raise OperationError('actionData not specified for output')
|
||||||
|
|
||||||
db_alias = output.get('db_alias')
|
db_alias = output.get('db_alias')
|
||||||
remaing_args = ['db', 'sharded', 'nonAtomic']
|
remaing_args = ['db', 'sharded', 'nonAtomic']
|
||||||
@ -1430,7 +1450,7 @@ class BaseQuerySet(object):
|
|||||||
# snapshot is not handled at all by PyMongo 3+
|
# snapshot is not handled at all by PyMongo 3+
|
||||||
# TODO: evaluate similar possibilities using modifiers
|
# TODO: evaluate similar possibilities using modifiers
|
||||||
if self._snapshot:
|
if self._snapshot:
|
||||||
msg = "The snapshot option is not anymore available with PyMongo 3+"
|
msg = 'The snapshot option is not anymore available with PyMongo 3+'
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
cursor_args = {
|
cursor_args = {
|
||||||
'no_cursor_timeout': not self._timeout
|
'no_cursor_timeout': not self._timeout
|
||||||
@ -1442,7 +1462,7 @@ class BaseQuerySet(object):
|
|||||||
if fields_name not in cursor_args:
|
if fields_name not in cursor_args:
|
||||||
cursor_args[fields_name] = {}
|
cursor_args[fields_name] = {}
|
||||||
|
|
||||||
cursor_args[fields_name]['_text_score'] = {'$meta': "textScore"}
|
cursor_args[fields_name]['_text_score'] = {'$meta': 'textScore'}
|
||||||
|
|
||||||
return cursor_args
|
return cursor_args
|
||||||
|
|
||||||
@ -1497,8 +1517,8 @@ class BaseQuerySet(object):
|
|||||||
if self._mongo_query is None:
|
if self._mongo_query is None:
|
||||||
self._mongo_query = self._query_obj.to_query(self._document)
|
self._mongo_query = self._query_obj.to_query(self._document)
|
||||||
if self._class_check and self._initial_query:
|
if self._class_check and self._initial_query:
|
||||||
if "_cls" in self._mongo_query:
|
if '_cls' in self._mongo_query:
|
||||||
self._mongo_query = {"$and": [self._initial_query, self._mongo_query]}
|
self._mongo_query = {'$and': [self._initial_query, self._mongo_query]}
|
||||||
else:
|
else:
|
||||||
self._mongo_query.update(self._initial_query)
|
self._mongo_query.update(self._initial_query)
|
||||||
return self._mongo_query
|
return self._mongo_query
|
||||||
@ -1510,8 +1530,7 @@ class BaseQuerySet(object):
|
|||||||
return self.__dereference
|
return self.__dereference
|
||||||
|
|
||||||
def no_dereference(self):
|
def no_dereference(self):
|
||||||
"""Turn off any dereferencing for the results of this queryset.
|
"""Turn off any dereferencing for the results of this queryset."""
|
||||||
"""
|
|
||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
queryset._auto_dereference = False
|
queryset._auto_dereference = False
|
||||||
return queryset
|
return queryset
|
||||||
@ -1540,7 +1559,7 @@ class BaseQuerySet(object):
|
|||||||
emit(null, 1);
|
emit(null, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
""" % dict(field=field)
|
""" % {'field': field}
|
||||||
reduce_func = """
|
reduce_func = """
|
||||||
function(key, values) {
|
function(key, values) {
|
||||||
var total = 0;
|
var total = 0;
|
||||||
@ -1562,8 +1581,8 @@ class BaseQuerySet(object):
|
|||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
count = sum(frequencies.values())
|
count = sum(frequencies.values())
|
||||||
frequencies = dict([(k, float(v) / count)
|
frequencies = {k: float(v) / count
|
||||||
for k, v in frequencies.items()])
|
for k, v in frequencies.items()}
|
||||||
|
|
||||||
return frequencies
|
return frequencies
|
||||||
|
|
||||||
@ -1615,10 +1634,10 @@ class BaseQuerySet(object):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
total, data, types = self.exec_js(freq_func, field)
|
total, data, types = self.exec_js(freq_func, field)
|
||||||
values = dict([(types.get(k), int(v)) for k, v in data.iteritems()])
|
values = {types.get(k): int(v) for k, v in data.iteritems()}
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
values = dict([(k, float(v) / total) for k, v in values.items()])
|
values = {k: float(v) / total for k, v in values.items()}
|
||||||
|
|
||||||
frequencies = {}
|
frequencies = {}
|
||||||
for k, v in values.iteritems():
|
for k, v in values.iteritems():
|
||||||
@ -1640,14 +1659,14 @@ class BaseQuerySet(object):
|
|||||||
for x in document._subclasses][1:]
|
for x in document._subclasses][1:]
|
||||||
for field in fields:
|
for field in fields:
|
||||||
try:
|
try:
|
||||||
field = ".".join(f.db_field for f in
|
field = '.'.join(f.db_field for f in
|
||||||
document._lookup_field(field.split('.')))
|
document._lookup_field(field.split('.')))
|
||||||
ret.append(field)
|
ret.append(field)
|
||||||
except LookUpError, err:
|
except LookUpError as err:
|
||||||
found = False
|
found = False
|
||||||
for subdoc in subclasses:
|
for subdoc in subclasses:
|
||||||
try:
|
try:
|
||||||
subfield = ".".join(f.db_field for f in
|
subfield = '.'.join(f.db_field for f in
|
||||||
subdoc._lookup_field(field.split('.')))
|
subdoc._lookup_field(field.split('.')))
|
||||||
ret.append(subfield)
|
ret.append(subfield)
|
||||||
found = True
|
found = True
|
||||||
@ -1660,15 +1679,14 @@ class BaseQuerySet(object):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _get_order_by(self, keys):
|
def _get_order_by(self, keys):
|
||||||
"""Creates a list of order by fields
|
"""Creates a list of order by fields"""
|
||||||
"""
|
|
||||||
key_list = []
|
key_list = []
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if not key:
|
if not key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if key == '$text_score':
|
if key == '$text_score':
|
||||||
key_list.append(('_text_score', {'$meta': "textScore"}))
|
key_list.append(('_text_score', {'$meta': 'textScore'}))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
direction = pymongo.ASCENDING
|
direction = pymongo.ASCENDING
|
||||||
@ -1740,7 +1758,7 @@ class BaseQuerySet(object):
|
|||||||
# If we need to coerce types, we need to determine the
|
# If we need to coerce types, we need to determine the
|
||||||
# type of this field and use the corresponding
|
# type of this field and use the corresponding
|
||||||
# .to_python(...)
|
# .to_python(...)
|
||||||
from mongoengine.fields import EmbeddedDocumentField
|
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
|
||||||
|
|
||||||
obj = self._document
|
obj = self._document
|
||||||
for chunk in path.split('.'):
|
for chunk in path.split('.'):
|
||||||
@ -1774,7 +1792,7 @@ class BaseQuerySet(object):
|
|||||||
field_name = match.group(1).split('.')
|
field_name = match.group(1).split('.')
|
||||||
fields = self._document._lookup_field(field_name)
|
fields = self._document._lookup_field(field_name)
|
||||||
# Substitute the correct name for the field into the javascript
|
# Substitute the correct name for the field into the javascript
|
||||||
return ".".join([f.db_field for f in fields])
|
return '.'.join([f.db_field for f in fields])
|
||||||
|
|
||||||
code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
|
code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
|
||||||
code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub,
|
code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub,
|
||||||
@ -1785,21 +1803,21 @@ class BaseQuerySet(object):
|
|||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
method = getattr(queryset._cursor, method_name)
|
method = getattr(queryset._cursor, method_name)
|
||||||
method(val)
|
method(val)
|
||||||
setattr(queryset, "_" + method_name, val)
|
setattr(queryset, '_' + method_name, val)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
# Deprecated
|
# Deprecated
|
||||||
def ensure_index(self, **kwargs):
|
def ensure_index(self, **kwargs):
|
||||||
"""Deprecated use :func:`Document.ensure_index`"""
|
"""Deprecated use :func:`Document.ensure_index`"""
|
||||||
msg = ("Doc.objects()._ensure_index() is deprecated. "
|
msg = ('Doc.objects()._ensure_index() is deprecated. '
|
||||||
"Use Doc.ensure_index() instead.")
|
'Use Doc.ensure_index() instead.')
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
self._document.__class__.ensure_index(**kwargs)
|
self._document.__class__.ensure_index(**kwargs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _ensure_indexes(self):
|
def _ensure_indexes(self):
|
||||||
"""Deprecated use :func:`~Document.ensure_indexes`"""
|
"""Deprecated use :func:`~Document.ensure_indexes`"""
|
||||||
msg = ("Doc.objects()._ensure_indexes() is deprecated. "
|
msg = ('Doc.objects()._ensure_indexes() is deprecated. '
|
||||||
"Use Doc.ensure_indexes() instead.")
|
'Use Doc.ensure_indexes() instead.')
|
||||||
warnings.warn(msg, DeprecationWarning)
|
warnings.warn(msg, DeprecationWarning)
|
||||||
self._document.__class__.ensure_indexes()
|
self._document.__class__.ensure_indexes()
|
||||||
|
@ -67,7 +67,7 @@ class QueryFieldList(object):
|
|||||||
return bool(self.fields)
|
return bool(self.fields)
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
field_list = dict((field, self.value) for field in self.fields)
|
field_list = {field: self.value for field in self.fields}
|
||||||
if self.slice:
|
if self.slice:
|
||||||
field_list.update(self.slice)
|
field_list.update(self.slice)
|
||||||
if self._id is not None:
|
if self._id is not None:
|
||||||
|
@ -53,15 +53,14 @@ class QuerySet(BaseQuerySet):
|
|||||||
return self._len
|
return self._len
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Provides the string representation of the QuerySet
|
"""Provide a string representation of the QuerySet"""
|
||||||
"""
|
|
||||||
if self._iter:
|
if self._iter:
|
||||||
return '.. queryset mid-iteration ..'
|
return '.. queryset mid-iteration ..'
|
||||||
|
|
||||||
self._populate_cache()
|
self._populate_cache()
|
||||||
data = self._result_cache[:REPR_OUTPUT_SIZE + 1]
|
data = self._result_cache[:REPR_OUTPUT_SIZE + 1]
|
||||||
if len(data) > REPR_OUTPUT_SIZE:
|
if len(data) > REPR_OUTPUT_SIZE:
|
||||||
data[-1] = "...(remaining elements truncated)..."
|
data[-1] = '...(remaining elements truncated)...'
|
||||||
return repr(data)
|
return repr(data)
|
||||||
|
|
||||||
def _iter_results(self):
|
def _iter_results(self):
|
||||||
@ -113,7 +112,7 @@ class QuerySet(BaseQuerySet):
|
|||||||
# Pull in ITER_CHUNK_SIZE docs from the database and store them in
|
# Pull in ITER_CHUNK_SIZE docs from the database and store them in
|
||||||
# the result cache.
|
# the result cache.
|
||||||
try:
|
try:
|
||||||
for i in xrange(ITER_CHUNK_SIZE):
|
for _ in xrange(ITER_CHUNK_SIZE):
|
||||||
self._result_cache.append(self.next())
|
self._result_cache.append(self.next())
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# Getting this exception means there are no more docs in the
|
# Getting this exception means there are no more docs in the
|
||||||
@ -142,7 +141,7 @@ class QuerySet(BaseQuerySet):
|
|||||||
.. versionadded:: 0.8.3 Convert to non caching queryset
|
.. versionadded:: 0.8.3 Convert to non caching queryset
|
||||||
"""
|
"""
|
||||||
if self._result_cache is not None:
|
if self._result_cache is not None:
|
||||||
raise OperationError("QuerySet already cached")
|
raise OperationError('QuerySet already cached')
|
||||||
return self.clone_into(QuerySetNoCache(self._document, self._collection))
|
return self.clone_into(QuerySetNoCache(self._document, self._collection))
|
||||||
|
|
||||||
|
|
||||||
@ -165,13 +164,14 @@ class QuerySetNoCache(BaseQuerySet):
|
|||||||
return '.. queryset mid-iteration ..'
|
return '.. queryset mid-iteration ..'
|
||||||
|
|
||||||
data = []
|
data = []
|
||||||
for i in xrange(REPR_OUTPUT_SIZE + 1):
|
for _ in xrange(REPR_OUTPUT_SIZE + 1):
|
||||||
try:
|
try:
|
||||||
data.append(self.next())
|
data.append(self.next())
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
|
||||||
if len(data) > REPR_OUTPUT_SIZE:
|
if len(data) > REPR_OUTPUT_SIZE:
|
||||||
data[-1] = "...(remaining elements truncated)..."
|
data[-1] = '...(remaining elements truncated)...'
|
||||||
|
|
||||||
self.rewind()
|
self.rewind()
|
||||||
return repr(data)
|
return repr(data)
|
||||||
|
@ -3,8 +3,9 @@ from collections import defaultdict
|
|||||||
from bson import ObjectId, SON
|
from bson import ObjectId, SON
|
||||||
from bson.dbref import DBRef
|
from bson.dbref import DBRef
|
||||||
import pymongo
|
import pymongo
|
||||||
|
import six
|
||||||
|
|
||||||
from mongoengine.base.fields import UPDATE_OPERATORS
|
from mongoengine.base import UPDATE_OPERATORS
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.connection import get_connection
|
from mongoengine.connection import get_connection
|
||||||
from mongoengine.errors import InvalidQueryError
|
from mongoengine.errors import InvalidQueryError
|
||||||
@ -29,12 +30,11 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
|
|||||||
|
|
||||||
# TODO make this less complex
|
# TODO make this less complex
|
||||||
def query(_doc_cls=None, **kwargs):
|
def query(_doc_cls=None, **kwargs):
|
||||||
"""Transform a query from Django-style format to Mongo format.
|
"""Transform a query from Django-style format to Mongo format."""
|
||||||
"""
|
|
||||||
mongo_query = {}
|
mongo_query = {}
|
||||||
merge_query = defaultdict(list)
|
merge_query = defaultdict(list)
|
||||||
for key, value in sorted(kwargs.items()):
|
for key, value in sorted(kwargs.items()):
|
||||||
if key == "__raw__":
|
if key == '__raw__':
|
||||||
mongo_query.update(value)
|
mongo_query.update(value)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
op = parts.pop()
|
op = parts.pop()
|
||||||
|
|
||||||
# Allow to escape operator-like field name by __
|
# Allow to escape operator-like field name by __
|
||||||
if len(parts) > 1 and parts[-1] == "":
|
if len(parts) > 1 and parts[-1] == '':
|
||||||
parts.pop()
|
parts.pop()
|
||||||
|
|
||||||
negate = False
|
negate = False
|
||||||
@ -59,7 +59,7 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
# Switch field names to proper names [set in Field(name='foo')]
|
# Switch field names to proper names [set in Field(name='foo')]
|
||||||
try:
|
try:
|
||||||
fields = _doc_cls._lookup_field(parts)
|
fields = _doc_cls._lookup_field(parts)
|
||||||
except Exception, e:
|
except Exception as e:
|
||||||
raise InvalidQueryError(e)
|
raise InvalidQueryError(e)
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
cleaned_fields = []
|
cleaned_fields = []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
append_field = True
|
append_field = True
|
||||||
if isinstance(field, basestring):
|
if isinstance(field, six.string_types):
|
||||||
parts.append(field)
|
parts.append(field)
|
||||||
append_field = False
|
append_field = False
|
||||||
# is last and CachedReferenceField
|
# is last and CachedReferenceField
|
||||||
@ -87,9 +87,9 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
|
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
|
||||||
singular_ops += STRING_OPERATORS
|
singular_ops += STRING_OPERATORS
|
||||||
if op in singular_ops:
|
if op in singular_ops:
|
||||||
if isinstance(field, basestring):
|
if isinstance(field, six.string_types):
|
||||||
if (op in STRING_OPERATORS and
|
if (op in STRING_OPERATORS and
|
||||||
isinstance(value, basestring)):
|
isinstance(value, six.string_types)):
|
||||||
StringField = _import_class('StringField')
|
StringField = _import_class('StringField')
|
||||||
value = StringField.prepare_query_value(op, value)
|
value = StringField.prepare_query_value(op, value)
|
||||||
else:
|
else:
|
||||||
@ -129,10 +129,10 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
value = query(field.field.document_type, **value)
|
value = query(field.field.document_type, **value)
|
||||||
else:
|
else:
|
||||||
value = field.prepare_query_value(op, value)
|
value = field.prepare_query_value(op, value)
|
||||||
value = {"$elemMatch": value}
|
value = {'$elemMatch': value}
|
||||||
elif op in CUSTOM_OPERATORS:
|
elif op in CUSTOM_OPERATORS:
|
||||||
NotImplementedError("Custom method '%s' has not "
|
NotImplementedError('Custom method "%s" has not '
|
||||||
"been implemented" % op)
|
'been implemented' % op)
|
||||||
elif op not in STRING_OPERATORS:
|
elif op not in STRING_OPERATORS:
|
||||||
value = {'$' + op: value}
|
value = {'$' + op: value}
|
||||||
|
|
||||||
@ -197,15 +197,16 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def update(_doc_cls=None, **update):
|
def update(_doc_cls=None, **update):
|
||||||
"""Transform an update spec from Django-style format to Mongo format.
|
"""Transform an update spec from Django-style format to Mongo
|
||||||
|
format.
|
||||||
"""
|
"""
|
||||||
mongo_update = {}
|
mongo_update = {}
|
||||||
for key, value in update.items():
|
for key, value in update.items():
|
||||||
if key == "__raw__":
|
if key == '__raw__':
|
||||||
mongo_update.update(value)
|
mongo_update.update(value)
|
||||||
continue
|
continue
|
||||||
parts = key.split('__')
|
parts = key.split('__')
|
||||||
# if there is no operator, default to "set"
|
# if there is no operator, default to 'set'
|
||||||
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
|
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
|
||||||
parts.insert(0, 'set')
|
parts.insert(0, 'set')
|
||||||
# Check for an operator and transform to mongo-style if there is
|
# Check for an operator and transform to mongo-style if there is
|
||||||
@ -224,21 +225,21 @@ def update(_doc_cls=None, **update):
|
|||||||
elif op == 'add_to_set':
|
elif op == 'add_to_set':
|
||||||
op = 'addToSet'
|
op = 'addToSet'
|
||||||
elif op == 'set_on_insert':
|
elif op == 'set_on_insert':
|
||||||
op = "setOnInsert"
|
op = 'setOnInsert'
|
||||||
|
|
||||||
match = None
|
match = None
|
||||||
if parts[-1] in COMPARISON_OPERATORS:
|
if parts[-1] in COMPARISON_OPERATORS:
|
||||||
match = parts.pop()
|
match = parts.pop()
|
||||||
|
|
||||||
# Allow to escape operator-like field name by __
|
# Allow to escape operator-like field name by __
|
||||||
if len(parts) > 1 and parts[-1] == "":
|
if len(parts) > 1 and parts[-1] == '':
|
||||||
parts.pop()
|
parts.pop()
|
||||||
|
|
||||||
if _doc_cls:
|
if _doc_cls:
|
||||||
# Switch field names to proper names [set in Field(name='foo')]
|
# Switch field names to proper names [set in Field(name='foo')]
|
||||||
try:
|
try:
|
||||||
fields = _doc_cls._lookup_field(parts)
|
fields = _doc_cls._lookup_field(parts)
|
||||||
except Exception, e:
|
except Exception as e:
|
||||||
raise InvalidQueryError(e)
|
raise InvalidQueryError(e)
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
@ -246,7 +247,7 @@ def update(_doc_cls=None, **update):
|
|||||||
appended_sub_field = False
|
appended_sub_field = False
|
||||||
for field in fields:
|
for field in fields:
|
||||||
append_field = True
|
append_field = True
|
||||||
if isinstance(field, basestring):
|
if isinstance(field, six.string_types):
|
||||||
# Convert the S operator to $
|
# Convert the S operator to $
|
||||||
if field == 'S':
|
if field == 'S':
|
||||||
field = '$'
|
field = '$'
|
||||||
@ -267,7 +268,7 @@ def update(_doc_cls=None, **update):
|
|||||||
else:
|
else:
|
||||||
field = cleaned_fields[-1]
|
field = cleaned_fields[-1]
|
||||||
|
|
||||||
GeoJsonBaseField = _import_class("GeoJsonBaseField")
|
GeoJsonBaseField = _import_class('GeoJsonBaseField')
|
||||||
if isinstance(field, GeoJsonBaseField):
|
if isinstance(field, GeoJsonBaseField):
|
||||||
value = field.to_mongo(value)
|
value = field.to_mongo(value)
|
||||||
|
|
||||||
@ -281,7 +282,7 @@ def update(_doc_cls=None, **update):
|
|||||||
value = [field.prepare_query_value(op, v) for v in value]
|
value = [field.prepare_query_value(op, v) for v in value]
|
||||||
elif field.required or value is not None:
|
elif field.required or value is not None:
|
||||||
value = field.prepare_query_value(op, value)
|
value = field.prepare_query_value(op, value)
|
||||||
elif op == "unset":
|
elif op == 'unset':
|
||||||
value = 1
|
value = 1
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
@ -291,16 +292,16 @@ def update(_doc_cls=None, **update):
|
|||||||
key = '.'.join(parts)
|
key = '.'.join(parts)
|
||||||
|
|
||||||
if not op:
|
if not op:
|
||||||
raise InvalidQueryError("Updates must supply an operation "
|
raise InvalidQueryError('Updates must supply an operation '
|
||||||
"eg: set__FIELD=value")
|
'eg: set__FIELD=value')
|
||||||
|
|
||||||
if 'pull' in op and '.' in key:
|
if 'pull' in op and '.' in key:
|
||||||
# Dot operators don't work on pull operations
|
# Dot operators don't work on pull operations
|
||||||
# unless they point to a list field
|
# unless they point to a list field
|
||||||
# Otherwise it uses nested dict syntax
|
# Otherwise it uses nested dict syntax
|
||||||
if op == 'pullAll':
|
if op == 'pullAll':
|
||||||
raise InvalidQueryError("pullAll operations only support "
|
raise InvalidQueryError('pullAll operations only support '
|
||||||
"a single field depth")
|
'a single field depth')
|
||||||
|
|
||||||
# Look for the last list field and use dot notation until there
|
# Look for the last list field and use dot notation until there
|
||||||
field_classes = [c.__class__ for c in cleaned_fields]
|
field_classes = [c.__class__ for c in cleaned_fields]
|
||||||
@ -311,7 +312,7 @@ def update(_doc_cls=None, **update):
|
|||||||
# Then process as normal
|
# Then process as normal
|
||||||
last_listField = len(
|
last_listField = len(
|
||||||
cleaned_fields) - field_classes.index(ListField)
|
cleaned_fields) - field_classes.index(ListField)
|
||||||
key = ".".join(parts[:last_listField])
|
key = '.'.join(parts[:last_listField])
|
||||||
parts = parts[last_listField:]
|
parts = parts[last_listField:]
|
||||||
parts.insert(0, key)
|
parts.insert(0, key)
|
||||||
|
|
||||||
@ -319,7 +320,7 @@ def update(_doc_cls=None, **update):
|
|||||||
for key in parts:
|
for key in parts:
|
||||||
value = {key: value}
|
value = {key: value}
|
||||||
elif op == 'addToSet' and isinstance(value, list):
|
elif op == 'addToSet' and isinstance(value, list):
|
||||||
value = {key: {"$each": value}}
|
value = {key: {'$each': value}}
|
||||||
else:
|
else:
|
||||||
value = {key: value}
|
value = {key: value}
|
||||||
key = '$' + op
|
key = '$' + op
|
||||||
@ -333,78 +334,82 @@ def update(_doc_cls=None, **update):
|
|||||||
|
|
||||||
|
|
||||||
def _geo_operator(field, op, value):
|
def _geo_operator(field, op, value):
|
||||||
"""Helper to return the query for a given geo query"""
|
"""Helper to return the query for a given geo query."""
|
||||||
if op == "max_distance":
|
if op == 'max_distance':
|
||||||
value = {'$maxDistance': value}
|
value = {'$maxDistance': value}
|
||||||
elif op == "min_distance":
|
elif op == 'min_distance':
|
||||||
value = {'$minDistance': value}
|
value = {'$minDistance': value}
|
||||||
elif field._geo_index == pymongo.GEO2D:
|
elif field._geo_index == pymongo.GEO2D:
|
||||||
if op == "within_distance":
|
if op == 'within_distance':
|
||||||
value = {'$within': {'$center': value}}
|
value = {'$within': {'$center': value}}
|
||||||
elif op == "within_spherical_distance":
|
elif op == 'within_spherical_distance':
|
||||||
value = {'$within': {'$centerSphere': value}}
|
value = {'$within': {'$centerSphere': value}}
|
||||||
elif op == "within_polygon":
|
elif op == 'within_polygon':
|
||||||
value = {'$within': {'$polygon': value}}
|
value = {'$within': {'$polygon': value}}
|
||||||
elif op == "near":
|
elif op == 'near':
|
||||||
value = {'$near': value}
|
value = {'$near': value}
|
||||||
elif op == "near_sphere":
|
elif op == 'near_sphere':
|
||||||
value = {'$nearSphere': value}
|
value = {'$nearSphere': value}
|
||||||
elif op == 'within_box':
|
elif op == 'within_box':
|
||||||
value = {'$within': {'$box': value}}
|
value = {'$within': {'$box': value}}
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Geo method '%s' has not "
|
raise NotImplementedError('Geo method "%s" has not been '
|
||||||
"been implemented for a GeoPointField" % op)
|
'implemented for a GeoPointField' % op)
|
||||||
else:
|
else:
|
||||||
if op == "geo_within":
|
if op == 'geo_within':
|
||||||
value = {"$geoWithin": _infer_geometry(value)}
|
value = {'$geoWithin': _infer_geometry(value)}
|
||||||
elif op == "geo_within_box":
|
elif op == 'geo_within_box':
|
||||||
value = {"$geoWithin": {"$box": value}}
|
value = {'$geoWithin': {'$box': value}}
|
||||||
elif op == "geo_within_polygon":
|
elif op == 'geo_within_polygon':
|
||||||
value = {"$geoWithin": {"$polygon": value}}
|
value = {'$geoWithin': {'$polygon': value}}
|
||||||
elif op == "geo_within_center":
|
elif op == 'geo_within_center':
|
||||||
value = {"$geoWithin": {"$center": value}}
|
value = {'$geoWithin': {'$center': value}}
|
||||||
elif op == "geo_within_sphere":
|
elif op == 'geo_within_sphere':
|
||||||
value = {"$geoWithin": {"$centerSphere": value}}
|
value = {'$geoWithin': {'$centerSphere': value}}
|
||||||
elif op == "geo_intersects":
|
elif op == 'geo_intersects':
|
||||||
value = {"$geoIntersects": _infer_geometry(value)}
|
value = {'$geoIntersects': _infer_geometry(value)}
|
||||||
elif op == "near":
|
elif op == 'near':
|
||||||
value = {'$near': _infer_geometry(value)}
|
value = {'$near': _infer_geometry(value)}
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Geo method '%s' has not "
|
raise NotImplementedError(
|
||||||
"been implemented for a %s " % (op, field._name))
|
'Geo method "%s" has not been implemented for a %s '
|
||||||
|
% (op, field._name)
|
||||||
|
)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _infer_geometry(value):
|
def _infer_geometry(value):
|
||||||
"""Helper method that tries to infer the $geometry shape for a given value"""
|
"""Helper method that tries to infer the $geometry shape for a
|
||||||
|
given value.
|
||||||
|
"""
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
if "$geometry" in value:
|
if '$geometry' in value:
|
||||||
return value
|
return value
|
||||||
elif 'coordinates' in value and 'type' in value:
|
elif 'coordinates' in value and 'type' in value:
|
||||||
return {"$geometry": value}
|
return {'$geometry': value}
|
||||||
raise InvalidQueryError("Invalid $geometry dictionary should have "
|
raise InvalidQueryError('Invalid $geometry dictionary should have '
|
||||||
"type and coordinates keys")
|
'type and coordinates keys')
|
||||||
elif isinstance(value, (list, set)):
|
elif isinstance(value, (list, set)):
|
||||||
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
|
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
|
||||||
# TODO: should both TypeError and IndexError be alike interpreted?
|
# TODO: should both TypeError and IndexError be alike interpreted?
|
||||||
|
|
||||||
try:
|
try:
|
||||||
value[0][0][0]
|
value[0][0][0]
|
||||||
return {"$geometry": {"type": "Polygon", "coordinates": value}}
|
return {'$geometry': {'type': 'Polygon', 'coordinates': value}}
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
value[0][0]
|
value[0][0]
|
||||||
return {"$geometry": {"type": "LineString", "coordinates": value}}
|
return {'$geometry': {'type': 'LineString', 'coordinates': value}}
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
value[0]
|
value[0]
|
||||||
return {"$geometry": {"type": "Point", "coordinates": value}}
|
return {'$geometry': {'type': 'Point', 'coordinates': value}}
|
||||||
except (TypeError, IndexError):
|
except (TypeError, IndexError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary "
|
raise InvalidQueryError('Invalid $geometry data. Can be either a '
|
||||||
"or (nested) lists of coordinate(s)")
|
'dictionary or (nested) lists of coordinate(s)')
|
||||||
|
@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor):
|
|||||||
self.document = document
|
self.document = document
|
||||||
|
|
||||||
def visit_combination(self, combination):
|
def visit_combination(self, combination):
|
||||||
operator = "$and"
|
operator = '$and'
|
||||||
if combination.operation == combination.OR:
|
if combination.operation == combination.OR:
|
||||||
operator = "$or"
|
operator = '$or'
|
||||||
return {operator: combination.children}
|
return {operator: combination.children}
|
||||||
|
|
||||||
def visit_query(self, query):
|
def visit_query(self, query):
|
||||||
@ -79,8 +79,7 @@ class QueryCompilerVisitor(QNodeVisitor):
|
|||||||
|
|
||||||
|
|
||||||
class QNode(object):
|
class QNode(object):
|
||||||
"""Base class for nodes in query trees.
|
"""Base class for nodes in query trees."""
|
||||||
"""
|
|
||||||
|
|
||||||
AND = 0
|
AND = 0
|
||||||
OR = 1
|
OR = 1
|
||||||
@ -94,7 +93,8 @@ class QNode(object):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _combine(self, other, operation):
|
def _combine(self, other, operation):
|
||||||
"""Combine this node with another node into a QCombination object.
|
"""Combine this node with another node into a QCombination
|
||||||
|
object.
|
||||||
"""
|
"""
|
||||||
if getattr(other, 'empty', True):
|
if getattr(other, 'empty', True):
|
||||||
return self
|
return self
|
||||||
@ -116,8 +116,8 @@ class QNode(object):
|
|||||||
|
|
||||||
|
|
||||||
class QCombination(QNode):
|
class QCombination(QNode):
|
||||||
"""Represents the combination of several conditions by a given logical
|
"""Represents the combination of several conditions by a given
|
||||||
operator.
|
logical operator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, operation, children):
|
def __init__(self, operation, children):
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
__all__ = ('pre_init', 'post_init', 'pre_save', 'pre_save_post_validation',
|
||||||
|
'post_save', 'pre_delete', 'post_delete')
|
||||||
__all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation',
|
|
||||||
'post_save', 'pre_delete', 'post_delete']
|
|
||||||
|
|
||||||
signals_available = False
|
signals_available = False
|
||||||
try:
|
try:
|
||||||
@ -34,6 +32,7 @@ except ImportError:
|
|||||||
temporarily_connected_to = _fail
|
temporarily_connected_to = _fail
|
||||||
del _fail
|
del _fail
|
||||||
|
|
||||||
|
|
||||||
# the namespace for code signals. If you are not mongoengine code, do
|
# the namespace for code signals. If you are not mongoengine code, do
|
||||||
# not put signals in here. Create your own namespace instead.
|
# not put signals in here. Create your own namespace instead.
|
||||||
_signals = Namespace()
|
_signals = Namespace()
|
||||||
|
12
setup.cfg
12
setup.cfg
@ -1,13 +1,11 @@
|
|||||||
[nosetests]
|
[nosetests]
|
||||||
verbosity = 2
|
verbosity=2
|
||||||
detailed-errors = 1
|
detailed-errors=1
|
||||||
cover-erase = 1
|
tests=tests
|
||||||
cover-branches = 1
|
cover-package=mongoengine
|
||||||
cover-package = mongoengine
|
|
||||||
tests = tests
|
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
ignore=E501,F401,F403,F405,I201
|
ignore=E501,F401,F403,F405,I201
|
||||||
exclude=build,dist,docs,venv,.tox,.eggs,tests
|
exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests
|
||||||
max-complexity=45
|
max-complexity=45
|
||||||
application-import-names=mongoengine,tests
|
application-import-names=mongoengine,tests
|
||||||
|
25
setup.py
25
setup.py
@ -21,8 +21,9 @@ except Exception:
|
|||||||
|
|
||||||
|
|
||||||
def get_version(version_tuple):
|
def get_version(version_tuple):
|
||||||
if not isinstance(version_tuple[-1], int):
|
"""Return the version tuple as a string, e.g. for (0, 10, 7),
|
||||||
return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1]
|
return '0.10.7'.
|
||||||
|
"""
|
||||||
return '.'.join(map(str, version_tuple))
|
return '.'.join(map(str, version_tuple))
|
||||||
|
|
||||||
|
|
||||||
@ -41,31 +42,29 @@ CLASSIFIERS = [
|
|||||||
'Operating System :: OS Independent',
|
'Operating System :: OS Independent',
|
||||||
'Programming Language :: Python',
|
'Programming Language :: Python',
|
||||||
"Programming Language :: Python :: 2",
|
"Programming Language :: Python :: 2",
|
||||||
"Programming Language :: Python :: 2.6",
|
|
||||||
"Programming Language :: Python :: 2.7",
|
"Programming Language :: Python :: 2.7",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.2",
|
|
||||||
"Programming Language :: Python :: 3.3",
|
"Programming Language :: Python :: 3.3",
|
||||||
"Programming Language :: Python :: 3.4",
|
"Programming Language :: Python :: 3.4",
|
||||||
|
"Programming Language :: Python :: 3.5",
|
||||||
"Programming Language :: Python :: Implementation :: CPython",
|
"Programming Language :: Python :: Implementation :: CPython",
|
||||||
"Programming Language :: Python :: Implementation :: PyPy",
|
"Programming Language :: Python :: Implementation :: PyPy",
|
||||||
'Topic :: Database',
|
'Topic :: Database',
|
||||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||||
]
|
]
|
||||||
|
|
||||||
extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])}
|
extra_opts = {
|
||||||
|
'packages': find_packages(exclude=['tests', 'tests.*']),
|
||||||
|
'tests_require': ['nose', 'coverage==4.2', 'blinker', 'Pillow>=2.0.0']
|
||||||
|
}
|
||||||
if sys.version_info[0] == 3:
|
if sys.version_info[0] == 3:
|
||||||
extra_opts['use_2to3'] = True
|
extra_opts['use_2to3'] = True
|
||||||
extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0']
|
if 'test' in sys.argv or 'nosetests' in sys.argv:
|
||||||
if "test" in sys.argv or "nosetests" in sys.argv:
|
|
||||||
extra_opts['packages'] = find_packages()
|
extra_opts['packages'] = find_packages()
|
||||||
extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]}
|
extra_opts['package_data'] = {
|
||||||
|
'tests': ['fields/mongoengine.png', 'fields/mongodb_leaf.png']}
|
||||||
else:
|
else:
|
||||||
# coverage 4 does not support Python 3.2 anymore
|
extra_opts['tests_require'] += ['python-dateutil']
|
||||||
extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil']
|
|
||||||
|
|
||||||
if sys.version_info[0] == 2 and sys.version_info[1] == 6:
|
|
||||||
extra_opts['tests_require'].append('unittest2')
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='mongoengine',
|
name='mongoengine',
|
||||||
|
@ -2,4 +2,3 @@ from all_warnings import AllWarnings
|
|||||||
from document import *
|
from document import *
|
||||||
from queryset import *
|
from queryset import *
|
||||||
from fields import *
|
from fields import *
|
||||||
from migration import *
|
|
||||||
|
@ -3,8 +3,6 @@ This test has been put into a module. This is because it tests warnings that
|
|||||||
only get triggered on first hit. This way we can ensure its imported into the
|
only get triggered on first hit. This way we can ensure its imported into the
|
||||||
top level and called first by the test suite.
|
top level and called first by the test suite.
|
||||||
"""
|
"""
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from class_methods import *
|
from class_methods import *
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from bson import SON
|
from bson import SON
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import get_db
|
from mongoengine.connection import get_db
|
||||||
@ -143,11 +141,9 @@ class DynamicTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_three_level_complex_data_lookups(self):
|
def test_three_level_complex_data_lookups(self):
|
||||||
"""Ensure you can query three level document dynamic fields"""
|
"""Ensure you can query three level document dynamic fields"""
|
||||||
p = self.Person()
|
p = self.Person.objects.create(
|
||||||
p.misc = {'hello': {'hello2': 'world'}}
|
misc={'hello': {'hello2': 'world'}}
|
||||||
p.save()
|
)
|
||||||
# from pprint import pprint as pp; import pdb; pdb.set_trace();
|
|
||||||
print self.Person.objects(misc__hello__hello2='world')
|
|
||||||
self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count())
|
self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count())
|
||||||
|
|
||||||
def test_complex_embedded_document_validation(self):
|
def test_complex_embedded_document_validation(self):
|
||||||
|
@ -556,8 +556,8 @@ class IndexesTest(unittest.TestCase):
|
|||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
for i in xrange(0, 10):
|
for i in range(0, 10):
|
||||||
tags = [("tag %i" % n) for n in xrange(0, i % 2)]
|
tags = [("tag %i" % n) for n in range(0, i % 2)]
|
||||||
BlogPost(tags=tags).save()
|
BlogPost(tags=tags).save()
|
||||||
|
|
||||||
self.assertEqual(BlogPost.objects.count(), 10)
|
self.assertEqual(BlogPost.objects.count(), 10)
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -253,19 +251,17 @@ class InheritanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(classes, [Human])
|
self.assertEqual(classes, [Human])
|
||||||
|
|
||||||
def test_allow_inheritance(self):
|
def test_allow_inheritance(self):
|
||||||
"""Ensure that inheritance may be disabled on simple classes and that
|
"""Ensure that inheritance is disabled by default on simple
|
||||||
_cls and _subclasses will not be used.
|
classes and that _cls will not be used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Animal(Document):
|
class Animal(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
def create_dog_class():
|
# can't inherit because Animal didn't explicitly allow inheritance
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
class Dog(Animal):
|
class Dog(Animal):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.assertRaises(ValueError, create_dog_class)
|
|
||||||
|
|
||||||
# Check that _cls etc aren't present on simple documents
|
# Check that _cls etc aren't present on simple documents
|
||||||
dog = Animal(name='dog').save()
|
dog = Animal(name='dog').save()
|
||||||
self.assertEqual(dog.to_mongo().keys(), ['_id', 'name'])
|
self.assertEqual(dog.to_mongo().keys(), ['_id', 'name'])
|
||||||
@ -275,17 +271,15 @@ class InheritanceTest(unittest.TestCase):
|
|||||||
self.assertFalse('_cls' in obj)
|
self.assertFalse('_cls' in obj)
|
||||||
|
|
||||||
def test_cant_turn_off_inheritance_on_subclass(self):
|
def test_cant_turn_off_inheritance_on_subclass(self):
|
||||||
"""Ensure if inheritance is on in a subclass you cant turn it off
|
"""Ensure if inheritance is on in a subclass you cant turn it off.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Animal(Document):
|
class Animal(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
meta = {'allow_inheritance': True}
|
meta = {'allow_inheritance': True}
|
||||||
|
|
||||||
def create_mammal_class():
|
with self.assertRaises(ValueError):
|
||||||
class Mammal(Animal):
|
class Mammal(Animal):
|
||||||
meta = {'allow_inheritance': False}
|
meta = {'allow_inheritance': False}
|
||||||
self.assertRaises(ValueError, create_mammal_class)
|
|
||||||
|
|
||||||
def test_allow_inheritance_abstract_document(self):
|
def test_allow_inheritance_abstract_document(self):
|
||||||
"""Ensure that abstract documents can set inheritance rules and that
|
"""Ensure that abstract documents can set inheritance rules and that
|
||||||
@ -298,10 +292,9 @@ class InheritanceTest(unittest.TestCase):
|
|||||||
class Animal(FinalDocument):
|
class Animal(FinalDocument):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
def create_mammal_class():
|
with self.assertRaises(ValueError):
|
||||||
class Mammal(Animal):
|
class Mammal(Animal):
|
||||||
pass
|
pass
|
||||||
self.assertRaises(ValueError, create_mammal_class)
|
|
||||||
|
|
||||||
# Check that _cls isn't present in simple documents
|
# Check that _cls isn't present in simple documents
|
||||||
doc = Animal(name='dog')
|
doc = Animal(name='dog')
|
||||||
@ -360,29 +353,26 @@ class InheritanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(berlin.pk, berlin.auto_id_0)
|
self.assertEqual(berlin.pk, berlin.auto_id_0)
|
||||||
|
|
||||||
def test_abstract_document_creation_does_not_fail(self):
|
def test_abstract_document_creation_does_not_fail(self):
|
||||||
|
|
||||||
class City(Document):
|
class City(Document):
|
||||||
continent = StringField()
|
continent = StringField()
|
||||||
meta = {'abstract': True,
|
meta = {'abstract': True,
|
||||||
'allow_inheritance': False}
|
'allow_inheritance': False}
|
||||||
|
|
||||||
bkk = City(continent='asia')
|
bkk = City(continent='asia')
|
||||||
self.assertEqual(None, bkk.pk)
|
self.assertEqual(None, bkk.pk)
|
||||||
# TODO: expected error? Shouldn't we create a new error type?
|
# TODO: expected error? Shouldn't we create a new error type?
|
||||||
self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1))
|
with self.assertRaises(KeyError):
|
||||||
|
setattr(bkk, 'pk', 1)
|
||||||
|
|
||||||
def test_allow_inheritance_embedded_document(self):
|
def test_allow_inheritance_embedded_document(self):
|
||||||
"""Ensure embedded documents respect inheritance
|
"""Ensure embedded documents respect inheritance."""
|
||||||
"""
|
|
||||||
|
|
||||||
class Comment(EmbeddedDocument):
|
class Comment(EmbeddedDocument):
|
||||||
content = StringField()
|
content = StringField()
|
||||||
|
|
||||||
def create_special_comment():
|
with self.assertRaises(ValueError):
|
||||||
class SpecialComment(Comment):
|
class SpecialComment(Comment):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.assertRaises(ValueError, create_special_comment)
|
|
||||||
|
|
||||||
doc = Comment(content='test')
|
doc = Comment(content='test')
|
||||||
self.assertFalse('_cls' in doc.to_mongo())
|
self.assertFalse('_cls' in doc.to_mongo())
|
||||||
|
|
||||||
@ -454,11 +444,11 @@ class InheritanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(Guppy._get_collection_name(), 'fish')
|
self.assertEqual(Guppy._get_collection_name(), 'fish')
|
||||||
self.assertEqual(Human._get_collection_name(), 'human')
|
self.assertEqual(Human._get_collection_name(), 'human')
|
||||||
|
|
||||||
def create_bad_abstract():
|
# ensure that a subclass of a non-abstract class can't be abstract
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
class EvilHuman(Human):
|
class EvilHuman(Human):
|
||||||
evil = BooleanField(default=True)
|
evil = BooleanField(default=True)
|
||||||
meta = {'abstract': True}
|
meta = {'abstract': True}
|
||||||
self.assertRaises(ValueError, create_bad_abstract)
|
|
||||||
|
|
||||||
def test_abstract_embedded_documents(self):
|
def test_abstract_embedded_documents(self):
|
||||||
# 789: EmbeddedDocument shouldn't inherit abstract
|
# 789: EmbeddedDocument shouldn't inherit abstract
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import bson
|
import bson
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@ -16,12 +13,12 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
|
|||||||
PickleDynamicEmbedded, PickleDynamicTest)
|
PickleDynamicEmbedded, PickleDynamicTest)
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
from mongoengine.base import get_document, _document_registry
|
||||||
|
from mongoengine.connection import get_db
|
||||||
from mongoengine.errors import (NotRegistered, InvalidDocumentError,
|
from mongoengine.errors import (NotRegistered, InvalidDocumentError,
|
||||||
InvalidQueryError, NotUniqueError,
|
InvalidQueryError, NotUniqueError,
|
||||||
FieldDoesNotExist, SaveConditionError)
|
FieldDoesNotExist, SaveConditionError)
|
||||||
from mongoengine.queryset import NULLIFY, Q
|
from mongoengine.queryset import NULLIFY, Q
|
||||||
from mongoengine.connection import get_db
|
|
||||||
from mongoengine.base import get_document
|
|
||||||
from mongoengine.context_managers import switch_db, query_counter
|
from mongoengine.context_managers import switch_db, query_counter
|
||||||
from mongoengine import signals
|
from mongoengine import signals
|
||||||
|
|
||||||
@ -102,21 +99,18 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(options['size'], 4096)
|
self.assertEqual(options['size'], 4096)
|
||||||
|
|
||||||
# Check that the document cannot be redefined with different options
|
# Check that the document cannot be redefined with different options
|
||||||
def recreate_log_document():
|
class Log(Document):
|
||||||
class Log(Document):
|
date = DateTimeField(default=datetime.now)
|
||||||
date = DateTimeField(default=datetime.now)
|
meta = {
|
||||||
meta = {
|
'max_documents': 11,
|
||||||
'max_documents': 11,
|
}
|
||||||
}
|
|
||||||
# Create the collection by accessing Document.objects
|
|
||||||
Log.objects
|
|
||||||
self.assertRaises(InvalidCollectionError, recreate_log_document)
|
|
||||||
|
|
||||||
Log.drop_collection()
|
# Accessing Document.objects creates the collection
|
||||||
|
with self.assertRaises(InvalidCollectionError):
|
||||||
|
Log.objects
|
||||||
|
|
||||||
def test_capped_collection_default(self):
|
def test_capped_collection_default(self):
|
||||||
"""Ensure that capped collections defaults work properly.
|
"""Ensure that capped collections defaults work properly."""
|
||||||
"""
|
|
||||||
class Log(Document):
|
class Log(Document):
|
||||||
date = DateTimeField(default=datetime.now)
|
date = DateTimeField(default=datetime.now)
|
||||||
meta = {
|
meta = {
|
||||||
@ -134,16 +128,14 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(options['size'], 10 * 2**20)
|
self.assertEqual(options['size'], 10 * 2**20)
|
||||||
|
|
||||||
# Check that the document with default value can be recreated
|
# Check that the document with default value can be recreated
|
||||||
def recreate_log_document():
|
class Log(Document):
|
||||||
class Log(Document):
|
date = DateTimeField(default=datetime.now)
|
||||||
date = DateTimeField(default=datetime.now)
|
meta = {
|
||||||
meta = {
|
'max_documents': 10,
|
||||||
'max_documents': 10,
|
}
|
||||||
}
|
|
||||||
# Create the collection by accessing Document.objects
|
# Create the collection by accessing Document.objects
|
||||||
Log.objects
|
Log.objects
|
||||||
recreate_log_document()
|
|
||||||
Log.drop_collection()
|
|
||||||
|
|
||||||
def test_capped_collection_no_max_size_problems(self):
|
def test_capped_collection_no_max_size_problems(self):
|
||||||
"""Ensure that capped collections with odd max_size work properly.
|
"""Ensure that capped collections with odd max_size work properly.
|
||||||
@ -166,16 +158,14 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertTrue(options['size'] >= 10000)
|
self.assertTrue(options['size'] >= 10000)
|
||||||
|
|
||||||
# Check that the document with odd max_size value can be recreated
|
# Check that the document with odd max_size value can be recreated
|
||||||
def recreate_log_document():
|
class Log(Document):
|
||||||
class Log(Document):
|
date = DateTimeField(default=datetime.now)
|
||||||
date = DateTimeField(default=datetime.now)
|
meta = {
|
||||||
meta = {
|
'max_size': 10000,
|
||||||
'max_size': 10000,
|
}
|
||||||
}
|
|
||||||
# Create the collection by accessing Document.objects
|
# Create the collection by accessing Document.objects
|
||||||
Log.objects
|
Log.objects
|
||||||
recreate_log_document()
|
|
||||||
Log.drop_collection()
|
|
||||||
|
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
"""Ensure that unicode representation works
|
"""Ensure that unicode representation works
|
||||||
@ -286,7 +276,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
list_stats = []
|
list_stats = []
|
||||||
|
|
||||||
for i in xrange(10):
|
for i in range(10):
|
||||||
s = Stats()
|
s = Stats()
|
||||||
s.save()
|
s.save()
|
||||||
list_stats.append(s)
|
list_stats.append(s)
|
||||||
@ -356,14 +346,14 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(User._fields['username'].db_field, '_id')
|
self.assertEqual(User._fields['username'].db_field, '_id')
|
||||||
self.assertEqual(User._meta['id_field'], 'username')
|
self.assertEqual(User._meta['id_field'], 'username')
|
||||||
|
|
||||||
def create_invalid_user():
|
# test no primary key field
|
||||||
User(name='test').save() # no primary key field
|
self.assertRaises(ValidationError, User(name='test').save)
|
||||||
self.assertRaises(ValidationError, create_invalid_user)
|
|
||||||
|
|
||||||
def define_invalid_user():
|
# define a subclass with a different primary key field than the
|
||||||
|
# parent
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
class EmailUser(User):
|
class EmailUser(User):
|
||||||
email = StringField(primary_key=True)
|
email = StringField(primary_key=True)
|
||||||
self.assertRaises(ValueError, define_invalid_user)
|
|
||||||
|
|
||||||
class EmailUser(User):
|
class EmailUser(User):
|
||||||
email = StringField()
|
email = StringField()
|
||||||
@ -411,12 +401,10 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Mimic Place and NicePlace definitions being in a different file
|
# Mimic Place and NicePlace definitions being in a different file
|
||||||
# and the NicePlace model not being imported in at query time.
|
# and the NicePlace model not being imported in at query time.
|
||||||
from mongoengine.base import _document_registry
|
|
||||||
del(_document_registry['Place.NicePlace'])
|
del(_document_registry['Place.NicePlace'])
|
||||||
|
|
||||||
def query_without_importing_nice_place():
|
with self.assertRaises(NotRegistered):
|
||||||
print Place.objects.all()
|
list(Place.objects.all())
|
||||||
self.assertRaises(NotRegistered, query_without_importing_nice_place)
|
|
||||||
|
|
||||||
def test_document_registry_regressions(self):
|
def test_document_registry_regressions(self):
|
||||||
|
|
||||||
@ -745,7 +733,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
t.save()
|
t.save()
|
||||||
except ValidationError, e:
|
except ValidationError as e:
|
||||||
expect_msg = "Draft entries may not have a publication date."
|
expect_msg = "Draft entries may not have a publication date."
|
||||||
self.assertTrue(expect_msg in e.message)
|
self.assertTrue(expect_msg in e.message)
|
||||||
self.assertEqual(e.to_dict(), {'__all__': expect_msg})
|
self.assertEqual(e.to_dict(), {'__all__': expect_msg})
|
||||||
@ -784,7 +772,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15))
|
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15))
|
||||||
try:
|
try:
|
||||||
t.save()
|
t.save()
|
||||||
except ValidationError, e:
|
except ValidationError as e:
|
||||||
expect_msg = "Value of z != x + y"
|
expect_msg = "Value of z != x + y"
|
||||||
self.assertTrue(expect_msg in e.message)
|
self.assertTrue(expect_msg in e.message)
|
||||||
self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}})
|
self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}})
|
||||||
@ -798,8 +786,10 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_modify_empty(self):
|
def test_modify_empty(self):
|
||||||
doc = self.Person(name="bob", age=10).save()
|
doc = self.Person(name="bob", age=10).save()
|
||||||
self.assertRaises(
|
|
||||||
InvalidDocumentError, lambda: self.Person().modify(set__age=10))
|
with self.assertRaises(InvalidDocumentError):
|
||||||
|
self.Person().modify(set__age=10)
|
||||||
|
|
||||||
self.assertDbEqual([dict(doc.to_mongo())])
|
self.assertDbEqual([dict(doc.to_mongo())])
|
||||||
|
|
||||||
def test_modify_invalid_query(self):
|
def test_modify_invalid_query(self):
|
||||||
@ -807,9 +797,8 @@ class InstanceTest(unittest.TestCase):
|
|||||||
doc2 = self.Person(name="jim", age=20).save()
|
doc2 = self.Person(name="jim", age=20).save()
|
||||||
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
|
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
|
||||||
|
|
||||||
self.assertRaises(
|
with self.assertRaises(InvalidQueryError):
|
||||||
InvalidQueryError,
|
doc1.modify({'id': doc2.id}, set__value=20)
|
||||||
lambda: doc1.modify(dict(id=doc2.id), set__value=20))
|
|
||||||
|
|
||||||
self.assertDbEqual(docs)
|
self.assertDbEqual(docs)
|
||||||
|
|
||||||
@ -818,7 +807,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
doc2 = self.Person(name="jim", age=20).save()
|
doc2 = self.Person(name="jim", age=20).save()
|
||||||
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
|
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
|
||||||
|
|
||||||
assert not doc1.modify(dict(name=doc2.name), set__age=100)
|
assert not doc1.modify({'name': doc2.name}, set__age=100)
|
||||||
|
|
||||||
self.assertDbEqual(docs)
|
self.assertDbEqual(docs)
|
||||||
|
|
||||||
@ -827,7 +816,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
doc2 = self.Person(id=ObjectId(), name="jim", age=20)
|
doc2 = self.Person(id=ObjectId(), name="jim", age=20)
|
||||||
docs = [dict(doc1.to_mongo())]
|
docs = [dict(doc1.to_mongo())]
|
||||||
|
|
||||||
assert not doc2.modify(dict(name=doc2.name), set__age=100)
|
assert not doc2.modify({'name': doc2.name}, set__age=100)
|
||||||
|
|
||||||
self.assertDbEqual(docs)
|
self.assertDbEqual(docs)
|
||||||
|
|
||||||
@ -1293,12 +1282,11 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_document_update(self):
|
def test_document_update(self):
|
||||||
|
|
||||||
def update_not_saved_raises():
|
# try updating a non-saved document
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
person = self.Person(name='dcrosta')
|
person = self.Person(name='dcrosta')
|
||||||
person.update(set__name='Dan Crosta')
|
person.update(set__name='Dan Crosta')
|
||||||
|
|
||||||
self.assertRaises(OperationError, update_not_saved_raises)
|
|
||||||
|
|
||||||
author = self.Person(name='dcrosta')
|
author = self.Person(name='dcrosta')
|
||||||
author.save()
|
author.save()
|
||||||
|
|
||||||
@ -1308,19 +1296,17 @@ class InstanceTest(unittest.TestCase):
|
|||||||
p1 = self.Person.objects.first()
|
p1 = self.Person.objects.first()
|
||||||
self.assertEqual(p1.name, author.name)
|
self.assertEqual(p1.name, author.name)
|
||||||
|
|
||||||
def update_no_value_raises():
|
# try sending an empty update
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
person = self.Person.objects.first()
|
person = self.Person.objects.first()
|
||||||
person.update()
|
person.update()
|
||||||
|
|
||||||
self.assertRaises(OperationError, update_no_value_raises)
|
# update that doesn't explicitly specify an operator should default
|
||||||
|
# to 'set__'
|
||||||
def update_no_op_should_default_to_set():
|
person = self.Person.objects.first()
|
||||||
person = self.Person.objects.first()
|
person.update(name="Dan")
|
||||||
person.update(name="Dan")
|
person.reload()
|
||||||
person.reload()
|
self.assertEqual("Dan", person.name)
|
||||||
return person.name
|
|
||||||
|
|
||||||
self.assertEqual("Dan", update_no_op_should_default_to_set())
|
|
||||||
|
|
||||||
def test_update_unique_field(self):
|
def test_update_unique_field(self):
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
@ -1329,8 +1315,8 @@ class InstanceTest(unittest.TestCase):
|
|||||||
doc1 = Doc(name="first").save()
|
doc1 = Doc(name="first").save()
|
||||||
doc2 = Doc(name="second").save()
|
doc2 = Doc(name="second").save()
|
||||||
|
|
||||||
self.assertRaises(NotUniqueError, lambda:
|
with self.assertRaises(NotUniqueError):
|
||||||
doc2.update(set__name=doc1.name))
|
doc2.update(set__name=doc1.name)
|
||||||
|
|
||||||
def test_embedded_update(self):
|
def test_embedded_update(self):
|
||||||
"""
|
"""
|
||||||
@ -1848,15 +1834,13 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_duplicate_db_fields_raise_invalid_document_error(self):
|
def test_duplicate_db_fields_raise_invalid_document_error(self):
|
||||||
"""Ensure a InvalidDocumentError is thrown if duplicate fields
|
"""Ensure a InvalidDocumentError is thrown if duplicate fields
|
||||||
declare the same db_field"""
|
declare the same db_field.
|
||||||
|
"""
|
||||||
def throw_invalid_document_error():
|
with self.assertRaises(InvalidDocumentError):
|
||||||
class Foo(Document):
|
class Foo(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
name2 = StringField(db_field='name')
|
name2 = StringField(db_field='name')
|
||||||
|
|
||||||
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
|
|
||||||
|
|
||||||
def test_invalid_son(self):
|
def test_invalid_son(self):
|
||||||
"""Raise an error if loading invalid data"""
|
"""Raise an error if loading invalid data"""
|
||||||
class Occurrence(EmbeddedDocument):
|
class Occurrence(EmbeddedDocument):
|
||||||
@ -1868,11 +1852,13 @@ class InstanceTest(unittest.TestCase):
|
|||||||
forms = ListField(StringField(), default=list)
|
forms = ListField(StringField(), default=list)
|
||||||
occurs = ListField(EmbeddedDocumentField(Occurrence), default=list)
|
occurs = ListField(EmbeddedDocumentField(Occurrence), default=list)
|
||||||
|
|
||||||
def raise_invalid_document():
|
with self.assertRaises(InvalidDocumentError):
|
||||||
Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one',
|
Word._from_son({
|
||||||
'occurs': {"hello": None}})
|
'stem': [1, 2, 3],
|
||||||
|
'forms': 1,
|
||||||
self.assertRaises(InvalidDocumentError, raise_invalid_document)
|
'count': 'one',
|
||||||
|
'occurs': {"hello": None}
|
||||||
|
})
|
||||||
|
|
||||||
def test_reverse_delete_rule_cascade_and_nullify(self):
|
def test_reverse_delete_rule_cascade_and_nullify(self):
|
||||||
"""Ensure that a referenced document is also deleted upon deletion.
|
"""Ensure that a referenced document is also deleted upon deletion.
|
||||||
@ -2103,8 +2089,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(Bar.objects.get().foo, None)
|
self.assertEqual(Bar.objects.get().foo, None)
|
||||||
|
|
||||||
def test_invalid_reverse_delete_rule_raise_errors(self):
|
def test_invalid_reverse_delete_rule_raise_errors(self):
|
||||||
|
with self.assertRaises(InvalidDocumentError):
|
||||||
def throw_invalid_document_error():
|
|
||||||
class Blog(Document):
|
class Blog(Document):
|
||||||
content = StringField()
|
content = StringField()
|
||||||
authors = MapField(ReferenceField(
|
authors = MapField(ReferenceField(
|
||||||
@ -2114,21 +2099,15 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.Person,
|
self.Person,
|
||||||
reverse_delete_rule=NULLIFY))
|
reverse_delete_rule=NULLIFY))
|
||||||
|
|
||||||
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
|
with self.assertRaises(InvalidDocumentError):
|
||||||
|
|
||||||
def throw_invalid_document_error_embedded():
|
|
||||||
class Parents(EmbeddedDocument):
|
class Parents(EmbeddedDocument):
|
||||||
father = ReferenceField('Person', reverse_delete_rule=DENY)
|
father = ReferenceField('Person', reverse_delete_rule=DENY)
|
||||||
mother = ReferenceField('Person', reverse_delete_rule=DENY)
|
mother = ReferenceField('Person', reverse_delete_rule=DENY)
|
||||||
|
|
||||||
self.assertRaises(
|
|
||||||
InvalidDocumentError, throw_invalid_document_error_embedded)
|
|
||||||
|
|
||||||
def test_reverse_delete_rule_cascade_recurs(self):
|
def test_reverse_delete_rule_cascade_recurs(self):
|
||||||
"""Ensure that a chain of documents is also deleted upon cascaded
|
"""Ensure that a chain of documents is also deleted upon cascaded
|
||||||
deletion.
|
deletion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class BlogPost(Document):
|
class BlogPost(Document):
|
||||||
content = StringField()
|
content = StringField()
|
||||||
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
|
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
|
||||||
@ -2344,15 +2323,14 @@ class InstanceTest(unittest.TestCase):
|
|||||||
pickle_doc.save()
|
pickle_doc.save()
|
||||||
pickle_doc.delete()
|
pickle_doc.delete()
|
||||||
|
|
||||||
def test_throw_invalid_document_error(self):
|
def test_override_method_with_field(self):
|
||||||
|
"""Test creating a field with a field name that would override
|
||||||
# test handles people trying to upsert
|
the "validate" method.
|
||||||
def throw_invalid_document_error():
|
"""
|
||||||
|
with self.assertRaises(InvalidDocumentError):
|
||||||
class Blog(Document):
|
class Blog(Document):
|
||||||
validate = DictField()
|
validate = DictField()
|
||||||
|
|
||||||
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
|
|
||||||
|
|
||||||
def test_mutating_documents(self):
|
def test_mutating_documents(self):
|
||||||
|
|
||||||
class B(EmbeddedDocument):
|
class B(EmbeddedDocument):
|
||||||
@ -2815,11 +2793,10 @@ class InstanceTest(unittest.TestCase):
|
|||||||
log.log = "Saving"
|
log.log = "Saving"
|
||||||
log.save()
|
log.save()
|
||||||
|
|
||||||
def change_shard_key():
|
# try to change the shard key
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
log.machine = "127.0.0.1"
|
log.machine = "127.0.0.1"
|
||||||
|
|
||||||
self.assertRaises(OperationError, change_shard_key)
|
|
||||||
|
|
||||||
def test_shard_key_in_embedded_document(self):
|
def test_shard_key_in_embedded_document(self):
|
||||||
class Foo(EmbeddedDocument):
|
class Foo(EmbeddedDocument):
|
||||||
foo = StringField()
|
foo = StringField()
|
||||||
@ -2840,12 +2817,11 @@ class InstanceTest(unittest.TestCase):
|
|||||||
bar_doc.bar = 'baz'
|
bar_doc.bar = 'baz'
|
||||||
bar_doc.save()
|
bar_doc.save()
|
||||||
|
|
||||||
def change_shard_key():
|
# try to change the shard key
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
bar_doc.foo.foo = 'something'
|
bar_doc.foo.foo = 'something'
|
||||||
bar_doc.save()
|
bar_doc.save()
|
||||||
|
|
||||||
self.assertRaises(OperationError, change_shard_key)
|
|
||||||
|
|
||||||
def test_shard_key_primary(self):
|
def test_shard_key_primary(self):
|
||||||
class LogEntry(Document):
|
class LogEntry(Document):
|
||||||
machine = StringField(primary_key=True)
|
machine = StringField(primary_key=True)
|
||||||
@ -2866,11 +2842,10 @@ class InstanceTest(unittest.TestCase):
|
|||||||
log.log = "Saving"
|
log.log = "Saving"
|
||||||
log.save()
|
log.save()
|
||||||
|
|
||||||
def change_shard_key():
|
# try to change the shard key
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
log.machine = "127.0.0.1"
|
log.machine = "127.0.0.1"
|
||||||
|
|
||||||
self.assertRaises(OperationError, change_shard_key)
|
|
||||||
|
|
||||||
def test_kwargs_simple(self):
|
def test_kwargs_simple(self):
|
||||||
|
|
||||||
class Embedded(EmbeddedDocument):
|
class Embedded(EmbeddedDocument):
|
||||||
@ -2955,11 +2930,9 @@ class InstanceTest(unittest.TestCase):
|
|||||||
def test_bad_mixed_creation(self):
|
def test_bad_mixed_creation(self):
|
||||||
"""Ensure that document gives correct error when duplicating arguments
|
"""Ensure that document gives correct error when duplicating arguments
|
||||||
"""
|
"""
|
||||||
def construct_bad_instance():
|
with self.assertRaises(TypeError):
|
||||||
return self.Person("Test User", 42, name="Bad User")
|
return self.Person("Test User", 42, name="Bad User")
|
||||||
|
|
||||||
self.assertRaises(TypeError, construct_bad_instance)
|
|
||||||
|
|
||||||
def test_data_contains_id_field(self):
|
def test_data_contains_id_field(self):
|
||||||
"""Ensure that asking for _data returns 'id'
|
"""Ensure that asking for _data returns 'id'
|
||||||
"""
|
"""
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@ -60,7 +57,7 @@ class ValidatorErrorTest(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
User().validate()
|
User().validate()
|
||||||
except ValidationError, e:
|
except ValidationError as e:
|
||||||
self.assertTrue("User:None" in e.message)
|
self.assertTrue("User:None" in e.message)
|
||||||
self.assertEqual(e.to_dict(), {
|
self.assertEqual(e.to_dict(), {
|
||||||
'username': 'Field is required',
|
'username': 'Field is required',
|
||||||
@ -70,7 +67,7 @@ class ValidatorErrorTest(unittest.TestCase):
|
|||||||
user.name = None
|
user.name = None
|
||||||
try:
|
try:
|
||||||
user.save()
|
user.save()
|
||||||
except ValidationError, e:
|
except ValidationError as e:
|
||||||
self.assertTrue("User:RossC0" in e.message)
|
self.assertTrue("User:RossC0" in e.message)
|
||||||
self.assertEqual(e.to_dict(), {
|
self.assertEqual(e.to_dict(), {
|
||||||
'name': 'Field is required'})
|
'name': 'Field is required'})
|
||||||
@ -118,7 +115,7 @@ class ValidatorErrorTest(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
Doc(id="bad").validate()
|
Doc(id="bad").validate()
|
||||||
except ValidationError, e:
|
except ValidationError as e:
|
||||||
self.assertTrue("SubDoc:None" in e.message)
|
self.assertTrue("SubDoc:None" in e.message)
|
||||||
self.assertEqual(e.to_dict(), {
|
self.assertEqual(e.to_dict(), {
|
||||||
"e": {'val': 'OK could not be converted to int'}})
|
"e": {'val': 'OK could not be converted to int'}})
|
||||||
@ -136,7 +133,7 @@ class ValidatorErrorTest(unittest.TestCase):
|
|||||||
doc.e.val = "OK"
|
doc.e.val = "OK"
|
||||||
try:
|
try:
|
||||||
doc.save()
|
doc.save()
|
||||||
except ValidationError, e:
|
except ValidationError as e:
|
||||||
self.assertTrue("Doc:test" in e.message)
|
self.assertTrue("Doc:test" in e.message)
|
||||||
self.assertEqual(e.to_dict(), {
|
self.assertEqual(e.to_dict(), {
|
||||||
"e": {'val': 'OK could not be converted to int'}})
|
"e": {'val': 'OK could not be converted to int'}})
|
||||||
@ -156,14 +153,14 @@ class ValidatorErrorTest(unittest.TestCase):
|
|||||||
|
|
||||||
s = SubDoc()
|
s = SubDoc()
|
||||||
|
|
||||||
self.assertRaises(ValidationError, lambda: s.validate())
|
self.assertRaises(ValidationError, s.validate)
|
||||||
|
|
||||||
d1.e = s
|
d1.e = s
|
||||||
d2.e = s
|
d2.e = s
|
||||||
|
|
||||||
del d1
|
del d1
|
||||||
|
|
||||||
self.assertRaises(ValidationError, lambda: d2.validate())
|
self.assertRaises(ValidationError, d2.validate)
|
||||||
|
|
||||||
def test_parent_reference_in_child_document(self):
|
def test_parent_reference_in_child_document(self):
|
||||||
"""
|
"""
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from nose.plugins.skip import SkipTest
|
from nose.plugins.skip import SkipTest
|
||||||
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
@ -29,10 +25,9 @@ except ImportError:
|
|||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import get_db
|
from mongoengine.connection import get_db
|
||||||
from mongoengine.base import _document_registry
|
from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList,
|
||||||
from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList
|
_document_registry)
|
||||||
from mongoengine.errors import NotRegistered, DoesNotExist
|
from mongoengine.errors import NotRegistered, DoesNotExist
|
||||||
from mongoengine.python_support import PY3, b, bin_type
|
|
||||||
|
|
||||||
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
|
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
|
||||||
|
|
||||||
@ -653,8 +648,8 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Post UTC - microseconds are rounded (down) nearest millisecond and
|
# Post UTC - microseconds are rounded (down) nearest millisecond and
|
||||||
# dropped
|
# dropped
|
||||||
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
|
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
|
||||||
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
|
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1)
|
||||||
log = LogEntry()
|
log = LogEntry()
|
||||||
log.date = d1
|
log.date = d1
|
||||||
log.save()
|
log.save()
|
||||||
@ -663,15 +658,15 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.assertEqual(log.date, d2)
|
self.assertEqual(log.date, d2)
|
||||||
|
|
||||||
# Post UTC - microseconds are rounded (down) nearest millisecond
|
# Post UTC - microseconds are rounded (down) nearest millisecond
|
||||||
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
|
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999)
|
||||||
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000)
|
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000)
|
||||||
log.date = d1
|
log.date = d1
|
||||||
log.save()
|
log.save()
|
||||||
log.reload()
|
log.reload()
|
||||||
self.assertNotEqual(log.date, d1)
|
self.assertNotEqual(log.date, d1)
|
||||||
self.assertEqual(log.date, d2)
|
self.assertEqual(log.date, d2)
|
||||||
|
|
||||||
if not PY3:
|
if not six.PY3:
|
||||||
# Pre UTC dates microseconds below 1000 are dropped
|
# Pre UTC dates microseconds below 1000 are dropped
|
||||||
# This does not seem to be true in PY3
|
# This does not seem to be true in PY3
|
||||||
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
|
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
|
||||||
@ -691,7 +686,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
LogEntry.drop_collection()
|
LogEntry.drop_collection()
|
||||||
|
|
||||||
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01)
|
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1)
|
||||||
log = LogEntry()
|
log = LogEntry()
|
||||||
log.date = d1
|
log.date = d1
|
||||||
log.validate()
|
log.validate()
|
||||||
@ -708,8 +703,8 @@ class FieldTest(unittest.TestCase):
|
|||||||
LogEntry.drop_collection()
|
LogEntry.drop_collection()
|
||||||
|
|
||||||
# create 60 log entries
|
# create 60 log entries
|
||||||
for i in xrange(1950, 2010):
|
for i in range(1950, 2010):
|
||||||
d = datetime.datetime(i, 01, 01, 00, 00, 01)
|
d = datetime.datetime(i, 1, 1, 0, 0, 1)
|
||||||
LogEntry(date=d).save()
|
LogEntry(date=d).save()
|
||||||
|
|
||||||
self.assertEqual(LogEntry.objects.count(), 60)
|
self.assertEqual(LogEntry.objects.count(), 60)
|
||||||
@ -756,7 +751,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Post UTC - microseconds are rounded (down) nearest millisecond and
|
# Post UTC - microseconds are rounded (down) nearest millisecond and
|
||||||
# dropped - with default datetimefields
|
# dropped - with default datetimefields
|
||||||
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
|
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
|
||||||
log = LogEntry()
|
log = LogEntry()
|
||||||
log.date = d1
|
log.date = d1
|
||||||
log.save()
|
log.save()
|
||||||
@ -765,7 +760,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Post UTC - microseconds are rounded (down) nearest millisecond - with
|
# Post UTC - microseconds are rounded (down) nearest millisecond - with
|
||||||
# default datetimefields
|
# default datetimefields
|
||||||
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
|
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999)
|
||||||
log.date = d1
|
log.date = d1
|
||||||
log.save()
|
log.save()
|
||||||
log.reload()
|
log.reload()
|
||||||
@ -782,7 +777,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
# Pre UTC microseconds above 1000 is wonky - with default datetimefields
|
# Pre UTC microseconds above 1000 is wonky - with default datetimefields
|
||||||
# log.date has an invalid microsecond value so I can't construct
|
# log.date has an invalid microsecond value so I can't construct
|
||||||
# a date to compare.
|
# a date to compare.
|
||||||
for i in xrange(1001, 3113, 33):
|
for i in range(1001, 3113, 33):
|
||||||
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
|
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
|
||||||
log.date = d1
|
log.date = d1
|
||||||
log.save()
|
log.save()
|
||||||
@ -792,7 +787,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.assertEqual(log, log1)
|
self.assertEqual(log, log1)
|
||||||
|
|
||||||
# Test string padding
|
# Test string padding
|
||||||
microsecond = map(int, [math.pow(10, x) for x in xrange(6)])
|
microsecond = map(int, [math.pow(10, x) for x in range(6)])
|
||||||
mm = dd = hh = ii = ss = [1, 10]
|
mm = dd = hh = ii = ss = [1, 10]
|
||||||
|
|
||||||
for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond):
|
for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond):
|
||||||
@ -814,7 +809,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
LogEntry.drop_collection()
|
LogEntry.drop_collection()
|
||||||
|
|
||||||
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
|
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
|
||||||
log = LogEntry()
|
log = LogEntry()
|
||||||
log.date = d1
|
log.date = d1
|
||||||
log.save()
|
log.save()
|
||||||
@ -825,8 +820,8 @@ class FieldTest(unittest.TestCase):
|
|||||||
LogEntry.drop_collection()
|
LogEntry.drop_collection()
|
||||||
|
|
||||||
# create 60 log entries
|
# create 60 log entries
|
||||||
for i in xrange(1950, 2010):
|
for i in range(1950, 2010):
|
||||||
d = datetime.datetime(i, 01, 01, 00, 00, 01, 999)
|
d = datetime.datetime(i, 1, 1, 0, 0, 1, 999)
|
||||||
LogEntry(date=d).save()
|
LogEntry(date=d).save()
|
||||||
|
|
||||||
self.assertEqual(LogEntry.objects.count(), 60)
|
self.assertEqual(LogEntry.objects.count(), 60)
|
||||||
@ -1134,12 +1129,11 @@ class FieldTest(unittest.TestCase):
|
|||||||
e.mapping = [1]
|
e.mapping = [1]
|
||||||
e.save()
|
e.save()
|
||||||
|
|
||||||
def create_invalid_mapping():
|
# try creating an invalid mapping
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
e.mapping = ["abc"]
|
e.mapping = ["abc"]
|
||||||
e.save()
|
e.save()
|
||||||
|
|
||||||
self.assertRaises(ValidationError, create_invalid_mapping)
|
|
||||||
|
|
||||||
Simple.drop_collection()
|
Simple.drop_collection()
|
||||||
|
|
||||||
def test_list_field_rejects_strings(self):
|
def test_list_field_rejects_strings(self):
|
||||||
@ -1406,12 +1400,11 @@ class FieldTest(unittest.TestCase):
|
|||||||
e.mapping['someint'] = 1
|
e.mapping['someint'] = 1
|
||||||
e.save()
|
e.save()
|
||||||
|
|
||||||
def create_invalid_mapping():
|
# try creating an invalid mapping
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
e.mapping['somestring'] = "abc"
|
e.mapping['somestring'] = "abc"
|
||||||
e.save()
|
e.save()
|
||||||
|
|
||||||
self.assertRaises(ValidationError, create_invalid_mapping)
|
|
||||||
|
|
||||||
Simple.drop_collection()
|
Simple.drop_collection()
|
||||||
|
|
||||||
def test_dictfield_complex(self):
|
def test_dictfield_complex(self):
|
||||||
@ -1484,11 +1477,10 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.assertEqual(BaseDict, type(e.mapping))
|
self.assertEqual(BaseDict, type(e.mapping))
|
||||||
self.assertEqual({"ints": [3, 4]}, e.mapping)
|
self.assertEqual({"ints": [3, 4]}, e.mapping)
|
||||||
|
|
||||||
def create_invalid_mapping():
|
# try creating an invalid mapping
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
e.update(set__mapping={"somestrings": ["foo", "bar", ]})
|
e.update(set__mapping={"somestrings": ["foo", "bar", ]})
|
||||||
|
|
||||||
self.assertRaises(ValueError, create_invalid_mapping)
|
|
||||||
|
|
||||||
Simple.drop_collection()
|
Simple.drop_collection()
|
||||||
|
|
||||||
def test_mapfield(self):
|
def test_mapfield(self):
|
||||||
@ -1503,18 +1495,14 @@ class FieldTest(unittest.TestCase):
|
|||||||
e.mapping['someint'] = 1
|
e.mapping['someint'] = 1
|
||||||
e.save()
|
e.save()
|
||||||
|
|
||||||
def create_invalid_mapping():
|
with self.assertRaises(ValidationError):
|
||||||
e.mapping['somestring'] = "abc"
|
e.mapping['somestring'] = "abc"
|
||||||
e.save()
|
e.save()
|
||||||
|
|
||||||
self.assertRaises(ValidationError, create_invalid_mapping)
|
with self.assertRaises(ValidationError):
|
||||||
|
|
||||||
def create_invalid_class():
|
|
||||||
class NoDeclaredType(Document):
|
class NoDeclaredType(Document):
|
||||||
mapping = MapField()
|
mapping = MapField()
|
||||||
|
|
||||||
self.assertRaises(ValidationError, create_invalid_class)
|
|
||||||
|
|
||||||
Simple.drop_collection()
|
Simple.drop_collection()
|
||||||
|
|
||||||
def test_complex_mapfield(self):
|
def test_complex_mapfield(self):
|
||||||
@ -1543,14 +1531,10 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
|
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
|
||||||
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
|
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
|
||||||
|
|
||||||
def create_invalid_mapping():
|
with self.assertRaises(ValidationError):
|
||||||
e.mapping['someint'] = 123
|
e.mapping['someint'] = 123
|
||||||
e.save()
|
e.save()
|
||||||
|
|
||||||
self.assertRaises(ValidationError, create_invalid_mapping)
|
|
||||||
|
|
||||||
Extensible.drop_collection()
|
|
||||||
|
|
||||||
def test_embedded_mapfield_db_field(self):
|
def test_embedded_mapfield_db_field(self):
|
||||||
|
|
||||||
class Embedded(EmbeddedDocument):
|
class Embedded(EmbeddedDocument):
|
||||||
@ -1760,8 +1744,8 @@ class FieldTest(unittest.TestCase):
|
|||||||
# Reference is no longer valid
|
# Reference is no longer valid
|
||||||
foo.delete()
|
foo.delete()
|
||||||
bar = Bar.objects.get()
|
bar = Bar.objects.get()
|
||||||
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref'))
|
self.assertRaises(DoesNotExist, getattr, bar, 'ref')
|
||||||
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref'))
|
self.assertRaises(DoesNotExist, getattr, bar, 'generic_ref')
|
||||||
|
|
||||||
# When auto_dereference is disabled, there is no trouble returning DBRef
|
# When auto_dereference is disabled, there is no trouble returning DBRef
|
||||||
bar = Bar.objects.get()
|
bar = Bar.objects.get()
|
||||||
@ -2036,7 +2020,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
})
|
})
|
||||||
|
|
||||||
def test_cached_reference_fields_on_embedded_documents(self):
|
def test_cached_reference_fields_on_embedded_documents(self):
|
||||||
def build():
|
with self.assertRaises(InvalidDocumentError):
|
||||||
class Test(Document):
|
class Test(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
@ -2045,8 +2029,6 @@ class FieldTest(unittest.TestCase):
|
|||||||
'test': CachedReferenceField(Test)
|
'test': CachedReferenceField(Test)
|
||||||
})
|
})
|
||||||
|
|
||||||
self.assertRaises(InvalidDocumentError, build)
|
|
||||||
|
|
||||||
def test_cached_reference_auto_sync(self):
|
def test_cached_reference_auto_sync(self):
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
TYPES = (
|
TYPES = (
|
||||||
@ -2863,7 +2845,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
content_type = StringField()
|
content_type = StringField()
|
||||||
blob = BinaryField()
|
blob = BinaryField()
|
||||||
|
|
||||||
BLOB = b('\xe6\x00\xc4\xff\x07')
|
BLOB = six.b('\xe6\x00\xc4\xff\x07')
|
||||||
MIME_TYPE = 'application/octet-stream'
|
MIME_TYPE = 'application/octet-stream'
|
||||||
|
|
||||||
Attachment.drop_collection()
|
Attachment.drop_collection()
|
||||||
@ -2873,7 +2855,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
attachment_1 = Attachment.objects().first()
|
attachment_1 = Attachment.objects().first()
|
||||||
self.assertEqual(MIME_TYPE, attachment_1.content_type)
|
self.assertEqual(MIME_TYPE, attachment_1.content_type)
|
||||||
self.assertEqual(BLOB, bin_type(attachment_1.blob))
|
self.assertEqual(BLOB, six.binary_type(attachment_1.blob))
|
||||||
|
|
||||||
Attachment.drop_collection()
|
Attachment.drop_collection()
|
||||||
|
|
||||||
@ -2900,13 +2882,13 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
attachment_required = AttachmentRequired()
|
attachment_required = AttachmentRequired()
|
||||||
self.assertRaises(ValidationError, attachment_required.validate)
|
self.assertRaises(ValidationError, attachment_required.validate)
|
||||||
attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07'))
|
attachment_required.blob = Binary(six.b('\xe6\x00\xc4\xff\x07'))
|
||||||
attachment_required.validate()
|
attachment_required.validate()
|
||||||
|
|
||||||
attachment_size_limit = AttachmentSizeLimit(
|
attachment_size_limit = AttachmentSizeLimit(
|
||||||
blob=b('\xe6\x00\xc4\xff\x07'))
|
blob=six.b('\xe6\x00\xc4\xff\x07'))
|
||||||
self.assertRaises(ValidationError, attachment_size_limit.validate)
|
self.assertRaises(ValidationError, attachment_size_limit.validate)
|
||||||
attachment_size_limit.blob = b('\xe6\x00\xc4\xff')
|
attachment_size_limit.blob = six.b('\xe6\x00\xc4\xff')
|
||||||
attachment_size_limit.validate()
|
attachment_size_limit.validate()
|
||||||
|
|
||||||
Attachment.drop_collection()
|
Attachment.drop_collection()
|
||||||
@ -3152,7 +3134,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
shirt.validate()
|
shirt.validate()
|
||||||
except ValidationError, error:
|
except ValidationError as error:
|
||||||
# get the validation rules
|
# get the validation rules
|
||||||
error_dict = error.to_dict()
|
error_dict = error.to_dict()
|
||||||
self.assertEqual(error_dict['size'], SIZE_MESSAGE)
|
self.assertEqual(error_dict['size'], SIZE_MESSAGE)
|
||||||
@ -3181,7 +3163,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.db['mongoengine.counters'].drop()
|
self.db['mongoengine.counters'].drop()
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
for x in xrange(10):
|
for x in range(10):
|
||||||
Person(name="Person %s" % x).save()
|
Person(name="Person %s" % x).save()
|
||||||
|
|
||||||
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
|
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
|
||||||
@ -3205,7 +3187,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.db['mongoengine.counters'].drop()
|
self.db['mongoengine.counters'].drop()
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
for x in xrange(10):
|
for x in range(10):
|
||||||
Person(name="Person %s" % x).save()
|
Person(name="Person %s" % x).save()
|
||||||
|
|
||||||
self.assertEqual(Person.id.get_next_value(), 11)
|
self.assertEqual(Person.id.get_next_value(), 11)
|
||||||
@ -3220,7 +3202,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.db['mongoengine.counters'].drop()
|
self.db['mongoengine.counters'].drop()
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
for x in xrange(10):
|
for x in range(10):
|
||||||
Person(name="Person %s" % x).save()
|
Person(name="Person %s" % x).save()
|
||||||
|
|
||||||
self.assertEqual(Person.id.get_next_value(), '11')
|
self.assertEqual(Person.id.get_next_value(), '11')
|
||||||
@ -3236,7 +3218,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.db['mongoengine.counters'].drop()
|
self.db['mongoengine.counters'].drop()
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
for x in xrange(10):
|
for x in range(10):
|
||||||
Person(name="Person %s" % x).save()
|
Person(name="Person %s" % x).save()
|
||||||
|
|
||||||
c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'})
|
c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'})
|
||||||
@ -3261,7 +3243,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.db['mongoengine.counters'].drop()
|
self.db['mongoengine.counters'].drop()
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
for x in xrange(10):
|
for x in range(10):
|
||||||
Person(name="Person %s" % x).save()
|
Person(name="Person %s" % x).save()
|
||||||
|
|
||||||
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
|
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
|
||||||
@ -3323,7 +3305,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
Animal.drop_collection()
|
Animal.drop_collection()
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
for x in xrange(10):
|
for x in range(10):
|
||||||
Animal(name="Animal %s" % x).save()
|
Animal(name="Animal %s" % x).save()
|
||||||
Person(name="Person %s" % x).save()
|
Person(name="Person %s" % x).save()
|
||||||
|
|
||||||
@ -3353,7 +3335,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.db['mongoengine.counters'].drop()
|
self.db['mongoengine.counters'].drop()
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
for x in xrange(10):
|
for x in range(10):
|
||||||
p = Person(name="Person %s" % x)
|
p = Person(name="Person %s" % x)
|
||||||
p.save()
|
p.save()
|
||||||
|
|
||||||
@ -3540,7 +3522,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.assertRaises(ValidationError, post.validate)
|
self.assertRaises(ValidationError, post.validate)
|
||||||
try:
|
try:
|
||||||
post.validate()
|
post.validate()
|
||||||
except ValidationError, error:
|
except ValidationError as error:
|
||||||
# ValidationError.errors property
|
# ValidationError.errors property
|
||||||
self.assertTrue(hasattr(error, 'errors'))
|
self.assertTrue(hasattr(error, 'errors'))
|
||||||
self.assertTrue(isinstance(error.errors, dict))
|
self.assertTrue(isinstance(error.errors, dict))
|
||||||
@ -3601,8 +3583,6 @@ class FieldTest(unittest.TestCase):
|
|||||||
Ensure that tuples remain tuples when they are
|
Ensure that tuples remain tuples when they are
|
||||||
inside a ComplexBaseField
|
inside a ComplexBaseField
|
||||||
"""
|
"""
|
||||||
from mongoengine.base import BaseField
|
|
||||||
|
|
||||||
class EnumField(BaseField):
|
class EnumField(BaseField):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -3836,9 +3816,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
filtered = self.post1.comments.filter()
|
filtered = self.post1.comments.filter()
|
||||||
|
|
||||||
# Ensure nothing was changed
|
# Ensure nothing was changed
|
||||||
# < 2.6 Incompatible >
|
self.assertListEqual(filtered, self.post1.comments)
|
||||||
# self.assertListEqual(filtered, self.post1.comments)
|
|
||||||
self.assertEqual(filtered, self.post1.comments)
|
|
||||||
|
|
||||||
def test_single_keyword_filter(self):
|
def test_single_keyword_filter(self):
|
||||||
"""
|
"""
|
||||||
@ -3889,10 +3867,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
Tests the filter method of a List of Embedded Documents
|
Tests the filter method of a List of Embedded Documents
|
||||||
when the keyword is not a known keyword.
|
when the keyword is not a known keyword.
|
||||||
"""
|
"""
|
||||||
# < 2.6 Incompatible >
|
with self.assertRaises(AttributeError):
|
||||||
# with self.assertRaises(AttributeError):
|
self.post2.comments.filter(year=2)
|
||||||
# self.post2.comments.filter(year=2)
|
|
||||||
self.assertRaises(AttributeError, self.post2.comments.filter, year=2)
|
|
||||||
|
|
||||||
def test_no_keyword_exclude(self):
|
def test_no_keyword_exclude(self):
|
||||||
"""
|
"""
|
||||||
@ -3902,9 +3878,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
filtered = self.post1.comments.exclude()
|
filtered = self.post1.comments.exclude()
|
||||||
|
|
||||||
# Ensure everything was removed
|
# Ensure everything was removed
|
||||||
# < 2.6 Incompatible >
|
self.assertListEqual(filtered, [])
|
||||||
# self.assertListEqual(filtered, [])
|
|
||||||
self.assertEqual(filtered, [])
|
|
||||||
|
|
||||||
def test_single_keyword_exclude(self):
|
def test_single_keyword_exclude(self):
|
||||||
"""
|
"""
|
||||||
@ -3950,10 +3924,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
Tests the exclude method of a List of Embedded Documents
|
Tests the exclude method of a List of Embedded Documents
|
||||||
when the keyword is not a known keyword.
|
when the keyword is not a known keyword.
|
||||||
"""
|
"""
|
||||||
# < 2.6 Incompatible >
|
with self.assertRaises(AttributeError):
|
||||||
# with self.assertRaises(AttributeError):
|
self.post2.comments.exclude(year=2)
|
||||||
# self.post2.comments.exclude(year=2)
|
|
||||||
self.assertRaises(AttributeError, self.post2.comments.exclude, year=2)
|
|
||||||
|
|
||||||
def test_chained_filter_exclude(self):
|
def test_chained_filter_exclude(self):
|
||||||
"""
|
"""
|
||||||
@ -3991,10 +3963,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
single keyword.
|
single keyword.
|
||||||
"""
|
"""
|
||||||
comment = self.post1.comments.get(author='user1')
|
comment = self.post1.comments.get(author='user1')
|
||||||
|
self.assertIsInstance(comment, self.Comments)
|
||||||
# < 2.6 Incompatible >
|
|
||||||
# self.assertIsInstance(comment, self.Comments)
|
|
||||||
self.assertTrue(isinstance(comment, self.Comments))
|
|
||||||
self.assertEqual(comment.author, 'user1')
|
self.assertEqual(comment.author, 'user1')
|
||||||
|
|
||||||
def test_multi_keyword_get(self):
|
def test_multi_keyword_get(self):
|
||||||
@ -4003,10 +3972,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
multiple keywords.
|
multiple keywords.
|
||||||
"""
|
"""
|
||||||
comment = self.post2.comments.get(author='user2', message='message2')
|
comment = self.post2.comments.get(author='user2', message='message2')
|
||||||
|
self.assertIsInstance(comment, self.Comments)
|
||||||
# < 2.6 Incompatible >
|
|
||||||
# self.assertIsInstance(comment, self.Comments)
|
|
||||||
self.assertTrue(isinstance(comment, self.Comments))
|
|
||||||
self.assertEqual(comment.author, 'user2')
|
self.assertEqual(comment.author, 'user2')
|
||||||
self.assertEqual(comment.message, 'message2')
|
self.assertEqual(comment.message, 'message2')
|
||||||
|
|
||||||
@ -4015,44 +3981,32 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
Tests the get method of a List of Embedded Documents without
|
Tests the get method of a List of Embedded Documents without
|
||||||
a keyword to return multiple documents.
|
a keyword to return multiple documents.
|
||||||
"""
|
"""
|
||||||
# < 2.6 Incompatible >
|
with self.assertRaises(MultipleObjectsReturned):
|
||||||
# with self.assertRaises(MultipleObjectsReturned):
|
self.post1.comments.get()
|
||||||
# self.post1.comments.get()
|
|
||||||
self.assertRaises(MultipleObjectsReturned, self.post1.comments.get)
|
|
||||||
|
|
||||||
def test_keyword_multiple_return_get(self):
|
def test_keyword_multiple_return_get(self):
|
||||||
"""
|
"""
|
||||||
Tests the get method of a List of Embedded Documents with a keyword
|
Tests the get method of a List of Embedded Documents with a keyword
|
||||||
to return multiple documents.
|
to return multiple documents.
|
||||||
"""
|
"""
|
||||||
# < 2.6 Incompatible >
|
with self.assertRaises(MultipleObjectsReturned):
|
||||||
# with self.assertRaises(MultipleObjectsReturned):
|
self.post2.comments.get(author='user2')
|
||||||
# self.post2.comments.get(author='user2')
|
|
||||||
self.assertRaises(
|
|
||||||
MultipleObjectsReturned, self.post2.comments.get, author='user2'
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_unknown_keyword_get(self):
|
def test_unknown_keyword_get(self):
|
||||||
"""
|
"""
|
||||||
Tests the get method of a List of Embedded Documents with an
|
Tests the get method of a List of Embedded Documents with an
|
||||||
unknown keyword.
|
unknown keyword.
|
||||||
"""
|
"""
|
||||||
# < 2.6 Incompatible >
|
with self.assertRaises(AttributeError):
|
||||||
# with self.assertRaises(AttributeError):
|
self.post2.comments.get(year=2020)
|
||||||
# self.post2.comments.get(year=2020)
|
|
||||||
self.assertRaises(AttributeError, self.post2.comments.get, year=2020)
|
|
||||||
|
|
||||||
def test_no_result_get(self):
|
def test_no_result_get(self):
|
||||||
"""
|
"""
|
||||||
Tests the get method of a List of Embedded Documents where get
|
Tests the get method of a List of Embedded Documents where get
|
||||||
returns no results.
|
returns no results.
|
||||||
"""
|
"""
|
||||||
# < 2.6 Incompatible >
|
with self.assertRaises(DoesNotExist):
|
||||||
# with self.assertRaises(DoesNotExist):
|
self.post1.comments.get(author='user3')
|
||||||
# self.post1.comments.get(author='user3')
|
|
||||||
self.assertRaises(
|
|
||||||
DoesNotExist, self.post1.comments.get, author='user3'
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_first(self):
|
def test_first(self):
|
||||||
"""
|
"""
|
||||||
@ -4062,9 +4016,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
comment = self.post1.comments.first()
|
comment = self.post1.comments.first()
|
||||||
|
|
||||||
# Ensure a Comment object was returned.
|
# Ensure a Comment object was returned.
|
||||||
# < 2.6 Incompatible >
|
self.assertIsInstance(comment, self.Comments)
|
||||||
# self.assertIsInstance(comment, self.Comments)
|
|
||||||
self.assertTrue(isinstance(comment, self.Comments))
|
|
||||||
self.assertEqual(comment, self.post1.comments[0])
|
self.assertEqual(comment, self.post1.comments[0])
|
||||||
|
|
||||||
def test_create(self):
|
def test_create(self):
|
||||||
@ -4077,22 +4029,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
self.post1.save()
|
self.post1.save()
|
||||||
|
|
||||||
# Ensure the returned value is the comment object.
|
# Ensure the returned value is the comment object.
|
||||||
# < 2.6 Incompatible >
|
self.assertIsInstance(comment, self.Comments)
|
||||||
# self.assertIsInstance(comment, self.Comments)
|
|
||||||
self.assertTrue(isinstance(comment, self.Comments))
|
|
||||||
self.assertEqual(comment.author, 'user4')
|
self.assertEqual(comment.author, 'user4')
|
||||||
self.assertEqual(comment.message, 'message1')
|
self.assertEqual(comment.message, 'message1')
|
||||||
|
|
||||||
# Ensure the new comment was actually saved to the database.
|
# Ensure the new comment was actually saved to the database.
|
||||||
# < 2.6 Incompatible >
|
self.assertIn(
|
||||||
# self.assertIn(
|
comment,
|
||||||
# comment,
|
self.BlogPost.objects(comments__author='user4')[0].comments
|
||||||
# self.BlogPost.objects(comments__author='user4')[0].comments
|
|
||||||
# )
|
|
||||||
self.assertTrue(
|
|
||||||
comment in self.BlogPost.objects(
|
|
||||||
comments__author='user4'
|
|
||||||
)[0].comments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_filtered_create(self):
|
def test_filtered_create(self):
|
||||||
@ -4107,22 +4051,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
self.post1.save()
|
self.post1.save()
|
||||||
|
|
||||||
# Ensure the returned value is the comment object.
|
# Ensure the returned value is the comment object.
|
||||||
# < 2.6 Incompatible >
|
self.assertIsInstance(comment, self.Comments)
|
||||||
# self.assertIsInstance(comment, self.Comments)
|
|
||||||
self.assertTrue(isinstance(comment, self.Comments))
|
|
||||||
self.assertEqual(comment.author, 'user4')
|
self.assertEqual(comment.author, 'user4')
|
||||||
self.assertEqual(comment.message, 'message1')
|
self.assertEqual(comment.message, 'message1')
|
||||||
|
|
||||||
# Ensure the new comment was actually saved to the database.
|
# Ensure the new comment was actually saved to the database.
|
||||||
# < 2.6 Incompatible >
|
self.assertIn(
|
||||||
# self.assertIn(
|
comment,
|
||||||
# comment,
|
self.BlogPost.objects(comments__author='user4')[0].comments
|
||||||
# self.BlogPost.objects(comments__author='user4')[0].comments
|
|
||||||
# )
|
|
||||||
self.assertTrue(
|
|
||||||
comment in self.BlogPost.objects(
|
|
||||||
comments__author='user4'
|
|
||||||
)[0].comments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_keyword_update(self):
|
def test_no_keyword_update(self):
|
||||||
@ -4135,22 +4071,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
self.post1.save()
|
self.post1.save()
|
||||||
|
|
||||||
# Ensure that nothing was altered.
|
# Ensure that nothing was altered.
|
||||||
# < 2.6 Incompatible >
|
self.assertIn(
|
||||||
# self.assertIn(
|
original[0],
|
||||||
# original[0],
|
self.BlogPost.objects(id=self.post1.id)[0].comments
|
||||||
# self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
# )
|
|
||||||
self.assertTrue(
|
|
||||||
original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# < 2.6 Incompatible >
|
self.assertIn(
|
||||||
# self.assertIn(
|
original[1],
|
||||||
# original[1],
|
self.BlogPost.objects(id=self.post1.id)[0].comments
|
||||||
# self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
# )
|
|
||||||
self.assertTrue(
|
|
||||||
original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure the method returned 0 as the number of entries
|
# Ensure the method returned 0 as the number of entries
|
||||||
@ -4196,13 +4124,9 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
comments.save()
|
comments.save()
|
||||||
|
|
||||||
# Ensure that the new comment has been added to the database.
|
# Ensure that the new comment has been added to the database.
|
||||||
# < 2.6 Incompatible >
|
self.assertIn(
|
||||||
# self.assertIn(
|
new_comment,
|
||||||
# new_comment,
|
self.BlogPost.objects(id=self.post1.id)[0].comments
|
||||||
# self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
# )
|
|
||||||
self.assertTrue(
|
|
||||||
new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete(self):
|
def test_delete(self):
|
||||||
@ -4214,23 +4138,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
# Ensure that all the comments under post1 were deleted in the
|
# Ensure that all the comments under post1 were deleted in the
|
||||||
# database.
|
# database.
|
||||||
# < 2.6 Incompatible >
|
self.assertListEqual(
|
||||||
# self.assertListEqual(
|
|
||||||
# self.BlogPost.objects(id=self.post1.id)[0].comments, []
|
|
||||||
# )
|
|
||||||
self.assertEqual(
|
|
||||||
self.BlogPost.objects(id=self.post1.id)[0].comments, []
|
self.BlogPost.objects(id=self.post1.id)[0].comments, []
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure that post1 comments were deleted from the list.
|
# Ensure that post1 comments were deleted from the list.
|
||||||
# < 2.6 Incompatible >
|
self.assertListEqual(self.post1.comments, [])
|
||||||
# self.assertListEqual(self.post1.comments, [])
|
|
||||||
self.assertEqual(self.post1.comments, [])
|
|
||||||
|
|
||||||
# Ensure that comments still returned a EmbeddedDocumentList object.
|
# Ensure that comments still returned a EmbeddedDocumentList object.
|
||||||
# < 2.6 Incompatible >
|
self.assertIsInstance(self.post1.comments, EmbeddedDocumentList)
|
||||||
# self.assertIsInstance(self.post1.comments, EmbeddedDocumentList)
|
|
||||||
self.assertTrue(isinstance(self.post1.comments, EmbeddedDocumentList))
|
|
||||||
|
|
||||||
# Ensure that the delete method returned 2 as the number of entries
|
# Ensure that the delete method returned 2 as the number of entries
|
||||||
# deleted from the database
|
# deleted from the database
|
||||||
@ -4270,21 +4186,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
|||||||
self.post1.save()
|
self.post1.save()
|
||||||
|
|
||||||
# Ensure that only the user2 comment was deleted.
|
# Ensure that only the user2 comment was deleted.
|
||||||
# < 2.6 Incompatible >
|
self.assertNotIn(
|
||||||
# self.assertNotIn(
|
comment, self.BlogPost.objects(id=self.post1.id)[0].comments
|
||||||
# comment, self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
# )
|
|
||||||
self.assertTrue(
|
|
||||||
comment not in self.BlogPost.objects(id=self.post1.id)[0].comments
|
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1
|
len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure that the user2 comment no longer exists in the list.
|
# Ensure that the user2 comment no longer exists in the list.
|
||||||
# < 2.6 Incompatible >
|
self.assertNotIn(comment, self.post1.comments)
|
||||||
# self.assertNotIn(comment, self.post1.comments)
|
|
||||||
self.assertTrue(comment not in self.post1.comments)
|
|
||||||
self.assertEqual(len(self.post1.comments), 1)
|
self.assertEqual(len(self.post1.comments), 1)
|
||||||
|
|
||||||
# Ensure that the delete method returned 1 as the number of entries
|
# Ensure that the delete method returned 1 as the number of entries
|
||||||
|
@ -1,18 +1,16 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import gridfs
|
import gridfs
|
||||||
|
import six
|
||||||
|
|
||||||
from nose.plugins.skip import SkipTest
|
from nose.plugins.skip import SkipTest
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import get_db
|
from mongoengine.connection import get_db
|
||||||
from mongoengine.python_support import b, StringIO
|
from mongoengine.python_support import StringIO
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -49,7 +47,7 @@ class FileTest(unittest.TestCase):
|
|||||||
|
|
||||||
PutFile.drop_collection()
|
PutFile.drop_collection()
|
||||||
|
|
||||||
text = b('Hello, World!')
|
text = six.b('Hello, World!')
|
||||||
content_type = 'text/plain'
|
content_type = 'text/plain'
|
||||||
|
|
||||||
putfile = PutFile()
|
putfile = PutFile()
|
||||||
@ -88,8 +86,8 @@ class FileTest(unittest.TestCase):
|
|||||||
|
|
||||||
StreamFile.drop_collection()
|
StreamFile.drop_collection()
|
||||||
|
|
||||||
text = b('Hello, World!')
|
text = six.b('Hello, World!')
|
||||||
more_text = b('Foo Bar')
|
more_text = six.b('Foo Bar')
|
||||||
content_type = 'text/plain'
|
content_type = 'text/plain'
|
||||||
|
|
||||||
streamfile = StreamFile()
|
streamfile = StreamFile()
|
||||||
@ -123,8 +121,8 @@ class FileTest(unittest.TestCase):
|
|||||||
|
|
||||||
StreamFile.drop_collection()
|
StreamFile.drop_collection()
|
||||||
|
|
||||||
text = b('Hello, World!')
|
text = six.b('Hello, World!')
|
||||||
more_text = b('Foo Bar')
|
more_text = six.b('Foo Bar')
|
||||||
content_type = 'text/plain'
|
content_type = 'text/plain'
|
||||||
|
|
||||||
streamfile = StreamFile()
|
streamfile = StreamFile()
|
||||||
@ -155,8 +153,8 @@ class FileTest(unittest.TestCase):
|
|||||||
class SetFile(Document):
|
class SetFile(Document):
|
||||||
the_file = FileField()
|
the_file = FileField()
|
||||||
|
|
||||||
text = b('Hello, World!')
|
text = six.b('Hello, World!')
|
||||||
more_text = b('Foo Bar')
|
more_text = six.b('Foo Bar')
|
||||||
|
|
||||||
SetFile.drop_collection()
|
SetFile.drop_collection()
|
||||||
|
|
||||||
@ -185,7 +183,7 @@ class FileTest(unittest.TestCase):
|
|||||||
GridDocument.drop_collection()
|
GridDocument.drop_collection()
|
||||||
|
|
||||||
with tempfile.TemporaryFile() as f:
|
with tempfile.TemporaryFile() as f:
|
||||||
f.write(b("Hello World!"))
|
f.write(six.b("Hello World!"))
|
||||||
f.flush()
|
f.flush()
|
||||||
|
|
||||||
# Test without default
|
# Test without default
|
||||||
@ -202,7 +200,7 @@ class FileTest(unittest.TestCase):
|
|||||||
self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id)
|
self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id)
|
||||||
|
|
||||||
# Test with default
|
# Test with default
|
||||||
doc_d = GridDocument(the_file=b(''))
|
doc_d = GridDocument(the_file=six.b(''))
|
||||||
doc_d.save()
|
doc_d.save()
|
||||||
|
|
||||||
doc_e = GridDocument.objects.with_id(doc_d.id)
|
doc_e = GridDocument.objects.with_id(doc_d.id)
|
||||||
@ -228,7 +226,7 @@ class FileTest(unittest.TestCase):
|
|||||||
# First instance
|
# First instance
|
||||||
test_file = TestFile()
|
test_file = TestFile()
|
||||||
test_file.name = "Hello, World!"
|
test_file.name = "Hello, World!"
|
||||||
test_file.the_file.put(b('Hello, World!'))
|
test_file.the_file.put(six.b('Hello, World!'))
|
||||||
test_file.save()
|
test_file.save()
|
||||||
|
|
||||||
# Second instance
|
# Second instance
|
||||||
@ -282,7 +280,7 @@ class FileTest(unittest.TestCase):
|
|||||||
|
|
||||||
test_file = TestFile()
|
test_file = TestFile()
|
||||||
self.assertFalse(bool(test_file.the_file))
|
self.assertFalse(bool(test_file.the_file))
|
||||||
test_file.the_file.put(b('Hello, World!'), content_type='text/plain')
|
test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain')
|
||||||
test_file.save()
|
test_file.save()
|
||||||
self.assertTrue(bool(test_file.the_file))
|
self.assertTrue(bool(test_file.the_file))
|
||||||
|
|
||||||
@ -302,7 +300,7 @@ class FileTest(unittest.TestCase):
|
|||||||
class TestFile(Document):
|
class TestFile(Document):
|
||||||
the_file = FileField()
|
the_file = FileField()
|
||||||
|
|
||||||
text = b('Hello, World!')
|
text = six.b('Hello, World!')
|
||||||
content_type = 'text/plain'
|
content_type = 'text/plain'
|
||||||
|
|
||||||
testfile = TestFile()
|
testfile = TestFile()
|
||||||
@ -346,7 +344,7 @@ class FileTest(unittest.TestCase):
|
|||||||
testfile.the_file.put(text, content_type=content_type, filename="hello")
|
testfile.the_file.put(text, content_type=content_type, filename="hello")
|
||||||
testfile.save()
|
testfile.save()
|
||||||
|
|
||||||
text = b('Bonjour, World!')
|
text = six.b('Bonjour, World!')
|
||||||
testfile.the_file.replace(text, content_type=content_type, filename="hello")
|
testfile.the_file.replace(text, content_type=content_type, filename="hello")
|
||||||
testfile.save()
|
testfile.save()
|
||||||
|
|
||||||
@ -372,14 +370,14 @@ class FileTest(unittest.TestCase):
|
|||||||
TestImage.drop_collection()
|
TestImage.drop_collection()
|
||||||
|
|
||||||
with tempfile.TemporaryFile() as f:
|
with tempfile.TemporaryFile() as f:
|
||||||
f.write(b("Hello World!"))
|
f.write(six.b("Hello World!"))
|
||||||
f.flush()
|
f.flush()
|
||||||
|
|
||||||
t = TestImage()
|
t = TestImage()
|
||||||
try:
|
try:
|
||||||
t.image.put(f)
|
t.image.put(f)
|
||||||
self.fail("Should have raised an invalidation error")
|
self.fail("Should have raised an invalidation error")
|
||||||
except ValidationError, e:
|
except ValidationError as e:
|
||||||
self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f)
|
self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f)
|
||||||
|
|
||||||
t = TestImage()
|
t = TestImage()
|
||||||
@ -496,7 +494,7 @@ class FileTest(unittest.TestCase):
|
|||||||
# First instance
|
# First instance
|
||||||
test_file = TestFile()
|
test_file = TestFile()
|
||||||
test_file.name = "Hello, World!"
|
test_file.name = "Hello, World!"
|
||||||
test_file.the_file.put(b('Hello, World!'),
|
test_file.the_file.put(six.b('Hello, World!'),
|
||||||
name="hello.txt")
|
name="hello.txt")
|
||||||
test_file.save()
|
test_file.save()
|
||||||
|
|
||||||
@ -504,16 +502,15 @@ class FileTest(unittest.TestCase):
|
|||||||
self.assertEqual(data.get('name'), 'hello.txt')
|
self.assertEqual(data.get('name'), 'hello.txt')
|
||||||
|
|
||||||
test_file = TestFile.objects.first()
|
test_file = TestFile.objects.first()
|
||||||
self.assertEqual(test_file.the_file.read(),
|
self.assertEqual(test_file.the_file.read(), six.b('Hello, World!'))
|
||||||
b('Hello, World!'))
|
|
||||||
|
|
||||||
test_file = TestFile.objects.first()
|
test_file = TestFile.objects.first()
|
||||||
test_file.the_file = b('HELLO, WORLD!')
|
test_file.the_file = six.b('HELLO, WORLD!')
|
||||||
test_file.save()
|
test_file.save()
|
||||||
|
|
||||||
test_file = TestFile.objects.first()
|
test_file = TestFile.objects.first()
|
||||||
self.assertEqual(test_file.the_file.read(),
|
self.assertEqual(test_file.the_file.read(),
|
||||||
b('HELLO, WORLD!'))
|
six.b('HELLO, WORLD!'))
|
||||||
|
|
||||||
def test_copyable(self):
|
def test_copyable(self):
|
||||||
class PutFile(Document):
|
class PutFile(Document):
|
||||||
@ -521,7 +518,7 @@ class FileTest(unittest.TestCase):
|
|||||||
|
|
||||||
PutFile.drop_collection()
|
PutFile.drop_collection()
|
||||||
|
|
||||||
text = b('Hello, World!')
|
text = six.b('Hello, World!')
|
||||||
content_type = 'text/plain'
|
content_type = 'text/plain'
|
||||||
|
|
||||||
putfile = PutFile()
|
putfile = PutFile()
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
import unittest
|
|
||||||
|
|
||||||
from convert_to_new_inheritance_model import *
|
|
||||||
from decimalfield_as_float import *
|
|
||||||
from referencefield_dbref_to_object_id import *
|
|
||||||
from turn_off_inheritance import *
|
|
||||||
from uuidfield_to_binary import *
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
@ -1,51 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from mongoengine import Document, connect
|
|
||||||
from mongoengine.connection import get_db
|
|
||||||
from mongoengine.fields import StringField
|
|
||||||
|
|
||||||
__all__ = ('ConvertToNewInheritanceModel', )
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertToNewInheritanceModel(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
connect(db='mongoenginetest')
|
|
||||||
self.db = get_db()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
for collection in self.db.collection_names():
|
|
||||||
if 'system.' in collection:
|
|
||||||
continue
|
|
||||||
self.db.drop_collection(collection)
|
|
||||||
|
|
||||||
def test_how_to_convert_to_the_new_inheritance_model(self):
|
|
||||||
"""Demonstrates migrating from 0.7 to 0.8
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. Declaration of the class
|
|
||||||
class Animal(Document):
|
|
||||||
name = StringField()
|
|
||||||
meta = {
|
|
||||||
'allow_inheritance': True,
|
|
||||||
'indexes': ['name']
|
|
||||||
}
|
|
||||||
|
|
||||||
# 2. Remove _types
|
|
||||||
collection = Animal._get_collection()
|
|
||||||
collection.update({}, {"$unset": {"_types": 1}}, multi=True)
|
|
||||||
|
|
||||||
# 3. Confirm extra data is removed
|
|
||||||
count = collection.find({'_types': {"$exists": True}}).count()
|
|
||||||
self.assertEqual(0, count)
|
|
||||||
|
|
||||||
# 4. Remove indexes
|
|
||||||
info = collection.index_information()
|
|
||||||
indexes_to_drop = [key for key, value in info.iteritems()
|
|
||||||
if '_types' in dict(value['key'])]
|
|
||||||
for index in indexes_to_drop:
|
|
||||||
collection.drop_index(index)
|
|
||||||
|
|
||||||
# 5. Recreate indexes
|
|
||||||
Animal.ensure_indexes()
|
|
@ -1,50 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import unittest
|
|
||||||
import decimal
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
from mongoengine import Document, connect
|
|
||||||
from mongoengine.connection import get_db
|
|
||||||
from mongoengine.fields import StringField, DecimalField, ListField
|
|
||||||
|
|
||||||
__all__ = ('ConvertDecimalField', )
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertDecimalField(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
connect(db='mongoenginetest')
|
|
||||||
self.db = get_db()
|
|
||||||
|
|
||||||
def test_how_to_convert_decimal_fields(self):
|
|
||||||
"""Demonstrates migrating from 0.7 to 0.8
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. Old definition - using dbrefs
|
|
||||||
class Person(Document):
|
|
||||||
name = StringField()
|
|
||||||
money = DecimalField(force_string=True)
|
|
||||||
monies = ListField(DecimalField(force_string=True))
|
|
||||||
|
|
||||||
Person.drop_collection()
|
|
||||||
Person(name="Wilson Jr", money=Decimal("2.50"),
|
|
||||||
monies=[Decimal("2.10"), Decimal("5.00")]).save()
|
|
||||||
|
|
||||||
# 2. Start the migration by changing the schema
|
|
||||||
# Change DecimalField - add precision and rounding settings
|
|
||||||
class Person(Document):
|
|
||||||
name = StringField()
|
|
||||||
money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP)
|
|
||||||
monies = ListField(DecimalField(precision=2,
|
|
||||||
rounding=decimal.ROUND_HALF_UP))
|
|
||||||
|
|
||||||
# 3. Loop all the objects and mark parent as changed
|
|
||||||
for p in Person.objects:
|
|
||||||
p._mark_as_changed('money')
|
|
||||||
p._mark_as_changed('monies')
|
|
||||||
p.save()
|
|
||||||
|
|
||||||
# 4. Confirmation of the fix!
|
|
||||||
wilson = Person.objects(name="Wilson Jr").as_pymongo()[0]
|
|
||||||
self.assertTrue(isinstance(wilson['money'], float))
|
|
||||||
self.assertTrue(all([isinstance(m, float) for m in wilson['monies']]))
|
|
@ -1,52 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from mongoengine import Document, connect
|
|
||||||
from mongoengine.connection import get_db
|
|
||||||
from mongoengine.fields import StringField, ReferenceField, ListField
|
|
||||||
|
|
||||||
__all__ = ('ConvertToObjectIdsModel', )
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertToObjectIdsModel(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
connect(db='mongoenginetest')
|
|
||||||
self.db = get_db()
|
|
||||||
|
|
||||||
def test_how_to_convert_to_object_id_reference_fields(self):
|
|
||||||
"""Demonstrates migrating from 0.7 to 0.8
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. Old definition - using dbrefs
|
|
||||||
class Person(Document):
|
|
||||||
name = StringField()
|
|
||||||
parent = ReferenceField('self', dbref=True)
|
|
||||||
friends = ListField(ReferenceField('self', dbref=True))
|
|
||||||
|
|
||||||
Person.drop_collection()
|
|
||||||
|
|
||||||
p1 = Person(name="Wilson", parent=None).save()
|
|
||||||
f1 = Person(name="John", parent=None).save()
|
|
||||||
f2 = Person(name="Paul", parent=None).save()
|
|
||||||
f3 = Person(name="George", parent=None).save()
|
|
||||||
f4 = Person(name="Ringo", parent=None).save()
|
|
||||||
Person(name="Wilson Jr", parent=p1, friends=[f1, f2, f3, f4]).save()
|
|
||||||
|
|
||||||
# 2. Start the migration by changing the schema
|
|
||||||
# Change ReferenceField as now dbref defaults to False
|
|
||||||
class Person(Document):
|
|
||||||
name = StringField()
|
|
||||||
parent = ReferenceField('self')
|
|
||||||
friends = ListField(ReferenceField('self'))
|
|
||||||
|
|
||||||
# 3. Loop all the objects and mark parent as changed
|
|
||||||
for p in Person.objects:
|
|
||||||
p._mark_as_changed('parent')
|
|
||||||
p._mark_as_changed('friends')
|
|
||||||
p.save()
|
|
||||||
|
|
||||||
# 4. Confirmation of the fix!
|
|
||||||
wilson = Person.objects(name="Wilson Jr").as_pymongo()[0]
|
|
||||||
self.assertEqual(p1.id, wilson['parent'])
|
|
||||||
self.assertEqual([f1.id, f2.id, f3.id, f4.id], wilson['friends'])
|
|
@ -1,62 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from mongoengine import Document, connect
|
|
||||||
from mongoengine.connection import get_db
|
|
||||||
from mongoengine.fields import StringField
|
|
||||||
|
|
||||||
__all__ = ('TurnOffInheritanceTest', )
|
|
||||||
|
|
||||||
|
|
||||||
class TurnOffInheritanceTest(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
connect(db='mongoenginetest')
|
|
||||||
self.db = get_db()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
for collection in self.db.collection_names():
|
|
||||||
if 'system.' in collection:
|
|
||||||
continue
|
|
||||||
self.db.drop_collection(collection)
|
|
||||||
|
|
||||||
def test_how_to_turn_off_inheritance(self):
|
|
||||||
"""Demonstrates migrating from allow_inheritance = True to False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. Old declaration of the class
|
|
||||||
|
|
||||||
class Animal(Document):
|
|
||||||
name = StringField()
|
|
||||||
meta = {
|
|
||||||
'allow_inheritance': True,
|
|
||||||
'indexes': ['name']
|
|
||||||
}
|
|
||||||
|
|
||||||
# 2. Turn off inheritance
|
|
||||||
class Animal(Document):
|
|
||||||
name = StringField()
|
|
||||||
meta = {
|
|
||||||
'allow_inheritance': False,
|
|
||||||
'indexes': ['name']
|
|
||||||
}
|
|
||||||
|
|
||||||
# 3. Remove _types and _cls
|
|
||||||
collection = Animal._get_collection()
|
|
||||||
collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True)
|
|
||||||
|
|
||||||
# 3. Confirm extra data is removed
|
|
||||||
count = collection.find({"$or": [{'_types': {"$exists": True}},
|
|
||||||
{'_cls': {"$exists": True}}]}).count()
|
|
||||||
assert count == 0
|
|
||||||
|
|
||||||
# 4. Remove indexes
|
|
||||||
info = collection.index_information()
|
|
||||||
indexes_to_drop = [key for key, value in info.iteritems()
|
|
||||||
if '_types' in dict(value['key'])
|
|
||||||
or '_cls' in dict(value['key'])]
|
|
||||||
for index in indexes_to_drop:
|
|
||||||
collection.drop_index(index)
|
|
||||||
|
|
||||||
# 5. Recreate indexes
|
|
||||||
Animal.ensure_indexes()
|
|
@ -1,48 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import unittest
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from mongoengine import Document, connect
|
|
||||||
from mongoengine.connection import get_db
|
|
||||||
from mongoengine.fields import StringField, UUIDField, ListField
|
|
||||||
|
|
||||||
__all__ = ('ConvertToBinaryUUID', )
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertToBinaryUUID(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
connect(db='mongoenginetest')
|
|
||||||
self.db = get_db()
|
|
||||||
|
|
||||||
def test_how_to_convert_to_binary_uuid_fields(self):
|
|
||||||
"""Demonstrates migrating from 0.7 to 0.8
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. Old definition - using dbrefs
|
|
||||||
class Person(Document):
|
|
||||||
name = StringField()
|
|
||||||
uuid = UUIDField(binary=False)
|
|
||||||
uuids = ListField(UUIDField(binary=False))
|
|
||||||
|
|
||||||
Person.drop_collection()
|
|
||||||
Person(name="Wilson Jr", uuid=uuid.uuid4(),
|
|
||||||
uuids=[uuid.uuid4(), uuid.uuid4()]).save()
|
|
||||||
|
|
||||||
# 2. Start the migration by changing the schema
|
|
||||||
# Change UUIDFIeld as now binary defaults to True
|
|
||||||
class Person(Document):
|
|
||||||
name = StringField()
|
|
||||||
uuid = UUIDField()
|
|
||||||
uuids = ListField(UUIDField())
|
|
||||||
|
|
||||||
# 3. Loop all the objects and mark parent as changed
|
|
||||||
for p in Person.objects:
|
|
||||||
p._mark_as_changed('uuid')
|
|
||||||
p._mark_as_changed('uuids')
|
|
||||||
p.save()
|
|
||||||
|
|
||||||
# 4. Confirmation of the fix!
|
|
||||||
wilson = Person.objects(name="Wilson Jr").as_pymongo()[0]
|
|
||||||
self.assertTrue(isinstance(wilson['uuid'], uuid.UUID))
|
|
||||||
self.assertTrue(all([isinstance(u, uuid.UUID) for u in wilson['uuids']]))
|
|
@ -1,6 +1,3 @@
|
|||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
@ -95,7 +92,7 @@ class OnlyExcludeAllTest(unittest.TestCase):
|
|||||||
exclude = ['d', 'e']
|
exclude = ['d', 'e']
|
||||||
only = ['b', 'c']
|
only = ['b', 'c']
|
||||||
|
|
||||||
qs = MyDoc.objects.fields(**dict(((i, 1) for i in include)))
|
qs = MyDoc.objects.fields(**{i: 1 for i in include})
|
||||||
self.assertEqual(qs._loaded_fields.as_dict(),
|
self.assertEqual(qs._loaded_fields.as_dict(),
|
||||||
{'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1})
|
{'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1})
|
||||||
qs = qs.only(*only)
|
qs = qs.only(*only)
|
||||||
@ -103,14 +100,14 @@ class OnlyExcludeAllTest(unittest.TestCase):
|
|||||||
qs = qs.exclude(*exclude)
|
qs = qs.exclude(*exclude)
|
||||||
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
|
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
|
||||||
|
|
||||||
qs = MyDoc.objects.fields(**dict(((i, 1) for i in include)))
|
qs = MyDoc.objects.fields(**{i: 1 for i in include})
|
||||||
qs = qs.exclude(*exclude)
|
qs = qs.exclude(*exclude)
|
||||||
self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
|
self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
|
||||||
qs = qs.only(*only)
|
qs = qs.only(*only)
|
||||||
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
|
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
|
||||||
|
|
||||||
qs = MyDoc.objects.exclude(*exclude)
|
qs = MyDoc.objects.exclude(*exclude)
|
||||||
qs = qs.fields(**dict(((i, 1) for i in include)))
|
qs = qs.fields(**{i: 1 for i in include})
|
||||||
self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
|
self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
|
||||||
qs = qs.only(*only)
|
qs = qs.only(*only)
|
||||||
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
|
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
|
||||||
@ -129,7 +126,7 @@ class OnlyExcludeAllTest(unittest.TestCase):
|
|||||||
exclude = ['d', 'e']
|
exclude = ['d', 'e']
|
||||||
only = ['b', 'c']
|
only = ['b', 'c']
|
||||||
|
|
||||||
qs = MyDoc.objects.fields(**dict(((i, 1) for i in include)))
|
qs = MyDoc.objects.fields(**{i: 1 for i in include})
|
||||||
qs = qs.exclude(*exclude)
|
qs = qs.exclude(*exclude)
|
||||||
qs = qs.only(*only)
|
qs = qs.only(*only)
|
||||||
qs = qs.fields(slice__b=5)
|
qs = qs.fields(slice__b=5)
|
||||||
|
@ -1,9 +1,5 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
import unittest
|
||||||
|
|
||||||
from pymongo.errors import OperationFailure
|
from pymongo.errors import OperationFailure
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mongoengine import connect, Document, IntField
|
from mongoengine import connect, Document, IntField
|
||||||
|
@ -9,13 +9,13 @@ from nose.plugins.skip import SkipTest
|
|||||||
import pymongo
|
import pymongo
|
||||||
from pymongo.errors import ConfigurationError
|
from pymongo.errors import ConfigurationError
|
||||||
from pymongo.read_preferences import ReadPreference
|
from pymongo.read_preferences import ReadPreference
|
||||||
|
import six
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import get_connection, get_db
|
from mongoengine.connection import get_connection, get_db
|
||||||
from mongoengine.context_managers import query_counter, switch_db
|
from mongoengine.context_managers import query_counter, switch_db
|
||||||
from mongoengine.errors import InvalidQueryError
|
from mongoengine.errors import InvalidQueryError
|
||||||
from mongoengine.python_support import IS_PYMONGO_3, PY3
|
from mongoengine.python_support import IS_PYMONGO_3
|
||||||
from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
|
from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
|
||||||
QuerySet, QuerySetManager, queryset_manager)
|
QuerySet, QuerySetManager, queryset_manager)
|
||||||
|
|
||||||
@ -25,7 +25,10 @@ __all__ = ("QuerySetTest",)
|
|||||||
class db_ops_tracker(query_counter):
|
class db_ops_tracker(query_counter):
|
||||||
|
|
||||||
def get_ops(self):
|
def get_ops(self):
|
||||||
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
|
ignore_query = {
|
||||||
|
'ns': {'$ne': '%s.system.indexes' % self.db.name},
|
||||||
|
'command.count': {'$ne': 'system.profile'}
|
||||||
|
}
|
||||||
return list(self.db.system.profile.find(ignore_query))
|
return list(self.db.system.profile.find(ignore_query))
|
||||||
|
|
||||||
|
|
||||||
@ -94,12 +97,12 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
author = ReferenceField(self.Person)
|
author = ReferenceField(self.Person)
|
||||||
author2 = GenericReferenceField()
|
author2 = GenericReferenceField()
|
||||||
|
|
||||||
def test_reference():
|
# test addressing a field from a reference
|
||||||
|
with self.assertRaises(InvalidQueryError):
|
||||||
list(BlogPost.objects(author__name="test"))
|
list(BlogPost.objects(author__name="test"))
|
||||||
|
|
||||||
self.assertRaises(InvalidQueryError, test_reference)
|
# should fail for a generic reference as well
|
||||||
|
with self.assertRaises(InvalidQueryError):
|
||||||
def test_generic_reference():
|
|
||||||
list(BlogPost.objects(author2__name="test"))
|
list(BlogPost.objects(author2__name="test"))
|
||||||
|
|
||||||
def test_find(self):
|
def test_find(self):
|
||||||
@ -174,7 +177,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Test larger slice __repr__
|
# Test larger slice __repr__
|
||||||
self.Person.objects.delete()
|
self.Person.objects.delete()
|
||||||
for i in xrange(55):
|
for i in range(55):
|
||||||
self.Person(name='A%s' % i, age=i).save()
|
self.Person(name='A%s' % i, age=i).save()
|
||||||
|
|
||||||
self.assertEqual(self.Person.objects.count(), 55)
|
self.assertEqual(self.Person.objects.count(), 55)
|
||||||
@ -218,14 +221,15 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
person = self.Person.objects[1]
|
person = self.Person.objects[1]
|
||||||
self.assertEqual(person.name, "User B")
|
self.assertEqual(person.name, "User B")
|
||||||
|
|
||||||
self.assertRaises(IndexError, self.Person.objects.__getitem__, 2)
|
with self.assertRaises(IndexError):
|
||||||
|
self.Person.objects[2]
|
||||||
|
|
||||||
# Find a document using just the object id
|
# Find a document using just the object id
|
||||||
person = self.Person.objects.with_id(person1.id)
|
person = self.Person.objects.with_id(person1.id)
|
||||||
self.assertEqual(person.name, "User A")
|
self.assertEqual(person.name, "User A")
|
||||||
|
|
||||||
self.assertRaises(
|
with self.assertRaises(InvalidQueryError):
|
||||||
InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id)
|
self.Person.objects(name="User A").with_id(person1.id)
|
||||||
|
|
||||||
def test_find_only_one(self):
|
def test_find_only_one(self):
|
||||||
"""Ensure that a query using ``get`` returns at most one result.
|
"""Ensure that a query using ``get`` returns at most one result.
|
||||||
@ -363,7 +367,8 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
# test invalid batch size
|
# test invalid batch size
|
||||||
qs = A.objects.batch_size(-1)
|
qs = A.objects.batch_size(-1)
|
||||||
self.assertRaises(ValueError, lambda: list(qs))
|
with self.assertRaises(ValueError):
|
||||||
|
list(qs)
|
||||||
|
|
||||||
def test_update_write_concern(self):
|
def test_update_write_concern(self):
|
||||||
"""Test that passing write_concern works"""
|
"""Test that passing write_concern works"""
|
||||||
@ -392,18 +397,14 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
"""Test to ensure that update is passed a value to update to"""
|
"""Test to ensure that update is passed a value to update to"""
|
||||||
self.Person.drop_collection()
|
self.Person.drop_collection()
|
||||||
|
|
||||||
author = self.Person(name='Test User')
|
author = self.Person.objects.create(name='Test User')
|
||||||
author.save()
|
|
||||||
|
|
||||||
def update_raises():
|
with self.assertRaises(OperationError):
|
||||||
self.Person.objects(pk=author.pk).update({})
|
self.Person.objects(pk=author.pk).update({})
|
||||||
|
|
||||||
def update_one_raises():
|
with self.assertRaises(OperationError):
|
||||||
self.Person.objects(pk=author.pk).update_one({})
|
self.Person.objects(pk=author.pk).update_one({})
|
||||||
|
|
||||||
self.assertRaises(OperationError, update_raises)
|
|
||||||
self.assertRaises(OperationError, update_one_raises)
|
|
||||||
|
|
||||||
def test_update_array_position(self):
|
def test_update_array_position(self):
|
||||||
"""Ensure that updating by array position works.
|
"""Ensure that updating by array position works.
|
||||||
|
|
||||||
@ -431,8 +432,8 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
Blog.objects.create(posts=[post2, post1])
|
Blog.objects.create(posts=[post2, post1])
|
||||||
|
|
||||||
# Update all of the first comments of second posts of all blogs
|
# Update all of the first comments of second posts of all blogs
|
||||||
Blog.objects().update(set__posts__1__comments__0__name="testc")
|
Blog.objects().update(set__posts__1__comments__0__name='testc')
|
||||||
testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
|
testc_blogs = Blog.objects(posts__1__comments__0__name='testc')
|
||||||
self.assertEqual(testc_blogs.count(), 2)
|
self.assertEqual(testc_blogs.count(), 2)
|
||||||
|
|
||||||
Blog.drop_collection()
|
Blog.drop_collection()
|
||||||
@ -441,14 +442,13 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Update only the first blog returned by the query
|
# Update only the first blog returned by the query
|
||||||
Blog.objects().update_one(
|
Blog.objects().update_one(
|
||||||
set__posts__1__comments__1__name="testc")
|
set__posts__1__comments__1__name='testc')
|
||||||
testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
|
testc_blogs = Blog.objects(posts__1__comments__1__name='testc')
|
||||||
self.assertEqual(testc_blogs.count(), 1)
|
self.assertEqual(testc_blogs.count(), 1)
|
||||||
|
|
||||||
# Check that using this indexing syntax on a non-list fails
|
# Check that using this indexing syntax on a non-list fails
|
||||||
def non_list_indexing():
|
with self.assertRaises(InvalidQueryError):
|
||||||
Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
|
Blog.objects().update(set__posts__1__comments__0__name__1='asdf')
|
||||||
self.assertRaises(InvalidQueryError, non_list_indexing)
|
|
||||||
|
|
||||||
Blog.drop_collection()
|
Blog.drop_collection()
|
||||||
|
|
||||||
@ -516,15 +516,12 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4])
|
self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4])
|
||||||
|
|
||||||
# Nested updates arent supported yet..
|
# Nested updates arent supported yet..
|
||||||
def update_nested():
|
with self.assertRaises(OperationError):
|
||||||
Simple.drop_collection()
|
Simple.drop_collection()
|
||||||
Simple(x=[{'test': [1, 2, 3, 4]}]).save()
|
Simple(x=[{'test': [1, 2, 3, 4]}]).save()
|
||||||
Simple.objects(x__test=2).update(set__x__S__test__S=3)
|
Simple.objects(x__test=2).update(set__x__S__test__S=3)
|
||||||
self.assertEqual(simple.x, [1, 2, 3, 4])
|
self.assertEqual(simple.x, [1, 2, 3, 4])
|
||||||
|
|
||||||
self.assertRaises(OperationError, update_nested)
|
|
||||||
Simple.drop_collection()
|
|
||||||
|
|
||||||
def test_update_using_positional_operator_embedded_document(self):
|
def test_update_using_positional_operator_embedded_document(self):
|
||||||
"""Ensure that the embedded documents can be updated using the positional
|
"""Ensure that the embedded documents can be updated using the positional
|
||||||
operator."""
|
operator."""
|
||||||
@ -617,11 +614,11 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
members = DictField()
|
members = DictField()
|
||||||
|
|
||||||
club = Club()
|
club = Club()
|
||||||
club.members['John'] = dict(gender="M", age=13)
|
club.members['John'] = {'gender': 'M', 'age': 13}
|
||||||
club.save()
|
club.save()
|
||||||
|
|
||||||
Club.objects().update(
|
Club.objects().update(
|
||||||
set__members={"John": dict(gender="F", age=14)})
|
set__members={"John": {'gender': 'F', 'age': 14}})
|
||||||
|
|
||||||
club = Club.objects().first()
|
club = Club.objects().first()
|
||||||
self.assertEqual(club.members['John']['gender'], "F")
|
self.assertEqual(club.members['John']['gender'], "F")
|
||||||
@ -802,7 +799,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
post2 = Post(comments=[comment2, comment2])
|
post2 = Post(comments=[comment2, comment2])
|
||||||
|
|
||||||
blogs = []
|
blogs = []
|
||||||
for i in xrange(1, 100):
|
for i in range(1, 100):
|
||||||
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
|
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
|
||||||
|
|
||||||
Blog.objects.insert(blogs, load_bulk=False)
|
Blog.objects.insert(blogs, load_bulk=False)
|
||||||
@ -839,30 +836,31 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(Blog.objects.count(), 2)
|
self.assertEqual(Blog.objects.count(), 2)
|
||||||
|
|
||||||
# test handles people trying to upsert
|
# test inserting an existing document (shouldn't be allowed)
|
||||||
def throw_operation_error():
|
with self.assertRaises(OperationError):
|
||||||
|
blog = Blog.objects.first()
|
||||||
|
Blog.objects.insert(blog)
|
||||||
|
|
||||||
|
# test inserting a query set
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
blogs = Blog.objects
|
blogs = Blog.objects
|
||||||
Blog.objects.insert(blogs)
|
Blog.objects.insert(blogs)
|
||||||
|
|
||||||
self.assertRaises(OperationError, throw_operation_error)
|
# insert a new doc
|
||||||
|
|
||||||
# Test can insert new doc
|
|
||||||
new_post = Blog(title="code123", id=ObjectId())
|
new_post = Blog(title="code123", id=ObjectId())
|
||||||
Blog.objects.insert(new_post)
|
Blog.objects.insert(new_post)
|
||||||
|
|
||||||
# test handles other classes being inserted
|
class Author(Document):
|
||||||
def throw_operation_error_wrong_doc():
|
pass
|
||||||
class Author(Document):
|
|
||||||
pass
|
# try inserting a different document class
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
Blog.objects.insert(Author())
|
Blog.objects.insert(Author())
|
||||||
|
|
||||||
self.assertRaises(OperationError, throw_operation_error_wrong_doc)
|
# try inserting a non-document
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
def throw_operation_error_not_a_document():
|
|
||||||
Blog.objects.insert("HELLO WORLD")
|
Blog.objects.insert("HELLO WORLD")
|
||||||
|
|
||||||
self.assertRaises(OperationError, throw_operation_error_not_a_document)
|
|
||||||
|
|
||||||
Blog.drop_collection()
|
Blog.drop_collection()
|
||||||
|
|
||||||
blog1 = Blog(title="code", posts=[post1, post2])
|
blog1 = Blog(title="code", posts=[post1, post2])
|
||||||
@ -882,14 +880,13 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
blog3 = Blog(title="baz", posts=[post1, post2])
|
blog3 = Blog(title="baz", posts=[post1, post2])
|
||||||
Blog.objects.insert([blog1, blog2])
|
Blog.objects.insert([blog1, blog2])
|
||||||
|
|
||||||
def throw_operation_error_not_unique():
|
with self.assertRaises(NotUniqueError):
|
||||||
Blog.objects.insert([blog2, blog3])
|
Blog.objects.insert([blog2, blog3])
|
||||||
|
|
||||||
self.assertRaises(NotUniqueError, throw_operation_error_not_unique)
|
|
||||||
self.assertEqual(Blog.objects.count(), 2)
|
self.assertEqual(Blog.objects.count(), 2)
|
||||||
|
|
||||||
Blog.objects.insert([blog2, blog3], write_concern={"w": 0,
|
Blog.objects.insert([blog2, blog3],
|
||||||
'continue_on_error': True})
|
write_concern={"w": 0, 'continue_on_error': True})
|
||||||
self.assertEqual(Blog.objects.count(), 3)
|
self.assertEqual(Blog.objects.count(), 3)
|
||||||
|
|
||||||
def test_get_changed_fields_query_count(self):
|
def test_get_changed_fields_query_count(self):
|
||||||
@ -1022,7 +1019,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Doc.drop_collection()
|
Doc.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1000):
|
for i in range(1000):
|
||||||
Doc(number=i).save()
|
Doc(number=i).save()
|
||||||
|
|
||||||
docs = Doc.objects.order_by('number')
|
docs = Doc.objects.order_by('number')
|
||||||
@ -1176,7 +1173,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
qs = list(qs)
|
qs = list(qs)
|
||||||
expected = list(expected)
|
expected = list(expected)
|
||||||
self.assertEqual(len(qs), len(expected))
|
self.assertEqual(len(qs), len(expected))
|
||||||
for i in xrange(len(qs)):
|
for i in range(len(qs)):
|
||||||
self.assertEqual(qs[i], expected[i])
|
self.assertEqual(qs[i], expected[i])
|
||||||
|
|
||||||
def test_ordering(self):
|
def test_ordering(self):
|
||||||
@ -1216,7 +1213,8 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertSequence(qs, expected)
|
self.assertSequence(qs, expected)
|
||||||
|
|
||||||
def test_clear_ordering(self):
|
def test_clear_ordering(self):
|
||||||
""" Ensure that the default ordering can be cleared by calling order_by().
|
"""Ensure that the default ordering can be cleared by calling
|
||||||
|
order_by() w/o any arguments.
|
||||||
"""
|
"""
|
||||||
class BlogPost(Document):
|
class BlogPost(Document):
|
||||||
title = StringField()
|
title = StringField()
|
||||||
@ -1232,12 +1230,13 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
BlogPost.objects.filter(title='whatever').first()
|
BlogPost.objects.filter(title='whatever').first()
|
||||||
self.assertEqual(len(q.get_ops()), 1)
|
self.assertEqual(len(q.get_ops()), 1)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
q.get_ops()[0]['query']['$orderby'], {u'published_date': -1})
|
q.get_ops()[0]['query']['$orderby'],
|
||||||
|
{'published_date': -1}
|
||||||
|
)
|
||||||
|
|
||||||
with db_ops_tracker() as q:
|
with db_ops_tracker() as q:
|
||||||
BlogPost.objects.filter(title='whatever').order_by().first()
|
BlogPost.objects.filter(title='whatever').order_by().first()
|
||||||
self.assertEqual(len(q.get_ops()), 1)
|
self.assertEqual(len(q.get_ops()), 1)
|
||||||
print q.get_ops()[0]['query']
|
|
||||||
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
|
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
|
||||||
|
|
||||||
def test_no_ordering_for_get(self):
|
def test_no_ordering_for_get(self):
|
||||||
@ -1710,7 +1709,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Log.drop_collection()
|
Log.drop_collection()
|
||||||
|
|
||||||
for i in xrange(10):
|
for i in range(10):
|
||||||
Log().save()
|
Log().save()
|
||||||
|
|
||||||
Log.objects()[3:5].delete()
|
Log.objects()[3:5].delete()
|
||||||
@ -1910,12 +1909,10 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
|
Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
|
||||||
self.assertEqual(Site.objects.first().collaborators, [])
|
self.assertEqual(Site.objects.first().collaborators, [])
|
||||||
|
|
||||||
def pull_all():
|
with self.assertRaises(InvalidQueryError):
|
||||||
Site.objects(id=s.id).update_one(
|
Site.objects(id=s.id).update_one(
|
||||||
pull_all__collaborators__user=['Ross'])
|
pull_all__collaborators__user=['Ross'])
|
||||||
|
|
||||||
self.assertRaises(InvalidQueryError, pull_all)
|
|
||||||
|
|
||||||
def test_pull_from_nested_embedded(self):
|
def test_pull_from_nested_embedded(self):
|
||||||
|
|
||||||
class User(EmbeddedDocument):
|
class User(EmbeddedDocument):
|
||||||
@ -1946,12 +1943,10 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
pull__collaborators__unhelpful={'name': 'Frank'})
|
pull__collaborators__unhelpful={'name': 'Frank'})
|
||||||
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
|
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
|
||||||
|
|
||||||
def pull_all():
|
with self.assertRaises(InvalidQueryError):
|
||||||
Site.objects(id=s.id).update_one(
|
Site.objects(id=s.id).update_one(
|
||||||
pull_all__collaborators__helpful__name=['Ross'])
|
pull_all__collaborators__helpful__name=['Ross'])
|
||||||
|
|
||||||
self.assertRaises(InvalidQueryError, pull_all)
|
|
||||||
|
|
||||||
def test_pull_from_nested_mapfield(self):
|
def test_pull_from_nested_mapfield(self):
|
||||||
|
|
||||||
class Collaborator(EmbeddedDocument):
|
class Collaborator(EmbeddedDocument):
|
||||||
@ -1980,12 +1975,10 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
pull__collaborators__unhelpful={'user': 'Frank'})
|
pull__collaborators__unhelpful={'user': 'Frank'})
|
||||||
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
|
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
|
||||||
|
|
||||||
def pull_all():
|
with self.assertRaises(InvalidQueryError):
|
||||||
Site.objects(id=s.id).update_one(
|
Site.objects(id=s.id).update_one(
|
||||||
pull_all__collaborators__helpful__user=['Ross'])
|
pull_all__collaborators__helpful__user=['Ross'])
|
||||||
|
|
||||||
self.assertRaises(InvalidQueryError, pull_all)
|
|
||||||
|
|
||||||
def test_update_one_pop_generic_reference(self):
|
def test_update_one_pop_generic_reference(self):
|
||||||
|
|
||||||
class BlogTag(Document):
|
class BlogTag(Document):
|
||||||
@ -2610,7 +2603,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
BlogPost(hits=2, tags=['music', 'actors']).save()
|
BlogPost(hits=2, tags=['music', 'actors']).save()
|
||||||
|
|
||||||
def test_assertions(f):
|
def test_assertions(f):
|
||||||
f = dict((key, int(val)) for key, val in f.items())
|
f = {key: int(val) for key, val in f.items()}
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(['music', 'film', 'actors', 'watch']), set(f.keys()))
|
set(['music', 'film', 'actors', 'watch']), set(f.keys()))
|
||||||
self.assertEqual(f['music'], 3)
|
self.assertEqual(f['music'], 3)
|
||||||
@ -2625,7 +2618,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Ensure query is taken into account
|
# Ensure query is taken into account
|
||||||
def test_assertions(f):
|
def test_assertions(f):
|
||||||
f = dict((key, int(val)) for key, val in f.items())
|
f = {key: int(val) for key, val in f.items()}
|
||||||
self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys()))
|
self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys()))
|
||||||
self.assertEqual(f['music'], 2)
|
self.assertEqual(f['music'], 2)
|
||||||
self.assertEqual(f['actors'], 1)
|
self.assertEqual(f['actors'], 1)
|
||||||
@ -2689,7 +2682,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
doc.save()
|
doc.save()
|
||||||
|
|
||||||
def test_assertions(f):
|
def test_assertions(f):
|
||||||
f = dict((key, int(val)) for key, val in f.items())
|
f = {key: int(val) for key, val in f.items()}
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(['62-3331-1656', '62-3332-1656']), set(f.keys()))
|
set(['62-3331-1656', '62-3332-1656']), set(f.keys()))
|
||||||
self.assertEqual(f['62-3331-1656'], 2)
|
self.assertEqual(f['62-3331-1656'], 2)
|
||||||
@ -2703,7 +2696,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Ensure query is taken into account
|
# Ensure query is taken into account
|
||||||
def test_assertions(f):
|
def test_assertions(f):
|
||||||
f = dict((key, int(val)) for key, val in f.items())
|
f = {key: int(val) for key, val in f.items()}
|
||||||
self.assertEqual(set(['62-3331-1656']), set(f.keys()))
|
self.assertEqual(set(['62-3331-1656']), set(f.keys()))
|
||||||
self.assertEqual(f['62-3331-1656'], 2)
|
self.assertEqual(f['62-3331-1656'], 2)
|
||||||
|
|
||||||
@ -2810,10 +2803,10 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Test.drop_collection()
|
Test.drop_collection()
|
||||||
|
|
||||||
for i in xrange(50):
|
for i in range(50):
|
||||||
Test(val=1).save()
|
Test(val=1).save()
|
||||||
|
|
||||||
for i in xrange(20):
|
for i in range(20):
|
||||||
Test(val=2).save()
|
Test(val=2).save()
|
||||||
|
|
||||||
freqs = Test.objects.item_frequencies(
|
freqs = Test.objects.item_frequencies(
|
||||||
@ -3603,7 +3596,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Post.drop_collection()
|
Post.drop_collection()
|
||||||
|
|
||||||
for i in xrange(10):
|
for i in range(10):
|
||||||
Post(title="Post %s" % i).save()
|
Post(title="Post %s" % i).save()
|
||||||
|
|
||||||
self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True))
|
self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True))
|
||||||
@ -3618,7 +3611,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
MyDoc.drop_collection()
|
MyDoc.drop_collection()
|
||||||
for i in xrange(0, 10):
|
for i in range(0, 10):
|
||||||
MyDoc().save()
|
MyDoc().save()
|
||||||
|
|
||||||
self.assertEqual(MyDoc.objects.count(), 10)
|
self.assertEqual(MyDoc.objects.count(), 10)
|
||||||
@ -3674,7 +3667,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Number.drop_collection()
|
Number.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 101):
|
for i in range(1, 101):
|
||||||
t = Number(n=i)
|
t = Number(n=i)
|
||||||
t.save()
|
t.save()
|
||||||
|
|
||||||
@ -3821,11 +3814,9 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertTrue(a in results)
|
self.assertTrue(a in results)
|
||||||
self.assertTrue(c in results)
|
self.assertTrue(c in results)
|
||||||
|
|
||||||
def invalid_where():
|
with self.assertRaises(TypeError):
|
||||||
list(IntPair.objects.where(fielda__gte=3))
|
list(IntPair.objects.where(fielda__gte=3))
|
||||||
|
|
||||||
self.assertRaises(TypeError, invalid_where)
|
|
||||||
|
|
||||||
def test_scalar(self):
|
def test_scalar(self):
|
||||||
|
|
||||||
class Organization(Document):
|
class Organization(Document):
|
||||||
@ -4081,7 +4072,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Test larger slice __repr__
|
# Test larger slice __repr__
|
||||||
self.Person.objects.delete()
|
self.Person.objects.delete()
|
||||||
for i in xrange(55):
|
for i in range(55):
|
||||||
self.Person(name='A%s' % i, age=i).save()
|
self.Person(name='A%s' % i, age=i).save()
|
||||||
|
|
||||||
self.assertEqual(self.Person.objects.scalar('name').count(), 55)
|
self.assertEqual(self.Person.objects.scalar('name').count(), 55)
|
||||||
@ -4089,7 +4080,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
"A0", "%s" % self.Person.objects.order_by('name').scalar('name').first())
|
"A0", "%s" % self.Person.objects.order_by('name').scalar('name').first())
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0])
|
"A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0])
|
||||||
if PY3:
|
if six.PY3:
|
||||||
self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by(
|
self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by(
|
||||||
'age').scalar('name')[1:3])
|
'age').scalar('name')[1:3])
|
||||||
self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by(
|
self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by(
|
||||||
@ -4107,7 +4098,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
pks = self.Person.objects.order_by('age').scalar('pk')[1:3]
|
pks = self.Person.objects.order_by('age').scalar('pk')[1:3]
|
||||||
names = self.Person.objects.scalar('name').in_bulk(list(pks)).values()
|
names = self.Person.objects.scalar('name').in_bulk(list(pks)).values()
|
||||||
if PY3:
|
if six.PY3:
|
||||||
expected = "['A1', 'A2']"
|
expected = "['A1', 'A2']"
|
||||||
else:
|
else:
|
||||||
expected = "[u'A1', u'A2']"
|
expected = "[u'A1', u'A2']"
|
||||||
@ -4463,7 +4454,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
for i in xrange(100):
|
for i in range(100):
|
||||||
Person(name="No: %s" % i).save()
|
Person(name="No: %s" % i).save()
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
@ -4494,7 +4485,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
for i in xrange(100):
|
for i in range(100):
|
||||||
Person(name="No: %s" % i).save()
|
Person(name="No: %s" % i).save()
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
@ -4538,7 +4529,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
fields = DictField()
|
fields = DictField()
|
||||||
|
|
||||||
Noddy.drop_collection()
|
Noddy.drop_collection()
|
||||||
for i in xrange(100):
|
for i in range(100):
|
||||||
noddy = Noddy()
|
noddy = Noddy()
|
||||||
for j in range(20):
|
for j in range(20):
|
||||||
noddy.fields["key" + str(j)] = "value " + str(j)
|
noddy.fields["key" + str(j)] = "value " + str(j)
|
||||||
@ -4550,7 +4541,9 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertEqual(counter, 100)
|
self.assertEqual(counter, 100)
|
||||||
|
|
||||||
self.assertEqual(len(list(docs)), 100)
|
self.assertEqual(len(list(docs)), 100)
|
||||||
self.assertRaises(TypeError, lambda: len(docs))
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
len(docs)
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
self.assertEqual(q, 0)
|
self.assertEqual(q, 0)
|
||||||
@ -4739,7 +4732,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
for i in xrange(100):
|
for i in range(100):
|
||||||
Person(name="No: %s" % i).save()
|
Person(name="No: %s" % i).save()
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
@ -4863,10 +4856,10 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
def test_delete_count(self):
|
def test_delete_count(self):
|
||||||
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)]
|
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)]
|
||||||
self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count
|
self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count
|
||||||
|
|
||||||
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)]
|
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)]
|
||||||
|
|
||||||
self.assertEqual(self.Person.objects().skip(1).delete(), 2) # test Document delete with existing documents
|
self.assertEqual(self.Person.objects().skip(1).delete(), 2) # test Document delete with existing documents
|
||||||
|
|
||||||
@ -4875,12 +4868,14 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_max_time_ms(self):
|
def test_max_time_ms(self):
|
||||||
# 778: max_time_ms can get only int or None as input
|
# 778: max_time_ms can get only int or None as input
|
||||||
self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number")
|
self.assertRaises(TypeError,
|
||||||
|
self.Person.objects(name="name").max_time_ms,
|
||||||
|
'not a number')
|
||||||
|
|
||||||
def test_subclass_field_query(self):
|
def test_subclass_field_query(self):
|
||||||
class Animal(Document):
|
class Animal(Document):
|
||||||
is_mamal = BooleanField()
|
is_mamal = BooleanField()
|
||||||
meta = dict(allow_inheritance=True)
|
meta = {'allow_inheritance': True}
|
||||||
|
|
||||||
class Cat(Animal):
|
class Cat(Animal):
|
||||||
whiskers_length = FloatField()
|
whiskers_length = FloatField()
|
||||||
@ -4925,7 +4920,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
class Data(Document):
|
class Data(Document):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
for i in xrange(300):
|
for i in range(300):
|
||||||
Data().save()
|
Data().save()
|
||||||
|
|
||||||
records = Data.objects.limit(250)
|
records = Data.objects.limit(250)
|
||||||
@ -4957,7 +4952,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
class Data(Document):
|
class Data(Document):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
for i in xrange(300):
|
for i in range(300):
|
||||||
Data().save()
|
Data().save()
|
||||||
|
|
||||||
qs = Data.objects.limit(250)
|
qs = Data.objects.limit(250)
|
||||||
|
@ -238,7 +238,8 @@ class TransformTest(unittest.TestCase):
|
|||||||
box = [(35.0, -125.0), (40.0, -100.0)]
|
box = [(35.0, -125.0), (40.0, -100.0)]
|
||||||
# I *meant* to execute location__within_box=box
|
# I *meant* to execute location__within_box=box
|
||||||
events = Event.objects(location__within=box)
|
events = Event.objects(location__within=box)
|
||||||
self.assertRaises(InvalidQueryError, lambda: events.count())
|
with self.assertRaises(InvalidQueryError):
|
||||||
|
events.count()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -185,7 +185,7 @@ class QTest(unittest.TestCase):
|
|||||||
x = IntField()
|
x = IntField()
|
||||||
|
|
||||||
TestDoc.drop_collection()
|
TestDoc.drop_collection()
|
||||||
for i in xrange(1, 101):
|
for i in range(1, 101):
|
||||||
t = TestDoc(x=i)
|
t = TestDoc(x=i)
|
||||||
t.save()
|
t.save()
|
||||||
|
|
||||||
@ -268,14 +268,13 @@ class QTest(unittest.TestCase):
|
|||||||
self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3)
|
self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3)
|
||||||
|
|
||||||
# Test invalid query objs
|
# Test invalid query objs
|
||||||
def wrong_query_objs():
|
with self.assertRaises(InvalidQueryError):
|
||||||
self.Person.objects('user1')
|
self.Person.objects('user1')
|
||||||
|
|
||||||
def wrong_query_objs_filter():
|
# filter should fail, too
|
||||||
self.Person.objects('user1')
|
with self.assertRaises(InvalidQueryError):
|
||||||
|
self.Person.objects.filter('user1')
|
||||||
|
|
||||||
self.assertRaises(InvalidQueryError, wrong_query_objs)
|
|
||||||
self.assertRaises(InvalidQueryError, wrong_query_objs_filter)
|
|
||||||
|
|
||||||
def test_q_regex(self):
|
def test_q_regex(self):
|
||||||
"""Ensure that Q objects can be queried using regexes.
|
"""Ensure that Q objects can be queried using regexes.
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
import sys
|
|
||||||
import datetime
|
import datetime
|
||||||
from pymongo.errors import OperationFailure
|
from pymongo.errors import OperationFailure
|
||||||
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import unittest2 as unittest
|
import unittest2 as unittest
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -19,7 +16,8 @@ from mongoengine import (
|
|||||||
)
|
)
|
||||||
from mongoengine.python_support import IS_PYMONGO_3
|
from mongoengine.python_support import IS_PYMONGO_3
|
||||||
import mongoengine.connection
|
import mongoengine.connection
|
||||||
from mongoengine.connection import get_db, get_connection, ConnectionError
|
from mongoengine.connection import (MongoEngineConnectionError, get_db,
|
||||||
|
get_connection)
|
||||||
|
|
||||||
|
|
||||||
def get_tz_awareness(connection):
|
def get_tz_awareness(connection):
|
||||||
@ -159,7 +157,10 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
c.mongoenginetest.add_user("username", "password")
|
c.mongoenginetest.add_user("username", "password")
|
||||||
|
|
||||||
if not IS_PYMONGO_3:
|
if not IS_PYMONGO_3:
|
||||||
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
|
self.assertRaises(
|
||||||
|
MongoEngineConnectionError, connect, 'testdb_uri_bad',
|
||||||
|
host='mongodb://test:password@localhost'
|
||||||
|
)
|
||||||
|
|
||||||
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
|
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
|
||||||
|
|
||||||
@ -229,10 +230,11 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
self.assertRaises(OperationFailure, test_conn.server_info)
|
self.assertRaises(OperationFailure, test_conn.server_info)
|
||||||
else:
|
else:
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ConnectionError, connect, 'mongoenginetest', alias='test1',
|
MongoEngineConnectionError, connect, 'mongoenginetest',
|
||||||
|
alias='test1',
|
||||||
host='mongodb://username2:password@localhost/mongoenginetest'
|
host='mongodb://username2:password@localhost/mongoenginetest'
|
||||||
)
|
)
|
||||||
self.assertRaises(ConnectionError, get_db, 'test1')
|
self.assertRaises(MongoEngineConnectionError, get_db, 'test1')
|
||||||
|
|
||||||
# Authentication succeeds with "authSource"
|
# Authentication succeeds with "authSource"
|
||||||
connect(
|
connect(
|
||||||
@ -253,7 +255,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
register_connection('testdb', 'mongoenginetest2')
|
register_connection('testdb', 'mongoenginetest2')
|
||||||
|
|
||||||
self.assertRaises(ConnectionError, get_connection)
|
self.assertRaises(MongoEngineConnectionError, get_connection)
|
||||||
conn = get_connection('testdb')
|
conn = get_connection('testdb')
|
||||||
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
|
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
|
||||||
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
@ -79,7 +77,7 @@ class ContextManagersTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
User(name='user %s' % i).save()
|
User(name='user %s' % i).save()
|
||||||
|
|
||||||
user = User.objects.first()
|
user = User.objects.first()
|
||||||
@ -117,7 +115,7 @@ class ContextManagersTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
User(name='user %s' % i).save()
|
User(name='user %s' % i).save()
|
||||||
|
|
||||||
user = User.objects.first()
|
user = User.objects.first()
|
||||||
@ -195,7 +193,7 @@ class ContextManagersTest(unittest.TestCase):
|
|||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
self.assertEqual(0, q)
|
self.assertEqual(0, q)
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
db.test.find({}).count()
|
db.test.find({}).count()
|
||||||
|
|
||||||
self.assertEqual(50, q)
|
self.assertEqual(50, q)
|
||||||
|
@ -23,7 +23,8 @@ class TestStrictDict(unittest.TestCase):
|
|||||||
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
|
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
|
||||||
|
|
||||||
def test_init_fails_on_nonexisting_attrs(self):
|
def test_init_fails_on_nonexisting_attrs(self):
|
||||||
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
|
with self.assertRaises(AttributeError):
|
||||||
|
self.dtype(a=1, b=2, d=3)
|
||||||
|
|
||||||
def test_eq(self):
|
def test_eq(self):
|
||||||
d = self.dtype(a=1, b=1, c=1)
|
d = self.dtype(a=1, b=1, c=1)
|
||||||
@ -46,14 +47,12 @@ class TestStrictDict(unittest.TestCase):
|
|||||||
d = self.dtype()
|
d = self.dtype()
|
||||||
d.a = 1
|
d.a = 1
|
||||||
self.assertEqual(d.a, 1)
|
self.assertEqual(d.a, 1)
|
||||||
self.assertRaises(AttributeError, lambda: d.b)
|
self.assertRaises(AttributeError, getattr, d, 'b')
|
||||||
|
|
||||||
def test_setattr_raises_on_nonexisting_attr(self):
|
def test_setattr_raises_on_nonexisting_attr(self):
|
||||||
d = self.dtype()
|
d = self.dtype()
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
def _f():
|
|
||||||
d.x = 1
|
d.x = 1
|
||||||
self.assertRaises(AttributeError, _f)
|
|
||||||
|
|
||||||
def test_setattr_getattr_special(self):
|
def test_setattr_getattr_special(self):
|
||||||
d = self.strict_dict_class(["items"])
|
d = self.strict_dict_class(["items"])
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from bson import DBRef, ObjectId
|
from bson import DBRef, ObjectId
|
||||||
@ -32,7 +30,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
user = User(name='user %s' % i)
|
user = User(name='user %s' % i)
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
@ -90,7 +88,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
user = User(name='user %s' % i)
|
user = User(name='user %s' % i)
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
@ -162,7 +160,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 26):
|
for i in range(1, 26):
|
||||||
user = User(name='user %s' % i)
|
user = User(name='user %s' % i)
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
@ -440,7 +438,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
a = UserA(name='User A %s' % i)
|
a = UserA(name='User A %s' % i)
|
||||||
a.save()
|
a.save()
|
||||||
|
|
||||||
@ -531,7 +529,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
a = UserA(name='User A %s' % i)
|
a = UserA(name='User A %s' % i)
|
||||||
a.save()
|
a.save()
|
||||||
|
|
||||||
@ -614,15 +612,15 @@ class FieldTest(unittest.TestCase):
|
|||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
user = User(name='user %s' % i)
|
user = User(name='user %s' % i)
|
||||||
user.save()
|
user.save()
|
||||||
members.append(user)
|
members.append(user)
|
||||||
|
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
|
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
@ -687,7 +685,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
a = UserA(name='User A %s' % i)
|
a = UserA(name='User A %s' % i)
|
||||||
a.save()
|
a.save()
|
||||||
|
|
||||||
@ -699,9 +697,9 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
members += [a, b, c]
|
members += [a, b, c]
|
||||||
|
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
@ -783,16 +781,16 @@ class FieldTest(unittest.TestCase):
|
|||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
a = UserA(name='User A %s' % i)
|
a = UserA(name='User A %s' % i)
|
||||||
a.save()
|
a.save()
|
||||||
|
|
||||||
members += [a]
|
members += [a]
|
||||||
|
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
|
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
@ -866,7 +864,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
a = UserA(name='User A %s' % i)
|
a = UserA(name='User A %s' % i)
|
||||||
a.save()
|
a.save()
|
||||||
|
|
||||||
@ -878,9 +876,9 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
members += [a, b, c]
|
members += [a, b, c]
|
||||||
|
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
group = Group(members={str(u.id): u for u in members})
|
||||||
group.save()
|
group.save()
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
@ -1103,7 +1101,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
User(name='user %s' % i).save()
|
User(name='user %s' % i).save()
|
||||||
|
|
||||||
Group(name="Test", members=User.objects).save()
|
Group(name="Test", members=User.objects).save()
|
||||||
@ -1132,7 +1130,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
User(name='user %s' % i).save()
|
User(name='user %s' % i).save()
|
||||||
|
|
||||||
Group(name="Test", members=User.objects).save()
|
Group(name="Test", members=User.objects).save()
|
||||||
@ -1169,7 +1167,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
Group.drop_collection()
|
Group.drop_collection()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for i in xrange(1, 51):
|
for i in range(1, 51):
|
||||||
a = UserA(name='User A %s' % i).save()
|
a = UserA(name='User A %s' % i).save()
|
||||||
b = UserB(name='User B %s' % i).save()
|
b = UserB(name='User B %s' % i).save()
|
||||||
c = UserC(name='User C %s' % i).save()
|
c = UserC(name='User C %s' % i).save()
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from pymongo import ReadPreference
|
from pymongo import ReadPreference
|
||||||
@ -18,7 +15,7 @@ else:
|
|||||||
|
|
||||||
import mongoengine
|
import mongoengine
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import ConnectionError
|
from mongoengine.connection import MongoEngineConnectionError
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTest(unittest.TestCase):
|
class ConnectionTest(unittest.TestCase):
|
||||||
@ -41,7 +38,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
conn = connect(db='mongoenginetest',
|
conn = connect(db='mongoenginetest',
|
||||||
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
|
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
|
||||||
read_preference=READ_PREF)
|
read_preference=READ_PREF)
|
||||||
except ConnectionError, e:
|
except MongoEngineConnectionError as e:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(conn, CONN_CLASS):
|
if not isinstance(conn, CONN_CLASS):
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
sys.path[0:0] = [""]
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
Loading…
x
Reference in New Issue
Block a user