[Verifier] Verify sizes of matrix.multiply operands and specified shape.
authorFlorian Hahn <flo@fhahn.com>
Tue, 4 Apr 2023 19:51:30 +0000 (20:51 +0100)
committerFlorian Hahn <flo@fhahn.com>
Tue, 4 Apr 2023 19:51:43 +0000 (20:51 +0100)
Extend the verifier to check if the size of the matrix operands of
matrix.multiply match the sizes specified by the numeric arguments.

Reviewed By: thegameg

Differential Revision: https://reviews.llvm.org/D147466

llvm/lib/IR/Verifier.cpp
llvm/test/Verifier/matrix-intrinsics.ll

index 5cc90e0..90dd6ef 100644 (file)
@@ -5638,15 +5638,28 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     Type *Op0ElemTy = nullptr;
     Type *Op1ElemTy = nullptr;
     switch (ID) {
-    case Intrinsic::matrix_multiply:
+    case Intrinsic::matrix_multiply: {
       NumRows = cast<ConstantInt>(Call.getArgOperand(2));
+      ConstantInt *N = cast<ConstantInt>(Call.getArgOperand(3));
       NumColumns = cast<ConstantInt>(Call.getArgOperand(4));
+      Check(cast<FixedVectorType>(Call.getArgOperand(0)->getType())
+                    ->getNumElements() ==
+                NumRows->getZExtValue() * N->getZExtValue(),
+            "First argument of a matrix operation does not match specified "
+            "shape!");
+      Check(cast<FixedVectorType>(Call.getArgOperand(1)->getType())
+                    ->getNumElements() ==
+                N->getZExtValue() * NumColumns->getZExtValue(),
+            "Second argument of a matrix operation does not match specified "
+            "shape!");
+
       ResultTy = cast<VectorType>(Call.getType());
       Op0ElemTy =
           cast<VectorType>(Call.getArgOperand(0)->getType())->getElementType();
       Op1ElemTy =
           cast<VectorType>(Call.getArgOperand(1)->getType())->getElementType();
       break;
+    }
     case Intrinsic::matrix_transpose:
       NumRows = cast<ConstantInt>(Call.getArgOperand(1));
       NumColumns = cast<ConstantInt>(Call.getArgOperand(2));
index 60b9518..b6d5ad9 100644 (file)
@@ -1,4 +1,4 @@
-; RUN: not llvm-as -opaque-pointers < %s -o /dev/null 2>&1 | FileCheck %s
+; RUN: not llvm-as < %s -o /dev/null 2>&1 | FileCheck %s
 
 define <4 x float> @transpose(<4 x float> %m, i32 %arg) {
 ; CHECK: assembly parsed, but does not verify as correct!
@@ -20,17 +20,19 @@ define <4 x float> @transpose(<4 x float> %m, i32 %arg) {
 }
 
 define <4 x float> @multiply(<4 x float> %m, i32 %arg) {
-; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
-; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
+; CHECK-NEXT: First argument of a matrix operation does not match specified shape!
+; CHECK-NEXT: First argument of a matrix operation does not match specified shape!
+; CHECK-NEXT: Second argument of a matrix operation does not match specified shape!
 ; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
 ; CHECK-NEXT: immarg operand has non-immediate parameter
 ; CHECK-NEXT: i32 %arg
-; CHECK-NEXT:   %result.3 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
+; CHECK-NEXT:   %result.4 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
   %result.0 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %m, <4 x float> %m, i32 0, i32 0, i32 0)
   %result.1 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.0, <4 x float> %m, i32 3, i32 2, i32 2)
   %result.2 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.1, <4 x float> %m, i32 2, i32 2, i32 1)
-  %result.3 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
-  ret <4 x float> %result.3
+  %result.3 = call <3 x float> @llvm.matrix.multiply.v3f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 2, i32 2, i32 2)
+  %result.4 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
+  ret <4 x float> %result.4
 }
 
 define <4 x float> @column.major_load(ptr %m, ptr %n, i32 %arg) {
@@ -136,3 +138,4 @@ declare <4 x float> @llvm.matrix.multiply.v4f32.v4i32.v4f32(<4 x i32>, <4 x floa
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4i32(<4 x float>, <4 x i32>, i32, i32, i32)
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32, i32, i32)
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)
+declare <3 x float> @llvm.matrix.multiply.v3f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)