From 521e255ad9656a213971b30ba1beeec395b2e27e Mon Sep 17 00:00:00 2001 From: Mathias Svensson Date: Mon, 28 Oct 2019 19:19:47 +0100 Subject: [PATCH] Rust: Add idiomatic iterator for Vector type (#5579) * Rust: Add idiomatic iterator for Vector type * Add comments explaining some implementation details --- rust/flatbuffers/src/vector.rs | 132 +++++++++++++++++++++++- tests/rust_usage_test/tests/integration_test.rs | 75 ++++++++++++++ 2 files changed, 204 insertions(+), 3 deletions(-) diff --git a/rust/flatbuffers/src/vector.rs b/rust/flatbuffers/src/vector.rs index ef1986a..354aec7 100644 --- a/rust/flatbuffers/src/vector.rs +++ b/rust/flatbuffers/src/vector.rs @@ -14,14 +14,15 @@ * limitations under the License. */ +use std::iter::{DoubleEndedIterator, ExactSizeIterator, FusedIterator}; use std::marker::PhantomData; use std::mem::size_of; use std::slice::from_raw_parts; use std::str::from_utf8_unchecked; +use endian_scalar::read_scalar_at; #[cfg(target_endian = "little")] use endian_scalar::EndianScalar; -use endian_scalar::{read_scalar, read_scalar_at}; use follow::Follow; use primitives::*; @@ -50,7 +51,7 @@ impl<'a, T: 'a> Vector<'a, T> { #[inline(always)] pub fn len(&self) -> usize { - read_scalar::(&self.0[self.1 as usize..]) as usize + read_scalar_at::(&self.0, self.1) as usize } #[inline(always)] pub fn is_empty(&self) -> bool { @@ -61,11 +62,16 @@ impl<'a, T: 'a> Vector<'a, T> { impl<'a, T: Follow<'a> + 'a> Vector<'a, T> { #[inline(always)] pub fn get(&self, idx: usize) -> T::Inner { - debug_assert!(idx < read_scalar::(&self.0[self.1 as usize..]) as usize); + debug_assert!(idx < read_scalar_at::(&self.0, self.1) as usize); let sz = size_of::(); debug_assert!(sz > 0); T::follow(self.0, self.1 as usize + SIZE_UOFFSET + sz * idx) } + + #[inline(always)] + pub fn iter(&self) -> VectorIter<'a, T> { + VectorIter::new(*self) + } } pub trait SafeSliceAccess {} @@ -147,3 +153,123 @@ impl<'a, T: Follow<'a> + 'a> Follow<'a> for Vector<'a, T> { Vector::new(buf, loc) } } + +#[derive(Debug)] +pub struct VectorIter<'a, T: 'a> { + buf: &'a [u8], + loc: usize, + remaining: usize, + phantom: PhantomData, +} + +impl<'a, T: 'a> VectorIter<'a, T> { + #[inline] + pub fn new(inner: Vector<'a, T>) -> Self { + VectorIter { + buf: inner.0, + // inner.1 is the location of the data for the vector. + // The first SIZE_UOFFSET bytes is the length. We skip + // that to get to the actual vector content. + loc: inner.1 + SIZE_UOFFSET, + remaining: inner.len(), + phantom: PhantomData, + } + } +} + +impl<'a, T: Follow<'a> + 'a> Clone for VectorIter<'a, T> { + #[inline] + fn clone(&self) -> Self { + VectorIter { + buf: self.buf, + loc: self.loc, + remaining: self.remaining, + phantom: self.phantom, + } + } +} + +impl<'a, T: Follow<'a> + 'a> Iterator for VectorIter<'a, T> { + type Item = T::Inner; + + #[inline] + fn next(&mut self) -> Option { + let sz = size_of::(); + debug_assert!(sz > 0); + + if self.remaining == 0 { + None + } else { + let result = T::follow(self.buf, self.loc); + self.loc += sz; + self.remaining -= 1; + Some(result) + } + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let sz = size_of::(); + debug_assert!(sz > 0); + + self.remaining = self.remaining.saturating_sub(n); + + // Note that this might overflow, but that is okay because + // in that case self.remaining will have been set to zero. + self.loc = self.loc.wrapping_add(sz * n); + + self.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +impl<'a, T: Follow<'a> + 'a> DoubleEndedIterator for VectorIter<'a, T> { + #[inline] + fn next_back(&mut self) -> Option { + let sz = size_of::(); + debug_assert!(sz > 0); + + if self.remaining == 0 { + None + } else { + self.remaining -= 1; + Some(T::follow(self.buf, self.loc + sz * self.remaining)) + } + } + + #[inline] + fn nth_back(&mut self, n: usize) -> Option { + self.remaining = self.remaining.saturating_sub(n); + self.next_back() + } +} + +impl<'a, T: 'a + Follow<'a>> ExactSizeIterator for VectorIter<'a, T> { + #[inline] + fn len(&self) -> usize { + self.remaining + } +} + +impl<'a, T: 'a + Follow<'a>> FusedIterator for VectorIter<'a, T> {} + +impl<'a, T: Follow<'a> + 'a> IntoIterator for Vector<'a, T> { + type Item = T::Inner; + type IntoIter = VectorIter<'a, T>; + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, 'b, T: Follow<'a> + 'a> IntoIterator for &'b Vector<'a, T> { + type Item = T::Inner; + type IntoIter = VectorIter<'a, T>; + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/tests/rust_usage_test/tests/integration_test.rs b/tests/rust_usage_test/tests/integration_test.rs index c2af738..c81f71a 100644 --- a/tests/rust_usage_test/tests/integration_test.rs +++ b/tests/rust_usage_test/tests/integration_test.rs @@ -184,6 +184,7 @@ fn serialized_example_is_accessible_and_correct(bytes: &[u8], identifier_require let inv = m.inventory().unwrap(); check_eq!(inv.len(), 5)?; check_eq!(inv.iter().sum::(), 10u8)?; + check_eq!(inv.iter().rev().sum::(), 10u8)?; check_is_some!(m.test4())?; let test4 = m.test4().unwrap(); @@ -509,6 +510,18 @@ mod roundtrip_generated_code { assert_eq!(m.testarrayofstring().unwrap().len(), 2); assert_eq!(m.testarrayofstring().unwrap().get(0), "foobar"); assert_eq!(m.testarrayofstring().unwrap().get(1), "baz"); + + let rust_vec_inst = m.testarrayofstring().unwrap(); + let rust_vec_iter_collect = rust_vec_inst.iter().collect::>(); + assert_eq!(rust_vec_iter_collect.len(), 2); + assert_eq!(rust_vec_iter_collect[0], "foobar"); + assert_eq!(rust_vec_iter_collect[1], "baz"); + + let rust_vec_iter_rev_collect = rust_vec_inst.iter().rev().collect::>(); + assert_eq!(rust_vec_iter_rev_collect.len(), 2); + assert_eq!(rust_vec_iter_rev_collect[1], "foobar"); + assert_eq!(rust_vec_iter_rev_collect[0], "baz"); + } #[test] fn vector_of_string_store_manual_build() { @@ -523,6 +536,17 @@ mod roundtrip_generated_code { assert_eq!(m.testarrayofstring().unwrap().len(), 2); assert_eq!(m.testarrayofstring().unwrap().get(0), "foobar"); assert_eq!(m.testarrayofstring().unwrap().get(1), "baz"); + + let rust_vec_inst = m.testarrayofstring().unwrap(); + let rust_vec_iter_collect = rust_vec_inst.iter().collect::>(); + assert_eq!(rust_vec_iter_collect.len(), 2); + assert_eq!(rust_vec_iter_collect[0], "foobar"); + assert_eq!(rust_vec_iter_collect[1], "baz"); + + let rust_vec_iter_rev_collect = rust_vec_inst.iter().rev().collect::>(); + assert_eq!(rust_vec_iter_rev_collect.len(), 2); + assert_eq!(rust_vec_iter_rev_collect[0], "baz"); + assert_eq!(rust_vec_iter_rev_collect[1], "foobar"); } #[test] fn vector_of_ubyte_store() { @@ -543,6 +567,13 @@ mod roundtrip_generated_code { name: Some(name), testarrayofbools: Some(v), ..Default::default()}); assert_eq!(m.testarrayofbools().unwrap(), &[false, true, false, true][..]); + + let rust_vec_inst = m.testarrayofbools().unwrap(); + let rust_vec_iter_collect = rust_vec_inst.iter().collect::>(); + assert_eq!(rust_vec_iter_collect, &[&false, &true, &false, &true][..]); + + let rust_vec_iter_rev_collect = rust_vec_inst.iter().rev().collect::>(); + assert_eq!(rust_vec_iter_rev_collect, &[&true, &false, &true, &false][..]); } #[test] fn vector_of_f64_store() { @@ -554,6 +585,15 @@ mod roundtrip_generated_code { vector_of_doubles: Some(v), ..Default::default()}); assert_eq!(m.vector_of_doubles().unwrap().len(), 1); assert_eq!(m.vector_of_doubles().unwrap().get(0), 3.14159265359f64); + + let rust_vec_inst = m.vector_of_doubles().unwrap(); + let rust_vec_iter_collect = rust_vec_inst.iter().collect::>(); + assert_eq!(rust_vec_iter_collect.len(), 1); + assert_eq!(rust_vec_iter_collect[0], 3.14159265359f64); + + let rust_vec_iter_rev_collect = rust_vec_inst.iter().rev().collect::>(); + assert_eq!(rust_vec_iter_rev_collect.len(), 1); + assert_eq!(rust_vec_iter_rev_collect[0], 3.14159265359f64); } #[test] fn vector_of_struct_store() { @@ -564,6 +604,13 @@ mod roundtrip_generated_code { name: Some(name), test4: Some(v), ..Default::default()}); assert_eq!(m.test4().unwrap(), &[my_game::example::Test::new(127, -128), my_game::example::Test::new(3, 123)][..]); + + let rust_vec_inst = m.test4().unwrap(); + let rust_vec_iter_collect = rust_vec_inst.iter().collect::>(); + assert_eq!(rust_vec_iter_collect, &[&my_game::example::Test::new(127, -128), &my_game::example::Test::new(3, 123)][..]); + + let rust_vec_iter_rev_collect = rust_vec_inst.iter().rev().collect::>(); + assert_eq!(rust_vec_iter_rev_collect, &[&my_game::example::Test::new(3, 123), &my_game::example::Test::new(127, -128)][..]); } #[test] fn vector_of_struct_store_with_type_inference() { @@ -613,6 +660,21 @@ mod roundtrip_generated_code { assert_eq!(m.testarrayoftables().unwrap().get(0).name(), "foo"); assert_eq!(m.testarrayoftables().unwrap().get(1).hp(), 100); assert_eq!(m.testarrayoftables().unwrap().get(1).name(), "bar"); + + let rust_vec_inst = m.testarrayoftables().unwrap(); + let rust_vec_iter_collect = rust_vec_inst.iter().collect::>(); + assert_eq!(rust_vec_iter_collect.len(), 2); + assert_eq!(rust_vec_iter_collect[0].hp(), 55); + assert_eq!(rust_vec_iter_collect[0].name(), "foo"); + assert_eq!(rust_vec_iter_collect[1].hp(), 100); + assert_eq!(rust_vec_iter_collect[1].name(), "bar"); + + let rust_vec_iter_rev_collect = rust_vec_inst.iter().rev().collect::>(); + assert_eq!(rust_vec_iter_rev_collect.len(), 2); + assert_eq!(rust_vec_iter_rev_collect[0].hp(), 100); + assert_eq!(rust_vec_iter_rev_collect[0].name(), "bar"); + assert_eq!(rust_vec_iter_rev_collect[1].hp(), 55); + assert_eq!(rust_vec_iter_rev_collect[1].name(), "foo"); } } @@ -721,6 +783,12 @@ mod generated_code_alignment_and_padding { let aln = ::std::mem::align_of::(); assert_eq!((a_ptr - start_ptr) % aln, 0); } + for a in abilities.iter().rev() { + let a_ptr = a as *const my_game::example::Ability as usize; + assert!(a_ptr > start_ptr); + let aln = ::std::mem::align_of::(); + assert_eq!((a_ptr - start_ptr) % aln, 0); + } } } @@ -806,6 +874,13 @@ mod roundtrip_vectors { result_vec.push(got.get(i)); } assert_eq!(result_vec, xs); + + let rust_vec_iter = got.iter().collect::>(); + assert_eq!(rust_vec_iter, xs); + + let mut rust_vec_rev_iter = got.iter().rev().collect::>(); + rust_vec_rev_iter.reverse(); + assert_eq!(rust_vec_rev_iter, xs); } #[test] -- 2.7.4