map support
This commit is contained in:
parent
219233b50e
commit
29f12ea88d
@ -6,7 +6,7 @@ use prost_reflect::{
|
|||||||
prost_types::{
|
prost_types::{
|
||||||
field_descriptor_proto::{Label, Type},
|
field_descriptor_proto::{Label, Type},
|
||||||
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
|
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
|
||||||
FileDescriptorProto, OneofDescriptorProto,
|
FileDescriptorProto, MessageOptions, OneofDescriptorProto,
|
||||||
},
|
},
|
||||||
DescriptorPool, MessageDescriptor,
|
DescriptorPool, MessageDescriptor,
|
||||||
};
|
};
|
||||||
@ -36,17 +36,17 @@ pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn add_message_to_file(
|
fn add_message_to_file(
|
||||||
name: String,
|
message_name: String,
|
||||||
obj: &PyAny,
|
obj: &PyAny,
|
||||||
pool: &DescriptorPool,
|
pool: &DescriptorPool,
|
||||||
file: &mut FileDescriptorProto,
|
file: &mut FileDescriptorProto,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut messages_to_add = vec![(name, obj)];
|
let mut messages_to_add = vec![(message_name, obj)];
|
||||||
|
|
||||||
while let Some((name, obj)) = messages_to_add.pop() {
|
while let Some((message_name, obj)) = messages_to_add.pop() {
|
||||||
let meta = obj.get_proto_meta()?;
|
let meta = obj.get_proto_meta()?;
|
||||||
let mut message = DescriptorProto {
|
let mut message = DescriptorProto {
|
||||||
name: Some(name.to_string()),
|
name: Some(message_name.to_string()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -62,56 +62,101 @@ fn add_message_to_file(
|
|||||||
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
|
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
field.set_type(map_type(
|
let proto_type = field_meta.getattr("proto_type")?.extract::<&str>()?;
|
||||||
field_meta.getattr("proto_type")?.extract::<&str>()?,
|
|
||||||
)?);
|
|
||||||
|
|
||||||
match field.r#type() {
|
if proto_type == "map" {
|
||||||
Type::Message => {
|
field.set_type(Type::Message);
|
||||||
let cls = meta.get_class(field_name)?;
|
let (key, val) = field_meta.getattr("map_types")?.extract::<(&str, &str)>()?;
|
||||||
let cls_name = cls.qualified_name()?;
|
let key = map_type(key)?;
|
||||||
field.type_name = Some(cls_name.clone());
|
let val = map_type(val)?;
|
||||||
|
|
||||||
if name != cls_name
|
if matches!(
|
||||||
&& pool.get_message_by_name(&cls_name).is_none()
|
key,
|
||||||
&& !file.message_type.iter().any(|item| item.name() == cls_name)
|
Type::Float | Type::Double | Type::Bytes | Type::Message | Type::Enum
|
||||||
&& !messages_to_add.iter().any(|item| item.0 == cls_name)
|
) {
|
||||||
{
|
return Err(Error::UnsupportedMapKeyType(key));
|
||||||
messages_to_add.push((cls_name, cls.call0()?));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Type::Enum => {
|
|
||||||
let cls = meta.get_class(field_name)?;
|
|
||||||
let cls_name = cls.qualified_name()?;
|
|
||||||
field.type_name = Some(cls_name.to_string());
|
|
||||||
|
|
||||||
if pool.get_enum_by_name(&cls_name).is_none()
|
let map_entry_name = format!("{field_name}Entry");
|
||||||
&& !file.enum_type.iter().any(|item| item.name() == cls_name)
|
field.type_name = Some(format!("{message_name}.{map_entry_name}"));
|
||||||
{
|
|
||||||
let mut proto = EnumDescriptorProto {
|
|
||||||
name: Some(cls_name),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
for item in cls.iter()? {
|
|
||||||
let item = item?;
|
|
||||||
proto.value.push(EnumValueDescriptorProto {
|
|
||||||
number: Some(item.getattr("value")?.extract()?),
|
|
||||||
name: Some(item.getattr("name")?.extract()?),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
file.enum_type.push(proto);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
if meta.is_list_field(field_name)? {
|
|
||||||
field.set_label(Label::Repeated);
|
field.set_label(Label::Repeated);
|
||||||
} else if field_meta.getattr("optional")?.extract::<bool>()? {
|
message.nested_type.push(DescriptorProto {
|
||||||
field.proto3_optional = Some(true);
|
name: Some(map_entry_name),
|
||||||
|
field: vec![
|
||||||
|
{
|
||||||
|
let mut proto = FieldDescriptorProto {
|
||||||
|
name: Some("key".to_string()),
|
||||||
|
number: Some(1),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
proto.set_type(key);
|
||||||
|
proto
|
||||||
|
},
|
||||||
|
{
|
||||||
|
let mut proto = FieldDescriptorProto {
|
||||||
|
name: Some("value".to_string()),
|
||||||
|
number: Some(2),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
proto.set_type(val);
|
||||||
|
proto
|
||||||
|
},
|
||||||
|
],
|
||||||
|
options: Some(MessageOptions {
|
||||||
|
map_entry: Some(true),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
field.set_type(map_type(proto_type)?);
|
||||||
|
match field.r#type() {
|
||||||
|
Type::Message => {
|
||||||
|
let cls = meta.get_class(field_name)?;
|
||||||
|
let cls_name = cls.qualified_name()?;
|
||||||
|
field.type_name = Some(cls_name.clone());
|
||||||
|
|
||||||
|
if message_name != cls_name
|
||||||
|
&& pool.get_message_by_name(&cls_name).is_none()
|
||||||
|
&& !file.message_type.iter().any(|item| item.name() == cls_name)
|
||||||
|
&& !messages_to_add.iter().any(|item| item.0 == cls_name)
|
||||||
|
{
|
||||||
|
messages_to_add.push((cls_name, cls.call0()?));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Type::Enum => {
|
||||||
|
let cls = meta.get_class(field_name)?;
|
||||||
|
let cls_name = cls.qualified_name()?;
|
||||||
|
field.type_name = Some(cls_name.to_string());
|
||||||
|
|
||||||
|
if pool.get_enum_by_name(&cls_name).is_none()
|
||||||
|
&& !file.enum_type.iter().any(|item| item.name() == cls_name)
|
||||||
|
{
|
||||||
|
let mut proto = EnumDescriptorProto {
|
||||||
|
name: Some(cls_name),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
for item in cls.iter()? {
|
||||||
|
let item = item?;
|
||||||
|
proto.value.push(EnumValueDescriptorProto {
|
||||||
|
number: Some(item.getattr("value")?.extract()?),
|
||||||
|
name: Some(item.getattr("name")?.extract()?),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
file.enum_type.push(proto);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if meta.is_list_field(field_name)? {
|
||||||
|
field.set_label(Label::Repeated);
|
||||||
|
} else if field_meta.getattr("optional")?.extract::<bool>()? {
|
||||||
|
field.proto3_optional = Some(true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(grp) = meta.oneof_group(field_name)? {
|
if let Some(grp) = meta.oneof_group(field_name)? {
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
use prost_reflect::{prost::DecodeError, DescriptorError};
|
use prost_reflect::{
|
||||||
|
prost::DecodeError, prost_types::field_descriptor_proto::Type, DescriptorError,
|
||||||
|
};
|
||||||
use pyo3::{exceptions::PyRuntimeError, PyErr};
|
use pyo3::{exceptions::PyRuntimeError, PyErr};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
@ -8,6 +10,8 @@ pub enum Error {
|
|||||||
NoBetterprotoMessage(#[from] PyErr),
|
NoBetterprotoMessage(#[from] PyErr),
|
||||||
#[error("Unsupported type `{0}`.")]
|
#[error("Unsupported type `{0}`.")]
|
||||||
UnsupportedType(String),
|
UnsupportedType(String),
|
||||||
|
#[error("Unsupported map key type `{0:?}`.")]
|
||||||
|
UnsupportedMapKeyType(Type),
|
||||||
#[error("Error on proto registration")]
|
#[error("Error on proto registration")]
|
||||||
FailedToRegisterDescriptor(#[from] DescriptorError),
|
FailedToRegisterDescriptor(#[from] DescriptorError),
|
||||||
#[error("The given binary data does not match the protobuf schema.")]
|
#[error("The given binary data does not match the protobuf schema.")]
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
use crate::{
|
use crate::{error::Result, py_any_extras::PyAnyExtras};
|
||||||
error::{Error, Result},
|
use prost_reflect::{DynamicMessage, MapKey, Value};
|
||||||
py_any_extras::PyAnyExtras,
|
use pyo3::{
|
||||||
|
types::{IntoPyDict, PyBytes},
|
||||||
|
PyAny, PyObject, Python, ToPyObject,
|
||||||
};
|
};
|
||||||
use prost_reflect::{DynamicMessage, Value};
|
|
||||||
use pyo3::{PyAny, PyObject, ToPyObject, types::PyBytes};
|
|
||||||
|
|
||||||
pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> {
|
pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> {
|
||||||
for field in msg.take_fields() {
|
for field in msg.take_fields() {
|
||||||
@ -56,6 +56,27 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) ->
|
|||||||
let cls = proto_meta.get_class(field_name)?;
|
let cls = proto_meta.get_class(field_name)?;
|
||||||
Ok(cls.call1((x,))?.to_object(py))
|
Ok(cls.call1((x,))?.to_object(py))
|
||||||
}
|
}
|
||||||
value => Err(Error::UnsupportedType(value.to_string())),
|
Value::Map(map) => {
|
||||||
|
let res: Result<Vec<_>> = map
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| {
|
||||||
|
let key = map_key(k, py);
|
||||||
|
let val = map_field_value(field_name, v, proto_meta)?;
|
||||||
|
Ok((key, val))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(res?.into_py_dict(py).to_object(py))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_key(key: MapKey, py: Python) -> PyObject {
|
||||||
|
match key {
|
||||||
|
MapKey::Bool(x) => x.to_object(py),
|
||||||
|
MapKey::I32(x) => x.to_object(py),
|
||||||
|
MapKey::I64(x) => x.to_object(py),
|
||||||
|
MapKey::U32(x) => x.to_object(py),
|
||||||
|
MapKey::U64(x) => x.to_object(py),
|
||||||
|
MapKey::String(x) => x.to_object(py),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
@dataclass(repr=False)
|
@dataclass(repr=False)
|
||||||
class Baz(betterproto.Message):
|
class Baz(betterproto.Message):
|
||||||
@ -30,6 +30,7 @@ class Bar(betterproto.Message):
|
|||||||
foo2: Foo = betterproto.message_field(2)
|
foo2: Foo = betterproto.message_field(2)
|
||||||
packed: List[int] = betterproto.int64_field(3)
|
packed: List[int] = betterproto.int64_field(3)
|
||||||
enm: Enm = betterproto.enum_field(4)
|
enm: Enm = betterproto.enum_field(4)
|
||||||
|
map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL)
|
||||||
|
|
||||||
# Serialization has not been changed yet. So nothing unusual here
|
# Serialization has not been changed yet. So nothing unusual here
|
||||||
buffer = bytes(
|
buffer = bytes(
|
||||||
@ -37,7 +38,11 @@ buffer = bytes(
|
|||||||
foo1=Foo(1, 2.34),
|
foo1=Foo(1, 2.34),
|
||||||
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]),
|
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]),
|
||||||
packed=[5, 3, 1],
|
packed=[5, 3, 1],
|
||||||
enm=Enm.B
|
enm=Enm.B,
|
||||||
|
map={
|
||||||
|
1: True,
|
||||||
|
42: False
|
||||||
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user