Fallback sum/add to CPU if needed (#15267)
author    Gu, Jinghui <jinghui.gu@intel.com>
Wed, 6 Feb 2019 17:25:42 +0000 (09:25 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Feb 2019 17:35:14 +0000 (09:35 -0800)
Summary:
Fallback sum/add to CPU if needed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15267

Differential Revision: D13935064

Pulled By: yinghai

fbshipit-source-id: eb228683d00a0462a1970f849d35365bc98340d6

caffe2/ideep/operators/elementwise_sum_op.cc
caffe2/python/ideep/elementwise_sum_op_test.py
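
At a high level, the patched IDEEPSumOp::RunOnDevice first inspects every input: if all inputs are IDEEP tensors with identical dims it stays on the IDEEP path (direct copy for a single input, ideep::sum otherwise); if any input is a CPU tensor, or the shapes differ, it routes to the CPU fallback, using the elementwise Add op for exactly two inputs and the Sum op otherwise. The stand-alone Python sketch below is only a reading aid for the C++ diff; choose_sum_path is a hypothetical helper and not part of Caffe2.

# Hypothetical reading aid: restates the dispatch logic that this patch
# adds to IDEEPSumOp::RunOnDevice. Not part of Caffe2 itself.
def choose_sum_path(inputs):
    """inputs: list of (is_ideep_tensor, dims) pairs, one per operator input."""
    cpu_path = "cpu_add" if len(inputs) == 2 else "cpu_sum"
    ideep_dims = None
    for is_ideep, dims in inputs:
        if not is_ideep:
            # Any CPU tensor in the input list forces the CPU fallback.
            return cpu_path
        if ideep_dims is None:
            ideep_dims = dims
        elif dims != ideep_dims:
            # Shape mismatch (broadcast) is not handled by ideep::sum.
            return cpu_path
    # All inputs are IDEEP tensors with identical dims.
    return "ideep_copy" if len(inputs) == 1 else "ideep_sum"

assert choose_sum_path([(True, (2, 3)), (True, (2, 3))]) == "ideep_sum"
assert choose_sum_path([(True, (2, 3)), (False, (2, 3))]) == "cpu_add"
assert choose_sum_path([(True, (2, 3))] * 2 + [(False, (2, 3))]) == "cpu_sum"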

diff --git a/caffe2/ideep/operators/elementwise_sum_op.cc b/caffe2/ideep/operators/elementwise_sum_op.cc
index 373d880..0a55fe0 100644
@@ -1,4 +1,7 @@
 #include <caffe2/ideep/ideep_utils.h>
+#include <caffe2/ideep/operators/operator_fallback_ideep.h>
+#include "caffe2/operators/utility_ops.h"
+#include "caffe2/operators/elementwise_add_op.h"
 
 namespace caffe2 {
 
@@ -6,39 +9,66 @@ class IDEEPSumOp final : public IDEEPOperator {
  public:
   USE_IDEEP_DEF_ALIASES();
   USE_IDEEP_OPERATOR_FUNCTIONS();
+  using FALLBACK_SUM = IDEEPFallbackOp<SumOp<CPUContext>, SkipIndices<0>>;
+  using FALLBACK_ADD = IDEEPFallbackOp<BinaryElementwiseOp<
+    NumericTypes, CPUContext, AddFunctor<CPUContext>>, SkipIndices<0>>;
 
   IDEEPSumOp(const OperatorDef& operator_def, Workspace* ws)
-      : IDEEPOperator(operator_def, ws) {}
+      : IDEEPOperator(operator_def, ws),
+        fallback_sum_(operator_def, ws),
+        fallback_add_(operator_def, ws) {}
   virtual ~IDEEPSumOp() {}
 
   bool RunOnDevice() override {
-    const auto& X = Input(INPUT0);
-    auto* Y = Output(OUTPUT);
-
-    if (InputSize() == 1) {
-      ideep::direct_copy::compute(X, *Y);
-
-    } else {
-      vector<itensor> inputs;
-      const vector<float> scales(InputSize(), 1.0);
-      const auto dims = X.get_dims();
-      for (int i = 0; i < InputSize(); ++i) {
-        if (Input(i).get_dims() != dims) {
-          CAFFE_ENFORCE_EQ(
-              dims,
-              Input(i).get_dims(),
-              "Broadcast is not yet supported with IDEEP.");
+    itensor::dims input_dims;
+    bool fallback_to_cpu = false;
+    vector<itensor> inputs_itensor;
+
+    // We only support element-wise sum for ideep tensors here.
+    // If a CPU tensor is detected in the input list, we have to fall back
+    // to the corresponding CPU operator.
+    for (int i = 0; i < InputSize(); ++i) {
+      if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
+        auto& tensor_ideep = Input(i);
+        if (input_dims.empty()) {
+          input_dims = tensor_ideep.get_dims();
+        } else if (input_dims != tensor_ideep.get_dims()) {
+          fallback_to_cpu = true;
+          break;
         }
-        inputs.emplace_back(Input(i));
+        inputs_itensor.emplace_back(tensor_ideep);
+      } else {
+        CAFFE_ENFORCE(
+            BlobIsTensorType(OperatorBase::InputBlob(i), CPU),
+            "Expect cpu tensor if not itensor");
+        fallback_to_cpu = true;
+        break;
+      }
+    }
+
+    if (!fallback_to_cpu) {
+      auto* Y = Output(OUTPUT);
+      if (InputSize() == 1) {
+        const auto& X = Input(INPUT0);
+        ideep::direct_copy::compute(X, *Y);
+      } else {
+        const vector<float> scales(InputSize(), 1.0);
+        ideep::sum::compute(scales, inputs_itensor, *Y);
       }
+      return true;
+    }
 
-      ideep::sum::compute(scales, inputs, *Y);
+    if (InputSize() == 2) {
+      return fallback_add_.Run(0);
     }
 
-    return true;
+    return fallback_sum_.Run(0);
   }
 
  private:
+  FALLBACK_SUM fallback_sum_;
+  FALLBACK_ADD fallback_add_;
+
   INPUT_TAGS(INPUT0);
   OUTPUT_TAGS(OUTPUT);
 };
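
The new unit test below drives this path from Python by feeding one blob on CPU and the remaining blobs on IDEEP into a single Sum operator placed on the IDEEP device. A minimal sketch of that scenario, assuming a Caffe2 build with IDEEP enabled (the caffe2_pb2.IDEEP device type and the blob names X_0/X_1/Y here are illustrative), would look like:

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

# Sketch: one CPU input plus one IDEEP input now routes Sum through the
# CPU fallback instead of failing inside the IDEEP operator.
cpu_opt = core.DeviceOption(caffe2_pb2.CPU)
ideep_opt = core.DeviceOption(caffe2_pb2.IDEEP)

X0 = np.random.rand(2, 3, 4, 4).astype(np.float32)
X1 = np.random.rand(2, 3, 4, 4).astype(np.float32)
workspace.FeedBlob("X_0", X0, cpu_opt)    # CPU tensor
workspace.FeedBlob("X_1", X1, ideep_opt)  # IDEEP tensor

op = core.CreateOperator("Sum", ["X_0", "X_1"], ["Y"], device_option=ideep_opt)
workspace.RunOperatorOnce(op)
Y = workspace.FetchBlob("Y")
assert np.allclose(Y, X0 + X1, atol=0.01, rtol=0.01)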
diff --git a/caffe2/python/ideep/elementwise_sum_op_test.py b/caffe2/python/ideep/elementwise_sum_op_test.py
index 5444502..9b08edb 100644
@@ -38,5 +38,45 @@ class ElementwiseSumTest(hu.HypothesisTestCase):
         self.assertDeviceChecks(dc, op, Xs, [0])
 
 
+    @given(size=st.integers(7, 9),
+           input_channels=st.integers(1, 3),
+           batch_size=st.integers(1, 3),
+           inputs=st.integers(2, 7),
+           inplace=st.booleans(),
+           **mu.gcs)
+    def test_elementwise_sum_fallback(self,
+                                      size,
+                                      input_channels,
+                                      batch_size,
+                                      inputs,
+                                      inplace,
+                                      gc,
+                                      dc):
+        op = core.CreateOperator(
+            "Sum",
+            ["X_{}".format(i) for i in range(inputs)],
+            ["X_0" if inplace else "Y"],
+            device_option=dc[1]
+        )
+        Xs = [np.random.rand(batch_size, input_channels, size, size).astype(
+            np.float32) for _ in range(inputs)]
+
+        sum_val = Xs[0].copy()  # copy so the in-place += below does not alter Xs[0]
+        workspace.FeedBlob("X_0", Xs[0], dc[0])
+        for i, x in enumerate(Xs):
+            if i == 0: continue
+            sum_val += x
+            workspace.FeedBlob("X_{}".format(i), x, dc[1])
+
+        workspace.RunOperatorOnce(op)
+        Y = workspace.FetchBlob("X_0" if inplace else "Y")
+
+        if not np.allclose(sum_val, Y, atol=0.01, rtol=0.01):
+            print(Y.flatten())
+            print(sum_val.flatten())
+            print(np.max(np.abs(Y - sum_val)))
+            self.assertTrue(False)
+
+
 if __name__ == "__main__":
     unittest.main()
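
If Caffe2 was built with IDEEP (MKL-DNN) support, the hypothesis-driven test above, which randomizes shapes, the number of inputs, and in-place vs. out-of-place output, can be run directly through the module's unittest entry point, e.g.:

    python caffe2/python/ideep/elementwise_sum_op_test.py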