diff --git a/.gitignore b/.gitignore index 42dcc6e6..51a9ca1d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,9 @@ *.pyc .*.swp +*.egg docs/.build docs/_build build/ dist/ -mongoengine.egg-info/ \ No newline at end of file +mongoengine.egg-info/ +env/ \ No newline at end of file diff --git a/AUTHORS b/AUTHORS index 93ecfa8d..93fe819e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -2,3 +2,4 @@ Harry Marr Matt Dennewitz Deepak Thukral Florian Schlachter +Steve Challis diff --git a/docs/apireference.rst b/docs/apireference.rst index 267b22aa..34d4536d 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -64,3 +64,7 @@ Fields .. autoclass:: mongoengine.ReferenceField .. autoclass:: mongoengine.GenericReferenceField + +.. autoclass:: mongoengine.FileField + +.. autoclass:: mongoengine.GeoPointField diff --git a/docs/changelog.rst b/docs/changelog.rst index 479ea21c..d7c6fe85 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,32 @@ Changelog ========= +Changes in v0.4 +=============== +- Added ``GridFSStorage`` Django storage backend +- Added ``FileField`` for GridFS support +- New Q-object implementation, which is no longer based on Javascript +- Added ``SortedListField`` +- Added ``EmailField`` +- Added ``GeoPointField`` +- Added ``exact`` and ``iexact`` match operators to ``QuerySet`` +- Added ``get_document_or_404`` and ``get_list_or_404`` Django shortcuts +- Added new query operators for Geo queries +- Added ``not`` query operator +- Added new update operators: ``pop`` and ``add_to_set`` +- Added ``__raw__`` query parameter +- Added support for custom querysets +- Fixed document inheritance primary key issue +- Added support for querying by array element position +- Base class can now be defined for ``DictField`` +- Fixed MRO error that occured on document inheritance +- Added ``QuerySet.distinct``, ``QuerySet.create``, ``QuerySet.snapshot``, + ``QuerySet.timeout`` and ``QuerySet.all`` +- Subsequent calls to ``connect()`` now work +- 
Introduced ``min_length`` for ``StringField`` +- Fixed multi-process connection issue +- Other minor fixes + Changes in v0.3 =============== - Added MapReduce support diff --git a/docs/django.rst b/docs/django.rst index 92a8a52b..8a490571 100644 --- a/docs/django.rst +++ b/docs/django.rst @@ -19,7 +19,7 @@ MongoDB but still use many of the Django authentication infrastucture (such as the :func:`login_required` decorator and the :func:`authenticate` function). To enable the MongoEngine auth backend, add the following to you **settings.py** file:: - + AUTHENTICATION_BACKENDS = ( 'mongoengine.django.auth.MongoEngineBackend', ) @@ -44,3 +44,44 @@ into you settings module:: SESSION_ENGINE = 'mongoengine.django.sessions' .. versionadded:: 0.2.1 + +Storage +======= +With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`, +it is useful to have a Django file storage backend that wraps this. The new +storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it +is very similar to using the default FileSystemStorage.:: + + fs = mongoengine.django.GridFSStorage() + + filename = fs.save('hello.txt', 'Hello, World!') + +All of the `Django Storage API methods +`_ have been +implemented except :func:`path`. If the filename provided already exists, an +underscore and a number (before # the file extension, if one exists) will be +appended to the filename until the generated filename doesn't exist. The +:func:`save` method will return the new filename.:: + + >>> fs.exists('hello.txt') + True + >>> fs.open('hello.txt').read() + 'Hello, World!' 
+ >>> fs.size('hello.txt') + 13 + >>> fs.url('hello.txt') + 'http://your_media_url/hello.txt' + >>> fs.open('hello.txt').name + 'hello.txt' + >>> fs.listdir() + ([], [u'hello.txt']) + +All files will be saved and retrieved in GridFS via the :class::`FileDocument` +document, allowing easy access to the files without the GridFSStorage +backend.:: + + >>> from mongoengine.django.storage import FileDocument + >>> FileDocument.objects() + [] + +.. versionadded:: 0.4 diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 3c276869..106d4ec8 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -46,6 +46,12 @@ are as follows: * :class:`~mongoengine.EmbeddedDocumentField` * :class:`~mongoengine.ReferenceField` * :class:`~mongoengine.GenericReferenceField` +* :class:`~mongoengine.BooleanField` +* :class:`~mongoengine.FileField` +* :class:`~mongoengine.EmailField` +* :class:`~mongoengine.SortedListField` +* :class:`~mongoengine.BinaryField` +* :class:`~mongoengine.GeoPointField` Field arguments --------------- @@ -66,6 +72,25 @@ arguments can be set on all fields: :attr:`default` (Default: None) A value to use when no value is set for this field. 
+ The definion of default parameters follow `the general rules on Python + `__, + which means that some care should be taken when dealing with default mutable objects + (like in :class:`~mongoengine.ListField` or :class:`~mongoengine.DictField`):: + + class ExampleFirst(Document): + # Default an empty list + values = ListField(IntField(), default=list) + + class ExampleSecond(Document): + # Default a set of values + values = ListField(IntField(), default=lambda: [1,2,3]) + + class ExampleDangerous(Document): + # This can make an .append call to add values to the default (and all the following objects), + # instead to just an object + values = ListField(IntField(), default=[1,2,3]) + + :attr:`unique` (Default: False) When True, no documents in the collection will have the same value for this field. @@ -214,6 +239,20 @@ either a single field name, or a list or tuple of field names:: first_name = StringField() last_name = StringField(unique_with='first_name') +Skipping Document validation on save +------------------------------------ +You can also skip the whole document validation process by setting +``validate=False`` when caling the :meth:`~mongoengine.document.Document.save` +method:: + + class Recipient(Document): + name = StringField() + email = EmailField() + + recipient = Recipient(name='admin', email='root@localhost') + recipient.save() # will raise a ValidationError while + recipient.save(validate=False) # won't + Document collections ==================== Document classes that inherit **directly** from :class:`~mongoengine.Document` @@ -259,6 +298,10 @@ or a **-** sign. Note that direction only matters on multi-field indexes. :: meta = { 'indexes': ['title', ('title', '-rating')] } + +.. 
note:: + Geospatial indexes will be automatically created for all + :class:`~mongoengine.GeoPointField`\ s Ordering ======== diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index b5a1f029..7b5d165b 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -59,6 +59,13 @@ you may still use :attr:`id` to access the primary key if you want:: >>> bob.id == bob.email == 'bob@example.com' True +You can also access the document's "primary key" using the :attr:`pk` field; in +is an alias to :attr:`id`:: + + >>> page = Page(title="Another Test Page") + >>> page.save() + >>> page.id == page.pk + .. note:: If you define your own primary key field, the field implicitly becomes required, so a :class:`ValidationError` will be thrown if you don't provide diff --git a/docs/guide/gridfs.rst b/docs/guide/gridfs.rst new file mode 100644 index 00000000..0cd06539 --- /dev/null +++ b/docs/guide/gridfs.rst @@ -0,0 +1,83 @@ +====== +GridFS +====== + +.. versionadded:: 0.4 + +Writing +------- + +GridFS support comes in the form of the :class:`~mongoengine.FileField` field +object. This field acts as a file-like object and provides a couple of +different ways of inserting and retrieving data. Arbitrary metadata such as +content type can also be stored alongside the files. In the following example, +a document is created to store details about animals, including a photo:: + + class Animal(Document): + genus = StringField() + family = StringField() + photo = FileField() + + marmot = Animal('Marmota', 'Sciuridae') + + marmot_photo = open('marmot.jpg', 'r') # Retrieve a photo from disk + marmot.photo = marmot_photo # Store photo in the document + marmot.photo.content_type = 'image/jpeg' # Store metadata + + marmot.save() + +Another way of writing to a :class:`~mongoengine.FileField` is to use the +:func:`put` method. 
This allows for metadata to be stored in the same call as +the file:: + + marmot.photo.put(marmot_photo, content_type='image/jpeg') + + marmot.save() + +Retrieval +--------- + +So using the :class:`~mongoengine.FileField` is just like using any other +field. The file can also be retrieved just as easily:: + + marmot = Animal.objects(genus='Marmota').first() + photo = marmot.photo.read() + content_type = marmot.photo.content_type + +Streaming +--------- + +Streaming data into a :class:`~mongoengine.FileField` is achieved in a +slightly different manner. First, a new file must be created by calling the +:func:`new_file` method. Data can then be written using :func:`write`:: + + marmot.photo.new_file() + marmot.photo.write('some_image_data') + marmot.photo.write('some_more_image_data') + marmot.photo.close() + + marmot.photo.save() + +Deletion +-------- + +Deleting stored files is achieved with the :func:`delete` method:: + + marmot.photo.delete() + +.. note:: + The FileField in a Document actually only stores the ID of a file in a + separate GridFS collection. This means that deleting a document + with a defined FileField does not actually delete the file. You must be + careful to delete any files in a Document as above before deleting the + Document itself. + + +Replacing files +--------------- + +Files can be replaced with the :func:`replace` method. This works just like +the :func:`put` method so even metadata can (and should) be replaced:: + + another_marmot = open('another_marmot.png', 'r') + marmot.photo.replace(another_marmot, content_type='image/png') diff --git a/docs/guide/index.rst b/docs/guide/index.rst index 7fdfe932..aac72469 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -10,3 +10,4 @@ User Guide defining-documents document-instances querying + gridfs diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 113ee431..832fed50 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -34,7 +34,7 @@ arguments. 
The keys in the keyword arguments correspond to fields on the Fields on embedded documents may also be referred to using field lookup syntax by using a double-underscore in place of the dot in object attribute access syntax:: - + # This will return a QuerySet that will only iterate over pages that have # been written by a user whose 'country' field is set to 'uk' uk_pages = Page.objects(author__country='uk') @@ -53,11 +53,21 @@ lists that contain that item will be matched:: # 'tags' list Page.objects(tags='coding') +Raw queries +----------- +It is possible to provide a raw PyMongo query as a query parameter, which will +be integrated directly into the query. This is done using the ``__raw__`` +keyword argument:: + + Page.objects(__raw__={'tags': 'coding'}) + +.. versionadded:: 0.4 + Query operators =============== Operators other than equality may also be used in queries; just attach the operator name to a key with a double-underscore:: - + # Only find users whose age is 18 or less young_users = Users.objects(age__lte=18) @@ -68,10 +78,12 @@ Available operators are as follows: * ``lte`` -- less than or equal to * ``gt`` -- greater than * ``gte`` -- greater than or equal to +* ``not`` -- negate a standard check, may be used before other operators (e.g. + ``Q(age__not__mod=5)``) * ``in`` -- value is in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided) * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values -* ``all`` -- every item in array is in list of values provided +* ``all`` -- every item in list of values provided is in array * ``size`` -- the size of the array is * ``exists`` -- value for field exists @@ -89,6 +101,27 @@ expressions: .. 
versionadded:: 0.3 +There are a few special operators for performing geographical queries, that +may used with :class:`~mongoengine.GeoPointField`\ s: + +* ``within_distance`` -- provide a list containing a point and a maximum + distance (e.g. [(41.342, -87.653), 5]) +* ``within_box`` -- filter documents to those within a given bounding box (e.g. + [(35.0, -125.0), (40.0, -100.0)]) +* ``near`` -- order the documents by how close they are to a given point + +.. versionadded:: 0.4 + +Querying by position +==================== +It is possible to query by position in a list by using a numerical value as a +query operator. So if you wanted to find all pages whose first tag was ``db``, +you could use the following query:: + + BlogPost.objects(tags__0='db') + +.. versionadded:: 0.4 + Limiting and skipping results ============================= Just as with traditional ORMs, you may limit the number of results returned, or @@ -111,7 +144,7 @@ You may also index the query to retrieve a single result. If an item at that index does not exists, an :class:`IndexError` will be raised. A shortcut for retrieving the first result and returning :attr:`None` if no result exists is provided (:meth:`~mongoengine.queryset.QuerySet.first`):: - + >>> # Make sure there are no users >>> User.drop_collection() >>> User.objects[0] @@ -174,13 +207,29 @@ custom manager methods as you like:: @queryset_manager def live_posts(doc_cls, queryset): - return queryset(published=True).filter(published=True) + return queryset.filter(published=True) BlogPost(title='test1', published=False).save() BlogPost(title='test2', published=True).save() assert len(BlogPost.objects) == 2 assert len(BlogPost.live_posts) == 1 +Custom QuerySets +================ +Should you want to add custom methods for interacting with or filtering +documents, extending the :class:`~mongoengine.queryset.QuerySet` class may be +the way to go. 
To use a custom :class:`~mongoengine.queryset.QuerySet` class on +a document, set ``queryset_class`` to the custom class in a +:class:`~mongoengine.Document`\ s ``meta`` dictionary:: + + class AwesomerQuerySet(QuerySet): + pass + + class Page(Document): + meta = {'queryset_class': AwesomerQuerySet} + +.. versionadded:: 0.4 + Aggregation =========== MongoDB provides some aggregation methods out of the box, but there are not as @@ -399,14 +448,17 @@ that you may use with these methods: * ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``inc`` -- increment a value by a given amount * ``dec`` -- decrement a value by a given amount +* ``pop`` -- remove the last item from a list * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list +* ``pop`` -- remove the first or last element of a list * ``pull`` -- remove a value from a list * ``pull_all`` -- remove several values from a list +* ``add_to_set`` -- add value to a list only if its not in the list already The syntax for atomic updates is similar to the querying syntax, but the modifier comes before the field, not after it:: - + >>> post = BlogPost(title='Test', page_views=0, tags=['database']) >>> post.save() >>> BlogPost.objects(id=post.id).update_one(inc__page_views=1) diff --git a/docs/index.rst b/docs/index.rst index a28b344c..ccb7fbe2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ MongoDB. To install it, simply run .. code-block:: console - # easy_install -U mongoengine + # pip install -U mongoengine The source is available on `GitHub `_. 
diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index e01d31ae..6d18ffe7 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ + __author__ = 'Harry Marr' -VERSION = (0, 3, 0) +VERSION = (0, 4, 0) def get_version(): version = '%s.%s' % (VERSION[0], VERSION[1]) diff --git a/mongoengine/base.py b/mongoengine/base.py index 22347a2c..6b74cb07 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -23,10 +23,11 @@ class BaseField(object): # Fields may have _types inserted into indexes by default _index_with_types = True - + _geo_index = False + def __init__(self, db_field=None, name=None, required=False, default=None, - unique=False, unique_with=None, primary_key=False, validation=None, - choices=None): + unique=False, unique_with=None, primary_key=False, + validation=None, choices=None): self.db_field = (db_field or name) if not primary_key else '_id' if name: import warnings @@ -87,22 +88,24 @@ class BaseField(object): # check choices if self.choices is not None: if value not in self.choices: - raise ValidationError("Value must be one of %s."%unicode(self.choices)) - + raise ValidationError("Value must be one of %s." + % unicode(self.choices)) + # check validation argument if self.validation is not None: if callable(self.validation): if not self.validation(value): - raise ValidationError('Value does not match custom validation method.') + raise ValidationError('Value does not match custom' \ + 'validation method.') else: raise ValueError('validation argument must be a callable.') - + self.validate(value) class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. 
""" - + def to_python(self, value): return value # return unicode(value) @@ -148,7 +151,7 @@ class DocumentMetaclass(type): # Get superclasses from superclass superclasses[base._class_name] = base superclasses.update(base._superclasses) - + if hasattr(base, '_meta'): # Ensure that the Document class may be subclassed - # inheritance may be disabled to remove dependency on @@ -189,20 +192,23 @@ class DocumentMetaclass(type): field.owner_document = new_class module = attrs.get('__module__') - + base_excs = tuple(base.DoesNotExist for base in bases if hasattr(base, 'DoesNotExist')) or (DoesNotExist,) exc = subclass_exception('DoesNotExist', base_excs, module) new_class.add_to_class('DoesNotExist', exc) - + base_excs = tuple(base.MultipleObjectsReturned for base in bases if hasattr(base, 'MultipleObjectsReturned')) base_excs = base_excs or (MultipleObjectsReturned,) exc = subclass_exception('MultipleObjectsReturned', base_excs, module) new_class.add_to_class('MultipleObjectsReturned', exc) - + + global _document_registry + _document_registry[name] = new_class + return new_class - + def add_to_class(self, name, value): setattr(self, name, value) @@ -213,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): """ def __new__(cls, name, bases, attrs): - global _document_registry - super_new = super(TopLevelDocumentMetaclass, cls).__new__ # Classes defined in this package are abstract and should not have # their own metadata with DB collection, etc. @@ -225,15 +229,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): return super_new(cls, name, bases, attrs) collection = name.lower() - + id_field = None base_indexes = [] + base_meta = {} # Subclassed documents inherit collection from superclass for base in bases: if hasattr(base, '_meta') and 'collection' in base._meta: collection = base._meta['collection'] + # Propagate index options. 
+ for key in ('index_background', 'index_drop_dups', 'index_opts'): + if key in base._meta: + base_meta[key] = base._meta[key] + id_field = id_field or base._meta.get('id_field') base_indexes += base._meta.get('indexes', []) @@ -244,7 +254,12 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): 'ordering': [], # default ordering applied at runtime 'indexes': [], # indexes to be ensured at runtime 'id_field': id_field, + 'index_background': False, + 'index_drop_dups': False, + 'index_opts': {}, + 'queryset_class': QuerySet, } + meta.update(base_meta) # Apply document-defined meta options meta.update(attrs.get('meta', {})) @@ -253,18 +268,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Set up collection manager, needs the class to have fields so use # DocumentMetaclass before instantiating CollectionManager object new_class = super_new(cls, name, bases, attrs) - new_class.objects = QuerySetManager() + + # Provide a default queryset unless one has been manually provided + if not hasattr(new_class, 'objects'): + new_class.objects = QuerySetManager() user_indexes = [QuerySet._build_index_spec(new_class, spec) for spec in meta['indexes']] + base_indexes new_class._meta['indexes'] = user_indexes - + unique_indexes = [] for field_name, field in new_class._fields.items(): # Generate a list of indexes needed by uniqueness constraints if field.unique: field.required = True - unique_fields = [field_name] + unique_fields = [field.db_field] # Add any unique_with fields to the back of the index spec if field.unique_with: @@ -305,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class._fields['id'] = ObjectIdField(db_field='_id') new_class.id = new_class._fields['id'] - _document_registry[name] = new_class - return new_class @@ -314,14 +330,17 @@ class BaseDocument(object): def __init__(self, **values): self._data = {} + # Assign default values to instance + for attr_name in self._fields.keys(): + # Use default value if present + value = 
getattr(self, attr_name, None) + setattr(self, attr_name, value) # Assign initial values to instance - for attr_name, attr_value in self._fields.items(): - if attr_name in values: + for attr_name in values.keys(): + try: setattr(self, attr_name, values.pop(attr_name)) - else: - # Use default value if present - value = getattr(self, attr_name, None) - setattr(self, attr_name, value) + except AttributeError: + pass def validate(self): """Ensure that all fields' values are valid and that required fields @@ -337,8 +356,8 @@ class BaseDocument(object): try: field._validate(value) except (ValueError, AttributeError, AssertionError), e: - raise ValidationError('Invalid value for field of type "' + - field.__class__.__name__ + '"') + raise ValidationError('Invalid value for field of type "%s": %s' + % (field.__class__.__name__, value)) elif field.required: raise ValidationError('Field "%s" is required' % field.name) @@ -357,6 +376,16 @@ class BaseDocument(object): all_subclasses.update(subclass._get_subclasses()) return all_subclasses + @apply + def pk(): + """Primary key alias + """ + def fget(self): + return getattr(self, self._meta['id_field']) + def fset(self, value): + return setattr(self, self._meta['id_field'], value) + return property(fget, fset) + def __iter__(self): return iter(self._fields) @@ -413,8 +442,10 @@ class BaseDocument(object): self._meta.get('allow_inheritance', True) == False): data['_cls'] = self._class_name data['_types'] = self._superclasses.keys() + [self._class_name] + if data.has_key('_id') and not data['_id']: + del data['_id'] return data - + @classmethod def _from_son(cls, son): """Create an instance of a Document (subclass) from a PyMongo SON. 
@@ -444,12 +475,14 @@ class BaseDocument(object): for field_name, field in cls._fields.items(): if field.db_field in data: - data[field_name] = field.to_python(data[field.db_field]) + value = data[field.db_field] + data[field_name] = (value if value is None + else field.to_python(value)) obj = cls(**data) obj._present_fields = present_fields return obj - + def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id'): if self.id == other.id: diff --git a/mongoengine/connection.py b/mongoengine/connection.py index ec3bf784..814fde13 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,62 +1,71 @@ from pymongo import Connection - +import multiprocessing __all__ = ['ConnectionError', 'connect'] -_connection_settings = { +_connection_defaults = { 'host': 'localhost', 'port': 27017, } -_connection = None +_connection = {} +_connection_settings = _connection_defaults.copy() _db_name = None _db_username = None _db_password = None -_db = None +_db = {} class ConnectionError(Exception): pass -def _get_connection(): +def _get_connection(reconnect=False): global _connection + identity = get_identity() # Connect to the database if not already connected - if _connection is None: + if _connection.get(identity) is None or reconnect: try: - _connection = Connection(**_connection_settings) + _connection[identity] = Connection(**_connection_settings) except: raise ConnectionError('Cannot connect to the database') - return _connection + return _connection[identity] -def _get_db(): +def _get_db(reconnect=False): global _db, _connection + identity = get_identity() # Connect if not already connected - if _connection is None: - _connection = _get_connection() + if _connection.get(identity) is None or reconnect: + _connection[identity] = _get_connection(reconnect=reconnect) - if _db is None: + if _db.get(identity) is None or reconnect: # _db_name will be None if the user hasn't called connect() if _db_name is None: raise 
ConnectionError('Not connected to the database') # Get DB from current connection and authenticate if necessary - _db = _connection[_db_name] + _db[identity] = _connection[identity][_db_name] if _db_username and _db_password: - _db.authenticate(_db_username, _db_password) + _db[identity].authenticate(_db_username, _db_password) - return _db + return _db[identity] +def get_identity(): + identity = multiprocessing.current_process()._identity + identity = 0 if not identity else identity[0] + return identity + def connect(db, username=None, password=None, **kwargs): """Connect to the database specified by the 'db' argument. Connection settings may be provided here as well if the database is not running on the default port on localhost. If authentication is needed, provide username and password arguments as well. """ - global _connection_settings, _db_name, _db_username, _db_password - _connection_settings.update(kwargs) + global _connection_settings, _db_name, _db_username, _db_password, _db + _connection_settings = dict(_connection_defaults, **kwargs) _db_name = db _db_username = username _db_password = password - return _get_db() \ No newline at end of file + return _get_db(reconnect=True) + diff --git a/mongoengine/django/auth.py b/mongoengine/django/auth.py index d4b0ff0b..595852ef 100644 --- a/mongoengine/django/auth.py +++ b/mongoengine/django/auth.py @@ -32,6 +32,9 @@ class User(Document): last_login = DateTimeField(default=datetime.datetime.now) date_joined = DateTimeField(default=datetime.datetime.now) + def __unicode__(self): + return self.username + def get_full_name(self): """Returns the users first and last names, separated by a space. """ @@ -72,10 +75,9 @@ class User(Document): email address. """ now = datetime.datetime.now() - + # Normalize the address by lowercasing the domain part of the email # address. 
- # Not sure why we'r allowing null email when its not allowed in django if email is not None: try: email_name, domain_part = email.strip().split('@', 1) @@ -83,12 +85,12 @@ class User(Document): pass else: email = '@'.join([email_name, domain_part.lower()]) - + user = User(username=username, email=email, date_joined=now) user.set_password(password) user.save() return user - + def get_and_delete_messages(self): return [] diff --git a/mongoengine/django/storage.py b/mongoengine/django/storage.py new file mode 100644 index 00000000..341455cd --- /dev/null +++ b/mongoengine/django/storage.py @@ -0,0 +1,112 @@ +import os +import itertools +import urlparse + +from mongoengine import * +from django.conf import settings +from django.core.files.storage import Storage +from django.core.exceptions import ImproperlyConfigured + + +class FileDocument(Document): + """A document used to store a single file in GridFS. + """ + file = FileField() + + +class GridFSStorage(Storage): + """A custom storage backend to store files in GridFS + """ + + def __init__(self, base_url=None): + + if base_url is None: + base_url = settings.MEDIA_URL + self.base_url = base_url + self.document = FileDocument + self.field = 'file' + + def delete(self, name): + """Deletes the specified file from the storage system. + """ + if self.exists(name): + doc = self.document.objects.first() + field = getattr(doc, self.field) + self._get_doc_with_name(name).delete() # Delete the FileField + field.delete() # Delete the FileDocument + + def exists(self, name): + """Returns True if a file referened by the given name already exists in the + storage system, or False if the name is available for a new file. + """ + doc = self._get_doc_with_name(name) + if doc: + field = getattr(doc, self.field) + return bool(field.name) + else: + return False + + def listdir(self, path=None): + """Lists the contents of the specified path, returning a 2-tuple of lists; + the first item being directories, the second item being files. 
+ """ + def name(doc): + return getattr(doc, self.field).name + docs = self.document.objects + return [], [name(d) for d in docs if name(d)] + + def size(self, name): + """Returns the total size, in bytes, of the file specified by name. + """ + doc = self._get_doc_with_name(name) + if doc: + return getattr(doc, self.field).length + else: + raise ValueError("No such file or directory: '%s'" % name) + + def url(self, name): + """Returns an absolute URL where the file's contents can be accessed + directly by a web browser. + """ + if self.base_url is None: + raise ValueError("This file is not accessible via a URL.") + return urlparse.urljoin(self.base_url, name).replace('\\', '/') + + def _get_doc_with_name(self, name): + """Find the documents in the store with the given name + """ + docs = self.document.objects + doc = [d for d in docs if getattr(d, self.field).name == name] + if doc: + return doc[0] + else: + return None + + def _open(self, name, mode='rb'): + doc = self._get_doc_with_name(name) + if doc: + return getattr(doc, self.field) + else: + raise ValueError("No file found with the name '%s'." % name) + + def get_available_name(self, name): + """Returns a filename that's free on the target storage system, and + available for new content to be written to. + """ + file_root, file_ext = os.path.splitext(name) + # If the filename already exists, add an underscore and a number (before + # the file extension, if one exists) to the filename until the generated + # filename doesn't exist. + count = itertools.count(1) + while self.exists(name): + # file_ext includes the dot. 
+ name = os.path.join("%s_%s%s" % (file_root, count.next(), file_ext)) + + return name + + def _save(self, name, content): + doc = self.document() + getattr(doc, self.field).put(content, filename=name) + doc.save() + + return name diff --git a/mongoengine/django/tests.py b/mongoengine/django/tests.py new file mode 100644 index 00000000..a8d7c7ff --- /dev/null +++ b/mongoengine/django/tests.py @@ -0,0 +1,21 @@ +#coding: utf-8 +from django.test import TestCase +from django.conf import settings + +from mongoengine import connect + +class MongoTestCase(TestCase): + """ + TestCase class that clear the collection between the tests + """ + db_name = 'test_%s' % settings.MONGO_DATABASE_NAME + def __init__(self, methodName='runtest'): + self.db = connect(self.db_name) + super(MongoTestCase, self).__init__(methodName) + + def _post_teardown(self): + super(MongoTestCase, self)._post_teardown() + for collection in self.db.collection_names(): + if collection == 'system.indexes': + continue + self.db.drop_collection(collection) diff --git a/mongoengine/document.py b/mongoengine/document.py index e5dec145..fef737db 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -15,7 +15,7 @@ class EmbeddedDocument(BaseDocument): fields on :class:`~mongoengine.Document`\ s through the :class:`~mongoengine.EmbeddedDocumentField` field type. """ - + __metaclass__ = DocumentMetaclass @@ -56,7 +56,7 @@ class Document(BaseDocument): __metaclass__ = TopLevelDocumentMetaclass - def save(self, safe=True, force_insert=False): + def save(self, safe=True, force_insert=False, validate=True): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. 
@@ -67,8 +67,10 @@ class Document(BaseDocument): :param safe: check if the operation succeeded before returning :param force_insert: only try to create a new document, don't allow updates of existing documents + :param validate: validates the document; set to ``False`` for skipping """ - self.validate() + if validate: + self.validate() doc = self.to_mongo() try: collection = self.__class__.objects._collection @@ -119,23 +121,23 @@ class Document(BaseDocument): class MapReduceDocument(object): """A document returned from a map/reduce query. - + :param collection: An instance of :class:`~pymongo.Collection` :param key: Document/result key, often an instance of :class:`~pymongo.objectid.ObjectId`. If supplied as an ``ObjectId`` found in the given ``collection``, the object can be accessed via the ``object`` property. :param value: The result(s) for this key. - + .. versionadded:: 0.3 """ - + def __init__(self, document, collection, key, value): self._document = document self._collection = collection self.key = key self.value = value - + @property def object(self): """Lazy-load the object referenced by 
``self.key`` @@ -143,7 +145,7 @@ class MapReduceDocument(object): """ id_field = self._document()._meta['id_field'] id_field_type = type(id_field) - + if not isinstance(self.key, id_field_type): try: self.key = id_field_type(self.key) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 7883f78a..e95fd65e 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -10,13 +10,16 @@ import pymongo.son import pymongo.binary import datetime import decimal +import gridfs +import warnings +import types __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', - 'DecimalField', 'URLField', 'GenericReferenceField', - 'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField'] + 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', + 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -39,7 +42,7 @@ class StringField(BaseField): if self.max_length is not None and len(value) > self.max_length: raise ValidationError('String value is too long') - + if self.min_length is not None and len(value) < self.min_length: raise ValidationError('String value is too short') @@ -67,6 +70,9 @@ class StringField(BaseField): regex = r'%s$' elif op == 'exact': regex = r'^%s$' + + # escape unsafe characters which could lead to a re.error + value = re.escape(value) value = re.compile(regex % value, flags) return value @@ -103,8 +109,11 @@ class URLField(StringField): message = 'This URL appears to be a broken link: %s' % e raise ValidationError(message) + class EmailField(StringField): """A field that validates input as an E-Mail-Address. + + .. 
versionadded:: 0.4 """ EMAIL_REGEX = re.compile( @@ -112,11 +121,12 @@ class EmailField(StringField): r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) - + def validate(self, value): if not EmailField.EMAIL_REGEX.match(value): raise ValidationError('Invalid Mail-address: %s' % value) + class IntField(BaseField): """An integer field. """ @@ -140,6 +150,7 @@ class IntField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Integer value is too large') + class FloatField(BaseField): """An floating point number field. """ @@ -176,7 +187,7 @@ class DecimalField(BaseField): if not isinstance(value, basestring): value = unicode(value) return decimal.Decimal(value) - + def to_mongo(self, value): return unicode(value) @@ -195,6 +206,7 @@ class DecimalField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Decimal value is too large') + class BooleanField(BaseField): """A boolean field type. @@ -207,6 +219,7 @@ class BooleanField(BaseField): def validate(self, value): assert isinstance(value, bool) + class DateTimeField(BaseField): """A datetime field. """ @@ -214,38 +227,49 @@ class DateTimeField(BaseField): def validate(self, value): assert isinstance(value, datetime.datetime) + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. 
""" - def __init__(self, document, **kwargs): - if not issubclass(document, EmbeddedDocument): - raise ValidationError('Invalid embedded document class provided ' - 'to an EmbeddedDocumentField') - self.document = document + def __init__(self, document_type, **kwargs): + if not isinstance(document_type, basestring): + if not issubclass(document_type, EmbeddedDocument): + raise ValidationError('Invalid embedded document class ' + 'provided to an EmbeddedDocumentField') + self.document_type_obj = document_type super(EmbeddedDocumentField, self).__init__(**kwargs) + @property + def document_type(self): + if isinstance(self.document_type_obj, basestring): + if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: + self.document_type_obj = self.owner_document + else: + self.document_type_obj = get_document(self.document_type_obj) + return self.document_type_obj + def to_python(self, value): - if not isinstance(value, self.document): - return self.document._from_son(value) + if not isinstance(value, self.document_type): + return self.document_type._from_son(value) return value def to_mongo(self, value): - return self.document.to_mongo(value) + return self.document_type.to_mongo(value) def validate(self, value): """Make sure that the document instance is an instance of the EmbeddedDocument subclass provided when the document was defined. 
""" # Using isinstance also works for subclasses of self.document - if not isinstance(value, self.document): + if not isinstance(value, self.document_type): raise ValidationError('Invalid embedded document instance ' 'provided to an EmbeddedDocumentField') - self.document.validate(value) + self.document_type.validate(value) def lookup_member(self, member_name): - return self.document._fields.get(member_name) + return self.document_type._fields.get(member_name) def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -264,6 +288,7 @@ class ListField(BaseField): raise ValidationError('Argument to ListField constructor must be ' 'a valid field') self.field = field + kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) def __get__(self, instance, owner): @@ -318,20 +343,32 @@ class ListField(BaseField): try: [self.field.validate(item) for item in value] except Exception, err: - raise ValidationError('Invalid ListField item (%s)' % str(err)) + raise ValidationError('Invalid ListField item (%s)' % str(item)) def prepare_query_value(self, op, value): if op in ('set', 'unset'): - return [self.field.to_mongo(v) for v in value] - return self.field.to_mongo(value) + return [self.field.prepare_query_value(op, v) for v in value] + return self.field.prepare_query_value(op, value) def lookup_member(self, member_name): return self.field.lookup_member(member_name) + def _set_owner_document(self, owner_document): + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) + + class SortedListField(ListField): """A ListField that sorts the contents of its list before writing to the database in order to ensure that a sorted list is always retrieved. + + .. 
versionadded:: 0.4 """ _ordering = None @@ -343,9 +380,11 @@ class SortedListField(ListField): def to_mongo(self, value): if self._ordering is not None: - return sorted([self.field.to_mongo(item) for item in value], key=itemgetter(self._ordering)) + return sorted([self.field.to_mongo(item) for item in value], + key=itemgetter(self._ordering)) return sorted([self.field.to_mongo(item) for item in value]) + class DictField(BaseField): """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. @@ -356,6 +395,7 @@ class DictField(BaseField): def __init__(self, basecls=None, *args, **kwargs): self.basecls = basecls or BaseField assert issubclass(self.basecls, BaseField) + kwargs.setdefault('default', lambda: {}) super(DictField, self).__init__(*args, **kwargs) def validate(self, value): @@ -372,24 +412,6 @@ class DictField(BaseField): def lookup_member(self, member_name): return self.basecls(db_field=member_name) -class GeoLocationField(DictField): - """Supports geobased fields""" - - def validate(self, value): - """Make sure that a geo-value is of type (x, y) - """ - if not isinstance(value, tuple) and not isinstance(value, list): - raise ValidationError('GeoLocationField can only hold tuples or lists of (x, y)') - - if len(value) <> 2: - raise ValidationError('GeoLocationField must have exactly two elements (x, y)') - - def to_mongo(self, value): - return {'x': value[0], 'y': value[1]} - - def to_python(self, value): - return value.keys() - class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on access (lazily). 
@@ -401,7 +423,6 @@ class ReferenceField(BaseField): raise ValidationError('Argument to ReferenceField constructor ' 'must be a document class or a string') self.document_type_obj = document_type - self.document_obj = None super(ReferenceField, self).__init__(**kwargs) @property @@ -501,7 +522,8 @@ class GenericReferenceField(BaseField): return {'_cls': document.__class__.__name__, '_ref': ref} def prepare_query_value(self, op, value): - return self.to_mongo(value)['_ref'] + return self.to_mongo(value) + class BinaryField(BaseField): """A binary data field. @@ -523,3 +545,161 @@ class BinaryField(BaseField): if self.max_bytes is not None and len(value) > self.max_bytes: raise ValidationError('Binary value is too long') + + +class GridFSError(Exception): + pass + + +class GridFSProxy(object): + """Proxy object to handle writing and reading of files to and from GridFS + + .. versionadded:: 0.4 + """ + + def __init__(self, grid_id=None): + self.fs = gridfs.GridFS(_get_db()) # Filesystem instance + self.newfile = None # Used for partial writes + self.grid_id = grid_id # Store GridFS id for file + + def __getattr__(self, name): + obj = self.get() + if name in dir(obj): + return getattr(obj, name) + raise AttributeError + + def __get__(self, instance, value): + return self + + def get(self, id=None): + if id: + self.grid_id = id + try: + return self.fs.get(id or self.grid_id) + except: + # File has been deleted + return None + + def new_file(self, **kwargs): + self.newfile = self.fs.new_file(**kwargs) + self.grid_id = self.newfile._id + + def put(self, file, **kwargs): + if self.grid_id: + raise GridFSError('This document already has a file. Either delete ' + 'it or call replace to overwrite it') + self.grid_id = self.fs.put(file, **kwargs) + + def write(self, string): + if self.grid_id: + if not self.newfile: + raise GridFSError('This document already has a file. 
Either ' + 'delete it or call replace to overwrite it') + else: + self.new_file() + self.newfile.write(string) + + def writelines(self, lines): + if not self.newfile: + self.new_file() + self.grid_id = self.newfile._id + self.newfile.writelines(lines) + + def read(self): + try: + return self.get().read() + except: + return None + + def delete(self): + # Delete file from GridFS, FileField still remains + self.fs.delete(self.grid_id) + self.grid_id = None + + def replace(self, file, **kwargs): + self.delete() + self.put(file, **kwargs) + + def close(self): + if self.newfile: + self.newfile.close() + else: + msg = "The close() method is only necessary after calling write()" + warnings.warn(msg) + + +class FileField(BaseField): + """A GridFS storage field. + + .. versionadded:: 0.4 + """ + + def __init__(self, **kwargs): + super(FileField, self).__init__(**kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self + + # Check if a file already exists for this model + grid_file = instance._data.get(self.name) + self.grid_file = grid_file + if self.grid_file: + return self.grid_file + return GridFSProxy() + + def __set__(self, instance, value): + if isinstance(value, file) or isinstance(value, str): + # using "FileField() = file/string" notation + grid_file = instance._data.get(self.name) + # If a file already exists, delete it + if grid_file: + try: + grid_file.delete() + except: + pass + # Create a new file with the new data + grid_file.put(value) + else: + # Create a new proxy object as we don't already have one + instance._data[self.name] = GridFSProxy() + instance._data[self.name].put(value) + else: + instance._data[self.name] = value + + def to_mongo(self, value): + # Store the GridFS file id in MongoDB + if isinstance(value, GridFSProxy) and value.grid_id is not None: + return value.grid_id + return None + + def to_python(self, value): + if value is not None: + return GridFSProxy(value) + + def validate(self, value): + if value.grid_id is 
not None: + assert isinstance(value, GridFSProxy) + assert isinstance(value.grid_id, pymongo.objectid.ObjectId) + + +class GeoPointField(BaseField): + """A list storing a latitude and longitude. + + .. versionadded:: 0.4 + """ + + _geo_index = True + + def validate(self, value): + """Make sure that a geo-value is of type (x, y) + """ + if not isinstance(value, (list, tuple)): + raise ValidationError('GeoPointField can only accept tuples or ' + 'lists of (x, y)') + + if not len(value) == 2: + raise ValidationError('Value must be a two-dimensional point.') + if (not isinstance(value[0], (float, int)) and + not isinstance(value[1], (float, int))): + raise ValidationError('Both values in point must be float or int.') diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index faf0cf44..519dda03 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1,11 +1,13 @@ from connection import _get_db +import pprint import pymongo import pymongo.code import pymongo.dbref import pymongo.objectid import re import copy +import itertools __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', 'InvalidCollectionError'] @@ -17,6 +19,7 @@ REPR_OUTPUT_SIZE = 20 class DoesNotExist(Exception): pass + class MultipleObjectsReturned(Exception): pass @@ -28,50 +31,192 @@ class InvalidQueryError(Exception): class OperationError(Exception): pass + class InvalidCollectionError(Exception): pass + RE_TYPE = type(re.compile('')) -class Q(object): +class QNodeVisitor(object): + """Base visitor class for visiting Q-object nodes in a query tree. 
+ """ - OR = '||' - AND = '&&' - OPERATORS = { - 'eq': ('((this.%(field)s instanceof Array) && ' - ' this.%(field)s.indexOf(%(value)s) != -1) ||' - ' this.%(field)s == %(value)s'), - 'ne': 'this.%(field)s != %(value)s', - 'gt': 'this.%(field)s > %(value)s', - 'gte': 'this.%(field)s >= %(value)s', - 'lt': 'this.%(field)s < %(value)s', - 'lte': 'this.%(field)s <= %(value)s', - 'lte': 'this.%(field)s <= %(value)s', - 'in': '%(value)s.indexOf(this.%(field)s) != -1', - 'nin': '%(value)s.indexOf(this.%(field)s) == -1', - 'mod': '%(field)s %% %(value)s', - 'all': ('%(value)s.every(function(a){' - 'return this.%(field)s.indexOf(a) != -1 })'), - 'size': 'this.%(field)s.length == %(value)s', - 'exists': 'this.%(field)s != null', - 'regex_eq': '%(value)s.test(this.%(field)s)', - 'regex_ne': '!%(value)s.test(this.%(field)s)', - } + def visit_combination(self, combination): + """Called by QCombination objects. + """ + return combination - def __init__(self, **query): - self.query = [query] + def visit_query(self, query): + """Called by (New)Q objects. + """ + return query - def _combine(self, other, op): - obj = Q() - if not other.query[0]: + +class SimplificationVisitor(QNodeVisitor): + """Simplifies query trees by combinging unnecessary 'and' connection nodes + into a single Q-object. + """ + + def visit_combination(self, combination): + if combination.operation == combination.AND: + # The simplification only applies to 'simple' queries + if all(isinstance(node, Q) for node in combination.children): + queries = [node.query for node in combination.children] + return Q(**self._query_conjunction(queries)) + return combination + + def _query_conjunction(self, queries): + """Merges query dicts - effectively &ing them together. 
+ """ + query_ops = set() + combined_query = {} + for query in queries: + ops = set(query.keys()) + # Make sure that the same operation isn't applied more than once + # to a single field + intersection = ops.intersection(query_ops) + if intersection: + msg = 'Duplicate query contitions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + query_ops.update(ops) + combined_query.update(copy.deepcopy(query)) + return combined_query + + +class QueryTreeTransformerVisitor(QNodeVisitor): + """Transforms the query tree in to a form that may be used with MongoDB. + """ + + def visit_combination(self, combination): + if combination.operation == combination.AND: + # MongoDB doesn't allow us to have too many $or operations in our + # queries, so the aim is to move the ORs up the tree to one + # 'master' $or. Firstly, we must find all the necessary parts (part + # of an AND combination or just standard Q object), and store them + # separately from the OR parts. + or_groups = [] + and_parts = [] + for node in combination.children: + if isinstance(node, QCombination): + if node.operation == node.OR: + # Any of the children in an $or component may cause + # the query to succeed + or_groups.append(node.children) + elif node.operation == node.AND: + and_parts.append(node) + elif isinstance(node, Q): + and_parts.append(node) + + # Now we combine the parts into a usable query. AND together all of + # the necessary parts. Then for each $or part, create a new query + # that ANDs the necessary part with the $or part. + clauses = [] + for or_group in itertools.product(*or_groups): + q_object = reduce(lambda a, b: a & b, and_parts, Q()) + q_object = reduce(lambda a, b: a & b, or_group, q_object) + clauses.append(q_object) + + # Finally, $or the generated clauses in to one query. Each of the + # clauses is sufficient for the query to succeed. 
+ return reduce(lambda a, b: a | b, clauses, Q()) + + if combination.operation == combination.OR: + children = [] + # Crush any nested ORs in to this combination as MongoDB doesn't + # support nested $or operations + for node in combination.children: + if (isinstance(node, QCombination) and + node.operation == combination.OR): + children += node.children + else: + children.append(node) + combination.children = children + + return combination + + +class QueryCompilerVisitor(QNodeVisitor): + """Compiles the nodes in a query tree to a PyMongo-compatible query + dictionary. + """ + + def __init__(self, document): + self.document = document + + def visit_combination(self, combination): + if combination.operation == combination.OR: + return {'$or': combination.children} + elif combination.operation == combination.AND: + return self._mongo_query_conjunction(combination.children) + return combination + + def visit_query(self, query): + return QuerySet._transform_query(self.document, **query.query) + + def _mongo_query_conjunction(self, queries): + """Merges Mongo query dicts - effectively &ing them together. + """ + combined_query = {} + for query in queries: + for field, ops in query.items(): + if field not in combined_query: + combined_query[field] = ops + else: + # The field is already present in the query the only way + # we can merge is if both the existing value and the new + # value are operation dicts, reject anything else + if (not isinstance(combined_query[field], dict) or + not isinstance(ops, dict)): + message = 'Conflicting values for ' + field + raise InvalidQueryError(message) + + current_ops = set(combined_query[field].keys()) + new_ops = set(ops.keys()) + # Make sure that the same operation isn't applied more than + # once to a single field + intersection = current_ops.intersection(new_ops) + if intersection: + msg = 'Duplicate query contitions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + # Right! 
We've got two non-overlapping dicts of operations! + combined_query[field].update(copy.deepcopy(ops)) + return combined_query + + +class QNode(object): + """Base class for nodes in query trees. + """ + + AND = 0 + OR = 1 + + def to_query(self, document): + query = self.accept(SimplificationVisitor()) + query = query.accept(QueryTreeTransformerVisitor()) + query = query.accept(QueryCompilerVisitor(document)) + return query + + def accept(self, visitor): + raise NotImplementedError + + def _combine(self, other, operation): + """Combine this node with another node into a QCombination object. + """ + if other.empty: return self - if self.query[0]: - obj.query = (['('] + copy.deepcopy(self.query) + [op] + - copy.deepcopy(other.query) + [')']) - else: - obj.query = copy.deepcopy(other.query) - return obj + + if self.empty: + return other + + return QCombination(operation, [self, other]) + + @property + def empty(self): + return False def __or__(self, other): return self._combine(other, self.OR) @@ -79,70 +224,49 @@ class Q(object): def __and__(self, other): return self._combine(other, self.AND) - def as_js(self, document): - js = [] - js_scope = {} - for i, item in enumerate(self.query): - if isinstance(item, dict): - item_query = QuerySet._transform_query(document, **item) - # item_query will values will either be a value or a dict - js.append(self._item_query_as_js(item_query, js_scope, i)) + +class QCombination(QNode): + """Represents the combination of several conditions by a given logical + operator. 
+ """ + + def __init__(self, operation, children): + self.operation = operation + self.children = [] + for node in children: + # If the child is a combination of the same type, we can merge its + # children directly into this combinations children + if isinstance(node, QCombination) and node.operation == operation: + self.children += node.children else: - js.append(item) - return pymongo.code.Code(' '.join(js), js_scope) + self.children.append(node) - def _item_query_as_js(self, item_query, js_scope, item_num): - # item_query will be in one of the following forms - # {'age': 25, 'name': 'Test'} - # {'age': {'$lt': 25}, 'name': {'$in': ['Test', 'Example']} - # {'age': {'$lt': 25, '$gt': 18}} - js = [] - for i, (key, value) in enumerate(item_query.items()): - op = 'eq' - # Construct a variable name for the value in the JS - value_name = 'i%sf%s' % (item_num, i) - if isinstance(value, dict): - # Multiple operators for this field - for j, (op, value) in enumerate(value.items()): - # Create a custom variable name for this operator - op_value_name = '%so%s' % (value_name, j) - # Construct the JS that uses this op - value, operation_js = self._build_op_js(op, key, value, - op_value_name) - # Update the js scope with the value for this op - js_scope[op_value_name] = value - js.append(operation_js) - else: - # Construct the JS for this field - value, field_js = self._build_op_js(op, key, value, value_name) - js_scope[value_name] = value - js.append(field_js) - print ' && '.join(js) - return ' && '.join(js) + def accept(self, visitor): + for i in range(len(self.children)): + self.children[i] = self.children[i].accept(visitor) - def _build_op_js(self, op, key, value, value_name): - """Substitute the values in to the correct chunk of Javascript. 
- """ - print op, key, value, value_name - if isinstance(value, RE_TYPE): - # Regexes are handled specially - if op.strip('$') == 'ne': - op_js = Q.OPERATORS['regex_ne'] - else: - op_js = Q.OPERATORS['regex_eq'] - else: - op_js = Q.OPERATORS[op.strip('$')] + return visitor.visit_combination(self) - # Comparing two ObjectIds in Javascript doesn't work.. - if isinstance(value, pymongo.objectid.ObjectId): - value = unicode(value) + @property + def empty(self): + return not bool(self.children) + + +class Q(QNode): + """A simple query object, used in a query tree to build up more complex + query structures. + """ + + def __init__(self, **query): + self.query = query + + def accept(self, visitor): + return visitor.visit_query(self) + + @property + def empty(self): + return not bool(self.query) - # Perform the substitution - operation_js = op_js % { - 'field': key, - 'value': value_name - } - return value, operation_js class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, @@ -153,20 +277,32 @@ class QuerySet(object): self._document = document self._collection_obj = collection self._accessed_collection = False - self._query = {} + self._mongo_query = None + self._query_obj = Q() + self._initial_query = {} self._where_clause = None self._loaded_fields = [] self._ordering = [] - + self._snapshot = False + self._timeout = True + # If inheritance is allowed, only return instances and instances of # subclasses of the class being used if document._meta.get('allow_inheritance'): - self._query = {'_types': self._document._class_name} + self._initial_query = {'_types': self._document._class_name} self._cursor_obj = None self._limit = None self._skip = None - def ensure_index(self, key_or_list): + @property + def _query(self): + if self._mongo_query is None: + self._mongo_query = self._query_obj.to_query(self._document) + self._mongo_query.update(self._initial_query) + return self._mongo_query + + def ensure_index(self, key_or_list, 
drop_dups=False, background=False, + **kwargs): """Ensure that the given indexes are in place. :param key_or_list: a single index key or a list of index keys (to @@ -174,7 +310,8 @@ class QuerySet(object): or a **-** to determine the index ordering """ index_list = QuerySet._build_index_spec(self._document, key_or_list) - self._collection.ensure_index(index_list) + self._collection.ensure_index(index_list, drop_dups=drop_dups, + background=background) return self @classmethod @@ -222,10 +359,14 @@ class QuerySet(object): objects, only the last one will be used :param query: Django-style query keyword arguments """ + #if q_obj: + #self._where_clause = q_obj.as_js(self._document) + query = Q(**query) if q_obj: - self._where_clause = q_obj.as_js(self._document) - query = QuerySet._transform_query(_doc_cls=self._document, **query) - self._query.update(query) + query &= q_obj + self._query_obj &= query + self._mongo_query = None + self._cursor_obj = None return self def filter(self, *q_objs, **query): @@ -233,6 +374,10 @@ class QuerySet(object): """ return self.__call__(*q_objs, **query) + def all(self): + """Returns all documents.""" + return self.__call__() + @property def _collection(self): """Property that returns the collection object. 
This allows us to @@ -240,33 +385,45 @@ class QuerySet(object): """ if not self._accessed_collection: self._accessed_collection = True - + + background = self._document._meta.get('index_background', False) + drop_dups = self._document._meta.get('index_drop_dups', False) + index_opts = self._document._meta.get('index_options', {}) + # Ensure document-defined indexes are created if self._document._meta['indexes']: for key_or_list in self._document._meta['indexes']: - #self.ensure_index(key_or_list) - self._collection.ensure_index(key_or_list) + self._collection.ensure_index(key_or_list, + background=background, **index_opts) # Ensure indexes created by uniqueness constraints for index in self._document._meta['unique_indexes']: - self._collection.ensure_index(index, unique=True) + self._collection.ensure_index(index, unique=True, + background=background, drop_dups=drop_dups, **index_opts) # If _types is being used (for polymorphism), it needs an index if '_types' in self._query: - self._collection.ensure_index('_types') - + self._collection.ensure_index('_types', + background=background, **index_opts) + # Ensure all needed field indexes are created - for field_name, field_instance in self._document._fields.iteritems(): - if field_instance.__class__.__name__ == 'GeoLocationField': - self._collection.ensure_index([(field_name, pymongo.GEO2D),]) + for field in self._document._fields.values(): + if field.__class__._geo_index: + index_spec = [(field.db_field, pymongo.GEO2D)] + self._collection.ensure_index(index_spec, + background=background, **index_opts) + return self._collection_obj @property def _cursor(self): if self._cursor_obj is None: - cursor_args = {} + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout, + } if self._loaded_fields: - cursor_args = {'fields': self._loaded_fields} + cursor_args['fields'] = self._loaded_fields self._cursor_obj = self._collection.find(self._query, **cursor_args) # Apply where clauses to cursor @@ -291,6 +448,9 
@@ class QuerySet(object): for field_name in parts: if field is None: # Look up first field from the document + if field_name == 'pk': + # Deal with "primary key" alias + field_name = document._meta['id_field'] field = document._fields[field_name] else: # Look up subfield on the previous field @@ -314,19 +474,31 @@ class QuerySet(object): """Transform a query from Django-style format to Mongo format. """ operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'near'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', + 'all', 'size', 'exists', 'not'] + geo_operators = ['within_distance', 'within_box', 'near'] + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] mongo_query = {} for key, value in query.items(): + if key == "__raw__": + mongo_query.update(value) + continue + parts = key.split('__') + indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] + parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is op = None - if parts[-1] in operators + match_operators: + if parts[-1] in operators + match_operators + geo_operators: op = parts.pop() + negate = False + if parts[-1] == 'not': + parts.pop() + negate = True + if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) @@ -334,20 +506,34 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte'] + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += match_operators if op in singular_ops: value = field.prepare_query_value(op, value) - elif op in ('in', 'nin', 'all'): + elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] - if 
field.__class__.__name__ == 'GenericReferenceField': - parts.append('_ref') + # if op and op not in match_operators: + if op: + if op in geo_operators: + if op == "within_distance": + value = {'$within': {'$center': value}} + elif op == "near": + value = {'$near': value} + elif op == 'within_box': + value = {'$within': {'$box': value}} + else: + raise NotImplementedError("Geo method '%s' has not " + "been implemented" % op) + elif op not in match_operators: + value = {'$' + op: value} - if op and op not in match_operators: - value = {'$' + op: value} + if negate: + value = {'$not': value} + for i, part in indices: + parts.insert(i, part) key = '.'.join(parts) if op is None or key not in mongo_query: mongo_query[key] = value @@ -405,6 +591,15 @@ class QuerySet(object): message = u'%d items returned, instead of 1' % count raise self._document.MultipleObjectsReturned(message) + def create(self, **kwargs): + """Create new object. Returns the saved object instance. + + .. versionadded:: 0.4 + """ + doc = self._document(**kwargs) + doc.save() + return doc + def first(self): """Retrieve the first object matching the query. """ @@ -429,7 +624,7 @@ class QuerySet(object): def in_bulk(self, object_ids): """Retrieve a set of documents by their ids. - + :param object_ids: a list or tuple of ``ObjectId``\ s :rtype: dict of ObjectIds as keys and collection-specific Document subclasses as values. @@ -441,7 +636,7 @@ class QuerySet(object): docs = self._collection.find({'_id': {'$in': object_ids}}) for doc in docs: doc_map[doc['_id']] = self._document._from_son(doc) - + return doc_map def next(self): @@ -595,12 +790,22 @@ class QuerySet(object): # Integer index provided elif isinstance(key, int): return self._document._from_son(self._cursor[key]) + raise AttributeError + + def distinct(self, field): + """Return a list of distinct values for a given field. + + :param field: the field to select distinct values from + + .. 
versionadded:: 0.4 + """ + return self._cursor.distinct(field) def only(self, *fields): """Load only a subset of this document's fields. :: - + post = BlogPost.objects(...).only("title") - + :param fields: fields to include .. versionadded:: 0.3 @@ -629,11 +834,13 @@ class QuerySet(object): """ key_list = [] for key in keys: + if not key: continue direction = pymongo.ASCENDING if key[0] == '-': direction = pymongo.DESCENDING if key[0] in ('-', '+'): key = key[1:] + key = key.replace('__', '.') key_list.append((key, direction)) self._ordering = key_list @@ -649,10 +856,23 @@ class QuerySet(object): plan = self._cursor.explain() if format: - import pprint plan = pprint.pformat(plan) return plan + def snapshot(self, enabled): + """Enable or disable snapshot mode when querying. + + :param enabled: whether or not snapshot mode is enabled + """ + self._snapshot = enabled + + def timeout(self, enabled): + """Enable or disable the default mongod timeout when querying. + + :param enabled: whether or not the timeout is used + """ + self._timeout = enabled + def delete(self, safe=False): """Delete the documents matched by the query. @@ -664,8 +884,8 @@ class QuerySet(object): def _transform_update(cls, _doc_cls=None, **update): """Transform an update spec from Django-style format to Mongo format. 
""" - operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull', - 'pull_all'] + operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all', + 'pull', 'pull_all', 'add_to_set'] mongo_update = {} for key, value in update.items(): @@ -683,6 +903,8 @@ class QuerySet(object): op = 'inc' if value > 0: value = -value + elif op == 'add_to_set': + op = op.replace('_to_set', 'ToSet') if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] @@ -691,7 +913,8 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - if op in (None, 'set', 'unset', 'push', 'pull'): + if op in (None, 'set', 'unset', 'pop', 'push', 'pull', + 'addToSet'): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] @@ -710,7 +933,8 @@ class QuerySet(object): return mongo_update def update(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on the fields matched by the query. + """Perform an atomic update on the fields matched by the query. When + ``safe_update`` is used, the number of affected documents is returned. 
:param safe: check if the operation succeeded before returning :param update: Django-style update keyword arguments @@ -722,8 +946,10 @@ class QuerySet(object): update = QuerySet._transform_update(self._document, **update) try: - self._collection.update(self._query, update, safe=safe_update, - upsert=upsert, multi=True) + ret = self._collection.update(self._query, update, multi=True, + upsert=upsert, safe=safe_update) + if ret is not None and 'n' in ret: + return ret['n'] except pymongo.errors.OperationFailure, err: if unicode(err) == u'multi not coded yet': message = u'update() method requires MongoDB 1.1.3+' @@ -731,7 +957,8 @@ class QuerySet(object): raise OperationError(u'Update failed (%s)' % unicode(err)) def update_one(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on first field matched by the query. + """Perform an atomic update on first field matched by the query. When + ``safe_update`` is used, the number of affected documents is returned. :param safe: check if the operation succeeded before returning :param update: Django-style update keyword arguments @@ -743,11 +970,14 @@ class QuerySet(object): # Explicitly provide 'multi=False' to newer versions of PyMongo # as the default may change to 'True' if pymongo.version >= '1.1.1': - self._collection.update(self._query, update, safe=safe_update, - upsert=upsert, multi=False) + ret = self._collection.update(self._query, update, multi=False, + upsert=upsert, safe=safe_update) else: # Older versions of PyMongo don't support 'multi' - self._collection.update(self._query, update, safe=safe_update) + ret = self._collection.update(self._query, update, + safe=safe_update) + if ret is not None and 'n' in ret: + return ret['n'] except pymongo.errors.OperationFailure, e: raise OperationError(u'Update failed [%s]' % unicode(e)) @@ -840,7 +1070,7 @@ class QuerySet(object): var total = 0.0; var num = 0; db[collection].find(query).forEach(function(doc) { - if (doc[averageField]) { + if 
(doc[averageField] !== undefined) { total += doc[averageField]; num += 1; } @@ -850,20 +1080,27 @@ class QuerySet(object): """ return self.exec_js(average_func, field) - def item_frequencies(self, list_field, normalize=False): - """Returns a dictionary of all items present in a list field across + def item_frequencies(self, field, normalize=False): + """Returns a dictionary of all items present in a field across the whole queried set of documents, and their corresponding frequency. This is useful for generating tag clouds, or searching documents. - :param list_field: the list field to use + If the field is a :class:`~mongoengine.ListField`, the items within + each list will be counted individually. + + :param field: the field to use :param normalize: normalize the results so they add to 1.0 """ freq_func = """ - function(listField) { + function(field) { if (options.normalize) { var total = 0.0; db[collection].find(query).forEach(function(doc) { - total += doc[listField].length; + if (doc[field].constructor == Array) { + total += doc[field].length; + } else { + total++; + } }); } @@ -873,14 +1110,19 @@ class QuerySet(object): inc /= total; } db[collection].find(query).forEach(function(doc) { - doc[listField].forEach(function(item) { + if (doc[field].constructor == Array) { + doc[field].forEach(function(item) { + frequencies[item] = inc + (frequencies[item] || 0); + }); + } else { + var item = doc[field]; frequencies[item] = inc + (frequencies[item] || 0); - }); + } }); return frequencies; } """ - return self.exec_js(freq_func, list_field, normalize=normalize) + return self.exec_js(freq_func, field, normalize=normalize) def __repr__(self): limit = REPR_OUTPUT_SIZE + 1 @@ -896,7 +1138,7 @@ class QuerySetManager(object): def __init__(self, manager_func=None): self._manager_func = manager_func - self._collection = None + self._collections = {} def __get__(self, instance, owner): """Descriptor for instantiating a new QuerySet object when @@ -906,10 +1148,9 @@ class 
QuerySetManager(object): # Document class being used rather than a document object return self - if self._collection is None: - db = _get_db() - collection = owner._meta['collection'] - + db = _get_db() + collection = owner._meta['collection'] + if (db, collection) not in self._collections: # Create collection as a capped collection if specified if owner._meta['max_size'] or owner._meta['max_documents']: # Get max document limit and max byte size from meta @@ -917,10 +1158,10 @@ class QuerySetManager(object): max_documents = owner._meta['max_documents'] if collection in db.collection_names(): - self._collection = db[collection] + self._collections[(db, collection)] = db[collection] # The collection already exists, check if its capped # options match the specified capped options - options = self._collection.options() + options = self._collections[(db, collection)].options() if options.get('max') != max_documents or \ options.get('size') != max_size: msg = ('Cannot create collection "%s" as a capped ' @@ -931,12 +1172,15 @@ class QuerySetManager(object): opts = {'capped': True, 'size': max_size} if max_documents: opts['max'] = max_documents - self._collection = db.create_collection(collection, **opts) + self._collections[(db, collection)] = db.create_collection( + collection, **opts + ) else: - self._collection = db[collection] + self._collections[(db, collection)] = db[collection] # owner is the document that contains the QuerySetManager - queryset = QuerySet(owner, self._collection) + queryset_class = owner._meta['queryset_class'] or QuerySet + queryset = queryset_class(owner, self._collections[(db, collection)]) if self._manager_func: if self._manager_func.func_code.co_argcount == 1: queryset = self._manager_func(queryset) diff --git a/tests/document.py b/tests/document.py index 8bc907c5..c0567632 100644 --- a/tests/document.py +++ b/tests/document.py @@ -200,6 +200,37 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() self.assertFalse(collection 
in self.db.collection_names()) + def test_inherited_collections(self): + """Ensure that subclassed documents don't override parents' collections. + """ + class Drink(Document): + name = StringField() + + class AlcoholicDrink(Drink): + meta = {'collection': 'booze'} + + class Drinker(Document): + drink = GenericReferenceField() + + Drink.drop_collection() + AlcoholicDrink.drop_collection() + Drinker.drop_collection() + + red_bull = Drink(name='Red Bull') + red_bull.save() + + programmer = Drinker(drink=red_bull) + programmer.save() + + beer = AlcoholicDrink(name='Beer') + beer.save() + + real_person = Drinker(drink=beer) + real_person.save() + + self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) + self.assertEqual(Drinker.objects[1].drink.name, beer.name) + def test_capped_collection(self): """Ensure that capped collections work properly. """ @@ -264,11 +295,12 @@ class DocumentTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info.values()) - self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) + in info) + self.assertTrue([('_types', 1), ('addDate', -1)] in info) # tags is a list field so it shouldn't have _types in the index - self.assertTrue([('tags', 1)] in info.values()) + self.assertTrue([('tags', 1)] in info) class ExtendedBlogPost(BlogPost): title = StringField() @@ -278,10 +310,11 @@ class DocumentTest(unittest.TestCase): list(ExtendedBlogPost.objects) info = ExtendedBlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info.values()) - self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) - self.assertTrue([('_types', 1), ('title', 1)] in info.values()) + in info) 
+ self.assertTrue([('_types', 1), ('addDate', -1)] in info) + self.assertTrue([('_types', 1), ('title', 1)] in info) BlogPost.drop_collection() @@ -353,12 +386,26 @@ class DocumentTest(unittest.TestCase): user_obj = User.objects.first() self.assertEqual(user_obj.id, 'test') + self.assertEqual(user_obj.pk, 'test') user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'test') self.assertTrue('username' not in user_son['_id']) User.drop_collection() + + user = User(pk='mongo', name='mongo user') + user.save() + + user_obj = User.objects.first() + self.assertEqual(user_obj.id, 'mongo') + self.assertEqual(user_obj.pk, 'mongo') + + user_son = User.objects._collection.find_one() + self.assertEqual(user_son['_id'], 'mongo') + self.assertTrue('username' not in user_son['_id']) + + User.drop_collection() def test_creation(self): """Ensure that document may be created using keyword arguments. @@ -446,6 +493,16 @@ class DocumentTest(unittest.TestCase): self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['age'], 30) self.assertEqual(person_obj['_id'], person.id) + # Test skipping validation on save + class Recipient(Document): + email = EmailField(required=True) + + recipient = Recipient(email='root@localhost') + self.assertRaises(ValidationError, recipient.save) + try: + recipient.save(validate=False) + except ValidationError: + fail() def test_delete(self): """Ensure that document may be deleted using the delete method. @@ -467,6 +524,18 @@ class DocumentTest(unittest.TestCase): collection = self.db[self.Person._meta['collection']] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') + + def test_save_custom_pk(self): + """Ensure that a document may be saved with a custom _id using pk alias. 
+ """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30, + pk='497ce96f395f2f052a494fd4') + person.save() + # Ensure that the object is in the database with the correct _id + collection = self.db[self.Person._meta['collection']] + person_obj = collection.find_one({'name': 'Test User'}) + self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') def test_save_list(self): """Ensure that a list field may be properly saved. diff --git a/tests/fields.py b/tests/fields.py index 4050e264..5602cdec 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -3,6 +3,7 @@ import datetime from decimal import Decimal import pymongo +import gridfs from mongoengine import * from mongoengine.connection import _get_db @@ -188,6 +189,9 @@ class FieldTest(unittest.TestCase): def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements. """ + class User(Document): + pass + class Comment(EmbeddedDocument): content = StringField() @@ -195,6 +199,7 @@ class FieldTest(unittest.TestCase): content = StringField() comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) + authors = ListField(ReferenceField(User)) post = BlogPost(content='Went for a walk today...') post.validate() @@ -209,15 +214,21 @@ class FieldTest(unittest.TestCase): post.tags = ('fun', 'leisure') post.validate() - comments = [Comment(content='Good for you'), Comment(content='Yay.')] - post.comments = comments - post.validate() - post.comments = ['a'] self.assertRaises(ValidationError, post.validate) post.comments = 'yay' self.assertRaises(ValidationError, post.validate) + comments = [Comment(content='Good for you'), Comment(content='Yay.')] + post.comments = comments + post.validate() + + post.authors = [Comment()] + self.assertRaises(ValidationError, post.validate) + + post.authors = [User()] + post.validate() + def test_sorted_list_sorting(self): """Ensure that a sorted list field properly sorts 
values. """ @@ -227,7 +238,8 @@ class FieldTest(unittest.TestCase): class BlogPost(Document): content = StringField() - comments = SortedListField(EmbeddedDocumentField(Comment), ordering='order') + comments = SortedListField(EmbeddedDocumentField(Comment), + ordering='order') tags = SortedListField(StringField()) post = BlogPost(content='Went for a walk today...') @@ -393,14 +405,54 @@ class FieldTest(unittest.TestCase): class Employee(Document): name = StringField() boss = ReferenceField('self') + friends = ListField(ReferenceField('self')) bill = Employee(name='Bill Lumbergh') bill.save() - peter = Employee(name='Peter Gibbons', boss=bill) + + michael = Employee(name='Michael Bolton') + michael.save() + + samir = Employee(name='Samir Nagheenanajar') + samir.save() + + friends = [michael, samir] + peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) peter.save() peter = Employee.objects.with_id(peter.id) self.assertEqual(peter.boss, bill) + self.assertEqual(peter.friends, friends) + + def test_recursive_embedding(self): + """Ensure that EmbeddedDocumentFields can contain their own documents. 
+ """ + class Tree(Document): + name = StringField() + children = ListField(EmbeddedDocumentField('TreeNode')) + + class TreeNode(EmbeddedDocument): + name = StringField() + children = ListField(EmbeddedDocumentField('self')) + + tree = Tree(name="Tree") + + first_child = TreeNode(name="Child 1") + tree.children.append(first_child) + + second_child = TreeNode(name="Child 2") + first_child.children.append(second_child) + + third_child = TreeNode(name="Child 3") + first_child.children.append(third_child) + + tree.save() + + tree_obj = Tree.objects.first() + self.assertEqual(len(tree.children), 1) + self.assertEqual(tree.children[0].name, first_child.name) + self.assertEqual(tree.children[0].children[0].name, second_child.name) + self.assertEqual(tree.children[0].children[1].name, third_child.name) def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. @@ -607,7 +659,130 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_file_fields(self): + """Ensure that file fields can be written to and their data retrieved + """ + class PutFile(Document): + file = FileField() + class StreamFile(Document): + file = FileField() + + class SetFile(Document): + file = FileField() + + text = 'Hello, World!' 
+ more_text = 'Foo Bar' + content_type = 'text/plain' + + PutFile.drop_collection() + StreamFile.drop_collection() + SetFile.drop_collection() + + putfile = PutFile() + putfile.file.put(text, content_type=content_type) + putfile.save() + putfile.validate() + result = PutFile.objects.first() + self.assertTrue(putfile == result) + self.assertEquals(result.file.read(), text) + self.assertEquals(result.file.content_type, content_type) + result.file.delete() # Remove file from GridFS + + streamfile = StreamFile() + streamfile.file.new_file(content_type=content_type) + streamfile.file.write(text) + streamfile.file.write(more_text) + streamfile.file.close() + streamfile.save() + streamfile.validate() + result = StreamFile.objects.first() + self.assertTrue(streamfile == result) + self.assertEquals(result.file.read(), text + more_text) + self.assertEquals(result.file.content_type, content_type) + result.file.delete() + + # Ensure deleted file returns None + self.assertTrue(result.file.read() == None) + + setfile = SetFile() + setfile.file = text + setfile.save() + setfile.validate() + result = SetFile.objects.first() + self.assertTrue(setfile == result) + self.assertEquals(result.file.read(), text) + + # Try replacing file with new one + result.file.replace(more_text) + result.save() + result.validate() + result = SetFile.objects.first() + self.assertTrue(setfile == result) + self.assertEquals(result.file.read(), more_text) + result.file.delete() + + PutFile.drop_collection() + StreamFile.drop_collection() + SetFile.drop_collection() + + # Make sure FileField is optional and not required + class DemoFile(Document): + file = FileField() + d = DemoFile.objects.create() + + def test_file_uniqueness(self): + """Ensure that each instance of a FileField is unique + """ + class TestFile(Document): + name = StringField() + file = FileField() + + # First instance + testfile = TestFile() + testfile.name = "Hello, World!" 
+ testfile.file.put('Hello, World!') + testfile.save() + + # Second instance + testfiledupe = TestFile() + data = testfiledupe.file.read() # Should be None + + self.assertTrue(testfile.name != testfiledupe.name) + self.assertTrue(testfile.file.read() != data) + + TestFile.drop_collection() + + def test_geo_indexes(self): + """Ensure that indexes are created automatically for GeoPointFields. + """ + class Event(Document): + title = StringField() + location = GeoPointField() + + Event.drop_collection() + event = Event(title="Coltrane Motion @ Double Door", + location=[41.909889, -87.677137]) + event.save() + + info = Event.objects._collection.index_information() + self.assertTrue(u'location_2d' in info) + self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')]) + + Event.drop_collection() + + def test_ensure_unique_default_instances(self): + """Ensure that every field has it's own unique default instance.""" + class D(Document): + data = DictField() + data2 = DictField(default=lambda: {}) + + d1 = D() + d1.data['foo'] = 'bar' + d1.data2['foo'] = 'bar' + d2 = D() + self.assertEqual(d2.data, {}) + self.assertEqual(d2.data2, {}) if __name__ == '__main__': unittest.main() diff --git a/tests/queryset.py b/tests/queryset.py index 51f92993..6ca4174d 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -53,9 +53,6 @@ class QuerySetTest(unittest.TestCase): person2 = self.Person(name="User B", age=30) person2.save() - q1 = Q(name='test') - q2 = Q(age__gte=18) - # Find all people in the collection people = self.Person.objects self.assertEqual(len(people), 2) @@ -156,7 +153,8 @@ class QuerySetTest(unittest.TestCase): # Retrieve the first person from the database self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) - self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get) + self.assertRaises(self.Person.MultipleObjectsReturned, + self.Person.objects.get) # Use a query to filter the people found to just person2 person = 
self.Person.objects.get(age=30) @@ -165,8 +163,49 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") + def test_find_array_position(self): + """Ensure that query by array position works. + """ + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() - + Blog.objects.create(tags=['a', 'b']) + self.assertEqual(len(Blog.objects(tags__0='a')), 1) + self.assertEqual(len(Blog.objects(tags__0='b')), 0) + self.assertEqual(len(Blog.objects(tags__1='a')), 0) + self.assertEqual(len(Blog.objects(tags__1='b')), 1) + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + blog = Blog.objects(posts__0__comments__0__name='testa').get() + self.assertEqual(blog, blog1) + + query = Blog.objects(posts__1__comments__1__name='testb') + self.assertEqual(len(query), 2) + + query = Blog.objects(posts__1__comments__1__name='testa') + self.assertEqual(len(query), 0) + + query = Blog.objects(posts__0__comments__1__name='testa') + self.assertEqual(len(query), 0) + + Blog.drop_collection() def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new @@ -193,7 +232,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(created, False) # Try retrieving when no objects exists - new doc should be created - person, created = self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'}) + kwargs = dict(age=50, defaults={'name': 'User C'}) + person, created = self.Person.objects.get_or_create(**kwargs) 
self.assertEqual(created, True) person = self.Person.objects.get(age=50) @@ -288,6 +328,25 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, person) obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() self.assertEqual(obj, None) + + # Test unsafe expressions + person = self.Person(name='Guido van Rossum [.\'Geek\']') + person.save() + + obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first() + self.assertEqual(obj, person) + + def test_not(self): + """Ensure that the __not operator works as expected. + """ + alice = self.Person(name='Alice', age=25) + alice.save() + + obj = self.Person.objects(name__iexact='alice').first() + self.assertEqual(obj, alice) + + obj = self.Person.objects(name__not__iexact='alice').first() + self.assertEqual(obj, None) def test_filter_chaining(self): """Ensure filters can be chained together. @@ -498,9 +557,10 @@ class QuerySetTest(unittest.TestCase): obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__ne=re.compile('^bob'))).first() + obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__ne=re.compile('^Gui'))).first() + + obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() self.assertEqual(obj, None) def test_q_lists(self): @@ -664,28 +724,32 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertTrue('db' in post.tags and 'nosql' in post.tags) + tags = post.tags[:-1] + BlogPost.objects.update(pop__tags=1) + post.reload() + self.assertEqual(post.tags, tags) + + BlogPost.objects.update_one(add_to_set__tags='unique') + BlogPost.objects.update_one(add_to_set__tags='unique') + post.reload() + self.assertEqual(post.tags.count('unique'), 1) + BlogPost.drop_collection() def test_update_pull(self): """Ensure that the 'pull' update operation works correctly. 
""" - class Comment(EmbeddedDocument): - content = StringField() - class BlogPost(Document): slug = StringField() - comments = ListField(EmbeddedDocumentField(Comment)) + tags = ListField(StringField()) - comment1 = Comment(content="test1") - comment2 = Comment(content="test2") - - post = BlogPost(slug="test", comments=[comment1, comment2]) + post = BlogPost(slug="test", tags=['code', 'mongodb', 'code']) post.save() - self.assertTrue(comment2 in post.comments) - BlogPost.objects(slug="test").update(pull__comments__content="test2") + BlogPost.objects(slug="test").update(pull__tags="code") post.reload() - self.assertTrue(comment2 not in post.comments) + self.assertTrue('code' not in post.tags) + self.assertEqual(len(post.tags), 1) def test_order_by(self): """Ensure that QuerySets may be ordered. @@ -921,7 +985,7 @@ class QuerySetTest(unittest.TestCase): BlogPost(hits=1, tags=['music', 'film', 'actors']).save() BlogPost(hits=2, tags=['music']).save() - BlogPost(hits=3, tags=['music', 'actors']).save() + BlogPost(hits=2, tags=['music', 'actors']).save() f = BlogPost.objects.item_frequencies('tags') f = dict((key, int(val)) for key, val in f.items()) @@ -943,16 +1007,26 @@ class QuerySetTest(unittest.TestCase): self.assertAlmostEqual(f['actors'], 2.0/6.0) self.assertAlmostEqual(f['film'], 1.0/6.0) + # Check item_frequencies works for non-list fields + f = BlogPost.objects.item_frequencies('hits') + f = dict((key, int(val)) for key, val in f.items()) + self.assertEqual(set(['1', '2']), set(f.keys())) + self.assertEqual(f['1'], 1) + self.assertEqual(f['2'], 2) + BlogPost.drop_collection() def test_average(self): """Ensure that field can be averaged correctly. 
""" + self.Person(name='person', age=0).save() + self.assertEqual(int(self.Person.objects.average('age')), 0) + ages = [23, 54, 12, 94, 27] for i, age in enumerate(ages): self.Person(name='test%s' % i, age=age).save() - avg = float(sum(ages)) / len(ages) + avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) self.Person(name='ageless person').save() @@ -970,15 +1044,34 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + def test_distinct(self): + """Ensure that the QuerySet.distinct method works. + """ + self.Person(name='Mr Orange', age=20).save() + self.Person(name='Mr White', age=20).save() + self.Person(name='Mr Orange', age=30).save() + self.Person(name='Mr Pink', age=30).save() + self.assertEqual(set(self.Person.objects.distinct('name')), + set(['Mr Orange', 'Mr White', 'Mr Pink'])) + self.assertEqual(set(self.Person.objects.distinct('age')), + set([20, 30])) + self.assertEqual(set(self.Person.objects(age=30).distinct('name')), + set(['Mr Orange', 'Mr Pink'])) + def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. 
""" class BlogPost(Document): tags = ListField(StringField()) + deleted = BooleanField(default=False) + + @queryset_manager + def objects(doc_cls, queryset): + return queryset(deleted=False) @queryset_manager def music_posts(doc_cls, queryset): - return queryset(tags='music') + return queryset(tags='music', deleted=False) BlogPost.drop_collection() @@ -988,6 +1081,8 @@ class QuerySetTest(unittest.TestCase): post2.save() post3 = BlogPost(tags=['film', 'actors']) post3.save() + post4 = BlogPost(tags=['film', 'actors'], deleted=True) + post4.save() self.assertEqual([p.id for p in BlogPost.objects], [post1.id, post2.id, post3.id]) @@ -1011,7 +1106,8 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() data = {'title': 'Post 1', 'comments': [Comment(content='test')]} - BlogPost(**data).save() + post = BlogPost(**data) + post.save() self.assertTrue('postTitle' in BlogPost.objects(title=data['title'])._query) @@ -1019,12 +1115,33 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects(title=data['title'])._query) self.assertEqual(len(BlogPost.objects(title=data['title'])), 1) + self.assertTrue('_id' in BlogPost.objects(pk=post.id)._query) + self.assertEqual(len(BlogPost.objects(pk=post.id)), 1) + self.assertTrue('postComments.commentContent' in BlogPost.objects(comments__content='test')._query) self.assertEqual(len(BlogPost.objects(comments__content='test')), 1) BlogPost.drop_collection() + def test_query_pk_field_name(self): + """Ensure that the correct "primary key" field name is used when querying + """ + class BlogPost(Document): + title = StringField(primary_key=True, db_field='postTitle') + + BlogPost.drop_collection() + + data = { 'title':'Post 1' } + post = BlogPost(**data) + post.save() + + self.assertTrue('_id' in BlogPost.objects(pk=data['title'])._query) + self.assertTrue('_id' in BlogPost.objects(title=data['title'])._query) + self.assertEqual(len(BlogPost.objects(pk=data['title'])), 1) + + BlogPost.drop_collection() + def 
test_query_value_conversion(self): """Ensure that query values are properly converted when necessary. """ @@ -1087,8 +1204,9 @@ class QuerySetTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() - self.assertTrue([('_types', 1)] in info.values()) - self.assertTrue([('_types', 1), ('date', -1)] in info.values()) + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('_types', 1)] in info) + self.assertTrue([('_types', 1), ('date', -1)] in info) BlogPost.drop_collection() @@ -1164,46 +1282,104 @@ class QuerySetTest(unittest.TestCase): def tearDown(self): self.Person.drop_collection() + def test_geospatial_operators(self): + """Ensure that geospatial queries are working. + """ + class Event(Document): + title = StringField() + date = DateTimeField() + location = GeoPointField() + + def __unicode__(self): + return self.title + + Event.drop_collection() + + event1 = Event(title="Coltrane Motion @ Double Door", + date=datetime.now() - timedelta(days=1), + location=[41.909889, -87.677137]) + event2 = Event(title="Coltrane Motion @ Bottom of the Hill", + date=datetime.now() - timedelta(days=10), + location=[37.7749295, -122.4194155]) + event3 = Event(title="Coltrane Motion @ Empty Bottle", + date=datetime.now(), + location=[41.900474, -87.686638]) + + event1.save() + event2.save() + event3.save() + + # find all events "near" pitchfork office, chicago. + # note that "near" will show the san francisco event, too, + # although it sorts to last. 
+ events = Event.objects(location__near=[41.9120459, -87.67892]) + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event1, event3, event2]) + + # find events within 5 miles of pitchfork office, chicago + point_and_distance = [[41.9120459, -87.67892], 5] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 2) + events = list(events) + self.assertTrue(event2 not in events) + self.assertTrue(event1 in events) + self.assertTrue(event3 in events) + + # ensure ordering is respected by "near" + events = Event.objects(location__near=[41.9120459, -87.67892]) + events = events.order_by("-date") + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event3, event1, event2]) + + # find events around san francisco + point_and_distance = [[37.7566023, -122.415579], 10] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0], event2) + + # find events within 1 mile of greenpoint, brooklyn, nyc, ny + point_and_distance = [[40.7237134, -73.9509714], 1] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 0) + + # ensure ordering is respected by "within_distance" + point_and_distance = [[41.9120459, -87.67892], 10] + events = Event.objects(location__within_distance=point_and_distance) + events = events.order_by("-date") + self.assertEqual(events.count(), 2) + self.assertEqual(events[0], event3) + + # check that within_box works + box = [(35.0, -125.0), (40.0, -100.0)] + events = Event.objects(location__within_box=box) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0].id, event2.id) + + Event.drop_collection() + + def test_custom_querysets(self): + """Ensure that custom QuerySet classes may be used. 
+ """ + class CustomQuerySet(QuerySet): + def not_empty(self): + return len(self) > 0 + + class Post(Document): + meta = {'queryset_class': CustomQuerySet} + + Post.drop_collection() + + self.assertTrue(isinstance(Post.objects, CustomQuerySet)) + self.assertFalse(Post.objects.not_empty()) + + Post().save() + self.assertTrue(Post.objects.not_empty()) + + Post.drop_collection() + class QTest(unittest.TestCase): - def test_or_and(self): - """Ensure that Q objects may be combined correctly. - """ - q1 = Q(name='test') - q2 = Q(age__gte=18) - - query = ['(', {'name': 'test'}, '||', {'age__gte': 18}, ')'] - self.assertEqual((q1 | q2).query, query) - - query = ['(', {'name': 'test'}, '&&', {'age__gte': 18}, ')'] - self.assertEqual((q1 & q2).query, query) - - query = ['(', '(', {'name': 'test'}, '&&', {'age__gte': 18}, ')', '||', - {'name': 'example'}, ')'] - self.assertEqual((q1 & q2 | Q(name='example')).query, query) - - def test_item_query_as_js(self): - """Ensure that the _item_query_as_js utilitiy method works properly. - """ - q = Q() - examples = [ - - ({'name': 'test'}, ('((this.name instanceof Array) && ' - 'this.name.indexOf(i0f0) != -1) || this.name == i0f0'), - {'i0f0': 'test'}), - ({'age': {'$gt': 18}}, 'this.age > i0f0o0', {'i0f0o0': 18}), - ({'name': 'test', 'age': {'$gt': 18, '$lte': 65}}, - ('this.age <= i0f0o0 && this.age > i0f0o1 && ' - '((this.name instanceof Array) && ' - 'this.name.indexOf(i0f1) != -1) || this.name == i0f1'), - {'i0f0o0': 65, 'i0f0o1': 18, 'i0f1': 'test'}), - ] - for item, js, scope in examples: - test_scope = {} - self.assertEqual(q._item_query_as_js(item, test_scope, 0), js) - self.assertEqual(scope, test_scope) - def test_empty_q(self): """Ensure that empty Q objects won't hurt. 
""" @@ -1213,11 +1389,131 @@ class QTest(unittest.TestCase): q4 = Q(name='test') q5 = Q() - query = ['(', {'age__gte': 18}, '||', {'name': 'test'}, ')'] - self.assertEqual((q1 | q2 | q3 | q4 | q5).query, query) + class Person(Document): + name = StringField() + age = IntField() + + query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]} + self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query) + + query = {'age': {'$gte': 18}, 'name': 'test'} + self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) + + def test_q_with_dbref(self): + """Ensure Q objects handle DBRefs correctly""" + connect(db='mongoenginetest') + + class User(Document): + pass + + class Post(Document): + created_user = ReferenceField(User) + + user = User.objects.create() + Post.objects.create(created_user=user) + + self.assertEqual(Post.objects.filter(created_user=user).count(), 1) + self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) + + def test_and_combination(self): + """Ensure that Q-objects correctly AND together. + """ + class TestDoc(Document): + x = IntField() + y = StringField() + + # Check than an error is raised when conflicting queries are anded + def invalid_combination(): + query = Q(x__lt=7) & Q(x__lt=3) + query.to_query(TestDoc) + self.assertRaises(InvalidQueryError, invalid_combination) + + # Check normal cases work without an error + query = Q(x__lt=7) & Q(x__gt=3) + + q1 = Q(x__lt=7) + q2 = Q(x__gt=3) + query = (q1 & q2).to_query(TestDoc) + self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}}) + + # More complex nested example + query = Q(x__lt=100) & Q(y__ne='NotMyString') + query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) + mongo_query = { + 'x': {'$lt': 100, '$gt': -100}, + 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, + } + self.assertEqual(query.to_query(TestDoc), mongo_query) + + def test_or_combination(self): + """Ensure that Q-objects correctly OR together. 
+ """ + class TestDoc(Document): + x = IntField() + + q1 = Q(x__lt=3) + q2 = Q(x__gt=7) + query = (q1 | q2).to_query(TestDoc) + self.assertEqual(query, { + '$or': [ + {'x': {'$lt': 3}}, + {'x': {'$gt': 7}}, + ] + }) + + def test_and_or_combination(self): + """Ensure that Q-objects handle ANDing ORed components. + """ + class TestDoc(Document): + x = IntField() + y = BooleanField() + + query = (Q(x__gt=0) | Q(x__exists=False)) + query &= Q(x__lt=100) + self.assertEqual(query.to_query(TestDoc), { + '$or': [ + {'x': {'$lt': 100, '$gt': 0}}, + {'x': {'$lt': 100, '$exists': False}}, + ] + }) + + q1 = (Q(x__gt=0) | Q(x__exists=False)) + q2 = (Q(x__lt=100) | Q(y=True)) + query = (q1 & q2).to_query(TestDoc) + + self.assertEqual(['$or'], query.keys()) + conditions = [ + {'x': {'$lt': 100, '$gt': 0}}, + {'x': {'$lt': 100, '$exists': False}}, + {'x': {'$gt': 0}, 'y': True}, + {'x': {'$exists': False}, 'y': True}, + ] + self.assertEqual(len(conditions), len(query['$or'])) + for condition in conditions: + self.assertTrue(condition in query['$or']) + + def test_or_and_or_combination(self): + """Ensure that Q-objects handle ORing ANDed ORed components. :) + """ + class TestDoc(Document): + x = IntField() + y = BooleanField() + + q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False))) + q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) + query = (q1 | q2).to_query(TestDoc) + + self.assertEqual(['$or'], query.keys()) + conditions = [ + {'x': {'$gt': 0}, 'y': True}, + {'x': {'$gt': 0}, 'y': {'$exists': False}}, + {'x': {'$lt': 100}, 'y':False}, + {'x': {'$lt': 100}, 'y': {'$exists': False}}, + ] + self.assertEqual(len(conditions), len(query['$or'])) + for condition in conditions: + self.assertTrue(condition in query['$or']) - query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')'] - self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query) if __name__ == '__main__': unittest.main()