default options for OutputTensorCopyFrom (#15248)
authorJerry Zhang <jerryzh@fb.com>
Thu, 20 Dec 2018 02:10:36 +0000 (18:10 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 02:14:47 +0000 (18:14 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15248

OutputTensorCopyFrom takes four arguments: index, a source Tensor, TensorOptions and whether we want to perform an async call.
We want to provide some default option for TensorOptions, (1). default device to context_.device() (2). default dtype to input.dtype(). User can also explicitly provide these options to override default values.

next diff will change the order of TensorOptions parameter so that user don't need to write down tensor options unless they want to override.

Reviewed By: dzhulgakov

Differential Revision: D13453824

fbshipit-source-id: 87401f81c7c3f9fd3d8936c710e6c2e04a59b689

caffe2/core/operator.h

index 3e56e87..5e1a9ab 100644 (file)
@@ -186,7 +186,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
     if (isLegacyOperator()) {
       CAFFE_ENFORCE_WITH_CALLER(
           options.device_opt() != c10::nullopt,
-          "device must be provided in option.");
+          "device must be provided in options.");
       return BlobGetMutableTensor(outputs_.at(idx), dims, options);
     }
     auto* ival = ivalue_outputs_[idx];
@@ -208,14 +208,17 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
       at::TensorOptions options,
       const Tensor& src,
       bool async = false) {
-    Tensor* t = Output<Tensor>(idx, options.device().type());
-    // TODO:
-    // We plan to use the following:
-    // Tensor* t = OutputTensor(idx, src.sizes(), src.options()+options);
-    // that is overwrite options of src Tensor
-    CAFFE_ENFORCE(
-        !t->dtype_initialized() || t->dtype() == src.dtype(),
-        "We don't allow a change of data type in OutputTensor");
+    CAFFE_ENFORCE_WITH_CALLER(
+        options.device_opt() != c10::nullopt,
+        "device must be provided in options.");
+    // Ouptut Tensor will always have the same data type as `src`
+    if (!options.has_dtype()) {
+      options = options.dtype(src.dtype());
+    }
+    CAFFE_ENFORCE_WITH_CALLER(
+        options.dtype() == src.dtype(),
+        "We don't allow change of src data type in OutputTensorCopyFrom");
+    Tensor* t = OutputTensor(idx, src.sizes(), options);
     t->CopyFrom(src, async);
     return t;
   }
@@ -587,6 +590,7 @@ class Operator : public OperatorBase {
   }
 
   Tensor XOutput(int idx, at::IntList dims, at::TensorOptions options) {
+    // We'll default device to the device of the current Operator Context
     if (options.device_opt() == c10::nullopt) {
       return OperatorBase::XOutputTensor(
           idx, dims, options.device(context_.device()));
@@ -595,6 +599,7 @@ class Operator : public OperatorBase {
   }
 
   Tensor* Output(int idx, at::IntList dims, at::TensorOptions options) {
+    // We'll default device to the device of the current Operator Context
     if (options.device_opt() == c10::nullopt) {
       return OperatorBase::OutputTensor(
           idx, dims, options.device(context_.device()));
@@ -606,6 +611,18 @@ class Operator : public OperatorBase {
     return OperatorBase::template Output<Tensor>(idx, type);
   }
 
+  Tensor* OutputTensorCopyFrom(
+      int idx,
+      at::TensorOptions options,
+      const Tensor& src,
+      bool async = false) {
+    if (options.device_opt() == c10::nullopt) {
+      return OperatorBase::OutputTensorCopyFrom(
+          idx, options.device(context_.device()), src, async);
+    }
+    return OperatorBase::OutputTensorCopyFrom(idx, options, src, async);
+  }
+
   void WaitEvent(const Event& ev, int stream_id = -1) final {
     if (stream_id >= 0) {
       context_.SwitchToDevice(stream_id);
@@ -773,13 +790,14 @@ class Operator : public OperatorBase {
   /* using override */ using OperatorBase::OutputSize;              \
   /* using override */ using OperatorBase::IsInputOutputAlias
 
-#define USE_OPERATOR_FUNCTIONS(context)                    \
-  USE_OPERATOR_BASE_FUNCTIONS;                             \
-  /* using override */ using Operator<context>::context_;  \
-  /* using override */ using Operator<context>::Input;     \
-  /* using override */ using Operator<context>::InputBlob; \
-  /* using override */ using Operator<context>::Output;    \
-  /* using override */ using Operator<context>::OutputBlob
+#define USE_OPERATOR_FUNCTIONS(context)                     \
+  USE_OPERATOR_BASE_FUNCTIONS;                              \
+  /* using override */ using Operator<context>::context_;   \
+  /* using override */ using Operator<context>::Input;      \
+  /* using override */ using Operator<context>::InputBlob;  \
+  /* using override */ using Operator<context>::Output;     \
+  /* using override */ using Operator<context>::OutputBlob; \
+  /* using override */ using Operator<context>::OutputTensorCopyFrom
 
 #define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)