[loco] Implement FeatureCodec shape inference (#6191)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 5 Aug 2019 04:42:13 +0000 (13:42 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 5 Aug 2019 04:42:13 +0000 (13:42 +0900)
This commit extends CanonicalShaapeInferenRule to accept
FeatureEncode and FeatureDecode nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
compiler/loco/src/Service/GraphTestcase.h

index e18d4b2..d76868d 100644 (file)
@@ -50,8 +50,21 @@ public:
   // TODO Support EltwiseMul
   // TODO Support Forward
   // TODO Support FeatureBiasAdd
-  // TODO Support FeatureDecode
-  // TODO Support FeatureEncode
+
+  // CASE: FeatureDecode
+  loco::NodeShape visit(const loco::FeatureDecode *node) final
+  {
+    auto input_node_shape = loco::shape_get(node->input());
+    return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::FeatureShape>())};
+  }
+
+  // CASE: FeatureEncode
+  loco::NodeShape visit(const loco::FeatureEncode *node) final
+  {
+    auto input_node_shape = loco::shape_get(node->input());
+    return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
+  }
+
   // TODO Support FilterEncode
   // TODO Support FixedReshape
   // TODO Support MaxPool2D
index 537c26c..df0e33a 100644 (file)
@@ -66,3 +66,28 @@ TEST(CanonicalShapeInferenceRuleTest, relu)
   ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
   ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
 }
+
+TEST(CanonicalShapeInferenceRuleTest, feature_codec)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::FeatureCodec> testcase;
+
+  testcase.pull_node->shape({1, 2, 3, 4});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  ASSERT_TRUE(loco::shape_known(testcase.encode_node));
+  ASSERT_EQ(loco::shape_get(testcase.encode_node).domain(), loco::Domain::Feature);
+
+  ASSERT_TRUE(loco::shape_known(testcase.decode_node));
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().rank(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(0), 1);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(1), 2);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(2), 3);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(3), 4);
+}
index 4ae4101..a192323 100644 (file)
@@ -2,11 +2,15 @@
 #define __GRAPH_TESTCASE_H__
 
 #include "loco/IR/Graph.h"
+#include "loco/IR/PermutingCodec.h"
+
+#include <stdex/Memory.h>
 
 enum class GraphCode
 {
   Identity,
   Relu,
+  FeatureCodec,
 };
 
 template <GraphCode Code> class GraphTestcase;
@@ -93,4 +97,62 @@ private:
   std::unique_ptr<loco::Graph> _graph;
 };
 
+template <> class GraphTestcase<GraphCode::FeatureCodec> final
+{
+public:
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    Permutation<Domain::Feature> perm;
+
+    perm[FeatureAxis::Count] = 0;
+    perm[FeatureAxis::Height] = 1;
+    perm[FeatureAxis::Width] = 2;
+    perm[FeatureAxis::Depth] = 3;
+
+    // Create a sample network
+    _graph = make_graph();
+
+    // Create Nodes
+    pull_node = _graph->nodes()->create<Pull>();
+
+    encode_node = _graph->nodes()->create<FeatureEncode>();
+    encode_node->input(pull_node);
+    encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+    decode_node = _graph->nodes()->create<FeatureDecode>();
+    decode_node->input(encode_node);
+    decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+    push_node = _graph->nodes()->create<loco::Push>();
+    push_node->from(decode_node);
+
+    // Create Graph Input
+    auto graph_input = _graph->inputs()->create();
+
+    graph_input->name("input");
+    graph_input->node(pull_node);
+    pull_node->index(0);
+
+    // Create Graph Output
+    auto graph_output = _graph->outputs()->create();
+
+    graph_output->name("output");
+    graph_output->node(push_node);
+    push_node->index(0);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::FeatureEncode *encode_node = nullptr;
+  loco::FeatureDecode *decode_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
 #endif // __GRAPH_TESTCASE_H__