From dd14e5825209386129770296f9bc3a14ab0b4592 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 30 Sep 2020 12:58:24 -0700 Subject: [PATCH] [mlir][vector] First step of vector distribution transformation This is the first of several steps to support distributing large vectors. This adds instructions extract_map and insert_map that allow us to do incremental lowering. Right now the transformation only applies to simple pointwise operations with a vector size matching the multiplicity of the IDs used to distribute the vector. This can be used to distribute large vectors to loops or SPMD. Differential Revision: https://reviews.llvm.org/D88341 --- mlir/include/mlir/Dialect/Vector/VectorOps.td | 105 +++++++++++++++++++++ .../include/mlir/Dialect/Vector/VectorTransforms.h | 41 ++++++++ mlir/lib/Dialect/Vector/VectorOps.cpp | 47 +++++++++ mlir/lib/Dialect/Vector/VectorTransforms.cpp | 34 +++++++ mlir/test/Dialect/Vector/invalid.mlir | 28 ++++++ mlir/test/Dialect/Vector/ops.mlir | 11 +++ mlir/test/Dialect/Vector/vector-distribution.mlir | 13 +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp | 26 +++++ 8 files changed, 305 insertions(+) create mode 100644 mlir/test/Dialect/Vector/vector-distribution.mlir diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index f74c868..42e9470 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -454,6 +454,71 @@ def Vector_ExtractSlicesOp : }]; } +def Vector_ExtractMapOp : + Vector_Op<"extract_map", [NoSideEffect]>, + Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>, + Results<(outs AnyVector)> { + let summary = "vector extract map operation"; + let description = [{ + Takes a 1-D vector and extracts a sub-part of the vector starting at id with + a size of `vector size / multiplicity`. This maps a given multiplicity of + the vector to a Value such as a loop induction variable or an SPMD id. 
+ + Similarly to vector.tuple_get, this operation is used for progressive + lowering and should be folded away before converting to LLVM. + + + For instance, the following code: + ```mlir + %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32> + %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32> + %c = addf %a, %b: vector<32xf32> + vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32> + ``` + can be rewritten to: + ```mlir + %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32> + %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32> + %ea = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32> + %eb = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32> + %ec = addf %ea, %eb : vector<1xf32> + %c = vector.insert_map %ec, %id, 32 : vector<1xf32> to vector<32xf32> + vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32> + ``` + + Where %id can be an induction variable or an SPMD id going from 0 to 31. 
+ + And can then be rewritten to: + ```mlir + %a = vector.transfer_read %A[%id]: memref<32xf32>, vector<1xf32> + %b = vector.transfer_read %B[%id]: memref<32xf32>, vector<1xf32> + %c = addf %a, %b: vector<1xf32> + vector.transfer_write %c, %C[%id]: memref<32xf32>, vector<1xf32> + ``` + + Example: + + ```mlir + %ev = vector.extract_map %v[%id:32] : vector<32xf32> to vector<1xf32> + ``` + }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, " # + "Value vector, Value id, int64_t multiplicity">]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return vector().getType().cast(); + } + VectorType getResultType() { + return getResult().getType().cast(); + } + }]; + let assemblyFormat = [{ + $vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to` + type(results) + }]; +} + def Vector_FMAOp : Op]>, @@ -626,6 +691,46 @@ def Vector_InsertSlicesOp : }]; } +def Vector_InsertMapOp : + Vector_Op<"insert_map", [NoSideEffect]>, + Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>, + Results<(outs AnyVector)> { + let summary = "vector insert map operation"; + let description = [{ + Inserts a 1-D vector within a larger vector starting at id. The new + vector created will have a size of `vector size * multiplicity`. This + represents how a sub-part of the vector is written for a given Value such as + a loop induction variable or an SPMD id. + + Similarly to vector.tuple_get, this operation is used for progressive + lowering and should be folded away before converting to LLVM. + + This operation is meant to be used in combination with vector.extract_map. + See the example in the vector.extract_map description. 
+ + Example: + + ```mlir + %v = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32> + ``` + }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, " # + "Value vector, Value id, int64_t multiplicity">]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return vector().getType().cast(); + } + VectorType getResultType() { + return getResult().getType().cast(); + } + }]; + let assemblyFormat = [{ + $vector `,` $id `,` $multiplicity attr-dict `:` type($vector) `to` + type(results) + }]; +} + def Vector_InsertStridedSliceOp : Vector_Op<"insert_strided_slice", [NoSideEffect, PredOpTrait<"operand #0 and result have same element type", diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h index 9587c56..da9650c 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -172,6 +172,47 @@ private: FilterConstraintType filter; }; +struct DistributeOps { + ExtractMapOp extract; + InsertMapOp insert; +}; + +/// Distribute a 1D vector pointwise operation over a range of given IDs taking +/// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or +/// SPMD id). This transformation only inserts +/// vector.extract_map/vector.insert_map. It is meant to be used with +/// canonicalization patterns to propagate and fold the vector +/// insert_map/extract_map operations. +/// Transforms: +/// %v = addf %a, %b : vector<32xf32> +/// to: +/// %v = addf %a, %b : vector<32xf32> +/// %ev = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32> +/// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32> +Optional distributPointwiseVectorOp(OpBuilder &builder, + Operation *op, Value id, + int64_t multiplicity); +/// Canonicalize an extract_map using the result of a pointwise operation. 
+/// Transforms: +/// %v = addf %a, %b : vector<32xf32> +/// %dv = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32> +/// to: +/// %da = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32> +/// %db = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32> +/// %dv = addf %da, %db : vector<1xf32> +struct PointwiseExtractPattern : public OpRewritePattern { + using FilterConstraintType = std::function; + PointwiseExtractPattern( + MLIRContext *context, FilterConstraintType constraint = + [](ExtractMapOp op) { return success(); }) + : OpRewritePattern(context), filter(constraint) {} + LogicalResult matchAndRewrite(ExtractMapOp extract, + PatternRewriter &rewriter) const override; + +private: + FilterConstraintType filter; +}; + } // namespace vector //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 348ccf8..1a83c55 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -901,6 +901,29 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl &results) { } //===----------------------------------------------------------------------===// +// ExtractMapOp +//===----------------------------------------------------------------------===// + +void ExtractMapOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value id, int64_t multiplicity) { + VectorType type = vector.getType().cast(); + VectorType resultType = VectorType::get(type.getNumElements() / multiplicity, + type.getElementType()); + ExtractMapOp::build(builder, result, resultType, vector, id, multiplicity); +} + +static LogicalResult verify(ExtractMapOp op) { + if (op.getSourceVectorType().getShape().size() != 1 || + op.getResultType().getShape().size() != 1) + return op.emitOpError("expects source and destination vectors of rank 1"); + if (op.getResultType().getNumElements() * 
(int64_t)op.multiplicity() != + op.getSourceVectorType().getNumElements()) + return op.emitOpError("vector sizes mismatch. Source size must be equal " + "to destination size * multiplicity"); + return success(); +} + +//===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// @@ -1123,6 +1146,30 @@ void InsertSlicesOp::getStrides(SmallVectorImpl &results) { } //===----------------------------------------------------------------------===// +// InsertMapOp +//===----------------------------------------------------------------------===// + +void InsertMapOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value id, int64_t multiplicity) { + VectorType type = vector.getType().cast(); + VectorType resultType = VectorType::get(type.getNumElements() * multiplicity, + type.getElementType()); + InsertMapOp::build(builder, result, resultType, vector, id, multiplicity); +} + +static LogicalResult verify(InsertMapOp op) { + if (op.getSourceVectorType().getShape().size() != 1 || + op.getResultType().getShape().size() != 1) + return op.emitOpError("expected source and destination vectors of rank 1"); + if ((int64_t)op.multiplicity() * op.getSourceVectorType().getNumElements() != + op.getResultType().getNumElements()) + return op.emitOpError( + "vector sizes mismatch. 
Destination size must be equal " + "to source size * multiplicity"); + return success(); +} + +//===----------------------------------------------------------------------===// // InsertStridedSliceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 5bf7857a..6a244a4 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2418,6 +2418,40 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite( return failure(); } +LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite( + ExtractMapOp extract, PatternRewriter &rewriter) const { + Operation *definedOp = extract.vector().getDefiningOp(); + if (!definedOp || definedOp->getNumResults() != 1) + return failure(); + // TODO: Create an interfaceOp for elementwise operations. + if (!isa(definedOp)) + return failure(); + Location loc = extract.getLoc(); + SmallVector extractOperands; + for (OpOperand &operand : definedOp->getOpOperands()) + extractOperands.push_back(rewriter.create( + loc, operand.get(), extract.id(), extract.multiplicity())); + Operation *newOp = cloneOpWithOperandsAndTypes( + rewriter, loc, definedOp, extractOperands, extract.getResult().getType()); + rewriter.replaceOp(extract, newOp->getResult(0)); + return success(); +} + +Optional +mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op, + Value id, int64_t multiplicity) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(op); + Location loc = op->getLoc(); + Value result = op->getResult(0); + DistributeOps ops; + ops.extract = + builder.create(loc, result, id, multiplicity); + ops.insert = + builder.create(loc, ops.extract, id, multiplicity); + return ops; +} + // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO: Add this as DRR pattern. 
void mlir::vector::populateVectorToVectorTransformationPatterns( diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 3a08123..25e002f 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1328,3 +1328,31 @@ func @compress_dim_mask_mismatch(%base: memref, %mask: vector<17xi1>, %va // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}} vector.compressstore %base, %mask, %value : memref, vector<17xi1>, vector<16xf32> } + +// ----- + +func @extract_map_rank(%v: vector<2x32xf32>, %id : index) { + // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}} + %0 = vector.extract_map %v[%id : 32] : vector<2x32xf32> to vector<2x1xf32> +} + +// ----- + +func @extract_map_size(%v: vector<63xf32>, %id : index) { + // expected-error@+1 {{'vector.extract_map' op vector sizes mismatch. Source size must be equal to destination size * multiplicity}} + %0 = vector.extract_map %v[%id : 32] : vector<63xf32> to vector<2xf32> +} + +// ----- + +func @insert_map_rank(%v: vector<2x1xf32>, %id : index) { + // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}} + %0 = vector.insert_map %v, %id, 32 : vector<2x1xf32> to vector<2x32xf32> +} + +// ----- + +func @insert_map_size(%v: vector<1xf32>, %id : index) { + // expected-error@+1 {{'vector.insert_map' op vector sizes mismatch. 
Destination size must be equal to source size * multiplicity}} + %0 = vector.insert_map %v, %id, 32 : vector<1xf32> to vector<64xf32> +} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 2a62be9..7315d21 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -432,3 +432,14 @@ func @expand_and_compress(%base: memref, %mask: vector<16xi1>, %passthru: vector.compressstore %base, %mask, %0 : memref, vector<16xi1>, vector<16xf32> return } + +// CHECK-LABEL: @extract_insert_map +func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> { + // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}} : 16] : vector<32xf32> to vector<2xf32> + %vd = vector.extract_map %v[%id : 16] : vector<32xf32> to vector<2xf32> + // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}, 16 : vector<2xf32> to vector<32xf32> + %r = vector.insert_map %vd, %id, 16 : vector<2xf32> to vector<32xf32> + // CHECK: return %[[R]] : vector<32xf32> + return %r : vector<32xf32> +} + diff --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir new file mode 100644 index 0000000..0216a01 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-distribution.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s + +// CHECK-LABEL: func @distribute_vector_add +// CHECK-SAME: (%[[ID:.*]]: index +// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32> +// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32> +// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32> +// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ID]], 32 : vector<1xf32> to vector<32xf32> +// CHECK-NEXT: return %[[INS]] : vector<32xf32> +func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> { + %0 = addf %A, %B 
: vector<32xf32> + return %0: vector<32xf32> +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index ab84603..2ffe10b 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -125,6 +125,28 @@ struct TestVectorUnrollingPatterns } }; +struct TestVectorDistributePatterns + : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { + MLIRContext *ctx = &getContext(); + OwningRewritePatternList patterns; + FuncOp func = getFunction(); + func.walk([&](AddFOp op) { + OpBuilder builder(op); + Optional ops = distributPointwiseVectorOp( + builder, op.getOperation(), func.getArgument(0), 32); + assert(ops.hasValue()); + SmallPtrSet extractOp({ops->extract}); + op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); + }); + patterns.insert(ctx); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + struct TestVectorTransferFullPartialSplitPatterns : public PassWrapper { @@ -178,5 +200,9 @@ void registerTestVectorConversions() { vectorTransformFullPartialPass("test-vector-transfer-full-partial-split", "Test conversion patterns to split " "transfer ops via scf.if + linalg ops"); + PassRegistration distributePass( + "test-vector-distribute-patterns", + "Test conversion patterns to distribute vector ops in the vector " + "dialect"); } } // namespace mlir -- 2.7.4