From 552ef9fc094f7c0b0d7bf7a9b4d5da1319a2a67a Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 15 Jun 2020 15:30:21 -0700 Subject: [PATCH] [mlir][DialectConversion] Add overload of addDynamicallyLegalDialect to support lambdas This allows for passing a lambda to addDynamicallyLegalDialect without needing to explicit wrap with Optional. Differential Revision: https://reviews.llvm.org/D81680 --- mlir/include/mlir/Transforms/DialectConversion.h | 6 ++++++ mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 7 +++---- mlir/test/lib/Transforms/TestBufferPlacement.cpp | 4 +--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f9d6671..3e7d503 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -607,6 +607,12 @@ public: if (callback) setLegalityCallback(dialectNames, *callback); } + template + void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) { + SmallVector dialectNames({Args::getDialectNamespace()...}); + setDialectAction(dialectNames, LegalizationAction::Dynamic); + setLegalityCallback(dialectNames, callback); + } /// Register unknown operations as dynamically legal. For operations(and /// dialects) that do not have a set legalization action, treat them as diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 490e670..7df2be9 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -657,10 +657,9 @@ spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { new SPIRVConversionTarget(targetAttr)); SPIRVConversionTarget *targetPtr = target.get(); target->addDynamicallyLegalDialect( - Optional( - // We need to capture the raw pointer here because it is stable: - // target will be destroyed once this function is returned. - [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); })); + // We need to capture the raw pointer here because it is stable: + // target will be destroyed once this function is returned. + [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); return target; } diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp index cbccb7d..0976f71 100644 --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -127,9 +127,7 @@ struct TestBufferPlacementPreparationPass auto isLegalOperation = [&](Operation *op) { return converter.isLegal(op); }; - target.addDynamicallyLegalDialect( - Optional( - isLegalOperation)); + target.addDynamicallyLegalDialect(isLegalOperation); // Mark Standard Return operations illegal as long as one operand is tensor. target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { -- 2.7.4