diff --git a/betterproto-extras/Cargo.lock b/betterproto-extras/Cargo.lock index 670818b..a61fc33 100644 --- a/betterproto-extras/Cargo.lock +++ b/betterproto-extras/Cargo.lock @@ -20,6 +20,7 @@ version = "0.1.0" dependencies = [ "prost-reflect", "pyo3", + "thiserror", ] [[package]] @@ -144,7 +145,7 @@ dependencies = [ "itertools", "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -213,7 +214,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -224,7 +225,7 @@ checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -268,12 +269,43 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "target-lexicon" version = "0.12.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" +[[package]] +name = "thiserror" +version = "1.0.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.29", +] + [[package]] name = "unicode-ident" version = "1.0.11" diff --git a/betterproto-extras/Cargo.toml b/betterproto-extras/Cargo.toml index 412adca..dc782ba 100644 --- a/betterproto-extras/Cargo.toml +++ b/betterproto-extras/Cargo.toml @@ -10,3 +10,4 @@ crate-type = ["cdylib"] [dependencies] prost-reflect = "0.11.4" pyo3 = { version = "0.19.0", features = ["abi3-py37", "extension-module"] } +thiserror = "1.0.47" diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index b83caaa..f301bcb 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -1,15 +1,17 @@ +use crate::{ + error::{Error, Result}, + py_any_extras::PyAnyExtras, +}; use prost_reflect::{ prost_types::{ field_descriptor_proto::Type, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, }, DescriptorPool, MessageDescriptor, }; -use pyo3::{exceptions::PyRuntimeError, PyAny, PyResult}; +use pyo3::PyAny; use std::sync::{Mutex, OnceLock}; -use crate::py_any_extras::PyAnyExtras; - -pub fn create_cached_descriptor(obj: &PyAny) -> PyResult { +pub fn create_cached_descriptor(obj: &PyAny) -> Result { static DESCRIPTOR_POOL: OnceLock> = OnceLock::new(); let mut pool = DESCRIPTOR_POOL .get_or_init(|| Mutex::new(DescriptorPool::new())) @@ -22,7 +24,7 @@ pub fn create_cached_descriptor(obj: &PyAny) -> PyResult { fn create_cached_descriptor_in_pool<'py>( obj: &'py PyAny, pool: &mut DescriptorPool, -) -> PyResult { +) -> Result { let name = obj.qualified_class_name()?; if let Some(desc) = pool.get_message_by_name(&name) { return Ok(desc); @@ -69,13 +71,12 @@ fn create_cached_descriptor_in_pool<'py>( file.message_type.push(message); - pool.add_file_descriptor_proto(file) - .map_err(|_| PyRuntimeError::new_err("Error on proto registration"))?; + pool.add_file_descriptor_proto(file)?; Ok(pool.get_message_by_name(&name).expect("Just registered...")) } -fn map_type(str: &str) -> PyResult { +fn map_type(str: &str) -> Result { match str { "enum" => Ok(Type::Enum), "bool" => Ok(Type::Bool), @@ -94,6 +95,6 @@ fn map_type(str: &str) -> PyResult { "string" => Ok(Type::String), "bytes" => Ok(Type::Bytes), "message" => Ok(Type::Message), - _ => Err(PyRuntimeError::new_err("Unsupported type")), + _ => Err(Error::UnsupportedType(str.to_string())), } } diff --git a/betterproto-extras/src/error.rs b/betterproto-extras/src/error.rs new file mode 100644 index 0000000..edb6a3f --- /dev/null +++ b/betterproto-extras/src/error.rs @@ -0,0 +1,23 @@ +use prost_reflect::{prost::DecodeError, DescriptorError}; +use pyo3::{exceptions::PyRuntimeError, PyErr}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Given object is not a valid betterproto message.")] + NoBetterprotoMessage(#[from] PyErr), + #[error("Unsupported type `{0}`.")] + UnsupportedType(String), + #[error("Error on proto registration")] + FailedToRegisterDescriptor(#[from] DescriptorError), + #[error("The given binary data does not match the protobuf schema.")] + FailedToDecode(#[from] DecodeError), +} + +pub type Result = core::result::Result; + +impl From for PyErr { + fn from(value: Error) -> Self { + PyRuntimeError::new_err(value.to_string()) + } +} diff --git a/betterproto-extras/src/lib.rs b/betterproto-extras/src/lib.rs index b78211e..416d5d1 100644 --- a/betterproto-extras/src/lib.rs +++ b/betterproto-extras/src/lib.rs @@ -1,17 +1,18 @@ mod descriptor_pool; +mod error; mod merging; mod py_any_extras; use descriptor_pool::create_cached_descriptor; +use error::Result; use merging::merge_msg_into_pyobj; use prost_reflect::DynamicMessage; -use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use pyo3::prelude::*; #[pyfunction] -fn deserialize(obj: &PyAny, buf: &[u8]) -> PyResult<()> { +fn deserialize(obj: &PyAny, buf: &[u8]) -> Result<()> { let desc = create_cached_descriptor(obj)?; - let msg = DynamicMessage::decode(desc, buf) - .map_err(|_| PyRuntimeError::new_err("Error on deserializing."))?; + let msg = DynamicMessage::decode(desc, buf)?; merge_msg_into_pyobj(obj, &msg)?; Ok(()) } diff --git a/betterproto-extras/src/merging.rs b/betterproto-extras/src/merging.rs index ec8092f..09cc2fb 100644 --- a/betterproto-extras/src/merging.rs +++ b/betterproto-extras/src/merging.rs @@ -1,9 +1,11 @@ +use crate::{ + error::{Error, Result}, + py_any_extras::PyAnyExtras, +}; use prost_reflect::{DynamicMessage, ReflectMessage, Value}; -use pyo3::{exceptions::PyRuntimeError, IntoPy, PyAny, PyResult}; +use pyo3::{IntoPy, PyAny}; -use crate::py_any_extras::PyAnyExtras; - -pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> PyResult<()> { +pub fn merge_msg_into_pyobj(obj: &PyAny, msg: &DynamicMessage) -> Result<()> { for field in msg.descriptor().fields() { let field_name = field.name(); let proto_meta = obj.get_proto_meta()?; @@ -21,7 +23,7 @@ fn map_field_value<'py>( field_name: &str, field_value: &Value, proto_meta: &'py PyAny, -) -> PyResult<&'py PyAny> { +) -> Result<&'py PyAny> { let py = proto_meta.py(); match field_value { Value::Bool(x) => Ok(x.into_py(py).into_ref(py)), @@ -38,6 +40,6 @@ fn map_field_value<'py>( merge_msg_into_pyobj(obj, msg)?; Ok(obj) } - _ => Err(PyRuntimeError::new_err("Unsupported type")), + value => Err(Error::UnsupportedType(value.to_string())), } } diff --git a/betterproto-extras/src/py_any_extras.rs b/betterproto-extras/src/py_any_extras.rs index 86db978..31c7b52 100644 --- a/betterproto-extras/src/py_any_extras.rs +++ b/betterproto-extras/src/py_any_extras.rs @@ -1,25 +1,27 @@ -use pyo3::{PyAny, PyResult}; +use crate::error::Result; +use pyo3::PyAny; pub trait PyAnyExtras<'py> { - fn qualified_class_name(&self) -> PyResult; - fn get_proto_meta(&self) -> PyResult<&'py PyAny>; - fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny>; + fn qualified_class_name(&self) -> Result; + fn get_proto_meta(&self) -> Result<&'py PyAny>; + fn create_instance(&self, field_name: &str) -> Result<&'py PyAny>; } impl<'py> PyAnyExtras<'py> for &'py PyAny { - fn qualified_class_name(&self) -> PyResult { + fn qualified_class_name(&self) -> Result { let class = self.getattr("__class__")?; let module = class.getattr("__module__")?; let name = class.getattr("__name__")?; Ok(format!("{module}.{name}")) } - fn get_proto_meta(&self) -> PyResult<&'py PyAny> { - self.getattr("_betterproto") + fn get_proto_meta(&self) -> Result<&'py PyAny> { + Ok(self.getattr("_betterproto")?) } - fn create_instance(&self, field_name: &str) -> PyResult<&'py PyAny> { - let res = self.getattr("cls_by_field")? + fn create_instance(&self, field_name: &str) -> Result<&'py PyAny> { + let res = self + .getattr("cls_by_field")? .get_item(field_name)? .call0()?; Ok(res)