diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index 7ffee0e..90b0e98 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -4,7 +4,8 @@ use crate::{ }; use prost_reflect::{ prost_types::{ - field_descriptor_proto::{Type, Label}, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, + field_descriptor_proto::{Label, Type}, + DescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto, }, DescriptorPool, MessageDescriptor, }; @@ -69,6 +70,27 @@ fn create_cached_descriptor_in_pool<'py>( field.set_label(Label::Repeated); } + if let Some(grp) = meta.oneof_group(field_name)? { + let oneof_index = message + .oneof_decl + .iter() + .enumerate() + .filter(|x| x.1.name() == grp) + .map(|x| x.0) + .last(); + + match oneof_index { + Some(i) => field.oneof_index = Some(i as i32), + None => { + message.oneof_decl.push(OneofDescriptorProto { + name: Some(grp), + ..Default::default() + }); + field.oneof_index = Some((message.oneof_decl.len() - 1) as i32) + } + } + } + field }); } diff --git a/betterproto-extras/src/merging.rs b/betterproto-extras/src/merging.rs index 6fe118f..88d12eb 100644 --- a/betterproto-extras/src/merging.rs +++ b/betterproto-extras/src/merging.rs @@ -9,11 +9,13 @@ pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> { for field in msg.descriptor().fields() { let field_name = field.name(); let proto_meta = obj.get_proto_meta()?; - if let Some(field_value) = msg.get_field_by_name(field_name) { - obj.setattr( - field_name, - map_field_value(field_name, &field_value, proto_meta)?, - )?; + if msg.has_field_by_name(field_name) { + if let Some(field_value) = msg.get_field_by_name(field_name) { + obj.setattr( + field_name, + map_field_value(field_name, &field_value, proto_meta)?, + )?; + } } } Ok(()) diff --git a/betterproto-extras/src/py_any_extras.rs b/betterproto-extras/src/py_any_extras.rs index f42611a..1c7e9cb 100644 --- a/betterproto-extras/src/py_any_extras.rs +++ b/betterproto-extras/src/py_any_extras.rs @@ -1,14 +1,15 @@ use crate::error::Result; use pyo3::PyAny; -pub trait PyAnyExtras<'py> { +pub trait PyAnyExtras { fn qualified_class_name(&self) -> Result; - fn get_proto_meta(&self) -> Result<&'py PyAny>; - fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>; + fn get_proto_meta(&self) -> Result<&PyAny>; + fn create_instance(&self, field_name: &str) -> Result<&PyAny>; fn is_list_field(&self, field_name: &str) -> Result; + fn oneof_group(&self, field_name: &str) -> Result>; } -impl<'py> PyAnyExtras<'py> for &'py PyAny { +impl PyAnyExtras for PyAny { fn qualified_class_name(&self) -> Result { let class = self.getattr("__class__")?; let module = class.getattr("__module__")?; @@ -16,11 +17,11 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny { Ok(format!("{module}.{name}")) } - fn get_proto_meta(&self) -> Result<&'py PyAny> { + fn get_proto_meta(&self) -> Result<&PyAny> { Ok(self.getattr("_betterproto")?) } - fn create_instance(&self, field_name: &str) -> Result<&'py PyAny> { + fn create_instance(&self, field_name: &str) -> Result<&PyAny> { let res = self .getattr("cls_by_field")? .get_item(field_name)? @@ -34,4 +35,12 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny { let name = cls.getattr("__name__")?; Ok(module.to_string() == "builtins" && name.to_string() == "list") } + + fn oneof_group(&self, field_name: &str) -> Result> { + let opt = self + .getattr("oneof_group_by_field")? + .call_method1("get", (field_name,))? + .extract()?; + Ok(opt) + } } diff --git a/example.py b/example.py index ca69ed8..2f1048a 100644 --- a/example.py +++ b/example.py @@ -5,17 +5,20 @@ import betterproto from dataclasses import dataclass from typing import List -@dataclass +@dataclass(repr=False) class Baz(betterproto.Message): - a: str = betterproto.string_field(1) + a: float = betterproto.float_field(1, group = "x") + b: int = betterproto.int64_field(2, group = "x") + c: float = betterproto.float_field(3, group = "y") + d: int = betterproto.int64_field(4, group = "y") -@dataclass +@dataclass(repr=False) class Foo(betterproto.Message): x: int = betterproto.int32_field(1) y: float = betterproto.double_field(2) z: List[Baz] = betterproto.message_field(3) -@dataclass +@dataclass(repr=False) class Bar(betterproto.Message): foo1: Foo = betterproto.message_field(1) foo2: Foo = betterproto.message_field(2) @@ -25,7 +28,7 @@ class Bar(betterproto.Message): buffer = bytes( Bar( foo1=Foo(1, 2.34), - foo2=Foo(3, 4.56, [Baz("Hi"), Baz("There")]), + foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5), Baz(b = 2, d = 3)]), packed=[5, 3, 1] ) )