From 13dbad87f6890ca278dfec9c4c90668d1c01e8ef Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 16 May 2019 11:57:36 -0700 Subject: [PATCH] Add linalg.range_intersect conversion to LLVM. This CL adds lowering for linalg.range_intersect into LLVM by computing: * new_min <- max (range1.min, range2.min) * new_max <- min (range1.max, range2.max) * new_step <- range1.step * range2.step -- PiperOrigin-RevId: 248571810 --- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 84 ++++++++++++++++++----- mlir/test/Linalg/llvm.mlir | 22 ++++++ 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 2f5a108..f5d9e2b 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -60,6 +60,8 @@ using call = OperationBuilder; using gep = ValueBuilder; using llvm_load = ValueBuilder; using llvm_store = OperationBuilder; +using llvm_select = ValueBuilder; +using icmp = ValueBuilder; template static llvm::Type *getPtrToElementType(T containerType, @@ -145,8 +147,7 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) { // Create an array attribute containing integer attributes with values provided // in `position`. -static ArrayAttr makePositionAttr(FuncBuilder &builder, - ArrayRef position) { +static ArrayAttr positionAttr(FuncBuilder &builder, ArrayRef position) { SmallVector attrs; attrs.reserve(position.size()); for (auto p : position) @@ -203,9 +204,9 @@ public: allocated = bitcast(elementPtrType, allocated); Value *desc = undef(bufferDescriptorType); desc = insertvalue(bufferDescriptorType, desc, allocated, - makePositionAttr(rewriter, 0)); + positionAttr(rewriter, 0)); desc = insertvalue(bufferDescriptorType, desc, size, - makePositionAttr(rewriter, 1)); + positionAttr(rewriter, 1)); return {desc}; } }; @@ -239,9 +240,8 @@ public: // Emit MLIR for buffer_dealloc. edsc::ScopedContext context(rewriter, op->getLoc()); - Value *casted = - bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0], - makePositionAttr(rewriter, 0))); + Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0], + positionAttr(rewriter, 0))); call(ArrayRef(), rewriter.getFunctionAttr(freeFunc), casted); return {}; @@ -258,7 +258,7 @@ public: FuncBuilder &rewriter) const override { auto int64Ty = lowering.convertType(operands[0]->getType()); edsc::ScopedContext context(rewriter, op->getLoc()); - return {extractvalue(int64Ty, operands[0], makePositionAttr(rewriter, 1))}; + return {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))}; } }; @@ -275,7 +275,7 @@ public: edsc::ScopedContext context(rewriter, op->getLoc()); return {extractvalue( indexTy, operands[0], - makePositionAttr(rewriter, {2, static_cast(dimOp.getIndex())}))}; + positionAttr(rewriter, {2, static_cast(dimOp.getIndex())}))}; } }; @@ -299,7 +299,7 @@ public: getPtrToElementType(loadOp.getViewType(), lowering)); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { - return makePositionAttr(rewriter, values); + return positionAttr(rewriter, values); }; // Linearize subscripts as: @@ -349,11 +349,57 @@ public: // Fill in an aggregate value of the descriptor. Value *desc = undef(rangeDescriptorTy); desc = insertvalue(rangeDescriptorTy, desc, operands[0], - makePositionAttr(rewriter, 0)); + positionAttr(rewriter, 0)); desc = insertvalue(rangeDescriptorTy, desc, operands[1], - makePositionAttr(rewriter, 1)); + positionAttr(rewriter, 1)); desc = insertvalue(rangeDescriptorTy, desc, operands[2], - makePositionAttr(rewriter, 2)); + positionAttr(rewriter, 2)); + + return {desc}; + } +}; + +// RangeIntersectOp creates a new range descriptor. +class RangeIntersectOpConversion : public LLVMOpLowering { +public: + explicit RangeIntersectOpConversion(MLIRContext *context, + LLVMLowering &lowering_) + : LLVMOpLowering(RangeIntersectOp::getOperationName(), context, + lowering_) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + auto rangeIntersectOp = cast(op); + auto rangeDescriptorTy = + convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering); + auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); + auto int1Ty = lowering.convertType(rewriter.getIntegerType(1)); + + edsc::ScopedContext context(rewriter, op->getLoc()); + auto min1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 0)); + auto min2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 0)); + auto max1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1)); + auto max2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 1)); + auto step1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 2)); + auto step2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 2)); + + // Fill in an aggregate value of the descriptor. + auto SLE = + rewriter.getI64IntegerAttr(static_cast(CmpIPredicate::SLE)); + auto SGE = + rewriter.getI64IntegerAttr(static_cast(CmpIPredicate::SGE)); + Value *desc = undef(rangeDescriptorTy); + desc = insertvalue( + rangeDescriptorTy, desc, + llvm_select(int64Ty, icmp(int1Ty, SGE, min1, min2), min1, min2), + positionAttr(rewriter, 0)); + desc = insertvalue( + rangeDescriptorTy, desc, + llvm_select(int64Ty, icmp(int1Ty, SLE, max1, max2), max1, max2), + positionAttr(rewriter, 1)); + // TODO(ntv): this assumes both steps are one for now. Enforce and extend. + desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2), + positionAttr(rewriter, 2)); return {desc}; } @@ -374,7 +420,7 @@ public: // Helper function to create an integer array attribute out of a list of // values. auto pos = [&rewriter](ArrayRef values) { - return makePositionAttr(rewriter, values); + return positionAttr(rewriter, values); }; // Helper function to obtain the ptr of the given `view`. auto getViewPtr = [pos, &rewriter, this](ViewType type, @@ -471,7 +517,7 @@ public: auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { - return makePositionAttr(rewriter, values); + return positionAttr(rewriter, values); }; // First operand to `view` is the buffer descriptor. @@ -545,10 +591,10 @@ protected: return ConversionListBuilder< BufferAllocOpConversion, BufferDeallocOpConversion, BufferSizeOpConversion, DimOpConversion, DotOpConversion, - LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, ViewOpConversion>::build(&converterStorage, - llvmDialect->getContext(), - *this); + LoadOpConversion, RangeOpConversion, RangeIntersectOpConversion, + SliceOpConversion, StoreOpConversion, + ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(), + *this); } Type convertAdditionalType(Type t) override { diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index 143b851..1c641fb 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -80,3 +80,25 @@ func @dim(%arg0: !linalg.view) { } // CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) { // CHECK: %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + +func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range { + %0 = linalg.range_intersect %arg0, %arg1 : !linalg.range + return %0 : !linalg.range +} +// CHECK-LABEL: func @range_intersect(%arg0: !llvm<"{ i64, i64, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">) -> !llvm<"{ i64, i64, i64 }"> { +// CHECK: %0 = llvm.extractvalue %arg0[0] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %1 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %2 = llvm.extractvalue %arg0[1] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %3 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %4 = llvm.extractvalue %arg0[2] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %5 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %6 = llvm.undef : !llvm<"{ i64, i64, i64 }"> +// CHECK: %7 = llvm.icmp "sge" %0, %1 : !llvm.i64 +// CHECK: %8 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i64 +// CHECK: %9 = llvm.insertvalue %8, %6[0] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %10 = llvm.icmp "sle" %2, %3 : !llvm.i64 +// CHECK: %11 = llvm.select %10, %2, %3 : !llvm.i1, !llvm.i64 +// CHECK: %12 = llvm.insertvalue %11, %9[1] : !llvm<"{ i64, i64, i64 }"> +// CHECK: %13 = llvm.mul %4, %5 : !llvm.i64 +// CHECK: %14 = llvm.insertvalue %13, %12[2] : !llvm<"{ i64, i64, i64 }"> +// CHECK: llvm.return %14 : !llvm<"{ i64, i64, i64 }"> -- 2.7.4