[mir] Add binary elementwise operations (#6351)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Wed, 7 Aug 2019 12:04:29 +0000 (15:04 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 7 Aug 2019 12:04:29 +0000 (15:04 +0300)
first step of future replacement of single `ElementwiseOp` class with set of individual operations.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir/CMakeLists.txt
compiler/mir/include/mir/OpDefs.h
compiler/mir/include/mir/ops/AddOp.h [new file with mode: 0644]
compiler/mir/include/mir/ops/BinaryElementwiseOp.h [new file with mode: 0644]
compiler/mir/include/mir/ops/DivOp.h [new file with mode: 0644]
compiler/mir/include/mir/ops/MaxOp.h [new file with mode: 0644]
compiler/mir/include/mir/ops/MulOp.h [new file with mode: 0644]
compiler/mir/include/mir/ops/SubOp.h [new file with mode: 0644]
compiler/mir/include/mir/ops/operations.lst.h
compiler/mir/src/Operation.cpp
compiler/mir/src/ops/BinaryElementwiseOp.cpp [new file with mode: 0644]

index 9fa8ef0..2e04cf3 100644 (file)
@@ -1,4 +1,5 @@
 set(MIR_SOURCES
+    src/ops/BinaryElementwiseOp.cpp
     src/ops/ConcatOp.cpp
     src/ops/Conv2DOp.cpp
     src/ops/DeConv2DOp.cpp
index 92b5b3b..3a5034d 100644 (file)
@@ -17,6 +17,7 @@
 #ifndef _MIR_OPDEFS_H_
 #define _MIR_OPDEFS_H_
 
+#include "mir/ops/AddOp.h"
 #include "mir/ops/BatchNormOp.h"
 #include "mir/ops/BiasAddOp.h"
 #include "mir/ops/CappedReluOp.h"
@@ -26,6 +27,7 @@
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/Deconv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/DivOp.h"
 #include "mir/ops/DropoutOp.h"
 #include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/EluOp.h"
@@ -34,6 +36,8 @@
 #include "mir/ops/GemmOp.h"
 #include "mir/ops/InputOp.h"
 #include "mir/ops/LeakyReluOp.h"
+#include "mir/ops/MaxOp.h"
+#include "mir/ops/MulOp.h"
 #include "mir/ops/OutputOp.h"
 #include "mir/ops/PadOp.h"
 #include "mir/ops/PoolOp.h"
@@ -47,6 +51,7 @@
 #include "mir/ops/SoftmaxOp.h"
 #include "mir/ops/SqrtOp.h"
 #include "mir/ops/SqueezeOp.h"
+#include "mir/ops/SubOp.h"
 #include "mir/ops/TanhOp.h"
 #include "mir/ops/TransposeOp.h"
 
diff --git a/compiler/mir/include/mir/ops/AddOp.h b/compiler/mir/include/mir/ops/AddOp.h
new file mode 100644 (file)
index 0000000..962cd48
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _MIR_OPS_ADD_OP_H_
+#define _MIR_OPS_ADD_OP_H_
+
+#include "mir/ops/BinaryElementwiseOp.h"
+
+namespace mir
+{
+namespace ops
+{
+
+class AddOp : public BinaryElementwiseOp
+{
+public:
+  AddOp(Output *arg1, Output *arg2) : BinaryElementwiseOp(Type::add, arg1, arg2) {}
+
+  Operation *copyWithInputs(const std::vector<Output *> &inputs) override
+  {
+    return new AddOp(inputs[0], inputs[1]);
+  }
+};
+
+} // namespace ops
+} // namespace mir
+
+#endif //_MIR_OPS_ADD_OP_H_
diff --git a/compiler/mir/include/mir/ops/BinaryElementwiseOp.h b/compiler/mir/include/mir/ops/BinaryElementwiseOp.h
new file mode 100644 (file)
index 0000000..51f73b0
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _MIR_OPS_BINARY_ELEMENTWISE_OP_H_
+#define _MIR_OPS_BINARY_ELEMENTWISE_OP_H_
+
+#include "mir/Operation.h"
+
+namespace mir
+{
+namespace ops
+{
+
+class BinaryElementwiseOp : public Operation
+{
+protected:
+  BinaryElementwiseOp(Type type, Output *lhs, Output *rhs)
+      : Operation(type, {lhs, rhs}), _needs_broadcast(false)
+  {
+    inferOutputShapes();
+  };
+
+public:
+  bool getBroadcast() const { return _needs_broadcast; }
+
+private:
+  void inferOutputShapes();
+
+  bool _needs_broadcast;
+};
+
+} // namespace ops
+} // namespace mir
+
+#endif //_MIR_OPS_BINARY_ELEMENTWISE_OP_H_
diff --git a/compiler/mir/include/mir/ops/DivOp.h b/compiler/mir/include/mir/ops/DivOp.h
new file mode 100644 (file)
index 0000000..349e75b
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _MIR_OPS_DIV_OP_H_
+#define _MIR_OPS_DIV_OP_H_
+
+#include "mir/ops/BinaryElementwiseOp.h"
+
+namespace mir
+{
+namespace ops
+{
+
+class DivOp : public BinaryElementwiseOp
+{
+public:
+  DivOp(Output *arg1, Output *arg2) : BinaryElementwiseOp(Type::div, arg1, arg2) {}
+
+  Operation *copyWithInputs(const std::vector<Output *> &inputs) override
+  {
+    return new DivOp(inputs[0], inputs[1]);
+  }
+};
+
+} // namespace ops
+} // namespace mir
+
+#endif //_MIR_OPS_DIV_OP_H_
diff --git a/compiler/mir/include/mir/ops/MaxOp.h b/compiler/mir/include/mir/ops/MaxOp.h
new file mode 100644 (file)
index 0000000..ca2d91a
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _MIR_OPS_MAX_OP_H_
+#define _MIR_OPS_MAX_OP_H_
+
+#include "mir/ops/BinaryElementwiseOp.h"
+
+namespace mir
+{
+namespace ops
+{
+
+class MaxOp : public BinaryElementwiseOp
+{
+public:
+  MaxOp(Output *arg1, Output *arg2) : BinaryElementwiseOp(Type::max, arg1, arg2) {}
+
+  Operation *copyWithInputs(const std::vector<Output *> &inputs) override
+  {
+    return new MaxOp(inputs[0], inputs[1]);
+  }
+};
+
+} // namespace ops
+} // namespace mir
+
+#endif //_MIR_OPS_MAX_OP_H_
diff --git a/compiler/mir/include/mir/ops/MulOp.h b/compiler/mir/include/mir/ops/MulOp.h
new file mode 100644 (file)
index 0000000..c76e307
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _MIR_OPS_MUL_OP_H_
+#define _MIR_OPS_MUL_OP_H_
+
+#include "mir/ops/BinaryElementwiseOp.h"
+
+namespace mir
+{
+namespace ops
+{
+
+class MulOp : public BinaryElementwiseOp
+{
+public:
+  MulOp(Output *arg1, Output *arg2) : BinaryElementwiseOp(Type::mul, arg1, arg2) {}
+
+  Operation *copyWithInputs(const std::vector<Output *> &inputs) override
+  {
+    return new MulOp(inputs[0], inputs[1]);
+  }
+};
+
+} // namespace ops
+} // namespace mir
+
+#endif //_MIR_OPS_MUL_OP_H_
diff --git a/compiler/mir/include/mir/ops/SubOp.h b/compiler/mir/include/mir/ops/SubOp.h
new file mode 100644 (file)
index 0000000..519b238
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _MIR_OPS_SUB_OP_H_
+#define _MIR_OPS_SUB_OP_H_
+
+#include "mir/ops/BinaryElementwiseOp.h"
+
+namespace mir
+{
+namespace ops
+{
+
+class SubOp : public BinaryElementwiseOp
+{
+public:
+  SubOp(Output *arg1, Output *arg2) : BinaryElementwiseOp(Type::sub, arg1, arg2) {}
+
+  Operation *copyWithInputs(const std::vector<Output *> &inputs) override
+  {
+    return new SubOp(inputs[0], inputs[1]);
+  }
+};
+
+} // namespace ops
+} // namespace mir
+
+#endif //_MIR_OPS_SUB_OP_H_
index 7e853d5..791d807 100644 (file)
@@ -18,6 +18,7 @@
 #error "You should define HANDLE_OP before including this file"
 #endif // HANDLE_OP
 
+HANDLE_OP(add, AddOp)
 HANDLE_OP(batchNorm, BatchNormOp)
 HANDLE_OP(biasAdd, BiasAddOp)
 HANDLE_OP(cappedReLU, CappedReluOp)
@@ -26,6 +27,7 @@ HANDLE_OP(constant, ConstantOp)
 HANDLE_OP(conv2D, Conv2DOp)
 HANDLE_OP(deConv2D, DeConv2DOp)
 HANDLE_OP(depthwiseConv, DepthwiseConv2DOp)
+HANDLE_OP(div, DivOp)
 HANDLE_OP(dropout, DropoutOp)
 HANDLE_OP(elementwise, ElementwiseOp)
 HANDLE_OP(ELU, EluOp)
@@ -34,6 +36,8 @@ HANDLE_OP(gather, GatherOp)
 HANDLE_OP(gemmOp, GemmOp)
 HANDLE_OP(input, InputOp)
 HANDLE_OP(leakyReLU, LeakyReluOp)
+HANDLE_OP(max, MaxOp)
+HANDLE_OP(mul, MulOp)
 HANDLE_OP(output, OutputOp)
 HANDLE_OP(pad, PadOp)
 HANDLE_OP(pool, PoolOp)
@@ -47,5 +51,6 @@ HANDLE_OP(slice, SliceOp)
 HANDLE_OP(softmax, SoftmaxOp)
 HANDLE_OP(sqrt, SqrtOp)
 HANDLE_OP(squeeze, SqueezeOp)
+HANDLE_OP(sub, SubOp)
 HANDLE_OP(tanh, TanhOp)
 HANDLE_OP(transpose, TransposeOp)
index 9fac454..dbdd37a 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include "mir/Operation.h"
+#include "mir/ops/AddOp.h"
 #include "mir/ops/BatchNormOp.h"
 #include "mir/ops/BiasAddOp.h"
 #include "mir/ops/CappedReluOp.h"
@@ -23,6 +24,7 @@
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/Deconv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/DivOp.h"
 #include "mir/ops/DropoutOp.h"
 #include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/EluOp.h"
@@ -31,6 +33,8 @@
 #include "mir/ops/GemmOp.h"
 #include "mir/ops/InputOp.h"
 #include "mir/ops/LeakyReluOp.h"
+#include "mir/ops/MaxOp.h"
+#include "mir/ops/MulOp.h"
 #include "mir/ops/OutputOp.h"
 #include "mir/ops/PadOp.h"
 #include "mir/ops/PoolOp.h"
@@ -44,6 +48,7 @@
 #include "mir/ops/SoftmaxOp.h"
 #include "mir/ops/SqueezeOp.h"
 #include "mir/ops/SqrtOp.h"
+#include "mir/ops/SubOp.h"
 #include "mir/ops/TanhOp.h"
 #include "mir/ops/TransposeOp.h"
 
diff --git a/compiler/mir/src/ops/BinaryElementwiseOp.cpp b/compiler/mir/src/ops/BinaryElementwiseOp.cpp
new file mode 100644 (file)
index 0000000..06df8cd
--- /dev/null
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mir/ops/BinaryElementwiseOp.h"
+
+#include <cassert>
+
+namespace mir
+{
+namespace ops
+{
+
+void BinaryElementwiseOp::inferOutputShapes()
+{
+  int max_rank = getInputShape(0).rank();
+  size_t max_ind = 0;
+  for (size_t i = 0; i < getNumInputs(); i++)
+  {
+    if (max_rank < getInputShape(i).rank())
+    {
+      max_rank = getInputShape(i).rank();
+      max_ind = i;
+    }
+  }
+  Shape max_shape = getInputShape(max_ind);
+  for (size_t i = 0; i < getNumInputs(); i++)
+  {
+    const auto &current_shape = getInputShape(i);
+    _needs_broadcast = _needs_broadcast || max_shape != current_shape; // check not equal
+    const int rank = current_shape.rank();
+    for (int axis = 0; axis < rank; axis++)
+    {
+      auto current_dim = current_shape.dim(rank - axis - 1);
+      // get max for all axes
+      if (max_shape.dim(max_rank - axis - 1) == 1 && current_dim != 1)
+      {
+        max_shape.dim(max_rank - axis - 1) = current_dim;
+      }
+      else
+      {
+        assert((current_dim == 1 || current_dim == max_shape.dim(max_rank - axis - 1)) &&
+               "Incompatible shapes in broadcast!");
+      }
+    }
+  }
+  setOutputShape(0, max_shape);
+}
+
+} // namespace ops
+} // namespace mir
\ No newline at end of file