From 652a8bf7d0d104c21964076083684ad0936fe05e Mon Sep 17 00:00:00 2001 From: Raghavan Raman Date: Thu, 9 Sep 2021 08:26:16 -0700 Subject: [PATCH] [nnc] Updated indices during broadcast to use int64_t (#64627) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64627 This fixes the root cause of S242719 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D30801686 Pulled By: navahgar fbshipit-source-id: b6d3ebdc7eb57116eaced53c2f35c7798bb17e80 --- test/cpp/tensorexpr/test_kernel.cpp | 30 ++++++++++++++++++++++++++++++ test/test_tensorexpr.py | 16 ++++++++++++++++ torch/csrc/jit/tensorexpr/kernel.cpp | 2 +- 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index d12f142..1a6d086 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -577,6 +577,36 @@ TEST_F(Kernel, CatInputTypesPromotion) { } } +TEST_F(Kernel, CatAndInlineWithAConstantDim) { + const auto graph_string = R"IR( + graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu), + %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)): + %2 : bool = prim::Constant[value=0]() + %3 : int = prim::Constant[value=1]() + %4 : Tensor[] = prim::ListConstruct(%0, %1) + %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3) + %6 : Tensor[] = prim::ListConstruct(%5) + %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3) + %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2) + return (%8, %7))IR"; + + auto graph = std::make_shared<Graph>(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + + auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::_cast_Float(at::cat({a, b}, 1), 0); + + std::vector<at::Tensor> inputs = {a, b}; 
+ std::vector<IValue> stack = fmap<IValue>(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); +} + TEST_F(Kernel, CatWoConditionals) { getCatWoConditionals() = true; const auto graph_string = R"IR( diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 366c262..0ecc4de 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -1108,6 +1108,22 @@ class TestTensorExprFuser(BaseTestClass): ref = foo(*values) np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) + def test_cat_with_constant_dim(self): + for device in self.devices: + def foo(*args): + v1 = torch.cat(args, dim=1) + v2 = torch.cat([v1], dim=1) + return v2 * v2 + + empty = torch.tensor([], device=device, dtype=torch.float32) + inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)] + traced = torch.jit.trace(foo, inputs) + + x = warmup_and_run_forward(traced, *inputs) + self.assertLastGraphAllFused() + ref = foo(*inputs) + np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) + def test_scalar(self): @torch.jit.script def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor: diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index f850d7d..17e4c96 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -466,7 +466,7 @@ std::vector<ExprHandle> computeIndicesToBroadcast( while (sizeIt != inputSizes.rend()) { auto const& size = intValue(*sizeIt); if (size && *size == 1) { - bcast.emplace_back(0); + bcast.emplace_back(LongImm::make(0)); } else { bcast.emplace_back(*axisIt); } -- 2.7.4