From f05c415e32bb8c841f67b6e1396df939d7f2d9bc Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 8 Aug 2019 16:43:00 +0900 Subject: [PATCH] [moco-tf] Introduce ResolveConstantShape transform (#6379) * [moco-tf] Introduce ResolveConstantShape transform This commit introduces ResolveConstantShape transform, which is responsible to replace determined TFShape node with TFConst. Signed-off-by: Cheongyo Bahk * Separate out condition check stage --- .../src/Transforms/ResolveConstantShape.cpp | 126 +++++++++++++++++++++ .../moco-tf/src/Transforms/ResolveConstantShape.h | 44 +++++++ 2 files changed, 170 insertions(+) create mode 100644 compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp create mode 100644 compiler/moco-tf/src/Transforms/ResolveConstantShape.h diff --git a/compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp b/compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp new file mode 100644 index 0000000..017aa66 --- /dev/null +++ b/compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ResolveConstantShape.h" + +#include "IR/TFShape.h" +#include "IR/TFConst.h" +#include "Annotations/ShapeInferenceData.h" + +#include + +#include + +namespace +{ + +/** + * WHEN: + * - TFShape's input shape is determined + * DO: + * - Replace TFShape into TFConst + * + * + * + * in ---- TFShape ---- out(s) + * + * + * in ---- TFShape + * + * TFConst ---- out(s) + */ +bool resolve_constant_shape(loco::Graph *graph, moco::tf::TFShape *shape_node) +{ + using moco::tf::ShapeInferenceData; + + auto input_shape = shape_node->input()->annot(); + + // Check condition + if (!input_shape) + { + // Cannot resolve without known input_shape + return false; + } + auto shape_rank = input_shape->rank(); + for (uint32_t axis = 0; axis < shape_rank; ++axis) + { + if (!input_shape->dim(axis).known()) + { + // Cannot resolve with unknown dimension + return false; + } + } + + auto input_tensor_shape = input_shape->tensor_shape(); + + // Make TFConst to replace TFShape + auto const_node = graph->nodes()->create(); + + // set dtype + auto dtype = shape_node->dtype(); + const_node->dtype(dtype); + + // set shape + const_node->rank(1); + const_node->dim(0) = shape_rank; + + // set data + if (dtype == loco::DataType::S32) + { + // TODO Better to make template for this when support new dtype + const_node->size(shape_rank); + for (uint32_t axis = 0; axis < shape_rank; ++axis) + { + int32_t dim = (int32_t)input_tensor_shape.dim(axis).value(); + assert(dim > 0); + const_node->at(axis) = dim; + } + } + else + { + throw std::runtime_error("ResolveConstantShape: Not supported output data type"); + } + + // replace + loco::replace(shape_node).with(const_node); + + return true; +} + +} // namespace + +namespace moco +{ +namespace tf +{ + +bool ResolveConstantShape::run(loco::Graph *graph) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(graph))) + { + if (auto shape_node = as(node)) + { + if (resolve_constant_shape(graph, shape_node)) + changed = true; + } + } + + return changed; +} + +} // namespace tf +} // namespace moco diff --git a/compiler/moco-tf/src/Transforms/ResolveConstantShape.h b/compiler/moco-tf/src/Transforms/ResolveConstantShape.h new file mode 100644 index 0000000..069418b --- /dev/null +++ b/compiler/moco-tf/src/Transforms/ResolveConstantShape.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MOCO_TF_RESOLVE_CONSTANT_SHAPE_H__ +#define __MOCO_TF_RESOLVE_CONSTANT_SHAPE_H__ + +#include "Transform.h" + +#include + +namespace moco +{ +namespace tf +{ + +/** + * @brief Replace fully determined TFShape node into TFConst + */ +class ResolveConstantShape : public Transform +{ +public: + const char *name(void) const final { return "ResolveConstantShape"; } + +public: + bool run(loco::Graph *graph) override; +}; + +} // namespace tf +} // namespace moco + +#endif // __MOCO_TF_RESOLVE_CONSTANT_SHAPE_H__ -- 2.7.4