From 29f12ea88d2574516f6f605a007c6905b3195496 Mon Sep 17 00:00:00 2001 From: Erik Friese Date: Mon, 4 Sep 2023 12:38:18 +0200 Subject: [PATCH] map support --- betterproto-extras/src/descriptor_pool.rs | 145 ++++++++++++++-------- betterproto-extras/src/error.rs | 6 +- betterproto-extras/src/merging.rs | 33 ++++- example.py | 9 +- 4 files changed, 134 insertions(+), 59 deletions(-) diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index e88bdbe..9ce590b 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -6,7 +6,7 @@ use prost_reflect::{ prost_types::{ field_descriptor_proto::{Label, Type}, DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto, - FileDescriptorProto, OneofDescriptorProto, + FileDescriptorProto, MessageOptions, OneofDescriptorProto, }, DescriptorPool, MessageDescriptor, }; @@ -36,17 +36,17 @@ pub fn create_cached_descriptor(obj: &PyAny) -> Result { } fn add_message_to_file( - name: String, + message_name: String, obj: &PyAny, pool: &DescriptorPool, file: &mut FileDescriptorProto, ) -> Result<()> { - let mut messages_to_add = vec![(name, obj)]; + let mut messages_to_add = vec![(message_name, obj)]; - while let Some((name, obj)) = messages_to_add.pop() { + while let Some((message_name, obj)) = messages_to_add.pop() { let meta = obj.get_proto_meta()?; let mut message = DescriptorProto { - name: Some(name.to_string()), + name: Some(message_name.to_string()), ..Default::default() }; @@ -62,56 +62,101 @@ fn add_message_to_file( number: Some(field_meta.getattr("number")?.extract::()?), ..Default::default() }; - field.set_type(map_type( - field_meta.getattr("proto_type")?.extract::<&str>()?, - )?); + let proto_type = field_meta.getattr("proto_type")?.extract::<&str>()?; - match field.r#type() { - Type::Message => { - let cls = meta.get_class(field_name)?; - let cls_name = cls.qualified_name()?; - field.type_name = Some(cls_name.clone()); + if proto_type == "map" { + field.set_type(Type::Message); + let (key, val) = field_meta.getattr("map_types")?.extract::<(&str, &str)>()?; + let key = map_type(key)?; + let val = map_type(val)?; - if name != cls_name - && pool.get_message_by_name(&cls_name).is_none() - && !file.message_type.iter().any(|item| item.name() == cls_name) - && !messages_to_add.iter().any(|item| item.0 == cls_name) - { - messages_to_add.push((cls_name, cls.call0()?)); - } + if matches!( + key, + Type::Float | Type::Double | Type::Bytes | Type::Message | Type::Enum + ) { + return Err(Error::UnsupportedMapKeyType(key)); } - Type::Enum => { - let cls = meta.get_class(field_name)?; - let cls_name = cls.qualified_name()?; - field.type_name = Some(cls_name.to_string()); - if pool.get_enum_by_name(&cls_name).is_none() - && !file.enum_type.iter().any(|item| item.name() == cls_name) - { - let mut proto = EnumDescriptorProto { - name: Some(cls_name), - ..Default::default() - }; - - for item in cls.iter()? { - let item = item?; - proto.value.push(EnumValueDescriptorProto { - number: Some(item.getattr("value")?.extract()?), - name: Some(item.getattr("name")?.extract()?), - ..Default::default() - }); - } - - file.enum_type.push(proto); - } - } - _ => {} - } - - if meta.is_list_field(field_name)? { + let map_entry_name = format!("{field_name}Entry"); + field.type_name = Some(format!("{message_name}.{map_entry_name}")); field.set_label(Label::Repeated); - } else if field_meta.getattr("optional")?.extract::()? { - field.proto3_optional = Some(true); + message.nested_type.push(DescriptorProto { + name: Some(map_entry_name), + field: vec![ + { + let mut proto = FieldDescriptorProto { + name: Some("key".to_string()), + number: Some(1), + ..Default::default() + }; + proto.set_type(key); + proto + }, + { + let mut proto = FieldDescriptorProto { + name: Some("value".to_string()), + number: Some(2), + ..Default::default() + }; + proto.set_type(val); + proto + }, + ], + options: Some(MessageOptions { + map_entry: Some(true), + ..Default::default() + }), + ..Default::default() + }) + } else { + field.set_type(map_type(proto_type)?); + match field.r#type() { + Type::Message => { + let cls = meta.get_class(field_name)?; + let cls_name = cls.qualified_name()?; + field.type_name = Some(cls_name.clone()); + + if message_name != cls_name + && pool.get_message_by_name(&cls_name).is_none() + && !file.message_type.iter().any(|item| item.name() == cls_name) + && !messages_to_add.iter().any(|item| item.0 == cls_name) + { + messages_to_add.push((cls_name, cls.call0()?)); + } + } + Type::Enum => { + let cls = meta.get_class(field_name)?; + let cls_name = cls.qualified_name()?; + field.type_name = Some(cls_name.to_string()); + + if pool.get_enum_by_name(&cls_name).is_none() + && !file.enum_type.iter().any(|item| item.name() == cls_name) + { + let mut proto = EnumDescriptorProto { + name: Some(cls_name), + ..Default::default() + }; + + for item in cls.iter()? { + let item = item?; + proto.value.push(EnumValueDescriptorProto { + number: Some(item.getattr("value")?.extract()?), + name: Some(item.getattr("name")?.extract()?), + ..Default::default() + }); + } + + file.enum_type.push(proto); + } + } + _ => {} + } + + if meta.is_list_field(field_name)? { + field.set_label(Label::Repeated); + } else if field_meta.getattr("optional")?.extract::()? { + field.proto3_optional = Some(true); + } } if let Some(grp) = meta.oneof_group(field_name)? { diff --git a/betterproto-extras/src/error.rs b/betterproto-extras/src/error.rs index edb6a3f..0279041 100644 --- a/betterproto-extras/src/error.rs +++ b/betterproto-extras/src/error.rs @@ -1,4 +1,6 @@ -use prost_reflect::{prost::DecodeError, DescriptorError}; +use prost_reflect::{ + prost::DecodeError, prost_types::field_descriptor_proto::Type, DescriptorError, +}; use pyo3::{exceptions::PyRuntimeError, PyErr}; use thiserror::Error; @@ -8,6 +10,8 @@ pub enum Error { NoBetterprotoMessage(#[from] PyErr), #[error("Unsupported type `{0}`.")] UnsupportedType(String), + #[error("Unsupported map key type `{0:?}`.")] + UnsupportedMapKeyType(Type), #[error("Error on proto registration")] FailedToRegisterDescriptor(#[from] DescriptorError), #[error("The given binary data does not match the protobuf schema.")] diff --git a/betterproto-extras/src/merging.rs b/betterproto-extras/src/merging.rs index 5738bb1..d8ce976 100644 --- a/betterproto-extras/src/merging.rs +++ b/betterproto-extras/src/merging.rs @@ -1,9 +1,9 @@ -use crate::{ - error::{Error, Result}, - py_any_extras::PyAnyExtras, +use crate::{error::Result, py_any_extras::PyAnyExtras}; +use prost_reflect::{DynamicMessage, MapKey, Value}; +use pyo3::{ + types::{IntoPyDict, PyBytes}, + PyAny, PyObject, Python, ToPyObject, }; -use prost_reflect::{DynamicMessage, Value}; -use pyo3::{PyAny, PyObject, ToPyObject, types::PyBytes}; pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> { for field in msg.take_fields() { @@ -56,6 +56,27 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) -> let cls = proto_meta.get_class(field_name)?; Ok(cls.call1((x,))?.to_object(py)) } - value => Err(Error::UnsupportedType(value.to_string())), + Value::Map(map) => { + let res: Result> = map + .into_iter() + .map(|(k, v)| { + let key = map_key(k, py); + let val = map_field_value(field_name, v, proto_meta)?; + Ok((key, val)) + }) + .collect(); + Ok(res?.into_py_dict(py).to_object(py)) + } + } +} + +fn map_key(key: MapKey, py: Python) -> PyObject { + match key { + MapKey::Bool(x) => x.to_object(py), + MapKey::I32(x) => x.to_object(py), + MapKey::I64(x) => x.to_object(py), + MapKey::U32(x) => x.to_object(py), + MapKey::U64(x) => x.to_object(py), + MapKey::String(x) => x.to_object(py), } } diff --git a/example.py b/example.py index d19f888..f72c8b8 100644 --- a/example.py +++ b/example.py @@ -3,7 +3,7 @@ import betterproto from dataclasses import dataclass -from typing import List, Optional +from typing import Dict, List, Optional @dataclass(repr=False) class Baz(betterproto.Message): @@ -30,6 +30,7 @@ class Bar(betterproto.Message): foo2: Foo = betterproto.message_field(2) packed: List[int] = betterproto.int64_field(3) enm: Enm = betterproto.enum_field(4) + map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL) # Serialization has not been changed yet. So nothing unusual here buffer = bytes( @@ -37,7 +38,11 @@ buffer = bytes( foo1=Foo(1, 2.34), foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]), packed=[5, 3, 1], - enm=Enm.B + enm=Enm.B, + map={ + 1: True, + 42: False + } ) )