2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
20 #include "ir/DataType.h"
26 ::arm_compute::DataLayout asDataLayout(onert::ir::Layout layout)
30 case onert::ir::Layout::NHWC:
31 return ::arm_compute::DataLayout::NHWC;
32 case onert::ir::Layout::NCHW:
33 return ::arm_compute::DataLayout::NCHW;
35 return ::arm_compute::DataLayout::UNKNOWN;
48 ::arm_compute::TensorShape asTensorShape(const ir::Shape &shape, ir::Layout frontend_layout,
49 ir::Layout backend_layout, bool apply_dim_correction)
51 // If shape's rank is 0, the tensor is a scalar
52 // Sometimes, some ACL kernel can use a scalar as tensor. But ACL does not allocate buffer for
53 // tensor having rank as 0.
54 const auto tensor_shape = shape.rank() == 0 ? ir::Shape{1} : shape;
56 const uint32_t rank = tensor_shape.rank();
58 ::arm_compute::TensorShape res{};
60 res.set_num_dimensions(rank);
62 for (uint32_t axis = 0; axis < rank; ++axis)
64 // NOTE In some cases, in incorrect dimensions is required.
65 // For example, intput_size is 1 in LSTM. The input-to-input weights([num_units, input_size]) of
66 // LSTM is used as the weight of the FullyConnected.
67 // The FullyConnected's weight must be greater or equal than 2-dimensions.
68 // However, if the dimension correction is applied to input_to_input_weights with input_size
69 // equal to 1, it will be changed to 1-D.
70 // So input_to_input_weights is not used by the weight of FullyConnected.
71 res.set(ToARMComputeAxis(rank, axis, frontend_layout, backend_layout).value(),
72 tensor_shape.dim(axis), apply_dim_correction);
78 ::arm_compute::Coordinates asTensorCoordinate(const ir::Coordinates &coord,
79 ir::Layout frontend_layout, ir::Layout backend_layout)
81 const uint32_t rank = coord.size();
83 ::arm_compute::Coordinates res{};
85 res.set_num_dimensions(rank);
87 for (uint32_t axis = 0; axis < rank; ++axis)
89 res.set(ToARMComputeAxis(rank, axis, frontend_layout, backend_layout).value(), coord[axis]);
95 ::arm_compute::DataType asDataType(const ir::DataType type)
99 case ir::DataType::FLOAT32:
100 return ::arm_compute::DataType::F32;
101 case ir::DataType::INT32:
102 return ::arm_compute::DataType::S32;
103 case ir::DataType::UINT32:
104 return ::arm_compute::DataType::U32;
105 case ir::DataType::QUANT_UINT8_ASYMM:
106 return ::arm_compute::DataType::QASYMM8;
107 case ir::DataType::BOOL8:
108 case ir::DataType::UINT8:
109 return ::arm_compute::DataType::U8;
110 case ir::DataType::QUANT_INT8_SYMM:
111 return ::arm_compute::DataType::S8;
112 case ir::DataType::FLOAT16:
113 return ::arm_compute::DataType::F16;
115 throw std::runtime_error("Not supported, yet");
120 ::arm_compute::QuantizationInfo asQuantizationInfo(const float scale, const int32_t offset)
122 return ::arm_compute::QuantizationInfo(scale, offset);
125 ::arm_compute::TensorInfo asTensorInfo(const ir::Shape &shape, const ir::TypeInfo &typeInfo,
126 ir::Layout frontend_layout, ir::Layout backend_layout,
127 bool apply_dim_correction)
129 ::arm_compute::TensorInfo info(
130 asTensorShape(shape, frontend_layout, backend_layout, apply_dim_correction), 1,
131 asDataType(typeInfo.type()), asQuantizationInfo(typeInfo.scale(), typeInfo.offset()));
132 info.set_data_layout(asDataLayout(backend_layout));
136 ::arm_compute::PadStrideInfo asPadStrideInfo(const ir::ExplicitPadding &padding,
137 const ir::Stride &stride)
139 return ::arm_compute::PadStrideInfo{stride.horizontal,
145 ::arm_compute::DimensionRoundingType::FLOOR};
148 ::arm_compute::ActivationLayerInfo asActivationLayerInfo(const ir::Activation act_code)
152 case ir::Activation::NONE:
153 return ::arm_compute::ActivationLayerInfo{};
154 case ir::Activation::RELU:
155 return ::arm_compute::ActivationLayerInfo{
156 ::arm_compute::ActivationLayerInfo::ActivationFunction::RELU};
157 case ir::Activation::RELU1:
158 return ::arm_compute::ActivationLayerInfo{
159 ::arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 1.0f, -1.0f};
160 case ir::Activation::RELU6:
161 return ::arm_compute::ActivationLayerInfo{
162 ::arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.0f, 0.0f};
163 // Cases for activation of LSTM.
164 case ir::Activation::TANH:
165 return ::arm_compute::ActivationLayerInfo{
166 ::arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f};
167 case ir::Activation::SIGMOID:
168 // NOTE The sigmoid function is a special case of the Logistic function when L=1, k=1, x0=0.
169 // TODO In ACL and nnapi sepc, currently, Logistic's L always is 1, k always is 1, x0 always
170 // 0(always sigmoid) regardless of values of the parameter.
171 // If ACL support non-sigmoid logistic, should fix param values.
172 return ::arm_compute::ActivationLayerInfo{
173 ::arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC, 0.0f, 0.0f};
175 throw std::runtime_error{"Not supported, yet"};
180 arm_compute::Coordinates asCoordinates(const ir::Operand &operand, int32_t rank,
181 ir::Layout frontend_layout, ir::Layout backend_layout)
183 std::set<uint32_t> axes = asSet(operand, rank, frontend_layout, backend_layout);
185 arm_compute::Coordinates reduce_axes;
186 for (const int32_t axis : axes)
188 reduce_axes.set(reduce_axes.num_dimensions(), axis);
194 std::set<uint32_t> asSet(const ir::Operand &operand, int32_t rank, ir::Layout frontend_layout,
195 ir::Layout backend_layout)
197 std::set<std::uint32_t> axes;
199 for (size_t i = 0; i < operand.shape().num_elements(); ++i)
202 switch (operand.typeInfo().type())
204 case ir::DataType::INT32:
205 axis = reinterpret_cast<const int32_t *>(operand.data()->base())[i];
207 case ir::DataType::INT64:
208 axis = reinterpret_cast<const int64_t *>(operand.data()->base())[i];
211 throw std::runtime_error("acl_common::asSet: Not supported data type");
215 axes.insert(ToARMComputeAxis(rank, axis, frontend_layout, backend_layout).value());
221 std::unique_ptr<AclFunction> asAclFunction(std::unique_ptr<::arm_compute::IFunction> &&layer)
223 return std::make_unique<AclFunction>(std::move(layer));
226 std::unique_ptr<AclClFunction> asAclClFunction(std::unique_ptr<::arm_compute::IFunction> &&layer)
228 return std::make_unique<AclClFunction>(std::move(layer));
231 ir::Layout asRuntimeLayout(::arm_compute::DataLayout data_layout)
235 case ::arm_compute::DataLayout::NHWC:
236 return ir::Layout::NHWC;
237 case ::arm_compute::DataLayout::NCHW:
238 return ir::Layout::NCHW;
240 return ir::Layout::UNKNOWN;
244 ir::DataType asRuntimeDataType(::arm_compute::DataType data_type)
248 case ::arm_compute::DataType::F32:
249 return ir::DataType::FLOAT32;
250 case ::arm_compute::DataType::S32:
251 return ir::DataType::INT32;
252 case ::arm_compute::DataType::U32:
253 return ir::DataType::UINT32;
254 case ::arm_compute::DataType::QASYMM8:
255 return ir::DataType::QUANT_UINT8_ASYMM;
256 case ::arm_compute::DataType::U8:
257 return ir::DataType::UINT8;
258 case ::arm_compute::DataType::QSYMM8:
259 return ir::DataType::QUANT_INT8_SYMM;
260 case ::arm_compute::DataType::F16:
261 return ir::DataType::FLOAT16;
263 throw std::runtime_error{"Not supported, yet"};
268 arm_compute::ReduceOperation convertReduceType(ir::operation::Reduce::ReduceType reduce_type_ir)
270 switch (reduce_type_ir)
272 case ir::operation::Reduce::ReduceType::MAX:
273 return arm_compute::ReduceOperation::MAX;
274 case ir::operation::Reduce::ReduceType::MIN:
275 return arm_compute::ReduceOperation::MIN;
276 case ir::operation::Reduce::ReduceType::SUM:
277 return arm_compute::ReduceOperation::SUM;
279 throw std::runtime_error("convertReduceType: Not supported operation yet");
283 } // namespace acl_common
284 } // namespace backend