Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Reshape.cpp
index 61d3300..d3234e4 100644 (file)
@@ -17,6 +17,8 @@
 
 #include "kernels/Reshape.h"
 
+#include "kernels/Utils.h"
+
 #include <cassert>
 #include <cstring>
 
@@ -28,12 +30,26 @@ namespace kernels
 
 static Shape extractShapeFromTensor(const Tensor *tensor)
 {
-  assert(tensor->element_type() == DataType::S32);
   Shape shape(tensor->shape().num_elements());
-  const auto *shape_data = tensor->data<int32_t>();
-  for (int i = 0; i < tensor->shape().num_elements(); ++i)
+  if (tensor->element_type() == DataType::S32)
+  {
+    const auto *shape_data = tensor->data<int32_t>();
+    for (int i = 0; i < tensor->shape().num_elements(); ++i)
+    {
+      shape.dim(i) = shape_data[i];
+    }
+  }
+  else if (tensor->element_type() == DataType::S64)
+  {
+    const auto *shape_data = tensor->data<int64_t>();
+    for (int i = 0; i < tensor->shape().num_elements(); ++i)
+    {
+      shape.dim(i) = static_cast<int32_t>(shape_data[i]);
+    }
+  }
+  else
   {
-    shape.dim(i) = shape_data[i];
+    LUCI_INTERPRETER_CHECK(false);
   }
   return shape;
 }