supporting datetime and timedelta

This commit is contained in:
Erik Friese 2023-09-05 11:27:04 +02:00
parent 950d2f6536
commit fd02cb6180
4 changed files with 109 additions and 21 deletions

View File

@ -18,6 +18,7 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
name = "betterproto-extras"
version = "0.1.0"
dependencies = [
"indoc 2.0.3",
"prost-reflect",
"pyo3",
"thiserror",
@ -53,6 +54,12 @@ version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
[[package]]
name = "indoc"
version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c785eefb63ebd0e33416dfcb8d6da0bf27ce752843a45632a67bf10d4d4b5c4"
[[package]]
name = "itertools"
version = "0.10.5"
@ -175,7 +182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38"
dependencies = [
"cfg-if",
"indoc",
"indoc 1.0.9",
"libc",
"memoffset",
"parking_lot",

View File

@ -8,6 +8,7 @@ name = "betterproto_extras"
crate-type = ["cdylib"]
[dependencies]
indoc = "2.0.3"
prost-reflect = "0.11.5"
pyo3 = { version = "0.19.0", features = ["abi3-py37", "extension-module"] }
pyo3 = { version = "0.19.2", features = ["abi3-py37", "extension-module"] }
thiserror = "1.0.47"

View File

@ -118,44 +118,62 @@ fn add_message_to_file(
.transpose()?
{
Some(Type::Bool) => {
field.type_name = Some(".google.protobuf.BoolValue".to_string());
field.type_name = Some("google.protobuf.BoolValue".to_string());
}
Some(Type::Double) => {
field.type_name = Some(".google.protobuf.DoubleValue".to_string());
field.type_name = Some("google.protobuf.DoubleValue".to_string());
}
Some(Type::Float) => {
field.type_name = Some(".google.protobuf.FloatValue".to_string());
field.type_name = Some("google.protobuf.FloatValue".to_string());
}
Some(Type::Int64) => {
field.type_name = Some(".google.protobuf.Int64Value".to_string());
field.type_name = Some("google.protobuf.Int64Value".to_string());
}
Some(Type::Uint64) => {
field.type_name = Some(".google.protobuf.UInt64Value".to_string());
field.type_name = Some("google.protobuf.UInt64Value".to_string());
}
Some(Type::Int32) => {
field.type_name = Some(".google.protobuf.Int32Value".to_string());
field.type_name = Some("google.protobuf.Int32Value".to_string());
}
Some(Type::Uint32) => {
field.type_name = Some(".google.protobuf.UInt32Value".to_string());
field.type_name = Some("google.protobuf.UInt32Value".to_string());
}
Some(Type::String) => {
field.type_name = Some(".google.protobuf.StringValue".to_string());
field.type_name = Some("google.protobuf.StringValue".to_string());
}
Some(Type::Bytes) => {
field.type_name = Some(".google.protobuf.BytesValue".to_string());
field.type_name = Some("google.protobuf.BytesValue".to_string());
}
Some(t) => return Err(Error::UnsupportedWrapperType(t)),
None => {
let cls = meta.get_class(field_name)?;
let cls_name = cls.qualified_name()?;
field.type_name = Some(cls_name.clone());
if message_name != cls_name
&& pool.get_message_by_name(&cls_name).is_none()
&& !file.message_type.iter().any(|item| item.name() == cls_name)
&& !messages_to_add.iter().any(|item| item.0 == cls_name)
{
messages_to_add.push((cls_name, cls.call0()?));
match cls_name.as_str() {
"datetime.datetime" => {
field.type_name =
Some("google.protobuf.Timestamp".to_string());
}
"datetime.timedelta" => {
field.type_name =
Some("google.protobuf.Duration".to_string());
}
_ => {
field.type_name = Some(cls_name.clone());
if message_name != cls_name
&& pool.get_message_by_name(&cls_name).is_none()
&& !file
.message_type
.iter()
.any(|item| item.name() == cls_name)
&& !messages_to_add
.iter()
.any(|item| item.0 == cls_name)
{
messages_to_add.push((cls_name, cls.call0()?));
}
}
}
}
},

View File

@ -1,8 +1,13 @@
use crate::{error::Result, py_any_extras::PyAnyExtras};
use prost_reflect::{DynamicMessage, MapKey, ReflectMessage, Value};
use indoc::indoc;
use prost_reflect::{
prost_types::{Duration, Timestamp},
DynamicMessage, MapKey, ReflectMessage, Value,
};
use pyo3::{
types::{IntoPyDict, PyBytes},
PyAny, PyObject, Python, ToPyObject,
sync::GILOnceCell,
types::{IntoPyDict, PyBytes, PyModule},
Py, PyAny, PyObject, Python, ToPyObject,
};
pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> {
@ -79,6 +84,14 @@ fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) ->
.get_field_by_number(1)
.and_then(|val| val.as_bytes().map(|b| PyBytes::new(py, b)))
.to_object(py)),
"google.protobuf.Timestamp" => {
let msg = msg.transcode_to::<Timestamp>()?;
Ok(create_py_datetime(&msg, py))
}
"google.protobuf.Duration" => {
let msg = msg.transcode_to::<Duration>()?;
Ok(create_py_timedelta(&msg, py))
}
_ => {
let obj = proto_meta.create_instance(field_name)?;
merge_msg_into_pyobj(obj, msg)?;
@ -118,3 +131,52 @@ fn map_key(key: MapKey, py: Python) -> PyObject {
MapKey::String(x) => x.to_object(py),
}
}
fn create_py_datetime(ts: &Timestamp, py: Python) -> PyObject {
static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || {
let constructor = PyModule::from_code(
py,
indoc! {"
from datetime import datetime, timezone
def constructor(ts):
return datetime.fromtimestamp(ts, tz=timezone.utc)
"},
"",
"",
)
.expect("This is a valid Python module")
.getattr("constructor")
.expect("Attribute exists");
Py::from(constructor)
});
let ts = (ts.seconds as f64) + (ts.nanos as f64) / 1e9;
constructor
.call1(py, (ts,))
.expect("static function will not fail")
}
fn create_py_timedelta(duration: &Duration, py: Python) -> PyObject {
static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || {
let constructor = PyModule::from_code(
py,
indoc! {"
from datetime import timedelta
def constructor(s, ms):
return timedelta(seconds=s, microseconds=ms)
"},
"",
"",
)
.expect("This is a valid Python module")
.getattr("constructor")
.expect("Attribute exists");
Py::from(constructor)
});
constructor
.call1(py, (duration.seconds as f64, (duration.nanos as f64) / 1e3))
.expect("static function will not fail")
}