Add Pool operation description (#286)
authorVladimir Plazun/AI Tools Lab/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Fri, 1 Jun 2018 14:12:50 +0000 (18:12 +0400)
committerSergey Vostokov/AI Tools Lab/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Fri, 1 Jun 2018 14:12:50 +0000 (17:12 +0300)
Add Pool operation desription class

This class is used to represent pool operation in computation graph

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/core/include/nnc/core/IR/model/operations/pool_op.h [new file with mode: 0644]

diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/model/operations/pool_op.h b/contrib/nnc/libs/core/include/nnc/core/IR/model/operations/pool_op.h
new file mode 100644 (file)
index 0000000..75bcdc1
--- /dev/null
@@ -0,0 +1,71 @@
+#ifndef _NNC_CORE_IR_MODEL_POOL_H_
+#define _NNC_CORE_IR_MODEL_POOL_H_
+
+#include <vector>
+
+#include "nnc/core/IR/model/operations/operation.h"
+#include "nnc/core/IR/model/operations/common.h"
+
+#include "nncc/core/ADT/tensor/Shape.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace IR
+{
+namespace model
+{
+namespace ops
+{
+
+using nncc::core::ADT::tensor::Shape;
+
+class PoolOp : public OpDescription
+{
+public:
+  enum class PoolingType
+  {
+    MAX,
+    AVG,
+    MIN
+  };
+
+  explicit PoolOp(const Shape &windowShape, const Shape &strides, PoolingType poolType,
+                  PaddingType padding)
+      : OpDescription(1, 1), _padding(padding), _poolingType(poolType), _strides(strides),
+        _windowShape(windowShape)
+  {
+    _pads.resize(_windowShape.rank());
+  }
+
+  PaddingType getPaddingType() const { return _padding; }
+
+  PoolingType getPoolingType() const { return _poolingType; }
+
+  const Shape &getWindowShape() const { return _windowShape; }
+
+  const Shape &getStrides() const { return _strides; }
+
+  const int getPadding(int dim) const { return _pads[dim]; }
+
+  void setPadding(int dim, int pad) { _pads[dim] = pad; }
+
+private:
+  PaddingType _padding;
+  PoolingType _poolingType;
+  Shape _windowShape;
+  Shape _strides;
+  std::vector<int> _pads;
+};
+
+} // namespace ops
+} // namespace model
+} // namespace IR
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_IR_MODEL_POOL_H_