From a51e6692b4006822156f0ddd091b5eb4af393475 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 5 Aug 2019 13:42:13 +0900 Subject: [PATCH] [loco] Implement FeatureCodec shape inference (#6191) This commit extends CanonicalShaapeInferenRule to accept FeatureEncode and FeatureDecode nodes. Signed-off-by: Jonghyun Park --- .../src/Service/CanonicalShapeInferenceRule.cpp | 17 +++++- .../Service/CanonicalShapeInferenceRule.test.cpp | 25 +++++++++ compiler/loco/src/Service/GraphTestcase.h | 62 ++++++++++++++++++++++ 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index e18d4b2..d76868d 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -50,8 +50,21 @@ public: // TODO Support EltwiseMul // TODO Support Forward // TODO Support FeatureBiasAdd - // TODO Support FeatureDecode - // TODO Support FeatureEncode + + // CASE: FeatureDecode + loco::NodeShape visit(const loco::FeatureDecode *node) final + { + auto input_node_shape = loco::shape_get(node->input()); + return loco::NodeShape{node->decoder()->shape(input_node_shape.as())}; + } + + // CASE: FeatureEncode + loco::NodeShape visit(const loco::FeatureEncode *node) final + { + auto input_node_shape = loco::shape_get(node->input()); + return loco::NodeShape{node->encoder()->shape(input_node_shape.as())}; + } + // TODO Support FilterEncode // TODO Support FixedReshape // TODO Support MaxPool2D diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index 537c26c..df0e33a 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -66,3 +66,28 @@ TEST(CanonicalShapeInferenceRuleTest, relu) ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(2), 3); ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(3), 4); } + +TEST(CanonicalShapeInferenceRuleTest, feature_codec) +{ + // Create a sample network + GraphTestcase testcase; + + testcase.pull_node->shape({1, 2, 3, 4}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.encode_node)); + ASSERT_EQ(loco::shape_get(testcase.encode_node).domain(), loco::Domain::Feature); + + ASSERT_TRUE(loco::shape_known(testcase.decode_node)); + ASSERT_EQ(loco::shape_get(testcase.decode_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as().rank(), 4); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as().dim(1), 2); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as().dim(2), 3); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as().dim(3), 4); +} diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 4ae4101..a192323 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -2,11 +2,15 @@ #define __GRAPH_TESTCASE_H__ #include "loco/IR/Graph.h" +#include "loco/IR/PermutingCodec.h" + +#include enum class GraphCode { Identity, Relu, + FeatureCodec, }; template class GraphTestcase; @@ -93,4 +97,62 @@ private: std::unique_ptr _graph; }; +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + using namespace loco; + + Permutation perm; + + perm[FeatureAxis::Count] = 0; + perm[FeatureAxis::Height] = 1; + perm[FeatureAxis::Width] = 2; + perm[FeatureAxis::Depth] = 3; + + // Create a sample network + _graph = make_graph(); + + // Create Nodes + pull_node = _graph->nodes()->create(); + + encode_node = _graph->nodes()->create(); + encode_node->input(pull_node); + encode_node->encoder(stdex::make_unique>(perm)); + + decode_node = _graph->nodes()->create(); + decode_node->input(encode_node); + decode_node->decoder(stdex::make_unique>(perm)); + + push_node = _graph->nodes()->create(); + push_node->from(decode_node); + + // Create Graph Input + auto graph_input = _graph->inputs()->create(); + + graph_input->name("input"); + graph_input->node(pull_node); + pull_node->index(0); + + // Create Graph Output + auto graph_output = _graph->outputs()->create(); + + graph_output->name("output"); + graph_output->node(push_node); + push_node->index(0); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + #endif // __GRAPH_TESTCASE_H__ -- 2.7.4