arm_compute v18.05
[platform/upstream/armcl.git] / tests / validation / reference / FlattenLayer.cpp
index 611701d..44f4d93 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,12 +34,8 @@ namespace validation
 namespace reference
 {
 template <typename T>
-SimpleTensor<T> flatten_layer(const SimpleTensor<T> &src)
+SimpleTensor<T> flatten_layer(const SimpleTensor<T> &src, const TensorShape &shape_flatten)
 {
-    TensorShape shape_flatten(src.shape());
-    shape_flatten.set(0, src.shape()[0] * src.shape()[1] * src.shape()[2]);
-    shape_flatten.remove_dimension(1);
-    shape_flatten.remove_dimension(1);
     SimpleTensor<T> dst(shape_flatten, src.data_type(), 1, src.fixed_point_position());
 
     // Note: Since the reference implementation does not use padding bytes, we can copy directly the content of the source tensor
@@ -48,10 +44,10 @@ SimpleTensor<T> flatten_layer(const SimpleTensor<T> &src)
     return dst;
 }
 
-template SimpleTensor<float> flatten_layer(const SimpleTensor<float> &src);
-template SimpleTensor<half> flatten_layer(const SimpleTensor<half> &src);
-template SimpleTensor<qint8_t> flatten_layer(const SimpleTensor<qint8_t> &src);
-template SimpleTensor<qint16_t> flatten_layer(const SimpleTensor<qint16_t> &src);
+template SimpleTensor<float> flatten_layer(const SimpleTensor<float> &src, const TensorShape &shape_flatten);
+template SimpleTensor<half> flatten_layer(const SimpleTensor<half> &src, const TensorShape &shape_flatten);
+template SimpleTensor<qint8_t> flatten_layer(const SimpleTensor<qint8_t> &src, const TensorShape &shape_flatten);
+template SimpleTensor<qint16_t> flatten_layer(const SimpleTensor<qint16_t> &src, const TensorShape &shape_flatten);
 } // namespace reference
 } // namespace validation
 } // namespace test