Store ScalarType and Backend instead of Type in TensorIterator
authorRoy Li <royboy@fb.com>
Thu, 4 Apr 2019 09:21:09 +0000 (02:21 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 09:24:16 +0000 (02:24 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17601

Reviewed By: ezyang

Differential Revision: D14274754

fbshipit-source-id: b08880ae586b6ae57d4c0bbeb203796d087926c4

aten/src/ATen/native/Indexing.cpp
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/TensorIterator.h
test/test_cuda.py

index 062d267..f236f99 100644 (file)
@@ -407,7 +407,7 @@ static AdvancedIndex make_info(Tensor self, TensorList orig) {
 static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) {
   auto builder = TensorIterator::Builder();
   builder.dont_compute_common_dtype();
-  builder.add_output(Tensor(), &info.src.dispatch_type());
+  builder.add_output(Tensor(), info.src.type().backend(), info.src.scalar_type());
   builder.add_input(info.src);
   for (auto& index : info.indices) {
     builder.add_input(index);
@@ -424,7 +424,7 @@ static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedInd
   builder.dont_compute_common_dtype();
   builder.dont_resize_outputs();
   builder.add_output(info.src);
-  builder.add_input(value, &info.src.dispatch_type());
+  builder.add_input(value, info.src.type().backend(), info.src.scalar_type());
   for (auto& index : info.indices) {
     builder.add_input(index);
   }
index b47caca..e86444a 100644 (file)
@@ -87,53 +87,53 @@ compute_result_type(at::ArrayRef<OperandInfo> operands, const F& predicate) {
 void TensorIterator::compute_types() {
   bool missing_dtypes = false;
   for (auto& op : operands_) {
-    if (!op.tensor.defined() && !op.type) {
+    if (!op.tensor.defined() && !op.is_type_defined()) {
       missing_dtypes = true;
     }
   }
 
   if (missing_dtypes || compute_common_dtype_) {
-    auto& type = compute_common_type();
+    ScalarType common_dtype;
+    Backend common_backend;
+    std::tie(common_backend, common_dtype) = compute_common_type();
     for (auto& op : operands_) {
-      auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.scalar_type());
-      if (!op.type) {
-        op.type = &type;
-      } else if (compute_common_dtype_ && op.type != &type) {
+      if (!op.is_type_defined()) {
+        op.set_type(common_backend, common_dtype);
+      } else if (compute_common_dtype_ && !op.is_type_equal(common_backend, common_dtype)) {
         if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
-            type.device_type() == kCUDA && op_tensor_type.device_type() == kCPU) {
+            common_backend == Backend::CUDA && op.tensor.type().backend() == Backend::CPU) {
           // don't cast CPU scalars in CUDA ops that directly support them
-          op.type = &op_tensor_type;
+          op.set_type(op.tensor.type().backend(), op.tensor.scalar_type());
         } else if (promote_gpu_output_dtypes_ && op.tensor.defined() &&
-            !op.is_output && op_tensor_type.scalarType() == kHalf &&
-            type.scalarType() == kFloat && type.device_type() == kCUDA &&
-            op_tensor_type.device_type() == kCUDA) {
+            !op.is_output && op.tensor.scalar_type() == kHalf &&
+            common_dtype == kFloat && common_backend == Backend::CUDA &&
+            op.tensor.type().backend() == Backend::CUDA) {
           // allow input tensor type upcasting for fp16 to fp32 in fused kernel
           // on GPU
-          op.type = &op_tensor_type;
+          op.set_type(op.tensor.type().backend(), op.tensor.scalar_type());
         } else {
-          op.type = &type;
+          op.set_type(common_backend, common_dtype);
         }
       }
     }
   }
 
   for (auto& op : operands_) {
-    auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.scalar_type());
-    if (op.tensor.defined() && op_tensor_type != *op.type) {
+    if (op.tensor.defined() && !op.is_type_equal(op.tensor.type().backend(), op.tensor.scalar_type())) {
       if (op.is_output) {
-        AT_ERROR("output with type ", op_tensor_type.toString(),
-                 " doesn't match the desired type ", op.type->toString());
+        AT_ERROR("output with backend ", toString(op.tensor.type().backend()), " and dtype ", toString(op.tensor.scalar_type()),
+                 " doesn't match the desired backend ", toString(op.backend), " and dtype ", toString(op.dtype));
       } else if (op.tensor.dim() == 0) {
-        op.tensor = op.tensor.to(*op.type);
+        op.tensor = op.tensor.to(op.options());
       } else {
-        AT_ERROR("expected type ", op.type->toString(), " but got ",
-            op_tensor_type.toString());
+        AT_ERROR("expected backend ", toString(op.backend), " and dtype ", toString(op.dtype),
+                 " but got backend ", toString(op.tensor.type().backend()), " and dtype ", toString(op.tensor.scalar_type()));
       }
     }
   }
 }
 
-Type& TensorIterator::compute_common_type() {
+std::pair<Backend, ScalarType> TensorIterator::compute_common_type() {
   // See [Result type computation] in TensorIterator.h
   auto result_type = ScalarType::Undefined;
   auto backend = Backend::Undefined;
@@ -154,7 +154,7 @@ Type& TensorIterator::compute_common_type() {
   AT_ASSERT(result_type != ScalarType::Undefined);
   AT_ASSERT(backend != Backend::Undefined);
 
-  return at::globalContext().getNonVariableType(backend, result_type);
+  return std::make_pair(backend, result_type);
 }
 
 DimVector TensorIterator::compatible_stride(int element_size) const {
@@ -182,8 +182,8 @@ void TensorIterator::allocate_outputs() {
   for (int i = 0; i < num_outputs_; i++) {
     auto& op = operands_[i];
     if (!op.tensor.defined()) {
-      AT_ASSERTM(op.type, "no type for operand", i);
-      int element_size = op.type->typeMeta().itemsize();
+      AT_ASSERTM(op.is_type_defined(), "no type for operand", i);
+      int element_size = elementSize(op.dtype);
       op.stride_bytes = compatible_stride(element_size);
 
       auto tensor_shape = invert_perm(shape_);
@@ -191,7 +191,7 @@ void TensorIterator::allocate_outputs() {
       for (int dim = 0; dim < ndim(); dim++) {
         tensor_stride[dim] /= element_size;
       }
-      op.tensor = at::empty_strided(tensor_shape, tensor_stride, op.type->options());
+      op.tensor = at::empty_strided(tensor_shape, tensor_stride, op.options());
     }
   }
 }
@@ -420,7 +420,7 @@ bool TensorIterator::is_scalar(int arg) const {
 }
 
 bool TensorIterator::is_cpu_scalar(int arg) const {
-  return is_scalar(arg) && operands_[arg].tensor.type().device_type() == kCPU;
+  return is_scalar(arg) && device_type(arg) == kCPU;
 }
 
 void* TensorIterator::data_ptr(int arg) const {
index affcade..6a9ca8c 100644 (file)
@@ -66,10 +66,11 @@ struct DimCounter {
 };
 struct CAFFE2_API OperandInfo {
   OperandInfo() {}
-  OperandInfo(const Tensor& t, const Type* type=nullptr)
-    : tensor(t), type(const_cast<Type*>(type)) {
-      if (t.defined() && !type) {
-        this->type = &t.dispatch_type();
+  explicit OperandInfo(const Tensor& t, const Backend backend=Backend::Undefined, const ScalarType dtype=ScalarType::Undefined)
+    : tensor(t), backend(backend), dtype(dtype) {
+      if (t.defined() && (backend == Backend::Undefined || dtype == ScalarType::Undefined)) {
+        this->backend = t.type().backend();
+        this->dtype = t.scalar_type();
       }
   }
 
@@ -85,7 +86,25 @@ struct CAFFE2_API OperandInfo {
   /// input should be converted to this type if necessary. For outputs, this
   /// specifies which type to allocate. Note that there is very limited support
   /// for type conversions currently: they are only allowed for zero-dim tensors.
-  Type* type = nullptr;
+  Backend backend = Backend::Undefined;
+  ScalarType dtype = ScalarType::Undefined;
+
+  bool is_type_defined() {
+    return dtype != ScalarType::Undefined && backend != Backend::Undefined;
+  }
+
+  bool is_type_equal(Backend b, ScalarType s) {
+    return dtype == s && backend == b;
+  }
+
+  void set_type(Backend b, ScalarType s) {
+    dtype = s;
+    backend = b;
+  }
+
+  TensorOptions options() {
+    return TensorOptions(backendToDeviceType(backend)).dtype(dtype);
+  }
 
   /// The data pointer. This may be different from tensor.data_ptr() if the
   /// iterator is split.
@@ -148,13 +167,9 @@ struct CAFFE2_API TensorIterator {
   /// Accessors for each operand
   IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; }
   void* data_ptr(int arg) const;
-  const Type& type(int arg=0) const {
-    AT_ASSERT(operands_[arg].type);
-    return *operands_[arg].type;
-  }
-  ScalarType dtype(int arg=0) const { return type(arg).scalarType(); }
-  DeviceType device_type(int arg=0) const { return type(arg).device_type(); }
-  int64_t element_size(int arg) const { return type(arg).typeMeta().itemsize(); }
+  ScalarType dtype(int arg=0) const { return operands_[arg].dtype; }
+  DeviceType device_type(int arg=0) const { return backendToDeviceType(operands_[arg].backend); }
+  int64_t element_size(int arg) const { return elementSize(dtype(arg)); }
   bool is_scalar(int arg) const;
   bool is_cpu_scalar(int arg) const;
 
@@ -237,7 +252,7 @@ protected:
   void reorder_dimensions();
   void permute_dimensions(IntArrayRef perm);
   void compute_types();
-  Type& compute_common_type();
+  std::pair<Backend, ScalarType> compute_common_type();
   void allocate_outputs();
   void coalesce_dimensions();
 
@@ -261,13 +276,13 @@ struct TensorIterator::Builder {
 
   Builder() : iter_(new TensorIterator()) {};
 
-  void add_output(const Tensor& output, const Type* type=nullptr) {
-    iter_->operands_.emplace_back(output, type);
+  void add_output(const Tensor& output, const Backend backend=Backend::Undefined, const ScalarType dtype=ScalarType::Undefined) {
+    iter_->operands_.emplace_back(output, backend, dtype);
     iter_->num_outputs_++;
   }
 
-  void add_input(const Tensor& input, const Type* type=nullptr) {
-    iter_->operands_.emplace_back(input, type);
+  void add_input(const Tensor& input, const Backend backend=Backend::Undefined, const ScalarType dtype=ScalarType::Undefined) {
+    iter_->operands_.emplace_back(input, backend, dtype);
   }
 
   void dont_compute_common_dtype() {
index e7a7f01..d7ab446 100644 (file)
@@ -1025,7 +1025,7 @@ class TestCuda(TestCase):
 
             self.assertEqual(x * y, 4.5)
             self.assertEqual(y * x, 4.5)
-            with self.assertRaisesRegex(RuntimeError, "doesn't match the desired type"):
+            with self.assertRaisesRegex(RuntimeError, "doesn't match the desired"):
                 y *= x
             x *= y
             self.assertEqual(x, 4.5)
@@ -2059,15 +2059,13 @@ class TestCuda(TestCase):
     def test_sum_cpu_gpu_mismatch(self):
         x = torch.randn(20, dtype=torch.float32, device='cuda')
         y = torch.randn(1, dtype=torch.float32)
-        with self.assertRaisesRegex(RuntimeError, 'expected type'
-                                    ' torch.FloatTensor but got'
-                                    ' torch.cuda.FloatTensor'):
+        with self.assertRaisesRegex(RuntimeError,
+                                    'expected backend CPU and dtype Float but got backend CUDA and dtype Float'):
             torch.sum(x, dim=[0], dtype=torch.float32, out=y)
         # makeing sure half to float promotion is also properly working.
         x = x.half()
-        with self.assertRaisesRegex(RuntimeError, 'expected type'
-                                    ' torch.FloatTensor but got'
-                                    ' torch.cuda.HalfTensor'):
+        with self.assertRaisesRegex(RuntimeError,
+                                    'expected backend CPU and dtype Float but got backend CUDA and dtype Half'):
             torch.sum(x, dim=[0], dtype=torch.float32, out=y)
 
     @skipIfRocm