2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "kernels/Fill.h"
18 #include "kernels/Utils.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 namespace luci_interpreter
26 Fill::Fill(const Tensor *dims, const Tensor *value, Tensor *output)
27 : Kernel({dims, value}, {output})
31 template <typename T> void Fill::configureShape()
33 const auto dims_data = getTensorData<T>(dims());
34 Shape output_shape(dims()->shape().dim(0));
36 for (int i = 0; i < output_shape.num_dims(); ++i)
38 T data = dims_data[i];
40 throw std::runtime_error("Fill dimensions must be >= 0");
42 output_shape.dim(i) = data;
45 output()->resize(output_shape);
48 void Fill::configure()
50 const auto dims_shape = dims()->shape();
51 const auto value_shape = value()->shape();
53 // Make sure the 1st input tensor is 1-D
54 LUCI_INTERPRETER_CHECK(dims_shape.num_dims() == 1);
56 // Make sure the 1st input tensor is int32 or int64
57 LUCI_INTERPRETER_CHECK(dims()->element_type() == DataType::S32 or
58 dims()->element_type() == DataType::S64);
60 // Make sure the 2nd input tensor is a scalar
61 LUCI_INTERPRETER_CHECK(value_shape.num_dims() == 0)
63 // Check zero point and scale for S16 and S8
64 if (value()->element_type() == loco::DataType::S16 or
65 value()->element_type() == loco::DataType::S8)
67 LUCI_INTERPRETER_CHECK(value()->scale() == output()->scale());
68 LUCI_INTERPRETER_CHECK(value()->zero_point() == output()->zero_point());
70 if (value()->element_type() == loco::DataType::S16)
71 LUCI_INTERPRETER_CHECK(value()->zero_point() == 0);
74 switch (dims()->element_type())
77 configureShape<int32_t>();
80 configureShape<int64_t>();
83 throw std::runtime_error("Unsupported type.");
87 void Fill::execute() const
89 switch (output()->element_type())
92 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int8_t>(value()),
93 getTensorShape(output()), getTensorData<int8_t>(output()));
96 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int16_t>(value()),
97 getTensorShape(output()), getTensorData<int16_t>(output()));
100 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int32_t>(value()),
101 getTensorShape(output()), getTensorData<int32_t>(output()));
104 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int64_t>(value()),
105 getTensorShape(output()), getTensorData<int64_t>(output()));
107 case DataType::FLOAT32:
108 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<float>(value()),
109 getTensorShape(output()), getTensorData<float>(output()));
112 throw std::runtime_error("Unsupported type.");
116 } // namespace kernels
117 } // namespace luci_interpreter