BaseMemRefType getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
- Attribute memorySpace = {});
+ unsigned memorySpace = 0);
/// Return a MemRef type with fully dynamic layout. If the given tensor type
/// is unranked, return an unranked MemRef type.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
- Attribute memorySpace = {});
+ unsigned memorySpace = 0);
/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
-BaseMemRefType
-getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
- Attribute memorySpace = {});
+BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
+ unsigned memorySpace = 0);
} // namespace bufferization
} // namespace mlir
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
- Attribute memorySpace) {
+ unsigned memorySpace) {
+ auto memorySpaceAttr = IntegerAttr::get(
+ IntegerType::get(tensorType.getContext(), 64), memorySpace);
+
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
assert(!layout && "UnrankedTensorType cannot have a layout map");
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
- memorySpace);
+ memorySpaceAttr);
}
// Case 2: Ranked memref type with specified layout.
if (layout) {
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
- memorySpace);
+ memorySpaceAttr);
}
// Case 3: Configured with "fully dynamic layout maps".
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
- Attribute memorySpace) {
+ unsigned memorySpace) {
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
}
// Case 2: Ranked memref type.
+ auto memorySpaceAttr = IntegerAttr::get(
+ IntegerType::get(tensorType.getContext(), 64), memorySpace);
auto rankedTensorType = tensorType.cast<RankedTensorType>();
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
dynamicStrides, dynamicOffset, rankedTensorType.getContext());
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), stridedLayout,
- memorySpace);
+ memorySpaceAttr);
}
/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
- Attribute memorySpace) {
+ unsigned memorySpace) {
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
// Case 2: Ranked memref type.
auto rankedTensorType = tensorType.cast<RankedTensorType>();
+ auto memorySpaceAttr = IntegerAttr::get(
+ IntegerType::get(tensorType.getContext(), 64), memorySpace);
MemRefLayoutAttrInterface layout = {};
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
- memorySpace);
+ memorySpaceAttr);
}
// The result buffer still has the old (pre-cast) type.
Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
- Attribute memorySpace = sourceMemRefType.getMemorySpace();
TensorType resultTensorType =
castOp.getResult().getType().cast<TensorType>();
MemRefLayoutAttrInterface layout;
// Compute the new memref type.
Type resultMemRefType =
- getMemRefType(resultTensorType, options, layout, memorySpace);
+ getMemRefType(resultTensorType, options, layout,
+ sourceMemRefType.getMemorySpaceAsInt());
// Replace the op with a memref.cast.
assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),