Async execute operation can take async arguments as dependencies.
Change `async.execute` custom parser/printer format to use `%value as %unwrapped: !async.value<!type>` sytax.
Reviewed By: mehdi_amini, herhut
Differential Revision: https://reviews.llvm.org/D88601
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H
#define MLIR_DIALECT_ASYNC_IR_ASYNC_H
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
Type valueType = type;
}
-def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
- "async value type">;
+def Async_AnyValueType : DialectType<AsyncDialect,
+ CPred<"$_self.isa<::mlir::async::ValueType>()">,
+ "async value type">;
+
+def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
+ Async_TokenType]>;
#endif // ASYNC_BASE_TD
class Async_Op<string mnemonic, list<OpTrait> traits = []> :
Op<AsyncDialect, mnemonic, traits>;
-def Async_ExecuteOp : Async_Op<"execute"> {
+def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> {
let summary = "Asynchronous execute operation";
let description = [{
The `body` region attached to the `async.execute` operation semantically
state). All dependencies must be made explicit with async execute arguments
(`async.token` or `async.value`).
+ `async.execute` operation takes `async.token` dependencies and `async.value`
+ operands separatly, and starts execution of the attached body region only
+ when all tokens and values become ready.
+
+ Example:
+
```mlir
- %done, %values = async.execute {
- %0 = "compute0"(...) : !some.type
- async.yield %1 : f32
- } : !async.token, !async.value<!some.type>
+ %dependency = ... : !async.token
+ %value = ... : !async.value<f32>
+
+ %token, %results =
+ async.execute [%dependency](%value as %unwrapped: !async.value<f32>)
+ -> !async.value<!some.type>
+ {
+ %0 = "compute0"(%unwrapped): (f32) -> !some.type
+ async.yield %0 : !some.type
+ }
%1 = "compute1"(...) : !some.type
```
+
+ In the example above asynchronous execution starts only after dependency
+ token and value argument become ready. Unwrapped value passed to the
+ attached body region as an %unwrapped value of f32 type.
}];
- // TODO: Take async.tokens/async.values as arguments.
- let arguments = (ins );
- let results = (outs Async_TokenType:$done,
- Variadic<Async_AnyValueType>:$values);
+ let arguments = (ins Variadic<Async_TokenType>:$dependencies,
+ Variadic<Async_AnyValueOrTokenType>:$operands);
+
+ let results = (outs Async_TokenType:$token,
+ Variadic<Async_AnyValueType>:$results);
let regions = (region SizedRegion<1>:$body);
- let printer = [{ return ::mlir::async::print(p, *this); }];
- let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+
+ let verifier = [{ return ::verify(*this); }];
}
def Async_YieldOp :
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
- let verifier = [{ return ::mlir::async::verify(*this); }];
+ let verifier = [{ return ::verify(*this); }];
}
#endif // ASYNC_OPS
#include "mlir/Dialect/Async/IR/Async.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Traits.h"
-#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/raw_ostream.h"
-namespace mlir {
-namespace async {
+using namespace mlir;
+using namespace mlir::async;
void AsyncDialect::initialize() {
addOperations<
/// ValueType
//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace async {
namespace detail {
// Storage for `async.value<T>` type, the only member is the wrapped type.
};
} // namespace detail
+} // namespace async
+} // namespace mlir
ValueType ValueType::get(Type valueType) {
return Base::get(valueType.getContext(), valueType);
// Get the underlying value types from async values returned from the
// parent `async.execute` operation.
auto executeOp = op.getParentOfType<ExecuteOp>();
- auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
+ auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
return result.getType().cast<ValueType>().getValueType();
});
/// ExecuteOp
//===----------------------------------------------------------------------===//
+constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
+
static void print(OpAsmPrinter &p, ExecuteOp op) {
- p << "async.execute ";
- p.printRegion(op.body());
- p.printOptionalAttrDict(op.getAttrs());
- p << " : ";
- p.printType(op.done().getType());
- if (!op.values().empty())
- p << ", ";
- llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
- p.printType(result.getType());
- });
+ p << op.getOperationName();
+
+ // [%tokens,...]
+ if (!op.dependencies().empty())
+ p << " [" << op.dependencies() << "]";
+
+ // (%value as %unwrapped: !async.value<!arg.type>, ...)
+ if (!op.operands().empty()) {
+ p << " (";
+ llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
+ p << operand << " as " << op.body().front().getArgument(n++) << ": "
+ << operand.getType();
+ });
+ p << ")";
+ }
+
+ // -> (!async.value<!return.type>, ...)
+ p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
+ p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
+ p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
}
static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
MLIRContext *ctx = result.getContext();
- // Parse asynchronous region.
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
- /*enableNameShadowing=*/false))
+ // Sizes of parsed variadic operands, will be updated below after parsing.
+ int32_t numDependencies = 0;
+ int32_t numOperands = 0;
+
+ auto tokenTy = TokenType::get(ctx);
+
+ // Parse dependency tokens.
+ if (succeeded(parser.parseOptionalLSquare())) {
+ SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
+ if (parser.parseOperandList(tokenArgs) ||
+ parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
+ parser.parseRSquare())
+ return failure();
+
+ numDependencies = tokenArgs.size();
+ }
+
+ // Parse async value operands (%value as %unwrapped : !async.value<!type>).
+ SmallVector<OpAsmParser::OperandType, 4> valueArgs;
+ SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
+ SmallVector<Type, 4> valueTypes;
+ SmallVector<Type, 4> unwrappedTypes;
+
+ if (succeeded(parser.parseOptionalLParen())) {
+ auto argsLoc = parser.getCurrentLocation();
+
+ // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
+ auto parseAsyncValueArg = [&]() -> ParseResult {
+ if (parser.parseOperand(valueArgs.emplace_back()) ||
+ parser.parseKeyword("as") ||
+ parser.parseOperand(unwrappedArgs.emplace_back()) ||
+ parser.parseColonType(valueTypes.emplace_back()))
+ return failure();
+
+ auto valueTy = valueTypes.back().dyn_cast<ValueType>();
+ unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
+
+ return success();
+ };
+
+ // If the next token is `)` skip async value arguments parsing.
+ if (failed(parser.parseOptionalRParen())) {
+ do {
+ if (parseAsyncValueArg())
+ return failure();
+ } while (succeeded(parser.parseOptionalComma()));
+
+ if (parser.parseRParen() ||
+ parser.resolveOperands(valueArgs, valueTypes, argsLoc,
+ result.operands))
+ return failure();
+ }
+
+ numOperands = valueArgs.size();
+ }
+
+ // Add derived `operand_segment_sizes` attribute based on parsed operands.
+ auto operandSegmentSizes = DenseIntElementsAttr::get(
+ VectorType::get({2}, parser.getBuilder().getI32Type()),
+ {numDependencies, numOperands});
+ result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
+
+ // Parse the types of results returned from the async execute op.
+ SmallVector<Type, 4> resultTypes;
+ if (parser.parseOptionalArrowTypeList(resultTypes))
return failure();
+ // Async execute first result is always a completion token.
+ parser.addTypeToList(tokenTy, result.types);
+ parser.addTypesToList(resultTypes, result.types);
+
// Parse operation attributes.
NamedAttrList attrs;
- if (parser.parseOptionalAttrDict(attrs))
+ if (parser.parseOptionalAttrDictWithKeyword(attrs))
return failure();
result.addAttributes(attrs);
- // Parse result types.
- SmallVector<Type, 4> resultTypes;
- if (parser.parseColonTypeList(resultTypes))
- return failure();
-
- // First result type must be an async token type.
- if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
+ // Parse asynchronous region.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
+ /*argTypes=*/{unwrappedTypes},
+ /*enableNameShadowing=*/false))
return failure();
- parser.addTypesToList(resultTypes, result.types);
return success();
}
-} // namespace async
-} // namespace mlir
+static LogicalResult verify(ExecuteOp op) {
+ // Unwrap async.execute value operands types.
+ auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
+ return operand.getType().cast<ValueType>().getValueType();
+ });
+
+ // Verify that unwrapped argument types matches the body region arguments.
+ if (llvm::size(unwrappedTypes) != llvm::size(op.body().getArgumentTypes()))
+ return op.emitOpError("the number of async body region arguments does not "
+ "match the number of execute operation arguments");
+
+ if (!std::equal(unwrappedTypes.begin(), unwrappedTypes.end(),
+ op.body().getArgumentTypes().begin()))
+ return op.emitOpError("async body region argument types do not match the "
+ "execute operation arguments types");
+
+ return success();
+}
#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
-// RUN: mlir-opt %s | FileCheck %s
+// RUN: mlir-opt %s | FileCheck %s
// CHECK-LABEL: @identity_token
-func @identity_token(%arg0 : !async.token) -> !async.token {
+func @identity_token(%arg0: !async.token) -> !async.token {
// CHECK: return %arg0 : !async.token
return %arg0 : !async.token
}
// CHECK-LABEL: @empty_async_execute
func @empty_async_execute() -> !async.token {
- %done = async.execute {
+ // CHECK: async.execute
+ %token = async.execute {
async.yield
- } : !async.token
+ }
- // CHECK: return %done : !async.token
- return %done : !async.token
+ // CHECK: return %token : !async.token
+ return %token : !async.token
}
// CHECK-LABEL: @return_async_value
func @return_async_value() -> !async.value<f32> {
- %done, %values = async.execute {
+ // CHECK: async.execute -> !async.value<f32>
+ %token, %results = async.execute -> !async.value<f32> {
%cst = constant 1.000000e+00 : f32
async.yield %cst : f32
- } : !async.token, !async.value<f32>
+ }
- // CHECK: return %values : !async.value<f32>
- return %values : !async.value<f32>
+ // CHECK: return %results : !async.value<f32>
+ return %results : !async.value<f32>
+}
+
+// CHECK-LABEL: @return_captured_value
+func @return_captured_value() -> !async.token {
+ %cst = constant 1.000000e+00 : f32
+ // CHECK: async.execute -> !async.value<f32>
+ %token, %results = async.execute -> !async.value<f32> {
+ async.yield %cst : f32
+ }
+
+ // CHECK: return %token : !async.token
+ return %token : !async.token
}
// CHECK-LABEL: @return_async_values
func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
- %done, %values:2 = async.execute {
+ %token, %results:2 = async.execute -> (!async.value<f32>, !async.value<f32>) {
%cst1 = constant 1.000000e+00 : f32
%cst2 = constant 2.000000e+00 : f32
async.yield %cst1, %cst2 : f32, f32
- } : !async.token, !async.value<f32>, !async.value<f32>
+ }
+
+ // CHECK: return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
+ return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
+}
+
+// CHECK-LABEL: @async_token_dependencies
+func @async_token_dependencies(%arg0: !async.token) -> !async.token {
+ // CHECK: async.execute [%arg0]
+ %token = async.execute [%arg0] {
+ async.yield
+ }
+
+ // CHECK: return %token : !async.token
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @async_value_operands
+func @async_value_operands(%arg0: !async.value<f32>) -> !async.token {
+ // CHECK: async.execute (%arg0 as %arg1: !async.value<f32>) -> !async.value<f32>
+ %token, %results = async.execute (%arg0 as %arg1: !async.value<f32>) -> !async.value<f32> {
+ async.yield %arg1 : f32
+ }
+
+ // CHECK: return %token : !async.token
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @async_token_and_value_operands
+func @async_token_and_value_operands(%arg0: !async.token, %arg1: !async.value<f32>) -> !async.token {
+ // CHECK: async.execute [%arg0] (%arg1 as %arg2: !async.value<f32>) -> !async.value<f32>
+ %token, %results = async.execute [%arg0] (%arg1 as %arg2: !async.value<f32>) -> !async.value<f32> {
+ async.yield %arg2 : f32
+ }
+
+ // CHECK: return %token : !async.token
+ return %token : !async.token
+}
- // CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
- return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
+// CHECK-LABEL: @empty_tokens_or_values_operands
+func @empty_tokens_or_values_operands() {
+ // CHECK: async.execute {
+ %token0 = async.execute [] () -> () { async.yield }
+ // CHECK: async.execute {
+ %token1 = async.execute () -> () { async.yield }
+ // CHECK: async.execute {
+ %token2 = async.execute -> () { async.yield }
+ // CHECK: async.execute {
+ %token3 = async.execute () { async.yield }
+ // CHECK: async.execute {
+ %token4 = async.execute [] { async.yield }
+ return
}