[mlir] Fix asan issue in Vectorization.cpp of Linalg.
authorAlexander Belyaev <pifon@google.com>
Thu, 27 Oct 2022 16:02:32 +0000 (18:02 +0200)
committerAlexander Belyaev <pifon@google.com>
Thu, 27 Oct 2022 16:11:08 +0000 (18:11 +0200)
Differential Revision: https://reviews.llvm.org/D136852

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

index a435c10..356ba00 100644 (file)
@@ -528,10 +528,10 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
   // 3. Turn all BBArgs into vector.transfer_read / load.
   Location loc = linalgOp.getLoc();
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-    BlockArgument bbarg = block->getArgument(opOperand.getOperandNumber());
-    if (linalgOp.isScalar(&opOperand)) {
-      bvm.map(bbarg, opOperand.get());
+  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
+    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
+    if (linalgOp.isScalar(opOperand)) {
+      bvm.map(bbarg, opOperand->get());
       continue;
     }
     VectorType readType;
@@ -540,23 +540,23 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
     // if (linalgOp.getShape(&opOperand).empty()) {
     //   readType = VectorType::get({}, bbarg.getType());
     // } else {
-    if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) {
+    if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
       map = inverseAndBroadcastProjectedPermutation(
-          linalgOp.getMatchingIndexingMap(&opOperand));
+          linalgOp.getMatchingIndexingMap(opOperand));
       readType = VectorType::get(commonVectorShape,
-                                 getElementTypeOrSelf(opOperand.get()));
+                                 getElementTypeOrSelf(opOperand->get()));
     } else {
       map = inversePermutation(
-          reindexIndexingMap(linalgOp.getMatchingIndexingMap(&opOperand)));
-      readType = VectorType::get(map.compose(linalgOp.getShape(&opOperand)),
-                                 getElementTypeOrSelf(opOperand.get()));
+          reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
+      readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+                                 getElementTypeOrSelf(opOperand->get()));
     }
     // }
 
-    auto shape = linalgOp.getShape(&opOperand);
+    auto shape = linalgOp.getShape(opOperand);
     SmallVector<Value> indices(shape.size(), zero);
     Value readValue = b.create<vector::TransferReadOp>(
-        loc, readType, opOperand.get(), indices, map);
+        loc, readType, opOperand->get(), indices, map);
     // Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readValue.getType().cast<VectorType>().getRank() == 0)
@@ -564,7 +564,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
 
     LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
     bvm.map(bbarg, readValue);
-    bvm.map(opOperand.get(), readValue);
+    bvm.map(opOperand->get(), readValue);
   }
 
   SmallVector<CustomVectorizationHook> hooks;