diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index a49f9968..7ffb9976 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -762,14 +762,29 @@ class BaseQuerySet(object): distinct = self._dereference(queryset._cursor.distinct(field), 1, name=field, instance=self._document) - # We may need to cast to the correct type eg. - # ListField(EmbeddedDocumentField) - doc_field = getattr( - self._document._fields.get(field), "field", None) - instance = getattr(doc_field, "document_type", False) + doc_field = self._document._fields.get(field.split('.', 1)[0]) + instance = False + # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) EmbeddedDocumentField = _import_class('EmbeddedDocumentField') - GenericEmbeddedDocumentField = _import_class( - 'GenericEmbeddedDocumentField') + ListField = _import_class('ListField') + GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') + if isinstance(doc_field, ListField): + doc_field = getattr(doc_field, "field", doc_field) + if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): + instance = getattr(doc_field, "document_type", False) + # handle distinct on subdocuments + if '.' in field: + for field_part in field.split('.')[1:]: + # if looping on embedded document, get the document type instance + if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): + doc_field = instance + # now get the subdocument + doc_field = getattr(doc_field, field_part, doc_field) + # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) + if isinstance(doc_field, ListField): + doc_field = getattr(doc_field, "field", doc_field) + if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): + instance = getattr(doc_field, "document_type", False) if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): distinct = [instance(**doc) for doc in distinct] diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 7f3d5c48..03d2bdd9 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2981,6 +2981,46 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(authors, [mark_twain, john_tolkien]) + def test_distinct_ListField_EmbeddedDocumentField_EmbeddedDocumentField(self): + class Continent(EmbeddedDocument): + continent_name = StringField() + + class Country(EmbeddedDocument): + country_name = StringField() + continent = EmbeddedDocumentField(Continent) + + class Author(EmbeddedDocument): + name = StringField() + country = EmbeddedDocumentField(Country) + + class Book(Document): + title = StringField() + authors = ListField(EmbeddedDocumentField(Author)) + + Book.drop_collection() + + europe = Continent(continent_name='europe') + asia = Continent(continent_name='asia') + + scotland = Country(country_name="Scotland", continent=europe) + tibet = Country(country_name="Tibet", continent=asia) + + mark_twain = Author(name="Mark Twain", country=scotland) + john_tolkien = Author(name="John Ronald Reuel Tolkien", country=tibet) + + book = Book(title="Tom Sawyer", authors=[mark_twain]).save() + book = Book( + title="The Lord of the Rings", authors=[john_tolkien]).save() + book = Book( + title="The Stories", authors=[mark_twain, john_tolkien]).save() + country_list = Book.objects.distinct("authors.country") + + self.assertEqual(country_list, [scotland, tibet]) + + continent_list = Book.objects.distinct("authors.country.continent") + + self.assertEqual(continent_list, [europe, asia]) + def test_distinct_ListField_ReferenceField(self): class Bar(Document):