From bdd3389b175c29a17e728fa50ae7c84ab2d9bf47 Mon Sep 17 00:00:00 2001 From: Erik Friese Date: Thu, 31 Aug 2023 17:17:28 +0200 Subject: [PATCH] bugfix in proto descriptor creation reference cycles in betterproto messages have led to infinite recursion --- betterproto-extras/src/descriptor_pool.rs | 195 +++++++++++----------- 1 file changed, 100 insertions(+), 95 deletions(-) diff --git a/betterproto-extras/src/descriptor_pool.rs b/betterproto-extras/src/descriptor_pool.rs index a672366..e88bdbe 100644 --- a/betterproto-extras/src/descriptor_pool.rs +++ b/betterproto-extras/src/descriptor_pool.rs @@ -20,116 +20,121 @@ pub fn create_cached_descriptor(obj: &PyAny) -> Result { .lock() .unwrap(); - create_cached_message_in_pool(obj, &mut pool) -} - -fn create_cached_message_in_pool( - obj: &PyAny, - pool: &mut DescriptorPool, -) -> Result { let name = obj.qualified_class_name()?; if let Some(desc) = pool.get_message_by_name(&name) { return Ok(desc); } - let meta = obj.get_proto_meta()?; - - let mut message = DescriptorProto { + let mut file = FileDescriptorProto { name: Some(name.clone()), ..Default::default() }; - for item in meta - .getattr("meta_by_field_name")? - .call_method0("items")? - .iter()? - { - let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?; - message.field.push({ - let mut field = FieldDescriptorProto { - name: Some(field_name.to_string()), - number: Some(field_meta.getattr("number")?.extract::()?), - ..Default::default() - }; - field.set_type(map_type( - field_meta.getattr("proto_type")?.extract::<&str>()?, - )?); - - match field.r#type() { - Type::Message => { - let instance = meta.create_instance(field_name)?; - let cls_name = instance.qualified_class_name()?; - field.type_name = Some(cls_name.to_string()); - create_cached_message_in_pool(instance, pool)?; - } - Type::Enum => { - let cls = meta.get_class(field_name)?; - let cls_name = cls.qualified_name()?; - field.type_name = Some(cls_name.to_string()); - create_cached_enum_in_pool(cls, pool)?; - } - _ => {} - } - - if meta.is_list_field(field_name)? { - field.set_label(Label::Repeated); - } else if field_meta.getattr("optional")?.extract::()? { - field.proto3_optional = Some(true); - } - - if let Some(grp) = meta.oneof_group(field_name)? { - let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp); - - match oneof_index { - Some(i) => field.oneof_index = Some(i as i32), - None => { - message.oneof_decl.push(OneofDescriptorProto { - name: Some(grp), - ..Default::default() - }); - field.oneof_index = Some((message.oneof_decl.len() - 1) as i32) - } - } - } - - field - }); - } - - pool.add_file_descriptor_proto(FileDescriptorProto { - name: Some(name.clone()), - message_type: vec![message], - ..Default::default() - })?; - + add_message_to_file(name.clone(), obj, &pool, &mut file)?; + pool.add_file_descriptor_proto(file)?; Ok(pool.get_message_by_name(&name).expect("Just registered...")) } -fn create_cached_enum_in_pool(cls: &PyAny, pool: &mut DescriptorPool) -> Result<()> { - let cls_name = cls.qualified_name()?; - if pool.get_enum_by_name(&cls_name).is_some() { - return Ok(()); - } +fn add_message_to_file( + name: String, + obj: &PyAny, + pool: &DescriptorPool, + file: &mut FileDescriptorProto, +) -> Result<()> { + let mut messages_to_add = vec![(name, obj)]; - let mut proto = EnumDescriptorProto { - name: Some(cls_name.clone()), - ..Default::default() - }; - - for item in cls.iter()? { - let item = item?; - proto.value.push(EnumValueDescriptorProto { - number: Some(item.getattr("value")?.extract()?), - name: Some(item.getattr("name")?.extract()?), + while let Some((name, obj)) = messages_to_add.pop() { + let meta = obj.get_proto_meta()?; + let mut message = DescriptorProto { + name: Some(name.to_string()), ..Default::default() - }); - } + }; - pool.add_file_descriptor_proto(FileDescriptorProto { - name: Some(cls_name), - enum_type: vec![proto], - ..Default::default() - })?; + for item in meta + .getattr("meta_by_field_name")? + .call_method0("items")? + .iter()? + { + let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?; + message.field.push({ + let mut field = FieldDescriptorProto { + name: Some(field_name.to_string()), + number: Some(field_meta.getattr("number")?.extract::()?), + ..Default::default() + }; + field.set_type(map_type( + field_meta.getattr("proto_type")?.extract::<&str>()?, + )?); + + match field.r#type() { + Type::Message => { + let cls = meta.get_class(field_name)?; + let cls_name = cls.qualified_name()?; + field.type_name = Some(cls_name.clone()); + + if name != cls_name + && pool.get_message_by_name(&cls_name).is_none() + && !file.message_type.iter().any(|item| item.name() == cls_name) + && !messages_to_add.iter().any(|item| item.0 == cls_name) + { + messages_to_add.push((cls_name, cls.call0()?)); + } + } + Type::Enum => { + let cls = meta.get_class(field_name)?; + let cls_name = cls.qualified_name()?; + field.type_name = Some(cls_name.to_string()); + + if pool.get_enum_by_name(&cls_name).is_none() + && !file.enum_type.iter().any(|item| item.name() == cls_name) + { + let mut proto = EnumDescriptorProto { + name: Some(cls_name), + ..Default::default() + }; + + for item in cls.iter()? { + let item = item?; + proto.value.push(EnumValueDescriptorProto { + number: Some(item.getattr("value")?.extract()?), + name: Some(item.getattr("name")?.extract()?), + ..Default::default() + }); + } + + file.enum_type.push(proto); + } + } + _ => {} + } + + if meta.is_list_field(field_name)? { + field.set_label(Label::Repeated); + } else if field_meta.getattr("optional")?.extract::()? { + field.proto3_optional = Some(true); + } + + if let Some(grp) = meta.oneof_group(field_name)? { + let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp); + + match oneof_index { + Some(i) => field.oneof_index = Some(i as i32), + None => { + message.oneof_decl.push(OneofDescriptorProto { + name: Some(grp), + ..Default::default() + }); + field.oneof_index = Some((message.oneof_decl.len() - 1) as i32) + } + } + } + + field + }); + } + + file.message_type.push(message); + } Ok(()) }