From: Nicolas Vasilache Date: Sun, 17 May 2020 14:15:58 +0000 (-0400) Subject: [mlir] NFC - VectorTransforms use OpBuilder where relevant X-Git-Tag: llvmorg-12-init~5843 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1d6eb09d2225310b1af54856c34fdcd45cd0f9ef;p=platform%2Fupstream%2Fllvm.git [mlir] NFC - VectorTransforms use OpBuilder where relevant Summary: This will allow using unrolling outside of only rewrite patterns. Differential Revision: https://reviews.llvm.org/D80083 --- diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h index a7325ce..337ac75 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -65,7 +65,7 @@ namespace vector { // This will be extended in the future to support more advanced use cases than // simple pointwise ops. SmallVector -unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, +unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op, ArrayRef targetShape); } // namespace vector diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 851b54b..af7e5ad 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -68,8 +68,8 @@ static int64_t computeMaxLinearIndex(ArrayRef basis) { // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. -static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder, - Location loc, Operation *op, +static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, + Operation *op, ArrayRef operands, ArrayRef resultTypes) { OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, @@ -98,7 +98,7 @@ static void getMappedElements(const DenseMap &indexMap, static TupleType generateExtractSlicesOpResultType(VectorType vectorType, ArrayRef sizes, ArrayRef strides, - PatternRewriter &builder) { + OpBuilder &builder) { assert(llvm::all_of(strides, [](int64_t s) { return s == 1; })); assert(static_cast(sizes.size()) == vectorType.getRank()); assert(static_cast(strides.size()) == vectorType.getRank()); @@ -140,7 +140,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue, const DenseMap &indexMap, ArrayRef targetShape, UnrolledVectorState &state, - PatternRewriter &builder) { + OpBuilder &builder) { // Compute unrolled shape of 'vectorType'. state.unrolledShape.resize(vectorType.getRank()); getMappedElements(indexMap, targetShape, state.unrolledShape); @@ -183,7 +183,7 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state, static Value getOrCreateUnrolledVectorSlice( Location loc, UnrolledVectorState &state, ArrayRef vectorOffsets, ArrayRef offsets, DenseMap &indexMap, - Value initValue, SmallVectorImpl &cache, PatternRewriter &builder) { + Value initValue, SmallVectorImpl &cache, OpBuilder &builder) { // Compute slice offsets. SmallVector sliceOffsets(state.unrolledShape.size()); getMappedElements(indexMap, offsets, sliceOffsets); @@ -275,7 +275,7 @@ static Value unrollSingleResultStructuredOp(Operation *op, std::vector &vectors, unsigned resultIndex, ArrayRef targetShape, - PatternRewriter &builder) { + OpBuilder &builder) { auto shapedType = op->getResult(0).getType().dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) assert(false && "Expected a statically shaped result type"); @@ -426,7 +426,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef targetShape, // Entry point for unrolling declarative pattern rewrites. SmallVector mlir::vector::unrollSingleResultOpMatchingType( - PatternRewriter &builder, Operation *op, ArrayRef targetShape) { + OpBuilder &builder, Operation *op, ArrayRef targetShape) { assert(op->getNumResults() == 1 && "Expected single result operation"); // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'. @@ -451,12 +451,10 @@ SmallVector mlir::vector::unrollSingleResultOpMatchingType( /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and /// calls 'fn' with linear index and indices for each slice. -static void -generateTransferOpSlices(Type memrefElementType, VectorType vectorType, - TupleType tupleType, ArrayRef sizes, - ArrayRef strides, ArrayRef indices, - PatternRewriter &rewriter, - function_ref)> fn) { +static void generateTransferOpSlices( + Type memrefElementType, VectorType vectorType, TupleType tupleType, + ArrayRef sizes, ArrayRef strides, ArrayRef indices, + OpBuilder &builder, function_ref)> fn) { // Compute strides w.r.t. to slice counts in each dimension. auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes); assert(maybeDimSliceCounts.hasValue()); @@ -484,7 +482,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType, } unsigned indexOffset = numSliceIndices - vectorRank; - auto *ctx = rewriter.getContext(); + auto *ctx = builder.getContext(); for (unsigned i = 0; i < numSlices; ++i) { auto vectorOffsets = delinearize(sliceStrides, i); auto elementOffsets = @@ -498,7 +496,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType, auto expr = getAffineDimExpr(0, ctx) + getAffineConstantExpr(elementOffsets[j - indexOffset], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - sliceIndices[j] = rewriter.create( + sliceIndices[j] = builder.create( indices[j].getLoc(), map, ArrayRef(indices[j])); } } @@ -1683,8 +1681,13 @@ public: // TODO(andydavis) Add this as DRR pattern. void mlir::vector::populateVectorToVectorTransformationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); + // clang-format off + patterns.insert(context); + // clang-format on } void mlir::vector::populateVectorSlicesLoweringPatterns( @@ -1695,9 +1698,14 @@ void mlir::vector::populateVectorSlicesLoweringPatterns( void mlir::vector::populateVectorContractLoweringPatterns( OwningRewritePatternList &patterns, MLIRContext *context, VectorTransformsOptions parameters) { - patterns.insert(context); + // clang-format off + patterns.insert(context); + // clang-format on patterns.insert(parameters, context); }