Native deserialization based on Rust and PyO3
Proof of concept Only capable of deserializing (nested) Messages with primitive fields No handling of lists, maps, enums, .. implemented yet See `example.py` for a working example
This commit is contained in:
99
betterproto-extras/src/descriptor_pool.rs
Normal file
99
betterproto-extras/src/descriptor_pool.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use prost_reflect::{
|
||||
prost_types::{
|
||||
field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
|
||||
},
|
||||
DescriptorPool, MessageDescriptor,
|
||||
};
|
||||
use pyo3::{exceptions::PyRuntimeError, PyAny, PyResult};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use crate::py_any_extras::PyAnyExtras;
|
||||
|
||||
pub fn create_cached_descriptor(obj: &PyAny) -> PyResult<MessageDescriptor> {
|
||||
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
|
||||
let mut pool = DESCRIPTOR_POOL
|
||||
.get_or_init(|| Mutex::new(DescriptorPool::new()))
|
||||
.lock()
|
||||
.unwrap();
|
||||
|
||||
create_cached_descriptor_in_pool(obj, &mut pool)
|
||||
}
|
||||
|
||||
fn create_cached_descriptor_in_pool<'py>(
|
||||
obj: &'py PyAny,
|
||||
pool: &mut DescriptorPool,
|
||||
) -> PyResult<MessageDescriptor> {
|
||||
let name = obj.qualified_class_name()?;
|
||||
if let Some(desc) = pool.get_message_by_name(&name) {
|
||||
return Ok(desc);
|
||||
}
|
||||
|
||||
let meta = obj.get_proto_meta()?;
|
||||
|
||||
let mut message = DescriptorProto {
|
||||
name: Some(name.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut file = FileDescriptorProto {
|
||||
name: Some(name.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
for item in meta
|
||||
.getattr("meta_by_field_name")?
|
||||
.call_method0("items")?
|
||||
.iter()?
|
||||
{
|
||||
let (field_name, field_meta) = item?.extract::<(&str, &'py PyAny)>()?;
|
||||
message.field.push({
|
||||
let mut field = FieldDescriptorProto {
|
||||
name: Some(field_name.to_string()),
|
||||
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
|
||||
..Default::default()
|
||||
};
|
||||
field.set_type(map_type(
|
||||
field_meta.getattr("proto_type")?.extract::<&str>()?,
|
||||
)?);
|
||||
|
||||
if field.r#type() == Type::Message {
|
||||
let instance = meta.create_instance(field_name)?;
|
||||
let cls_name = instance.qualified_class_name()?;
|
||||
field.type_name = Some(cls_name.to_string());
|
||||
create_cached_descriptor_in_pool(instance, pool)?;
|
||||
}
|
||||
|
||||
field
|
||||
});
|
||||
}
|
||||
|
||||
file.message_type.push(message);
|
||||
|
||||
pool.add_file_descriptor_proto(file)
|
||||
.map_err(|_| PyRuntimeError::new_err("Error on proto registration"))?;
|
||||
|
||||
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
|
||||
}
|
||||
|
||||
fn map_type(str: &str) -> PyResult<Type> {
|
||||
match str {
|
||||
"enum" => Ok(Type::Enum),
|
||||
"bool" => Ok(Type::Bool),
|
||||
"int32" => Ok(Type::Int32),
|
||||
"int64" => Ok(Type::Int64),
|
||||
"uint32" => Ok(Type::Uint32),
|
||||
"uint64" => Ok(Type::Uint64),
|
||||
"sint32" => Ok(Type::Sint32),
|
||||
"sint64" => Ok(Type::Sint64),
|
||||
"float" => Ok(Type::Float),
|
||||
"double" => Ok(Type::Double),
|
||||
"fixed32" => Ok(Type::Fixed32),
|
||||
"sfixed32" => Ok(Type::Sfixed32),
|
||||
"fixed64" => Ok(Type::Fixed64),
|
||||
"sfixed64" => Ok(Type::Sfixed64),
|
||||
"string" => Ok(Type::String),
|
||||
"bytes" => Ok(Type::Bytes),
|
||||
"message" => Ok(Type::Message),
|
||||
_ => Err(PyRuntimeError::new_err("Unsupported type")),
|
||||
}
|
||||
}
|
||||
23
betterproto-extras/src/lib.rs
Normal file
23
betterproto-extras/src/lib.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
mod descriptor_pool;
|
||||
mod merging;
|
||||
mod py_any_extras;
|
||||
|
||||
use descriptor_pool::create_cached_descriptor;
|
||||
use merging::merge_msg_into_pyobj;
|
||||
use prost_reflect::DynamicMessage;
|
||||
use pyo3::{exceptions::PyRuntimeError, prelude::*};
|
||||
|
||||
#[pyfunction]
|
||||
fn deserialize(obj: &PyAny, buf: &[u8]) -> PyResult<()> {
|
||||
let desc = create_cached_descriptor(obj)?;
|
||||
let msg = DynamicMessage::decode(desc, buf)
|
||||
.map_err(|_| PyRuntimeError::new_err("Error on deserializing."))?;
|
||||
merge_msg_into_pyobj(obj, &msg)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn betterproto_extras(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(deserialize, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
43
betterproto-extras/src/merging.rs
Normal file
43
betterproto-extras/src/merging.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use prost_reflect::{DynamicMessage, ReflectMessage, Value};
|
||||
use pyo3::{exceptions::PyRuntimeError, IntoPy, PyAny, PyResult};
|
||||
|
||||
use crate::py_any_extras::PyAnyExtras;
|
||||
|
||||
pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> PyResult<()> {
|
||||
for field in msg.descriptor().fields() {
|
||||
let field_name = field.name();
|
||||
let proto_meta = obj.get_proto_meta()?;
|
||||
if let Some(field_value) = msg.get_field_by_name(field_name) {
|
||||
obj.setattr(
|
||||
field_name,
|
||||
map_field_value(field_name, &field_value, proto_meta)?,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn map_field_value<'py>(
|
||||
field_name: &str,
|
||||
field_value: &Value,
|
||||
proto_meta: &'py PyAny,
|
||||
) -> PyResult<&'py PyAny> {
|
||||
let py = proto_meta.py();
|
||||
match field_value {
|
||||
Value::Bool(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::Bytes(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::F32(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::F64(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::I32(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::I64(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::String(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::U32(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::U64(x) => Ok(x.into_py(py).into_ref(py)),
|
||||
Value::Message(msg) => {
|
||||
let obj = proto_meta.create_instance(field_name)?;
|
||||
merge_msg_into_pyobj(obj, msg)?;
|
||||
Ok(obj)
|
||||
}
|
||||
_ => Err(PyRuntimeError::new_err("Unsupported type")),
|
||||
}
|
||||
}
|
||||
27
betterproto-extras/src/py_any_extras.rs
Normal file
27
betterproto-extras/src/py_any_extras.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use pyo3::{PyAny, PyResult};
|
||||
|
||||
pub trait PyAnyExtras<'py> {
|
||||
fn qualified_class_name(&self) -> PyResult<String>;
|
||||
fn get_proto_meta(&self) -> PyResult<&'py PyAny>;
|
||||
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny>;
|
||||
}
|
||||
|
||||
impl<'py> PyAnyExtras<'py> for &'py PyAny {
|
||||
fn qualified_class_name(&self) -> PyResult<String> {
|
||||
let class = self.getattr("__class__")?;
|
||||
let module = class.getattr("__module__")?;
|
||||
let name = class.getattr("__name__")?;
|
||||
Ok(format!("{module}.{name}"))
|
||||
}
|
||||
|
||||
fn get_proto_meta(&self) -> PyResult<&'py PyAny> {
|
||||
self.getattr("_betterproto")
|
||||
}
|
||||
|
||||
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny> {
|
||||
let res = self.getattr("cls_by_field")?
|
||||
.get_item(field_name)?
|
||||
.call0()?;
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user