Implement the plan of Cast operation (#1574)
author Jiseob Jang / Motion Control Lab (SR) / Engineer / Samsung Electronics <jiseob.jang@samsung.com>
Mon, 11 Jun 2018 05:28:15 +0000 (14:28 +0900)
committer Hyeongseok Oh / Motion Control Lab (SR) / Staff Engineer / Samsung Electronics <hseok82.oh@samsung.com>
Mon, 11 Jun 2018 05:28:15 +0000 (14:28 +0900)
This commit implements the plan of the Cast operation.
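
For the Cast node, the planner now registers shape constraints for the input
and output operands (scalar, vector, and 4-D feature shapes are supported;
other ranks are rejected) and adds a stage that appends a SimpleCastLayer to
the execution builder. SimpleCastLayer maps the CL tensors on the host and
converts the data element by element through the new copyCast() helper, which
handles F32, S32, and U32 destinations.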

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc
runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
runtimes/pure_arm_compute/src/internal/layers/SimpleCastLayer.h [new file with mode: 0644]

index ac79353..1957bcf 100644 (file)
@@ -28,6 +28,7 @@
 #include "internal/arm_compute/feature/View.h"
 #include "internal/layers/GenericReshapeLayer.h"
 #include "internal/layers/SimpleArithmeticAdditionLayer.h"
+#include "internal/layers/SimpleCastLayer.h"
 
 #include "util/kernel/IndexIterator.h"
 #include "util/feature/IndexIterator.h"
@@ -1448,7 +1449,75 @@ void Planner::visit(const ::internal::tflite::op::ReduceMax::Node &node)
 
 void Planner::visit(const ::internal::tflite::op::Cast::Node &node)
 {
-  // TODO Implement the plan of Cast
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+
+  const auto output_shape = _ctx.at(output_index).shape();
+  const auto input_shape = _ctx.at(input_index).shape();
+  assert(output_shape.rank() == input_shape.rank());
+  for (uint32_t n = 0; n < input_shape.rank(); ++n)
+  {
+    assert(output_shape.dim(n) == input_shape.dim(n));
+  }
+
+  // TODO Move this to the place where the operand is handled, if possible.
+  // Set Shape Constraints and TensorInfo
+  switch (input_shape.rank())
+  {
+    case 0: // scalar
+    {
+      _builder.addShapeConstr(output_index, asTensorInfo(1, _ctx.at(output_index).type()));
+      _builder.addShapeConstr(input_index, asTensorInfo(1, _ctx.at(input_index).type()));
+      break;
+    }
+    case 1: // vector
+    {
+      _builder.addShapeConstr(output_index,
+                              asTensorInfo(output_shape.asVector(), _ctx.at(output_index).type()));
+      _builder.addShapeConstr(input_index,
+                              asTensorInfo(input_shape.asVector(), _ctx.at(input_index).type()));
+      break;
+    }
+    case 4: // feature
+    {
+      _builder.addShapeConstr(output_index,
+                              asTensorInfo(output_shape.asFeature(), _ctx.at(output_index).type()));
+      _builder.addShapeConstr(input_index,
+                              asTensorInfo(input_shape.asFeature(), _ctx.at(input_index).type()));
+      break;
+    }
+    default:
+      throw std::runtime_error("Not supported, yet");
+      break;
+  }
+
+  // Construct operation parameters
+  struct Param
+  {
+    int input_index;
+    int output_index;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.input_index = input_index.asInt();
+
+  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+    auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
+    auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
+
+    std::unique_ptr<::arm_compute::IFunction> fn;
+
+    auto l = make_layer<SimpleCastLayer>();
+
+    l->configure(input_alloc, output_alloc);
+    fn = std::move(l);
+
+    builder.append(std::move(fn));
+  };
+
+  _builder.addStage(stage);
 }
 
 void Planner::visit(const ::internal::tflite::op::TopKV2::Node &node)
index 2a22253..c8f386f 100644 (file)
@@ -49,4 +49,32 @@ inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand:
   return ::arm_compute::TensorInfo(asTensorShape(shape), 1, asDataType(type));
 }
 
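+// Write 'value' into the element of 'to' at coordinates 'id', converted to the
+// destination tensor's data type (F32, S32 and U32 are supported)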
+template <typename FromT>
+void copyCast(const FromT value, ::arm_compute::ICLTensor *to, const ::arm_compute::Coordinates &id)
+{
+  switch (to->info()->data_type())
+  {
+    case ::arm_compute::DataType::F32:
+    {
+      *reinterpret_cast<float *>(to->ptr_to_element(id)) = static_cast<float>(value);
+      break;
+    }
+    case ::arm_compute::DataType::S32:
+    {
+      *reinterpret_cast<int32_t *>(to->ptr_to_element(id)) = static_cast<int32_t>(value);
+      break;
+    }
+    case ::arm_compute::DataType::U32:
+    {
+      *reinterpret_cast<uint32_t *>(to->ptr_to_element(id)) = static_cast<uint32_t>(value);
+      break;
+    }
+    default:
+      throw std::runtime_error("Not supported, yet");
+      break;
+  }
+}
+
 #endif // __ARM_COMPUTE_CAST_H__
diff --git a/runtimes/pure_arm_compute/src/internal/layers/SimpleCastLayer.h b/runtimes/pure_arm_compute/src/internal/layers/SimpleCastLayer.h
new file mode 100644 (file)
index 0000000..634dfa4
--- /dev/null
@@ -0,0 +1,75 @@
+#ifndef __SIMPLE_CAST_LAYER_H__
+#define __SIMPLE_CAST_LAYER_H__
+
+#include <arm_compute/core/CL/ICLTensor.h>
+#include <arm_compute/core/Helpers.h>
+#include <arm_compute/core/Window.h>
+#include <arm_compute/runtime/CL/CLScheduler.h>
+#include <arm_compute/runtime/IFunction.h>
+
+#include "internal/arm_compute/Cast.h"
+
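+// IFunction that casts an input tensor to an output tensor of a (possibly)
+// different data type, element by element, on the host via copyCast()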
+class SimpleCastLayer : public ::arm_compute::IFunction
+{
+public:
+  void configure(::arm_compute::ICLTensor *in, ::arm_compute::ICLTensor *out)
+  {
+    _in = in;
+    _out = out;
+  }
+
+public:
+  void run(void) override
+  {
+    auto &q = ::arm_compute::CLScheduler::get().queue();
+
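+    // Map both CL tensors so their buffers are accessible from the host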
+    _in->map(q);
+    _out->map(q);
+
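+    // Visit every element of the output and cast the corresponding input element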
+    arm_compute::Window window;
+    window.use_tensor_dimensions(_out->info()->tensor_shape());
+
+    execute_window_loop(window,
+                        [this](const arm_compute::Coordinates &id) { castData(_in, _out, id); });
+
+    _out->unmap(q);
+    _in->unmap(q);
+  }
+
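+  // Cast a single element of 'in' at coordinates 'id' and write it to 'out'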
+  void castData(::arm_compute::ICLTensor *in, ::arm_compute::ICLTensor *out,
+                const arm_compute::Coordinates &id)
+  {
+    switch (in->info()->data_type())
+    {
+      case ::arm_compute::DataType::F32:
+      {
+        copyCast(*reinterpret_cast<float *>(in->ptr_to_element(id)), out, id);
+        break;
+      }
+      case ::arm_compute::DataType::S32:
+      {
+        copyCast(*reinterpret_cast<int32_t *>(in->ptr_to_element(id)), out, id);
+        break;
+      }
+      case ::arm_compute::DataType::U32:
+      {
+        copyCast(*reinterpret_cast<uint32_t *>(in->ptr_to_element(id)), out, id);
+        break;
+      }
+      default:
+        throw std::runtime_error("Not supported, yet");
+        break;
+    }
+  }
+
+private:
+  ::arm_compute::ICLTensor *_in;
+  ::arm_compute::ICLTensor *_out;
+};
+
+#endif // __SIMPLE_CAST_LAYER_H__