[mlir] Add "mask" operand to vector.transfer_read/write.
authorMatthias Springer <springerm@google.com>
Wed, 7 Apr 2021 12:11:55 +0000 (21:11 +0900)
committerMatthias Springer <springerm@google.com>
Wed, 7 Apr 2021 12:33:13 +0000 (21:33 +0900)
Also factors out out-of-bounds mask generation from vector.transfer_read/write into a new MaterializeTransferMask pattern.

Differential Revision: https://reviews.llvm.org/D100001

13 files changed:
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir

index efd26ff..0ee3fd5 100644 (file)
@@ -68,7 +68,7 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false, bool enableIndexOptimizations = true);
+    bool reassociateFPReductions = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
index a2ec152..c11e811 100644 (file)
@@ -88,6 +88,10 @@ void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
 /// `vector.store` and `vector.broadcast`.
 void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
 
+/// These patterns materialize masks for various vector ops such as transfers.
+void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
+                                               bool enableIndexOptimizations);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr
index 5ff118b..14afe95 100644 (file)
@@ -1135,10 +1135,12 @@ def Vector_TransferReadOp :
   Vector_Op<"transfer_read", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
-      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      AttrSizedOperandSegments
     ]>,
     Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map, AnyType:$padding,
+               Optional<VectorOf<[I1]>>:$mask,
                OptionalAttr<BoolArrayAttr>:$in_bounds)>,
     Results<(outs AnyVector:$vector)> {
 
@@ -1167,13 +1169,19 @@ def Vector_TransferReadOp :
     return type.
 
     An SSA value `padding` of the same elemental type as the MemRef/Tensor is
-    provided to specify a fallback value in the case of out-of-bounds accesses.
+    provided to specify a fallback value in the case of out-of-bounds accesses
+    and/or masking.
+
+    An optional SSA value `mask` of the same shape as the vector type may be
+    specified to mask out elements. Such elements will be replaces with
+    `padding`. Elements whose corresponding mask element is `0` are masked out.
 
     An optional boolean array attribute is provided to specify which dimensions
     of the transfer are guaranteed to be within bounds. The absence of this
     `in_bounds` attribute signifies that any dimension of the transfer may be
     out-of-bounds. A `vector.transfer_read` can be lowered to a simple load if
-    all dimensions are specified to be within bounds.
+    all dimensions are specified to be within bounds and no `mask` was
+    specified.
 
     This operation is called 'read' by opposition to 'load' because the
     super-vector granularity is generally not representable with a single
@@ -1299,6 +1307,14 @@ def Vector_TransferReadOp :
     // 'getMinorIdentityMap' (resp. zero).
     OpBuilder<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
+    // Builder that does not set mask.
+    OpBuilder<(ins "Type":$vector, "Value":$source,
+      "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding,
+      "ArrayAttr":$inBounds)>,
+    // Builder that does not set mask.
+    OpBuilder<(ins "Type":$vector, "Value":$source,
+      "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding,
+      "ArrayAttr":$inBounds)>
   ];
 
   let hasFolder = 1;
@@ -1308,11 +1324,13 @@ def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
-      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      AttrSizedOperandSegments
   ]>,
     Arguments<(ins AnyVector:$vector, AnyShaped:$source,
                Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map,
+               Optional<VectorOf<[I1]>>:$mask,
                OptionalAttr<BoolArrayAttr>:$in_bounds)>,
     Results<(outs Optional<AnyRankedTensor>:$result)> {
 
@@ -1341,11 +1359,16 @@ def Vector_TransferWriteOp :
 
     The size of the slice is specified by the size of the vector.
 
+    An optional SSA value `mask` of the same shape as the vector type may be
+    specified to mask out elements. Elements whose corresponding mask element
+    is `0` are masked out.
+
     An optional boolean array attribute is provided to specify which dimensions
     of the transfer are guaranteed to be within bounds. The absence of this
     `in_bounds` attribute signifies that any dimension of the transfer may be
     out-of-bounds. A `vector.transfer_write` can be lowered to a simple store
-    if all dimensions are specified to be within bounds.
+    if all dimensions are specified to be within bounds and no `mask` was
+    specified.
 
     This operation is called 'write' by opposition to 'store' because the
     super-vector granularity is generally not representable with a single
@@ -1392,6 +1415,8 @@ def Vector_TransferWriteOp :
     OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
     OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+      "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>,
+    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
   ];
 
index 82e4bc2..0c752c3 100644 (file)
@@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
   return res;
 }
 
-static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
-                                   Location loc, Type targetType, Value value) {
-  if (targetType == value.getType())
-    return value;
-
-  bool targetIsIndex = targetType.isIndex();
-  bool valueIsIndex = value.getType().isIndex();
-  if (targetIsIndex ^ valueIsIndex)
-    return rewriter.create<IndexCastOp>(loc, targetType, value);
-
-  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
-  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
-  assert(targetIntegerType && valueIntegerType &&
-         "unexpected cast between types other than integers and index");
-  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-
-  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
-    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
-  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
-}
-
-// Helper that returns a vector comparison that constructs a mask:
-//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
-//
-// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
-//       much more compact, IR for this operation, but LLVM eventually
-//       generates more elaborate instructions for this intrinsic since it
-//       is very conservative on the boundary conditions.
-static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
-                                   Operation *op, bool enableIndexOptimizations,
-                                   int64_t dim, Value b, Value *off = nullptr) {
-  auto loc = op->getLoc();
-  // If we can assume all indices fit in 32-bit, we perform the vector
-  // comparison in 32-bit to get a higher degree of SIMD parallelism.
-  // Otherwise we perform the vector comparison using 64-bit indices.
-  Value indices;
-  Type idxType;
-  if (enableIndexOptimizations) {
-    indices = rewriter.create<ConstantOp>(
-        loc, rewriter.getI32VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
-    idxType = rewriter.getI32Type();
-  } else {
-    indices = rewriter.create<ConstantOp>(
-        loc, rewriter.getI64VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
-    idxType = rewriter.getI64Type();
-  }
-  // Add in an offset if requested.
-  if (off) {
-    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
-    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
-    indices = rewriter.create<AddIOp>(loc, ov, indices);
-  }
-  // Construct the vector comparison.
-  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
-  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
-  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
-}
-
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
@@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
   if (failed(getMemRefAlignment(
           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
-  auto adaptor = TransferWriteOpAdaptor(operands);
+  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                              align);
   return success();
@@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
 
-  auto adaptor = TransferWriteOpAdaptor(operands);
+  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
       xferOp, adaptor.vector(), dataPtr, mask,
       rewriter.getI32IntegerAttr(align));
@@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
 
 static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                   ArrayRef<Value> operands) {
-  return TransferReadOpAdaptor(operands);
+  return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
 }
 
 static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                    ArrayRef<Value> operands) {
-  return TransferWriteOpAdaptor(operands);
+  return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
 }
 
 namespace {
@@ -618,33 +558,6 @@ private:
   const bool reassociateFPReductions;
 };
 
-/// Conversion pattern for a vector.create_mask (1-D only).
-class VectorCreateMaskOpConversion
-    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
-public:
-  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
-                                        bool enableIndexOpt)
-      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
-        enableIndexOptimizations(enableIndexOpt) {}
-
-  LogicalResult
-  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto dstType = op.getType();
-    int64_t rank = dstType.getRank();
-    if (rank == 1) {
-      rewriter.replaceOp(
-          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
-                                    dstType.getDimSize(0), operands[0]));
-      return success();
-    }
-    return failure();
-  }
-
-private:
-  const bool enableIndexOptimizations;
-};
-
 class VectorShuffleOpConversion
     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
 public:
@@ -1177,20 +1090,12 @@ public:
   }
 };
 
-/// Conversion pattern that converts a 1-D vector transfer read/write op in a
-/// sequence of:
-/// 1. Get the source/dst address as an LLVM vector pointer.
-/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 4. Create a mask where offsetVector is compared against memref upper bound.
-/// 5. Rewrite op as a masked read or write.
+/// Conversion pattern that converts a 1-D vector transfer read/write op into a
+/// a masked or unmasked read/write.
 template <typename ConcreteOp>
 class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
 public:
-  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
-                                    bool enableIndexOpt)
-      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
-        enableIndexOptimizations(enableIndexOpt) {}
+  using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
@@ -1212,6 +1117,9 @@ public:
     auto strides = computeContiguousStrides(memRefType);
     if (!strides)
       return failure();
+    // Out-of-bounds dims are handled by MaterializeTransferMask.
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
 
     auto toLLVMTy = [&](Type t) {
       return this->getTypeConverter()->convertType(t);
@@ -1241,40 +1149,24 @@ public:
 #endif // ifndef NDEBUG
     }
 
-    // 1. Get the source/dst address as an LLVM vector pointer.
+    // Get the source/dst address as an LLVM vector pointer.
     VectorType vtp = xferOp.getVectorType();
     Value dataPtr = this->getStridedElementPtr(
         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
     Value vectorDataPtr =
         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
 
-    if (xferOp.isDimInBounds(0))
+    // Rewrite as an unmasked masked read / write.
+    if (!xferOp.mask())
       return replaceTransferOpWithLoadOrStore(rewriter,
                                               *this->getTypeConverter(), loc,
                                               xferOp, operands, vectorDataPtr);
 
-    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-    // 4. Let dim the memref dimension, compute the vector comparison mask
-    //    (in-bounds mask):
-    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-    //
-    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
-    //       dimensions here.
-    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
-    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
-    Value off = xferOp.indices()[lastIndex];
-    Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
-    Value mask = buildVectorComparison(
-        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
-
-    // 5. Rewrite as a masked read / write.
+    // Rewrite as a masked read / write.
     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
-                                       xferOp, operands, vectorDataPtr, mask);
+                                       xferOp, operands, vectorDataPtr,
+                                       xferOp.mask());
   }
-
-private:
-  const bool enableIndexOptimizations;
 };
 
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@@ -1484,17 +1376,13 @@ public:
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions, bool enableIndexOptimizations) {
+    bool reassociateFPReductions) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern,
                VectorInsertStridedSliceOpDifferentRankRewritePattern,
                VectorInsertStridedSliceOpSameRankRewritePattern,
                VectorExtractStridedSliceOpConversion>(ctx);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpConversion,
-               VectorTransferConversion<TransferReadOp>,
-               VectorTransferConversion<TransferWriteOp>>(
-      converter, enableIndexOptimizations);
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,
@@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
            VectorLoadStoreConversion<vector::MaskedStoreOp,
                                      vector::MaskedStoreOpAdaptor>,
            VectorGatherOpConversion, VectorScatterOpConversion,
-           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
-          converter);
+           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+           VectorTransferConversion<TransferReadOp>,
+           VectorTransferConversion<TransferWriteOp>>(converter);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
index abddcd7..49ee670 100644 (file)
@@ -71,9 +71,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());
   RewritePatternSet patterns(&getContext());
+  populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, enableIndexOptimizations);
+  populateVectorToLLVMConversionPatterns(converter, patterns,
+                                         reassociateFPReductions);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.
index b55c8bc..2f033b1 100644 (file)
@@ -42,7 +42,7 @@ static LogicalResult replaceTransferOpWithMubuf(
     LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
     Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
     Value &glc, Value &slc) {
-  auto adaptor = TransferWriteOpAdaptor(operands);
+  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
                                                    dwordConfig, vindex,
                                                    offsetSizeInBytes, glc, slc);
@@ -62,7 +62,7 @@ public:
   LogicalResult
   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    typename ConcreteOp::Adaptor adaptor(operands);
+    typename ConcreteOp::Adaptor adaptor(operands, xferOp->getAttrDictionary());
 
     if (xferOp.getVectorType().getRank() > 1 ||
         llvm::size(xferOp.indices()) == 0)
index 6e963ae..72d32d0 100644 (file)
@@ -538,6 +538,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
   using namespace mlir::edsc::op;
 
   TransferReadOp transfer = cast<TransferReadOp>(op);
+  if (transfer.mask())
+    return failure();
   auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
   if (!memRefType)
     return failure();
@@ -624,6 +626,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
   using namespace edsc::op;
 
   TransferWriteOp transfer = cast<TransferWriteOp>(op);
+  if (transfer.mask())
+    return failure();
   auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
   if (!memRefType)
     return failure();
index 7e4233d..cff5fcb 100644 (file)
@@ -2295,8 +2295,27 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, vectorType, source, indices, permMap, inBounds);
 }
 
+/// Builder that does not provide a mask.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           Type vectorType, Value source, ValueRange indices,
+                           AffineMap permutationMap, Value padding,
+                           ArrayAttr inBounds) {
+  build(builder, result, vectorType, source, indices, permutationMap, padding,
+        /*mask=*/Value(), inBounds);
+}
+
+/// Builder that does not provide a mask.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           Type vectorType, Value source, ValueRange indices,
+                           AffineMapAttr permutationMap, Value padding,
+                           ArrayAttr inBounds) {
+  build(builder, result, vectorType, source, indices, permutationMap, padding,
+        /*mask=*/Value(), inBounds);
+}
+
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
-  SmallVector<StringRef, 2> elidedAttrs;
+  SmallVector<StringRef, 3> elidedAttrs;
+  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
   if (op.permutation_map().isMinorIdentity())
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideInBounds = true;
@@ -2316,27 +2335,36 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
 static void print(OpAsmPrinter &p, TransferReadOp op) {
   p << op.getOperationName() << " " << op.source() << "[" << op.indices()
     << "], " << op.padding();
+  if (op.mask())
+    p << ", " << op.mask();
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
   p << " : " << op.getShapedType() << ", " << op.getVectorType();
 }
 
 static ParseResult parseTransferReadOp(OpAsmParser &parser,
                                        OperationState &result) {
+  auto &builder = parser.getBuilder();
   llvm::SMLoc typesLoc;
   OpAsmParser::OperandType sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   OpAsmParser::OperandType paddingInfo;
   SmallVector<Type, 2> types;
+  OpAsmParser::OperandType maskInfo;
   // Parsing with support for paddingValue.
   if (parser.parseOperand(sourceInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseComma() || parser.parseOperand(paddingInfo) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseComma() || parser.parseOperand(paddingInfo))
+    return failure();
+  ParseResult hasMask = parser.parseOptionalComma();
+  if (hasMask.succeeded()) {
+    parser.parseOperand(maskInfo);
+  }
+  if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
     return failure();
   if (types.size() != 2)
     return parser.emitError(typesLoc, "requires two types");
-  auto indexType = parser.getBuilder().getIndexType();
+  auto indexType = builder.getIndexType();
   auto shapedType = types[0].dyn_cast<ShapedType>();
   if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
@@ -2349,12 +2377,21 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
-  return failure(
-      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
       parser.resolveOperand(paddingInfo, shapedType.getElementType(),
-                            result.operands) ||
-      parser.addTypeToList(vectorType, result.types));
+                            result.operands))
+    return failure();
+  if (hasMask.succeeded()) {
+    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    if (parser.resolveOperand(maskInfo, maskType, result.operands))
+      return failure();
+  }
+  result.addAttribute(
+      TransferReadOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
+                                static_cast<int32_t>(hasMask.succeeded())}));
+  return parser.addTypeToList(vectorType, result.types);
 }
 
 static LogicalResult verify(TransferReadOp op) {
@@ -2525,7 +2562,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             /*optional*/ ArrayAttr inBounds) {
   Type resultType = source.getType().dyn_cast<RankedTensorType>();
   build(builder, result, resultType, vector, source, indices, permutationMap,
-        inBounds);
+        /*mask=*/Value(), inBounds);
 }
 
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
@@ -2534,24 +2571,39 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             /*optional*/ ArrayAttr inBounds) {
   Type resultType = source.getType().dyn_cast<RankedTensorType>();
   build(builder, result, resultType, vector, source, indices, permutationMap,
-        inBounds);
+        /*mask=*/Value(), inBounds);
+}
+
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value source, ValueRange indices,
+                            AffineMap permutationMap, /*optional*/ Value mask,
+                            /*optional*/ ArrayAttr inBounds) {
+  Type resultType = source.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, resultType, vector, source, indices, permutationMap,
+        mask, inBounds);
 }
 
 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                         OperationState &result) {
+  auto &builder = parser.getBuilder();
   llvm::SMLoc typesLoc;
   OpAsmParser::OperandType vectorInfo, sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   SmallVector<Type, 2> types;
+  OpAsmParser::OperandType maskInfo;
   if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
       parser.parseOperand(sourceInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
+    return failure();
+  ParseResult hasMask = parser.parseOptionalComma();
+  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
+    return failure();
+  if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
     return failure();
   if (types.size() != 2)
     return parser.emitError(typesLoc, "requires two types");
-  auto indexType = parser.getBuilder().getIndexType();
+  auto indexType = builder.getIndexType();
   VectorType vectorType = types[0].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
@@ -2564,17 +2616,28 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
-  return failure(
-      parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
+  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
       parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
-      parser.resolveOperands(indexInfo, indexType, result.operands) ||
-      (shapedType.isa<RankedTensorType>() &&
-       parser.addTypeToList(shapedType, result.types)));
+      parser.resolveOperands(indexInfo, indexType, result.operands))
+    return failure();
+  if (hasMask.succeeded()) {
+    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    if (parser.resolveOperand(maskInfo, maskType, result.operands))
+      return failure();
+  }
+  result.addAttribute(
+      TransferWriteOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
+                                static_cast<int32_t>(hasMask.succeeded())}));
+  return failure(shapedType.isa<RankedTensorType>() &&
+                 parser.addTypeToList(shapedType, result.types));
 }
 
 static void print(OpAsmPrinter &p, TransferWriteOp op) {
   p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
     << op.indices() << "]";
+  if (op.mask())
+    p << ", " << op.mask();
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
   p << " : " << op.getVectorType() << ", " << op.getShapedType();
 }
index b48c8ac..ba8ca26 100644 (file)
@@ -596,6 +596,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
                                   OpBuilder &builder) {
   if (!isIdentitySuffix(readOp.permutation_map()))
     return nullptr;
+  if (readOp.mask())
+    return nullptr;
   auto sourceVectorType = readOp.getVectorType();
   SmallVector<int64_t, 4> strides(targetShape.size(), 1);
 
@@ -641,6 +643,8 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
   auto writeOp = cast<vector::TransferWriteOp>(op);
   if (!isIdentitySuffix(writeOp.permutation_map()))
     return failure();
+  if (writeOp.mask())
+    return failure();
   VectorType sourceVectorType = writeOp.getVectorType();
   SmallVector<int64_t, 4> strides(targetShape.size(), 1);
   TupleType tupleType = generateExtractSlicesOpResultType(
@@ -722,6 +726,9 @@ public:
     if (ignoreFilter && ignoreFilter(readOp))
       return failure();
 
+    if (readOp.mask())
+      return failure();
+
     // TODO: Support splitting TransferReadOp with non-identity permutation
     // maps. Repurpose code from MaterializeVectors transformation.
     if (!isIdentitySuffix(readOp.permutation_map()))
@@ -768,6 +775,9 @@ public:
     if (ignoreFilter && ignoreFilter(writeOp))
       return failure();
 
+    if (writeOp.mask())
+      return failure();
+
     // TODO: Support splitting TransferWriteOp with non-identity permutation
     // maps. Repurpose code from MaterializeVectors transformation.
     if (!isIdentitySuffix(writeOp.permutation_map()))
@@ -2546,6 +2556,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
          "Expected splitFullAndPartialTransferPrecondition to hold");
   auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
 
+  if (xferReadOp.mask())
+    return failure();
+
   // TODO: add support for write case.
   if (!xferReadOp)
     return failure();
@@ -2677,6 +2690,8 @@ struct TransferReadExtractPattern
         dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
     if (!extract)
       return failure();
+    if (read.mask())
+      return failure();
     edsc::ScopedContext scope(rewriter, read.getLoc());
     using mlir::edsc::op::operator+;
     using mlir::edsc::op::operator*;
@@ -2712,6 +2727,8 @@ struct TransferWriteInsertPattern
     auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
     if (!insert)
       return failure();
+    if (write.mask())
+      return failure();
     edsc::ScopedContext scope(rewriter, write.getLoc());
     using mlir::edsc::op::operator+;
     using mlir::edsc::op::operator*;
@@ -2742,6 +2759,7 @@ struct TransferWriteInsertPattern
 /// - If the memref's element type is a vector type then it coincides with the
 ///   result type.
 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
+/// - The op has no mask.
 struct TransferReadToVectorLoadLowering
     : public OpRewritePattern<vector::TransferReadOp> {
   TransferReadToVectorLoadLowering(MLIRContext *context)
@@ -2780,7 +2798,8 @@ struct TransferReadToVectorLoadLowering
     //       MaskedLoadOp.
     if (read.hasOutOfBoundsDim())
       return failure();
-
+    if (read.mask())
+      return failure();
     Operation *loadOp;
     if (!broadcastedDims.empty() &&
         unbroadcastedVectorType.getNumElements() == 1) {
@@ -2815,6 +2834,7 @@ struct TransferReadToVectorLoadLowering
 ///   type of the written value.
 /// - The permutation map is the minor identity map (neither permutation nor
 ///   broadcasting is allowed).
+/// - The op has no mask.
 struct TransferWriteToVectorStoreLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   TransferWriteToVectorStoreLowering(MLIRContext *context)
@@ -2840,6 +2860,8 @@ struct TransferWriteToVectorStoreLowering
     //       MaskedStoreOp.
     if (write.hasOutOfBoundsDim())
       return failure();
+    if (write.mask())
+      return failure();
     rewriter.replaceOpWithNewOp<vector::StoreOp>(
         write, write.vector(), write.source(), write.indices());
     return success();
@@ -2880,6 +2902,8 @@ struct TransferReadPermutationLowering
         map.getPermutationMap(permutation, op.getContext());
     if (permutationMap.isIdentity())
       return failure();
+    if (op.mask())
+      return failure();
     // Caluclate the map of the new read by applying the inverse permutation.
     permutationMap = inversePermutation(permutationMap);
     AffineMap newMap = permutationMap.compose(map);
@@ -2914,6 +2938,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.mask())
+      return failure();
     AffineMap map = op.permutation_map();
     unsigned numLeadingBroadcast = 0;
     for (auto expr : map.getResults()) {
@@ -3062,6 +3088,9 @@ struct CastAwayTransferReadLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    if (read.mask())
+      return failure();
+
     auto shapedType = read.source().getType().cast<ShapedType>();
     if (shapedType.getElementType() != read.getVectorType().getElementType())
       return failure();
@@ -3102,6 +3131,9 @@ struct CastAwayTransferWriteLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
+    if (write.mask())
+      return failure();
+
     auto shapedType = write.source().getType().dyn_cast<ShapedType>();
     if (shapedType.getElementType() != write.getVectorType().getElementType())
       return failure();
@@ -3371,6 +3403,151 @@ struct BubbleUpBitCastForStridedSliceInsert
   }
 };
 
+static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
+                                   Type targetType, Value value) {
+  if (targetType == value.getType())
+    return value;
+
+  bool targetIsIndex = targetType.isIndex();
+  bool valueIsIndex = value.getType().isIndex();
+  if (targetIsIndex ^ valueIsIndex)
+    return rewriter.create<IndexCastOp>(loc, targetType, value);
+
+  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
+  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
+  assert(targetIntegerType && valueIntegerType &&
+         "unexpected cast between types other than integers and index");
+  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
+
+  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
+    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
+  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
+}
+
+// Helper that returns a vector comparison that constructs a mask:
+//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
+//
+// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
+//       much more compact, IR for this operation, but LLVM eventually
+//       generates more elaborate instructions for this intrinsic since it
+//       is very conservative on the boundary conditions.
+static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
+                                   bool enableIndexOptimizations, int64_t dim,
+                                   Value b, Value *off = nullptr) {
+  auto loc = op->getLoc();
+  // If we can assume all indices fit in 32-bit, we perform the vector
+  // comparison in 32-bit to get a higher degree of SIMD parallelism.
+  // Otherwise we perform the vector comparison using 64-bit indices.
+  Value indices;
+  Type idxType;
+  if (enableIndexOptimizations) {
+    indices = rewriter.create<ConstantOp>(
+        loc, rewriter.getI32VectorAttr(
+                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
+    idxType = rewriter.getI32Type();
+  } else {
+    indices = rewriter.create<ConstantOp>(
+        loc, rewriter.getI64VectorAttr(
+                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
+    idxType = rewriter.getI64Type();
+  }
+  // Add in an offset if requested.
+  if (off) {
+    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
+    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
+    indices = rewriter.create<AddIOp>(loc, ov, indices);
+  }
+  // Construct the vector comparison.
+  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
+  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
+}
+
+template <typename ConcreteOp>
+struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
+public:
+  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
+      : mlir::OpRewritePattern<ConcreteOp>(context),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(ConcreteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (!xferOp.hasOutOfBoundsDim())
+      return failure();
+
+    if (xferOp.getVectorType().getRank() > 1 ||
+        llvm::size(xferOp.indices()) == 0)
+      return failure();
+
+    Location loc = xferOp->getLoc();
+    VectorType vtp = xferOp.getVectorType();
+
+    // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
+    // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+    // * Let dim the memref dimension, compute the vector comparison mask
+    //   (in-bounds mask):
+    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+    //
+    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
+    //       dimensions here.
+    unsigned vecWidth = vtp.getNumElements();
+    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
+    Value off = xferOp.indices()[lastIndex];
+    Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
+    Value mask = buildVectorComparison(
+        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
+
+    if (xferOp.mask()) {
+      // Intersect the in-bounds with the mask specified as an op parameter.
+      mask = rewriter.create<AndOp>(loc, mask, xferOp.mask());
+    }
+
+    rewriter.updateRootInPlace(xferOp, [&]() {
+      xferOp.maskMutable().assign(mask);
+      xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
+    });
+
+    return success();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
+/// Conversion pattern for a vector.create_mask (1-D only).
+class VectorCreateMaskOpConversion
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
+                                        bool enableIndexOpt)
+      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+    int64_t rank = dstType.getRank();
+    if (rank == 1) {
+      rewriter.replaceOp(
+          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
+                                    dstType.getDimSize(0), op.getOperand(0)));
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
+void mlir::vector::populateVectorMaskMaterializationPatterns(
+    RewritePatternSet &patterns, bool enableIndexOptimizations) {
+  patterns.add<VectorCreateMaskOpConversion,
+               MaterializeTransferMask<vector::TransferReadOp>,
+               MaterializeTransferMask<vector::TransferWriteOp>>(
+      patterns.getContext(), enableIndexOptimizations);
+}
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
index 249f8c0..c09b4ac 100644 (file)
@@ -3,20 +3,19 @@
 
 // CMP32-LABEL: @genbool_var_1d(
 // CMP32-SAME: %[[ARG:.*]]: index)
-// CMP32: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
 // CMP32: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
-// CMP32: %[[T1:.*]] = trunci %[[A]] : i64 to i32
+// CMP32: %[[T1:.*]] = index_cast %[[ARG]] : index to i32
 // CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32>
 // CMP32: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi32>
 // CMP32: return %[[T3]] : vector<11xi1>
 
 // CMP64-LABEL: @genbool_var_1d(
 // CMP64-SAME: %[[ARG:.*]]: index)
-// CMP64: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
 // CMP64: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>
-// CMP64: %[[T1:.*]] = splat %[[A]] : vector<11xi64>
-// CMP64: %[[T2:.*]] = cmpi slt, %[[T0]], %[[T1]] : vector<11xi64>
-// CMP64: return %[[T2]] : vector<11xi1>
+// CMP64: %[[T1:.*]] = index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64>
+// CMP64: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi64>
+// CMP64: return %[[T3]] : vector<11xi1>
 
 func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   %0 = vector.create_mask %arg0 : vector<11xi1>
index a5161b6..9faf7ca 100644 (file)
@@ -1049,31 +1049,31 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 // CHECK-LABEL: func @transfer_read_1d
 //  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
 //       CHECK: %[[c7:.*]] = constant 7.0
-//
-// 1. Bitcast to vector form.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-//       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
-//  CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
 //       CHECK: %[[C0:.*]] = constant 0 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 //       CHECK: %[[linearIndex:.*]] = constant dense
 //  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 //  CHECK-SAME: vector<17xi32>
 //
-// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 //       CHECK: %[[otrunc:.*]] = index_cast %[[BASE]] : index to i32
 //       CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32>
 //       CHECK: %[[offsetVec2:.*]] = addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
 //
-// 4. Let dim the memref dimension, compute the vector comparison mask:
+// 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 //       CHECK: %[[dtrunc:.*]] = index_cast %[[DIM]] : index to i32
 //       CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
 //       CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
 //
+// 4. Bitcast to vector form.
+//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+//  CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+//       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+//  CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
+//
 // 5. Rewrite as a masked read.
 //       CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
@@ -1081,26 +1081,26 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //  CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
 
 //
-// 1. Bitcast to vector form.
-//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-//       CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
-//  CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
-//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 //       CHECK: %[[linearIndex_b:.*]] = constant dense
 //  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 //  CHECK-SAME: vector<17xi32>
 //
-// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 //       CHECK: splat %{{.*}} : vector<17xi32>
 //       CHECK: addi
 //
-// 4. Let dim the memref dimension, compute the vector comparison mask:
+// 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 //       CHECK: splat %{{.*}} : vector<17xi32>
 //       CHECK: %[[mask_b:.*]] = cmpi slt, {{.*}} : vector<17xi32>
 //
+// 4. Bitcast to vector form.
+//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+//  CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+//       CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
+//  CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
+//
 // 5. Rewrite as a masked write.
 //       CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
 //  CHECK-SAME: {alignment = 4 : i32} :
@@ -1182,6 +1182,21 @@ func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector<17xf
 
 // -----
 
+// CHECK-LABEL: func @transfer_read_1d_mask
+// CHECK: %[[mask1:.*]] = constant dense<[false, false, true, false, true]>
+// CHECK: %[[cmpi:.*]] = cmpi slt
+// CHECK: %[[mask2:.*]] = and %[[cmpi]], %[[mask1]]
+// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
+// CHECK: return %[[r]]
+func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
+  %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<5xf32>
+  return %f: vector<5xf32>
+}
+
+// -----
+
 func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
   %c0 = constant 0: i32
   %v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} :
index 1e6f95a..43bef97 100644 (file)
@@ -11,6 +11,7 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %c0 = constant 0 : i32
   %vf0 = splat %f0 : vector<4x3xf32>
   %v0 = splat %c0 : vector<4x3xi32>
+  %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
 
   //
   // CHECK: vector.transfer_read
@@ -27,7 +28,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
   %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
-
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
+  %7 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
 
   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -39,7 +41,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
   vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
-
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : vector<5xf32>, memref<?x?xf32>
+  vector.transfer_write %7, %arg0[%c3, %c3], %m : vector<5xf32>, memref<?x?xf32>
   return
 }
 
index 5cd7d09..bed94f0 100644 (file)
@@ -12,6 +12,14 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) {
   return
 }
 
+func @transfer_read_mask_1d(%A : memref<?xf32>, %base: index) {
+  %fm42 = constant -42.0: f32
+  %m = constant dense<[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]> : vector<13xi1>
+  %f = vector.transfer_read %A[%base], %fm42, %m : memref<?xf32>, vector<13xf32>
+  vector.print %f: vector<13xf32>
+  return
+}
+
 func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base], %fm42
@@ -21,6 +29,15 @@ func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
   return
 }
 
+func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
+  %fm42 = constant -42.0: f32
+  %m = constant dense<[0, 1, 0, 1]> : vector<4xi1>
+  %f = vector.transfer_read %A[%base], %fm42, %m {in_bounds = [true]}
+      : memref<?xf32>, vector<4xf32>
+  vector.print %f: vector<4xf32>
+  return
+}
+
 func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4xf32>
@@ -47,6 +64,8 @@ func @entry() {
   // Read shifted by 2 and pad with -42:
   //   ( 2, 3, 4, -42, ..., -42)
   call @transfer_read_1d(%A, %c2) : (memref<?xf32>, index) -> ()
+  // Read with mask and out-of-bounds access.
+  call @transfer_read_mask_1d(%A, %c2) : (memref<?xf32>, index) -> ()
   // Write into memory shifted by 3
   //   memory contains [[ 0, 1, 2, 0, 0, xxx garbage xxx ]]
   call @transfer_write_1d(%A, %c3) : (memref<?xf32>, index) -> ()
@@ -56,9 +75,13 @@ func @entry() {
   // Read in-bounds 4 @ 1, guaranteed to not overflow.
   // Exercises proper alignment.
   call @transfer_read_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
+  // Read in-bounds with mask.
+  call @transfer_read_mask_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
   return
 }
 
 // CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
+// CHECK: ( -42, -42, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
 // CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
 // CHECK: ( 1, 2, 0, 0 )
+// CHECK: ( -42, 2, -42, 0 )