use crate::{ error::{Error, Result}, py_any_extras::PyAnyExtras, }; use prost_reflect::{ prost_types::{ field_descriptor_proto::{Label, Type}, DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto, }, DescriptorPool, MessageDescriptor, }; use pyo3::PyAny; use std::sync::{Mutex, OnceLock}; pub fn create_cached_descriptor(obj: &PyAny) -> Result { static DESCRIPTOR_POOL: OnceLock> = OnceLock::new(); let mut pool = DESCRIPTOR_POOL .get_or_init(|| Mutex::new(DescriptorPool::new())) .lock() .unwrap(); create_cached_message_in_pool(obj, &mut pool) } fn create_cached_message_in_pool( obj: &PyAny, pool: &mut DescriptorPool, ) -> Result { let name = obj.qualified_class_name()?; if let Some(desc) = pool.get_message_by_name(&name) { return Ok(desc); } let meta = obj.get_proto_meta()?; let mut message = DescriptorProto { name: Some(name.clone()), ..Default::default() }; for item in meta .getattr("meta_by_field_name")? .call_method0("items")? .iter()? { let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?; message.field.push({ let mut field = FieldDescriptorProto { name: Some(field_name.to_string()), number: Some(field_meta.getattr("number")?.extract::()?), ..Default::default() }; field.set_type(map_type( field_meta.getattr("proto_type")?.extract::<&str>()?, )?); match field.r#type() { Type::Message => { let instance = meta.create_instance(field_name)?; let cls_name = instance.qualified_class_name()?; field.type_name = Some(cls_name.to_string()); create_cached_message_in_pool(instance, pool)?; } Type::Enum => { let cls = meta.get_class(field_name)?; let cls_name = cls.qualified_name()?; field.type_name = Some(cls_name.to_string()); create_cached_enum_in_pool(cls, pool)?; } _ => {} } if meta.is_list_field(field_name)? { field.set_label(Label::Repeated); } else if field_meta.getattr("optional")?.extract::()? { field.proto3_optional = Some(true); } if let Some(grp) = meta.oneof_group(field_name)? { let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp); match oneof_index { Some(i) => field.oneof_index = Some(i as i32), None => { message.oneof_decl.push(OneofDescriptorProto { name: Some(grp), ..Default::default() }); field.oneof_index = Some((message.oneof_decl.len() - 1) as i32) } } } field }); } pool.add_file_descriptor_proto(FileDescriptorProto { name: Some(name.clone()), message_type: vec![message], ..Default::default() })?; Ok(pool.get_message_by_name(&name).expect("Just registered...")) } fn create_cached_enum_in_pool(cls: &PyAny, pool: &mut DescriptorPool) -> Result<()> { let cls_name = cls.qualified_name()?; if pool.get_enum_by_name(&cls_name).is_some() { return Ok(()); } let mut proto = EnumDescriptorProto { name: Some(cls_name.clone()), ..Default::default() }; for item in cls.iter()? { let item = item?; proto.value.push(EnumValueDescriptorProto { number: Some(item.getattr("value")?.extract()?), name: Some(item.getattr("name")?.extract()?), ..Default::default() }); } pool.add_file_descriptor_proto(FileDescriptorProto { name: Some(cls_name), enum_type: vec![proto], ..Default::default() })?; Ok(()) } fn map_type(str: &str) -> Result { match str { "enum" => Ok(Type::Enum), "bool" => Ok(Type::Bool), "int32" => Ok(Type::Int32), "int64" => Ok(Type::Int64), "uint32" => Ok(Type::Uint32), "uint64" => Ok(Type::Uint64), "sint32" => Ok(Type::Sint32), "sint64" => Ok(Type::Sint64), "float" => Ok(Type::Float), "double" => Ok(Type::Double), "fixed32" => Ok(Type::Fixed32), "sfixed32" => Ok(Type::Sfixed32), "fixed64" => Ok(Type::Fixed64), "sfixed64" => Ok(Type::Sfixed64), "string" => Ok(Type::String), "bytes" => Ok(Type::Bytes), "message" => Ok(Type::Message), _ => Err(Error::UnsupportedType(str.to_string())), } }