From 74278dd01e5713920a35f1c3e0731e535667c19a Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 17 Dec 2019 14:57:07 -0800 Subject: [PATCH] NFC: Use TypeSwitch to simplify existing code. PiperOrigin-RevId: 286066371 --- mlir/examples/toy/Ch1/parser/AST.cpp | 24 ++++---- mlir/examples/toy/Ch2/parser/AST.cpp | 24 ++++---- mlir/examples/toy/Ch3/parser/AST.cpp | 24 ++++---- mlir/examples/toy/Ch4/parser/AST.cpp | 24 ++++---- mlir/examples/toy/Ch5/parser/AST.cpp | 24 ++++---- mlir/examples/toy/Ch6/parser/AST.cpp | 24 ++++---- mlir/examples/toy/Ch7/parser/AST.cpp | 25 ++++---- mlir/lib/Analysis/MemRefBoundCheck.cpp | 9 ++- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 33 +++++----- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 70 ++++++++-------------- mlir/lib/Transforms/Utils/Utils.cpp | 12 ++-- 11 files changed, 117 insertions(+), 176 deletions(-) diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp index 0c7735ec..32221d2 100644 --- a/mlir/examples/toy/Ch1/parser/AST.cpp +++ b/mlir/examples/toy/Ch1/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -84,20 +85,15 @@ template static std::string loc(T *node) { /// Dispatch to a generic expressions to the appropriate subclass using RTTI void ASTDumper::dump(ExprAST *expr) { -#define dispatch(CLASS) \ - if (CLASS *node = llvm::dyn_cast(expr)) \ - return dump(node); - dispatch(VarDeclExprAST); - dispatch(LiteralExprAST); - dispatch(NumberExprAST); - dispatch(VariableExprAST); - dispatch(ReturnExprAST); - dispatch(BinaryExprAST); - dispatch(CallExprAST); - dispatch(PrintExprAST); - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); } /// A variable declaration is printing the variable name, the type, and then diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp index 0c7735ec..32221d2 100644 --- a/mlir/examples/toy/Ch2/parser/AST.cpp +++ b/mlir/examples/toy/Ch2/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -84,20 +85,15 @@ template static std::string loc(T *node) { /// Dispatch to a generic expressions to the appropriate subclass using RTTI void ASTDumper::dump(ExprAST *expr) { -#define dispatch(CLASS) \ - if (CLASS *node = llvm::dyn_cast(expr)) \ - return dump(node); - dispatch(VarDeclExprAST); - dispatch(LiteralExprAST); - dispatch(NumberExprAST); - dispatch(VariableExprAST); - dispatch(ReturnExprAST); - dispatch(BinaryExprAST); - dispatch(CallExprAST); - dispatch(PrintExprAST); - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); } /// A variable declaration is printing the variable name, the type, and then diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp index 0c7735ec..32221d2 100644 --- a/mlir/examples/toy/Ch3/parser/AST.cpp +++ b/mlir/examples/toy/Ch3/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -84,20 +85,15 @@ template static std::string loc(T *node) { /// Dispatch to a generic expressions to the appropriate subclass using RTTI void ASTDumper::dump(ExprAST *expr) { -#define dispatch(CLASS) \ - if (CLASS *node = llvm::dyn_cast(expr)) \ - return dump(node); - dispatch(VarDeclExprAST); - dispatch(LiteralExprAST); - dispatch(NumberExprAST); - dispatch(VariableExprAST); - dispatch(ReturnExprAST); - dispatch(BinaryExprAST); - dispatch(CallExprAST); - dispatch(PrintExprAST); - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); } /// A variable declaration is printing the variable name, the type, and then diff --git a/mlir/examples/toy/Ch4/parser/AST.cpp b/mlir/examples/toy/Ch4/parser/AST.cpp index 0c7735ec..32221d2 100644 --- a/mlir/examples/toy/Ch4/parser/AST.cpp +++ b/mlir/examples/toy/Ch4/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -84,20 +85,15 @@ template static std::string loc(T *node) { /// Dispatch to a generic expressions to the appropriate subclass using RTTI void ASTDumper::dump(ExprAST *expr) { -#define dispatch(CLASS) \ - if (CLASS *node = llvm::dyn_cast(expr)) \ - return dump(node); - dispatch(VarDeclExprAST); - dispatch(LiteralExprAST); - dispatch(NumberExprAST); - dispatch(VariableExprAST); - dispatch(ReturnExprAST); - dispatch(BinaryExprAST); - dispatch(CallExprAST); - dispatch(PrintExprAST); - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); } /// A variable declaration is printing the variable name, the type, and then diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp index 0c7735ec..32221d2 100644 --- a/mlir/examples/toy/Ch5/parser/AST.cpp +++ b/mlir/examples/toy/Ch5/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -84,20 +85,15 @@ template static std::string loc(T *node) { /// Dispatch to a generic expressions to the appropriate subclass using RTTI void ASTDumper::dump(ExprAST *expr) { -#define dispatch(CLASS) \ - if (CLASS *node = llvm::dyn_cast(expr)) \ - return dump(node); - dispatch(VarDeclExprAST); - dispatch(LiteralExprAST); - dispatch(NumberExprAST); - dispatch(VariableExprAST); - dispatch(ReturnExprAST); - dispatch(BinaryExprAST); - dispatch(CallExprAST); - dispatch(PrintExprAST); - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); } /// A variable declaration is printing the variable name, the type, and then diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp index 0c7735ec..32221d2 100644 --- a/mlir/examples/toy/Ch6/parser/AST.cpp +++ b/mlir/examples/toy/Ch6/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -84,20 +85,15 @@ template static std::string loc(T *node) { /// Dispatch to a generic expressions to the appropriate subclass using RTTI void ASTDumper::dump(ExprAST *expr) { -#define dispatch(CLASS) \ - if (CLASS *node = llvm::dyn_cast(expr)) \ - return dump(node); - dispatch(VarDeclExprAST); - dispatch(LiteralExprAST); - dispatch(NumberExprAST); - dispatch(VariableExprAST); - dispatch(ReturnExprAST); - dispatch(BinaryExprAST); - dispatch(CallExprAST); - dispatch(PrintExprAST); - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); } /// A variable declaration is printing the variable name, the type, and then diff --git a/mlir/examples/toy/Ch7/parser/AST.cpp b/mlir/examples/toy/Ch7/parser/AST.cpp index 7405629..6ade418 100644 --- a/mlir/examples/toy/Ch7/parser/AST.cpp +++ b/mlir/examples/toy/Ch7/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -86,21 +87,15 @@ template static std::string loc(T *node) { /// Dispatch to a generic expressions to the appropriate subclass using RTTI void ASTDumper::dump(ExprAST *expr) { -#define dispatch(CLASS) \ - if (CLASS *node = llvm::dyn_cast(expr)) \ - return dump(node); - dispatch(VarDeclExprAST); - dispatch(LiteralExprAST); - dispatch(StructLiteralExprAST); - dispatch(NumberExprAST); - dispatch(VariableExprAST); - dispatch(ReturnExprAST); - dispatch(BinaryExprAST); - dispatch(CallExprAST); - dispatch(PrintExprAST); - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; + mlir::TypeSwitch(expr) + .Case([&](auto *node) { dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); } /// A variable declaration is printing the variable name, the type, and then diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 52379c0..4696ce6 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" @@ -49,11 +50,9 @@ std::unique_ptr> mlir::createMemRefBoundCheckPass() { void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Operation *opInst) { - if (auto loadOp = dyn_cast(opInst)) { - boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = dyn_cast(opInst)) { - boundCheckLoadOrStoreOp(storeOp); - } + TypeSwitch(opInst).Case( + [](auto op) { boundCheckLoadOrStoreOp(op); }); + // TODO(bondhugula): do this for DMA ops as well. }); } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 51cdd72..5d6a92f 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -21,6 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -232,25 +233,19 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) { } // Dispatch based on the actual type. Return null type on error. -Type LLVMTypeConverter::convertStandardType(Type type) { - if (auto funcType = type.dyn_cast()) - return convertFunctionType(funcType); - if (auto intType = type.dyn_cast()) - return convertIntegerType(intType); - if (auto floatType = type.dyn_cast()) - return convertFloatType(floatType); - if (auto indexType = type.dyn_cast()) - return convertIndexType(indexType); - if (auto memRefType = type.dyn_cast()) - return convertMemRefType(memRefType); - if (auto memRefType = type.dyn_cast()) - return convertUnrankedMemRefType(memRefType); - if (auto vectorType = type.dyn_cast()) - return convertVectorType(vectorType); - if (auto llvmType = type.dyn_cast()) - return llvmType; - - return {}; +Type LLVMTypeConverter::convertStandardType(Type t) { + return TypeSwitch(t) + .Case([&](FloatType type) { return convertFloatType(type); }) + .Case([&](FunctionType type) { return convertFunctionType(type); }) + .Case([&](IndexType type) { return convertIndexType(type); }) + .Case([&](IntegerType type) { return convertIntegerType(type); }) + .Case([&](MemRefType type) { return convertMemRefType(type); }) + .Case([&](UnrankedMemRefType type) { + return convertUnrankedMemRefType(type); + }) + .Case([&](VectorType type) { return convertVectorType(type); }) + .Case([](LLVM::LLVMType type) { return type; }) + .Default([](Type) { return Type(); }); } LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 2cb75de..f7591bf 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SPIRV/Serialization.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" @@ -1634,54 +1635,33 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { return success(); } -LogicalResult Serializer::processOperation(Operation *op) { - LLVM_DEBUG(llvm::dbgs() << "[op] '" << op->getName() << "'\n"); +LogicalResult Serializer::processOperation(Operation *opInst) { + LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); // First dispatch the ops that do not directly mirror an instruction from // the SPIR-V spec. - if (auto addressOfOp = dyn_cast(op)) { - return processAddressOfOp(addressOfOp); - } - if (auto branchOp = dyn_cast(op)) { - return processBranchOp(branchOp); - } - if (auto condBranchOp = dyn_cast(op)) { - return processBranchConditionalOp(condBranchOp); - } - if (auto constOp = dyn_cast(op)) { - return processConstantOp(constOp); - } - if (auto fnOp = dyn_cast(op)) { - return processFuncOp(fnOp); - } - if (auto varOp = dyn_cast(op)) { - return processVariableOp(varOp); - } - if (auto varOp = dyn_cast(op)) { - return processGlobalVariableOp(varOp); - } - if (auto selectionOp = dyn_cast(op)) { - return processSelectionOp(selectionOp); - } - if (auto loopOp = dyn_cast(op)) { - return processLoopOp(loopOp); - } - if (isa(op)) { - return success(); - } - if (auto refOpOp = dyn_cast(op)) { - return processReferenceOfOp(refOpOp); - } - if (auto specConstOp = dyn_cast(op)) { - return processSpecConstantOp(specConstOp); - } - if (auto undefOp = dyn_cast(op)) { - return processUndefOp(undefOp); - } - - // Then handle all the ops that directly mirror SPIR-V instructions with - // auto-generated methods. - return dispatchToAutogenSerialization(op); + return TypeSwitch(opInst) + .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) + .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) + .Case([&](spirv::BranchConditionalOp op) { + return processBranchConditionalOp(op); + }) + .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) + .Case([&](FuncOp op) { return processFuncOp(op); }) + .Case([&](spirv::GlobalVariableOp op) { + return processGlobalVariableOp(op); + }) + .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) + .Case([&](spirv::ModuleEndOp) { return success(); }) + .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) + .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) + .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) + .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) + .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) + + // Then handle all the ops that directly mirror SPIR-V instructions with + // auto-generated methods. + .Default([&](auto *op) { return dispatchToAutogenSerialization(op); }); } namespace { diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 190b6c3..79a6d7a 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/Utils.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" @@ -47,14 +48,9 @@ static bool isMemRefDereferencingOp(Operation &op) { /// Return the AffineMapAttr associated with memory 'op' on 'memref'. static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { - if (auto loadOp = dyn_cast(op)) - return loadOp.getAffineMapAttrForMemRef(memref); - else if (auto storeOp = dyn_cast(op)) - return storeOp.getAffineMapAttrForMemRef(memref); - else if (auto dmaStart = dyn_cast(op)) - return dmaStart.getAffineMapAttrForMemRef(memref); - assert(isa(op)); - return cast(op).getAffineMapAttrForMemRef(memref); + return TypeSwitch(op) + .Case( + [=](auto op) { return op.getAffineMapAttrForMemRef(memref); }); } // Perform the replacement in `op`. -- 2.7.4