#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
#define DEBUG_TYPE "convert-async-to-llvm"
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
+static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
return FunctionType::get(ctx, {value}, {i1});
}
+ static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
+ auto i1 = IntegerType::get(ctx, 1);
+ return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
+ }
+
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
+ addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
LogicalResult
matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
- rewriter.replaceOpWithNewOp<CallOp>(
- op, operandType.isa<TokenType>() ? kEmplaceToken : kEmplaceValue,
- TypeRange(), operands);
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kEmplaceToken; })
+ .Case<ValueType>([](Type) { return kEmplaceValue; });
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+
return success();
}
};
LogicalResult
matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
- rewriter.replaceOpWithNewOp<CallOp>(
- op, operandType.isa<TokenType>() ? kSetTokenError : kSetValueError,
- TypeRange(), operands);
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kSetTokenError; })
+ .Case<ValueType>([](Type) { return kSetValueError; });
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+
return success();
}
};
LogicalResult
matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
- rewriter.replaceOpWithNewOp<CallOp>(
- op, operandType.isa<TokenType>() ? kIsTokenError : kIsValueError,
- rewriter.getI1Type(), operands);
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kIsTokenError; })
+ .Case<GroupType>([](Type) { return kIsGroupError; })
+ .Case<ValueType>([](Type) { return kIsValueError; });
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
+ operands);
return success();
}
};
LogicalResult
matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
-
- StringRef apiFuncName;
- if (operandType.isa<TokenType>())
- apiFuncName = kAwaitToken;
- else if (operandType.isa<ValueType>())
- apiFuncName = kAwaitValue;
- else if (operandType.isa<GroupType>())
- apiFuncName = kAwaitGroup;
- else
- return rewriter.notifyMatchFailure(op, "unsupported async type");
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kAwaitToken; })
+ .Case<ValueType>([](Type) { return kAwaitValue; })
+ .Case<GroupType>([](Type) { return kAwaitGroup; });
rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
rewriter.eraseOp(op);
LogicalResult
matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
-
- StringRef apiFuncName;
- if (operandType.isa<TokenType>())
- apiFuncName = kAwaitTokenAndExecute;
- else if (operandType.isa<ValueType>())
- apiFuncName = kAwaitValueAndExecute;
- else if (operandType.isa<GroupType>())
- apiFuncName = kAwaitAllAndExecute;
- else
- return rewriter.notifyMatchFailure(op, "unsupported async type");
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
+ .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
+ .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
- // TODO: Async groups do not yet support runtime errors.
- if (!std::is_same<AwaitAllOp, AwaitType>::value) {
- // Split the resume block into error checking and continuation.
- Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
-
- // Check if the awaited value is in the error state.
- builder.setInsertionPointToStart(resume);
- auto isError = builder.create<RuntimeIsErrorOp>(
- loc, rewriter.getI1Type(), operand);
- builder.create<CondBranchOp>(isError,
- /*trueDest=*/setupSetErrorBlock(coro),
- /*trueArgs=*/ArrayRef<Value>(),
- /*falseDest=*/continuation,
- /*falseArgs=*/ArrayRef<Value>());
-
- // Make sure that replacement value will be constructed in the
- // continuation block.
- rewriter.setInsertionPointToStart(continuation);
- }
+ // Split the resume block into error checking and continuation.
+ Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
+
+ // Check if the awaited value is in the error state.
+ builder.setInsertionPointToStart(resume);
+ auto isError =
+ builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
+ builder.create<CondBranchOp>(isError,
+ /*trueDest=*/setupSetErrorBlock(coro),
+ /*trueArgs=*/ArrayRef<Value>(),
+ /*falseDest=*/continuation,
+ /*falseArgs=*/ArrayRef<Value>());
+
+ // Make sure that replacement value will be constructed in the
+ // continuation block.
+ rewriter.setInsertionPointToStart(continuation);
}
// Erase or replace the await operation with the new value.
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
AsyncGroup(AsyncRuntime *runtime)
- : RefCounted(runtime), pendingTokens(0), rank(0) {}
+ : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
std::atomic<int> pendingTokens;
+ std::atomic<int> numErrors;
std::atomic<int> rank;
// Pending awaiters are guarded by a mutex.
int rank = group->rank.fetch_add(1);
group->pendingTokens.fetch_add(1);
- auto onTokenReady = [group]() {
+ auto onTokenReady = [group, token]() {
+ // Increment the number of errors in the group.
+ if (State(token->state).isError())
+ group->numErrors.fetch_add(1);
+
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
return State(value->state).isError();
}
+extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
+ return group->numErrors.load() > 0;
+}
+
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
std::unique_lock<std::mutex> lock(token->mu);
if (!State(token->state).isAvailableOrError())
&mlir::runtime::mlirAsyncRuntimeIsTokenError);
exportSymbol("mlirAsyncRuntimeIsValueError",
&mlir::runtime::mlirAsyncRuntimeIsValueError);
+ exportSymbol("mlirAsyncRuntimeIsGroupError",
+ &mlir::runtime::mlirAsyncRuntimeIsGroupError);
exportSymbol("mlirAsyncRuntimeAwaitToken",
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
exportSymbol("mlirAsyncRuntimeAwaitValue",