}
}
+// Checks that TensorExprKernel compiles a graph with two chained aten::cat
+// ops using a constant concat dim: %5 concatenates two 1x512 inputs along
+// dim 1, and %7 is a single-element cat of %5 (a no-op candidate for
+// inlining) that is then consumed by aten::_cast_Float. Both %8 and %7 are
+// graph outputs, so the inlined cat's result must still be materialized.
+TEST_F(Kernel, CatAndInlineWithAConstantDim) {
+ const auto graph_string = R"IR(
+ graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
+ %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)):
+ %2 : bool = prim::Constant[value=0]()
+ %3 : int = prim::Constant[value=1]()
+ %4 : Tensor[] = prim::ListConstruct(%0, %1)
+ %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3)
+ %6 : Tensor[] = prim::ListConstruct(%5)
+ %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3)
+ %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2)
+ return (%8, %7))IR";
+
+ // Parse the IR above and compile it with the tensor-expression kernel.
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+ TensorExprKernel k(graph);
+
+ // Reference computed eagerly: cat along dim 1, then _cast_Float whose
+ // second argument (0) mirrors the %2 bool constant in the IR.
+ auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto ref = at::_cast_Float(at::cat({a, b}, 1), 0);
+
+ // Run the compiled kernel; stack[0] after run() corresponds to output %8.
+ std::vector<at::Tensor> inputs = {a, b};
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ auto o = stack[0].toTensor();
+ // Compare shape, dtype, and values against the eager reference.
+ ASSERT_EQ(o.sizes(), ref.sizes());
+ ASSERT_EQ(o.dtype(), ref.dtype());
+ ASSERT_TRUE(at::allclose(o, ref));
+}
+
TEST_F(Kernel, CatWoConditionals) {
getCatWoConditionals() = true;
const auto graph_string = R"IR(
ref = foo(*values)
np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
+ # Exercises fusion of torch.cat with a constant `dim`: an inner cat over
+ # the traced inputs followed by a single-element cat (a no-op wrapper)
+ # whose result feeds an elementwise multiply. Asserts the whole traced
+ # graph was fused and that fused output matches the eager reference.
+ def test_cat_with_constant_dim(self):
+ for device in self.devices:
+ def foo(*args):
+ v1 = torch.cat(args, dim=1)
+ v2 = torch.cat([v1], dim=1)
+ return v2 * v2
+
+ # NOTE(review): the leading 1-D empty tensor presumably relies on the
+ # legacy behavior where an empty tensor is skipped by cat regardless
+ # of dim -- confirm against the fuser's cat handling.
+ empty = torch.tensor([], device=device, dtype=torch.float32)
+ inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)]
+ traced = torch.jit.trace(foo, inputs)
+
+ # Warm up so the fuser kicks in, then check fusion and numerics.
+ x = warmup_and_run_forward(traced, *inputs)
+ self.assertLastGraphAllFused()
+ ref = foo(*inputs)
+ np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
+
def test_scalar(self):
@torch.jit.script
def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor: