[Matrix] Preserve volatile when loading loads/stores.
authorFlorian Hahn <flo@fhahn.com>
Thu, 18 Jun 2020 11:14:19 +0000 (12:14 +0100)
committerFlorian Hahn <flo@fhahn.com>
Thu, 18 Jun 2020 11:14:19 +0000 (12:14 +0100)
Currently the matrix lowering turns volatile loads/stores into
non-volatile ones. This patch updates the lowering to preserve the
volatile bit.

Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke, nicolasvasilache

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D81498

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-volatile.ll
llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll

index 445ab19..97a40b8 100644 (file)
@@ -718,9 +718,9 @@ public:
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
         Changed |= VisitBinaryOperator(BinOp);
       if (match(Inst, m_Load(m_Value(Op1))))
-        Changed |= VisitLoad(Inst, Op1, Builder);
+        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
-        Changed |= VisitStore(Inst, Op1, Op2, Builder);
+        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
     }
 
     RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
@@ -732,16 +732,18 @@ public:
     return Changed;
   }
 
-  LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType,
+  LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType, bool IsVolatile,
                              IRBuilder<> &Builder) {
-    return Builder.CreateAlignedLoad(
-        ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load");
+    return Builder.CreateAlignedLoad(ColumnPtr,
+                                     Align(DL.getABITypeAlignment(EltType)),
+                                     IsVolatile, "col.load");
   }
 
   StoreInst *createVectorStore(Value *ColumnValue, Value *ColumnPtr,
-                               Type *EltType, IRBuilder<> &Builder) {
+                               Type *EltType, bool IsVolatile,
+                               IRBuilder<> &Builder) {
     return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
-                                      DL.getABITypeAlign(EltType));
+                                      DL.getABITypeAlign(EltType), IsVolatile);
   }
 
   /// Turns \p BasePtr into an elementwise pointer to \p EltType.
@@ -777,8 +779,8 @@ public:
 
   /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
   /// vectors.
-  MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, ShapeInfo Shape,
-                      IRBuilder<> &Builder) {
+  MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, bool IsVolatile,
+                      ShapeInfo Shape, IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     MatrixTy Result;
@@ -786,7 +788,8 @@ public:
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
                                      Shape.getStride(), VType->getElementType(),
                                      Builder);
-      Value *Vector = createVectorLoad(GEP, VType->getElementType(), Builder);
+      Value *Vector =
+          createVectorLoad(GEP, VType->getElementType(), IsVolatile, Builder);
       Result.addVector(Vector);
     }
     return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
@@ -795,8 +798,8 @@ public:
 
   /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
   /// starting at \p MatrixPtr[I][J].
-  MatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, Value *I,
-                      Value *J, ShapeInfo ResultShape, Type *EltTy,
+  MatrixTy loadMatrix(Value *MatrixPtr, bool IsVolatile, ShapeInfo MatrixShape,
+                      Value *I, Value *J, ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
 
     Value *Offset = Builder.CreateAdd(
@@ -813,17 +816,18 @@ public:
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
     return loadMatrix(TileTy, TilePtr,
-                      Builder.getInt64(MatrixShape.getStride()), ResultShape,
-                      Builder);
+                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
+                      ResultShape, Builder);
   }
 
   /// Lower a load instruction with shape information.
-  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
+  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, bool IsVolatile,
                  ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
-    finalizeLowering(Inst,
-                     loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder),
-                     Builder);
+    finalizeLowering(
+        Inst,
+        loadMatrix(Inst->getType(), Ptr, Stride, IsVolatile, Shape, Builder),
+        Builder);
   }
 
   /// Lowers llvm.matrix.column.major.load.
@@ -835,12 +839,13 @@ public:
     Value *Ptr = Inst->getArgOperand(0);
     Value *Stride = Inst->getArgOperand(1);
     LowerLoad(Inst, Ptr, Stride,
+              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
   }
 
   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
   /// MatrixPtr[I][J].
-  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
+  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, bool IsVolatile,
                    ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy,
                    IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
@@ -857,20 +862,21 @@ public:
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
     storeMatrix(TileTy, StoreVal, TilePtr,
-                Builder.getInt64(MatrixShape.getStride()), Builder);
+                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
   }
 
   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
   /// vectors.
   MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride,
-                       IRBuilder<> &Builder) {
+                       bool IsVolatile, IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     for (auto Vec : enumerate(StoreVal.vectors())) {
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                      Stride, StoreVal.getStride(),
                                      VType->getElementType(), Builder);
-      createVectorStore(Vec.value(), GEP, VType->getElementType(), Builder);
+      createVectorStore(Vec.value(), GEP, VType->getElementType(), IsVolatile,
+                        Builder);
     }
     return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                    StoreVal.getNumVectors());
@@ -878,12 +884,13 @@ public:
 
   /// Lower a store instruction with shape information.
   void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
-                  ShapeInfo Shape) {
+                  bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
     auto StoreVal = getMatrix(Matrix, Shape, Builder);
-    finalizeLowering(
-        Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder),
-        Builder);
+    finalizeLowering(Inst,
+                     storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride,
+                                 IsVolatile, Builder),
+                     Builder);
   }
 
   /// Lowers llvm.matrix.column.major.store.
@@ -896,6 +903,7 @@ public:
     Value *Ptr = Inst->getArgOperand(1);
     Value *Stride = Inst->getArgOperand(2);
     LowerStore(Inst, Matrix, Ptr, Stride,
+               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
                {Inst->getArgOperand(4), Inst->getArgOperand(5)});
   }
 
@@ -1207,16 +1215,16 @@ public:
 
         for (unsigned K = 0; K < M; K += TileSize) {
           const unsigned TileM = std::min(M - K, unsigned(TileSize));
-          MatrixTy A =
-              loadMatrix(APtr, LShape, Builder.getInt64(I), Builder.getInt64(K),
-                         {TileR, TileM}, EltType, Builder);
-          MatrixTy B =
-              loadMatrix(BPtr, RShape, Builder.getInt64(K), Builder.getInt64(J),
-                         {TileM, TileC}, EltType, Builder);
+          MatrixTy A = loadMatrix(APtr, LoadOp0->isVolatile(), LShape,
+                                  Builder.getInt64(I), Builder.getInt64(K),
+                                  {TileR, TileM}, EltType, Builder);
+          MatrixTy B = loadMatrix(BPtr, LoadOp1->isVolatile(), RShape,
+                                  Builder.getInt64(K), Builder.getInt64(J),
+                                  {TileM, TileC}, EltType, Builder);
           emitMatrixMultiply(Res, A, B, AllowContract, Builder, true);
         }
-        storeMatrix(Res, CPtr, {R, M}, Builder.getInt64(I), Builder.getInt64(J),
-                    EltType, Builder);
+        storeMatrix(Res, CPtr, Store->isVolatile(), {R, M}, Builder.getInt64(I),
+                    Builder.getInt64(J), EltType, Builder);
       }
 
     // Mark eliminated instructions as fused and remove them.
@@ -1324,23 +1332,24 @@ public:
   }
 
   /// Lower load instructions, if shape information is available.
-  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
+  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
     auto I = ShapeMap.find(Inst);
     if (I == ShapeMap.end())
       return false;
 
-    LowerLoad(Inst, Ptr, Builder.getInt64(I->second.getStride()), I->second);
+    LowerLoad(Inst, Ptr, Builder.getInt64(I->second.getStride()),
+              Inst->isVolatile(), I->second);
     return true;
   }
 
-  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
+  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                   IRBuilder<> &Builder) {
     auto I = ShapeMap.find(StoredVal);
     if (I == ShapeMap.end())
       return false;
 
     LowerStore(Inst, StoredVal, Ptr, Builder.getInt64(I->second.getStride()),
-               I->second);
+               Inst->isVolatile(), I->second);
     return true;
   }
 
index 98fbdf3..9da5c20 100644 (file)
@@ -8,15 +8,15 @@ define <9 x double> @strided_load_3x3_volatile(<9 x double>* %in, i64 %stride) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load volatile <3 x double>, <3 x double>* [[VEC_CAST]], align 8
 ; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START1]]
 ; CHECK-NEXT:    [[VEC_CAST3:%.*]] = bitcast double* [[VEC_GEP2]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST3]], align 8
+; CHECK-NEXT:    load volatile <3 x double>, <3 x double>* [[VEC_CAST3]], align 8
 ; CHECK-NEXT:    [[VEC_START5:%.*]] = mul i64 2, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START5]]
 ; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    load volatile <3 x double>, <3 x double>* [[VEC_CAST7]], align 8
 ; CHECK-NOT:     = load
 ;
 entry:
@@ -30,10 +30,10 @@ define <4 x double> @load_volatile_multiply(<4 x double>* %in) {
 ; CHECK-LABEL: @load_volatile_multiply(
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x double>* [[IN:%.*]] to double*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP1]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP1]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
 ; CHECK-NOT:     = load
 ;
   %in.m = load volatile <4 x double>, <4 x double>* %in, align 8
index ea70512..2136520 100644 (file)
@@ -14,29 +14,29 @@ define void @multiply_all_volatile(<4 x double>* noalias %A, <4 x double>* noali
 ; CHECK-NEXT:    [[COL_CAST:%.*]] = bitcast double* [[TMP1]] to <4 x double>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x double>* [[COL_CAST]] to double*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP2]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP2]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x double>* [[B:%.*]] to double*
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr double, double* [[TMP3]], i64 0
 ; CHECK-NEXT:    [[COL_CAST3:%.*]] = bitcast double* [[TMP4]] to <4 x double>*
 ; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <4 x double>* [[COL_CAST3]] to double*
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast double* [[TMP5]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
 ; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP5]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
 
 ; CHECK:         [[TMP18:%.*]] = bitcast <4 x double>* [[C:%.*]] to double*
 ; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr double, double* [[TMP18]], i64 0
 ; CHECK-NEXT:    [[COL_CAST18:%.*]] = bitcast double* [[TMP19]] to <4 x double>*
 ; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <4 x double>* [[COL_CAST18]] to double*
 ; CHECK-NEXT:    [[VEC_CAST19:%.*]] = bitcast double* [[TMP20]] to <2 x double>*
-; CHECK-NEXT:    store <2 x double> {{.*}}, <2 x double>* [[VEC_CAST19]], align 8
+; CHECK-NEXT:    store volatile <2 x double> {{.*}}, <2 x double>* [[VEC_CAST19]], align 8
 ; CHECK-NEXT:    [[VEC_GEP20:%.*]] = getelementptr double, double* [[TMP20]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST21:%.*]] = bitcast double* [[VEC_GEP20]] to <2 x double>*
-; CHECK-NEXT:    store <2 x double> {{.*}}, <2 x double>* [[VEC_CAST21]], align 8
+; CHECK-NEXT:    store volatile <2 x double> {{.*}}, <2 x double>* [[VEC_CAST21]], align 8
 ; CHECK-NEXT:    ret void
 ;
 
@@ -59,10 +59,10 @@ define void @multiply_load0_volatile(<4 x double>* noalias %A, <4 x double>* noa
 ; CHECK-NEXT:    [[COL_CAST:%.*]] = bitcast double* [[TMP1]] to <4 x double>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x double>* [[COL_CAST]] to double*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP2]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP2]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x double>* [[B:%.*]] to double*
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr double, double* [[TMP3]], i64 0
 ; CHECK-NEXT:    [[COL_CAST3:%.*]] = bitcast double* [[TMP4]] to <4 x double>*
@@ -112,10 +112,10 @@ define void @multiply_load1_volatile(<4 x double>* noalias %A, <4 x double>* noa
 ; CHECK-NEXT:    [[COL_CAST3:%.*]] = bitcast double* [[TMP4]] to <4 x double>*
 ; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <4 x double>* [[COL_CAST3]] to double*
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast double* [[TMP5]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
 ; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP5]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    load volatile <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
 
 ; CHECK:         [[TMP18:%.*]] = bitcast <4 x double>* [[C:%.*]] to double*
 ; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr double, double* [[TMP18]], i64 0
@@ -166,10 +166,10 @@ define void @multiply_store_volatile(<4 x double>* noalias %A, <4 x double>* noa
 ; CHECK-NEXT:    [[COL_CAST18:%.*]] = bitcast double* [[TMP19]] to <4 x double>*
 ; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <4 x double>* [[COL_CAST18]] to double*
 ; CHECK-NEXT:    [[VEC_CAST19:%.*]] = bitcast double* [[TMP20]] to <2 x double>*
-; CHECK-NEXT:    store <2 x double> {{.*}}, <2 x double>* [[VEC_CAST19]], align 8
+; CHECK-NEXT:    store volatile <2 x double> {{.*}}, <2 x double>* [[VEC_CAST19]], align 8
 ; CHECK-NEXT:    [[VEC_GEP20:%.*]] = getelementptr double, double* [[TMP20]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST21:%.*]] = bitcast double* [[VEC_GEP20]] to <2 x double>*
-; CHECK-NEXT:    store <2 x double> {{.*}}, <2 x double>* [[VEC_CAST21]], align 8
+; CHECK-NEXT:    store volatile <2 x double> {{.*}}, <2 x double>* [[VEC_CAST21]], align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
index 38b6ee3..7b60d69 100644 (file)
@@ -6,10 +6,10 @@ define void @strided_store_volatile(<6 x i32> %in, i32* %out) {
 ; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x i32> [[IN:%.*]], <6 x i32> undef, <3 x i32> <i32 0, i32 1, i32 2>
 ; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <6 x i32> [[IN]], <6 x i32> undef, <3 x i32> <i32 3, i32 4, i32 5>
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[OUT:%.*]] to <3 x i32>*
-; CHECK-NEXT:    store <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT]], i64 5
 ; CHECK-NEXT:    [[VEC_CAST2:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT:    store <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST2]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* %out, i64 5, i1 true, i32 3, i32 2)
@@ -23,10 +23,10 @@ define void @multiply_store_volatile(<4 x i32> %in, <4 x i32>* %out) {
 ; CHECK-LABEL: @multiply_store_volatile(
 ; CHECK:         [[TMP29:%.*]] = bitcast <4 x i32>* %out to i32*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[TMP29]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[TMP29]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST25:%.*]] = bitcast i32* [[VEC_GEP]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 4
+; CHECK-NEXT:    store volatile <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %res = call <4 x i32> @llvm.matrix.multiply(<4 x i32> %in, <4 x i32> %in, i32 2, i32 2, i32 2)
@@ -43,11 +43,11 @@ define void @strided_store_align32(<6 x i32> %in, i64 %stride, i32* %out) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT:    store <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
 ; CHECK-NEXT:    [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
-; CHECK-NEXT:    store <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 4
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 32 %out, i64 %stride, i1 true, i32 3, i32 2)
@@ -61,11 +61,11 @@ define void @strided_store_align2(<6 x i32> %in, i64 %stride, i32* %out) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT:    store <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
 ; CHECK-NEXT:    [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
-; CHECK-NEXT:    store <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 4
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 2 %out, i64 %stride, i1 true, i32 3, i32 2)