#include "Phase.h"
#include "Transforms.h"
+#include "Transforms/ShapeInferencePass.h"
+#include "Transforms/TypeInferencePass.h"
+
#include "Canonicalization/AddCanonicalizer.h"
#include "Canonicalization/AvgPoolCanonicalizer.h"
#include "Canonicalization/BiasAddCanonicalizer.h"
moco::tf::Phase phase;
/* TRANSFORM DECLARATION BEGIN */
+ // Run shape and type inference at the top
+ phase.emplace_back(stdex::make_unique<ShapeInferencePass>());
+ phase.emplace_back(stdex::make_unique<TypeInferencePass>());
+
phase.emplace_back(stdex::make_unique<AddCanonicalizer>());
phase.emplace_back(stdex::make_unique<AvgPoolCanonicalizer>());
phase.emplace_back(stdex::make_unique<BiasAddCanonicalizer>());
phase.emplace_back(stdex::make_unique<SubCanonicalizer>());
/* TRANSFORM DECLARATION END */
- moco::tf::PhaseRunner<moco::tf::PhaseStrategy::Saturate> phase_runner{g};
+ moco::tf::PhaseRunner<moco::tf::PhaseStrategy::Restart> phase_runner{g};
phase_runner.run(phase);
// Assert if graph has TF dialect nodes