2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Fill.h"
19 #include "kernels/Utils.h"
22 namespace luci_interpreter
27 Fill::Fill(const Tensor *dims, const Tensor *value, Tensor *output)
28 : Kernel({dims, value}, {output})
32 template <typename T> void Fill::configureShape()
34 const auto dims_data = getTensorData<T>(dims());
35 Shape output_shape(dims()->shape().dim(0));
37 for (int i = 0; i < output_shape.num_dims(); ++i)
39 T data = dims_data[i];
41 assert(false && "Fill dimensions must be >= 0");
43 output_shape.dim(i) = data;
45 // TODO: enable it only if kernel with dynamic shapes
46 output()->resize(output_shape);
49 void Fill::configure()
51 const auto dims_shape = dims()->shape();
52 const auto value_shape = value()->shape();
54 // Make sure the 1st input tensor is 1-D
55 LUCI_INTERPRETER_CHECK(dims_shape.num_dims() == 1);
57 // Make sure the 1st input tensor is int32 or int64
58 LUCI_INTERPRETER_CHECK(dims()->element_type() == DataType::S32 or
59 dims()->element_type() == DataType::S64);
61 // Make sure the 2nd input tensor is a scalar
62 LUCI_INTERPRETER_CHECK(value_shape.num_dims() == 0)
64 // Check zero point and scale for S16 and S8
65 if (value()->element_type() == DataType::S16 or value()->element_type() == DataType::S8)
67 LUCI_INTERPRETER_CHECK(value()->scale() == output()->scale());
68 LUCI_INTERPRETER_CHECK(value()->zero_point() == output()->zero_point());
70 if (value()->element_type() == DataType::S16)
71 LUCI_INTERPRETER_CHECK(value()->zero_point() == 0);
74 switch (dims()->element_type())
77 configureShape<int32_t>();
80 configureShape<int64_t>();
83 assert(false && "Unsupported type.");
87 void Fill::execute() const
89 switch (output()->element_type())
92 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int8_t>(value()),
93 getTensorShape(output()), getTensorData<int8_t>(output()));
96 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int16_t>(value()),
97 getTensorShape(output()), getTensorData<int16_t>(output()));
100 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int32_t>(value()),
101 getTensorShape(output()), getTensorData<int32_t>(output()));
104 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int64_t>(value()),
105 getTensorShape(output()), getTensorData<int64_t>(output()));
107 case DataType::FLOAT32:
108 tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<float>(value()),
109 getTensorShape(output()), getTensorData<float>(output()));
112 assert(false && "Unsupported type.");
116 } // namespace kernels
117 } // namespace luci_interpreter