Adds support for unrolling single-result vector operations with iterator type lists...
authorAndy Davis <andydavis@google.com>
Wed, 4 Dec 2019 14:53:07 +0000 (06:53 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Dec 2019 14:53:37 +0000 (06:53 -0800)
Adds unit tests for unrolling the vector ContractionOp with different iteration orders.

PiperOrigin-RevId: 283747503

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/lib/Dialect/VectorOps/VectorToVector.cpp
mlir/test/Conversion/VectorConversions/vector-to-vector.mlir

index d34fa9a..36c26fe 100644 (file)
@@ -157,6 +157,18 @@ def Vector_ContractionOp :
     static StringRef getParallelIteratorTypeName() {
       return "parallel";
     }
+
+    // Returns the bounds of each dimension in the iteration space spanned
+    // by the iterator types of this operation.
+    void getIterationBounds(SmallVectorImpl<int64_t> &iterationBounds);
+
+    // Returns a list of index maps, where there is a list entry for each
+    // op indexing map attribute (i.e. one for each input and output, with
+    // the output listed last). Each index map, maps from this operations
+    // iteration space, to vector dimensions of the maps input/output.
+    void getIterationIndexMap(
+      std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap);
+
     std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
     std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
   }];
index fe0940c..e716796 100644 (file)
@@ -40,4 +40,9 @@ def : Pat<(AddFOp:$op_results $a, $b),
           (UnrollVectorOp<[2, 2]> $op_results, $a, $b),
           [(Constraint<HasShape<[4, 4]>> $a)]>;
 
+// TODO(andydavis) Add Constraints on lhs/rhs shapes.
+def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
+          (UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
+          [(Constraint<HasShape<[4, 4]>> $c)]>;
+
 #endif // VECTOR_TRANSFORMS
index 7f3be9d..ab457a6 100644 (file)
@@ -271,6 +271,44 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
   return dimMap;
 }
 
+void ContractionOp::getIterationBounds(
+    SmallVectorImpl<int64_t> &iterationBounds) {
+  auto lhsShape = getLhsType().getShape();
+  auto resShape = getResultType().getShape();
+  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
+  SmallVector<int64_t, 2> iterationShape;
+  for (auto it : llvm::enumerate(iterator_types())) {
+    // Search lhs/rhs map results for 'targetExpr'.
+    auto targetExpr = getAffineDimExpr(it.index(), getContext());
+    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
+    if (iteratorTypeName == getReductionIteratorTypeName()) {
+      // Get reduction dim size from lhs shape (same size in rhsShape).
+      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
+      assert(lhsDimIndex >= 0);
+      iterationBounds.push_back(lhsShape[lhsDimIndex]);
+      continue;
+    }
+    // Get parallel dimension size from result shape.
+    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
+    assert(resDimIndex >= 0);
+    iterationBounds.push_back(resShape[resDimIndex]);
+  }
+}
+
+void ContractionOp::getIterationIndexMap(
+    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
+  unsigned numMaps = indexing_maps().getValue().size();
+  iterationIndexMap.resize(numMaps);
+  for (auto it : llvm::enumerate(indexing_maps())) {
+    auto index = it.index();
+    auto map = it.value().cast<AffineMapAttr>().getValue();
+    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+      auto dim = map.getResult(i).cast<AffineDimExpr>();
+      iterationIndexMap[index][dim.getPosition()] = i;
+    }
+  }
+}
+
 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
   return getDimMap(indexingMaps, iterator_types(),
index 1e2e651..0952312 100644 (file)
@@ -77,6 +77,15 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   return res;
 }
 
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(offsets.size() == basis.size());
+  int64_t linearIndex = 0;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex += offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
 /// Given a shape with sizes greater than 0 along all dimensions, returns the
 /// delinearized components of linearIndex along shape.
 static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
@@ -151,9 +160,9 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
                                               Location loc, Operation *op,
                                               ArrayRef<Value *> operands,
                                               ArrayRef<Type> resultTypes) {
-  OperationState *res = new OperationState(loc, op->getName().getStringRef(),
-                                           operands, resultTypes, {});
-  return builder.createOperation(*res);
+  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
+                     op->getAttrs());
+  return builder.createOperation(res);
 }
 
 // Helper function for Tablegen.
@@ -164,6 +173,223 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
   return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
 }
 
+static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
+                            VectorType vt) {
+  auto t = vt.getElementType();
+  Value *f = nullptr;
+  if (t.isBF16() || t.isF16())
+    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
+  else if (t.isF32())
+    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f));
+  else if (t.isF64())
+    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
+  if (f)
+    return rewriter.create<SplatOp>(loc, vt, f);
+  llvm_unreachable("Unsupported type in `makeSplatZero`");
+}
+
+// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
+// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
+static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
+                              ArrayRef<int64_t> inputElements,
+                              SmallVectorImpl<int64_t> &resultElements) {
+  assert(indexMap.size() == resultElements.size());
+  assert(inputElements.size() >= resultElements.size());
+  for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
+    auto it = indexMap.find(i);
+    if (it != indexMap.end())
+      resultElements[it->second] = inputElements[i];
+  }
+}
+
+// UnrolledOperandState aggregates per-operand state required for op unrolling.
+struct UnrolledOperandState {
+  Value *operand;
+  SmallVector<int64_t, 4> unrolledShape;
+  SmallVector<int64_t, 4> unrollFactors;
+  SmallVector<int64_t, 8> basis;
+  int64_t numInstances;
+};
+
+// Populates 'state' with unrolled shape, unroll factors, basis and
+// num unrolled instances for 'operand'.
+static void getUnrolledOperandState(Value *operand,
+                                    const DenseMap<int64_t, int64_t> &indexMap,
+                                    ArrayRef<int64_t> targetShape,
+                                    UnrolledOperandState &state) {
+  auto vectorType = operand->getType().cast<VectorType>();
+  state.operand = operand;
+  // Compute unrolled shape of 'operand'.
+  state.unrolledShape.resize(vectorType.getRank());
+  getMappedElements(indexMap, targetShape, state.unrolledShape);
+  // Compute unroll factors for unrolled shape.
+  auto maybeUnrollFactors =
+      shapeRatio(vectorType.getShape(), state.unrolledShape);
+  assert(maybeUnrollFactors.hasValue());
+  state.unrollFactors = *maybeUnrollFactors;
+  // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
+  state.basis = computeStrides(state.unrollFactors);
+  state.numInstances = computeMaxLinearIndex(state.unrollFactors);
+}
+
+// Computes and returns the linear index of the unrolled vector at
+// 'vectorOffsets' within the vector operand represented by 'state'.
+static int64_t
+getUnrolledOperandLinearIndex(UnrolledOperandState &state,
+                              ArrayRef<int64_t> vectorOffsets,
+                              DenseMap<int64_t, int64_t> &indexMap) {
+  // Compute operand offsets.
+  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
+  getMappedElements(indexMap, vectorOffsets, sliceOffsets);
+  // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
+  return linearize(sliceOffsets, state.basis);
+}
+
+// Returns an unrolled vector at 'vectorOffsets' within the vector operand
+// represented by 'state'. The value is created if not present in 'cache'.
+static Value *getOrCreateUnrolledOperandSlice(
+    Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
+    ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
+    SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
+  // Compute operand offsets.
+  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
+  getMappedElements(indexMap, offsets, sliceOffsets);
+  // TODO(b/144845578) Support non-1 strides.
+  SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
+  // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
+  int64_t sliceLinearIndex =
+      getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
+  assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
+  auto *operandSlice = cache[sliceLinearIndex];
+  if (operandSlice == nullptr) {
+    // Initialize 'cache' with slice from 'state.operand'.
+    operandSlice = builder.create<vector::StridedSliceOp>(
+        loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
+    // Store value back to 'cache'.
+    cache[sliceLinearIndex] = operandSlice;
+  }
+  return operandSlice;
+}
+
+//
+// unrollSingleResultStructuredOp
+//
+// Returns a value representing the result of structured operation 'op'
+// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
+// An iteration space index map argument 'iterationIndexMapList' must be
+// specified, with a map for each structured op input and a single map for the
+// single result. The last map in the list must be the single result map.
+// Extra operands can be passed to unrolled instances of 'op' using the
+// 'extraOperands' argument.
+//
+// Example:
+//
+//  // Before unrolling
+//
+//   operand0                operand1                operand2
+//       \                      |                      /
+//        -------------------- opA --------------------
+//
+//  // After unrolling by 2
+//
+//   operand0                operand1                operand2
+//   /      \                /      \                /      \
+// slice00  slice01       slice10  slice11        slice20  slice21
+//   \         |            |          |            /          |
+//    -------------------- opA0 --------------------           |
+//             |            |          |                       |
+//              \           |          |                      /
+//               -------------------- opA1 -------------------
+//                          |          |
+//                           \        /
+//                           insertslice
+//                                |
+
+// TODO(andydavis) Generalize this to support structured ops beyond
+// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
+static Value *unrollSingleResultStructuredOp(
+    Operation *op, ArrayRef<int64_t> iterationBounds,
+    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
+    ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
+    PatternRewriter &builder) {
+  auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
+  if (!shapedType || !shapedType.hasStaticShape())
+    assert(false && "Expected a statically shaped result type");
+
+  // Compute unroll factors for 'iterationBounds' based on 'targetShape'
+  auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
+  if (!maybeUnrollFactors.hasValue())
+    assert(false && "Failed to compute unroll factors for target shape");
+  auto unrollFactors = *maybeUnrollFactors;
+
+  // Compute unrolled operation state for each mapped operand.
+  unsigned numMaps = iterationIndexMapList.size();
+  SmallVector<UnrolledOperandState, 3> unrolledOperandState(numMaps);
+  assert(op->getNumOperands() >= numMaps);
+  for (unsigned i = 0; i < numMaps; ++i) {
+    getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
+                            targetShape, unrolledOperandState[i]);
+  }
+  // Compute number of total unrolled instances.
+  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
+  auto basis = computeStrides(unrollFactors);
+
+  auto &resultOperandState = unrolledOperandState[numMaps - 1];
+  auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
+                                            shapedType.getElementType());
+
+  // Initialize caches for intermediate vector results.
+  std::vector<SmallVector<Value *, 4>> caches(numMaps);
+  for (unsigned i = 0; i < numMaps; ++i) {
+    caches[i].resize(unrolledOperandState[i].numInstances);
+  }
+
+  // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
+  for (unsigned i = 0; i < numUnrolledInstances; ++i) {
+    // De-linearize w.r.t. 'basis'.
+    auto vectorOffsets = delinearize(i, basis);
+    // Convert from unrolled vector-space offsets to element-space offsets.
+    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+                          vectorOffsets, targetShape);
+    // Get cached slice (or create slice) for each operand at 'offsets'.
+    SmallVector<Value *, 3> operands;
+    operands.reserve(numMaps);
+    for (unsigned i = 0; i < numMaps; ++i) {
+      operands.push_back(getOrCreateUnrolledOperandSlice(
+          op->getLoc(), unrolledOperandState[i], vectorOffsets, offsets,
+          iterationIndexMapList[i], caches[i], builder));
+    }
+    // Create op on sliced vector arguments.
+    operands.append(extraOperands.begin(), extraOperands.end());
+    auto resultVector =
+        cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
+                                    unrolledResultType)
+            ->getResult(0);
+
+    // Compute linear result index.
+    int64_t resultIndex = getUnrolledOperandLinearIndex(
+        resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
+    // Update result cache at 'resultIndex'.
+    caches[numMaps - 1][resultIndex] = resultVector;
+  }
+
+  // Make zero splat into which we will insert results from 'cache[numMaps - 1]'
+  auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
+  auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
+  SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
+  // Insert vector accumulators into output.
+  for (unsigned i = 0; i < resultOperandState.numInstances; ++i) {
+    auto vectorOffsets = delinearize(i, resultOperandState.basis);
+    // Convert from unrolled vector-space offsets to element-space offsets.
+    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+                          vectorOffsets, resultOperandState.unrolledShape);
+    res = builder.create<vector::InsertStridedSliceOp>(
+        op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
+  }
+
+  return res;
+}
+
 // Entry point for unrolling declarative pattern rewrites.
 // `op` is unrolled to the `targetShape` as follows, for each of its operands:
 //   1. the unrolled type `unrolledVectorType` and number of unrolled instances
@@ -200,6 +426,26 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
 Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
                                                Operation *op,
                                                ArrayRef<int64_t> targetShape) {
+  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
+    // Get contraction op iteration bounds.
+    SmallVector<int64_t, 6> iterationBounds;
+    contractionOp.getIterationBounds(iterationBounds);
+    assert(iterationBounds.size() == targetShape.size());
+    // Get map from iteration space index to lhs/rhs/result shape index.
+    std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
+    contractionOp.getIterationIndexMap(iterationIndexMapList);
+    // TODO(andydavis) Support unrollable vector masks.
+    SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
+                                  contractionOp.masks().end());
+    // Unroll 'op' 'iterationBounds' to 'targetShape'.
+    return unrollSingleResultStructuredOp(op, iterationBounds,
+                                          iterationIndexMapList, targetShape,
+                                          masks, builder);
+  }
+  // TODO(andydavis) Create trivial iteration bounds and index map for
+  // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
+  // fakefork/join if possible.
+
   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                        "]: unrollSingleResultOpMatchingType on func:\n");
   LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
@@ -365,24 +611,6 @@ struct ConvertFakeForkFromBlockArgsOp : public RewritePattern {
   }
 };
 
-static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
-                            VectorType vt) {
-  auto t = vt.getElementType();
-  Value *f = nullptr;
-  if (t.isBF16() || t.isF16())
-    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f))
-            .getResult();
-  else if (t.isF32())
-    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f))
-            .getResult();
-  else if (t.isF64())
-    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f))
-            .getResult();
-  if (f)
-    return rewriter.create<SplatOp>(loc, vt, f).getResult();
-  llvm_unreachable("Unsupported type in `makeSplatZero`");
-}
-
 // Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
 // vector.strided_slice ops.
 struct ConvertFakeJoinOp : public RewritePattern {
index dcd611d..1d9331e 100644 (file)
@@ -47,3 +47,131 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
   %3 = addf %1, %2: vector<4x4xf32>
   return %3: vector<4x4xf32>
 }
+
+
+#contraction_accesses0 = [
+  (i, j, k) -> (i, k),
+  (i, j, k) -> (k, j),
+  (i, j, k) -> (i, j)
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @contraction4x4_ijk
+
+// Reducing output vector [0, 0]
+
+// CHECK:       %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [0, 2]
+
+// CHECK-NEXT:  %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 0]
+
+// CHECK-NEXT:  %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 2]
+
+// CHECK-NEXT:  %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S22]], %[[R1S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S42]], %[[R2S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Insert output vector slices into 4x4 vector result.
+// CHECK-NEXT:  %[[RES0:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[RES1:.*]] = vector.insert_strided_slice %[[R3S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[RES2:.*]] = vector.insert_strided_slice %[[R3S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[RES3:.*]] = vector.insert_strided_slice %[[R3S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  return %[[RES3]] : vector<4x4xf32>
+
+func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
+                         %arg2 : vector<4x4xf32>, %arg3 : index)
+                         -> (vector<4x4xf32>) {
+
+  %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
+      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
+
+  return %0 : vector<4x4xf32>
+}
+
+
+#contraction_accesses1 = [
+  (i, k, j) -> (i, k),
+  (i, k, j) -> (k, j),
+  (i, k, j) -> (i, j)
+]
+#contraction_trait1 = {
+  indexing_maps = #contraction_accesses1,
+  iterator_types = ["parallel", "reduction", "parallel"]
+}
+
+// CHECK-LABEL: func @contraction4x4_ikj
+
+// Reducing output vector [0, 0]
+
+// CHECK:       %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [0, 2]
+
+// CHECK-NEXT:  %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 0]
+
+// CHECK-NEXT:  %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 2]
+
+// CHECK-NEXT:  %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Insert output vector slices into 4x4 vector result.
+// CHECK-NEXT:  %[[RES0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[RES1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[RES2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[RES3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  return %[[RES3]] : vector<4x4xf32>
+
+func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
+                         %arg2 : vector<4x4xf32>, %arg3 : index)
+                         -> (vector<4x4xf32>) {
+
+  %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
+      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
+
+  return %0 : vector<4x4xf32>
+}