SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
ArrayRef<int64_t> sizes);
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
+
/// Given the slice strides together with a linear index in the dimension
/// space, returns the vector-space offsets in each dimension for a
/// de-linearized index.
return res;
}
-/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
-static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
- assert(offsets.size() == basis.size());
- int64_t linearIndex = 0;
- for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
- linearIndex += offsets[idx] * basis[idx];
- return linearIndex;
-}
-
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
}
};
+/// Returns the producer Value of the same type as 'consumerValue', by tracking
+/// the tuple index and offsets of the consumer vector value through the
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
+/// from consumer to producer. Each operation in the chain is structured, and
+/// so the tuple index and offsets can be mapped from result to input, while
+/// visiting each operation in the chain.
+/// Returns nullptr on failure.
+static Value getProducerValue(Value consumerValue) {
+  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
+  // A tupleIndex == -1 indicates that 'offsets' are w.r.t. a vector type.
+  int64_t tupleIndex = -1;
+  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
+  auto *op = consumerValue.getDefiningOp();
+  // Walk defining ops from consumer towards producer, updating 'tupleIndex'
+  // and 'offsets' at each structured op, until a value of the consumer's
+  // vector type is found or an unrecognized op ends the traversal.
+  while (op != nullptr) {
+    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
+      assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = tupleGetOp.getIndex();
+      op = tupleGetOp.vectors().getDefiningOp();
+    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Compute slice strides for 'extractSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      extractSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          extractSlicesOp.getSourceVectorType().getShape(), sizes);
+
+      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
+      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
+      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+
+      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
+      // to the 'extractSlicesOp' input vector type.
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        offsets[i] += elementOffsets[i];
+
+      // Clear 'tupleIndex' and update next defining 'op' to visit.
+      tupleIndex = -1;
+      op = extractSlicesOp.vector().getDefiningOp();
+    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
+      assert(tupleIndex == -1);
+
+      // Compute slice strides for 'insertSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      insertSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          insertSlicesOp.getResultVectorType().getShape(), sizes);
+
+      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
+      // of 'insertSlicesOp' result vector type at 'offsets'.
+      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
+      assert(offsets.size() == sizes.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        vectorOffsets[i] = offsets[i] / sizes[i];
+
+      // Compute the source tuple element index.
+      tupleIndex = linearize(vectorOffsets, sliceStrides);
+
+      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
+      // relative to input tuple element vector type at 'tupleIndex'.
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+        offsets[i] -= elementOffsets[i];
+        assert(offsets[i] >= 0);
+      }
+
+      // Update next defining 'op' to visit.
+      op = insertSlicesOp.vectors().getDefiningOp();
+    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Return tuple element 'value' at 'tupleIndex' if it matches type.
+      auto value = tupleOp.getOperand(tupleIndex);
+      if (value.getType() == consumerVectorType)
+        return value;
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = -1;
+      op = value.getDefiningOp();
+    } else {
+      // Unrecognized (unstructured) op: the chain cannot be followed further.
+      break;
+    }
+  }
+  return nullptr;
+}
+
/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
                              PatternRewriter &rewriter) const override {
-    // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
-    auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
-        tupleGetOp.vectors().getDefiningOp());
-    if (!extractSlicesOp)
-      return failure();
-
-    // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
-    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
-        extractSlicesOp.vector().getDefiningOp());
-    if (!insertSlicesOp)
-      return failure();
-
-    // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
-    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
-        insertSlicesOp.vectors().getDefiningOp());
-    if (!tupleOp)
-      return failure();
-
-    // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
-    Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
-    rewriter.replaceOp(tupleGetOp, tupleValue);
-    return success();
+    // Replace 'tupleGetOp' with the producer value of the same vector type,
+    // found by walking back through the structured op chain.
+    if (auto producer = getProducerValue(tupleGetOp.getResult())) {
+      rewriter.replaceOp(tupleGetOp, producer);
+      return success();
+    }
+    return failure();
}
};
using llvm::SetVector;
-namespace mlir {
+using namespace mlir;
-SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
- ArrayRef<int64_t> sizes) {
+SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> sizes) {
int64_t rank = shape.size();
// Compute the count for each dimension.
SmallVector<int64_t, 4> sliceDimCounts(rank);
return sliceStrides;
}
-SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
- int64_t index) {
+// Linearized index = sum over i of offsets[i] * basis[i], where 'basis'
+// holds the stride of each dimension (see declaration in the header).
+int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(offsets.size() == basis.size());
+  int64_t linearIndex = 0;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex += offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
+SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
+ int64_t index) {
int64_t rank = sliceStrides.size();
SmallVector<int64_t, 4> vectorOffsets(rank);
for (int64_t r = 0; r < rank; ++r) {
return vectorOffsets;
}
-SmallVector<int64_t, 4>
-computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
-                                            ArrayRef<int64_t> vectorOffsets) {
+// The element offset in each dimension d is vectorOffsets[d] * sizes[d].
+SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
+    ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
  return functional::zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                            vectorOffsets, sizes);
}
-SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
- ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> elementOffsets) {
+SmallVector<int64_t, 4>
+mlir::computeSliceSizes(ArrayRef<int64_t> shape, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> elementOffsets) {
int64_t rank = shape.size();
SmallVector<int64_t, 4> sliceSizes(rank);
for (unsigned r = 0; r < rank; ++r)
return sliceSizes;
}
-Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
- ArrayRef<int64_t> subShape) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
+ ArrayRef<int64_t> subShape) {
if (superShape.size() < subShape.size()) {
return Optional<SmallVector<int64_t, 4>>();
}
return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
}
-Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
- VectorType subVectorType) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
+ VectorType subVectorType) {
assert(superVectorType.getElementType() == subVectorType.getElementType() &&
"vector types must be of the same elemental type");
return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
return getParentsOfType<AffineForOp>(op);
}
-AffineMap
-makePermutationMap(Operation *op, ArrayRef<Value> indices,
- const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+AffineMap mlir::makePermutationMap(
+ Operation *op, ArrayRef<Value> indices,
+ const DenseMap<Operation *, unsigned> &loopToVectorDim) {
DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingforOps(op);
for (auto *forInst : enclosingLoops) {
enclosingLoopToVectorDim.insert(*it);
}
}
- return makePermutationMap(indices, enclosingLoopToVectorDim);
+ return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}
bool matcher::operatesOnSuperVectorsOf(Operation &op,
return true;
}
-} // namespace mlir
return %1 : vector<8xf32>
}
+// Tests that a straight-line chain of tuple -> insert_slices -> extract_slices
+// -> tuple_get ops folds away, so the final tuple_get resolves to %arg7.
+// CHECK-LABEL: func @tuple_get_producer_consumer
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+// CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+  %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
+  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
+  %5 = vector.tuple_get %4, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %5
+  return %5 : vector<2x4xf32>
+}
+
+// Same as @tuple_get_producer_consumer, but with the intermediate tuple
+// elements swizzled (re-packed in reverse order) to exercise tracking the
+// tuple index through a TupleOp in the middle of the chain.
+// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+// CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer_swizzle(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+
+  // Extract tuple elements.
+  %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+
+  // Swizzle tuple elements.
+  %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
+  // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
+  %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
+  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
+  %8 = vector.tuple_get %7, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %8
+  return %8 : vector<2x4xf32>
+}
+
// CHECK-LABEL: func @vector_transfers_vector_element_type
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index