Fix fallback issues to handle inplace case (#15726)
author: Gu, Jinghui <jinghui.gu@intel.com>
Fri, 11 Jan 2019 03:44:29 +0000 (19:44 -0800)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 03:47:09 +0000 (19:47 -0800)
Summary:
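Fix the IDEEP fallback operator (IDEEPFallbackOp) to handle the in-place case:
- For non-ideep inputs, skip ShareExternal when the input's raw pointer already equals the local input blob's raw pointer.
- For CPU tensor outputs marked in-place (output_inplace_[i]), copy the result into the existing output tensor with CopyFrom instead of resetting the blob and aliasing the source.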
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15726

Differential Revision: D13591243

Pulled By: yinghai

fbshipit-source-id: 6897f1daacb36beabcdfc22c39242bbdfdd0e534

caffe2/ideep/operators/operator_fallback_ideep.h

index 118f92d..77001b9 100644 (file)
@@ -99,12 +99,14 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
         }
       } else {
         VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
-        // Note(jiayq): This removes a const but conceptually
-        // local_input_blobs will only be used as const blob input for the
-        // base op so we are still fine.
-        local_input_blobs_[i]->ShareExternal(
-            const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
-            OperatorBase::Inputs()[i]->meta());
+        if (OperatorBase::Inputs()[i]->GetRaw() != local_input_blobs_[i]->GetRaw()) {
+          // Note(jiayq): This removes a const but conceptually
+          // local_input_blobs will only be used as const blob input for the
+          // base op so we are still fine.
+          local_input_blobs_[i]->ShareExternal(
+              const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
+              OperatorBase::Inputs()[i]->meta());
+        }
         input_share_[i] = true;
       }
     }
@@ -150,8 +152,13 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
       } else {
         VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
         Blob* dst = OperatorBase::OutputBlob(i);
-        dst->Reset(new Tensor(CPU));
-        BlobSetTensor(dst, src.Alias());
+        if (output_inplace_[i]) {
+          auto dtensor = BlobGetMutableTensor(dst, CPU);
+          dtensor->CopyFrom(src);
+        } else {
+          dst->Reset(new Tensor(CPU));
+          BlobSetTensor(dst, src.Alias());
+        }
       }
     }
     return true;