using FixedReshape = loco::Reshape<loco::ReshapeType::Fixed>;
- // TODO Support not only TFConst, but also loco::ConstGen as shape input
- // TODO Support other cases like dynamic reshape
-
/**
- * @note This will replace TFReshape + TFConst(as shape input) node pair into
- * Canonical Reshape<ReshapeType::Fixed>, or 'FixedReshape'. TFConst
- * should not have -1 as its entry to be converted to FixedReshape.
+ * This rule canonicalizes TFReshape only when its output shape is known at
+ * compile time, i.e. fixed reshape case.
+ * TODO Support other cases like dynamic reshape
+ *
+ * This will replace TFReshape + TFConst or Canonical ConstGen(as shape input)
+ * node pair into Canonical Reshape<ReshapeType::Fixed>, or 'FixedReshape'.
+ * Shape input (TFConst or Canonical ConstGen) should not have wildcard
+ * dimension to be converted to FixedReshape.
*
- * Before
- * (shape)
- * TFConst ----
- * \
- * In --------- TFReshape ------- Out(s)
- * (tensor)
+ * Before
+ * TFConst (shape)
+ * or ---
+ * ConstGen \
+ * \
+ * In --------- TFReshape ------- Out(s)
+ * (tensor)
*
- * After
- * TFConst ----
- * \
- * ---------- TFReshape
- * /
- * In -------- FixedReshape ----- Out(s)
+ * After
+ * TFConst
+ * or ---
+ * ConstGen \
+ * \
+ * ---------- TFReshape
+ * /
+ * In -------- FixedReshape ----- Out(s)
*/
- // create loco nodes
+ // create loco node to replace
auto fixed_reshape = graph->nodes()->create<FixedReshape>();
- // Only supports fixed reshape
- auto tfconst_shape_input = dynamic_cast<moco::tf::TFConst *>(node->shape());
- assert(is_fixed_shape_input(tfconst_shape_input));
+ // Supports 2 cases for Reshape's shape input:
+ // TF-dialect TFConst or Canonical ConstGen
+ loco::Node *shape_input = node->shape();
+ auto tfconst_shape_input = dynamic_cast<moco::tf::TFConst *>(shape_input);
+ auto constgen_shape_input = dynamic_cast<loco::ConstGen *>(shape_input);
+ if (tfconst_shape_input)
{
- // set attribute
- auto shape = dynamic_cast<moco::tf::TFConst *>(node->shape());
- auto rank = shape->dim(0).value();
+ // Only support fixed reshape
+ // TODO support dynamic reshape
+ assert(is_fixed_shape_input(tfconst_shape_input));
+ auto rank = tfconst_shape_input->dim(0).value();
fixed_reshape->rank(rank);
for (uint32_t axis = 0; axis < rank; ++axis)
{
- fixed_reshape->dim(axis) = shape->at<loco::DataType::S32>(axis);
+ fixed_reshape->dim(axis) = tfconst_shape_input->at<loco::DataType::S32>(axis);
}
}
+ else if (constgen_shape_input)
+ {
+ // ditto
+ assert(is_fixed_shape_input(constgen_shape_input));
+
+ auto rank = constgen_shape_input->dim(0).value();
+ fixed_reshape->rank(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ fixed_reshape->dim(axis) = constgen_shape_input->at<loco::DataType::S32>(axis);
+ }
+ }
+ else
+ {
+ // TODO support dynamic reshape from not const node
+ throw std::runtime_error("ReshapeCanonicalizer: only support const node as input shape");
+ }
// replace
auto in = node->tensor();