From 93db1c0501d57b8c49f16456ddefe99c474216c8 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/=EB=8F=99=EC=9E=91=EC=A0=9C?=
 =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?=
 =?utf8?q?=EC=9E=90?=
Date: Mon, 11 Jun 2018 14:28:15 +0900
Subject: [PATCH] Implement the plan of Cast operation (#1574)

This commit implements the plan of Cast operation.

Signed-off-by: jiseob.jang
---
 runtimes/pure_arm_compute/src/compilation.cc      | 71 +++++++++++++++++++++-
 .../src/internal/arm_compute/Cast.h               | 26 ++++++++
 .../src/internal/layers/SimpleCastLayer.h         | 66 ++++++++++++++++++++
 3 files changed, 162 insertions(+), 1 deletion(-)
 create mode 100644 runtimes/pure_arm_compute/src/internal/layers/SimpleCastLayer.h

diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc
index ac79353..1957bcf 100644
--- a/runtimes/pure_arm_compute/src/compilation.cc
+++ b/runtimes/pure_arm_compute/src/compilation.cc
@@ -28,6 +28,7 @@
 #include "internal/arm_compute/feature/View.h"
 #include "internal/layers/GenericReshapeLayer.h"
 #include "internal/layers/SimpleArithmeticAdditionLayer.h"
+#include "internal/layers/SimpleCastLayer.h"
 
 #include "util/kernel/IndexIterator.h"
 #include "util/feature/IndexIterator.h"
@@ -1448,7 +1449,75 @@ void Planner::visit(const ::internal::tflite::op::ReduceMax::Node &node)
 
 void Planner::visit(const ::internal::tflite::op::Cast::Node &node)
 {
-  // TODO Implement the plan of Cast
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+
+  const auto output_shape = _ctx.at(output_index).shape();
+  const auto input_shape = _ctx.at(input_index).shape();
+  assert(output_shape.rank() == input_shape.rank());
+  for (uint32_t n = 0; n < input_shape.rank(); ++n)
+  {
+    assert(output_shape.dim(n) == input_shape.dim(n));
+  }
+
+  // TODO Should move to the place where the operand is handled, if it is possible.
+  // Set Shape Constraints and TensorInfo
+  switch (input_shape.rank())
+  {
+    case 0: // scalar
+    {
+      _builder.addShapeConstr(output_index, asTensorInfo(1, _ctx.at(output_index).type()));
+      _builder.addShapeConstr(input_index, asTensorInfo(1, _ctx.at(input_index).type()));
+      break;
+    }
+    case 1: // vector
+    {
+      _builder.addShapeConstr(output_index,
+                              asTensorInfo(input_shape.asVector(), _ctx.at(output_index).type()));
+      _builder.addShapeConstr(input_index,
+                              asTensorInfo(output_shape.asVector(), _ctx.at(input_index).type()));
+      break;
+    }
+    case 4: // feature
+    {
+      _builder.addShapeConstr(output_index,
+                              asTensorInfo(input_shape.asFeature(), _ctx.at(output_index).type()));
+      _builder.addShapeConstr(input_index,
+                              asTensorInfo(output_shape.asFeature(), _ctx.at(input_index).type()));
+      break;
+    }
+    default:
+      throw std::runtime_error("Not supported, yet");
+      break;
+  }
+
+  // Construct operation parameters
+  struct Param
+  {
+    int input_index;
+    int output_index;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.input_index = input_index.asInt();
+
+  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+    auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
+    auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
+
+    std::unique_ptr<::arm_compute::IFunction> fn;
+
+    auto l = make_layer<SimpleCastLayer>();
+
+    l->configure(input_alloc, output_alloc);
+    fn = std::move(l);
+
+    builder.append(std::move(fn));
+  };
+
+  _builder.addStage(stage);
 }
 
 void Planner::visit(const ::internal::tflite::op::TopKV2::Node &node)
diff --git a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
index 2a22253..c8f386f 100644
--- a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
+++ b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
@@ -49,4 +49,30 @@ inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand
 {
   return ::arm_compute::TensorInfo(asTensorShape(shape), 1, asDataType(type));
 }
+template <typename FromT>
+void copyCast(const FromT value, ::arm_compute::ICLTensor *to, const ::arm_compute::Coordinates &id)
+{
+  switch (to->info()->data_type())
+  {
+    case ::arm_compute::DataType::F32:
+    {
+      *reinterpret_cast<float *>(to->ptr_to_element(id)) = static_cast<float>(value);
+      break;
+    }
+    case ::arm_compute::DataType::S32:
+    {
+      *reinterpret_cast<int32_t *>(to->ptr_to_element(id)) = static_cast<int32_t>(value);
+      break;
+    }
+    case ::arm_compute::DataType::U32:
+    {
+      *reinterpret_cast<uint32_t *>(to->ptr_to_element(id)) = static_cast<uint32_t>(value);
+      break;
+    }
+    default:
+      throw std::runtime_error("Not supported, yet");
+      break;
+  }
+}
+
 #endif // __ARM_COMPUTE_CAST_H__
diff --git a/runtimes/pure_arm_compute/src/internal/layers/SimpleCastLayer.h b/runtimes/pure_arm_compute/src/internal/layers/SimpleCastLayer.h
new file mode 100644
index 0000000..634dfa4
--- /dev/null
+++ b/runtimes/pure_arm_compute/src/internal/layers/SimpleCastLayer.h
@@ -0,0 +1,66 @@
+#ifndef __SIMPLE_CAST_LAYER_H__
+#define __SIMPLE_CAST_LAYER_H__
+
+#include <arm_compute/runtime/CL/CLScheduler.h>
+
+#include "internal/op/Cast.h"
+
+class SimpleCastLayer : public ::arm_compute::IFunction
+{
+public:
+  void configure(::arm_compute::ICLTensor *in, ::arm_compute::ICLTensor *out)
+  {
+    _in = in;
+    _out = out;
+  }
+
+public:
+  void run(void) override
+  {
+    auto &q = ::arm_compute::CLScheduler::get().queue();
+
+    _in->map(q);
+    _out->map(q);
+
+    arm_compute::Window window;
+    window.use_tensor_dimensions(_out->info()->tensor_shape());
+
+    execute_window_loop(window,
+                        [this](const arm_compute::Coordinates &id) { castData(_in, _out, id); });
+
+    _out->unmap(q);
+    _in->unmap(q);
+  }
+
+  void castData(::arm_compute::ICLTensor *in, ::arm_compute::ICLTensor *out,
+                const arm_compute::Coordinates &id)
+  {
+    switch (in->info()->data_type())
+    {
+      case ::arm_compute::DataType::F32:
+      {
+        copyCast(*reinterpret_cast<float *>(in->ptr_to_element(id)), out, id);
+        break;
+      }
+      case ::arm_compute::DataType::S32:
+      {
+        copyCast(*reinterpret_cast<int32_t *>(in->ptr_to_element(id)), out, id);
+        break;
+      }
+      case ::arm_compute::DataType::U32:
+      {
+        copyCast(*reinterpret_cast<uint32_t *>(in->ptr_to_element(id)), out, id);
+        break;
+      }
+      default:
+        throw std::runtime_error("Not supported, yet");
+        break;
+    }
+  }
+
+private:
+  ::arm_compute::ICLTensor *_in;
+  ::arm_compute::ICLTensor *_out;
+};
+
+#endif // __SIMPLE_CAST_LAYER_H__
-- 
2.7.4
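
For readers unfamiliar with the pattern in the patch, below is a minimal, self-contained sketch of the two-level type dispatch that SimpleCastLayer::castData() and copyCast() perform: an outer switch on the source element type, and an inner switch (here inside copy_cast) on the destination type. The ElemType enum, the raw-buffer signatures, and the names cast_buffer/copy_cast are assumptions made so the sketch compiles on its own; the code in the patch instead operates on mapped ::arm_compute::ICLTensor elements visited through an arm_compute::Window loop.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Stands in for ::arm_compute::DataType in this sketch.
enum class ElemType { F32, S32, U32 };

// Analogue of copyCast(): write `value` into `to` at `index`, converted to `to_type`.
template <typename FromT>
void copy_cast(FromT value, void *to, ElemType to_type, std::size_t index)
{
  switch (to_type)
  {
    case ElemType::F32:
      static_cast<float *>(to)[index] = static_cast<float>(value);
      break;
    case ElemType::S32:
      static_cast<int32_t *>(to)[index] = static_cast<int32_t>(value);
      break;
    case ElemType::U32:
      static_cast<uint32_t *>(to)[index] = static_cast<uint32_t>(value);
      break;
    default:
      throw std::runtime_error("Not supported, yet");
  }
}

// Analogue of castData() plus the window loop in SimpleCastLayer::run():
// dispatch on the source element type, then let copy_cast() dispatch on the
// destination type, element by element.
void cast_buffer(const void *in, ElemType in_type, void *out, ElemType out_type,
                 std::size_t count)
{
  for (std::size_t i = 0; i < count; ++i)
  {
    switch (in_type)
    {
      case ElemType::F32:
        copy_cast(static_cast<const float *>(in)[i], out, out_type, i);
        break;
      case ElemType::S32:
        copy_cast(static_cast<const int32_t *>(in)[i], out, out_type, i);
        break;
      case ElemType::U32:
        copy_cast(static_cast<const uint32_t *>(in)[i], out, out_type, i);
        break;
      default:
        throw std::runtime_error("Not supported, yet");
    }
  }
}

int main()
{
  std::vector<float> in{1.5f, -2.25f, 3.0f};
  std::vector<int32_t> out(in.size());

  // Cast float -> int32, as the TFLite CAST operation would.
  cast_buffer(in.data(), ElemType::F32, out.data(), ElemType::S32, in.size());

  for (int32_t v : out)
    std::cout << v << "\n"; // prints 1, -2, 3 (static_cast truncates toward zero)
}

The structure also shows why supporting an additional data type means touching both switches: every source/destination combination is reached through exactly one branch at each level.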