[mlir] Fix call op conversion in bare-ptr calling convention
authorDiego Caballero <diego.caballero@intel.com>
Fri, 2 Oct 2020 15:42:13 +0000 (08:42 -0700)
committerDiego Caballero <diego.caballero@intel.com>
Fri, 2 Oct 2020 15:48:21 +0000 (08:48 -0700)
We hit an llvm_unreachable related to unranked memrefs for call ops
with scalar types. Removing the llvm_unreachable since the conversion
should gracefully bail out in the presence of unranked memrefs. Adding
tests to verify that.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D88709

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir

index c77c0b5..37d0c94 100644 (file)
@@ -436,14 +436,10 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors(
     SmallVectorImpl<Value> &values) {
   assert(stdTypes.size() == values.size() &&
          "The number of types and values doesn't match");
-  for (unsigned i = 0, end = values.size(); i < end; ++i) {
-    Type stdTy = stdTypes[i];
-    if (auto memrefTy = stdTy.dyn_cast<MemRefType>())
+  for (unsigned i = 0, end = values.size(); i < end; ++i)
+    if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
       values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
                                                     memrefTy, values[i]);
-    else
-      llvm_unreachable("Unranked memrefs are not supported");
-  }
 }
 
 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
index 5dd36ba..b93446f 100644 (file)
@@ -416,3 +416,33 @@ func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
   // BAREPTR-NEXT:    llvm.return %[[res]] : !llvm.ptr<i8>
   return %res : memref<20xi8>
 }
+
+// -----
+
+// BAREPTR: llvm.func @goo(!llvm.float) -> !llvm.float
+func @goo(f32) -> f32
+
+// BAREPTR-LABEL: func @check_scalar_func_call
+// BAREPTR-SAME:    %[[in:.*]]: !llvm.float)
+func @check_scalar_func_call(%in : f32) {
+  // BAREPTR-NEXT:    %[[call:.*]] = llvm.call @goo(%[[in]]) : (!llvm.float) -> !llvm.float
+  %res = call @goo(%in) : (f32) -> (f32)
+  return
+}
+
+// -----
+
+// Unranked memrefs are currently not supported in the bare-ptr calling
+// convention. Check that the conversion to the LLVM-IR dialect doesn't happen
+// in the presence of unranked memrefs when using such a calling convention.
+
+// BAREPTR: func @hoo(memref<*xi8>) -> memref<*xi8>
+func @hoo(memref<*xi8>) -> memref<*xi8>
+
+// BAREPTR-LABEL: func @check_unranked_memref_func_call(%{{.*}}: memref<*xi8>) -> memref<*xi8>
+func @check_unranked_memref_func_call(%in: memref<*xi8>) -> memref<*xi8> {
+  // BAREPTR-NEXT: call @hoo(%{{.*}}) : (memref<*xi8>) -> memref<*xi8>
+  %res = call @hoo(%in) : (memref<*xi8>) -> memref<*xi8>
+  // BAREPTR-NEXT: return %{{.*}} : memref<*xi8>
+  return %res : memref<*xi8>
+}