From 64f2d050a271a8b23b5efa14c1bd5815ea0b94c2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 19 Aug 2019 19:07:46 +0900 Subject: [PATCH] [moco-tf] Introduce ShapeInference (#6679) This will introduce ShapeInference for loco graph and query for loco node that will do shape inference for all dialects used in moco Signed-off-by: SaeHie Park --- compiler/moco-tf/src/Dialect/ShapeInference.cpp | 56 +++++++++++++++++++++++++ compiler/moco-tf/src/Dialect/ShapeInference.h | 48 +++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 compiler/moco-tf/src/Dialect/ShapeInference.cpp create mode 100644 compiler/moco-tf/src/Dialect/ShapeInference.h diff --git a/compiler/moco-tf/src/Dialect/ShapeInference.cpp b/compiler/moco-tf/src/Dialect/ShapeInference.cpp new file mode 100644 index 0000000..a747a69 --- /dev/null +++ b/compiler/moco-tf/src/Dialect/ShapeInference.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ShapeInference.h" + +#include "TFShapeInferenceRule.h" +#include "TFDialect.h" + +#include + +#include +#include +#include +#include + +#include + +namespace moco +{ +namespace tf +{ + +void ShapeInference::run(loco::Graph *g) +{ + loco::CanonicalShapeInferenceRule canonical_rule; + TFShapeInferenceRule tf_rule; + + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule); + // TODO: add CustomOp shape inference + + loco::apply(&rules).to(g); +} + +loco::NodeShape ShapeInference::get(loco::Node *node) +{ + assert(loco::shape_known(node)); + return loco::shape_get(node); +} + +} // namespace tf +} // namespace moco diff --git a/compiler/moco-tf/src/Dialect/ShapeInference.h b/compiler/moco-tf/src/Dialect/ShapeInference.h new file mode 100644 index 0000000..fc58c1b --- /dev/null +++ b/compiler/moco-tf/src/Dialect/ShapeInference.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MOCO_TF_SHAPE_INFERENCE_H__ +#define __MOCO_TF_SHAPE_INFERENCE_H__ + +#include +#include + +namespace moco +{ +namespace tf +{ + +/** + * @brief Class to prepare shape inferecne for all dialects used in moco and + * query shape for the node + * + * HOW TO USE + * + * ShapeInference::run(g); + * + * ShapeInference::get(g->nodes()->at(..)); + */ +struct ShapeInference +{ + static void run(loco::Graph *g); + + static loco::NodeShape get(loco::Node *node); +}; + +} // namespace tf +} // namespace moco + +#endif // __MOCO_TF_SHAPE_INFERENCE_H__ -- 2.7.4