From 29497ef910ca8af3f40b7566a550d2af5a3a5a69 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Thu, 14 Jun 2018 20:10:27 +0900 Subject: [PATCH] Support other types of operand for SimpleArithmeticAdditionLayer (#1668) This commit supports other types of operand for SimpleArithmeticAdditionLayer. Signed-off-by: jiseob.jang --- .../layers/SimpleArithmeticAdditionLayer.h | 42 ++++++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/runtimes/pure_arm_compute/src/internal/layers/SimpleArithmeticAdditionLayer.h b/runtimes/pure_arm_compute/src/internal/layers/SimpleArithmeticAdditionLayer.h index 7fd5b00..58877f3 100644 --- a/runtimes/pure_arm_compute/src/internal/layers/SimpleArithmeticAdditionLayer.h +++ b/runtimes/pure_arm_compute/src/internal/layers/SimpleArithmeticAdditionLayer.h @@ -27,10 +27,46 @@ public: window.use_tensor_dimensions(_out->info()->tensor_shape()); execute_window_loop(window, [this](const arm_compute::Coordinates &id) { - const auto lhs_value = *reinterpret_cast(_lhs->ptr_to_element(id)); - const auto rhs_value = *reinterpret_cast(_rhs->ptr_to_element(id)); + // NOTE Must be two input tensors of identical type + // Must be output tensor of the same type as input0. + assert(_lhs->info()->data_type() == _rhs->info()->data_type()); + assert(_lhs->info()->data_type() == _out->info()->data_type()); - *reinterpret_cast(_out->ptr_to_element(id)) = lhs_value + rhs_value; + switch (_lhs->info()->data_type()) + { + case ::arm_compute::DataType::F32: + { + const auto lhs_value = *reinterpret_cast(_lhs->ptr_to_element(id)); + const auto rhs_value = *reinterpret_cast(_rhs->ptr_to_element(id)); + *reinterpret_cast(_out->ptr_to_element(id)) = lhs_value + rhs_value; + break; + } + case ::arm_compute::DataType::S32: + { + const auto lhs_value = *reinterpret_cast(_lhs->ptr_to_element(id)); + const auto rhs_value = *reinterpret_cast(_rhs->ptr_to_element(id)); + *reinterpret_cast(_out->ptr_to_element(id)) = lhs_value + rhs_value; + break; + } + case ::arm_compute::DataType::U32: + { + const auto lhs_value = *reinterpret_cast(_lhs->ptr_to_element(id)); + const auto rhs_value = *reinterpret_cast(_rhs->ptr_to_element(id)); + *reinterpret_cast(_out->ptr_to_element(id)) = lhs_value + rhs_value; + break; + } + case ::arm_compute::DataType::QASYMM8: + { + const auto lhs_value = *reinterpret_cast(_lhs->ptr_to_element(id)); + const auto rhs_value = *reinterpret_cast(_rhs->ptr_to_element(id)); + // How to handle with overflow? + *reinterpret_cast(_out->ptr_to_element(id)) = lhs_value + rhs_value; + break; + } + default: + throw std::runtime_error("Not supported, yet"); + break; + } }); _out->unmap(q); -- 2.7.4