Fall back to CPU concat op to handle TensorCPU inputs (#15263)
author Gu, Jinghui <jinghui.gu@intel.com>
Mon, 7 Jan 2019 19:07:51 +0000 (11:07 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 7 Jan 2019 19:13:23 +0000 (11:13 -0800)
Summary:
Fall back to CPU concat op to handle TensorCPU inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15263

Differential Revision: D13587030

Pulled By: yinghai

fbshipit-source-id: 010a8579d61c3beb8556eb92493a552b2ab0030c

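In effect, IDEEPConcatOp no longer rejects non-empty TensorCPU inputs: empty inputs (zero dims or zero elements) of either kind are skipped, and the first non-empty CPU tensor routes the whole op through ConcatOp<CPUContext> via IDEEPFallbackOp. A minimal sketch of the scenario this enables, assuming an MKL-DNN-enabled caffe2 build (blob names and shapes are illustrative, not from the patch):

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

# Place the Concat op on the IDEEP device, but feed plain CPU blobs.
op = core.CreateOperator(
    "Concat",
    ["X_0", "X_1"],
    ["Y", "split_info"],
    axis=1,
    device_option=core.DeviceOption(caffe2_pb2.IDEEP),
)
workspace.FeedBlob("X_0", np.random.rand(2, 3).astype(np.float32))  # TensorCPU
workspace.FeedBlob("X_1", np.random.rand(2, 4).astype(np.float32))  # TensorCPU
# Before this patch the op failed its CAFFE_ENFORCE ("Expect zero dim tensor");
# with the patch it falls back to the CPU ConcatOp and succeeds.
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("Y").shape)  # expected: (2, 7)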
caffe2/ideep/operators/concat_split_op.cc
caffe2/python/ideep/concat_split_op_test.py

diff --git a/caffe2/ideep/operators/concat_split_op.cc b/caffe2/ideep/operators/concat_split_op.cc
index 1a0de0c..b105aad 100644
--- a/caffe2/ideep/operators/concat_split_op.cc
+++ b/caffe2/ideep/operators/concat_split_op.cc
@@ -1,4 +1,6 @@
 #include <caffe2/ideep/ideep_utils.h>
+#include <caffe2/ideep/operators/operator_fallback_ideep.h>
+#include <caffe2/operators/concat_split_op.h>
 
 namespace caffe2 {
 
@@ -6,9 +8,11 @@ class IDEEPConcatOp final : public IDEEPOperator {
  public:
   USE_IDEEP_DEF_ALIASES();
   USE_IDEEP_OPERATOR_FUNCTIONS();
+  using FALLBACK_OP = IDEEPFallbackOp<ConcatOp<CPUContext>, SkipIndices<0>>;
 
   IDEEPConcatOp(const OperatorDef& operator_def, Workspace* ws)
-      : IDEEPOperator(operator_def, ws) {
+      : IDEEPOperator(operator_def, ws),
+        fallback_(operator_def, ws) {
     CAFFE_ENFORCE(
       !(OperatorBase::HasArgument("axis") && OperatorBase::HasArgument("order")),
         "You shouldn't specify both the dim to concat, and the order "
@@ -25,39 +29,47 @@ class IDEEPConcatOp final : public IDEEPOperator {
   virtual ~IDEEPConcatOp() {}
 
   bool RunOnDevice() override {
-    auto* output = Output(OUTPUT);
+    bool fallback_to_cpu = false;
+    vector<itensor> inputs_itensor;
 
-    vector<itensor> inputs;
     for (int i = 0; i < InputSize(); ++i) {
       if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
-        inputs.emplace_back(Input(i));
+        auto& tensor_ideep = Input(i);
+        if (tensor_ideep.ndims() == 0 || tensor_ideep.get_nelems() == 0)
+          continue;
+        inputs_itensor.emplace_back(tensor_ideep);
       } else {
         CAFFE_ENFORCE(
             BlobIsTensorType(OperatorBase::InputBlob(i), CPU),
             "Expect cpu tensor if not itensor");
         auto& tensor_cpu = OperatorBase::Input<Tensor>(i, CPU);
-        CAFFE_ENFORCE(
-            tensor_cpu.sizes().size() == 0 || tensor_cpu.numel() == 0,
-            "Expect zero dim tensor");
+        if (tensor_cpu.sizes().size() == 0 || tensor_cpu.numel() == 0)
+          continue;
+        fallback_to_cpu = true;
+        break;
       }
     }
 
-    auto axis_vdata = ideep::concat::compute(inputs, axis_, add_axis_, *output);
-    Tensor* axis_info = OutputTensor(
-        AXIS_INFO,
-        vector<int64_t>(1, InputSize()),
-        at::dtype<int>().device(CPU));
-    int* axis_data = axis_info->template mutable_data<int>();
-    for (int i = 0; i < axis_vdata.size(); i++) {
-      axis_data[i] = axis_vdata[i];
+    if (!fallback_to_cpu) {
+      auto* output = Output(OUTPUT);
+      Tensor* axis_info = OutputTensor(AXIS_INFO,
+        vector<int64_t>(1, InputSize()), at::dtype<int>().device(CPU));
+      auto* axis_data = axis_info->template mutable_data<int>();
+      auto axis_vdata =
+        ideep::concat::compute(inputs_itensor, axis_, add_axis_, *output);
+      for (int i = 0; i < axis_vdata.size(); i++) {
+        axis_data[i] = axis_vdata[i];
+      }
+      return true;
     }
 
-    return true;
+    return fallback_.Run(0);
   }
 
  private:
   int axis_;
   int add_axis_;
+  FALLBACK_OP fallback_;
 
   INPUT_TAGS(INPUT0);
   OUTPUT_TAGS(OUTPUT, AXIS_INFO);
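For readability, the patched RunOnDevice amounts to the following scan (a rough Python rendering, not part of the patch; Blob and the injected callables are hypothetical stand-ins):

from dataclasses import dataclass

@dataclass
class Blob:
    """Toy stand-in for one operator input (hypothetical)."""
    is_itensor: bool  # stored as an ideep itensor vs. a TensorCPU
    ndims: int
    nelems: int

def run_on_device(inputs, ideep_concat, cpu_fallback):
    """Mirror of the patched IDEEPConcatOp::RunOnDevice."""
    itensors = []
    for blob in inputs:
        if blob.ndims == 0 or blob.nelems == 0:
            continue                     # empty placeholders are skipped
        if blob.is_itensor:
            itensors.append(blob)        # collect non-empty ideep inputs
        else:
            return cpu_fallback(inputs)  # non-empty TensorCPU: use CPU ConcatOp
    return ideep_concat(itensors)        # all-itensor fast path

Note that the fallback runs over the original, unfiltered inputs (fallback_.Run(0) in the C++ above); the filtering only affects the ideep fast path.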
diff --git a/caffe2/python/ideep/concat_split_op_test.py b/caffe2/python/ideep/concat_split_op_test.py
index 3a5763f..29e7750 100644
--- a/caffe2/python/ideep/concat_split_op_test.py
+++ b/caffe2/python/ideep/concat_split_op_test.py
@@ -113,6 +113,49 @@ class TestConcatSplitOps(hu.HypothesisTestCase):
             self.assertGradientChecks(gc, op, splits, i, [0])
 
 
+    @given(tensor_splits=_tensor_splits(add_axis=True), **mu.gcs)
+    def test_concat_with_TensorCPU(self, tensor_splits, gc, dc):
+        axis, _, splits = tensor_splits
+        op0 = core.CreateOperator(
+            "Concat",
+            ['X_{}'.format(i) for i in range(len(splits))],
+            ['concat_result0', 'split_info0'],
+            axis=axis,
+            add_axis=1,
+            device_option=dc[0]
+        )
+        op1 = core.CreateOperator(
+            "Concat",
+            ['X_{}'.format(i) for i in range(len(splits))],
+            ['concat_result1', 'split_info1'],
+            axis=axis,
+            add_axis=1,
+            device_option=dc[1]
+        )
+
+        for i, X in enumerate(splits):
+            workspace.FeedBlob('X_{}'.format(i), X, dc[0])
+
+        workspace.RunOperatorOnce(op0)
+        res0 = workspace.FetchBlob('concat_result0')
+        inf0 = workspace.FetchBlob('split_info0')
+
+        workspace.RunOperatorOnce(op1)
+        res1 = workspace.FetchBlob('concat_result1')
+        inf1 = workspace.FetchBlob('split_info1')
+
+        if not np.allclose(res0, res1, atol=0.0, rtol=0.0):
+            print(res1.flatten())
+            print(res0.flatten())
+            print(np.max(np.abs(res1 - res0)))
+            self.assertTrue(False)
+
+        if not np.allclose(inf0, inf1, atol=0.0, rtol=0.0):
+            print(inf1.flatten())
+            print(inf0.flatten())
+            print(np.max(np.abs(inf1 - inf0)))
+            self.assertTrue(False)
+
 
 if __name__ == "__main__":
     unittest.main()
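Since the module keeps its unittest.main() entry point, the new case can be exercised directly on an MKL-DNN-enabled build, e.g. python caffe2/python/ideep/concat_split_op_test.py, or narrowed to just the new case with pytest's standard -k filter (pytest caffe2/python/ideep/concat_split_op_test.py -k test_concat_with_TensorCPU).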