oneof support
This commit is contained in:
parent
d79a9eee14
commit
a12c9d24de
@ -4,7 +4,8 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use prost_reflect::{
|
use prost_reflect::{
|
||||||
prost_types::{
|
prost_types::{
|
||||||
field_descriptor_proto::{Type, Label}, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
|
field_descriptor_proto::{Label, Type},
|
||||||
|
DescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto,
|
||||||
},
|
},
|
||||||
DescriptorPool, MessageDescriptor,
|
DescriptorPool, MessageDescriptor,
|
||||||
};
|
};
|
||||||
@ -69,6 +70,27 @@ fn create_cached_descriptor_in_pool<'py>(
|
|||||||
field.set_label(Label::Repeated);
|
field.set_label(Label::Repeated);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(grp) = meta.oneof_group(field_name)? {
|
||||||
|
let oneof_index = message
|
||||||
|
.oneof_decl
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|x| x.1.name() == grp)
|
||||||
|
.map(|x| x.0)
|
||||||
|
.last();
|
||||||
|
|
||||||
|
match oneof_index {
|
||||||
|
Some(i) => field.oneof_index = Some(i as i32),
|
||||||
|
None => {
|
||||||
|
message.oneof_decl.push(OneofDescriptorProto {
|
||||||
|
name: Some(grp),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
field.oneof_index = Some((message.oneof_decl.len() - 1) as i32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
field
|
field
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> {
|
|||||||
for field in msg.descriptor().fields() {
|
for field in msg.descriptor().fields() {
|
||||||
let field_name = field.name();
|
let field_name = field.name();
|
||||||
let proto_meta = obj.get_proto_meta()?;
|
let proto_meta = obj.get_proto_meta()?;
|
||||||
|
if msg.has_field_by_name(field_name) {
|
||||||
if let Some(field_value) = msg.get_field_by_name(field_name) {
|
if let Some(field_value) = msg.get_field_by_name(field_name) {
|
||||||
obj.setattr(
|
obj.setattr(
|
||||||
field_name,
|
field_name,
|
||||||
@ -16,6 +17,7 @@ pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use pyo3::PyAny;
|
use pyo3::PyAny;
|
||||||
|
|
||||||
pub trait PyAnyExtras<'py> {
|
pub trait PyAnyExtras {
|
||||||
fn qualified_class_name(&self) -> Result<String>;
|
fn qualified_class_name(&self) -> Result<String>;
|
||||||
fn get_proto_meta(&self) -> Result<&'py PyAny>;
|
fn get_proto_meta(&self) -> Result<&PyAny>;
|
||||||
fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>;
|
fn create_instance(&self, field_name: &str) -> Result<&PyAny>;
|
||||||
fn is_list_field(&self, field_name: &str) -> Result<bool>;
|
fn is_list_field(&self, field_name: &str) -> Result<bool>;
|
||||||
|
fn oneof_group(&self, field_name: &str) -> Result<Option<String>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'py> PyAnyExtras<'py> for &'py PyAny {
|
impl PyAnyExtras for PyAny {
|
||||||
fn qualified_class_name(&self) -> Result<String> {
|
fn qualified_class_name(&self) -> Result<String> {
|
||||||
let class = self.getattr("__class__")?;
|
let class = self.getattr("__class__")?;
|
||||||
let module = class.getattr("__module__")?;
|
let module = class.getattr("__module__")?;
|
||||||
@ -16,11 +17,11 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny {
|
|||||||
Ok(format!("{module}.{name}"))
|
Ok(format!("{module}.{name}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_proto_meta(&self) -> Result<&'py PyAny> {
|
fn get_proto_meta(&self) -> Result<&PyAny> {
|
||||||
Ok(self.getattr("_betterproto")?)
|
Ok(self.getattr("_betterproto")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_instance(&self, field_name: &str) -> Result<&'py PyAny> {
|
fn create_instance(&self, field_name: &str) -> Result<&PyAny> {
|
||||||
let res = self
|
let res = self
|
||||||
.getattr("cls_by_field")?
|
.getattr("cls_by_field")?
|
||||||
.get_item(field_name)?
|
.get_item(field_name)?
|
||||||
@ -34,4 +35,12 @@ impl<'py> PyAnyExtras<'py> for &'py PyAny {
|
|||||||
let name = cls.getattr("__name__")?;
|
let name = cls.getattr("__name__")?;
|
||||||
Ok(module.to_string() == "builtins" && name.to_string() == "list")
|
Ok(module.to_string() == "builtins" && name.to_string() == "list")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn oneof_group(&self, field_name: &str) -> Result<Option<String>> {
|
||||||
|
let opt = self
|
||||||
|
.getattr("oneof_group_by_field")?
|
||||||
|
.call_method1("get", (field_name,))?
|
||||||
|
.extract()?;
|
||||||
|
Ok(opt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
13
example.py
13
example.py
@ -5,17 +5,20 @@ import betterproto
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@dataclass
|
@dataclass(repr=False)
|
||||||
class Baz(betterproto.Message):
|
class Baz(betterproto.Message):
|
||||||
a: str = betterproto.string_field(1)
|
a: float = betterproto.float_field(1, group = "x")
|
||||||
|
b: int = betterproto.int64_field(2, group = "x")
|
||||||
|
c: float = betterproto.float_field(3, group = "y")
|
||||||
|
d: int = betterproto.int64_field(4, group = "y")
|
||||||
|
|
||||||
@dataclass
|
@dataclass(repr=False)
|
||||||
class Foo(betterproto.Message):
|
class Foo(betterproto.Message):
|
||||||
x: int = betterproto.int32_field(1)
|
x: int = betterproto.int32_field(1)
|
||||||
y: float = betterproto.double_field(2)
|
y: float = betterproto.double_field(2)
|
||||||
z: List[Baz] = betterproto.message_field(3)
|
z: List[Baz] = betterproto.message_field(3)
|
||||||
|
|
||||||
@dataclass
|
@dataclass(repr=False)
|
||||||
class Bar(betterproto.Message):
|
class Bar(betterproto.Message):
|
||||||
foo1: Foo = betterproto.message_field(1)
|
foo1: Foo = betterproto.message_field(1)
|
||||||
foo2: Foo = betterproto.message_field(2)
|
foo2: Foo = betterproto.message_field(2)
|
||||||
@ -25,7 +28,7 @@ class Bar(betterproto.Message):
|
|||||||
buffer = bytes(
|
buffer = bytes(
|
||||||
Bar(
|
Bar(
|
||||||
foo1=Foo(1, 2.34),
|
foo1=Foo(1, 2.34),
|
||||||
foo2=Foo(3, 4.56, [Baz("Hi"), Baz("There")]),
|
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5), Baz(b = 2, d = 3)]),
|
||||||
packed=[5, 3, 1]
|
packed=[5, 3, 1]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user