From fa39789bac7e2e76280f17832e517a8cd378f48d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Mon, 18 Jul 2011 12:44:28 -0300 Subject: [PATCH] added SequenceField --- mongoengine/fields.py | 33 ++++++++++++++++++++++++++++++++- tests/fields.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 88040115..a89ec3e4 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -20,7 +20,8 @@ __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', 'DecimalField', 'ComplexDateTimeField', 'URLField', 'GenericReferenceField', 'FileField', 'BinaryField', - 'SortedListField', 'EmailField', 'GeoPointField'] + 'SortedListField', 'EmailField', 'GeoPointField', + 'SequenceField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -876,3 +877,33 @@ class GeoPointField(BaseField): if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): raise ValidationError('Both values in point must be float or int.') + + +class SequenceField(IntField): + def generate_new_value(self): + """ + Generate and Increment counter + """ + sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(), + self.name) + collection = _get_db()['mongoengine.counters'] + counter = collection.find_and_modify(query={"_id": sequence_id}, + update={"$inc" : {"next": 1}}, + new=True, + upsert=True) + return counter['next'] + + def __get__(self, instance, owner): + if not instance._data: + return + + if instance is None: + return self + + value = instance._data.get(self.name) + + if not value: + value = self.generate_new_value() + instance._data[self.name] = value + + return value diff --git a/tests/fields.py b/tests/fields.py index 7a752998..2ceda7df 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1380,5 +1380,42 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data2, {}) + def test_sequence_field(self): + class Person(Document): + id = SequenceField(primary_key=True) + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + p = Person() + p.save() + + p = Person.objects.first() + self.assertEqual(p.id, 1) + + def test_multiple_sequence_field(self): + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in xrange(10): + p = Person(name="Person %s" % x) + p.save() + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + for x in xrange(10): + p = Person(name="Person %s" % x) + p.save() + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 21)) + + counter = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(counter['next'], 20) + if __name__ == '__main__': unittest.main()