From: Cheng,Penghui Date: Sat, 30 Mar 2019 01:51:50 +0000 (-0700) Subject: support pre-convert filter format for mkldnn training mode and change 'OptimizeForIde... X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~538 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e13101e0691b0eabc1900f482a615ea7f14e7a72;p=platform%2Fupstream%2Fpytorch.git support pre-convert filter format for mkldnn training mode and change 'OptimizeForIdeep' to 'OptimizeForMkldnn' (#15171) Summary: For MKL-DNN, the filter data will be reordered to primitive format, which takes a lot of time. So this patch provides a method to convert filter format before training. And "OptimizeForIdeep" will be changed to "OptimizeForMkldnn" in this patch. This patch depends on https://github.com/pytorch/pytorch/pull/12866 Pull Request resolved: https://github.com/pytorch/pytorch/pull/15171 Differential Revision: D14590741 Pulled By: yinghai fbshipit-source-id: 07971c9977edac3c8eec08ca2c39cda639683492 --- diff --git a/caffe2/ideep/ideep_utils.h b/caffe2/ideep/ideep_utils.h index a379d7c..db4195c 100644 --- a/caffe2/ideep/ideep_utils.h +++ b/caffe2/ideep/ideep_utils.h @@ -9,6 +9,12 @@ namespace caffe2 { +enum ConvAlgorithm { + CONV_ALGORITHM_AUTO = 0, + CONV_ALGORITHM_WINOGRAD = 1, + CONV_ALGORITHM_MAX = CONV_ALGORITHM_WINOGRAD + 1 +}; + #define USE_IDEEP_DEF_ALIASES() \ using itensor = ideep::tensor; \ using iformat = ideep::format; \ @@ -18,7 +24,4 @@ namespace caffe2 { using iattr = ideep::descriptor_group::attr_t; \ using ibn_flag = ideep::batch_normalization_flag; -const int CONV_ALGORITHM_AUTO = 0; -const int CONV_ALGORITHM_WINOGRAD = 1; - } // namespace caffe2 diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc index 6f5cde5..44b8b0c 100644 --- a/caffe2/opt/converter.cc +++ b/caffe2/opt/converter.cc @@ -140,8 +140,30 @@ class ConvConverter : public Converter { ~ConvConverter() override {} }; +class ConvTransposeConverter : public Converter { + std::unique_ptr 
convertToNeuralNetOperator( + const OperatorDef& op) override { + std::unique_ptr nnOp; + auto argMap = getArgumentsFromOperator(op); + auto kernelShape = getKernelShape(argMap); + nnOp = util::make_unique(kernelShape); + auto c = dyn_cast(nnOp.get()); + + c->setStrides(getStrides(argMap)); + c->setPads(getPads(argMap)); + c->setGroup(getGroup(argMap)); + + return nnOp; + } + // Does not override default converter to OperatorDef + + virtual ~ConvTransposeConverter() {} +}; + REGISTER_CONVERTER(Conv, ConvConverter); +REGISTER_CONVERTER(ConvTranspose, ConvTransposeConverter); + TRIVIAL_CONVERTER(Relu); REGISTER_CONVERTER(Relu, ReluConverter); diff --git a/caffe2/opt/optimize_ideep.cc b/caffe2/opt/optimize_ideep.cc index c657cd7..05bce30 100644 --- a/caffe2/opt/optimize_ideep.cc +++ b/caffe2/opt/optimize_ideep.cc @@ -12,7 +12,7 @@ namespace opt { using namespace nom; #ifndef CAFFE2_USE_MKLDNN -void OptimizeForIdeep( +void OptimizeForMkldnn( repr::NNModule* nn, caffe2::Workspace* ws, bool training_mode) { @@ -37,6 +37,15 @@ T* getTensor(Blob* blob) { return nullptr; } +template +T* getMutableTensor(Blob* blob) { + CAFFE_ENFORCE(blob, "Blob is invalid"); + if (blob->template IsType()) { + return blob->template GetMutable(); + } + return nullptr; +} + const caffe2::OperatorDef& getOpDef(const repr::NeuralNetOperator& nnOp) { auto annotation = nnOp.getAnnotation(); if (annotation == nullptr) { @@ -72,7 +81,7 @@ bool shouldFuseConv(const repr::Conv& conv) { return isOnIdeepDevice(conv) ? 
(conv.getGroup() <= 1) : false; } -void removeStopGradientForInference(repr::NNModule *nn) { +void removeStopGradientForInference(repr::NNModule* nn) { auto isStopGradientNode = [](const repr::NNGraph::NodeRef& node) { if (!repr::nn::is(node)) { return false; @@ -430,10 +439,10 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) { } } -void setPoolingInferenceMode(repr::NNModule *nn) { +void setPoolingInferenceMode(repr::NNModule* nn) { for (auto node_pair : repr::nn::dataIterator(nn->dataFlow)) { repr::NNGraph::NodeRef maxPoolNode; - repr::MaxPool *maxPool; + repr::MaxPool* maxPool; std::tie(maxPool, maxPoolNode) = node_pair; if (!isOnIdeepDevice(*maxPool)) { @@ -441,9 +450,9 @@ void setPoolingInferenceMode(repr::NNModule *nn) { continue; } - auto *op = getMutableOpDef(*maxPool); + auto* op = getMutableOpDef(*maxPool); bool found_training_mode = false; - for (auto &arg : *op->mutable_arg()) { + for (auto& arg : *op->mutable_arg()) { if (arg.name() == "training_mode") { arg.set_i(0); found_training_mode = true; @@ -452,19 +461,149 @@ void setPoolingInferenceMode(repr::NNModule *nn) { } if (!found_training_mode) { - auto *arg = op->add_arg(); + auto* arg = op->add_arg(); arg->set_name("training_mode"); arg->set_i(0); } } } -void OptimizeForIdeep( - repr::NNModule* nn, - caffe2::Workspace* ws, - bool training_mode) { +// Pre-convert filters format to expected one here +// in order to avoid boring conversions during computations +void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) { + for (auto& node : nn->dataFlow.getMutableNodes()) { + if (!repr::nn::is(node) && + !repr::nn::is(node) && !repr::nn::is(node)) { + continue; + } + + auto* nnOp = repr::nn::get(node); + if (!isOnIdeepDevice(*nnOp)) { + LOG(INFO) << "Not a IDEEP operator"; + continue; + } + + auto inputs = repr::nn::getInputs(node); + if (inputs.size() < 2) { + LOG(WARNING) << "Invalid input size"; + continue; + } + + auto* filterBlob = getBlob(inputs[1], ws); + auto* filter = 
getMutableTensor(filterBlob); + if (filter == nullptr) { + continue; + } + + itensor::descriptor expectedDesc; + if (repr::nn::is(node)) { + if (filter->get_public_format() == ideep::format::iohw) + continue; + auto convTranspose = repr::nn::get(node); + auto initValue = [](vector& v, vector i) { + if (v.empty()) + v = i; + }; + auto strides = convTranspose->getStrides(); + initValue(strides, {1, 1}); + auto pads = convTranspose->getPads(); + initValue(pads, {0, 0, 0, 0}); + auto* op = getMutableOpDef(*convTranspose); + auto aalgorithm = ialgo::deconvolution_direct; + auto dataType = filter->get_data_type(); + ideep::tensor::dims filter_dims_mkldnn{filter->get_dim(1), + filter->get_dim(0), + filter->get_dim(2), + filter->get_dim(3)}; + expectedDesc = + ideep::convolution_transpose_forward::expected_weights_descriptor( + filter_dims_mkldnn, + dataType, + strides, + {pads[0], pads[1]}, + {pads[2], pads[3]}); + + if (filter->get_descriptor() != expectedDesc) { + filter->set_public_format(ideep::format::iohw); + itensor&& newFilter(expectedDesc); + ideep::reorder::compute(*filter, newFilter); + newFilter.set_public_format(ideep::format::iohw); + filterBlob->Reset(new itensor(newFilter)); + } + } else if (repr::nn::is(node)) { + auto conv = repr::nn::get(node); + auto initValue = [](vector& v, vector i) { + if (v.empty()) + v = i; + }; + auto strides = conv->getStrides(); + initValue(strides, {1, 1}); + auto pads = conv->getPads(); + initValue(pads, {0, 0, 0, 0}); + auto dilations = conv->getDilations(); + initValue(dilations, {1, 1}); + + auto* op = getMutableOpDef(*conv); + auto aalgorithm = ialgo::convolution_direct; + for (auto& arg : *op->mutable_arg()) { + if ((arg.name() == "conv_algorithm") && + (arg.i() == CONV_ALGORITHM_WINOGRAD)) { + aalgorithm = ialgo::convolution_winograd; + } + } + auto dataType = filter->get_data_type(); + + filter->make_group(conv->getGroup()); + expectedDesc = ideep::convolution_forward::expected_weights_descriptor( + 
filter->get_dims(), + dataType, + strides, + {pads[0], pads[1]}, + {pads[2], pads[3]}, + dilations, + conv->getGroup(), + aalgorithm); + + if (filter->get_descriptor() != expectedDesc) { + itensor&& newFilter(expectedDesc); + ideep::reorder::compute(*filter, newFilter); + filterBlob->Reset(new itensor(newFilter)); + } + // convert weights for FC + } else if (repr::nn::is(node)) { + auto fc = repr::nn::get(node); + auto axis_w = fc->getAxisW(); + if (axis_w != 1) { + auto f_dims = filter->get_dims(); + auto f_dim0 = std::accumulate( + f_dims.begin(), + f_dims.begin() + axis_w, + 1, + std::multiplies()); + auto f_dim1 = std::accumulate( + f_dims.begin() + axis_w, + f_dims.end(), + 1, + std::multiplies()); + filter->reshape({f_dim0, f_dim1}); + } + + expectedDesc = ideep::inner_product_forward::expected_weights_descriptor( + filter->get_dims()); + + if (filter->get_descriptor() != expectedDesc) { + itensor&& newFilter(expectedDesc); + ideep::reorder::compute(filter->as_weights(), newFilter); + filterBlob->Reset(new itensor(newFilter)); + } + } + } +} + +void OptimizeForMkldnn(repr::NNModule *nn, caffe2::Workspace *ws, + bool training_mode) { if (training_mode) { - // Only support inference so far + preConvertFiltersFormat(nn, ws); return; } diff --git a/caffe2/opt/optimize_ideep.h b/caffe2/opt/optimize_ideep.h index d23fb54..85b86bf 100644 --- a/caffe2/opt/optimize_ideep.h +++ b/caffe2/opt/optimize_ideep.h @@ -8,7 +8,7 @@ namespace caffe2 { namespace opt { -CAFFE2_API void OptimizeForIdeep( +CAFFE2_API void OptimizeForMkldnn( nom::repr::NNModule* nn, caffe2::Workspace* ws, bool training_mode = false); diff --git a/caffe2/python/ideep/conv_op_test.py b/caffe2/python/ideep/conv_op_test.py index 02ad31d..0377c83 100644 --- a/caffe2/python/ideep/conv_op_test.py +++ b/caffe2/python/ideep/conv_op_test.py @@ -9,7 +9,7 @@ from hypothesis import given, settings import numpy as np from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace -from 
caffe2.python.transformations import optimizeForIDEEP +from caffe2.python.transformations import optimizeForMKLDNN import caffe2.python.hypothesis_test_util as hu import caffe2.python.ideep_test_util as mu @@ -133,7 +133,7 @@ class ConvTest(hu.HypothesisTestCase): old_net = caffe2_pb2.NetDef() old_net.op.extend([op1]) net.Proto().CopyFrom(old_net) - optimizeForIDEEP(net) + optimizeForMKLDNN(net) workspace.RunOperatorOnce(net.Proto().op[0]) Y1 = workspace.FetchBlob('Y') diff --git a/caffe2/python/ideep/convfusion_op_test.py b/caffe2/python/ideep/convfusion_op_test.py index de66a4d..8c40be8 100644 --- a/caffe2/python/ideep/convfusion_op_test.py +++ b/caffe2/python/ideep/convfusion_op_test.py @@ -10,7 +10,7 @@ import copy import numpy as np from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace -from caffe2.python.transformations import optimizeForIDEEP +from caffe2.python.transformations import optimizeForMKLDNN import caffe2.python.hypothesis_test_util as hu import caffe2.python.ideep_test_util as mu @@ -103,7 +103,7 @@ class ConvFusionTest(hu.HypothesisTestCase): workspace.FeedBlob('b0', b, dc[1]) net = core.Net("net") net.Proto().CopyFrom(old_net) - optimizeForIDEEP(net) + optimizeForMKLDNN(net) self.assertTrue(len(net.Proto().op) == 1) self.assertTrue(net.Proto().op[0].type == "ConvFusion") workspace.RunOperatorOnce(net.Proto().op[0]) @@ -243,7 +243,7 @@ class ConvFusionTest(hu.HypothesisTestCase): workspace.FeedBlob('b0', b, dc[1]) net = core.Net("net") net.Proto().CopyFrom(old_net) - optimizeForIDEEP(net) + optimizeForMKLDNN(net) self.assertTrue(len(net.Proto().op) == 2) self.assertTrue(net.Proto().op[1].type == "ConvFusion") workspace.RunNetOnce(net.Proto()) @@ -393,7 +393,7 @@ class ConvFusionTest(hu.HypothesisTestCase): workspace.FeedBlob('b0', b, dc[1]) net = core.Net("net") net.Proto().CopyFrom(old_net) - optimizeForIDEEP(net) + optimizeForMKLDNN(net) self.assertTrue(len(net.Proto().op) == 2) self.assertTrue(net.Proto().op[1].type == 
"ConvFusion") workspace.RunNetOnce(net.Proto()) @@ -481,7 +481,7 @@ class ConvFusionTest(hu.HypothesisTestCase): workspace.FeedBlob('var', var, dc[1]) net = core.Net("net") net.Proto().CopyFrom(old_net) - optimizeForIDEEP(net) + optimizeForMKLDNN(net) self.assertTrue(len(net.Proto().op) == 1) self.assertTrue(net.Proto().op[0].type == "Conv") workspace.RunOperatorOnce(net.Proto().op[0]) @@ -562,7 +562,7 @@ class ConvFusionTest(hu.HypothesisTestCase): workspace.FeedBlob('bias', bias, dc[1]) net = core.Net("net") net.Proto().CopyFrom(old_net) - optimizeForIDEEP(net) + optimizeForMKLDNN(net) self.assertTrue(len(net.Proto().op) == 1) self.assertTrue(net.Proto().op[0].type == "Conv") workspace.RunOperatorOnce(net.Proto().op[0]) diff --git a/caffe2/python/ideep/pre_convert_test.py b/caffe2/python/ideep/pre_convert_test.py new file mode 100644 index 0000000..a32eedd --- /dev/null +++ b/caffe2/python/ideep/pre_convert_test.py @@ -0,0 +1,97 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +import hypothesis.strategies as st +from hypothesis import given +import numpy as np +from caffe2.proto import caffe2_pb2 +from caffe2.python import ( + brew, + core, + model_helper, + workspace, +) +from caffe2.python.transformations import optimizeForMKLDNN +import caffe2.python.hypothesis_test_util as hu + + +@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") +class PreConvertTest(hu.HypothesisTestCase): + @given(input_channels=st.integers(15, 16), + batch_size=st.integers(1, 3)) + def test_preConvert(self, input_channels, batch_size): + def AddModel(model, data): + conv1 = brew.conv(model, data, 'conv1', dim_in=input_channels, + dim_out=10, kernel=3, stride=1, pad=1, training_mode=1) + deconv1 = brew.conv_transpose(model, conv1, 'deconv1', dim_in=10, dim_out=10, + kernel=2, stride=2, pad=0, training_mode=1) + fc1 = brew.fc(model, deconv1, 'fc1', 
dim_in=10 * 56 * 56, dim_out=3) + softmax = brew.softmax(model, fc1, 'softmax') + + return softmax + + def AddTrainingOperators(model, softmax, label): + """Adds training operators to the model.""" + # Compute cross entropy between softmax scores and labels + xent = model.LabelCrossEntropy([softmax, label], 'xent') + # Compute the expected loss + loss = model.AveragedLoss(xent, "loss") + # Use the average loss we just computed to add gradient operators to the model + model.AddGradientOperators([loss]) + + arg_scope = {"order": "NCHW", 'no_bias': False} + # Create the model helper for the train model + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0) + with core.DeviceScope(device_opt): + train_model = model_helper.ModelHelper(name="test_train", arg_scope=arg_scope) + # Add the model definition (fc layers, conv layers, softmax, etc.) + softmax = AddModel(train_model, "X") + AddTrainingOperators(train_model, softmax, "label") + + X = np.random.rand( + batch_size, input_channels, 28, 28).astype(np.float32) - 0.5 + label = np.random.randint(3, size=batch_size).astype(np.int32) + blob_dict = {} + output_dict = {} + output_dict_cosim = {} + old_ws_name = workspace.CurrentWorkspace() + workspace.FeedBlob('X', X) + workspace.FeedBlob('label', label) + workspace.RunNetOnce(train_model.param_init_net) + for op in train_model.net.Proto().op: + if op.type == "Softmax": + break + for j in range(1, len(op.input)): + blob_dict[op.input[j]] = workspace.FetchBlob(op.input[j]) + + workspace.CreateNet(train_model.net, overwrite=True) + optimizeForMKLDNN(train_model.net, training_mode=True) + workspace.RunNet(train_model.net) + for op in train_model.net.Proto().op: + for blob in op.output: + output_dict[blob] = workspace.FetchBlob(blob) + + workspace.SwitchWorkspace("_device_check_", True) + workspace.FeedBlob('X', X) + workspace.FeedBlob('label', label) + for blob in blob_dict.keys(): + workspace.FeedBlob(blob, blob_dict[blob]) + workspace.CreateNet(train_model.net, 
overwrite=True) + workspace.RunNet(train_model.net) + for blob in output_dict.keys(): + output_dict_cosim[blob] = workspace.FetchBlob(blob) + + for blob in output_dict.keys(): + if not np.allclose(output_dict[blob], output_dict_cosim[blob], atol=0.001, rtol=0.0001): + print("blob {} error".format(blob)) + print(np.max(np.abs(output_dict[blob] - output_dict_cosim[blob]))) + self.assertTrue(False) + + workspace.ResetWorkspace() + workspace.SwitchWorkspace(old_ws_name) + +if __name__ == "__main__": + unittest.main() diff --git a/caffe2/python/ideep/shape_op_test.py b/caffe2/python/ideep/shape_op_test.py index d87ae54..a7defc9 100644 --- a/caffe2/python/ideep/shape_op_test.py +++ b/caffe2/python/ideep/shape_op_test.py @@ -9,7 +9,6 @@ from hypothesis import given, settings import numpy as np from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace -from caffe2.python.transformations import optimizeForIDEEP import caffe2.python.hypothesis_test_util as hu import caffe2.python.ideep_test_util as mu diff --git a/caffe2/python/ideep/transform_ideep_net.py b/caffe2/python/ideep/transform_ideep_net.py index d420a63..6345b76 100644 --- a/caffe2/python/ideep/transform_ideep_net.py +++ b/caffe2/python/ideep/transform_ideep_net.py @@ -277,7 +277,7 @@ def fuse_conv_relu(net): op.device_option.CopyFrom(device_option) new_net = caffe2_pb2.NetDef() - new_net.ParseFromString(C.transform_optimizeForIDEEP(net.SerializeToString())) + new_net.ParseFromString(C.transform_optimizeForMKLDNN(net.SerializeToString())) return new_net diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 1bc772a..e4f3e6f 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -1691,12 +1691,12 @@ void addGlobalMethods(py::module& m) { // into a python interface in transformations.py // Prefix the transformation with transform_ to avoid clobbering the // function namespace. 
- m.def("transform_optimizeForIDEEP", [](py::bytes def, bool training_mode) { + m.def("transform_optimizeForMKLDNN", [](py::bytes def, bool training_mode) { caffe2::NetDef proto; CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast(), &proto)); auto nn = caffe2::convertToNNModule(proto); - opt::OptimizeForIdeep(&nn, gWorkspace, training_mode); + opt::OptimizeForMkldnn(&nn, gWorkspace, training_mode); auto new_proto = caffe2::convertToCaffe2Proto(nn, proto); std::string out; diff --git a/caffe2/python/pybind_state_ideep.cc b/caffe2/python/pybind_state_ideep.cc index 2b20af2..ff4971e 100644 --- a/caffe2/python/pybind_state_ideep.cc +++ b/caffe2/python/pybind_state_ideep.cc @@ -64,7 +64,8 @@ public: numpy_type != -1, "Unsupported ideep memory data type? This usually should not happen " "since ideep memory usually only do float and double."); - itensor::dims dims = atensor.get_dims(); + itensor::dims dims = atensor.get_public_format_dims(); + std::vector npy_dims(dims.begin(), dims.end()); result.copied = force_copy || atensor.need_reorder(); diff --git a/caffe2/python/transformations.py b/caffe2/python/transformations.py index 2f9fc7a..ed0a327 100644 --- a/caffe2/python/transformations.py +++ b/caffe2/python/transformations.py @@ -46,9 +46,9 @@ def fuseNNPACKConvRelu(net): ) -def optimizeForIDEEP(net, training_mode = False): +def optimizeForMKLDNN(net, training_mode = False): net.Proto().ParseFromString( - C.transform_optimizeForIDEEP(net.Proto().SerializeToString(), training_mode) + C.transform_optimizeForMKLDNN(net.Proto().SerializeToString(), training_mode) )