From: Jerry Zhang Date: Thu, 20 Dec 2018 02:10:36 +0000 (-0800) Subject: default options for OutputTensorCopyFrom (#15248) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2161 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5dd5ef3214c77b8683361dbda087cded855bbce3;p=platform%2Fupstream%2Fpytorch.git default options for OutputTensorCopyFrom (#15248) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15248 OutputTensorCopyFrom takes four arguments: index, a source Tensor, TensorOptions and whether we want to perform an async call. We want to provide some default option for TensorOptions, (1). default device to context_.device() (2). default dtype to input.dtype(). User can also explicitly provide these options to override default values. next diff will change the order of TensorOptions parameter so that user don't need to write down tensor options unless they want to override. Reviewed By: dzhulgakov Differential Revision: D13453824 fbshipit-source-id: 87401f81c7c3f9fd3d8936c710e6c2e04a59b689 --- diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 3e56e87..5e1a9ab 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -186,7 +186,7 @@ class CAFFE2_API OperatorBase : public Observable { if (isLegacyOperator()) { CAFFE_ENFORCE_WITH_CALLER( options.device_opt() != c10::nullopt, - "device must be provided in option."); + "device must be provided in options."); return BlobGetMutableTensor(outputs_.at(idx), dims, options); } auto* ival = ivalue_outputs_[idx]; @@ -208,14 +208,17 @@ class CAFFE2_API OperatorBase : public Observable { at::TensorOptions options, const Tensor& src, bool async = false) { - Tensor* t = Output(idx, options.device().type()); - // TODO: - // We plan to use the following: - // Tensor* t = OutputTensor(idx, src.sizes(), src.options()+options); - // that is overwrite options of src Tensor - CAFFE_ENFORCE( - !t->dtype_initialized() || t->dtype() == src.dtype(), - "We don't allow a change of data type in OutputTensor"); + CAFFE_ENFORCE_WITH_CALLER( + options.device_opt() != c10::nullopt, + "device must be provided in options."); + // Ouptut Tensor will always have the same data type as `src` + if (!options.has_dtype()) { + options = options.dtype(src.dtype()); + } + CAFFE_ENFORCE_WITH_CALLER( + options.dtype() == src.dtype(), + "We don't allow change of src data type in OutputTensorCopyFrom"); + Tensor* t = OutputTensor(idx, src.sizes(), options); t->CopyFrom(src, async); return t; } @@ -587,6 +590,7 @@ class Operator : public OperatorBase { } Tensor XOutput(int idx, at::IntList dims, at::TensorOptions options) { + // We'll default device to the device of the current Operator Context if (options.device_opt() == c10::nullopt) { return OperatorBase::XOutputTensor( idx, dims, options.device(context_.device())); @@ -595,6 +599,7 @@ class Operator : public OperatorBase { } Tensor* Output(int idx, at::IntList dims, at::TensorOptions options) { + // We'll default device to the device of the current Operator Context if (options.device_opt() == c10::nullopt) { return OperatorBase::OutputTensor( idx, dims, options.device(context_.device())); @@ -606,6 +611,18 @@ class Operator : public OperatorBase { return OperatorBase::template Output(idx, type); } + Tensor* OutputTensorCopyFrom( + int idx, + at::TensorOptions options, + const Tensor& src, + bool async = false) { + if (options.device_opt() == c10::nullopt) { + return OperatorBase::OutputTensorCopyFrom( + idx, options.device(context_.device()), src, async); + } + return OperatorBase::OutputTensorCopyFrom(idx, options, src, async); + } + void WaitEvent(const Event& ev, int stream_id = -1) final { if (stream_id >= 0) { context_.SwitchToDevice(stream_id); @@ -773,13 +790,14 @@ class Operator : public OperatorBase { /* using override */ using OperatorBase::OutputSize; \ /* using override */ using OperatorBase::IsInputOutputAlias -#define USE_OPERATOR_FUNCTIONS(context) \ - USE_OPERATOR_BASE_FUNCTIONS; \ - /* using override */ using Operator::context_; \ - /* using override */ using Operator::Input; \ - /* using override */ using Operator::InputBlob; \ - /* using override */ using Operator::Output; \ - /* using override */ using Operator::OutputBlob +#define USE_OPERATOR_FUNCTIONS(context) \ + USE_OPERATOR_BASE_FUNCTIONS; \ + /* using override */ using Operator::context_; \ + /* using override */ using Operator::Input; \ + /* using override */ using Operator::InputBlob; \ + /* using override */ using Operator::Output; \ + /* using override */ using Operator::OutputBlob; \ + /* using override */ using Operator::OutputTensorCopyFrom #define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)