From: Nicolas Vasilache
Date: Fri, 6 Sep 2019 15:30:54 +0000 (-0700)
Subject: Simplify Linalg ABI integration with external function calls.
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1b8eff8fcd80fa40792d24953373b54e6da1305f;p=platform%2Fupstream%2Fllvm.git

Simplify Linalg ABI integration with external function calls.

View descriptors are currently converted to a *pointer to* an LLVM struct in
order to avoid ABI issues related to C struct packing. This creates unnecessary
complexity and hampers unification with memrefs.
Instead, this CL makes view descriptors convert to an LLVM struct (as they did
originally) and promotes all structs to pointers right before calling an
external function.

PiperOrigin-RevId: 267602693
---

diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 2f413de..30aaa0d53 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -103,9 +103,6 @@ private:
   // pointer as defined by the data layout of the module.
   LLVM::LLVMType getIndexType();
 
-  // Wrap the given LLVM IR type into an LLVM IR dialect type.
-  Type wrap(llvm::Type *llvmType);
-
   // Extract an LLVM IR dialect type.
   LLVM::LLVMType unwrap(Type type);
 };
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index e29827e..569b139 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -72,7 +72,7 @@ using llvm_select = ValueBuilder<LLVM::SelectOp>;
 using mul = ValueBuilder<LLVM::MulOp>;
 using ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
 using sub = ValueBuilder<LLVM::SubOp>;
-using undef = ValueBuilder<LLVM::UndefOp>;
+using llvm_undef = ValueBuilder<LLVM::UndefOp>;
 using urem = ValueBuilder<LLVM::URemOp>;
 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
@@ -123,10 +123,10 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
   if (t.isa<RangeType>())
     return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
 
-  // A linalg.view type converts to a *pointer to* a view descriptor. The view
-  // descriptor contains the pointer to the data buffer, followed by a 64-bit
-  // integer containing the distance between the beginning of the buffer and the
-  // first element to be accessed through the view, followed by two arrays, each
+  // A linalg.view type converts to a view descriptor. The view descriptor
+  // contains the pointer to the data buffer, followed by a 64-bit integer
+  // containing the distance between the beginning of the buffer and the first
+  // element to be accessed through the view, followed by two arrays, each
   // containing as many 64-bit integers as the rank of the View. The first array
   // represents the size, in number of original elements, of the view along the
   // given dimension. When taking the view, the size is the difference between
@@ -146,12 +146,11 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
   //   int64_t offset;
   //   int64_t sizes[Rank];
   //   int64_t strides[Rank];
-  // } *;
+  // };
   if (auto viewType = t.dyn_cast<ViewType>()) {
     auto ptrTy = getPtrToElementType(viewType, lowering);
     auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
-    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy)
-        .getPointerTo();
+    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
   }
 
   return Type();
 }
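The layout documented above fixes the C-level contract for a view. As a sketch
of why this CL keeps descriptors as structs inside the module and promotes them
to pointers only at external call boundaries (all names below are illustrative,
not symbols from this patch): passing the struct by value across a C ABI exposes
it to per-compiler packing and argument-splitting rules, while passing the
address of a stack copy is trivially portable.

    #include <cstdint>

    // Mirrors { float*, i64, [1 x i64], [1 x i64] } for a 1-D float view.
    struct ViewDescriptor1D {
      float *ptr;          // pointer into the data buffer
      int64_t offset;      // distance from buffer start to first element
      int64_t sizes[1];    // size of the view per dimension
      int64_t strides[1];  // stride per dimension
    };

    extern "C" void linalg_external_fn(ViewDescriptor1D *view);

    void callSite(ViewDescriptor1D desc) {
      // What the promotion emits, in C++ terms: copy the descriptor into an
      // aligned stack slot and pass the slot's address to the callee.
      ViewDescriptor1D slot = desc;
      linalg_external_fn(&slot);
    }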
@@ -179,7 +178,7 @@ namespace {
 /// Factor out the common information for all view conversions:
 ///   1. common types in (standard and LLVM dialects)
 ///   2. `pos` method
-///   3. op of the FuncOp alloca'ed value and descriptor.
+///   3. view descriptor construction `desc`.
 class BaseViewConversionHelper {
 public:
   BaseViewConversionHelper(Operation *op, ViewType viewType,
@@ -189,28 +188,17 @@ public:
       int64Ty(
           lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
       rewriter(rewriter) {
-    IndexType indexType = rewriter.getIndexType();
-    viewDescriptorPtrTy =
-        convertLinalgType(viewType, lowering).cast<LLVMType>();
-    OpBuilder::InsertionGuard insertGuard(rewriter);
-    rewriter.setInsertionPointToStart(
-        &op->getParentOfType<FuncOp>().getBlocks().front());
-
-    edsc::ScopedContext context(rewriter, op->getLoc());
-    Value *one = constant(int64Ty, IntegerAttr::get(indexType, 1));
-    // Alloca with proper alignment.
-    allocatedDesc = llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
-    // Load the alloca'ed descriptor.
-    desc = llvm_load(allocatedDesc);
+    viewDescriptorTy = convertLinalgType(viewType, lowering).cast<LLVMType>();
+    desc = rewriter.create<LLVM::UndefOp>(op->getLoc(), viewDescriptorTy);
   }
 
   ArrayAttr pos(ArrayRef<int64_t> values) const {
     return positionAttr(rewriter, values);
   };
 
-  LLVMType elementTy, int64Ty, viewDescriptorPtrTy;
+  LLVMType elementTy, int64Ty, viewDescriptorTy;
   ConversionPatternRewriter &rewriter;
-  Value *allocatedDesc, *desc;
+  Value *desc;
 };
 
 }  // namespace
@@ -283,7 +271,7 @@ public:
       data = gep(voidPtrTy, allocated, offset);
     }
     data = bitcast(elementPtrType, data);
-    Value *desc = undef(bufferDescriptorTy);
+    Value *desc = llvm_undef(bufferDescriptorTy);
     desc = insertvalue(bufferDescriptorTy, desc, allocated,
                        positionAttr(rewriter, kBasePtrPosInBuffer));
     desc = insertvalue(bufferDescriptorTy, desc, data,
@@ -362,7 +350,7 @@ public:
     auto pos = positionAttr(
         rewriter, {kSizePosInView, static_cast<int>(dimOp.getIndex())});
     linalg::DimOpOperandAdaptor adaptor(operands);
-    Value *viewDescriptor = llvm_load(adaptor.view());
+    Value *viewDescriptor = adaptor.view();
     rewriter.replaceOp(op, {extractvalue(indexTy, viewDescriptor, pos)});
     return matchSuccess();
   }
@@ -382,7 +370,7 @@ public:
   // current view indices. Use the base offset and strides stored in the view
   // descriptor to emit IR iteratively computing the actual offset, followed by
   // a getelementptr. This must be called under an edsc::ScopedContext.
-  Value *obtainDataPtr(Operation *op, Value *viewDescriptorPtr,
+  Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                        ArrayRef<Value *> indices,
                        ConversionPatternRewriter &rewriter) const {
     auto loadOp = cast<Op>(op);
@@ -394,7 +382,6 @@ public:
 
     // Linearize subscripts as:
     //   base_offset + SUM_i index_i * stride_i.
-    Value *viewDescriptor = llvm_load(viewDescriptorPtr);
     Value *base = extractvalue(elementTy, viewDescriptor, pos(kPtrPosInView));
     Value *offset =
         extractvalue(int64Ty, viewDescriptor, pos(kOffsetPosInView));
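In scalar terms, the linearization that `obtainDataPtr` emits above is the
following address computation (a sketch; the descriptor fields follow the
documented layout, and `elementAddress` is an illustrative name):

    #include <cstdint>

    // Element address for view[indices]:
    //   base + offset + SUM_i indices[i] * strides[i]
    template <int Rank>
    float *elementAddress(float *base, int64_t offset,
                          const int64_t (&strides)[Rank],
                          const int64_t (&indices)[Rank]) {
      int64_t linear = offset;
      for (int i = 0; i < Rank; ++i)
        linear += indices[i] * strides[i];
      return base + linear;  // lowered as a single getelementptr
    }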
@@ -442,7 +429,7 @@ public:
     // Fill in an aggregate value of the descriptor.
     RangeOpOperandAdaptor adaptor(operands);
-    Value *desc = undef(rangeDescriptorTy);
+    Value *desc = llvm_undef(rangeDescriptorTy);
     desc = insertvalue(desc, adaptor.min(), positionAttr(rewriter, 0));
     desc = insertvalue(desc, adaptor.max(), positionAttr(rewriter, 1));
     desc = insertvalue(desc, adaptor.step(), positionAttr(rewriter, 2));
@@ -468,57 +455,42 @@ public:
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     SliceOpOperandAdaptor adaptor(operands);
+    Value *baseDesc = adaptor.view();
+
     auto sliceOp = cast<SliceOp>(op);
-    auto viewDescriptorPtrTy =
-        convertLinalgType(sliceOp.getViewType(), lowering);
-    auto viewType = sliceOp.getBaseViewType();
-    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+    BaseViewConversionHelper helper(op, sliceOp.getViewType(), rewriter,
+                                    lowering);
+    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty,
+             viewDescriptorTy = helper.viewDescriptorTy;
+    Value *desc = helper.desc;
 
-    // Helper function to create an integer array attribute out of a list of
-    // values.
-    auto pos = [&rewriter](ArrayRef<int64_t> values) {
-      return positionAttr(rewriter, values);
-    };
+    auto viewType = sliceOp.getBaseViewType();
 
     edsc::ScopedContext context(rewriter, op->getLoc());
-    // Declare the view descriptor and insert data ptr *at the entry block of
-    // the function*, which is the preferred location for LLVM's analyses.
-    auto ip = rewriter.getInsertionPoint();
-    auto ib = rewriter.getInsertionBlock();
-    rewriter.setInsertionPointToStart(
-        &op->getParentOfType<FuncOp>().getBlocks().front());
     Value *zero =
         constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
-    Value *one =
-        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
-    // Alloca with proper alignment.
-    Value *allocatedDesc =
-        llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
-    Value *desc = llvm_load(allocatedDesc);
-    rewriter.setInsertionPoint(ib, ip);
-
-    Value *baseDesc = llvm_load(adaptor.view());
-
-    auto ptrPos = pos(kPtrPosInView);
-    auto elementTy = getPtrToElementType(sliceOp.getViewType(), lowering);
+
+    auto ptrPos = helper.pos(kPtrPosInView);
     desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
 
     // TODO(ntv): extract sizes and emit asserts.
     SmallVector<Value *, 4> strides(viewType.getRank());
     for (int i = 0, e = viewType.getRank(); i < e; ++i) {
-      strides[i] = extractvalue(int64Ty, baseDesc, pos({kStridePosInView, i}));
+      strides[i] =
+          extractvalue(int64Ty, baseDesc, helper.pos({kStridePosInView, i}));
     }
 
     // Compute and insert base offset.
-    Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
+    Value *baseOffset =
+        extractvalue(int64Ty, baseDesc, helper.pos(kOffsetPosInView));
     for (int i = 0, e = viewType.getRank(); i < e; ++i) {
       Value *indexing = adaptor.indexings()[i];
       Value *min = indexing;
       if (sliceOp.indexing(i)->getType().isa<RangeType>())
-        min = extractvalue(int64Ty, indexing, pos(0));
+        min = extractvalue(int64Ty, indexing, helper.pos(0));
       baseOffset = add(baseOffset, mul(min, strides[i]));
     }
-    desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
+    desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));
 
     // Compute and insert view sizes (max - min along the range) and strides.
     // Skip the non-range operands as they will be projected away from the view.
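The slice hunk that follows implements, per range-indexed dimension, the
clamped size/stride arithmetic below (a sketch; `sliceDim` is an illustrative
name, not a function in this patch):

    #include <algorithm>
    #include <cstdint>

    // One sliced dimension: given the base view's size/stride and a range
    // {min, max, step}, compute the new size and stride. The offset update
    // (baseOffset += min * baseStride) happens in the hunk just above.
    inline void sliceDim(int64_t baseSize, int64_t baseStride, int64_t min,
                         int64_t max, int64_t step,
                         int64_t &size, int64_t &stride) {
      max = std::min(max, baseSize);           // bound by the base view size
      size = std::max<int64_t>(max - min, 0);  // clamp to non-negative
      stride = baseStride * step;              // coarsen stride by the step
    }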
@@ -528,11 +500,11 @@ public:
       if (indexing->getType().isa<RangeType>()) {
         int rank = en.index();
         Value *rangeDescriptor = adaptor.indexings()[rank];
-        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
-        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
-        Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+        Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
+        Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
+        Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
         Value *baseSize =
-            extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
+            extractvalue(int64Ty, baseDesc, helper.pos({kSizePosInView, rank}));
         // Bound upper by base view upper bound.
         max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                           baseSize);
@@ -541,15 +513,15 @@ public:
         size = llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
         Value *stride = mul(strides[rank], step);
-        desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
-        desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
+        desc =
+            insertvalue(desc, size, helper.pos({kSizePosInView, numNewDims}));
+        desc = insertvalue(desc, stride,
+                           helper.pos({kStridePosInView, numNewDims}));
         ++numNewDims;
       }
     }
 
-    // Store back in alloca'ed region.
-    llvm_store(desc, allocatedDesc);
-    rewriter.replaceOp(op, allocatedDesc);
+    rewriter.replaceOp(op, desc);
     return matchSuccess();
   }
 };
@@ -588,16 +560,19 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     // Initialize the common boilerplate and alloca at the top of the FuncOp.
     TransposeOpOperandAdaptor adaptor(operands);
+    Value *baseDesc = adaptor.view();
+
     auto tranposeOp = cast<TransposeOp>(op);
+    // No permutation, early exit.
+    if (tranposeOp.permutation().isIdentity())
+      return rewriter.replaceOp(op, baseDesc), matchSuccess();
+
     BaseViewConversionHelper helper(op, tranposeOp.getViewType(), rewriter,
                                     lowering);
     LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
-    Value *allocatedDesc = helper.allocatedDesc, *desc = helper.desc;
+    Value *desc = helper.desc;
 
     edsc::ScopedContext context(rewriter, op->getLoc());
-    // Load the descriptor of the view constructed by the helper.
-    Value *baseDesc = llvm_load(adaptor.view());
-
     // Copy the base pointer from the old descriptor to the new one.
     ArrayAttr ptrPos = helper.pos(kPtrPosInView);
     desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
@@ -606,12 +581,6 @@ public:
     ArrayAttr offPos = helper.pos(kOffsetPosInView);
     desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
 
-    if (tranposeOp.permutation().isIdentity()) {
-      // No permutation, just store back in alloca'ed region.
-      llvm_store(desc, allocatedDesc);
-      return rewriter.replaceOp(op, allocatedDesc), matchSuccess();
-    }
-
     // Iterate over the dimensions and apply size/stride permutation.
     for (auto en : llvm::enumerate(tranposeOp.permutation().getResults())) {
       int sourcePos = en.index();
@@ -625,9 +594,7 @@ public:
           insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos}));
     }
 
-    // Store back in alloca'ed region.
-    llvm_store(desc, allocatedDesc);
-    rewriter.replaceOp(op, allocatedDesc);
+    rewriter.replaceOp(op, desc);
     return matchSuccess();
   }
 };
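The permutation loop above amounts to the following scalar routine (a sketch;
in the pattern the permutation comes from the affine map's results, and
`permuteDims` is an illustrative name):

    #include <cstdint>
    #include <vector>

    // Dimension i of the source view becomes dimension perm[i] of the result;
    // sizes and strides move together, while the data pointer and offset are
    // copied unchanged.
    inline void permuteDims(const std::vector<int> &perm,
                            const std::vector<int64_t> &sizes,
                            const std::vector<int64_t> &strides,
                            std::vector<int64_t> &newSizes,
                            std::vector<int64_t> &newStrides) {
      for (int i = 0, e = static_cast<int>(perm.size()); i < e; ++i) {
        newSizes[perm[i]] = sizes[i];
        newStrides[perm[i]] = strides[i];
      }
    }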
@@ -647,45 +614,32 @@ public:
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto viewOp = cast<ViewOp>(op);
     ViewOpOperandAdaptor adaptor(operands);
-    auto viewDescriptorPtrTy =
-        convertLinalgType(viewOp.getViewType(), lowering);
-    auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
-    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
-    auto pos = [&rewriter](ArrayRef<int64_t> values) {
-      return positionAttr(rewriter, values);
-    };
+    auto viewOp = cast<ViewOp>(op);
+    BaseViewConversionHelper helper(op, viewOp.getViewType(), rewriter,
+                                    lowering);
+    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty,
+             viewDescriptorTy = helper.viewDescriptorTy;
+    Value *desc = helper.desc;
 
     Value *bufferDescriptor = adaptor.buffer();
     auto bufferTy = getPtrToElementType(
         viewOp.buffer()->getType().cast<BufferType>(), lowering);
 
-    // Declare the descriptor of the view.
     edsc::ScopedContext context(rewriter, op->getLoc());
-    auto ip = rewriter.getInsertionPoint();
-    auto ib = rewriter.getInsertionBlock();
-    rewriter.setInsertionPointToStart(
-        &op->getParentOfType<FuncOp>().getBlocks().front());
-    Value *one =
-        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
-    // Alloca for proper alignment.
-    Value *allocatedDesc =
-        llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
-    Value *desc = llvm_load(allocatedDesc);
-    rewriter.setInsertionPoint(ib, ip);
 
     // Copy the buffer pointer from the old descriptor to the new one.
     Value *bufferAsViewElementType =
-        bitcast(elementTy,
-                extractvalue(bufferTy, bufferDescriptor, pos(kPtrPosInBuffer)));
-    desc = insertvalue(desc, bufferAsViewElementType, pos(kPtrPosInView));
+        bitcast(elementTy, extractvalue(bufferTy, bufferDescriptor,
                                         helper.pos(kPtrPosInBuffer)));
+    desc =
+        insertvalue(desc, bufferAsViewElementType, helper.pos(kPtrPosInView));
 
     // Zero base offset.
     auto indexTy = rewriter.getIndexType();
     Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
-    desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
+    desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));
 
     // Compute and insert view sizes (max - min along the range).
     int numRanges = llvm::size(viewOp.ranges());
@@ -693,26 +647,72 @@ public:
     for (int i = numRanges - 1; i >= 0; --i) {
       // Update stride.
       Value *rangeDescriptor = operands[1 + i];
-      Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+      Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
       Value *stride = mul(runningStride, step);
-      desc = insertvalue(desc, stride, pos({kStridePosInView, i}));
+      desc = insertvalue(desc, stride, helper.pos({kStridePosInView, i}));
       // Update size.
-      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
-      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+      Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
+      Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
       Value *size = sub(max, min);
-      desc = insertvalue(desc, size, pos({kSizePosInView, i}));
+      desc = insertvalue(desc, size, helper.pos({kSizePosInView, i}));
       // Update stride for the next dimension.
       if (i > 0)
         runningStride = mul(runningStride, max);
     }
 
-    // Store back in alloca'ed region.
-    llvm_store(desc, allocatedDesc);
-    rewriter.replaceOp(op, allocatedDesc);
+    rewriter.replaceOp(op, desc);
     return matchSuccess();
   }
 };
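The loop above fills the fresh view's sizes and strides from its ranges.
Assuming `runningStride` starts at 1 (its initialization lies outside the
hunk), the computation reduces to the following sketch, with the last
dimension varying fastest:

    #include <cstdint>
    #include <vector>

    struct Range { int64_t min, max, step; };

    // stride[i] = step[i] * PROD_{j>i} max[j]; size[i] = max[i] - min[i].
    inline void viewSizesAndStrides(const std::vector<Range> &ranges,
                                    std::vector<int64_t> &sizes,
                                    std::vector<int64_t> &strides) {
      int64_t runningStride = 1;  // assumed initial value
      for (int i = static_cast<int>(ranges.size()) - 1; i >= 0; --i) {
        strides[i] = runningStride * ranges[i].step;
        sizes[i] = ranges[i].max - ranges[i].min;
        if (i > 0)
          runningStride *= ranges[i].max;
      }
    }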
+// Promote LLVM struct types to pointer to struct types to avoid ABI issues
+// related to C struct packing.
+static SmallVector<Type, 4>
+promoteStructTypes(Operation::operand_range operands,
+                   LLVMTypeConverter &lowering) {
+  SmallVector<Type, 4> res;
+  for (auto operand : operands) {
+    auto type = lowering.convertType(operand->getType()).cast<LLVM::LLVMType>();
+    if (type.isStructTy())
+      res.push_back(type.getPointerTo());
+    else
+      res.push_back(type);
+  }
+  return res;
+}
+
+// Promote LLVM struct to pointer to struct to avoid ABI issues related to
+// C struct packing.
+static SmallVector<Value *, 4>
+promoteStructs(Location loc, ArrayRef<Value *> operands,
+               ConversionPatternRewriter &rewriter,
+               LLVMTypeConverter &lowering) {
+  auto *context = rewriter.getContext();
+  auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+  auto indexType = IndexType::get(context);
+  edsc::ScopedContext scope(rewriter, loc);
+  SmallVector<Value *, 4> promotedOperands;
+  promotedOperands.reserve(operands.size());
+  for (auto *operand : operands) {
+    auto type = operand->getType().cast<LLVM::LLVMType>();
+    if (!type.isStructTy()) {
+      promotedOperands.push_back(operand);
+      continue;
+    }
+    // Alloca with proper alignment. This is purely for solving ABI issues
+    // related to C struct packing across external library call boundaries. We
+    // do not expect optimizations of this alloca op and so we omit
+    // allocating at the entry block.
+    auto ptrType = type.cast<LLVM::LLVMType>().getPointerTo();
+    Value *one = constant(int64Ty, IntegerAttr::get(indexType, 1));
+    Value *allocated = llvm_alloca(ptrType, one, /*alignment=*/8);
+    // Store into the alloca'ed descriptor.
+    llvm_store(operand, allocated);
+    promotedOperands.push_back(allocated);
+  }
+  return promotedOperands;
+}
+
 // Get function definition for the LinalgOp. If it doesn't exist, insert a
 // definition.
 template <typename LinalgOp>
@@ -731,9 +731,9 @@ getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
   }
 
   // Get the Function type consistent with LLVM Lowering.
-  SmallVector<Type, 4> inputTypes;
-  for (auto operand : op->getOperands())
-    inputTypes.push_back(lowering.convertType(operand->getType()));
+  // Structs are automatically promoted to pointer to struct in order to avoid
+  // ABI issues related to C struct packing that we don't want to handle here.
+  auto inputTypes = promoteStructTypes(op->getOperands(), lowering);
   assert(op->getNumResults() == 0 &&
          "Library call for linalg operation can be generated only for ops that "
         "have void return types");
@@ -778,8 +778,9 @@ public:
     auto fAttr = rewriter.getSymbolRefAttr(f);
     auto named = rewriter.getNamedAttr("callee", fAttr);
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
-                                              ArrayRef<NamedAttribute>{named});
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, promoteStructs(op->getLoc(), operands, rewriter, lowering),
+        ArrayRef<NamedAttribute>{named});
     return matchSuccess();
   }
 };
@@ -809,8 +810,9 @@ public:
     auto fAttr = rewriter.getSymbolRefAttr(f);
     auto named = rewriter.getNamedAttr("callee", fAttr);
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
-                                              ArrayRef<NamedAttribute>{named});
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, promoteStructs(op->getLoc(), operands, rewriter, lowering),
+        ArrayRef<NamedAttribute>{named});
     return matchSuccess();
  }
 };
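With the promotion in place, an external library sees one pointer per view
descriptor. A hypothetical callee matching the
`linalg_dot_viewxf32_viewxf32_viewf32` declaration exercised by the test below
could look as follows (a sketch only; the runtime library, its struct names,
and its exact accumulation semantics are assumptions, not part of this patch):

    #include <cstdint>

    // { float*, i64, [1 x i64], [1 x i64] } and { float*, i64, [0 x i64], [0 x i64] }.
    struct ViewDescriptor1D { float *ptr; int64_t offset; int64_t sizes[1]; int64_t strides[1]; };
    struct ViewDescriptor0D { float *ptr; int64_t offset; };  // rank 0: empty arrays

    extern "C" void linalg_dot_viewxf32_viewxf32_viewf32(ViewDescriptor1D *X,
                                                         ViewDescriptor1D *Y,
                                                         ViewDescriptor0D *Z) {
      float acc = 0.0f;
      for (int64_t i = 0; i < X->sizes[0]; ++i)
        acc += X->ptr[X->offset + i * X->strides[0]] *
               Y->ptr[Y->offset + i * Y->strides[0]];
      Z->ptr[Z->offset] += acc;  // accumulation assumed for illustration
    }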
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index b82d114..3d21051 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -48,10 +48,7 @@ func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
   return
 }
 // CHECK-LABEL: func @view
-// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }">
 // CHECK-NEXT: llvm.bitcast {{.*}} : !llvm<"float*"> to !llvm<"float*">
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
@@ -64,7 +61,6 @@ func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
 // CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
 // CHECK-NEXT: llvm.return
@@ -72,9 +68,6 @@ func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
   return
 }
 // CHECK-LABEL: func @view3d
-// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-NEXT: llvm.load {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
 // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
 // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
@@ -90,13 +83,9 @@ func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
   return
 }
 // CHECK-LABEL: func @slice
-//   1st load from view_op
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-//   2nd load from reloading the view descriptor pointer
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-//   3rd load from slice_op
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-//   insert data ptr
+//   insert ptr for view op
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//   insert data ptr for slice op
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
@@ -127,14 +116,15 @@ func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
   linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
   return
 }
-// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) {
-// CHECK: llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) -> ()
+// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
+// CHECK-NEXT: llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) -> ()
 
 func @dim(%arg0: !linalg.view<?x?xf32>) {
   %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
   return
 }
-// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
 // CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
 
 func @subview(%arg0: !linalg.view<?x?xf32>) {
@@ -145,10 +135,9 @@ func @subview(%arg0: !linalg.view<?x?xf32>) {
 // CHECK-LABEL: func @subview
 //
 //   Subview lowers to range + slice op
-// CHECK: llvm.alloca %{{.*}} x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
 // CHECK: llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
 // CHECK: llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
 //
 //   Select occurs in slice op lowering
 // CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
@@ -177,11 +166,8 @@ func @view_with_range_and_index(%arg0: !linalg.view<?x?xf64>) {
   return
 }
 // CHECK-LABEL: func @view_with_range_and_index
-//   top of the function alloca + load.
-// CHECK: llvm.alloca %{{.*}} x !llvm<"{ double*, i64, [1 x i64], [1 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ double*, i64, [1 x i64], [1 x i64] }*">
-// CHECK: llvm.load %{{.*}} : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }*">
-//   loop-body load from descriptor ptr.
-// CHECK: llvm.load %{{.*}} : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }*">
+//   loop-body.
+// CHECK: llvm.mlir.undef : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
 // CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
@@ -204,7 +190,7 @@ func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
   return
 }
 // CHECK-LABEL: func @transpose
-// CHECK: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
@@ -213,7 +199,6 @@ func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
 // CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
 
 func @copy_transpose(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
   linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
@@ -223,24 +208,23 @@ func @copy_transpose(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 // Transpose output
-// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
-// Call external copy
-// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// Call external copy after promoting input and output structs to pointers
+// CHECK-COUNT-2: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
+// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()