diff --git a/AUTHORS b/AUTHORS index 9371cd76..084c0129 100644 --- a/AUTHORS +++ b/AUTHORS @@ -129,6 +129,16 @@ that much better: * Peter Teichman * Jakub Kot * Jorge Bastida + * Aleksandr Sorokoumov + * Yohan Graterol + * bool-dev + * Russ Weeks + * Paul Swartz + * Sundar Raman + * Benoit Louy + * lraucy + * hellysmile + * Jaepil Jeong * Stefan Wójcik * Pete Campton * Martyn Smith @@ -145,4 +155,3 @@ that much better: * Jared Forsyth * Kenneth Falck * Lukasz Balcerzak - * Aleksandr Sorokoumov diff --git a/docs/changelog.rst b/docs/changelog.rst index ecfe941d..a3b4eaaa 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -51,6 +51,19 @@ Changes in 0.8.X Changes in 0.7.10 ================= +- Allow construction using positional parameters (#268) +- Updated EmailField length to support long domains (#243) +- Added 64-bit integer support (#251) +- Added Django sessions TTL support (#224) +- Fixed issue with numerical keys in MapField(EmbeddedDocumentField()) (#240) +- Fixed clearing _changed_fields for complex nested embedded documents (#237, #239, #242) +- Added "_id" to _data dictionary (#255) +- Only mark a field as changed if the value has changed (#258) +- Explicitly check for Document instances when dereferencing (#261) +- Fixed order_by chaining issue (#265) +- Added dereference support for tuples (#250) +- Resolve field name to db field name when using distinct(#260, #264, #269) +- Added kwargs to doc.save to help interop with django (#223, #270) - Fixed cloning querysets in PY3 - Int fields no longer unset in save when changed to 0 (#272) - Fixed ReferenceField query chaining bug fixed (#254) diff --git a/docs/django.rst b/docs/django.rst index ba934324..6f27b902 100644 --- a/docs/django.rst +++ b/docs/django.rst @@ -10,9 +10,15 @@ In your **settings.py** file, ignore the standard database settings (unless you also plan to use the ORM in your project), and instead call :func:`~mongoengine.connect` somewhere in the settings module. -.. note:: If getting an ``ImproperlyConfigured: settings.DATABASES is - improperly configured`` error you may need to remove - ``django.contrib.sites`` from ``INSTALLED_APPS`` in settings.py. +.. note :: + If you are not using another Database backend you may need to add a dummy + database backend to ``settings.py`` eg:: + + DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.dummy' + } + } Authentication ============== @@ -49,6 +55,9 @@ into you settings module:: SESSION_ENGINE = 'mongoengine.django.sessions' +Django provides session cookie, which expires after ```SESSION_COOKIE_AGE``` seconds, but doesnt delete cookie at sessions backend, so ``'mongoengine.django.sessions'`` supports `mongodb TTL +`_. + .. versionadded:: 0.2.1 Storage diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 2d3b090b..ebb34101 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -30,13 +30,22 @@ class BaseDocument(object): _dynamic_lock = True _initialised = False - def __init__(self, __auto_convert=True, **values): + def __init__(self, __auto_convert=True, *args, **values): """ Initialise a document or embedded document :param __auto_convert: Try and will cast python objects to Object types :param values: A dictionary of values for the document """ + if args: + # Combine positional arguments with named arguments. + # We only want named arguments. + field = iter(self._fields_ordered) + for value in args: + name = next(field) + if name in values: + raise TypeError("Multiple values for keyword argument '" + name + "'") + values[name] = value signals.pre_init.send(self.__class__, document=self, values=values) @@ -117,15 +126,15 @@ class BaseDocument(object): self._mark_as_changed(name) if (self._is_document and not self._created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): + name in self._meta.get('shard_key', tuple()) and + self._data.get(name) != value): OperationError = _import_class('OperationError') msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) # Check if the user has created a new instance of a class if (self._is_document and self._initialised - and self._created and name == self._meta['id_field']): + and self._created and name == self._meta['id_field']): super(BaseDocument, self).__setattr__('_created', False) super(BaseDocument, self).__setattr__(name, value) @@ -143,7 +152,10 @@ class BaseDocument(object): self.__set_field_display() def __iter__(self): - return iter(self._fields) + if 'id' in self._fields and 'id' not in self._fields_ordered: + return iter(('id', ) + self._fields_ordered) + + return iter(self._fields_ordered) def __getitem__(self, name): """Dictionary-style field access, return a field's value if present. @@ -264,7 +276,7 @@ class BaseDocument(object): for name, field in self._fields.items()] if self._dynamic: fields += [(field, self._data.get(name)) - for name, field in self._dynamic_fields.items()] + for name, field in self._dynamic_fields.items()] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") @@ -273,7 +285,7 @@ class BaseDocument(object): if value is not None: try: if isinstance(field, (EmbeddedDocumentField, - GenericEmbeddedDocumentField)): + GenericEmbeddedDocumentField)): field._validate(value, clean=clean) else: field._validate(value) @@ -330,7 +342,7 @@ class BaseDocument(object): # Convert lists / values so we can watch for any changes on them if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): + not isinstance(value, BaseList)): value = BaseList(value, self, name) elif isinstance(value, dict) and not isinstance(value, BaseDict): value = BaseDict(value, self, name) @@ -344,9 +356,25 @@ class BaseDocument(object): return key = self._db_field_map.get(key, key) if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): + key not in self._changed_fields): self._changed_fields.append(key) + def _clear_changed_fields(self): + self._changed_fields = [] + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + for field_name, field in self._fields.iteritems(): + if (isinstance(field, ComplexBaseField) and + isinstance(field.field, EmbeddedDocumentField)): + field_value = getattr(self, field_name, None) + if field_value: + for idx in (field_value if isinstance(field_value, dict) + else xrange(len(field_value))): + field_value[idx]._clear_changed_fields() + elif isinstance(field, EmbeddedDocumentField): + field_value = getattr(self, field_name, None) + if field_value: + field_value._clear_changed_fields() + def _get_changed_fields(self, key='', inspected=None): """Returns a list of all fields that have explicitly been changed. """ @@ -418,7 +446,7 @@ class BaseDocument(object): for p in parts: if isinstance(d, DBRef): break - elif p.isdigit(): + elif isinstance(d, list) and p.isdigit(): d = d[int(p)] elif hasattr(d, 'get'): d = d.get(p) @@ -449,7 +477,7 @@ class BaseDocument(object): parts = path.split('.') db_field_name = parts.pop() for p in parts: - if p.isdigit(): + if isinstance(d, list) and p.isdigit(): d = d[int(p)] elif (hasattr(d, '__getattribute__') and not isinstance(d, dict)): @@ -514,7 +542,7 @@ class BaseDocument(object): value = data[field.db_field] try: data[field_name] = (value if value is None - else field.to_python(value)) + else field.to_python(value)) if field_name != field.db_field: del data[field.db_field] except (AttributeError, ValueError), e: @@ -548,14 +576,14 @@ class BaseDocument(object): geo_indices = cls._geo_indices() unique_indices = cls._unique_with_indexes() index_specs = [cls._build_index_spec(spec) - for spec in meta_indexes] + for spec in meta_indexes] def merge_index_specs(index_specs, indices): if not indices: return index_specs spec_fields = [v['fields'] - for k, v in enumerate(index_specs)] + for k, v in enumerate(index_specs)] # Merge unqiue_indexes with existing specs for k, v in enumerate(indices): if v['fields'] in spec_fields: @@ -727,7 +755,7 @@ class BaseDocument(object): field = DynamicField(db_field=field_name) else: raise LookUpError('Cannot resolve field "%s"' - % field_name) + % field_name) else: ReferenceField = _import_class('ReferenceField') GenericReferenceField = _import_class('GenericReferenceField') @@ -744,7 +772,7 @@ class BaseDocument(object): continue elif not new_field: raise LookUpError('Cannot resolve field "%s"' - % field_name) + % field_name) field = new_field # update field to the new field type fields.append(field) return fields diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 8d2ee876..6ebba362 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -81,8 +81,12 @@ class BaseField(object): def __set__(self, instance, value): """Descriptor for assigning a value to a field in a document. """ - instance._data[self.name] = value - if instance._initialised: + changed = False + if (self.name not in instance._data or + instance._data[self.name] != value): + changed = True + instance._data[self.name] = value + if changed and instance._initialised: instance._mark_as_changed(self.name) def error(self, message="", errors=None, field_name=None): diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 2b63bfa8..a53744db 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -78,7 +78,7 @@ class DocumentMetaclass(type): # Count names to ensure no db_field redefinitions field_names[attr_value.db_field] = field_names.get( - attr_value.db_field, 0) + 1 + attr_value.db_field, 0) + 1 # Ensure no duplicate db_fields duplicate_db_fields = [k for k, v in field_names.items() if v > 1] @@ -90,9 +90,12 @@ class DocumentMetaclass(type): # Set _fields and db_field maps attrs['_fields'] = doc_fields attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) - for k, v in doc_fields.iteritems()]) + for k, v in doc_fields.iteritems()]) + attrs['_fields_ordered'] = tuple(i[1] for i in sorted( + (v.creation_counter, v.name) + for v in doc_fields.itervalues())) attrs['_reverse_db_field_map'] = dict( - (v, k) for k, v in attrs['_db_field_map'].iteritems()) + (v, k) for k, v in attrs['_db_field_map'].iteritems()) # # Set document hierarchy @@ -101,7 +104,7 @@ class DocumentMetaclass(type): class_name = [name] for base in flattened_bases: if (not getattr(base, '_is_base_cls', True) and - not getattr(base, '_meta', {}).get('abstract', True)): + not getattr(base, '_meta', {}).get('abstract', True)): # Collate heirarchy for _cls and _subclasses class_name.append(base.__name__) @@ -109,11 +112,11 @@ class DocumentMetaclass(type): # Warn if allow_inheritance isn't set and prevent # inheritance of classes where inheritance is set to False allow_inheritance = base._meta.get('allow_inheritance', - ALLOW_INHERITANCE) - if (allow_inheritance != True and - not base._meta.get('abstract')): + ALLOW_INHERITANCE) + if (allow_inheritance is not True and + not base._meta.get('abstract')): raise ValueError('Document %s may not be subclassed' % - base.__name__) + base.__name__) # Get superclasses from last base superclass document_bases = [b for b in flattened_bases diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 25d46b46..1e220d47 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -33,7 +33,7 @@ class DeReference(object): self.max_depth = max_depth doc_type = None - if instance and instance._fields: + if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)): doc_type = instance._fields.get(name) if hasattr(doc_type, 'field'): doc_type = doc_type.field @@ -84,7 +84,7 @@ class DeReference(object): # Recursively find dbreferences depth += 1 for k, item in iterator: - if hasattr(item, '_fields'): + if isinstance(item, Document): for field_name, field in item._fields.iteritems(): v = item._data.get(field_name, None) if isinstance(v, (DBRef)): @@ -174,6 +174,7 @@ class DeReference(object): if not hasattr(items, 'items'): is_list = True + as_tuple = isinstance(items, tuple) iterator = enumerate(items) data = [] else: @@ -190,7 +191,7 @@ class DeReference(object): if k in self.object_map and not is_list: data[k] = self.object_map[k] - elif hasattr(v, '_fields'): + elif isinstance(v, Document): for field_name, field in v._fields.iteritems(): v = data[k]._data.get(field_name, None) if isinstance(v, (DBRef)): @@ -208,7 +209,7 @@ class DeReference(object): if instance and name: if is_list: - return BaseList(data, instance, name) + return tuple(data) if as_tuple else BaseList(data, instance, name) return BaseDict(data, instance, name) depth += 1 return data diff --git a/mongoengine/django/sessions.py b/mongoengine/django/sessions.py index 1c9288ed..0d199a6c 100644 --- a/mongoengine/django/sessions.py +++ b/mongoengine/django/sessions.py @@ -32,9 +32,17 @@ class MongoSession(Document): else fields.DictField() expire_date = fields.DateTimeField() - meta = {'collection': MONGOENGINE_SESSION_COLLECTION, - 'db_alias': MONGOENGINE_SESSION_DB_ALIAS, - 'allow_inheritance': False} + meta = { + 'collection': MONGOENGINE_SESSION_COLLECTION, + 'db_alias': MONGOENGINE_SESSION_DB_ALIAS, + 'allow_inheritance': False, + 'indexes': [ + { + 'fields': ['expire_date'], + 'expireAfterSeconds': settings.SESSION_COOKIE_AGE + } + ] + } def get_decoded(self): return SessionStore().decode(self.session_data) diff --git a/mongoengine/document.py b/mongoengine/document.py index 525d9644..9057075e 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -160,7 +160,7 @@ class Document(BaseDocument): def save(self, safe=True, force_insert=False, validate=True, clean=True, write_options=None, cascade=None, cascade_kwargs=None, - _refs=None): + _refs=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -278,7 +278,7 @@ class Document(BaseDocument): if id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) - self._changed_fields = [] + self._clear_changed_fields() self._created = False signals.post_save.send(self.__class__, document=self, created=created) return self diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 11e9d3f3..690e7ace 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -27,7 +27,7 @@ except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', +__all__ = ['StringField', 'IntField', 'LongField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', 'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField', @@ -143,7 +143,7 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -153,7 +153,7 @@ class EmailField(StringField): class IntField(BaseField): - """An integer field. + """An 32-bit integer field. """ def __init__(self, min_value=None, max_value=None, **kwargs): @@ -186,6 +186,40 @@ class IntField(BaseField): return int(value) +class LongField(BaseField): + """An 64-bit integer field. + """ + + def __init__(self, min_value=None, max_value=None, **kwargs): + self.min_value, self.max_value = min_value, max_value + super(LongField, self).__init__(**kwargs) + + def to_python(self, value): + try: + value = long(value) + except ValueError: + pass + return value + + def validate(self, value): + try: + value = long(value) + except: + self.error('%s could not be converted to long' % value) + + if self.min_value is not None and value < self.min_value: + self.error('Long value is too small') + + if self.max_value is not None and value > self.max_value: + self.error('Long value is too large') + + def prepare_query_value(self, op, value): + if value is None: + return value + + return long(value) + + class FloatField(BaseField): """An floating point number field. """ diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index ba6134fc..c299190f 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -609,8 +609,11 @@ class QuerySet(object): .. versionchanged:: 0.6 - Improved db_field refrence handling """ queryset = self.clone() - return queryset._dereference(queryset._cursor.distinct(field), 1, - name=field, instance=queryset._document) + try: + field = self._fields_to_dbfields([field]).pop() + finally: + return self._dereference(queryset._cursor.distinct(field), 1, + name=field, instance=self._document) def only(self, *fields): """Load only a subset of this document's fields. :: @@ -696,7 +699,7 @@ class QuerySet(object): prefixed with **+** or **-** to determine the ordering direction """ queryset = self.clone() - queryset._ordering = self._get_order_by(keys) + queryset._ordering = queryset._get_order_by(keys) return queryset def explain(self, format=False): diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 8932a54f..95d11e8f 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -32,7 +32,7 @@ class SimplificationVisitor(QNodeVisitor): if combination.operation == combination.AND: # The simplification only applies to 'simple' queries if all(isinstance(node, Q) for node in combination.children): - queries = [node.query for node in combination.children] + queries = [n.query for n in combination.children] return Q(**self._query_conjunction(queries)) return combination diff --git a/tests/document/instance.py b/tests/document/instance.py index 9a4149ff..07991c1b 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -827,20 +827,20 @@ class InstanceTest(unittest.TestCase): float_field = FloatField(default=1.1) boolean_field = BooleanField(default=True) datetime_field = DateTimeField(default=datetime.now) - embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, - default=lambda: EmbeddedDoc()) + embedded_document_field = EmbeddedDocumentField( + EmbeddedDoc, default=lambda: EmbeddedDoc()) list_field = ListField(default=lambda: [1, 2, 3]) dict_field = DictField(default=lambda: {"hello": "world"}) objectid_field = ObjectIdField(default=bson.ObjectId) reference_field = ReferenceField(Simple, default=lambda: - Simple().save()) + Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) generic_reference_field = GenericReferenceField( - default=lambda: Simple().save()) + default=lambda: Simple().save()) sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) email_field = EmailField(default="ross@example.com") @@ -848,7 +848,7 @@ class InstanceTest(unittest.TestCase): sequence_field = SequenceField() uuid_field = UUIDField(default=uuid.uuid4) generic_embedded_document_field = GenericEmbeddedDocumentField( - default=lambda: EmbeddedDoc()) + default=lambda: EmbeddedDoc()) Simple.drop_collection() Doc.drop_collection() @@ -1127,20 +1127,20 @@ class InstanceTest(unittest.TestCase): u3 = User(username="hmarr") u3.save() - p1 = Page(comments = [Comment(user=u1, comment="Its very good"), - Comment(user=u2, comment="Hello world"), - Comment(user=u3, comment="Ping Pong"), - Comment(user=u1, comment="I like a beer")]) + p1 = Page(comments=[Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world"), + Comment(user=u3, comment="Ping Pong"), + Comment(user=u1, comment="I like a beer")]) p1.save() - p2 = Page(comments = [Comment(user=u1, comment="Its very good"), - Comment(user=u2, comment="Hello world")]) + p2 = Page(comments=[Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world")]) p2.save() - p3 = Page(comments = [Comment(user=u3, comment="Its very good")]) + p3 = Page(comments=[Comment(user=u3, comment="Its very good")]) p3.save() - p4 = Page(comments = [Comment(user=u2, comment="Heavy Metal song")]) + p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")]) p4.save() self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) @@ -1183,7 +1183,6 @@ class InstanceTest(unittest.TestCase): class Site(Document): page = EmbeddedDocumentField(Page) - Site.drop_collection() site = Site(page=Page(log_message="Warning: Dummy message")) site.save() @@ -1328,7 +1327,8 @@ class InstanceTest(unittest.TestCase): occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) def raise_invalid_document(): - Word._from_son({'stem': [1,2,3], 'forms': 1, 'count': 'one', 'occurs': {"hello": None}}) + Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one', + 'occurs': {"hello": None}}) self.assertRaises(InvalidDocumentError, raise_invalid_document) @@ -1350,7 +1350,7 @@ class InstanceTest(unittest.TestCase): reviewer = self.Person(name='Re Viewer') reviewer.save() - post = BlogPost(content = 'Watched some TV') + post = BlogPost(content='Watched some TV') post.author = author post.reviewer = reviewer post.save() @@ -1432,7 +1432,6 @@ class InstanceTest(unittest.TestCase): author.delete() self.assertEqual(len(BlogPost.objects), 0) - def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): ''' ensure the pre_delete signal is triggered upon a cascading deletion setup a blog post with content, an author and editor @@ -1627,7 +1626,7 @@ class InstanceTest(unittest.TestCase): u1 = User.objects.create() u2 = User.objects.create() u3 = User.objects.create() - u4 = User() # New object + u4 = User() # New object b1 = BlogPost.objects.create() b2 = BlogPost.objects.create() @@ -1638,9 +1637,9 @@ class InstanceTest(unittest.TestCase): self.assertTrue(u1 in all_user_list) self.assertTrue(u2 in all_user_list) self.assertTrue(u3 in all_user_list) - self.assertFalse(u4 in all_user_list) # New object - self.assertFalse(b1 in all_user_list) # Other object - self.assertFalse(b2 in all_user_list) # Other object + self.assertFalse(u4 in all_user_list) # New object + self.assertFalse(b1 in all_user_list) # Other object + self.assertFalse(b2 in all_user_list) # Other object # in Dict all_user_dic = {} @@ -1650,9 +1649,9 @@ class InstanceTest(unittest.TestCase): self.assertEqual(all_user_dic.get(u1, False), "OK") self.assertEqual(all_user_dic.get(u2, False), "OK") self.assertEqual(all_user_dic.get(u3, False), "OK") - self.assertEqual(all_user_dic.get(u4, False), False) # New object - self.assertEqual(all_user_dic.get(b1, False), False) # Other object - self.assertEqual(all_user_dic.get(b2, False), False) # Other object + self.assertEqual(all_user_dic.get(u4, False), False) # New object + self.assertEqual(all_user_dic.get(b1, False), False) # Other object + self.assertEqual(all_user_dic.get(b2, False), False) # Other object # in Set all_user_set = set(User.objects.all()) @@ -1730,7 +1729,6 @@ class InstanceTest(unittest.TestCase): self.assertEqual(Doc.objects(archived=False).count(), 1) - def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs""" class Doc(DynamicDocument): @@ -1852,9 +1850,9 @@ class InstanceTest(unittest.TestCase): self.assertEquals('testdb-2', B._meta.get('db_alias')) self.assertEquals('mongoenginetest', - A._get_collection().database.name) + A._get_collection().database.name) self.assertEquals('mongoenginetest2', - B._get_collection().database.name) + B._get_collection().database.name) def test_db_alias_propagates(self): """db_alias propagates? @@ -1920,21 +1918,21 @@ class InstanceTest(unittest.TestCase): # Checks self.assertEqual(",".join([str(b) for b in Book.objects.all()]), - "1,2,3,4,5,6,7,8,9") + "1,2,3,4,5,6,7,8,9") # bob related books self.assertEqual(",".join([str(b) for b in Book.objects.filter( - Q(extra__a=bob) | - Q(author=bob) | - Q(extra__b=bob))]), - "1,2,3,4") + Q(extra__a=bob) | + Q(author=bob) | + Q(extra__b=bob))]), + "1,2,3,4") # Susan & Karl related books self.assertEqual(",".join([str(b) for b in Book.objects.filter( - Q(extra__a__all=[karl, susan]) | - Q(author__all=[karl, susan]) | - Q(extra__b__all=[ - karl.to_dbref(), susan.to_dbref()])) - ]), "1") + Q(extra__a__all=[karl, susan]) | + Q(author__all=[karl, susan]) | + Q(extra__b__all=[ + karl.to_dbref(), susan.to_dbref()])) + ]), "1") # $Where self.assertEqual(u",".join([str(b) for b in Book.objects.filter( @@ -1943,8 +1941,8 @@ class InstanceTest(unittest.TestCase): function(){ return this.name == '1' || this.name == '2';}""" - } - )]), "1,2") + })]), + "1,2") def test_switch_db_instance(self): register_connection('testdb-1', 'mongoenginetest2') @@ -2020,7 +2018,6 @@ class InstanceTest(unittest.TestCase): self.assertEqual("Bar", user._data["foo"]) self.assertEqual([1, 2, 3], user._data["data"]) - def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): @@ -2109,8 +2106,8 @@ class InstanceTest(unittest.TestCase): docs = ListField(EmbeddedDocumentField(Embedded)) classic_doc = Doc(doc_name="my doc", docs=[ - Embedded(name="embedded doc1"), - Embedded(name="embedded doc2")]) + Embedded(name="embedded doc1"), + Embedded(name="embedded doc2")]) dict_doc = Doc(**{"doc_name": "my doc", "docs": [{"name": "embedded doc1"}, {"name": "embedded doc2"}]}) @@ -2118,5 +2115,82 @@ class InstanceTest(unittest.TestCase): self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc._data, dict_doc._data) + def test_positional_creation(self): + """Ensure that document may be created using positional arguments. + """ + person = self.Person("Test User", 42) + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 42) + + def test_mixed_creation(self): + """Ensure that document may be created using mixed arguments. + """ + person = self.Person("Test User", age=42) + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 42) + + def test_bad_mixed_creation(self): + """Ensure that document gives correct error when duplicating arguments + """ + def construct_bad_instance(): + return self.Person("Test User", 42, name="Bad User") + + self.assertRaises(TypeError, construct_bad_instance) + + def test_data_contains_id_field(self): + """Ensure that asking for _data returns 'id' + """ + class Person(Document): + name = StringField() + + Person.drop_collection() + Person(name="Harry Potter").save() + + person = Person.objects.first() + self.assertTrue('id' in person._data.keys()) + self.assertEqual(person._data.get('id'), person.id) + + def test_complex_nesting_document_and_embedded_document(self): + + class Macro(EmbeddedDocument): + value = DynamicField(default="UNDEFINED") + + class Parameter(EmbeddedDocument): + macros = MapField(EmbeddedDocumentField(Macro)) + + def expand(self): + self.macros["test"] = Macro() + + class Node(Document): + parameters = MapField(EmbeddedDocumentField(Parameter)) + + def expand(self): + self.flattened_parameter = {} + for parameter_name, parameter in self.parameters.iteritems(): + parameter.expand() + + class System(Document): + name = StringField(required=True) + nodes = MapField(ReferenceField(Node, dbref=False)) + + def save(self, *args, **kwargs): + for node_name, node in self.nodes.iteritems(): + node.expand() + node.save(*args, **kwargs) + super(System, self).save(*args, **kwargs) + + System.drop_collection() + Node.drop_collection() + + system = System(name="system") + system.nodes["node"] = Node() + system.save() + system.nodes["node"].parameters["param"] = Parameter() + system.save() + + system = System.objects.first() + self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) + + if __name__ == '__main__': unittest.main() diff --git a/tests/document/validation.py b/tests/document/validation.py index 24ffed65..d3f3fd70 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -130,8 +130,8 @@ class ValidatorErrorTest(unittest.TestCase): doc = Doc.objects.first() keys = doc._data.keys() self.assertEqual(2, len(keys)) - self.assertTrue('id' in keys) self.assertTrue('e' in keys) + self.assertTrue('id' in keys) doc.e.val = "OK" try: diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 124c9538..9a7b82f7 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -145,6 +145,17 @@ class FieldTest(unittest.TestCase): self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) + def test_long_ne_operator(self): + class TestDocument(Document): + long_fld = LongField() + + TestDocument.drop_collection() + + TestDocument(long_fld=None).save() + TestDocument(long_fld=1).save() + + self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) + def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to string fields. """ @@ -218,6 +229,23 @@ class FieldTest(unittest.TestCase): person.age = 'ten' self.assertRaises(ValidationError, person.validate) + def test_long_validation(self): + """Ensure that invalid values cannot be assigned to long fields. + """ + class TestDocument(Document): + value = LongField(min_value=0, max_value=110) + + doc = TestDocument() + doc.value = 50 + doc.validate() + + doc.value = -1 + self.assertRaises(ValidationError, doc.validate) + doc.age = 120 + self.assertRaises(ValidationError, doc.validate) + doc.age = 'ten' + self.assertRaises(ValidationError, doc.validate) + def test_float_validation(self): """Ensure that invalid values cannot be assigned to float fields. """ @@ -971,6 +999,24 @@ class FieldTest(unittest.TestCase): doc = self.db.test.find_one() self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) + def test_mapfield_numerical_index(self): + """Ensure that MapField accept numeric strings as indexes.""" + class Embedded(EmbeddedDocument): + name = StringField() + + class Test(Document): + my_map = MapField(EmbeddedDocumentField(Embedded)) + + Test.drop_collection() + + test = Test() + test.my_map['1'] = Embedded(name='test') + test.save() + test.my_map['1'].name = 'test updated' + test.save() + + Test.drop_collection() + def test_map_field_lookup(self): """Ensure MapField lookups succeed on Fields without a lookup method""" @@ -2399,11 +2445,26 @@ class FieldTest(unittest.TestCase): self.assertTrue(1 in error_dict['comments']) self.assertTrue('content' in error_dict['comments'][1]) self.assertEqual(error_dict['comments'][1]['content'], - u'Field is required') + u'Field is required') post.comments[1].content = 'here we go' post.validate() + def test_email_field(self): + class User(Document): + email = EmailField() + + user = User(email="ross@example.com") + self.assertTrue(user.validate() is None) + + user = User(email=("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S" + "ucictfqpdkK9iS1zeFw8sg7s7cwAF7suIfUfeyueLpfosjn3" + "aJIazqqWkm7.net")) + self.assertTrue(user.validate() is None) + + user = User(email='me@localhost') + self.assertRaises(ValidationError, user.validate) + def test_email_field_honors_regex(self): class User(Document): email = EmailField(regex=r'\w+@example.com') diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 3e53c456..da0e89ab 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -931,6 +931,11 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() Blog.drop_collection() + def assertSequence(self, qs, expected): + self.assertEqual(len(qs), len(expected)) + for i in range(len(qs)): + self.assertEqual(qs[i], expected[i]) + def test_ordering(self): """Ensure default ordering is applied and can be overridden. """ @@ -957,14 +962,13 @@ class QuerySetTest(unittest.TestCase): # get the "first" BlogPost using default ordering # from BlogPost.meta.ordering - latest_post = BlogPost.objects.first() - self.assertEqual(latest_post.title, "Blog Post #3") + expected = [blog_post_3, blog_post_2, blog_post_1] + self.assertSequence(BlogPost.objects.all(), expected) # override default ordering, order BlogPosts by "published_date" - first_post = BlogPost.objects.order_by("+published_date").first() - self.assertEqual(first_post.title, "Blog Post #1") - - BlogPost.drop_collection() + qs = BlogPost.objects.order_by("+published_date") + expected = [blog_post_1, blog_post_2, blog_post_3] + self.assertSequence(qs, expected) def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. @@ -1505,8 +1509,8 @@ class QuerySetTest(unittest.TestCase): def test_order_by(self): """Ensure that QuerySets may be ordered. """ - self.Person(name="User A", age=20).save() self.Person(name="User B", age=40).save() + self.Person(name="User A", age=20).save() self.Person(name="User C", age=30).save() names = [p.name for p in self.Person.objects.order_by('-age')] @@ -1521,11 +1525,67 @@ class QuerySetTest(unittest.TestCase): ages = [p.age for p in self.Person.objects.order_by('-name')] self.assertEqual(ages, [30, 40, 20]) + def test_order_by_optional(self): + class BlogPost(Document): + title = StringField() + published_date = DateTimeField(required=False) + + BlogPost.drop_collection() + + blog_post_3 = BlogPost(title="Blog Post #3", + published_date=datetime(2010, 1, 6, 0, 0 ,0)) + blog_post_2 = BlogPost(title="Blog Post #2", + published_date=datetime(2010, 1, 5, 0, 0 ,0)) + blog_post_4 = BlogPost(title="Blog Post #4", + published_date=datetime(2010, 1, 7, 0, 0 ,0)) + blog_post_1 = BlogPost(title="Blog Post #1", published_date=None) + + blog_post_3.save() + blog_post_1.save() + blog_post_4.save() + blog_post_2.save() + + expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4] + self.assertSequence(BlogPost.objects.order_by('published_date'), + expected) + self.assertSequence(BlogPost.objects.order_by('+published_date'), + expected) + + expected.reverse() + self.assertSequence(BlogPost.objects.order_by('-published_date'), + expected) + + def test_order_by_list(self): + class BlogPost(Document): + title = StringField() + published_date = DateTimeField(required=False) + + BlogPost.drop_collection() + + blog_post_1 = BlogPost(title="A", + published_date=datetime(2010, 1, 6, 0, 0 ,0)) + blog_post_2 = BlogPost(title="B", + published_date=datetime(2010, 1, 6, 0, 0 ,0)) + blog_post_3 = BlogPost(title="C", + published_date=datetime(2010, 1, 7, 0, 0 ,0)) + + blog_post_2.save() + blog_post_3.save() + blog_post_1.save() + + qs = BlogPost.objects.order_by('published_date', 'title') + expected = [blog_post_1, blog_post_2, blog_post_3] + self.assertSequence(qs, expected) + + qs = BlogPost.objects.order_by('-published_date', '-title') + expected.reverse() + self.assertSequence(qs, expected) + def test_order_by_chaining(self): """Ensure that an order_by query chains properly and allows .only() """ - self.Person(name="User A", age=20).save() self.Person(name="User B", age=40).save() + self.Person(name="User A", age=20).save() self.Person(name="User C", age=30).save() only_age = self.Person.objects.order_by('-age').only('age') @@ -1537,6 +1597,21 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(names, [None, None, None]) self.assertEqual(ages, [40, 30, 20]) + qs = self.Person.objects.all().order_by('-age') + qs = qs.limit(10) + ages = [p.age for p in qs] + self.assertEqual(ages, [40, 30, 20]) + + qs = self.Person.objects.all().limit(10) + qs = qs.order_by('-age') + ages = [p.age for p in qs] + self.assertEqual(ages, [40, 30, 20]) + + qs = self.Person.objects.all().skip(0) + qs = qs.order_by('-age') + ages = [p.age for p in qs] + self.assertEqual(ages, [40, 30, 20]) + def test_confirm_order_by_reference_wont_work(self): """Ordering by reference is not possible. Use map / reduce.. or denormalise""" @@ -2065,6 +2140,25 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Foo.objects.distinct("bar"), [bar]) + def test_distinct_handles_db_field(self): + """Ensure that distinct resolves field name to db_field as expected. + """ + class Product(Document): + product_id = IntField(db_field='pid') + + Product.drop_collection() + + Product(product_id=1).save() + Product(product_id=2).save() + Product(product_id=1).save() + + self.assertEqual(set(Product.objects.distinct('product_id')), + set([1, 2])) + self.assertEqual(set(Product.objects.distinct('pid')), + set([1, 2])) + + Product.drop_collection() + def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. """ diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 4198f3c4..f8925238 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -7,7 +7,7 @@ from bson import DBRef, ObjectId from mongoengine import * from mongoengine.connection import get_db -from mongoengine.context_managers import query_counter, no_dereference +from mongoengine.context_managers import query_counter class FieldTest(unittest.TestCase): @@ -212,8 +212,9 @@ class FieldTest(unittest.TestCase): # Migrate the data for g in Group.objects(): - g.author = g.author - g.members = g.members + # Explicitly mark as changed so resets + g._mark_as_changed('author') + g._mark_as_changed('members') g.save() group = Group.objects.first() @@ -1120,6 +1121,37 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) + def test_tuples_as_tuples(self): + """ + Ensure that tuples remain tuples when they are + inside a ComplexBaseField + """ + from mongoengine.base import BaseField + + class EnumField(BaseField): + + def __init__(self, **kwargs): + super(EnumField, self).__init__(**kwargs) + + def to_mongo(self, value): + return value + + def to_python(self, value): + return tuple(value) + + class TestDoc(Document): + items = ListField(EnumField()) + + TestDoc.drop_collection() + tuples = [(100, 'Testing')] + doc = TestDoc() + doc.items = tuples + doc.save() + x = TestDoc.objects().get() + self.assertTrue(x is not None) + self.assertTrue(len(x.items) == 1) + self.assertTrue(tuple(x.items[0]) in tuples) + self.assertTrue(x.items[0] in tuples) if __name__ == '__main__': unittest.main()