158 lines
5.0 KiB
Rust
158 lines
5.0 KiB
Rust
use crate::{
|
|
error::{Error, Result},
|
|
py_any_extras::PyAnyExtras,
|
|
};
|
|
use prost_reflect::{
|
|
prost_types::{
|
|
field_descriptor_proto::{Label, Type},
|
|
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
|
|
FileDescriptorProto, OneofDescriptorProto,
|
|
},
|
|
DescriptorPool, MessageDescriptor,
|
|
};
|
|
use pyo3::PyAny;
|
|
use std::sync::{Mutex, OnceLock};
|
|
|
|
pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
|
|
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
|
|
let mut pool = DESCRIPTOR_POOL
|
|
.get_or_init(|| Mutex::new(DescriptorPool::new()))
|
|
.lock()
|
|
.unwrap();
|
|
|
|
create_cached_message_in_pool(obj, &mut pool)
|
|
}
|
|
|
|
fn create_cached_message_in_pool(
|
|
obj: &PyAny,
|
|
pool: &mut DescriptorPool,
|
|
) -> Result<MessageDescriptor> {
|
|
let name = obj.qualified_class_name()?;
|
|
if let Some(desc) = pool.get_message_by_name(&name) {
|
|
return Ok(desc);
|
|
}
|
|
|
|
let meta = obj.get_proto_meta()?;
|
|
|
|
let mut message = DescriptorProto {
|
|
name: Some(name.clone()),
|
|
..Default::default()
|
|
};
|
|
|
|
for item in meta
|
|
.getattr("meta_by_field_name")?
|
|
.call_method0("items")?
|
|
.iter()?
|
|
{
|
|
let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?;
|
|
message.field.push({
|
|
let mut field = FieldDescriptorProto {
|
|
name: Some(field_name.to_string()),
|
|
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
|
|
..Default::default()
|
|
};
|
|
field.set_type(map_type(
|
|
field_meta.getattr("proto_type")?.extract::<&str>()?,
|
|
)?);
|
|
|
|
match field.r#type() {
|
|
Type::Message => {
|
|
let instance = meta.create_instance(field_name)?;
|
|
let cls_name = instance.qualified_class_name()?;
|
|
field.type_name = Some(cls_name.to_string());
|
|
create_cached_message_in_pool(instance, pool)?;
|
|
}
|
|
Type::Enum => {
|
|
let cls = meta.get_class(field_name)?;
|
|
let cls_name = cls.qualified_name()?;
|
|
field.type_name = Some(cls_name.to_string());
|
|
create_cached_enum_in_pool(cls, pool)?;
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
if meta.is_list_field(field_name)? {
|
|
field.set_label(Label::Repeated);
|
|
} else if field_meta.getattr("optional")?.extract::<bool>()? {
|
|
field.proto3_optional = Some(true);
|
|
}
|
|
|
|
if let Some(grp) = meta.oneof_group(field_name)? {
|
|
let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp);
|
|
|
|
match oneof_index {
|
|
Some(i) => field.oneof_index = Some(i as i32),
|
|
None => {
|
|
message.oneof_decl.push(OneofDescriptorProto {
|
|
name: Some(grp),
|
|
..Default::default()
|
|
});
|
|
field.oneof_index = Some((message.oneof_decl.len() - 1) as i32)
|
|
}
|
|
}
|
|
}
|
|
|
|
field
|
|
});
|
|
}
|
|
|
|
pool.add_file_descriptor_proto(FileDescriptorProto {
|
|
name: Some(name.clone()),
|
|
message_type: vec![message],
|
|
..Default::default()
|
|
})?;
|
|
|
|
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
|
|
}
|
|
|
|
fn create_cached_enum_in_pool(cls: &PyAny, pool: &mut DescriptorPool) -> Result<()> {
|
|
let cls_name = cls.qualified_name()?;
|
|
if pool.get_enum_by_name(&cls_name).is_some() {
|
|
return Ok(());
|
|
}
|
|
|
|
let mut proto = EnumDescriptorProto {
|
|
name: Some(cls_name.clone()),
|
|
..Default::default()
|
|
};
|
|
|
|
for item in cls.iter()? {
|
|
let item = item?;
|
|
proto.value.push(EnumValueDescriptorProto {
|
|
number: Some(item.getattr("value")?.extract()?),
|
|
name: Some(item.getattr("name")?.extract()?),
|
|
..Default::default()
|
|
});
|
|
}
|
|
|
|
pool.add_file_descriptor_proto(FileDescriptorProto {
|
|
name: Some(cls_name),
|
|
enum_type: vec![proto],
|
|
..Default::default()
|
|
})?;
|
|
Ok(())
|
|
}
|
|
|
|
fn map_type(str: &str) -> Result<Type> {
|
|
match str {
|
|
"enum" => Ok(Type::Enum),
|
|
"bool" => Ok(Type::Bool),
|
|
"int32" => Ok(Type::Int32),
|
|
"int64" => Ok(Type::Int64),
|
|
"uint32" => Ok(Type::Uint32),
|
|
"uint64" => Ok(Type::Uint64),
|
|
"sint32" => Ok(Type::Sint32),
|
|
"sint64" => Ok(Type::Sint64),
|
|
"float" => Ok(Type::Float),
|
|
"double" => Ok(Type::Double),
|
|
"fixed32" => Ok(Type::Fixed32),
|
|
"sfixed32" => Ok(Type::Sfixed32),
|
|
"fixed64" => Ok(Type::Fixed64),
|
|
"sfixed64" => Ok(Type::Sfixed64),
|
|
"string" => Ok(Type::String),
|
|
"bytes" => Ok(Type::Bytes),
|
|
"message" => Ok(Type::Message),
|
|
_ => Err(Error::UnsupportedType(str.to_string())),
|
|
}
|
|
}
|