// TODO Support EltwiseMul
// TODO Support Forward
// TODO Support FeatureBiasAdd
- // TODO Support FeatureDecode
- // TODO Support FeatureEncode
+
+ // CASE: FeatureDecode
+ loco::NodeShape visit(const loco::FeatureDecode *node) final
+ {
+ auto input_node_shape = loco::shape_get(node->input());
+ return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::FeatureShape>())};
+ }
+
+ // CASE: FeatureEncode
+ loco::NodeShape visit(const loco::FeatureEncode *node) final
+ {
+ auto input_node_shape = loco::shape_get(node->input());
+ return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
+ }
+
// TODO Support FilterEncode
// TODO Support FixedReshape
// TODO Support MaxPool2D
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
}
+
+TEST(CanonicalShapeInferenceRuleTest, feature_codec)
+{
+ // Create a sample network
+ GraphTestcase<GraphCode::FeatureCodec> testcase;
+
+ testcase.pull_node->shape({1, 2, 3, 4});
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(testcase.encode_node));
+ ASSERT_EQ(loco::shape_get(testcase.encode_node).domain(), loco::Domain::Feature);
+
+ ASSERT_TRUE(loco::shape_known(testcase.decode_node));
+ ASSERT_EQ(loco::shape_get(testcase.decode_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().rank(), 4);
+ ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(0), 1);
+ ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(1), 2);
+ ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(2), 3);
+ ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(3), 4);
+}
#define __GRAPH_TESTCASE_H__
#include "loco/IR/Graph.h"
+#include "loco/IR/PermutingCodec.h"
+
+#include <stdex/Memory.h>
enum class GraphCode
{
Identity,
Relu,
+ FeatureCodec,
};
template <GraphCode Code> class GraphTestcase;
std::unique_ptr<loco::Graph> _graph;
};
+template <> class GraphTestcase<GraphCode::FeatureCodec> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ Permutation<Domain::Feature> perm;
+
+ perm[FeatureAxis::Count] = 0;
+ perm[FeatureAxis::Height] = 1;
+ perm[FeatureAxis::Width] = 2;
+ perm[FeatureAxis::Depth] = 3;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ // Create Nodes
+ pull_node = _graph->nodes()->create<Pull>();
+
+ encode_node = _graph->nodes()->create<FeatureEncode>();
+ encode_node->input(pull_node);
+ encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+ decode_node = _graph->nodes()->create<FeatureDecode>();
+ decode_node->input(encode_node);
+ decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->from(decode_node);
+
+ // Create Graph Input
+ auto graph_input = _graph->inputs()->create();
+
+ graph_input->name("input");
+ graph_input->node(pull_node);
+ pull_node->index(0);
+
+ // Create Graph Output
+ auto graph_output = _graph->outputs()->create();
+
+ graph_output->name("output");
+ graph_output->node(push_node);
+ push_node->index(0);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FeatureEncode *encode_node = nullptr;
+ loco::FeatureDecode *decode_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
#endif // __GRAPH_TESTCASE_H__