[mlir][Linalg] Lower linalg.reshape to LLVM for the static case
authorNicolas Vasilache <ntv@google.com>
Mon, 6 Jan 2020 00:37:56 +0000 (19:37 -0500)
committerNicolas Vasilache <ntv@google.com>
Wed, 8 Jan 2020 18:07:41 +0000 (13:07 -0500)
Summary:
This diff adds lowering of the linalg.reshape op to LLVM.

A new descriptor is created with fields initialized as follows:
1. allocatedPTr, alignedPtr and offset are copied from the source descriptor
2. sizes are copied from the static destination shape
3. strides are copied from the static strides collected with `getStridesAndOffset`

Only the static case in which the target view conforms to strided memref
semantics is supported. Other cases are left for future work and will be added on
a per-need basis.

Reviewers: ftynse, mravishankar

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72316

mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/test/Dialect/Linalg/llvm.mlir

index 2dd36c9..86890b1 100644 (file)
@@ -122,8 +122,14 @@ public:
   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
   Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
   void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+  void setConstantSize(unsigned i, int64_t v) {
+    d.setConstantSize(rewriter(), loc(), i, v);
+  }
   Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
   void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
+  void setConstantStride(unsigned i, int64_t v) {
+    d.setConstantStride(rewriter(), loc(), i, v);
+  }
 
   operator Value() { return d; }
 
@@ -161,6 +167,48 @@ public:
   }
 };
 
+// ReshapeOp creates a new view descriptor of the proper rank.
+// For now, the only conversion supported is for target MemRef with static sizes
+// and strides.
+class ReshapeOpConversion : public LLVMOpLowering {
+public:
+  explicit ReshapeOpConversion(MLIRContext *context,
+                               LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reshapeOp = cast<ReshapeOp>(op);
+    MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+
+    if (!dstType.hasStaticShape())
+      return matchFailure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto res = getStridesAndOffset(dstType, strides, offset);
+    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+          return ShapedType::isDynamicStrideOrOffset(val);
+        }))
+      return matchFailure();
+
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    ReshapeOpOperandAdaptor adaptor(operands);
+    BaseViewConversionHelper baseDesc(adaptor.view());
+    BaseViewConversionHelper desc(lowering.convertType(dstType));
+    desc.setAllocatedPtr(baseDesc.allocatedPtr());
+    desc.setAlignedPtr(baseDesc.alignedPtr());
+    desc.setOffset(baseDesc.offset());
+    for (auto en : llvm::enumerate(dstType.getShape()))
+      desc.setConstantSize(en.index(), en.value());
+    for (auto en : llvm::enumerate(strides))
+      desc.setConstantStride(en.index(), en.value());
+    rewriter.replaceOp(op, {desc});
+    return matchSuccess();
+  }
+};
+
 /// Conversion pattern that transforms a linalg.slice op into:
 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
@@ -508,8 +556,8 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
 void mlir::populateLinalgToLLVMConversionPatterns(
     LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
     MLIRContext *ctx) {
-  patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
-                  YieldOpConversion>(ctx, converter);
+  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
+                  TransposeOpConversion, YieldOpConversion>(ctx, converter);
 }
 
 namespace {
index 7054a3d..d70ee8c 100644 (file)
@@ -196,3 +196,63 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
 // CHECK-LABEL: func @matmul_vec_indexed(
 //   CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //   CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+
+func @reshape_static(%arg0: memref<3x4x5xf32>) {
+  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+  %0 = linalg.reshape %arg0 [(i, j, k, l, m) -> (i, j),
+                             (i, j, k, l, m) -> (k),
+                             (i, j, k, l, m) -> (l, m)] :
+    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  %r0 = linalg.reshape %0 [(i, j, k, l, m) -> (i, j),
+                           (i, j, k, l, m) -> (k),
+                           (i, j, k, l, m) -> (l, m)] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return
+}
+// CHECK-LABEL: func @reshape_static(
+//       CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">