[mlir] fix a crash in linalg.generic parser
authorAlex Zinenko <zinenko@google.com>
Tue, 11 Apr 2023 13:23:13 +0000 (13:23 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 11 Apr 2023 15:35:58 +0000 (15:35 +0000)
Report an error when the `iterator_types` attribute is missing.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D148015

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir

index ee4d064..da582a5 100644 (file)
@@ -777,18 +777,23 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
   // The name is unimportant as we will overwrite result.attributes.
   // The core linalg traits must contain the information necessary to pass the
   // verifier.
+  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
   if (parser.parseAttribute(dictAttr, "_", result.attributes))
     return failure();
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
 
-  // Convert array of string into an array of IteratyType enums. This is needed,
-  // because tests still use the old format when 'iterator_types' attribute is
-  // represented as an array of strings.
+  // Convert array of string into an array of IteratorType enums. This is
+  // needed, because tests still use the old format when 'iterator_types'
+  // attribute is represented as an array of strings.
   // TODO: Remove this conversion once tests are fixed.
-  ArrayAttr iteratorTypes =
-      result.attributes.get(getIteratorTypesAttrName(result.name))
-          .cast<ArrayAttr>();
+  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
+      result.attributes.get(getIteratorTypesAttrName(result.name)));
+  if (!iteratorTypes) {
+    return parser.emitError(attributeLocation)
+           << "expected " << getIteratorTypesAttrName(result.name)
+           << " array attribute";
+  }
 
   SmallVector<Attribute> iteratorTypeAttrs;
 
index 03540be..af3dc66 100644 (file)
@@ -725,3 +725,11 @@ func.func @broadcast_size_1_extension_not_supported(
       dimensions = [1]
   func.return %bcast : tensor<4x?x16xf32>
 }
+
+// -----
+
+func.func @missing_iterator_types() {
+  // expected-error @below {{expected "iterator_types" array attribute}}
+  linalg.generic {} ins() outs()
+  return
+}