From 13017246815be7d4c3eb85fccffd2aef94663e34 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 17 May 2019 14:32:25 -0700 Subject: [PATCH] Allow for the case where ShapedType is a MemRef in fixed point math kernel utils MemRef may soon be a subclass of ShapedType. -- PiperOrigin-RevId: 248788950 --- mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index 2fdac8e..26428cf 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -166,7 +166,7 @@ struct QuantizedMultiplierSmallerThanOneExp { int exponent; }; -/// Casts an integer or floating point based type to a new element type. +/// Casts an integer or floating point based shaped type to a new element type. inline Type castElementType(Type t, Type newElementType) { if (auto st = t.dyn_cast()) { switch (st.getKind()) { @@ -176,6 +176,9 @@ inline Type castElementType(Type t, Type newElementType) { return RankedTensorType::get(st.getShape(), newElementType); case StandardTypes::Kind::UnrankedTensor: return UnrankedTensorType::get(newElementType); + case StandardTypes::Kind::MemRef: + return MemRefType::get(st.getShape(), newElementType, + st.cast().getAffineMaps()); } } assert(t.isIntOrFloat()); @@ -183,7 +186,7 @@ inline Type castElementType(Type t, Type newElementType) { } /// Creates an IntegerAttr with a type that matches the shape of 't' (which can -/// be a primitive/vector/tensor). +/// be a scalar primitive or a shaped type). inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) { if (auto st = t.dyn_cast()) { assert(st.getElementType().isa()); @@ -208,8 +211,8 @@ inline APFloat convertFloatToType(FloatType ft, APFloat value) { return value; } -/// Creates an IntegerAttr with a type that matches the shape of 't' (which can -/// be a primitive/vector/tensor). +/// Creates a FloatAttr with a type that matches the shape of 't' (which can be +/// a scalar primitive or a shaped type). inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) { if (auto st = t.dyn_cast()) { FloatType floatElementType = st.getElementType().dyn_cast(); -- 2.7.4