[mlir] support creating memref descriptors from static shape with non-zero offset
authorTobias Gysi <tobias.gysi@inf.ethz.ch>
Wed, 12 Feb 2020 21:36:21 +0000 (22:36 +0100)
committerAlex Zinenko <zinenko@google.com>
Wed, 12 Feb 2020 21:40:49 +0000 (22:40 +0100)
This patch adapts the method MemRefDescriptor::fromStaticShape to
support static non-zero offsets. The updated method uses the
getStridesAndOffset method to extract strides and offset. The patch also
adapts the test cases since sizes and strides are now set in forward
instead of reverse order.

Differential Revision: https://reviews.llvm.org/D74474

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir

index 8d97ff1..57ebe42 100644 (file)
@@ -430,7 +430,17 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
                                   LLVMTypeConverter &typeConverter,
                                   MemRefType type, Value memory) {
   assert(type.hasStaticShape() && "unexpected dynamic shape");
-  assert(type.getAffineMaps().empty() && "unexpected layout map");
+
+  // Extract all strides and offsets and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto result = getStridesAndOffset(type, strides, offset);
+  (void)result;
+  assert(succeeded(result) && "unexpected failure in stride computation");
+  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+         "expected static offset");
+  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
+         "expected static strides");
 
   auto convertedType = typeConverter.convertType(type);
   assert(convertedType && "unexpected failure in memref type conversion");
@@ -438,16 +448,12 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
   descr.setAllocatedPtr(builder, loc, memory);
   descr.setAlignedPtr(builder, loc, memory);
-  descr.setConstantOffset(builder, loc, 0);
-
-  // Fill in sizes and strides, in reverse order to simplify stride
-  // calculation.
-  uint64_t runningStride = 1;
-  for (unsigned i = type.getRank(); i > 0; --i) {
-    unsigned dim = i - 1;
-    descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
-    descr.setConstantStride(builder, loc, dim, runningStride);
-    runningStride *= type.getDimSize(dim);
+  descr.setConstantOffset(builder, loc, offset);
+
+  // Fill in sizes and strides
+  for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
+    descr.setConstantSize(builder, loc, i, type.getDimSize(i));
+    descr.setConstantStride(builder, loc, i, strides[i]);
   }
   return descr;
 }
index 115c71d..c6d080f 100644 (file)
@@ -92,18 +92,18 @@ gpu.module @kernel {
     // CHECK: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
     // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
     // CHECK: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
-    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
-    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c6]], %[[descr4]][3, 2]
-    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 2]
+    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
+    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
+    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
+    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0]
     // CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
     // CHECK: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1]
     // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
     // CHECK: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1]
-    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
-    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c4]], %[[descr8]][3, 0]
-    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
-    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c12]], %[[descr9]][4, 0]
+    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
+    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2]
+    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2]
 
     %c0 = constant 0 : index
     store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, 3>
index cc8cfc3..b47d355 100644 (file)
@@ -24,20 +24,48 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
 // BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   return %static : memref<32x18xf32>
 }
 
 // -----
 
+// CHECK-LABEL: func @check_static_return_with_offset
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-LABEL: func @check_static_return_with_offset
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> {
+// CHECK:  llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(22 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  return %static : memref<32x18xf32, offset:7, strides:[22,1]>
+}
+
+// -----
+
 // CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
 // ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
 // BAREPTR-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
@@ -302,7 +330,7 @@ func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f
 // BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
 func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
 // CHECK:        llvm.mlir.constant(42 : index) : !llvm.i64
-// BAREPTR:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// BAREPTR:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
 // BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
   %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
 // CHECK-NEXT:  llvm.mlir.constant(32 : index) : !llvm.i64