[Matrix] Use alignment info when lowering loads/stores.
authorFlorian Hahn <flo@fhahn.com>
Thu, 18 Jun 2020 12:06:00 +0000 (13:06 +0100)
committerFlorian Hahn <flo@fhahn.com>
Thu, 18 Jun 2020 12:19:31 +0000 (13:19 +0100)
This patch updates LowerMatrixIntrinsics to preserve the alignment
specified at the original load/stores and the align attribute for the
pointer argument of the column.major.load/store intrinsics.

We can always use the specified alignment for the load of the first
column. For subsequent columns, the alignment may need to be reduced.

For ConstantInt strides, compute the offset for the start of the column in
bytes and use commonAlignment to get the largest valid alignment.

For non-ConstantInt strides, we need to take the common alignment of the
initial alignment and the element size in bytes.

Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke, rjmccall

Reviewed By: rjmccall

Differential Revision: https://reviews.llvm.org/D81960

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll

index 97a40b8..c663714 100644 (file)
@@ -37,6 +37,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -732,20 +733,6 @@ public:
     return Changed;
   }
 
-  LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType, bool IsVolatile,
-                             IRBuilder<> &Builder) {
-    return Builder.CreateAlignedLoad(ColumnPtr,
-                                     Align(DL.getABITypeAlignment(EltType)),
-                                     IsVolatile, "col.load");
-  }
-
-  StoreInst *createVectorStore(Value *ColumnValue, Value *ColumnPtr,
-                               Type *EltType, bool IsVolatile,
-                               IRBuilder<> &Builder) {
-    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
-                                      DL.getABITypeAlign(EltType), IsVolatile);
-  }
-
   /// Turns \p BasePtr into an elementwise pointer to \p EltType.
   Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
     unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
@@ -777,10 +764,30 @@ public:
     return true;
   }
 
+  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
+  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
+  /// ConstantInt, reduce the initial alignment based on the byte offset. For
+  /// non-ConstantInt strides, return the common alignment of the initial
+  /// alignment and the element size in bytes.
+  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
+                         MaybeAlign A) const {
+    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
+    if (Idx == 0)
+      return InitialAlign;
+
+    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
+    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
+      uint64_t StrideInBytes =
+          ConstStride->getZExtValue() * ElementSizeInBits / 8;
+      return commonAlignment(InitialAlign, Idx * StrideInBytes);
+    }
+    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
+  }
+
   /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
   /// vectors.
-  MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, bool IsVolatile,
-                      ShapeInfo Shape, IRBuilder<> &Builder) {
+  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
+                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     MatrixTy Result;
@@ -788,8 +795,10 @@ public:
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
                                      Shape.getStride(), VType->getElementType(),
                                      Builder);
-      Value *Vector =
-          createVectorLoad(GEP, VType->getElementType(), IsVolatile, Builder);
+      Value *Vector = Builder.CreateAlignedLoad(
+          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
+          IsVolatile, "col.load");
+
       Result.addVector(Vector);
     }
     return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
@@ -798,8 +807,9 @@ public:
 
   /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
   /// starting at \p MatrixPtr[I][J].
-  MatrixTy loadMatrix(Value *MatrixPtr, bool IsVolatile, ShapeInfo MatrixShape,
-                      Value *I, Value *J, ShapeInfo ResultShape, Type *EltTy,
+  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
+                      ShapeInfo MatrixShape, Value *I, Value *J,
+                      ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
 
     Value *Offset = Builder.CreateAdd(
@@ -815,19 +825,19 @@ public:
     Value *TilePtr =
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
-    return loadMatrix(TileTy, TilePtr,
+    return loadMatrix(TileTy, TilePtr, Align,
                       Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                       ResultShape, Builder);
   }
 
   /// Lower a load instruction with shape information.
-  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, bool IsVolatile,
-                 ShapeInfo Shape) {
+  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
+                 bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
-    finalizeLowering(
-        Inst,
-        loadMatrix(Inst->getType(), Ptr, Stride, IsVolatile, Shape, Builder),
-        Builder);
+    finalizeLowering(Inst,
+                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
+                                Shape, Builder),
+                     Builder);
   }
 
   /// Lowers llvm.matrix.column.major.load.
@@ -838,16 +848,16 @@ public:
            "Intrinsic only supports column-major layout!");
     Value *Ptr = Inst->getArgOperand(0);
     Value *Stride = Inst->getArgOperand(1);
-    LowerLoad(Inst, Ptr, Stride,
+    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
               cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
   }
 
   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
   /// MatrixPtr[I][J].
-  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, bool IsVolatile,
-                   ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy,
-                   IRBuilder<> &Builder) {
+  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
+                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
+                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
         Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
 
@@ -861,34 +871,38 @@ public:
     Value *TilePtr =
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
-    storeMatrix(TileTy, StoreVal, TilePtr,
+    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
   }
 
   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
   /// vectors.
-  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride,
-                       bool IsVolatile, IRBuilder<> &Builder) {
+  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
+                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
+                       IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     for (auto Vec : enumerate(StoreVal.vectors())) {
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                      Stride, StoreVal.getStride(),
                                      VType->getElementType(), Builder);
-      createVectorStore(Vec.value(), GEP, VType->getElementType(), IsVolatile,
-                        Builder);
+      Builder.CreateAlignedStore(Vec.value(), GEP,
+                                 getAlignForIndex(Vec.index(), Stride,
+                                                  VType->getElementType(),
+                                                  MAlign),
+                                 IsVolatile);
     }
     return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                    StoreVal.getNumVectors());
   }
 
   /// Lower a store instruction with shape information.
-  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
-                  bool IsVolatile, ShapeInfo Shape) {
+  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
+                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
     auto StoreVal = getMatrix(Matrix, Shape, Builder);
     finalizeLowering(Inst,
-                     storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride,
+                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                  IsVolatile, Builder),
                      Builder);
   }
@@ -902,7 +916,7 @@ public:
     Value *Matrix = Inst->getArgOperand(0);
     Value *Ptr = Inst->getArgOperand(1);
     Value *Stride = Inst->getArgOperand(2);
-    LowerStore(Inst, Matrix, Ptr, Stride,
+    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
                cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
                {Inst->getArgOperand(4), Inst->getArgOperand(5)});
   }
@@ -1215,16 +1229,18 @@ public:
 
         for (unsigned K = 0; K < M; K += TileSize) {
           const unsigned TileM = std::min(M - K, unsigned(TileSize));
-          MatrixTy A = loadMatrix(APtr, LoadOp0->isVolatile(), LShape,
-                                  Builder.getInt64(I), Builder.getInt64(K),
-                                  {TileR, TileM}, EltType, Builder);
-          MatrixTy B = loadMatrix(BPtr, LoadOp1->isVolatile(), RShape,
-                                  Builder.getInt64(K), Builder.getInt64(J),
-                                  {TileM, TileC}, EltType, Builder);
+          MatrixTy A =
+              loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
+                         LShape, Builder.getInt64(I), Builder.getInt64(K),
+                         {TileR, TileM}, EltType, Builder);
+          MatrixTy B =
+              loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
+                         RShape, Builder.getInt64(K), Builder.getInt64(J),
+                         {TileM, TileC}, EltType, Builder);
           emitMatrixMultiply(Res, A, B, AllowContract, Builder, true);
         }
-        storeMatrix(Res, CPtr, Store->isVolatile(), {R, M}, Builder.getInt64(I),
-                    Builder.getInt64(J), EltType, Builder);
+        storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
+                    Builder.getInt64(I), Builder.getInt64(J), EltType, Builder);
       }
 
     // Mark eliminated instructions as fused and remove them.
@@ -1337,8 +1353,9 @@ public:
     if (I == ShapeMap.end())
       return false;
 
-    LowerLoad(Inst, Ptr, Builder.getInt64(I->second.getStride()),
-              Inst->isVolatile(), I->second);
+    LowerLoad(Inst, Ptr, Inst->getAlign(),
+              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+              I->second);
     return true;
   }
 
@@ -1348,8 +1365,9 @@ public:
     if (I == ShapeMap.end())
       return false;
 
-    LowerStore(Inst, StoredVal, Ptr, Builder.getInt64(I->second.getStride()),
-               Inst->isVolatile(), I->second);
+    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+               I->second);
     return true;
   }
 
index 8caddb0..69bc882 100644 (file)
@@ -76,9 +76,9 @@ entry:
   %c.addr = alloca i32, align 4
   store i32 %r, i32* %r.addr, align 4
   store i32 %c, i32* %c.addr, align 4
-  %0 = load <4 x double>, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 0), align 16
+  %0 = load <4 x double>, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 0), align 8
   %mul = call <4 x double> @llvm.matrix.multiply(<4 x double> %0, <4 x double> %0, i32 2, i32 2, i32 2)
-  store <4 x double> %0, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 2), align 16
+  store <4 x double> %0, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 2), align 8
   ret void
 }
 
index 9da5c20..14b81a1 100644 (file)
@@ -51,7 +51,7 @@ define <9 x double> @strided_load_3x3_align32(<9 x double>* %in, i64 %stride) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 32
 ; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START1]]
 ; CHECK-NEXT:    [[VEC_CAST3:%.*]] = bitcast double* [[VEC_GEP2]] to <3 x double>*
@@ -74,15 +74,15 @@ define <9 x double> @strided_load_3x3_align2(<9 x double>* %in, i64 %stride) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 2
 ; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START1]]
 ; CHECK-NEXT:    [[VEC_CAST3:%.*]] = bitcast double* [[VEC_GEP2]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST3]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST3]], align 2
 ; CHECK-NEXT:    [[VEC_START5:%.*]] = mul i64 2, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START5]]
 ; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST7]], align 2
 ; CHECK-NOT:     = load
 ;
 entry:
@@ -95,10 +95,10 @@ define <4 x double> @load_align2_multiply(<4 x double>* %in) {
 ; CHECK-LABEL: @load_align2_multiply(
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x double>* [[IN:%.*]] to double*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP1]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST]], align 2
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP1]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST1]], align 2
 ; CHECK-NOT:     = load
 ;
   %in.m = load <4 x double>, <4 x double>* %in, align 2
@@ -111,13 +111,13 @@ define <6 x float> @strided_load_2x3_align16_stride2(<6 x float>* %in) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <6 x float>* [[IN:%.*]] to float*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast float* [[TMP0]] to <2 x float>*
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, float* [[TMP0]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast float* [[VEC_GEP]] to <2 x float>*
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST1]], align 4
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST1]], align 8
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, float* [[TMP0]], i64 4
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast float* [[VEC_GEP3]] to <2 x float>*
-; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST4]], align 4
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST4]], align 16
 ; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
 ; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x float> [[COL_LOAD5]], <2 x float> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
 ; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> [[TMP2]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
index 7b60d69..6688dad 100644 (file)
@@ -43,7 +43,7 @@ define void @strided_store_align32(<6 x i32> %in, i64 %stride, i32* %out) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 32
 ; CHECK-NEXT:    [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
@@ -61,11 +61,11 @@ define void @strided_store_align2(<6 x i32> %in, i64 %stride, i32* %out) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 2
 ; CHECK-NEXT:    [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
-; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 2
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 2 %out, i64 %stride, i1 true, i32 3, i32 2)
@@ -76,10 +76,10 @@ define void @multiply_store_align16_stride8(<4 x i32> %in, <4 x i32>* %out) {
 ; CHECK-LABEL: @multiply_store_align16_stride8(
 ; CHECK:         [[TMP29:%.*]] = bitcast <4 x i32>* %out to i32*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[TMP29]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[TMP29]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST25:%.*]] = bitcast i32* [[VEC_GEP]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 4
+; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %res = call <4 x i32> @llvm.matrix.multiply(<4 x i32> %in, <4 x i32> %in, i32 2, i32 2, i32 2)
@@ -93,13 +93,13 @@ define void @strided_store_align8_stride12(<6 x i32> %in, i32* %out) {
 ; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <6 x i32> [[IN]], <6 x i32> undef, <2 x i32> <i32 2, i32 3>
 ; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <6 x i32> [[IN]], <6 x i32> undef, <2 x i32> <i32 4, i32 5>
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[OUT:%.*]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> [[SPLIT]], <2 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store <2 x i32> [[SPLIT]], <2 x i32>* [[VEC_CAST]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT]], i64 3
 ; CHECK-NEXT:    [[VEC_CAST3:%.*]] = bitcast i32* [[VEC_GEP]] to <2 x i32>*
 ; CHECK-NEXT:    store <2 x i32> [[SPLIT1]], <2 x i32>* [[VEC_CAST3]], align 4
 ; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr i32, i32* [[OUT]], i64 6
 ; CHECK-NEXT:    [[VEC_CAST5:%.*]] = bitcast i32* [[VEC_GEP4]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> [[SPLIT2]], <2 x i32>* [[VEC_CAST5]], align 4
+; CHECK-NEXT:    store <2 x i32> [[SPLIT2]], <2 x i32>* [[VEC_CAST5]], align 8
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 8 %out, i64 3, i1 false, i32 2, i32 3)