From a413d08fc1ebf8d065634ef506290eab85160ed4 Mon Sep 17 00:00:00 2001 From: Erik Friese Date: Wed, 30 Aug 2023 19:27:59 +0200 Subject: [PATCH] enum support --- betterproto-extras/src/descriptor_pool.rs | 79 ++++++++++++++++------- betterproto-extras/src/merging.rs | 4 ++ betterproto-extras/src/py_any_extras.rs | 24 ++++--- example.py | 9 ++- 4 files changed, 81 insertions(+), 35 deletions(-) diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index 0d2bc24..a672366 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -5,7 +5,8 @@ use crate::{ use prost_reflect::{ prost_types::{ field_descriptor_proto::{Label, Type}, - DescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto, + DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto, + FileDescriptorProto, OneofDescriptorProto, }, DescriptorPool, MessageDescriptor, }; @@ -19,11 +20,11 @@ pub fn create_cached_descriptor(obj: &PyAny) -> Result { .lock() .unwrap(); - create_cached_descriptor_in_pool(obj, &mut pool) + create_cached_message_in_pool(obj, &mut pool) } -fn create_cached_descriptor_in_pool<'py>( - obj: &'py PyAny, +fn create_cached_message_in_pool( + obj: &PyAny, pool: &mut DescriptorPool, ) -> Result { let name = obj.qualified_class_name()?; @@ -38,17 +39,12 @@ fn create_cached_descriptor_in_pool<'py>( ..Default::default() }; - let mut file = FileDescriptorProto { - name: Some(name.clone()), - ..Default::default() - }; - for item in meta .getattr("meta_by_field_name")? .call_method0("items")? .iter()? { - let (field_name, field_meta) = item?.extract::<(&str, &'py PyAny)>()?; + let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?; message.field.push({ let mut field = FieldDescriptorProto { name: Some(field_name.to_string()), @@ -59,11 +55,20 @@ fn create_cached_descriptor_in_pool<'py>( field_meta.getattr("proto_type")?.extract::<&str>()?, )?); - if field.r#type() == Type::Message { - let instance = meta.create_instance(field_name)?; - let cls_name = instance.qualified_class_name()?; - field.type_name = Some(cls_name.to_string()); - create_cached_descriptor_in_pool(instance, pool)?; + match field.r#type() { + Type::Message => { + let instance = meta.create_instance(field_name)?; + let cls_name = instance.qualified_class_name()?; + field.type_name = Some(cls_name.to_string()); + create_cached_message_in_pool(instance, pool)?; + } + Type::Enum => { + let cls = meta.get_class(field_name)?; + let cls_name = cls.qualified_name()?; + field.type_name = Some(cls_name.to_string()); + create_cached_enum_in_pool(cls, pool)?; + } + _ => {} } if meta.is_list_field(field_name)? { @@ -73,13 +78,7 @@ fn create_cached_descriptor_in_pool<'py>( } if let Some(grp) = meta.oneof_group(field_name)? { - let oneof_index = message - .oneof_decl - .iter() - .enumerate() - .filter(|x| x.1.name() == grp) - .map(|x| x.0) - .last(); + let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp); match oneof_index { Some(i) => field.oneof_index = Some(i as i32), @@ -97,13 +96,43 @@ fn create_cached_descriptor_in_pool<'py>( }); } - file.message_type.push(message); - - pool.add_file_descriptor_proto(file)?; + pool.add_file_descriptor_proto(FileDescriptorProto { + name: Some(name.clone()), + message_type: vec![message], + ..Default::default() + })?; Ok(pool.get_message_by_name(&name).expect("Just registered...")) } +fn create_cached_enum_in_pool(cls: &PyAny, pool: &mut DescriptorPool) -> Result<()> { + let cls_name = cls.qualified_name()?; + if pool.get_enum_by_name(&cls_name).is_some() { + return Ok(()); + } + + let mut proto = EnumDescriptorProto { + name: Some(cls_name.clone()), + ..Default::default() + }; + + for item in cls.iter()? { + let item = item?; + proto.value.push(EnumValueDescriptorProto { + number: Some(item.getattr("value")?.extract()?), + name: Some(item.getattr("name")?.extract()?), + ..Default::default() + }); + } + + pool.add_file_descriptor_proto(FileDescriptorProto { + name: Some(cls_name), + enum_type: vec![proto], + ..Default::default() + })?; + Ok(()) +} + fn map_type(str: &str) -> Result { match str { "enum" => Ok(Type::Enum), diff --git a/betterproto-extras/src/merging.rs b/betterproto-extras/src/merging.rs index 6f13084..683f5f0 100644 --- a/betterproto-extras/src/merging.rs +++ b/betterproto-extras/src/merging.rs @@ -52,6 +52,10 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) -> .map(|x| map_field_value(field_name, x, proto_meta)) .collect::>>()? .to_object(py)), + Value::EnumNumber(x) => { + let cls = proto_meta.get_class(field_name)?; + Ok(cls.call1((x,))?.to_object(py)) + } value => Err(Error::UnsupportedType(value.to_string())), } } diff --git a/betterproto-extras/src/py_any_extras.rs b/betterproto-extras/src/py_any_extras.rs index 1c7e9cb..1f4a881 100644 --- a/betterproto-extras/src/py_any_extras.rs +++ b/betterproto-extras/src/py_any_extras.rs @@ -2,31 +2,37 @@ use crate::error::Result; use pyo3::PyAny; pub trait PyAnyExtras { + fn qualified_name(&self) -> Result; fn qualified_class_name(&self) -> Result; fn get_proto_meta(&self) -> Result<&PyAny>; + fn get_class(&self, field_name: &str) -> Result<&PyAny>; fn create_instance(&self, field_name: &str) -> Result<&PyAny>; fn is_list_field(&self, field_name: &str) -> Result; fn oneof_group(&self, field_name: &str) -> Result>; } impl PyAnyExtras for PyAny { - fn qualified_class_name(&self) -> Result { - let class = self.getattr("__class__")?; - let module = class.getattr("__module__")?; - let name = class.getattr("__name__")?; + fn qualified_name(&self) -> Result { + let module = self.getattr("__module__")?; + let name = self.getattr("__name__")?; Ok(format!("{module}.{name}")) } + fn qualified_class_name(&self) -> Result { + self.getattr("__class__")?.qualified_name() + } + fn get_proto_meta(&self) -> Result<&PyAny> { Ok(self.getattr("_betterproto")?) } + fn get_class(&self, field_name: &str) -> Result<&PyAny> { + let cls = self.getattr("cls_by_field")?.get_item(field_name)?; + Ok(cls) + } + fn create_instance(&self, field_name: &str) -> Result<&PyAny> { - let res = self - .getattr("cls_by_field")? - .get_item(field_name)? - .call0()?; - Ok(res) + Ok(self.get_class(field_name)?.call0()?) } fn is_list_field(&self, field_name: &str) -> Result { diff --git a/example.py b/example.py index 627b359..d19f888 100644 --- a/example.py +++ b/example.py @@ -19,18 +19,25 @@ class Foo(betterproto.Message): y: float = betterproto.double_field(2) z: List[Baz] = betterproto.message_field(3) +class Enm(betterproto.Enum): + A = 0 + B = 1 + C = 2 + @dataclass(repr=False) class Bar(betterproto.Message): foo1: Foo = betterproto.message_field(1) foo2: Foo = betterproto.message_field(2) packed: List[int] = betterproto.int64_field(3) + enm: Enm = betterproto.enum_field(4) # Serialization has not been changed yet. So nothing unusual here buffer = bytes( Bar( foo1=Foo(1, 2.34), foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]), - packed=[5, 3, 1] + packed=[5, 3, 1], + enm=Enm.B ) )