--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLSHAPE_ANNOT__
+#define __LOCOEX_SERVICE_TFLSHAPE_ANNOT__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/TensorShape.h>
+
+namespace locoex
+{
+
+/**
+ * @brief Class to annotate shape to a TFL node
+ */
+struct TFLShapeAnnot : public loco::NodeAnnotation
+{
+public:
+ TFLShapeAnnot(const loco::TensorShape &shape) : _shape(shape) {}
+
+public:
+ const loco::TensorShape shape(void) const { return _shape; }
+
+private:
+ const loco::TensorShape _shape;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLSHAPE_ANNOT__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLShapeInferenceRule.h"
+#include "TFLShapeAnnot.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include "ShapeInference.h"
+
+#include <cassert>
+
+namespace
+{
+
+/**
+ * @brief Class to infer the shape of TFLNode
+ *
+ * @note All TFLNode's inputs and outouts are always loco::Domain::Tensor
+ */
+class ShapeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::NodeShape>
+{
+public:
+ loco::NodeShape visit(const locoex::TFLNode *node) final
+ {
+ if (loco::shape_known(node)) // if shape was already inferred by inference rule
+ {
+ assert(loco::shape_get(node).domain() == loco::Domain::Tensor);
+
+ return loco::shape_get(node);
+ }
+ else
+ { // getting Shape data that was annotated while converting canonical node to TFLNode
+ auto shape_hint = node->annot<locoex::TFLShapeAnnot>();
+ assert(shape_hint != nullptr);
+
+ return shape_hint->shape();
+ }
+ }
+};
+
+} // namespace
+
+namespace locoex
+{
+
+bool TFLShapeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ return TFLDialect::get() == d;
+}
+
+bool TFLShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
+{
+ assert(node->dialect() == TFLDialect::get());
+ assert(dynamic_cast<const TFLNode *>(node) != nullptr);
+
+ ShapeInferenceAlgorithm alg;
+ shape = dynamic_cast<const TFLNode *>(node)->accept(&alg);
+
+ return true;
+}
+
+} // namespace locoex
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
+#define __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
+
+#include <loco/Service/ShapeInference.h>
+
+namespace locoex
+{
+
+struct TFLShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+ bool infer(const loco::Node *, loco::NodeShape &) const final;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/Service/TFLShapeAnnot.h"
+#include "Dialect/Service/TFLShapeInferenceRule.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
+{
+ // Create a simple network
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+
+ auto tfl_node = g->nodes()->create<locoex::TFLRelu>();
+ tfl_node->input(pull_node);
+
+ auto push_node = g->nodes()->create<loco::Push>();
+ push_node->from(tfl_node);
+
+ auto input = g->inputs()->create();
+ {
+ input->name("input");
+ loco::link(input, pull_node);
+ }
+ auto output = g->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push_node);
+ }
+
+ // pre-check
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ // scenario.
+ // step 1. add annotation and run shape inference.
+ // TFLShapeInference will get shape info from annotated data
+ // step 2. then, run shape inference again
+
+ // step 1.
+ loco::TensorShape ts;
+ {
+ ts.rank(2);
+ ts.dim(0) = 1;
+ ts.dim(1) = 3;
+ }
+ auto shape_annot = stdex::make_unique<locoex::TFLShapeAnnot>(ts);
+ tfl_node->annot<locoex::TFLShapeAnnot>(std::move(shape_annot));
+
+ locoex::TFLShapeInferenceRule tfl_rule;
+ loco::apply(&tfl_rule).to(g.get());
+
+ // Verify
+ auto check_shape = [](locoex::TFLRelu *tfl_node) {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 2);
+ ASSERT_EQ(shape.dim(0), 1);
+ ASSERT_EQ(shape.dim(1), 3);
+ };
+
+ check_shape(tfl_node);
+
+ // step 2.
+ loco::apply(&tfl_rule).to(g.get());
+
+ check_shape(tfl_node);
+}