From: Alex Zinenko Date: Thu, 16 Jan 2020 13:32:33 +0000 (+0100) Subject: [mlir] Harden error propagation in LLVM import X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=84c3f05c8e3e28fd58c458f842e721bbbaa837b2;p=platform%2Fupstream%2Fllvm.git [mlir] Harden error propagation in LLVM import Summary: LLVM importer to MLIR was implemented mostly as a prototype. As such, it did not deal handle errors in a consistent way, reporting them out stderr in some cases and continuing the execution in the error state until eventually crashing. This is not desirable for a user-facing tool. Make sure errors are returned from functions, consistently checked at call sites and propagated further. Functions returning nullable IR values return nullptr to denote the error state. Other functions return LogicalResult. LLVM importer in mlir-translate should no longer crash on unsupported inputs. The errors are reported without association with the source file (and therefore cannot be checked using -verify-diagnostics). Attaching them to the actual input file is left for future work. Differential Revision: https://reviews.llvm.org/D72839 --- diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index e882161..0692b9b 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -73,11 +73,12 @@ private: /// is unavailable. Location processDebugLoc(const llvm::DebugLoc &loc, llvm::Instruction *inst = nullptr); - /// `br` branches to `target`. Return the block arguments to attach to the - /// generated branch op. These should be in the same order as the PHIs in - /// `target`. - SmallVector processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target); + /// `br` branches to `target`. Append the block arguments to attach to the + /// generated branch op to `blockArguments`. These should be in the same order + /// as the PHIs in `target`. + LogicalResult processBranchArgs(llvm::BranchInst *br, + llvm::BasicBlock *target, + SmallVectorImpl &blockArguments); /// Returns the standard type equivalent to be used in attributes for the /// given LLVM IR dialect type. Type getStdTypeForAttr(LLVMType type); @@ -151,17 +152,27 @@ LLVMType Importer::processType(llvm::Type *type) { return LLVMType::getDoubleTy(dialect); case llvm::Type::IntegerTyID: return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth()); - case llvm::Type::PointerTyID: - return processType(type->getPointerElementType()) - .getPointerTo(type->getPointerAddressSpace()); - case llvm::Type::ArrayTyID: - return LLVMType::getArrayTy(processType(type->getArrayElementType()), - type->getArrayNumElements()); + case llvm::Type::PointerTyID: { + LLVMType elementType = processType(type->getPointerElementType()); + if (!elementType) + return nullptr; + return elementType.getPointerTo(type->getPointerAddressSpace()); + } + case llvm::Type::ArrayTyID: { + LLVMType elementType = processType(type->getArrayElementType()); + if (!elementType) + return nullptr; + return LLVMType::getArrayTy(elementType, type->getArrayNumElements()); + } case llvm::Type::VectorTyID: { - if (type->getVectorIsScalable()) + if (type->getVectorIsScalable()) { emitError(unknownLoc) << "scalable vector types not supported"; - return LLVMType::getVectorTy(processType(type->getVectorElementType()), - type->getVectorNumElements()); + return nullptr; + } + LLVMType elementType = processType(type->getVectorElementType()); + if (!elementType) + return nullptr; + return LLVMType::getVectorTy(elementType, type->getVectorNumElements()); } case llvm::Type::VoidTyID: return LLVMType::getVoidTy(dialect); @@ -171,18 +182,30 @@ LLVMType Importer::processType(llvm::Type *type) { return LLVMType::getX86_FP80Ty(dialect); case llvm::Type::StructTyID: { SmallVector elementTypes; - for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) - elementTypes.push_back(processType(type->getStructElementType(i))); + elementTypes.reserve(type->getStructNumElements()); + for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) { + LLVMType ty = processType(type->getStructElementType(i)); + if (!ty) + return nullptr; + elementTypes.push_back(ty); + } return LLVMType::getStructTy(dialect, elementTypes, cast(type)->isPacked()); } case llvm::Type::FunctionTyID: { llvm::FunctionType *fty = cast(type); SmallVector paramTypes; - for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) - paramTypes.push_back(processType(fty->getParamType(i))); - return LLVMType::getFunctionTy(processType(fty->getReturnType()), - paramTypes, fty->isVarArg()); + for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) { + LLVMType ty = processType(fty->getParamType(i)); + if (!ty) + return nullptr; + paramTypes.push_back(ty); + } + LLVMType result = processType(fty->getReturnType()); + if (!result) + return nullptr; + + return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg()); } default: { // FIXME: Diagnostic should be able to natively handle types that have @@ -191,7 +214,7 @@ LLVMType Importer::processType(llvm::Type *type) { llvm::raw_string_ostream os(s); os << *type; emitError(unknownLoc) << "unhandled type: " << os.str(); - return {}; + return nullptr; } } } @@ -217,10 +240,14 @@ Type Importer::getStdTypeForAttr(LLVMType type) { // LLVM vectors can only contain scalars. if (type.isVectorTy()) { auto numElements = type.getUnderlyingType()->getVectorElementCount(); - if (numElements.Scalable) + if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; - return VectorType::get(numElements.Min, - getStdTypeForAttr(type.getVectorElementType())); + return nullptr; + } + Type elementType = getStdTypeForAttr(type.getVectorElementType()); + if (!elementType) + return nullptr; + return VectorType::get(numElements.Min, elementType); } // LLVM arrays can contain other arrays or vectors. @@ -239,20 +266,26 @@ Type Importer::getStdTypeForAttr(LLVMType type) { LLVMType vectorType = type.getArrayElementType(); auto numElements = vectorType.getUnderlyingType()->getVectorElementCount(); - if (numElements.Scalable) + if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; + return nullptr; + } shape.push_back(numElements.Min); - LLVMType elementType = vectorType.getVectorElementType(); - return VectorType::get(shape, getStdTypeForAttr(elementType)); + Type elementType = getStdTypeForAttr(vectorType.getVectorElementType()); + if (!elementType) + return nullptr; + return VectorType::get(shape, elementType); } // Otherwise use a tensor. - return RankedTensorType::get(shape, - getStdTypeForAttr(type.getArrayElementType())); + Type elementType = getStdTypeForAttr(type.getArrayElementType()); + if (!elementType) + return nullptr; + return RankedTensorType::get(shape, elementType); } - llvm_unreachable("no equivalent standard type for typed attributes"); + return nullptr; } // Get the given constant as an attribute. Not all constants can be represented @@ -277,9 +310,11 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) { // Convert constant data to a dense elements attribute. if (auto *cd = dyn_cast(value)) { LLVMType type = processType(cd->getElementType()); + if (!type) + return nullptr; + auto attrType = getStdTypeForAttr(processType(cd->getType())) .dyn_cast_or_null(); - assert(attrType); if (!attrType) return nullptr; @@ -368,15 +403,19 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { Attribute valueAttr; if (GV->hasInitializer()) valueAttr = getConstantAsAttr(GV->getInitializer()); + LLVMType type = processType(GV->getValueType()); + if (!type) + return nullptr; GlobalOp op = b.create( - UnknownLoc::get(context), processType(GV->getValueType()), - GV->isConstant(), processLinkage(GV->getLinkage()), GV->getName(), - valueAttr); + UnknownLoc::get(context), type, GV->isConstant(), + processLinkage(GV->getLinkage()), GV->getName(), valueAttr); if (GV->hasInitializer() && !valueAttr) { Region &r = op.getInitializerRegion(); currentEntryBlock = b.createBlock(&r); b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); Value v = processConstant(GV->getInitializer()); + if (!v) + return nullptr; b.create(op.getLoc(), ArrayRef({v})); } return globals[GV] = op; @@ -386,13 +425,17 @@ Value Importer::processConstant(llvm::Constant *c) { if (Attribute attr = getConstantAsAttr(c)) { // These constants can be represented as attributes. OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); - return instMap[c] = b.create(unknownLoc, - processType(c->getType()), attr); + LLVMType type = processType(c->getType()); + if (!type) + return nullptr; + return instMap[c] = b.create(unknownLoc, type, attr); } if (auto *cn = dyn_cast(c)) { OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); - return instMap[c] = - b.create(unknownLoc, processType(cn->getType())); + LLVMType type = processType(cn->getType()); + if (!type) + return nullptr; + return instMap[c] = b.create(unknownLoc, type); } if (auto *ce = dyn_cast(c)) { llvm::Instruction *i = ce->getAsInstruction(); @@ -420,13 +463,19 @@ Value Importer::processValue(llvm::Value *value) { // this instruction yet, create an unknown op and remap it later. if (isa(value)) { OperationState state(UnknownLoc::get(context), "unknown"); - state.addTypes({processType(value->getType())}); + LLVMType type = processType(value->getType()); + if (!type) + return nullptr; + state.addTypes(type); unknownInstMap[value] = b.createOperation(state); return unknownInstMap[value]->getResult(0); } if (auto *GV = dyn_cast(value)) { - return b.create(UnknownLoc::get(context), processGlobal(GV), + auto global = processGlobal(GV); + if (!global) + return nullptr; + return b.create(UnknownLoc::get(context), global, ArrayRef()); } @@ -520,14 +569,17 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { // `br` branches to `target`. Return the branch arguments to `br`, in the // same order of the PHIs in `target`. -SmallVector Importer::processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target) { - SmallVector v; +LogicalResult +Importer::processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target, + SmallVectorImpl &blockArguments) { for (auto inst = target->begin(); isa(inst); ++inst) { auto *PN = cast(&*inst); - v.push_back(processValue(PN->getIncomingValueForBlock(br->getParent()))); + Value value = processValue(PN->getIncomingValueForBlock(br->getParent())); + if (!value) + return failure(); + blockArguments.push_back(value); } - return v; + return success(); } LogicalResult Importer::processInstruction(llvm::Instruction *inst) { @@ -577,20 +629,32 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { OperationState state(loc, opcMap.lookup(inst->getOpcode())); SmallVector ops; ops.reserve(inst->getNumOperands()); - for (auto *op : inst->operand_values()) - ops.push_back(processValue(op)); + for (auto *op : inst->operand_values()) { + Value value = processValue(op); + if (!value) + return failure(); + ops.push_back(value); + } state.addOperands(ops); - if (!inst->getType()->isVoidTy()) - state.addTypes(ArrayRef({processType(inst->getType())})); + if (!inst->getType()->isVoidTy()) { + LLVMType type = processType(inst->getType()); + if (!type) + return failure(); + state.addTypes(type); + } Operation *op = b.createOperation(state); if (!inst->getType()->isVoidTy()) v = op->getResult(0); return success(); } case llvm::Instruction::ICmp: { + Value lhs = processValue(inst->getOperand(0)); + Value rhs = processValue(inst->getOperand(1)); + if (!lhs || !rhs) + return failure(); v = b.create( - loc, getICmpPredicate(cast(inst)->getPredicate()), - processValue(inst->getOperand(0)), processValue(inst->getOperand(1))); + loc, getICmpPredicate(cast(inst)->getPredicate()), lhs, + rhs); return success(); } case llvm::Instruction::Br: { @@ -598,35 +662,57 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { OperationState state(loc, brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); SmallVector ops; - if (brInst->isConditional()) - ops.push_back(processValue(brInst->getCondition())); + if (brInst->isConditional()) { + Value condition = processValue(brInst->getCondition()); + if (!condition) + return failure(); + ops.push_back(condition); + } state.addOperands(ops); SmallVector succs; - for (auto *succ : llvm::reverse(brInst->successors())) - state.addSuccessor(blocks[succ], processBranchArgs(brInst, succ)); + for (auto *succ : llvm::reverse(brInst->successors())) { + SmallVector blockArguments; + if (failed(processBranchArgs(brInst, succ, blockArguments))) + return failure(); + state.addSuccessor(blocks[succ], blockArguments); + } b.createOperation(state); return success(); } case llvm::Instruction::PHI: { - v = b.getInsertionBlock()->addArgument(processType(inst->getType())); + LLVMType type = processType(inst->getType()); + if (!type) + return failure(); + v = b.getInsertionBlock()->addArgument(type); return success(); } case llvm::Instruction::Call: { llvm::CallInst *ci = cast(inst); SmallVector ops; ops.reserve(inst->getNumOperands()); - for (auto &op : ci->arg_operands()) - ops.push_back(processValue(op.get())); + for (auto &op : ci->arg_operands()) { + Value arg = processValue(op.get()); + if (!arg) + return failure(); + ops.push_back(arg); + } SmallVector tys; - if (!ci->getType()->isVoidTy()) - tys.push_back(processType(inst->getType())); + if (!ci->getType()->isVoidTy()) { + LLVMType type = processType(inst->getType()); + if (!type) + return failure(); + tys.push_back(type); + } Operation *op; if (llvm::Function *callee = ci->getCalledFunction()) { op = b.create(loc, tys, b.getSymbolRefAttr(callee->getName()), ops); } else { - ops.insert(ops.begin(), processValue(ci->getCalledValue())); + Value calledValue = processValue(ci->getCalledValue()); + if (!calledValue) + return failure(); + ops.insert(ops.begin(), calledValue); op = b.create(loc, tys, ops, ArrayRef()); } if (!ci->getType()->isVoidTy()) @@ -637,10 +723,16 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { // FIXME: Support inbounds GEPs. llvm::GetElementPtrInst *gep = cast(inst); SmallVector ops; - for (auto *op : gep->operand_values()) - ops.push_back(processValue(op)); - v = b.create(loc, processType(inst->getType()), ops, - ArrayRef()); + for (auto *op : gep->operand_values()) { + Value value = processValue(op); + if (!value) + return failure(); + ops.push_back(value); + } + Type type = processType(inst->getType()); + if (!type) + return failure(); + v = b.create(loc, type, ops, ArrayRef()); return success(); } } @@ -651,9 +743,13 @@ LogicalResult Importer::processFunction(llvm::Function *f) { instMap.clear(); unknownInstMap.clear(); + LLVMType functionType = processType(f->getFunctionType()); + if (!functionType) + return failure(); + b.setInsertionPoint(module.getBody(), getFuncInsertPt()); LLVMFuncOp fop = b.create(UnknownLoc::get(context), f->getName(), - processType(f->getFunctionType())); + functionType); if (f->isDeclaration()) return success(); @@ -666,8 +762,9 @@ LogicalResult Importer::processFunction(llvm::Function *f) { currentEntryBlock = blockList[0]; // Add function arguments to the entry block. - for (auto &arg : f->args()) - instMap[&arg] = blockList[0]->addArgument(processType(arg.getType())); + for (auto kv : llvm::enumerate(f->args())) + instMap[&kv.value()] = blockList[0]->addArgument( + functionType.getFunctionParamType(kv.index())); for (auto bbs : llvm::zip(*f, blockList)) { if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))