state). All dependencies must be made explicit with async execute arguments
(`async.token` or `async.value`).
- Example:
-
```mlir
- %0 = async.execute {
- "compute0"(...)
- async.yield
- } : !async.token
+ %done, %values = async.execute {
+ %0 = "compute0"(...) : !some.type
+ async.yield %1 : f32
+ } : !async.token, !async.value<!some.type>
- %1 = "compute1"(...)
+ %1 = "compute1"(...) : !some.type
```
}];
// TODO: Take async.tokens/async.values as arguments.
let arguments = (ins );
- let results = (outs Async_TokenType:$done);
+ let results = (outs Async_TokenType:$done,
+ Variadic<Async_AnyValueType>:$values);
let regions = (region SizedRegion<1>:$body);
- let assemblyFormat = "$body attr-dict `:` type($done)";
+ let printer = [{ return ::mlir::async::print(p, *this); }];
+ let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
}
def Async_YieldOp :
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+ let verifier = [{ return ::mlir::async::verify(*this); }];
}
#endif // ASYNC_OPS
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
-using namespace mlir;
-using namespace mlir::async;
+namespace mlir {
+namespace async {
void AsyncDialect::initialize() {
addOperations<
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
>();
addTypes<TokenType>();
+ addTypes<ValueType>();
}
/// Parse a type registered to this dialect.
if (keyword == "token")
return TokenType::get(getContext());
+ if (keyword == "value") {
+ Type ty;
+ if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
+ parser.emitError(parser.getNameLoc(), "failed to parse async value type");
+ return Type();
+ }
+ return ValueType::get(ty);
+ }
+
parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
return Type();
}
/// Print a type registered to this dialect.
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
- .Case<TokenType>([&](Type) { os << "token"; })
+ .Case<TokenType>([&](TokenType) { os << "token"; })
+ .Case<ValueType>([&](ValueType valueTy) {
+ os << "value<";
+ os.printType(valueTy.getValueType());
+ os << '>';
+ })
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
}
+//===----------------------------------------------------------------------===//
+/// ValueType
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+
+// Storage for `async.value<T>` type, the only member is the wrapped type.
+struct ValueTypeStorage : public TypeStorage {
+ ValueTypeStorage(Type valueType) : valueType(valueType) {}
+
+ /// The hash key used for uniquing.
+ using KeyTy = Type;
+ bool operator==(const KeyTy &key) const { return key == valueType; }
+
+ /// Construction.
+ static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
+ Type valueType) {
+ return new (allocator.allocate<ValueTypeStorage>())
+ ValueTypeStorage(valueType);
+ }
+
+ Type valueType;
+};
+
+} // namespace detail
+
+ValueType ValueType::get(Type valueType) {
+ return Base::get(valueType.getContext(), valueType);
+}
+
+Type ValueType::getValueType() { return getImpl()->valueType; }
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(YieldOp op) {
+ // Get the underlying value types from async values returned from the
+ // parent `async.execute` operation.
+ auto executeOp = op.getParentOfType<ExecuteOp>();
+ auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
+ return result.getType().cast<ValueType>().getValueType();
+ });
+
+ if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin()))
+ return op.emitOpError("Operand types do not match the types returned from "
+ "the parent ExecuteOp");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+/// ExecuteOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, ExecuteOp op) {
+ p << "async.execute ";
+ p.printRegion(op.body());
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : ";
+ p.printType(op.done().getType());
+ if (!op.values().empty())
+ p << ", ";
+ llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
+ p.printType(result.getType());
+ });
+}
+
+static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
+ MLIRContext *ctx = result.getContext();
+
+ // Parse asynchronous region.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
+ /*enableNameShadowing=*/false))
+ return failure();
+
+ // Parse operation attributes.
+ NamedAttrList attrs;
+ if (parser.parseOptionalAttrDict(attrs))
+ return failure();
+ result.addAttributes(attrs);
+
+ // Parse result types.
+ SmallVector<Type, 4> resultTypes;
+ if (parser.parseColonTypeList(resultTypes))
+ return failure();
+
+ // First result type must be an async token type.
+ if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
+ return failure();
+ parser.addTypesToList(resultTypes, result.types);
+
+ return success();
+}
+
+} // namespace async
+} // namespace mlir
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
// RUN: mlir-opt %s | FileCheck %s
-// CHECK-LABEL: @identity
-func @identity(%arg0 : !async.token) -> !async.token {
+// CHECK-LABEL: @identity_token
+func @identity_token(%arg0 : !async.token) -> !async.token {
// CHECK: return %arg0 : !async.token
return %arg0 : !async.token
}
+// CHECK-LABEL: @identity_value
+func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
+ // CHECK: return %arg0 : !async.value<f32>
+ return %arg0 : !async.value<f32>
+}
+
// CHECK-LABEL: @empty_async_execute
func @empty_async_execute() -> !async.token {
- %0 = async.execute {
+ %done = async.execute {
async.yield
} : !async.token
- return %0 : !async.token
+ // CHECK: return %done : !async.token
+ return %done : !async.token
+}
+
+// CHECK-LABEL: @return_async_value
+func @return_async_value() -> !async.value<f32> {
+ %done, %values = async.execute {
+ %cst = constant 1.000000e+00 : f32
+ async.yield %cst : f32
+ } : !async.token, !async.value<f32>
+
+ // CHECK: return %values : !async.value<f32>
+ return %values : !async.value<f32>
+}
+
+// CHECK-LABEL: @return_async_values
+func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
+ %done, %values:2 = async.execute {
+ %cst1 = constant 1.000000e+00 : f32
+ %cst2 = constant 2.000000e+00 : f32
+ async.yield %cst1, %cst2 : f32, f32
+ } : !async.token, !async.value<f32>, !async.value<f32>
+
+ // CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
+ return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
}