[mir] Introduce "copy with changed input" method to Operation (#3346)
author Andrei Shedko/AI Tools Lab/SRR/Engineer/Samsung Electronics <a.shedko@samsung.com>
Wed, 24 Apr 2019 17:55:25 +0000 (20:55 +0300)
committer Efimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 24 Apr 2019 17:55:25 +0000 (20:55 +0300)
Introduce "copy with changed input" method to Operation
to make graph optimizations easier to implement.
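
For illustration, a minimal usage sketch from an optimization pass (the
variables `g`, `old_op` and `new_producer` are hypothetical and assumed to be
provided by the pass; only copyOpWithInputs, copyWithInputs and getNumInputs
come from this change):

    // Assumed context: mir::Graph* g, mir::Operation* old_op (single-input op),
    // mir::Operation::Output* new_producer that should feed the copy instead.
    std::vector<mir::Operation::Output*> new_inputs{new_producer};
    assert(new_inputs.size() == old_op->getNumInputs());
    // Clones old_op with the same attributes and name, but reading from
    // new_producer, and registers the clone in the graph.
    mir::Operation* copy = g->copyOpWithInputs(old_op, new_inputs);

The pass can then redirect the users of old_op's output to the copy's output
and drop old_op.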

Signed-off-by: Andrei Shedko <a.shedko@samsung.com>
34 files changed:
contrib/mir/include/mir/Graph.h
contrib/mir/include/mir/Operation.h
contrib/mir/include/mir/ops/BatchNormOp.h
contrib/mir/include/mir/ops/BiasAddOp.h
contrib/mir/include/mir/ops/CappedReluOp.h
contrib/mir/include/mir/ops/ConcatOp.h
contrib/mir/include/mir/ops/ConstantOp.h
contrib/mir/include/mir/ops/Conv2DOp.h
contrib/mir/include/mir/ops/Deconv2DOp.h
contrib/mir/include/mir/ops/DepthwiseConv2DOp.h
contrib/mir/include/mir/ops/DropoutOp.h
contrib/mir/include/mir/ops/ElementwiseOp.h
contrib/mir/include/mir/ops/EluOp.h
contrib/mir/include/mir/ops/FullyConnectedOp.h
contrib/mir/include/mir/ops/GatherOp.h
contrib/mir/include/mir/ops/GemmOp.h
contrib/mir/include/mir/ops/InputOp.h
contrib/mir/include/mir/ops/LeakyReluOp.h
contrib/mir/include/mir/ops/OutputOp.h
contrib/mir/include/mir/ops/PadOp.h
contrib/mir/include/mir/ops/PoolOp.h
contrib/mir/include/mir/ops/ReduceFOp.h
contrib/mir/include/mir/ops/ReluOp.h
contrib/mir/include/mir/ops/ReshapeOp.h
contrib/mir/include/mir/ops/ResizeOp.h
contrib/mir/include/mir/ops/ScaleOp.h
contrib/mir/include/mir/ops/SigmoidOp.h
contrib/mir/include/mir/ops/SliceOp.h
contrib/mir/include/mir/ops/SoftmaxOp.h
contrib/mir/include/mir/ops/SqrtOp.h
contrib/mir/include/mir/ops/SqueezeOp.h
contrib/mir/include/mir/ops/TanhOp.h
contrib/mir/include/mir/ops/TransposeOp.h
contrib/mir/src/TensorVariant.cpp

index b333eb9..a61bd2c 100644 (file)
@@ -49,6 +49,18 @@ class Graph {
     return op;
   }
 
+  /**
+   * @brief Copies `old_op` with new inputs and registers it in the graph.
+   */
+  Operation* copyOpWithInputs(Operation* old_op, const std::vector<Operation::Output*>& inputs) {
+    assert(inputs.size() == old_op->getNumInputs());
+    auto op = old_op->copyWithInputs(inputs);
+    op->setId(_lastNodeId++);
+    op->setName(old_op->getName());
+    registerOp(op);
+    return op;
+  }
+
   void accept(IVisitor* visitor);
 
   /**
index 89edf9a..96190de 100644 (file)
@@ -112,7 +112,6 @@ public:
     Output* _producer;
   };
 
-
   virtual ~Operation() = default;
 
   Type getType() const { return _type; }
@@ -162,6 +161,8 @@ public:
 
   void accept(IVisitor* v);
 
+  virtual Operation* copyWithInputs(const std::vector<Output*>& inputs) = 0;
+
 protected:
   Operation(Type type, const std::vector<Output*>& inputs, std::size_t num_outputs = 1);
 
index ebdb37a..4d63d72 100644 (file)
@@ -31,6 +31,10 @@ public:
     setOutputShape(0, getInputShape(0));
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new BatchNormOp(inputs[0], _movingAvgFraction, _eps, _spatial);
+  }
+
   /**
    * @return The epsilon value to use to avoid division by zero.
    */
index c4304d2..0e1369e 100644 (file)
@@ -29,6 +29,10 @@ public:
     // Infer output shape.
     setOutputShape(0, getInputShape(0));
   }
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new BiasAddOp(inputs[0], inputs[1]);
+  }
 };
 
 } // namespace ops
index 909e23a..a72b8ee 100644 (file)
@@ -29,6 +29,10 @@ public:
     setOutputShape(0, getInputShape(0));
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new CappedReluOp(inputs[0], _cap);
+  }
+
   float getCap() const { return _cap; }
 
 private:
index d60f612..28af3da 100644 (file)
@@ -32,6 +32,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new ConcatOp(inputs, _axis);
+  }
+
   int32_t getAxis() const {
     if (_axis < 0) {
       // Negative axis is used to index starting from the last element of the shape
index 7168e42..b3afed3 100644 (file)
@@ -30,6 +30,12 @@ public:
 
   const TensorVariant& getValue() const { return _value; }
 
+  Operation* copyWithInputs(const std::vector<mir::Operation::Output*>& inputs) override {
+    assert(false && "Copying constants is not allowed!");
+    (void) inputs;
+    return nullptr;
+  }
+
 private:
   TensorVariant _value;
 };
index cf02f79..6f9a71d 100644 (file)
@@ -37,6 +37,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new Conv2DOp(inputs[0], inputs[1], _strides, _paddingBefore, _paddingAfter);
+  }
+
   const Shape& getStrides() const { return _strides; }
 
   const std::vector<int32_t>& getPaddingBefore() const { return _paddingBefore; }
index 38c424d..7b746fb 100644 (file)
@@ -66,6 +66,10 @@ public:
     inferPaddings();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new DeConv2DOp(inputs[0], inputs[1], _strides, _paddingAfter);
+  }
+
   const Shape& getStrides() const { return _strides; }
 
   PaddingType getPaddingType() const { return _paddingType; }
index e1d6d6a..5cae15e 100644 (file)
@@ -37,6 +37,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new DepthwiseConv2DOp(inputs[0], inputs[1], _strides, _paddingBefore, _paddingAfter);
+  }
+
   const Shape& getStrides() const { return _strides; }
 
   const std::vector<int32_t>& getPaddingBefore() const { return _paddingBefore; }
index 4b3b494..9b4094c 100644 (file)
@@ -29,6 +29,10 @@ public:
     setOutputShape(0, getInputShape(0));
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new DropoutOp(inputs[0], _rate);
+  }
+
   /**
    * @return The ratio of random dropout
    */
index ba00cc6..a2c4cc3 100644 (file)
@@ -43,6 +43,10 @@ public:
     inferOutputShapes();
   };
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new ElementwiseOp(inputs, _opType);
+  }
+
   bool getBroadcast() const { return _needsBroadcast; }
 
 private:
index 53a9122..6c97a93 100644 (file)
@@ -28,6 +28,10 @@ public:
     setOutputShape(0, getInputShape(0));
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new EluOp(inputs[0], _alpha);
+  }
+
   float getAlpha() const { return _alpha; }
 
 private:
index 5a97fcf..4e47f20 100644 (file)
@@ -29,6 +29,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new FullyConnectedOp(inputs[0], inputs[1]);
+  }
+
 private:
   void inferOutputShapes();
 };
index 7302e59..773cb7c 100644 (file)
@@ -34,6 +34,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new GatherOp(inputs[0], inputs[1], _axis);
+  }
+
   int32_t getAxis() const { return _axis; }
 
 private:
index c9d8491..4dc5766 100644 (file)
@@ -29,6 +29,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new GemmOp(inputs[0], inputs[1], inputs[2]);
+  }
+
 private:
   void inferOutputShapes();
 };
index ad49f7f..936043f 100644 (file)
@@ -27,6 +27,12 @@ public:
   explicit InputOp(const Shape& shape) : Operation(Type::input, {}) {
     setOutputShape(0, shape);
   }
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    assert(false && "Copying graph inputs is not allowed!");
+    (void) inputs;
+    return nullptr;
+  }
 };
 
 } // namespace ops
index c83fc16..7fbd9e4 100644 (file)
@@ -30,6 +30,10 @@ public:
     setOutputShape(0, getInputShape(0));
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new LeakyReluOp(inputs[0], _alpha);
+  }
+
   float getAlpha() const {
     return _alpha;
   }
index 544e895..469e793 100644 (file)
@@ -25,6 +25,10 @@ namespace ops {
 class OutputOp : public Operation {
 public:
   explicit OutputOp(Output* input) : Operation(Type::output, {input}) {}
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new OutputOp(inputs[0]);
+  }
 };
 
 } // namespace ops
index 9eca214..88fed11 100644 (file)
@@ -43,6 +43,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new PadOp(inputs[0], _numDims, _paddings, _scalarValue);
+  }
+
   /**
    * @param dim Dimension number
    * @return Pair of paddings for dimension
index 3938c3e..46ac221 100644 (file)
@@ -55,6 +55,11 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new PoolOp(inputs[0], _poolingType, _windowShape, _strides,
+                      _paddingBefore, _paddingAfter, _borderType);
+  }
+
   BorderType getBorderType() const { return _borderType; }
 
   PoolingType getPoolingType() const { return _poolingType; }
index 5fd1b67..f2ac010 100644 (file)
@@ -73,6 +73,10 @@ public:
     setOutputShape(0, output_shape);
   };
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new ReduceFOp(inputs[0], _reduceDims, _keepDims, _funcType);
+  }
+
   const std::vector<int32_t>& getReductionDims() { return _reduceDims; };
 
   bool getKeepDims() const { return _keepDims; };
index 1ac5a0c..0149188 100644 (file)
@@ -28,6 +28,10 @@ public:
     // Infer output shape.
     setOutputShape(0, getInputShape(0));
   }
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new ReluOp(inputs[0]);
+  }
 };
 
 } // namespace ops
index 791f05b..03586a0 100644 (file)
@@ -45,6 +45,10 @@ public:
 
     setOutputShape(0, output_shape);
   }
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new ReshapeOp(inputs[0], getOutputShape(0));
+  }
 };
 
 } // namespace ops
index 98f6a29..963d40a 100644 (file)
@@ -50,6 +50,10 @@ public:
     setOutputShape(0, output_shape);
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new ResizeOp(inputs[0], _mode, getOutputShape(0));
+  }
+
   ResizeOp(Output* arg, ResizeMethod mode, const Shape& output_shape)
       : Operation(Type::resizeIm, {arg}), _mode(mode) {
     // Calculate scales based on given shape.
index 7ba276c..a4c3490 100644 (file)
@@ -28,6 +28,10 @@ public:
     // Infer output shape.
     setOutputShape(0, getInputShape(0));
   }
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new ScaleOp(inputs[0], inputs[1]);
+  }
 };
 
 } // namespace ops
index 3eb743e..d7780e0 100644 (file)
@@ -28,6 +28,10 @@ public:
     // Infer output shape.
     setOutputShape(0, getInputShape(0));
   }
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new SigmoidOp(inputs[0]);
+  }
 };
 
 } // namespace ops
index b2d25c5..03da4ce 100644 (file)
@@ -29,6 +29,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new SliceOp(inputs[0], _starts, _sizes);
+  }
+
   const Shape& getStarts() { return _starts; }
 
   const Shape& getSizes() { return _sizes; }
index 2a370cc..d32c7a1 100644 (file)
@@ -31,6 +31,10 @@ public:
     setOutputShape(0, getInputShape(0));
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new SoftmaxOp(inputs[0], _axis);
+  }
+
   int32_t getAxis() const {
     if (_axis < 0) {
       // Negative axis is used to index starting from the last element of the shape
index e52bb55..1728d72 100644 (file)
@@ -28,6 +28,10 @@ public:
   explicit SqrtOp(Output* arg) : Operation(Type::sqrt, {arg}) {
     setOutputShape(0, getInputShape(0));
   };
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new SqrtOp(inputs[0]);
+  }
 };
 
 } // namespace ops
index 2aadc32..9017b4f 100644 (file)
@@ -31,6 +31,10 @@ public:
     inferOutputShapes();
   }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new SqueezeOp(inputs[0], _dims_to_squeeze);
+  }
+
   void inferOutputShapes();
 
   int32_t getNumSqueezeDims() const { return static_cast<int32_t>(_dims_to_squeeze.size()); }
index 51ea345..67cd5e9 100644 (file)
@@ -28,6 +28,10 @@ public:
     // Infer output shape.
     setOutputShape(0, getInputShape(0));
   }
+
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new TanhOp(inputs[0]);
+  }
 };
 
 } // namespace ops
index f50786a..09950c7 100644 (file)
@@ -34,6 +34,10 @@ public:
 
   const std::vector<std::size_t>& getAxisOrder() const { return _axisOrder; }
 
+  Operation* copyWithInputs(const std::vector<Output*>& inputs) override {
+    return new TransposeOp(inputs[0], _axisOrder);
+  }
+
 private:
   void inferOutputShapes();
 
index 6caf1ca..3c76ecd 100644 (file)
@@ -21,7 +21,7 @@ namespace mir
 {
 
 TensorVariant::TensorVariant(DTYPE dtype, const Shape& shape)
-    : _dtype(dtype), _shape(shape), _strides(shape.rank()) {
+    : _dtype(dtype), _strides(shape.rank()), _shape(shape) {
   switch (dtype) {
     case DTYPE::FLOAT32:
       _elementSize = sizeof(float);