[nnc] Updated indices during broadcast to use int64_t (#64627)
authorRaghavan Raman <raghavanr@fb.com>
Thu, 9 Sep 2021 15:26:16 +0000 (08:26 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 15:29:37 +0000 (08:29 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64627

This fixes the root cause of S242719

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30801686

Pulled By: navahgar

fbshipit-source-id: b6d3ebdc7eb57116eaced53c2f35c7798bb17e80

test/cpp/tensorexpr/test_kernel.cpp
test/test_tensorexpr.py
torch/csrc/jit/tensorexpr/kernel.cpp

index d12f142..1a6d086 100644 (file)
@@ -577,6 +577,36 @@ TEST_F(Kernel, CatInputTypesPromotion) {
   }
 }
 
+TEST_F(Kernel, CatAndInlineWithAConstantDim) {
+  const auto graph_string = R"IR(
+      graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
+            %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)):
+        %2 : bool = prim::Constant[value=0]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : Tensor[] = prim::ListConstruct(%0, %1)
+        %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3)
+        %6 : Tensor[] = prim::ListConstruct(%5)
+        %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3)
+        %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2)
+        return (%8, %7))IR";
+
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+  TensorExprKernel k(graph);
+
+  auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto ref = at::_cast_Float(at::cat({a, b}, 1), 0);
+
+  std::vector<at::Tensor> inputs = {a, b};
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  auto o = stack[0].toTensor();
+  ASSERT_EQ(o.sizes(), ref.sizes());
+  ASSERT_EQ(o.dtype(), ref.dtype());
+  ASSERT_TRUE(at::allclose(o, ref));
+}
+
 TEST_F(Kernel, CatWoConditionals) {
   getCatWoConditionals() = true;
   const auto graph_string = R"IR(
index 366c262..0ecc4de 100644 (file)
@@ -1108,6 +1108,22 @@ class TestTensorExprFuser(BaseTestClass):
             ref = foo(*values)
             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
 
+    def test_cat_with_constant_dim(self):
+        for device in self.devices:
+            def foo(*args):
+                v1 = torch.cat(args, dim=1)
+                v2 = torch.cat([v1], dim=1)
+                return v2 * v2
+
+            empty = torch.tensor([], device=device, dtype=torch.float32)
+            inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)]
+            traced = torch.jit.trace(foo, inputs)
+
+            x = warmup_and_run_forward(traced, *inputs)
+            self.assertLastGraphAllFused()
+            ref = foo(*inputs)
+            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
+
     def test_scalar(self):
         @torch.jit.script
         def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
index f850d7d..17e4c96 100644 (file)
@@ -466,7 +466,7 @@ std::vector<ExprHandle> computeIndicesToBroadcast(
   while (sizeIt != inputSizes.rend()) {
     auto const& size = intValue(*sizeIt);
     if (size && *size == 1) {
-      bcast.emplace_back(0);
+      bcast.emplace_back(LongImm::make(0));
     } else {
       bcast.emplace_back(*axisIt);
     }