loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); }
loco::DataType visit(const loco::Pull *node) { return node->dtype(); }
loco::DataType visit(const loco::ReLU *node) { return loco::dtype_get(node->input()); }
- // TODO Support ReLU6
+ loco::DataType visit(const loco::ReLU6 *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
};
ASSERT_TRUE(loco::dtype_known(push_node));
ASSERT_EQ(loco::dtype_get(push_node), loco::DataType::U8);
}
+
+TEST(CanonicalTypeInferenceRuleTest, relu6)
+{
+ // Create a simple Relu6 network
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+
+ pull_node->dtype(loco::DataType::FLOAT32);
+
+ auto relu6_node = g->nodes()->create<loco::ReLU6>();
+
+ relu6_node->input(pull_node);
+
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ push_node->from(relu6_node);
+
+ auto graph_input = g->inputs()->create();
+
+ graph_input->name("input");
+ graph_input->node(pull_node);
+
+ auto graph_output = g->outputs()->create();
+
+ graph_output->name("output");
+ graph_output->node(push_node);
+
+ // Run Type Inference
+ loco::CanonicalTypeInferenceRule rule;
+
+ loco::apply(&rule).to(g.get());
+
+ // Verify!
+ ASSERT_TRUE(loco::dtype_known(relu6_node));
+ ASSERT_EQ(loco::dtype_get(relu6_node), loco::DataType::FLOAT32);
+}