[luci] Shape inf reduce for virtual outputs (#3938)
authorSaeHie Park <saehie.park@gmail.com>
Mon, 24 Aug 2020 04:48:30 +0000 (13:48 +0900)
committerGitHub <noreply@github.com>
Mon, 24 Aug 2020 04:48:30 +0000 (13:48 +0900)
This will relocate implementation of shape inference to reduce class LOC

ONE-DCO-1.0-Signed-off-by: SaeHie Park <saehie.park@gmail.com>

compiler/luci/service/src/CircleShapeInferenceRule.cpp

index 7b2bd99..779c206 100644 (file)
@@ -849,6 +849,315 @@ loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indic
   return output_shape;
 }
 
+// Virtual
+loco::NodeShape infer_input(const luci::CircleInput *node)
+{
+  loco::TensorShape shape;
+
+  shape.rank(node->rank());
+  for (uint32_t axis = 0; axis < node->rank(); axis++)
+    shape.dim(axis) = node->dim(axis);
+
+  return loco::NodeShape{shape};
+}
+
+loco::NodeShape infer_output(const luci::CircleOutput *node)
+{
+  auto graph_outputs = node->graph()->outputs();
+  auto graph_output = graph_outputs->at(node->index());
+  auto output_shape = graph_output->shape();
+
+  return loco::NodeShape{*output_shape};
+}
+
+loco::NodeShape infer_if_out(const luci::CircleIfOut *node)
+{
+  /**
+   * @note  IF operator type and shape are that of the "then" and "else"
+   *        Graph Outputs.
+   */
+  auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
+  if (circle_if == nullptr)
+  {
+    INTERNAL_EXN("CircleIf IR is not configured correctly");
+  }
+
+  auto index = node->index();
+  auto then_graph = circle_if->then_graph();
+  auto else_graph = circle_if->else_graph();
+  assert(then_graph != nullptr);
+  assert(else_graph != nullptr);
+
+  // shape and type are assumed to be same
+  // these are checked at post_import_graph() in Import
+  auto then_outputs = loco::output_nodes(then_graph);
+  auto else_outputs = loco::output_nodes(else_graph);
+  assert(then_outputs.size() == else_outputs.size());
+  assert(index < static_cast<int32_t>(then_outputs.size()));
+
+  auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
+  auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
+
+  auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
+  auto else_graph_outputs = else_graph->outputs();
+  assert(then_graph_outputs->size() == else_graph_outputs->size());
+
+  auto then_graph_output = then_graph_outputs->at(then_out->index());
+  auto else_graph_output = else_graph_outputs->at(else_out->index());
+  (void)else_graph_output; // make compiler happy for unused variable warnings
+  assert(*then_graph_output->shape() == *else_graph_output->shape());
+
+  return loco::NodeShape{*then_graph_output->shape()};
+}
+
+loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node)
+{
+  const loco::DataType S32 = loco::DataType::S32;
+
+  auto nmsv4 = dynamic_cast<const luci::CircleNonMaxSuppressionV4 *>(node->input());
+  if (nmsv4 == nullptr)
+    INTERNAL_EXN("CircleNonMaxSuppressionV4 IR is not configured correctly");
+
+  auto index = node->index();
+  if (index == 1)
+    return loco::TensorShape({0});
+
+  assert(index == 0);
+
+  auto unknown = loco::TensorShape{loco::Dimension()};
+  auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv4->max_output_size());
+  if (max_output_size == nullptr)
+    return unknown; // we need CircleConst for max output size
+
+  LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
+
+  if (max_output_size->size<S32>() < 1)
+    return unknown;
+
+  auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
+  return loco::TensorShape{max_output_size_value};
+}
+
+loco::NodeShape infer_non_max_suppression_v5_out(const luci::CircleNonMaxSuppressionV5Out *node)
+{
+  const loco::DataType S32 = loco::DataType::S32;
+
+  auto nmsv5 = dynamic_cast<const luci::CircleNonMaxSuppressionV5 *>(node->input());
+  if (nmsv5 == nullptr)
+    INTERNAL_EXN("CircleNonMaxSuppressionV5 IR is not configured correctly");
+
+  auto index = node->index();
+  if (index == 2)
+    return loco::TensorShape({0});
+
+  assert(index == 0 || index == 1);
+
+  auto unknown = loco::TensorShape{loco::Dimension()};
+  auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv5->max_output_size());
+  if (max_output_size == nullptr)
+    return unknown; // we need CircleConst for max output size
+
+  LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
+
+  if (max_output_size->size<S32>() < 1)
+    return unknown;
+
+  auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
+  return loco::TensorShape{max_output_size_value};
+}
+
+loco::NodeShape infer_split_out(const luci::CircleSplitOut *node)
+{
+  const loco::DataType S32 = loco::DataType::S32;
+
+  auto split = dynamic_cast<const luci::CircleSplit *>(node->input());
+  if (split == nullptr)
+    INTERNAL_EXN("CircleSplit IR is not configured correctly");
+
+  loco::NodeShape unknown;
+
+  auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
+
+  auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
+  if (split_dim == nullptr)
+    return unknown; // we need CircleConst for split_dim
+  LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
+
+  assert(split_dim->size<S32>() == 1);
+  auto split_dim_axis = split_dim->at<S32>(0);
+  if (split_dim_axis < 0)
+    split_dim_axis += split_shape.rank();
+
+  auto split_dim_value = split_shape.dim(split_dim_axis).value();
+  assert(split_dim_value % split->num_split() == 0);
+  const int split_depth = split_dim_value / split->num_split();
+
+  loco::TensorShape output_shape = split_shape;
+
+  // All shapes are equally same
+  output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
+
+  return loco::NodeShape{output_shape};
+}
+
+loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
+{
+  const loco::DataType S32 = loco::DataType::S32;
+
+  auto split = dynamic_cast<const luci::CircleSplitV *>(node->input());
+  if (split == nullptr)
+    INTERNAL_EXN("CircleSplit IR is not configured correctly");
+
+  loco::NodeShape unknown;
+
+  auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
+
+  auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
+  if (size_splits == nullptr)
+    return unknown; // we need CircleConst for size_splits
+  LUCI_ASSERT(size_splits->dtype() == S32, "Only support int32 for size_splits");
+
+  auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
+  if (split_dim == nullptr)
+    return unknown; // we need CircleConst for split_dim
+  LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
+
+  // fetch axis
+  assert(split_dim->size<S32>() == 1);
+  auto split_dim_axis = split_dim->at<S32>(0);
+  if (split_dim_axis < 0)
+    split_dim_axis += split_shape.rank();
+
+  // interpret size_splits values
+  int32_t size_splits_count = static_cast<int32_t>(size_splits->size<S32>());
+  assert(size_splits_count == split->num_split());
+
+  int64_t minus_one_count = 0, size_splits_sum = 0;
+  for (int32_t idx = 0; idx < size_splits_count; ++idx)
+  {
+    auto size = size_splits->at<S32>(idx);
+    assert(size >= -1);
+    if (size == -1)
+      ++minus_one_count;
+    else
+      size_splits_sum += size;
+  }
+  if (minus_one_count > 1)
+    INTERNAL_EXN("CircleSplitV size_splits has more than two -1 values");
+
+  // calcuate this SplitVOut shape
+  auto input_size = split_shape.dim(split_dim_axis).value();
+  assert(size_splits_sum <= input_size);
+
+  auto index_this = node->index();
+  assert(0 <= index_this && index_this < split->num_split());
+  auto split_depth = size_splits->at<S32>(index_this);
+  if (split_depth == -1)
+    split_depth = input_size - size_splits_sum;
+
+  loco::TensorShape output_shape = split_shape;
+
+  output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
+
+  return loco::NodeShape{output_shape};
+}
+
+loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node)
+{
+  const loco::DataType S32 = loco::DataType::S32;
+
+  auto topkv2 = dynamic_cast<const luci::CircleTopKV2 *>(node->input());
+  if (topkv2 == nullptr)
+    INTERNAL_EXN("CircleSplit IR is not configured correctly");
+
+  // shape of topkv2 is same as topkv2->input()
+  auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>();
+
+  auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
+  LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
+  assert(node_k->size<S32>() == 1);
+
+  loco::TensorShape output_shape;
+
+  output_shape.rank(input_shape.rank());
+  for (uint32_t idx = 0; idx < input_shape.rank() - 1; ++idx)
+  {
+    output_shape.dim(idx) = input_shape.dim(idx);
+  }
+  output_shape.dim(input_shape.rank() - 1) = node_k->at<S32>(0);
+
+  return loco::NodeShape{output_shape};
+}
+
+loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node)
+{
+  if (node->index() == 0)
+  {
+    auto unique_shape = own_shape(node);
+    return loco::NodeShape{unique_shape};
+  }
+  assert(node->index() == 1);
+  auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
+  auto unique_shape = loco::shape_get(unique->input()).as<loco::TensorShape>();
+
+  assert(unique_shape.rank() == 1);
+
+  loco::TensorShape shape_output;
+  shape_output.rank(1);
+  shape_output.dim(0) = unique_shape.dim(0);
+  return loco::NodeShape{shape_output};
+}
+
+loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node)
+{
+  auto unpack = dynamic_cast<const luci::CircleUnpack *>(node->input());
+  if (unpack == nullptr)
+  {
+    INTERNAL_EXN("CircleUnpack IR is not configured correctly");
+  }
+
+  auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>();
+
+  return loco::NodeShape{unpack_shape};
+}
+
+loco::NodeShape infer_while_out(const luci::CircleWhileOut *node)
+{
+  /**
+   * @note  WHILE operator's shape is the same with the "cond"
+   *        Graph input.
+   */
+  auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
+  if (circle_while == nullptr)
+  {
+    INTERNAL_EXN("CircleWhile IR is not configured correctly");
+  }
+
+  auto index = node->index();
+  auto cond_graph = circle_while->cond_graph();
+  assert(cond_graph != nullptr);
+
+  // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
+  // loco::input_nodes
+  auto cond_inputs = loco::input_nodes(cond_graph);
+  auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
+
+  auto cond_graph_inputs = cond_graph->inputs();
+  auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
+
+  auto cond_graph_input_shape = *cond_graph_input->shape();
+  auto this_shape = own_shape(node);
+
+  if (!(this_shape == cond_graph_input_shape))
+  {
+    LOGGER(l);
+    WARN(l) << "Warning: CircleWhileOut '" << node->name() << "' shape mispatch " << this_shape
+            << " vs " << cond_graph_input_shape;
+  }
+
+  return loco::NodeShape{this_shape};
+}
+
 /**
  * @brief Class to infer the shape of CircleNode
  *
@@ -2063,25 +2372,9 @@ public:
   }
 
   // Virtual
-  loco::NodeShape visit(const luci::CircleInput *node) final
-  {
-    loco::TensorShape shape;
-
-    shape.rank(node->rank());
-    for (uint32_t axis = 0; axis < node->rank(); axis++)
-      shape.dim(axis) = node->dim(axis);
-
-    return loco::NodeShape{shape};
-  }
-
-  loco::NodeShape visit(const luci::CircleOutput *node) final
-  {
-    auto graph_outputs = node->graph()->outputs();
-    auto graph_output = graph_outputs->at(node->index());
-    auto output_shape = graph_output->shape();
+  loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }
 
-    return loco::NodeShape{*output_shape};
-  }
+  loco::NodeShape visit(const luci::CircleOutput *node) final { return infer_output(node); }
 
   loco::NodeShape visit(const luci::CircleOutputDummy *node) final { return use_own(node); }
 
@@ -2089,293 +2382,32 @@ public:
 
   loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
 
-  loco::NodeShape visit(const luci::CircleIfOut *node) final
-  {
-    /**
-     * @note  IF operator type and shape are that of the "then" and "else"
-     *        Graph Outputs.
-     */
-    auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
-    if (circle_if == nullptr)
-    {
-      INTERNAL_EXN("CircleIf IR is not configured correctly");
-    }
-
-    auto index = node->index();
-    auto then_graph = circle_if->then_graph();
-    auto else_graph = circle_if->else_graph();
-    assert(then_graph != nullptr);
-    assert(else_graph != nullptr);
-
-    // shape and type are assumed to be same
-    // these are checked at post_import_graph() in Import
-    auto then_outputs = loco::output_nodes(then_graph);
-    auto else_outputs = loco::output_nodes(else_graph);
-    assert(then_outputs.size() == else_outputs.size());
-    assert(index < static_cast<int32_t>(then_outputs.size()));
-
-    auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
-    auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
-
-    auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
-    auto else_graph_outputs = else_graph->outputs();
-    assert(then_graph_outputs->size() == else_graph_outputs->size());
-
-    auto then_graph_output = then_graph_outputs->at(then_out->index());
-    auto else_graph_output = else_graph_outputs->at(else_out->index());
-    (void)else_graph_output; // make compiler happy for unused variable warnings
-    assert(*then_graph_output->shape() == *else_graph_output->shape());
-
-    return loco::NodeShape{*then_graph_output->shape()};
-  }
+  loco::NodeShape visit(const luci::CircleIfOut *node) final { return infer_if_out(node); }
 
   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
   {
-    const loco::DataType S32 = loco::DataType::S32;
-
-    auto nmsv4 = dynamic_cast<const luci::CircleNonMaxSuppressionV4 *>(node->input());
-    if (nmsv4 == nullptr)
-      INTERNAL_EXN("CircleNonMaxSuppressionV4 IR is not configured correctly");
-
-    auto index = node->index();
-    if (index == 1)
-      return loco::TensorShape({0});
-
-    assert(index == 0);
-
-    auto unknown = loco::TensorShape{loco::Dimension()};
-    auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv4->max_output_size());
-    if (max_output_size == nullptr)
-      return unknown; // we need CircleConst for max output size
-
-    LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
-
-    if (max_output_size->size<S32>() < 1)
-      return unknown;
-
-    auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
-    return loco::TensorShape{max_output_size_value};
+    return infer_non_max_suppression_v4_out(node);
   }
 
   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final
   {
-    const loco::DataType S32 = loco::DataType::S32;
-
-    auto nmsv5 = dynamic_cast<const luci::CircleNonMaxSuppressionV5 *>(node->input());
-    if (nmsv5 == nullptr)
-      INTERNAL_EXN("CircleNonMaxSuppressionV5 IR is not configured correctly");
-
-    auto index = node->index();
-    if (index == 2)
-      return loco::TensorShape({0});
-
-    assert(index == 0 || index == 1);
-
-    auto unknown = loco::TensorShape{loco::Dimension()};
-    auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv5->max_output_size());
-    if (max_output_size == nullptr)
-      return unknown; // we need CircleConst for max output size
-
-    LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
-
-    if (max_output_size->size<S32>() < 1)
-      return unknown;
-
-    auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
-    return loco::TensorShape{max_output_size_value};
+    return infer_non_max_suppression_v5_out(node);
   }
 
-  loco::NodeShape visit(const luci::CircleSplitOut *node) final
-  {
-    const loco::DataType S32 = loco::DataType::S32;
-
-    auto split = dynamic_cast<const luci::CircleSplit *>(node->input());
-    if (split == nullptr)
-      INTERNAL_EXN("CircleSplit IR is not configured correctly");
-
-    loco::NodeShape unknown;
-
-    auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
-
-    auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
-    if (split_dim == nullptr)
-      return unknown; // we need CircleConst for split_dim
-    LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
-
-    assert(split_dim->size<S32>() == 1);
-    auto split_dim_axis = split_dim->at<S32>(0);
-    if (split_dim_axis < 0)
-      split_dim_axis += split_shape.rank();
+  loco::NodeShape visit(const luci::CircleSplitOut *node) final { return infer_split_out(node); }
 
-    auto split_dim_value = split_shape.dim(split_dim_axis).value();
-    assert(split_dim_value % split->num_split() == 0);
-    const int split_depth = split_dim_value / split->num_split();
-
-    loco::TensorShape output_shape = split_shape;
-
-    // All shapes are equally same
-    output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
-
-    return loco::NodeShape{output_shape};
-  }
-
-  loco::NodeShape visit(const luci::CircleSplitVOut *node) final
-  {
-    const loco::DataType S32 = loco::DataType::S32;
-
-    auto split = dynamic_cast<const luci::CircleSplitV *>(node->input());
-    if (split == nullptr)
-      INTERNAL_EXN("CircleSplit IR is not configured correctly");
-
-    loco::NodeShape unknown;
-
-    auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
-
-    auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
-    if (size_splits == nullptr)
-      return unknown; // we need CircleConst for size_splits
-    LUCI_ASSERT(size_splits->dtype() == S32, "Only support int32 for size_splits");
-
-    auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
-    if (split_dim == nullptr)
-      return unknown; // we need CircleConst for split_dim
-    LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
-
-    // fetch axis
-    assert(split_dim->size<S32>() == 1);
-    auto split_dim_axis = split_dim->at<S32>(0);
-    if (split_dim_axis < 0)
-      split_dim_axis += split_shape.rank();
-
-    // interpret size_splits values
-    int32_t size_splits_count = static_cast<int32_t>(size_splits->size<S32>());
-    assert(size_splits_count == split->num_split());
-
-    int64_t minus_one_count = 0, size_splits_sum = 0;
-    for (int32_t idx = 0; idx < size_splits_count; ++idx)
-    {
-      auto size = size_splits->at<S32>(idx);
-      assert(size >= -1);
-      if (size == -1)
-        ++minus_one_count;
-      else
-        size_splits_sum += size;
-    }
-    if (minus_one_count > 1)
-      INTERNAL_EXN("CircleSplitV size_splits has more than two -1 values");
-
-    // calcuate this SplitVOut shape
-    auto input_size = split_shape.dim(split_dim_axis).value();
-    assert(size_splits_sum <= input_size);
-
-    auto index_this = node->index();
-    assert(0 <= index_this && index_this < split->num_split());
-    auto split_depth = size_splits->at<S32>(index_this);
-    if (split_depth == -1)
-      split_depth = input_size - size_splits_sum;
-
-    loco::TensorShape output_shape = split_shape;
-
-    output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
-
-    return loco::NodeShape{output_shape};
-  }
+  loco::NodeShape visit(const luci::CircleSplitVOut *node) final { return infer_split_v_out(node); }
 
   loco::NodeShape visit(const luci::CircleTopKV2Out *node) final
   {
-    const loco::DataType S32 = loco::DataType::S32;
-
-    auto topkv2 = dynamic_cast<const luci::CircleTopKV2 *>(node->input());
-    if (topkv2 == nullptr)
-      INTERNAL_EXN("CircleSplit IR is not configured correctly");
-
-    // shape of topkv2 is same as topkv2->input()
-    auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>();
-
-    auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
-    LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
-    assert(node_k->size<S32>() == 1);
-
-    loco::TensorShape output_shape;
-
-    output_shape.rank(input_shape.rank());
-    for (uint32_t idx = 0; idx < input_shape.rank() - 1; ++idx)
-    {
-      output_shape.dim(idx) = input_shape.dim(idx);
-    }
-    output_shape.dim(input_shape.rank() - 1) = node_k->at<S32>(0);
-
-    return loco::NodeShape{output_shape};
-  }
-
-  loco::NodeShape visit(const luci::CircleUniqueOut *node) final
-  {
-    if (node->index() == 0)
-    {
-      auto unique_shape = own_shape(node);
-      return loco::NodeShape{unique_shape};
-    }
-    assert(node->index() == 1);
-    auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
-    auto unique_shape = loco::shape_get(unique->input()).as<loco::TensorShape>();
-
-    assert(unique_shape.rank() == 1);
-
-    loco::TensorShape shape_output;
-    shape_output.rank(1);
-    shape_output.dim(0) = unique_shape.dim(0);
-    return loco::NodeShape{shape_output};
+    return infer_top_k_v2_out(node);
   }
 
-  loco::NodeShape visit(const luci::CircleUnpackOut *node) final
-  {
-    auto unpack = dynamic_cast<const luci::CircleUnpack *>(node->input());
-    if (unpack == nullptr)
-    {
-      INTERNAL_EXN("CircleUnpack IR is not configured correctly");
-    }
+  loco::NodeShape visit(const luci::CircleUniqueOut *node) final { return infer_unique_out(node); }
 
-    auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>();
+  loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
 
-    return loco::NodeShape{unpack_shape};
-  }
-
-  loco::NodeShape visit(const luci::CircleWhileOut *node) final
-  {
-    /**
-     * @note  WHILE operator's shape is the same with the "cond"
-     *        Graph input.
-     */
-    auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
-    if (circle_while == nullptr)
-    {
-      INTERNAL_EXN("CircleWhile IR is not configured correctly");
-    }
-
-    auto index = node->index();
-    auto cond_graph = circle_while->cond_graph();
-    assert(cond_graph != nullptr);
-
-    // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
-    // loco::input_nodes
-    auto cond_inputs = loco::input_nodes(cond_graph);
-    auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
-
-    auto cond_graph_inputs = cond_graph->inputs();
-    auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
-
-    auto cond_graph_input_shape = *cond_graph_input->shape();
-    auto this_shape = own_shape(node);
-
-    if (!(this_shape == cond_graph_input_shape))
-    {
-      LOGGER(l);
-      WARN(l) << "Warning: CircleWhileOut '" << node->name() << "' shape mispatch " << this_shape
-              << " vs " << cond_graph_input_shape;
-    }
-
-    return loco::NodeShape{this_shape};
-  }
+  loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
 };
 
 } // namespace