int exponent;
};
-/// Casts an integer or floating point based type to a new element type.
+/// Casts an integer or floating point based shaped type to a new element type.
inline Type castElementType(Type t, Type newElementType) {
if (auto st = t.dyn_cast<ShapedType>()) {
switch (st.getKind()) {
return RankedTensorType::get(st.getShape(), newElementType);
case StandardTypes::Kind::UnrankedTensor:
return UnrankedTensorType::get(newElementType);
+ case StandardTypes::Kind::MemRef:
+ return MemRefType::get(st.getShape(), newElementType,
+ st.cast<MemRefType>().getAffineMaps());
}
}
assert(t.isIntOrFloat());
}
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
-/// be a primitive/vector/tensor).
+/// be a scalar primitive or a shaped type).
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
if (auto st = t.dyn_cast<ShapedType>()) {
assert(st.getElementType().isa<IntegerType>());
return value;
}
-/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
-/// be a primitive/vector/tensor).
+/// Creates a FloatAttr with a type that matches the shape of 't' (which can be
+/// a scalar primitive or a shaped type).
inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
if (auto st = t.dyn_cast<ShapedType>()) {
FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();