--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFTypeInferenceRule.h"
+
+#include "TFDialect.h"
+#include "TFNodeVisitor.h"
+#include "TFNodes.h"
+
+#include "TFNodeImpl.h"
+
+#include <cassert>
+
+namespace
+{
+
+using namespace moco::tf;
+
+struct TypeForwardAlgorithm final : public moco::tf::TFNodeVisitor<loco::DataType>
+{
+ loco::DataType visit(const TFAdd *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFAvgPool *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFBiasAdd *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->lhs()); }
+
+ loco::DataType visit(const TFConst *node) { return node->dtype(); }
+
+ loco::DataType visit(const TFConv2D *node) { return dtype_get(node->ifm()); }
+ loco::DataType visit(const TFDepthwiseConv2dNative *node) { return dtype_get(node->ifm()); }
+ loco::DataType visit(const TFFusedBatchNorm *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFIdentity *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFMaxPool *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFMul *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFRealDiv *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFRelu *node) { return dtype_get(node->features()); }
+ loco::DataType visit(const TFRelu6 *node) { return dtype_get(node->features()); }
+ loco::DataType visit(const TFReshape *node) { return dtype_get(node->tensor()); }
+ loco::DataType visit(const TFRsqrt *node) { return dtype_get(node->x()); }
+
+ loco::DataType visit(const TFShape *node) { return node->dtype(); }
+
+ loco::DataType visit(const TFSqrt *node) { return dtype_get(node->x()); }
+ // TODO handle TFSquaredDifference
+ loco::DataType visit(const TFSqueeze *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFSub *node) { return dtype_get(node->x()); }
+};
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool TFTypeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ // This rule recognizes only "TFDialect" dialect!
+ return TFDialect::get() == d;
+}
+
+bool TFTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
+{
+ assert(node->dialect() == TFDialect::get());
+
+ TypeForwardAlgorithm alg;
+
+// clang-format off
+#define TENSORFLOW_NODE(OPCODE,CLASS) \
+ if (dynamic_cast<const moco::tf::CLASS *>(node)) \
+ { \
+ auto tfnode = dynamic_cast<const moco::tf::CLASS *>(node); \
+ dtype = tfnode->accept(&alg); \
+ assert(dtype != loco::DataType::Unknown); \
+ return true; \
+ }
+#include "Dialect/TFNodes.lst"
+#undef TENSORFLOW_NODE
+ // clang-format on
+
+ return false;
+}
+
+} // namespace tf
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_TYPE_INFERENCE_RULE_H__
+#define __MOCO_TF_TYPE_INFERENCE_RULE_H__
+
+#include <loco/Service/TypeInference.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Type Inference Rule for TFDialect
+ */
+struct TFTypeInferenceRule final : public loco::TypeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+
+ bool infer(const loco::Node *, loco::DataType &) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_TYPE_INFERENCE_RULE_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TypeInference.h"
+#include "TFTypeInferenceRule.h"
+#include "TFDialect.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <stdex/Memory.h>
+
+#include <type_traits>
+
+#include <cassert>
+
+namespace moco
+{
+namespace tf
+{
+
+void TypeInference::run(loco::Graph *g)
+{
+ loco::CanonicalTypeInferenceRule canonical_rule;
+ TFTypeInferenceRule tf_rule; // rule for TF dialect
+
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule);
+
+ loco::apply(&rules).to(g);
+}
+
+loco::DataType TypeInference::get(const loco::Node *node)
+{
+ assert(loco::dtype_known(node));
+ return loco::dtype_get(node);
+}
+
+} // namespace tf
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_TYPE_INFERENCE_H__
+#define __MOCO_TF_TYPE_INFERENCE_H__
+
+#include <loco/IR/Nodes.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Class to prepare type inferecne for all dialects used in moco and query type for the node
+ *
+ * HOW TO USE
+ *
+ * TypeInference::run(g);
+ *
+ * TypeInference::get(g->nodes()->at(0));
+ * TypeInference::get(g->nodes()->at(...));
+ */
+struct TypeInference
+{
+ static void run(loco::Graph *g);
+
+ static loco::DataType get(const loco::Node *node);
+};
+
+inline loco::DataType dtype_get(const loco::Node *node) { return TypeInference::get(node); }
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_TYPE_INFERENCE_H__