[moco-tf] TypeInference (#6593)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 14 Aug 2019 08:29:12 +0000 (17:29 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 14 Aug 2019 08:29:12 +0000 (17:29 +0900)
* [moco-tf] TypeInference

This will introduce TF dialect TypeInference and rule class

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* explicit use name space

* update commentf for TypeInference

* add todo

* update class name and comment

* comment for not supported yet node

compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp [new file with mode: 0644]
compiler/moco-tf/src/Dialect/TFTypeInferenceRule.h [new file with mode: 0644]
compiler/moco-tf/src/Dialect/TypeInference.cpp [new file with mode: 0644]
compiler/moco-tf/src/Dialect/TypeInference.h [new file with mode: 0644]

diff --git a/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp
new file mode 100644 (file)
index 0000000..2d823da
--- /dev/null
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFTypeInferenceRule.h"
+
+#include "TFDialect.h"
+#include "TFNodeVisitor.h"
+#include "TFNodes.h"
+
+#include "TFNodeImpl.h"
+
+#include <cassert>
+
+namespace
+{
+
+using namespace moco::tf;
+
+struct TypeForwardAlgorithm final : public moco::tf::TFNodeVisitor<loco::DataType>
+{
+  loco::DataType visit(const TFAdd *node) { return dtype_get(node->x()); }
+  loco::DataType visit(const TFAvgPool *node) { return dtype_get(node->value()); }
+  loco::DataType visit(const TFBiasAdd *node) { return dtype_get(node->value()); }
+  loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->lhs()); }
+
+  loco::DataType visit(const TFConst *node) { return node->dtype(); }
+
+  loco::DataType visit(const TFConv2D *node) { return dtype_get(node->ifm()); }
+  loco::DataType visit(const TFDepthwiseConv2dNative *node) { return dtype_get(node->ifm()); }
+  loco::DataType visit(const TFFusedBatchNorm *node) { return dtype_get(node->input()); }
+  loco::DataType visit(const TFIdentity *node) { return dtype_get(node->input()); }
+  loco::DataType visit(const TFMaxPool *node) { return dtype_get(node->value()); }
+  loco::DataType visit(const TFMul *node) { return dtype_get(node->x()); }
+  loco::DataType visit(const TFRealDiv *node) { return dtype_get(node->x()); }
+  loco::DataType visit(const TFRelu *node) { return dtype_get(node->features()); }
+  loco::DataType visit(const TFRelu6 *node) { return dtype_get(node->features()); }
+  loco::DataType visit(const TFReshape *node) { return dtype_get(node->tensor()); }
+  loco::DataType visit(const TFRsqrt *node) { return dtype_get(node->x()); }
+
+  loco::DataType visit(const TFShape *node) { return node->dtype(); }
+
+  loco::DataType visit(const TFSqrt *node) { return dtype_get(node->x()); }
+  // TODO handle TFSquaredDifference
+  loco::DataType visit(const TFSqueeze *node) { return dtype_get(node->input()); }
+  loco::DataType visit(const TFSub *node) { return dtype_get(node->x()); }
+};
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool TFTypeInferenceRule::recognize(const loco::Dialect *d) const
+{
+  // This rule recognizes only "TFDialect" dialect!
+  return TFDialect::get() == d;
+}
+
+bool TFTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
+{
+  assert(node->dialect() == TFDialect::get());
+
+  TypeForwardAlgorithm alg;
+
+// clang-format off
+#define TENSORFLOW_NODE(OPCODE,CLASS)                          \
+  if (dynamic_cast<const moco::tf::CLASS *>(node))             \
+  {                                                            \
+    auto tfnode = dynamic_cast<const moco::tf::CLASS *>(node); \
+    dtype = tfnode->accept(&alg);                              \
+    assert(dtype != loco::DataType::Unknown);                  \
+    return true;                                               \
+  }
+#include "Dialect/TFNodes.lst"
+#undef TENSORFLOW_NODE
+  // clang-format on
+
+  return false;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.h b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.h
new file mode 100644 (file)
index 0000000..3e6a647
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_TYPE_INFERENCE_RULE_H__
+#define __MOCO_TF_TYPE_INFERENCE_RULE_H__
+
+#include <loco/Service/TypeInference.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Type Inference Rule for TFDialect
+ */
+struct TFTypeInferenceRule final : public loco::TypeInferenceRule
+{
+  bool recognize(const loco::Dialect *) const final;
+
+  bool infer(const loco::Node *, loco::DataType &) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_TYPE_INFERENCE_RULE_H__
diff --git a/compiler/moco-tf/src/Dialect/TypeInference.cpp b/compiler/moco-tf/src/Dialect/TypeInference.cpp
new file mode 100644 (file)
index 0000000..0d95f26
--- /dev/null
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TypeInference.h"
+#include "TFTypeInferenceRule.h"
+#include "TFDialect.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <stdex/Memory.h>
+
+#include <type_traits>
+
+#include <cassert>
+
+namespace moco
+{
+namespace tf
+{
+
+void TypeInference::run(loco::Graph *g)
+{
+  loco::CanonicalTypeInferenceRule canonical_rule;
+  TFTypeInferenceRule tf_rule; // rule for TF dialect
+
+  loco::MultiDialectTypeInferenceRule rules;
+
+  rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule);
+
+  loco::apply(&rules).to(g);
+}
+
+loco::DataType TypeInference::get(const loco::Node *node)
+{
+  assert(loco::dtype_known(node));
+  return loco::dtype_get(node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Dialect/TypeInference.h b/compiler/moco-tf/src/Dialect/TypeInference.h
new file mode 100644 (file)
index 0000000..d3aafb0
--- /dev/null
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_TYPE_INFERENCE_H__
+#define __MOCO_TF_TYPE_INFERENCE_H__
+
+#include <loco/IR/Nodes.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Class to prepare type inferecne for all dialects used in moco and query type for the node
+ *
+ * HOW TO USE
+ *
+ *   TypeInference::run(g);
+ *
+ *   TypeInference::get(g->nodes()->at(0));
+ *   TypeInference::get(g->nodes()->at(...));
+ */
+struct TypeInference
+{
+  static void run(loco::Graph *g);
+
+  static loco::DataType get(const loco::Node *node);
+};
+
+inline loco::DataType dtype_get(const loco::Node *node) { return TypeInference::get(node); }
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_TYPE_INFERENCE_H__