[nnc] Support for binary elementwise operations (#6408)
author Sergei Barannikov/AI Tools Lab/SRR/Engineer/Samsung Electronics <s.barannikov@samsung.com>
Thu, 8 Aug 2019 15:30:10 +0000 (18:30 +0300)
committer Alexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 8 Aug 2019 15:30:10 +0000 (18:30 +0300)
* Add implementations of binary elementwise operations to backends.
* Switch backend unittests to new ops.
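
For illustration, the change to graph construction looks like this (taken
verbatim from the unittest updates below):

    // before: generic n-ary elementwise op selected by an enum
    g.create<mir::ops::ElementwiseOp>("y", inputs, mir::ops::ElementwiseOp::OpType::add);
    // after: dedicated binary op taking exactly two inputs
    g.create<mir::ops::AddOp>("y", inputs[0], inputs[1]);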

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
18 files changed:
compiler/nnc/include/passes/interpreter/Interpreter.h
compiler/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
compiler/nnc/passes/acl_soft_backend/AclCppOpGenerator.h
compiler/nnc/passes/interpreter/Interpreter.cpp
compiler/nnc/passes/interpreter/ops/Add.h [new file with mode: 0644]
compiler/nnc/passes/interpreter/ops/Div.h [new file with mode: 0644]
compiler/nnc/passes/interpreter/ops/Max.h [new file with mode: 0644]
compiler/nnc/passes/interpreter/ops/Mul.h [new file with mode: 0644]
compiler/nnc/passes/interpreter/ops/Sub.h [new file with mode: 0644]
compiler/nnc/passes/soft_backend/ModelAnalyzer.cpp
compiler/nnc/passes/soft_backend/ModelAnalyzer.h
compiler/nnc/passes/soft_backend/SBSerializer.cpp
compiler/nnc/passes/soft_backend/SBSerializer.h
compiler/nnc/unittests/acl_backend/MIRToDOM.cpp
compiler/nnc/unittests/optimizations/CombineTransposes.cpp
compiler/nnc/unittests/optimizations/RemoveDeadEnds.cpp
compiler/nnc/unittests/optimizations/Util.h
compiler/nnc/unittests/soft_backend/CPPOperations.cpp

diff --git a/compiler/nnc/include/passes/interpreter/Interpreter.h b/compiler/nnc/include/passes/interpreter/Interpreter.h
index f50cbb1..97532fa 100644
@@ -36,6 +36,7 @@ public:
 
   mir::TensorVariant getResult(const mir::Operation::Output* tensor);
 
+  void visit(mir::ops::AddOp& op) override;
   void visit(mir::ops::BatchNormOp& op) override;
   void visit(mir::ops::CappedReluOp& op) override;
   void visit(mir::ops::ConcatOp& op) override;
@@ -43,6 +44,7 @@ public:
   void visit(mir::ops::Conv2DOp& op) override;
   void visit(mir::ops::DeConv2DOp& op) override;
   void visit(mir::ops::DepthwiseConv2DOp& op) override;
+  void visit(mir::ops::DivOp& op) override;
   void visit(mir::ops::DropoutOp& op) override;
   void visit(mir::ops::ElementwiseOp& op) override;
   void visit(mir::ops::EluOp& op) override;
@@ -51,6 +53,8 @@ public:
   void visit(mir::ops::GemmOp& op) override;
   void visit(mir::ops::InputOp& op) override;
   void visit(mir::ops::LeakyReluOp& op) override;
+  void visit(mir::ops::MaxOp& op) override;
+  void visit(mir::ops::MulOp& op) override;
   void visit(mir::ops::OutputOp& op) override;
   void visit(mir::ops::PadOp& op) override;
   void visit(mir::ops::PoolOp& op) override;
@@ -63,6 +67,7 @@ public:
   void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqrtOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
+  void visit(mir::ops::SubOp& op) override;
   void visit(mir::ops::TanhOp& op) override;
   void visit(mir::ops::TransposeOp& op) override;
 
diff --git a/compiler/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp b/compiler/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
index 8900b49..2612604 100644
 #include "mir/Tensor.h"
 
 #include "mir/Operation.h"
-#include "mir/ops/BatchNormOp.h"
-#include "mir/ops/CappedReluOp.h"
-#include "mir/ops/ConcatOp.h"
-#include "mir/ops/ConstantOp.h"
-#include "mir/ops/Conv2DOp.h"
-#include "mir/ops/Deconv2DOp.h"
-#include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/DropoutOp.h"
-#include "mir/ops/ElementwiseOp.h"
-#include "mir/ops/FullyConnectedOp.h"
-#include "mir/ops/GemmOp.h"
-#include "mir/ops/InputOp.h"
-#include "mir/ops/LeakyReluOp.h"
-#include "mir/ops/OutputOp.h"
-#include "mir/ops/PadOp.h"
-#include "mir/ops/PoolOp.h"
-#include "mir/ops/ReduceOp.h"
-#include "mir/ops/ReluOp.h"
-#include "mir/ops/ReshapeOp.h"
-#include "mir/ops/ResizeOp.h"
-#include "mir/ops/SigmoidOp.h"
-#include "mir/ops/SoftmaxOp.h"
-#include "mir/ops/SqrtOp.h"
-#include "mir/ops/TanhOp.h"
-#include "mir/ops/TransposeOp.h"
+#include "mir/OpDefs.h"
 
 #include <algorithm>
 
@@ -986,4 +962,46 @@ void AclCppOpGenerator::visit(mir::ops::OutputOp& /*op*/) {
   // No-op.
 }
 
-}  // namespace nnc
+void AclCppOpGenerator::visit(mir::ops::AddOp &op)
+{
+  assert(op.getNumInputs() == 2);
+  const auto *ir_lhs = op.getInput(0)->getProducer();
+  const auto *ir_rhs = op.getInput(1)->getProducer();
+  const auto *ir_output = op.getOutput(0);
+
+  // Create the output tensor in the DOM and obtain its identifier.
+  auto out = genTensor(ir_output);
+  addToPersistentTensors(out);
+
+  // Get the identifiers of the input tensors in the DOM.
+  auto lhs = AF::id(tensorName(ir_lhs));
+  auto rhs = AF::id(tensorName(ir_rhs));
+
+  genAddition(out->name() + "_" + "addition", 0, ir_rhs->getShape(), lhs, rhs, out);
+}
+
+void AclCppOpGenerator::visit(mir::ops::DivOp &) { throw AclCppException("NYI"); }
+
+void AclCppOpGenerator::visit(mir::ops::MaxOp &) { throw AclCppException("NYI"); }
+
+void AclCppOpGenerator::visit(mir::ops::MulOp &op)
+{
+  assert(op.getNumInputs() == 2);
+  const auto *ir_lhs = op.getInput(0)->getProducer();
+  const auto *ir_rhs = op.getInput(1)->getProducer();
+  const auto *ir_output = op.getOutput(0);
+
+  // Create the output tensor in the DOM and obtain its identifier.
+  auto out = genTensor(ir_output);
+  addToPersistentTensors(out);
+
+  // Get the identifiers of the input tensors in the DOM.
+  auto lhs = AF::id(tensorName(ir_lhs));
+  auto rhs = AF::id(tensorName(ir_rhs));
+
+  genMultiplication(out->name() + "_" + "multiplication", 0, ir_rhs->getShape(), lhs, rhs, out);
+}
+
+void AclCppOpGenerator::visit(mir::ops::SubOp &) { throw AclCppException("NYI"); }
+
+} // namespace nnc
diff --git a/compiler/nnc/passes/acl_soft_backend/AclCppOpGenerator.h b/compiler/nnc/passes/acl_soft_backend/AclCppOpGenerator.h
index a0e55cc..8636fda 100644
@@ -47,6 +47,7 @@ public:
    * @brief Implementations of the IVisitor visitors.
    * @param op
    */
+  void visit(mir::ops::AddOp& op) override;
   void visit(mir::ops::BatchNormOp& op) override;
   void visit(mir::ops::CappedReluOp& op) override;
   void visit(mir::ops::ConcatOp& op) override;
@@ -54,6 +55,7 @@ public:
   void visit(mir::ops::Conv2DOp& op) override;
   void visit(mir::ops::DeConv2DOp& op) override;
   void visit(mir::ops::DepthwiseConv2DOp& op) override;
+  void visit(mir::ops::DivOp& op) override;
   void visit(mir::ops::DropoutOp& op) override;
   void visit(mir::ops::ElementwiseOp& op) override;
   void visit(mir::ops::EluOp& op) override;
@@ -62,6 +64,8 @@ public:
   void visit(mir::ops::GemmOp& op) override;
   void visit(mir::ops::InputOp& op) override;
   void visit(mir::ops::LeakyReluOp& op) override;
+  void visit(mir::ops::MaxOp& op) override;
+  void visit(mir::ops::MulOp& op) override;
   void visit(mir::ops::OutputOp& op) override;
   void visit(mir::ops::PadOp& op) override;
   void visit(mir::ops::PoolOp& op) override;
@@ -74,6 +78,7 @@ public:
   void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqrtOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
+  void visit(mir::ops::SubOp& op) override;
   void visit(mir::ops::TanhOp& op) override;
   void visit(mir::ops::TransposeOp& op) override;
 
diff --git a/compiler/nnc/passes/interpreter/Interpreter.cpp b/compiler/nnc/passes/interpreter/Interpreter.cpp
index a93680c..862cdc8 100644
 
 #include "passes/interpreter/Interpreter.h"
 
-#include "mir/ops/BatchNormOp.h"
-#include "mir/ops/CappedReluOp.h"
-#include "mir/ops/ConcatOp.h"
-#include "mir/ops/ConstantOp.h"
-#include "mir/ops/Conv2DOp.h"
-#include "mir/ops/Deconv2DOp.h"
-#include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/DropoutOp.h"
-#include "mir/ops/ElementwiseOp.h"
-#include "mir/ops/EluOp.h"
-#include "mir/ops/FullyConnectedOp.h"
-#include "mir/ops/GatherOp.h"
-#include "mir/ops/GemmOp.h"
-#include "mir/ops/InputOp.h"
-#include "mir/ops/OutputOp.h"
-#include "mir/ops/LeakyReluOp.h"
-#include "mir/ops/PadOp.h"
-#include "mir/ops/PoolOp.h"
-#include "mir/ops/ReduceOp.h"
-#include "mir/ops/ReluOp.h"
-#include "mir/ops/ResizeOp.h"
-#include "mir/ops/SigmoidOp.h"
-#include "mir/ops/SliceOp.h"
-#include "mir/ops/SoftmaxOp.h"
-#include "mir/ops/SqrtOp.h"
-#include "mir/ops/SqueezeOp.h"
-#include "mir/ops/TanhOp.h"
-#include "mir/ops/TransposeOp.h"
-
+#include "ops/Add.h"
 #include "ops/BatchNorm.h"
 #include "ops/Concat.h"
 #include "ops/Conv2D.h"
 #include "ops/DeConv2D.h"
 #include "ops/DepthwiseConv2D.h"
+#include "ops/Div.h"
 #include "ops/Dropout.h"
 #include "ops/Elementwise.h"
 #include "ops/FullyConnected.h"
 #include "ops/Gather.h"
 #include "ops/Gemm.h"
+#include "ops/Max.h"
+#include "ops/Mul.h"
 #include "ops/Pad.h"
 #include "ops/Pool.h"
 #include "ops/Reduce.h"
 #include "ops/Reshape.h"
 #include "ops/Softmax.h"
+#include "ops/Sub.h"
 #include "ops/Transpose.h"
 #include "ops/common.h"
 
+#include "mir/OpDefs.h"
+
 #include <cmath>
 #include <cassert>
 #include <iostream>
@@ -309,4 +287,39 @@ void NNInterpreter::visit(ops::OutputOp& /*op*/) {
   // No-op.
 }
 
+void NNInterpreter::visit(ops::AddOp &op)
+{
+  auto inputs = getInputTensors(op);
+  auto outputs = Add(op, inputs[0], inputs[1]);
+  setOutputTensors(op, std::move(outputs));
+}
+
+void NNInterpreter::visit(mir::ops::DivOp &op)
+{
+  auto inputs = getInputTensors(op);
+  auto outputs = Div(op, inputs[0], inputs[1]);
+  setOutputTensors(op, std::move(outputs));
+}
+
+void NNInterpreter::visit(mir::ops::MaxOp &op)
+{
+  auto inputs = getInputTensors(op);
+  auto outputs = Max(op, inputs[0], inputs[1]);
+  setOutputTensors(op, std::move(outputs));
+}
+
+void NNInterpreter::visit(mir::ops::MulOp &op)
+{
+  auto inputs = getInputTensors(op);
+  auto outputs = Mul(op, inputs[0], inputs[1]);
+  setOutputTensors(op, std::move(outputs));
+}
+
+void NNInterpreter::visit(mir::ops::SubOp &op)
+{
+  auto inputs = getInputTensors(op);
+  auto outputs = Sub(op, inputs[0], inputs[1]);
+  setOutputTensors(op, std::move(outputs));
+}
+
 } // namespace nnc
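
A minimal sketch of how these visitors get invoked (Graph::accept and
NNInterpreter::getResult appear elsewhere in this patch; the name add_op and
the input-binding step are assumptions, since the binding API is not shown
here):

    NNInterpreter interpreter;
    // Input tensors are assumed to be bound to the graph's InputOps here.
    g.accept(&interpreter);  // dispatches to visit(ops::AddOp&) and friends
    mir::TensorVariant out = interpreter.getResult(add_op->getOutput(0));
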
diff --git a/compiler/nnc/passes/interpreter/ops/Add.h b/compiler/nnc/passes/interpreter/ops/Add.h
new file mode 100644
index 0000000..64776ed
--- /dev/null
+++ b/compiler/nnc/passes/interpreter/ops/Add.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_ADD_
+#define _NNC_CORE_BACKEND_INTERPRETER_ADD_
+
+#include "mir/ops/AddOp.h"
+#include "mir/Tensor.h"
+#include "mir/ShapeRange.h"
+
+namespace nnc
+{
+
+std::vector<mir::TensorVariant> Add(const mir::ops::AddOp &op, const mir::TensorVariant &lhs,
+                                    const mir::TensorVariant &rhs)
+{
+  mir::TensorVariant broadcasted_lhs(lhs, op.getOutputShape(0));
+  mir::TensorVariant broadcasted_rhs(rhs, op.getOutputShape(0));
+  mir::TensorVariant res(mir::DataType::FLOAT32, op.getOutputShape(0));
+  mir::Tensor<float> lhs_accessor(broadcasted_lhs);
+  mir::Tensor<float> rhs_accessor(broadcasted_rhs);
+  mir::Tensor<float> res_accessor(res);
+
+  for (const auto &index : mir::ShapeRange(op.getOutputShape(0)))
+  {
+    res_accessor.at(index) = lhs_accessor.at(index) + rhs_accessor.at(index);
+  }
+
+  return {res};
+}
+
+} // namespace nnc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_ADD_
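
A minimal usage sketch of the kernel above, assuming FLOAT32 operands of
matching shape (TensorVariant, Tensor and ShapeRange signatures are taken
from this patch; graph construction follows the unittests below):

    mir::Graph g;
    mir::Shape shape{2, 3};
    auto *x = g.create<mir::ops::InputOp>("x", shape);
    auto *y = g.create<mir::ops::InputOp>("y", shape);
    auto *add = g.create<mir::ops::AddOp>("add", x->getOutput(0), y->getOutput(0));

    mir::TensorVariant lhs(mir::DataType::FLOAT32, shape);
    mir::TensorVariant rhs(mir::DataType::FLOAT32, shape);
    mir::Tensor<float> lhs_acc(lhs), rhs_acc(rhs);
    for (const auto &index : mir::ShapeRange(shape))
    {
      lhs_acc.at(index) = 1.0f;  // constant operands for the example
      rhs_acc.at(index) = 2.0f;
    }

    // Invoke the interpreter kernel directly, as NNInterpreter::visit does.
    auto results = nnc::Add(*static_cast<mir::ops::AddOp *>(add), lhs, rhs);
    // Every element of results[0] is now 3.0f.
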
diff --git a/compiler/nnc/passes/interpreter/ops/Div.h b/compiler/nnc/passes/interpreter/ops/Div.h
new file mode 100644
index 0000000..862147e
--- /dev/null
+++ b/compiler/nnc/passes/interpreter/ops/Div.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_DIV_
+#define _NNC_CORE_BACKEND_INTERPRETER_DIV_
+
+#include "mir/ops/DivOp.h"
+#include "mir/Tensor.h"
+#include "mir/ShapeRange.h"
+
+namespace nnc
+{
+
+std::vector<mir::TensorVariant> Div(const mir::ops::DivOp &op, const mir::TensorVariant &lhs,
+                                    const mir::TensorVariant &rhs)
+{
+  mir::TensorVariant broadcasted_lhs(lhs, op.getOutputShape(0));
+  mir::TensorVariant broadcasted_rhs(rhs, op.getOutputShape(0));
+  mir::TensorVariant res(mir::DataType::FLOAT32, op.getOutputShape(0));
+  mir::Tensor<float> lhs_accessor(broadcasted_lhs);
+  mir::Tensor<float> rhs_accessor(broadcasted_rhs);
+  mir::Tensor<float> res_accessor(res);
+
+  for (const auto &index : mir::ShapeRange(op.getOutputShape(0)))
+  {
+    res_accessor.at(index) = lhs_accessor.at(index) / rhs_accessor.at(index);
+  }
+
+  return {res};
+}
+
+} // namespace nnc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_DIV_
diff --git a/compiler/nnc/passes/interpreter/ops/Max.h b/compiler/nnc/passes/interpreter/ops/Max.h
new file mode 100644
index 0000000..52cb387
--- /dev/null
+++ b/compiler/nnc/passes/interpreter/ops/Max.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_MAX_
+#define _NNC_CORE_BACKEND_INTERPRETER_MAX_
+
+#include "mir/ops/MaxOp.h"
+#include "mir/Tensor.h"
+#include "mir/ShapeRange.h"
+
+#include <algorithm>
+
+namespace nnc
+{
+
+std::vector<mir::TensorVariant> Max(const mir::ops::MaxOp &op, const mir::TensorVariant &lhs,
+                                    const mir::TensorVariant &rhs)
+{
+  mir::TensorVariant broadcasted_lhs(lhs, op.getOutputShape(0));
+  mir::TensorVariant broadcasted_rhs(rhs, op.getOutputShape(0));
+  mir::TensorVariant res(mir::DataType::FLOAT32, op.getOutputShape(0));
+  mir::Tensor<float> lhs_accessor(broadcasted_lhs);
+  mir::Tensor<float> rhs_accessor(broadcasted_rhs);
+  mir::Tensor<float> res_accessor(res);
+
+  for (const auto &index : mir::ShapeRange(op.getOutputShape(0)))
+  {
+    res_accessor.at(index) = std::max(lhs_accessor.at(index), rhs_accessor.at(index));
+  }
+
+  return {res};
+}
+
+} // namespace nnc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_MAX_
diff --git a/compiler/nnc/passes/interpreter/ops/Mul.h b/compiler/nnc/passes/interpreter/ops/Mul.h
new file mode 100644
index 0000000..fb3ce3e
--- /dev/null
+++ b/compiler/nnc/passes/interpreter/ops/Mul.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_MUL_
+#define _NNC_CORE_BACKEND_INTERPRETER_MUL_
+
+#include "mir/ops/MulOp.h"
+#include "mir/Tensor.h"
+#include "mir/ShapeRange.h"
+
+namespace nnc
+{
+
+std::vector<mir::TensorVariant> Mul(const mir::ops::MulOp &op, const mir::TensorVariant &lhs,
+                                    const mir::TensorVariant &rhs)
+{
+  mir::TensorVariant broadcasted_lhs(lhs, op.getOutputShape(0));
+  mir::TensorVariant broadcasted_rhs(rhs, op.getOutputShape(0));
+  mir::TensorVariant res(mir::DataType::FLOAT32, op.getOutputShape(0));
+  mir::Tensor<float> lhs_accessor(broadcasted_lhs);
+  mir::Tensor<float> rhs_accessor(broadcasted_rhs);
+  mir::Tensor<float> res_accessor(res);
+
+  for (const auto &index : mir::ShapeRange(op.getOutputShape(0)))
+  {
+    res_accessor.at(index) = lhs_accessor.at(index) * rhs_accessor.at(index);
+  }
+
+  return {res};
+}
+
+} // namespace nnc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_MUL_
diff --git a/compiler/nnc/passes/interpreter/ops/Sub.h b/compiler/nnc/passes/interpreter/ops/Sub.h
new file mode 100644
index 0000000..90bce4d
--- /dev/null
+++ b/compiler/nnc/passes/interpreter/ops/Sub.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_SUB_
+#define _NNC_CORE_BACKEND_INTERPRETER_SUB_
+
+#include "mir/ops/SubOp.h"
+#include "mir/Tensor.h"
+#include "mir/ShapeRange.h"
+
+namespace nnc
+{
+
+std::vector<mir::TensorVariant> Sub(const mir::ops::SubOp &op, const mir::TensorVariant &lhs,
+                                    const mir::TensorVariant &rhs)
+{
+  mir::TensorVariant broadcasted_lhs(lhs, op.getOutputShape(0));
+  mir::TensorVariant broadcasted_rhs(rhs, op.getOutputShape(0));
+  mir::TensorVariant res(mir::DataType::FLOAT32, op.getOutputShape(0));
+  mir::Tensor<float> lhs_accessor(broadcasted_lhs);
+  mir::Tensor<float> rhs_accessor(broadcasted_rhs);
+  mir::Tensor<float> res_accessor(res);
+
+  for (const auto &index : mir::ShapeRange(op.getOutputShape(0)))
+  {
+    res_accessor.at(index) = lhs_accessor.at(index) - rhs_accessor.at(index);
+  }
+
+  return {res};
+}
+
+} // namespace nnc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_SUB_
diff --git a/compiler/nnc/passes/soft_backend/ModelAnalyzer.cpp b/compiler/nnc/passes/soft_backend/ModelAnalyzer.cpp
index 7ea1f4a..5776298 100644
 
 #include "mir/Shape.h"
 #include "mir/Graph.h"
+#include "mir/OpDefs.h"
 
-#include "mir/ops/BatchNormOp.h"
-#include "mir/ops/CappedReluOp.h"
-#include "mir/ops/ConcatOp.h"
-#include "mir/ops/ConstantOp.h"
-#include "mir/ops/Conv2DOp.h"
-#include "mir/ops/Deconv2DOp.h"
-#include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/DropoutOp.h"
-#include "mir/ops/ElementwiseOp.h"
-#include "mir/ops/EluOp.h"
-#include "mir/ops/FullyConnectedOp.h"
-#include "mir/ops/GatherOp.h"
-#include "mir/ops/GemmOp.h"
-#include "mir/ops/InputOp.h"
-#include "mir/ops/LeakyReluOp.h"
-#include "mir/ops/OutputOp.h"
-#include "mir/ops/PadOp.h"
-#include "mir/ops/PoolOp.h"
-#include "mir/ops/ReduceOp.h"
-#include "mir/ops/ReluOp.h"
-#include "mir/ops/ReshapeOp.h"
-#include "mir/ops/ResizeOp.h"
-#include "mir/ops/SigmoidOp.h"
-#include "mir/ops/SliceOp.h"
-#include "mir/ops/SoftmaxOp.h"
-#include "mir/ops/SqrtOp.h"
-#include "mir/ops/SqueezeOp.h"
-#include "mir/ops/TanhOp.h"
-#include "mir/ops/TransposeOp.h"
-
-#include <type_traits>
-#include <limits>
 #include <stack>
 #include <map>
 
@@ -480,4 +449,29 @@ void ModelAnalyzer::visit(mir::ops::OutputOp& op) {
   appendOperationToInference(&op, "out");
 }
 
+void ModelAnalyzer::visit(mir::ops::AddOp &op)
+{
+  appendOperationToInference(&op, "ElementWise<Add>");
+}
+
+void ModelAnalyzer::visit(mir::ops::DivOp &op)
+{
+  appendOperationToInference(&op, "ElementWise<Div>");
+}
+
+void ModelAnalyzer::visit(mir::ops::MaxOp &op)
+{
+  appendOperationToInference(&op, "ElementWise<Max>");
+}
+
+void ModelAnalyzer::visit(mir::ops::MulOp &op)
+{
+  appendOperationToInference(&op, "ElementWise<Mul>");
+}
+
+void ModelAnalyzer::visit(mir::ops::SubOp &op)
+{
+  appendOperationToInference(&op, "ElementWise<Sub>");
+}
+
 } // namespace nnc
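
Note: the strings registered here name kernels in the generated artifact;
judging by the updated soft-backend tests at the end of this patch, they
resolve to calls of the form ElementWise<Add, Tensor, Tensor>. A purely
hypothetical sketch of such a kernel (the real ones live in the soft
backend's code snippets, which this patch does not modify):

    #include <cstddef>

    // Hypothetical: mirrors the generated name "ElementWise<Add>".
    struct Add
    {
      static float calc(float lhs, float rhs) { return lhs + rhs; }
    };

    template <typename Op>
    void elementWise(float *out, const float *lhs, const float *rhs, std::size_t n)
    {
      for (std::size_t i = 0; i < n; ++i)
        out[i] = Op::calc(lhs[i], rhs[i]);
    }
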
diff --git a/compiler/nnc/passes/soft_backend/ModelAnalyzer.h b/compiler/nnc/passes/soft_backend/ModelAnalyzer.h
index 0dfb537..73a0ad9 100644
@@ -46,6 +46,7 @@ public:
  */
   void analyze(const mir::Graph* g);
 
+  void visit(mir::ops::AddOp& op) override;
   void visit(mir::ops::BatchNormOp& op) override;
   void visit(mir::ops::CappedReluOp& op) override;
   void visit(mir::ops::ConcatOp& op) override;
@@ -53,6 +54,7 @@ public:
   void visit(mir::ops::Conv2DOp& op) override;
   void visit(mir::ops::DeConv2DOp& op) override;
   void visit(mir::ops::DepthwiseConv2DOp& op) override;
+  void visit(mir::ops::DivOp& op) override;
   void visit(mir::ops::DropoutOp& op) override;
   void visit(mir::ops::ElementwiseOp& op) override;
   void visit(mir::ops::EluOp& op) override;
@@ -61,6 +63,8 @@ public:
   void visit(mir::ops::GemmOp& op) override;
   void visit(mir::ops::InputOp& op) override;
   void visit(mir::ops::LeakyReluOp& op) override;
+  void visit(mir::ops::MaxOp& op) override;
+  void visit(mir::ops::MulOp& op) override;
   void visit(mir::ops::OutputOp& op) override;
   void visit(mir::ops::PadOp& op) override;
   void visit(mir::ops::PoolOp& op) override;
@@ -73,6 +77,7 @@ public:
   void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqrtOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
+  void visit(mir::ops::SubOp& op) override;
   void visit(mir::ops::TanhOp& op) override;
   void visit(mir::ops::TransposeOp& op) override;
 
diff --git a/compiler/nnc/passes/soft_backend/SBSerializer.cpp b/compiler/nnc/passes/soft_backend/SBSerializer.cpp
index 07d1bf3..38c6122 100644
 
 #include "CommonData.def"
 
-#include "mir/ops/BatchNormOp.h"
-#include "mir/ops/CappedReluOp.h"
-#include "mir/ops/ConcatOp.h"
-#include "mir/ops/ConstantOp.h"
-#include "mir/ops/Conv2DOp.h"
-#include "mir/ops/Deconv2DOp.h"
-#include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/DropoutOp.h"
-#include "mir/ops/ElementwiseOp.h"
-#include "mir/ops/EluOp.h"
-#include "mir/ops/FullyConnectedOp.h"
-#include "mir/ops/GatherOp.h"
-#include "mir/ops/GemmOp.h"
-#include "mir/ops/LeakyReluOp.h"
-#include "mir/ops/PadOp.h"
-#include "mir/ops/PoolOp.h"
-#include "mir/ops/ReduceOp.h"
-#include "mir/ops/ReluOp.h"
-#include "mir/ops/ReshapeOp.h"
-#include "mir/ops/ResizeOp.h"
-#include "mir/ops/SliceOp.h"
-#include "mir/ops/SoftmaxOp.h"
-#include "mir/ops/SqueezeOp.h"
-#include "mir/ops/SqrtOp.h"
-#include "mir/ops/TanhOp.h"
-#include "mir/ops/TransposeOp.h"
+#include "mir/OpDefs.h"
 
 #include "pass/PassException.h"
 #include <algorithm>
@@ -369,4 +344,44 @@ void Serializer::visit(mir::ops::OutputOp& /*op*/) {
   // no parameters to dump
 }
 
+void Serializer::visit(mir::ops::AddOp &op)
+{
+  _curOp->paramStartOffset = _buffer.size();
+  // Op type is known at codegen time
+  serializeT(static_cast<int32_t>(op.getBroadcast()));
+  serializeShape(op.getOutputShape(0));
+}
+
+void Serializer::visit(mir::ops::DivOp &op)
+{
+  _curOp->paramStartOffset = _buffer.size();
+  // Op type is known at codegen time
+  serializeT(static_cast<int32_t>(op.getBroadcast()));
+  serializeShape(op.getOutputShape(0));
+}
+
+void Serializer::visit(mir::ops::MaxOp &op)
+{
+  _curOp->paramStartOffset = _buffer.size();
+  // Op type is known at codegen time
+  serializeT(static_cast<int32_t>(op.getBroadcast()));
+  serializeShape(op.getOutputShape(0));
+}
+
+void Serializer::visit(mir::ops::MulOp &op)
+{
+  _curOp->paramStartOffset = _buffer.size();
+  // Op type is known at codegen time
+  serializeT(static_cast<int32_t>(op.getBroadcast()));
+  serializeShape(op.getOutputShape(0));
+}
+
+void Serializer::visit(mir::ops::SubOp &op)
+{
+  _curOp->paramStartOffset = _buffer.size();
+  // Op type is known at codegen time
+  serializeT(static_cast<int32_t>(op.getBroadcast()));
+  serializeShape(op.getOutputShape(0));
+}
+
 } // namespace nnc
diff --git a/compiler/nnc/passes/soft_backend/SBSerializer.h b/compiler/nnc/passes/soft_backend/SBSerializer.h
index 7c5c98e..f45877a 100644
@@ -39,7 +39,7 @@ namespace nnc {
  */
 class Serializer : public mir::IVisitor {
 public:
-
+  void visit(mir::ops::AddOp& op) override;
   void visit(mir::ops::BatchNormOp& op) override;
   void visit(mir::ops::CappedReluOp& op) override;
   void visit(mir::ops::ConcatOp& op) override;
@@ -47,6 +47,7 @@ public:
   void visit(mir::ops::Conv2DOp& op) override;
   void visit(mir::ops::DeConv2DOp& op) override;
   void visit(mir::ops::DepthwiseConv2DOp& op) override;
+  void visit(mir::ops::DivOp& op) override;
   void visit(mir::ops::DropoutOp& op) override;
   void visit(mir::ops::ElementwiseOp& op) override;
   void visit(mir::ops::EluOp& op) override;
@@ -55,6 +56,8 @@ public:
   void visit(mir::ops::GemmOp& op) override;
   void visit(mir::ops::InputOp& op) override;
   void visit(mir::ops::LeakyReluOp& op) override;
+  void visit(mir::ops::MaxOp& op) override;
+  void visit(mir::ops::MulOp& op) override;
   void visit(mir::ops::OutputOp& op) override;
   void visit(mir::ops::PadOp& op) override;
   void visit(mir::ops::PoolOp& op) override;
@@ -67,6 +70,7 @@ public:
   void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqrtOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
+  void visit(mir::ops::SubOp& op) override;
   void visit(mir::ops::TanhOp& op) override;
   void visit(mir::ops::TransposeOp& op) override;
 
diff --git a/compiler/nnc/unittests/acl_backend/MIRToDOM.cpp b/compiler/nnc/unittests/acl_backend/MIRToDOM.cpp
index ea25f0a..46f329f 100644
@@ -34,7 +34,6 @@
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/Deconv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/EluOp.h"
 #include "mir/ops/FullyConnectedOp.h"
 #include "mir/ops/InputOp.h"
@@ -202,8 +201,6 @@ TensorVariant createTensorVariant(const Shape& shape) {
 
 }
 
-#include <fstream>
-
 // Actual tests
 
 TEST(acl_backend_mir_to_dom, constant) {
diff --git a/compiler/nnc/unittests/optimizations/CombineTransposes.cpp b/compiler/nnc/unittests/optimizations/CombineTransposes.cpp
index 727465b..9343213 100644
 #include "passes/optimizations/CombineTransposes.h"
 #include "mir/ops/TransposeOp.h"
 #include "mir/ops/ReluOp.h"
-#include "mir/ops/ElementwiseOp.h"
-#include "mir/ops/ConstantOp.h"
-#include "mir/ops/TanhOp.h"
-#include "mir/ops/ConcatOp.h"
 #include "mir/ops/OutputOp.h"
-#include "mir/ops/PoolOp.h"
 #include "Util.h"
 #include <gtest/gtest.h>
 
@@ -99,7 +94,7 @@ TEST(OptPass, combineTransposesBush) {
    *        //       \\
    *[Transpose 2] [Transpose 3]
    *       \\       //
-   *    [Elementwise<add>]
+   *          [Add]
    */
   Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3, 2});
   Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
@@ -108,16 +103,13 @@ TEST(OptPass, combineTransposesBush) {
                                               vector<size_t>{1, 0, 2, 3});
   Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0),
                                               vector<size_t>{1, 0, 2, 3});
-  Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
-                                                vector<Operation::Output*>{tr2->getOutput(0),
-                                                                           tr3->getOutput(0)},
-                                                ops::ElementwiseOp::OpType::add);
+  Operation *elw = g.create<ops::AddOp>("elewiseAdd", tr2->getOutput(0), tr3->getOutput(0));
   std::stringstream ss;
   DumpVisitor d(ss);
   CombineTransposes pass;
   pass.run(&g);
   g.accept(&d);
-  ASSERT_EQ("i_input.e_elewiseAdd.", ss.str());
+  ASSERT_EQ("i_input.b_elewiseAdd.", ss.str());
   ASSERT_EQ(elw->getInput(0)->getProducer()->getNode()->getName(), "input");
   ASSERT_EQ(elw->getInput(1)->getProducer()->getNode()->getName(), "input");
 }
@@ -131,7 +123,7 @@ TEST(OptPass, combineTransposesOpOrder) {
    *      ||          ||
    * [Transpose 2] [Transpose 3]
    *       \\       //
-   *    [Elementwise<add>]
+   *          [Add]
    */
   Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 2, 3});
   Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 2, 3});
@@ -139,10 +131,7 @@ TEST(OptPass, combineTransposesOpOrder) {
   Operation* tr1 = g.create<ops::TransposeOp>("tr1", in2->getOutput(0), vector<size_t>{2, 1, 0});
   Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr0->getOutput(0), vector<size_t>{1, 0, 2});
   Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{2, 1, 0});
-  Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
-                                                vector<Operation::Output*>{tr2->getOutput(0),
-                                                                           tr3->getOutput(0)},
-                                                ops::ElementwiseOp::OpType::add);
+  Operation *elw = g.create<ops::AddOp>("elewiseAdd", tr2->getOutput(0), tr3->getOutput(0));
   g.create<ops::OutputOp>("out", elw->getOutput(0));
   int n1 = elw->getInput(0)->getNode()->getInput(0)->getNode()->getInput(0)->getNode()->getId();
   int n2 = elw->getInput(1)->getNode()->getInput(0)->getNode()->getInput(0)->getNode()->getId();
diff --git a/compiler/nnc/unittests/optimizations/RemoveDeadEnds.cpp b/compiler/nnc/unittests/optimizations/RemoveDeadEnds.cpp
index 9d0f440..9f19dbe 100644
@@ -16,9 +16,7 @@
 
 
 #include "passes/optimizations/RemoveDeadEnds.h"
-#include "mir/ops/TransposeOp.h"
 #include "mir/ops/ReluOp.h"
-#include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/ConstantOp.h"
 
 #include <gtest/gtest.h>
diff --git a/compiler/nnc/unittests/optimizations/Util.h b/compiler/nnc/unittests/optimizations/Util.h
index da5c9a6..a14bf73 100644
@@ -18,7 +18,6 @@
 #define NNCC_UTIL_H
 #include "mir/ops/TransposeOp.h"
 #include "mir/ops/ReluOp.h"
-#include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/TanhOp.h"
 #include "mir/ops/ConcatOp.h"
@@ -53,8 +52,6 @@ public:
 
   void visit(mir::ops::ConstantOp& op) override { _s << "const_" << op.getName() << "."; }
 
-  void visit(mir::ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; }
-
   std::ostream& _s;
 };
 
diff --git a/compiler/nnc/unittests/soft_backend/CPPOperations.cpp b/compiler/nnc/unittests/soft_backend/CPPOperations.cpp
index e8c7224..a907332 100644
 #include "SBSerializer.h"
 
 // operations part
+#include "mir/ops/AddOp.h"
 #include "mir/ops/CappedReluOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/Deconv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/DivOp.h"
 #include "mir/ops/EluOp.h"
 #include "mir/ops/FullyConnectedOp.h"
 #include "mir/ops/InputOp.h"
 #include "mir/ops/LeakyReluOp.h"
+#include "mir/ops/MaxOp.h"
+#include "mir/ops/MulOp.h"
 #include "mir/ops/OutputOp.h"
 #include "mir/ops/PadOp.h"
 #include "mir/ops/PoolOp.h"
@@ -78,6 +81,7 @@
 #include "mir/ops/SliceOp.h"
 #include "mir/ops/SoftmaxOp.h"
 #include "mir/ops/SqrtOp.h"
+#include "mir/ops/SubOp.h"
 #include "mir/ops/TanhOp.h"
 #include "mir/ops/TransposeOp.h"
 
@@ -367,8 +371,10 @@ TEST(cpp_operations_test, concat) {
     }
 }
 
-TEST(cpp_operations_test, add2bc) {
-  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+TEST(cpp_operations_test, addbc)
+{
+  for (int num_dims = 2; num_dims <= 4; ++num_dims)
+  {
     // test prerequisites
     vector<int> shape_data1{3, 44, 5, 1};
     vector<int> shape_data2{3, 1, 5, 6};
@@ -378,70 +384,63 @@ TEST(cpp_operations_test, add2bc) {
     vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
     fillTensors(input_ntensors[0], input_atensors[0], shape_data1, 1.0f);
     fillTensors(input_ntensors[1], input_atensors[1], shape_data2, 2.0f);
-    auto op_generator = [](mir::Graph& g, const std::vector<mir::Operation::Output*>& inputs) {
-      return g.create<mir::ops::ElementwiseOp>("y", inputs, mir::ops::ElementwiseOp::OpType::add);
+    auto op_generator = [](mir::Graph &g, const std::vector<mir::Operation::Output *> &inputs) {
+      return g.create<mir::ops::AddOp>("y", inputs[0], inputs[1]);
     };
 
     createAndRunTestGraph(op_generator, ElementWise<Add, Tensor, Tensor>, input_ntensors,
-                          input_atensors[0],
-                          input_atensors[1]);
+                          input_atensors[0], input_atensors[1]);
   }
 }
 
-TEST(cpp_operations_test, mul3bc) {
-  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+TEST(cpp_operations_test, mulbc)
+{
+  for (int num_dims = 2; num_dims <= 4; ++num_dims)
+  {
     // test prerequisites
     vector<int> shape_data1{3, 22, 5, 1};
     vector<int> shape_data2{3, 1, 5, 6};
-    vector<int> shape_data3{1, 22, 1, 6};
     shape_data1.resize(num_dims);
     shape_data2.resize(num_dims);
-    shape_data3.resize(num_dims);
-    vector<Tensor> input_atensors(3);
-    vector<unique_ptr<mir::TensorVariant>> input_ntensors(3);
+    vector<Tensor> input_atensors(2);
+    vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
     fillTensors(input_ntensors[0], input_atensors[0], shape_data1, 1.0f);
     fillTensors(input_ntensors[1], input_atensors[1], shape_data2, 2.0f);
-    fillTensors(input_ntensors[2], input_atensors[2], shape_data3, 3.0f);
-    auto opGenerator = [](mir::Graph& g, const std::vector<mir::Operation::Output*>& inputs) {
-      return g.create<mir::ops::ElementwiseOp>("y", inputs, mir::ops::ElementwiseOp::OpType::mul);
+    auto opGenerator = [](mir::Graph &g, const std::vector<mir::Operation::Output *> &inputs) {
+      return g.create<mir::ops::MulOp>("y", inputs[0], inputs[1]);
     };
 
-    createAndRunTestGraph(opGenerator, ElementWise<Mul, Tensor, Tensor, Tensor>, input_ntensors,
-                          input_atensors[0],
-                          input_atensors[1], input_atensors[2]);
+    createAndRunTestGraph(opGenerator, ElementWise<Mul, Tensor, Tensor>, input_ntensors,
+                          input_atensors[0], input_atensors[1]);
   }
 }
 
-TEST(cpp_operations_test, div3bc) {
-  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+TEST(cpp_operations_test, divbc)
+{
+  for (int num_dims = 2; num_dims <= 4; ++num_dims)
+  {
     // test prerequisites
     vector<int> shape_data1{3, 22, 5, 1};
     vector<int> shape_data2{3, 1, 5, 6};
-    vector<int> shape_data3{1, 22, 1, 6};
     shape_data1.resize(num_dims);
     shape_data2.resize(num_dims);
-    shape_data3.resize(num_dims);
-    vector<Tensor> input_atensors(3);
-    vector<unique_ptr<mir::TensorVariant>> input_ntensors(3);
+    vector<Tensor> input_atensors(2);
+    vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
     fillTensors(input_ntensors[0], input_atensors[0], shape_data1, 5.0f);
     fillTensors(input_ntensors[1], input_atensors[1], shape_data2, 2.0f);
-    fillTensors(input_ntensors[2], input_atensors[2], shape_data3, 3.0f);
-    auto opGenerator = [](mir::Graph& g, const std::vector<mir::Operation::Output*>& inputs) {
-      return g.create<mir::ops::ElementwiseOp>("y", inputs, mir::ops::ElementwiseOp::OpType::div);
+    auto opGenerator = [](mir::Graph &g, const std::vector<mir::Operation::Output *> &inputs) {
+      return g.create<mir::ops::DivOp>("y", inputs[0], inputs[1]);
     };
 
-    createAndRunTestGraph(
-      opGenerator, ElementWise<Div, Tensor, Tensor, Tensor>,
-      input_ntensors,
-      input_atensors[0],
-      input_atensors[1],
-      input_atensors[2]
-    );
+    createAndRunTestGraph(opGenerator, ElementWise<Div, Tensor, Tensor>, input_ntensors,
+                          input_atensors[0], input_atensors[1]);
   }
 }
 
-TEST(cpp_operations_test, add2) {
-  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+TEST(cpp_operations_test, add)
+{
+  for (int num_dims = 2; num_dims <= 4; ++num_dims)
+  {
     // test prerequisites
     vector<int> shape_data{2, 3, 5, 7};
     shape_data.resize(num_dims);
@@ -449,85 +448,72 @@ TEST(cpp_operations_test, add2) {
     vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
     fillTensors(input_ntensors[0], input_atensors[0], shape_data, 1.0f);
     fillTensors(input_ntensors[1], input_atensors[1], shape_data, 2.0f);
-    auto op_generator = [](mir::Graph& g, const std::vector<mir::Operation::Output*>& inputs) {
-      return g.create<mir::ops::ElementwiseOp>("y", inputs, mir::ops::ElementwiseOp::OpType::add);
+    auto op_generator = [](mir::Graph &g, const std::vector<mir::Operation::Output *> &inputs) {
+      return g.create<mir::ops::AddOp>("y", inputs[0], inputs[1]);
     };
 
-    createAndRunTestGraph(op_generator,
-                          ElementWise<Add,Tensor,Tensor>,
-                          input_ntensors,
-                          input_atensors[0],
-                          input_atensors[1]);
+    createAndRunTestGraph(op_generator, ElementWise<Add, Tensor, Tensor>, input_ntensors,
+                          input_atensors[0], input_atensors[1]);
   }
 }
 
-TEST(cpp_operations_test, sub3) {
-  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+TEST(cpp_operations_test, sub)
+{
+  for (int num_dims = 2; num_dims <= 4; ++num_dims)
+  {
     // test prerequisites
     vector<int> shape_data{2, 3, 5, 7};
     shape_data.resize(num_dims);
-    vector<Tensor> input_atensors(3);
-    vector<unique_ptr<mir::TensorVariant>> input_n_tensors(3);
+    vector<Tensor> input_atensors(2);
+    vector<unique_ptr<mir::TensorVariant>> input_n_tensors(2);
     fillTensors(input_n_tensors[0], input_atensors[0], shape_data, 1.0f);
     fillTensors(input_n_tensors[1], input_atensors[1], shape_data, 2.0f);
-    fillTensors(input_n_tensors[2], input_atensors[2], shape_data, 3.0f);
-    auto opGenerator = [](mir::Graph& g, const std::vector<mir::Operation::Output*>& inputs) {
-      return g.create<mir::ops::ElementwiseOp>("y", inputs,
-                                               mir::ops::ElementwiseOp::OpType::sub);
+    auto opGenerator = [](mir::Graph &g, const std::vector<mir::Operation::Output *> &inputs) {
+      return g.create<mir::ops::SubOp>("y", inputs[0], inputs[1]);
     };
 
-    createAndRunTestGraph(opGenerator, ElementWise<Sub, Tensor, Tensor, Tensor>, input_n_tensors,
-                          input_atensors[0],
-                          input_atensors[1],
-                          input_atensors[2]);
+    createAndRunTestGraph(opGenerator, ElementWise<Sub, Tensor, Tensor>, input_n_tensors,
+                          input_atensors[0], input_atensors[1]);
   }
 }
 
-TEST(cpp_operations_test, mul3) {
-  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+TEST(cpp_operations_test, mul)
+{
+  for (int num_dims = 2; num_dims <= 4; ++num_dims)
+  {
     // test prerequisites
     vector<int> shape_data{2, 3, 5, 7};
     shape_data.resize(num_dims);
-    vector<Tensor> input_atensors(3);
-    vector<unique_ptr<mir::TensorVariant>> input_ntensors(3);
+    vector<Tensor> input_atensors(2);
+    vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
     fillTensors(input_ntensors[0], input_atensors[0], shape_data, 1.0f);
     fillTensors(input_ntensors[1], input_atensors[1], shape_data, 2.0f);
-    fillTensors(input_ntensors[2], input_atensors[2], shape_data, 3.0f);
-    auto op_generator = [](mir::Graph& g, const std::vector<mir::Operation::Output*>& inputs) {
-      return g.create<mir::ops::ElementwiseOp>("y", inputs, mir::ops::ElementwiseOp::OpType::mul);
+    auto op_generator = [](mir::Graph &g, const std::vector<mir::Operation::Output *> &inputs) {
+      return g.create<mir::ops::MulOp>("y", inputs[0], inputs[1]);
     };
 
-    createAndRunTestGraph(op_generator,
-                          ElementWise<Mul,Tensor,Tensor,Tensor>,
-                          input_ntensors,
-                          input_atensors[0],
-                          input_atensors[1],
-                          input_atensors[2]);
+    createAndRunTestGraph(op_generator, ElementWise<Mul, Tensor, Tensor>, input_ntensors,
+                          input_atensors[0], input_atensors[1]);
   }
 }
 
-TEST(cpp_operations_test, max4) {
-  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+TEST(cpp_operations_test, max)
+{
+  for (int num_dims = 2; num_dims <= 4; ++num_dims)
+  {
     // test prerequisites
     vector<int> shape_data{2, 3, 5, 7};
     shape_data.resize(num_dims);
-    vector<Tensor> input_atensors(4);
-    vector<unique_ptr<mir::TensorVariant>> input_ntensors(4);
+    vector<Tensor> input_atensors(2);
+    vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
     fillTensors(input_ntensors[0], input_atensors[0], shape_data, 1.0f);
     fillTensors(input_ntensors[1], input_atensors[1], shape_data, 2.0f);
-    fillTensors(input_ntensors[2], input_atensors[2], shape_data, 3.0f);
-    fillTensors(input_ntensors[3], input_atensors[3], shape_data, 3.0f);
-    auto op_generator = [](mir::Graph& g, const std::vector<mir::Operation::Output*>& inputs) {
-      return g.create<mir::ops::ElementwiseOp>("y", inputs, mir::ops::ElementwiseOp::OpType::max);
+    auto op_generator = [](mir::Graph &g, const std::vector<mir::Operation::Output *> &inputs) {
+      return g.create<mir::ops::MaxOp>("y", inputs[0], inputs[1]);
     };
 
-    createAndRunTestGraph(op_generator,
-                          ElementWise<Max,Tensor,Tensor,Tensor,Tensor>,
-                          input_ntensors,
-                          input_atensors[0],
-                          input_atensors[1],
-                          input_atensors[2],
-                          input_atensors[3]);
+    createAndRunTestGraph(op_generator, ElementWise<Max, Tensor, Tensor>, input_ntensors,
+                          input_atensors[0], input_atensors[1]);
   }
 }