From 1533dc88c484b2373f1f8ce27155277dbe09cdbc Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 2 Aug 2019 13:59:48 +0900 Subject: [PATCH] [moco-tf] Reshape canonicalize support ConstGen shape input (#6115) * [moco-tf] Reshape canonicalize support ConstGen shape input In canonicalization of TFReshape, shape input of Reshape supports not only TF-dialect TFConst, but also canonical ConstGen. This is to eliminate dependency on canonicalization order. Signed-off-by: Cheongyo Bahk * Update comments * Fix typo * Remove ambiguity on shape() --- .../src/Canonicalization/ReshapeCanonicalizer.cpp | 78 ++++++++++++++-------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp index cdee8bd..343518f 100644 --- a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp @@ -67,47 +67,73 @@ bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node) using FixedReshape = loco::Reshape; - // TODO Support not only TFConst, but also loco::ConstGen as shape input - // TODO Support other cases like dynamic reshape - /** - * @note This will replace TFReshape + TFConst(as shape input) node pair into - * Canonical Reshape, or 'FixedReshape'. TFConst - * should not have -1 as its entry to be converted to FixedReshape. + * This rule canonicalizes TFReshape only when its output shape is known at + * compile time, i.e. fixed reshape case. + * TODO Support other cases like dynamic reshape + * + * This will replace TFReshape + TFConst or Canonical ConstGen(as shape input) + * node pair into Canonical Reshape, or 'FixedReshape'. + * Shape input (TFConst or Canonical ConstGen) should not have wildcard + * dimension to be converted to FixedReshape. * - * Before - * (shape) - * TFConst ---- - * \ - * In --------- TFReshape ------- Out(s) - * (tensor) + * Before + * TFConst (shape) + * or --- + * ConstGen \ + * \ + * In --------- TFReshape ------- Out(s) + * (tensor) * - * After - * TFConst ---- - * \ - * ---------- TFReshape - * / - * In -------- FixedReshape ----- Out(s) + * After + * TFConst + * or --- + * ConstGen \ + * \ + * ---------- TFReshape + * / + * In -------- FixedReshape ----- Out(s) */ - // create loco nodes + // create loco node to replace auto fixed_reshape = graph->nodes()->create(); - // Only supports fixed reshape - auto tfconst_shape_input = dynamic_cast(node->shape()); - assert(is_fixed_shape_input(tfconst_shape_input)); + // Supports 2 cases for Reshape's shape input: + // TF-dialect TFConst or Canonical ConstGen + loco::Node *shape_input = node->shape(); + auto tfconst_shape_input = dynamic_cast(shape_input); + auto constgen_shape_input = dynamic_cast(shape_input); + if (tfconst_shape_input) { - // set attribute - auto shape = dynamic_cast(node->shape()); - auto rank = shape->dim(0).value(); + // Only support fixed reshape + // TODO support dynamic reshape + assert(is_fixed_shape_input(tfconst_shape_input)); + auto rank = tfconst_shape_input->dim(0).value(); fixed_reshape->rank(rank); for (uint32_t axis = 0; axis < rank; ++axis) { - fixed_reshape->dim(axis) = shape->at(axis); + fixed_reshape->dim(axis) = tfconst_shape_input->at(axis); } } + else if (constgen_shape_input) + { + // ditto + assert(is_fixed_shape_input(constgen_shape_input)); + + auto rank = constgen_shape_input->dim(0).value(); + fixed_reshape->rank(rank); + for (uint32_t axis = 0; axis < rank; ++axis) + { + fixed_reshape->dim(axis) = constgen_shape_input->at(axis); + } + } + else + { + // TODO support dynamic reshape from not const node + throw std::runtime_error("ReshapeCanonicalizer: only support const node as input shape"); + } // replace auto in = node->tensor(); -- 2.7.4