enum support
This commit is contained in:
parent
24d694afe2
commit
a413d08fc1
@ -5,7 +5,8 @@ use crate::{
|
|||||||
use prost_reflect::{
|
use prost_reflect::{
|
||||||
prost_types::{
|
prost_types::{
|
||||||
field_descriptor_proto::{Label, Type},
|
field_descriptor_proto::{Label, Type},
|
||||||
DescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto,
|
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
|
||||||
|
FileDescriptorProto, OneofDescriptorProto,
|
||||||
},
|
},
|
||||||
DescriptorPool, MessageDescriptor,
|
DescriptorPool, MessageDescriptor,
|
||||||
};
|
};
|
||||||
@ -19,11 +20,11 @@ pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
|
|||||||
.lock()
|
.lock()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
create_cached_descriptor_in_pool(obj, &mut pool)
|
create_cached_message_in_pool(obj, &mut pool)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_cached_descriptor_in_pool<'py>(
|
fn create_cached_message_in_pool(
|
||||||
obj: &'py PyAny,
|
obj: &PyAny,
|
||||||
pool: &mut DescriptorPool,
|
pool: &mut DescriptorPool,
|
||||||
) -> Result<MessageDescriptor> {
|
) -> Result<MessageDescriptor> {
|
||||||
let name = obj.qualified_class_name()?;
|
let name = obj.qualified_class_name()?;
|
||||||
@ -38,17 +39,12 @@ fn create_cached_descriptor_in_pool<'py>(
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut file = FileDescriptorProto {
|
|
||||||
name: Some(name.clone()),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
for item in meta
|
for item in meta
|
||||||
.getattr("meta_by_field_name")?
|
.getattr("meta_by_field_name")?
|
||||||
.call_method0("items")?
|
.call_method0("items")?
|
||||||
.iter()?
|
.iter()?
|
||||||
{
|
{
|
||||||
let (field_name, field_meta) = item?.extract::<(&str, &'py PyAny)>()?;
|
let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?;
|
||||||
message.field.push({
|
message.field.push({
|
||||||
let mut field = FieldDescriptorProto {
|
let mut field = FieldDescriptorProto {
|
||||||
name: Some(field_name.to_string()),
|
name: Some(field_name.to_string()),
|
||||||
@ -59,11 +55,20 @@ fn create_cached_descriptor_in_pool<'py>(
|
|||||||
field_meta.getattr("proto_type")?.extract::<&str>()?,
|
field_meta.getattr("proto_type")?.extract::<&str>()?,
|
||||||
)?);
|
)?);
|
||||||
|
|
||||||
if field.r#type() == Type::Message {
|
match field.r#type() {
|
||||||
let instance = meta.create_instance(field_name)?;
|
Type::Message => {
|
||||||
let cls_name = instance.qualified_class_name()?;
|
let instance = meta.create_instance(field_name)?;
|
||||||
field.type_name = Some(cls_name.to_string());
|
let cls_name = instance.qualified_class_name()?;
|
||||||
create_cached_descriptor_in_pool(instance, pool)?;
|
field.type_name = Some(cls_name.to_string());
|
||||||
|
create_cached_message_in_pool(instance, pool)?;
|
||||||
|
}
|
||||||
|
Type::Enum => {
|
||||||
|
let cls = meta.get_class(field_name)?;
|
||||||
|
let cls_name = cls.qualified_name()?;
|
||||||
|
field.type_name = Some(cls_name.to_string());
|
||||||
|
create_cached_enum_in_pool(cls, pool)?;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
if meta.is_list_field(field_name)? {
|
if meta.is_list_field(field_name)? {
|
||||||
@ -73,13 +78,7 @@ fn create_cached_descriptor_in_pool<'py>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(grp) = meta.oneof_group(field_name)? {
|
if let Some(grp) = meta.oneof_group(field_name)? {
|
||||||
let oneof_index = message
|
let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp);
|
||||||
.oneof_decl
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.filter(|x| x.1.name() == grp)
|
|
||||||
.map(|x| x.0)
|
|
||||||
.last();
|
|
||||||
|
|
||||||
match oneof_index {
|
match oneof_index {
|
||||||
Some(i) => field.oneof_index = Some(i as i32),
|
Some(i) => field.oneof_index = Some(i as i32),
|
||||||
@ -97,13 +96,43 @@ fn create_cached_descriptor_in_pool<'py>(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
file.message_type.push(message);
|
pool.add_file_descriptor_proto(FileDescriptorProto {
|
||||||
|
name: Some(name.clone()),
|
||||||
pool.add_file_descriptor_proto(file)?;
|
message_type: vec![message],
|
||||||
|
..Default::default()
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
|
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_cached_enum_in_pool(cls: &PyAny, pool: &mut DescriptorPool) -> Result<()> {
|
||||||
|
let cls_name = cls.qualified_name()?;
|
||||||
|
if pool.get_enum_by_name(&cls_name).is_some() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut proto = EnumDescriptorProto {
|
||||||
|
name: Some(cls_name.clone()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
for item in cls.iter()? {
|
||||||
|
let item = item?;
|
||||||
|
proto.value.push(EnumValueDescriptorProto {
|
||||||
|
number: Some(item.getattr("value")?.extract()?),
|
||||||
|
name: Some(item.getattr("name")?.extract()?),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.add_file_descriptor_proto(FileDescriptorProto {
|
||||||
|
name: Some(cls_name),
|
||||||
|
enum_type: vec![proto],
|
||||||
|
..Default::default()
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn map_type(str: &str) -> Result<Type> {
|
fn map_type(str: &str) -> Result<Type> {
|
||||||
match str {
|
match str {
|
||||||
"enum" => Ok(Type::Enum),
|
"enum" => Ok(Type::Enum),
|
||||||
|
@ -52,6 +52,10 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) ->
|
|||||||
.map(|x| map_field_value(field_name, x, proto_meta))
|
.map(|x| map_field_value(field_name, x, proto_meta))
|
||||||
.collect::<Result<Vec<PyObject>>>()?
|
.collect::<Result<Vec<PyObject>>>()?
|
||||||
.to_object(py)),
|
.to_object(py)),
|
||||||
|
Value::EnumNumber(x) => {
|
||||||
|
let cls = proto_meta.get_class(field_name)?;
|
||||||
|
Ok(cls.call1((x,))?.to_object(py))
|
||||||
|
}
|
||||||
value => Err(Error::UnsupportedType(value.to_string())),
|
value => Err(Error::UnsupportedType(value.to_string())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,31 +2,37 @@ use crate::error::Result;
|
|||||||
use pyo3::PyAny;
|
use pyo3::PyAny;
|
||||||
|
|
||||||
pub trait PyAnyExtras {
|
pub trait PyAnyExtras {
|
||||||
|
fn qualified_name(&self) -> Result<String>;
|
||||||
fn qualified_class_name(&self) -> Result<String>;
|
fn qualified_class_name(&self) -> Result<String>;
|
||||||
fn get_proto_meta(&self) -> Result<&PyAny>;
|
fn get_proto_meta(&self) -> Result<&PyAny>;
|
||||||
|
fn get_class(&self, field_name: &str) -> Result<&PyAny>;
|
||||||
fn create_instance(&self, field_name: &str) -> Result<&PyAny>;
|
fn create_instance(&self, field_name: &str) -> Result<&PyAny>;
|
||||||
fn is_list_field(&self, field_name: &str) -> Result<bool>;
|
fn is_list_field(&self, field_name: &str) -> Result<bool>;
|
||||||
fn oneof_group(&self, field_name: &str) -> Result<Option<String>>;
|
fn oneof_group(&self, field_name: &str) -> Result<Option<String>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PyAnyExtras for PyAny {
|
impl PyAnyExtras for PyAny {
|
||||||
fn qualified_class_name(&self) -> Result<String> {
|
fn qualified_name(&self) -> Result<String> {
|
||||||
let class = self.getattr("__class__")?;
|
let module = self.getattr("__module__")?;
|
||||||
let module = class.getattr("__module__")?;
|
let name = self.getattr("__name__")?;
|
||||||
let name = class.getattr("__name__")?;
|
|
||||||
Ok(format!("{module}.{name}"))
|
Ok(format!("{module}.{name}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn qualified_class_name(&self) -> Result<String> {
|
||||||
|
self.getattr("__class__")?.qualified_name()
|
||||||
|
}
|
||||||
|
|
||||||
fn get_proto_meta(&self) -> Result<&PyAny> {
|
fn get_proto_meta(&self) -> Result<&PyAny> {
|
||||||
Ok(self.getattr("_betterproto")?)
|
Ok(self.getattr("_betterproto")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_class(&self, field_name: &str) -> Result<&PyAny> {
|
||||||
|
let cls = self.getattr("cls_by_field")?.get_item(field_name)?;
|
||||||
|
Ok(cls)
|
||||||
|
}
|
||||||
|
|
||||||
fn create_instance(&self, field_name: &str) -> Result<&PyAny> {
|
fn create_instance(&self, field_name: &str) -> Result<&PyAny> {
|
||||||
let res = self
|
Ok(self.get_class(field_name)?.call0()?)
|
||||||
.getattr("cls_by_field")?
|
|
||||||
.get_item(field_name)?
|
|
||||||
.call0()?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_list_field(&self, field_name: &str) -> Result<bool> {
|
fn is_list_field(&self, field_name: &str) -> Result<bool> {
|
||||||
|
@ -19,18 +19,25 @@ class Foo(betterproto.Message):
|
|||||||
y: float = betterproto.double_field(2)
|
y: float = betterproto.double_field(2)
|
||||||
z: List[Baz] = betterproto.message_field(3)
|
z: List[Baz] = betterproto.message_field(3)
|
||||||
|
|
||||||
|
class Enm(betterproto.Enum):
|
||||||
|
A = 0
|
||||||
|
B = 1
|
||||||
|
C = 2
|
||||||
|
|
||||||
@dataclass(repr=False)
|
@dataclass(repr=False)
|
||||||
class Bar(betterproto.Message):
|
class Bar(betterproto.Message):
|
||||||
foo1: Foo = betterproto.message_field(1)
|
foo1: Foo = betterproto.message_field(1)
|
||||||
foo2: Foo = betterproto.message_field(2)
|
foo2: Foo = betterproto.message_field(2)
|
||||||
packed: List[int] = betterproto.int64_field(3)
|
packed: List[int] = betterproto.int64_field(3)
|
||||||
|
enm: Enm = betterproto.enum_field(4)
|
||||||
|
|
||||||
# Serialization has not been changed yet. So nothing unusual here
|
# Serialization has not been changed yet. So nothing unusual here
|
||||||
buffer = bytes(
|
buffer = bytes(
|
||||||
Bar(
|
Bar(
|
||||||
foo1=Foo(1, 2.34),
|
foo1=Foo(1, 2.34),
|
||||||
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]),
|
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]),
|
||||||
packed=[5, 3, 1]
|
packed=[5, 3, 1],
|
||||||
|
enm=Enm.B
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user