From 0298f2cfb1df80741a08fb7cd1eec9da70ed3441 Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Wed, 30 Jun 2021 00:00:11 -0700 Subject: [PATCH] [mlir] Fix wrong type in WmmaConstantOpToNVVMLowering InsertElement takes a scalar integer attribute not an array of integer. Differential Revision: https://reviews.llvm.org/D105174 --- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 2 +- mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index d46a185..d955673 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -371,7 +371,7 @@ struct WmmaConstantOpToNVVMLowering for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = rewriter.create( loc, typeConverter->convertType(rewriter.getIntegerType(32)), - rewriter.getI32ArrayAttr(vecEl)); + rewriter.getI32IntegerAttr(vecEl)); vecCst = rewriter.create(loc, vecType, vecCst, cst, idx); } diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir index f692dff..6eb641b 100644 --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -160,9 +160,9 @@ gpu.module @test_module { // CHECK-LABEL: func @gpu_wmma_constant_op // CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16 // CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<2xf16> -// CHECK: %[[C0:.+]] = llvm.mlir.constant([0 : i32]) : i32 +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[V1:.+]] = llvm.insertelement %[[CST]], %[[V0]][%[[C0]] : i32] : vector<2xf16> -// CHECK: %[[C1:.+]] = llvm.mlir.constant([1 : i32]) : i32 +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[V2:.+]] = llvm.insertelement %[[CST]], %[[V1]][%[[C1]] : i32] : vector<2xf16> // CHECK: %[[M0:.+]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[M1:.+]] = llvm.insertvalue %[[V2]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> -- 2.7.4