From 59058c441a9ba421b8f45cf1482544fd72ecb558 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Wed, 13 Apr 2022 18:38:11 +0000
Subject: [PATCH] [mlir][vector] Add operations used for Vector distribution

Add vector op warp_execute_on_lane_0 that will be used to do incremental
vector distribution in order to target warp level vector programming for
architectures with GPU-like SIMT programming model.

The idea behind the op is discussed further on discourse:
https://discourse.llvm.org/t/vector-vector-distribution-large-vector-to-small-vector/1983/23

Differential Revision: https://reviews.llvm.org/D123703
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.h   |   1 +
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td  | 136 +++++++++++++++++
 mlir/lib/Dialect/Vector/IR/CMakeLists.txt         |   1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp          | 178 ++++++++++++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir             |  75 +++++++++
 mlir/test/Dialect/Vector/ops.mlir                 |  27 ++++
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel |   2 +
 7 files changed, 420 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index b5e9f25..39c6353 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3e9ad30..76daf9e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -13,6 +13,7 @@
 #ifndef VECTOR_OPS
 #define VECTOR_OPS
 
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
@@ -2539,4 +2540,139 @@ def Vector_ScanOp :
   let hasVerifier = 1;
 }
 
+def Vector_YieldOp : Vector_Op<"yield", [
+    NoSideEffect, ReturnLike, Terminator]> {
+  let summary = "Terminates and yields values from vector regions.";
+  let description = [{
+    "vector.yield" yields an SSA value from the Vector dialect op region and
+    terminates the regions. The semantics of how the values are yielded is
+    defined by the parent operation.
+    If "vector.yield" has any operands, the operands must correspond to the
+    parent operation's results.
+    If the parent operation defines no value the vector.yield may be omitted
+    when printing the region.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [
+    OpBuilder<(ins), [{ /* nothing to do */ }]>,
+  ];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["areTypesCompatible"]>,
+       SingleBlockImplicitTerminator<"vector::YieldOp">,
+       RecursiveSideEffects]> {
+  let summary = "Executes operations in the associated region on lane #0 of a "
+                "GPU SIMT warp";
+  let description = [{
+    `warp_execute_on_lane_0` is an operation used to bridge the gap between
+    vector programming and GPU SIMT programming model. It allows to trivially
+    convert a region of vector code meant to run on a GPU warp into a valid SIMT
+    region and then allows incremental transformation to distribute vector
+    operations on the SIMT lane.
+
+    Any code present in the region would only be executed on first lane
+    based on the `laneid` operand. The `laneid` operand is an integer ID between
+    [0, `warp_size`). The `warp_size` attribute indicates the number of lanes in
+    a warp.
+
+    Operands are vector values distributed on all lanes that may be used by
+    the single lane execution. The matching region argument is a vector of all
+    the values of those lanes available to the single active lane. The
+    distributed dimension is implicit based on the shape of the operand and
+    argument. In the future this may be described by an affine map.
+
+    Return values are distributed on all lanes using laneId as index. The
+    vector is distributed based on the shape ratio between the vector type of
+    the yield and the result type.
+    If the shapes are the same this means the value is broadcasted to all lanes.
+    In the future the distribution can be made more explicit using affine_maps
+    and will support having multiple Ids.
+
+    Therefore the `warp_execute_on_lane_0` operations allow to implicitly copy
+    between lane0 and the lanes of the warp. When distributing a vector
+    from lane0 to all the lanes, the data are distributed in a block cyclic way.
+
+    During lowering values passed as operands and return value need to be
+    visible to different lanes within the warp. This would usually be done by
+    going through memory.
+
+    The region is *not* isolated from above. For values coming from the parent
+    region not going through operands only the lane 0 value will be accessible
+    so it generally only makes sense for uniform values.
+
+    Example:
+    ```
+    vector.warp_execute_on_lane_0 (%laneid)[32] {
+      ...
+    }
+    ```
+
+    This may be lowered to an scf.if region as below:
+    ```
+      %cnd = arith.cmpi eq, %laneid, %c0 : index
+      scf.if %cnd {
+        ...
+      }
+    ```
+
+    When the region has operands and/or return values:
+    ```
+    %0 = vector.warp_execute_on_lane_0(%laneid)[32]
+    args(%v0 : vector<4xi32>) -> (vector<1xf32>) {
+    ^bb0(%arg0 : vector<128xi32>) :
+      ...
+      vector.yield %1 : vector<32xf32>
+    }
+    ```
+
+    values at the region boundary would go through memory:
+    ```
+    %tmp0 = memref.alloc() : memref<32xf32, 3>
+    %tmp1 = memref.alloc() : memref<32xf32, 3>
+    %cnd = arith.cmpi eq, %laneid, %c0 : index
+    vector.store %v0, %tmp0[%laneid] : memref<32xf32>, vector<1xf32>
+    warp_sync
+    scf.if %cnd {
+      %arg0 = vector.load %tmp0[%c0] : memref<32xf32>, vector<32xf32>
+      ...
+      vector.store %1, %tmp1[%c0] : memref<32xf32>, vector<32xf32>
+    }
+    warp_sync
+    %0 = vector.load %tmp1[%laneid] : memref<32xf32>, vector<32xf32>
+    ```
+
+  }];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+  let arguments = (ins Index:$laneid, I64Attr:$warp_size,
+                       Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$warpRegion);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$laneid, "int64_t":$warpSize)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+                   "int64_t":$warpSize)>,
+    // `blockArgTypes` are different than `args` types as they
+    // represent all the `args` instances visible to lane 0. Therefore we need
+    // to explicitly pass the type.
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+                   "int64_t":$warpSize, "ValueRange":$args,
+                   "TypeRange":$blockArgTypes)>
+  ];
+
+  let extraClassDeclaration = [{
+    bool isDefinedOutsideOfRegion(Value value) {
+      return !getRegion().isAncestor(value.getParentRegion());
+    }
+  }];
+}
+
 #endif // VECTOR_OPS
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 5964753..17380bd 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVector
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
+  MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDialectUtils
   MLIRIR
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 940f926..af17460 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4590,6 +4590,184 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// WarpExecuteOnLane0Op
+//===----------------------------------------------------------------------===//
+
+void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
+  p << "(" << getLaneid() << ")";
+
+  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
+  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
+  p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
+
+  if (!getArgs().empty())
+    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
+  if (!getResults().empty())
+    p << " -> (" << getResults().getTypes() << ')';
+  p << " ";
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/!getResults().empty());
+  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
+}
+
+ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
+                                        OperationState &result) {
+  // Create the region.
+  result.regions.reserve(1);
+  Region *warpRegion = result.addRegion();
+
+  auto &builder = parser.getBuilder();
+  OpAsmParser::UnresolvedOperand laneId;
+
+  // Parse predicate operand.
+  if (parser.parseLParen() || parser.parseRegionArgument(laneId) ||
+      parser.parseRParen())
+    return failure();
+
+  int64_t warpSize;
+  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
+      parser.parseRSquare())
+    return failure();
+  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
+                                                        builder.getContext())),
+                      builder.getI64IntegerAttr(warpSize));
+
+  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
+    return failure();
+
+  llvm::SMLoc inputsOperandsLoc;
+  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
+  SmallVector<Type> inputTypes;
+  if (succeeded(parser.parseOptionalKeyword("args"))) {
+    if (parser.parseLParen())
+      return failure();
+
+    inputsOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(inputsOperands) ||
+        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+      return failure();
+  }
+  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
+                             result.operands))
+    return failure();
+
+  // Parse optional results type list.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+  // Parse the region.
+  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
+                         /*argTypes=*/{}))
+    return failure();
+  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+void WarpExecuteOnLane0Op::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The warp region is always executed
+  regions.push_back(RegionSuccessor(&getWarpRegion()));
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+                                 TypeRange resultTypes, Value laneId,
+                                 int64_t warpSize) {
+  build(builder, result, resultTypes, laneId, warpSize,
+        /*operands=*/llvm::None, /*argTypes=*/llvm::None);
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+                                 TypeRange resultTypes, Value laneId,
+                                 int64_t warpSize, ValueRange args,
+                                 TypeRange blockArgTypes) {
+  result.addOperands(laneId);
+  result.addAttribute(getAttributeNames()[0],
+                      builder.getI64IntegerAttr(warpSize));
+  result.addTypes(resultTypes);
+  result.addOperands(args);
+  assert(args.size() == blockArgTypes.size());
+  OpBuilder::InsertionGuard guard(builder);
+  Region *warpRegion = result.addRegion();
+  Block *block = builder.createBlock(warpRegion);
+  for (auto it : llvm::zip(blockArgTypes, args))
+    block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
+}
+
+/// Helper to check if the distributed vector type is consistent with the
+/// expanded type and distributed size.
+static LogicalResult verifyDistributedType(Type expanded, Type distributed,
+                                           int64_t warpSize, Operation *op) {
+  // If the types match there is no distribution.
+  if (expanded == distributed)
+    return success();
+  auto expandedVecType = expanded.dyn_cast<VectorType>();
+  auto distributedVecType = distributed.dyn_cast<VectorType>();
+  if (!expandedVecType || !distributedVecType)
+    return op->emitOpError("expected vector type for distributed operands.");
+  if (expandedVecType.getRank() != distributedVecType.getRank() ||
+      expandedVecType.getElementType() != distributedVecType.getElementType())
+    return op->emitOpError(
+        "expected distributed vectors to have same rank and element type.");
+  bool foundDistributedDim = false;
+  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
+    if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
+      continue;
+    if (expandedVecType.getDimSize(i) ==
+        distributedVecType.getDimSize(i) * warpSize) {
+      if (foundDistributedDim)
+        return op->emitOpError()
+               << "expected only one dimension to be distributed from "
+               << expandedVecType << " to " << distributedVecType;
+      foundDistributedDim = true;
+      continue;
+    }
+    return op->emitOpError() << "incompatible distribution dimensions from "
+                             << expandedVecType << " to " << distributedVecType;
+  }
+  return success();
+}
+
+LogicalResult WarpExecuteOnLane0Op::verify() {
+  if (getArgs().size() != getWarpRegion().getNumArguments())
+    return emitOpError(
+        "expected same number op arguments and block arguments.");
+  auto yield =
+      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+  if (yield.getNumOperands() != getNumResults())
+    return emitOpError(
+        "expected same number of yield operands and return values.");
+  int64_t warpSize = getWarpSize();
+  for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
+    if (failed(verifyDistributedType(std::get<0>(it).getType(),
+                                     std::get<1>(it).getType(), warpSize,
+                                     getOperation())))
+      return failure();
+  }
+  for (auto it : llvm::zip(yield.getOperands(), getResults())) {
+    if (failed(verifyDistributedType(std::get<0>(it).getType(),
+                                     std::get<1>(it).getType(), warpSize,
+                                     getOperation())))
+      return failure();
+  }
+  return success();
+}
+
+bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
+  return succeeded(
+      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index f60d2b1..e3e01b9 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1528,3 +1528,78 @@ func @invalid_splat(%v : f32) {
   vector.splat %v : memref<8xf32>
   return
 }
+
+// -----
+
+func @warp_wrong_num_outputs(%laneid: index) {
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected same number of yield operands and return values.}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<4xi32>) {
+  }
+  return
+}
+
+// -----
+
+func @warp_wrong_num_inputs(%laneid: index) {
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected same number op arguments and block arguments.}}
+  vector.warp_execute_on_lane_0(%laneid)[64] {
+  ^bb0(%arg0 : vector<128xi32>) :
+  }
+  return
+}
+
+// -----
+
+func @warp_wrong_return_distribution(%laneid: index) {
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op incompatible distribution dimensions from 'vector<128xi32>' to 'vector<4xi32>'}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<4xi32>) {
+    %0 = arith.constant dense<2>: vector<128xi32>
+    vector.yield %0 : vector<128xi32>
+  }
+  return
+}
+
+
+// -----
+
+func @warp_wrong_arg_distribution(%laneid: index, %v0 : vector<4xi32>) {
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op incompatible distribution dimensions from 'vector<128xi32>' to 'vector<4xi32>'}}
+  vector.warp_execute_on_lane_0(%laneid)[64]
+  args(%v0 : vector<4xi32>) {
+    ^bb0(%arg0 : vector<128xi32>) :
+  }
+  return
+}
+
+// -----
+
+func @warp_2_distributed_dims(%laneid: index) {
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected only one dimension to be distributed from 'vector<128x128xi32>' to 'vector<4x4xi32>'}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
+    %0 = arith.constant dense<2>: vector<128x128xi32>
+    vector.yield %0 : vector<128x128xi32>
+  }
+  return
+}
+
+// -----
+
+func @warp_mismatch_rank(%laneid: index) {
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
+    %0 = arith.constant dense<2>: vector<128xi32>
+    vector.yield %0 : vector<128xi32>
+  }
+  return
+}
+
+// -----
+
+func @warp_mismatch_type(%laneid: index) {
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
+    %0 = arith.constant dense<2>: vector<128xi32>
+    vector.yield %0 : vector<128xi32>
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 43b38ef..3db28eb 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -745,3 +745,30 @@ func @vector_splat_0d(%a: f32) -> vector<f32> {
   %0 = vector.splat %a : vector<f32>
   return %0 : vector<f32>
 }
+
+// CHECK-LABEL: func @warp_execute_on_lane_0(
+func @warp_execute_on_lane_0(%laneid: index) {
+//  CHECK-NEXT:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+//  CHECK-NEXT:   }
+  }
+//  CHECK-NEXT:   return
+  return
+}
+
+// CHECK-LABEL: func @warp_operand_result(
+func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4xi32>) {
+//  CHECK-NEXT:   %{{.*}} = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xi32>) -> (vector<4xi32>) {
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32]
+  args(%v0 : vector<4xi32>) -> (vector<4xi32>) {
+   ^bb0(%arg0 : vector<128xi32>) :
+    %0 = arith.constant dense<2>: vector<128xi32>
+    %1 = arith.addi %arg0, %0 : vector<128xi32>
+//       CHECK:   vector.yield %{{.*}} : vector<128xi32>
+    vector.yield %1 : vector<128xi32>
+//  CHECK-NEXT:   }
+  }
+  return %2 : vector<4xi32>
+}
+
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 656a089..65096cc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2963,6 +2963,7 @@ cc_library(
     deps = [
         ":ArithmeticDialect",
         ":ArithmeticUtils",
+        ":ControlFlowInterfaces",
         ":DialectUtils",
         ":IR",
         ":MemRefDialect",
@@ -7275,6 +7276,7 @@ td_library(
     srcs = ["include/mlir/Dialect/Vector/IR/VectorOps.td"],
     includes = ["include"],
     deps = [
+        ":ControlFlowInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
-- 
2.7.4