proper error handling

This commit is contained in:
Erik Friese 2023-08-26 21:29:41 +02:00
parent 604dcb104f
commit 26da86d2cd
7 changed files with 93 additions and 31 deletions

View File

@ -20,6 +20,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"prost-reflect", "prost-reflect",
"pyo3", "pyo3",
"thiserror",
] ]
[[package]] [[package]]
@ -144,7 +145,7 @@ dependencies = [
"itertools", "itertools",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 1.0.109",
] ]
[[package]] [[package]]
@ -213,7 +214,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn", "syn 1.0.109",
] ]
[[package]] [[package]]
@ -224,7 +225,7 @@ checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 1.0.109",
] ]
[[package]] [[package]]
@ -268,12 +269,43 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "syn"
version = "2.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]] [[package]]
name = "target-lexicon" name = "target-lexicon"
version = "0.12.11" version = "0.12.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a"
[[package]]
name = "thiserror"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.29",
]
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.11" version = "1.0.11"

View File

@ -10,3 +10,4 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
prost-reflect = "0.11.4" prost-reflect = "0.11.4"
pyo3 = { version = "0.19.0", features = ["abi3-py37", "extension-module"] } pyo3 = { version = "0.19.0", features = ["abi3-py37", "extension-module"] }
thiserror = "1.0.47"

View File

@ -1,15 +1,17 @@
use crate::{
error::{Error, Result},
py_any_extras::PyAnyExtras,
};
use prost_reflect::{ use prost_reflect::{
prost_types::{ prost_types::{
field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
}, },
DescriptorPool, MessageDescriptor, DescriptorPool, MessageDescriptor,
}; };
use pyo3::{exceptions::PyRuntimeError, PyAny, PyResult}; use pyo3::PyAny;
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use crate::py_any_extras::PyAnyExtras; pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
pub fn create_cached_descriptor(obj: &PyAny) -> PyResult<MessageDescriptor> {
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new(); static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
let mut pool = DESCRIPTOR_POOL let mut pool = DESCRIPTOR_POOL
.get_or_init(|| Mutex::new(DescriptorPool::new())) .get_or_init(|| Mutex::new(DescriptorPool::new()))
@ -22,7 +24,7 @@ pub fn create_cached_descriptor(obj: &PyAny) -> PyResult<MessageDescriptor> {
fn create_cached_descriptor_in_pool<'py>( fn create_cached_descriptor_in_pool<'py>(
obj: &'py PyAny, obj: &'py PyAny,
pool: &mut DescriptorPool, pool: &mut DescriptorPool,
) -> PyResult<MessageDescriptor> { ) -> Result<MessageDescriptor> {
let name = obj.qualified_class_name()?; let name = obj.qualified_class_name()?;
if let Some(desc) = pool.get_message_by_name(&name) { if let Some(desc) = pool.get_message_by_name(&name) {
return Ok(desc); return Ok(desc);
@ -69,13 +71,12 @@ fn create_cached_descriptor_in_pool<'py>(
file.message_type.push(message); file.message_type.push(message);
pool.add_file_descriptor_proto(file) pool.add_file_descriptor_proto(file)?;
.map_err(|_| PyRuntimeError::new_err("Error on proto registration"))?;
Ok(pool.get_message_by_name(&name).expect("Just registered...")) Ok(pool.get_message_by_name(&name).expect("Just registered..."))
} }
fn map_type(str: &str) -> PyResult<Type> { fn map_type(str: &str) -> Result<Type> {
match str { match str {
"enum" => Ok(Type::Enum), "enum" => Ok(Type::Enum),
"bool" => Ok(Type::Bool), "bool" => Ok(Type::Bool),
@ -94,6 +95,6 @@ fn map_type(str: &str) -> PyResult<Type> {
"string" => Ok(Type::String), "string" => Ok(Type::String),
"bytes" => Ok(Type::Bytes), "bytes" => Ok(Type::Bytes),
"message" => Ok(Type::Message), "message" => Ok(Type::Message),
_ => Err(PyRuntimeError::new_err("Unsupported type")), _ => Err(Error::UnsupportedType(str.to_string())),
} }
} }

View File

@ -0,0 +1,23 @@
use prost_reflect::{prost::DecodeError, DescriptorError};
use pyo3::{exceptions::PyRuntimeError, PyErr};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Given object is not a valid betterproto message.")]
NoBetterprotoMessage(#[from] PyErr),
#[error("Unsupported type `{0}`.")]
UnsupportedType(String),
#[error("Error on proto registration")]
FailedToRegisterDescriptor(#[from] DescriptorError),
#[error("The given binary data does not match the protobuf schema.")]
FailedToDecode(#[from] DecodeError),
}
pub type Result<T> = core::result::Result<T, Error>;
impl From<Error> for PyErr {
fn from(value: Error) -> Self {
PyRuntimeError::new_err(value.to_string())
}
}

View File

@ -1,17 +1,18 @@
mod descriptor_pool; mod descriptor_pool;
mod error;
mod merging; mod merging;
mod py_any_extras; mod py_any_extras;
use descriptor_pool::create_cached_descriptor; use descriptor_pool::create_cached_descriptor;
use error::Result;
use merging::merge_msg_into_pyobj; use merging::merge_msg_into_pyobj;
use prost_reflect::DynamicMessage; use prost_reflect::DynamicMessage;
use pyo3::{exceptions::PyRuntimeError, prelude::*}; use pyo3::prelude::*;
#[pyfunction] #[pyfunction]
fn deserialize(obj: &PyAny, buf: &[u8]) -> PyResult<()> { fn deserialize(obj: &PyAny, buf: &[u8]) -> Result<()> {
let desc = create_cached_descriptor(obj)?; let desc = create_cached_descriptor(obj)?;
let msg = DynamicMessage::decode(desc, buf) let msg = DynamicMessage::decode(desc, buf)?;
.map_err(|_| PyRuntimeError::new_err("Error on deserializing."))?;
merge_msg_into_pyobj(obj, &msg)?; merge_msg_into_pyobj(obj, &msg)?;
Ok(()) Ok(())
} }

View File

@ -1,9 +1,11 @@
use crate::{
error::{Error, Result},
py_any_extras::PyAnyExtras,
};
use prost_reflect::{DynamicMessage, ReflectMessage, Value}; use prost_reflect::{DynamicMessage, ReflectMessage, Value};
use pyo3::{exceptions::PyRuntimeError, IntoPy, PyAny, PyResult}; use pyo3::{IntoPy, PyAny};
use crate::py_any_extras::PyAnyExtras; pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> {
pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> PyResult<()> {
for field in msg.descriptor().fields() { for field in msg.descriptor().fields() {
let field_name = field.name(); let field_name = field.name();
let proto_meta = obj.get_proto_meta()?; let proto_meta = obj.get_proto_meta()?;
@ -21,7 +23,7 @@ fn map_field_value<'py>(
field_name: &str, field_name: &str,
field_value: &Value, field_value: &Value,
proto_meta: &'py PyAny, proto_meta: &'py PyAny,
) -> PyResult<&'py PyAny> { ) -> Result<&'py PyAny> {
let py = proto_meta.py(); let py = proto_meta.py();
match field_value { match field_value {
Value::Bool(x) => Ok(x.into_py(py).into_ref(py)), Value::Bool(x) => Ok(x.into_py(py).into_ref(py)),
@ -38,6 +40,6 @@ fn map_field_value<'py>(
merge_msg_into_pyobj(obj, msg)?; merge_msg_into_pyobj(obj, msg)?;
Ok(obj) Ok(obj)
} }
_ => Err(PyRuntimeError::new_err("Unsupported type")), value => Err(Error::UnsupportedType(value.to_string())),
} }
} }

View File

@ -1,25 +1,27 @@
use pyo3::{PyAny, PyResult}; use crate::error::Result;
use pyo3::PyAny;
pub trait PyAnyExtras<'py> { pub trait PyAnyExtras<'py> {
fn qualified_class_name(&self) -> PyResult<String>; fn qualified_class_name(&self) -> Result<String>;
fn get_proto_meta(&self) -> PyResult<&'py PyAny>; fn get_proto_meta(&self) -> Result<&'py PyAny>;
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny>; fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>;
} }
impl<'py> PyAnyExtras<'py> for &'py PyAny { impl<'py> PyAnyExtras<'py> for &'py PyAny {
fn qualified_class_name(&self) -> PyResult<String> { fn qualified_class_name(&self) -> Result<String> {
let class = self.getattr("__class__")?; let class = self.getattr("__class__")?;
let module = class.getattr("__module__")?; let module = class.getattr("__module__")?;
let name = class.getattr("__name__")?; let name = class.getattr("__name__")?;
Ok(format!("{module}.{name}")) Ok(format!("{module}.{name}"))
} }
fn get_proto_meta(&self) -> PyResult<&'py PyAny> { fn get_proto_meta(&self) -> Result<&'py PyAny> {
self.getattr("_betterproto") Ok(self.getattr("_betterproto")?)
} }
fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny> { fn create_instance(&self, field_name: &str) -> Result<&'py PyAny> {
let res = self.getattr("cls_by_field")? let res = self
.getattr("cls_by_field")?
.get_item(field_name)? .get_item(field_name)?
.call0()?; .call0()?;
Ok(res) Ok(res)