#include "internal/Swizzle.h"
#include "internal/Model.h"
-inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape)
+inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape,
+ bool apply_dim_correction = true)
{
const uint32_t rank = shape.rank();
for (uint32_t axis = 0; axis < rank; ++axis)
{
- // NOTE Do NOT update TensorShape with operator[] (in ::arm_compute::Dimensions)
- // TensorShape::set applies dimension correction after value update.
- // Various asserts in ARMCompute work only when this correction is applied.
- res.set(ToARMComputeAxis(rank, axis).value(), shape.dim(axis));
+ // NOTE In some cases, in incorrect dimensions is required.
+ // For example, intput_size is 1 in LSTM. The input-to-input weights([num_units, input_size]) of
+ // LSTM is used as the weight of the FullyConnected.
+ // The FullyConnected's weight must be greater or equal than 2-dimensions.
+ // However, if the dimension correction is applied to input_to_input_weights with input_size
+ // equal to 1, it will be changed to 1-D.
+ // So input_to_input_weights is not used by the weight of FullyConnected.
+ res.set(ToARMComputeAxis(rank, axis).value(), shape.dim(axis), apply_dim_correction);
}
return res;