From b9715156ff909fb38725893afb1d18709cb7f1bd Mon Sep 17 00:00:00 2001
From: Tobias Gysi
Date: Tue, 20 Apr 2021 11:26:44 +0000
Subject: [PATCH] [mlir][linalg] lower index operations during linalg to
 vector lowering.

The patch extends the vectorization pass to lower linalg index operations
to vector code. It allocates constant 1d vectors that enumerate the indices
along the iteration dimensions and broadcasts/transposes these 1d vectors
to the iteration space.

Differential Revision: https://reviews.llvm.org/D100373
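
As an illustration, for a `linalg.generic` over a `1x2x4x8` iteration space
(the shapes used by the tests below), `linalg.index 1 : index` lowers to
roughly the following sequence; the SSA names are illustrative, not verbatim
compiler output:

  %cst = constant dense<[0, 1]> : vector<2xindex>
  %bcast = vector.broadcast %cst : vector<2xindex> to vector<1x8x4x2xindex>
  %idx = vector.transpose %bcast, [0, 3, 2, 1]
           : vector<1x8x4x2xindex> to vector<1x2x4x8xindex>

An index along the trailing iteration dimension needs no transpose: its 1d
constant vector (here `dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>`)
already lives in the last dimension of the iteration space.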
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td    | 10 ++++
 mlir/include/mlir/IR/Builders.h                   |  1 +
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp   | 23 +++++++++
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp |  3 +-
 .../Dialect/Linalg/Transforms/Vectorization.cpp   | 58 ++++++++++++++++++---
 mlir/lib/IR/Builders.cpp                          |  6 +++
 mlir/test/Dialect/Linalg/vectorization.mlir       | 60 ++++++++++++++++------
 7 files changed, 134 insertions(+), 27 deletions(-)
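
Note on `computeStaticLoopSizes` (a worked example, not part of the change
itself): the helper applies the same recipe as `createLoopRanges`, but to
static sizes. Assuming the standard `linalg.matmul` indexing maps, operands
`A: memref<4x8xf32>`, `B: memref<8x16xf32>`, and `C: memref<4x16xf32>` give
the flat static dim list `[4, 8, 8, 16, 4, 16]`; the loops-to-shapes map

  (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)

matches every result that is a plain AffineDimExpr against that list,
yielding the static loop sizes d0 = 4, d1 = 16, d2 = 8.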
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 6c3e86c..0512b35 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1242,11 +1242,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     /// appear in the operands.
     SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
 
+    /// Return the flat list of all operands' static dimension sizes in the
+    /// order they appear in the operands. All operand dimension sizes have to
+    /// be statically known.
+    SmallVector<int64_t, 4> createFlatListOfOperandStaticDims();
+
     /// Create the loop ranges to materialize the computation over the current
     /// operands. This is done by applying `getShapesToLoopsMap` to
     /// `createFlatListOfOperandDims`.
     SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);
 
+    /// Compute the static loop sizes necessary to vectorize the computation.
+    /// This is done by applying `getShapesToLoopsMap` to
+    /// `createFlatListOfOperandStaticDims`.
+    SmallVector<int64_t, 4> computeStaticLoopSizes();
+
     /// Returns all the operands past the inputs, output_buffers and
     /// init_tensors operands. Asserts that these operands are value types to
     /// allow transformations like tiling to just use the values when cloning
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index f8b119c..1e0863c 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -124,6 +124,7 @@ public:
   DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
   DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
   DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
+  DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values);
 
   /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
   /// These are generally preferable for representing general lists of integers
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f1bf22c..1c45467 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -193,6 +193,16 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
   return res;
 }
 
+SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
+  SmallVector<int64_t, 4> res;
+  for (Value v : getShapedOperands()) {
+    ShapedType t = v.getType().template cast<ShapedType>();
+    assert(t.hasStaticShape() && "expected operands to have static shapes");
+    llvm::append_range(res, t.getShape());
+  }
+  return res;
+}
+
 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
   AffineMap map = getLoopsToShapesMap();
   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
@@ -211,6 +221,19 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
   return res;
 }
 
+SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
+  AffineMap map = getLoopsToShapesMap();
+  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
+  SmallVector<int64_t, 4> res(numDims, 0);
+  for (unsigned idx = 0; idx < numRes; ++idx) {
+    auto result = map.getResult(idx);
+    if (auto d = result.dyn_cast<AffineDimExpr>())
+      res[d.getPosition()] = allShapeSizes[idx];
+  }
+  return res;
+}
+
 /// Visitor to check if any of the given set of positions from AffineDimExprs
 /// are used within an AffineExpr.
 struct HasAffineDimExprVisitor
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 55402a7..2e8b158 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -462,8 +462,7 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  // TODO: remove hasIndexSemantics check once index ops are supported.
-  if (!linalgOp || linalgOp.hasIndexSemantics())
+  if (!linalgOp)
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c14a3b3..14ef418 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -166,6 +166,42 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
 }
 
+/// Helper function to vectorize the index operations of a `linalgOp`. Return
+/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// should map the produced operations. This function is meant to be used as a
+/// CustomVectorizationHook.
+static VectorizationResult
+vectorizeLinalgIndex(OpBuilder &builder, Operation *op, LinalgOp linalgOp) {
+  IndexOp indexOp = dyn_cast<IndexOp>(op);
+  if (!indexOp)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+  auto loc = indexOp.getLoc();
+  // Compute the static loop sizes of the index op.
+  auto targetShape = linalgOp.computeStaticLoopSizes();
+  // Compute a one-dimensional index vector for the index op dimension.
+  SmallVector<int64_t> constantSeq(
+      llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
+  ConstantOp constantOp =
+      builder.create<ConstantOp>(loc, builder.getIndexVectorAttr(constantSeq));
+  // Return the one-dimensional index vector if it lives in the trailing
+  // dimension of the iteration space since the vectorization algorithm in this
+  // case can handle the broadcast.
+  if (indexOp.dim() == targetShape.size() - 1)
+    return VectorizationResult{VectorizationStatus::NewOp, constantOp};
+  // Otherwise permute the targetShape to move the index dimension last,
+  // broadcast the one-dimensional index vector to the permuted shape, and
+  // finally transpose the broadcasted index vector to undo the permutation.
+  std::swap(targetShape[indexOp.dim()], targetShape.back());
+  auto broadCastOp = builder.create<vector::BroadcastOp>(
+      loc, VectorType::get(targetShape, builder.getIndexType()), constantOp);
+  SmallVector<int64_t> transposition(
+      llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
+  std::swap(transposition.back(), transposition[indexOp.dim()]);
+  auto transposeOp =
+      builder.create<vector::TransposeOp>(loc, broadCastOp, transposition);
+  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+}
+
 /// Generic vectorization for a single operation `op`, given already vectorized
 /// operands carried by `bvm`. Vectorization occurs as follows:
 /// 1. Try to apply any of the `customVectorizationHooks` and return its
@@ -245,7 +281,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
   for (Operation &op : r.front()) {
-    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+    if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
           OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(),
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
@@ -293,7 +329,9 @@ static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
 /// load).
 /// TODO: Reuse opportunities for RAR dependencies.
-/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
+/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
+/// 4b. Register CustomVectorizationHook for IndexOp to access the iteration
+/// indices.
 /// 5. Iteratively call vectorizeOneOp on the region operations.
 LogicalResult vectorizeAsLinalgGeneric(
     OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
@@ -333,16 +371,23 @@ LogicalResult vectorizeAsLinalgGeneric(
     bvm.map(vectorArg, vectorRead);
   }
 
-  // 4. Register CustomVectorizationHook for yieldOp.
+  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+  // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
  };
-  // Append the vectorizeYield hook.
-  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
   hooks.push_back(vectorizeYield);
 
+  // 4b. Register CustomVectorizationHook for indexOp.
+  CustomVectorizationHook vectorizeIndex =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    return vectorizeLinalgIndex(builder, op, linalgOp);
+  };
+  hooks.push_back(vectorizeIndex);
+
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block.getOperations()) {
     VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
@@ -401,9 +446,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  // TODO: remove once index ops are supported.
-  if (linalgOp.hasIndexSemantics())
-    return failure();
   if (isElementwise(op))
     return success();
   return success(isaContractionOpInterface(linalgOp));
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d1ab379..4f8aa9e 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -120,6 +120,12 @@ DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
                                    values);
 }
 
+DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
+  return DenseIntElementsAttr::get(
+      VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
+      values);
+}
+
 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
   return DenseIntElementsAttr::get(
       RankedTensorType::get(static_cast<int64_t>(values.size()),
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index faaadcf9..c18bf5b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -174,6 +174,49 @@ func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_trailing_index
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
+func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
+  // CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  outs(%arg0: memref<1x2x4x8xindex>) {
+  ^bb0(%arg1: index):
+    // CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<8xindex> to vector<1x2x4x8xindex>
+    // CHECK: vector.transfer_write %[[BCST]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
+    %0 = linalg.index 3 : index
+    linalg.yield %0 : index
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_inner_index
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
+func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) {
+  // CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1]> : vector<2xindex>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  outs(%arg0: memref<1x2x4x8xindex>) {
+  ^bb0(%arg1: index):
+    // CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<2xindex> to vector<1x8x4x2xindex>
+    // CHECK: %[[TRAN:.*]] = vector.transpose %[[BCST]], [0, 3, 2, 1] : vector<1x8x4x2xindex> to vector<1x2x4x8xindex>
+    // CHECK: vector.transfer_write %[[TRAN]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @generic_vectorize
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
 //  CHECK-SAME:  %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
@@ -252,7 +295,6 @@ func @generic_vectorize(%arg0: memref<4x256xf32>,
   return
 }
 
-
 // -----
 
 // CHECK-LABEL: func @generic_vectorize_tensor
@@ -469,19 +511,3 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
   } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
   return %0 : tensor<6x?x?x?xf32>
 }
-
-// -----
-
-// CHECK-LABEL: @index_op
-//       CHECK: linalg.generic
-func @index_op(%arg0: memref<4x8xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : memref<4x8xindex>) {
-  ^bb0(%arg1: index):  // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
-  }
-  return
-}
-- 
2.7.4