From: Andy Davis
Date: Wed, 4 Dec 2019 14:53:07 +0000 (-0800)
Subject: Adds support for unrolling single-result vector operations with iterator type lists...
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=34e1f4aa510ea62155b9d2ab4e810a55ad6f4c5b;p=platform%2Fupstream%2Fllvm.git

Adds support for unrolling single-result vector operations with iterator type lists
and indexing maps to a target vector size.

Adds unit tests for unrolling the vector ContractionOp with different iteration orders.

PiperOrigin-RevId: 283747503
---

diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index d34fa9a..36c26fe 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -157,6 +157,18 @@ def Vector_ContractionOp :
     static StringRef getParallelIteratorTypeName() {
       return "parallel";
     }
+
+    // Returns the bounds of each dimension in the iteration space spanned
+    // by the iterator types of this operation.
+    void getIterationBounds(SmallVectorImpl<int64_t> &iterationBounds);
+
+    // Returns a list of index maps, where there is a list entry for each
+    // op indexing map attribute (i.e. one for each input and output, with
+    // the output listed last). Each index map maps from this operation's
+    // iteration space to the vector dimensions of the map's input/output.
+    void getIterationIndexMap(
+        std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap);
+
     std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
     std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
   }];
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
index fe0940c..e7167962 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
@@ -40,4 +40,9 @@ def : Pat<(AddFOp:$op_results $a, $b),
           (UnrollVectorOp<[2, 2]> $op_results, $a, $b),
           [(Constraint<HasShape<[4, 4]>> $a)]>;
 
+// TODO(andydavis) Add Constraints on lhs/rhs shapes.
+def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
+          (UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
+          [(Constraint<HasShape<[4, 4]>> $c)]>;
+
 #endif // VECTOR_TRANSFORMS
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 7f3be9d..ab457a6 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -271,6 +271,44 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
   return dimMap;
 }
 
+void ContractionOp::getIterationBounds(
+    SmallVectorImpl<int64_t> &iterationBounds) {
+  auto lhsShape = getLhsType().getShape();
+  auto resShape = getResultType().getShape();
+  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
+  SmallVector<int64_t, 4> iterationShape;
+  for (auto it : llvm::enumerate(iterator_types())) {
+    // Search lhs/rhs map results for 'targetExpr'.
+    auto targetExpr = getAffineDimExpr(it.index(), getContext());
+    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
+    if (iteratorTypeName == getReductionIteratorTypeName()) {
+      // Get reduction dim size from lhs shape (same size in rhsShape).
+      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
+      assert(lhsDimIndex >= 0);
+      iterationBounds.push_back(lhsShape[lhsDimIndex]);
+      continue;
+    }
+    // Get parallel dimension size from result shape.
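+    // For example, given the maps [(i, j, k) -> (i, k), (i, j, k) -> (k, j),
+    // (i, j, k) -> (i, j)] on a vector<4x6xf32> x vector<6x4xf32> ->
+    // vector<4x4xf32> contraction (see the unit tests below), the parallel
+    // bounds i = 4 and j = 4 come from the result shape, and the reduction
+    // bound k = 6 comes from the lhs shape.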
+    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
+    assert(resDimIndex >= 0);
+    iterationBounds.push_back(resShape[resDimIndex]);
+  }
+}
+
+void ContractionOp::getIterationIndexMap(
+    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
+  unsigned numMaps = indexing_maps().getValue().size();
+  iterationIndexMap.resize(numMaps);
+  for (auto it : llvm::enumerate(indexing_maps())) {
+    auto index = it.index();
+    auto map = it.value().cast<AffineMapAttr>().getValue();
+    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+      auto dim = map.getResult(i).cast<AffineDimExpr>();
+      iterationIndexMap[index][dim.getPosition()] = i;
+    }
+  }
+}
+
 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
   return getDimMap(indexingMaps, iterator_types(),
diff --git a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp b/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
index 1e2e651..0952312 100644
--- a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
@@ -77,6 +77,15 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   return res;
 }
 
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(offsets.size() == basis.size());
+  int64_t linearIndex = 0;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex += offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
 /// Given a shape with sizes greater than 0 along all dimensions, returns the
 /// delinearized components of linearIndex along shape.
 static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
@@ -151,9 +160,9 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
                                               Location loc, Operation *op,
                                               ArrayRef<Value *> operands,
                                               ArrayRef<Type> resultTypes) {
-  OperationState *res = new OperationState(loc, op->getName().getStringRef(),
-                                           operands, resultTypes, {});
-  return builder.createOperation(*res);
+  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
+                     op->getAttrs());
+  return builder.createOperation(res);
 }
 
 // Helper function for Tablegen.
@@ -164,6 +173,223 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
   return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
 }
 
+static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
+                            VectorType vt) {
+  auto t = vt.getElementType();
+  Value *f = nullptr;
+  if (t.isBF16() || t.isF16())
+    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f));
+  else if (t.isF32())
+    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f));
+  else if (t.isF64())
+    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
+  if (f)
+    return rewriter.create<SplatOp>(loc, vt, f);
+  llvm_unreachable("Unsupported type in `makeSplatZero`");
+}
+
+// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
+// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
+static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
+                              ArrayRef<int64_t> inputElements,
+                              SmallVectorImpl<int64_t> &resultElements) {
+  assert(indexMap.size() == resultElements.size());
+  assert(inputElements.size() >= resultElements.size());
+  for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
+    auto it = indexMap.find(i);
+    if (it != indexMap.end())
+      resultElements[it->second] = inputElements[i];
+  }
+}
+
+// UnrolledOperandState aggregates per-operand state required for op unrolling.
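+// For example, a vector<4x6xf32> lhs operand unrolled to 2x2 slices yields
+// unrolledShape = [2, 2], unrollFactors = [2, 3], basis (strides over the
+// unroll factors) = [3, 1], and numInstances = 6.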
+struct UnrolledOperandState {
+  Value *operand;
+  SmallVector<int64_t, 4> unrolledShape;
+  SmallVector<int64_t, 4> unrollFactors;
+  SmallVector<int64_t, 8> basis;
+  int64_t numInstances;
+};
+
+// Populates 'state' with the unrolled shape, unroll factors, basis, and
+// number of unrolled instances for 'operand'.
+static void getUnrolledOperandState(Value *operand,
+                                    const DenseMap<int64_t, int64_t> &indexMap,
+                                    ArrayRef<int64_t> targetShape,
+                                    UnrolledOperandState &state) {
+  auto vectorType = operand->getType().cast<VectorType>();
+  state.operand = operand;
+  // Compute unrolled shape of 'operand'.
+  state.unrolledShape.resize(vectorType.getRank());
+  getMappedElements(indexMap, targetShape, state.unrolledShape);
+  // Compute unroll factors for unrolled shape.
+  auto maybeUnrollFactors =
+      shapeRatio(vectorType.getShape(), state.unrolledShape);
+  assert(maybeUnrollFactors.hasValue());
+  state.unrollFactors = *maybeUnrollFactors;
+  // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
+  state.basis = computeStrides(state.unrollFactors);
+  state.numInstances = computeMaxLinearIndex(state.unrollFactors);
+}
+
+// Computes and returns the linear index of the unrolled vector at
+// 'vectorOffsets' within the vector operand represented by 'state'.
+static int64_t
+getUnrolledOperandLinearIndex(UnrolledOperandState &state,
+                              ArrayRef<int64_t> vectorOffsets,
+                              DenseMap<int64_t, int64_t> &indexMap) {
+  // Compute operand offsets.
+  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
+  getMappedElements(indexMap, vectorOffsets, sliceOffsets);
+  // Compute and return linear index of 'sliceOffsets' w.r.t. 'state.basis'.
+  return linearize(sliceOffsets, state.basis);
+}
+
+// Returns an unrolled vector at 'vectorOffsets' within the vector operand
+// represented by 'state'. The value is created if not present in 'cache'.
+static Value *getOrCreateUnrolledOperandSlice(
+    Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
+    ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
+    SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
+  // Compute operand offsets.
+  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
+  getMappedElements(indexMap, offsets, sliceOffsets);
+  // TODO(b/144845578) Support non-1 strides.
+  SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
+  // Compute linear index of 'sliceOffsets' w.r.t. 'state.basis'.
+  int64_t sliceLinearIndex =
+      getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
+  assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
+  auto *operandSlice = cache[sliceLinearIndex];
+  if (operandSlice == nullptr) {
+    // Initialize 'cache' with slice from 'state.operand'.
+    operandSlice = builder.create<vector::StridedSliceOp>(
+        loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
+    // Store value back to 'cache'.
+    cache[sliceLinearIndex] = operandSlice;
+  }
+  return operandSlice;
+}
+
+//
+// unrollSingleResultStructuredOp
+//
+// Returns a value representing the result of structured operation 'op'
+// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
+// An iteration space index map argument 'iterationIndexMapList' must be
+// specified, with a map for each structured op input and a single map for the
+// single result. The last map in the list must be the single result map.
+// Extra operands can be passed to unrolled instances of 'op' using the
+// 'extraOperands' argument.
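+// (For ContractionOp below, 'extraOperands' carries the op's mask operands,
+// which are forwarded unchanged to each unrolled instance.)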
+//
+// Example:
+//
+//  // Before unrolling
+//
+//   operand0         operand1         operand2
+//       \                |                /
+//        -------------------- opA --------------------
+//
+//  // After unrolling by 2
+//
+//   operand0         operand1         operand2
+//   /      \         /      \         /      \
+// slice00  slice01 slice10  slice11 slice20  slice21
+//   \         |       |        |       /        |
+//    -------------------- opA0 --------------------    |
+//             |            |           |               |
+//              \           |           |              /
+//               -------------------- opA1 -------------------
+//                          |                 |
+//                           \               /
+//                            insertslice
+//                                 |
+
+// TODO(andydavis) Generalize this to support structured ops beyond
+// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'.
+static Value *unrollSingleResultStructuredOp(
+    Operation *op, ArrayRef<int64_t> iterationBounds,
+    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
+    ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
+    PatternRewriter &builder) {
+  auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
+  if (!shapedType || !shapedType.hasStaticShape())
+    assert(false && "Expected a statically shaped result type");
+
+  // Compute unroll factors for 'iterationBounds' based on 'targetShape'.
+  auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
+  if (!maybeUnrollFactors.hasValue())
+    assert(false && "Failed to compute unroll factors for target shape");
+  auto unrollFactors = *maybeUnrollFactors;
+
+  // Compute unrolled operation state for each mapped operand.
+  unsigned numMaps = iterationIndexMapList.size();
+  SmallVector<UnrolledOperandState, 4> unrolledOperandState(numMaps);
+  assert(op->getNumOperands() >= numMaps);
+  for (unsigned i = 0; i < numMaps; ++i) {
+    getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
+                            targetShape, unrolledOperandState[i]);
+  }
+  // Compute number of total unrolled instances.
+  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
+  auto basis = computeStrides(unrollFactors);
+
+  auto &resultOperandState = unrolledOperandState[numMaps - 1];
+  auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
+                                            shapedType.getElementType());
+
+  // Initialize caches for intermediate vector results.
+  std::vector<SmallVector<Value *, 4>> caches(numMaps);
+  for (unsigned i = 0; i < numMaps; ++i) {
+    caches[i].resize(unrolledOperandState[i].numInstances);
+  }
+
+  // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
+  for (unsigned i = 0; i < numUnrolledInstances; ++i) {
+    // De-linearize w.r.t. 'basis'.
+    auto vectorOffsets = delinearize(i, basis);
+    // Convert from unrolled vector-space offsets to element-space offsets.
+    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+                          vectorOffsets, targetShape);
+    // Get cached slice (or create slice) for each operand at 'offsets'.
+    SmallVector<Value *, 4> operands;
+    operands.reserve(numMaps);
+    for (unsigned j = 0; j < numMaps; ++j) {
+      operands.push_back(getOrCreateUnrolledOperandSlice(
+          op->getLoc(), unrolledOperandState[j], vectorOffsets, offsets,
+          iterationIndexMapList[j], caches[j], builder));
+    }
+    // Create op on sliced vector arguments.
+    operands.append(extraOperands.begin(), extraOperands.end());
+    auto resultVector =
+        cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
+                                    unrolledResultType)
+            ->getResult(0);
+
+    // Compute linear result index.
+    int64_t resultIndex = getUnrolledOperandLinearIndex(
+        resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
+    // Update result cache at 'resultIndex'.
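+    // For a contraction, a later unrolled instance that maps to the same
+    // result index reads this cached slice as its accumulator operand,
+    // forming a reduction chain (see the CHECK lines in the unit tests).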
+    caches[numMaps - 1][resultIndex] = resultVector;
+  }
+
+  // Make a zero splat into which we will insert results from
+  // 'caches[numMaps - 1]'.
+  auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
+  auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
+  SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
+  // Insert vector accumulators into output.
+  for (unsigned i = 0; i < resultOperandState.numInstances; ++i) {
+    auto vectorOffsets = delinearize(i, resultOperandState.basis);
+    // Convert from unrolled vector-space offsets to element-space offsets.
+    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+                          vectorOffsets, resultOperandState.unrolledShape);
+    res = builder.create<vector::InsertStridedSliceOp>(
+        op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
+  }
+
+  return res;
+}
+
 // Entry point for unrolling declarative pattern rewrites.
 // `op` is unrolled to the `targetShape` as follows, for each of its operands:
 // 1. the unrolled type `unrolledVectorType` and number of unrolled instances
@@ -200,6 +426,26 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
 Value *
 mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
                                                Operation *op,
                                                ArrayRef<int64_t> targetShape) {
+  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
+    // Get contraction op iteration bounds.
+    SmallVector<int64_t, 4> iterationBounds;
+    contractionOp.getIterationBounds(iterationBounds);
+    assert(iterationBounds.size() == targetShape.size());
+    // Get map from iteration space index to lhs/rhs/result shape index.
+    std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
+    contractionOp.getIterationIndexMap(iterationIndexMapList);
+    // TODO(andydavis) Support unrollable vector masks.
+    SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
+                                  contractionOp.masks().end());
+    // Unroll 'op' from 'iterationBounds' to 'targetShape'.
+    return unrollSingleResultStructuredOp(op, iterationBounds,
+                                          iterationIndexMapList, targetShape,
+                                          masks, builder);
+  }
+  // TODO(andydavis) Create trivial iteration bounds and index map for
+  // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
+  // fakefork/join if possible.
+
   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                        "]: unrollSingleResultOpMatchingType on func:\n");
   LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
@@ -365,24 +611,6 @@ struct ConvertFakeForkFromBlockArgsOp : public RewritePattern {
   }
 };
 
-static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
-                            VectorType vt) {
-  auto t = vt.getElementType();
-  Value *f = nullptr;
-  if (t.isBF16() || t.isF16())
-    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f))
-            .getResult();
-  else if (t.isF32())
-    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f))
-            .getResult();
-  else if (t.isF64())
-    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f))
-            .getResult();
-  if (f)
-    return rewriter.create<SplatOp>(loc, vt, f).getResult();
-  llvm_unreachable("Unsupported type in `makeSplatZero`");
-}
-
 // Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
 // vector.strided_slice ops.
 struct ConvertFakeJoinOp : public RewritePattern {
diff --git a/mlir/test/Conversion/VectorConversions/vector-to-vector.mlir b/mlir/test/Conversion/VectorConversions/vector-to-vector.mlir
index dcd611d..1d9331e 100644
--- a/mlir/test/Conversion/VectorConversions/vector-to-vector.mlir
+++ b/mlir/test/Conversion/VectorConversions/vector-to-vector.mlir
@@ -47,3 +47,131 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
   %3 = addf %1, %2: vector<4x4xf32>
   return %3: vector<4x4xf32>
 }
+
+
+#contraction_accesses0 = [
+  (i, j, k) -> (i, k),
+  (i, j, k) -> (k, j),
+  (i, j, k) -> (i, j)
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @contraction4x4_ijk
+
+// Reducing output vector [0, 0]
+
+// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [0, 2]
+
+// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 0]
+
+// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 2]
+
+// CHECK-NEXT: %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S22]], %[[R1S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S42]], %[[R2S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Insert output vector slices into 4x4 vector result.
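+// Each 2x2 accumulator computed above is inserted back at its element-space
+// offset into a zero splat of the 4x4 result type; the last insert is returned.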
+// CHECK-NEXT: %[[RES0:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES1:.*]] = vector.insert_strided_slice %[[R3S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES2:.*]] = vector.insert_strided_slice %[[R3S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES3:.*]] = vector.insert_strided_slice %[[R3S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: return %[[RES3]] : vector<4x4xf32>
+
+func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
+                         %arg2 : vector<4x4xf32>, %arg3 : index)
+    -> (vector<4x4xf32>) {
+
+  %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
+      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
+
+  return %0 : vector<4x4xf32>
+}
+
+
+#contraction_accesses1 = [
+  (i, k, j) -> (i, k),
+  (i, k, j) -> (k, j),
+  (i, k, j) -> (i, j)
+]
+#contraction_trait1 = {
+  indexing_maps = #contraction_accesses1,
+  iterator_types = ["parallel", "reduction", "parallel"]
+}
+
+// CHECK-LABEL: func @contraction4x4_ikj
+
+// Reducing output vector [0, 0]
+
+// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [0, 2]
+
+// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 0]
+
+// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 2]
+
+// CHECK-NEXT: %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Insert output vector slices into 4x4 vector result.
+// CHECK-NEXT: %[[RES0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[RES3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: return %[[RES3]] : vector<4x4xf32>
+
+func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
+                         %arg2 : vector<4x4xf32>, %arg3 : index)
+    -> (vector<4x4xf32>) {
+
+  %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
+      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
+
+  return %0 : vector<4x4xf32>
+}