[moco-tf] Reshape canonicalize support ConstGen shape input (#6115)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Fri, 2 Aug 2019 04:59:48 +0000 (13:59 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 2 Aug 2019 04:59:48 +0000 (13:59 +0900)
* [moco-tf] Reshape canonicalize support ConstGen shape input

In canonicalization of TFReshape, shape input of Reshape supports not
only TF-dialect TFConst, but also canonical ConstGen. This is to
eliminate dependency on canonicalization order.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
* Update comments

* Fix typo

* Remove ambiguity on shape()

compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp

index cdee8bd..343518f 100644 (file)
@@ -67,47 +67,73 @@ bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node)
 
   using FixedReshape = loco::Reshape<loco::ReshapeType::Fixed>;
 
-  // TODO Support not only TFConst, but also loco::ConstGen as shape input
-  // TODO Support other cases like dynamic reshape
-
   /**
-   * @note This will replace TFReshape + TFConst(as shape input) node pair into
-   *       Canonical Reshape<ReshapeType::Fixed>, or 'FixedReshape'. TFConst
-   *       should not have -1 as its entry to be converted to FixedReshape.
+   * This rule canonicalizes TFReshape only when its output shape is known at
+   * compile time, i.e. fixed reshape case.
+   * TODO Support other cases like dynamic reshape
+   *
+   * This will replace TFReshape + TFConst or Canonical ConstGen(as shape input)
+   * node pair into Canonical Reshape<ReshapeType::Fixed>, or 'FixedReshape'.
+   * Shape input (TFConst or Canonical ConstGen) should not have wildcard
+   * dimension to be converted to FixedReshape.
    *
-   *       Before
-   *                         (shape)
-   *                 TFConst ----
-   *                             \
-   *                 In --------- TFReshape ------- Out(s)
-   *                    (tensor)
+   * Before
+   *           TFConst   (shape)
+   *              or   ---
+   *           ConstGen   \
+   *                       \
+   *           In --------- TFReshape ------- Out(s)
+   *              (tensor)
    *
-   *       After
-   *                 TFConst ----
-   *                             \
-   *                   ---------- TFReshape
-   *                  /
-   *                 In -------- FixedReshape ----- Out(s)
+   * After
+   *           TFConst
+   *              or   ---
+   *           ConstGen   \
+   *                       \
+   *             ---------- TFReshape
+   *            /
+   *           In -------- FixedReshape ----- Out(s)
    */
 
-  // create loco nodes
+  // create loco node to replace
   auto fixed_reshape = graph->nodes()->create<FixedReshape>();
 
-  // Only supports fixed reshape
-  auto tfconst_shape_input = dynamic_cast<moco::tf::TFConst *>(node->shape());
-  assert(is_fixed_shape_input(tfconst_shape_input));
+  // Supports 2 cases for Reshape's shape input:
+  //     TF-dialect TFConst or Canonical ConstGen
+  loco::Node *shape_input = node->shape();
+  auto tfconst_shape_input = dynamic_cast<moco::tf::TFConst *>(shape_input);
+  auto constgen_shape_input = dynamic_cast<loco::ConstGen *>(shape_input);
 
+  if (tfconst_shape_input)
   {
-    // set attribute
-    auto shape = dynamic_cast<moco::tf::TFConst *>(node->shape());
-    auto rank = shape->dim(0).value();
+    // Only support fixed reshape
+    // TODO support dynamic reshape
+    assert(is_fixed_shape_input(tfconst_shape_input));
 
+    auto rank = tfconst_shape_input->dim(0).value();
     fixed_reshape->rank(rank);
     for (uint32_t axis = 0; axis < rank; ++axis)
     {
-      fixed_reshape->dim(axis) = shape->at<loco::DataType::S32>(axis);
+      fixed_reshape->dim(axis) = tfconst_shape_input->at<loco::DataType::S32>(axis);
     }
   }
+  else if (constgen_shape_input)
+  {
+    // ditto
+    assert(is_fixed_shape_input(constgen_shape_input));
+
+    auto rank = constgen_shape_input->dim(0).value();
+    fixed_reshape->rank(rank);
+    for (uint32_t axis = 0; axis < rank; ++axis)
+    {
+      fixed_reshape->dim(axis) = constgen_shape_input->at<loco::DataType::S32>(axis);
+    }
+  }
+  else
+  {
+    // TODO support dynamic reshape from not const node
+    throw std::runtime_error("ReshapeCanonicalizer: only support const node as input shape");
+  }
 
   // replace
   auto in = node->tensor();