deserializing lists

This commit is contained in:
Erik Friese 2023-08-27 13:14:45 +02:00
parent d848d05710
commit d79a9eee14
4 changed files with 35 additions and 9 deletions

View File

@ -4,7 +4,7 @@ use crate::{
};
use prost_reflect::{
prost_types::{
field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
field_descriptor_proto::{Type, Label}, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
},
DescriptorPool, MessageDescriptor,
};
@ -65,6 +65,10 @@ fn create_cached_descriptor_in_pool<'py>(
create_cached_descriptor_in_pool(instance, pool)?;
}
if meta.is_list_field(field_name)? {
field.set_label(Label::Repeated);
}
field
});
}

View File

@ -3,7 +3,7 @@ use crate::{
py_any_extras::PyAnyExtras,
};
use prost_reflect::{DynamicMessage, ReflectMessage, Value};
use pyo3::{PyAny, ToPyObject, PyObject};
use pyo3::{PyAny, PyObject, ToPyObject};
pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> {
for field in msg.descriptor().fields() {
@ -19,11 +19,7 @@ pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> {
Ok(())
}
fn map_field_value(
field_name: &str,
field_value: &Value,
proto_meta: &PyAny,
) -> Result<PyObject> {
fn map_field_value(field_name: &str, field_value: &Value, proto_meta: &PyAny) -> Result<PyObject> {
let py = proto_meta.py();
match field_value {
Value::Bool(x) => Ok(x.to_object(py)),
@ -40,6 +36,11 @@ fn map_field_value(
merge_msg_into_pyobj(obj, msg)?;
Ok(obj.to_object(py))
}
Value::List(ls) => Ok(ls
.iter()
.map(|x| map_field_value(field_name, x, proto_meta))
.collect::<Result<Vec<PyObject>>>()?
.to_object(py)),
value => Err(Error::UnsupportedType(value.to_string())),
}
}

View File

@ -5,6 +5,7 @@ pub trait PyAnyExtras<'py> {
fn qualified_class_name(&self) -> Result<String>;
fn get_proto_meta(&self) -> Result<&'py PyAny>;
fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>;
fn is_list_field(&self, field_name: &str) -> Result<bool>;
}
impl<'py> PyAnyExtras<'py> for &'py PyAny {
@ -26,4 +27,11 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny {
.call0()?;
Ok(res)
}
fn is_list_field(&self, field_name: &str) -> Result<bool> {
let cls = self.getattr("default_gen")?.get_item(field_name)?;
let module = cls.getattr("__module__")?;
let name = cls.getattr("__name__")?;
Ok(module.to_string() == "builtins" && name.to_string() == "list")
}
}

View File

@ -3,20 +3,33 @@
import betterproto
from dataclasses import dataclass
from typing import List
@dataclass
class Baz(betterproto.Message):
a: str = betterproto.string_field(1)
@dataclass
class Foo(betterproto.Message):
x: int = betterproto.int32_field(1)
y: float = betterproto.double_field(2)
z: List[Baz] = betterproto.message_field(3)
@dataclass
class Bar(betterproto.Message):
foo1: Foo = betterproto.message_field(1)
foo2: Foo = betterproto.message_field(2)
packed: List[int] = betterproto.int64_field(3)
# Serialization has not been changed yet. So nothing unusual here
buffer = bytes(Bar(foo1 = Foo(1, 2.34), foo2 = Foo(3, 4.56)))
buffer = bytes(
Bar(
foo1=Foo(1, 2.34),
foo2=Foo(3, 4.56, [Baz("Hi"), Baz("There")]),
packed=[5, 3, 1]
)
)
# Native deserialization happening here
bar = Bar().parse(buffer)
print(bar)
print(bar)