[mlir] Add an additional check to vectorizeStaticLinalgOpPrecondition.
authorAdrian Kuegel <akuegel@google.com>
Wed, 22 Jun 2022 11:50:30 +0000 (13:50 +0200)
committerAdrian Kuegel <akuegel@google.com>
Thu, 23 Jun 2022 08:24:04 +0000 (10:24 +0200)
We need to make sure that the types used in the body are valid element types
for VectorType.

Differential Revision: https://reviews.llvm.org/D128336

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index efbf051..6c3d25d 100644 (file)
@@ -552,6 +552,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+  // All types in the body should be a supported element type for VectorType.
+  for (Operation &innerOp : op->getRegion(0).front()) {
+    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
+          return !VectorType::isValidElementType(type);
+        })) {
+      return failure();
+    }
+    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
+          return !VectorType::isValidElementType(type);
+        })) {
+      return failure();
+    }
+  }
   if (isElementwise(op))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
index 99617d5..dbd0957 100644 (file)
@@ -207,6 +207,23 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
 
 // -----
 
+// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
+func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex<f32>>, %arg0 : complex<f32>) {
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.transfer_write
+  linalg.generic {
+    indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+    iterator_types = ["parallel", "parallel"]}
+   ins(%arg0 : complex<f32>)
+  outs(%A: memref<8x16xcomplex<f32>>) {
+    ^bb(%0: complex<f32>, %1: complex<f32>) :
+      linalg.yield %0 : complex<f32>
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @test_vectorize_fill
 func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>