diff --git a/src/idl_gen_rust.cpp b/src/idl_gen_rust.cpp index 393006591..1ca6254a9 100644 --- a/src/idl_gen_rust.cpp +++ b/src/idl_gen_rust.cpp @@ -165,6 +165,24 @@ FullType GetFullType(const Type &type) { return ftBool; } +// If the second parameter is false then wrap the first with Option<...> +std::string WrapInOptionIfNotRequired(std::string s, bool required) { + if (required) { + return s; + } else { + return "Option<" + s + ">"; + } +} + +// If the second parameter is false then add .unwrap() +std::string AddUnwrapIfRequired(std::string s, bool required) { + if (required) { + return s + ".unwrap()"; + } else { + return s; + } +} + namespace rust { class RustGenerator : public BaseGenerator { @@ -967,11 +985,11 @@ class RustGenerator : public BaseGenerator { } case ftStruct: { const auto typname = WrapInNameSpace(*type.struct_def); - return "Option<&" + lifetime + " " + typname + ">"; + return WrapInOptionIfNotRequired("&" + lifetime + " " + typname, field.required); } case ftTable: { const auto typname = WrapInNameSpace(*type.struct_def); - return "Option<" + typname + "<" + lifetime + ">>"; + return WrapInOptionIfNotRequired(typname + "<" + lifetime + ">", field.required); } case ftEnumKey: case ftUnionKey: { @@ -980,38 +998,38 @@ class RustGenerator : public BaseGenerator { } case ftUnionValue: { - return "Option>"; + return WrapInOptionIfNotRequired("flatbuffers::Table<" + lifetime + ">", field.required); } case ftString: { - return "Option<&" + lifetime + " str>"; + return WrapInOptionIfNotRequired("&" + lifetime + " str", field.required); } case ftVectorOfInteger: case ftVectorOfFloat: { const auto typname = GetTypeBasic(type.VectorType()); if (IsOneByte(type.VectorType().base_type)) { - return "Option<&" + lifetime + " [" + typname + "]>"; + return WrapInOptionIfNotRequired("&" + lifetime + " [" + typname + "]", field.required); } - return "Option>"; + return WrapInOptionIfNotRequired("flatbuffers::Vector<" + lifetime + ", " + typname + ">", field.required); } case ftVectorOfBool: { - return "Option<&" + lifetime + " [bool]>"; + return WrapInOptionIfNotRequired("&" + lifetime + " [bool]", field.required); } case ftVectorOfEnumKey: { const auto typname = WrapInNameSpace(*type.enum_def); - return "Option>"; + return WrapInOptionIfNotRequired("flatbuffers::Vector<" + lifetime + ", " + typname + ">", field.required); } case ftVectorOfStruct: { const auto typname = WrapInNameSpace(*type.struct_def); - return "Option<&" + lifetime + " [" + typname + "]>"; + return WrapInOptionIfNotRequired("&" + lifetime + " [" + typname + "]", field.required); } case ftVectorOfTable: { const auto typname = WrapInNameSpace(*type.struct_def); - return "Option>>>"; + return WrapInOptionIfNotRequired("flatbuffers::Vector>>", field.required); } case ftVectorOfString: { - return "Option>>"; + return WrapInOptionIfNotRequired("flatbuffers::Vector>", field.required); } case ftVectorOfUnionValue: { FLATBUFFERS_ASSERT(false && "vectors of unions are not yet supported"); @@ -1041,17 +1059,17 @@ class RustGenerator : public BaseGenerator { } case ftStruct: { const auto typname = WrapInNameSpace(*type.struct_def); - return "self._tab.get::<" + typname + ">(" + offset_name + ", None)"; + return AddUnwrapIfRequired("self._tab.get::<" + typname + ">(" + offset_name + ", None)", field.required); } case ftTable: { const auto typname = WrapInNameSpace(*type.struct_def); - return "self._tab.get::>>(" + offset_name + ", None)"; + return AddUnwrapIfRequired("self._tab.get::>>(" + offset_name + ", None)", field.required); } case ftUnionValue: { - return "self._tab.get::>>(" + offset_name + \ - ", None)"; + ", None)", field.required); } case ftUnionKey: case ftEnumKey: { @@ -1062,8 +1080,8 @@ class RustGenerator : public BaseGenerator { ", Some(" + default_value + ")).unwrap()"; } case ftString: { - return "self._tab.get::>(" + \ - offset_name + ", None)"; + return AddUnwrapIfRequired("self._tab.get::>(" + \ + offset_name + ", None)", field.required); } case ftVectorOfInteger: @@ -1076,35 +1094,35 @@ class RustGenerator : public BaseGenerator { if (IsOneByte(type.VectorType().base_type)) { s += ".map(|v| v.safe_slice())"; } - return s; + return AddUnwrapIfRequired(s, field.required); } case ftVectorOfBool: { - return "self._tab.get::>>(" + \ - offset_name + ", None).map(|v| v.safe_slice())"; + offset_name + ", None).map(|v| v.safe_slice())", field.required); } case ftVectorOfEnumKey: { const auto typname = WrapInNameSpace(*type.enum_def); - return "self._tab.get::>>(" + \ - offset_name + ", None)"; + offset_name + ", None)", field.required); } case ftVectorOfStruct: { const auto typname = WrapInNameSpace(*type.struct_def); - return "self._tab.get::>>(" + \ - offset_name + ", None).map(|v| v.safe_slice() )"; + offset_name + ", None).map(|v| v.safe_slice() )", field.required); } case ftVectorOfTable: { const auto typname = WrapInNameSpace(*type.struct_def); - return "self._tab.get::>>>>(" + offset_name + ", None)"; + "<" + lifetime + ">>>>>(" + offset_name + ", None)", field.required); } case ftVectorOfString: { - return "self._tab.get::>>>(" + offset_name + ", None)"; + lifetime + " str>>>>(" + offset_name + ", None)", field.required); } case ftVectorOfUnionValue: { FLATBUFFERS_ASSERT(false && "vectors of unions are not yet supported"); @@ -1449,19 +1467,8 @@ class RustGenerator : public BaseGenerator { // must only be called if the field key is defined. void GenKeyFieldMethods(const FieldDef &field) { FLATBUFFERS_ASSERT(field.key); - const bool is_string = (field.value.type.base_type == BASE_TYPE_STRING); - if (is_string) { - code_.SetValue("KEY_TYPE", "Option<&str>"); - } else { - FLATBUFFERS_ASSERT(IsScalar(field.value.type.base_type)); - auto type = GetTypeBasic(field.value.type); - if (parser_.opts.scoped_enums && field.value.type.enum_def && - IsScalar(field.value.type.base_type)) { - type = GetTypeGet(field.value.type); - } - code_.SetValue("KEY_TYPE", type); - } + code_.SetValue("KEY_TYPE", GenTableAccessorFuncReturnType(field, "")); code_ += " #[inline]"; code_ += " pub fn key_compare_less_than(&self, o: &{{STRUCT_NAME}}) -> " diff --git a/tests/monster_test_generated.rs b/tests/monster_test_generated.rs index d2316d745..020a543eb 100644 --- a/tests/monster_test_generated.rs +++ b/tests/monster_test_generated.rs @@ -912,8 +912,8 @@ impl<'a> Monster<'a> { self._tab.get::(Monster::VT_HP, Some(100)).unwrap() } #[inline] - pub fn name(&'a self) -> Option<&'a str> { - self._tab.get::>(Monster::VT_NAME, None) + pub fn name(&'a self) -> &'a str { + self._tab.get::>(Monster::VT_NAME, None).unwrap() } #[inline] pub fn key_compare_less_than(&self, o: &Monster) -> bool { @@ -921,7 +921,7 @@ impl<'a> Monster<'a> { } #[inline] - pub fn key_compare_with_value(&self, val: Option<&str>) -> ::std::cmp::Ordering { + pub fn key_compare_with_value(&self, val: & str) -> ::std::cmp::Ordering { let key = self.name(); key.cmp(&val) } diff --git a/tests/rust_usage_test/tests/integration_test.rs b/tests/rust_usage_test/tests/integration_test.rs index 4f3eca435..0e5f8fac9 100644 --- a/tests/rust_usage_test/tests/integration_test.rs +++ b/tests/rust_usage_test/tests/integration_test.rs @@ -159,8 +159,7 @@ fn serialized_example_is_accessible_and_correct(bytes: &[u8], identifier_require check_eq!(m.hp(), 80)?; check_eq!(m.mana(), 150)?; - check_eq!(m.name(), Some("MyMonster"))?; - check_is_some!(m.name())?; + check_eq!(m.name(), "MyMonster")?; let pos = m.pos().unwrap(); check_eq!(pos.x(), 1.0f32)?; @@ -178,7 +177,7 @@ fn serialized_example_is_accessible_and_correct(bytes: &[u8], identifier_require let table2 = m.test().unwrap(); let monster2 = my_game::example::Monster::init_from_table(table2); - check_eq!(monster2.name(), Some("Fred"))?; + check_eq!(monster2.name(), "Fred")?; check_is_some!(m.inventory())?; let inv = m.inventory().unwrap(); @@ -266,7 +265,7 @@ mod roundtrip_generated_code { let mut b = flatbuffers::FlatBufferBuilder::new(); let name = b.create_string("foobar"); let m = build_mon(&mut b, &my_game::example::MonsterArgs{name: Some(name), ..Default::default()}); - assert_eq!(m.name(), Some("foobar")); + assert_eq!(m.name(), "foobar"); } #[test] fn struct_store() { @@ -325,11 +324,11 @@ mod roundtrip_generated_code { } let mon = my_game::example::get_root_as_monster(b.finished_data()); - assert_eq!(mon.name(), Some("bar")); + assert_eq!(mon.name(), "bar"); assert_eq!(mon.test_type(), my_game::example::Any::Monster); assert_eq!(my_game::example::Monster::init_from_table(mon.test().unwrap()).name(), - Some("foo")); - assert_eq!(mon.test_as_monster().unwrap().name(), Some("foo")); + "foo"); + assert_eq!(mon.test_as_monster().unwrap().name(), "foo"); assert_eq!(mon.test_as_test_simple_table_with_enum(), None); assert_eq!(mon.test_as_my_game_example_2_monster(), None); } @@ -361,8 +360,8 @@ mod roundtrip_generated_code { } let mon = my_game::example::get_root_as_monster(b.finished_data()); - assert_eq!(mon.name(), Some("bar")); - assert_eq!(mon.enemy().unwrap().name(), Some("foo")); + assert_eq!(mon.name(), "bar"); + assert_eq!(mon.enemy().unwrap().name(), "foo"); } #[test] fn table_full_namespace_default() { @@ -391,7 +390,7 @@ mod roundtrip_generated_code { } let mon = my_game::example::get_root_as_monster(b.finished_data()); - assert_eq!(mon.name(), Some("bar")); + assert_eq!(mon.name(), "bar"); assert_eq!(mon.testempty().unwrap().id(), Some("foo")); } #[test] @@ -434,13 +433,13 @@ mod roundtrip_generated_code { let m2_a = my_game::example::get_root_as_monster(m.testnestedflatbuffer().unwrap()); assert_eq!(m2_a.hp(), 123); - assert_eq!(m2_a.name(), Some("foobar")); + assert_eq!(m2_a.name(), "foobar"); assert!(m.testnestedflatbuffer_nested_flatbuffer().is_some()); let m2_b = m.testnestedflatbuffer_nested_flatbuffer().unwrap(); assert_eq!(m2_b.hp(), 123); - assert_eq!(m2_b.name(), Some("foobar")); + assert_eq!(m2_b.name(), "foobar"); } #[test] fn nested_flatbuffer_default() { @@ -561,9 +560,9 @@ mod roundtrip_generated_code { testarrayoftables: Some(v), ..Default::default()}); assert_eq!(m.testarrayoftables().unwrap().len(), 2); assert_eq!(m.testarrayoftables().unwrap().get(0).hp(), 55); - assert_eq!(m.testarrayoftables().unwrap().get(0).name(), Some("foo")); + assert_eq!(m.testarrayoftables().unwrap().get(0).name(), "foo"); assert_eq!(m.testarrayoftables().unwrap().get(1).hp(), 100); - assert_eq!(m.testarrayoftables().unwrap().get(1).name(), Some("bar")); + assert_eq!(m.testarrayoftables().unwrap().get(1).name(), "bar"); } } @@ -947,7 +946,7 @@ mod framing_format { let m = flatbuffers::get_size_prefixed_root::(buf); assert_eq!(m.mana(), 200); assert_eq!(m.hp(), 300); - assert_eq!(m.name(), Some("bob")); + assert_eq!(m.name(), "bob"); } } @@ -1459,13 +1458,11 @@ mod generated_key_comparisons { let a = my_game::example::get_root_as_monster(buf); // preconditions - assert_eq!(a.name(), Some("MyMonster")); + assert_eq!(a.name(), "MyMonster"); - assert_eq!(a.key_compare_with_value(None), ::std::cmp::Ordering::Greater); - - assert_eq!(a.key_compare_with_value(Some("AAA")), ::std::cmp::Ordering::Greater); - assert_eq!(a.key_compare_with_value(Some("MyMonster")), ::std::cmp::Ordering::Equal); - assert_eq!(a.key_compare_with_value(Some("ZZZ")), ::std::cmp::Ordering::Less); + assert_eq!(a.key_compare_with_value("AAA"), ::std::cmp::Ordering::Greater); + assert_eq!(a.key_compare_with_value("MyMonster"), ::std::cmp::Ordering::Equal); + assert_eq!(a.key_compare_with_value("ZZZ"), ::std::cmp::Ordering::Less); } #[test] @@ -1478,8 +1475,8 @@ mod generated_key_comparisons { let b = a.test_as_monster().unwrap(); // preconditions - assert_eq!(a.name(), Some("MyMonster")); - assert_eq!(b.name(), Some("Fred")); + assert_eq!(a.name(), "MyMonster"); + assert_eq!(b.name(), "Fred"); assert_eq!(a.key_compare_less_than(&a), false); assert_eq!(a.key_compare_less_than(&b), false);