[nnc] Add Resize Nearest Neighbor (#2315)
author Andrei Shedko/AI Tools Lab/SRR/Assistant Engineer/Samsung Electronics <a.shedko@partner.samsung.com>
Wed, 28 Nov 2018 15:13:26 +0000 (18:13 +0300)
committer Efimov Alexander/AI Tools Lab/Samsung Electronics <a.efimov@samsung.com>
Wed, 28 Nov 2018 15:13:26 +0000 (18:13 +0300)
Added Resize Nearest Neighbor to the TFLite importer and the interpreter.
Added shape inference tests covering several cases.

The op is designed so that further resize types can be added later by simply extending the enum and handling the new values in the backends.
The corresponding op in ONNX is Upsample, and it can be supported without modifications.
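
For reference, below is a minimal standalone sketch of the nearest-neighbor mapping this kind of resize performs. It is illustrative only and independent of the nnc classes in this patch; the helper name, the HWC buffer layout, and the index clamping are assumptions of the sketch, not the exact interpreter code.

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Nearest-neighbor resize of an H x W x C float buffer (HWC layout).
    // Each output coordinate is scaled back into the input and rounded to
    // the nearest input index.
    std::vector<float> resizeNearestNeighbor(const std::vector<float>& in,
                                             int32_t in_h, int32_t in_w, int32_t channels,
                                             int32_t out_h, int32_t out_w) {
      assert(static_cast<int32_t>(in.size()) == in_h * in_w * channels);
      const float scale_h = static_cast<float>(in_h) / out_h;  // input/output ratio
      const float scale_w = static_cast<float>(in_w) / out_w;
      std::vector<float> out(static_cast<size_t>(out_h) * out_w * channels);
      for (int32_t y = 0; y < out_h; ++y) {
        const int32_t src_y = std::min(in_h - 1, static_cast<int32_t>(std::lround(y * scale_h)));
        for (int32_t x = 0; x < out_w; ++x) {
          const int32_t src_x = std::min(in_w - 1, static_cast<int32_t>(std::lround(x * scale_w)));
          for (int32_t c = 0; c < channels; ++c)
            out[(static_cast<size_t>(y) * out_w + x) * channels + c] =
                in[(static_cast<size_t>(src_y) * in_w + src_x) * channels + c];
        }
      }
      return out;
    }

The clamping of the rounded source index is only there to keep the sketch in bounds for arbitrary scales.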

Signed-off-by: Andrei Shedko <a.shedko@partner.samsung.com>
21 files changed:
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/core/modelIR/Shape.cpp
contrib/nnc/core/modelIR/ShapeInference.cpp
contrib/nnc/include/core/modelIR/IrDotDumper.h
contrib/nnc/include/core/modelIR/Shape.h
contrib/nnc/include/core/modelIR/ShapeInference.h
contrib/nnc/include/core/modelIR/operations/ResizeOp.h [new file with mode: 0644]
contrib/nnc/include/core/modelIR/operations/operations.lst.h
contrib/nnc/include/passes/acl_soft_backend/AclCppOpGenerator.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.h
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.h
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
contrib/nnc/unittests/core/ShapeInference.cpp

index dc2ef09..c1f0986 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include <iostream>
+
 #include "core/modelIR/IrDotDumper.h"
 
 namespace nnc {
@@ -223,5 +224,15 @@ void mir::IrDotDumper::visit(ops::PadOp& op) {
   dotBuilder.updateWithOp(&op, node_info);
 }
 
+void IrDotDumper::visit(ops::ResizeOp& op) {
+  auto node_info = DotIrNodeInfo().withType("Resize", op.getName())
+    .withInShapes(getInputShapes(op))
+    .withOutShapes(getOutputShapes(op))
+    .withMisc("Mode", (int) op.getMode());
+  // scale and resShape are only needed in Shape Inference
+
+  dotBuilder.updateWithOp(&op, node_info);
+}
+
 } // namespace mir
 } // namespace nnc
index c6e3e81..84a65b0 100644 (file)
@@ -24,6 +24,7 @@
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/VariableOp.h"
 #include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
index 8867e6b..0ddb7f9 100644 (file)
@@ -24,6 +24,8 @@ namespace nnc
 namespace mir
 {
 
+constexpr int32_t mir::Shape::autoDim;
+
 Shape::Shape(std::initializer_list<int32_t> &&l) : _dims{l}
 {
   // DO NOTHING
@@ -57,7 +59,7 @@ int32_t Shape::numElements() const
 
   for (int32_t axis = 0; axis < rank(); ++axis)
   {
-    assert(dim(axis) != Shape::AUTO_DIM);
+    assert(dim(axis) != Shape::autoDim);
     res *= dim(axis);
   }
 
@@ -92,7 +94,7 @@ std::ostream &operator<<(std::ostream &s, const Shape &sh)
   {
     if (axis != 0)
       s << ", ";
-    if (sh.dim(axis) == Shape::AUTO_DIM)
+    if (sh.dim(axis) == Shape::autoDim)
       s << "AUTO";
     else
       s << sh.dim(axis);
index 3ef6069..98d98a1 100644 (file)
@@ -32,6 +32,7 @@
 #include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/BatchNormOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/DropoutOp.h"
@@ -149,6 +150,32 @@ void ShapeInference::visit(ops::ReluOp& op) {
   op.setOutputShape(0, op.getInputShape(0));
 }
 
+void ShapeInference::visit(ops::ResizeOp& op) {
+  fillInputShapes(op);
+  const auto& in_s = op.getInputShape(0);
+  Shape out_s = in_s;
+  auto res_s = op.getResultShape();
+  const std::vector<float>& scales = op.getScales();
+
+  if (scales.size() > 0) {
+    assert(
+      in_s.rank() == static_cast<int32_t>(scales.size())
+      && "Scaling parameter incompatible with input shape");
+    for (int32_t i = 0; i < in_s.rank(); i++) {
+      out_s.dim(i) = (int32_t)lroundf(scales[i] * in_s.dim(i));
+    }
+  } else {
+    // Assume batch is cut off
+    assert(in_s.rank() == 3);
+    out_s.dim(0) = res_s.dim(0);
+    out_s.dim(1) = res_s.dim(1);
+    out_s.dim(2) = in_s.dim(2);
+    op.setScales({static_cast<float> (out_s.dim(0)) / in_s.dim(0),
+                  static_cast<float> (out_s.dim(1)) / in_s.dim(1), 1.0f});
+  }
+  op.setOutputShape(0, out_s);
+}
+
 void ShapeInference::visit(ops::SoftmaxOp& op) {
   fillInputShapes(op);
   op.setOutputShape(0, op.getInputShape(0));
@@ -232,14 +259,14 @@ void ShapeInference::visit(ops::ReshapeOp& op) {
   //can't use num_elements due to -1 in input shape and Shape using unsigned ints for dimensions
   for( int32_t d = 0; d < outShape.rank(); ++d ) {
     auto dim = outShape.dim(d);
-    if( dim != Shape::AUTO_DIM) {
+    if( dim != Shape::autoDim) {
       outElementsNum *= dim;
     }
   }
 
   for( int32_t d = 0; d < outShape.rank(); ++d ) {
     auto& dim = outShape.dim(d);
-    if( dim == Shape::AUTO_DIM ) {
+    if( dim == Shape::autoDim ) {
       dim = static_cast<int32_t>(inElementsNum / outElementsNum);
     }
   }
index f053843..567cc82 100644 (file)
@@ -31,6 +31,7 @@
 #include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/BatchNormOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/DropoutOp.h"
@@ -63,6 +64,7 @@ public:
   void visit(ops::BiasAddOp& op) override;
   void visit(ops::VariableOp& op) override;
   void visit(ops::ReshapeOp& op) override;
+  void visit(ops::ResizeOp& op) override;
   void visit(ops::ScaleOp& op) override;
   void visit(ops::BatchNormOp& op) override;
   void visit(ops::DropoutOp& op) override;
index e3d5eab..6c245e1 100644 (file)
@@ -30,7 +30,7 @@ namespace mir
 class Shape
 {
 public:
-  static const auto AUTO_DIM = static_cast<int32_t>(-1);
+  static constexpr int32_t autoDim = -1;
 
   Shape() = default;
   Shape(std::initializer_list<int32_t> &&l);
index 3dcd9af..17e9f4a 100644 (file)
@@ -38,6 +38,7 @@ public:
   void visit(ops::CappedReluOp& op) override;
   void visit(ops::BiasAddOp& op) override;
   void visit(ops::ReshapeOp& op) override;
+  void visit(ops::ResizeOp& op) override;
   void visit(ops::VariableOp& op) override;
   void visit(ops::ScaleOp& op) override;
   void visit(ops::BatchNormOp& op) override;
diff --git a/contrib/nnc/include/core/modelIR/operations/ResizeOp.h b/contrib/nnc/include/core/modelIR/operations/ResizeOp.h
new file mode 100644 (file)
index 0000000..b18a7f0
--- /dev/null
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_IR_MODEL_RESIZEOP_H_
+#define _NNC_CORE_IR_MODEL_RESIZEOP_H_
+
+#include "core/modelIR/Operation.h"
+#include "core/modelIR/Shape.h"
+#include <vector>
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+class ResizeOp : public Operation {
+public:
+
+  enum class ResizeMethod {
+    nearestNeighbor, // TODO: BICUBIC and BILINEAR
+  };
+
+  explicit ResizeOp(const IODescriptor& arg, ResizeMethod mode, const std::vector<float>& scales) :
+    Operation(Type::resizeIm, {arg}), _mode(mode), _scales(scales),
+    _resultShape({}) {}
+
+  explicit ResizeOp(const IODescriptor& arg, ResizeMethod mode, const Shape& shape) :
+    Operation(Type::resizeIm, {arg}), _mode(mode),
+    _scales({}), _resultShape(shape) {}
+
+  /** @return The resize mode */
+  ResizeMethod getMode() const { return _mode; }
+
+  const Shape& getResultShape() const { return _resultShape; }
+
+  const std::vector<float>& getScales() const { return _scales; }
+
+  void setScales(const std::vector<float>& scales) { _scales = scales; }
+
+private:
+  std::vector<float> _scales;
+  Shape _resultShape;
+  ResizeMethod _mode;
+};
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
+
+
+#endif //_NNC_CORE_IR_MODEL_RESIZEOP_H_
index c04286a..c89e716 100644 (file)
@@ -29,6 +29,7 @@ HANDLE_OP(biasAdd, BiasAddOp)
 HANDLE_OP(variable, VariableOp)
 HANDLE_OP(ReLU, ReluOp)
 HANDLE_OP(reshape, ReshapeOp)
+HANDLE_OP(resizeIm, ResizeOp)
 HANDLE_OP(scale, ScaleOp)
 HANDLE_OP(batchNorm, BatchNormOp)
 HANDLE_OP(dropout, DropoutOp)
index 6c9af1e..c8c2e7e 100644 (file)
@@ -58,6 +58,7 @@ public:
   void visit(mir::ops::VariableOp& op) override;
   void visit(mir::ops::ReluOp& op) override;
   void visit(mir::ops::ReshapeOp& op) override;
+  void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
   void visit(mir::ops::BatchNormOp& op) override;
   void visit(mir::ops::DropoutOp& op) override;
index 1bc2af0..de0742a 100644 (file)
@@ -47,6 +47,7 @@ public:
   void visit(ops::BiasAddOp& op) override;
   void visit(ops::VariableOp& op) override;
   void visit(ops::ReshapeOp& op) override;
+  void visit(ops::ResizeOp& op) override;
   void visit(ops::ScaleOp& op) override;
   void visit(ops::BatchNormOp& op) override;
   void visit(ops::DropoutOp& op) override;
index 7b06637..d593b91 100644 (file)
@@ -15,6 +15,7 @@
 #include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
@@ -681,5 +682,9 @@ void AclCppOpGenerator::visit(ops::SqueezeOp& op) {
   assert(false && "Unimplemented operation: Squeeze");
 }
 
+void AclCppOpGenerator::visit(mir::ops::ResizeOp& op) {
+  assert(false && "Unimplemented operation: Resize");
+}
+
 }
 // namespace nnc
index 877f945..c2e9611 100644 (file)
@@ -30,6 +30,7 @@
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/VariableOp.h"
 #include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
@@ -282,4 +283,26 @@ void NNInterpreter::visit(ops::PadOp& op) {
   var(op.getId()) = Pad(input, op)();
 }
 
+void NNInterpreter::visit(ops::ResizeOp& op) {
+  mapByName(&op);
+  auto operand = op.getPrevNodes()[0];
+  Tensor<float> input(var(operand.op->getId())[operand.index]);
+  assert(input.getShape().rank() == 3 && "Must be rank 3 (for now)");
+  switch (op.getMode()) {
+    case ops::ResizeOp::ResizeMethod::nearestNeighbor: {
+      auto scales = op.getScales();
+      var(op.getId()) = Fill<float>(op.getOutputShape(0), [&scales, &input, &op](const Index& id) {
+        const Index in_idx = {static_cast<int> (lroundf(scales[0] * id.at(0))),
+                              static_cast<int> (lroundf(scales[1] * id.at(1))),
+                              static_cast<int> (lroundf(scales[2] * id.at(2)))};
+        return input.at(in_idx);
+      })();
+      break;
+    }
+    default:
+      assert(false && "Not supported Optype");
+  }
+
+}
+
 } // namespace nnc
index 9d93201..cf0c795 100644 (file)
@@ -225,6 +225,10 @@ void ModelAnalyzer::visit(ops::ReshapeOp& op) {
   addOpDescr(&op, "reshape");
 }
 
+void ModelAnalyzer::visit(mir::ops::ResizeOp& op) {
+  assert(false && "Not implemented");
+}
+
 void ModelAnalyzer::visit(ops::DropoutOp& op) {
   addOpDescr(&op, "dropout");
 }
index 77831e0..79cd275 100644 (file)
@@ -98,6 +98,7 @@ public:
   void visit(mir::ops::VariableOp& op) override;
   void visit(mir::ops::ReluOp& op) override;
   void visit(mir::ops::ReshapeOp& op) override;
+  void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
   void visit(mir::ops::BatchNormOp& op) override;
   void visit(mir::ops::DropoutOp& op) override;
index 7c09d23..7b51dd5 100644 (file)
@@ -30,6 +30,7 @@
 #include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/BatchNormOp.h"
@@ -320,4 +321,8 @@ void Serializer::visit(mir::ops::PadOp& op) {
   throw PassException("Not implemented yet");
 }
 
+void Serializer::visit(mir::ops::ResizeOp& op) {
+  throw PassException("Not implemented yet");
+}
+
 } // namespace nnc
index f5135a7..e1c3c62 100644 (file)
@@ -52,6 +52,7 @@ public:
   void visit(mir::ops::VariableOp& op) override;
   void visit(mir::ops::ReluOp& op) override;
   void visit(mir::ops::ReshapeOp& op) override;
+  void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
   void visit(mir::ops::BatchNormOp& op) override;
   void visit(mir::ops::DropoutOp& op) override;
index 28f1dc5..0269df9 100644 (file)
@@ -92,6 +92,7 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
       break;
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_RESHAPE:
+    case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
     case BuiltinOperator_SQUEEZE:
     case BuiltinOperator_PAD:
     case BuiltinOperator_ADD:
@@ -182,6 +183,10 @@ void TfliteImporter::walkOperator(const Operator* op) {
     case BuiltinOperator_RESHAPE:
       outputs = _opCreator->convertReshape(inputs, params, op->builtin_options_as<ReshapeOptions>());
       break;
+    case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
+      outputs = _opCreator->convertResizeNN(inputs, params,
+        op->builtin_options_as<ResizeNearestNeighborOptions>());
+      break;
     case BuiltinOperator_FULLY_CONNECTED:
       outputs = _opCreator->convertFullyConnected(inputs, params,
                                                   op->builtin_options_as<FullyConnectedOptions>());
index e61378c..566f72f 100644 (file)
@@ -22,6 +22,7 @@
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
 #include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/Tensor.h"
 #include "pass/PassException.h"
 
+#include "core/modelIR/Tensor.h"
+#include "core/modelIR/Shape.h"
+#include "core/modelIR/ShapeRange.h"
+
 using namespace nnc::mir;
 using namespace ::tflite;
 
@@ -133,6 +138,7 @@ std::vector<mir::Operation*> TFLiteOpCreator::convertReshape(InputOps inputs, In
   return outputs;
 }
 
+
 std::vector<mir::Operation*>
 TFLiteOpCreator::createTransposeConv(InputOps& inputs, InputParams& params,
                                      const ::tflite::TransposeConvOptions* opts) {
@@ -142,6 +148,23 @@ TFLiteOpCreator::createTransposeConv(InputOps& inputs, InputParams& params,
                                    paddingMap[opts->padding()]);
 }
 
+std::vector<mir::Operation*> TFLiteOpCreator::convertResizeNN(
+  InputOps& inputs, InputParams& params,
+  const ::tflite::ResizeNearestNeighborOptions* opts) {
+  // TODO support aligned corners
+  assert(!opts->align_corners() && "Aligned corners not currently supported");
+
+  mir::Tensor<int> out_shapes = mir::Tensor<int>(*params[0].get());
+  std::vector<int> res_shape;
+  for (const auto& i : mir::ShapeRange(out_shapes.getShape()))
+    res_shape.push_back(out_shapes.at(i));
+  res_shape.push_back(Shape::autoDim);
+  // assume no batch
+  return createOp<ops::ResizeOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0),
+                                 ops::ResizeOp::ResizeMethod::nearestNeighbor, Shape(res_shape));
+}
+
+
 std::vector<mir::Operation*>
 TFLiteOpCreator::createAdd(InputOps& inputs, InputParams&, const ::tflite::AddOptions* opts) {
   std::vector<IODescriptor> descriptors;
@@ -235,8 +258,8 @@ mir::Operation* TFLiteOpCreator::addFusedActivation(mir::Operation* input,
   }
 }
 
-std::vector<mir::Operation*> TFLiteOpCreator::createSqueeze(InputOps inputs, InputParams params,
-                                                            const ::tflite::SqueezeOptions* opts) {
+std::vector<mir::Operation*> TFLiteOpCreator::createSqueeze(
+  InputOps inputs, InputParams params, const ::tflite::SqueezeOptions* opts) {
 
   std::vector<int32_t> squeeze_dims{opts->squeeze_dims()->begin(), opts->squeeze_dims()->end()};
 
index 8574e15..5b8a14d 100644 (file)
@@ -69,6 +69,9 @@ public:
   std::vector<mir::Operation*> convertFullyConnected(InputOps, InputParams,
                                                      const ::tflite::FullyConnectedOptions*);
 
+  std::vector<mir::Operation*> convertResizeNN(InputOps, InputParams,
+                                               const ::tflite::ResizeNearestNeighborOptions*);
+
   std::vector<mir::Operation*> createSqueeze(InputOps& inputs, InputParams& params,
                                              const ::tflite::SqueezeOptions* opts);
 
index 9fefe88..924b1a0 100644 (file)
@@ -17,8 +17,8 @@
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/ShapeInference.h"
 #include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
-#include "core/modelIR/Shape.h"
 
 #include "gtest/gtest.h"
 
@@ -31,8 +31,9 @@ TEST(ShapeInferenceTest, ReshapeAutoDimension) {
   Shape input_shape{10, 2, 5};
   Shape expected_shape{10, 1, 10};
 
+
   auto input = g.create<ops::VariableOp>("input", input_shape);
-  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, 1, Shape::AUTO_DIM});
+  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, 1, Shape::autoDim});
   op->setInputShape(0, input_shape);
 
   si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
@@ -40,6 +41,42 @@ TEST(ShapeInferenceTest, ReshapeAutoDimension) {
   ASSERT_EQ(expected_shape, op->getOutputShape(0));
 }
 
+TEST(ShapeInferenceTest, ResizeWithShape) {
+  Graph g;
+  ShapeInference si;
+
+  Shape result_shape{10, 10, 3};
+
+  auto input = g.create<ops::VariableOp>("input", Shape{5, 5, 3});
+
+  auto op = g.create<ops::ResizeOp>(
+    "Resize", input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor,
+    Shape{10, 10, Shape::autoDim}
+  );
+
+  g.accept(&si);
+
+  ASSERT_EQ(result_shape, op->getOutputShape(0));
+}
+
+TEST(ShapeInferenceTest, ResizeWithScale) {
+  Graph g;
+  ShapeInference si;
+
+  Shape result_shape{30, 10, 3};
+
+  auto input = g.create<ops::VariableOp>("input", Shape{5, 5, 3});
+
+  auto op = g.create<ops::ResizeOp>(
+    "Resize", input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor,
+    std::vector<float>{6, 2, 1}
+  );
+
+  g.accept(&si);
+
+  ASSERT_EQ(result_shape, op->getOutputShape(0));
+}
+
 TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) {
   Graph g;
   ShapeInference si;
@@ -48,7 +85,7 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) {
   Shape result_shape_shrink{10, 20};
 
   auto input = g.create<ops::VariableOp>("input", input_shape);
-  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, Shape::AUTO_DIM});
+  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, Shape::autoDim});
   op->setInputShape(0, input_shape);
 
   si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
@@ -63,7 +100,8 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionExpand) {
   Shape result_shape_expand{5, 10, 2, 2};
 
   auto input = g.create<ops::VariableOp>("input", input_shape);
-  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{5, Shape::AUTO_DIM, 2, 2});
+  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0),
+                                     Shape{5, Shape::autoDim, 2, 2});
   op->setInputShape(0, input_shape);
 
   si.visit(*dynamic_cast<ops::ReshapeOp*>(op));