support pre-convert filter format for mkldnn training mode and change 'OptimizeForIde...
authorCheng,Penghui <penghui.cheng@intel.com>
Sat, 30 Mar 2019 01:51:50 +0000 (18:51 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 02:00:48 +0000 (19:00 -0700)
Summary:
For MKL-DNN,the filter data will be reorderd to primitive format, it takes a lot of time.
So the patch provide a method to convert filter format before training.
And "OptimizeForIdeep" will be changed to "OptimizeForMkldnn" in this patch.
 This patch depends on https://github.com/pytorch/pytorch/pull/12866
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15171

Differential Revision: D14590741

Pulled By: yinghai

fbshipit-source-id: 07971c9977edac3c8eec08ca2c39cda639683492

12 files changed:
caffe2/ideep/ideep_utils.h
caffe2/opt/converter.cc
caffe2/opt/optimize_ideep.cc
caffe2/opt/optimize_ideep.h
caffe2/python/ideep/conv_op_test.py
caffe2/python/ideep/convfusion_op_test.py
caffe2/python/ideep/pre_convert_test.py [new file with mode: 0644]
caffe2/python/ideep/shape_op_test.py
caffe2/python/ideep/transform_ideep_net.py
caffe2/python/pybind_state.cc
caffe2/python/pybind_state_ideep.cc
caffe2/python/transformations.py

index a379d7c..db4195c 100644 (file)
@@ -9,6 +9,12 @@
 
 namespace caffe2 {
 
+enum ConvAlgorithm {
+  CONV_ALGORITHM_AUTO = 0,
+  CONV_ALGORITHM_WINOGRAD = 1,
+  CONV_ALGORITHM_MAX = CONV_ALGORITHM_WINOGRAD + 1
+};
+
 #define USE_IDEEP_DEF_ALIASES()                                                \
   using itensor = ideep::tensor;                                               \
   using iformat = ideep::format;                                               \
@@ -18,7 +24,4 @@ namespace caffe2 {
   using iattr = ideep::descriptor_group::attr_t;                               \
   using ibn_flag = ideep::batch_normalization_flag;
 
-const int CONV_ALGORITHM_AUTO = 0;
-const int CONV_ALGORITHM_WINOGRAD = 1;
-
 } // namespace caffe2
index 6f5cde5..44b8b0c 100644 (file)
@@ -140,8 +140,30 @@ class ConvConverter : public Converter {
   ~ConvConverter() override {}
 };
 
+class ConvTransposeConverter : public Converter {
+  std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator(
+      const OperatorDef& op) override {
+    std::unique_ptr<repr::NeuralNetOperator> nnOp;
+    auto argMap = getArgumentsFromOperator(op);
+    auto kernelShape = getKernelShape(argMap);
+    nnOp = util::make_unique<repr::ConvTranspose>(kernelShape);
+    auto c = dyn_cast<repr::ConvTranspose>(nnOp.get());
+
+    c->setStrides(getStrides(argMap));
+    c->setPads(getPads(argMap));
+    c->setGroup(getGroup(argMap));
+
+    return nnOp;
+  }
+  // Does not override default converter to OperatorDef
+
+  virtual ~ConvTransposeConverter() {}
+};
+
 REGISTER_CONVERTER(Conv, ConvConverter);
 
+REGISTER_CONVERTER(ConvTranspose, ConvTransposeConverter);
+
 TRIVIAL_CONVERTER(Relu);
 REGISTER_CONVERTER(Relu, ReluConverter);
 
index c657cd7..05bce30 100644 (file)
@@ -12,7 +12,7 @@ namespace opt {
 using namespace nom;
 
 #ifndef CAFFE2_USE_MKLDNN
-void OptimizeForIdeep(
+void OptimizeForMkldnn(
     repr::NNModule* nn,
     caffe2::Workspace* ws,
     bool training_mode) {
@@ -37,6 +37,15 @@ T* getTensor(Blob* blob) {
   return nullptr;
 }
 
+template <class T>
+T* getMutableTensor(Blob* blob) {
+  CAFFE_ENFORCE(blob, "Blob is invalid");
+  if (blob->template IsType<T>()) {
+    return blob->template GetMutable<T>();
+  }
+  return nullptr;
+}
+
 const caffe2::OperatorDef& getOpDef(const repr::NeuralNetOperator& nnOp) {
   auto annotation = nnOp.getAnnotation();
   if (annotation == nullptr) {
@@ -72,7 +81,7 @@ bool shouldFuseConv(const repr::Conv& conv) {
   return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
 }
 
-void removeStopGradientForInference(repr::NNModule *nn) {
+void removeStopGradientForInference(repr::NNModulenn) {
   auto isStopGradientNode = [](const repr::NNGraph::NodeRef& node) {
     if (!repr::nn::is<repr::NeuralNetOperator>(node)) {
       return false;
@@ -430,10 +439,10 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) {
   }
 }
 
-void setPoolingInferenceMode(repr::NNModule *nn) {
+void setPoolingInferenceMode(repr::NNModulenn) {
   for (auto node_pair : repr::nn::dataIterator<repr::MaxPool>(nn->dataFlow)) {
     repr::NNGraph::NodeRef maxPoolNode;
-    repr::MaxPool *maxPool;
+    repr::MaxPoolmaxPool;
     std::tie(maxPool, maxPoolNode) = node_pair;
 
     if (!isOnIdeepDevice(*maxPool)) {
@@ -441,9 +450,9 @@ void setPoolingInferenceMode(repr::NNModule *nn) {
       continue;
     }
 
-    auto *op = getMutableOpDef(*maxPool);
+    autoop = getMutableOpDef(*maxPool);
     bool found_training_mode = false;
-    for (auto &arg : *op->mutable_arg()) {
+    for (autoarg : *op->mutable_arg()) {
       if (arg.name() == "training_mode") {
         arg.set_i(0);
         found_training_mode = true;
@@ -452,19 +461,149 @@ void setPoolingInferenceMode(repr::NNModule *nn) {
     }
 
     if (!found_training_mode) {
-      auto *arg = op->add_arg();
+      autoarg = op->add_arg();
       arg->set_name("training_mode");
       arg->set_i(0);
     }
   }
 }
 
-void OptimizeForIdeep(
-    repr::NNModule* nn,
-    caffe2::Workspace* ws,
-    bool training_mode) {
+// Pre-convert filters format to expected one here
+// in order to avoid boring conversions during computations
+void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
+  for (auto& node : nn->dataFlow.getMutableNodes()) {
+    if (!repr::nn::is<repr::ConvTranspose>(node) &&
+        !repr::nn::is<repr::Conv>(node) && !repr::nn::is<repr::FC>(node)) {
+      continue;
+    }
+
+    auto* nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
+    if (!isOnIdeepDevice(*nnOp)) {
+      LOG(INFO) << "Not a IDEEP operator";
+      continue;
+    }
+
+    auto inputs = repr::nn::getInputs(node);
+    if (inputs.size() < 2) {
+      LOG(WARNING) << "Invalid input size";
+      continue;
+    }
+
+    auto* filterBlob = getBlob(inputs[1], ws);
+    auto* filter = getMutableTensor<itensor>(filterBlob);
+    if (filter == nullptr) {
+      continue;
+    }
+
+    itensor::descriptor expectedDesc;
+    if (repr::nn::is<repr::ConvTranspose>(node)) {
+      if (filter->get_public_format() == ideep::format::iohw)
+        continue;
+      auto convTranspose = repr::nn::get<repr::ConvTranspose>(node);
+      auto initValue = [](vector<int>& v, vector<int> i) {
+        if (v.empty())
+          v = i;
+      };
+      auto strides = convTranspose->getStrides();
+      initValue(strides, {1, 1});
+      auto pads = convTranspose->getPads();
+      initValue(pads, {0, 0, 0, 0});
+      auto* op = getMutableOpDef(*convTranspose);
+      auto aalgorithm = ialgo::deconvolution_direct;
+      auto dataType = filter->get_data_type();
+      ideep::tensor::dims filter_dims_mkldnn{filter->get_dim(1),
+                                             filter->get_dim(0),
+                                             filter->get_dim(2),
+                                             filter->get_dim(3)};
+      expectedDesc =
+          ideep::convolution_transpose_forward::expected_weights_descriptor(
+              filter_dims_mkldnn,
+              dataType,
+              strides,
+              {pads[0], pads[1]},
+              {pads[2], pads[3]});
+
+      if (filter->get_descriptor() != expectedDesc) {
+        filter->set_public_format(ideep::format::iohw);
+        itensor&& newFilter(expectedDesc);
+        ideep::reorder::compute(*filter, newFilter);
+        newFilter.set_public_format(ideep::format::iohw);
+        filterBlob->Reset<itensor>(new itensor(newFilter));
+      }
+    } else if (repr::nn::is<repr::Conv>(node)) {
+      auto conv = repr::nn::get<repr::Conv>(node);
+      auto initValue = [](vector<int>& v, vector<int> i) {
+        if (v.empty())
+          v = i;
+      };
+      auto strides = conv->getStrides();
+      initValue(strides, {1, 1});
+      auto pads = conv->getPads();
+      initValue(pads, {0, 0, 0, 0});
+      auto dilations = conv->getDilations();
+      initValue(dilations, {1, 1});
+
+      auto* op = getMutableOpDef(*conv);
+      auto aalgorithm = ialgo::convolution_direct;
+      for (auto& arg : *op->mutable_arg()) {
+        if ((arg.name() == "conv_algorithm") &&
+            (arg.i() == CONV_ALGORITHM_WINOGRAD)) {
+          aalgorithm = ialgo::convolution_winograd;
+        }
+      }
+      auto dataType = filter->get_data_type();
+
+      filter->make_group(conv->getGroup());
+      expectedDesc = ideep::convolution_forward::expected_weights_descriptor(
+          filter->get_dims(),
+          dataType,
+          strides,
+          {pads[0], pads[1]},
+          {pads[2], pads[3]},
+          dilations,
+          conv->getGroup(),
+          aalgorithm);
+
+      if (filter->get_descriptor() != expectedDesc) {
+        itensor&& newFilter(expectedDesc);
+        ideep::reorder::compute(*filter, newFilter);
+        filterBlob->Reset<itensor>(new itensor(newFilter));
+      }
+      // convert weights for FC
+    } else if (repr::nn::is<repr::FC>(node)) {
+      auto fc = repr::nn::get<repr::FC>(node);
+      auto axis_w = fc->getAxisW();
+      if (axis_w != 1) {
+        auto f_dims = filter->get_dims();
+        auto f_dim0 = std::accumulate(
+            f_dims.begin(),
+            f_dims.begin() + axis_w,
+            1,
+            std::multiplies<itensor::dim_t>());
+        auto f_dim1 = std::accumulate(
+            f_dims.begin() + axis_w,
+            f_dims.end(),
+            1,
+            std::multiplies<itensor::dim_t>());
+        filter->reshape({f_dim0, f_dim1});
+      }
+
+      expectedDesc = ideep::inner_product_forward::expected_weights_descriptor(
+          filter->get_dims());
+
+      if (filter->get_descriptor() != expectedDesc) {
+        itensor&& newFilter(expectedDesc);
+        ideep::reorder::compute(filter->as_weights(), newFilter);
+        filterBlob->Reset<itensor>(new itensor(newFilter));
+      }
+    }
+  }
+}
+
+void OptimizeForMkldnn(repr::NNModule *nn, caffe2::Workspace *ws,
+                      bool training_mode) {
   if (training_mode) {
-    // Only support inference so far
+    preConvertFiltersFormat(nn, ws);
     return;
   }
 
index d23fb54..85b86bf 100644 (file)
@@ -8,7 +8,7 @@
 namespace caffe2 {
 namespace opt {
 
-CAFFE2_API void OptimizeForIdeep(
+CAFFE2_API void OptimizeForMkldnn(
     nom::repr::NNModule* nn,
     caffe2::Workspace* ws,
     bool training_mode = false);
index 02ad31d..0377c83 100644 (file)
@@ -9,7 +9,7 @@ from hypothesis import given, settings
 import numpy as np
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
-from caffe2.python.transformations import optimizeForIDEEP
+from caffe2.python.transformations import optimizeForMKLDNN
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.ideep_test_util as mu
 
@@ -133,7 +133,7 @@ class ConvTest(hu.HypothesisTestCase):
         old_net = caffe2_pb2.NetDef()
         old_net.op.extend([op1])
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         workspace.RunOperatorOnce(net.Proto().op[0])
         Y1 = workspace.FetchBlob('Y')
 
index de66a4d..8c40be8 100644 (file)
@@ -10,7 +10,7 @@ import copy
 import numpy as np
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
-from caffe2.python.transformations import optimizeForIDEEP
+from caffe2.python.transformations import optimizeForMKLDNN
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.ideep_test_util as mu
 
@@ -103,7 +103,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('b0', b, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 1)
         self.assertTrue(net.Proto().op[0].type == "ConvFusion")
         workspace.RunOperatorOnce(net.Proto().op[0])
@@ -243,7 +243,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('b0', b, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 2)
         self.assertTrue(net.Proto().op[1].type == "ConvFusion")
         workspace.RunNetOnce(net.Proto())
@@ -393,7 +393,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('b0', b, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 2)
         self.assertTrue(net.Proto().op[1].type == "ConvFusion")
         workspace.RunNetOnce(net.Proto())
@@ -481,7 +481,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('var', var, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 1)
         self.assertTrue(net.Proto().op[0].type == "Conv")
         workspace.RunOperatorOnce(net.Proto().op[0])
@@ -562,7 +562,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('bias', bias, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 1)
         self.assertTrue(net.Proto().op[0].type == "Conv")
         workspace.RunOperatorOnce(net.Proto().op[0])
diff --git a/caffe2/python/ideep/pre_convert_test.py b/caffe2/python/ideep/pre_convert_test.py
new file mode 100644 (file)
index 0000000..a32eedd
--- /dev/null
@@ -0,0 +1,97 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+import hypothesis.strategies as st
+from hypothesis import given
+import numpy as np
+from caffe2.proto import caffe2_pb2
+from caffe2.python import (
+    brew,
+    core,
+    model_helper,
+    workspace,
+)
+from caffe2.python.transformations import optimizeForMKLDNN
+import caffe2.python.hypothesis_test_util as hu
+
+
+@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.")
+class PreConvertTest(hu.HypothesisTestCase):
+    @given(input_channels=st.integers(15, 16),
+           batch_size=st.integers(1, 3))
+    def test_preConvert(self, input_channels, batch_size):
+        def AddModel(model, data):
+            conv1 = brew.conv(model, data, 'conv1', dim_in=input_channels,
+                              dim_out=10, kernel=3, stride=1, pad=1, training_mode=1)
+            deconv1 = brew.conv_transpose(model, conv1, 'deconv1', dim_in=10, dim_out=10,
+                                          kernel=2, stride=2, pad=0, training_mode=1)
+            fc1 = brew.fc(model, deconv1, 'fc1', dim_in=10 * 56 * 56, dim_out=3)
+            softmax = brew.softmax(model, fc1, 'softmax')
+
+            return softmax
+
+        def AddTrainingOperators(model, softmax, label):
+            """Adds training operators to the model."""
+            # Compute cross entropy between softmax scores and labels
+            xent = model.LabelCrossEntropy([softmax, label], 'xent')
+            # Compute the expected loss
+            loss = model.AveragedLoss(xent, "loss")
+            # Use the average loss we just computed to add gradient operators to the model
+            model.AddGradientOperators([loss])
+
+        arg_scope = {"order": "NCHW", 'no_bias': False}
+        # Create the model helper for the train model
+        device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
+        with core.DeviceScope(device_opt):
+            train_model = model_helper.ModelHelper(name="test_train", arg_scope=arg_scope)
+            # Add the model definition (fc layers, conv layers, softmax, etc.)
+            softmax = AddModel(train_model, "X")
+            AddTrainingOperators(train_model, softmax, "label")
+
+            X = np.random.rand(
+                batch_size, input_channels, 28, 28).astype(np.float32) - 0.5
+            label = np.random.randint(3, size=batch_size).astype(np.int32)
+            blob_dict = {}
+            output_dict = {}
+            output_dict_cosim = {}
+            old_ws_name = workspace.CurrentWorkspace()
+            workspace.FeedBlob('X', X)
+            workspace.FeedBlob('label', label)
+            workspace.RunNetOnce(train_model.param_init_net)
+            for op in train_model.net.Proto().op:
+                if op.type == "Softmax":
+                    break
+                for j in range(1, len(op.input)):
+                    blob_dict[op.input[j]] = workspace.FetchBlob(op.input[j])
+
+            workspace.CreateNet(train_model.net, overwrite=True)
+            optimizeForMKLDNN(train_model.net, training_mode=True)
+            workspace.RunNet(train_model.net)
+            for op in train_model.net.Proto().op:
+                for blob in op.output:
+                    output_dict[blob] = workspace.FetchBlob(blob)
+
+            workspace.SwitchWorkspace("_device_check_", True)
+            workspace.FeedBlob('X', X)
+            workspace.FeedBlob('label', label)
+            for blob in blob_dict.keys():
+                workspace.FeedBlob(blob, blob_dict[blob])
+            workspace.CreateNet(train_model.net, overwrite=True)
+            workspace.RunNet(train_model.net)
+            for blob in output_dict.keys():
+                output_dict_cosim[blob] = workspace.FetchBlob(blob)
+
+            for blob in output_dict.keys():
+                if not np.allclose(output_dict[blob], output_dict_cosim[blob], atol=0.001, rtol=0.0001):
+                    print("blob {} error".format(blob))
+                    print(np.max(np.abs(output_dict[blob] - output_dict_cosim[blob])))
+                    self.assertTrue(False)
+
+            workspace.ResetWorkspace()
+            workspace.SwitchWorkspace(old_ws_name)
+
+if __name__ == "__main__":
+    unittest.main()
index d87ae54..a7defc9 100644 (file)
@@ -9,7 +9,6 @@ from hypothesis import given, settings
 import numpy as np
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
-from caffe2.python.transformations import optimizeForIDEEP
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.ideep_test_util as mu
 
index d420a63..6345b76 100644 (file)
@@ -277,7 +277,7 @@ def fuse_conv_relu(net):
         op.device_option.CopyFrom(device_option)
 
     new_net = caffe2_pb2.NetDef()
-    new_net.ParseFromString(C.transform_optimizeForIDEEP(net.SerializeToString()))
+    new_net.ParseFromString(C.transform_optimizeForMKLDNN(net.SerializeToString()))
     return new_net
 
 
index 1bc772a..e4f3e6f 100644 (file)
@@ -1691,12 +1691,12 @@ void addGlobalMethods(py::module& m) {
   // into a python interface in transformations.py
   // Prefix the transformation with transform_ to avoid clobbering the
   // function namespace.
-  m.def("transform_optimizeForIDEEP", [](py::bytes def, bool training_mode) {
+  m.def("transform_optimizeForMKLDNN", [](py::bytes def, bool training_mode) {
     caffe2::NetDef proto;
     CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast<std::string>(), &proto));
 
     auto nn = caffe2::convertToNNModule(proto);
-    opt::OptimizeForIdeep(&nn, gWorkspace, training_mode);
+    opt::OptimizeForMkldnn(&nn, gWorkspace, training_mode);
     auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);
 
     std::string out;
index 2b20af2..ff4971e 100644 (file)
@@ -64,7 +64,8 @@ public:
         numpy_type != -1,
         "Unsupported ideep memory data type? This usually should not happen "
         "since ideep memory usually only do float and double.");
-    itensor::dims dims = atensor.get_dims();
+    itensor::dims dims = atensor.get_public_format_dims();
+
     std::vector<npy_intp> npy_dims(dims.begin(), dims.end());
 
     result.copied = force_copy || atensor.need_reorder();
index 2f9fc7a..ed0a327 100644 (file)
@@ -46,9 +46,9 @@ def fuseNNPACKConvRelu(net):
     )
 
 
-def optimizeForIDEEP(net, training_mode = False):
+def optimizeForMKLDNN(net, training_mode = False):
     net.Proto().ParseFromString(
-        C.transform_optimizeForIDEEP(net.Proto().SerializeToString(), training_mode)
+        C.transform_optimizeForMKLDNN(net.Proto().SerializeToString(), training_mode)
     )