From: Thomas Raoux Date: Thu, 18 Jun 2020 20:05:09 +0000 (-0700) Subject: [mlir][spirv] Allow mixed type cooperative matrix muladd X-Git-Tag: llvmorg-12-init~2621 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=25cbfa0788846c7ec06affb9c0e0d4a87b510c02;p=platform%2Fupstream%2Fllvm.git [mlir][spirv] Allow mixed type cooperative matrix muladd muladd can have differenti types for lhs/rhs and acc/destination. Change verifier and update the test to use supported example. Differential Revision: https://reviews.llvm.org/D82042 --- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index efe6858..87456f0 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2753,8 +2753,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { typeR.getScope() != typeB.getScope() || typeR.getScope() != typeC.getScope()) return op.emitOpError("matrix scope must match"); - if (typeR.getElementType() != typeA.getElementType() || - typeR.getElementType() != typeB.getElementType() || + if (typeA.getElementType() != typeB.getElementType() || typeR.getElementType() != typeC.getElementType()) return op.emitOpError("matrix element type must match"); return success(); diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir index 51c7090..a2dafad 100644 --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -37,9 +37,9 @@ spv.func @cooperative_matrix_length() -> i32 "None" { } // CHECK-LABEL: @cooperative_matrix_muladd -spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> - %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> +spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x32xi8, Subgroup>, %b : !spv.coopmatrix<32x8xi8, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> + %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> spv.Return }