4 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
10 * http://www.apache.org/licenses/LICENSE-2.0
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
19 #include "PermuteLayer.h"
32 PermuteLayer::PermuteLayer(const std::vector<ITensor *> &src_tensors,
33 const std::vector<ITensor *> &dst_tensors,
34 const std::vector<ITensor *> &input_deriv_tensors,
35 const std::vector<ITensor *> &output_deriv_tensors,
36 bool ignore_forward_in_training,
37 const std::shared_ptr<ExternalContext> &external_context)
38 : builtin::kernel::PermuteLayer{src_tensors, dst_tensors, external_context},
39 _input_deriv_tensors{input_deriv_tensors}, _output_deriv_tensors{output_deriv_tensors},
40 _ignore_forward_in_training{ignore_forward_in_training}
42 assert(input_deriv_tensors.size() == output_deriv_tensors.size());
43 assert(src_tensors.size() == dst_tensors.size());
46 void PermuteLayer::optimize()
48 builtin::kernel::PermuteLayer::optimize();
50 // TODO Calculate offsets of derivative tensors if necessary
53 void PermuteLayer::forward(bool training)
55 if (training && _ignore_forward_in_training)
58 builtin::kernel::PermuteLayer::run();
61 void PermuteLayer::backward()
63 for (uint32_t i = 0; i < _output_deriv_tensors.size(); ++i)
65 auto src_deriv = _output_deriv_tensors.at(i);
66 auto dst_deriv = _input_deriv_tensors.at(i);
68 // NOTE The derivative tensors corresponding to inputs/outputs of model are nullptr
69 // because permuting those tensors is meaningless
70 if (src_deriv && dst_deriv)
72 const auto rank = src_deriv->getShape().rank();
73 auto output_offsets = _dst_tensors_offsets.at(i);
74 auto input_offsets = _src_tensors_offsets.at(i);
76 exec::IPermuteFunction::permute(src_deriv, dst_deriv, rank, output_offsets, input_offsets);
83 } // namespace builtin
84 } // namespace backend