From 4e825c59be48b602a4790c91df0801138f3cbb6e Mon Sep 17 00:00:00 2001
From: Andy Davis
Date: Tue, 17 Dec 2019 06:26:31 -0800
Subject: [PATCH] Update vector op unrolling transformation to generate
 ExtractSlicesOp and InsertSlicesOp (instead of less structured chain of
 StridedSliceOps and InsertStridedSliceOps).

PiperOrigin-RevId: 285968051
---
 mlir/lib/Dialect/VectorOps/VectorOps.cpp            |   4 +-
 mlir/lib/Dialect/VectorOps/VectorTransforms.cpp     | 115 ++++++--
 mlir/test/Dialect/VectorOps/vector-transforms.mlir  | 295 ++++++++++-----------
 3 files changed, 236 insertions(+), 178 deletions(-)

diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 48fc0d4..1f6a4bc 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -543,8 +543,8 @@ isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
     SmallVector<int64_t, 4> vectorOffsets(rank);
     int64_t linearIndex = i;
     for (unsigned j = 0; j < rank; ++j) {
-      vectorOffsets.push_back(linearIndex / sliceStrides[i]);
-      linearIndex %= sliceStrides[i];
+      vectorOffsets[j] = linearIndex / sliceStrides[j];
+      linearIndex %= sliceStrides[j];
     }
     // Convert from unrolled vector-space offsets to element-space offsets.
     auto offsets = mlir::functional::zipMap(
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 6825709..8d70f4a 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -142,6 +142,47 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
   }
 }
 
+// Returns a tuple type with vector element types for each resulting slice
+// of 'vectorType' unrolled by 'sizes' and 'strides'.
+// TODO(andydavis) Move this to a utility function and share it with
+// Extract/InsertSlicesOp verification.
+static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
+                                                   ArrayRef<int64_t> sizes,
+                                                   ArrayRef<int64_t> strides,
+                                                   PatternRewriter &builder) {
+  assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
+  unsigned rank = vectorType.getRank();
+  assert(sizes.size() == rank);
+  assert(strides.size() == rank);
+
+  // Compute shape ratio of 'shape' and 'sizes'.
+  auto shape = vectorType.getShape();
+  auto maybeDimSliceCounts = shapeRatio(shape, sizes);
+  assert(maybeDimSliceCounts.hasValue());
+  auto sliceDimCounts = *maybeDimSliceCounts;
+
+  // Compute strides w.r.t number of slices in each dimension.
+  auto basis = computeStrides(sliceDimCounts);
+  int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
+  SmallVector<Type, 4> vectorTypes(sliceCount);
+  for (unsigned i = 0; i < sliceCount; ++i) {
+    // De-linearize w.r.t. 'basis'.
+    auto vectorOffsets = delinearize(i, basis);
+    // Convert from unrolled vector-space offsets to element-space offsets.
+    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+                          vectorOffsets, sizes);
+    // Initialize 'sliceSizes' to target 'sizes'
+    SmallVector<int64_t, 4> sliceSizes(sizes.begin(), sizes.end());
+    for (unsigned j = 0; j < rank; ++j) {
+      // Based on 'offsets' and 'shape' clip some dim sizes for partial tiles.
+      sliceSizes[j] = std::min(sliceSizes[j], shape[j] - offsets[j]);
+    }
+    // Create Vector type and add to 'vectorTypes[i]'.
+    vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
+  }
+  return TupleType::get(vectorTypes, builder.getContext());
+}
+
 // UnrolledVectorState aggregates per-operand/result vector state required for
 // unrolling.
 struct UnrolledVectorState {
@@ -149,14 +190,16 @@ struct UnrolledVectorState {
   SmallVector<int64_t, 4> unrollFactors;
   SmallVector<int64_t, 8> basis;
   int64_t numInstances;
+  Value *slicesTuple;
 };
 
 // Populates 'state' with unrolled shape, unroll factors, basis and
 // num unrolled instances for 'vectorType'.
-static void initUnrolledVectorState(VectorType vectorType,
+static void initUnrolledVectorState(VectorType vectorType, Value *initValue,
                                     const DenseMap<int64_t, int64_t> &indexMap,
                                     ArrayRef<int64_t> targetShape,
-                                    UnrolledVectorState &state) {
+                                    UnrolledVectorState &state,
+                                    PatternRewriter &builder) {
   // Compute unrolled shape of 'vectorType'.
   state.unrolledShape.resize(vectorType.getRank());
   getMappedElements(indexMap, targetShape, state.unrolledShape);
@@ -168,6 +211,16 @@ static void initUnrolledVectorState(VectorType vectorType,
   // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
   state.basis = computeStrides(state.unrollFactors);
   state.numInstances = computeMaxLinearIndex(state.unrollFactors);
+  state.slicesTuple = nullptr;
+  if (initValue != nullptr) {
+    // Create ExtractSlicesOp.
+    SmallVector<int64_t, 4> sizes(state.unrolledShape);
+    SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
+    auto tupleType =
+        generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
+    state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
+        initValue->getLoc(), tupleType, initValue, sizes, strides);
+  }
 }
 
 // Computes and returns the linear index of the unrolled vector at
@@ -202,10 +255,14 @@ static Value *getOrCreateUnrolledVectorSlice(
   assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
   auto *valueSlice = cache[sliceLinearIndex];
   if (valueSlice == nullptr) {
-    assert(initValue != nullptr);
-    // Initialize 'cache' with slice from 'state.value'.
-    valueSlice = builder.create<vector::StridedSliceOp>(
-        loc, initValue, sliceOffsets, state.unrolledShape, sliceStrides);
+    // Return tuple element at 'sliceLinearIndex'.
+    auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
+    auto initValueType = initValue->getType().cast<VectorType>();
+    auto vectorType =
+        VectorType::get(state.unrolledShape, initValueType.getElementType());
+    // Initialize 'cache' with slice from 'initValue'.
+    valueSlice = builder.create<vector::TupleGetOp>(
+        loc, vectorType, state.slicesTuple, tupleIndex);
     // Store value back to 'cache'.
     cache[sliceLinearIndex] = valueSlice;
   }
@@ -293,8 +350,10 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
   unsigned numVectors = vectors.size();
   SmallVector<UnrolledVectorState, 4> unrolledVectorState(numVectors);
   for (unsigned i = 0; i < numVectors; ++i) {
-    initUnrolledVectorState(vectors[i].type, vectors[i].indexMap, targetShape,
-                            unrolledVectorState[i]);
+    int64_t operandIndex = vectors[i].operandIndex;
+    auto *operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
+    initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
+                            targetShape, unrolledVectorState[i], builder);
   }
   // Compute number of total unrolled instances.
   auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
@@ -341,21 +400,26 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
     caches[resultIndex][linearIndex] = resultVector;
   }
 
-  // Make zero splat into which we will insert results from
-  // 'cache[resultIndex]'
-  auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
-  auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
-  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
-  // Insert vector accumulators into output.
+  // Create TupleOp of unrolled result vectors.
+  SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
+  SmallVector<Value *, 4> vectorTupleValues(resultValueState.numInstances);
   for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
-    auto vectorOffsets = delinearize(i, resultValueState.basis);
-    // Convert from unrolled vector-space offsets to element-space offsets.
-    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
-                          vectorOffsets, resultValueState.unrolledShape);
-    res = builder.create<vector::InsertStridedSliceOp>(
-        op->getLoc(), caches[resultIndex][i], res, offsets, strides);
+    vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>();
+    vectorTupleValues[i] = caches[resultIndex][i];
   }
-  return res;
+  TupleType tupleType = builder.getTupleType(vectorTupleTypes);
+  Value *tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
+                                                   vectorTupleValues);
+
+  // Create InsertSlicesOp(Tuple(result_vectors)).
+  auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
+  SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
+  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
+
+  Value *insertSlicesOp = builder.create<vector::InsertSlicesOp>(
+      op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
+      builder.getI64ArrayAttr(strides));
+  return insertSlicesOp;
 }
 
 static void getVectorContractionOpUnrollState(
@@ -381,10 +445,10 @@ static void getVectorContractionOpUnrollState(
   if (llvm::size(contractionOp.masks()) == 2) {
     // Add vectors for lhs/rhs vector mask arguments. Masks have the
    // same vector shape lhs/rhs args, so copy their index maps.
-    vectors.push_back(
-        {vectors[0].type, vectors[0].indexMap, accOperandIndex + 1, false});
-    vectors.push_back(
-        {vectors[1].type, vectors[1].indexMap, accOperandIndex + 2, false});
+    vectors.push_back({contractionOp.getLHSVectorMaskType(),
+                       vectors[0].indexMap, accOperandIndex + 1, false});
+    vectors.push_back({contractionOp.getRHSVectorMaskType(),
+                       vectors[1].indexMap, accOperandIndex + 2, false});
   }
   // Unroll 'op' 'iterationBounds' to 'targetShape'.
   // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
@@ -509,6 +573,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO(andydavis) Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
index c8d92ee..783f542 100644
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -3,64 +3,68 @@
 // CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1)
 
 // CHECK-LABEL: func @add4x2
-// CHECK: %[[V1:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V2:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A1:.*]] = addf %[[V1]], %[[V2]] : vector<2x2xf32>
-// CHECK-NEXT: %[[V3:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V4:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A2:.*]] = addf %[[V3]], %[[V4]] : vector<2x2xf32>
-// CHECK-NEXT: %[[R1:.*]] = vector.insert_strided_slice %[[A1]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
-// CHECK-NEXT: %[[R2:.*]] = vector.insert_strided_slice %[[A2]], %[[R1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
+// CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.tuple %[[A1]], %[[A2]] : vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.insert_slices %[[R1]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
 // CHECK-NEXT: return %[[R2:.*]] : vector<4x2xf32>
+
 func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
   %1 = addf %0, %0: vector<4x2xf32>
   return %1: vector<4x2xf32>
 }
 
 // CHECK-LABEL: func @add4x4
+// CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+
+// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>
-// CHECK: %[[V1:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V2:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A1:.*]] = addf %[[V1]], %[[V2]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>
-// CHECK-NEXT: %[[V3:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V4:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A2:.*]] = addf %[[V3]], %[[V4]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A3:.*]] = addf %[[TG5]], %[[TG6]] : vector<2x2xf32>
-// CHECK-NEXT: %[[V5:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V6:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A3:.*]] = addf %[[V5]], %[[V6]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A4:.*]] = addf %[[TG7]], %[[TG8]] : vector<2x2xf32>
-// CHECK-NEXT: %[[V7:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V8:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A4:.*]] = addf %[[V7]], %[[V8]] : vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.tuple %[[A1]], %[[A2]], %[[A3]], %[[A4]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.insert_slices %[[R1]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
-// CHECK-NEXT: %[[R1:.*]] = vector.insert_strided_slice %[[A1]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[R2:.*]] = vector.insert_strided_slice %[[A2]], %[[R1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[R3:.*]] = vector.insert_strided_slice %[[A3]], %[[R2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[R4:.*]] = vector.insert_strided_slice %[[A4]], %[[R3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %[[R2]], [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[V9:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V10:.*]] = vector.strided_slice %[[R4]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A5:.*]] = addf %[[V9]], %[[V10]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A5:.*]] = addf %[[TG9]], %[[TG10]] : vector<2x2xf32>
-// CHECK-NEXT: %[[V11:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V12:.*]] = vector.strided_slice %[[R4]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A6:.*]] = addf %[[V11]], %[[V12]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A6:.*]] = addf %[[TG11]], %[[TG12]] : vector<2x2xf32>
-// CHECK-NEXT: %[[V13:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V14:.*]] = vector.strided_slice %[[R4]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A7:.*]] = addf %[[V13]], %[[V14]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A7:.*]] = addf %[[TG13]], %[[TG14]] : vector<2x2xf32>
-// CHECK-NEXT: %[[V15:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[V16:.*]] = vector.strided_slice %[[R4]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A8:.*]] = addf %[[V15]], %[[V16]] : vector<2x2xf32>
+// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[A8:.*]] = addf %[[TG15]], %[[TG16]] : vector<2x2xf32>
-// CHECK-NEXT: %[[R5:.*]] = vector.insert_strided_slice %[[A5]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[R6:.*]] = vector.insert_strided_slice %[[A6]], %[[R5]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[R7:.*]] = vector.insert_strided_slice %[[A7]], %[[R6]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[R8:.*]] = vector.insert_strided_slice %[[A8]], %[[R7]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: return %[[R8]] : vector<4x4xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.tuple %[[A5]], %[[A6]], %[[A7]], %[[A8]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[R4:.*]] = vector.insert_slices %[[R3]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
+// CHECK-NEXT: return %[[R4]] : vector<4x4xf32>
 func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
   %2 = addf %0, %1: vector<4x4xf32>
@@ -80,64 +84,76 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
 // CHECK-LABEL: func @contraction4x4_ijk
+// CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 6] : vector<4x6xi1>
+// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [6, 4] : vector<6x4xi1>
+
 // Reducing output vector [0, 0]
-// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %[[LMASK1]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[RMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %[[LMASK2]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x6xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<6x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %[[LMASK]], [2, 2], [1, 1] : vector<4x6xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %[[RMASK]], [2, 2], [1, 1] : vector<6x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+
+// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [0, 2]
-// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[RMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[RMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %[[LMASK1]], %[[RMASK4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[RMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %[[LMASK2]], %[[RMASK5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [2, 0]
-// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK3]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %[[LMASK4]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %[[LMASK5]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [2, 2]
-// CHECK-NEXT: %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S22]], %[[R1S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S42]], %[[R2S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// Insert output vector slices into 4x4 vector result.
-// CHECK-NEXT: %[[RES0:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[RES1:.*]] = vector.insert_strided_slice %[[R3S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[RES2:.*]] = vector.insert_strided_slice %[[R3S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[RES3:.*]] = vector.insert_strided_slice %[[R3S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: return %[[RES3]] : vector<4x4xf32>
+// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
+// CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>
 func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
                          %arg2 : vector<4x4xf32>, %arg3 : index)
@@ -162,40 +178,47 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
 // CHECK-LABEL: func @contraction4x4_ikj
+
+// CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 2] : vector<4x2xi1>
+// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [2, 4] : vector<2x4xi1>
+
 // Reducing output vector [0, 0]
-// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
+
+// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [0, 2]
-// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [2, 0]
-// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK1]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [2, 2]
-// CHECK-NEXT: %[[R0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S02]], %[[R0S22]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// Insert output vector slices into 4x4 vector result.
-// CHECK-NEXT: %[[RES0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[RES1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[RES0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[RES2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[RES1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[RES3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[RES2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: return %[[RES3]] : vector<4x4xf32>
+// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
+// CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>
 func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
                          %arg2 : vector<4x4xf32>, %arg3 : index)
@@ -209,41 +232,16 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
 }
 
 // CHECK-LABEL: func @contraction4x4_ikj_xfer_read
+// TODO(andydavis) Add VTR splitting back into this test in follow up CL.
 
-// Capture constants used to re-index vector transfer reads.
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[C0:.*]] = constant 0 : index
 
 // Check LHS vector.transfer read is split for each user.
-// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
-// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[ISS0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
-
-// Check RHS vector.transfer read is split for each user.
-// CHECK: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS2:.*]] = vector.insert_strided_slice %[[VTR2]], %{{.*}} {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
-// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[ISS2]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
-
-// Check ACC vector.transfer read is split for each user (should be 4).
-// CHECK: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS4:.*]] = vector.insert_strided_slice %[[VTR4]], %{{.*}} {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[ISS4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS6:.*]] = vector.insert_strided_slice %[[VTR6]], %[[ISS5]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[ISS7:.*]] = vector.insert_strided_slice %[[VTR7]], %[[ISS6]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-
-// Check LHS slice uses splat of split tranfer read results.
-// CHECK: vector.strided_slice %[[ISS1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
-
-// Check RHS slice uses splat of split tranfer read results.
-// CHECK: vector.strided_slice %[[ISS3]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
-
-// Check ACC slice uses splat of split tranfer read results.
-// CHECK: vector.strided_slice %[[ISS7]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<4x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x4xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<4x4xf32>
+
 func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
                                    %arg1 : memref<2x4xf32>,
@@ -270,17 +268,12 @@ func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
   return %3 : vector<4x4xf32>
 }
 
+// TODO(andydavis) Update test with VTR split transform.
 // CHECK-LABEL: func @vector_transfers
-// CHECK-COUNT-8: vector.transfer_read
-// CHECK-COUNT-2: vector.strided_slice
-// CHECK-COUNT-1: addf
-// CHECK-COUNT-2: vector.strided_slice
-// CHECK-COUNT-1: addf
-// CHECK-COUNT-2: vector.strided_slice
-// CHECK-COUNT-1: addf
-// CHECK-COUNT-2: vector.strided_slice
-// CHECK-COUNT-1: addf
-// CHECK-COUNT-4: vector.insert_strided_slice
+// CHECK-COUNT-2: vector.transfer_read
+// CHECK-COUNT-2: vector.extract_slices
+// CHECK-COUNT-4: addf
+// CHECK-COUNT-1: vector.insert_slices
 // CHECK: vector.transfer_write
 
 func @vector_transfers(%arg0: index, %arg1: index) {
-- 
2.7.4
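
Editorial note: the slice bookkeeping that generateExtractSlicesOpResultType performs in this patch (shape ratio, linearized slice count, de-linearization against a row-major basis, conversion to element-space offsets, and clipping of partial tiles) can be illustrated with a small standalone C++ sketch. The helper names below (sliceCounts, computeStrides) are local stand-ins written only for this example and are not the MLIR utilities (shapeRatio, computeStrides, delinearize) used in the patch; the sketch assumes unit strides, as the patch asserts.

// Illustrative sketch (not MLIR code): compute the per-slice shapes that an
// extract_slices of 'shape' into tiles of 'sizes' would produce, in the same
// slice order used by the unrolling transformation (row-major over the
// per-dimension slice counts).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Number of slices per dimension: ceil(shape[i] / sizes[i]).
static std::vector<int64_t> sliceCounts(const std::vector<int64_t> &shape,
                                        const std::vector<int64_t> &sizes) {
  std::vector<int64_t> counts(shape.size());
  for (size_t i = 0; i < shape.size(); ++i)
    counts[i] = (shape[i] + sizes[i] - 1) / sizes[i];
  return counts;
}

// Row-major strides ("basis") over the slice counts.
static std::vector<int64_t> computeStrides(const std::vector<int64_t> &counts) {
  std::vector<int64_t> basis(counts.size(), 1);
  for (int i = static_cast<int>(counts.size()) - 2; i >= 0; --i)
    basis[i] = basis[i + 1] * counts[i + 1];
  return basis;
}

int main() {
  std::vector<int64_t> shape = {4, 6}; // e.g. the vector<4x6xf32> LHS above
  std::vector<int64_t> sizes = {2, 2}; // target slice shape
  auto counts = sliceCounts(shape, sizes);
  auto basis = computeStrides(counts);
  int64_t sliceCount = 1;
  for (auto c : counts)
    sliceCount *= c;

  for (int64_t i = 0; i < sliceCount; ++i) {
    // De-linearize the slice index and convert to element-space offsets.
    int64_t rem = i;
    std::vector<int64_t> sliceShape(sizes);
    for (size_t j = 0; j < shape.size(); ++j) {
      int64_t vectorOffset = rem / basis[j];
      rem %= basis[j];
      int64_t offset = vectorOffset * sizes[j];
      // Clip the last tile in each dimension if it would run past the shape.
      sliceShape[j] = std::min(sizes[j], shape[j] - offset);
    }
    std::cout << "slice " << i << ": " << sliceShape[0] << "x" << sliceShape[1]
              << "\n";
  }
  return 0;
}

With shape = {4, 6} and sizes = {2, 2} the sketch prints six 2x2 slices, which is why the @contraction4x4_ijk checks above expect tuple types containing six vector<2x2xf32> (and six vector<2x2xi1> for the masks).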