290 lines
12 KiB
Rust
290 lines
12 KiB
Rust
use crate::{
|
|
error::{Error, Result},
|
|
py_any_extras::PyAnyExtras,
|
|
};
|
|
use prost_reflect::{
|
|
prost_types::{
|
|
field_descriptor_proto::{Label, Type},
|
|
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
|
|
FileDescriptorProto, MessageOptions, OneofDescriptorProto,
|
|
},
|
|
DescriptorPool, MessageDescriptor,
|
|
};
|
|
use pyo3::PyAny;
|
|
use std::sync::{Mutex, OnceLock};
|
|
|
|
pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
|
|
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
|
|
let mut pool = DESCRIPTOR_POOL
|
|
.get_or_init(|| Mutex::new(DescriptorPool::global()))
|
|
.lock()
|
|
.unwrap();
|
|
|
|
let cls = obj.getattr("__class__")?;
|
|
let name = format!("{}_{}", cls.qualified_name()?, cls.py_identifier());
|
|
if let Some(desc) = pool.get_message_by_name(&name) {
|
|
return Ok(desc);
|
|
}
|
|
|
|
let mut file = FileDescriptorProto {
|
|
name: Some(name.clone()),
|
|
..Default::default()
|
|
};
|
|
|
|
add_message_to_file(name.clone(), obj, &pool, &mut file)?;
|
|
pool.add_file_descriptor_proto(file)?;
|
|
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
|
|
}
|
|
|
|
fn add_message_to_file(
|
|
message_name: String,
|
|
obj: &PyAny,
|
|
pool: &DescriptorPool,
|
|
file: &mut FileDescriptorProto,
|
|
) -> Result<()> {
|
|
let mut messages_to_add = vec![(message_name, obj)];
|
|
|
|
while let Some((message_name, obj)) = messages_to_add.pop() {
|
|
let meta = obj.get_proto_meta()?;
|
|
let mut message = DescriptorProto {
|
|
name: Some(message_name.to_string()),
|
|
..Default::default()
|
|
};
|
|
|
|
for item in meta
|
|
.getattr("meta_by_field_name")?
|
|
.call_method0("items")?
|
|
.iter()?
|
|
{
|
|
let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?;
|
|
message.field.push({
|
|
let mut field = FieldDescriptorProto {
|
|
name: Some(field_name.to_string()),
|
|
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
|
|
..Default::default()
|
|
};
|
|
let proto_type = field_meta.getattr("proto_type")?.extract::<&str>()?;
|
|
|
|
if proto_type == "map" {
|
|
field.set_type(Type::Message);
|
|
let (key, val) = field_meta.getattr("map_types")?.extract::<(&str, &str)>()?;
|
|
let key = map_type(key)?;
|
|
let val = map_type(val)?;
|
|
|
|
if matches!(
|
|
key,
|
|
Type::Float | Type::Double | Type::Bytes | Type::Message | Type::Enum
|
|
) {
|
|
return Err(Error::UnsupportedMapKeyType(key));
|
|
}
|
|
|
|
let map_entry_name = format!("{field_name}Entry");
|
|
field.type_name = Some(format!("{message_name}.{map_entry_name}"));
|
|
field.set_label(Label::Repeated);
|
|
message.nested_type.push(DescriptorProto {
|
|
name: Some(map_entry_name),
|
|
field: vec![
|
|
{
|
|
let mut proto = FieldDescriptorProto {
|
|
name: Some("key".to_string()),
|
|
number: Some(1),
|
|
..Default::default()
|
|
};
|
|
proto.set_type(key);
|
|
proto
|
|
},
|
|
{
|
|
let mut proto = FieldDescriptorProto {
|
|
name: Some("value".to_string()),
|
|
number: Some(2),
|
|
..Default::default()
|
|
};
|
|
proto.set_type(val);
|
|
if val == Type::Message {
|
|
set_type_name(
|
|
&message_name,
|
|
meta.get_class(&format!("{field_name}.value"))?,
|
|
&mut proto,
|
|
file,
|
|
&mut messages_to_add,
|
|
pool,
|
|
)?;
|
|
}
|
|
proto
|
|
},
|
|
],
|
|
options: Some(MessageOptions {
|
|
map_entry: Some(true),
|
|
..Default::default()
|
|
}),
|
|
..Default::default()
|
|
})
|
|
} else {
|
|
field.set_type(map_type(proto_type)?);
|
|
match field.r#type() {
|
|
Type::Message => match field_meta
|
|
.getattr("wraps")?
|
|
.extract::<Option<&str>>()?
|
|
.map(map_type)
|
|
.transpose()?
|
|
{
|
|
Some(Type::Bool) => {
|
|
field.type_name = Some("google.protobuf.BoolValue".to_string());
|
|
}
|
|
Some(Type::Double) => {
|
|
field.type_name = Some("google.protobuf.DoubleValue".to_string());
|
|
}
|
|
Some(Type::Float) => {
|
|
field.type_name = Some("google.protobuf.FloatValue".to_string());
|
|
}
|
|
Some(Type::Int64) => {
|
|
field.type_name = Some("google.protobuf.Int64Value".to_string());
|
|
}
|
|
Some(Type::Uint64) => {
|
|
field.type_name = Some("google.protobuf.UInt64Value".to_string());
|
|
}
|
|
Some(Type::Int32) => {
|
|
field.type_name = Some("google.protobuf.Int32Value".to_string());
|
|
}
|
|
Some(Type::Uint32) => {
|
|
field.type_name = Some("google.protobuf.UInt32Value".to_string());
|
|
}
|
|
Some(Type::String) => {
|
|
field.type_name = Some("google.protobuf.StringValue".to_string());
|
|
}
|
|
Some(Type::Bytes) => {
|
|
field.type_name = Some("google.protobuf.BytesValue".to_string());
|
|
}
|
|
Some(t) => return Err(Error::UnsupportedWrapperType(t)),
|
|
None => {
|
|
set_type_name(
|
|
&message_name,
|
|
meta.get_class(field_name)?,
|
|
&mut field,
|
|
file,
|
|
&mut messages_to_add,
|
|
pool,
|
|
)?;
|
|
}
|
|
},
|
|
Type::Enum => {
|
|
let cls = meta.get_class(field_name)?;
|
|
let cls_name =
|
|
format!("{}_{}", cls.qualified_name()?, cls.py_identifier());
|
|
field.type_name = Some(cls_name.to_string());
|
|
|
|
if pool.get_enum_by_name(&cls_name).is_none()
|
|
&& !file.enum_type.iter().any(|item| item.name() == cls_name)
|
|
{
|
|
let mut proto = EnumDescriptorProto {
|
|
name: Some(cls_name.clone()),
|
|
..Default::default()
|
|
};
|
|
|
|
for item in cls.iter()? {
|
|
let item = item?;
|
|
proto.value.push(EnumValueDescriptorProto {
|
|
number: Some(item.getattr("value")?.extract()?),
|
|
name: Some(format!(
|
|
"{}_{}",
|
|
cls_name,
|
|
item.getattr("name")?.extract::<&str>()?
|
|
)),
|
|
..Default::default()
|
|
});
|
|
}
|
|
|
|
file.enum_type.push(proto);
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
if meta.is_list_field(field_name)? {
|
|
field.set_label(Label::Repeated);
|
|
} else if field_meta.getattr("optional")?.extract::<bool>()? {
|
|
field.proto3_optional = Some(true);
|
|
}
|
|
}
|
|
|
|
if let Some(grp) = meta.oneof_group(field_name)? {
|
|
let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp);
|
|
|
|
match oneof_index {
|
|
Some(i) => field.oneof_index = Some(i as i32),
|
|
None => {
|
|
message.oneof_decl.push(OneofDescriptorProto {
|
|
name: Some(grp),
|
|
..Default::default()
|
|
});
|
|
field.oneof_index = Some((message.oneof_decl.len() - 1) as i32)
|
|
}
|
|
}
|
|
}
|
|
|
|
field
|
|
});
|
|
}
|
|
|
|
file.message_type.push(message);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn map_type(str: &str) -> Result<Type> {
|
|
match str {
|
|
"enum" => Ok(Type::Enum),
|
|
"bool" => Ok(Type::Bool),
|
|
"int32" => Ok(Type::Int32),
|
|
"int64" => Ok(Type::Int64),
|
|
"uint32" => Ok(Type::Uint32),
|
|
"uint64" => Ok(Type::Uint64),
|
|
"sint32" => Ok(Type::Sint32),
|
|
"sint64" => Ok(Type::Sint64),
|
|
"float" => Ok(Type::Float),
|
|
"double" => Ok(Type::Double),
|
|
"fixed32" => Ok(Type::Fixed32),
|
|
"sfixed32" => Ok(Type::Sfixed32),
|
|
"fixed64" => Ok(Type::Fixed64),
|
|
"sfixed64" => Ok(Type::Sfixed64),
|
|
"string" => Ok(Type::String),
|
|
"bytes" => Ok(Type::Bytes),
|
|
"message" => Ok(Type::Message),
|
|
_ => Err(Error::UnsupportedType(str.to_string())),
|
|
}
|
|
}
|
|
|
|
fn set_type_name<'py>(
|
|
message_name: &str,
|
|
field_cls: &'py PyAny,
|
|
field: &mut FieldDescriptorProto,
|
|
file: &FileDescriptorProto,
|
|
messages_to_add: &mut Vec<(String, &'py PyAny)>,
|
|
pool: &DescriptorPool,
|
|
) -> Result<()> {
|
|
let cls_name = field_cls.qualified_name()?;
|
|
|
|
match cls_name.as_str() {
|
|
"datetime.datetime" => {
|
|
field.type_name = Some("google.protobuf.Timestamp".to_string());
|
|
}
|
|
"datetime.timedelta" => {
|
|
field.type_name = Some("google.protobuf.Duration".to_string());
|
|
}
|
|
_ => {
|
|
let cls_name = format!("{}_{}", cls_name, field_cls.py_identifier());
|
|
field.type_name = Some(cls_name.clone());
|
|
|
|
if message_name != cls_name
|
|
&& pool.get_message_by_name(&cls_name).is_none()
|
|
&& !file.message_type.iter().any(|item| item.name() == cls_name)
|
|
&& !messages_to_add.iter().any(|item| item.0 == cls_name)
|
|
{
|
|
messages_to_add.push((cls_name, field_cls.call0()?));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|