From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Tue, 20 Aug 2019 07:46:48 +0000 (+0900) Subject: [moco-tf] Canonicalizer for TFRsqrt (#6723) X-Git-Tag: accepted/tizen/unified/20190903.052428~294 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=16e768d141709b4e988cf070da193a2fdad833e5;p=platform%2Fcore%2Fml%2Fnnfw.git [moco-tf] Canonicalizer for TFRsqrt (#6723) This will introduce Canonicalizer for TFRsqrt node that converts to "1/Sqrt(x)" Signed-off-by: SaeHie Park --- diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp new file mode 100644 index 0000000..b4fbcac --- /dev/null +++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "RsqrtCanonicalizer.h" + +#include "Annotations/ShapeInferenceData.h" + +#include "Dialect/TFDialect.h" +#include "Dialect/TFNodes.h" +#include "Dialect/TFNodeVisitor.h" +#include "Dialect/TFNodeImpl.h" + +#include + +#include + +#include + +namespace +{ + +template +void prepare_const_gen(loco::ConstGen *const_node, const moco::tf::ShapeInferenceData *shapedata, + T value); + +template <> +void prepare_const_gen(loco::ConstGen *const_node, + const moco::tf::ShapeInferenceData *shapedata, float value) +{ + LOGGER(l); + + uint32_t const_num_elements = 1; + + auto dtype = loco::DataType::FLOAT32; + const_node->dtype(dtype); + + auto rank = shapedata->rank(); + const_node->rank(rank); + for (uint32_t r = 0; r < rank; ++r) + { + if (shapedata->dim(r).known()) + const_node->dim(r) = shapedata->dim(r); + else + throw std::runtime_error("Cannot handle unknown shape"); + + assert(shapedata->dim(r).value() > 0); + + const_num_elements *= shapedata->dim(r).value(); + } + + INFO(l) << "prepare_const_gen : Elements = " << const_num_elements; + + const_node->size(const_num_elements); + for (uint32_t i = 0; i < const_num_elements; ++i) + { + const_node->at(i) = value; + } +} + +bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node) +{ + /** + * @note This will replace TFRsqrt node with Canonical EltwiseSqrt + EltwiseRealDiv + * + * Before + * A --- TFRsqrt -- C + * After + * +- TFRsqrt -- + * | + * | ConstGen --+ + * | \ + * A -+- EltwiseSqrt -- EltwiseDiv -- C + * + * Where + * A : features of TFRsqrt + * C : a node that uses TFSqrt as an input + * TFRsqrt is disconnected from C + * TFRsqrt is converted to 1 / EltwiseSqrt + */ + + auto rsqrt_shapedata = node->annot(); + if (rsqrt_shapedata == nullptr) + { + // We need this shape information + assert(false); // this shouldn't happen, let's add an alarm + return false; + } + + if (!loco::dtype_known(node)) + { + // We need type of this node + return false; + } + + auto sqrt_node = graph->nodes()->create(); + auto eltdiv_node = graph->nodes()->create(); + auto const_node = graph->nodes()->create(); + + auto dtype = loco::dtype_get(node); + + switch (dtype) + { + case loco::DataType::FLOAT32: + prepare_const_gen(const_node, rsqrt_shapedata, 1.0f); + break; + + default: + throw std::runtime_error("NYI for this DataType"); + } + + auto node_A = node->x(); + + // update connections + sqrt_node->input(node_A); + eltdiv_node->lhs(const_node); + eltdiv_node->rhs(sqrt_node); + + // replace node + replace(node).with(eltdiv_node); + + return true; +} + +} // namespace + +namespace moco +{ +namespace tf +{ + +bool RsqrtCanonicalizer::run(loco::Graph *graph) +{ + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + bool changed = false; + + for (auto node : active_nodes) + { + if (node->dialect() == TFDialect::get()) + { + auto tf_node = dynamic_cast(node); + if (tf_node != nullptr) + { + if (canonicalize_rsqrt(graph, tf_node)) + changed = true; + } + } + } + + return changed; +} + +} // namespace tf +} // namespace moco diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h new file mode 100644 index 0000000..a58c0ad --- /dev/null +++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MOCO_TF_RSQRT_CANONICALIZER_H__ +#define __MOCO_TF_RSQRT_CANONICALIZER_H__ + +#include "Transform.h" + +#include + +namespace moco +{ +namespace tf +{ + +/** + * @brief Convert TFRsqrt to Canonical EltwiseDiv + EltwiseSqrt + */ +class RsqrtCanonicalizer : public Transform +{ +public: + const char *name(void) const final { return "RsqrtCanonicalizer"; } + +public: + bool run(loco::Graph *graph) override; +}; + +} // namespace tf +} // namespace moco + +#endif // __MOCO_TF_RSQRT_CANONICALIZER_H__