enum support

This commit is contained in:
Erik Friese 2023-08-30 19:27:59 +02:00
parent 24d694afe2
commit a413d08fc1
4 changed files with 81 additions and 35 deletions

View File

@ -5,7 +5,8 @@ use crate::{
use prost_reflect::{ use prost_reflect::{
prost_types::{ prost_types::{
field_descriptor_proto::{Label, Type}, field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto, DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
FileDescriptorProto, OneofDescriptorProto,
}, },
DescriptorPool, MessageDescriptor, DescriptorPool, MessageDescriptor,
}; };
@ -19,11 +20,11 @@ pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
.lock() .lock()
.unwrap(); .unwrap();
create_cached_descriptor_in_pool(obj, &mut pool) create_cached_message_in_pool(obj, &mut pool)
} }
fn create_cached_descriptor_in_pool<'py>( fn create_cached_message_in_pool(
obj: &'py PyAny, obj: &PyAny,
pool: &mut DescriptorPool, pool: &mut DescriptorPool,
) -> Result<MessageDescriptor> { ) -> Result<MessageDescriptor> {
let name = obj.qualified_class_name()?; let name = obj.qualified_class_name()?;
@ -38,17 +39,12 @@ fn create_cached_descriptor_in_pool<'py>(
..Default::default() ..Default::default()
}; };
let mut file = FileDescriptorProto {
name: Some(name.clone()),
..Default::default()
};
for item in meta for item in meta
.getattr("meta_by_field_name")? .getattr("meta_by_field_name")?
.call_method0("items")? .call_method0("items")?
.iter()? .iter()?
{ {
let (field_name, field_meta) = item?.extract::<(&str, &'py PyAny)>()?; let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?;
message.field.push({ message.field.push({
let mut field = FieldDescriptorProto { let mut field = FieldDescriptorProto {
name: Some(field_name.to_string()), name: Some(field_name.to_string()),
@ -59,11 +55,20 @@ fn create_cached_descriptor_in_pool<'py>(
field_meta.getattr("proto_type")?.extract::<&str>()?, field_meta.getattr("proto_type")?.extract::<&str>()?,
)?); )?);
if field.r#type() == Type::Message { match field.r#type() {
let instance = meta.create_instance(field_name)?; Type::Message => {
let cls_name = instance.qualified_class_name()?; let instance = meta.create_instance(field_name)?;
field.type_name = Some(cls_name.to_string()); let cls_name = instance.qualified_class_name()?;
create_cached_descriptor_in_pool(instance, pool)?; field.type_name = Some(cls_name.to_string());
create_cached_message_in_pool(instance, pool)?;
}
Type::Enum => {
let cls = meta.get_class(field_name)?;
let cls_name = cls.qualified_name()?;
field.type_name = Some(cls_name.to_string());
create_cached_enum_in_pool(cls, pool)?;
}
_ => {}
} }
if meta.is_list_field(field_name)? { if meta.is_list_field(field_name)? {
@ -73,13 +78,7 @@ fn create_cached_descriptor_in_pool<'py>(
} }
if let Some(grp) = meta.oneof_group(field_name)? { if let Some(grp) = meta.oneof_group(field_name)? {
let oneof_index = message let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp);
.oneof_decl
.iter()
.enumerate()
.filter(|x| x.1.name() == grp)
.map(|x| x.0)
.last();
match oneof_index { match oneof_index {
Some(i) => field.oneof_index = Some(i as i32), Some(i) => field.oneof_index = Some(i as i32),
@ -97,13 +96,43 @@ fn create_cached_descriptor_in_pool<'py>(
}); });
} }
file.message_type.push(message); pool.add_file_descriptor_proto(FileDescriptorProto {
name: Some(name.clone()),
pool.add_file_descriptor_proto(file)?; message_type: vec![message],
..Default::default()
})?;
Ok(pool.get_message_by_name(&name).expect("Just registered...")) Ok(pool.get_message_by_name(&name).expect("Just registered..."))
} }
fn create_cached_enum_in_pool(cls: &PyAny, pool: &mut DescriptorPool) -> Result<()> {
let cls_name = cls.qualified_name()?;
if pool.get_enum_by_name(&cls_name).is_some() {
return Ok(());
}
let mut proto = EnumDescriptorProto {
name: Some(cls_name.clone()),
..Default::default()
};
for item in cls.iter()? {
let item = item?;
proto.value.push(EnumValueDescriptorProto {
number: Some(item.getattr("value")?.extract()?),
name: Some(item.getattr("name")?.extract()?),
..Default::default()
});
}
pool.add_file_descriptor_proto(FileDescriptorProto {
name: Some(cls_name),
enum_type: vec![proto],
..Default::default()
})?;
Ok(())
}
fn map_type(str: &str) -> Result<Type> { fn map_type(str: &str) -> Result<Type> {
match str { match str {
"enum" => Ok(Type::Enum), "enum" => Ok(Type::Enum),

View File

@ -52,6 +52,10 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) ->
.map(|x| map_field_value(field_name, x, proto_meta)) .map(|x| map_field_value(field_name, x, proto_meta))
.collect::<Result<Vec<PyObject>>>()? .collect::<Result<Vec<PyObject>>>()?
.to_object(py)), .to_object(py)),
Value::EnumNumber(x) => {
let cls = proto_meta.get_class(field_name)?;
Ok(cls.call1((x,))?.to_object(py))
}
value => Err(Error::UnsupportedType(value.to_string())), value => Err(Error::UnsupportedType(value.to_string())),
} }
} }

View File

@ -2,31 +2,37 @@ use crate::error::Result;
use pyo3::PyAny; use pyo3::PyAny;
pub trait PyAnyExtras { pub trait PyAnyExtras {
fn qualified_name(&self) -> Result<String>;
fn qualified_class_name(&self) -> Result<String>; fn qualified_class_name(&self) -> Result<String>;
fn get_proto_meta(&self) -> Result<&PyAny>; fn get_proto_meta(&self) -> Result<&PyAny>;
fn get_class(&self, field_name: &str) -> Result<&PyAny>;
fn create_instance(&self, field_name: &str) -> Result<&PyAny>; fn create_instance(&self, field_name: &str) -> Result<&PyAny>;
fn is_list_field(&self, field_name: &str) -> Result<bool>; fn is_list_field(&self, field_name: &str) -> Result<bool>;
fn oneof_group(&self, field_name: &str) -> Result<Option<String>>; fn oneof_group(&self, field_name: &str) -> Result<Option<String>>;
} }
impl PyAnyExtras for PyAny { impl PyAnyExtras for PyAny {
fn qualified_class_name(&self) -> Result<String> { fn qualified_name(&self) -> Result<String> {
let class = self.getattr("__class__")?; let module = self.getattr("__module__")?;
let module = class.getattr("__module__")?; let name = self.getattr("__name__")?;
let name = class.getattr("__name__")?;
Ok(format!("{module}.{name}")) Ok(format!("{module}.{name}"))
} }
fn qualified_class_name(&self) -> Result<String> {
self.getattr("__class__")?.qualified_name()
}
fn get_proto_meta(&self) -> Result<&PyAny> { fn get_proto_meta(&self) -> Result<&PyAny> {
Ok(self.getattr("_betterproto")?) Ok(self.getattr("_betterproto")?)
} }
fn get_class(&self, field_name: &str) -> Result<&PyAny> {
let cls = self.getattr("cls_by_field")?.get_item(field_name)?;
Ok(cls)
}
fn create_instance(&self, field_name: &str) -> Result<&PyAny> { fn create_instance(&self, field_name: &str) -> Result<&PyAny> {
let res = self Ok(self.get_class(field_name)?.call0()?)
.getattr("cls_by_field")?
.get_item(field_name)?
.call0()?;
Ok(res)
} }
fn is_list_field(&self, field_name: &str) -> Result<bool> { fn is_list_field(&self, field_name: &str) -> Result<bool> {

View File

@ -19,18 +19,25 @@ class Foo(betterproto.Message):
y: float = betterproto.double_field(2) y: float = betterproto.double_field(2)
z: List[Baz] = betterproto.message_field(3) z: List[Baz] = betterproto.message_field(3)
class Enm(betterproto.Enum):
A = 0
B = 1
C = 2
@dataclass(repr=False) @dataclass(repr=False)
class Bar(betterproto.Message): class Bar(betterproto.Message):
foo1: Foo = betterproto.message_field(1) foo1: Foo = betterproto.message_field(1)
foo2: Foo = betterproto.message_field(2) foo2: Foo = betterproto.message_field(2)
packed: List[int] = betterproto.int64_field(3) packed: List[int] = betterproto.int64_field(3)
enm: Enm = betterproto.enum_field(4)
# Serialization has not been changed yet. So nothing unusual here # Serialization has not been changed yet. So nothing unusual here
buffer = bytes( buffer = bytes(
Bar( Bar(
foo1=Foo(1, 2.34), foo1=Foo(1, 2.34),
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]), foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]),
packed=[5, 3, 1] packed=[5, 3, 1],
enm=Enm.B
) )
) )