using mul = ValueBuilder<mlir::LLVM::MulOp>;
using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
-using undef = ValueBuilder<mlir::LLVM::UndefOp>;
+using llvm_undef = ValueBuilder<mlir::LLVM::UndefOp>;
using urem = ValueBuilder<mlir::LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
if (t.isa<RangeType>())
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
- // A linalg.view type converts to a *pointer to* a view descriptor. The view
- // descriptor contains the pointer to the data buffer, followed by a 64-bit
- // integer containing the distance between the beginning of the buffer and the
- // first element to be accessed through the view, followed by two arrays, each
+ // A linalg.view type converts to a view descriptor. The view descriptor
+ // contains the pointer to the data buffer, followed by a 64-bit integer
+ // containing the distance between the beginning of the buffer and the first
+ // element to be accessed through the view, followed by two arrays, each
// containing as many 64-bit integers as the rank of the View. The first array
// represents the size, in number of original elements, of the view along the
// given dimension. When taking the view, the size is the difference between
// int64_t offset;
// int64_t sizes[Rank];
// int64_t strides[Rank];
- // } *;
+ // };
if (auto viewType = t.dyn_cast<ViewType>()) {
auto ptrTy = getPtrToElementType(viewType, lowering);
auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
- return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy)
- .getPointerTo();
+ return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
}
return Type();
/// Factor out the common information for all view conversions:
/// 1. common types in (standard and LLVM dialects)
/// 2. `pos` method
-/// 3. op of the FuncOp alloca'ed value and descriptor.
+/// 3. the view descriptor under construction, `desc`.
class BaseViewConversionHelper {
public:
BaseViewConversionHelper(Operation *op, ViewType viewType,
int64Ty(
lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
rewriter(rewriter) {
- IndexType indexType = rewriter.getIndexType();
- viewDescriptorPtrTy =
- convertLinalgType(viewType, lowering).cast<LLVMType>();
- OpBuilder::InsertionGuard insertGuard(rewriter);
- rewriter.setInsertionPointToStart(
- &op->getParentOfType<FuncOp>().getBlocks().front());
-
- edsc::ScopedContext context(rewriter, op->getLoc());
- Value *one = constant(int64Ty, IntegerAttr::get(indexType, 1));
- // Alloca with proper alignment.
- allocatedDesc = llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
- // Load the alloca'ed descriptor.
- desc = llvm_load(allocatedDesc);
+ viewDescriptorTy = convertLinalgType(viewType, lowering).cast<LLVMType>();
+ desc = rewriter.create<LLVM::UndefOp>(op->getLoc(), viewDescriptorTy);
}
ArrayAttr pos(ArrayRef<int> values) const {
return positionAttr(rewriter, values);
};
- LLVMType elementTy, int64Ty, viewDescriptorPtrTy;
+ LLVMType elementTy, int64Ty, viewDescriptorTy;
ConversionPatternRewriter &rewriter;
- Value *allocatedDesc, *desc;
+ Value *desc;
};
} // namespace
data = gep(voidPtrTy, allocated, offset);
}
data = bitcast(elementPtrType, data);
- Value *desc = undef(bufferDescriptorTy);
+ Value *desc = llvm_undef(bufferDescriptorTy);
desc = insertvalue(bufferDescriptorTy, desc, allocated,
positionAttr(rewriter, kBasePtrPosInBuffer));
desc = insertvalue(bufferDescriptorTy, desc, data,
auto pos = positionAttr(
rewriter, {kSizePosInView, static_cast<int>(dimOp.getIndex())});
linalg::DimOpOperandAdaptor adaptor(operands);
- Value *viewDescriptor = llvm_load(adaptor.view());
+ Value *viewDescriptor = adaptor.view();
rewriter.replaceOp(op, {extractvalue(indexTy, viewDescriptor, pos)});
return matchSuccess();
}
// current view indices. Use the base offset and strides stored in the view
// descriptor to emit IR iteratively computing the actual offset, followed by
// a getelementptr. This must be called under an edsc::ScopedContext.
- Value *obtainDataPtr(Operation *op, Value *viewDescriptorPtr,
+ Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter) const {
auto loadOp = cast<Op>(op);
// Linearize subscripts as:
// base_offset + SUM_i index_i * stride_i.
- Value *viewDescriptor = llvm_load(viewDescriptorPtr);
Value *base = extractvalue(elementTy, viewDescriptor, pos(kPtrPosInView));
Value *offset =
extractvalue(int64Ty, viewDescriptor, pos(kOffsetPosInView));
// Fill in an aggregate value of the descriptor.
RangeOpOperandAdaptor adaptor(operands);
- Value *desc = undef(rangeDescriptorTy);
+ Value *desc = llvm_undef(rangeDescriptorTy);
desc = insertvalue(desc, adaptor.min(), positionAttr(rewriter, 0));
desc = insertvalue(desc, adaptor.max(), positionAttr(rewriter, 1));
desc = insertvalue(desc, adaptor.step(), positionAttr(rewriter, 2));
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
SliceOpOperandAdaptor adaptor(operands);
+ Value *baseDesc = adaptor.view();
+
auto sliceOp = cast<SliceOp>(op);
- auto viewDescriptorPtrTy =
- convertLinalgType(sliceOp.getViewType(), lowering);
- auto viewType = sliceOp.getBaseViewType();
- auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+ BaseViewConversionHelper helper(op, sliceOp.getViewType(), rewriter,
+ lowering);
+ LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty,
+ viewDescriptorTy = helper.viewDescriptorTy;
+ Value *desc = helper.desc;
- // Helper function to create an integer array attribute out of a list of
- // values.
- auto pos = [&rewriter](ArrayRef<int> values) {
- return positionAttr(rewriter, values);
- };
+ auto viewType = sliceOp.getBaseViewType();
edsc::ScopedContext context(rewriter, op->getLoc());
- // Declare the view descriptor and insert data ptr *at the entry block of
- // the function*, which is the preferred location for LLVM's analyses.
- auto ip = rewriter.getInsertionPoint();
- auto ib = rewriter.getInsertionBlock();
- rewriter.setInsertionPointToStart(
- &op->getParentOfType<FuncOp>().getBlocks().front());
Value *zero =
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
- Value *one =
- constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
- // Alloca with proper alignment.
- Value *allocatedDesc =
- llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
- Value *desc = llvm_load(allocatedDesc);
- rewriter.setInsertionPoint(ib, ip);
-
- Value *baseDesc = llvm_load(adaptor.view());
-
- auto ptrPos = pos(kPtrPosInView);
- auto elementTy = getPtrToElementType(sliceOp.getViewType(), lowering);
+
+ auto ptrPos = helper.pos(kPtrPosInView);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value *, 4> strides(viewType.getRank());
for (int i = 0, e = viewType.getRank(); i < e; ++i) {
- strides[i] = extractvalue(int64Ty, baseDesc, pos({kStridePosInView, i}));
+ strides[i] =
+ extractvalue(int64Ty, baseDesc, helper.pos({kStridePosInView, i}));
}
// Compute and insert base offset.
- Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
+ Value *baseOffset =
+ extractvalue(int64Ty, baseDesc, helper.pos(kOffsetPosInView));
for (int i = 0, e = viewType.getRank(); i < e; ++i) {
Value *indexing = adaptor.indexings()[i];
Value *min = indexing;
if (sliceOp.indexing(i)->getType().isa<RangeType>())
- min = extractvalue(int64Ty, indexing, pos(0));
+ min = extractvalue(int64Ty, indexing, helper.pos(0));
baseOffset = add(baseOffset, mul(min, strides[i]));
}
- desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
+ desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));
// Compute and insert view sizes (max - min along the range) and strides.
// Skip the non-range operands as they will be projected away from the view.
if (indexing->getType().isa<RangeType>()) {
int rank = en.index();
Value *rangeDescriptor = adaptor.indexings()[rank];
- Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
- Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
- Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+ Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
+ Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
+ Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
Value *baseSize =
- extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
+ extractvalue(int64Ty, baseDesc, helper.pos({kSizePosInView, rank}));
// Bound upper by base view upper bound.
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
baseSize);
size =
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
Value *stride = mul(strides[rank], step);
- desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
- desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
+ desc =
+ insertvalue(desc, size, helper.pos({kSizePosInView, numNewDims}));
+ desc = insertvalue(desc, stride,
+ helper.pos({kStridePosInView, numNewDims}));
++numNewDims;
}
}
- // Store back in alloca'ed region.
- llvm_store(desc, allocatedDesc);
- rewriter.replaceOp(op, allocatedDesc);
+ rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
TransposeOpOperandAdaptor adaptor(operands);
+ Value *baseDesc = adaptor.view();
+
auto tranposeOp = cast<TransposeOp>(op);
+ // No permutation, early exit.
+ if (tranposeOp.permutation().isIdentity())
+ return rewriter.replaceOp(op, baseDesc), matchSuccess();
+
BaseViewConversionHelper helper(op, tranposeOp.getViewType(), rewriter,
lowering);
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
- Value *allocatedDesc = helper.allocatedDesc, *desc = helper.desc;
+ Value *desc = helper.desc;
edsc::ScopedContext context(rewriter, op->getLoc());
- // Load the descriptor of the view constructed by the helper.
- Value *baseDesc = llvm_load(adaptor.view());
-
// Copy the base pointer from the old descriptor to the new one.
ArrayAttr ptrPos = helper.pos(kPtrPosInView);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
ArrayAttr offPos = helper.pos(kOffsetPosInView);
desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
- if (tranposeOp.permutation().isIdentity()) {
- // No permutation, just store back in alloca'ed region.
- llvm_store(desc, allocatedDesc);
- return rewriter.replaceOp(op, allocatedDesc), matchSuccess();
- }
-
// Iterate over the dimensions and apply size/stride permutation.
for (auto en : llvm::enumerate(tranposeOp.permutation().getResults())) {
int sourcePos = en.index();
insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos}));
}
- // Store back in alloca'ed region.
- llvm_store(desc, allocatedDesc);
- rewriter.replaceOp(op, allocatedDesc);
+ rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
- auto viewOp = cast<ViewOp>(op);
ViewOpOperandAdaptor adaptor(operands);
- auto viewDescriptorPtrTy =
- convertLinalgType(viewOp.getViewType(), lowering);
- auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
- auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
- auto pos = [&rewriter](ArrayRef<int> values) {
- return positionAttr(rewriter, values);
- };
+ auto viewOp = cast<ViewOp>(op);
+ BaseViewConversionHelper helper(op, viewOp.getViewType(), rewriter,
+ lowering);
+ LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty,
+ viewDescriptorTy = helper.viewDescriptorTy;
+ Value *desc = helper.desc;
Value *bufferDescriptor = adaptor.buffer();
auto bufferTy = getPtrToElementType(
viewOp.buffer()->getType().cast<BufferType>(), lowering);
- // Declare the descriptor of the view.
edsc::ScopedContext context(rewriter, op->getLoc());
- auto ip = rewriter.getInsertionPoint();
- auto ib = rewriter.getInsertionBlock();
- rewriter.setInsertionPointToStart(
- &op->getParentOfType<FuncOp>().getBlocks().front());
- Value *one =
- constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
- // Alloca for proper alignment.
- Value *allocatedDesc =
- llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
- Value *desc = llvm_load(allocatedDesc);
- rewriter.setInsertionPoint(ib, ip);
// Copy the buffer pointer from the old descriptor to the new one.
Value *bufferAsViewElementType =
- bitcast(elementTy,
- extractvalue(bufferTy, bufferDescriptor, pos(kPtrPosInBuffer)));
- desc = insertvalue(desc, bufferAsViewElementType, pos(kPtrPosInView));
+ bitcast(elementTy, extractvalue(bufferTy, bufferDescriptor,
+ helper.pos(kPtrPosInBuffer)));
+ desc =
+ insertvalue(desc, bufferAsViewElementType, helper.pos(kPtrPosInView));
// Zero base offset.
auto indexTy = rewriter.getIndexType();
Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
- desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
+ desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));
// Compute and insert view sizes (max - min along the range).
int numRanges = llvm::size(viewOp.ranges());
for (int i = numRanges - 1; i >= 0; --i) {
// Update stride.
Value *rangeDescriptor = operands[1 + i];
- Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+ Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
Value *stride = mul(runningStride, step);
- desc = insertvalue(desc, stride, pos({kStridePosInView, i}));
+ desc = insertvalue(desc, stride, helper.pos({kStridePosInView, i}));
// Update size.
- Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
- Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+ Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
+ Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
Value *size = sub(max, min);
- desc = insertvalue(desc, size, pos({kSizePosInView, i}));
+ desc = insertvalue(desc, size, helper.pos({kSizePosInView, i}));
// Update stride for the next dimension.
if (i > 0)
runningStride = mul(runningStride, max);
}
- // Store back in alloca'ed region.
- llvm_store(desc, allocatedDesc);
- rewriter.replaceOp(op, allocatedDesc);
+ rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
+// Promote each operand's converted LLVM type from a struct type to a
+// pointer-to-struct type. Descriptors are passed by pointer across external
+// library call boundaries to avoid ABI issues related to C struct packing.
+static SmallVector<Type, 4>
+promoteStructTypes(Operation::operand_range operands,
+                   LLVMTypeConverter &lowering) {
+  SmallVector<Type, 4> res;
+  for (auto operand : operands) {
+    auto type = lowering.convertType(operand->getType()).cast<LLVM::LLVMType>();
+    // Only struct types are promoted; all other types pass through unchanged.
+    if (type.isStructTy())
+      res.push_back(type.getPointerTo());
+    else
+      res.push_back(type);
+  }
+  return res;
+}
+
+// Promote each struct-typed operand to a pointer to that struct by allocating
+// stack space and storing the struct value into it; non-struct operands are
+// forwarded as-is. Passing descriptors indirectly avoids ABI issues related
+// to C struct packing across external library call boundaries.
+static SmallVector<Value *, 4>
+promoteStructs(Location loc, ArrayRef<Value *> operands,
+               ConversionPatternRewriter &rewriter,
+               LLVMTypeConverter &lowering) {
+  auto *context = rewriter.getContext();
+  auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+  auto indexType = IndexType::get(context);
+  edsc::ScopedContext scope(rewriter, loc);
+  SmallVector<Value *, 4> promotedOperands;
+  promotedOperands.reserve(operands.size());
+  for (auto *operand : operands) {
+    auto type = operand->getType().cast<LLVM::LLVMType>();
+    if (!type.isStructTy()) {
+      promotedOperands.push_back(operand);
+      continue;
+    }
+    // Alloca with proper alignment. This is purely for solving ABI issues
+    // related to C struct packing across external library call boundaries. We
+    // do not expect optimizations of this alloca op and so we omit
+    // allocating at the entry block.
+    // Note: `type` is already an LLVMType, so no further cast is needed.
+    auto ptrType = type.getPointerTo();
+    Value *one = constant(int64Ty, IntegerAttr::get(indexType, 1));
+    Value *allocated = llvm_alloca(ptrType, one, /*alignment=*/8);
+    // Store the struct into the alloca'ed slot; the callee receives a pointer.
+    llvm_store(operand, allocated);
+    promotedOperands.push_back(allocated);
+  }
+  return promotedOperands;
+}
+
// Get function definition for the LinalgOp. If it doesn't exist, insert a
// definition.
template <typename LinalgOp>
}
// Get the Function type consistent with LLVM Lowering.
- SmallVector<Type, 4> inputTypes;
- for (auto operand : op->getOperands())
- inputTypes.push_back(lowering.convertType(operand->getType()));
+ // Structs are automatically promoted to pointer to struct in order to avoid
+ // ABI issues related to C struct packing that we don't want to handle here.
+ auto inputTypes = promoteStructTypes(op->getOperands(), lowering);
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
auto fAttr = rewriter.getSymbolRefAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
- ArrayRef<NamedAttribute>{named});
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, promoteStructs(op->getLoc(), operands, rewriter, lowering),
+ ArrayRef<NamedAttribute>{named});
return matchSuccess();
}
};
auto fAttr = rewriter.getSymbolRefAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
- ArrayRef<NamedAttribute>{named});
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, promoteStructs(op->getLoc(), operands, rewriter, lowering),
+ ArrayRef<NamedAttribute>{named});
return matchSuccess();
}
};
return
}
// CHECK-LABEL: func @view
-// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }">
// CHECK-NEXT: llvm.bitcast {{.*}} : !llvm<"float*"> to !llvm<"float*">
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
// CHECK-NEXT: llvm.return
func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
return
}
// CHECK-LABEL: func @view3d
-// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-NEXT: llvm.load {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
return
}
// CHECK-LABEL: func @slice
-// 1st load from view_op
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-// 2nd load from reloading the view descriptor pointer
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-// 3rd load from slice_op
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-// insert data ptr
+// insert ptr for view op
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// insert data ptr for slice op
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
return
}
-// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) {
-// CHECK: llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) -> ()
+// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
+// CHECK-NEXT: llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) -> ()
func @dim(%arg0: !linalg.view<?x?xf32>) {
%0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
return
}
-// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
func @subview(%arg0: !linalg.view<?x?xf32>) {
// CHECK-LABEL: func @subview
//
// Subview lowers to range + slice op
-// CHECK: llvm.alloca %{{.*}} x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
// CHECK: llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
// CHECK: llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
//
// Select occurs in slice op lowering
// CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
return
}
// CHECK-LABEL: func @view_with_range_and_index
-// top of the function alloca + load.
-// CHECK: llvm.alloca %{{.*}} x !llvm<"{ double*, i64, [1 x i64], [1 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ double*, i64, [1 x i64], [1 x i64] }*">
-// CHECK: llvm.load %{{.*}} : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }*">
-// loop-body load from descriptor ptr.
-// CHECK: llvm.load %{{.*}} : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }*">
+// loop-body.
+// CHECK: llvm.mlir.undef : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
return
}
// CHECK-LABEL: func @transpose
-// CHECK: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
func @copy_transpose(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
}
// CHECK-LABEL: func @copy
// Tranpose input
-// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// Transpose output
-// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
-// Call external copy
-// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// Call external copy after promoting input and output structs to pointers
+// CHECK-COUNT-2: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
+// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()