google wrapper types

This commit is contained in:
Erik Friese 2023-09-04 17:18:50 +02:00
parent 29f12ea88d
commit 950d2f6536
4 changed files with 97 additions and 20 deletions

View File

@ -16,7 +16,7 @@ use std::sync::{Mutex, OnceLock};
pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> { pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new(); static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
let mut pool = DESCRIPTOR_POOL let mut pool = DESCRIPTOR_POOL
.get_or_init(|| Mutex::new(DescriptorPool::new())) .get_or_init(|| Mutex::new(DescriptorPool::global()))
.lock() .lock()
.unwrap(); .unwrap();
@ -111,7 +111,41 @@ fn add_message_to_file(
} else { } else {
field.set_type(map_type(proto_type)?); field.set_type(map_type(proto_type)?);
match field.r#type() { match field.r#type() {
Type::Message => { Type::Message => match field_meta
.getattr("wraps")?
.extract::<Option<&str>>()?
.map(map_type)
.transpose()?
{
Some(Type::Bool) => {
field.type_name = Some(".google.protobuf.BoolValue".to_string());
}
Some(Type::Double) => {
field.type_name = Some(".google.protobuf.DoubleValue".to_string());
}
Some(Type::Float) => {
field.type_name = Some(".google.protobuf.FloatValue".to_string());
}
Some(Type::Int64) => {
field.type_name = Some(".google.protobuf.Int64Value".to_string());
}
Some(Type::Uint64) => {
field.type_name = Some(".google.protobuf.UInt64Value".to_string());
}
Some(Type::Int32) => {
field.type_name = Some(".google.protobuf.Int32Value".to_string());
}
Some(Type::Uint32) => {
field.type_name = Some(".google.protobuf.UInt32Value".to_string());
}
Some(Type::String) => {
field.type_name = Some(".google.protobuf.StringValue".to_string());
}
Some(Type::Bytes) => {
field.type_name = Some(".google.protobuf.BytesValue".to_string());
}
Some(t) => return Err(Error::UnsupportedWrapperType(t)),
None => {
let cls = meta.get_class(field_name)?; let cls = meta.get_class(field_name)?;
let cls_name = cls.qualified_name()?; let cls_name = cls.qualified_name()?;
field.type_name = Some(cls_name.clone()); field.type_name = Some(cls_name.clone());
@ -124,6 +158,7 @@ fn add_message_to_file(
messages_to_add.push((cls_name, cls.call0()?)); messages_to_add.push((cls_name, cls.call0()?));
} }
} }
},
Type::Enum => { Type::Enum => {
let cls = meta.get_class(field_name)?; let cls = meta.get_class(field_name)?;
let cls_name = cls.qualified_name()?; let cls_name = cls.qualified_name()?;

View File

@ -12,6 +12,8 @@ pub enum Error {
UnsupportedType(String), UnsupportedType(String),
#[error("Unsupported map key type `{0:?}`.")] #[error("Unsupported map key type `{0:?}`.")]
UnsupportedMapKeyType(Type), UnsupportedMapKeyType(Type),
#[error("Unsupported wrapper type `{0:?}`.")]
UnsupportedWrapperType(Type),
#[error("Error on proto registration")] #[error("Error on proto registration")]
FailedToRegisterDescriptor(#[from] DescriptorError), FailedToRegisterDescriptor(#[from] DescriptorError),
#[error("The given binary data does not match the protobuf schema.")] #[error("The given binary data does not match the protobuf schema.")]

View File

@ -1,5 +1,5 @@
use crate::{error::Result, py_any_extras::PyAnyExtras}; use crate::{error::Result, py_any_extras::PyAnyExtras};
use prost_reflect::{DynamicMessage, MapKey, Value}; use prost_reflect::{DynamicMessage, MapKey, ReflectMessage, Value};
use pyo3::{ use pyo3::{
types::{IntoPyDict, PyBytes}, types::{IntoPyDict, PyBytes},
PyAny, PyObject, Python, ToPyObject, PyAny, PyObject, Python, ToPyObject,
@ -42,11 +42,49 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) ->
Value::String(x) => Ok(x.to_object(py)), Value::String(x) => Ok(x.to_object(py)),
Value::U32(x) => Ok(x.to_object(py)), Value::U32(x) => Ok(x.to_object(py)),
Value::U64(x) => Ok(x.to_object(py)), Value::U64(x) => Ok(x.to_object(py)),
Value::Message(msg) => { Value::Message(msg) => match msg.descriptor().full_name() {
"google.protobuf.BoolValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_bool())
.to_object(py)),
"google.protobuf.DoubleValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_f64())
.to_object(py)),
"google.protobuf.FloatValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_f32())
.to_object(py)),
"google.protobuf.Int64Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_i64())
.to_object(py)),
"google.protobuf.UInt64Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_u64())
.to_object(py)),
"google.protobuf.Int32Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_i32())
.to_object(py)),
"google.protobuf.UInt32Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_u32())
.to_object(py)),
"google.protobuf.StringValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_str().map(|s| s.to_string()))
.to_object(py)),
"google.protobuf.BytesValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_bytes().map(|b| PyBytes::new(py, b)))
.to_object(py)),
_ => {
let obj = proto_meta.create_instance(field_name)?; let obj = proto_meta.create_instance(field_name)?;
merge_msg_into_pyobj(obj, msg)?; merge_msg_into_pyobj(obj, msg)?;
Ok(obj.to_object(py)) Ok(obj.to_object(py))
} }
},
Value::List(ls) => Ok(ls Value::List(ls) => Ok(ls
.into_iter() .into_iter()
.map(|x| map_field_value(field_name, x, proto_meta)) .map(|x| map_field_value(field_name, x, proto_meta))

View File

@ -31,6 +31,7 @@ class Bar(betterproto.Message):
packed: List[int] = betterproto.int64_field(3) packed: List[int] = betterproto.int64_field(3)
enm: Enm = betterproto.enum_field(4) enm: Enm = betterproto.enum_field(4)
map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL) map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL)
maybe: Optional[bool] = betterproto.message_field(6, wraps=betterproto.TYPE_BOOL)
# Serialization has not been changed yet. So nothing unusual here # Serialization has not been changed yet. So nothing unusual here
buffer = bytes( buffer = bytes(
@ -42,7 +43,8 @@ buffer = bytes(
map={ map={
1: True, 1: True,
42: False 42: False
} },
maybe=True
) )
) )