[mlir] Harden error propagation in LLVM import
authorAlex Zinenko <zinenko@google.com>
Thu, 16 Jan 2020 13:32:33 +0000 (14:32 +0100)
committerAlex Zinenko <zinenko@google.com>
Mon, 27 Jan 2020 15:15:11 +0000 (16:15 +0100)
Summary:
LLVM importer to MLIR was implemented mostly as a prototype. As such, it did
not deal handle errors in a consistent way, reporting them out stderr in some
cases and continuing the execution in the error state until eventually
crashing. This is not desirable for a user-facing tool. Make sure errors are
returned from functions, consistently checked at call sites and propagated
further. Functions returning nullable IR values return nullptr to denote the
error state. Other functions return LogicalResult. LLVM importer in
mlir-translate should no longer crash on unsupported inputs.

The errors are reported without association with the source file (and therefore
cannot be checked using -verify-diagnostics). Attaching them to the actual
input file is left for future work.

Differential Revision: https://reviews.llvm.org/D72839

mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp

index e882161..0692b9b 100644 (file)
@@ -73,11 +73,12 @@ private:
   /// is unavailable.
   Location processDebugLoc(const llvm::DebugLoc &loc,
                            llvm::Instruction *inst = nullptr);
-  /// `br` branches to `target`. Return the block arguments to attach to the
-  /// generated branch op. These should be in the same order as the PHIs in
-  /// `target`.
-  SmallVector<Value, 4> processBranchArgs(llvm::BranchInst *br,
-                                          llvm::BasicBlock *target);
+  /// `br` branches to `target`. Append the block arguments to attach to the
+  /// generated branch op to `blockArguments`. These should be in the same order
+  /// as the PHIs in `target`.
+  LogicalResult processBranchArgs(llvm::BranchInst *br,
+                                  llvm::BasicBlock *target,
+                                  SmallVectorImpl<Value> &blockArguments);
   /// Returns the standard type equivalent to be used in attributes for the
   /// given LLVM IR dialect type.
   Type getStdTypeForAttr(LLVMType type);
@@ -151,17 +152,27 @@ LLVMType Importer::processType(llvm::Type *type) {
     return LLVMType::getDoubleTy(dialect);
   case llvm::Type::IntegerTyID:
     return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
-  case llvm::Type::PointerTyID:
-    return processType(type->getPointerElementType())
-        .getPointerTo(type->getPointerAddressSpace());
-  case llvm::Type::ArrayTyID:
-    return LLVMType::getArrayTy(processType(type->getArrayElementType()),
-                                type->getArrayNumElements());
+  case llvm::Type::PointerTyID: {
+    LLVMType elementType = processType(type->getPointerElementType());
+    if (!elementType)
+      return nullptr;
+    return elementType.getPointerTo(type->getPointerAddressSpace());
+  }
+  case llvm::Type::ArrayTyID: {
+    LLVMType elementType = processType(type->getArrayElementType());
+    if (!elementType)
+      return nullptr;
+    return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
+  }
   case llvm::Type::VectorTyID: {
-    if (type->getVectorIsScalable())
+    if (type->getVectorIsScalable()) {
       emitError(unknownLoc) << "scalable vector types not supported";
-    return LLVMType::getVectorTy(processType(type->getVectorElementType()),
-                                 type->getVectorNumElements());
+      return nullptr;
+    }
+    LLVMType elementType = processType(type->getVectorElementType());
+    if (!elementType)
+      return nullptr;
+    return LLVMType::getVectorTy(elementType, type->getVectorNumElements());
   }
   case llvm::Type::VoidTyID:
     return LLVMType::getVoidTy(dialect);
@@ -171,18 +182,30 @@ LLVMType Importer::processType(llvm::Type *type) {
     return LLVMType::getX86_FP80Ty(dialect);
   case llvm::Type::StructTyID: {
     SmallVector<LLVMType, 4> elementTypes;
-    for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i)
-      elementTypes.push_back(processType(type->getStructElementType(i)));
+    elementTypes.reserve(type->getStructNumElements());
+    for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
+      LLVMType ty = processType(type->getStructElementType(i));
+      if (!ty)
+        return nullptr;
+      elementTypes.push_back(ty);
+    }
     return LLVMType::getStructTy(dialect, elementTypes,
                                  cast<llvm::StructType>(type)->isPacked());
   }
   case llvm::Type::FunctionTyID: {
     llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
     SmallVector<LLVMType, 4> paramTypes;
-    for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i)
-      paramTypes.push_back(processType(fty->getParamType(i)));
-    return LLVMType::getFunctionTy(processType(fty->getReturnType()),
-                                   paramTypes, fty->isVarArg());
+    for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
+      LLVMType ty = processType(fty->getParamType(i));
+      if (!ty)
+        return nullptr;
+      paramTypes.push_back(ty);
+    }
+    LLVMType result = processType(fty->getReturnType());
+    if (!result)
+      return nullptr;
+
+    return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
   }
   default: {
     // FIXME: Diagnostic should be able to natively handle types that have
@@ -191,7 +214,7 @@ LLVMType Importer::processType(llvm::Type *type) {
     llvm::raw_string_ostream os(s);
     os << *type;
     emitError(unknownLoc) << "unhandled type: " << os.str();
-    return {};
+    return nullptr;
   }
   }
 }
@@ -217,10 +240,14 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
   // LLVM vectors can only contain scalars.
   if (type.isVectorTy()) {
     auto numElements = type.getUnderlyingType()->getVectorElementCount();
-    if (numElements.Scalable)
+    if (numElements.Scalable) {
       emitError(unknownLoc) << "scalable vectors not supported";
-    return VectorType::get(numElements.Min,
-                           getStdTypeForAttr(type.getVectorElementType()));
+      return nullptr;
+    }
+    Type elementType = getStdTypeForAttr(type.getVectorElementType());
+    if (!elementType)
+      return nullptr;
+    return VectorType::get(numElements.Min, elementType);
   }
 
   // LLVM arrays can contain other arrays or vectors.
@@ -239,20 +266,26 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
       LLVMType vectorType = type.getArrayElementType();
       auto numElements =
           vectorType.getUnderlyingType()->getVectorElementCount();
-      if (numElements.Scalable)
+      if (numElements.Scalable) {
         emitError(unknownLoc) << "scalable vectors not supported";
+        return nullptr;
+      }
       shape.push_back(numElements.Min);
 
-      LLVMType elementType = vectorType.getVectorElementType();
-      return VectorType::get(shape, getStdTypeForAttr(elementType));
+      Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
+      if (!elementType)
+        return nullptr;
+      return VectorType::get(shape, elementType);
     }
 
     // Otherwise use a tensor.
-    return RankedTensorType::get(shape,
-                                 getStdTypeForAttr(type.getArrayElementType()));
+    Type elementType = getStdTypeForAttr(type.getArrayElementType());
+    if (!elementType)
+      return nullptr;
+    return RankedTensorType::get(shape, elementType);
   }
 
-  llvm_unreachable("no equivalent standard type for typed attributes");
+  return nullptr;
 }
 
 // Get the given constant as an attribute. Not all constants can be represented
@@ -277,9 +310,11 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
   // Convert constant data to a dense elements attribute.
   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
     LLVMType type = processType(cd->getElementType());
+    if (!type)
+      return nullptr;
+
     auto attrType = getStdTypeForAttr(processType(cd->getType()))
                         .dyn_cast_or_null<ShapedType>();
-    assert(attrType);
     if (!attrType)
       return nullptr;
 
@@ -368,15 +403,19 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
   Attribute valueAttr;
   if (GV->hasInitializer())
     valueAttr = getConstantAsAttr(GV->getInitializer());
+  LLVMType type = processType(GV->getValueType());
+  if (!type)
+    return nullptr;
   GlobalOp op = b.create<GlobalOp>(
-      UnknownLoc::get(context), processType(GV->getValueType()),
-      GV->isConstant(), processLinkage(GV->getLinkage()), GV->getName(),
-      valueAttr);
+      UnknownLoc::get(context), type, GV->isConstant(),
+      processLinkage(GV->getLinkage()), GV->getName(), valueAttr);
   if (GV->hasInitializer() && !valueAttr) {
     Region &r = op.getInitializerRegion();
     currentEntryBlock = b.createBlock(&r);
     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
     Value v = processConstant(GV->getInitializer());
+    if (!v)
+      return nullptr;
     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
   }
   return globals[GV] = op;
@@ -386,13 +425,17 @@ Value Importer::processConstant(llvm::Constant *c) {
   if (Attribute attr = getConstantAsAttr(c)) {
     // These constants can be represented as attributes.
     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
-    return instMap[c] = b.create<ConstantOp>(unknownLoc,
-                                             processType(c->getType()), attr);
+    LLVMType type = processType(c->getType());
+    if (!type)
+      return nullptr;
+    return instMap[c] = b.create<ConstantOp>(unknownLoc, type, attr);
   }
   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
-    return instMap[c] =
-               b.create<NullOp>(unknownLoc, processType(cn->getType()));
+    LLVMType type = processType(cn->getType());
+    if (!type)
+      return nullptr;
+    return instMap[c] = b.create<NullOp>(unknownLoc, type);
   }
   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
     llvm::Instruction *i = ce->getAsInstruction();
@@ -420,13 +463,19 @@ Value Importer::processValue(llvm::Value *value) {
   // this instruction yet, create an unknown op and remap it later.
   if (isa<llvm::Instruction>(value)) {
     OperationState state(UnknownLoc::get(context), "unknown");
-    state.addTypes({processType(value->getType())});
+    LLVMType type = processType(value->getType());
+    if (!type)
+      return nullptr;
+    state.addTypes(type);
     unknownInstMap[value] = b.createOperation(state);
     return unknownInstMap[value]->getResult(0);
   }
 
   if (auto *GV = dyn_cast<llvm::GlobalVariable>(value)) {
-    return b.create<AddressOfOp>(UnknownLoc::get(context), processGlobal(GV),
+    auto global = processGlobal(GV);
+    if (!global)
+      return nullptr;
+    return b.create<AddressOfOp>(UnknownLoc::get(context), global,
                                  ArrayRef<NamedAttribute>());
   }
 
@@ -520,14 +569,17 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
 
 // `br` branches to `target`. Return the branch arguments to `br`, in the
 // same order of the PHIs in `target`.
-SmallVector<Value, 4> Importer::processBranchArgs(llvm::BranchInst *br,
-                                                  llvm::BasicBlock *target) {
-  SmallVector<Value, 4> v;
+LogicalResult
+Importer::processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target,
+                            SmallVectorImpl<Value> &blockArguments) {
   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
     auto *PN = cast<llvm::PHINode>(&*inst);
-    v.push_back(processValue(PN->getIncomingValueForBlock(br->getParent())));
+    Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
+    if (!value)
+      return failure();
+    blockArguments.push_back(value);
   }
-  return v;
+  return success();
 }
 
 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
@@ -577,20 +629,32 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     OperationState state(loc, opcMap.lookup(inst->getOpcode()));
     SmallVector<Value, 4> ops;
     ops.reserve(inst->getNumOperands());
-    for (auto *op : inst->operand_values())
-      ops.push_back(processValue(op));
+    for (auto *op : inst->operand_values()) {
+      Value value = processValue(op);
+      if (!value)
+        return failure();
+      ops.push_back(value);
+    }
     state.addOperands(ops);
-    if (!inst->getType()->isVoidTy())
-      state.addTypes(ArrayRef<Type>({processType(inst->getType())}));
+    if (!inst->getType()->isVoidTy()) {
+      LLVMType type = processType(inst->getType());
+      if (!type)
+        return failure();
+      state.addTypes(type);
+    }
     Operation *op = b.createOperation(state);
     if (!inst->getType()->isVoidTy())
       v = op->getResult(0);
     return success();
   }
   case llvm::Instruction::ICmp: {
+    Value lhs = processValue(inst->getOperand(0));
+    Value rhs = processValue(inst->getOperand(1));
+    if (!lhs || !rhs)
+      return failure();
     v = b.create<ICmpOp>(
-        loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()),
-        processValue(inst->getOperand(0)), processValue(inst->getOperand(1)));
+        loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
+        rhs);
     return success();
   }
   case llvm::Instruction::Br: {
@@ -598,35 +662,57 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     OperationState state(loc,
                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
     SmallVector<Value, 4> ops;
-    if (brInst->isConditional())
-      ops.push_back(processValue(brInst->getCondition()));
+    if (brInst->isConditional()) {
+      Value condition = processValue(brInst->getCondition());
+      if (!condition)
+        return failure();
+      ops.push_back(condition);
+    }
     state.addOperands(ops);
     SmallVector<Block *, 4> succs;
-    for (auto *succ : llvm::reverse(brInst->successors()))
-      state.addSuccessor(blocks[succ], processBranchArgs(brInst, succ));
+    for (auto *succ : llvm::reverse(brInst->successors())) {
+      SmallVector<Value, 4> blockArguments;
+      if (failed(processBranchArgs(brInst, succ, blockArguments)))
+        return failure();
+      state.addSuccessor(blocks[succ], blockArguments);
+    }
     b.createOperation(state);
     return success();
   }
   case llvm::Instruction::PHI: {
-    v = b.getInsertionBlock()->addArgument(processType(inst->getType()));
+    LLVMType type = processType(inst->getType());
+    if (!type)
+      return failure();
+    v = b.getInsertionBlock()->addArgument(type);
     return success();
   }
   case llvm::Instruction::Call: {
     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
     SmallVector<Value, 4> ops;
     ops.reserve(inst->getNumOperands());
-    for (auto &op : ci->arg_operands())
-      ops.push_back(processValue(op.get()));
+    for (auto &op : ci->arg_operands()) {
+      Value arg = processValue(op.get());
+      if (!arg)
+        return failure();
+      ops.push_back(arg);
+    }
 
     SmallVector<Type, 2> tys;
-    if (!ci->getType()->isVoidTy())
-      tys.push_back(processType(inst->getType()));
+    if (!ci->getType()->isVoidTy()) {
+      LLVMType type = processType(inst->getType());
+      if (!type)
+        return failure();
+      tys.push_back(type);
+    }
     Operation *op;
     if (llvm::Function *callee = ci->getCalledFunction()) {
       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
                             ops);
     } else {
-      ops.insert(ops.begin(), processValue(ci->getCalledValue()));
+      Value calledValue = processValue(ci->getCalledValue());
+      if (!calledValue)
+        return failure();
+      ops.insert(ops.begin(), calledValue);
       op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>());
     }
     if (!ci->getType()->isVoidTy())
@@ -637,10 +723,16 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     // FIXME: Support inbounds GEPs.
     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
     SmallVector<Value, 4> ops;
-    for (auto *op : gep->operand_values())
-      ops.push_back(processValue(op));
-    v = b.create<GEPOp>(loc, processType(inst->getType()), ops,
-                        ArrayRef<NamedAttribute>());
+    for (auto *op : gep->operand_values()) {
+      Value value = processValue(op);
+      if (!value)
+        return failure();
+      ops.push_back(value);
+    }
+    Type type = processType(inst->getType());
+    if (!type)
+      return failure();
+    v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>());
     return success();
   }
   }
@@ -651,9 +743,13 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
   instMap.clear();
   unknownInstMap.clear();
 
+  LLVMType functionType = processType(f->getFunctionType());
+  if (!functionType)
+    return failure();
+
   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
   LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(),
-                                        processType(f->getFunctionType()));
+                                        functionType);
   if (f->isDeclaration())
     return success();
 
@@ -666,8 +762,9 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
   currentEntryBlock = blockList[0];
 
   // Add function arguments to the entry block.
-  for (auto &arg : f->args())
-    instMap[&arg] = blockList[0]->addArgument(processType(arg.getType()));
+  for (auto kv : llvm::enumerate(f->args()))
+    instMap[&kv.value()] = blockList[0]->addArgument(
+        functionType.getFunctionParamType(kv.index()));
 
   for (auto bbs : llvm::zip(*f, blockList)) {
     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))