Refactor linalg.view lowering to LLVM - NFC
authorNicolas Vasilache <ntv@google.com>
Wed, 14 Aug 2019 14:01:04 +0000 (07:01 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Aug 2019 14:01:41 +0000 (07:01 -0700)
This CL fuses the emission of size and stride information and makes it clearer which indexings are stepped over when querying the positions. This refactor was motivated by an index calculation bug in the stride computation.

PiperOrigin-RevId: 263341610

mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/test/Linalg/llvm.mlir

index ac71c32509057b56a3c6c91dd1c8a28957e2067e..908191ccd660c72737f17c5015ff0e2f652b915f 100644 (file)
@@ -388,6 +388,7 @@ public:
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    SliceOpOperandAdaptor adaptor(operands);
     auto sliceOp = cast<SliceOp>(op);
     auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
@@ -408,56 +409,45 @@ public:
     // Declare the view descriptor and insert data ptr.
     Value *desc = undef(viewDescriptorTy);
     desc = insertvalue(viewDescriptorTy, desc,
-                       getViewPtr(viewType, operands[0]), pos(0));
+                       getViewPtr(viewType, adaptor.view()), pos(0));
 
     // TODO(ntv): extract sizes and emit asserts.
     SmallVector<Value *, 4> strides(viewType.getRank());
-    for (int dim = 0, e = viewType.getRank(); dim < e; ++dim) {
-      strides[dim] = extractvalue(int64Ty, operands[0], pos({3, dim}));
+    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
+      strides[i] = extractvalue(int64Ty, adaptor.view(), pos({3, i}));
     }
 
     // Compute and insert base offset.
-    Value *baseOffset = extractvalue(int64Ty, operands[0], pos(1));
-    for (int j = 0, e = viewType.getRank(); j < e; ++j) {
-      Value *indexing = operands[1 + j];
+    Value *baseOffset = extractvalue(int64Ty, adaptor.view(), pos(1));
+    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
+      Value *indexing = adaptor.indexings()[i];
       Value *min =
-          sliceOp.indexing(j)->getType().isa<RangeType>()
+          sliceOp.indexing(i)->getType().isa<RangeType>()
               ? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
               : indexing;
-      Value *product = mul(min, strides[j]);
+      Value *product = mul(min, strides[i]);
       baseOffset = add(baseOffset, product);
     }
     desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
 
-    // Compute and insert view sizes (max - min along the range).  Skip the
-    // non-range operands as they will be projected away from the view.
-    int i = 0, j = 0;
-    for (Value *index : sliceOp.indexings()) {
-      if (!index->getType().isa<RangeType>()) {
-        ++j;
-        continue;
+    // Compute and insert view sizes (max - min along the range) and strides.
+    // Skip the non-range operands as they will be projected away from the view.
+    int numNewDims = 0;
+    for (auto en : llvm::enumerate(sliceOp.indexings())) {
+      Value *indexing = en.value();
+      if (indexing->getType().isa<RangeType>()) {
+        int i = en.index();
+        Value *rangeDescriptor = adaptor.indexings()[i];
+        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+        Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+        Value *size = sub(max, min);
+        Value *stride = mul(strides[i], step);
+        desc = insertvalue(viewDescriptorTy, desc, size, pos({2, numNewDims}));
+        desc =
+            insertvalue(viewDescriptorTy, desc, stride, pos({3, numNewDims}));
+        ++numNewDims;
       }
-
-      Value *rangeDescriptor = operands[1 + j];
-      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
-      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
-      Value *size = sub(max, min);
-
-      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
-      ++i;
-      ++j;
-    }
-
-    // Compute and insert view strides.  Step over the strides that correspond
-    // to non-range operands as they are projected away from the view.
-    i = 0;
-    for (int j = 0, e = strides.size(); j < e; ++j) {
-      if (!sliceOp.indexing(j)->getType().isa<RangeType>())
-        continue;
-      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
-      Value *stride = mul(strides[j], step);
-      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
-      ++i;
     }
 
     rewriter.replaceOp(op, desc);
index a56b631d6f86aa4b22684192636498c50620bf9c..9fa05af756141e2038bd45b2c5779c6fe8279856 100644 (file)
@@ -76,10 +76,10 @@ func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
 //  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 //  CHECK-NEXT:   %{{.*}} = llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
 //  CHECK-NEXT:   %{{.*}} = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
-//  CHECK-NEXT:   %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
-//  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 //  CHECK-NEXT:   %{{.*}} = llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
 //  CHECK-NEXT:   %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+//  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 //  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 
 func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {