From: Ivan Vagin/AI Tools Lab /SRR/Engineer/삼성전자 Date: Fri, 7 Dec 2018 11:23:02 +0000 (+0300) Subject: [nnc] Initial implementation of caffe2_op_creator (#2333) X-Git-Tag: nncc_backup~1153 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7b84a6938a383f0dc0905decec8ede6c7c591af3;p=platform%2Fcore%2Fml%2Fnnfw.git [nnc] Initial implementation of caffe2_op_creator (#2333) Initial implementation of caffe2_op_creator. `mobilenet` is supported; to support the `inception` model we still need to support custom paddings in pooling ops and to test the untested operation conversions. Implemented ops: - Add - AveragePool - Conv - Concat - Dropout - FC - GivenTensorFill - MaxPool - Mul - Relu - Softmax - SpatialBN - Sum Not tested ops: - Add - Concat - Mul - SpatialBN Signed-off-by: Ivan Vagin --- diff --git a/contrib/nnc/driver/Options.cpp b/contrib/nnc/driver/Options.cpp index 317dc48..028bd67 100644 --- a/contrib/nnc/driver/Options.cpp +++ b/contrib/nnc/driver/Options.cpp @@ -92,7 +92,7 @@ Option initNet(optname("--init-net"), std::string(), optional(false), optvalues(""), - nullptr, + checkInFile, separators(""), #ifdef NNC_FRONTEND_CAFFE2_ENABLED showopt(true), diff --git a/contrib/nnc/include/passes/caffe2_frontend/caffe2_importer.h b/contrib/nnc/include/passes/caffe2_frontend/caffe2_importer.h index b9fa140..db637ce 100644 --- a/contrib/nnc/include/passes/caffe2_frontend/caffe2_importer.h +++ b/contrib/nnc/include/passes/caffe2_frontend/caffe2_importer.h @@ -33,7 +33,7 @@ class OperatorDef; class NetDef; } namespace nnc { -// class Caffe2OpCreator; +class Caffe2OpCreator; enum class SupportedCaffe2OpType : uint8_t; } @@ -68,7 +68,7 @@ private: std::string _initNet; mir::Graph* _graph; std::unique_ptr<::caffe2::NetDef> _net; - // std::unique_ptr _opCreator; + std::unique_ptr _opCreator; std::vector _inputShapes; static const std::map _operatorTypes; @@ -77,7 +77,7 @@ private: // This map maps caffe2 operators names to MIR operators // that correspond to previous caffe2 
operators std::map _blobNameToIODescriptor; - mir::Operation* _lastNode; + mir::Operation* _lastMIROp; std::map> _MIRTensors; @@ -85,47 +85,42 @@ private: * @brief Pass through caffe2 graph and collect ops unsupported by NNC * @throw PassException with message, containing detected problems */ - // void collectUnsupportedOps(); + void collectUnsupportedOps(); /** * @brief Collecting unsupported parts of caffe2 operator */ - // void collectUnsupportedOp(const ::caffe2::OperatorDef&); + void collectUnsupportedOp(const ::caffe2::OperatorDef&); /** * @brief Creating MIR node from single caffe2 operator */ - // void createMIRNodesFromOp(const ::caffe2::OperatorDef&); + void createMIRNodesFromOp(const ::caffe2::OperatorDef&); /** * @brief Since caffe2 tensor values stored separately (in init_net) - preload them in _MIRTensors */ - // void preloadAllTensors(); + void preloadAllTensors(); /** * @brief Creates MIR tensor from caffe2 givenTensorFill op */ - // std::shared_ptr createTensor(const ::caffe2::OperatorDef&); + std::shared_ptr createTensor(const ::caffe2::OperatorDef&); /** * @brief Returns MIR ops, under given caffe2 op */ - // std::vector getInputMIROps(const ::caffe2::OperatorDef&); - - /** - * @brief create MIR inputs with given names and shapes - */ - // void createGraphInputs(const std::vector&, const std::vector&); + std::vector getInputMIROps(const ::caffe2::OperatorDef&); /** * @brief Mark output MIR nodes */ - // void setGraphOutputs(); + void setGraphOutputs(); /** * @brief Set MIR node names */ - // void setIrNodeNames(); + void setIrNodeNames(); }; } // namespace nnc diff --git a/contrib/nnc/include/passes/common_frontend/op_creator_helper.h b/contrib/nnc/include/passes/common_frontend/op_creator_helper.h new file mode 100644 index 0000000..11309b9 --- /dev/null +++ b/contrib/nnc/include/passes/common_frontend/op_creator_helper.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FRONTEND_COMMON_OP_CREATOR_HELPER_H_ +#define FRONTEND_COMMON_OP_CREATOR_HELPER_H_ + +#include +#include + +#include "core/modelIR/Shape.h" +#include "core/modelIR/TensorVariant.h" + +namespace nnc { + +/** Convert a kernel for grouped 2d convolution into a kernel for ordinary 2d convolution + * + * Grouped convolution breaks input and kernel channels into the selected number of groups and applies convolution to every group of channels independently. + * This technique reduces the kernel size (channels from different groups are not merged, so there is no need to store redundant 0 weights). 
+ * This is not supported by the compiler for now, so this function unfolds the compact kernel into the classic "every input layer affects every output layer" form, + * by inserting zero coefficients where needed + * + * @param groups number of groups in grouped convolution + * @param folded_kernel original grouped kernel + * @return unfolded kernel, compatible with ordinary conv2D operation + */ +std::shared_ptr +fixGroupedKernel(int groups, std::shared_ptr folded_kernel); + +} // namespace nnc + +#endif // FRONTEND_COMMON_OP_CREATOR_HELPER_H_ diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp b/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp index 69bc7d8..d9673ae 100644 --- a/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp +++ b/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp @@ -26,7 +26,7 @@ #include "caffe2/proto/caffe2.pb.h" #include "caffe2_op_types.h" -// #include "caffe2_op_creator.h" +#include "caffe2_op_creator.h" #include "core/modelIR/Shape.h" #include "core/modelIR/operations/VariableOp.h" @@ -44,13 +44,13 @@ Caffe2Importer::Caffe2Importer(std::string predictNet, std::string initNet, std::vector> shapes) : _predictNet(std::move(predictNet)), _initNet(std::move(initNet)), - _graph(new mir::Graph())/*, - _opCreator(new Caffe2OpCreator(_graph))*/ { - for(auto& shape : shapes) + _graph(new mir::Graph()), + _opCreator(new Caffe2OpCreator(_graph)) { + for (auto& shape : shapes) _inputShapes.emplace_back(shape); } -Caffe2Importer::~Caffe2Importer()=default; +Caffe2Importer::~Caffe2Importer() = default; PassData Caffe2Importer::run(PassData) { import(); @@ -66,41 +66,240 @@ void Caffe2Importer::import() { _net.reset(new NetDef()); if (!readProtoFromBinaryFile<::caffe2::NetDef>(_predictNet.c_str(), _net.get())) - throw PassException("Could not load model: " + _predictNet+ "\n"); + throw PassException("Could not load model: " + _predictNet + "\n"); std::unique_ptr net2; net2.reset(new NetDef()); if 
(!readProtoFromBinaryFile<::caffe2::NetDef>(_initNet.c_str(), net2.get())) - throw PassException("Could not load model: " + _initNet+ "\n"); + throw PassException("Could not load model: " + _initNet + "\n"); _net->MergeFrom(*net2); - // collectUnsupportedOps(); + preloadAllTensors(); - // preloadAllTensors(); + collectUnsupportedOps(); } mir::Graph* Caffe2Importer::createIR() { - throw PassException("Caffe2: NYI"); - /* for (auto& op : _net->op()) createMIRNodesFromOp(op); setIrNodeNames(); setGraphOutputs(); - */ return _graph; } +void Caffe2Importer::collectUnsupportedOps() { + for (auto& op : _net->op()) + collectUnsupportedOp(op); + + if (!_problemsOpSet.empty()) { + std::string msg("Detected problems:\n"); + for (const auto& problemStr : _problemsOpSet) + msg.append(problemStr + "\n"); + throw PassException(msg); + } +} + +void Caffe2Importer::collectUnsupportedOp(const OperatorDef& op) { + if (_operatorTypes.find(op.type()) == _operatorTypes.end()) { + _problemsOpSet.insert(op.type() + ": unknown layer"); + return; + } + + SupportedCaffe2OpType opType = _operatorTypes.at(op.type()); + switch (opType) { + case SupportedCaffe2OpType::FC: + _opCreator->checkFC(op, _problemsOpSet); + break; + case SupportedCaffe2OpType::spatialBN: + _opCreator->checkSpatialBN(op, _problemsOpSet); + break; + case SupportedCaffe2OpType::add: + case SupportedCaffe2OpType::averagePool: + case SupportedCaffe2OpType::concat: + case SupportedCaffe2OpType::constantFill: + case SupportedCaffe2OpType::conv: + case SupportedCaffe2OpType::dropout: + case SupportedCaffe2OpType::givenTensorFill: + case SupportedCaffe2OpType::maxPool: + case SupportedCaffe2OpType::mul: + case SupportedCaffe2OpType::relu: + case SupportedCaffe2OpType::softmax: + case SupportedCaffe2OpType::sum: + _opCreator->commonCheck(op, _problemsOpSet); + break; + default: + _problemsOpSet.insert(op.type() + ": unsupported layer"); + break; + } +} + +void Caffe2Importer::preloadAllTensors() { + for (auto& op : _net->op()) { 
+ // All tensor values are stored in 'GivenTensorFill' and 'ConstantFill' operators, so skip rest + auto opType = _operatorTypes.at(op.type()); + if ((opType == SupportedCaffe2OpType::givenTensorFill + || opType == SupportedCaffe2OpType::constantFill) + && hasArgument(op.arg(), "values")) { + _MIRTensors.insert( + std::pair>(op.output(0), createTensor(op))); + } + } +} + +void Caffe2Importer::createMIRNodesFromOp(const OperatorDef& op) { + std::vector outputs; + + // If op input not met yet - consider it as model input + if (op.input_size() > 0 + && _blobNameToIODescriptor.find(op.input(0)) == _blobNameToIODescriptor.end()) { + + outputs = _opCreator->createInput(op.input(0), _inputShapes.front()); + _blobNameToIODescriptor[op.input(0)] = outputs.at(0); + + _inputShapes.erase(_inputShapes.begin(), _inputShapes.begin() + 1); + } + + auto inputs = getInputMIROps(op); + + SupportedCaffe2OpType opType = _operatorTypes.at(op.type()); + switch (opType) { + case SupportedCaffe2OpType::constantFill: + case SupportedCaffe2OpType::givenTensorFill: + return; + case SupportedCaffe2OpType::add: + outputs = _opCreator->convertAdd(inputs, op, _MIRTensors); + break; + case SupportedCaffe2OpType::averagePool: + outputs = _opCreator->convertAveragePool(inputs, op); + break; + case SupportedCaffe2OpType::conv: + outputs = _opCreator->convertConv(inputs, op, _MIRTensors); + break; + case SupportedCaffe2OpType::concat: + outputs = _opCreator->convertConcat(inputs, op); + break; + case SupportedCaffe2OpType::dropout: + outputs = _opCreator->convertDropout(inputs, op); + break; + case SupportedCaffe2OpType::FC: + outputs = _opCreator->convertFullyConnected(inputs, op, _MIRTensors); + break; + case SupportedCaffe2OpType::maxPool: + outputs = _opCreator->convertMaxPool(inputs, op); + break; + case SupportedCaffe2OpType::mul: + outputs = _opCreator->convertMul(inputs, op, _MIRTensors); + break; + case SupportedCaffe2OpType::relu: + outputs = _opCreator->convertRelu(inputs); + break; + case 
SupportedCaffe2OpType::softmax: + outputs = _opCreator->convertSoftmax(inputs, op); + break; + case SupportedCaffe2OpType::spatialBN: + outputs = _opCreator->convertSpatialBN(inputs, op, _MIRTensors); + break; + case SupportedCaffe2OpType::sum: + outputs = _opCreator->convertSum(inputs); + break; + default: + assert(false && "All unsupported types should have been found before this pass."); + } + + for (int i = 0; i < outputs.size(); ++i) { + // caffe2 input blob name could be same as output blob name, and next line will overwrite + // '_blobNameToIODescriptor' element, but in all networks that I saw it was not a problem + _blobNameToIODescriptor[op.output(i)] = outputs.at(i); + } + + _lastMIROp = outputs.at(0).op; +} + +std::shared_ptr Caffe2Importer::createTensor(const OperatorDef& op) { + assert(hasArgument(op.arg(), "shape") && hasArgument(op.arg(), "values")); + + auto shape = findArgumentByName(op.arg(), "shape"); + auto values = findArgumentByName(op.arg(), "values"); + + // Create untyped tensor. Note, tensor contents will be *copied* here. 
+ auto type = mir::DTYPE::FLOAT32; + size_t elementSize = sizeof(float); + size_t bufferSize = values.floats().size() * elementSize; + const char* srcData = reinterpret_cast(values.floats().data()); + std::shared_ptr tensorBufferCopy(new char[bufferSize], + std::default_delete()); + char* dstData = tensorBufferCopy.get(); + memcpy(dstData, srcData, bufferSize); + + Shape tensor_shape = ShapeHelper::createShape( + shape.ints(), static_cast(shape.ints().size())); + + auto tensor = std::make_shared(tensor_shape, tensorBufferCopy, type, elementSize); + + return tensor; +} + +std::vector Caffe2Importer::getInputMIROps(const OperatorDef& op) { + // caffe2 operation inputs not same as MIR inputs (ex: in caffe2 conv kernel and bias also inputs) + // so choose caffe2 inputs, which are 'real' inputs + std::vector inputs; + SupportedCaffe2OpType opType = _operatorTypes.at(op.type()); + switch (opType) { + case SupportedCaffe2OpType::givenTensorFill: + case SupportedCaffe2OpType::constantFill: + break; + case SupportedCaffe2OpType::add: + case SupportedCaffe2OpType::averagePool: + case SupportedCaffe2OpType::conv: + case SupportedCaffe2OpType::dropout: + case SupportedCaffe2OpType::FC: + case SupportedCaffe2OpType::maxPool: + case SupportedCaffe2OpType::mul: + case SupportedCaffe2OpType::relu: + case SupportedCaffe2OpType::softmax: + case SupportedCaffe2OpType::spatialBN: + inputs.push_back(_blobNameToIODescriptor[op.input(0)]); + break; + case SupportedCaffe2OpType::sum: + case SupportedCaffe2OpType::concat: + for (auto& i : op.input()) + inputs.push_back(_blobNameToIODescriptor[i]); + break; + default: + assert(false && "All unsupported types should have been found before this pass."); + } + + return inputs; +} + +void Caffe2Importer::setGraphOutputs() { + // For now, we assume that: + // - there is exactly one output; + // - the output is from the last layer. 
+ _graph->markOutput(_lastMIROp); +} + +void Caffe2Importer::setIrNodeNames() { + for (auto& item : _blobNameToIODescriptor) + item.second.op->setName(item.first); +} + const std::map Caffe2Importer::_operatorTypes = { + {"Add", SupportedCaffe2OpType::add}, {"AveragePool", SupportedCaffe2OpType::averagePool}, {"Conv", SupportedCaffe2OpType::conv}, + {"Concat", SupportedCaffe2OpType::concat}, + {"ConstantFill", SupportedCaffe2OpType::constantFill}, {"Dropout", SupportedCaffe2OpType::dropout}, {"FC", SupportedCaffe2OpType::FC}, {"GivenTensorFill", SupportedCaffe2OpType::givenTensorFill}, {"MaxPool", SupportedCaffe2OpType::maxPool}, + {"Mul", SupportedCaffe2OpType::mul}, {"Relu", SupportedCaffe2OpType::relu}, {"Softmax", SupportedCaffe2OpType::softmax}, + {"SpatialBN", SupportedCaffe2OpType::spatialBN}, {"Sum", SupportedCaffe2OpType::sum} }; diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp new file mode 100644 index 0000000..57b4652 --- /dev/null +++ b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "core/modelIR/operations/BatchNormOp.h" +#include "core/modelIR/operations/BiasAddOp.h" +#include "core/modelIR/operations/CappedReluOp.h" +#include "core/modelIR/operations/ConcatOp.h" +#include "core/modelIR/operations/Conv2DOp.h" +#include "core/modelIR/operations/DepthwiseConv2DOp.h" +#include "core/modelIR/operations/DropoutOp.h" +#include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/FullyConnectedOp.h" +#include "core/modelIR/operations/PoolOp.h" +#include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ScaleOp.h" +#include "core/modelIR/operations/SoftmaxOp.h" +#include "core/modelIR/operations/TransposeOp.h" +#include "core/modelIR/operations/VariableOp.h" + +#include "core/modelIR/Index.h" +#include "core/modelIR/Shape.h" +#include "core/modelIR/ShapeRange.h" +#include "core/modelIR/Tensor.h" +#include "core/modelIR/TensorUtil.h" + +#include "passes/common_frontend/op_creator_helper.h" +#include "passes/common_frontend/shape_helper.h" +#include "pass/PassException.h" +#include "caffe2_op_creator.h" +#include "caffe2_proto_helper.h" + +#include +#include +#include +#include "option/Options.h" + + +namespace nnc { + +using namespace ::caffe2; +using namespace mir; +using nnc::mir::transposeTensor; + +// +// Helper functions +// + +mir::IODescriptor Caffe2OpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) { + if (cli::debugTranspose) { + // NCHW -> NHWC + auto transpose = createOp(arg, std::vector{0, 2, 3, 1}); + return transpose->getOutput(0); + } else { + return arg; + } +} + +mir::IODescriptor Caffe2OpCreator::convertMIRToCaffe(const mir::IODescriptor& arg) { + if (cli::debugTranspose) { + // NHWC -> NCHW + auto transpose = createOp(arg, std::vector{0, 3, 1, 2}); + return transpose->getOutput(0); + } else { + return arg; + } +} + +// +// Check functions +// + +void Caffe2OpCreator::commonCheck(const ::caffe2::OperatorDef& op, + 
std::set& problemsOpSet) { + if (getSingleArgument(op, "order", "NCHW") != "NCHW") + problemsOpSet.insert("Only 'NCHW' oreder is supported"); +} + +void Caffe2OpCreator::checkFC(const ::caffe2::OperatorDef& op, + std::set& problemsOpSet) { + commonCheck(op, problemsOpSet); + for (auto& s : {"axis", "axis_w", "float16_compute"}) + if (hasArgument(op.arg(), s)) + problemsOpSet.insert(std::string("FC: only default '") + s + "' value is supported"); +} + +void Caffe2OpCreator::checkSpatialBN(const ::caffe2::OperatorDef& op, + std::set& problemsOpSet) { + commonCheck(op, problemsOpSet); + if (op.input_size() != 5) + problemsOpSet.insert( + "SpatialBN must have exactly 5 inputs ('sums' and 'sumsq' are not supported yet)"); +} + +// +// Convert functions +// + +std::vector +Caffe2OpCreator::convertAdd(const std::vector& inputs, + const ::caffe2::OperatorDef& op, + const MIRTensors& mirTensors) { + // TODO: not tested + throw PassException("Caffe2 Add op not tested yet"); + auto& addend = mirTensors.at(op.input(1)); + auto add = createOp(inputs[0], *addend); + return {add->getOutput(0)}; +} + +std::vector +Caffe2OpCreator::convertAveragePool(const std::vector& inputs, + const OperatorDef& op) { + // TODO: implement custom paddings + bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r") + || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b"); + if (has_custom_pad) + throw PassException("Custom one-side padding not supported yet"); + + int kernel_size = static_cast(findArgumentByName(op.arg(), "kernel").i()); + Shape window_shape = Shape({kernel_size, kernel_size}); + + int stride = static_cast(findArgumentByName(op.arg(), "stride").i()); + Shape strides = Shape({stride, stride}); + + ops::PoolOp::PoolingType pool_type = ops::PoolOp::PoolingType::AVG; + ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::ZEROFILLED; + + int pad = getSingleArgument(op, "pad", 0); + std::vector padding{pad, pad}; + + auto pooling = 
createOp(inputs[0], pool_type, window_shape, strides, padding, + padding, border_type, ops::PoolOp::RoundMode::ceil); + + return {pooling->getOutput(0)}; +} + +std::vector Caffe2OpCreator::convertConv(const std::vector& inputs, + const ::caffe2::OperatorDef& op, + const MIRTensors& mirTensors) { + int stride = getSingleArgument(op, "stride", 1); + Shape stride_shape = Shape({stride, stride}); + + int pad = getSingleArgument(op, "pad", 0); + std::vector padding{pad, pad}; + + auto kernel_tensor = transposeTensor<2, 3, 1, 0>(mirTensors.at(op.input(1))); + auto in_group_size = kernel_tensor->getShape().dim(2); + auto out_channels = kernel_tensor->getShape().dim(3); + int num_groups = getSingleArgument(op, "group", 1); + bool is_depthwise = (num_groups != 1) && (in_group_size == 1) && (out_channels == num_groups); + + mir::Operation* conv2d; + if (is_depthwise) { + // This is depthwise convolution + // TODO handle properly kernel with layer multiplier + std::shared_ptr transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor); + conv2d = createOp(convertCaffeToMIR(inputs[0]), *transposed_tensor, + stride_shape, padding, padding); + } else { + // first we need to convert kernel of grouped convolution to appropriate ordinary kernel + if (num_groups != 1) + kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor); + + conv2d = createOp(convertCaffeToMIR(inputs[0]), *kernel_tensor, + stride_shape, padding, padding); + } + + if (op.input_size() > 2) { // Bias is optional + auto bias_add = createOp(conv2d->getOutput(0), *mirTensors.at(op.input(2))); + return {convertMIRToCaffe(bias_add->getOutput(0))}; + } + return {convertMIRToCaffe(conv2d->getOutput(0))}; +} + +std::vector Caffe2OpCreator::convertConcat(const std::vector& inputs, + const ::caffe2::OperatorDef& op) { + // TODO: not tested + throw PassException("Caffe2 Concat op not tested yet"); + int axis = getSingleArgument(op, "axis", -1); + auto result = createOp(inputs, axis); + return 
{result->getOutput(0)}; +} + +std::vector Caffe2OpCreator::convertDropout(const std::vector& inputs, + const ::caffe2::OperatorDef& op) { + // TODO: not tested + throw PassException("Caffe2 Dropout op not tested yet"); + int is_test = getSingleArgument(op, "is_test", 0); + if (is_test) + return {inputs[0]}; + + float dropot_ratio = getSingleArgument(op, "ratio", 0.5f); + auto dropout = createOp(inputs[0], dropot_ratio); + return {dropout->getOutput(0)}; +} + +// TODO: describe caffe2 FC interface +std::vector +Caffe2OpCreator::convertFullyConnected(const std::vector& inputs, + const ::caffe2::OperatorDef& op, + const MIRTensors& mirTensors) { + auto weightsTensor = mirTensors.at(op.input(1)); + weightsTensor = transposeTensor<1, 0>(weightsTensor); + int32_t fc_input_size = weightsTensor->getShape().dim(0); + + // Add Reshape operation to make sure the input for FC operation has shape [1, fcInputSize] + // It is needed because Caffe2 FC layer takes NCHW input and flattens the CHW part. + auto reshape = createOp(inputs[0], Shape({1, fc_input_size})); + + auto fully_connected = createOp(reshape->getOutput(0), *weightsTensor); + + auto bias = createOp(fully_connected->getOutput(0), *mirTensors.at(op.input(2))); + return {bias->getOutput(0)}; +} + +std::vector +Caffe2OpCreator::createInput(const std::string& input_name, const mir::Shape& input_shape) { + // TODO For now we only support convolutional networks with one element per batch. + assert(input_shape.rank() == 4 && input_shape.dim(0) == 1); + + // TODO Do not transpose data on input and remove transpose. 
+ auto transposed_shape = mir::Shape{input_shape.dim(0), input_shape.dim(2), + input_shape.dim(3), input_shape.dim(1)}; + auto variable = _graph->create(input_name, transposed_shape); + return {convertMIRToCaffe(variable->getOutput(0))}; +} + +std::vector Caffe2OpCreator::convertMaxPool(const std::vector& inputs, + const OperatorDef& op) { + // TODO: implement custom paddings + bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r") + || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b"); + if (has_custom_pad) + throw PassException("Custom one-side padding not supported yet"); + + int window_length = static_cast(findArgumentByName(op.arg(), "kernel").i()); + Shape window_shape = Shape({window_length, window_length}); + + int stride = static_cast(findArgumentByName(op.arg(), "stride").i()); + Shape strides = Shape({stride, stride}); + + ops::PoolOp::PoolingType pool_type = ops::PoolOp::PoolingType::MAX; + ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::EMPTY; + + int pad = getSingleArgument(op, "pad", 0); + std::vector padding{pad, pad}; + + auto pooling = createOp(convertCaffeToMIR(inputs[0]), pool_type, window_shape, + strides, padding, padding, border_type, + ops::PoolOp::RoundMode::ceil); + + return {convertMIRToCaffe(pooling->getOutput(0))}; +} + +std::vector +Caffe2OpCreator::convertMul(const std::vector& inputs, + const ::caffe2::OperatorDef& op, + const MIRTensors& mirTensors) { + // TODO: not tested + throw PassException("Caffe Mul op not tested yet"); + auto& multiplier = mirTensors.at(op.input(1)); + auto mul = createOp(inputs[0], *multiplier); + return {mul->getOutput(0)}; +} + +std::vector Caffe2OpCreator::convertRelu(const std::vector& inputs) { + auto relu = createOp(inputs[0]); + return {relu->getOutput(0)}; +} + +std::vector Caffe2OpCreator::convertSoftmax(const std::vector& inputs, + const ::caffe2::OperatorDef& op) { + int axis = getSingleArgument(op, "axis", 1); + auto softmax = 
createOp(inputs[0], axis); + return {softmax->getOutput(0)}; +} + +std::vector +Caffe2OpCreator::convertSpatialBN(const std::vector& inputs, + const ::caffe2::OperatorDef& op, + const MIRTensors& mirTensors) { + // TODO: not tested + throw PassException("Caffe2 SpatialBN op not tested yet"); + // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias + + auto& scale = mirTensors.at(op.input(1)); + auto& bias = mirTensors.at(op.input(2)); + auto& mean = mirTensors.at(op.input(3)); + auto& var = mirTensors.at(op.input(4)); + float eps = getSingleArgument(op, "epsilon", 1e-5f); + + // res1 = X - mean + Tensor bias_data(*mean); + for (Index idx: ShapeRange(bias_data.getShape())) + bias_data.at(idx) *= -1; + auto bias_add_1 = createOp(convertCaffeToMIR(inputs[0]), *mean); + + // res2 = res1 * scale / (var + epsilon) + Tensor multiplier(*scale); + for (Index idx: ShapeRange(scale->getShape())) + multiplier.at(idx) = 1.0f / std::sqrt(*(float*) var->at(idx) + eps); + auto scale_op = createOp(bias_add_1->getOutput(0), *scale); + + // overall_res = res2 + bias + auto bias_add_2 = createOp(scale_op->getOutput(0), *bias); + + return {convertMIRToCaffe(bias_add_2->getOutput(0))}; +} + +std::vector Caffe2OpCreator::convertSum(const std::vector& inputs) { + auto op = createOp(inputs, ops::ElementwiseOp::OpType::add); + return {op->getOutput(0)}; +} + +} // namespace nnc diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h new file mode 100644 index 0000000..09b2d0e --- /dev/null +++ b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNCC_CAFFE2_OP_CREATOR_H +#define NNCC_CAFFE2_OP_CREATOR_H + +#include +#include +#include +#include + +#include "core/modelIR/Graph.h" +#include "core/modelIR/Operation.h" +#include "core/modelIR/TensorVariant.h" +#include "core/modelIR/operations/CommonProps.h" +#include "core/modelIR/Shape.h" + +#include "caffe2/proto/caffe2.pb.h" + +namespace nnc { + +using nnc::mir::Graph; +using nnc::mir::Operation; +using IrTensor = nnc::mir::TensorVariant; +using nnc::mir::Shape; +using MIRTensors = const std::map>; + +class Caffe2OpCreator { +public: + explicit Caffe2OpCreator(Graph* g) : _graph(g) {}; + + void commonCheck(const ::caffe2::OperatorDef&, std::set&); + + void checkFC(const ::caffe2::OperatorDef&, std::set&); + + void checkSpatialBN(const ::caffe2::OperatorDef&, std::set&); + + std::vector convertAdd(const std::vector&, + const ::caffe2::OperatorDef&, const MIRTensors&); + + std::vector convertAveragePool(const std::vector&, + const ::caffe2::OperatorDef&); + + std::vector convertConv(const std::vector&, + const ::caffe2::OperatorDef&, const MIRTensors&); + + std::vector convertConcat(const std::vector&, + const ::caffe2::OperatorDef&); + + std::vector convertDropout(const std::vector&, + const ::caffe2::OperatorDef&); + + std::vector convertFullyConnected(const std::vector&, + const ::caffe2::OperatorDef&, + const MIRTensors&); + + std::vector createInput(const std::string&, const mir::Shape&); + + std::vector convertMaxPool(const std::vector&, + const ::caffe2::OperatorDef&); + + std::vector convertMul(const std::vector&, + 
const ::caffe2::OperatorDef&, const MIRTensors&); + + std::vector convertRelu(const std::vector&); + + std::vector convertSoftmax(const std::vector&, + const ::caffe2::OperatorDef&); + + std::vector convertSpatialBN(const std::vector&, + const ::caffe2::OperatorDef&, const MIRTensors&); + + std::vector convertSum(const std::vector&); + +private: + Graph* _graph = nullptr; + + mir::IODescriptor convertCaffeToMIR(const mir::IODescriptor& arg); + + mir::IODescriptor convertMIRToCaffe(const mir::IODescriptor& arg); + + template + mir::Operation* createOp(Types&& ... args); +}; + +template +mir::Operation* Caffe2OpCreator::createOp(Types&& ... args) { + // TODO: set operation names + return _graph->create("", std::forward(args)...); +} + +} // namespace nnc + +#endif //NNCC_CAFFE2_OP_CREATOR_H diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_op_types.h b/contrib/nnc/passes/caffe2_frontend/caffe2_op_types.h index 6aa56e7..8ac7260 100644 --- a/contrib/nnc/passes/caffe2_frontend/caffe2_op_types.h +++ b/contrib/nnc/passes/caffe2_frontend/caffe2_op_types.h @@ -20,14 +20,19 @@ namespace nnc { enum class SupportedCaffe2OpType : uint8_t { + add, averagePool, + concat, conv, + constantFill, dropout, FC, givenTensorFill, maxPool, + mul, relu, softmax, + spatialBN, sum }; diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.cpp b/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.cpp index b13436c..8e79e89 100644 --- a/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.cpp +++ b/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.cpp @@ -36,4 +36,25 @@ const bool hasArgument(RepArgument args, std::string name) { return false; } +int getSingleArgument(const ::caffe2::OperatorDef& op, const std::string& argument_name, + const int default_value) { + if (hasArgument(op.arg(), argument_name)) + return static_cast(findArgumentByName(op.arg(), argument_name).i()); + return default_value; +} + +float getSingleArgument(const ::caffe2::OperatorDef& op, const 
std::string& argument_name, + const float default_value) { + if (hasArgument(op.arg(), argument_name)) + return findArgumentByName(op.arg(), argument_name).f(); + return default_value; +} + +std::string getSingleArgument(const ::caffe2::OperatorDef& op, const std::string& argument_name, + const std::string& default_value) { + if (hasArgument(op.arg(), argument_name)) + return findArgumentByName(op.arg(), argument_name).s(); + return default_value; +} + } // namespace nnc diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.h b/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.h index 750f396..5b62e23 100644 --- a/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.h +++ b/contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.h @@ -22,8 +22,13 @@ namespace nnc { using RepArgument = const ::google::protobuf::RepeatedPtrField<::caffe2::Argument>&; const ::caffe2::Argument& findArgumentByName(RepArgument args, std::string name); + const bool hasArgument(RepArgument args, std::string name); +int getSingleArgument(const ::caffe2::OperatorDef&, const std::string&, const int); +float getSingleArgument(const ::caffe2::OperatorDef&, const std::string&, const float); +std::string getSingleArgument(const ::caffe2::OperatorDef&, const std::string&, const std::string&); + } // namespace nnc #endif // NNCC_CAFFE2_PROTO_HELPER_H diff --git a/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp index c0cb96a..caff041 100644 --- a/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp +++ b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp @@ -39,6 +39,7 @@ #include "core/modelIR/TensorUtil.h" #include "passes/common_frontend/shape_helper.h" +#include "passes/common_frontend/op_creator_helper.h" #include "pass/PassException.h" #include "caffe_op_creator.h" @@ -52,65 +53,6 @@ namespace nnc { using namespace mir; using namespace ::caffe; -/** Convert kernel for grouped 2d convolution in kernel for 
ordinary 2d convolution - * - * Grouped convolution breaks input and kernel channels into selected number of groups and applies convolution in every group of channels independently. - * This technique allows to save kernel size(channels from different groups are not merged, no need to store redundant 0 weights). - * This is not supported by compiler for now, so this function unfolds compact kernel into classic flavored "every input layer affects every output layer", - * by inserting zero coefficients where needed - * - * @param groups number of groups in grouped convolution - * @param foldedKernel original grouped kernel - * @return unfolded kernel, compatible with ordinary conv2D operation - */ -static std::shared_ptr -fixGroupedKernel(int groups, std::shared_ptr folded_kernel) { - const int kernel_in_chan_num = 2; - const int kernel_out_chan_num = 3; - - const Shape& kernel_shape = folded_kernel->getShape(); - auto kernel_in_channels = kernel_shape.dim(kernel_in_chan_num); - auto kernel_out_channels = kernel_shape.dim(kernel_out_chan_num); - auto in_channels = kernel_in_channels * groups; - - // Original kernel has shape [H, W, inputChannels/groups, outputChannels] - // here creates unfolded kernel with shape [H, W, inputChannels, outputChannels] - Shape unfold_kernel_shape(kernel_shape); - unfold_kernel_shape.dim(kernel_in_chan_num) = in_channels; - auto buffer_size = unfold_kernel_shape.numElements() * folded_kernel->getElementSize(); - std::shared_ptr buffer(new char[buffer_size], std::default_delete()); - size_t data_size = folded_kernel->getElementSize(); - std::shared_ptr unfold_kernel = - std::make_shared(unfold_kernel_shape, buffer, folded_kernel->getDataType(), - data_size); - - int in_group_size = kernel_in_channels; - int out_group_size = kernel_out_channels / groups; - assert(kernel_out_channels % groups == 0); - - // Iterate over "unfolded" kernel Shape and insert appropriate values into result kernel - for (const mir::Index& idx: 
mir::ShapeRange(unfold_kernel_shape)) { - auto in_group_no = idx.at(kernel_in_chan_num) / in_group_size; - auto out_group_no = idx.at(kernel_out_chan_num) / out_group_size; - // check that input channel group fits output channel group - if (in_group_no == out_group_no) { - // compute index in original kernel that corresponds output index - mir::Index folded_idx(idx); - folded_idx.at(kernel_in_chan_num) %= in_group_size; - - std::copy(folded_kernel->at(folded_idx), folded_kernel->at(folded_idx) + data_size, - unfold_kernel->at(idx)); - } else { - // fill element of output kernel with zero element - assert(folded_kernel->getDataType() == DTYPE::FLOAT32 && - "unsupported data type, add appropriate zero element creation"); - float* elem = reinterpret_cast(unfold_kernel->at(idx)); - *elem = 0.0f; - } - } - return unfold_kernel; -} - mir::IODescriptor CaffeOpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) { if (cli::debugTranspose) { // NCHW -> NHWC diff --git a/contrib/nnc/passes/common_frontend/CMakeLists.txt b/contrib/nnc/passes/common_frontend/CMakeLists.txt index 84df465..cb3cc2e 100644 --- a/contrib/nnc/passes/common_frontend/CMakeLists.txt +++ b/contrib/nnc/passes/common_frontend/CMakeLists.txt @@ -2,8 +2,7 @@ # Common for every importer code library # ########################################## -set(COMMON_SOURCES - model_allocation.cpp) +set(COMMON_SOURCES model_allocation.cpp op_creator_helper.cpp) add_library(nn_import_common STATIC ${COMMON_SOURCES}) set_target_properties(nn_import_common PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/contrib/nnc/passes/common_frontend/op_creator_helper.cpp b/contrib/nnc/passes/common_frontend/op_creator_helper.cpp new file mode 100644 index 0000000..7bc0315 --- /dev/null +++ b/contrib/nnc/passes/common_frontend/op_creator_helper.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "passes/common_frontend/op_creator_helper.h" + +#include "core/modelIR/Shape.h" +#include "core/modelIR/ShapeRange.h" +#include "core/modelIR/TensorVariant.h" + +namespace nnc { + +using namespace mir; + +std::shared_ptr +fixGroupedKernel(int groups, std::shared_ptr folded_kernel) { + const int kernel_in_chan_num = 2; + const int kernel_out_chan_num = 3; + + const Shape& kernel_shape = folded_kernel->getShape(); + auto kernel_in_channels = kernel_shape.dim(kernel_in_chan_num); + auto kernel_out_channels = kernel_shape.dim(kernel_out_chan_num); + auto in_channels = kernel_in_channels * groups; + + // Original kernel has shape [H, W, inputChannels/groups, outputChannels] + // here creates unfolded kernel with shape [H, W, inputChannels, outputChannels] + Shape unfold_kernel_shape(kernel_shape); + unfold_kernel_shape.dim(kernel_in_chan_num) = in_channels; + auto buffer_size = unfold_kernel_shape.numElements() * folded_kernel->getElementSize(); + std::shared_ptr buffer(new char[buffer_size], std::default_delete()); + size_t data_size = folded_kernel->getElementSize(); + std::shared_ptr unfold_kernel = + std::make_shared(unfold_kernel_shape, buffer, folded_kernel->getDataType(), + data_size); + + int in_group_size = kernel_in_channels; + int out_group_size = kernel_out_channels / groups; + assert(kernel_out_channels % groups == 0); + + // Iterate over "unfolded" kernel Shape and 
insert appropriate values into result kernel + for (const mir::Index& idx: mir::ShapeRange(unfold_kernel_shape)) { + auto in_group_no = idx.at(kernel_in_chan_num) / in_group_size; + auto out_group_no = idx.at(kernel_out_chan_num) / out_group_size; + // check that input channel group fits output channel group + if (in_group_no == out_group_no) { + // compute index in original kernel that corresponds output index + mir::Index folded_idx(idx); + folded_idx.at(kernel_in_chan_num) %= in_group_size; + + std::copy(folded_kernel->at(folded_idx), folded_kernel->at(folded_idx) + data_size, + unfold_kernel->at(idx)); + } else { + // fill element of output kernel with zero element + assert(folded_kernel->getDataType() == DTYPE::FLOAT32 && + "unsupported data type, add appropriate zero element creation"); + float* elem = reinterpret_cast(unfold_kernel->at(idx)); + *elem = 0.0f; + } + } + return unfold_kernel; +} + +} // namespace nnc \ No newline at end of file