// Returns the canonical string name for the "parallel" iterator type, used in
// the op's 'iterator_types' attribute.
static StringRef getParallelIteratorTypeName() {
return "parallel";
}
+
+ // Returns the bounds of each dimension in the iteration space spanned
+ // by the iterator types of this operation.
+ void getIterationBounds(SmallVectorImpl<int64_t> &iterationBounds);
+
+ // Returns a list of index maps, where there is a list entry for each
+ // op indexing map attribute (i.e. one for each input and output, with
+ // the output listed last). Each index map maps from this operation's
+ // iteration space to the vector dimensions of the map's input/output.
+ void getIterationIndexMap(
+ std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap);
+
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
}];
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
[(Constraint<HasShape<[4, 4]>> $a)]>;
+// TODO(andydavis) Add Constraints on lhs/rhs shapes.
+// Unrolls a vector.contract whose accumulator '$c' has shape 4x4 using a
+// 2x2x2 target shape (one factor per iteration dim). Note: '$masks', '$attr0'
+// and '$attr1' are matched but not forwarded to UnrollVectorOp.
+def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
+ (UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
+ [(Constraint<HasShape<[4, 4]>> $c)]>;
+
#endif // VECTOR_TRANSFORMS
return dimMap;
}
+// Populates 'iterationBounds' with one bound per iterator type, in iterator
+// order: reduction dims take their size from the lhs shape (rhs has the same
+// size for that dim), parallel dims take theirs from the result shape.
+void ContractionOp::getIterationBounds(
+ SmallVectorImpl<int64_t> &iterationBounds) {
+ auto lhsShape = getLhsType().getShape();
+ auto resShape = getResultType().getShape();
+ SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
+ for (auto it : llvm::enumerate(iterator_types())) {
+ // Search lhs/rhs map results for 'targetExpr'.
+ auto targetExpr = getAffineDimExpr(it.index(), getContext());
+ auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
+ if (iteratorTypeName == getReductionIteratorTypeName()) {
+ // Get reduction dim size from lhs shape (same size in rhsShape).
+ int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
+ assert(lhsDimIndex >= 0);
+ iterationBounds.push_back(lhsShape[lhsDimIndex]);
+ continue;
+ }
+ // Get parallel dimension size from result shape.
+ int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
+ assert(resDimIndex >= 0);
+ iterationBounds.push_back(resShape[resDimIndex]);
+ }
+}
+
+// Populates 'iterationIndexMap' with one DenseMap per indexing-map attribute
+// (inputs first, the single result last, matching attribute order). Each map
+// sends an iteration-space dim position to the result position holding that
+// dim within the corresponding indexing map.
+void ContractionOp::getIterationIndexMap(
+ std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
+ unsigned numMaps = indexing_maps().getValue().size();
+ iterationIndexMap.resize(numMaps);
+ for (auto it : llvm::enumerate(indexing_maps())) {
+ auto index = it.index();
+ auto map = it.value().cast<AffineMapAttr>().getValue();
+ // Record iteration-space dim -> map result position for every result.
+ // NOTE(review): assumes each map result is a plain AffineDimExpr; the
+ // cast below would assert otherwise.
+ for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+ auto dim = map.getResult(i).cast<AffineDimExpr>();
+ iterationIndexMap[index][dim.getPosition()] = i;
+ }
+ }
+}
+
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
return getDimMap(indexingMaps, iterator_types(),
return res;
}
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+/// Both arrays must have the same size; the result is the dot product
+/// sum_i(offsets[i] * basis[i]).
+static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+ assert(offsets.size() == basis.size());
+ int64_t linearIndex = 0;
+ for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+ linearIndex += offsets[idx] * basis[idx];
+ return linearIndex;
+}
+
/// Given a shape with sizes greater than 0 along all dimensions, returns the
/// delinearized components of linearIndex along shape.
static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
Location loc, Operation *op,
ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes) {
- OperationState *res = new OperationState(loc, op->getName().getStringRef(),
- operands, resultTypes, {});
- return builder.createOperation(*res);
+ OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
+ op->getAttrs());
+ return builder.createOperation(res);
}
// Helper function for Tablegen.
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
}
+// Returns a zero splat of vector type 'vt' by materializing a scalar zero
+// constant of the element type and splatting it. Supports bf16/f16/f32/f64
+// elements; aborts on any other element type.
+static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
+ VectorType vt) {
+ auto t = vt.getElementType();
+ Value *f = nullptr;
+ if (t.isBF16() || t.isF16())
+ // Use an f16 attribute so the attribute semantics match the
+ // half-precision constant type (was incorrectly getF64FloatAttr).
+ f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f));
+ else if (t.isF32())
+ f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f));
+ else if (t.isF64())
+ f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
+ if (f)
+ return rewriter.create<SplatOp>(loc, vt, f);
+ llvm_unreachable("Unsupported type in `makeSplatZero`");
+}
+
+// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
+// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
+// 'resultElements' must already be sized to indexMap.size(); unmapped result
+// slots are left untouched.
+static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
+ ArrayRef<int64_t> inputElements,
+ SmallVectorImpl<int64_t> &resultElements) {
+ assert(indexMap.size() == resultElements.size());
+ assert(inputElements.size() >= resultElements.size());
+ for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
+ auto it = indexMap.find(i);
+ if (it != indexMap.end())
+ resultElements[it->second] = inputElements[i];
+ }
+}
+
+// UnrolledOperandState aggregates per-operand state required for op unrolling.
+struct UnrolledOperandState {
+ // The vector-typed operand being unrolled.
+ Value *operand;
+ // Shape of each unrolled slice of 'operand'.
+ SmallVector<int64_t, 4> unrolledShape;
+ // Per-dim slice counts: operand shape divided by 'unrolledShape'.
+ SmallVector<int64_t, 4> unrollFactors;
+ // Strides for linearizing slice offsets (from computeStrides).
+ SmallVector<int64_t, 8> basis;
+ // Total number of unrolled slices of 'operand'.
+ int64_t numInstances;
+};
+
+// Populates 'state' with unrolled shape, unroll factors, basis and
+// num unrolled instances for 'operand'. 'indexMap' projects iteration-space
+// dims onto the dims of 'operand'; 'targetShape' gives the per-iteration-dim
+// target slice sizes.
+static void getUnrolledOperandState(Value *operand,
+ const DenseMap<int64_t, int64_t> &indexMap,
+ ArrayRef<int64_t> targetShape,
+ UnrolledOperandState &state) {
+ auto vectorType = operand->getType().cast<VectorType>();
+ state.operand = operand;
+ // Compute unrolled shape of 'operand'.
+ state.unrolledShape.resize(vectorType.getRank());
+ getMappedElements(indexMap, targetShape, state.unrolledShape);
+ // Compute unroll factors for unrolled shape.
+ // NOTE(review): assumes 'unrolledShape' evenly divides the operand's shape;
+ // otherwise shapeRatio returns None and the assert below fires.
+ auto maybeUnrollFactors =
+ shapeRatio(vectorType.getShape(), state.unrolledShape);
+ assert(maybeUnrollFactors.hasValue());
+ state.unrollFactors = *maybeUnrollFactors;
+ // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
+ state.basis = computeStrides(state.unrollFactors);
+ state.numInstances = computeMaxLinearIndex(state.unrollFactors);
+}
+
+// Computes and returns the linear index of the unrolled vector at
+// 'vectorOffsets' within the vector operand represented by 'state'.
+// 'vectorOffsets' are slice-granularity offsets over the full iteration
+// space; 'indexMap' selects the components relevant to this operand.
+static int64_t
+getUnrolledOperandLinearIndex(UnrolledOperandState &state,
+ ArrayRef<int64_t> vectorOffsets,
+ DenseMap<int64_t, int64_t> &indexMap) {
+ // Compute operand offsets.
+ SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
+ getMappedElements(indexMap, vectorOffsets, sliceOffsets);
+ // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
+ return linearize(sliceOffsets, state.basis);
+}
+
+// Returns an unrolled vector at 'vectorOffsets' within the vector operand
+// represented by 'state'. The value is created if not present in 'cache'.
+// 'vectorOffsets' are slice-granularity offsets (used for the cache key);
+// 'offsets' are the corresponding element-space offsets (used to slice).
+static Value *getOrCreateUnrolledOperandSlice(
+ Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
+ ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
+ SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
+ // Compute operand offsets.
+ SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
+ getMappedElements(indexMap, offsets, sliceOffsets);
+ // TODO(b/144845578) Support non-1 strides.
+ SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
+ // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
+ int64_t sliceLinearIndex =
+ getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
+ assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
+ auto *operandSlice = cache[sliceLinearIndex];
+ if (operandSlice == nullptr) {
+ // Initialize 'cache' with slice from 'state.operand'.
+ operandSlice = builder.create<vector::StridedSliceOp>(
+ loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
+ // Store value back to 'cache'.
+ cache[sliceLinearIndex] = operandSlice;
+ }
+ return operandSlice;
+}
+
+//
+// unrollSingleResultStructuredOp
+//
+// Returns a value representing the result of structured operation 'op'
+// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
+// An iteration space index map argument 'iterationIndexMapList' must be
+// specified, with a map for each structured op input and a single map for the
+// single result. The last map in the list must be the single result map.
+// Extra operands can be passed to unrolled instances of 'op' using the
+// 'extraOperands' argument.
+//
+// Example:
+//
+// // Before unrolling
+//
+// operand0 operand1 operand2
+// \ | /
+// -------------------- opA --------------------
+//
+// // After unrolling by 2
+//
+// operand0 operand1 operand2
+// / \ / \ / \
+// slice00 slice01 slice10 slice11 slice20 slice21
+// \ | | | / |
+// -------------------- opA0 -------------------- |
+// | | | |
+// \ | | /
+// -------------------- opA1 -------------------
+// | |
+// \ /
+// insertslice
+// |
+
+// TODO(andydavis) Generalize this to support structured ops beyond
+// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
+// Returns a value representing the single result of structured operation
+// 'op', unrolled to 'targetShape' over the iteration space described by
+// 'iterationBounds'. 'iterationIndexMapList' holds one iteration-space index
+// map per mapped operand, with the single result map listed last.
+// 'extraOperands' are appended unchanged to each unrolled instance of 'op'.
+static Value *unrollSingleResultStructuredOp(
+ Operation *op, ArrayRef<int64_t> iterationBounds,
+ std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
+ ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
+ PatternRewriter &builder) {
+ auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
+ // Direct assert instead of 'if (...) assert(false)': identical in debug
+ // builds and states the invariant explicitly.
+ assert(shapedType && shapedType.hasStaticShape() &&
+ "Expected a statically shaped result type");
+
+ // Compute unroll factors for 'iterationBounds' based on 'targetShape'
+ auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
+ assert(maybeUnrollFactors.hasValue() &&
+ "Failed to compute unroll factors for target shape");
+ auto unrollFactors = *maybeUnrollFactors;
+
+ // Compute unrolled operation state for each mapped operand.
+ unsigned numMaps = iterationIndexMapList.size();
+ SmallVector<UnrolledOperandState, 3> unrolledOperandState(numMaps);
+ assert(op->getNumOperands() >= numMaps);
+ for (unsigned i = 0; i < numMaps; ++i) {
+ getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
+ targetShape, unrolledOperandState[i]);
+ }
+ // Compute number of total unrolled instances.
+ auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
+ auto basis = computeStrides(unrollFactors);
+
+ // The result operand's map is last in 'iterationIndexMapList' by convention.
+ auto &resultOperandState = unrolledOperandState[numMaps - 1];
+ auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
+ shapedType.getElementType());
+
+ // Initialize caches for intermediate vector results.
+ std::vector<SmallVector<Value *, 4>> caches(numMaps);
+ for (unsigned i = 0; i < numMaps; ++i) {
+ caches[i].resize(unrolledOperandState[i].numInstances);
+ }
+
+ // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
+ // Use int64_t to match the type returned by computeMaxLinearIndex.
+ for (int64_t instance = 0; instance < numUnrolledInstances; ++instance) {
+ // De-linearize w.r.t. 'basis'.
+ auto vectorOffsets = delinearize(instance, basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+ vectorOffsets, targetShape);
+ // Get cached slice (or create slice) for each operand at 'offsets'.
+ SmallVector<Value *, 3> operands;
+ operands.reserve(numMaps);
+ // Distinct loop variable avoids shadowing the instance loop above.
+ for (unsigned mapIndex = 0; mapIndex < numMaps; ++mapIndex) {
+ operands.push_back(getOrCreateUnrolledOperandSlice(
+ op->getLoc(), unrolledOperandState[mapIndex], vectorOffsets, offsets,
+ iterationIndexMapList[mapIndex], caches[mapIndex], builder));
+ }
+ // Create op on sliced vector arguments.
+ operands.append(extraOperands.begin(), extraOperands.end());
+ auto resultVector =
+ cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
+ unrolledResultType)
+ ->getResult(0);
+
+ // Compute linear result index.
+ int64_t resultIndex = getUnrolledOperandLinearIndex(
+ resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
+ // Update result cache at 'resultIndex'.
+ caches[numMaps - 1][resultIndex] = resultVector;
+ }
+
+ // Make zero splat into which we will insert results from 'cache[numMaps - 1]'
+ auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
+ auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
+ SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
+ // Insert vector accumulators into output.
+ for (int64_t i = 0; i < resultOperandState.numInstances; ++i) {
+ auto vectorOffsets = delinearize(i, resultOperandState.basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+ vectorOffsets, resultOperandState.unrolledShape);
+ res = builder.create<vector::InsertStridedSliceOp>(
+ op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
+ }
+
+ return res;
+}
+
// Entry point for unrolling declarative pattern rewrites.
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
Operation *op,
ArrayRef<int64_t> targetShape) {
+ if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
+ // Get contraction op iteration bounds.
+ SmallVector<int64_t, 6> iterationBounds;
+ contractionOp.getIterationBounds(iterationBounds);
+ assert(iterationBounds.size() == targetShape.size());
+ // Get map from iteration space index to lhs/rhs/result shape index.
+ std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
+ contractionOp.getIterationIndexMap(iterationIndexMapList);
+ // TODO(andydavis) Support unrollable vector masks.
+ SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
+ contractionOp.masks().end());
+ // Unroll 'op' 'iterationBounds' to 'targetShape'.
+ return unrollSingleResultStructuredOp(op, iterationBounds,
+ iterationIndexMapList, targetShape,
+ masks, builder);
+ }
+ // TODO(andydavis) Create trivial iteration bounds and index map for
+ // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
+ // fakefork/join if possible.
+
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: unrollSingleResultOpMatchingType on func:\n");
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
}
};
-static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
- VectorType vt) {
- auto t = vt.getElementType();
- Value *f = nullptr;
- if (t.isBF16() || t.isF16())
- f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f))
- .getResult();
- else if (t.isF32())
- f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f))
- .getResult();
- else if (t.isF64())
- f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f))
- .getResult();
- if (f)
- return rewriter.create<SplatOp>(loc, vt, f).getResult();
- llvm_unreachable("Unsupported type in `makeSplatZero`");
-}
-
// Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
// vector.strided_slice ops.
struct ConvertFakeJoinOp : public RewritePattern {
%3 = addf %1, %2: vector<4x4xf32>
return %3: vector<4x4xf32>
}
+
+
+// Matmul-style contraction trait: i, j are parallel dims, k is the
+// reduction dim (iteration order i, j, k).
+#contraction_accesses0 = [
+ (i, j, k) -> (i, k),
+ (i, j, k) -> (k, j),
+ (i, j, k) -> (i, j)
+]
+#contraction_trait0 = {
+ indexing_maps = #contraction_accesses0,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @contraction4x4_ijk
+
+// Reducing output vector [0, 0]
+
+// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [0, 2]
+
+// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 0]
+
+// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 2]
+
+// CHECK-NEXT: %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S22]], %[[R1S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S42]], %[[R2S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Insert output vector slices into 4x4 vector result.
+// CHECK-NEXT: %[[RES0:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES1:.*]] = vector.insert_strided_slice %[[R3S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES2:.*]] = vector.insert_strided_slice %[[R3S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES3:.*]] = vector.insert_strided_slice %[[R3S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: return %[[RES3]] : vector<4x4xf32>
+
+// Contracts a 4x6 lhs with a 6x4 rhs into a 4x4 accumulator. The index
+// tuples built from %arg3 are placeholder masks (mask unrolling is TODO
+// per the pass implementation).
+func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
+ %arg2 : vector<4x4xf32>, %arg3 : index)
+ -> (vector<4x4xf32>) {
+
+ %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+ %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+ %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
+ : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
+
+ return %0 : vector<4x4xf32>
+}
+
+
+// Same contraction with iteration order (i, k, j): the reduction dim k is
+// second, exercising a different iterator ordering than trait0.
+#contraction_accesses1 = [
+ (i, k, j) -> (i, k),
+ (i, k, j) -> (k, j),
+ (i, k, j) -> (i, j)
+]
+#contraction_trait1 = {
+ indexing_maps = #contraction_accesses1,
+ iterator_types = ["parallel", "reduction", "parallel"]
+}
+
+// CHECK-LABEL: func @contraction4x4_ikj
+
+// Reducing output vector [0, 0]
+
+// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [0, 2]
+
+// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 0]
+
+// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 2]
+
+// CHECK-NEXT: %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Insert output vector slices into 4x4 vector result.
+// CHECK-NEXT: %[[RES0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: return %[[RES3]] : vector<4x4xf32>
+
+// Contracts a 4x2 lhs with a 2x4 rhs into a 4x4 accumulator; with a
+// reduction dim of size 2 and a 2x2x2 unroll, each output slice needs
+// a single contraction step (see CHECKs above).
+func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
+ %arg2 : vector<4x4xf32>, %arg3 : index)
+ -> (vector<4x4xf32>) {
+
+ %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+ %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+ %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
+ : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
+
+ return %0 : vector<4x4xf32>
+}