Fix stride computation bug when lowering linalg.view to llvm
authorNicolas Vasilache <ntv@google.com>
Mon, 20 May 2019 15:26:11 +0000 (08:26 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:49:10 +0000 (13:49 -0700)
--

PiperOrigin-RevId: 249053115

mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Linalg/Utils/Utils.cpp
mlir/test/Linalg/llvm.mlir

index 27e1385..f62dae9 100644 (file)
@@ -632,7 +632,6 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
                                      AffineMap::get(3, 0, {i, j}, {})};
   llvm_unreachable("Missing loopToOperandRangesMaps for op");
 }
-
 // Ideally this should all be Tablegen'd but there is no good story for op
 // expansion directly in MLIR for now.
 void mlir::linalg::emitScalarImplementation(
index a4da12a..ef762ff 100644 (file)
@@ -539,7 +539,7 @@ public:
     // Compute and insert view sizes (max - min along the range).
     int numIndexings = llvm::size(viewOp.getIndexings());
     Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
-    for (int i = 0; i < numIndexings; ++i) {
+    for (int i = numIndexings - 1; i >= 0; --i) {
       // Update stride.
       Value *rangeDescriptor = operands[1 + i];
       Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
@@ -550,9 +550,8 @@ public:
       Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
       Value *size = sub(max, min);
       desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
-      ++i;
       // Update stride for the next dimension.
-      if (i < numIndexings - 1)
+      if (i > 0)
         runningStride = mul(runningStride, max);
     }
 
index 75cddb1..41393a3 100644 (file)
@@ -138,7 +138,6 @@ Value *mlir::extractRangePart(Value *range, RangePart part) {
   }
   llvm_unreachable("need operations to extract range parts");
 }
-
 // Folding eagerly is necessary to abide by affine.for static step requirement.
 // We must propagate constants on the steps as aggressively as possible.
 // Returns nullptr if folding is not trivially feasible.
@@ -172,6 +171,7 @@ SmallVector<Value *, 4> mlir::applyMapToRangePart(FuncBuilder *b, Location loc,
                                                   RangePart part,
                                                   FunctionConstants &state) {
   SmallVector<Value *, 4> rangeParts(ranges.size());
+
   llvm::transform(ranges, rangeParts.begin(),
                   [&](Value *range) { return extractRangePart(range, part); });
 
index 1c641fb..d3ce0ec 100644 (file)
@@ -40,6 +40,21 @@ func @view(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range) {
 //  CHECK-NEXT:   %11 = llvm.sub %10, %9 : !llvm.i64
 //  CHECK-NEXT:   %12 = llvm.insertvalue %11, %8[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 
+func @view3d(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
+  %0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.view<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @view3d(%arg0: !llvm<"{ float*, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">, %arg2: !llvm<"{ i64, i64, i64 }">, %arg3: !llvm<"{ i64, i64, i64 }">) {
+//  CHECK-NEXT:   %5 = llvm.constant(1 : index) : !llvm.i64
+//  CHECK-NEXT:   %6 = llvm.extractvalue %arg3[2] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %7 = llvm.mul %5, %6 : !llvm.i64
+//  CHECK-NEXT:   %8 = llvm.insertvalue %7, %4[3, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   %10 = llvm.extractvalue %arg3[1] : !llvm<"{ i64, i64, i64 }">
+//       CHECK:   %13 = llvm.mul %5, %10 : !llvm.i64
+//  CHECK-NEXT:   %14 = llvm.extractvalue %arg2[2] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %15 = llvm.mul %13, %14 : !llvm.i64
+//  CHECK-NEXT:   %16 = llvm.insertvalue %15, %12[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+
 func @slice(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range) {
   %0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
   %1 = linalg.slice %0[%arg1] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>