Fix sequence fields in embedded documents (MongoEngine/mongoengine#166)

This commit is contained in:
Ross Lawley 2012-12-10 09:11:31 +00:00
parent 90d22c2a28
commit 9236f365fa
3 changed files with 34 additions and 2 deletions

View File

@ -5,6 +5,7 @@ Changelog
Changes in 0.7.8 Changes in 0.7.8
================ ================
- Fix sequence fields in embedded documents (MongoEngine/mongoengine#166)
- Fix query chaining with .order_by() (MongoEngine/mongoengine#176) - Fix query chaining with .order_by() (MongoEngine/mongoengine#176)
- Added optional encoding and collection config for Django sessions (MongoEngine/mongoengine#180, MongoEngine/mongoengine#181, MongoEngine/mongoengine#183) - Added optional encoding and collection config for Django sessions (MongoEngine/mongoengine#180, MongoEngine/mongoengine#181, MongoEngine/mongoengine#183)
- Fixed EmailField so can add extra validation (MongoEngine/mongoengine#173, MongoEngine/mongoengine#174, MongoEngine/mongoengine#187) - Fixed EmailField so can add extra validation (MongoEngine/mongoengine#173, MongoEngine/mongoengine#174, MongoEngine/mongoengine#187)

View File

@ -1348,7 +1348,7 @@ class SequenceField(IntField):
""" """
Generate and Increment the counter Generate and Increment the counter
""" """
sequence_name = self.sequence_name or self.owner_document._get_collection_name() sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name) sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name] collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_and_modify(query={"_id": sequence_id}, counter = collection.find_and_modify(query={"_id": sequence_id},
@ -1357,6 +1357,15 @@ class SequenceField(IntField):
upsert=True) upsert=True)
return counter['next'] return counter['next']
def get_sequence_name(self):
if self.sequence_name:
return self.sequence_name
owner = self.owner_document
if issubclass(owner, Document):
return owner._get_collection_name()
else:
return owner._class_name
def __get__(self, instance, owner): def __get__(self, instance, owner):
if instance is None: if instance is None:

View File

@ -2175,6 +2175,28 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'})
self.assertEqual(c['next'], 10) self.assertEqual(c['next'], 10)
def test_embedded_sequence_field(self):
class Comment(EmbeddedDocument):
id = SequenceField()
content = StringField(required=True)
class Post(Document):
title = StringField(required=True)
comments = ListField(EmbeddedDocumentField(Comment))
self.db['mongoengine.counters'].drop()
Post.drop_collection()
Post(title="MongoEngine",
comments=[Comment(content="NoSQL Rocks"),
Comment(content="MongoEngine Rocks")]).save()
c = self.db['mongoengine.counters'].find_one({'_id': 'Comment.id'})
self.assertEqual(c['next'], 2)
post = Post.objects.first()
self.assertEqual(1, post.comments[0].id)
self.assertEqual(2, post.comments[1].id)
def test_generic_embedded_document(self): def test_generic_embedded_document(self):
class Car(EmbeddedDocument): class Car(EmbeddedDocument):
name = StringField() name = StringField()