Enhance CPU support in Gloo-based multi-node mode. (#11330)
author    Shane Li <shane.li@intel.com>
Tue, 15 Jan 2019 19:07:55 +0000 (11:07 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 19:47:10 +0000 (11:47 -0800)
Summary:
1. Add the Gloo communication operators to the IDEEP fallback list;
2. Work around compile errors when a fallback operator's CPU implementation inherits directly from 'OperatorBase' (for example, PrefetchOperator);
3. Add CPU/IDEEP context support to several Python modules and to the ResNet-50 training example.
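
Hypothetical two-node launch of the updated trainer (paths are placeholders; flag names other than the new --use_ideep come from the trainer's existing argument parser). Each node runs the same command with its own --shard_id, and rendezvous happens over the shared file store:

    python caffe2/python/examples/resnet50_trainer.py \
        --train_data /path/to/train_lmdb --use_ideep True \
        --num_shards 2 --shard_id 0 --file_store_path /mnt/shared/rendezvous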
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11330

Reviewed By: yinghai

Differential Revision: D13624519

Pulled By: wesolwsk

fbshipit-source-id: ce39d57ddb8cd7786db2e873bfe954069d972f4f

caffe2/ideep/operators/operator_fallback_ideep.cc
caffe2/ideep/operators/operator_fallback_ideep.h
caffe2/image/image_input_op.cc
caffe2/python/data_parallel_model.py
caffe2/python/examples/resnet50_trainer.py
caffe2/python/predictor/predictor_exporter.py
cmake/Dependencies.cmake

caffe2/ideep/operators/operator_fallback_ideep.cc
index 6016923..bacb546 100644
 #include <caffe2/operators/transpose_op.h>
 #include <caffe2/operators/affine_channel_op.h>
 #include <caffe2/operators/stop_gradient.h>
+#include <caffe2/operators/order_switch_ops.h>
+#include <caffe2/operators/softmax_with_loss_op.h>
 #include <caffe2/sgd/iter_op.h>
 #include <caffe2/sgd/learning_rate_op.h>
 #include <caffe2/queue/queue_ops.h>
 #include <caffe2/operators/tensor_protos_db_input.h>
 
+#ifdef CAFFE2_USE_GLOO
+#include <caffe2/contrib/gloo/common_world_ops.h>
+#include <caffe2/contrib/gloo/broadcast_ops.h>
+#include <caffe2/contrib/gloo/allreduce_ops.h>
+#include <caffe2/contrib/gloo/allgather_ops.h>
+#include <caffe2/contrib/gloo/barrier_ops.h>
+#include <caffe2/contrib/gloo/reduce_scatter_ops.h>
+#endif
+
 // can add more non-IDEEP operators if needed
 namespace caffe2 {
 
@@ -170,5 +181,48 @@ REGISTER_IDEEP_OPERATOR(
         MulFunctor<CPUContext>>>);
 REGISTER_IDEEP_OPERATOR(TensorProtosDBInput, IDEEPFallbackOp<TensorProtosDBInput<CPUContext>>);
 REGISTER_IDEEP_OPERATOR(CloseBlobsQueue, IDEEPFallbackOp<CloseBlobsQueueOp<CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    SoftmaxWithLoss,
+    IDEEPFallbackOp<SoftmaxWithLossOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    SoftmaxWithLossGradient,
+    IDEEPFallbackOp<SoftmaxWithLossGradientOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    NHWC2NCHW,
+    IDEEPFallbackOp<NHWC2NCHWOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    NCHW2NHWC,
+    IDEEPFallbackOp<NCHW2NHWCOp<float, CPUContext>>);
+
+#ifdef CAFFE2_USE_GLOO
+namespace gloo {
+// gloo operators
+REGISTER_IDEEP_OPERATOR(
+    CreateCommonWorld,
+    IDEEPFallbackOp<CreateCommonWorld<CPUContext>, SkipIndices<0>>);
+REGISTER_IDEEP_OPERATOR(
+    CloneCommonWorld,
+    IDEEPFallbackOp<CloneCommonWorld<CPUContext>, SkipIndices<0>>);
+REGISTER_IDEEP_OPERATOR(
+    DestroyCommonWorld,
+    IDEEPFallbackOp<DestroyCommonWorld>);
+REGISTER_IDEEP_OPERATOR(
+    Broadcast,
+    IDEEPFallbackOp<BroadcastOp<CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    Allreduce,
+    IDEEPFallbackOp<AllreduceOp<CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    Allgather,
+    IDEEPFallbackOp<AllgatherOp<CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    Barrier,
+    IDEEPFallbackOp<BarrierOp<CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    ReduceScatter,
+    IDEEPFallbackOp<ReduceScatterOp<CPUContext>>);
+
+} // namespace gloo
+#endif
 
 } // namespace caffe2
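
The registration pattern above can be reused for other CPU-only operators. IDEEPFallbackOp copies ideep tensor inputs to CPU tensors, runs the wrapped CPU op, and copies the outputs back; SkipIndices<0> marks output 0 as non-convertible, which the CommonWorld ops need because their first output is an opaque Gloo handle rather than a tensor. A minimal sketch with a hypothetical operator MyCpuOp (not part of this change):

    #include <caffe2/ideep/operators/operator_fallback_ideep.h>
    #include <caffe2/ideep/utils/ideep_operator.h>

    namespace caffe2 {
    // Hypothetical: MyCpuOp is an existing Operator<CPUContext> implementation.
    REGISTER_IDEEP_OPERATOR(MyCpuOp, IDEEPFallbackOp<MyCpuOp<CPUContext>>);
    // For ops whose output 0 is not a tensor (e.g., a common world handle):
    // REGISTER_IDEEP_OPERATOR(MyWorldOp, IDEEPFallbackOp<MyWorldOp, SkipIndices<0>>);
    } // namespace caffe2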

caffe2/ideep/operators/operator_fallback_ideep.h
index e834ed9..26d07ea 100644
@@ -111,6 +111,8 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
       }
     }
 
+    // Some CPU ops, such as PrefetchOperator, inherit directly from
+    // OperatorBase and need the stream id argument '0' passed explicitly.
     if (!base_op_->Run(0)) {
       LOG(ERROR) << "Base op run failed in IDEEPFallbackOp. Def: "
                  << ProtoDebugString(this->debug_def());
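
The compile error this works around can be reproduced with a standalone sketch (not the actual caffe2 classes): a derived class that redeclares a virtual member without repeating the base's default argument does not inherit that default, so calling Run() through the derived type fails to compile while Run(0) always works.

    // standalone illustration of the default-argument hiding issue
    struct Base {
      virtual bool Run(int stream_id = 0) { return true; }
    };
    struct Prefetching : Base {
      // redeclared without '= 0'; the default is not inherited
      bool Run(int stream_id) override { return true; }
    };

    int main() {
      Prefetching op;
      // op.Run();              // error: candidate function expects 1 argument
      return op.Run(0) ? 0 : 1; // explicit '0' always compiles
    }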

caffe2/image/image_input_op.cc
index 4af9328..a01994c 100644
@@ -1,5 +1,10 @@
 #include "caffe2/image/image_input_op.h"
 
+#ifdef CAFFE2_USE_MKLDNN
+#include <caffe2/ideep/operators/operator_fallback_ideep.h>
+#include <caffe2/ideep/utils/ideep_operator.h>
+#endif
+
 namespace caffe2 {
 
 template <>
@@ -112,4 +117,10 @@ The dimension of the output image will always be cropxcrop
 
 NO_GRADIENT(ImageInput);
 
+#ifdef CAFFE2_USE_MKLDNN
+REGISTER_IDEEP_OPERATOR(
+    ImageInput,
+    IDEEPFallbackOp<ImageInputOp<CPUContext>>);
+#endif
+
 }  // namespace caffe2

caffe2/python/data_parallel_model.py
index 7a76545..cfe10ec 100644
@@ -38,6 +38,9 @@ def Parallelize_CPU(*args, **kwargs):
     kwargs['cpu_device'] = True
     Parallelize(*args, **kwargs)
 
+def Parallelize_iDeep(*args, **kwargs):
+    kwargs['ideep'] = True
+    Parallelize(*args, **kwargs)
 
 def Parallelize(
     model_helper_obj,
@@ -58,6 +61,7 @@ def Parallelize(
     use_nccl=False,
     max_concurrent_distributed_ops=16,
     cpu_device=False,
+    ideep=False,
     num_threads_per_device=4,
     shared_model=False,
     combine_spatial_bn=False,
@@ -119,6 +123,7 @@ def Parallelize(
       blobs_to_keep :   A list of blob names to keep and don't free during
                         dynamic memory optimization (for example loss blob).
       cpu_device        Use CPU instead of GPU.
+      ideep             Use IDEEP (MKL-DNN) devices instead of GPU.
       combine_spatial_bn:
                         When set to True, applies batch normalization across
                         all devices within the node. If False, batch
@@ -135,12 +140,12 @@ def Parallelize(
         device scope was: {}".format(scope.CurrentDeviceScope())
 
     if devices is None:
-        if not cpu_device:
+        if not (cpu_device or ideep):
             devices = list(range(0, workspace.NumGpuDevices()))
         else:
             devices = list(range(0, cpu_count()))
 
-    if not cpu_device:
+    if not (cpu_device or ideep):
         for gpu in devices:
             if gpu >= workspace.NumGpuDevices():
                 log.warning("** Only {} GPUs available, GPUs {} requested".format(
@@ -151,6 +156,13 @@ def Parallelize(
         model_helper_obj._shared_model = False
         device_name = "GPU"
         assert shared_model is False, "Shared model only supported on CPU"
+    elif ideep:
+        model_helper_obj._device_type = caffe2_pb2.IDEEP
+        model_helper_obj._device_prefix = "ideep"
+        device_name = "IDEEP"
+        model_helper_obj._shared_model = shared_model
+        assert not shared_model or rendezvous is None, \
+            "Shared model only supported on single-node currently"
     else:
         model_helper_obj._device_type = caffe2_pb2.CPU
         model_helper_obj._device_prefix = "cpu"
@@ -969,7 +981,7 @@ def GetLearningRateBlobNames(model):
     Returns a list of learning rates blob names used in the optimizer.
     '''
     if model._optimizer is not None:
-        if model._device_type == caffe2_pb2.CPU:
+        if model._device_type in (caffe2_pb2.CPU, caffe2_pb2.IDEEP):
             return [model._optimizer.get_cpu_blob_name('lr')]
         elif core.IsGPUDeviceType(model._device_type):
             return [model._optimizer.get_gpu_blob_name('lr', gpu, '')
@@ -1160,6 +1172,7 @@ def _SyncAllParamsDistributed(
 
     gpu_device_opt = core.DeviceOption(model._device_type, devices[0])
     cpu_device_opt = core.DeviceOption(caffe2_pb2.CPU)
+    ideep_device_opt = core.DeviceOption(caffe2_pb2.IDEEP)
 
     if model._broadcast_context is None:
         model._broadcast_context = CollectivesConcurrencyControl(
@@ -1186,7 +1199,7 @@ def _SyncAllParamsDistributed(
 
         device_opt = gpu_device_opt if _IsGPUBlob(
             model, param_name
-        ) else cpu_device_opt
+        ) else ideep_device_opt if _IsIDEEPBlob(model, param_name) else cpu_device_opt
 
         if rendezvous['engine'] == 'GLOO':
             with core.DeviceScope(device_opt):
@@ -1587,6 +1600,17 @@ def _InferBlobDevice(model):
     map_ops(model.net.Proto())
     model._blob_to_device = mapping
 
+def _IsIDEEPBlob(model, blob_name):
+    if blob_name in model._blob_to_device:
+        return model._blob_to_device[blob_name].device_type == caffe2_pb2.IDEEP
+    else:
+        blob_name = "{}_{}/{}".format(
+            model._device_prefix, model._devices[0], blob_name
+        )
+        if blob_name not in model._blob_to_device:
+            return model._device_type == caffe2_pb2.IDEEP
+        return model._blob_to_device[blob_name].device_type == caffe2_pb2.IDEEP
+
 def _IsGPUBlob(model, blob_name):
     if blob_name in model._blob_to_device:
         return core.IsGPUDeviceType(model._blob_to_device[blob_name].device_type)
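
A minimal usage sketch for the new entry point, assuming the builder callbacks (input_fn, forward_fn, opt_fn) are defined as usual for data_parallel_model and that rendezvous is a Gloo rendezvous dict (or None for single-node):

    from caffe2.python import data_parallel_model, model_helper

    model = model_helper.ModelHelper(name="resnet50_ideep")
    data_parallel_model.Parallelize_iDeep(
        model,
        input_builder_fun=input_fn,
        forward_pass_builder_fun=forward_fn,
        optimizer_builder_fun=opt_fn,
        devices=[0],            # device indices, as in the CPU path
        rendezvous=rendezvous,  # enables Gloo-based multi-node sync
    )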

caffe2/python/examples/resnet50_trainer.py
index 307d7b2..928c85b 100644
@@ -107,7 +107,7 @@ def AddNullInput(model, reader, batch_size, img_size, dtype):
     )
 
 
-def SaveModel(args, train_model, epoch):
+def SaveModel(args, train_model, epoch, use_ideep):
     prefix = "[]_{}".format(train_model._device_prefix, train_model._devices[0])
     predictor_export_meta = pred_exp.PredictorExportMeta(
         predict_net=train_model.net.Proto(),
@@ -134,10 +134,11 @@ def SaveModel(args, train_model, epoch):
         db_type="minidb",
         db_destination=model_path,
         predictor_export_meta=predictor_export_meta,
+        use_ideep=use_ideep,
     )
 
 
-def LoadModel(path, model):
+def LoadModel(path, model, use_ideep):
     '''
     Load pretrained model from file
     '''
@@ -148,8 +149,12 @@
     predict_init_net = core.Net(pred_utils.GetNet(
         meta_net_def, predictor_constants.PREDICT_INIT_NET_TYPE))
 
-    predict_init_net.RunAllOnGPU()
-    init_net.RunAllOnGPU()
+    if use_ideep:
+        predict_init_net.RunAllOnIDEEP()
+        init_net.RunAllOnIDEEP()
+    else:
+        predict_init_net.RunAllOnGPU()
+        init_net.RunAllOnGPU()
 
     assert workspace.RunNetOnce(predict_init_net)
     assert workspace.RunNetOnce(init_net)
@@ -288,12 +295,19 @@ def Train(args):
     log.info("Using epoch size: {}".format(args.epoch_size))
 
     # Create ModelHelper object
-    train_arg_scope = {
-        'order': 'NCHW',
-        'use_cudnn': True,
-        'cudnn_exhaustive_search': True,
-        'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
-    }
+    if args.use_ideep:
+        train_arg_scope = {
+            'use_cudnn': False,
+            'cudnn_exhaustive_search': False,
+            'training_mode': 1
+        }
+    else:
+        train_arg_scope = {
+            'order': 'NCHW',
+            'use_cudnn': True,
+            'cudnn_exhaustive_search': True,
+            'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
+        }
     train_model = model_helper.ModelHelper(
         name='resnext' + str(args.num_layers), arg_scope=train_arg_scope
     )
@@ -469,6 +483,7 @@ def Train(args):
         rendezvous=rendezvous,
         optimize_gradient_memory=False,
         cpu_device=args.use_cpu,
+        ideep=args.use_ideep,
         shared_model=args.use_cpu,
         combine_spatial_bn=args.use_cpu,
     )
@@ -482,11 +497,17 @@ def Train(args):
     test_model = None
     if (args.test_data is not None):
         log.info("----- Create test net ----")
-        test_arg_scope = {
-            'order': "NCHW",
-            'use_cudnn': True,
-            'cudnn_exhaustive_search': True,
-        }
+        if args.use_ideep:
+            test_arg_scope = {
+                'use_cudnn': False,
+                'cudnn_exhaustive_search': False,
+            }
+        else:
+            test_arg_scope = {
+                'order': "NCHW",
+                'use_cudnn': True,
+                'cudnn_exhaustive_search': True,
+            }
         test_model = model_helper.ModelHelper(
             name='resnext' + str(args.num_layers) + "_test",
             arg_scope=test_arg_scope,
@@ -526,7 +547,7 @@ def Train(args):
     epoch = 0
     # load the pre-trained model and reset epoch
     if args.load_model_path is not None:
-        LoadModel(args.load_model_path, train_model)
+        LoadModel(args.load_model_path, train_model, args.use_ideep)
 
         # Sync the model params
         data_parallel_model.FinalizeAfterCheckpoint(train_model)
@@ -564,7 +585,7 @@ def Train(args):
         )
 
         # Save the model for each epoch
-        SaveModel(args, train_model, epoch)
+        SaveModel(args, train_model, epoch, args.use_ideep)
 
         model_path = "%s/%s_" % (
             args.file_store_path,
@@ -638,6 +659,8 @@ def main():
                         help="Load previously saved model to continue training")
     parser.add_argument("--use_cpu", type=bool, default=False,
                         help="Use CPU instead of GPU")
+    parser.add_argument("--use_ideep", type=bool, default=False,
+                        help="Use ideep")
     parser.add_argument('--dtype', default='float',
                         choices=['float', 'float16'],
                         help='Data type used for training')

caffe2/python/predictor/predictor_exporter.py
index 60b4ab7..b4a1635 100644
@@ -189,8 +189,9 @@ def set_model_info(meta_net_def, project_str, model_class_str, version):
     meta_net_def.modelInfo.version = version
 
 
-def save_to_db(db_type, db_destination, predictor_export_meta):
+def save_to_db(db_type, db_destination, predictor_export_meta, use_ideep=False):
     meta_net_def = get_meta_net_def(predictor_export_meta)
+    device_type = caffe2_pb2.IDEEP if use_ideep else caffe2_pb2.CPU
     with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
         workspace.FeedBlob(
             predictor_constants.META_NET_DEF,
@@ -202,6 +203,7 @@ def save_to_db(db_type, db_destination, predictor_export_meta):
     op = core.CreateOperator(
         "Save",
         blobs_to_save, [],
+        device_option=core.DeviceOption(device_type),
         absolute_path=True,
         db=db_destination, db_type=db_type)
 

cmake/Dependencies.cmake
index 84f73d4..a49e597 100644
@@ -823,6 +823,7 @@ if(USE_GLOO)
     if(USE_CUDA)
       list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS gloo_cuda)
     endif()
+    add_compile_options(-DCAFFE2_USE_GLOO)
   endif()
 endif()