From 54cd6a7e97a226738e2c85b86559918dd9e3cd5d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 1 Jul 2019 10:29:09 -0700 Subject: [PATCH] NFC: Refactor Function to be value typed. Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows for Function to become value typed, which will greatly simplify the transition of Function to FuncOp(given that FuncOp is also value typed). PiperOrigin-RevId: 255983022 --- mlir/bindings/python/pybind.cpp | 35 +-- .../Linalg/Linalg1/include/linalg1/Common.h | 24 +-- mlir/examples/Linalg/Linalg2/Example.cpp | 20 +- mlir/examples/Linalg/Linalg3/Conversion.cpp | 22 +- mlir/examples/Linalg/Linalg3/Example.cpp | 32 +-- mlir/examples/Linalg/Linalg3/Execution.cpp | 22 +- .../Linalg/Linalg3/include/linalg3/Transforms.h | 6 +- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 2 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 12 +- mlir/examples/Linalg/Linalg4/Example.cpp | 40 ++-- .../Linalg/Linalg4/include/linalg4/Transforms.h | 4 +- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 9 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 54 ++--- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 22 +- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 54 ++--- mlir/include/mlir/Analysis/Dominance.h | 4 +- mlir/include/mlir/Analysis/NestedMatcher.h | 4 +- mlir/include/mlir/ExecutionEngine/MemRefUtils.h | 2 +- mlir/include/mlir/GPU/GPUDialect.h | 6 +- mlir/include/mlir/IR/Attributes.h | 3 +- mlir/include/mlir/IR/Block.h | 2 +- mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Dialect.h | 10 +- mlir/include/mlir/IR/Function.h | 234 ++++++++++++++------- mlir/include/mlir/IR/Module.h | 66 ++++-- mlir/include/mlir/IR/Operation.h | 2 +- mlir/include/mlir/IR/PatternMatch.h | 2 +- mlir/include/mlir/IR/Region.h | 11 +- mlir/include/mlir/IR/SymbolTable.h | 12 +- mlir/include/mlir/IR/Value.h | 4 +- mlir/include/mlir/LLVMIR/LLVMDialect.h | 2 +- mlir/include/mlir/Pass/AnalysisManager.h | 18 +- mlir/include/mlir/Pass/Pass.h | 17 +- mlir/include/mlir/Pass/PassInstrumentation.h | 10 +- mlir/include/mlir/StandardOps/Ops.td | 4 +- mlir/include/mlir/Transforms/DialectConversion.h | 4 +- mlir/include/mlir/Transforms/LowerAffine.h | 2 +- mlir/include/mlir/Transforms/ViewFunctionGraph.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 9 +- mlir/lib/Analysis/OpStats.cpp | 2 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 2 +- mlir/lib/Analysis/Verifier.cpp | 14 +- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 6 +- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 106 +++++----- .../GPUToCUDA/GenerateCubinAccessors.cpp | 26 +-- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 18 +- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 4 +- .../Dialect/QuantOps/Transforms/ConvertConst.cpp | 2 +- .../QuantOps/Transforms/ConvertSimQuant.cpp | 2 +- mlir/lib/ExecutionEngine/MemRefUtils.cpp | 10 +- mlir/lib/GPU/IR/GPUDialect.cpp | 18 +- mlir/lib/GPU/Transforms/KernelOutlining.cpp | 28 +-- mlir/lib/IR/AsmPrinter.cpp | 50 ++--- mlir/lib/IR/Attributes.cpp | 5 - mlir/lib/IR/Block.cpp | 2 +- mlir/lib/IR/Builders.cpp | 4 +- mlir/lib/IR/Dialect.cpp | 15 ++ mlir/lib/IR/Function.cpp | 59 +++--- mlir/lib/IR/Operation.cpp | 9 +- mlir/lib/IR/Region.cpp | 10 +- mlir/lib/IR/SymbolTable.cpp | 22 +- mlir/lib/IR/Value.cpp | 8 +- mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 4 +- mlir/lib/Linalg/Transforms/Fusion.cpp | 2 +- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 69 +++--- mlir/lib/Linalg/Transforms/LowerToLoops.cpp | 3 +- mlir/lib/Linalg/Transforms/Tiling.cpp | 2 +- mlir/lib/Parser/Parser.cpp | 28 +-- mlir/lib/Pass/IRPrinting.cpp | 12 +- mlir/lib/Pass/Pass.cpp | 23 +- mlir/lib/Pass/PassDetail.h | 2 +- .../Transforms/AddDefaultStatsTestPass.cpp | 2 +- .../Transforms/InferQuantizedTypesPass.cpp | 2 +- .../Transforms/RemoveInstrumentationPass.cpp | 2 +- mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp | 6 +- mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp | 2 +- .../SPIRV/Transforms/StdOpsToSPIRVConversion.cpp | 2 +- mlir/lib/StandardOps/Ops.cpp | 12 +- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 31 +-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 4 +- mlir/lib/Transforms/Canonicalizer.cpp | 2 +- mlir/lib/Transforms/DialectConversion.cpp | 54 ++--- mlir/lib/Transforms/DmaGeneration.cpp | 10 +- mlir/lib/Transforms/LoopFusion.cpp | 12 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 8 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 12 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 4 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 4 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 4 +- mlir/test/EDSC/builder-api-test.cpp | 150 +++++++------ .../test/lib/Transforms/TestVectorizationUtils.cpp | 16 +- mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp | 18 +- mlir/unittests/Pass/AnalysisManagerTest.cpp | 20 +- 103 files changed, 986 insertions(+), 874 deletions(-) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 222ef52..cdf4a7f 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/IR/Function.h" #include "llvm/IR/Module.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" @@ -110,13 +111,14 @@ struct PythonValueHandle { struct PythonFunction { PythonFunction() : function{nullptr} {} PythonFunction(mlir_func_t f) : function{f} {} - PythonFunction(mlir::Function *f) : function{f} {} + PythonFunction(mlir::Function f) + : function(const_cast(f.getAsOpaquePointer())) {} operator mlir_func_t() { return function; } std::string str() { - mlir::Function *f = reinterpret_cast(function); + mlir::Function f = mlir::Function::getFromOpaquePointer(function); std::string res; llvm::raw_string_ostream os(res); - f->print(os); + f.print(os); return res; } @@ -124,18 +126,18 @@ struct PythonFunction { // declaration, add the entry block, transforming the declaration into a // definition. Return true if the block was added, false otherwise. bool define() { - auto *f = reinterpret_cast(function); - if (!f->getBlocks().empty()) + auto f = mlir::Function::getFromOpaquePointer(function); + if (!f.getBlocks().empty()) return false; - f->addEntryBlock(); + f.addEntryBlock(); return true; } PythonValueHandle arg(unsigned index) { - Function *f = static_cast(function); - assert(index < f->getNumArguments() && "argument index out of bounds"); - return PythonValueHandle(ValueHandle(f->getArgument(index))); + auto f = mlir::Function::getFromOpaquePointer(function); + assert(index < f.getNumArguments() && "argument index out of bounds"); + return PythonValueHandle(ValueHandle(f.getArgument(index))); } mlir_func_t function; @@ -250,10 +252,9 @@ struct PythonFunctionContext { PythonFunction enter() { assert(function.function && "function is not set up"); - auto *mlirFunc = static_cast(function.function); - contextBuilder.emplace(mlirFunc->getBody()); - context = - new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc()); + auto mlirFunc = mlir::Function::getFromOpaquePointer(function.function); + contextBuilder.emplace(mlirFunc.getBody()); + context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc()); return function; } @@ -594,7 +595,7 @@ PythonMLIRModule::declareFunction(const std::string &name, } // Create the function itself. - auto *func = new mlir::Function( + auto func = mlir::Function::create( UnknownLoc::get(&mlirContext), name, mlir::Type::getFromOpaquePointer(funcType).cast(), attrs, inputAttrs); @@ -652,9 +653,9 @@ PYBIND11_MODULE(pybind, m) { return ValueHandle::create(value, floatType); }); m.def("constant_function", [](PythonFunction func) -> PythonValueHandle { - auto *function = reinterpret_cast(func.function); - auto attr = FunctionAttr::get(function); - return ValueHandle::create(function->getType(), attr); + auto function = Function::getFromOpaquePointer(func.function); + auto attr = FunctionAttr::get(function.getName(), function.getContext()); + return ValueHandle::create(function.getType(), attr); }); m.def("appendTo", [](const PythonBlockHandle &handle) { return PythonBlockAppender(handle); diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index ddd6df9..1f129c6 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -57,15 +57,15 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, } /// A basic function builder -inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name, - llvm::ArrayRef types, - llvm::ArrayRef resultTypes) { +inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name, + llvm::ArrayRef types, + llvm::ArrayRef resultTypes) { auto *context = module.getContext(); - auto *function = new mlir::Function( + auto function = mlir::Function::create( mlir::UnknownLoc::get(context), name, mlir::FunctionType::get({types}, resultTypes, context)); - function->addEntryBlock(); - module.getFunctions().push_back(function); + function.addEntryBlock(); + module.push_back(function); return function; } @@ -83,19 +83,19 @@ inline std::unique_ptr cleanupPassManager() { /// llvm::outs() for FileCheck'ing. /// If an error occurs, dump to llvm::errs() and do not print to llvm::outs() /// which will make the associated FileCheck test fail. -inline void cleanupAndPrintFunction(mlir::Function *f) { +inline void cleanupAndPrintFunction(mlir::Function f) { bool printToOuts = true; - auto check = [f, &printToOuts](mlir::LogicalResult result) { + auto check = [&f, &printToOuts](mlir::LogicalResult result) { if (failed(result)) { - f->emitError("Verification and cleanup passes failed"); + f.emitError("Verification and cleanup passes failed"); printToOuts = false; } }; auto pm = cleanupPassManager(); - check(f->getModule()->verify()); - check(pm->run(f->getModule())); + check(f.getModule()->verify()); + check(pm->run(f.getModule())); if (printToOuts) - f->print(llvm::outs()); + f.print(llvm::outs()); } /// Helper class to sugar building loop nests from indexings that appear in diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index a415dae..9534711 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -36,14 +36,14 @@ TEST_FUNC(linalg_ops) { MLIRContext context; Module module(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function *f = + mlir::Function f = makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)), + ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), @@ -75,14 +75,14 @@ TEST_FUNC(linalg_ops_folded_slices) { MLIRContext context; Module module(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices", - {indexType, indexType, indexType}, {}); + mlir::Function f = makeFunction(module, "linalg_ops_folded_slices", + {indexType, indexType, indexType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)), + ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), @@ -104,7 +104,7 @@ TEST_FUNC(linalg_ops_folded_slices) { // CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view // clang-format on - f->walk([](SliceOp slice) { + f.walk([](SliceOp slice) { auto *sliceResult = slice.getResult(); auto viewOp = emitAndReturnFullyComposedView(sliceResult); sliceResult->replaceAllUsesWith(viewOp.getResult()); diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp index 37d1b51..23d1cfe 100644 --- a/mlir/examples/Linalg/Linalg3/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -37,26 +37,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -67,7 +67,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(foo) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(module); diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index f02aef9..8b04344 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -34,26 +34,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + mlir::OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -64,7 +64,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_as_matvec) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); // clang-format off @@ -82,7 +82,7 @@ TEST_FUNC(matmul_as_matvec) { TEST_FUNC(matmul_as_dot) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); lowerToFinerGrainedTensorContraction(f); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); @@ -103,7 +103,7 @@ TEST_FUNC(matmul_as_dot) { TEST_FUNC(matmul_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); composeSliceOps(f); // clang-format off @@ -135,7 +135,7 @@ TEST_FUNC(matmul_as_loops) { TEST_FUNC(matmul_as_matvec_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops"); lowerToFinerGrainedTensorContraction(f); lowerToLoops(f); @@ -166,14 +166,14 @@ TEST_FUNC(matmul_as_matvec_as_loops) { TEST_FUNC(matmul_as_matvec_as_affine) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); lowerToLoops(f); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); - if (succeeded(pm.run(f->getModule()))) + if (succeeded(pm.run(f.getModule()))) cleanupAndPrintFunction(f); // clang-format off diff --git a/mlir/examples/Linalg/Linalg3/Execution.cpp b/mlir/examples/Linalg/Linalg3/Execution.cpp index 00d571c..94b233a 100644 --- a/mlir/examples/Linalg/Linalg3/Execution.cpp +++ b/mlir/examples/Linalg/Linalg3/Execution.cpp @@ -37,26 +37,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + mlir::OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -110,7 +110,7 @@ TEST_FUNC(execution) { // dialect through partial conversions. MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(module); diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index 9af528e..6c0aec0 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -55,11 +55,11 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps, /// Traverses `f` and rewrites linalg.slice, and the operations it depends on, /// to only use linalg.view operations. -void composeSliceOps(mlir::Function *f); +void composeSliceOps(mlir::Function f); /// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec) /// as linalg.matvec(resp. linalg.dot). -void lowerToFinerGrainedTensorContraction(mlir::Function *f); +void lowerToFinerGrainedTensorContraction(mlir::Function f); /// Operation-wise writing of linalg operations to loop form. /// It is the caller's responsibility to erase the `op` if necessary. @@ -69,7 +69,7 @@ llvm::Optional> writeAsLoops(mlir::Operation *op); /// Traverses `f` and rewrites linalg operations in loop form. -void lowerToLoops(mlir::Function *f); +void lowerToLoops(mlir::Function f); /// Creates a pass that rewrites linalg.load and linalg.store to affine.load and /// affine.store operations. diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 7b559bf..96b0f37 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -148,7 +148,7 @@ static void populateLinalg3ToLLVMConversionPatterns( void linalg::convertLinalg3ToLLVM(Module &module) { // Remove affine constructs. - for (auto &func : module) { + for (auto func : module) { auto rr = lowerAffineConstructs(func); (void)rr; assert(succeeded(rr) && "affine loop lowering failed"); diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index d5c8641..7b9e5ff 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -35,8 +35,8 @@ using namespace mlir::edsc::intrinsics; using namespace linalg; using namespace linalg::intrinsics; -void linalg::composeSliceOps(mlir::Function *f) { - f->walk([](SliceOp sliceOp) { +void linalg::composeSliceOps(mlir::Function f) { + f.walk([](SliceOp sliceOp) { auto *sliceResult = sliceOp.getResult(); auto viewOp = emitAndReturnFullyComposedView(sliceResult); sliceResult->replaceAllUsesWith(viewOp.getResult()); @@ -44,8 +44,8 @@ void linalg::composeSliceOps(mlir::Function *f) { }); } -void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { - f->walk([](Operation *op) { +void linalg::lowerToFinerGrainedTensorContraction(mlir::Function f) { + f.walk([](Operation *op) { if (auto matmulOp = dyn_cast(op)) { matmulOp.writeAsFinerGrainTensorContraction(); } else if (auto matvecOp = dyn_cast(op)) { @@ -211,8 +211,8 @@ linalg::writeAsLoops(Operation *op) { return llvm::None; } -void linalg::lowerToLoops(mlir::Function *f) { - f->walk([](Operation *op) { +void linalg::lowerToLoops(mlir::Function f) { + f.walk([](Operation *op) { if (writeAsLoops(op)) op->erase(); }); diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index cdc05a1..873e57e 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -34,27 +34,27 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -65,11 +65,11 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_tiled_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops"); lowerToTiledLoops(f, {8, 9}); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); - if (succeeded(pm.run(f->getModule()))) + if (succeeded(pm.run(f.getModule()))) cleanupAndPrintFunction(f); // clang-format off @@ -96,10 +96,10 @@ TEST_FUNC(matmul_tiled_loops) { TEST_FUNC(matmul_tiled_views) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); - OpBuilder b(f->getBody()); - lowerToTiledViews(f, {b.create(f->getLoc(), 8), - b.create(f->getLoc(), 9)}); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); + OpBuilder b(f.getBody()); + lowerToTiledViews(f, {b.create(f.getLoc(), 8), + b.create(f.getLoc(), 9)}); composeSliceOps(f); // clang-format off @@ -125,11 +125,11 @@ TEST_FUNC(matmul_tiled_views) { TEST_FUNC(matmul_tiled_views_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops"); - OpBuilder b(f->getBody()); - lowerToTiledViews(f, {b.create(f->getLoc(), 8), - b.create(f->getLoc(), 9)}); + OpBuilder b(f.getBody()); + lowerToTiledViews(f, {b.create(f.getLoc(), 8), + b.create(f.getLoc(), 9)}); composeSliceOps(f); lowerToLoops(f); // This cannot lower below linalg.load and linalg.store due to lost diff --git a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h index 2165cab..ba7273e 100644 --- a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h +++ b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h @@ -34,12 +34,12 @@ writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef tileSizes); /// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. -void lowerToTiledLoops(mlir::Function *f, llvm::ArrayRef tileSizes); +void lowerToTiledLoops(mlir::Function f, llvm::ArrayRef tileSizes); /// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. -void lowerToTiledViews(mlir::Function *f, +void lowerToTiledViews(mlir::Function f, llvm::ArrayRef tileSizes); } // namespace linalg diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 1a308df..16b395d 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -43,9 +43,8 @@ linalg::writeAsTiledLoops(Operation *op, ArrayRef tileSizes) { return llvm::None; } -void linalg::lowerToTiledLoops(mlir::Function *f, - ArrayRef tileSizes) { - f->walk([tileSizes](Operation *op) { +void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef tileSizes) { + f.walk([tileSizes](Operation *op) { if (writeAsTiledLoops(op, tileSizes).hasValue()) op->erase(); }); @@ -185,8 +184,8 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { return llvm::None; } -void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef tileSizes) { - f->walk([tileSizes](Operation *op) { +void linalg::lowerToTiledViews(mlir::Function f, ArrayRef tileSizes) { + f.walk([tileSizes](Operation *op) { if (auto matmulOp = dyn_cast(op)) { writeAsTiledViews(matmulOp, tileSizes); } else if (auto matvecOp = dyn_cast(op)) { diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 842c7a1..73789fa 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -75,7 +75,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -129,40 +129,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -172,16 +172,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index e365f37..23cb853 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 032766a..f2132c2 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emited. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 688c736..f237fd9 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -113,14 +113,14 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function *function; + mlir::Function function; std::string mangledName; SmallVector argumentsType; }; void runOnModule() override { auto &module = getModule(); - auto *main = module.getNamedFunction("main"); + auto main = module.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), "Shape inference failed: can't find a main function\n"); @@ -139,7 +139,7 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function &function : llvm::make_early_inc_range(module)) { + for (mlir::Function function : llvm::make_early_inc_range(module)) { if (auto genericAttr = function.getAttrOfType("toy.generic")) { if (genericAttr.getValue()) @@ -153,7 +153,7 @@ public: mlir::LogicalResult specialize(SmallVectorImpl &funcWorklist) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function *f = functionToSpecialize.function; + mlir::Function f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -169,36 +169,36 @@ public: auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, {ToyArrayType::get(&getContext())}, &getContext()); - auto *newFunction = new mlir::Function( - f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); - getModule().getFunctions().push_back(newFunction); + auto newFunction = mlir::Function::create( + f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs()); + getModule().push_back(newFunction); // Clone the function body mlir::BlockAndValueMapping mapper; - f->cloneInto(newFunction, mapper); + f.cloneInto(newFunction, mapper); LLVM_DEBUG({ llvm::dbgs() << "====== Cloned : \n"; - f->dump(); + f.dump(); llvm::dbgs() << "====== Into : \n"; - newFunction->dump(); + newFunction.dump(); }); f = newFunction; - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // Remap the entry-block arguments // FIXME: this seems like a bug in `cloneInto()` above? - auto &entryBlock = f->getBlocks().front(); + auto &entryBlock = f.getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast(f->getType().getInputs().size())); - entryBlock.addArguments(f->getType().getInputs()); + assert(blockArgSize == static_cast(f.getType().getInputs().size())); + entryBlock.addArguments(f.getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { argList[0]->replaceAllUsesWith(argList[blockArgSize]); entryBlock.eraseArgument(0); } - assert(succeeded(f->verify())); + assert(succeeded(f.verify())); } LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f->getName() << "'\n"); + << "Run shape inference on : '" << f.getName() << "'\n"); auto *toyDialect = getContext().getRegisteredDialect("toy"); if (!toyDialect) { @@ -211,7 +211,7 @@ public: // Populate the worklist with the operations that need shape inference: // these are the Toy operations that return a generic array. llvm::SmallPtrSet opWorklist; - f->walk([&](mlir::Operation *op) { + f.walk([&](mlir::Operation *op) { if (op->getDialect() == toyDialect) { if (op->getNumResults() == 1 && op->getResult(0)->getType().cast().isGeneric()) @@ -292,9 +292,9 @@ public: // restart after the callee is processed. if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); - auto *callee = getModule().getNamedFunction(calleeName); + auto callee = getModule().getNamedFunction(calleeName); if (!callee) { - f->emitError("Shape inference failed, call to unknown '") + f.emitError("Shape inference failed, call to unknown '") << calleeName << "'"; signalPassFailure(); return mlir::failure(); @@ -302,7 +302,7 @@ public: auto mangledName = mangle(calleeName, op->getOpOperands()); LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName << "', mangled: '" << mangledName << "'\n"); - auto *mangledCallee = getModule().getNamedFunction(mangledName); + auto mangledCallee = getModule().getNamedFunction(mangledName); if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. @@ -327,7 +327,7 @@ public: // Done with inference on this function, removing it from the worklist. funcWorklist.pop_back(); // Mark the function as non-generic now that inference has succeeded - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { @@ -337,31 +337,31 @@ public: << " operations couldn't be inferred\n"; for (auto *ope : opWorklist) errorMsg << " - " << *ope << "\n"; - f->emitError(errorMsg.str()); + f.emitError(errorMsg.str()); signalPassFailure(); return mlir::failure(); } // Finally, update the return type of the function based on the argument to // the return operation. - for (auto &block : f->getBlocks()) { + for (auto &block : f.getBlocks()) { auto ret = llvm::cast(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && - f->getType().getResult(0) == ret.getOperand()->getType()) + f.getType().getResult(0) == ret.getOperand()->getType()) // type match, we're done break; SmallVector retTy; if (ret.getNumOperands()) retTy.push_back(ret.getOperand()->getType()); std::vector argumentsType; - for (auto arg : f->getArguments()) + for (auto arg : f.getArguments()) argumentsType.push_back(arg->getType()); auto newType = mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f->setType(newType); - assert(succeeded(f->verify())); + f.setType(newType); + assert(succeeded(f.verify())); break; } return mlir::success(); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 8b2a392..60a8b5a 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -136,14 +136,14 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. - Function *printfFunc = getPrintf(*op->getFunction()->getModule()); + Function printfFunc = getPrintf(*op->getFunction().getModule()); auto print = cast(op); auto loc = print.getLoc(); // We will operate on a MemRef abstraction, we use a type.cast to get one // if our operand is still a Toy array. Value *operand = memRefTypeCast(rewriter, operands[0]); - Type retTy = printfFunc->getType().getResult(0); + Type retTy = printfFunc.getType().getResult(0); // Create our loop nest now using namespace edsc; @@ -205,8 +205,8 @@ private: /// Return the prototype declaration for printf in the module, create it if /// necessary. - Function *getPrintf(Module &module) const { - auto *printfFunc = module.getNamedFunction("printf"); + Function getPrintf(Module &module) const { + auto printfFunc = module.getNamedFunction("printf"); if (printfFunc) return printfFunc; @@ -218,10 +218,10 @@ private: auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo(); auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); - printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy); + printfFunc = Function::create(builder.getUnknownLoc(), "printf", printfTy); // It should be variadic, but we don't support it fully just yet. - printfFunc->setAttr("std.varargs", builder.getBoolAttr(true)); - module.getFunctions().push_back(printfFunc); + printfFunc.setAttr("std.varargs", builder.getBoolAttr(true)); + module.push_back(printfFunc); return printfFunc; } }; @@ -369,7 +369,7 @@ struct LateLoweringPass : public ModulePass { // affine dialect: they already include conversion to the LLVM dialect. // First patch calls type to return memref instead of ToyArray - for (auto &function : getModule()) { + for (auto function : getModule()) { function.walk([&](Operation *op) { auto callOp = dyn_cast(op); if (!callOp) @@ -384,7 +384,7 @@ struct LateLoweringPass : public ModulePass { }); } - for (auto &function : getModule()) { + for (auto function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). if (auto allocOp = dyn_cast(op)) { @@ -403,8 +403,8 @@ struct LateLoweringPass : public ModulePass { } // Lower Linalg to affine - for (auto &function : getModule()) - linalg::lowerToLoops(&function); + for (auto function : getModule()) + linalg::lowerToLoops(function); getModule().dump(); diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index f7e6fad..9ebfeb4 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emited. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index cad2ded..0abcb4b 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -113,7 +113,7 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function *function; + mlir::Function function; std::string mangledName; SmallVector argumentsType; }; @@ -121,7 +121,7 @@ public: void runOnModule() override { auto &module = getModule(); mlir::ModuleManager moduleManager(&module); - auto *main = moduleManager.getNamedFunction("main"); + auto main = moduleManager.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), "Shape inference failed: can't find a main function\n"); @@ -140,7 +140,7 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function &function : llvm::make_early_inc_range(module)) { + for (mlir::Function function : llvm::make_early_inc_range(module)) { if (auto genericAttr = function.getAttrOfType("toy.generic")) { if (genericAttr.getValue()) @@ -155,7 +155,7 @@ public: specialize(SmallVectorImpl &funcWorklist, mlir::ModuleManager &moduleManager) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function *f = functionToSpecialize.function; + mlir::Function f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -171,36 +171,36 @@ public: auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, {ToyArrayType::get(&getContext())}, &getContext()); - auto *newFunction = new mlir::Function( - f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); + auto newFunction = mlir::Function::create( + f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs()); moduleManager.insert(newFunction); // Clone the function body mlir::BlockAndValueMapping mapper; - f->cloneInto(newFunction, mapper); + f.cloneInto(newFunction, mapper); LLVM_DEBUG({ llvm::dbgs() << "====== Cloned : \n"; - f->dump(); + f.dump(); llvm::dbgs() << "====== Into : \n"; - newFunction->dump(); + newFunction.dump(); }); f = newFunction; - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // Remap the entry-block arguments // FIXME: this seems like a bug in `cloneInto()` above? - auto &entryBlock = f->getBlocks().front(); + auto &entryBlock = f.getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast(f->getType().getInputs().size())); - entryBlock.addArguments(f->getType().getInputs()); + assert(blockArgSize == static_cast(f.getType().getInputs().size())); + entryBlock.addArguments(f.getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { argList[0]->replaceAllUsesWith(argList[blockArgSize]); entryBlock.eraseArgument(0); } - assert(succeeded(f->verify())); + assert(succeeded(f.verify())); } LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f->getName() << "'\n"); + << "Run shape inference on : '" << f.getName() << "'\n"); auto *toyDialect = getContext().getRegisteredDialect("toy"); if (!toyDialect) { @@ -212,7 +212,7 @@ public: // Populate the worklist with the operations that need shape inference: // these are the Toy operations that return a generic array. llvm::SmallPtrSet opWorklist; - f->walk([&](mlir::Operation *op) { + f.walk([&](mlir::Operation *op) { if (op->getDialect() == toyDialect) { if (op->getNumResults() == 1 && op->getResult(0)->getType().cast().isGeneric()) @@ -295,16 +295,16 @@ public: // restart after the callee is processed. if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); - auto *callee = moduleManager.getNamedFunction(calleeName); + auto callee = moduleManager.getNamedFunction(calleeName); if (!callee) { signalPassFailure(); - return f->emitError("Shape inference failed, call to unknown '") + return f.emitError("Shape inference failed, call to unknown '") << calleeName << "'"; } auto mangledName = mangle(calleeName, op->getOpOperands()); LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName << "', mangled: '" << mangledName << "'\n"); - auto *mangledCallee = moduleManager.getNamedFunction(mangledName); + auto mangledCallee = moduleManager.getNamedFunction(mangledName); if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. @@ -315,7 +315,7 @@ public: // Found a specialized callee! Let's turn this into a normal call // operation. SmallVector operands(op->getOperands()); - mlir::OpBuilder builder(f->getBody()); + mlir::OpBuilder builder(f.getBody()); builder.setInsertionPoint(op); auto newCall = builder.create(op->getLoc(), mangledCallee, operands); @@ -330,12 +330,12 @@ public: // Done with inference on this function, removing it from the worklist. funcWorklist.pop_back(); // Mark the function as non-generic now that inference has succeeded - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { signalPassFailure(); - auto diag = f->emitError("Shape inference failed, ") + auto diag = f.emitError("Shape inference failed, ") << opWorklist.size() << " operations couldn't be inferred\n"; for (auto *ope : opWorklist) diag << " - " << *ope << "\n"; @@ -344,24 +344,24 @@ public: // Finally, update the return type of the function based on the argument to // the return operation. - for (auto &block : f->getBlocks()) { + for (auto &block : f.getBlocks()) { auto ret = llvm::cast(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && - f->getType().getResult(0) == ret.getOperand()->getType()) + f.getType().getResult(0) == ret.getOperand()->getType()) // type match, we're done break; SmallVector retTy; if (ret.getNumOperands()) retTy.push_back(ret.getOperand()->getType()); std::vector argumentsType; - for (auto arg : f->getArguments()) + for (auto arg : f.getArguments()) argumentsType.push_back(arg->getType()); auto newType = mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f->setType(newType); - assert(succeeded(f->verify())); + f.setType(newType); + assert(succeeded(f.verify())); break; } return mlir::success(); diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index e69756e..8d7b2d5 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -34,7 +34,7 @@ template class DominanceInfoBase { using base = llvm::DominatorTreeBase; public: - DominanceInfoBase(Function *function) { recalculate(function); } + DominanceInfoBase(Function function) { recalculate(function); } DominanceInfoBase(Operation *op) { recalculate(op); } DominanceInfoBase(DominanceInfoBase &&) = default; DominanceInfoBase &operator=(DominanceInfoBase &&) = default; @@ -43,7 +43,7 @@ public: DominanceInfoBase &operator=(const DominanceInfoBase &) = delete; /// Recalculate the dominance info. - void recalculate(Function *function); + void recalculate(Function function); void recalculate(Operation *op); /// Get the root dominance node of the given region. diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 3ab24f8..b89011a 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -104,8 +104,8 @@ struct NestedPattern { NestedPattern &operator=(const NestedPattern &) = default; /// Returns all the top-level matches in `func`. - void match(Function *func, SmallVectorImpl *matches) { - func->walk([&](Operation *op) { matchOne(op, matches); }); + void match(Function func, SmallVectorImpl *matches) { + func.walk([&](Operation *op) { matchOne(op, matches); }); } /// Returns all the top-level matches in `op`. diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h index a2d982d..3d20eaf 100644 --- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h +++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h @@ -44,7 +44,7 @@ struct StaticFloatMemRef { /// each of the arguments, initialize the storage with `initialValue`, and /// return a list of type-erased descriptor pointers. llvm::Expected> -allocateMemRefArguments(Function *func, float initialValue = 0.0); +allocateMemRefArguments(Function func, float initialValue = 0.0); /// Free a list of type-erased descriptors to statically-shaped memrefs with /// element type f32. diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h index 8f682ce..c0326de 100644 --- a/mlir/include/mlir/GPU/GPUDialect.h +++ b/mlir/include/mlir/GPU/GPUDialect.h @@ -44,7 +44,7 @@ public: /// Returns whether the given function is a kernel function, i.e., has the /// 'gpu.kernel' attribute. - static bool isKernel(Function *function); + static bool isKernel(Function function); }; /// Utility class for the GPU dialect to represent triples of `Value`s @@ -122,12 +122,12 @@ public: using Op::Op; static void build(Builder *builder, OperationState *result, - Function *kernelFunc, Value *gridSizeX, Value *gridSizeY, + Function kernelFunc, Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, ArrayRef kernelOperands); static void build(Builder *builder, OperationState *result, - Function *kernelFunc, KernelDim3 gridSize, + Function kernelFunc, KernelDim3 gridSize, KernelDim3 blockSize, ArrayRef kernelOperands); /// The kernel function specified by the operation's `kernel` attribute. diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 5b9bfca..b46e160 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -313,9 +313,8 @@ class FunctionAttr detail::StringAttributeStorage> { public: using Base::Base; - using ValueType = Function *; + using ValueType = StringRef; - static FunctionAttr get(Function *value); static FunctionAttr get(StringRef value, MLIRContext *ctx); /// Returns the name of the held function reference. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index f4ecb4e..feae5c9 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -101,7 +101,7 @@ public: /// Returns the function that this block is part of, even if the block is /// nested under an operation region. - Function *getFunction(); + Function getFunction(); /// Insert this block (which must not already be in a function) right before /// the specified block. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 6ce5c22..e5c8c03 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -112,7 +112,7 @@ public: AffineMapAttr getAffineMapAttr(AffineMap map); IntegerSetAttr getIntegerSetAttr(IntegerSet set); TypeAttr getTypeAttr(Type type); - FunctionAttr getFunctionAttr(Function *value); + FunctionAttr getFunctionAttr(Function value); FunctionAttr getFunctionAttr(StringRef value); ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef values); diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 56d0661..4e82689 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -145,17 +145,13 @@ public: /// Verify an attribute from this dialect on the given function. Returns /// failure if the verification failed, success otherwise. - virtual LogicalResult verifyFunctionAttribute(Function *, NamedAttribute) { - return success(); - } + virtual LogicalResult verifyFunctionAttribute(Function, NamedAttribute); /// Verify an attribute from this dialect on the argument at 'argIndex' for /// the given function. Returns failure if the verification failed, success /// otherwise. - virtual LogicalResult - verifyFunctionArgAttribute(Function *, unsigned argIndex, NamedAttribute) { - return success(); - } + virtual LogicalResult verifyFunctionArgAttribute(Function, unsigned argIndex, + NamedAttribute); /// Verify an attribute from this dialect on the given operation. Returns /// failure if the verification failed, success otherwise. diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 8f3b3b0..e11a45b 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -29,29 +29,79 @@ namespace mlir { class BlockAndValueMapping; class FunctionType; +class Function; class MLIRContext; class Module; -/// This is the base class for all of the MLIR function types. -class Function : public llvm::ilist_node_with_parent { +namespace detail { +/// This class represents all of the internal state of a Function. This allows +/// for the Function class to be value typed. +class FunctionStorage + : public llvm::ilist_node_with_parent { + FunctionStorage(Location location, StringRef name, FunctionType type, + ArrayRef attrs = {}); + FunctionStorage(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs); + /// The name of the function. + Identifier name; + + /// The module this function is embedded into. + Module *module = nullptr; + + /// The source location the function was defined or derived from. + Location location; + + /// The type of the function. + FunctionType type; + + /// This holds general named attributes for the function. + NamedAttributeList attrs; + + /// The attributes lists for each of the function arguments. + std::vector argAttrs; + + /// The body of the function. + Region body; + + friend struct llvm::ilist_traits; + friend Function; +}; +} // namespace detail + +/// This class represents an MLIR function, or the common unit of computation. +/// The region of a function is not allowed to implicitly capture global values, +/// and all external references must use Function arguments or attributes. +class Function { public: - Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs = {}); - Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs); + Function(detail::FunctionStorage *impl = nullptr) : impl(impl) {} + + static Function create(Location location, StringRef name, FunctionType type, + ArrayRef attrs = {}) { + return new detail::FunctionStorage(location, name, type, attrs); + } + static Function create(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs) { + return new detail::FunctionStorage(location, name, type, attrs, argAttrs); + } + + /// Allow converting a Function to bool for null checks. + operator bool() const { return impl; } + bool operator==(Function other) const { return impl == other.impl; } + bool operator!=(Function other) const { return !(*this == other); } /// The source location the function was defined or derived from. - Location getLoc() { return location; } + Location getLoc() { return impl->location; } /// Set the source location this function was defined or derived from. - void setLoc(Location loc) { location = loc; } + void setLoc(Location loc) { impl->location = loc; } /// Return the name of this function, without the @. - Identifier getName() { return name; } + Identifier getName() { return impl->name; } /// Return the type of this function. - FunctionType getType() { return type; } + FunctionType getType() { return impl->type; } /// Change the type of this function in place. This is an extremely dangerous /// operation and it is up to the caller to ensure that this is legal for this @@ -61,12 +111,12 @@ public: /// parameters we drop the extra attributes, if there are more parameters /// they won't have any attributes. void setType(FunctionType newType) { - type = newType; - argAttrs.resize(type.getNumInputs()); + impl->type = newType; + impl->argAttrs.resize(newType.getNumInputs()); } MLIRContext *getContext(); - Module *getModule() { return module; } + Module *getModule() { return impl->module; } /// Add an entry block to an empty function, and set up the block arguments /// to match the signature of the function. @@ -82,28 +132,28 @@ public: // Body Handling //===--------------------------------------------------------------------===// - Region &getBody() { return body; } - void eraseBody() { body.getBlocks().clear(); } + Region &getBody() { return impl->body; } + void eraseBody() { getBody().getBlocks().clear(); } /// This is the list of blocks in the function. using RegionType = Region::RegionType; - RegionType &getBlocks() { return body.getBlocks(); } + RegionType &getBlocks() { return getBody().getBlocks(); } // Iteration over the block in the function. using iterator = RegionType::iterator; using reverse_iterator = RegionType::reverse_iterator; - iterator begin() { return body.begin(); } - iterator end() { return body.end(); } - reverse_iterator rbegin() { return body.rbegin(); } - reverse_iterator rend() { return body.rend(); } + iterator begin() { return getBody().begin(); } + iterator end() { return getBody().end(); } + reverse_iterator rbegin() { return getBody().rbegin(); } + reverse_iterator rend() { return getBody().rend(); } - bool empty() { return body.empty(); } - void push_back(Block *block) { body.push_back(block); } - void push_front(Block *block) { body.push_front(block); } + bool empty() { return getBody().empty(); } + void push_back(Block *block) { getBody().push_back(block); } + void push_front(Block *block) { getBody().push_front(block); } - Block &back() { return body.back(); } - Block &front() { return body.front(); } + Block &back() { return getBody().back(); } + Block &front() { return getBody().front(); } //===--------------------------------------------------------------------===// // Operation Walkers @@ -150,53 +200,55 @@ public: /// the lifetime of an function. /// Return all of the attributes on this function. - ArrayRef getAttrs() { return attrs.getAttrs(); } + ArrayRef getAttrs() { return impl->attrs.getAttrs(); } /// Return the internal attribute list on this function. - NamedAttributeList &getAttrList() { return attrs; } + NamedAttributeList &getAttrList() { return impl->attrs; } /// Return all of the attributes for the argument at 'index'. ArrayRef getArgAttrs(unsigned index) { assert(index < getNumArguments() && "invalid argument number"); - return argAttrs[index].getAttrs(); + return impl->argAttrs[index].getAttrs(); } /// Set the attributes held by this function. void setAttrs(ArrayRef attributes) { - attrs.setAttrs(attributes); + impl->attrs.setAttrs(attributes); } /// Set the attributes held by the argument at 'index'. void setArgAttrs(unsigned index, ArrayRef attributes) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index].setAttrs(attributes); + impl->argAttrs[index].setAttrs(attributes); } void setArgAttrs(unsigned index, NamedAttributeList attributes) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index] = attributes; + impl->argAttrs[index] = attributes; } void setAllArgAttrs(ArrayRef attributes) { assert(attributes.size() == getNumArguments()); for (unsigned i = 0, e = attributes.size(); i != e; ++i) - argAttrs[i] = attributes[i]; + impl->argAttrs[i] = attributes[i]; } /// Return all argument attributes of this function. - MutableArrayRef getAllArgAttrs() { return argAttrs; } + MutableArrayRef getAllArgAttrs() { + return impl->argAttrs; + } /// Return the specified attribute if present, null otherwise. - Attribute getAttr(Identifier name) { return attrs.get(name); } - Attribute getAttr(StringRef name) { return attrs.get(name); } + Attribute getAttr(Identifier name) { return impl->attrs.get(name); } + Attribute getAttr(StringRef name) { return impl->attrs.get(name); } /// Return the specified attribute, if present, for the argument at 'index', /// null otherwise. Attribute getArgAttr(unsigned index, Identifier name) { assert(index < getNumArguments() && "invalid argument number"); - return argAttrs[index].get(name); + return impl->argAttrs[index].get(name); } Attribute getArgAttr(unsigned index, StringRef name) { assert(index < getNumArguments() && "invalid argument number"); - return argAttrs[index].get(name); + return impl->argAttrs[index].get(name); } template AttrClass getAttrOfType(Identifier name) { @@ -219,13 +271,15 @@ public: /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. - void setAttr(Identifier name, Attribute value) { attrs.set(name, value); } + void setAttr(Identifier name, Attribute value) { + impl->attrs.set(name, value); + } void setAttr(StringRef name, Attribute value) { setAttr(Identifier::get(name, getContext()), value); } void setArgAttr(unsigned index, Identifier name, Attribute value) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index].set(name, value); + impl->argAttrs[index].set(name, value); } void setArgAttr(unsigned index, StringRef name, Attribute value) { setArgAttr(index, Identifier::get(name, getContext()), value); @@ -234,12 +288,12 @@ public: /// Remove the attribute with the specified name if it exists. The return /// value indicates whether the attribute was present or not. NamedAttributeList::RemoveResult removeAttr(Identifier name) { - return attrs.remove(name); + return impl->attrs.remove(name); } NamedAttributeList::RemoveResult removeArgAttr(unsigned index, Identifier name) { assert(index < getNumArguments() && "invalid argument number"); - return attrs.remove(name); + return impl->attrs.remove(name); } //===--------------------------------------------------------------------===// @@ -281,44 +335,37 @@ public: /// contains entries for function arguments, these arguments are not included /// in the new function. Replaces references to cloned sub-values with the /// corresponding value that is copied, and adds those mappings to the mapper. - Function *clone(BlockAndValueMapping &mapper); - Function *clone(); + Function clone(BlockAndValueMapping &mapper); + Function clone(); /// Clone the internal blocks and attributes from this function into dest. Any /// cloned blocks are appended to the back of dest. This function asserts that /// the attributes of the current function and dest are compatible. - void cloneInto(Function *dest, BlockAndValueMapping &mapper); + void cloneInto(Function dest, BlockAndValueMapping &mapper); + + /// Methods for supporting PointerLikeTypeTraits. + const void *getAsOpaquePointer() const { + return static_cast(impl); + } + static Function getFromOpaquePointer(const void *pointer) { + return reinterpret_cast( + const_cast(pointer)); + } private: /// Set the name of this function. - void setName(Identifier newName) { name = newName; } - - /// The name of the function. - Identifier name; - - /// The module this function is embedded into. - Module *module = nullptr; - - /// The source location the function was defined or derived from. - Location location; - - /// The type of the function. - FunctionType type; - - /// This holds general named attributes for the function. - NamedAttributeList attrs; + void setName(Identifier newName) { impl->name = newName; } - /// The attributes lists for each of the function arguments. - std::vector argAttrs; - - /// The body of the function. - Region body; - - void operator=(Function &) = delete; - friend struct llvm::ilist_traits; + /// A pointer to the impl storage instance for this function. This allows for + /// 'Function' to be treated as a value type. + detail::FunctionStorage *impl = nullptr; // Allow access to 'setName'. friend class SymbolTable; + + // Allow access to 'impl'. + friend class Module; + friend class Region; }; //===--------------------------------------------------------------------===// @@ -487,21 +534,52 @@ private: namespace llvm { template <> -struct ilist_traits<::mlir::Function> - : public ilist_alloc_traits<::mlir::Function> { - using Function = ::mlir::Function; - using function_iterator = simple_ilist::iterator; +struct ilist_traits<::mlir::detail::FunctionStorage> + : public ilist_alloc_traits<::mlir::detail::FunctionStorage> { + using FunctionStorage = ::mlir::detail::FunctionStorage; + using function_iterator = simple_ilist::iterator; - static void deleteNode(Function *function) { delete function; } + static void deleteNode(FunctionStorage *function) { delete function; } - void addNodeToList(Function *function); - void removeNodeFromList(Function *function); - void transferNodesFromList(ilist_traits &otherList, + void addNodeToList(FunctionStorage *function); + void removeNodeFromList(FunctionStorage *function); + void transferNodesFromList(ilist_traits &otherList, function_iterator first, function_iterator last); private: mlir::Module *getContainingModule(); }; -} // end namespace llvm + +// Functions hash just like pointers. +template <> struct DenseMapInfo { + static mlir::Function getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::Function::getFromOpaquePointer(pointer); + } + static mlir::Function getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::Function::getFromOpaquePointer(pointer); + } + static unsigned getHashValue(mlir::Function val) { + return hash_value(val.getAsOpaquePointer()); + } + static bool isEqual(mlir::Function LHS, mlir::Function RHS) { + return LHS == RHS; + } +}; + +/// Allow stealing the low bits of FunctionStorage. +template <> struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::Function I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::Function getFromVoidPointer(void *P) { + return mlir::Function::getFromOpaquePointer(P); + } + enum { NumLowBitsAvailable = 3 }; +}; + +} // namespace llvm #endif // MLIR_IR_FUNCTION_H diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index 8161a30..d8a4789 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -34,34 +34,54 @@ public: MLIRContext *getContext() { return context; } + /// An iterator class used to iterate over the held functions. + class iterator : public llvm::mapped_iterator< + llvm::iplist::iterator, + Function (*)(detail::FunctionStorage &)> { + static Function unwrap(detail::FunctionStorage &impl) { return &impl; } + + public: + using reference = Function; + + /// Initializes the operand type iterator to the specified operand iterator. + iterator(llvm::iplist::iterator it) + : llvm::mapped_iterator::iterator, + Function (*)(detail::FunctionStorage &)>( + it, &unwrap) {} + iterator(Function it) + : iterator(llvm::iplist::iterator(it.impl)) {} + }; + /// This is the list of functions in the module. - using FunctionListType = llvm::iplist; - FunctionListType &getFunctions() { return functions; } + llvm::iterator_range getFunctions() { return {begin(), end()}; } // Iteration over the functions in the module. - using iterator = FunctionListType::iterator; - using reverse_iterator = FunctionListType::reverse_iterator; - iterator begin() { return functions.begin(); } iterator end() { return functions.end(); } - reverse_iterator rbegin() { return functions.rbegin(); } - reverse_iterator rend() { return functions.rend(); } + Function front() { return &functions.front(); } + Function back() { return &functions.back(); } + + void push_back(Function fn) { functions.push_back(fn.impl); } + void insert(iterator insertPt, Function fn) { + functions.insert(insertPt.getCurrent(), fn.impl); + } // Interfaces for working with the symbol table. /// Look up a function with the specified name, returning null if no such /// name exists. Function names never include the @ on them. Note: This /// performs a linear scan of held symbols. - Function *getNamedFunction(StringRef name) { + Function getNamedFunction(StringRef name) { return getNamedFunction(Identifier::get(name, getContext())); } /// Look up a function with the specified name, returning null if no such /// name exists. Function names never include the @ on them. Note: This /// performs a linear scan of held symbols. - Function *getNamedFunction(Identifier name) { - auto it = llvm::find_if( - functions, [name](Function &fn) { return fn.getName() == name; }); + Function getNamedFunction(Identifier name) { + auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) { + return Function(&fn).getName() == name; + }); return it == functions.end() ? nullptr : &*it; } @@ -74,11 +94,13 @@ public: void dump(); private: - friend struct llvm::ilist_traits; - friend class Function; + friend struct llvm::ilist_traits; + friend detail::FunctionStorage; + friend Function; /// getSublistAccess() - Returns pointer to member of function list - static FunctionListType Module::*getSublistAccess(Function *) { + static llvm::iplist Module::* + getSublistAccess(detail::FunctionStorage *) { return &Module::functions; } @@ -86,7 +108,7 @@ private: MLIRContext *context; /// This is the actual list of functions the module contains. - FunctionListType functions; + llvm::iplist functions; }; /// A class used to manage the symbols held by a module. This class handles @@ -98,24 +120,24 @@ public: /// Look up a symbol with the specified name, returning null if no such /// name exists. Names must never include the @ on them. - template Function *getNamedFunction(NameTy &&name) const { + template Function getNamedFunction(NameTy &&name) const { return symbolTable.lookup(name); } /// Insert a new symbol into the module, auto-renaming it as necessary. - void insert(Function *function) { + void insert(Function function) { symbolTable.insert(function); - module->getFunctions().push_back(function); + module->push_back(function); } - void insert(Module::iterator insertPt, Function *function) { + void insert(Module::iterator insertPt, Function function) { symbolTable.insert(function); - module->getFunctions().insert(insertPt, function); + module->insert(insertPt, function); } /// Remove the given symbol from the module symbol table and then erase it. - void erase(Function *function) { + void erase(Function function) { symbolTable.erase(function); - function->erase(); + function.erase(); } /// Return the internally held module. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index e532399..f916f4b 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -128,7 +128,7 @@ public: /// Returns the function that this operation is part of. /// The function is determined by traversing the chain of parent operations. /// Returns nullptr if the operation is unlinked. - Function *getFunction(); + Function getFunction(); /// Replace any uses of 'from' with 'to' within this operation. void replaceUsesOfWith(Value *from, Value *to); diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index a1b81fc..9214376 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -420,7 +420,7 @@ private: /// patterns in a greedy work-list driven manner. Return true if no more /// patterns can be matched in the result function. /// -bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns); +bool applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns); /// Helper class to create a list of rewrite patterns given a list of their /// types and a list of attributes perfect-forwarded to each of the conversion diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 2189ad4..ad0692b 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -27,11 +27,16 @@ namespace mlir { class BlockAndValueMapping; +namespace detail { +class FunctionStorage; +} + /// This class contains a list of basic blocks and has a notion of the object it /// is part of - a Function or an Operation. class Region { public: - explicit Region(Function *container = nullptr); + Region() = default; + explicit Region(Function container); explicit Region(Operation *container); ~Region(); @@ -77,7 +82,7 @@ public: /// A Region is either a function body or a part of an operation. If it is /// a Function body, then return this function, otherwise return null. - Function *getContainingFunction(); + Function getContainingFunction(); /// Return true if this region is a proper ancestor of the `other` region. bool isProperAncestor(Region *other); @@ -118,7 +123,7 @@ private: RegionType blocks; /// This is the object we are part of. - llvm::PointerUnion container; + llvm::PointerUnion container; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index 3074958..a351f66 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -18,7 +18,7 @@ #ifndef MLIR_IR_SYMBOLTABLE_H #define MLIR_IR_SYMBOLTABLE_H -#include "mlir/IR/Identifier.h" +#include "mlir/IR/Function.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -35,18 +35,18 @@ public: /// Look up a symbol with the specified name, returning null if no such /// name exists. Names never include the @ on them. - Function *lookup(StringRef name) const; + Function lookup(StringRef name) const; /// Look up a symbol with the specified name, returning null if no such /// name exists. Names never include the @ on them. - Function *lookup(Identifier name) const; + Function lookup(Identifier name) const; /// Erase the given symbol from the table. - void erase(Function *symbol); + void erase(Function symbol); /// Insert a new symbol into the table, and rename it as necessary to avoid /// collisions. - void insert(Function *symbol); + void insert(Function symbol); /// Returns the context held by this symbol table. MLIRContext *getContext() const { return context; } @@ -55,7 +55,7 @@ private: MLIRContext *context; /// This is a mapping from a name to the function with that name. - llvm::DenseMap symbolTable; + llvm::DenseMap symbolTable; /// This is used when name conflicts are detected. unsigned uniquingCounter = 0; diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index e90505e..4604ed9 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -72,7 +72,7 @@ public: } /// Return the function that this Value is defined in. - Function *getFunction(); + Function getFunction(); /// If this value is the result of an operation, return the operation that /// defines it. @@ -128,7 +128,7 @@ public: } /// Return the function that this argument is defined in. - Function *getFunction(); + Function getFunction(); Block *getOwner() { return owner; } diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index bd3286d..a28aa71 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -153,7 +153,7 @@ public: /// Verify a function argument attribute registered to this dialect. /// Returns failure if the verification failed, success otherwise. - LogicalResult verifyFunctionArgAttribute(Function *func, unsigned argIdx, + LogicalResult verifyFunctionArgAttribute(Function func, unsigned argIdx, NamedAttribute argAttr) override; private: diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h index 3751a936..c44f88f 100644 --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -106,7 +106,7 @@ template class AnalysisMap { } public: - explicit AnalysisMap(IRUnitT *ir) : ir(ir) {} + explicit AnalysisMap(IRUnitT ir) : ir(ir) {} /// Get an analysis for the current IR unit, computing it if necessary. template AnalysisT &getAnalysis(PassInstrumentor *pi) { @@ -140,8 +140,8 @@ public: } /// Returns the IR unit that this analysis map represents. - IRUnitT *getIRUnit() { return ir; } - const IRUnitT *getIRUnit() const { return ir; } + IRUnitT getIRUnit() { return ir; } + const IRUnitT getIRUnit() const { return ir; } /// Clear any held analyses. void clear() { analyses.clear(); } @@ -158,7 +158,7 @@ public: } private: - IRUnitT *ir; + IRUnitT ir; ConceptMap analyses; }; @@ -231,14 +231,14 @@ public: /// Query for the analysis of a function. The analysis is computed if it does /// not exist. template - AnalysisT &getFunctionAnalysis(Function *function) { + AnalysisT &getFunctionAnalysis(Function function) { return slice(function).getAnalysis(); } /// Query for a cached analysis of a child function, or return null. template llvm::Optional> - getCachedFunctionAnalysis(Function *function) const { + getCachedFunctionAnalysis(Function function) const { auto it = functionAnalyses.find(function); if (it == functionAnalyses.end()) return llvm::None; @@ -258,7 +258,7 @@ public: } /// Create an analysis slice for the given child function. - FunctionAnalysisManager slice(Function *function); + FunctionAnalysisManager slice(Function function); /// Invalidate any non preserved analyses. void invalidate(const detail::PreservedAnalyses &pa); @@ -269,11 +269,11 @@ public: private: /// The cached analyses for functions within the current module. - llvm::DenseMap>> + llvm::DenseMap>> functionAnalyses; /// The analyses for the owning module. - detail::AnalysisMap moduleAnalyses; + detail::AnalysisMap moduleAnalyses; /// An optional instrumentation object. PassInstrumentor *passInstrumentor; diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 5fd6dfd..41d20cc 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -70,12 +70,12 @@ class ModulePassExecutor; /// interface for accessing and initializing necessary state for pass execution. template struct PassExecutionState { - PassExecutionState(IRUnitT *ir, AnalysisManagerT &analysisManager) + PassExecutionState(IRUnitT ir, AnalysisManagerT &analysisManager) : irAndPassFailed(ir, false), analysisManager(analysisManager) {} /// The current IR unit being transformed and a bool for if the pass signaled /// a failure. - llvm::PointerIntPair irAndPassFailed; + llvm::PointerIntPair irAndPassFailed; /// The analysis manager for the IR unit. AnalysisManagerT &analysisManager; @@ -107,9 +107,7 @@ protected: virtual FunctionPassBase *clone() const = 0; /// Return the current function being transformed. - Function &getFunction() { - return *getPassState().irAndPassFailed.getPointer(); - } + Function getFunction() { return getPassState().irAndPassFailed.getPointer(); } /// Return the MLIR context for the current function being transformed. MLIRContext &getContext() { return *getFunction().getContext(); } @@ -128,7 +126,7 @@ protected: private: /// Forwarding function to execute this pass. LLVM_NODISCARD - LogicalResult run(Function *fn, FunctionAnalysisManager &fam); + LogicalResult run(Function fn, FunctionAnalysisManager &fam); /// The current execution state for the pass. llvm::Optional passState; @@ -140,7 +138,8 @@ private: /// Pass to transform a module. Derived passes should not inherit from this /// class directly, and instead should use the CRTP ModulePass class. class ModulePassBase : public Pass { - using PassStateT = detail::PassExecutionState; + using PassStateT = + detail::PassExecutionState; public: static bool classof(const Pass *pass) { @@ -272,7 +271,7 @@ struct FunctionPass : public detail::PassModel { template struct ModulePass : public detail::PassModel { /// Returns the analysis for a child function. - template AnalysisT &getFunctionAnalysis(Function *f) { + template AnalysisT &getFunctionAnalysis(Function f) { return this->getAnalysisManager().template getFunctionAnalysis( f); } @@ -280,7 +279,7 @@ struct ModulePass : public detail::PassModel { /// Returns an existing analysis for a child function if it exists. template llvm::Optional> - getCachedFunctionAnalysis(Function *f) { + getCachedFunctionAnalysis(Function f) { return this->getAnalysisManager() .template getCachedFunctionAnalysis(f); } diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h index 0f42706..4035832 100644 --- a/mlir/include/mlir/Pass/PassInstrumentation.h +++ b/mlir/include/mlir/Pass/PassInstrumentation.h @@ -77,29 +77,29 @@ public: ~PassInstrumentor(); /// See PassInstrumentation::runBeforePass for details. - template void runBeforePass(Pass *pass, IRUnitT *ir) { + template void runBeforePass(Pass *pass, IRUnitT ir) { runBeforePass(pass, llvm::Any(ir)); } /// See PassInstrumentation::runAfterPass for details. - template void runAfterPass(Pass *pass, IRUnitT *ir) { + template void runAfterPass(Pass *pass, IRUnitT ir) { runAfterPass(pass, llvm::Any(ir)); } /// See PassInstrumentation::runAfterPassFailed for details. - template void runAfterPassFailed(Pass *pass, IRUnitT *ir) { + template void runAfterPassFailed(Pass *pass, IRUnitT ir) { runAfterPassFailed(pass, llvm::Any(ir)); } /// See PassInstrumentation::runBeforeAnalysis for details. template - void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) { + void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) { runBeforeAnalysis(name, id, llvm::Any(ir)); } /// See PassInstrumentation::runAfterAnalysis for details. template - void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) { + void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) { runAfterAnalysis(name, id, llvm::Any(ir)); } diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 1b14e2a..a7afe1f 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -214,11 +214,11 @@ def CallOp : Std_Op<"call"> { let results = (outs Variadic); let builders = [OpBuilder< - "Builder *builder, OperationState *result, Function *callee," + "Builder *builder, OperationState *result, Function callee," "ArrayRef operands = {}", [{ result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); - result->addTypes(callee->getType().getResults()); + result->addTypes(callee.getType().getResults()); }]>, OpBuilder< "Builder *builder, OperationState *result, StringRef callee," "ArrayRef results, ArrayRef operands = {}", [{ diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 00da0d5..c8ede78 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -345,7 +345,7 @@ LLVM_NODISCARD LogicalResult applyConversionPatterns( /// Convert the given functions with the provided conversion patterns. This /// function returns failure if a type conversion failed. LLVM_NODISCARD -LogicalResult applyConversionPatterns(ArrayRef fns, +LogicalResult applyConversionPatterns(MutableArrayRef fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns); @@ -354,7 +354,7 @@ LogicalResult applyConversionPatterns(ArrayRef fns, /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LLVM_NODISCARD -LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target, +LogicalResult applyConversionPatterns(Function fn, ConversionTarget &target, OwningRewritePatternList &&patterns); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/LowerAffine.h b/mlir/include/mlir/Transforms/LowerAffine.h index d77b35a..09aa7dc 100644 --- a/mlir/include/mlir/Transforms/LowerAffine.h +++ b/mlir/include/mlir/Transforms/LowerAffine.h @@ -37,7 +37,7 @@ Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, /// Convert from the Affine dialect to the Standard dialect, in particular /// convert structured affine control flow into CFG branch-based control flow. -LogicalResult lowerAffineConstructs(Function &function); +LogicalResult lowerAffineConstructs(Function function); /// Emit code that computes the lower bound of the given affine loop using /// standard arithmetic operations. diff --git a/mlir/include/mlir/Transforms/ViewFunctionGraph.h b/mlir/include/mlir/Transforms/ViewFunctionGraph.h index c1da5ef..5780df5 100644 --- a/mlir/include/mlir/Transforms/ViewFunctionGraph.h +++ b/mlir/include/mlir/Transforms/ViewFunctionGraph.h @@ -33,11 +33,11 @@ class FunctionPassBase; /// Displays the CFG in a window. This is for use from the debugger and /// depends on Graphviz to generate the graph. -void viewGraph(Function &function, const Twine &name, bool shortNames = false, +void viewGraph(Function function, const Twine &name, bool shortNames = false, const Twine &title = "", llvm::GraphProgram::Name program = llvm::GraphProgram::DOT); -llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function, +llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function function, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print CFG graphs. diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 016ef43..d7650dc 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -303,7 +303,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { if (inserted) { reorderedDims.push_back(v); } - return getAffineDimExpr(iterPos->second, v->getFunction()->getContext()) + return getAffineDimExpr(iterPos->second, v->getFunction().getContext()) .cast(); } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 954a01b..b4cdeb7 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -37,17 +37,16 @@ template class llvm::DomTreeNodeBase; /// Recalculate the dominance info. template -void DominanceInfoBase::recalculate(Function *function) { +void DominanceInfoBase::recalculate(Function function) { dominanceInfos.clear(); // Build the top level function dominance. auto functionDominance = llvm::make_unique(); - functionDominance->recalculate(function->getBody()); - dominanceInfos.try_emplace(&function->getBody(), - std::move(functionDominance)); + functionDominance->recalculate(function.getBody()); + dominanceInfos.try_emplace(&function.getBody(), std::move(functionDominance)); /// Build the dominance for each of the operation regions. - function->walk([&](Operation *op) { + function.walk([&](Operation *op) { for (auto ®ion : op->getRegions()) { // Don't compute dominance if the region is empty. if (region.empty()) diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 5177afc..75a2fc1 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -45,7 +45,7 @@ void PrintOpStatsPass::runOnModule() { opCount.clear(); // Compute the operation statistics for each function in the module. - for (auto &fn : getModule()) + for (auto fn : getModule()) fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); printSummary(); } diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index cbda6d4..473d253 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { // Walks the function and emits a note for all 'affine.for' ops detected as // parallel. void TestParallelismDetection::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); OpBuilder b(f.getBody()); f.walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 1330fe0..0d05251 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -53,7 +53,7 @@ public: : ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {} /// Verify the body of the given function. - LogicalResult verify(Function &fn); + LogicalResult verify(Function fn); /// Verify the given operation. LogicalResult verify(Operation &op); @@ -104,7 +104,7 @@ private: } // end anonymous namespace /// Verify the body of the given function. -LogicalResult OperationVerifier::verify(Function &fn) { +LogicalResult OperationVerifier::verify(Function fn) { // Verify the body first. if (failed(verifyRegion(fn.getBody()))) return failure(); @@ -113,7 +113,7 @@ LogicalResult OperationVerifier::verify(Function &fn) { // check. We do this as a second pass since malformed CFG's can cause // dominator analysis constructure to crash and we want the verifier to be // resilient to malformed code. - DominanceInfo theDomInfo(&fn); + DominanceInfo theDomInfo(fn); domInfo = &theDomInfo; if (failed(verifyDominance(fn.getBody()))) return failure(); @@ -313,7 +313,7 @@ LogicalResult Function::verify() { // Verify this attribute with the defining dialect. if (auto *dialect = opVerifier.getDialectForAttribute(attr)) - if (failed(dialect->verifyFunctionAttribute(this, attr))) + if (failed(dialect->verifyFunctionAttribute(*this, attr))) return failure(); } @@ -331,7 +331,7 @@ LogicalResult Function::verify() { // Verify this attribute with the defining dialect. if (auto *dialect = opVerifier.getDialectForAttribute(attr)) - if (failed(dialect->verifyFunctionArgAttribute(this, i, attr))) + if (failed(dialect->verifyFunctionArgAttribute(*this, i, attr))) return failure(); } } @@ -369,7 +369,7 @@ LogicalResult Operation::verify() { LogicalResult Module::verify() { // Check that all functions are uniquely named. llvm::StringMap nameToOrigLoc; - for (auto &fn : *this) { + for (auto fn : *this) { auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc()); if (!it.second) return fn.emitError() @@ -379,7 +379,7 @@ LogicalResult Module::verify() { } // Check that each function is correct. - for (auto &fn : *this) + for (auto fn : *this) if (failed(fn.verify())) return failure(); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 9d7aeeb..022d8c7 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -64,8 +64,8 @@ public: LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXAsmPrinter(); - for (auto &function : getModule()) { - if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) { + for (auto function : getModule()) { + if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) { continue; } if (failed(translateGpuKernelToCubinAnnotation(function))) @@ -142,7 +142,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) { std::unique_ptr module(builder.createModule()); // TODO(herhut): Also handle called functions. - module->getFunctions().push_back(function.clone()); + module->push_back(function.clone()); auto llvmModule = translateModuleToNVVMIR(*module); auto cubin = convertModuleToCubin(*llvmModule, function); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index bd96f39..f9d5899 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -118,7 +118,7 @@ private: void declareCudaFunctions(Location loc); Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); - Value *generateKernelNameConstant(Function *kernelFunction, Location &loc, + Value *generateKernelNameConstant(Function kernelFunction, Location &loc, OpBuilder &builder); void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); @@ -130,7 +130,7 @@ public: // Cache the used LLVM types. initializeCachedTypes(); - for (auto &func : getModule()) { + for (auto func : getModule()) { func.walk( [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); }); } @@ -155,66 +155,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { Module &module = getModule(); Builder builder(&module); if (!module.getNamedFunction(cuModuleLoadName)) { - module.getFunctions().push_back( - new Function(loc, cuModuleLoadName, - builder.getFunctionType( - { - getPointerPointerType(), /* CUmodule *module */ - getPointerType() /* void *cubin */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuModuleLoadName, + builder.getFunctionType( + { + getPointerPointerType(), /* CUmodule *module */ + getPointerType() /* void *cubin */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuModuleGetFunctionName)) { // The helper uses void* instead of CUDA's opaque CUmodule and // CUfunction. - module.getFunctions().push_back( - new Function(loc, cuModuleGetFunctionName, - builder.getFunctionType( - { - getPointerPointerType(), /* void **function */ - getPointerType(), /* void *module */ - getPointerType() /* char *name */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuModuleGetFunctionName, + builder.getFunctionType( + { + getPointerPointerType(), /* void **function */ + getPointerType(), /* void *module */ + getPointerType() /* char *name */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuLaunchKernelName)) { // Other than the CUDA api, the wrappers use uintptr_t to match the // LLVM type if MLIR's index type, which the GPU dialect uses. // Furthermore, they use void* instead of CUDA's opaque CUfunction and // CUstream. - module.getFunctions().push_back( - new Function(loc, cuLaunchKernelName, - builder.getFunctionType( - { - getPointerType(), /* void* f */ - getIntPtrType(), /* intptr_t gridXDim */ - getIntPtrType(), /* intptr_t gridyDim */ - getIntPtrType(), /* intptr_t gridZDim */ - getIntPtrType(), /* intptr_t blockXDim */ - getIntPtrType(), /* intptr_t blockYDim */ - getIntPtrType(), /* intptr_t blockZDim */ - getInt32Type(), /* unsigned int sharedMemBytes */ - getPointerType(), /* void *hstream */ - getPointerPointerType(), /* void **kernelParams */ - getPointerPointerType() /* void **extra */ - }, - getCUResultType()))); + module.push_back(Function::create( + loc, cuLaunchKernelName, + builder.getFunctionType( + { + getPointerType(), /* void* f */ + getIntPtrType(), /* intptr_t gridXDim */ + getIntPtrType(), /* intptr_t gridyDim */ + getIntPtrType(), /* intptr_t gridZDim */ + getIntPtrType(), /* intptr_t blockXDim */ + getIntPtrType(), /* intptr_t blockYDim */ + getIntPtrType(), /* intptr_t blockZDim */ + getInt32Type(), /* unsigned int sharedMemBytes */ + getPointerType(), /* void *hstream */ + getPointerPointerType(), /* void **kernelParams */ + getPointerPointerType() /* void **extra */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuGetStreamHelperName)) { // Helper function to get the current CUDA stream. Uses void* instead of // CUDAs opaque CUstream. - module.getFunctions().push_back(new Function( + module.push_back(Function::create( loc, cuGetStreamHelperName, builder.getFunctionType({}, getPointerType() /* void *stream */))); } if (!module.getNamedFunction(cuStreamSynchronizeName)) { - module.getFunctions().push_back( - new Function(loc, cuStreamSynchronizeName, - builder.getFunctionType( - { - getPointerType() /* CUstream stream */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuStreamSynchronizeName, + builder.getFunctionType( + { + getPointerType() /* CUstream stream */ + }, + getCUResultType()))); } } @@ -264,14 +264,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, // %0[n] = constant name[n] // %0[n+1] = 0 Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( - Function *kernelFunction, Location &loc, OpBuilder &builder) { + Function kernelFunction, Location &loc, OpBuilder &builder) { // TODO(herhut): Make this a constant once this is supported. auto kernelNameSize = builder.create( loc, getInt32Type(), - builder.getI32IntegerAttr(kernelFunction->getName().size() + 1)); + builder.getI32IntegerAttr(kernelFunction.getName().size() + 1)); auto kernelName = builder.create(loc, getPointerType(), kernelNameSize); - for (auto byte : llvm::enumerate(kernelFunction->getName())) { + for (auto byte : llvm::enumerate(kernelFunction.getName())) { auto index = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(byte.index())); auto gep = builder.create(loc, getPointerType(), kernelName, @@ -284,7 +284,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( // Add trailing zero to terminate string. auto index = builder.create( loc, getInt32Type(), - builder.getI32IntegerAttr(kernelFunction->getName().size())); + builder.getI32IntegerAttr(kernelFunction.getName().size())); auto gep = builder.create(loc, getPointerType(), kernelName, ArrayRef{index}); auto value = builder.create( @@ -326,9 +326,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // TODO(herhut): This should rather be a static global once supported. auto kernelFunction = getModule().getNamedFunction(launchOp.kernel()); auto cubinGetter = - kernelFunction->getAttrOfType(kCubinGetterAnnotation); + kernelFunction.getAttrOfType(kCubinGetterAnnotation); if (!cubinGetter) { - kernelFunction->emitError("Missing ") + kernelFunction.emitError("Missing ") << kCubinGetterAnnotation << " attribute."; return signalPassFailure(); } @@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // Emit the load module call to load the module data. Error checking is done // in the called helper function. auto cuModule = allocatePointer(builder, loc); - Function *cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); + Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getFunctionAttr(cuModuleLoad), ArrayRef{cuModule, data.getResult(0)}); @@ -347,14 +347,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create(loc, getPointerType(), cuModule); auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder); auto cuFunction = allocatePointer(builder, loc); - Function *cuModuleGetFunction = + Function cuModuleGetFunction = getModule().getNamedFunction(cuModuleGetFunctionName); builder.create( loc, ArrayRef{getCUResultType()}, builder.getFunctionAttr(cuModuleGetFunction), ArrayRef{cuFunction, cuModuleRef, kernelName}); // Grab the global stream needed for execution. - Function *cuGetStreamHelper = + Function cuGetStreamHelper = getModule().getNamedFunction(cuGetStreamHelperName); auto cuStream = builder.create( loc, ArrayRef{getPointerType()}, diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index c1d4af3..97790a5 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -53,15 +53,15 @@ constexpr const char *kMallocHelperName = "mcuMalloc"; class GpuGenerateCubinAccessorsPass : public ModulePass { private: - Function *getMallocHelper(Location loc, Builder &builder) { - Function *result = getModule().getNamedFunction(kMallocHelperName); + Function getMallocHelper(Location loc, Builder &builder) { + Function result = getModule().getNamedFunction(kMallocHelperName); if (!result) { - result = new Function( + result = Function::create( loc, kMallocHelperName, builder.getFunctionType( ArrayRef{LLVM::LLVMType::getInt32Ty(llvmDialect)}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); - getModule().getFunctions().push_back(result); + getModule().push_back(result); } return result; } @@ -70,18 +70,18 @@ private: // data from blob. As there are currently no global constants, this uses a // sequence of store operations. // TODO(herhut): Use global constants instead. - Function *generateCubinAccessor(Builder &builder, Function &orig, - StringAttr blob) { + Function generateCubinAccessor(Builder &builder, Function &orig, + StringAttr blob) { Location loc = orig.getLoc(); SmallString<128> nameBuffer(orig.getName()); nameBuffer.append(kCubinGetterSuffix); // Generate a function that returns void*. - Function *result = new Function( + Function result = Function::create( loc, mlir::Identifier::get(nameBuffer, &getContext()), builder.getFunctionType(ArrayRef{}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); // Insert a body block that just returns the constant. - OpBuilder ob(result->getBody()); + OpBuilder ob(result.getBody()); ob.createBlock(); auto sizeConstant = ob.create( loc, LLVM::LLVMType::getInt32Ty(llvmDialect), @@ -115,18 +115,18 @@ public: void runOnModule() override { llvmDialect = getModule().getContext()->getRegisteredDialect(); - Builder builder(getModule().getContext()); + auto &module = getModule(); + Builder builder(&getContext()); - auto &functions = getModule().getFunctions(); + auto functions = module.getFunctions(); for (auto it = functions.begin(); it != functions.end();) { // Move iterator to after the current function so that potential insertion // of the accessor is after the kernel with cubin iself. - Function &orig = *it++; + Function orig = *it++; StringAttr cubinBlob = orig.getAttrOfType(kCubinAnnotation); if (!cubinBlob) continue; - it = - functions.insert(it, generateCubinAccessor(builder, orig, cubinBlob)); + module.insert(it, generateCubinAccessor(builder, orig, cubinBlob)); } } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 8727078..e849f6f 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -441,13 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern { createIndexConstant(rewriter, op->getLoc(), elementSize)}); // Insert the `malloc` declaration if it is not already present. - Function *mallocFunc = - op->getFunction()->getModule()->getNamedFunction("malloc"); + Function mallocFunc = + op->getFunction().getModule()->getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); - mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); - op->getFunction()->getModule()->getFunctions().push_back(mallocFunc); + mallocFunc = + Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + op->getFunction().getModule()->push_back(mallocFunc); } // Allocate the underlying buffer and store a pointer to it in the MemRef @@ -502,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present. - Function *freeFunc = - op->getFunction()->getModule()->getNamedFunction("free"); + Function freeFunc = op->getFunction().getModule()->getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); - freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); - op->getFunction()->getModule()->getFunctions().push_back(freeFunc); + freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + op->getFunction().getModule()->push_back(freeFunc); } auto type = transformed.memref()->getType().cast(); @@ -937,7 +937,7 @@ static void ensureDistinctSuccessors(Block &bb) { } void mlir::LLVM::ensureDistinctSuccessors(Module *m) { - for (auto &f : *m) { + for (auto f : *m) { for (auto &bb : f.getBlocks()) { ::ensureDistinctSuccessors(bb); } diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index ff19821..dafc8e7 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -365,7 +365,7 @@ struct UniformRealMulEwPattern : public OpRewritePattern { //===----------------------------------------------------------------------===// void LowerUniformRealMathPass::runOnFunction() { - auto &fn = getFunction(); + auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.push_back(llvm::make_unique(context)); @@ -386,7 +386,7 @@ static PassRegistration lowerUniformRealMathPass( //===----------------------------------------------------------------------===// void LowerUniformCastsPass::runOnFunction() { - auto &fn = getFunction(); + auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.push_back(llvm::make_unique(context)); diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 9dcc6df..8469fa2 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -106,7 +106,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, void ConvertConstPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back(llvm::make_unique(context)); applyPatternsGreedily(func, std::move(patterns)); diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index ea8095b..0c93146 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -95,7 +95,7 @@ public: void ConvertSimulatedQuantPass::runOnFunction() { bool hadFailure = false; OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back( llvm::make_unique(context, &hadFailure)); diff --git a/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/mlir/lib/ExecutionEngine/MemRefUtils.cpp index 5163603..f13b743d 100644 --- a/mlir/lib/ExecutionEngine/MemRefUtils.cpp +++ b/mlir/lib/ExecutionEngine/MemRefUtils.cpp @@ -67,10 +67,10 @@ allocMemRefDescriptor(Type type, bool allocateData = true, } llvm::Expected> -mlir::allocateMemRefArguments(Function *func, float initialValue) { +mlir::allocateMemRefArguments(Function func, float initialValue) { SmallVector args; - args.reserve(func->getNumArguments()); - for (const auto &arg : func->getArguments()) { + args.reserve(func.getNumArguments()); + for (const auto &arg : func.getArguments()) { auto descriptor = allocMemRefDescriptor(arg->getType(), /*allocateData=*/true, initialValue); @@ -79,10 +79,10 @@ mlir::allocateMemRefArguments(Function *func, float initialValue) { args.push_back(*descriptor); } - if (func->getType().getNumResults() > 1) + if (func.getType().getNumResults() > 1) return make_string_error("functions with more than 1 result not supported"); - for (Type resType : func->getType().getResults()) { + for (Type resType : func.getType().getResults()) { auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false); if (!descriptor) return descriptor.takeError(); diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index e39860b..5e8090b4 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -30,9 +30,9 @@ using namespace mlir::gpu; StringRef GPUDialect::getDialectName() { return "gpu"; } -bool GPUDialect::isKernel(Function *function) { +bool GPUDialect::isKernel(Function function) { UnitAttr isKernelAttr = - function->getAttrOfType(getKernelFuncAttrName()); + function.getAttrOfType(getKernelFuncAttrName()); return static_cast(isKernelAttr); } @@ -318,7 +318,7 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) { //===----------------------------------------------------------------------===// void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function *kernelFunc, Value *gridSizeX, + Function kernelFunc, Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, ArrayRef kernelOperands) { @@ -331,7 +331,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result, } void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function *kernelFunc, KernelDim3 gridSize, + Function kernelFunc, KernelDim3 gridSize, KernelDim3 blockSize, ArrayRef kernelOperands) { build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, @@ -366,23 +366,23 @@ LogicalResult LaunchFuncOp::verify() { return emitOpError("attribute 'kernel' must be a function"); } - auto *module = getOperation()->getFunction()->getModule(); - Function *kernelFunc = module->getNamedFunction(kernel()); + auto *module = getOperation()->getFunction().getModule(); + Function kernelFunc = module->getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; - if (!kernelFunc->getAttrOfType( + if (!kernelFunc.getAttrOfType( GPUDialect::getKernelFuncAttrName())) { return emitError("kernel function is missing the '") << GPUDialect::getKernelFuncAttrName() << "' attribute"; } - unsigned numKernelFuncArgs = kernelFunc->getNumArguments(); + unsigned numKernelFuncArgs = kernelFunc.getNumArguments(); if (getNumKernelOperands() != numKernelFuncArgs) { return emitOpError("got ") << getNumKernelOperands() << " kernel operands but expected " << numKernelFuncArgs; } - auto functionType = kernelFunc->getType(); + auto functionType = kernelFunc.getType(); for (unsigned i = 0; i < numKernelFuncArgs; ++i) { if (getKernelOperand(i)->getType() != functionType.getInput(i)) { return emitOpError("type of function argument ") diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 46363f0..f93febc 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc, // Add operations generating block/thread ids and gird/block dimensions at the // beginning of `kernelFunc` and replace uses of the respective function args. -static void injectGpuIndexOperations(Location loc, Function &kernelFunc) { +static void injectGpuIndexOperations(Location loc, Function kernelFunc) { OpBuilder OpBuilder(kernelFunc.getBody()); SmallVector indexOps; createForAllDimensions(OpBuilder, loc, indexOps); @@ -58,20 +58,20 @@ static void injectGpuIndexOperations(Location loc, Function &kernelFunc) { // Outline the `gpu.launch` operation body into a kernel function. Replace // `gpu.return` operations by `std.return` in the generated functions. -static Function *outlineKernelFunc(gpu::LaunchOp launchOp) { +static Function outlineKernelFunc(gpu::LaunchOp launchOp) { Location loc = launchOp.getLoc(); SmallVector kernelOperandTypes(launchOp.getKernelOperandTypes()); FunctionType type = FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); std::string kernelFuncName = - Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str(); - Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type); - outlinedFunc->getBody().takeBody(launchOp.getBody()); + Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str(); + Function outlinedFunc = Function::create(loc, kernelFuncName, type); + outlinedFunc.getBody().takeBody(launchOp.getBody()); Builder builder(launchOp.getContext()); - outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), - builder.getUnitAttr()); - injectGpuIndexOperations(loc, *outlinedFunc); - outlinedFunc->walk([](mlir::gpu::Return op) { + outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + injectGpuIndexOperations(loc, outlinedFunc); + outlinedFunc.walk([](mlir::gpu::Return op) { OpBuilder replacer(op); replacer.create(op.getLoc()); op.erase(); @@ -82,12 +82,12 @@ static Function *outlineKernelFunc(gpu::LaunchOp launchOp) { // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching // `kernelFunc`. static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, - Function &kernelFunc) { + Function kernelFunc) { OpBuilder builder(launchOp); SmallVector kernelOperandValues( launchOp.getKernelOperandValues()); builder.create( - launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(), + launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), launchOp.getBlockSizeOperandValues(), kernelOperandValues); launchOp.erase(); } @@ -98,11 +98,11 @@ class GpuKernelOutliningPass : public ModulePass { public: void runOnModule() override { ModuleManager moduleManager(&getModule()); - for (auto &func : getModule()) { + for (auto func : getModule()) { func.walk([&](mlir::gpu::LaunchOp op) { - Function *outlinedFunc = outlineKernelFunc(op); + Function outlinedFunc = outlineKernelFunc(op); moduleManager.insert(outlinedFunc); - convertToLaunchFuncOp(op, *outlinedFunc); + convertToLaunchFuncOp(op, outlinedFunc); }); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 8e3d578..346d35a 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -306,7 +306,7 @@ void ModuleState::initialize(Module *module) { initializeSymbolAliases(); // Walk the module and visit each operation. - for (auto &fn : *module) { + for (auto fn : *module) { visitType(fn.getType()); for (auto attr : fn.getAttrs()) ModuleState::visitAttribute(attr.second); @@ -342,7 +342,7 @@ public: void printAttribute(Attribute attr, bool mayElideType = false); void printType(Type type); - void print(Function *fn); + void print(Function fn); void printLocation(LocationAttr loc); void printAffineMap(AffineMap map); @@ -460,8 +460,8 @@ void ModulePrinter::print(Module *module) { state.printTypeAliases(os); // Print the module. - for (auto &fn : *module) - print(&fn); + for (auto fn : *module) + print(fn); } /// Print a floating point value in a way that the parser will be able to @@ -1186,7 +1186,7 @@ namespace { // CFG and ML functions. class FunctionPrinter : public ModulePrinter, private OpAsmPrinter { public: - FunctionPrinter(Function *function, ModulePrinter &other); + FunctionPrinter(Function function, ModulePrinter &other); // Prints the function as a whole. void print(); @@ -1275,7 +1275,7 @@ protected: void printValueID(Value *value, bool printResultNo = true) const; private: - Function *function; + Function function; /// This is the value ID for each SSA value in the current function. If this /// returns ~0, then the valueID has an entry in valueNames. @@ -1305,10 +1305,10 @@ private: }; } // end anonymous namespace -FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other) +FunctionPrinter::FunctionPrinter(Function function, ModulePrinter &other) : ModulePrinter(other), function(function) { - for (auto &block : *function) + for (auto &block : function) numberValuesInBlock(block); } @@ -1419,17 +1419,17 @@ void FunctionPrinter::print() { printFunctionSignature(); // Print out function attributes, if present. - auto attrs = function->getAttrs(); + auto attrs = function.getAttrs(); if (!attrs.empty()) { os << "\n attributes "; printOptionalAttrDict(attrs); } // Print the trailing location. - printTrailingLocation(function->getLoc()); + printTrailingLocation(function.getLoc()); - if (!function->empty()) { - printRegion(function->getBody(), /*printEntryBlockArgs=*/false, + if (!function.empty()) { + printRegion(function.getBody(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); os << "\n"; } @@ -1437,24 +1437,24 @@ void FunctionPrinter::print() { } void FunctionPrinter::printFunctionSignature() { - os << "func @" << function->getName() << '('; + os << "func @" << function.getName() << '('; - auto fnType = function->getType(); - bool isExternal = function->isExternal(); - for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) { + auto fnType = function.getType(); + bool isExternal = function.isExternal(); + for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) { if (i > 0) os << ", "; // If this is an external function, don't print argument labels. if (!isExternal) { - printOperand(function->getArgument(i)); + printOperand(function.getArgument(i)); os << ": "; } printType(fnType.getInput(i)); // Print the attributes for this argument. - printOptionalAttrDict(function->getArgAttrs(i)); + printOptionalAttrDict(function.getArgAttrs(i)); } os << ')'; @@ -1662,7 +1662,7 @@ void FunctionPrinter::printSuccessorAndUseList(Operation *term, } // Prints function with initialized module state. -void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); } +void ModulePrinter::print(Function fn) { FunctionPrinter(fn, *this).print(); } //===----------------------------------------------------------------------===// // print and dump methods @@ -1737,13 +1737,13 @@ void Value::print(raw_ostream &os) { void Value::dump() { print(llvm::errs()); } void Operation::print(raw_ostream &os) { - auto *function = getFunction(); + auto function = getFunction(); if (!function) { os << "<>\n"; return; } - ModuleState state(function->getContext()); + ModuleState state(function.getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(function, modulePrinter).print(this); } @@ -1754,13 +1754,13 @@ void Operation::dump() { } void Block::print(raw_ostream &os) { - auto *function = getFunction(); + auto function = getFunction(); if (!function) { os << "<>\n"; return; } - ModuleState state(function->getContext()); + ModuleState state(function.getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(function, modulePrinter).print(this); } @@ -1773,14 +1773,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) { os << "<>\n"; return; } - ModuleState state(getFunction()->getContext()); + ModuleState state(getFunction().getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(getFunction(), modulePrinter).printBlockName(this); } void Function::print(raw_ostream &os) { ModuleState state(getContext()); - ModulePrinter(os, state).print(this); + ModulePrinter(os, state).print(*this); } void Function::dump() { print(llvm::errs()); } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 01f9a06..9cbba0f 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -249,11 +249,6 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional loc, // FunctionAttr //===----------------------------------------------------------------------===// -FunctionAttr FunctionAttr::get(Function *value) { - assert(value && "Cannot get FunctionAttr for a null function"); - return get(value->getName(), value->getContext()); -} - FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) { return Base::get(ctx, StandardAttributes::Function, value, NoneType::get(ctx)); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index e7616f6..134f6e4 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -50,7 +50,7 @@ Operation *Block::getContainingOp() { return getParent() ? getParent()->getContainingOp() : nullptr; } -Function *Block::getFunction() { +Function Block::getFunction() { Block *block = this; while (auto *op = block->getContainingOp()) { block = op->getBlock(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 9b30205..89df642 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -177,8 +177,8 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); } -FunctionAttr Builder::getFunctionAttr(Function *value) { - return FunctionAttr::get(value); +FunctionAttr Builder::getFunctionAttr(Function value) { + return getFunctionAttr(value.getName()); } FunctionAttr Builder::getFunctionAttr(StringRef value) { return FunctionAttr::get(value, getContext()); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 4547452..e38b95f 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectHooks.h" +#include "mlir/IR/Function.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/ManagedStatic.h" @@ -68,6 +69,20 @@ Dialect::Dialect(StringRef name, MLIRContext *context) Dialect::~Dialect() {} +/// Verify an attribute from this dialect on the given function. Returns +/// failure if the verification failed, success otherwise. +LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) { + return success(); +} + +/// Verify an attribute from this dialect on the argument at 'argIndex' for +/// the given function. Returns failure if the verification failed, success +/// otherwise. +LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex, + NamedAttribute) { + return success(); +} + /// Parse an attribute registered to this dialect. Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const { emitError(loc) << "dialect '" << getNamespace() diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 7d17ed1..f8835f0 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -27,45 +27,50 @@ #include "llvm/ADT/Twine.h" using namespace mlir; +using namespace mlir::detail; -Function::Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs) +FunctionStorage::FunctionStorage(Location location, StringRef name, + FunctionType type, + ArrayRef attrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {} -Function::Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs) +FunctionStorage::FunctionStorage(Location location, StringRef name, + FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(attrs), argAttrs(argAttrs), body(this) {} MLIRContext *Function::getContext() { return getType().getContext(); } -Module *llvm::ilist_traits::getContainingModule() { +Module *llvm::ilist_traits::getContainingModule() { size_t Offset( size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr)))); - iplist *Anchor(static_cast *>(this)); + iplist *Anchor(static_cast *>(this)); return reinterpret_cast(reinterpret_cast(Anchor) - Offset); } /// This is a trait method invoked when a Function is added to a Module. We /// keep the module pointer and module symbol table up to date. -void llvm::ilist_traits::addNodeToList(Function *function) { - assert(!function->getModule() && "already in a module!"); +void llvm::ilist_traits::addNodeToList( + FunctionStorage *function) { + assert(!function->module && "already in a module!"); function->module = getContainingModule(); } /// This is a trait method invoked when a Function is removed from a Module. /// We keep the module pointer up to date. -void llvm::ilist_traits::removeNodeFromList(Function *function) { +void llvm::ilist_traits::removeNodeFromList( + FunctionStorage *function) { assert(function->module && "not already in a module!"); function->module = nullptr; } /// This is a trait method invoked when an operation is moved from one block /// to another. We keep the block pointer up to date. -void llvm::ilist_traits::transferNodesFromList( - ilist_traits &otherList, function_iterator first, +void llvm::ilist_traits::transferNodesFromList( + ilist_traits &otherList, function_iterator first, function_iterator last) { // If we are transferring functions within the same module, the Module // pointer doesn't need to be updated. @@ -82,8 +87,10 @@ void llvm::ilist_traits::transferNodesFromList( /// Unlink this function from its Module and delete it. void Function::erase() { - assert(getModule() && "Function has no parent"); - getModule()->getFunctions().erase(this); + if (auto *module = getModule()) + getModule()->functions.erase(impl); + else + delete impl; } /// Emit an error about fatal conditions with this function, reporting up to @@ -111,10 +118,10 @@ InFlightDiagnostic Function::emitRemark(const Twine &message) { /// Clone the internal blocks from this function into dest and all attributes /// from this function to dest. -void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { +void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) { // Add the attributes of this function to dest. llvm::MapVector newAttrs; - for (auto &attr : dest->getAttrs()) + for (auto &attr : dest.getAttrs()) newAttrs.insert(attr); for (auto &attr : getAttrs()) { auto insertPair = newAttrs.insert(attr); @@ -125,10 +132,10 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { assert((insertPair.second || insertPair.first->second == attr.second) && "the two functions have incompatible attributes"); } - dest->setAttrs(newAttrs.takeVector()); + dest.setAttrs(newAttrs.takeVector()); // Clone the body. - body.cloneInto(&dest->body, mapper); + impl->body.cloneInto(&dest.impl->body, mapper); } /// Create a deep copy of this function and all of its blocks, remapping @@ -136,8 +143,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { /// provided (leaving them alone if no entry is present). Replaces references /// to cloned sub-values with the corresponding value that is copied, and adds /// those mappings to the mapper. -Function *Function::clone(BlockAndValueMapping &mapper) { - FunctionType newType = type; +Function Function::clone(BlockAndValueMapping &mapper) { + FunctionType newType = impl->type; // If the function has a body, then the user might be deleting arguments to // the function by specifying them in the mapper. If so, we don't add the @@ -147,23 +154,23 @@ Function *Function::clone(BlockAndValueMapping &mapper) { SmallVector inputTypes; for (unsigned i = 0, e = getNumArguments(); i != e; ++i) if (!mapper.contains(getArgument(i))) - inputTypes.push_back(type.getInput(i)); - newType = FunctionType::get(inputTypes, type.getResults(), getContext()); + inputTypes.push_back(newType.getInput(i)); + newType = FunctionType::get(inputTypes, newType.getResults(), getContext()); } // Create the new function. - Function *newFunc = new Function(getLoc(), getName(), newType); + Function newFunc = Function::create(getLoc(), getName(), newType); /// Set the argument attributes for arguments that aren't being replaced. for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) if (isExternalFn || !mapper.contains(getArgument(i))) - newFunc->setArgAttrs(destI++, getArgAttrs(i)); + newFunc.setArgAttrs(destI++, getArgAttrs(i)); /// Clone the current function into the new one and return it. cloneInto(newFunc, mapper); return newFunc; } -Function *Function::clone() { +Function Function::clone() { BlockAndValueMapping mapper; return clone(mapper); } @@ -178,7 +185,7 @@ void Function::addEntryBlock() { assert(empty() && "function already has an entry block"); auto *entry = new Block(); push_back(entry); - entry->addArguments(type.getInputs()); + entry->addArguments(impl->type.getInputs()); } void Function::walk(const std::function &callback) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 83171f1..f953cd2 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ Operation *Operation::getParentOp() { return block ? block->getContainingOp() : nullptr; } -Function *Operation::getFunction() { +Function Operation::getFunction() { return block ? block->getFunction() : nullptr; } @@ -861,12 +861,13 @@ static LogicalResult verifyBBArguments(Operation::operand_range operands, } static LogicalResult verifyTerminatorSuccessors(Operation *op) { + auto *parent = op->getContainingRegion(); + // Verify that the operands lines up with the BB arguments in the successor. - Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { auto *succ = op->getSuccessor(i); - if (succ->getFunction() != fn) - return op->emitError("reference to block defined in another function"); + if (succ->getParent() != parent) + return op->emitError("reference to block defined in another region"); if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op))) return failure(); } diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 992d911..74c71b7 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -21,7 +21,7 @@ #include "mlir/IR/Operation.h" using namespace mlir; -Region::Region(Function *container) : container(container) {} +Region::Region(Function container) : container(container.impl) {} Region::Region(Operation *container) : container(container) {} @@ -38,7 +38,7 @@ MLIRContext *Region::getContext() { assert(!container.isNull() && "region is not attached to a container"); if (auto *inst = getContainingOp()) return inst->getContext(); - return getContainingFunction()->getContext(); + return getContainingFunction().getContext(); } /// Return a location for this region. This is the location attached to the @@ -47,7 +47,7 @@ Location Region::getLoc() { assert(!container.isNull() && "region is not attached to a container"); if (auto *inst = getContainingOp()) return inst->getLoc(); - return getContainingFunction()->getLoc(); + return getContainingFunction().getLoc(); } Region *Region::getContainingRegion() { @@ -60,8 +60,8 @@ Operation *Region::getContainingOp() { return container.dyn_cast(); } -Function *Region::getContainingFunction() { - return container.dyn_cast(); +Function Region::getContainingFunction() { + return container.dyn_cast(); } bool Region::isProperAncestor(Region *other) { diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index a0819a7..dafbd48 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -22,8 +22,8 @@ using namespace mlir; /// Build a symbol table with the symbols within the given module. SymbolTable::SymbolTable(Module *module) : context(module->getContext()) { - for (auto &func : *module) { - auto inserted = symbolTable.insert({func.getName(), &func}); + for (auto func : *module) { + auto inserted = symbolTable.insert({func.getName(), func}); (void)inserted; assert(inserted.second && "expected module to contain uniquely named functions"); @@ -32,34 +32,34 @@ SymbolTable::SymbolTable(Module *module) : context(module->getContext()) { /// Look up a symbol with the specified name, returning null if no such name /// exists. Names never include the @ on them. -Function *SymbolTable::lookup(StringRef name) const { +Function SymbolTable::lookup(StringRef name) const { return lookup(Identifier::get(name, context)); } /// Look up a symbol with the specified name, returning null if no such name /// exists. Names never include the @ on them. -Function *SymbolTable::lookup(Identifier name) const { +Function SymbolTable::lookup(Identifier name) const { return symbolTable.lookup(name); } /// Erase the given symbol from the table. -void SymbolTable::erase(Function *symbol) { - auto it = symbolTable.find(symbol->getName()); +void SymbolTable::erase(Function symbol) { + auto it = symbolTable.find(symbol.getName()); if (it != symbolTable.end() && it->second == symbol) symbolTable.erase(it); } /// Insert a new symbol into the table, and rename it as necessary to avoid /// collisions. -void SymbolTable::insert(Function *symbol) { +void SymbolTable::insert(Function symbol) { // Add this symbol to the symbol table, uniquing the name if a conflict is // detected. - if (symbolTable.insert({symbol->getName(), symbol}).second) + if (symbolTable.insert({symbol.getName(), symbol}).second) return; // If a conflict was detected, then the function will not have been added to // the symbol table. Try suffixes until we get to a unique name that works. - SmallString<128> nameBuffer(symbol->getName()); + SmallString<128> nameBuffer(symbol.getName()); unsigned originalLength = nameBuffer.size(); // Iteratively try suffixes until we find one that isn't used. We use a @@ -68,6 +68,6 @@ void SymbolTable::insert(Function *symbol) { nameBuffer.resize(originalLength); nameBuffer += '_'; nameBuffer += std::to_string(uniquingCounter++); - symbol->setName(Identifier::get(nameBuffer, context)); - } while (!symbolTable.insert({symbol->getName(), symbol}).second); + symbol.setName(Identifier::get(nameBuffer, context)); + } while (!symbolTable.insert({symbol.getName(), symbol}).second); } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 073c3b3..65a98f7 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -30,7 +30,7 @@ Operation *Value::getDefiningOp() { } /// Return the function that this Value is defined in. -Function *Value::getFunction() { +Function Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast(this)->getFunction(); @@ -84,7 +84,7 @@ void IRObjectWithUseList::dropAllUses() { //===----------------------------------------------------------------------===// /// Return the function that this argument is defined in. -Function *BlockArgument::getFunction() { +Function BlockArgument::getFunction() { if (auto *owner = getOwner()) return owner->getFunction(); return nullptr; @@ -92,6 +92,6 @@ Function *BlockArgument::getFunction() { /// Returns if the current argument is a function argument. bool BlockArgument::isFunctionArgument() { - auto *containingFn = getFunction(); - return containingFn && &containingFn->front() == getOwner(); + auto containingFn = getFunction(); + return containingFn && &containingFn.front() == getOwner(); } diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 0d3a5ca..0dbf63a 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -816,12 +816,12 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const { } /// Verify LLVMIR function argument attributes. -LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func, +LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func, unsigned argIdx, NamedAttribute argAttr) { // Check that llvm.noalias is a boolean attribute. if (argAttr.first == "llvm.noalias" && !argAttr.second.isa()) - return func->emitError() + return func.emitError() << "llvm.noalias argument attribute of non boolean type"; return success(); } diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp index 7ddb7b0c..5761cc6 100644 --- a/mlir/lib/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Linalg/Transforms/Fusion.cpp @@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } -static void fuseLinalgOps(Function &f, ArrayRef tileSizes) { +static void fuseLinalgOps(Function f, ArrayRef tileSizes) { OperationFolder state; DenseSet eraseSet; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index a8099aa..5fe4f07 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -170,12 +170,13 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. - auto *module = op->getFunction()->getModule(); - Function *mallocFunc = module->getNamedFunction("malloc"); + auto *module = op->getFunction().getModule(); + Function mallocFunc = module->getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); - mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); - module->getFunctions().push_back(mallocFunc); + mallocFunc = + Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + module->push_back(mallocFunc); } // Get MLIR types for injecting element pointer. @@ -230,12 +231,12 @@ public: auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. - auto *module = op->getFunction()->getModule(); - Function *freeFunc = module->getNamedFunction("free"); + auto *module = op->getFunction().getModule(); + Function freeFunc = module->getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); - freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); - module->getFunctions().push_back(freeFunc); + freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + module->push_back(freeFunc); } // Get MLIR types for extracting element pointer. @@ -572,37 +573,37 @@ public: // Create a function definition which takes as argument pointers to the input // types and returns pointers to the output types. -static Function *getLLVMLibraryCallImplDefinition(Function *libFn) { - auto implFnName = (libFn->getName().str() + "_impl"); - auto module = libFn->getModule(); - if (auto *f = module->getNamedFunction(implFnName)) { +static Function getLLVMLibraryCallImplDefinition(Function libFn) { + auto implFnName = (libFn.getName().str() + "_impl"); + auto module = libFn.getModule(); + if (auto f = module->getNamedFunction(implFnName)) { return f; } SmallVector fnArgTypes; - for (auto t : libFn->getType().getInputs()) { + for (auto t : libFn.getType().getInputs()) { assert(t.isa() && "Expected LLVM Type for argument while generating library Call " "Implementation Definition"); fnArgTypes.push_back(t.cast().getPointerTo()); } - auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext()); + auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext()); // Insert the implementation function definition. - auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType); - module->getFunctions().push_back(implFnDefn); + auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType); + module->push_back(implFnDefn); return implFnDefn; } // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition. template -static Function *getLLVMLibraryCallDeclaration(Operation *op, - LLVMTypeConverter &lowering, - PatternRewriter &rewriter) { +static Function getLLVMLibraryCallDeclaration(Operation *op, + LLVMTypeConverter &lowering, + PatternRewriter &rewriter) { assert(isa(op)); auto fnName = LinalgOp::getLibraryCallName(); - auto module = op->getFunction()->getModule(); - if (auto *f = module->getNamedFunction(fnName)) { + auto module = op->getFunction().getModule(); + if (auto f = module->getNamedFunction(fnName)) { return f; } @@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op, "Library call for linalg operation can be generated only for ops that " "have void return types"); auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); - auto libFn = new Function(op->getLoc(), fnName, libFnType); - module->getFunctions().push_back(libFn); + auto libFn = Function::create(op->getLoc(), fnName, libFnType); + module->push_back(libFn); // Return after creating the function definition. The body will be created // later. return libFn; } -static void getLLVMLibraryCallDefinition(Function *fn, +static void getLLVMLibraryCallDefinition(Function fn, LLVMTypeConverter &lowering) { // Generate the implementation function definition. auto implFn = getLLVMLibraryCallImplDefinition(fn); // Generate the function body. - fn->addEntryBlock(); + fn.addEntryBlock(); - OpBuilder builder(fn->getBody()); - edsc::ScopedContext scope(builder, fn->getLoc()); + OpBuilder builder(fn.getBody()); + edsc::ScopedContext scope(builder, fn.getLoc()); SmallVector implFnArgs; // Create a constant 1. auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()), - IntegerAttr::get(IndexType::get(fn->getContext()), 1)); - for (auto arg : fn->getArguments()) { + IntegerAttr::get(IndexType::get(fn.getContext()), 1)); + for (auto arg : fn.getArguments()) { // Allocate a stack for storing the argument value. The stack is passed to // the implementation function. auto alloca = @@ -665,17 +666,17 @@ public: return convertLinalgType(t, *this); } - void addLibraryFnDeclaration(Function *fn) { + void addLibraryFnDeclaration(Function fn) { libraryFnDeclarations.push_back(fn); } - ArrayRef getLibraryFnDeclarations() { + ArrayRef getLibraryFnDeclarations() { return libraryFnDeclarations; } private: /// List of library functions declarations needed during dialect conversion - SmallVector libraryFnDeclarations; + SmallVector libraryFnDeclarations; }; } // end anonymous namespace @@ -692,7 +693,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Only emit library call declaration. Fill in the body later. - auto *f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); + auto f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); static_cast(lowering).addLibraryFnDeclaration(f); auto fAttr = rewriter.getFunctionAttr(f); @@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) { void LowerLinalgToLLVMPass::runOnModule() { auto &module = getModule(); - for (auto &f : module.getFunctions()) { + for (auto f : module.getFunctions()) { lowerLinalgSubViewOps(f); lowerLinalgForToCFG(f); if (failed(lowerAffineConstructs(f))) diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index d31ba5b..2e616c3 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -104,9 +104,8 @@ struct LowerLinalgToLoopsPass : public FunctionPass { } // namespace void LowerLinalgToLoopsPass::runOnFunction() { - auto &f = getFunction(); OperationFolder state; - f.walk([&state](LinalgOp linalgOp) { + getFunction().walk([&state](LinalgOp linalgOp) { emitLinalgOpAsLoops(linalgOp, state); linalgOp.getOperation()->erase(); }); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index c63e1cf..2f752b2 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -259,7 +259,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef tileSizes, return tileLinalgOp(op, tileSizeValues, state); } -static void tileLinalgOps(Function &f, ArrayRef tileSizes) { +static void tileLinalgOps(Function f, ArrayRef tileSizes) { OperationFolder state; f.walk([tileSizes, &state](LinalgOp op) { auto opLoopsPair = tileLinalgOp(op, tileSizes, state); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 44f0596..4af2f09 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -254,7 +254,7 @@ public: /// trailing-location ::= location? /// template - ParseResult parseOptionalTrailingLocation(Owner *owner) { + ParseResult parseOptionalTrailingLocation(Owner &owner) { // If there is a 'loc' we parse a trailing location. if (!getToken().is(Token::kw_loc)) return success(); @@ -263,7 +263,7 @@ public: LocationAttr directLoc; if (parseLocation(directLoc)) return failure(); - owner->setLoc(directLoc); + owner.setLoc(directLoc); return success(); } @@ -2472,8 +2472,8 @@ namespace { /// operations. class OperationParser : public Parser { public: - OperationParser(ParserState &state, Function *function) - : Parser(state), function(function), opBuilder(function->getBody()) {} + OperationParser(ParserState &state, Function function) + : Parser(state), function(function), opBuilder(function.getBody()) {} ~OperationParser(); @@ -2588,7 +2588,7 @@ public: Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); private: - Function *function; + Function function; /// Returns the info for a block at the current scope for the given name. std::pair &getBlockInfoByName(StringRef name) { @@ -2690,7 +2690,7 @@ ParseResult OperationParser::popSSANameScope() { for (auto entry : forwardRefInCurrentScope) { errors.push_back({entry.second.getPointer(), entry.first}); // Add this block to the top-level region to allow for automatic cleanup. - function->push_back(entry.first); + function.push_back(entry.first); } llvm::array_pod_sort(errors.begin(), errors.end()); @@ -2984,7 +2984,7 @@ ParseResult OperationParser::parseOperation() { } // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(op)) + if (parseOptionalTrailingLocation(*op)) return failure(); return success(); @@ -4049,17 +4049,17 @@ ParseResult ModuleParser::parseFunc(Module *module) { } // Okay, the function signature was parsed correctly, create the function now. - auto *function = - new Function(getEncodedSourceLocation(loc), name, type, attrs); - module->getFunctions().push_back(function); + auto function = + Function::create(getEncodedSourceLocation(loc), name, type, attrs); + module->push_back(function); // Parse an optional trailing location. if (parseOptionalTrailingLocation(function)) return failure(); // Add the attributes to the function arguments. - for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) - function->setArgAttrs(i, argAttrs[i]); + for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) + function.setArgAttrs(i, argAttrs[i]); // External functions have no body. if (getToken().isNot(Token::l_brace)) @@ -4076,11 +4076,11 @@ ParseResult ModuleParser::parseFunc(Module *module) { // Parse the function body. auto parser = OperationParser(getState(), function); - if (parser.parseRegion(function->getBody(), entryArgs)) + if (parser.parseRegion(function.getBody(), entryArgs)) return failure(); // Verify that a valid function body was parsed. - if (function->empty()) + if (function.empty()) return emitError(braceLoc, "function must have a body"); return parser.finalize(braceLoc); diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 868d492..057f265 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -61,12 +61,12 @@ private: static void printIR(const llvm::Any &ir, bool printModuleScope, raw_ostream &out) { // Check for printing at module scope. - if (printModuleScope && llvm::any_isa(ir)) { - Function *function = llvm::any_cast(ir); + if (printModuleScope && llvm::any_isa(ir)) { + Function function = llvm::any_cast(ir); // Print the function name and a newline before the Module. - out << " (function: " << function->getName() << ")\n"; - function->getModule()->print(out); + out << " (function: " << function.getName() << ")\n"; + function.getModule()->print(out); return; } @@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, out << "\n"; // Print the given function. - if (llvm::any_isa(ir)) { - llvm::any_cast(ir)->print(out); + if (llvm::any_isa(ir)) { + llvm::any_cast(ir).print(out); return; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 2f605b6..27ec74c 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -46,8 +46,7 @@ static llvm::cl::opt void Pass::anchor() {} /// Forwarding function to execute this pass. -LogicalResult FunctionPassBase::run(Function *fn, - FunctionAnalysisManager &fam) { +LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) { // Initialize the pass state. passState.emplace(fn, fam); @@ -115,7 +114,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs) } /// Run all of the passes in this manager over the current function. -LogicalResult detail::FunctionPassExecutor::run(Function *function, +LogicalResult detail::FunctionPassExecutor::run(Function function, FunctionAnalysisManager &fam) { // Run each of the held passes. for (auto &pass : passes) @@ -141,7 +140,7 @@ LogicalResult detail::ModulePassExecutor::run(Module *module, /// Utility to run the given function and analysis manager on a provided /// function pass executor. static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, - Function *func, + Function func, FunctionAnalysisManager &fam) { // Run the function pipeline over the provided function. auto result = fpe.run(func, fam); @@ -158,14 +157,14 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, /// module. void ModuleToFunctionPassAdaptor::runOnModule() { ModuleAnalysisManager &mam = getAnalysisManager(); - for (auto &func : getModule()) { + for (auto func : getModule()) { // Skip external functions. if (func.isExternal()) continue; // Run the held function pipeline over the current function. - auto fam = mam.slice(&func); - if (failed(runFunctionPipeline(fpe, &func, fam))) + auto fam = mam.slice(func); + if (failed(runFunctionPipeline(fpe, func, fam))) return signalPassFailure(); // Clear out any computed function analyses. These analyses won't be used @@ -189,10 +188,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() { // Run a prepass over the module to collect the functions to execute a over. // This ensures that an analysis manager exists for each function, as well as // providing a queue of functions to execute over. - std::vector> funcAMPairs; - for (auto &func : getModule()) + std::vector> funcAMPairs; + for (auto func : getModule()) if (!func.isExternal()) - funcAMPairs.emplace_back(&func, mam.slice(&func)); + funcAMPairs.emplace_back(func, mam.slice(func)); // A parallel diagnostic handler that provides deterministic diagnostic // ordering. @@ -340,8 +339,8 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const { } /// Create an analysis slice for the given child function. -FunctionAnalysisManager ModuleAnalysisManager::slice(Function *func) { - assert(func->getModule() == moduleAnalyses.getIRUnit() && +FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) { + assert(func.getModule() == moduleAnalyses.getIRUnit() && "function has a different parent module"); auto it = functionAnalyses.find(func); if (it == functionAnalyses.end()) { diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index 46addfb..d2563fb 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -48,7 +48,7 @@ public: FunctionPassExecutor(const FunctionPassExecutor &rhs); /// Run the executor on the given function. - LogicalResult run(Function *function, FunctionAnalysisManager &fam); + LogicalResult run(Function function, FunctionAnalysisManager &fam); /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index 375a64d8..3f26bf0 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -71,7 +71,7 @@ void AddDefaultStatsPass::runOnFunction() { void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { - auto &func = getFunction(); + auto func = getFunction(); // Insert stats for each argument. for (auto *arg : func.getArguments()) { diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index dec4ea9..169fec3 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() { void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { CAGSlice cag(solverContext); - for (auto &f : getModule()) { + for (auto f : getModule()) { f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); }); } config.finalizeAnchors(cag); diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index ed3b095..6b376db 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -58,7 +58,7 @@ public: void RemoveInstrumentationPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back( llvm::make_unique>(context)); diff --git a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp index 3add211..543b730 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp @@ -36,11 +36,11 @@ using namespace mlir; // block. The created block will be terminated by `std.return`. Block *createOneBlockFunction(Builder builder, Module *module) { auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{}); - auto *fn = new Function(builder.getUnknownLoc(), "spirv_module", fnType); - module->getFunctions().push_back(fn); + auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType); + module->push_back(fn); auto *block = new Block(); - fn->push_back(block); + fn.push_back(block); OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName()); ReturnOp::build(&builder, &state); diff --git a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp index ebdcaf7..33572d5 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp @@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) { // wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed // to take in the SPIR-V ModuleOp directly after module and function are // migrated to be general ops. - for (auto &fn : *module) { + for (auto fn : *module) { fn.walk([&](spirv::ModuleOp spirvModule) { if (done) { spirvModule.emitError("found more than one 'spv.module' op"); diff --git a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp index 1a8d79c..1ce2b69 100644 --- a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp +++ b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp @@ -42,7 +42,7 @@ class StdOpsToSPIRVConversionPass void StdOpsToSPIRVConversionPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); populateWithGenerated(func.getContext(), &patterns); applyPatternsGreedily(func, std::move(patterns)); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 6d5073f..9fc216e 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -440,14 +440,14 @@ static LogicalResult verify(CallOp op) { auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' function attribute"); - auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. - auto fnType = fn->getType(); + auto fnType = fn.getType(); if (fnType.getNumInputs() != op.getNumOperands()) return op.emitOpError("incorrect number of operands for callee"); @@ -1107,13 +1107,13 @@ static LogicalResult verify(ConstantOp &op) { return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); // Check that the referenced function has the correct type. - if (fn->getType() != type) + if (fn.getType() != type) return op.emitOpError("reference to function with mismatched type"); return success(); @@ -1876,10 +1876,10 @@ static void print(OpAsmPrinter *p, ReturnOp op) { } static LogicalResult verify(ReturnOp op) { - auto *function = op.getOperation()->getFunction(); + auto function = op.getOperation()->getFunction(); // The operand number and types must match the function signature. - const auto &results = function->getType().getResults(); + const auto &results = function.getType().getResults(); if (op.getNumOperands() != results.size()) return op.emitOpError("has ") << op.getNumOperands() diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 74ade94..1e84092 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -69,7 +69,7 @@ std::unique_ptr mlir::translateModuleToNVVMIR(Module &m) { // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the // function as a kernel. - for (Function &func : m) { + for (Function func : m) { if (!func.getAttrOfType(gpu::GPUDialect::getKernelFuncAttrName())) continue; @@ -89,20 +89,21 @@ std::unique_ptr mlir::translateModuleToNVVMIR(Module &m) { return llvmModule; } -static TranslateFromMLIRRegistration registration( - "mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) { - if (!module) - return true; +static TranslateFromMLIRRegistration + registration("mlir-to-nvvmir", + [](Module *module, llvm::StringRef outputFilename) { + if (!module) + return true; - auto llvmModule = mlir::translateModuleToNVVMIR(*module); - if (!llvmModule) - return true; + auto llvmModule = mlir::translateModuleToNVVMIR(*module); + if (!llvmModule) + return true; - auto file = openOutputFile(outputFilename); - if (!file) - return true; + auto file = openOutputFile(outputFilename); + if (!file) + return true; - llvmModule->print(file->os(), nullptr); - file->keep(); - return false; - }); + llvmModule->print(file->os(), nullptr); + file->keep(); + return false; + }); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ef286cb..4a68ac7 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) { bool ModuleTranslation::convertFunctions() { // Declare all functions first because there may be function calls that form a // call graph with cycles. - for (Function &function : mlirModule) { + for (Function function : mlirModule) { mlir::BoolAttr isVarArgsAttr = function.getAttrOfType("std.varargs"); bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue(); @@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() { } // Convert functions. - for (Function &function : mlirModule) { + for (Function function : mlirModule) { // Ignore external functions. if (function.isExternal()) continue; diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 8a2002c..394b3ef 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -40,7 +40,7 @@ struct Canonicalizer : public FunctionPass { void Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index be60ada..84f00b9 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -849,7 +849,7 @@ struct FunctionConverter { /// error, success otherwise. If 'signatureConversion' is provided, the /// arguments of the entry block are updated accordingly. LogicalResult - convertFunction(Function *f, + convertFunction(Function f, TypeConverter::SignatureConversion *signatureConversion); /// Converts the given region starting from the entry block and following the @@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, } LogicalResult FunctionConverter::convertFunction( - Function *f, TypeConverter::SignatureConversion *signatureConversion) { + Function f, TypeConverter::SignatureConversion *signatureConversion) { // If this is an external function, there is nothing else to do. - if (f->isExternal()) + if (f.isExternal()) return success(); - DialectConversionRewriter rewriter(f->getBody(), typeConverter); + DialectConversionRewriter rewriter(f.getBody(), typeConverter); // Update the signature of the entry block. if (signatureConversion) { rewriter.argConverter.convertSignature( - &f->getBody().front(), *signatureConversion, rewriter.mapping); + &f.getBody().front(), *signatureConversion, rewriter.mapping); } // Rewrite the function body. if (failed( - convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) { + convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) { // Reset any of the generated rewrites. rewriter.discardRewrites(); return failure(); @@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const // applyConversionPatterns //===----------------------------------------------------------------------===// -namespace { -/// This class represents a function to be converted. It allows for converting -/// the body of functions and the signature in two phases. -struct ConvertedFunction { - ConvertedFunction(Function *fn, FunctionType newType, - ArrayRef newFunctionArgAttrs) - : fn(fn), newType(newType), - newFunctionArgAttrs(newFunctionArgAttrs.begin(), - newFunctionArgAttrs.end()) {} - - /// The function to convert. - Function *fn; - /// The new type and argument attributes for the function. - FunctionType newType; - SmallVector newFunctionArgAttrs; -}; -} // end anonymous namespace - /// Convert the given module with the provided conversion patterns and type /// conversion object. If conversion fails for specific functions, those /// functions remains unmodified. @@ -1149,37 +1131,33 @@ LogicalResult mlir::applyConversionPatterns(Module &module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { - std::vector allFunctions; - allFunctions.reserve(module.getFunctions().size()); - for (auto &func : module) - allFunctions.push_back(&func); + SmallVector allFunctions(module.getFunctions()); return applyConversionPatterns(allFunctions, target, converter, std::move(patterns)); } /// Convert the given functions with the provided conversion patterns. LogicalResult mlir::applyConversionPatterns( - ArrayRef fns, ConversionTarget &target, + MutableArrayRef fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { if (fns.empty()) return success(); // Build the function converter. - FunctionConverter funcConverter(fns.front()->getContext(), target, patterns, - &converter); + auto *ctx = fns.front().getContext(); + FunctionConverter funcConverter(ctx, target, patterns, &converter); // Try to convert each of the functions within the module. - auto *ctx = fns.front()->getContext(); - for (auto *func : fns) { + for (auto func : fns) { // Convert the function type using the type converter. auto conversion = - converter.convertSignature(func->getType(), func->getAllArgAttrs()); + converter.convertSignature(func.getType(), func.getAllArgAttrs()); if (!conversion) return failure(); // Update the function signature. - func->setType(conversion->getConvertedType(ctx)); - func->setAllArgAttrs(conversion->getConvertedArgAttrs()); + func.setType(conversion->getConvertedType(ctx)); + func.setAllArgAttrs(conversion->getConvertedArgAttrs()); // Convert the body of this function. if (failed(funcConverter.convertFunction(func, &*conversion))) @@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns( /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LogicalResult -mlir::applyConversionPatterns(Function &fn, ConversionTarget &target, +mlir::applyConversionPatterns(Function fn, ConversionTarget &target, OwningRewritePatternList &&patterns) { // Convert the body of this function. FunctionConverter converter(fn.getContext(), target, patterns); - return converter.convertFunction(&fn, /*signatureConversion=*/nullptr); + return converter.convertFunction(fn, /*signatureConversion=*/nullptr); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5a926ce..a3aa092 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -214,7 +214,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED emitRemarkForBlock(Block &block) { auto *op = block.getContainingOp(); - return op ? op->emitRemark() : block.getFunction()->emitRemark(); + return op ? op->emitRemark() : block.getFunction().emitRemark(); } /// Creates a buffer in the faster memory space for the specified region; @@ -246,8 +246,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, OpBuilder &b = region.isWrite() ? epilogue : prologue; // Builder to create constants at the top level. - auto *func = block->getFunction(); - OpBuilder top(func->getBody()); + auto func = block->getFunction(); + OpBuilder top(func.getBody()); auto loc = region.loc; auto *memref = region.memref; @@ -751,14 +751,14 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (auto *op = block->getContainingOp()) op->emitError(str); else - block->getFunction()->emitError(str); + block->getFunction().emitError(str); } return totalDmaBuffersSizeInBytes; } void DmaGeneration::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); OpBuilder topBuilder(f.getBody()); zeroIndex = topBuilder.create(f.getLoc(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8d2e75b..77b944f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -257,7 +257,7 @@ public: // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(Function &f); + bool init(Function f); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -637,7 +637,7 @@ public: // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. -bool MemRefDependenceGraph::init(Function &f) { +bool MemRefDependenceGraph::init(Function f) { DenseMap> memrefAccesses; // TODO: support multi-block functions. @@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); // Builder to create constants at the top level. - OpBuilder top(forInst->getFunction()->getBody()); + OpBuilder top(forInst->getFunction().getBody()); // Create new memref type based on slice bounds. auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); @@ -1750,9 +1750,9 @@ public: }; // Search for siblings which load the same memref function argument. - auto *fn = dstNode->op->getFunction(); - for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { - for (auto *user : fn->getArgument(i)->getUsers()) { + auto fn = dstNode->op->getFunction(); + for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { + for (auto *user : fn.getArgument(i)->getUsers()) { if (auto loadOp = dyn_cast(user)) { // Gather loops surrounding 'use'. SmallVector loops; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index c1be6e8..2744e5c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function &f, +static void getTileableBands(Function f, std::vector> *bands) { // Get maximal perfect nest of 'affine.for' insts starting from root // (inclusive). diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 0595392..6f13f62 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -92,8 +92,8 @@ void LoopUnroll::runOnFunction() { // Store innermost loops as we walk. std::vector loops; - void walkPostOrder(Function *f) { - for (auto &b : *f) + void walkPostOrder(Function f) { + for (auto &b : f) walkPostOrder(b.begin(), b.end()); } @@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. - Function &func = getFunction(); + Function func = getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; - ilg.walkPostOrder(&func); + ilg.walkPostOrder(func); auto &loops = ilg.loops; if (loops.empty()) break; diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 77a23b1..df30e27 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -726,7 +726,7 @@ public: } // end namespace -LogicalResult mlir::lowerAffineConstructs(Function &function) { +LogicalResult mlir::lowerAffineConstructs(Function function) { OwningRewritePatternList patterns; RewriteListBuildergetFunction()->print(dbgs())); + LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -667,7 +667,7 @@ static bool emitSlice(MaterializationState *state, /// because we currently disallow vectorization of defs that come from another /// scope. /// TODO(ntv): please document return value. -static bool materialize(Function *f, const SetVector &terminators, +static bool materialize(Function f, const SetVector &terminators, MaterializationState *state) { DenseSet seen; DominanceInfo domInfo(f); @@ -721,7 +721,7 @@ static bool materialize(Function *f, const SetVector &terminators, return true; } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); } return false; } @@ -731,13 +731,13 @@ void MaterializeVectorsPass::runOnFunction() { NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. - Function *f = &getFunction(); - if (f->getBlocks().size() != 1) + Function f = getFunction(); + if (f.getBlocks().size() != 1) return; using matcher::Op; LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); MaterializationState state(hwVectorSize); // Get the hardware vector type. diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index c5676af..1208e2f 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -212,7 +212,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - Function &f = getFunction(); + Function f = getFunction(); if (f.getBlocks().size() != 1) { markAllAnalysesPreserved(); return; diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index f97f549..c7c3621 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass { } // end anonymous namespace void StripDebugInfo::runOnFunction() { - Function &func = getFunction(); + Function func = getFunction(); auto unknownLoc = UnknownLoc::get(&getContext()); // Strip the debug info from the function and its operations. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 47ca378..e185f70 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -44,7 +44,7 @@ namespace { /// applies the locally optimal patterns in a roughly "bottom up" way. class GreedyPatternRewriteDriver : public PatternRewriter { public: - explicit GreedyPatternRewriteDriver(Function &fn, + explicit GreedyPatternRewriteDriver(Function fn, OwningRewritePatternList &&patterns) : PatternRewriter(fn.getBody()), matcher(std::move(patterns)) { worklist.reserve(64); @@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { /// patterns in a greedy work-list driven manner. Return true if no more /// patterns can be matched in the result function. /// -bool mlir::applyPatternsGreedily(Function &fn, +bool mlir::applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); bool converged = driver.simplifyFunction(maxPatternMatchIterations); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 728123f..4ddf93c 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { - OpBuilder topBuilder(op->getFunction()->getBody()); + OpBuilder topBuilder(op->getFunction().getBody()); auto constOp = topBuilder.create( forOp.getLoc(), forOp.getConstantLowerBound()); iv->replaceAllUsesWith(constOp); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 39a05d8..3fca26b 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1194,7 +1194,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. void Vectorize::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); if (!fastestVaryingPattern.empty() && fastestVaryingPattern.size() != vectorSizes.size()) { f.emitRemark("Fastest varying pattern specified with different size than " @@ -1220,7 +1220,7 @@ void Vectorize::runOnFunction() { unsigned patternDepth = pat.getDepth(); SmallVector matches; - pat.match(&f, &matches); + pat.match(f, &matches); // Iterate over all the top-level matches and vectorize eagerly. // This automatically prunes intersecting matches. for (auto m : matches) { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 1f2ab69..3c1a1b3 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -53,13 +53,13 @@ std::string DOTGraphTraits::getNodeLabel(Block *Block, Function *) { } // end namespace llvm -void mlir::viewGraph(Function &function, const llvm::Twine &name, +void mlir::viewGraph(Function function, const llvm::Twine &name, bool shortNames, const llvm::Twine &title, llvm::GraphProgram::Name program) { llvm::ViewGraph(&function, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function, +llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function, bool shortNames, const llvm::Twine &title) { return llvm::WriteGraph(os, &function, shortNames, title); } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 834e7c9..a88312d 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -43,13 +43,12 @@ static MLIRContext &globalContext() { return context; } -static std::unique_ptr makeFunction(StringRef name, - ArrayRef results = {}, - ArrayRef args = {}) { +static Function makeFunction(StringRef name, ArrayRef results = {}, + ArrayRef args = {}) { auto &ctx = globalContext(); - auto function = llvm::make_unique( - UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx)); - function->addEntryBlock(); + auto function = Function::create(UnknownLoc::get(&ctx), name, + FunctionType::get(args, results, &ctx)); + function.addEntryBlock(); return function; } @@ -62,10 +61,10 @@ TEST_FUNC(builder_dynamic_for_func_args) { auto f = makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)), - ub(f->getArgument(1)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)), + ub(f.getArgument(1)); ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type)); ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type)); ValueHandle i7(constant_int(7, 32)); @@ -102,7 +101,8 @@ TEST_FUNC(builder_dynamic_for_func_args) { // CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32 // CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32 // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_dynamic_for) { @@ -113,10 +113,10 @@ TEST_FUNC(builder_dynamic_for) { auto f = makeFunction("builder_dynamic_for", {}, {indexType, indexType, indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)), - c(f->getArgument(2)), d(f->getArgument(3)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), + c(f.getArgument(2)), d(f.getArgument(3)); LoopBuilder(&i, a - b, c + d, 2)(); // clang-format off @@ -125,7 +125,8 @@ TEST_FUNC(builder_dynamic_for) { // CHECK-DAG: [[r1:%[0-9]+]] = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3] // CHECK-NEXT: affine.for %i0 = (d0) -> (d0)([[r0]]) to (d0) -> (d0)([[r1]]) step 2 { // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_max_min_for) { @@ -136,10 +137,10 @@ TEST_FUNC(builder_max_min_for) { auto f = makeFunction("builder_max_min_for", {}, {indexType, indexType, indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)), - ub1(f->getArgument(2)), ub2(f->getArgument(3)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)), + ub1(f.getArgument(2)), ub2(f.getArgument(3)); LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)(); ret(); @@ -148,7 +149,8 @@ TEST_FUNC(builder_max_min_for) { // CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) { // CHECK: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_blocks) { @@ -157,14 +159,14 @@ TEST_FUNC(builder_blocks) { using namespace edsc::op; auto f = makeFunction("builder_blocks"); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle c1(ValueHandle::create(42, 32)), c2(ValueHandle::create(1234, 32)); ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), arg4(c1.getType()), r(c1.getType()); - BlockHandle b1, b2, functionBlock(&f->front()); + BlockHandle b1, b2, functionBlock(&f.front()); BlockBuilder(&b1, {&arg1, &arg2})( // b2 has not yet been constructed, need to come back later. // This is a byproduct of non-structured control-flow. @@ -192,7 +194,8 @@ TEST_FUNC(builder_blocks) { // CHECK-NEXT: br ^bb1(%3, %4 : i32, i32) // CHECK-NEXT: } // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_blocks_eager) { @@ -201,8 +204,8 @@ TEST_FUNC(builder_blocks_eager) { using namespace edsc::op; auto f = makeFunction("builder_blocks_eager"); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle c1(ValueHandle::create(42, 32)), c2(ValueHandle::create(1234, 32)); ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), @@ -235,7 +238,8 @@ TEST_FUNC(builder_blocks_eager) { // CHECK-NEXT: br ^bb1(%3, %4 : i32, i32) // CHECK-NEXT: } // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_cond_branch) { @@ -244,15 +248,15 @@ TEST_FUNC(builder_cond_branch) { auto f = makeFunction("builder_cond_branch", {}, {IntegerType::get(1, &globalContext())}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle funcArg(f->getArgument(0)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle funcArg(f.getArgument(0)); ValueHandle c32(ValueHandle::create(32, 32)), c64(ValueHandle::create(64, 64)), c42(ValueHandle::create(42, 32)); ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); - BlockHandle b1, b2, functionBlock(&f->front()); + BlockHandle b1, b2, functionBlock(&f.front()); BlockBuilder(&b1, {&arg1})([&] { ret(); }); BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); }); // Get back to entry block and add a conditional branch @@ -271,7 +275,8 @@ TEST_FUNC(builder_cond_branch) { // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 // CHECK-NEXT: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_cond_branch_eager) { @@ -281,9 +286,9 @@ TEST_FUNC(builder_cond_branch_eager) { auto f = makeFunction("builder_cond_branch_eager", {}, {IntegerType::get(1, &globalContext())}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle funcArg(f->getArgument(0)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle funcArg(f.getArgument(0)); ValueHandle c32(ValueHandle::create(32, 32)), c64(ValueHandle::create(64, 64)), c42(ValueHandle::create(42, 32)); @@ -309,7 +314,8 @@ TEST_FUNC(builder_cond_branch_eager) { // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 // CHECK-NEXT: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_helpers) { @@ -321,14 +327,14 @@ TEST_FUNC(builder_helpers) { auto f = makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle f7( ValueHandle::create(llvm::APFloat(7.0f), f32Type)); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), + vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2; int64_t step0, step1, step2; std::tie(lb0, ub0, step0) = vA.range(0); @@ -363,7 +369,8 @@ TEST_FUNC(builder_helpers) { // CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32 // CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(custom_ops) { @@ -373,8 +380,8 @@ TEST_FUNC(custom_ops) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("custom_ops", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); CustomOperation MY_CUSTOM_OP("my_custom_op"); CustomOperation MY_CUSTOM_OP_0("my_custom_op_0"); CustomOperation MY_CUSTOM_OP_2("my_custom_op_2"); @@ -382,7 +389,7 @@ TEST_FUNC(custom_ops) { // clang-format off ValueHandle vh(indexType), vh20(indexType), vh21(indexType); OperationHandle ih0, ih2; - IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1)); + IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1)); IndexHandle ten(index_t(10)), twenty(index_t(20)); LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{ vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}); @@ -402,7 +409,8 @@ TEST_FUNC(custom_ops) { // CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index) // CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(insertion_in_block) { @@ -412,8 +420,8 @@ TEST_FUNC(insertion_in_block) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("insertion_in_block", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); BlockHandle b1; // clang-format off ValueHandle::create(0, 32); @@ -427,7 +435,8 @@ TEST_FUNC(insertion_in_block) { // CHECK: ^bb1: // no predecessors // CHECK: {{.*}} = constant 1 : i32 // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(select_op) { @@ -438,12 +447,12 @@ TEST_FUNC(select_op) { auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0); auto f = makeFunction("select_op", {}, {memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle zero = constant_index(0), one = constant_index(1); - MemRefView vA(f->getArgument(0)); - IndexedValue A(f->getArgument(0)); + MemRefView vA(f.getArgument(0)); + IndexedValue A(f.getArgument(0)); IndexHandle i, j; LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ // This test exercises IndexedValue::operator Value*. @@ -461,7 +470,8 @@ TEST_FUNC(select_op) { // CHECK-DAG: {{.*}} = load // CHECK-NEXT: {{.*}} = select // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } // Inject an EDSC-constructed computation to exercise imperfectly nested 2-d @@ -474,12 +484,11 @@ TEST_FUNC(tile_2d) { MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0); auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle zero = constant_index(0); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); // clang-format off @@ -531,7 +540,8 @@ TEST_FUNC(tile_2d) { // CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32 // CHECK-NEXT: store {{.*}}, {{.*}}[%i8, %i9, %i7] : memref // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } // Inject an EDSC-constructed computation to exercise 2-d vectorization. @@ -544,16 +554,15 @@ TEST_FUNC(vectorize_2d) { auto owningF = makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType}); - mlir::Function *f = owningF.release(); + mlir::Function f = owningF; mlir::Module module(&globalContext()); - module.getFunctions().push_back(f); + module.push_back(f); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle zero = constant_index(0); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2)); // clang-format off @@ -580,9 +589,10 @@ TEST_FUNC(vectorize_2d) { pm.addPass(mlir::createCanonicalizerPass()); SmallVector vectorSizes{4, 4}; pm.addPass(mlir::createVectorizePass(vectorSizes)); - auto result = pm.run(f->getModule()); + auto result = pm.run(f.getModule()); if (succeeded(result)) - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } int main() { diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 4767e33..7bfb556 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -97,12 +97,12 @@ struct VectorizerTestPass : public FunctionPass { } // end anonymous namespace void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); using matcher::Op; SmallVector shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); auto subVectorType = - VectorType::get(shape, FloatType::getF32(f->getContext())); + VectorType::get(shape, FloatType::getF32(f.getContext())); // Only filter operations that operate on a strict super-vector and have one // return. This makes testing easier. auto filter = [&](Operation &op) { @@ -148,7 +148,7 @@ static NestedPattern patternTestSlicingOps() { } void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); @@ -163,7 +163,7 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { } void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -177,7 +177,7 @@ void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) { } void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); @@ -195,7 +195,7 @@ static bool customOpWithAffineMapAttribute(Operation &op) { } void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); using matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); @@ -227,7 +227,7 @@ static bool singleResultAffineApplyOpWithoutUses(Operation &op) { void VectorizerTestPass::testNormalizeMaps() { using matcher::Op; - auto *f = &getFunction(); + auto f = getFunction(); // Save matched AffineApplyOp that all need to be erased in the end. auto pattern = Op(affineApplyOp); @@ -256,7 +256,7 @@ void VectorizerTestPass::runOnFunction() { NestedPatternContext mlContext; // Only support single block functions at this point. - Function &f = getFunction(); + Function f = getFunction(); if (f.getBlocks().size() != 1) return; diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp index 54a9c6c..1ac6c40 100644 --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp @@ -163,8 +163,8 @@ static LogicalResult convertAffineStandardToLLVMIR(Module *module) { static Error compileAndExecuteFunctionWithMemRefs( Module *module, StringRef entryPoint, std::function transformer) { - Function *mainFunction = module->getNamedFunction(entryPoint); - if (!mainFunction || mainFunction->getBlocks().empty()) { + Function mainFunction = module->getNamedFunction(entryPoint); + if (!mainFunction || mainFunction.getBlocks().empty()) { return make_string_error("entry point not found"); } @@ -172,9 +172,9 @@ static Error compileAndExecuteFunctionWithMemRefs( // pretty print the results, because the function itself will be rewritten // to use the LLVM dialect. SmallVector argTypes = - llvm::to_vector<8>(mainFunction->getType().getInputs()); + llvm::to_vector<8>(mainFunction.getType().getInputs()); SmallVector resTypes = - llvm::to_vector<8>(mainFunction->getType().getResults()); + llvm::to_vector<8>(mainFunction.getType().getResults()); float init = std::stof(initValue.getValue()); @@ -206,18 +206,18 @@ static Error compileAndExecuteFunctionWithMemRefs( static Error compileAndExecuteSingleFloatReturnFunction( Module *module, StringRef entryPoint, std::function transformer) { - Function *mainFunction = module->getNamedFunction(entryPoint); - if (!mainFunction || mainFunction->isExternal()) { + Function mainFunction = module->getNamedFunction(entryPoint); + if (!mainFunction || mainFunction.isExternal()) { return make_string_error("entry point not found"); } - if (!mainFunction->getType().getInputs().empty()) + if (!mainFunction.getType().getInputs().empty()) return make_string_error("function inputs not supported"); - if (mainFunction->getType().getResults().size() != 1) + if (mainFunction.getType().getResults().size() != 1) return make_string_error("only single f32 function result supported"); - auto t = mainFunction->getType().getResults()[0].dyn_cast(); + auto t = mainFunction.getType().getResults()[0].dyn_cast(); if (!t) return make_string_error("only single llvm.f32 function result supported"); auto *llvmTy = t.getUnderlyingType(); diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp index 38a059b..d2a8237 100644 --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -25,11 +25,11 @@ using namespace mlir::detail; namespace { /// Minimal class definitions for two analyses. struct MyAnalysis { - MyAnalysis(Function *) {} + MyAnalysis(Function) {} MyAnalysis(Module *) {} }; struct OtherAnalysis { - OtherAnalysis(Function *) {} + OtherAnalysis(Function) {} OtherAnalysis(Module *) {} }; @@ -59,10 +59,10 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) { // Create a function and a module. std::unique_ptr module(new Module(&context)); - Function *func1 = - new Function(builder.getUnknownLoc(), "foo", - builder.getFunctionType(llvm::None, llvm::None)); - module->getFunctions().push_back(func1); + Function func1 = + Function::create(builder.getUnknownLoc(), "foo", + builder.getFunctionType(llvm::None, llvm::None)); + module->push_back(func1); // Test fine grain invalidation of the function analysis manager. ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr); @@ -87,10 +87,10 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) { // Create a function and a module. std::unique_ptr module(new Module(&context)); - Function *func1 = - new Function(builder.getUnknownLoc(), "foo", - builder.getFunctionType(llvm::None, llvm::None)); - module->getFunctions().push_back(func1); + Function func1 = + Function::create(builder.getUnknownLoc(), "foo", + builder.getFunctionType(llvm::None, llvm::None)); + module->push_back(func1); // Test fine grain invalidation of a function analysis from within a module // analysis manager. -- 2.7.4