#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
include "mlir/Dialect/Async/IR/AsyncDialect.td"
include "mlir/Dialect/Async/IR/AsyncTypes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
}];
}
+def Async_RuntimeNumWorkerThreadsOp :
+ Async_Op<"runtime.num_worker_threads",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "returns the number of worker threads in the runtime thread pool";
+ let description = [{
+ The `async.runtime.num_worker_threads` operation returns the number of
+ worker threads available in the async runtime's thread pool.
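+
+ Example (the SSA name is illustrative; the syntax follows the assembly
+ format declared below):
+
+ ```mlir
+ %num_workers = async.runtime.num_worker_threads : index
+ ```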
+ }];
+
+ let results = (outs Index:$result);
+ let assemblyFormat = "attr-dict `:` type($result)";
+}
+
#endif // ASYNC_OPS
Option<"numWorkerThreads", "num-workers",
"int32_t", /*default=*/"8",
- "The number of available workers to execute async operations.">,
+ "The number of available workers to execute async operations. If `-1` "
+ "the value will be retrieved from the runtime.">,
Option<"minTaskSize", "min-task-size",
"int32_t", /*default=*/"1000",
extern "C" void
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume);
+// Returns the current number of available worker threads in the thread pool.
+extern "C" int64_t mlirAsyncRuntimeGetNumWorkerThreads();
+
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
"mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
"mlirAsyncRuntimeAwaitAllInGroupAndExecute";
+static constexpr const char *kGetNumWorkerThreads =
+ "mlirAsyncRuntimGetNumWorkerThreads";
namespace {
/// Async Runtime API function types.
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
}
+ static FunctionType getNumWorkerThreadsFunctionType(MLIRContext *ctx) {
+ return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
+ }
+
// Auxiliary coroutine resume intrinsic wrapper.
static Type resumeFunctionType(MLIRContext *ctx) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
addFuncDecl(kAwaitAllAndExecute,
AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+ addFuncDecl(kGetNumWorkerThreads,
+             AsyncAPI::getNumWorkerThreadsFunctionType(ctx));
}
//===----------------------------------------------------------------------===//
} // namespace
//===----------------------------------------------------------------------===//
+// Convert async.runtime.num_worker_threads to the corresponding runtime API
+// call.
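+//
+// A minimal sketch of the expected lowered IR (a call to the runtime API
+// function declared above, returning an index):
+//
+//   %0 = call @mlirAsyncRuntimeGetNumWorkerThreads() : () -> index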
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeNumWorkerThreadsOpLowering
+ : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Replace with a runtime API function call.
+ rewriter.replaceOpWithNewOp<CallOp>(op, kGetNumWorkerThreads,
+ rewriter.getIndexType());
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
// Async reference counting ops lowering (`async.runtime.add_ref` and
// `async.runtime.drop_ref` to the corresponding API calls).
//===----------------------------------------------------------------------===//
patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
- RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
- RuntimeDropRefOpLowering>(converter, ctx);
+ RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
+ RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
+ ctx);
// Lower async.runtime operations that rely on LLVM type converter to convert
// from async value payload type to the LLVM type.
LINK_LIBS PUBLIC
MLIRDialect
+ MLIRInferTypeOpInterface
MLIRIR
)
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
numUnrollableLoops++;
}
+ Value numWorkerThreadsVal;
+ if (numWorkerThreads >= 0)
+ numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
+ else
+ numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
+
// With large number of threads the value of creating many compute blocks
- // is reduced because the problem typically becomes memory bound. For small
- // number of threads it helps with stragglers.
- float overshardingFactor = numWorkerThreads <= 4 ? 8.0
- : numWorkerThreads <= 8 ? 4.0
- : numWorkerThreads <= 16 ? 2.0
- : numWorkerThreads <= 32 ? 1.0
- : numWorkerThreads <= 64 ? 0.8
- : 0.6;
-
- // Do not overload worker threads with too many compute blocks.
- Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
- std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
+ // is reduced because the problem typically becomes memory bound. For this
+ // reason we scale the number of workers using IR equivalent to the
+ // following logic:
+ // float overshardingFactor = numWorkerThreads <= 4 ? 8.0
+ // : numWorkerThreads <= 8 ? 4.0
+ // : numWorkerThreads <= 16 ? 2.0
+ // : numWorkerThreads <= 32 ? 1.0
+ // : numWorkerThreads <= 64 ? 0.8
+ // : 0.6;
+
+ // Pairs of the exclusive lower bound of a bracket and the scaling factor
+ // to apply to the number of workers when it falls into that bracket.
+ const SmallVector<std::pair<int, float>> overshardingBrackets = {
+ {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
+ const float initialOvershardingFactor = 8.0f;
+
+ Value scalingFactor = b.create<arith::ConstantFloatOp>(
+ llvm::APFloat(initialOvershardingFactor), b.getF32Type());
+ for (const std::pair<int, float> &p : overshardingBrackets) {
+ Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
+ Value inBracket = b.create<arith::CmpIOp>(
+ arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
+ Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
+ llvm::APFloat(p.second), b.getF32Type());
+ scalingFactor =
+ b.create<SelectOp>(inBracket, bracketScalingFactor, scalingFactor);
+ }
+ Value numWorkersIndex =
+ b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
+ Value numWorkersFloat =
+ b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
+ Value scaledNumWorkers =
+ b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
+ Value scaledNumInt =
+ b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
+ Value scaledWorkers =
+ b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
+
+ Value maxComputeBlocks = b.create<arith::MaxSIOp>(
+ b.create<arith::ConstantIndexOp>(1), scaledWorkers);
// Compute parallel block size from the parallel problem size:
// blockSize = min(tripCount,
}
}
+extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
+ return getDefaultAsyncRuntime()->getThreadPool().getThreadCount();
+}
+
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
+ exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
+ &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -async-parallel-for=num-workers=-1 \
+// RUN: | FileCheck %s --dump-input=always
+
+// CHECK-LABEL: @num_worker_threads(
+// CHECK: %[[MEMREF:.*]]: memref<?xf32>
+func @num_worker_threads(%arg0: memref<?xf32>) {
+
+ // CHECK: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32
+ // CHECK: %[[bracketLowerBound4:.*]] = arith.constant 4 : index
+ // CHECK: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32
+ // CHECK: %[[bracketLowerBound8:.*]] = arith.constant 8 : index
+ // CHECK: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32
+ // CHECK: %[[bracketLowerBound16:.*]] = arith.constant 16 : index
+ // CHECK: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32
+ // CHECK: %[[bracketLowerBound32:.*]] = arith.constant 32 : index
+ // CHECK: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32
+ // CHECK: %[[bracketLowerBound64:.*]] = arith.constant 64 : index
+ // CHECK: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32
+ // CHECK: %[[workersIndex:.*]] = async.runtime.num_worker_threads : index
+ // CHECK: %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index
+ // CHECK: %[[scalingFactor4:.*]] = select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32
+ // CHECK: %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index
+ // CHECK: %[[scalingFactor8:.*]] = select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32
+ // CHECK: %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index
+ // CHECK: %[[scalingFactor16:.*]] = select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32
+ // CHECK: %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index
+ // CHECK: %[[scalingFactor32:.*]] = select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32
+ // CHECK: %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index
+ // CHECK: %[[scalingFactor64:.*]] = select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32
+ // CHECK: %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
+ // CHECK: %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
+ // CHECK: %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32
+ // CHECK: %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
+ // CHECK: %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
+
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 100 : index
+ %st = arith.constant 1 : index
+ scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+ %one = arith.constant 1.0 : f32
+ memref.store %one, %arg0[%i] : memref<?xf32>
+ }
+
+ return
+}
includes = ["include"],
deps = [
":ControlFlowInterfacesTdFiles",
+ ":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
],
":ControlFlowInterfaces",
":Dialect",
":IR",
+ ":InferTypeOpInterface",
":SideEffectInterfaces",
":StandardOps",
":Support",