diff --git a/betterproto-extras/Cargo.lock b/betterproto-extras/Cargo.lock index 911b4ee..6d6ce22 100644 --- a/betterproto-extras/Cargo.lock +++ b/betterproto-extras/Cargo.lock @@ -18,6 +18,7 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" name = "betterproto-extras" version = "0.1.0" dependencies = [ + "indoc 2.0.3", "prost-reflect", "pyo3", "thiserror", @@ -53,6 +54,12 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +[[package]] +name = "indoc" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c785eefb63ebd0e33416dfcb8d6da0bf27ce752843a45632a67bf10d4d4b5c4" + [[package]] name = "itertools" version = "0.10.5" @@ -175,7 +182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" dependencies = [ "cfg-if", - "indoc", + "indoc 1.0.9", "libc", "memoffset", "parking_lot", diff --git a/betterproto-extras/Cargo.toml b/betterproto-extras/Cargo.toml index a3af32f..205d1af 100644 --- a/betterproto-extras/Cargo.toml +++ b/betterproto-extras/Cargo.toml @@ -8,6 +8,7 @@ name = "betterproto_extras" crate-type = ["cdylib"] [dependencies] +indoc = "2.0.3" prost-reflect = "0.11.5" -pyo3 = { version = "0.19.0", features = ["abi3-py37", "extension-module"] } +pyo3 = { version = "0.19.2", features = ["abi3-py37", "extension-module"] } thiserror = "1.0.47" diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index 90fc203..2f02cb2 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -118,44 +118,62 @@ fn add_message_to_file( .transpose()? { Some(Type::Bool) => { - field.type_name = Some(".google.protobuf.BoolValue".to_string()); + field.type_name = Some("google.protobuf.BoolValue".to_string()); } Some(Type::Double) => { - field.type_name = Some(".google.protobuf.DoubleValue".to_string()); + field.type_name = Some("google.protobuf.DoubleValue".to_string()); } Some(Type::Float) => { - field.type_name = Some(".google.protobuf.FloatValue".to_string()); + field.type_name = Some("google.protobuf.FloatValue".to_string()); } Some(Type::Int64) => { - field.type_name = Some(".google.protobuf.Int64Value".to_string()); + field.type_name = Some("google.protobuf.Int64Value".to_string()); } Some(Type::Uint64) => { - field.type_name = Some(".google.protobuf.UInt64Value".to_string()); + field.type_name = Some("google.protobuf.UInt64Value".to_string()); } Some(Type::Int32) => { - field.type_name = Some(".google.protobuf.Int32Value".to_string()); + field.type_name = Some("google.protobuf.Int32Value".to_string()); } Some(Type::Uint32) => { - field.type_name = Some(".google.protobuf.UInt32Value".to_string()); + field.type_name = Some("google.protobuf.UInt32Value".to_string()); } Some(Type::String) => { - field.type_name = Some(".google.protobuf.StringValue".to_string()); + field.type_name = Some("google.protobuf.StringValue".to_string()); } Some(Type::Bytes) => { - field.type_name = Some(".google.protobuf.BytesValue".to_string()); + field.type_name = Some("google.protobuf.BytesValue".to_string()); } Some(t) => return Err(Error::UnsupportedWrapperType(t)), None => { let cls = meta.get_class(field_name)?; let cls_name = cls.qualified_name()?; - field.type_name = Some(cls_name.clone()); - if message_name != cls_name - && pool.get_message_by_name(&cls_name).is_none() - && !file.message_type.iter().any(|item| item.name() == cls_name) - && !messages_to_add.iter().any(|item| item.0 == cls_name) - { - messages_to_add.push((cls_name, cls.call0()?)); + match cls_name.as_str() { + "datetime.datetime" => { + field.type_name = + Some("google.protobuf.Timestamp".to_string()); + } + "datetime.timedelta" => { + field.type_name = + Some("google.protobuf.Duration".to_string()); + } + _ => { + field.type_name = Some(cls_name.clone()); + + if message_name != cls_name + && pool.get_message_by_name(&cls_name).is_none() + && !file + .message_type + .iter() + .any(|item| item.name() == cls_name) + && !messages_to_add + .iter() + .any(|item| item.0 == cls_name) + { + messages_to_add.push((cls_name, cls.call0()?)); + } + } } } }, diff --git a/betterproto-extras/src/merging.rs b/betterproto-extras/src/merging.rs index a36caa0..e7bb593 100644 --- a/betterproto-extras/src/merging.rs +++ b/betterproto-extras/src/merging.rs @@ -1,8 +1,13 @@ use crate::{error::Result, py_any_extras::PyAnyExtras}; -use prost_reflect::{DynamicMessage, MapKey, ReflectMessage, Value}; +use indoc::indoc; +use prost_reflect::{ + prost_types::{Duration, Timestamp}, + DynamicMessage, MapKey, ReflectMessage, Value, +}; use pyo3::{ - types::{IntoPyDict, PyBytes}, - PyAny, PyObject, Python, ToPyObject, + sync::GILOnceCell, + types::{IntoPyDict, PyBytes, PyModule}, + Py, PyAny, PyObject, Python, ToPyObject, }; pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> { @@ -79,6 +84,14 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) -> .get_field_by_number(1) .and_then(|val| val.as_bytes().map(|b| PyBytes::new(py, b))) .to_object(py)), + "google.protobuf.Timestamp" => { + let msg = msg.transcode_to::()?; + Ok(create_py_datetime(&msg, py)) + } + "google.protobuf.Duration" => { + let msg = msg.transcode_to::()?; + Ok(create_py_timedelta(&msg, py)) + } _ => { let obj = proto_meta.create_instance(field_name)?; merge_msg_into_pyobj(obj, msg)?; @@ -118,3 +131,52 @@ fn map_key(key: MapKey, py: Python) -> PyObject { MapKey::String(x) => x.to_object(py), } } + +fn create_py_datetime(ts: &Timestamp, py: Python) -> PyObject { + static CONSTRUCTOR_CACHE: GILOnceCell> = GILOnceCell::new(); + let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || { + let constructor = PyModule::from_code( + py, + indoc! {" + from datetime import datetime, timezone + + def constructor(ts): + return datetime.fromtimestamp(ts, tz=timezone.utc) + "}, + "", + "", + ) + .expect("This is a valid Python module") + .getattr("constructor") + .expect("Attribute exists"); + Py::from(constructor) + }); + let ts = (ts.seconds as f64) + (ts.nanos as f64) / 1e9; + constructor + .call1(py, (ts,)) + .expect("static function will not fail") +} + +fn create_py_timedelta(duration: &Duration, py: Python) -> PyObject { + static CONSTRUCTOR_CACHE: GILOnceCell> = GILOnceCell::new(); + let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || { + let constructor = PyModule::from_code( + py, + indoc! {" + from datetime import timedelta + + def constructor(s, ms): + return timedelta(seconds=s, microseconds=ms) + "}, + "", + "", + ) + .expect("This is a valid Python module") + .getattr("constructor") + .expect("Attribute exists"); + Py::from(constructor) + }); + constructor + .call1(py, (duration.seconds as f64, (duration.nanos as f64) / 1e3)) + .expect("static function will not fail") +}