oneof support

This commit is contained in:
Erik Friese 2023-08-27 14:03:22 +02:00
parent d79a9eee14
commit a12c9d24de
4 changed files with 53 additions and 17 deletions

View File

@ -4,7 +4,8 @@ use crate::{
}; };
use prost_reflect::{ use prost_reflect::{
prost_types::{ prost_types::{
field_descriptor_proto::{Type, Label}, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto,
}, },
DescriptorPool, MessageDescriptor, DescriptorPool, MessageDescriptor,
}; };
@ -69,6 +70,27 @@ fn create_cached_descriptor_in_pool<'py>(
field.set_label(Label::Repeated); field.set_label(Label::Repeated);
} }
if let Some(grp) = meta.oneof_group(field_name)? {
let oneof_index = message
.oneof_decl
.iter()
.enumerate()
.filter(|x| x.1.name() == grp)
.map(|x| x.0)
.last();
match oneof_index {
Some(i) => field.oneof_index = Some(i as i32),
None => {
message.oneof_decl.push(OneofDescriptorProto {
name: Some(grp),
..Default::default()
});
field.oneof_index = Some((message.oneof_decl.len() - 1) as i32)
}
}
}
field field
}); });
} }

View File

@ -9,11 +9,13 @@ pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> {
for field in msg.descriptor().fields() { for field in msg.descriptor().fields() {
let field_name = field.name(); let field_name = field.name();
let proto_meta = obj.get_proto_meta()?; let proto_meta = obj.get_proto_meta()?;
if let Some(field_value) = msg.get_field_by_name(field_name) { if msg.has_field_by_name(field_name) {
obj.setattr( if let Some(field_value) = msg.get_field_by_name(field_name) {
field_name, obj.setattr(
map_field_value(field_name, &field_value, proto_meta)?, field_name,
)?; map_field_value(field_name, &field_value, proto_meta)?,
)?;
}
} }
} }
Ok(()) Ok(())

View File

@ -1,14 +1,15 @@
use crate::error::Result; use crate::error::Result;
use pyo3::PyAny; use pyo3::PyAny;
pub trait PyAnyExtras<'py> { pub trait PyAnyExtras {
fn qualified_class_name(&self) -> Result<String>; fn qualified_class_name(&self) -> Result<String>;
fn get_proto_meta(&self) -> Result<&'py PyAny>; fn get_proto_meta(&self) -> Result<&PyAny>;
fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>; fn create_instance(&self, field_name: &str) -> Result<&PyAny>;
fn is_list_field(&self, field_name: &str) -> Result<bool>; fn is_list_field(&self, field_name: &str) -> Result<bool>;
fn oneof_group(&self, field_name: &str) -> Result<Option<String>>;
} }
impl<'py> PyAnyExtras<'py> for &'py PyAny { impl PyAnyExtras for PyAny {
fn qualified_class_name(&self) -> Result<String> { fn qualified_class_name(&self) -> Result<String> {
let class = self.getattr("__class__")?; let class = self.getattr("__class__")?;
let module = class.getattr("__module__")?; let module = class.getattr("__module__")?;
@ -16,11 +17,11 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny {
Ok(format!("{module}.{name}")) Ok(format!("{module}.{name}"))
} }
fn get_proto_meta(&self) -> Result<&'py PyAny> { fn get_proto_meta(&self) -> Result<&PyAny> {
Ok(self.getattr("_betterproto")?) Ok(self.getattr("_betterproto")?)
} }
fn create_instance(&self, field_name: &str) -> Result<&'py PyAny> { fn create_instance(&self, field_name: &str) -> Result<&PyAny> {
let res = self let res = self
.getattr("cls_by_field")? .getattr("cls_by_field")?
.get_item(field_name)? .get_item(field_name)?
@ -34,4 +35,12 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny {
let name = cls.getattr("__name__")?; let name = cls.getattr("__name__")?;
Ok(module.to_string() == "builtins" && name.to_string() == "list") Ok(module.to_string() == "builtins" && name.to_string() == "list")
} }
fn oneof_group(&self, field_name: &str) -> Result<Option<String>> {
let opt = self
.getattr("oneof_group_by_field")?
.call_method1("get", (field_name,))?
.extract()?;
Ok(opt)
}
} }

View File

@ -5,17 +5,20 @@ import betterproto
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
@dataclass @dataclass(repr=False)
class Baz(betterproto.Message): class Baz(betterproto.Message):
a: str = betterproto.string_field(1) a: float = betterproto.float_field(1, group = "x")
b: int = betterproto.int64_field(2, group = "x")
c: float = betterproto.float_field(3, group = "y")
d: int = betterproto.int64_field(4, group = "y")
@dataclass @dataclass(repr=False)
class Foo(betterproto.Message): class Foo(betterproto.Message):
x: int = betterproto.int32_field(1) x: int = betterproto.int32_field(1)
y: float = betterproto.double_field(2) y: float = betterproto.double_field(2)
z: List[Baz] = betterproto.message_field(3) z: List[Baz] = betterproto.message_field(3)
@dataclass @dataclass(repr=False)
class Bar(betterproto.Message): class Bar(betterproto.Message):
foo1: Foo = betterproto.message_field(1) foo1: Foo = betterproto.message_field(1)
foo2: Foo = betterproto.message_field(2) foo2: Foo = betterproto.message_field(2)
@ -25,7 +28,7 @@ class Bar(betterproto.Message):
buffer = bytes( buffer = bytes(
Bar( Bar(
foo1=Foo(1, 2.34), foo1=Foo(1, 2.34),
foo2=Foo(3, 4.56, [Baz("Hi"), Baz("There")]), foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5), Baz(b = 2, d = 3)]),
packed=[5, 3, 1] packed=[5, 3, 1]
) )
) )