#include "kernels/Reshape.h"
+#include "kernels/Utils.h"
+
#include <cassert>
#include <cstring>
static Shape extractShapeFromTensor(const Tensor *tensor)
{
- assert(tensor->element_type() == DataType::S32);
Shape shape(tensor->shape().num_elements());
- const auto *shape_data = tensor->data<int32_t>();
- for (int i = 0; i < tensor->shape().num_elements(); ++i)
+ if (tensor->element_type() == DataType::S32)
+ {
+ const auto *shape_data = tensor->data<int32_t>();
+ for (int i = 0; i < tensor->shape().num_elements(); ++i)
+ {
+ shape.dim(i) = shape_data[i];
+ }
+ }
+ else if (tensor->element_type() == DataType::S64)
+ {
+ const auto *shape_data = tensor->data<int64_t>();
+ for (int i = 0; i < tensor->shape().num_elements(); ++i)
+ {
+ shape.dim(i) = static_cast<int32_t>(shape_data[i]);
+ }
+ }
+ else
{
- shape.dim(i) = shape_data[i];
+ LUCI_INTERPRETER_CHECK(false);
}
return shape;
}