2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * @file Tensor3DSource.h
19 * @ingroup COM_AI_RUNTIME
20 * @brief This file defines Tensor3DSource class
22 #ifndef __TENSOR3D_SOURCE_H__
23 #define __TENSOR3D_SOURCE_H__
25 #include "internal/Source.h"
28 // This is memcpy() version of generic TensorSource for 3D tensor
#include <cstddef>
#include <cstring>

#include <arm_compute/core/ITensor.h>
#include <arm_compute/core/Window.h>
#include <arm_compute/core/Helpers.h>
35 * @brief Class to push tensor data to arm compute tensor
37 template <typename T> class Tensor3DSource final : public Source
41 * @brief Construct a new Tensor3DSource object
42 * @param[in] shape Shape of tensor
43 * @param[in] base Pointer of tensor data to push
44 * @param[in] size Size of tensor
46 Tensor3DSource(const nnfw::misc::tensor::Shape &shape, const T *base, const size_t size)
47 : _shape{shape}, _base{base}, _size{size}
54 * @brief Push tensor data to arm compute tensor
55 * @param[out] tensor Tensor object of arm compute to push tensor data
58 void push(::arm_compute::ITensor &tensor) const override
60 using ::arm_compute::Coordinates;
61 using ::arm_compute::execute_window_loop;
62 using ::arm_compute::Iterator;
63 using ::arm_compute::Window;
67 window.use_tensor_dimensions(tensor.info()->tensor_shape(), ::arm_compute::Window::DimY);
68 int32_t height_width = _shape.dim(1) * _shape.dim(2);
69 int32_t width = _shape.dim(2);
71 Iterator it(&tensor, window);
72 execute_window_loop(window,
73 [&](const ::arm_compute::Coordinates &id) {
74 const auto z = id.z();
75 const auto y = id.y();
76 memcpy(it.ptr(), _base + z * height_width + y * width, width * sizeof(T));
82 const nnfw::misc::tensor::Shape _shape;
89 #endif // __TENSOR3D_SOURCE_H__