From e02b6c00164ed04e0bd17558ca4bb3efe59b30b6 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Wed, 16 Jan 2019 17:22:21 +0300
Subject: [PATCH] [nnc] Remove BiasAdd and ScaleOp from Add operation in Caffe2
frontend (#2857)
* Remove BiasAdd and ScaleOp from Caffe2 convertAdd
* Turn the BatchNorm scale-shape check in the Caffe frontend into an assertion
* Fix Eltwise SUM with coefficients in the Caffe frontend by scaling inputs with a broadcast constant
Signed-off-by: Pavel Iliutchenko
---
.../nnc/passes/caffe2_frontend/caffe2_importer.cpp | 8 +-
.../passes/caffe2_frontend/caffe2_op_creator.cpp | 106 +++++++++------------
.../nnc/passes/caffe2_frontend/caffe2_op_creator.h | 10 +-
.../nnc/passes/caffe_frontend/caffe_op_creator.cpp | 39 ++++++--
4 files changed, 87 insertions(+), 76 deletions(-)
diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp b/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
index 246bafd..c075f99 100644
--- a/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
+++ b/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
@@ -101,15 +101,9 @@ void Caffe2Importer::collectUnsupportedOp(const OperatorDef& op) {
SupportedCaffe2OpType opType = _operatorTypes.at(op.type());
switch (opType) {
- case SupportedCaffe2OpType::add:
- _opCreator->checkAdd(op, _problemsOpSet);
- break;
case SupportedCaffe2OpType::FC:
_opCreator->checkFC(op, _problemsOpSet);
break;
- case SupportedCaffe2OpType::mul:
- _opCreator->checkMul(op, _problemsOpSet);
- break;
case SupportedCaffe2OpType::spatialBN:
_opCreator->checkSpatialBN(op, _problemsOpSet);
break;
@@ -118,11 +112,13 @@ void Caffe2Importer::collectUnsupportedOp(const OperatorDef& op) {
case SupportedCaffe2OpType::maxPool:
_opCreator->checkConvLikeOp(op, _problemsOpSet);
break;
+ case SupportedCaffe2OpType::add:
case SupportedCaffe2OpType::concat:
case SupportedCaffe2OpType::constantFill:
case SupportedCaffe2OpType::dropout:
case SupportedCaffe2OpType::givenTensorFill:
case SupportedCaffe2OpType::givenTensorInt64Fill:
+ case SupportedCaffe2OpType::mul:
case SupportedCaffe2OpType::relu:
case SupportedCaffe2OpType::sigmoid:
case SupportedCaffe2OpType::softmax:
diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
index 585e649..a446c19 100644
--- a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
+++ b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
@@ -158,13 +158,13 @@ static Shape getWindowShape(const ::caffe2::OperatorDef& op,
mir::IODescriptor Caffe2OpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) {
// NCHW -> NHWC
- auto transpose = createOp(arg, std::vector{0, 2, 3, 1});
+ auto transpose = createOp("CaffeToMIR", arg, std::vector{0, 2, 3, 1});
return transpose->getOutput(0);
}
mir::IODescriptor Caffe2OpCreator::convertMIRToCaffe(const mir::IODescriptor& arg) {
// NHWC -> NCHW
- auto transpose = createOp(arg, std::vector{0, 3, 1, 2});
+ auto transpose = createOp("MIRToCaffe", arg, std::vector{0, 3, 1, 2});
return transpose->getOutput(0);
}
@@ -172,17 +172,6 @@ mir::IODescriptor Caffe2OpCreator::convertMIRToCaffe(const mir::IODescriptor& ar
// Check functions
//
-void Caffe2OpCreator::checkAdd(const ::caffe2::OperatorDef& op,
- std::set& problemsOpSet) {
- commonCheck(op, problemsOpSet);
-
- if (getSingleArgument(op, "axis", 1) != 1)
- problemsOpSet.insert("Add: only 'axis' = 1 is supported");
-
- if (getSingleArgument(op, "broadcast", 1) != 1)
- problemsOpSet.insert("Add: only enabled 'broadcast' is supported");
-}
-
void Caffe2OpCreator::checkConvLikeOp(const ::caffe2::OperatorDef& op,
std::set& problemsOpSet) {
commonCheck(op, problemsOpSet);
@@ -218,17 +207,6 @@ void Caffe2OpCreator::checkFC(const ::caffe2::OperatorDef& op,
problemsOpSet.insert(std::string("FC: only default '") + s + "' value is supported");
}
-void Caffe2OpCreator::checkMul(const ::caffe2::OperatorDef& op,
- std::set& problemsOpSet) {
- commonCheck(op, problemsOpSet);
-
- if (getSingleArgument(op, "axis", 1) != 1)
- problemsOpSet.insert("Mul: only 'axis' = 1 is supported");
-
- if (getSingleArgument(op, "broadcast", 1) != 1)
- problemsOpSet.insert("Mul: only enabled 'broadcast' is supported");
-}
-
void Caffe2OpCreator::checkSpatialBN(const ::caffe2::OperatorDef& op,
std::set& problemsOpSet) {
commonCheck(op, problemsOpSet);
@@ -254,8 +232,6 @@ std::vector
Caffe2OpCreator::convertAdd(const std::vector& inputs,
const ::caffe2::OperatorDef& op,
const MIRTensors& mir_tensors) {
- const auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
-
std::vector add_input;
for (const auto& i : inputs)
@@ -263,11 +239,11 @@ Caffe2OpCreator::convertAdd(const std::vector& inputs,
// check mir tensors contain operand
if (mir_tensors.find(op.input(1)) != mir_tensors.end()) {
- auto next_input = createOp(mir_tensors.at(op.input(1)));
- add_input.push_back(next_input[0].getOutput(0));
+ auto next_input = createOp("Constant", mir_tensors.at(op.input(1)));
+ add_input.push_back(next_input->getOutput(0));
}
- auto add = createOp(add_input, ops::ElementwiseOp::OpType::add);
+ auto add = createOp("Elementwise_Add", add_input, ops::ElementwiseOp::OpType::add);
return {convertMIRToCaffe(add->getOutput(0))};
}
@@ -285,7 +261,7 @@ Caffe2OpCreator::convertAveragePool(const std::vector& inputs,
std::vector pad_before, pad_after;
std::tie(pad_before, pad_after) = getPadding(op);
- auto pooling = createOp(convertCaffeToMIR(inputs[0]), pool_type, window_shape,
+ auto pooling = createOp("Average_Pool", convertCaffeToMIR(inputs[0]), pool_type, window_shape,
strides, pad_before, pad_after, border_type,
ops::PoolOp::RoundMode::floor);
@@ -311,20 +287,20 @@ std::vector Caffe2OpCreator::convertConv(const std::vector(kernel_tensor);
- result = createOp(convertCaffeToMIR(inputs[0]), transposed_tensor,
+ result = createOp("Depthwise_Conv2D", convertCaffeToMIR(inputs[0]), transposed_tensor,
stride_shape, pad_before, pad_after);
} else {
// first we need to convert kernel of grouped convolution to appropriate ordinary kernel
if (num_groups != 1)
kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
- result = createOp(convertCaffeToMIR(inputs[0]), kernel_tensor,
+ result = createOp("Conv2D", convertCaffeToMIR(inputs[0]), kernel_tensor,
stride_shape, pad_before, pad_after);
}
if (op.input_size() > 2) { // Bias is optional
- auto bias = createOp(mir_tensors.at(op.input(2)))->getOutput(0);
- result = createOp(result->getOutput(0), bias);
+ auto bias = createOp("Constant", mir_tensors.at(op.input(2)))->getOutput(0);
+ result = createOp("Bias_Add", result->getOutput(0), bias);
}
return {convertMIRToCaffe(result->getOutput(0))};
@@ -333,7 +309,7 @@ std::vector Caffe2OpCreator::convertConv(const std::vector Caffe2OpCreator::convertConcat(const std::vector& inputs,
const ::caffe2::OperatorDef& op) {
int axis = getSingleArgument(op, "axis", 1);
- auto result = createOp(inputs, axis);
+ auto result = createOp("Concat", inputs, axis);
return {result->getOutput(0)};
}
@@ -344,7 +320,7 @@ std::vector Caffe2OpCreator::convertDropout(const std::vector(inputs[0], dropout_ratio);
+ auto dropout = createOp("Dropout", inputs[0], dropout_ratio);
return {dropout->getOutput(0)};
}
@@ -357,12 +333,15 @@ Caffe2OpCreator::convertFullyConnected(const std::vector& inputs,
auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
// Transform input into 2-D tensor by flattening axes
Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
- auto reshape = createOp(inputs[0], shape);
- auto weights = createOp(weights_tensor)->getOutput(0);
- auto result = createOp(reshape->getOutput(0), weights);
- auto bias = createOp(mir_tensors.at(op.input(2)))->getOutput(0);
- result = createOp(result->getOutput(0), bias);
+
+ auto reshape = createOp("Reshape", inputs[0], shape);
+ auto weights = createOp("Constant", weights_tensor)->getOutput(0);
+ auto result = createOp("Fully_Connected", reshape->getOutput(0), weights);
+ auto bias = createOp("Constant", mir_tensors.at(op.input(2)))->getOutput(0);
+ result = createOp("Bias_Add", result->getOutput(0), bias);
+
return {result->getOutput(0)};
+
}
std::vector Caffe2OpCreator::convertMaxPool(const std::vector& inputs,
@@ -376,7 +355,7 @@ std::vector Caffe2OpCreator::convertMaxPool(const std::vector pad_before, pad_after;
std::tie(pad_before, pad_after) = getPadding(op);
- auto pooling = createOp(convertCaffeToMIR(inputs[0]), pool_type, window_shape,
+ auto pooling = createOp("Pool", convertCaffeToMIR(inputs[0]), pool_type, window_shape,
strides, pad_before, pad_after, border_type,
ops::PoolOp::RoundMode::floor);
@@ -387,30 +366,38 @@ std::vector
Caffe2OpCreator::convertMul(const std::vector& inputs,
const ::caffe2::OperatorDef& op,
const MIRTensors& mir_tensors) {
- const auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
- // TODO: replace with elementwise op, when broadcating will be added in elementwise op
- auto multiplier = createOp(mir_tensors.at(op.input(1)))->getOutput(0);
- auto result = createOp(convertCaffeToMIR(inputs[0]), multiplier);
- return {convertMIRToCaffe(result->getOutput(0))};
+ std::vector input_descriptors;
+ for (const auto& i: inputs)
+ input_descriptors.push_back(convertCaffeToMIR(i.op->getOutput(0)));
+
+ // TODO: replace ConstantOp on inputs
+ if (mir_tensors.find(op.input(1)) != mir_tensors.end()) {
+ auto const_tensor = createOp("Constant", mir_tensors.at(op.input(1)));
+ input_descriptors.push_back(const_tensor->getOutput(0));
+ }
+
+ auto mul = createOp("Elementwise_Mul", input_descriptors, ops::ElementwiseOp::OpType::mul);
+
+ return {convertMIRToCaffe(mul->getOutput(0))};
}
std::vector
Caffe2OpCreator::convertRelu(const std::vector& inputs) {
- auto relu = createOp(inputs[0]);
+ auto relu = createOp("Relu", inputs[0]);
return {relu->getOutput(0)};
}
std::vector
Caffe2OpCreator::convertSigmoid(const std::vector& inputs) {
- auto result = createOp(inputs[0]);
+ auto result = createOp("Sigmoid", inputs[0]);
return {result->getOutput(0)};
}
std::vector Caffe2OpCreator::convertSoftmax(const std::vector& inputs,
const ::caffe2::OperatorDef& op) {
int axis = getSingleArgument(op, "axis", 1);
- auto softmax = createOp(inputs[0], axis);
+ auto softmax = createOp("Softmax", inputs[0], axis);
return {softmax->getOutput(0)};
}
@@ -430,19 +417,20 @@ Caffe2OpCreator::convertSpatialBN(const std::vector& inputs,
Tensor bias_data(mean_tensor);
for (auto& idx: ShapeRange(bias_data.getShape()))
bias_data.at(idx) *= -1;
- auto mean = createOp(mean_tensor)->getOutput(0);
- auto result = createOp(convertCaffeToMIR(inputs[0]), mean);
+
+ auto mean = createOp("Constant", mean_tensor)->getOutput(0);
+ auto result = createOp("Bias_Add", convertCaffeToMIR(inputs[0]), mean);
// res2 = res1 * scale / (var + epsilon)
Tensor multiplier(scale_tensor);
for (auto& idx: ShapeRange(scale_tensor.getShape()))
multiplier.at(idx) /= std::sqrt(*(float*) var_tensor.at(idx) + eps);
- auto scale = createOp(scale_tensor)->getOutput(0);
- result = createOp(result->getOutput(0), scale);
+ auto scale = createOp("Constant", scale_tensor)->getOutput(0);
+ result = createOp("Scale", result->getOutput(0), scale);
// overall_res = res2 + bias
- auto bias = createOp(bias_tensor)->getOutput(0);
- result = createOp(result->getOutput(0), bias);
+ auto bias = createOp("Constant", bias_tensor)->getOutput(0);
+ result = createOp("Bias_Add", result->getOutput(0), bias);
return {convertMIRToCaffe(result->getOutput(0))};
}
@@ -452,7 +440,7 @@ std::vector Caffe2OpCreator::convertSum(const std::vectorgetOutputShape(inputs[0].index) && "All Sum inputs must have same shape");
- auto op = createOp(inputs, ops::ElementwiseOp::OpType::add);
+ auto op = createOp("Elementwise_Add", inputs, ops::ElementwiseOp::OpType::add);
return {op->getOutput(0)};
}
@@ -464,7 +452,7 @@ Caffe2OpCreator::convertClip(const std::vector& inputs,
float min = getSingleArgument(op, "min", float(0));
assert(max > 0.0 && min == 0.0 && "Support only if clip is CappedRelu");
- auto cap_relu = createOp(inputs[0], max);
+ auto cap_relu = createOp("Capped_Relu", inputs[0], max);
return {cap_relu->getOutput(0)};
}
@@ -487,7 +475,7 @@ Caffe2OpCreator::convertReshape(const std::vector& inputs,
}
Shape out_shape(shape_vec);
- auto reshape = createOp(inputs[0], out_shape);
+ auto reshape = createOp("Reshape", inputs[0], out_shape);
return {reshape->getOutput(0)};
}
diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h
index 5c6fe9f..99058cb 100644
--- a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h
+++ b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h
@@ -107,13 +107,15 @@ private:
mir::IODescriptor convertMIRToCaffe(const mir::IODescriptor& arg);
template
- mir::Operation* createOp(Types&& ... args);
+ mir::Operation* createOp(const std::string& name, Types&& ... args);
};
template
-mir::Operation* Caffe2OpCreator::createOp(Types&& ... args) {
- // TODO: set operation names
- return _graph->create("", std::forward(args)...);
+mir::Operation* Caffe2OpCreator::createOp(const std::string& name, Types&& ... args) {
+ mir::Operation* new_op = _graph->create("", std::forward(args)...);
+ std::string op_name = name + "_" + std::to_string(new_op->getId());
+ new_op->setName(op_name);
+ return new_op;
}
} // namespace nnc
diff --git a/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
index 72d0a33..2369b72 100644
--- a/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
+++ b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
@@ -506,9 +506,8 @@ void CaffeOpCreator::checkBatchNorm(const caffe::LayerParameter& layer,
const auto& scale_shape = layer.blobs(2).shape();
// Check that last blob(with scaleFactor) containing only one number
- // FIXME This should be an assertion.
- if (scale_shape.dim_size() != 1 || scale_shape.dim(0) != 1)
- problems_op_set.insert("Unexpected shape of scale parameter in batch norm");
+ assert(scale_shape.dim_size() == 1 && scale_shape.dim(0) == 1 &&
+ "Unexpected shape of scale parameter in batch norm");
}
std::vector
@@ -597,21 +596,47 @@ CaffeOpCreator::convertEltwise(const caffe::LayerParameter& layer,
const std::vector& inputs) {
auto& opts = layer.eltwise_param();
ops::ElementwiseOp::OpType optype;
+ std::vector input_tensors;
switch (opts.operation()){
case EltwiseParameter_EltwiseOp_PROD:
optype = ops::ElementwiseOp::OpType::mul;
+ for (auto& i: inputs)
+ input_tensors.push_back(i);
break;
case EltwiseParameter_EltwiseOp_SUM:
optype = ops::ElementwiseOp::OpType::add;
- // TODO TechDebt: When broadcast is implemented this should create Scale Ops before sum args
- for (float c: opts.coeff())
- assert(c == 1.0f && "Coeff != 1 is not supported");
+ if (opts.coeff().size() > 0) {
+ assert(opts.coeff().size() == inputs.size());
+ for (int i = 0; i < opts.coeff().size(); i++) {
+ if (opts.coeff().Get(i) != 1.0f) {
+ auto coeff = new char[sizeof(float)];
+ memcpy(coeff, &opts.coeff().Get(i), sizeof(float));
+ auto coeff_tensor = TensorVariant(Shape{1},
+ std::shared_ptr(reinterpret_cast(coeff), std::default_delete()),
+ DTYPE::FLOAT32, sizeof(float));
+ auto coeff_const = createOp(layer.name() + "_const", coeff_tensor);
+ std::vector mul_inputs;
+ mul_inputs.push_back(coeff_const->getOutput(0));
+ mul_inputs.push_back(inputs[i]);
+ auto mul = createOp(layer.name() + "_mul",
+ mul_inputs, ops::ElementwiseOp::OpType::mul);
+ input_tensors.push_back(mul->getOutput(0));
+ } else {
+ input_tensors.push_back(inputs[i]);
+ }
+ }
+ } else {
+ for (auto& i: inputs)
+ input_tensors.push_back(i);
+ }
break;
case EltwiseParameter_EltwiseOp_MAX:
optype = ops::ElementwiseOp::OpType::max;
+ for (auto& i: inputs)
+ input_tensors.push_back(i);
break;
}
- auto elementwise = createOp(layer.name(), inputs, optype);
+ auto elementwise = createOp(layer.name(), input_tensors, optype);
return {elementwise->getOutput(0)};
}
--
2.7.4