[mlir][Vector] Introduce 'vector.load' and 'vector.store' ops
authorDiego Caballero <diego.caballero@intel.com>
Fri, 12 Feb 2021 17:41:46 +0000 (19:41 +0200)
committerDiego Caballero <diego.caballero@intel.com>
Fri, 12 Feb 2021 18:48:37 +0000 (20:48 +0200)
This patch adds the 'vector.load' and 'vector.store' ops to the Vector
dialect [1]. These operations model *contiguous* vector loads and stores
from/to memory. Their semantics are similar to the 'affine.vector_load' and
'affine.vector_store' counterparts but without the affine constraints. The
most relevant feature is that these new vector operations may perform a vector
load/store on memrefs with a non-vector element type, unlike 'std.load' and
'std.store' ops. This opens the representation to model more generic vector
load/store scenarios: unaligned vector loads/stores, perform scalar and vector
memory access on the same memref, decouple memory allocation constraints from
memory accesses, etc [1]. These operations will also facilitate the progressive
lowering of both Affine vector loads/stores and Vector transfer reads/writes
for those that read/write contiguous slices from/to memory.

In particular, this patch adds the 'vector.load' and 'vector.store' ops to the
Vector dialect, implements their lowering to the LLVM dialect, and changes the
lowering of 'affine.vector_load' and 'affine.vector_store' ops to the new vector
ops. The lowering of Vector transfer reads/writes will be implemented in the
future, probably as an independent pass. The API of 'vector.maskedload' and
'vector.maskedstore' has also been changed slightly to align it with the
transfer read/write ops and the vector new ops. This will improve reusability
among all these operations. For example, the lowering of 'vector.load',
'vector.store', 'vector.maskedload' and 'vector.maskedstore' to the LLVM dialect
is implemented with a single template conversion pattern.

[1] https://llvm.discourse.group/t/memref-type-and-data-layout/

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96185

mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir

index 1aeb92a..557446b 100644 (file)
@@ -1320,6 +1320,156 @@ def Vector_TransferWriteOp :
   let hasFolder = 1;
 }
 
+def Vector_LoadOp : Vector_Op<"load"> {
+  let summary = "reads an n-D slice of memory into an n-D vector";
+  let description = [{
+    The 'vector.load' operation reads an n-D slice of memory into an n-D
+    vector. It takes a 'base' memref, an index for each memref dimension and a
+    result vector type as arguments. It returns a value of the result vector
+    type. The 'base' memref and indices determine the start memory address from
+    which to read. Each index provides an offset for each memref dimension
+    based on the element type of the memref. The shape of the result vector
+    type determines the shape of the slice read from the start memory address.
+    The elements along each dimension of the slice are strided by the memref
+    strides. Only memref with default strides are allowed. These constraints
+    guarantee that elements read along the first dimension of the slice are
+    contiguous in memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the result
+    vector. If the memref element type is vector, it should match the result
+    vector type.
+
+    Example 1: 1-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3:  2-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4:  2-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.load' operation permits out-of-bounds
+    reads. Support and implementation of out-of-bounds vector loads is
+    target-specific. No assumptions should be made on the value of elements
+    loaded out of bounds. Not all targets may support out-of-bounds vector
+    loads.
+
+    Example 5:  Potential out-of-bound vector load.
+    ```mlir
+    %result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6:  Explicit out-of-bound vector load.
+    ```mlir
+    %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+      [MemRead]>:$base,
+      Variadic<Index>:$indices);
+  let results = (outs AnyVector:$result);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat =
+      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+}
+
+def Vector_StoreOp : Vector_Op<"store"> {
+  let summary = "writes an n-D vector to an n-D slice of memory";
+  let description = [{
+    The 'vector.store' operation writes an n-D vector to an n-D slice of memory.
+    It takes the vector value to be stored, a 'base' memref and an index for
+    each memref dimension. The 'base' memref and indices determine the start
+    memory address from which to write. Each index provides an offset for each
+    memref dimension based on the element type of the memref. The shape of the
+    vector value to store determines the shape of the slice written from the
+    start memory address. The elements along each dimension of the slice are
+    strided by the memref strides. Only memref with default strides are allowed.
+    These constraints guarantee that elements written along the first dimension
+    of the slice are contiguous in memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the value
+    to store. If the memref element type is vector, it should match the type
+    of the value to store.
+
+    Example 1: 1-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3:  2-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4:  2-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.store' operation permits out-of-bounds
+    writes. Support and implementation of out-of-bounds vector stores are
+    target-specific. No assumptions should be made on the memory written out of
+    bounds. Not all targets may support out-of-bounds vector stores.
+
+    Example 5:  Potential out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6:  Explicit out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyVector:$valueToStore,
+      Arg<AnyMemRef, "the reference to store to",
+      [MemWrite]>:$base,
+      Variadic<Index>:$indices);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
+                       "`:` type($base) `,` type($valueToStore)";
+}
+
 def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
@@ -1363,7 +1513,7 @@ def Vector_MaskedLoadOp :
     VectorType getPassThruVectorType() {
       return pass_thru().getType().cast<VectorType>();
     }
-    VectorType getResultVectorType() {
+    VectorType getVectorType() {
       return result().getType().cast<VectorType>();
     }
   }];
@@ -1377,7 +1527,7 @@ def Vector_MaskedStoreOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$value)> {
+               VectorOfRank<[1]>:$valueToStore)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
 
@@ -1411,12 +1561,13 @@ def Vector_MaskedStoreOp :
     VectorType getMaskVectorType() {
       return mask().getType().cast<VectorType>();
     }
-    VectorType getValueVectorType() {
-      return value().getType().cast<VectorType>();
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
-    "type($base) `,` type($mask) `,` type($value)";
+  let assemblyFormat =
+      "$base `[` $indices `]` `,` $mask `,` $valueToStore "
+      "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
 }
 
index 5a2fe91..ef04f68 100644 (file)
@@ -578,8 +578,9 @@ public:
     if (!resultOperands)
       return failure();
 
-    // Build std.load memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, op.getMemRef(),
+                                              *resultOperands);
     return success();
   }
 };
@@ -625,8 +626,8 @@ public:
       return failure();
 
     // Build std.store valueToStore, memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
-                                         op.getMemRef(), *maybeExpandedMap);
+    rewriter.replaceOpWithNewOp<mlir::StoreOp>(
+        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }
 };
@@ -695,8 +696,8 @@ public:
 };
 
 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
-/// and feed the results to a newly created 'vector.transfer_read' operation
-/// (which replaces the original 'affine.vector_load').
+/// and feed the results to a newly created 'vector.load' operation (which
+/// replaces the original 'affine.vector_load').
 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
 public:
   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
@@ -710,16 +711,16 @@ public:
     if (!resultOperands)
       return failure();
 
-    // Build vector.transfer_read memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<TransferReadOp>(
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<vector::LoadOp>(
         op, op.getVectorType(), op.getMemRef(), *resultOperands);
     return success();
   }
 };
 
 /// Apply the affine map from an 'affine.vector_store' operation to its
-/// operands, and feed the results to a newly created 'vector.transfer_write'
-/// operation (which replaces the original 'affine.vector_store').
+/// operands, and feed the results to a newly created 'vector.store' operation
+/// (which replaces the original 'affine.vector_store').
 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
 public:
   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
@@ -733,7 +734,7 @@ public:
     if (!maybeExpandedMap)
       return failure();
 
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }
index 54cdd9c..3393bb7 100644 (file)
@@ -357,64 +357,72 @@ public:
   }
 };
 
-/// Conversion pattern for a vector.maskedload.
-class VectorMaskedLoadOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = load->getLoc();
-    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
-    MemRefType memRefType = load.getMemRefType();
+/// Overloaded utility that replaces a vector.load, vector.store,
+/// vector.maskedload and vector.maskedstore with their respective LLVM
+/// couterparts.
+static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
+                                 vector::LoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
+}
 
-    // Resolve alignment.
-    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
+static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
+                                 vector::MaskedLoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
+}
 
-    // Resolve address.
-    auto vtype = typeConverter->convertType(load.getResultVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                               adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
+                                 vector::StoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
+                                             ptr, align);
+}
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
-        rewriter.getI32IntegerAttr(align));
-    return success();
-  }
-};
+static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
+                                 vector::MaskedStoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
+}
 
-/// Conversion pattern for a vector.maskedstore.
-class VectorMaskedStoreOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
+/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
+/// vector.maskedstore.
+template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
+class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
-  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
+  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = store->getLoc();
-    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
-    MemRefType memRefType = store.getMemRefType();
+    // Only 1-D vectors can be lowered to LLVM.
+    VectorType vectorTy = loadOrStoreOp.getVectorType();
+    if (vectorTy.getRank() > 1)
+      return failure();
+
+    auto loc = loadOrStoreOp->getLoc();
+    auto adaptor = LoadOrStoreOpAdaptor(operands);
+    MemRefType memRefTy = loadOrStoreOp.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
-    auto vtype = typeConverter->convertType(store.getValueVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
+                     .template cast<VectorType>();
+    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                                adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-        store, adaptor.value(), ptr, adaptor.mask(),
-        rewriter.getI32IntegerAttr(align));
+    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
     return success();
   }
 };
@@ -1511,8 +1519,14 @@ void mlir::populateVectorToLLVMConversionPatterns(
               VectorInsertOpConversion,
               VectorPrintOpConversion,
               VectorTypeCastOpConversion,
-              VectorMaskedLoadOpConversion,
-              VectorMaskedStoreOpConversion,
+              VectorLoadStoreConversion<vector::LoadOp,
+                                        vector::LoadOpAdaptor>,
+              VectorLoadStoreConversion<vector::MaskedLoadOp,
+                                        vector::MaskedLoadOpAdaptor>,
+              VectorLoadStoreConversion<vector::StoreOp,
+                                        vector::StoreOpAdaptor>,
+              VectorLoadStoreConversion<vector::MaskedStoreOp,
+                                        vector::MaskedStoreOpAdaptor>,
               VectorGatherOpConversion,
               VectorScatterOpConversion,
               VectorExpandLoadOpConversion,
index 99b9788..a56b49a 100644 (file)
@@ -2374,13 +2374,74 @@ void TransferWriteOp::getEffects(
 }
 
 //===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 MemRefType memRefTy) {
+  auto affineMaps = memRefTy.getAffineMaps();
+  if (!affineMaps.empty())
+    return op->emitOpError("base memref should have a default identity layout");
+  return success();
+}
+
+static LogicalResult verify(vector::LoadOp op) {
+  VectorType resVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != resVecTy)
+      return op.emitOpError("base memref and result vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (resVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and result element types should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(vector::StoreOp op) {
+  VectorType valueVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != valueVecTy)
+      return op.emitOpError(
+          "base memref and valueToStore vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (valueVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and valueToStore element type should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(MaskedLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
-  VectorType resVType = op.getResultVectorType();
+  VectorType resVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (resVType.getElementType() != memType.getElementType())
@@ -2427,15 +2488,15 @@ void MaskedLoadOp::getCanonicalizationPatterns(
 
 static LogicalResult verify(MaskedStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
-  VectorType valueVType = op.getValueVectorType();
+  VectorType valueVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (valueVType.getElementType() != memType.getElementType())
-    return op.emitOpError("base and value element type should match");
+    return op.emitOpError("base and valueToStore element type should match");
   if (llvm::size(op.indices()) != memType.getRank())
     return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return op.emitOpError("expected value dim to match mask dim");
+    return op.emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
 
@@ -2448,7 +2509,7 @@ public:
     switch (get1DMaskFormat(store.mask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-          store, store.value(), store.base(), store.indices(), false);
+          store, store.valueToStore(), store.base(), store.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(store);
index 7fba099..3df9bb3 100644 (file)
@@ -1,41 +1,5 @@
 // RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
 
-// CHECK-LABEL: func @affine_vector_load
-func @affine_vector_load(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  affine.for %i0 = 0 to 16 {
-    %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
-  }
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
-// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @affine_vector_store
-func @affine_vector_store(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  %1 = constant dense<11.0> : vector<4xf32>
-  affine.for %i0 = 0 to 16 {
-    affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
-}
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[val:.*]] = constant dense
-// CHECK:       %[[c_1:.*]] = constant -1 : index
-// CHECK-NEXT:  %[[a:.*]] = muli %arg0, %[[c_1]] : index
-// CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
-// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
-// CHECK-NEXT:  vector.transfer_write  %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
-  return
-}
-
-// -----
 
 // CHECK-LABEL: func @affine_vector_load
 func @affine_vector_load(%arg0 : index) {
@@ -47,8 +11,7 @@ func @affine_vector_load(%arg0 : index) {
 // CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
 // CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
-// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
+// CHECK-NEXT:  vector.load %[[buf]][%[[b]]] : memref<100xf32>, vector<8xf32>
   return
 }
 
@@ -68,7 +31,7 @@ func @affine_vector_store(%arg0 : index) {
 // CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
 // CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
 // CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
-// CHECK-NEXT:  vector.transfer_write  %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
+// CHECK-NEXT:  vector.store %[[val]], %[[buf]][%[[c]]] : memref<100xf32>, vector<4xf32>
   return
 }
 
@@ -83,8 +46,7 @@ func @vector_load_2d() {
 // CHECK:      %[[buf:.*]] = alloc
 // CHECK:      scf.for %[[i0:.*]] =
 // CHECK:        scf.for %[[i1:.*]] =
-// CHECK-NEXT:     %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:     vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK-NEXT:     vector.load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
     }
   }
   return
@@ -103,9 +65,8 @@ func @vector_store_2d() {
 // CHECK:      %[[val:.*]] = constant dense
 // CHECK:      scf.for %[[i0:.*]] =
 // CHECK:        scf.for %[[i1:.*]] =
-// CHECK-NEXT:     vector.transfer_write  %[[val]], %[[buf]][%[[i0]], %[[i1]]] : vector<2x8xf32>, memref<100x100xf32>
+// CHECK-NEXT:     vector.store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
     }
   }
   return
 }
-
index facc91c..3d12943 100644 (file)
@@ -23,6 +23,7 @@ func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
 
 // -----
 
+
 func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
   %0 = vector.broadcast %arg0 : f32 to vector<2xf32>
   return %0 : vector<2xf32>
@@ -1242,6 +1243,33 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 // -----
 
+func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @vector_load_op
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
+// CHECK: llvm.load %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
+
+func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = constant dense<11.0> : vector<4xf32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
+// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
+
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
index 099dad7..ab58fdc 100644 (file)
@@ -1198,6 +1198,38 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
 
 // -----
 
+func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+                               %i : index, %j : index, %value : vector<8xf32>) {
+  // expected-error@+1 {{'vector.store' op base memref should have a default identity layout}}
+  vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+                                         vector<8xf32>
+}
+
+// -----
+
+func @vector_memref_mismatch(%memref : memref<200x100xvector<4xf32>>, %i : index,
+                             %j : index, %value : vector<8xf32>) {
+  // expected-error@+1 {{'vector.store' op base memref and valueToStore vector types should match}}
+  vector.store %value, %memref[%i, %j] : memref<200x100xvector<4xf32>>, vector<8xf32>
+}
+
+// -----
+
+func @store_base_type_mismatch(%base : memref<?xf64>, %value : vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.store' op base and valueToStore element type should match}}
+  vector.store %value, %base[%c0] : memref<?xf64>, vector<16xf32>
+}
+
+// -----
+
+func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16xf32>) {
+  // expected-error@+1 {{'vector.store' op requires 1 indices}}
+  vector.store %value, %base[] : memref<?xf32>, vector<16xf32>
+}
+
+// -----
+
 func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
@@ -1231,7 +1263,7 @@ func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pa
 
 func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
+  // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1239,7 +1271,7 @@ func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
 
 func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
+  // expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
 }
 
index 7284cab..11197f1 100644 (file)
@@ -450,6 +450,56 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   return %0 : vector<16xi32>
 }
 
+// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
+func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_1d_vector_memref
+func @vector_load_and_store_1d_vector_memref(%memref : memref<200x100xvector<8xf32>>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_out_of_bounds
+func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
+  %0 = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
+  vector.store %0, %memref[%c0] : memref<7xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_2d_scalar_memref
+func @vector_load_and_store_2d_scalar_memref(%memref : memref<200x100xf32>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_2d_vector_memref
+func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvector<4x8xf32>>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  return
+}
+
 // CHECK-LABEL: @masked_load_and_store
 func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
   %c0 = constant 0 : index