[mlir][Vector] Remove lhs and rhs masks from vector.contract
authorDiego Caballero <diegocaballero@google.com>
Wed, 29 Mar 2023 19:16:12 +0000 (19:16 +0000)
committerDiego Caballero <diegocaballero@google.com>
Wed, 29 Mar 2023 19:53:29 +0000 (19:53 +0000)
This patch removes the historical lhs and rhs masks in vector.contract,
now that vector.mask supports vector.contract and the lhs and rhs masks
are barely supported by all the vector.contract lowerings and
transformations.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D144430

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir

index 539486f..0b844e1 100644 (file)
@@ -95,7 +95,6 @@ def Vector_ContractionOp :
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
-               Variadic<VectorOf<[I1]>>:$masks,
                ArrayAttr:$indexing_maps,
                Vector_IteratorTypeArrayAttr:$iterator_types,
                DefaultValuedAttr<Vector_CombiningKindAttr,
@@ -115,10 +114,6 @@ def Vector_ContractionOp :
     contraction. For integer types, only signless integer types are supported,
     and the promotion happens via sign extension.
 
-    Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
-    specify the dynamic dimension sizes of valid data within the lhs/rhs vector
-    arguments.
-
     An iterator type attribute list must be specified, where each element of
     the list represents an iterator with one of the following types:
 
@@ -195,17 +190,9 @@ def Vector_ContractionOp :
     %4 = vector.contract #contraction_trait %0, %1, %2
         : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
 
-    // 4D vector contraction with two contracting dimensions and optional
-    // vector mask arguments.
-    %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
-    %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
-
-    %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
-       : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
-
     // Vector contraction with mixed typed. lhs/rhs have different element
     // types than accumulator/result.
-    %6 = vector.contract #contraction_trait %0, %1, %2
+    %5 = vector.contract #contraction_trait %0, %1, %2
       : vector<10xf16>, vector<10xf16> into f32
 
     // Contract with max (K = 0).
@@ -219,7 +206,7 @@ def Vector_ContractionOp :
       iterator_types = ["reduction"],
       kind = #vector.kind<max>
     }
-    %7 = vector.contract #contraction_trait %0, %1, %2
+    %6 = vector.contract #contraction_trait %0, %1, %2
       : vector<10xf32>, vector<10xf32> into f32
     ```
   }];
@@ -241,14 +228,6 @@ def Vector_ContractionOp :
       return getRhs().getType().cast<VectorType>();
     }
     Type getAccType() { return getAcc().getType(); }
-    VectorType getLHSVectorMaskType() {
-      if (llvm::size(getMasks()) != 2) return VectorType();
-      return getOperand(3).getType().cast<VectorType>();
-    }
-    VectorType getRHSVectorMaskType() {
-      if (llvm::size(getMasks()) != 2) return VectorType();
-      return getOperand(4).getType().cast<VectorType>();
-    }
     Type getResultType() { return getResult().getType(); }
     SmallVector<StringRef> getTraitAttrNames();
     static unsigned getAccOperandIndex() { return 2; }
index cc98136..19860a1 100644 (file)
@@ -76,9 +76,6 @@ static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
 // Return true if the contract op can be convert to MMA matmul.
 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                           bool useNvGpu) {
-  if (!contract.getMasks().empty())
-    return false;
-
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
index ce7d184..abd8962 100644 (file)
@@ -729,8 +729,6 @@ void ContractionOp::print(OpAsmPrinter &p) {
   auto dictAttr = DictionaryAttr::get(getContext(), attrs);
   p << " " << dictAttr << " " << getLhs() << ", ";
   p << getRhs() << ", " << getAcc();
-  if (getMasks().size() == 2)
-    p << ", " << getMasks();
 
   p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
   p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
@@ -898,18 +896,6 @@ LogicalResult ContractionOp::verify() {
                                contractingDimMap, batchDimMap)))
     return failure();
 
-  // Verify that either two vector masks are set or none are set.
-  auto lhsMaskType = getLHSVectorMaskType();
-  auto rhsMaskType = getRHSVectorMaskType();
-  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
-    return emitOpError("invalid number of vector masks specified");
-  if (lhsMaskType && rhsMaskType) {
-    // Verify mask rank == argument rank.
-    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
-        rhsMaskType.getShape().size() != rhsType.getShape().size())
-      return emitOpError("invalid vector mask rank");
-  }
-
   // Verify supported combining kind.
   auto vectorType = resType.dyn_cast<VectorType>();
   auto elementType = vectorType ? vectorType.getElementType() : resType;
index f12884b..16751f8 100644 (file)
@@ -563,10 +563,6 @@ private:
 /// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
     vector::ContractionOp op, PatternRewriter &rewriter) const {
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
-
   if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::OuterProduct)
     return failure();
@@ -613,10 +609,6 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
   if (maskableOp.isMasked())
     return failure();
 
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
-
   if (failed(filter(op)))
     return failure();
 
@@ -749,10 +741,6 @@ struct ContractOpToElementwise
     if (maskableOp.isMasked())
       return failure();
 
-    // TODO: Remove native masks from contraction op?
-    if (!contractOp.getMasks().empty())
-      return failure();
-
     if (failed(filter(contractOp)))
       return failure();
 
@@ -872,10 +860,6 @@ private:
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
-
   if (failed(filter(op)))
     return failure();
 
@@ -1230,9 +1214,6 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
   if (maskableOp.isMasked())
     return failure();
 
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
   if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::Matmul)
     return failure();
index 722db6a..3715cf0 100644 (file)
@@ -292,9 +292,6 @@ struct CastAwayContractionLeadingOneDim
       return failure();
     if (oldAccType.getRank() < 2)
       return failure();
-    // TODO: implement masks.
-    if (!contractOp.getMasks().empty())
-      return failure();
     if (oldAccType.getShape()[0] != 1)
       return failure();
     // currently we support only dropping one dim but the pattern can be applied
index 20fc59e..ea83242 100644 (file)
@@ -1044,10 +1044,6 @@ struct CanonicalizeContractMatmulToMMT final
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Remove native masks from contraction op?
-    if (!op.getMasks().empty())
-      return failure();
-
     if (failed(filter(op)))
       return failure();
 
index 688c0a2..98577e3 100644 (file)
@@ -351,20 +351,12 @@ struct UnrollContractionPattern
       SmallVector<int64_t> lhsOffets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
       extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
-      // If there is a mask associated to lhs, extract it as well.
-      if (slicesOperands.size() > 3)
-        extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
-                       lhsOffets);
 
       // Extract the new rhs operand.
       AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
       SmallVector<int64_t> rhsOffets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
       extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
-      // If there is a mask associated to rhs, extract it as well.
-      if (slicesOperands.size() > 4)
-        extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
-                       rhsOffets);
 
       AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
       SmallVector<int64_t> accOffets =
index 4013d5d..aadbc14 100644 (file)
@@ -365,13 +365,6 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3
   // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
       : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
-  // Test contraction with optional vector mask arguments.
-  %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
-  %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
-  %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
-                                           %rhs_mask
-      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   // Test contraction with mixed type.
   // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
   %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
index 3f3e178..cb28279 100644 (file)
@@ -62,186 +62,6 @@ func.func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
   return %3: vector<4x4xf32>
 }
 
-#contraction_accesses0 = [
-  affine_map<(i, j, k) -> (i, k)>,
-  affine_map<(i, j, k) -> (k, j)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-#contraction_trait0 = {
-  indexing_maps = #contraction_accesses0,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL: func @contraction4x4_ijk
-
-// Reducing output vector [0, 0]
-
-//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-
-// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S6]], %[[S7]], %[[R1S00]], %[[S8]], %[[S9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S10:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S12:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S13:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S10]], %[[S11]], %[[R2S00]], %[[S12]], %[[S13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// Reducing output vector [0, 2]
-// CHECK-NEXT: %[[S14:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S17:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S15:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S18:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S14]], %[[S15]], %[[S16]], %[[S17]], %[[S18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S19:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S21:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S20:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S22:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S19]], %[[S20]], %[[R1S02]], %[[S21]], %[[S22]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S23:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S25:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S24:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S26:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S23]], %[[S24]], %[[R2S02]], %[[S25]], %[[S26]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// Reducing output vector [2, 0]
-
-// CHECK-NEXT: %[[S27:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S30:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S28:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S31:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S29:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S27]], %[[S28]], %[[S29]], %[[S30]], %[[S31]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S32:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S34:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S33:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S35:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT:  %[[R2S20:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S32]], %[[S33]], %[[R1S20]], %[[S34]], %[[S35]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S36:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S38:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S37:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S39:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT:  %[[R3S20:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S36]], %[[S37]], %[[R2S20]], %[[S38]], %[[S39]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// Reducing output vector [2, 2]
-
-// CHECK-NEXT: %[[S40:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S43:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S41:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S44:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S42:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S40]], %[[S41]], %[[S42]], %[[S43]], %[[S44]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S45:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S47:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S46:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S48:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S45]], %[[S46]], %[[R1S22]], %[[S47]], %[[S48]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[S49:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S51:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S50:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S52:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S49]], %[[S50]], %[[R2S22]], %[[S51]], %[[S52]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R3S02]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R3S20]], %[[VEC2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[R3S22]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-
-// CHECK-NEXT:  return %[[VEC4]] : vector<4x4xf32>
-
-func.func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
-                         %arg2 : vector<4x4xf32>, %arg3 : index)
-                         -> (vector<4x4xf32>) {
-  %lhsm = vector.constant_mask [4, 6] : vector<4x6xi1>
-  %rhsm = vector.constant_mask [6, 4] : vector<6x4xi1>
-  %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
-      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
-
-  return %0 : vector<4x4xf32>
-}
-
-#contraction_accesses1 = [
-  affine_map<(i, k, j) -> (i, k)>,
-  affine_map<(i, k, j) -> (k, j)>,
-  affine_map<(i, k, j) -> (i, j)>
-]
-#contraction_trait1 = {
-  indexing_maps = #contraction_accesses1,
-  iterator_types = ["parallel", "reduction", "parallel"]
-}
-
-// CHECK-LABEL: func @contraction4x4_ikj
-
-// Reducing output vector [0, 0]
-
-//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// Reducing output vector [0, 2]
-
-// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S10:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S6]], %[[S7]], %[[S8]], %[[S9]], %[[S10]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// Reducing output vector [2, 0]
-
-// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S14:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S12:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S15:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S13:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S11]], %[[S12]], %[[S13]], %[[S14]], %[[S15]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// Reducing output vector [2, 2]
-
-// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S19:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S17:.*]] = vector.extract_strided_slice %arg1 {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[S20:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[S18:.*]] = vector.extract_strided_slice %arg2 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S16]], %[[S17]], %[[S18]], %[[S19]], %[[S20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT:  return %[[VEC3]] : vector<4x4xf32>
-
-func.func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
-                         %arg2 : vector<4x4xf32>, %arg3 : index)
-                         -> (vector<4x4xf32>) {
-  %lhsm = vector.constant_mask [4, 2] : vector<4x2xi1>
-  %rhsm = vector.constant_mask [2, 4] : vector<2x4xi1>
-  %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
-      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
-
-  return %0 : vector<4x4xf32>
-}
-
 // CHECK-LABEL: func @contraction4x4_ikj_xfer_read
 
 // CHECK-DAG:      %[[C2:.*]] = arith.constant 2 : index
@@ -260,10 +80,10 @@ func.func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
 // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
 
-// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
 // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
@@ -271,6 +91,16 @@ func.func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
 // CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
 // CHECK-NEXT: return
 
+#contraction_accesses1 = [
+  affine_map<(i, k, j) -> (i, k)>,
+  affine_map<(i, k, j) -> (k, j)>,
+  affine_map<(i, k, j) -> (i, j)>
+]
+#contraction_trait1 = {
+  indexing_maps = #contraction_accesses1,
+  iterator_types = ["parallel", "reduction", "parallel"]
+}
+
 func.func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
                                    %arg1 : memref<2x4xf32>,
                                    %arg2 : memref<4x4xf32>) {
@@ -389,10 +219,10 @@ func.func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>)
 // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
 
-// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
 // CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>