Refine [mir] operations constructors in order to simplify interface. (#8044)
authorGusev Dmitry/Engineer/AI Tools Lab /SRR/Samsung Electronics <d.gusev@partner.samsung.com>
Mon, 14 Oct 2019 18:07:38 +0000 (21:07 +0300)
committerAlexander Efimov/./AI Tools Lab/Samsung Electronics <a.efimov@samsung.com>
Mon, 14 Oct 2019 18:07:38 +0000 (21:07 +0300)
Operations attributes are united into structures.

Signed-off-by: Dmitry Gusev <d.gusev@partner.samsung.com>
13 files changed:
compiler/mir/include/mir/Attributes.h [new file with mode: 0644]
compiler/mir/include/mir/ops/AvgPool2DOp.h
compiler/mir/include/mir/ops/Conv2DOp.h
compiler/mir/include/mir/ops/Deconv2DOp.h
compiler/mir/include/mir/ops/DepthwiseConv2DOp.h
compiler/mir/include/mir/ops/MaxPool2DOp.h
compiler/mir/include/mir/ops/PadOp.h
compiler/mir/src/ops/AvgPool2DOp.cpp
compiler/mir/src/ops/Conv2DOp.cpp
compiler/mir/src/ops/DeConv2DOp.cpp
compiler/mir/src/ops/DepthwiseConv2DOp.cpp
compiler/mir/src/ops/MaxPool2DOp.cpp
compiler/mir/src/ops/PadOp.cpp

diff --git a/compiler/mir/include/mir/Attributes.h b/compiler/mir/include/mir/Attributes.h
new file mode 100644 (file)
index 0000000..c065ad9
--- /dev/null
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef OP_ATTRIBUTES_H
+#define OP_ATTRIBUTES_H
+
+#include <vector>
+#include "mir/DataFormat.h"
+#include "mir/ops/PaddingType.h"
+
+namespace mir
+{
+
+struct Conv2DOpAttributes
+{
+  Conv2DOpAttributes() = default;
+
+  std::vector<std::int32_t> strides{1, 1};
+  std::vector<std::int32_t> padding_before{0, 0};
+  std::vector<std::int32_t> padding_after{0, 0};
+  DataFormat data_format{DataFormat::NHWC};
+};
+
+struct AvgPool2DOpAttributes
+{
+  AvgPool2DOpAttributes() = default;
+
+  std::vector<std::int32_t> window{1, 1};
+  std::vector<std::int32_t> strides{1, 1};
+  std::vector<std::int32_t> padding_before{0, 0};
+  std::vector<std::int32_t> padding_after{0, 0};
+  DataFormat data_format{DataFormat::NHWC};
+  bool include_pad{true};
+};
+
+struct MaxPool2DOpAttributes
+{
+  MaxPool2DOpAttributes() = default;
+
+  std::vector<std::int32_t> window{1, 1};
+  std::vector<std::int32_t> strides{1, 1};
+  std::vector<std::int32_t> padding_before{0, 0};
+  std::vector<std::int32_t> padding_after{0, 0};
+  DataFormat data_format{DataFormat::NHWC};
+};
+
+struct Deconv2DOpAttributes
+{
+  Deconv2DOpAttributes() = default;
+
+  std::vector<std::int32_t> strides{1, 1};
+  std::vector<std::int32_t> padding_before{0, 0};
+  std::vector<std::int32_t> padding_after{0, 0};
+  DataFormat data_format{DataFormat::NHWC};
+  ops::PaddingType padding_type{ops::PaddingType::Explicit};
+};
+
+struct PadOpAttributes
+{
+  PadOpAttributes() : padding_value(0.0) {}
+  PadOpAttributes(unsigned dims) : padding_before(dims), padding_after(dims), padding_value(0.0) {}
+
+  std::vector<std::int32_t> padding_before;
+  std::vector<std::int32_t> padding_after;
+  float padding_value;
+};
+}
+
+#endif
\ No newline at end of file
index c53cd79..d8b637a 100644 (file)
@@ -18,7 +18,7 @@
 #define _MIR_OPS_AVG_POOL_OP_H_
 
 #include "mir/Operation.h"
-#include "mir/DataFormat.h"
+#include "mir/Attributes.h"
 
 #include <cstdint>
 #include <vector>
@@ -36,40 +36,47 @@ public:
               const std::vector<std::int32_t> &padding_before,
               const std::vector<std::int32_t> &padding_after, bool include_pad,
               DataFormat data_format)
-      : Operation(Type::avgPool2D, {arg}), _window_size(window_size), _strides(strides),
-        _padding_before(padding_before), _padding_after(padding_after), _include_pad(include_pad),
-        _data_format(data_format)
+      : Operation(Type::avgPool2D, {arg})
+  {
+    _attributes.window = window_size;
+    _attributes.strides = strides;
+    _attributes.padding_before = padding_before;
+    _attributes.padding_after = padding_after;
+    _attributes.include_pad = include_pad;
+    _attributes.data_format = data_format;
+
+    inferOutputShapes();
+  }
+
+  AvgPool2DOp(Output *arg, const AvgPool2DOpAttributes &attributes)
+      : Operation(Type::avgPool2D, {arg}), _attributes(attributes)
   {
     inferOutputShapes();
   }
 
   Operation *copyWithInputs(const std::vector<Output *> &inputs) override
   {
-    return new AvgPool2DOp(inputs[0], _window_size, _strides, _padding_before, _padding_after,
-                           _include_pad, _data_format);
+    return new AvgPool2DOp(inputs[0], _attributes);
   };
 
-  const std::vector<std::int32_t> &getWindowSize() const { return _window_size; }
+  const std::vector<std::int32_t> &getWindowSize() const { return _attributes.window; }
+
+  const std::vector<std::int32_t> &getStrides() const { return _attributes.strides; }
 
-  const std::vector<std::int32_t> &getStrides() const { return _strides; }
+  const std::vector<std::int32_t> &getPaddingBefore() const { return _attributes.padding_before; }
 
-  const std::vector<std::int32_t> &getPaddingBefore() const { return _padding_before; }
+  const std::vector<std::int32_t> &getPaddingAfter() const { return _attributes.padding_after; }
 
-  const std::vector<std::int32_t> &getPaddingAfter() const { return _padding_after; }
+  bool getIncludePad() const { return _attributes.include_pad; }
 
-  bool getIncludePad() const { return _include_pad; }
+  DataFormat getDataFormat() const { return _attributes.data_format; }
 
-  DataFormat getDataFormat() const { return _data_format; }
+  const AvgPool2DOpAttributes &getAttributes() const { return _attributes; }
 
 private:
   void inferOutputShapes();
 
-  std::vector<std::int32_t> _window_size;
-  std::vector<std::int32_t> _strides;
-  std::vector<std::int32_t> _padding_before;
-  std::vector<std::int32_t> _padding_after;
-  bool _include_pad;
-  DataFormat _data_format;
+  AvgPool2DOpAttributes _attributes;
 };
 
 } // namespace ops
index 602905b..dd74e13 100644 (file)
@@ -18,7 +18,7 @@
 #define _MIR_OPS_CONV_2D_OP_H_
 
 #include "mir/Operation.h"
-#include "mir/DataFormat.h"
+#include "mir/Attributes.h"
 #include <vector>
 
 namespace mir
@@ -32,33 +32,41 @@ public:
   Conv2DOp(Output *input, Output *kernel, const std::vector<std::int32_t> &strides,
            const std::vector<std::int32_t> &padding_before,
            const std::vector<std::int32_t> &padding_after, DataFormat data_format)
-      : Operation(Type::conv2D, {input, kernel}), _strides(strides),
-        _padding_before(padding_before), _padding_after(padding_after), _data_format(data_format)
+      : Operation(Type::conv2D, {input, kernel})
+  {
+    _attributes.strides = strides;
+    _attributes.padding_before = padding_before;
+    _attributes.padding_after = padding_after;
+    _attributes.data_format = data_format;
+
+    inferOutputShapes();
+  }
+
+  Conv2DOp(Output *input, Output *kernel, const Conv2DOpAttributes &attributes)
+      : Operation(Type::conv2D, {input, kernel}), _attributes(attributes)
   {
     inferOutputShapes();
   }
 
   Operation *copyWithInputs(const std::vector<Output *> &inputs) override
   {
-    return new Conv2DOp(inputs[0], inputs[1], _strides, _padding_before, _padding_after,
-                        _data_format);
+    return new Conv2DOp(inputs[0], inputs[1], _attributes);
   };
 
-  const std::vector<std::int32_t> &getStrides() const { return _strides; }
+  const std::vector<std::int32_t> &getStrides() const { return _attributes.strides; }
+
+  const std::vector<std::int32_t> &getPaddingBefore() const { return _attributes.padding_before; }
 
-  const std::vector<std::int32_t> &getPaddingBefore() const { return _padding_before; }
+  const std::vector<std::int32_t> &getPaddingAfter() const { return _attributes.padding_after; }
 
-  const std::vector<std::int32_t> &getPaddingAfter() const { return _padding_after; }
+  const Conv2DOpAttributes &getAttributes() const { return _attributes; }
 
-  DataFormat getDataFormat() const { return _data_format; }
+  DataFormat getDataFormat() const { return _attributes.data_format; }
 
 private:
   void inferOutputShapes();
 
-  std::vector<std::int32_t> _strides;
-  std::vector<std::int32_t> _padding_before;
-  std::vector<std::int32_t> _padding_after;
-  DataFormat _data_format;
+  Conv2DOpAttributes _attributes;
 };
 
 } // namespace ops
index 68b1797..0748740 100644 (file)
@@ -18,7 +18,7 @@
 #define _MIR_OPS_DECONV_2D_OP_H_
 
 #include "mir/Operation.h"
-#include "mir/DataFormat.h"
+#include "mir/Attributes.h"
 #include "mir/ops/PaddingType.h"
 
 #include <cstdint>
@@ -35,17 +35,38 @@ public:
   DeConv2DOp(Output *input, Output *kernel, const std::vector<std::int32_t> &strides,
              const std::vector<std::int32_t> &padding_before,
              const std::vector<std::int32_t> &padding_after, DataFormat data_format)
-      : Operation(Type::deConv2D, {input, kernel}), _strides(strides),
-        _padding_type(PaddingType::Explicit), _padding_before(padding_before),
-        _padding_after(padding_after), _data_format(data_format)
+      : Operation(Type::deConv2D, {input, kernel})
   {
+    _attributes.strides = strides;
+    _attributes.padding_type = PaddingType::Explicit;
+    _attributes.padding_before = padding_before;
+    _attributes.padding_after = padding_after;
+    _attributes.data_format = data_format;
+
     inferOutputShapes();
   }
 
   DeConv2DOp(Output *input, Output *kernel, const std::vector<std::int32_t> &strides,
              PaddingType padding_type, const Shape &output_shape, DataFormat data_format)
-      : Operation(Type::deConv2D, {input, kernel}), _strides(strides), _padding_type(padding_type),
-        _padding_before(2), _padding_after(2), _data_format(data_format)
+      : Operation(Type::deConv2D, {input, kernel})
+  {
+    _attributes.strides = strides;
+    _attributes.padding_type = padding_type;
+    _attributes.data_format = data_format;
+
+    setOutputShape(0, output_shape);
+    inferPaddings();
+  }
+
+  DeConv2DOp(Output *input, Output *kernel, const Deconv2DOpAttributes &attributes)
+      : Operation(Type::deConv2D, {input, kernel}), _attributes(attributes)
+  {
+    inferOutputShapes();
+  }
+
+  DeConv2DOp(Output *input, Output *kernel, const Deconv2DOpAttributes &attributes,
+             const Shape &output_shape)
+      : Operation(Type::deConv2D, {input, kernel}), _attributes(attributes)
   {
     setOutputShape(0, output_shape);
     inferPaddings();
@@ -54,22 +75,22 @@ public:
   Operation *copyWithInputs(const std::vector<Output *> &inputs) override
   {
     if (getPaddingType() == PaddingType::Explicit)
-      return new DeConv2DOp(inputs[0], inputs[1], getStrides(), getPaddingBefore(),
-                            getPaddingAfter(), getDataFormat());
+      return new DeConv2DOp(inputs[0], inputs[1], _attributes);
     else
-      return new DeConv2DOp(inputs[0], inputs[1], getStrides(), getPaddingType(), getOutputShape(0),
-                            getDataFormat());
+      return new DeConv2DOp(inputs[0], inputs[1], _attributes, getOutputShape(0));
   }
 
-  const std::vector<std::int32_t> &getStrides() const { return _strides; }
+  const std::vector<std::int32_t> &getStrides() const { return _attributes.strides; }
+
+  PaddingType getPaddingType() const { return _attributes.padding_type; }
 
-  PaddingType getPaddingType() const { return _padding_type; }
+  const std::vector<std::int32_t> &getPaddingBefore() const { return _attributes.padding_before; }
 
-  const std::vector<std::int32_t> &getPaddingBefore() const { return _padding_before; }
+  const std::vector<std::int32_t> &getPaddingAfter() const { return _attributes.padding_after; }
 
-  const std::vector<std::int32_t> &getPaddingAfter() const { return _padding_after; }
+  DataFormat getDataFormat() const { return _attributes.data_format; }
 
-  DataFormat getDataFormat() const { return _data_format; }
+  const Deconv2DOpAttributes &getAttributes() const { return _attributes; }
 
 private:
   void inferOutputShapes();
@@ -79,11 +100,7 @@ private:
    */
   void inferPaddings();
 
-  std::vector<std::int32_t> _strides;
-  PaddingType _padding_type;
-  std::vector<std::int32_t> _padding_before;
-  std::vector<std::int32_t> _padding_after;
-  DataFormat _data_format;
+  Deconv2DOpAttributes _attributes;
 };
 
 } // namespace ops
index 3d31f69..b9686a4 100644 (file)
@@ -18,7 +18,7 @@
 #define _MIR_OPS_DEPTHWISE_CONV_2D_OP_H_
 
 #include "mir/Operation.h"
-#include "mir/DataFormat.h"
+#include "mir/Attributes.h"
 #include <vector>
 
 namespace mir
@@ -32,33 +32,41 @@ public:
   DepthwiseConv2DOp(Output *input, Output *kernel, const std::vector<std::int32_t> &strides,
                     const std::vector<std::int32_t> &padding_before,
                     const std::vector<std::int32_t> &padding_after, DataFormat data_format)
-      : Operation(Type::depthwiseConv, {input, kernel}), _strides(strides),
-        _padding_before(padding_before), _padding_after(padding_after), _data_format(data_format)
+      : Operation(Type::depthwiseConv, {input, kernel})
+  {
+    _attributes.strides = strides;
+    _attributes.padding_before = padding_before;
+    _attributes.padding_after = padding_after;
+    _attributes.data_format = data_format;
+
+    inferOutputShapes();
+  }
+
+  DepthwiseConv2DOp(Output *input, Output *kernel, const Conv2DOpAttributes &attributes)
+      : Operation(Type::depthwiseConv, {input, kernel}), _attributes(attributes)
   {
     inferOutputShapes();
   }
 
   Operation *copyWithInputs(const std::vector<Output *> &inputs) override
   {
-    return new DepthwiseConv2DOp(inputs[0], inputs[1], _strides, _padding_before, _padding_after,
-                                 _data_format);
+    return new DepthwiseConv2DOp(inputs[0], inputs[1], _attributes);
   }
 
-  const std::vector<std::int32_t> &getStrides() const { return _strides; }
+  const std::vector<std::int32_t> &getStrides() const { return _attributes.strides; }
+
+  const std::vector<std::int32_t> &getPaddingBefore() const { return _attributes.padding_before; }
 
-  const std::vector<std::int32_t> &getPaddingBefore() const { return _padding_before; }
+  const std::vector<std::int32_t> &getPaddingAfter() const { return _attributes.padding_after; }
 
-  const std::vector<std::int32_t> &getPaddingAfter() const { return _padding_after; }
+  DataFormat getDataFormat() const { return _attributes.data_format; }
 
-  DataFormat getDataFormat() const { return _data_format; }
+  const Conv2DOpAttributes &getAttributes() const { return _attributes; }
 
 private:
   void inferOutputShapes();
 
-  std::vector<std::int32_t> _strides;
-  std::vector<std::int32_t> _padding_before;
-  std::vector<std::int32_t> _padding_after;
-  DataFormat _data_format;
+  mir::Conv2DOpAttributes _attributes;
 };
 
 } // namespace ops
index 2a9d55c..c774f34 100644 (file)
@@ -18,7 +18,7 @@
 #define _MIR_OPS_MAX_POOL_OP_H_
 
 #include "mir/Operation.h"
-#include "mir/DataFormat.h"
+#include "mir/Attributes.h"
 
 #include <cstdint>
 #include <vector>
@@ -35,36 +35,44 @@ public:
               const std::vector<std::int32_t> &strides,
               const std::vector<std::int32_t> &padding_before,
               const std::vector<std::int32_t> &padding_after, DataFormat data_format)
-      : Operation(Type::maxPool2D, {arg}), _window_size(window_size), _strides(strides),
-        _padding_before(padding_before), _padding_after(padding_after), _data_format(data_format)
+      : Operation(Type::maxPool2D, {arg})
+  {
+    _attributes.window = window_size;
+    _attributes.strides = strides;
+    _attributes.padding_before = padding_before;
+    _attributes.padding_after = padding_after;
+    _attributes.data_format = data_format;
+
+    inferOutputShapes();
+  }
+
+  MaxPool2DOp(Output *arg, const MaxPool2DOpAttributes &attributes)
+      : Operation(Type::maxPool2D, {arg}), _attributes(attributes)
   {
     inferOutputShapes();
   }
 
   Operation *copyWithInputs(const std::vector<Output *> &inputs) override
   {
-    return new MaxPool2DOp(inputs[0], _window_size, _strides, _padding_before, _padding_after,
-                           _data_format);
+    return new MaxPool2DOp(inputs[0], _attributes);
   };
 
-  const std::vector<std::int32_t> &getWindowSize() const { return _window_size; }
+  const std::vector<std::int32_t> &getWindowSize() const { return _attributes.window; }
+
+  const std::vector<std::int32_t> &getStrides() const { return _attributes.strides; }
 
-  const std::vector<std::int32_t> &getStrides() const { return _strides; }
+  const std::vector<std::int32_t> &getPaddingBefore() const { return _attributes.padding_before; }
 
-  const std::vector<std::int32_t> &getPaddingBefore() const { return _padding_before; }
+  const std::vector<std::int32_t> &getPaddingAfter() const { return _attributes.padding_after; }
 
-  const std::vector<std::int32_t> &getPaddingAfter() const { return _padding_after; }
+  DataFormat getDataFormat() const { return _attributes.data_format; }
 
-  DataFormat getDataFormat() const { return _data_format; }
+  const MaxPool2DOpAttributes &getAttributes() const { return _attributes; }
 
 private:
   void inferOutputShapes();
 
-  std::vector<std::int32_t> _window_size;
-  std::vector<std::int32_t> _strides;
-  std::vector<std::int32_t> _padding_before;
-  std::vector<std::int32_t> _padding_after;
-  DataFormat _data_format;
+  MaxPool2DOpAttributes _attributes;
 };
 
 } // namespace ops
index 1b109e5..ca21574 100644 (file)
@@ -18,6 +18,7 @@
 #define _MIR_OPS_PAD_OP_H_
 
 #include "mir/Operation.h"
+#include "mir/Attributes.h"
 
 namespace mir
 {
@@ -33,30 +34,38 @@ public:
   /// @param padding_value The value to be used for padding.
   PadOp(Output *arg, const std::vector<std::int32_t> &padding_before,
         const std::vector<std::int32_t> &padding_after, float padding_value)
-      : Operation(Type::pad, {arg}), _padding_before(padding_before), _padding_after(padding_after),
-        _padding_value(padding_value)
+      : Operation(Type::pad, {arg})
   {
-    assert(_padding_before.size() == _padding_after.size());
+    _attributes.padding_before = padding_before;
+    _attributes.padding_after = padding_after;
+    _attributes.padding_value = padding_value;
+
+    assert(_attributes.padding_before.size() == _attributes.padding_after.size());
+    inferOutputShapes();
+  }
+
+  PadOp(Output *arg, const PadOpAttributes &attributes)
+      : Operation(Type::pad, {arg}), _attributes(attributes)
+  {
+    assert(_attributes.padding_before.size() == _attributes.padding_after.size());
     inferOutputShapes();
   }
 
   Operation *copyWithInputs(const std::vector<Output *> &inputs) override
   {
-    return new PadOp(inputs[0], _padding_before, _padding_after, _padding_value);
+    return new PadOp(inputs[0], _attributes);
   }
 
-  const std::vector<std::int32_t> &getPaddingBefore() const { return _padding_before; }
+  const std::vector<std::int32_t> &getPaddingBefore() const { return _attributes.padding_before; }
 
-  const std::vector<std::int32_t> &getPaddingAfter() const { return _padding_after; }
+  const std::vector<std::int32_t> &getPaddingAfter() const { return _attributes.padding_after; }
 
-  float getPaddingValue() const { return _padding_value; }
+  float getPaddingValue() const { return _attributes.padding_value; }
 
 private:
   void inferOutputShapes();
 
-  std::vector<std::int32_t> _padding_before;
-  std::vector<std::int32_t> _padding_after;
-  float _padding_value;
+  PadOpAttributes _attributes;
 };
 
 } // namespace ops
index 5a16e3a..2da6613 100644 (file)
@@ -24,16 +24,16 @@ namespace ops
 void AvgPool2DOp::inferOutputShapes()
 {
   const auto &input_shape = getInputShape(0);
-  const int batch_dim_index = getDataBatchDimIndex(_data_format);
-  const int channel_dim_index = getDataChannelDimIndex(_data_format);
+  const int batch_dim_index = getDataBatchDimIndex(_attributes.data_format);
+  const int channel_dim_index = getDataChannelDimIndex(_attributes.data_format);
 
   constexpr int num_spatial_dims = 2;
 
   assert(input_shape.rank() == 4);
-  assert(_window_size.size() == num_spatial_dims);
-  assert(_strides.size() == num_spatial_dims);
-  assert(_padding_before.size() == num_spatial_dims);
-  assert(_padding_after.size() == num_spatial_dims);
+  assert(_attributes.window.size() == num_spatial_dims);
+  assert(_attributes.strides.size() == num_spatial_dims);
+  assert(_attributes.padding_before.size() == num_spatial_dims);
+  assert(_attributes.padding_after.size() == num_spatial_dims);
 
   Shape output_shape(4);
 
@@ -42,13 +42,15 @@ void AvgPool2DOp::inferOutputShapes()
 
   for (int i = 0; i < num_spatial_dims; i++)
   {
-    const int spatial_dim_index = getDataSpatialDimIndex(_data_format, i);
-    const std::int32_t padded_input =
-        input_shape.dim(spatial_dim_index) + _padding_before.at(i) + _padding_after.at(i);
+    const int spatial_dim_index = getDataSpatialDimIndex(_attributes.data_format, i);
+    const std::int32_t padded_input = input_shape.dim(spatial_dim_index) +
+                                      _attributes.padding_before.at(i) +
+                                      _attributes.padding_after.at(i);
     // out_size = ceil((in_size - window_size + 1) / stride) =
     //   (in_size - window_size + 1 + stride - 1) / stride =
     //   (in_size - window_size) / stride + 1
-    output_shape.dim(spatial_dim_index) = (padded_input - _window_size[i]) / _strides[i] + 1;
+    output_shape.dim(spatial_dim_index) =
+        (padded_input - _attributes.window[i]) / _attributes.strides[i] + 1;
   }
 
   setOutputShape(0, output_shape);
index 9967aa5..065acb5 100644 (file)
@@ -26,15 +26,15 @@ void Conv2DOp::inferOutputShapes()
   // Kernel shape: [Co, Hk, Wk, Ci].
   const auto &input_shape = getInputShape(0);
   const auto &kernel_shape = getInputShape(1);
-  const int batch_dim_index = getDataBatchDimIndex(_data_format);
-  const int channel_dim_index = getDataChannelDimIndex(_data_format);
+  const int batch_dim_index = getDataBatchDimIndex(_attributes.data_format);
+  const int channel_dim_index = getDataChannelDimIndex(_attributes.data_format);
 
   assert(input_shape.rank() == 4);
   assert(kernel_shape.rank() == 4);
   assert(kernel_shape.dim(3) == input_shape.dim(channel_dim_index));
-  assert(_strides.size() == 2);
-  assert(_padding_before.size() == 2);
-  assert(_padding_after.size() == 2);
+  assert(_attributes.strides.size() == 2);
+  assert(_attributes.padding_before.size() == 2);
+  assert(_attributes.padding_after.size() == 2);
 
   Shape output_shape(4);
 
@@ -43,14 +43,14 @@ void Conv2DOp::inferOutputShapes()
 
   for (int i = 0; i < 2; i++)
   {
-    const int spatial_dim_index = getDataSpatialDimIndex(_data_format, i);
-    const std::int32_t padded_input =
-        input_shape.dim(spatial_dim_index) + _padding_before[i] + _padding_after[i];
+    const int spatial_dim_index = getDataSpatialDimIndex(_attributes.data_format, i);
+    const std::int32_t padded_input = input_shape.dim(spatial_dim_index) +
+                                      _attributes.padding_before[i] + _attributes.padding_after[i];
     // out_size = ceil((in_size - kernel_size + 1) / stride) =
     //   (in_size - kernel_size + 1 + stride - 1) / stride =
     //   (in_size - kernel_size) / stride + 1
     output_shape.dim(spatial_dim_index) =
-        (padded_input - kernel_shape.dim(1 + i)) / _strides[i] + 1;
+        (padded_input - kernel_shape.dim(1 + i)) / _attributes.strides[i] + 1;
   }
 
   setOutputShape(0, output_shape);
index 75a4a24..3b3def4 100644 (file)
@@ -24,7 +24,7 @@ namespace ops
 // See the formulas at https://github.com/onnx/onnx/blob/master/docs/Operators.md#convtranspose.
 void DeConv2DOp::inferPaddings()
 {
-  assert(_padding_type != PaddingType::Explicit);
+  assert(_attributes.padding_type != PaddingType::Explicit);
 
   const auto &input_shape = getInputShape(0);
   const auto &kernel_shape = getInputShape(1);
@@ -34,23 +34,24 @@ void DeConv2DOp::inferPaddings()
 
   for (int i = 0; i < num_spatial_dims; ++i)
   {
-    const int spatial_dim_index = getDataSpatialDimIndex(_data_format, i);
-    const std::int32_t total_padding = (input_shape.dim(spatial_dim_index) - 1) * _strides[i] +
-                                       kernel_shape.dim(i) - output_shape.dim(spatial_dim_index);
+    const int spatial_dim_index = getDataSpatialDimIndex(_attributes.data_format, i);
+    const std::int32_t total_padding =
+        (input_shape.dim(spatial_dim_index) - 1) * _attributes.strides[i] + kernel_shape.dim(i) -
+        output_shape.dim(spatial_dim_index);
 
-    switch (_padding_type)
+    switch (_attributes.padding_type)
     {
       case PaddingType::Valid:
         // TODO Figure out what to do.
         assert(false);
         break;
       case PaddingType::SameLower:
-        _padding_after[i] = total_padding / 2;
-        _padding_before[i] = total_padding - _padding_after[i];
+        _attributes.padding_after[i] = total_padding / 2;
+        _attributes.padding_before[i] = total_padding - _attributes.padding_after[i];
         break;
       case PaddingType::SameUpper:
-        _padding_before[i] = total_padding / 2;
-        _padding_after[i] = total_padding - _padding_before[i];
+        _attributes.padding_before[i] = total_padding / 2;
+        _attributes.padding_after[i] = total_padding - _attributes.padding_before[i];
         break;
       default:
         assert(false);
@@ -61,13 +62,13 @@ void DeConv2DOp::inferPaddings()
 // See the formulas at https://github.com/onnx/onnx/blob/master/docs/Operators.md#convtranspose.
 void DeConv2DOp::inferOutputShapes()
 {
-  assert(_padding_type == PaddingType::Explicit);
+  assert(_attributes.padding_type == PaddingType::Explicit);
 
   // Kernel shape: [Hk, Wk, Co, Ci]
   const auto &input_shape = getInputShape(0);
   const auto &kernel_shape = getInputShape(1);
-  const int batch_dim_index = getDataBatchDimIndex(_data_format);
-  const int channel_dim_index = getDataChannelDimIndex(_data_format);
+  const int batch_dim_index = getDataBatchDimIndex(_attributes.data_format);
+  const int channel_dim_index = getDataChannelDimIndex(_attributes.data_format);
 
   assert(input_shape.rank() == 4);
   assert(kernel_shape.rank() == 4);
@@ -82,10 +83,10 @@ void DeConv2DOp::inferOutputShapes()
 
   for (int i = 0; i < num_spatial_dims; i++)
   {
-    const int spatial_dim_index = getDataSpatialDimIndex(_data_format, i);
-    output_shape.dim(spatial_dim_index) = (input_shape.dim(spatial_dim_index) - 1) * _strides[i] +
-                                          kernel_shape.dim(i) -
-                                          (_padding_before.at(i) + _padding_after.at(i));
+    const int spatial_dim_index = getDataSpatialDimIndex(_attributes.data_format, i);
+    output_shape.dim(spatial_dim_index) =
+        (input_shape.dim(spatial_dim_index) - 1) * _attributes.strides[i] + kernel_shape.dim(i) -
+        (_attributes.padding_before.at(i) + _attributes.padding_after.at(i));
   }
 
   setOutputShape(0, output_shape);
index ff128b3..7b6e2f7 100644 (file)
@@ -26,15 +26,15 @@ void DepthwiseConv2DOp::inferOutputShapes()
   // Kernel shape: [Hk, Wk, Ci, M].
   const auto &input_shape = getInputShape(0);
   const auto &kernel_shape = getInputShape(1);
-  const int batch_dim_index = getDataBatchDimIndex(_data_format);
-  const int channel_dim_index = getDataChannelDimIndex(_data_format);
+  const int batch_dim_index = getDataBatchDimIndex(_attributes.data_format);
+  const int channel_dim_index = getDataChannelDimIndex(_attributes.data_format);
 
   assert(input_shape.rank() == 4);
   assert(kernel_shape.rank() == 4);
   assert(input_shape.dim(channel_dim_index) == kernel_shape.dim(2));
-  assert(_strides.size() == 2);
-  assert(_padding_before.size() == 2);
-  assert(_padding_after.size() == 2);
+  assert(_attributes.strides.size() == 2);
+  assert(_attributes.padding_before.size() == 2);
+  assert(_attributes.padding_after.size() == 2);
 
   Shape output_shape(4);
 
@@ -43,13 +43,14 @@ void DepthwiseConv2DOp::inferOutputShapes()
 
   for (int i = 0; i < 2; i++)
   {
-    const int spatial_dim_index = getDataSpatialDimIndex(_data_format, i);
-    const std::int32_t padded_input =
-        input_shape.dim(spatial_dim_index) + _padding_before[i] + _padding_after[i];
+    const int spatial_dim_index = getDataSpatialDimIndex(_attributes.data_format, i);
+    const std::int32_t padded_input = input_shape.dim(spatial_dim_index) +
+                                      _attributes.padding_before[i] + _attributes.padding_after[i];
     // out_size = ceil((in_size - kernel_size + 1) / stride) =
     //   (in_size - kernel_size + 1 + stride - 1) / stride =
     //   (in_size - kernel_size) / stride + 1
-    output_shape.dim(spatial_dim_index) = (padded_input - kernel_shape.dim(i)) / _strides[i] + 1;
+    output_shape.dim(spatial_dim_index) =
+        (padded_input - kernel_shape.dim(i)) / _attributes.strides[i] + 1;
   }
 
   setOutputShape(0, output_shape);
index df7e538..7480719 100644 (file)
@@ -24,16 +24,16 @@ namespace ops
 void MaxPool2DOp::inferOutputShapes()
 {
   const auto &input_shape = getInputShape(0);
-  const int batch_dim_index = getDataBatchDimIndex(_data_format);
-  const int channel_dim_index = getDataChannelDimIndex(_data_format);
+  const int batch_dim_index = getDataBatchDimIndex(_attributes.data_format);
+  const int channel_dim_index = getDataChannelDimIndex(_attributes.data_format);
 
   constexpr int num_spatial_dims = 2;
 
   assert(input_shape.rank() == 4);
-  assert(_window_size.size() == num_spatial_dims);
-  assert(_strides.size() == num_spatial_dims);
-  assert(_padding_before.size() == num_spatial_dims);
-  assert(_padding_after.size() == num_spatial_dims);
+  assert(_attributes.window.size() == num_spatial_dims);
+  assert(_attributes.strides.size() == num_spatial_dims);
+  assert(_attributes.padding_before.size() == num_spatial_dims);
+  assert(_attributes.padding_after.size() == num_spatial_dims);
 
   Shape output_shape(4);
 
@@ -42,13 +42,15 @@ void MaxPool2DOp::inferOutputShapes()
 
   for (int i = 0; i < num_spatial_dims; i++)
   {
-    const int spatial_dim_index = getDataSpatialDimIndex(_data_format, i);
-    const std::int32_t padded_input =
-        input_shape.dim(spatial_dim_index) + _padding_before.at(i) + _padding_after.at(i);
+    const int spatial_dim_index = getDataSpatialDimIndex(_attributes.data_format, i);
+    const std::int32_t padded_input = input_shape.dim(spatial_dim_index) +
+                                      _attributes.padding_before.at(i) +
+                                      _attributes.padding_after.at(i);
     // out_size = ceil((in_size - window_size + 1) / stride) =
     //   (in_size - window_size + 1 + stride - 1) / stride =
     //   (in_size - window_size) / stride + 1
-    output_shape.dim(spatial_dim_index) = (padded_input - _window_size[i]) / _strides[i] + 1;
+    output_shape.dim(spatial_dim_index) =
+        (padded_input - _attributes.window[i]) / _attributes.strides[i] + 1;
   }
 
   setOutputShape(0, output_shape);
index 75d8662..890daf1 100644 (file)
@@ -29,7 +29,8 @@ void PadOp::inferOutputShapes()
   Shape out_shape(num_dims);
   for (int32_t dim = 0; dim < num_dims; ++dim)
   {
-    out_shape.dim(dim) = _padding_before[dim] + input_shape.dim(dim) + _padding_after[dim];
+    out_shape.dim(dim) =
+        _attributes.padding_before[dim] + input_shape.dim(dim) + _attributes.padding_after[dim];
   }
 
   setOutputShape(0, out_shape);