*/
#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
#include "Dialect/Service/TFLTypeInferenceRule.h"
#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
#include <stdex/Memory.h>
auto g = loco::make_graph();
auto pull_node = g->nodes()->create<loco::Pull>();
+ {
+ pull_node->dtype(loco::DataType::S32);
+ }
auto tfl_node = g->nodes()->create<locoex::TFLRelu>();
tfl_node->features(pull_node);
// pre-check
ASSERT_FALSE(loco::dtype_known(tfl_node));
- // TODO Rewrite code below using new TFLShapeInferenceRule
+ // type inference
+ locoex::TFLTypeInferenceRule tfl_rule;
+ loco::CanonicalTypeInferenceRule canon_rule;
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canon_rule);
+ rules.bind(locoex::TFLDialect::get(), &tfl_rule);
+
+ loco::apply(&rules).to(g.get());
+
+ // Verify
+ ASSERT_TRUE(loco::dtype_known(tfl_node));
+ auto type = loco::dtype_get(tfl_node);
+ ASSERT_EQ(type, loco::DataType::S32);
}