[mlir][Linalg] Fix constant detection in linalg.pad_tensor vectorization.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Sun, 14 Feb 2021 15:53:13 +0000 (15:53 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Sun, 14 Feb 2021 15:53:39 +0000 (15:53 +0000)
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index bfd2884..e350d89 100644 (file)
@@ -429,8 +429,9 @@ LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
     if (Attribute attr = ofr.dyn_cast<Attribute>())
       return attr.cast<IntegerAttr>().getInt() != 0;
     Value v = ofr.get<Value>();
-    if (auto constOp = v.getDefiningOp<ConstantIntOp>())
-      return constOp.getValue() != 0;
+    if (auto constOp = v.getDefiningOp<ConstantOp>())
+      if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
+        return intAttr.getValue().getSExtValue() != 0;
     return true;
   };
 
index bb532b2..13d2e18 100644 (file)
@@ -402,7 +402,8 @@ func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32>
   //      CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
   //      CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]]
   // CHECK-SAME:   {masked = [false, false, false]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
-  %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 0, 0] {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
     ^bb0(%arg1: index, %arg2: index, %arg3: index):
       linalg.yield %pad_value : f32
     } : tensor<?x?x?xf32> to tensor<2x3x4xf32>