!targetMemRefType.hasStaticShape())
return matchFailure();
- Value *sourceMemRef = operands[0];
auto llvmSourceDescriptorTy =
- sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+ operands[0]->getType().dyn_cast<LLVM::LLVMType>();
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
return matchFailure();
+ MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
- Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
- Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
// Create descriptor.
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+ auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+ Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
- Value *allocated = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, allocated,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
- // Set ptr.
- Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAllocatedPtr(rewriter, loc, allocated);
+ // Set aligned ptr.
+ Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, zero,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ desc.setOffset(rewriter, loc, zero);
+
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
int64_t index = indexedSize.index();
auto sizeAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, size,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+ desc.setSize(rewriter, loc, index, size);
auto strideAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, stride,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+ desc.setStride(rewriter, loc, index, stride);
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};
}
// CHECK-LABEL: vector_type_cast
// CHECK: llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
-// CHECK: llvm.insertvalue %[[allocatedBit]], {{.*}}[0 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-// CHECK: %[[aligned:.*]] = llvm.extractvalue {{.*}}[1 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue %[[allocatedBit]], {{.*}}[0] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+// CHECK: %[[aligned:.*]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: %[[alignedBit:.*]] = llvm.bitcast %[[aligned]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
-// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
// CHECK: llvm.mlir.constant(0 : index
-// CHECK: llvm.insertvalue {{.*}}[2 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">