This adds TFLTypeInference to infer type for TFL nodes.
Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLTYPE_ANNOT__
+#define __LOCOEX_SERVICE_TFLTYPE_ANNOT__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/DataType.h>
+
+namespace locoex
+{
+
+/**
+ * @brief Class to annotate type to a TFL node
+ */
+struct TFLTypeAnnot : public loco::NodeAnnotation
+{
+public:
+ TFLTypeAnnot(const loco::DataType type) : _type(type) {}
+
+public:
+ loco::DataType type(void) const { return _type; }
+
+private:
+ const loco::DataType _type;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLTYPE_ANNOT__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/Service/TFLTypeAnnot.h"
+#include "Dialect/Service/TFLTypeInferenceRule.h"
+
+#include <loco.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+TEST(TFLTypeInferenceRuleTest, minimal_with_TFLRelu)
+{
+ // Create a simple network
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+
+ auto tfl_node = g->nodes()->create<locoex::TFLRelu>();
+ tfl_node->input(pull_node);
+
+ auto push_node = g->nodes()->create<loco::Push>();
+ push_node->from(tfl_node);
+
+ auto input = g->inputs()->create();
+ {
+ input->name("input");
+ loco::link(input, pull_node);
+ }
+ auto output = g->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push_node);
+ }
+
+ // pre-check
+ ASSERT_FALSE(loco::dtype_known(tfl_node));
+
+ // scenario.
+ // step 1. add annotation and run type inference.
+ // TFLTypeInference will get type info from annotated data
+ // step 2. then, run type inference again
+
+ // step 1.
+ loco::DataType dtype = loco::DataType::S64;
+
+ auto type_annot = stdex::make_unique<locoex::TFLTypeAnnot>(dtype);
+ tfl_node->annot<locoex::TFLTypeAnnot>(std::move(type_annot));
+
+ locoex::TFLTypeInferenceRule tfl_rule;
+ loco::apply(&tfl_rule).to(g.get());
+
+ // Verify
+ auto check_type = [](locoex::TFLRelu *tfl_node) {
+ ASSERT_TRUE(loco::dtype_known(tfl_node));
+ auto type = loco::dtype_get(tfl_node);
+ ASSERT_EQ(type, loco::DataType::S64);
+ };
+
+ check_type(tfl_node);
+
+ // step 2.
+ loco::apply(&tfl_rule).to(g.get());
+
+ check_type(tfl_node);
+}
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLTypeInferenceRule.h"
+#include "TFLTypeAnnot.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include <cassert>
+
+namespace
+{
+
+struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataType>
+{
+ loco::DataType visit(const locoex::TFLNode *tfl_node) final
+ {
+ if (loco::dtype_known(tfl_node)) // if type was already found by inference rule
+ {
+ return loco::dtype_get(tfl_node);
+ }
+ else
+ {
+ auto dtype_hint = tfl_node->annot<locoex::TFLTypeAnnot>();
+
+ return dtype_hint->type(); // normally dtype_hint is the type of TFLNode
+ }
+ }
+};
+
+} // namespace
+
+namespace locoex
+{
+
+bool TFLTypeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ return TFLDialect::get() == d;
+}
+
+bool TFLTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
+{
+ assert(node->dialect() == TFLDialect::get());
+
+ TypeInferenceAlgorithm alg;
+
+ dtype = dynamic_cast<const TFLNode *>(node)->accept(&alg);
+ assert(dtype != loco::DataType::Unknown);
+
+ return true;
+}
+
+} // namespace locoex
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__
+#define __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__
+
+#include <loco/Service/TypeInference.h>
+
+namespace locoex
+{
+
+/**
+ * @brief Type Inference Rule for TFLDialect
+ */
+struct TFLTypeInferenceRule final : public loco::TypeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+
+ bool infer(const loco::Node *, loco::DataType &) const final;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLTYPE_INFERENCE_RULE_H__