proper error handling

This commit is contained in:
Erik Friese 2023-08-26 21:29:41 +02:00
parent 604dcb104f
commit 26da86d2cd
7 changed files with 93 additions and 31 deletions

View File

@ -20,6 +20,7 @@ version = "0.1.0"
dependencies = [
"prost-reflect",
"pyo3",
"thiserror",
]
[[package]]
@ -144,7 +145,7 @@ dependencies = [
"itertools",
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
]
[[package]]
@ -213,7 +214,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn",
"syn 1.0.109",
]
[[package]]
@ -224,7 +225,7 @@ checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
]
[[package]]
@ -268,12 +269,43 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "target-lexicon"
version = "0.12.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a"
[[package]]
name = "thiserror"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.29",
]
[[package]]
name = "unicode-ident"
version = "1.0.11"

View File

@ -10,3 +10,4 @@ crate-type = ["cdylib"]
[dependencies]
prost-reflect = "0.11.4"
pyo3 = { version = "0.19.0", features = ["abi3-py37", "extension-module"] }
thiserror = "1.0.47"

View File

@ -1,15 +1,17 @@
use crate::{
error::{Error, Result},
py_any_extras::PyAnyExtras,
};
use prost_reflect::{
prost_types::{
field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
},
DescriptorPool, MessageDescriptor,
};
use pyo3::{exceptions::PyRuntimeError, PyAny, PyResult};
use pyo3::PyAny;
use std::sync::{Mutex, OnceLock};
use crate::py_any_extras::PyAnyExtras;
pub fn create_cached_descriptor(obj: &PyAny) -> PyResult<MessageDescriptor> {
pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
let mut pool = DESCRIPTOR_POOL
.get_or_init(|| Mutex::new(DescriptorPool::new()))
@ -22,7 +24,7 @@ pub fn create_cached_descriptor(obj: &PyAny) -> PyResult<MessageDescriptor> {
fn create_cached_descriptor_in_pool<'py>(
obj: &'py PyAny,
pool: &mut DescriptorPool,
) -> PyResult<MessageDescriptor> {
) -> Result<MessageDescriptor> {
let name = obj.qualified_class_name()?;
if let Some(desc) = pool.get_message_by_name(&name) {
return Ok(desc);
@ -69,13 +71,12 @@ fn create_cached_descriptor_in_pool<'py>(
file.message_type.push(message);
pool.add_file_descriptor_proto(file)
.map_err(|_| PyRuntimeError::new_err("Error on proto registration"))?;
pool.add_file_descriptor_proto(file)?;
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
}
fn map_type(str: &str) -> PyResult<Type> {
fn map_type(str: &str) -> Result<Type> {
match str {
"enum" => Ok(Type::Enum),
"bool" => Ok(Type::Bool),
@ -94,6 +95,6 @@ fn map_type(str: &str) -> PyResult<Type> {
"string" => Ok(Type::String),
"bytes" => Ok(Type::Bytes),
"message" => Ok(Type::Message),
_ => Err(PyRuntimeError::new_err("Unsupported type")),
_ => Err(Error::UnsupportedType(str.to_string())),
}
}

View File

@ -0,0 +1,23 @@
use prost_reflect::{prost::DecodeError, DescriptorError};
use pyo3::{exceptions::PyRuntimeError, PyErr};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Given object is not a valid betterproto message.")]
NoBetterprotoMessage(#[from] PyErr),
#[error("Unsupported type `{0}`.")]
UnsupportedType(String),
#[error("Error on proto registration")]
FailedToRegisterDescriptor(#[from] DescriptorError),
#[error("The given binary data does not match the protobuf schema.")]
FailedToDecode(#[from] DecodeError),
}
pub type Result<T> = core::result::Result<T, Error>;
impl From<Error> for PyErr {
fn from(value: Error) -> Self {
PyRuntimeError::new_err(value.to_string())
}
}

View File

@ -1,17 +1,18 @@
mod descriptor_pool;
mod error;
mod merging;
mod py_any_extras;
use descriptor_pool::create_cached_descriptor;
use error::Result;
use merging::merge_msg_into_pyobj;
use prost_reflect::DynamicMessage;
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use pyo3::prelude::*;
#[pyfunction]
fn deserialize(obj: &PyAny, buf: &[u8]) -> PyResult<()> {
fn deserialize(obj: &PyAny, buf: &[u8]) -> Result<()> {
let desc = create_cached_descriptor(obj)?;
let msg = DynamicMessage::decode(desc, buf)
.map_err(|_| PyRuntimeError::new_err("Error on deserializing."))?;
let msg = DynamicMessage::decode(desc, buf)?;
merge_msg_into_pyobj(obj, &msg)?;
Ok(())
}

View File

@ -1,9 +1,11 @@
use crate::{
error::{Error, Result},
py_any_extras::PyAnyExtras,
};
use prost_reflect::{DynamicMessage, ReflectMessage, Value};
use pyo3::{exceptions::PyRuntimeError, IntoPy, PyAny, PyResult};
use pyo3::{IntoPy, PyAny};
use crate::py_any_extras::PyAnyExtras;
pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> PyResult<()> {
pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> {
for field in msg.descriptor().fields() {
let field_name = field.name();
let proto_meta = obj.get_proto_meta()?;
@ -21,7 +23,7 @@ fn map_field_value<'py>(
field_name: &str,
field_value: &Value,
proto_meta: &'py PyAny,
) -> PyResult<&'py PyAny> {
) -> Result<&'py PyAny> {
let py = proto_meta.py();
match field_value {
Value::Bool(x) => Ok(x.into_py(py).into_ref(py)),
@ -38,6 +40,6 @@ fn map_field_value<'py>(
merge_msg_into_pyobj(obj, msg)?;
Ok(obj)
}
_ => Err(PyRuntimeError::new_err("Unsupported type")),
value => Err(Error::UnsupportedType(value.to_string())),
}
}

View File

@ -1,25 +1,27 @@
use pyo3::{PyAny, PyResult};
use crate::error::Result;
use pyo3::PyAny;
pub trait PyAnyExtras<'py> {
fn qualified_class_name(&self) -> PyResult<String>;
fn get_proto_meta(&self) -> PyResult<&'py PyAny>;
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny>;
fn qualified_class_name(&self) -> Result<String>;
fn get_proto_meta(&self) -> Result<&'py PyAny>;
fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>;
}
impl<'py> PyAnyExtras<'py> for &'py PyAny {
fn qualified_class_name(&self) -> PyResult<String> {
fn qualified_class_name(&self) -> Result<String> {
let class = self.getattr("__class__")?;
let module = class.getattr("__module__")?;
let name = class.getattr("__name__")?;
Ok(format!("{module}.{name}"))
}
fn get_proto_meta(&self) -> PyResult<&'py PyAny> {
self.getattr("_betterproto")
fn get_proto_meta(&self) -> Result<&'py PyAny> {
Ok(self.getattr("_betterproto")?)
}
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny> {
let res = self.getattr("cls_by_field")?
fn create_instance(&self, field_name: &str) -> Result<&'py PyAny> {
let res = self
.getattr("cls_by_field")?
.get_item(field_name)?
.call0()?;
Ok(res)