From fa9ac61fff1a52e63a4adf14aad936a8de04b6b7 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Fri, 9 Dec 2022 11:31:53 +0100 Subject: [PATCH] [mlir][llvm] Modernize LLVM instruction and global import (NFC). Modernize the import of LLVMIR instructions and global variables. Use longer variable names, factor out code used to import call or invoke instructions, use the CPP builders for importing branch instructions, etc. The revision is a preparation for a follow up revision that moves the import code to implement improved error handling. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D139404 --- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 214 ++++++++++++++------------- 1 file changed, 115 insertions(+), 99 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index f151965..9866b3c 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -403,8 +403,8 @@ public: /// into LLVM dialect attributes of LLVMFuncOp \p funcOp. void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp); - /// Imports GV as a GlobalOp, creating it if it doesn't exist. - GlobalOp processGlobal(llvm::GlobalVariable *gv); + /// Imports `globalVar` as a GlobalOp, creating it if it doesn't exist. + GlobalOp processGlobal(llvm::GlobalVariable *globalVar); private: /// Clears the block and value mapping before processing a new region. @@ -432,6 +432,12 @@ private: LogicalResult convertBranchArgs(llvm::Instruction *branch, llvm::BasicBlock *target, SmallVectorImpl &blockArguments); + /// Appends the converted result type and operands of `callInst` to the + /// `types` and `operands` arrays. For indirect calls, the method additionally + /// inserts the called function at the beginning of the `operands` array. + void convertCallTypeAndOperands(llvm::CallBase *callInst, + SmallVectorImpl &types, + SmallVectorImpl &operands); /// Returns the builtin type equivalent to be used in attributes for the given /// LLVM IR dialect type. Type getStdTypeForAttr(Type type); @@ -648,9 +654,9 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) { return nullptr; } -GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) { - if (globals.count(gv)) - return globals[gv]; +GlobalOp Importer::processGlobal(llvm::GlobalVariable *globalVar) { + if (globals.count(globalVar)) + return globals[globalVar]; // Insert the global after the last one or at the start of the module. OpBuilder::InsertionGuard guard(builder); @@ -661,38 +667,40 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) { } Attribute valueAttr; - if (gv->hasInitializer()) - valueAttr = getConstantAsAttr(gv->getInitializer()); - Type type = convertType(gv->getValueType()); + if (globalVar->hasInitializer()) + valueAttr = getConstantAsAttr(globalVar->getInitializer()); + Type type = convertType(globalVar->getValueType()); uint64_t alignment = 0; - llvm::MaybeAlign maybeAlign = gv->getAlign(); + llvm::MaybeAlign maybeAlign = globalVar->getAlign(); if (maybeAlign.has_value()) { llvm::Align align = maybeAlign.value(); alignment = align.value(); } - GlobalOp op = builder.create( - UnknownLoc::get(context), type, gv->isConstant(), - convertLinkageFromLLVM(gv->getLinkage()), gv->getName(), valueAttr, - alignment, /*addr_space=*/gv->getAddressSpace(), - /*dso_local=*/gv->isDSOLocal(), - /*thread_local=*/gv->isThreadLocal()); - globalInsertionOp = op; + GlobalOp globalOp = builder.create( + UnknownLoc::get(context), type, globalVar->isConstant(), + convertLinkageFromLLVM(globalVar->getLinkage()), globalVar->getName(), + valueAttr, alignment, /*addr_space=*/globalVar->getAddressSpace(), + /*dso_local=*/globalVar->isDSOLocal(), + /*thread_local=*/globalVar->isThreadLocal()); + globalInsertionOp = globalOp; - if (gv->hasInitializer() && !valueAttr) { + if (globalVar->hasInitializer() && !valueAttr) { clearBlockAndValueMapping(); - Block *block = builder.createBlock(&op.getInitializerRegion()); + Block *block = builder.createBlock(&globalOp.getInitializerRegion()); setConstantInsertionPointToStart(block); - Value value = convertConstantExpr(gv->getInitializer()); - builder.create(op.getLoc(), value); + Value value = convertConstantExpr(globalVar->getInitializer()); + builder.create(globalOp.getLoc(), value); } - if (gv->hasAtLeastLocalUnnamedAddr()) - op.setUnnamedAddr(convertUnnamedAddrFromLLVM(gv->getUnnamedAddr())); - if (gv->hasSection()) - op.setSection(gv->getSection()); + if (globalVar->hasAtLeastLocalUnnamedAddr()) { + globalOp.setUnnamedAddr( + convertUnnamedAddrFromLLVM(globalVar->getUnnamedAddr())); + } + if (globalVar->hasSection()) + globalOp.setSection(globalVar->getSection()); - return globals[gv] = op; + return globals[globalVar] = globalOp; } SetVector @@ -906,6 +914,20 @@ Importer::convertBranchArgs(llvm::Instruction *branch, llvm::BasicBlock *target, return success(); } +void Importer::convertCallTypeAndOperands(llvm::CallBase *callInst, + SmallVectorImpl &types, + SmallVectorImpl &operands) { + if (!callInst->getType()->isVoidTy()) + types.push_back(convertType(callInst->getType())); + + if (!callInst->getCalledFunction()) { + Value called = convertValue(callInst->getCalledOperand()); + operands.push_back(called); + } + SmallVector args(callInst->args()); + llvm::append_range(operands, convertValues(args)); +} + LogicalResult Importer::processInstruction(llvm::Instruction *inst) { // FIXME: Support uses of SubtargetData. // FIXME: Add support for inbounds GEPs. @@ -922,34 +944,31 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { if (succeeded(convertOperation(builder, inst))) return success(); - // Convert all special instructions that do not provide an MLIR builder. + // Convert all remaining instructions that do not provide an MLIR builder. Location loc = translateLoc(inst->getDebugLoc()); if (inst->getOpcode() == llvm::Instruction::Br) { auto *brInst = cast(inst); - OperationState state(loc, - brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); - if (brInst->isConditional()) { - Value condition = convertValue(brInst->getCondition()); - state.addOperands(condition); - } - std::array operandSegmentSizes = {1, 0, 0}; - for (int i : llvm::seq(0, brInst->getNumSuccessors())) { + SmallVector succBlocks; + SmallVector> succBlockArgs; + for (auto i : llvm::seq(0, brInst->getNumSuccessors())) { llvm::BasicBlock *succ = brInst->getSuccessor(i); - SmallVector blockArguments; - if (failed(convertBranchArgs(brInst, succ, blockArguments))) + SmallVector blockArgs; + if (failed(convertBranchArgs(brInst, succ, blockArgs))) return failure(); - state.addSuccessors(lookupBlock(succ)); - state.addOperands(blockArguments); - operandSegmentSizes[i + 1] = blockArguments.size(); + succBlocks.push_back(lookupBlock(succ)); + succBlockArgs.push_back(blockArgs); } if (brInst->isConditional()) { - state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr(operandSegmentSizes)); + Value condition = convertValue(brInst->getCondition()); + builder.create(loc, condition, succBlocks.front(), + succBlockArgs.front(), succBlocks.back(), + succBlockArgs.back()); + } else { + builder.create(loc, succBlockArgs.front(), + succBlocks.front()); } - - builder.create(state); return success(); } if (inst->getOpcode() == llvm::Instruction::Switch) { @@ -990,85 +1009,82 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { return success(); } if (inst->getOpcode() == llvm::Instruction::Call) { - llvm::CallInst *ci = cast(inst); - SmallVector args(ci->args()); - SmallVector ops = convertValues(args); - SmallVector tys; - if (!ci->getType()->isVoidTy()) { - Type type = convertType(inst->getType()); - tys.push_back(type); - } - Operation *op; - if (llvm::Function *callee = ci->getCalledFunction()) { - op = builder.create( - loc, tys, SymbolRefAttr::get(builder.getContext(), callee->getName()), - ops); + auto *callInst = cast(inst); + + SmallVector types; + SmallVector operands; + convertCallTypeAndOperands(callInst, types, operands); + + CallOp callOp; + if (llvm::Function *callee = callInst->getCalledFunction()) { + callOp = builder.create( + loc, types, SymbolRefAttr::get(context, callee->getName()), operands); } else { - Value calledValue = convertValue(ci->getCalledOperand()); - ops.insert(ops.begin(), calledValue); - op = builder.create(loc, tys, ops); + callOp = builder.create(loc, types, operands); } - if (!ci->getType()->isVoidTy()) - mapValue(inst, op->getResult(0)); + if (!callInst->getType()->isVoidTy()) + mapValue(inst, callOp.getResult()); return success(); } if (inst->getOpcode() == llvm::Instruction::LandingPad) { - llvm::LandingPadInst *lpi = cast(inst); - SmallVector ops; + auto *lpInst = cast(inst); - for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++) - ops.push_back(convertConstantExpr(lpi->getClause(i))); + SmallVector operands; + operands.reserve(lpInst->getNumClauses()); + for (auto i : llvm::seq(0, lpInst->getNumClauses())) { + Value operand = convertConstantExpr(lpInst->getClause(i)); + operands.push_back(operand); + } - Type ty = convertType(lpi->getType()); - Value res = builder.create(loc, ty, lpi->isCleanup(), ops); + Type type = convertType(lpInst->getType()); + Value res = + builder.create(loc, type, lpInst->isCleanup(), operands); mapValue(inst, res); return success(); } if (inst->getOpcode() == llvm::Instruction::Invoke) { - llvm::InvokeInst *ii = cast(inst); - - SmallVector tys; - if (!ii->getType()->isVoidTy()) - tys.push_back(convertType(inst->getType())); - - SmallVector args(ii->args()); - SmallVector ops = convertValues(args); - - SmallVector normalArgs, unwindArgs; - (void)convertBranchArgs(ii, ii->getNormalDest(), normalArgs); - (void)convertBranchArgs(ii, ii->getUnwindDest(), unwindArgs); - - Operation *op; - if (llvm::Function *callee = ii->getCalledFunction()) { - op = builder.create( - loc, tys, SymbolRefAttr::get(builder.getContext(), callee->getName()), - ops, lookupBlock(ii->getNormalDest()), normalArgs, - lookupBlock(ii->getUnwindDest()), unwindArgs); + auto *invokeInst = cast(inst); + + SmallVector types; + SmallVector operands; + convertCallTypeAndOperands(invokeInst, types, operands); + + SmallVector normalArgs, unwindArgs; + (void)convertBranchArgs(invokeInst, invokeInst->getNormalDest(), + normalArgs); + (void)convertBranchArgs(invokeInst, invokeInst->getUnwindDest(), + unwindArgs); + + InvokeOp invokeOp; + if (llvm::Function *callee = invokeInst->getCalledFunction()) { + invokeOp = builder.create( + loc, types, + SymbolRefAttr::get(builder.getContext(), callee->getName()), operands, + lookupBlock(invokeInst->getNormalDest()), normalArgs, + lookupBlock(invokeInst->getUnwindDest()), unwindArgs); } else { - ops.insert(ops.begin(), convertValue(ii->getCalledOperand())); - op = builder.create( - loc, tys, ops, lookupBlock(ii->getNormalDest()), normalArgs, - lookupBlock(ii->getUnwindDest()), unwindArgs); + invokeOp = builder.create( + loc, types, operands, lookupBlock(invokeInst->getNormalDest()), + normalArgs, lookupBlock(invokeInst->getUnwindDest()), unwindArgs); } - - if (!ii->getType()->isVoidTy()) - mapValue(inst, op->getResult(0)); + if (!invokeInst->getType()->isVoidTy()) + mapValue(inst, invokeOp.getResults().front()); return success(); } if (inst->getOpcode() == llvm::Instruction::GetElementPtr) { // FIXME: Support inbounds GEPs. - llvm::GetElementPtrInst *gep = cast(inst); - Value basePtr = convertValue(gep->getOperand(0)); - Type sourceElementType = convertType(gep->getSourceElementType()); + auto *gepInst = cast(inst); + Type sourceElementType = convertType(gepInst->getSourceElementType()); + Value basePtr = convertValue(gepInst->getOperand(0)); // Treat every indices as dynamic since GEPOp::build will refine those // indices into static attributes later. One small downside of this // approach is that many unused `llvm.mlir.constant` would be emitted // at first place. SmallVector indices; - for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) { - Value val = convertValue(operand); - indices.push_back(val); + for (llvm::Value *operand : llvm::drop_begin(gepInst->operand_values())) { + Value index = convertValue(operand); + indices.push_back(index); } Type type = convertType(inst->getType()); -- 2.7.4