if (isLegacyOperator()) {
CAFFE_ENFORCE_WITH_CALLER(
options.device_opt() != c10::nullopt,
- "device must be provided in option.");
+ "device must be provided in options.");
return BlobGetMutableTensor(outputs_.at(idx), dims, options);
}
auto* ival = ivalue_outputs_[idx];
at::TensorOptions options,
const Tensor& src,
bool async = false) {
- Tensor* t = Output<Tensor>(idx, options.device().type());
- // TODO:
- // We plan to use the following:
- // Tensor* t = OutputTensor(idx, src.sizes(), src.options()+options);
- // that is overwrite options of src Tensor
- CAFFE_ENFORCE(
- !t->dtype_initialized() || t->dtype() == src.dtype(),
- "We don't allow a change of data type in OutputTensor");
+ CAFFE_ENFORCE_WITH_CALLER(
+ options.device_opt() != c10::nullopt,
+ "device must be provided in options.");
+ // Ouptut Tensor will always have the same data type as `src`
+ if (!options.has_dtype()) {
+ options = options.dtype(src.dtype());
+ }
+ CAFFE_ENFORCE_WITH_CALLER(
+ options.dtype() == src.dtype(),
+ "We don't allow change of src data type in OutputTensorCopyFrom");
+ Tensor* t = OutputTensor(idx, src.sizes(), options);
t->CopyFrom(src, async);
return t;
}
}
Tensor XOutput(int idx, at::IntList dims, at::TensorOptions options) {
+ // We'll default device to the device of the current Operator Context
if (options.device_opt() == c10::nullopt) {
return OperatorBase::XOutputTensor(
idx, dims, options.device(context_.device()));
}
Tensor* Output(int idx, at::IntList dims, at::TensorOptions options) {
+ // We'll default device to the device of the current Operator Context
if (options.device_opt() == c10::nullopt) {
return OperatorBase::OutputTensor(
idx, dims, options.device(context_.device()));
return OperatorBase::template Output<Tensor>(idx, type);
}
+ Tensor* OutputTensorCopyFrom(
+ int idx,
+ at::TensorOptions options,
+ const Tensor& src,
+ bool async = false) {
+ if (options.device_opt() == c10::nullopt) {
+ return OperatorBase::OutputTensorCopyFrom(
+ idx, options.device(context_.device()), src, async);
+ }
+ return OperatorBase::OutputTensorCopyFrom(idx, options, src, async);
+ }
+
void WaitEvent(const Event& ev, int stream_id = -1) final {
if (stream_id >= 0) {
context_.SwitchToDevice(stream_id);
/* using override */ using OperatorBase::OutputSize; \
/* using override */ using OperatorBase::IsInputOutputAlias
-#define USE_OPERATOR_FUNCTIONS(context) \
- USE_OPERATOR_BASE_FUNCTIONS; \
- /* using override */ using Operator<context>::context_; \
- /* using override */ using Operator<context>::Input; \
- /* using override */ using Operator<context>::InputBlob; \
- /* using override */ using Operator<context>::Output; \
- /* using override */ using Operator<context>::OutputBlob
+#define USE_OPERATOR_FUNCTIONS(context) \
+ USE_OPERATOR_BASE_FUNCTIONS; \
+ /* using override */ using Operator<context>::context_; \
+ /* using override */ using Operator<context>::Input; \
+ /* using override */ using Operator<context>::InputBlob; \
+ /* using override */ using Operator<context>::Output; \
+ /* using override */ using Operator<context>::OutputBlob; \
+ /* using override */ using Operator<context>::OutputTensorCopyFrom
#define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)