[mlir][tensor] Fix insert_slice + tensor cast overflow
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 10 Dec 2021 21:27:20 +0000 (21:27 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 10 Dec 2021 21:41:26 +0000 (21:41 +0000)
InsertSliceOp may have subprefix semantics where missing trailing dimensions
are automatically inferred directly from the operand shape.
This revision fixes an overflow that occurs in such cases when the impl is based on the op rank.

Differential Revision: https://reviews.llvm.org/D115549

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir

index edddfb8..49cfec6 100644 (file)
@@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final
       return failure();
     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                      srcType.getShape().end());
-    for (int64_t i = 0; i < srcType.getRank(); ++i) {
-      if (Optional<int64_t> constInt =
-              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
-        newSrcShape[i] = *constInt;
-    }
+    // Offsets / sizes / strides can be a subprefix of the rank; take only the
+    // leading dimensions.
+    for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes()))
+      if (Optional<int64_t> constInt = getConstantIntValue(en.value()))
+        newSrcShape[en.index()] = *constInt;
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
index fc9abe4..50fda25 100644 (file)
@@ -536,6 +536,21 @@ func @insert_tensor_cast_on_insert_slice_src(
 
 // -----
 
+// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix(
+// CHECK-SAME:      %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
+//      CHECK:    %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x?xf32>
+//      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor<?x?x?xf32>
+//      CHECK:    return %[[r]]
+func @insert_tensor_cast_on_insert_slice_src_prefix(
+    %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
+  %c64 = arith.constant 64: index
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1]
+    : tensor<?x5x?xf32> into tensor<?x?x?xf32>
+  return %r : tensor<?x?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_insert
 //  CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
 func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {