diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index 9ce590b..90fc203 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -16,7 +16,7 @@ use std::sync::{Mutex, OnceLock}; pub fn create_cached_descriptor(obj: &PyAny) -> Result { static DESCRIPTOR_POOL: OnceLock> = OnceLock::new(); let mut pool = DESCRIPTOR_POOL - .get_or_init(|| Mutex::new(DescriptorPool::new())) + .get_or_init(|| Mutex::new(DescriptorPool::global())) .lock() .unwrap(); @@ -111,19 +111,54 @@ fn add_message_to_file( } else { field.set_type(map_type(proto_type)?); match field.r#type() { - Type::Message => { - let cls = meta.get_class(field_name)?; - let cls_name = cls.qualified_name()?; - field.type_name = Some(cls_name.clone()); - - if message_name != cls_name - && pool.get_message_by_name(&cls_name).is_none() - && !file.message_type.iter().any(|item| item.name() == cls_name) - && !messages_to_add.iter().any(|item| item.0 == cls_name) - { - messages_to_add.push((cls_name, cls.call0()?)); + Type::Message => match field_meta + .getattr("wraps")? + .extract::>()? + .map(map_type) + .transpose()? + { + Some(Type::Bool) => { + field.type_name = Some(".google.protobuf.BoolValue".to_string()); } - } + Some(Type::Double) => { + field.type_name = Some(".google.protobuf.DoubleValue".to_string()); + } + Some(Type::Float) => { + field.type_name = Some(".google.protobuf.FloatValue".to_string()); + } + Some(Type::Int64) => { + field.type_name = Some(".google.protobuf.Int64Value".to_string()); + } + Some(Type::Uint64) => { + field.type_name = Some(".google.protobuf.UInt64Value".to_string()); + } + Some(Type::Int32) => { + field.type_name = Some(".google.protobuf.Int32Value".to_string()); + } + Some(Type::Uint32) => { + field.type_name = Some(".google.protobuf.UInt32Value".to_string()); + } + Some(Type::String) => { + field.type_name = Some(".google.protobuf.StringValue".to_string()); + } + Some(Type::Bytes) => { + field.type_name = Some(".google.protobuf.BytesValue".to_string()); + } + Some(t) => return Err(Error::UnsupportedWrapperType(t)), + None => { + let cls = meta.get_class(field_name)?; + let cls_name = cls.qualified_name()?; + field.type_name = Some(cls_name.clone()); + + if message_name != cls_name + && pool.get_message_by_name(&cls_name).is_none() + && !file.message_type.iter().any(|item| item.name() == cls_name) + && !messages_to_add.iter().any(|item| item.0 == cls_name) + { + messages_to_add.push((cls_name, cls.call0()?)); + } + } + }, Type::Enum => { let cls = meta.get_class(field_name)?; let cls_name = cls.qualified_name()?; diff --git a/betterproto-extras/src/error.rs b/betterproto-extras/src/error.rs index 0279041..6cc0885 100644 --- a/betterproto-extras/src/error.rs +++ b/betterproto-extras/src/error.rs @@ -12,6 +12,8 @@ pub enum Error { UnsupportedType(String), #[error("Unsupported map key type `{0:?}`.")] UnsupportedMapKeyType(Type), + #[error("Unsupported wrapper type `{0:?}`.")] + UnsupportedWrapperType(Type), #[error("Error on proto registration")] FailedToRegisterDescriptor(#[from] DescriptorError), #[error("The given binary data does not match the protobuf schema.")] diff --git a/betterproto-extras/src/merging.rs b/betterproto-extras/src/merging.rs index d8ce976..a36caa0 100644 --- a/betterproto-extras/src/merging.rs +++ b/betterproto-extras/src/merging.rs @@ -1,5 +1,5 @@ use crate::{error::Result, py_any_extras::PyAnyExtras}; -use prost_reflect::{DynamicMessage, MapKey, Value}; +use prost_reflect::{DynamicMessage, MapKey, ReflectMessage, Value}; use pyo3::{ types::{IntoPyDict, PyBytes}, PyAny, PyObject, Python, ToPyObject, @@ -42,11 +42,49 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) -> Value::String(x) => Ok(x.to_object(py)), Value::U32(x) => Ok(x.to_object(py)), Value::U64(x) => Ok(x.to_object(py)), - Value::Message(msg) => { - let obj = proto_meta.create_instance(field_name)?; - merge_msg_into_pyobj(obj, msg)?; - Ok(obj.to_object(py)) - } + Value::Message(msg) => match msg.descriptor().full_name() { + "google.protobuf.BoolValue" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_bool()) + .to_object(py)), + "google.protobuf.DoubleValue" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_f64()) + .to_object(py)), + "google.protobuf.FloatValue" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_f32()) + .to_object(py)), + "google.protobuf.Int64Value" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_i64()) + .to_object(py)), + "google.protobuf.UInt64Value" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_u64()) + .to_object(py)), + "google.protobuf.Int32Value" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_i32()) + .to_object(py)), + "google.protobuf.UInt32Value" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_u32()) + .to_object(py)), + "google.protobuf.StringValue" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_str().map(|s| s.to_string())) + .to_object(py)), + "google.protobuf.BytesValue" => Ok(msg + .get_field_by_number(1) + .and_then(|val| val.as_bytes().map(|b| PyBytes::new(py, b))) + .to_object(py)), + _ => { + let obj = proto_meta.create_instance(field_name)?; + merge_msg_into_pyobj(obj, msg)?; + Ok(obj.to_object(py)) + } + }, Value::List(ls) => Ok(ls .into_iter() .map(|x| map_field_value(field_name, x, proto_meta)) diff --git a/example.py b/example.py index f72c8b8..75a97fd 100644 --- a/example.py +++ b/example.py @@ -31,6 +31,7 @@ class Bar(betterproto.Message): packed: List[int] = betterproto.int64_field(3) enm: Enm = betterproto.enum_field(4) map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL) + maybe: Optional[bool] = betterproto.message_field(6, wraps=betterproto.TYPE_BOOL) # Serialization has not been changed yet. So nothing unusual here buffer = bytes( @@ -42,7 +43,8 @@ buffer = bytes( map={ 1: True, 42: False - } + }, + maybe=True ) )