diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 6c8572c6..ff4978d2 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -14,10 +14,14 @@ REPR_OUTPUT_SIZE = 20 class InvalidQueryError(Exception): pass - class OperationError(Exception): pass +class MultipleObjectsReturned(Exception): + pass + +class DoesNotExist(Exception): + pass class Q(object): @@ -292,6 +296,21 @@ class QuerySet(object): return mongo_query + def get(self, *q_objs, **query): + """Retrieve the the matching object raising + 'MultipleObjectsReturned' or 'DoesNotExist' exceptions + if multiple or no results are found. + """ + self.__call__(*q_objs, **query) + count = self.count() + if count == 1: + return self[0] + elif count > 1: + raise MultipleObjectsReturned + else: + raise DoesNotExist + + def first(self): """Retrieve the first object matching the query. """ diff --git a/tests/queryset.py b/tests/queryset.py index 1e06615a..3e73da86 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2,7 +2,7 @@ import unittest import pymongo from datetime import datetime -from mongoengine.queryset import QuerySet +from mongoengine.queryset import QuerySet, MultipleObjectsReturned, DoesNotExist from mongoengine import * @@ -135,6 +135,27 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.with_id(person1.id) self.assertEqual(person.name, "User A") + def test_find_only_one(self): + """Ensure that a query using find_one returns a valid result. + """ + # Try retrieving when no objects exists + self.assertRaises(DoesNotExist, self.Person.objects.get) + + person1 = self.Person(name="User A", age=20) + person1.save() + person2 = self.Person(name="User B", age=30) + person2.save() + + # Retrieve the first person from the database + self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) + + # Use a query to filter the people found to just person2 + person = self.Person.objects.get(age=30) + self.assertEqual(person.name, "User B") + + person = self.Person.objects.get(age__lt=30) + self.assertEqual(person.name, "User A") + def test_filter_chaining(self): """Ensure filters can be chained together. """