From d79a9eee14d7453d3dd1805e2d17366bced2c23d Mon Sep 17 00:00:00 2001 From: Erik Friese Date: Sun, 27 Aug 2023 13:14:45 +0200 Subject: [PATCH] deserializing lists --- betterproto-extras/src/descriptor_pool.rs | 6 +++++- betterproto-extras/src/merging.rs | 13 +++++++------ betterproto-extras/src/py_any_extras.rs | 8 ++++++++ example.py | 17 +++++++++++++++-- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index f301bcb..7ffee0e 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -4,7 +4,7 @@ use crate::{ }; use prost_reflect::{ prost_types::{ - field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, + field_descriptor_proto::{Type, Label}, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, }, DescriptorPool, MessageDescriptor, }; @@ -65,6 +65,10 @@ fn create_cached_descriptor_in_pool<'py>( create_cached_descriptor_in_pool(instance, pool)?; } + if meta.is_list_field(field_name)? { + field.set_label(Label::Repeated); + } + field }); } diff --git a/betterproto-extras/src/merging.rs b/betterproto-extras/src/merging.rs index 1fbe712..6fe118f 100644 --- a/betterproto-extras/src/merging.rs +++ b/betterproto-extras/src/merging.rs @@ -3,7 +3,7 @@ use crate::{ py_any_extras::PyAnyExtras, }; use prost_reflect::{DynamicMessage, ReflectMessage, Value}; -use pyo3::{PyAny, ToPyObject, PyObject}; +use pyo3::{PyAny, PyObject, ToPyObject}; pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> { for field in msg.descriptor().fields() { @@ -19,11 +19,7 @@ pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> { Ok(()) } -fn map_field_value( - field_name: &str, - field_value: &Value, - proto_meta: &PyAny, -) -> Result { +fn map_field_value(field_name: &str, field_value: &Value, proto_meta: &PyAny) -> Result { let py = proto_meta.py(); match field_value { Value::Bool(x) => Ok(x.to_object(py)), @@ -40,6 +36,11 @@ fn map_field_value( merge_msg_into_pyobj(obj, msg)?; Ok(obj.to_object(py)) } + Value::List(ls) => Ok(ls + .iter() + .map(|x| map_field_value(field_name, x, proto_meta)) + .collect::>>()? + .to_object(py)), value => Err(Error::UnsupportedType(value.to_string())), } } diff --git a/betterproto-extras/src/py_any_extras.rs b/betterproto-extras/src/py_any_extras.rs index 31c7b52..f42611a 100644 --- a/betterproto-extras/src/py_any_extras.rs +++ b/betterproto-extras/src/py_any_extras.rs @@ -5,6 +5,7 @@ pub trait PyAnyExtras<'py> { fn qualified_class_name(&self) -> Result; fn get_proto_meta(&self) -> Result<&'py PyAny>; fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>; + fn is_list_field(&self, field_name: &str) -> Result; } impl<'py> PyAnyExtras<'py> for &'py PyAny { @@ -26,4 +27,11 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny { .call0()?; Ok(res) } + + fn is_list_field(&self, field_name: &str) -> Result { + let cls = self.getattr("default_gen")?.get_item(field_name)?; + let module = cls.getattr("__module__")?; + let name = cls.getattr("__name__")?; + Ok(module.to_string() == "builtins" && name.to_string() == "list") + } } diff --git a/example.py b/example.py index aea85b2..ca69ed8 100644 --- a/example.py +++ b/example.py @@ -3,20 +3,33 @@ import betterproto from dataclasses import dataclass +from typing import List + +@dataclass +class Baz(betterproto.Message): + a: str = betterproto.string_field(1) @dataclass class Foo(betterproto.Message): x: int = betterproto.int32_field(1) y: float = betterproto.double_field(2) + z: List[Baz] = betterproto.message_field(3) @dataclass class Bar(betterproto.Message): foo1: Foo = betterproto.message_field(1) foo2: Foo = betterproto.message_field(2) + packed: List[int] = betterproto.int64_field(3) # Serialization has not been changed yet. So nothing unusual here -buffer = bytes(Bar(foo1 = Foo(1, 2.34), foo2 = Foo(3, 4.56))) +buffer = bytes( + Bar( + foo1=Foo(1, 2.34), + foo2=Foo(3, 4.56, [Baz("Hi"), Baz("There")]), + packed=[5, 3, 1] + ) +) # Native deserialization happening here bar = Bar().parse(buffer) -print(bar) \ No newline at end of file +print(bar)