[mir2loco] Supported other types in ConstantOp transformer (#6516)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Mon, 12 Aug 2019 17:11:17 +0000 (20:11 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 12 Aug 2019 17:11:17 +0000 (20:11 +0300)
* Supported F32, F64, S32, S64 data types
* Temporary implementations for F64 and S64

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir2loco/src/mir2loco.cpp

index 4e0942a..1f60858 100644 (file)
@@ -24,6 +24,8 @@
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
 
+#include "mir/ShapeRange.h"
+
 #include <stdex/Memory.h>
 
 #include <cstring>
@@ -182,13 +184,63 @@ void Transformer::visit(mir::ops::ConstantOp &op)
   const auto &value = op.getValue();
   const_node->dtype(ConvertDataType(value.getDataType()));
   // TODO Support other data types
-  assert(const_node->dtype() == loco::DataType::FLOAT32);
-  const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
-  // TODO Change that when loco support other DataTypeImpl
-  float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
-  char *loco_ptr = reinterpret_cast<char *>(&const_float);
-  char *mir_ptr = value.at(mir::Index(out_shape.rank()));
-  std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+  switch (const_node->dtype())
+  {
+    case loco::DataType::FLOAT32:
+    {
+      const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+      float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+      char *loco_ptr = reinterpret_cast<char *>(&const_float);
+      char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+      std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+      break;
+    }
+    case loco::DataType::FLOAT64:
+    {
+      // TODO Change that when loco support other DataTypeImpl
+      const_node->dtype(loco::DataType::FLOAT32);
+      const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+      float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+      char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+      double *mir_double = reinterpret_cast<double *>(mir_ptr);
+      float *loco_float = &const_float;
+      for (const mir::Index &idx : mir::ShapeRange(out_shape))
+      {
+        *loco_float = static_cast<float>(*mir_double);
+        loco_float++;
+        mir_double++;
+      }
+      break;
+    }
+    case loco::DataType::S32:
+    {
+      const_node->size<loco::DataType::S32>(out_shape.numElements());
+      int32_t &const_int32 = const_node->at<loco::DataType::S32>(0);
+      char *loco_ptr = reinterpret_cast<char *>(&const_int32);
+      char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+      std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(int32_t));
+      break;
+    }
+    case loco::DataType::S64:
+    {
+      // TODO Change that when loco support other DataTypeImpl
+      const_node->dtype(loco::DataType::S32);
+      const_node->size<loco::DataType::S32>(out_shape.numElements());
+      int32_t &const_int32 = const_node->at<loco::DataType::S32>(0);
+      char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+      int64_t *mir_int64 = reinterpret_cast<int64_t *>(mir_ptr);
+      int32_t *loco_int32 = &const_int32;
+      for (const mir::Index &idx : mir::ShapeRange(out_shape))
+      {
+        *loco_int32 = static_cast<float>(*mir_int64);
+        loco_int32++;
+        mir_int64++;
+      }
+      break;
+    }
+    default:
+      std::runtime_error("Unsupported data type");
+  }
   // Add to map
   _mir2loco_map.emplace(&op, const_node);
 }