Native deserialization based on Rust and PyO3

Proof of concept
Only capable of deserializing (nested) Messages with primitive fields
No handling of lists, maps, enums, .. implemented yet
See `example.py` for a working example
This commit is contained in:
Erik Friese
2023-08-25 19:41:22 +02:00
parent 4cdf1bb9e0
commit 421aa78014
12 changed files with 1203 additions and 486 deletions

View File

@@ -0,0 +1,99 @@
use prost_reflect::{
prost_types::{
field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
},
DescriptorPool, MessageDescriptor,
};
use pyo3::{exceptions::PyRuntimeError, PyAny, PyResult};
use std::sync::{Mutex, OnceLock};
use crate::py_any_extras::PyAnyExtras;
pub fn create_cached_descriptor(obj: &PyAny) -> PyResult<MessageDescriptor> {
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
let mut pool = DESCRIPTOR_POOL
.get_or_init(|| Mutex::new(DescriptorPool::new()))
.lock()
.unwrap();
create_cached_descriptor_in_pool(obj, &mut pool)
}
fn create_cached_descriptor_in_pool<'py>(
obj: &'py PyAny,
pool: &mut DescriptorPool,
) -> PyResult<MessageDescriptor> {
let name = obj.qualified_class_name()?;
if let Some(desc) = pool.get_message_by_name(&name) {
return Ok(desc);
}
let meta = obj.get_proto_meta()?;
let mut message = DescriptorProto {
name: Some(name.clone()),
..Default::default()
};
let mut file = FileDescriptorProto {
name: Some(name.clone()),
..Default::default()
};
for item in meta
.getattr("meta_by_field_name")?
.call_method0("items")?
.iter()?
{
let (field_name, field_meta) = item?.extract::<(&str, &'py PyAny)>()?;
message.field.push({
let mut field = FieldDescriptorProto {
name: Some(field_name.to_string()),
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
..Default::default()
};
field.set_type(map_type(
field_meta.getattr("proto_type")?.extract::<&str>()?,
)?);
if field.r#type() == Type::Message {
let instance = meta.create_instance(field_name)?;
let cls_name = instance.qualified_class_name()?;
field.type_name = Some(cls_name.to_string());
create_cached_descriptor_in_pool(instance, pool)?;
}
field
});
}
file.message_type.push(message);
pool.add_file_descriptor_proto(file)
.map_err(|_| PyRuntimeError::new_err("Error on proto registration"))?;
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
}
fn map_type(str: &str) -> PyResult<Type> {
match str {
"enum" => Ok(Type::Enum),
"bool" => Ok(Type::Bool),
"int32" => Ok(Type::Int32),
"int64" => Ok(Type::Int64),
"uint32" => Ok(Type::Uint32),
"uint64" => Ok(Type::Uint64),
"sint32" => Ok(Type::Sint32),
"sint64" => Ok(Type::Sint64),
"float" => Ok(Type::Float),
"double" => Ok(Type::Double),
"fixed32" => Ok(Type::Fixed32),
"sfixed32" => Ok(Type::Sfixed32),
"fixed64" => Ok(Type::Fixed64),
"sfixed64" => Ok(Type::Sfixed64),
"string" => Ok(Type::String),
"bytes" => Ok(Type::Bytes),
"message" => Ok(Type::Message),
_ => Err(PyRuntimeError::new_err("Unsupported type")),
}
}

View File

@@ -0,0 +1,23 @@
mod descriptor_pool;
mod merging;
mod py_any_extras;
use descriptor_pool::create_cached_descriptor;
use merging::merge_msg_into_pyobj;
use prost_reflect::DynamicMessage;
use pyo3::{exceptions::PyRuntimeError, prelude::*};
#[pyfunction]
fn deserialize(obj: &PyAny, buf: &[u8]) -> PyResult<()> {
let desc = create_cached_descriptor(obj)?;
let msg = DynamicMessage::decode(desc, buf)
.map_err(|_| PyRuntimeError::new_err("Error on deserializing."))?;
merge_msg_into_pyobj(obj, &msg)?;
Ok(())
}
#[pymodule]
fn betterproto_extras(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(deserialize, m)?)?;
Ok(())
}

View File

@@ -0,0 +1,43 @@
use prost_reflect::{DynamicMessage, ReflectMessage, Value};
use pyo3::{exceptions::PyRuntimeError, IntoPy, PyAny, PyResult};
use crate::py_any_extras::PyAnyExtras;
pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> PyResult<()> {
for field in msg.descriptor().fields() {
let field_name = field.name();
let proto_meta = obj.get_proto_meta()?;
if let Some(field_value) = msg.get_field_by_name(field_name) {
obj.setattr(
field_name,
map_field_value(field_name, &field_value, proto_meta)?,
)?;
}
}
Ok(())
}
fn map_field_value<'py>(
field_name: &str,
field_value: &Value,
proto_meta: &'py PyAny,
) -> PyResult<&'py PyAny> {
let py = proto_meta.py();
match field_value {
Value::Bool(x) => Ok(x.into_py(py).into_ref(py)),
Value::Bytes(x) => Ok(x.into_py(py).into_ref(py)),
Value::F32(x) => Ok(x.into_py(py).into_ref(py)),
Value::F64(x) => Ok(x.into_py(py).into_ref(py)),
Value::I32(x) => Ok(x.into_py(py).into_ref(py)),
Value::I64(x) => Ok(x.into_py(py).into_ref(py)),
Value::String(x) => Ok(x.into_py(py).into_ref(py)),
Value::U32(x) => Ok(x.into_py(py).into_ref(py)),
Value::U64(x) => Ok(x.into_py(py).into_ref(py)),
Value::Message(msg) => {
let obj = proto_meta.create_instance(field_name)?;
merge_msg_into_pyobj(obj, msg)?;
Ok(obj)
}
_ => Err(PyRuntimeError::new_err("Unsupported type")),
}
}

View File

@@ -0,0 +1,27 @@
use pyo3::{PyAny, PyResult};
pub trait PyAnyExtras<'py> {
fn qualified_class_name(&self) -> PyResult<String>;
fn get_proto_meta(&self) -> PyResult<&'py PyAny>;
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny>;
}
impl<'py> PyAnyExtras<'py> for &'py PyAny {
fn qualified_class_name(&self) -> PyResult<String> {
let class = self.getattr("__class__")?;
let module = class.getattr("__module__")?;
let name = class.getattr("__name__")?;
Ok(format!("{module}.{name}"))
}
fn get_proto_meta(&self) -> PyResult<&'py PyAny> {
self.getattr("_betterproto")
}
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny> {
let res = self.getattr("cls_by_field")?
.get_item(field_name)?
.call0()?;
Ok(res)
}
}