diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index b5a1f029..7b5d165b 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -59,6 +59,13 @@ you may still use :attr:`id` to access the primary key if you want:: >>> bob.id == bob.email == 'bob@example.com' True +You can also access the document's "primary key" using the :attr:`pk` field; in +is an alias to :attr:`id`:: + + >>> page = Page(title="Another Test Page") + >>> page.save() + >>> page.id == page.pk + .. note:: If you define your own primary key field, the field implicitly becomes required, so a :class:`ValidationError` will be thrown if you don't provide diff --git a/mongoengine/base.py b/mongoengine/base.py index 0cbd707d..addcd6bf 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -330,14 +330,17 @@ class BaseDocument(object): def __init__(self, **values): self._data = {} + # Assign default values to instance + for attr_name in self._fields.keys(): + # Use default value if present + value = getattr(self, attr_name, None) + setattr(self, attr_name, value) # Assign initial values to instance - for attr_name, attr_value in self._fields.items(): - if attr_name in values: + for attr_name in values.keys(): + try: setattr(self, attr_name, values.pop(attr_name)) - else: - # Use default value if present - value = getattr(self, attr_name, None) - setattr(self, attr_name, value) + except AttributeError: + pass def validate(self): """Ensure that all fields' values are valid and that required fields @@ -373,6 +376,16 @@ class BaseDocument(object): all_subclasses.update(subclass._get_subclasses()) return all_subclasses + @apply + def pk(): + """Primary key alias + """ + def fget(self): + return getattr(self, self._meta['id_field']) + def fset(self, value): + return setattr(self, self._meta['id_field'], value) + return property(fget, fset) + def __iter__(self): return iter(self._fields) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 48936e68..69a110fe 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -312,6 +312,9 @@ class QuerySet(object): for field_name in parts: if field is None: # Look up first field from the document + if field_name == 'pk': + # Deal with "primary key" alias + field_name = document._meta['id_field'] field = document._fields[field_name] else: # Look up subfield on the previous field diff --git a/tests/document.py b/tests/document.py index c2e0d6f9..81f492c5 100644 --- a/tests/document.py +++ b/tests/document.py @@ -355,12 +355,26 @@ class DocumentTest(unittest.TestCase): user_obj = User.objects.first() self.assertEqual(user_obj.id, 'test') + self.assertEqual(user_obj.pk, 'test') user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'test') self.assertTrue('username' not in user_son['_id']) User.drop_collection() + + user = User(pk='mongo', name='mongo user') + user.save() + + user_obj = User.objects.first() + self.assertEqual(user_obj.id, 'mongo') + self.assertEqual(user_obj.pk, 'mongo') + + user_son = User.objects._collection.find_one() + self.assertEqual(user_son['_id'], 'mongo') + self.assertTrue('username' not in user_son['_id']) + + User.drop_collection() def test_creation(self): """Ensure that document may be created using keyword arguments. @@ -479,6 +493,18 @@ class DocumentTest(unittest.TestCase): collection = self.db[self.Person._meta['collection']] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') + + def test_save_custom_pk(self): + """Ensure that a document may be saved with a custom _id using pk alias. + """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30, + pk='497ce96f395f2f052a494fd4') + person.save() + # Ensure that the object is in the database with the correct _id + collection = self.db[self.Person._meta['collection']] + person_obj = collection.find_one({'name': 'Test User'}) + self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') def test_save_list(self): """Ensure that a list field may be properly saved. diff --git a/tests/queryset.py b/tests/queryset.py index 2271c366..92d3e4c5 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1094,7 +1094,8 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() data = {'title': 'Post 1', 'comments': [Comment(content='test')]} - BlogPost(**data).save() + post = BlogPost(**data) + post.save() self.assertTrue('postTitle' in BlogPost.objects(title=data['title'])._query) @@ -1102,12 +1103,33 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects(title=data['title'])._query) self.assertEqual(len(BlogPost.objects(title=data['title'])), 1) + self.assertTrue('_id' in BlogPost.objects(pk=post.id)._query) + self.assertEqual(len(BlogPost.objects(pk=post.id)), 1) + self.assertTrue('postComments.commentContent' in BlogPost.objects(comments__content='test')._query) self.assertEqual(len(BlogPost.objects(comments__content='test')), 1) BlogPost.drop_collection() + def test_query_pk_field_name(self): + """Ensure that the correct "primary key" field name is used when querying + """ + class BlogPost(Document): + title = StringField(primary_key=True, db_field='postTitle') + + BlogPost.drop_collection() + + data = { 'title':'Post 1' } + post = BlogPost(**data) + post.save() + + self.assertTrue('_id' in BlogPost.objects(pk=data['title'])._query) + self.assertTrue('_id' in BlogPost.objects(title=data['title'])._query) + self.assertEqual(len(BlogPost.objects(pk=data['title'])), 1) + + BlogPost.drop_collection() + def test_query_value_conversion(self): """Ensure that query values are properly converted when necessary. """