void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+ void setConstantSize(unsigned i, int64_t v) {
+ d.setConstantSize(rewriter(), loc(), i, v);
+ }
Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
+ void setConstantStride(unsigned i, int64_t v) {
+ d.setConstantStride(rewriter(), loc(), i, v);
+ }
operator Value() { return d; }
}
};
+// ReshapeOp creates a new view descriptor of the proper rank.
+// For now, the only conversion supported is for target MemRef with static sizes
+// and strides.
+class ReshapeOpConversion : public LLVMOpLowering {
+public:
+ explicit ReshapeOpConversion(MLIRContext *context,
+ LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reshapeOp = cast<ReshapeOp>(op);
+ MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+
+ // Only fully static result shapes are supported for now; bail out on
+ // anything dynamic so another pattern (or a later extension) can handle it.
+ if (!dstType.hasStaticShape())
+ return matchFailure();
+
+ // The target layout must also resolve to static strides. NOTE(review):
+ // `offset` is computed but not checked for being dynamic; the source
+ // descriptor's offset is forwarded unconditionally below — confirm this is
+ // intended.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto res = getStridesAndOffset(dstType, strides, offset);
+ if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+ return ShapedType::isDynamicStrideOrOffset(val);
+ }))
+ return matchFailure();
+
+ // Build a descriptor of the target rank that aliases the source buffer:
+ // allocated/aligned pointers and the offset are copied from the source
+ // descriptor, while sizes and strides are emitted as constants taken from
+ // the static target type.
+ edsc::ScopedContext context(rewriter, op->getLoc());
+ ReshapeOpOperandAdaptor adaptor(operands);
+ BaseViewConversionHelper baseDesc(adaptor.view());
+ BaseViewConversionHelper desc(lowering.convertType(dstType));
+ desc.setAllocatedPtr(baseDesc.allocatedPtr());
+ desc.setAlignedPtr(baseDesc.alignedPtr());
+ desc.setOffset(baseDesc.offset());
+ for (auto en : llvm::enumerate(dstType.getShape()))
+ desc.setConstantSize(en.index(), en.value());
+ for (auto en : llvm::enumerate(strides))
+ desc.setConstantStride(en.index(), en.value());
+ rewriter.replaceOp(op, {desc});
+ return matchSuccess();
+ }
+};
+
/// Conversion pattern that transforms a linalg.slice op into:
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
// Registers all Linalg-to-LLVM conversion patterns into `patterns`; the hunk
// below adds the new ReshapeOpConversion to the existing pattern list.
void mlir::populateLinalgToLLVMConversionPatterns(
LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
MLIRContext *ctx) {
- patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
- YieldOpConversion>(ctx, converter);
+ patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
+ TransposeOpConversion, YieldOpConversion>(ctx, converter);
}
namespace {
// CHECK-LABEL: func @matmul_vec_indexed(
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+
+// Static-shape test: expand a contiguous 3x4x5 memref into 1x3x4x1x5 (adding
+// unit dimensions), then collapse it back, exercising both directions of the
+// ReshapeOp lowering.
+func @reshape_static(%arg0: memref<3x4x5xf32>) {
+ // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+ // Expansion: source dim0 -> (i, j), dim1 -> (k), dim2 -> (l, m).
+ %0 = linalg.reshape %arg0 [(i, j, k, l, m) -> (i, j),
+ (i, j, k, l, m) -> (k),
+ (i, j, k, l, m) -> (l, m)] :
+ memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+ // Collapse: the same reassociation maps applied in the other direction.
+ %r0 = linalg.reshape %0 [(i, j, k, l, m) -> (i, j),
+ (i, j, k, l, m) -> (k),
+ (i, j, k, l, m) -> (l, m)] :
+ memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+ return
+}
+// CHECK-LABEL: func @reshape_static(
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">