Allow for the case where ShapedType is a MemRef in fixed point math kernel utils
authorGeoffrey Martin-Noble <gcmn@google.com>
Fri, 17 May 2019 21:32:25 +0000 (14:32 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:46:00 +0000 (13:46 -0700)
    MemRef may soon be a subclass of ShapedType.

--

PiperOrigin-RevId: 248788950

mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h

index 2fdac8e..26428cf 100644 (file)
@@ -166,7 +166,7 @@ struct QuantizedMultiplierSmallerThanOneExp {
   int exponent;
 };
 
-/// Casts an integer or floating point based type to a new element type.
+/// Casts an integer or floating point based shaped type to a new element type.
 inline Type castElementType(Type t, Type newElementType) {
   if (auto st = t.dyn_cast<ShapedType>()) {
     switch (st.getKind()) {
@@ -176,6 +176,9 @@ inline Type castElementType(Type t, Type newElementType) {
       return RankedTensorType::get(st.getShape(), newElementType);
     case StandardTypes::Kind::UnrankedTensor:
       return UnrankedTensorType::get(newElementType);
+    case StandardTypes::Kind::MemRef:
+      return MemRefType::get(st.getShape(), newElementType,
+                             st.cast<MemRefType>().getAffineMaps());
     }
   }
   assert(t.isIntOrFloat());
@@ -183,7 +186,7 @@ inline Type castElementType(Type t, Type newElementType) {
 }
 
 /// Creates an IntegerAttr with a type that matches the shape of 't' (which can
-/// be a primitive/vector/tensor).
+/// be a scalar primitive or a shaped type).
 inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
   if (auto st = t.dyn_cast<ShapedType>()) {
     assert(st.getElementType().isa<IntegerType>());
@@ -208,8 +211,8 @@ inline APFloat convertFloatToType(FloatType ft, APFloat value) {
   return value;
 }
 
-/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
-/// be a primitive/vector/tensor).
+/// Creates a FloatAttr with a type that matches the shape of 't' (which can be
+/// a scalar primitive or a shaped type).
 inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
   if (auto st = t.dyn_cast<ShapedType>()) {
     FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();