class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
detail::MemRefTypeStorage> {
public:
+ /// This is a builder type that keeps local references to arguments. Arguments
+ /// that are passed into the builder must out-live the builder.
+ class Builder {
+ public:
+ // Build from another MemRefType.
+ explicit Builder(MemRefType other)
+ : shape(other.getShape()), elementType(other.getElementType()),
+ affineMaps(other.getAffineMaps()),
+ memorySpace(other.getMemorySpace()) {}
+
+ // Build from scratch.
+ Builder(ArrayRef<int64_t> shape, Type elementType)
+ : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
+ }
+
+ Builder &setShape(ArrayRef<int64_t> newShape) {
+ shape = newShape;
+ return *this;
+ }
+
+ Builder &setElementType(Type newElementType) {
+ elementType = newElementType;
+ return *this;
+ }
+
+ Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
+ affineMaps = newAffineMaps;
+ return *this;
+ }
+
+ Builder &setMemorySpace(unsigned newMemorySpace) {
+ memorySpace = newMemorySpace;
+ return *this;
+ }
+
+ operator MemRefType() {
+ return MemRefType::get(shape, elementType, affineMaps, memorySpace);
+ }
+
+ private:
+ ArrayRef<int64_t> shape;
+ Type elementType;
+ ArrayRef<AffineMap> affineMaps;
+ unsigned memorySpace;
+ };
+
using Base::Base;
/// Get or create a new MemRefType based on shape, element type, affine
auto memref = type.dyn_cast<MemRefType>();
if (memref &&
memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
- type = MemRefType::get(memref.getShape(), memref.getElementType(),
- memref.getAffineMaps());
+ type = MemRefType::Builder(memref).setMemorySpace(0);
}
return LLVMTypeConverter::convertType(type);
case StandardTypes::Kind::UnrankedTensor:
return UnrankedTensorType::get(newElementType);
case StandardTypes::Kind::MemRef:
- return MemRefType::get(st.getShape(), newElementType,
- st.cast<MemRefType>().getAffineMaps());
+ return MemRefType::Builder(st.cast<MemRefType>())
+ .setElementType(newElementType);
}
}
assert(t.isIntOrFloat());
// Early-exit: if `type` is contiguous, the result must be contiguous.
if (canonicalizeStridedLayout(type).getAffineMaps().empty())
- return MemRefType::get(newSizes, type.getElementType(), {});
+ return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
// Convert back to int64_t because we don't have enough information to create
// new strided layouts from AffineExpr only. This corresponds to a case where
auto layout =
makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
return canonicalizeStridedLayout(
- MemRefType::get(newSizes, type.getElementType(), {layout}));
+ MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
}
/// Helper functions assert Attribute of the proper type in attr and returns the
unsigned rank = memRefType.getRank();
// TODO(ntv): propagate static size and stride information when available.
SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
- Type elementType = memRefType.getElementType();
- result.addTypes({MemRefType::get(
- sizes, elementType,
- {makeStridedLinearLayoutMap(strides, offset, b->getContext())},
- memRefType.getMemorySpace())});
+ result.addTypes({MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setAffineMaps(makeStridedLinearLayoutMap(
+ strides, offset, b->getContext()))});
}
static void print(OpAsmPrinter &p, SliceOp op) {
auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
map = permutationMap ? map.compose(permutationMap) : map;
// Compute result type.
- auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map,
- memRefType.getMemorySpace());
+ MemRefType resultType =
+ MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
build(b, result, resultType, view, attrs);
result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}
// Create new memref type (which will have fewer dynamic dimensions).
- auto newMemRefType = MemRefType::get(
- newShapeConstants, memrefType.getElementType(),
- memrefType.getAffineMaps(), memrefType.getMemorySpace());
+ MemRefType newMemRefType =
+ MemRefType::Builder(memrefType).setShape(newShapeConstants);
assert(static_cast<int64_t>(newOperands.size()) ==
newMemRefType.getNumDynamicDims());
rewriter.getContext());
// Create new memref type with constant folded dims and/or offset/strides.
- auto newMemRefType =
- MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
- memrefType.getMemorySpace());
+ MemRefType newMemRefType = MemRefType::Builder(memrefType)
+ .setShape(newShapeConstants)
+ .setAffineMaps({map});
(void)dynamicOffsetOperandCount; // unused in opt mode
assert(static_cast<int64_t>(newOperands.size()) ==
dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
auto rank = memRefType.getRank();
int64_t offset;
SmallVector<int64_t, 4> strides;
- Type elementType = memRefType.getElementType();
auto res = getStridesAndOffset(memRefType, strides, offset);
assert(succeeded(res) && "SubViewOp expected strided memref type");
(void)res;
auto stridedLayout =
makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
- return MemRefType::get(sizes, elementType, stridedLayout,
- memRefType.getMemorySpace());
+ return MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setAffineMaps(stridedLayout);
}
void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
assert(defOp);
staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
}
- MemRefType newMemRefType = MemRefType::get(
- staticShape, subViewType.getElementType(), subViewType.getAffineMaps(),
- subViewType.getMemorySpace());
+ MemRefType newMemRefType =
+ MemRefType::Builder(subViewType).setShape(staticShape);
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
AffineMap layoutMap = makeStridedLinearLayoutMap(
staticStrides, resultOffset, rewriter.getContext());
MemRefType newMemRefType =
- MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
- layoutMap, subViewType.getMemorySpace());
+ MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
AffineMap layoutMap = makeStridedLinearLayoutMap(
resultStrides, staticOffset, rewriter.getContext());
MemRefType newMemRefType =
- MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
- layoutMap, subViewType.getMemorySpace());
+ MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
subViewOp.sizes(), subViewOp.strides(), newMemRefType);
auto simplifiedLayoutExpr =
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
if (expr != simplifiedLayoutExpr)
- return MemRefType::get(t.getShape(), t.getElementType(),
- {AffineMap::get(m.getNumDims(), m.getNumSymbols(),
- {simplifiedLayoutExpr})});
-
- return MemRefType::get(t.getShape(), t.getElementType(), {});
+ return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
+ m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})});
+ return MemRefType::Builder(t).setAffineMaps({});
}
/// Return true if the layout for `t` is compatible with strided semantics.
SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
newShape[0] = 2;
std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
- auto newMemRefType =
- MemRefType::get(newShape, oldMemRefType.getElementType(), {},
- oldMemRefType.getMemorySpace());
- return newMemRefType;
+ return MemRefType::Builder(oldMemRefType)
+ .setShape(newShape)
+ .setAffineMaps({});
};
auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
auto oldMemRef = allocOp.getResult();
SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
- auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
- b.getMultiDimIdentityMap(newRank));
+ MemRefType newMemRefType =
+ MemRefType::Builder(memrefType)
+ .setShape(newShape)
+ .setAffineMaps(b.getMultiDimIdentityMap(newRank));
auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
// Replace all uses of the old memref.