use crate::{ error::{Error, Result}, py_any_extras::PyAnyExtras, }; use prost_reflect::{ prost_types::{ field_descriptor_proto::{Label, Type}, DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto, FileDescriptorProto, MessageOptions, OneofDescriptorProto, }, DescriptorPool, MessageDescriptor, }; use pyo3::PyAny; use std::sync::{Mutex, OnceLock}; pub fn create_cached_descriptor(obj: &PyAny) -> Result { static DESCRIPTOR_POOL: OnceLock> = OnceLock::new(); let mut pool = DESCRIPTOR_POOL .get_or_init(|| Mutex::new(DescriptorPool::global())) .lock() .unwrap(); let cls = obj.getattr("__class__")?; let name = format!("{}_{}", cls.qualified_name()?, cls.py_identifier()); if let Some(desc) = pool.get_message_by_name(&name) { return Ok(desc); } let mut file = FileDescriptorProto { name: Some(name.clone()), ..Default::default() }; add_message_to_file(name.clone(), obj, &pool, &mut file)?; pool.add_file_descriptor_proto(file)?; Ok(pool.get_message_by_name(&name).expect("Just registered...")) } fn add_message_to_file( message_name: String, obj: &PyAny, pool: &DescriptorPool, file: &mut FileDescriptorProto, ) -> Result<()> { let mut messages_to_add = vec![(message_name, obj)]; while let Some((message_name, obj)) = messages_to_add.pop() { let meta = obj.get_proto_meta()?; let mut message = DescriptorProto { name: Some(message_name.to_string()), ..Default::default() }; for item in meta .getattr("meta_by_field_name")? .call_method0("items")? .iter()? { let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?; message.field.push({ let mut field = FieldDescriptorProto { name: Some(field_name.to_string()), number: Some(field_meta.getattr("number")?.extract::()?), ..Default::default() }; let proto_type = field_meta.getattr("proto_type")?.extract::<&str>()?; if proto_type == "map" { field.set_type(Type::Message); let (key, val) = field_meta.getattr("map_types")?.extract::<(&str, &str)>()?; let key = map_type(key)?; let val = map_type(val)?; if matches!( key, Type::Float | Type::Double | Type::Bytes | Type::Message | Type::Enum ) { return Err(Error::UnsupportedMapKeyType(key)); } let map_entry_name = format!("{field_name}Entry"); field.type_name = Some(format!("{message_name}.{map_entry_name}")); field.set_label(Label::Repeated); message.nested_type.push(DescriptorProto { name: Some(map_entry_name), field: vec![ { let mut proto = FieldDescriptorProto { name: Some("key".to_string()), number: Some(1), ..Default::default() }; proto.set_type(key); proto }, { let mut proto = FieldDescriptorProto { name: Some("value".to_string()), number: Some(2), ..Default::default() }; proto.set_type(val); if val == Type::Message { set_type_name( &message_name, meta.get_class(&format!("{field_name}.value"))?, &mut proto, file, &mut messages_to_add, pool, )?; } proto }, ], options: Some(MessageOptions { map_entry: Some(true), ..Default::default() }), ..Default::default() }) } else { field.set_type(map_type(proto_type)?); match field.r#type() { Type::Message => match field_meta .getattr("wraps")? .extract::>()? .map(map_type) .transpose()? { Some(Type::Bool) => { field.type_name = Some("google.protobuf.BoolValue".to_string()); } Some(Type::Double) => { field.type_name = Some("google.protobuf.DoubleValue".to_string()); } Some(Type::Float) => { field.type_name = Some("google.protobuf.FloatValue".to_string()); } Some(Type::Int64) => { field.type_name = Some("google.protobuf.Int64Value".to_string()); } Some(Type::Uint64) => { field.type_name = Some("google.protobuf.UInt64Value".to_string()); } Some(Type::Int32) => { field.type_name = Some("google.protobuf.Int32Value".to_string()); } Some(Type::Uint32) => { field.type_name = Some("google.protobuf.UInt32Value".to_string()); } Some(Type::String) => { field.type_name = Some("google.protobuf.StringValue".to_string()); } Some(Type::Bytes) => { field.type_name = Some("google.protobuf.BytesValue".to_string()); } Some(t) => return Err(Error::UnsupportedWrapperType(t)), None => { set_type_name( &message_name, meta.get_class(field_name)?, &mut field, file, &mut messages_to_add, pool, )?; } }, Type::Enum => { let cls = meta.get_class(field_name)?; let cls_name = format!("{}_{}", cls.qualified_name()?, cls.py_identifier()); field.type_name = Some(cls_name.to_string()); if pool.get_enum_by_name(&cls_name).is_none() && !file.enum_type.iter().any(|item| item.name() == cls_name) { let mut proto = EnumDescriptorProto { name: Some(cls_name.clone()), ..Default::default() }; for item in cls.iter()? { let item = item?; proto.value.push(EnumValueDescriptorProto { number: Some(item.getattr("value")?.extract()?), name: Some(format!( "{}_{}", cls_name, item.getattr("name")?.extract::<&str>()? )), ..Default::default() }); } file.enum_type.push(proto); } } _ => {} } if meta.is_list_field(field_name)? { field.set_label(Label::Repeated); } else if field_meta.getattr("optional")?.extract::()? { field.proto3_optional = Some(true); } } if let Some(grp) = meta.oneof_group(field_name)? { let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp); match oneof_index { Some(i) => field.oneof_index = Some(i as i32), None => { message.oneof_decl.push(OneofDescriptorProto { name: Some(grp), ..Default::default() }); field.oneof_index = Some((message.oneof_decl.len() - 1) as i32) } } } field }); } file.message_type.push(message); } Ok(()) } fn map_type(str: &str) -> Result { match str { "enum" => Ok(Type::Enum), "bool" => Ok(Type::Bool), "int32" => Ok(Type::Int32), "int64" => Ok(Type::Int64), "uint32" => Ok(Type::Uint32), "uint64" => Ok(Type::Uint64), "sint32" => Ok(Type::Sint32), "sint64" => Ok(Type::Sint64), "float" => Ok(Type::Float), "double" => Ok(Type::Double), "fixed32" => Ok(Type::Fixed32), "sfixed32" => Ok(Type::Sfixed32), "fixed64" => Ok(Type::Fixed64), "sfixed64" => Ok(Type::Sfixed64), "string" => Ok(Type::String), "bytes" => Ok(Type::Bytes), "message" => Ok(Type::Message), _ => Err(Error::UnsupportedType(str.to_string())), } } fn set_type_name<'py>( message_name: &str, field_cls: &'py PyAny, field: &mut FieldDescriptorProto, file: &FileDescriptorProto, messages_to_add: &mut Vec<(String, &'py PyAny)>, pool: &DescriptorPool, ) -> Result<()> { let cls_name = field_cls.qualified_name()?; match cls_name.as_str() { "datetime.datetime" => { field.type_name = Some("google.protobuf.Timestamp".to_string()); } "datetime.timedelta" => { field.type_name = Some("google.protobuf.Duration".to_string()); } _ => { let cls_name = format!("{}_{}", cls_name, field_cls.py_identifier()); field.type_name = Some(cls_name.clone()); if message_name != cls_name && pool.get_message_by_name(&cls_name).is_none() && !file.message_type.iter().any(|item| item.name() == cls_name) && !messages_to_add.iter().any(|item| item.0 == cls_name) { messages_to_add.push((cls_name, field_cls.call0()?)); } } } Ok(()) }