[mlir][Vector] Add support for masked vector.contract
authorDiego Caballero <diegocaballero@google.com>
Wed, 15 Feb 2023 05:46:15 +0000 (05:46 +0000)
committerDiego Caballero <diegocaballero@google.com>
Wed, 15 Feb 2023 06:10:22 +0000 (06:10 +0000)
This patch adds support for masking vector.contract ops with the
vector.mask approach. This also includes the lowering of vector.contract
through the vector.outerproduct path to LLVM. For now, this only adds
support for one of the many potential flavors of
vector.contract/vector.outerproduct but unsupported cases will fail
gratefully.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143965

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir

index b366cf8..6f6d80c 100644 (file)
@@ -91,6 +91,7 @@ def Vector_ContractionOp :
       PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
       PredOpTrait<"third operand acc and result have same element type",
                   TCresVTEtIsSameAsOpBase<0, 2>>,
+      DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
@@ -632,6 +633,7 @@ def Vector_ExtractOp :
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [
        Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
+       DeclareOpInterfaceMethods<MaskableOpInterface>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
      ] # ElementwiseMappable.traits>,
     Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
@@ -923,7 +925,8 @@ def Vector_OuterProductOp :
     PredOpTrait<"lhs operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     PredOpTrait<"rhs operand and result have same element type",
-                TCresVTEtIsSameAsOpBase<0, 1>>]>,
+                TCresVTEtIsSameAsOpBase<0, 1>>,
+    DeclareOpInterfaceMethods<MaskableOpInterface>]>,
     Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
                Variadic<AnyVector>:$acc,
                DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
index 68865d3..a1c1c3d 100644 (file)
@@ -1107,12 +1107,48 @@ public:
     VectorType vType = fmaOp.getVectorType();
     if (vType.getRank() > 1)
       return failure();
+
+    // Masked fmas are lowered separately.
+    auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
 };
 
+/// Conversion pattern that turns a masked vector.fma on a 1-D vector into their
+/// LLVM counterpart representation. Non side effecting VP intrinsics are not
+/// fully supported by some backends, including x86, and they don't support
+/// pass-through values either. For these reasons, we generate an unmasked
+/// fma followed by a select instrution to emulate the masking behavior.
+/// This pattern is peepholed by some backends with support for masked fma
+/// instructions. This pattern does not match vectors of n >= 2 rank.
+class MaskedFMAOp1DConversion
+    : public VectorMaskOpConversionBase<vector::FMAOp> {
+public:
+  using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
+
+  MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
+      : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
+
+  virtual LogicalResult matchAndRewriteMaskableOp(
+      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+      ConversionPatternRewriter &rewriter) const override {
+    auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
+    Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
+
+    Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
+        fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
+        fmaOp.getAcc());
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
+        maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
+    return success();
+  }
+};
+
 class VectorInsertElementOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
 public:
@@ -1279,6 +1315,11 @@ public:
     if (vType.getRank() < 2)
       return failure();
 
+    // Masked fmas are lowered separately.
+    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     auto loc = op.getLoc();
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
@@ -1707,9 +1748,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,
-           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
-           VectorInsertOpConversion, VectorPrintOpConversion,
-           VectorTypeCastOpConversion, VectorScaleOpConversion,
+           VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
+           VectorInsertElementOpConversion, VectorInsertOpConversion,
+           VectorPrintOpConversion, VectorTypeCastOpConversion,
+           VectorScaleOpConversion,
            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
            VectorLoadStoreConversion<vector::MaskedLoadOp,
                                      vector::MaskedLoadOpAdaptor>,
index 02cee8c..3e145f1 100644 (file)
@@ -889,6 +889,34 @@ LogicalResult ContractionOp::verify() {
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized."
+Type ContractionOp::getExpectedMaskType() {
+  auto indexingMaps = this->getIndexingMapsArray();
+  AffineMap lhsIdxMap = indexingMaps[0];
+  AffineMap rhsIdxMap = indexingMaps[1];
+  VectorType lhsType = this->getLhsType();
+  VectorType rhsType = this->getRhsType();
+
+  unsigned numVecDims = lhsIdxMap.getNumDims();
+  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+
+  // Using the information in the indexing maps, extract the size of each
+  // dimension in the vector.contract operation from the two input operands.
+  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+
+  assert(!ShapedType::isDynamicShape(maskShape) &&
+         "Mask shape couldn't be computed");
+
+  return VectorType::get(maskShape,
+                         IntegerType::get(lhsType.getContext(), /*width=*/1));
+}
+
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
   return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                 getIteratorTypesAttrName(), getKindAttrName()};
@@ -1760,6 +1788,16 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized."
+Type FMAOp::getExpectedMaskType() {
+  auto vecType = this->getVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -2762,6 +2800,16 @@ LogicalResult OuterProductOp::verify() {
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized."
+Type OuterProductOp::getExpectedMaskType() {
+  auto vecType = this->getVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
index c7aee6d..9976cf7 100644 (file)
@@ -147,13 +147,13 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
 }
 
 /// Helper to create arithmetic operation associated with a kind of contraction.
-static std::optional<Value> createContractArithOp(Location loc, Value x,
-                                                  Value y, Value acc,
-                                                  vector::CombiningKind kind,
-                                                  PatternRewriter &rewriter,
-                                                  bool isInt) {
+static std::optional<Value>
+createContractArithOp(Location loc, Value x, Value y, Value acc,
+                      vector::CombiningKind kind, PatternRewriter &rewriter,
+                      bool isInt, Optional<Value> maybeMask = std::nullopt) {
   using vector::CombiningKind;
   Value mul;
+
   if (isInt) {
     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
       // Only valid for floating point types.
@@ -169,11 +169,17 @@ static std::optional<Value> createContractArithOp(Location loc, Value x,
       return std::nullopt;
     // Special case for fused multiply-add.
     if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
-      return std::optional<Value>(
-          rewriter.create<vector::FMAOp>(loc, x, y, acc));
+      Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+      if (maybeMask.has_value() && maybeMask.value())
+        fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
+      return fmaOp->getResult(0);
     }
     mul = rewriter.create<arith::MulFOp>(loc, x, y);
   }
+
+  assert((!maybeMask.has_value() || !maybeMask.value()) &&
+         "Unsupported masked case");
+
   if (!acc)
     return std::optional<Value>(mul);
   return makeArithReduction(rewriter, loc, kind, mul, acc);
@@ -550,14 +556,27 @@ public:
     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
     vector::CombiningKind kind = op.getKind();
 
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    Operation *rootOp;
+    Value mask;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = op;
+    }
+
     if (!rhsType) {
       // Special case: AXPY operation.
       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
       std::optional<Value> mult = createContractArithOp(
-          loc, op.getLhs(), b, acc, kind, rewriter, isInt);
+          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
       if (!mult.has_value())
         return failure();
-      rewriter.replaceOp(op, *mult);
+      rewriter.replaceOp(rootOp, *mult);
       return success();
     }
 
@@ -571,13 +590,14 @@ public:
       Value r = nullptr;
       if (acc)
         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
-      std::optional<Value> m =
-          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
+      std::optional<Value> m = createContractArithOp(
+          loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
       if (!m.has_value())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
     }
-    rewriter.replaceOp(op, result);
+
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -601,7 +621,12 @@ struct ContractOpToElementwise
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO: implement masks
+    // TODO: Support vector.mask.
+    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
+    // TODO: Remove native masks from contraction op?
     if (!contractOp.getMasks().empty())
       return failure();
 
@@ -1429,7 +1454,12 @@ namespace mlir {
 LogicalResult
 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                  PatternRewriter &rew) const {
-  // TODO: implement masks
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
   if (vectorTransformOptions.vectorContractLowering !=
@@ -1525,10 +1555,16 @@ struct UnrolledOuterProductGenerator
   UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
       : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
-        res(op.getAcc()), lhsType(op.getLhsType()) {}
+        res(op.getAcc()), lhsType(op.getLhsType()) {
+    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      mask = maskableOp.getMaskingOp().getMask();
+  }
 
   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
+    if (!v)
+      return v;
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
 
@@ -1547,16 +1583,27 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
-  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+                             Optional<Value> maybeMask = std::nullopt) {
     assert(reductionSize > 0);
+    // Incremental support for masking.
+    if (mask && !maybeMask.has_value())
+      return failure();
+
     Type resElementType = res.getType().cast<VectorType>().getElementType();
     for (int64_t k = 0; k < reductionSize; ++k) {
       Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
       Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
       extractA = promote(extractA, resElementType);
       extractB = promote(extractB, resElementType);
-      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), extractA,
-                                             extractB, res, kind);
+      Value extractMask;
+      if (maybeMask.has_value() && maybeMask.value())
+        extractMask =
+            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
+
+      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
+          loc, res.getType(), extractA, extractB, res, kind);
+      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
     }
     return res;
   }
@@ -1607,7 +1654,7 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
@@ -1646,7 +1693,7 @@ struct UnrolledOuterProductGenerator
 
 private:
   vector::CombiningKind kind;
-  Value lhs, rhs, res;
+  Value lhs, rhs, res, mask;
   VectorType lhsType;
 };
 } // namespace
@@ -1668,7 +1715,7 @@ private:
 /// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
     vector::ContractionOp op, PatternRewriter &rewriter) const {
-  // TODO: implement masks
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
 
@@ -1679,20 +1726,31 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  Operation *rootOp;
+  if (maskableOp.isMasked()) {
+    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+    rootOp = maskableOp.getMaskingOp();
+  } else {
+    rootOp = op;
+  }
+
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(op, *matmatRes);
+    rewriter.replaceOp(rootOp, *matmatRes);
     return success();
   }
   FailureOr<Value> matvecRes = e.matvec();
   if (succeeded(matvecRes)) {
-    rewriter.replaceOp(op, *matvecRes);
+    rewriter.replaceOp(rootOp, *matvecRes);
     return success();
   }
   FailureOr<Value> tmatvecRes = e.tmatvec();
   if (succeeded(tmatvecRes)) {
-    rewriter.replaceOp(op, *tmatvecRes);
+    rewriter.replaceOp(rootOp, *tmatvecRes);
     return success();
   }
 
@@ -1702,7 +1760,12 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
 LogicalResult
 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
-  // TODO: implement masks
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
 
@@ -1834,7 +1897,12 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-  // TODO: implement masks.
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
 
index f38f799..a72ab15 100644 (file)
@@ -416,6 +416,18 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
 // CHECK:       %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK:       return %[[T19]] : vector<2x3xf32>
 
+// -----
+
+func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// We can't check for the intermediate 'vector.mask { vector.fma }' state so we
+// just make sure the vector.fma is lowered.
+
+// CHECK:           llvm.intr.fmuladd
+// CHECK:           llvm.select
 
 // -----
 
@@ -2145,3 +2157,17 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
   %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
   return %0 : vector<8xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @masked_vector_fma(
+// CHECK-SAME:                                 %[[INPUT:.*]]: vector<8xf32>,
+// CHECK-SAME:                                 %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32>
+// CHECK:           %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]])  : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+// CHECK:           llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32>
+
+func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> {
+  %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
index d4dc35d..1617cb1 100644 (file)
@@ -1196,3 +1196,27 @@ func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vec
   %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
   return %0 : f32
 }
+
+func.func @masked_vector_contract(%arg0: vector<2x3xf32>,
+                                  %arg1: vector<3xf32>,
+                                  %arg2: vector<2xf32>,
+                                  %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// OUTERPRODUCT-LABEL:   func.func @masked_vector_contract(
+// OUTERPRODUCT-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
+// OUTERPRODUCT-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// OUTERPRODUCT:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// OUTERPRODUCT:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK0]] { vector.outerproduct
+
+// OUTERPRODUCT:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK1]] { vector.outerproduct
+
+// OUTERPRODUCT:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK2]] { vector.outerproduct