#include "mir/ops/ReluOp.h"
#include "mir/ops/ReshapeOp.h"
+#include "mir/ShapeRange.h"
+
#include <stdex/Memory.h>
#include <cstring>
const auto &value = op.getValue();
const_node->dtype(ConvertDataType(value.getDataType()));
// TODO Support other data types
- assert(const_node->dtype() == loco::DataType::FLOAT32);
- const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
- // TODO Change that when loco support other DataTypeImpl
- float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
- char *loco_ptr = reinterpret_cast<char *>(&const_float);
- char *mir_ptr = value.at(mir::Index(out_shape.rank()));
- std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+ switch (const_node->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ {
+ const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+ float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+ char *loco_ptr = reinterpret_cast<char *>(&const_float);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+ break;
+ }
+ case loco::DataType::FLOAT64:
+ {
+ // TODO Change that when loco support other DataTypeImpl
+ const_node->dtype(loco::DataType::FLOAT32);
+ const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+ float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ double *mir_double = reinterpret_cast<double *>(mir_ptr);
+ float *loco_float = &const_float;
+ for (const mir::Index &idx : mir::ShapeRange(out_shape))
+ {
+ *loco_float = static_cast<float>(*mir_double);
+ loco_float++;
+ mir_double++;
+ }
+ break;
+ }
+ case loco::DataType::S32:
+ {
+ const_node->size<loco::DataType::S32>(out_shape.numElements());
+ int32_t &const_int32 = const_node->at<loco::DataType::S32>(0);
+ char *loco_ptr = reinterpret_cast<char *>(&const_int32);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(int32_t));
+ break;
+ }
+ case loco::DataType::S64:
+ {
+ // TODO Change that when loco support other DataTypeImpl
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(out_shape.numElements());
+ int32_t &const_int32 = const_node->at<loco::DataType::S32>(0);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ int64_t *mir_int64 = reinterpret_cast<int64_t *>(mir_ptr);
+ int32_t *loco_int32 = &const_int32;
+ for (const mir::Index &idx : mir::ShapeRange(out_shape))
+ {
+ *loco_int32 = static_cast<float>(*mir_int64);
+ loco_int32++;
+ mir_int64++;
+ }
+ break;
+ }
+ default:
+ std::runtime_error("Unsupported data type");
+ }
// Add to map
_mir2loco_map.emplace(&op, const_node);
}