diff --git a/docs/changelog.rst b/docs/changelog.rst index c9535a07..92770afb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.7.8 ================ +- Fix sequence fields in embedded documents (MongoEngine/mongoengine#166) - Fix query chaining with .order_by() (MongoEngine/mongoengine#176) - Added optional encoding and collection config for Django sessions (MongoEngine/mongoengine#180, MongoEngine/mongoengine#181, MongoEngine/mongoengine#183) - Fixed EmailField so can add extra validation (MongoEngine/mongoengine#173, MongoEngine/mongoengine#174, MongoEngine/mongoengine#187) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 9c0bedec..3f413b25 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1338,7 +1338,7 @@ class SequenceField(IntField): .. versionadded:: 0.5 """ - def __init__(self, collection_name=None, db_alias = None, sequence_name = None, *args, **kwargs): + def __init__(self, collection_name=None, db_alias=None, sequence_name=None, *args, **kwargs): self.collection_name = collection_name or 'mongoengine.counters' self.db_alias = db_alias or DEFAULT_CONNECTION_NAME self.sequence_name = sequence_name @@ -1348,7 +1348,7 @@ class SequenceField(IntField): """ Generate and Increment the counter """ - sequence_name = self.sequence_name or self.owner_document._get_collection_name() + sequence_name = self.get_sequence_name() sequence_id = "%s.%s" % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] counter = collection.find_and_modify(query={"_id": sequence_id}, @@ -1357,6 +1357,15 @@ class SequenceField(IntField): upsert=True) return counter['next'] + def get_sequence_name(self): + if self.sequence_name: + return self.sequence_name + owner = self.owner_document + if issubclass(owner, Document): + return owner._get_collection_name() + else: + return owner._class_name + def __get__(self, instance, owner): if instance is None: diff --git a/tests/test_fields.py b/tests/test_fields.py index 88e82f17..fdcc3080 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2175,6 +2175,28 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) self.assertEqual(c['next'], 10) + def test_embedded_sequence_field(self): + class Comment(EmbeddedDocument): + id = SequenceField() + content = StringField(required=True) + + class Post(Document): + title = StringField(required=True) + comments = ListField(EmbeddedDocumentField(Comment)) + + self.db['mongoengine.counters'].drop() + Post.drop_collection() + + Post(title="MongoEngine", + comments=[Comment(content="NoSQL Rocks"), + Comment(content="MongoEngine Rocks")]).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'Comment.id'}) + self.assertEqual(c['next'], 2) + post = Post.objects.first() + self.assertEqual(1, post.comments[0].id) + self.assertEqual(2, post.comments[1].id) + def test_generic_embedded_document(self): class Car(EmbeddedDocument): name = StringField()