[mlir][Linalg] Fix crash in LinalgToStandard
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 20 Jan 2023 08:06:34 +0000 (00:06 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 20 Jan 2023 08:17:55 +0000 (00:17 -0800)
Properly handle `appendMangledType` failure instead of asserting.

Fixes #59986.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/standard.mlir

index e411de6..ebfb8ef 100644 (file)
@@ -1795,7 +1795,7 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
   return llvm::to_vector<4>(concatRanges);
 }
 
-static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
+static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
   if (auto memref = t.dyn_cast<MemRefType>()) {
     ss << "view";
     for (auto size : memref.getShape())
@@ -1804,16 +1804,19 @@ static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
       else
         ss << size << "x";
     appendMangledType(ss, memref.getElementType());
-  } else if (auto vec = t.dyn_cast<VectorType>()) {
+    return success();
+  }
+  if (auto vec = t.dyn_cast<VectorType>()) {
     ss << "vector";
     llvm::interleave(
         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
     appendMangledType(ss, vec.getElementType());
+    return success();
   } else if (t.isSignlessIntOrIndexOrFloat()) {
     ss << t;
-  } else {
-    llvm_unreachable("Invalid type for linalg library name mangling");
+    return success();
   }
+  return failure();
 }
 
 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
@@ -1823,11 +1826,14 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
   std::replace(name.begin(), name.end(), '.', '_');
   llvm::raw_string_ostream ss(name);
   ss << "_";
-  auto types = op->getOperandTypes();
-  llvm::interleave(
-      types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
-      [&]() { ss << "_"; });
-  return ss.str();
+  for (Type t : op->getOperandTypes()) {
+    if (failed(appendMangledType(ss, t)))
+      return std::string();
+    ss << "_";
+  }
+  std::string res = ss.str();
+  res.pop_back();
+  return res;
 }
 
 //===----------------------------------------------------------------------===//
index fcb215e..f50016f 100644 (file)
@@ -71,3 +71,11 @@ func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
   } -> tensor<?xf32>
   return 
 }
+
+// -----
+
+func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @below {{failed to legalize}}
+  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}