Fix for issue 14829 (#14908)
authorSoumyaroop Roy <soroy@tesla.com>
Mon, 17 Dec 2018 22:23:54 +0000 (14:23 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 17 Dec 2018 22:28:50 +0000 (14:28 -0800)
Summary:
* Modify the testcase as outlined in the issue
   * Issue url: https://github.com/pytorch/pytorch/issues/14829
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14908

Differential Revision: D13490360

Pulled By: ezyang

fbshipit-source-id: ff11a72e19b49223652182e82c2b4e65fe444ca7

test/cpp_extensions/complex_registration_extension.cpp
test/test_cpp_extensions.py

index cdb38e7..b71858e 100644 (file)
@@ -40,21 +40,9 @@ struct CPUComplexFloatType : public at::CPUTypeDefault {
   TypeID ID() const override;
 
   Tensor empty(IntList size, const TensorOptions & options) const override {
-    // TODO: Upstream this
-    int64_t numel = 1;
-    for (auto s : size) {
-      numel *= s;
-    }
-    Storage s{c10::make_intrusive<StorageImpl>(
-        scalarTypeToTypeMeta(ScalarType::ComplexFloat),
-        numel,
-        getCPUAllocator(),
-        /* resizable */ true)};
-    Tensor t{c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
-        std::move(s),
-        at::CPUTensorId(),
-        /* is_variable */ false)};
-    return t;
+    // Delegate to the appropriate cpu tensor factory
+    const DeviceGuard device_guard(options.device());
+    return at::native::empty_cpu(/* actuals */ size, options);
   }
 };
 
index a30483c..42dec14 100755 (executable)
@@ -310,7 +310,11 @@ class TestCppExtension(common.TestCase):
             verbose=True,
         )
 
-        torch.empty(2, 2, dtype=torch.complex64)
+        # Make sure that the empty tensor is of the desired shape and type
+        # Refer to https://github.com/pytorch/pytorch/issues/14829
+        t = torch.empty(2, 2, dtype=torch.complex64)
+        self.assertEqual(t.size(), torch.Size([2, 2]))
+        self.assertEqual(t.type(), 'torch.ComplexFloatTensor')
 
     @unittest.skipIf(not TEST_CUDA, "CUDA not found")
     def test_half_support(self):