/// Performs constant folding `calculate` with element-wise behavior on the one
/// attributes in `operands` and returns the result if possible.
-template <class AttrElementT,
- class ElementValueT = typename AttrElementT::ValueType,
- class CalculationT = function_ref<ElementValueT(ElementValueT)>>
-Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
- const CalculationT &&calculate) {
+template <
+ class AttrElementT, class ElementValueT = typename AttrElementT::ValueType,
+ class CalculationT = function_ref<Optional<ElementValueT>(ElementValueT)>>
+Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
+ const CalculationT &&calculate) {
assert(operands.size() == 1 && "unary op takes one operands");
if (!operands[0])
return {};
if (operands[0].isa<AttrElementT>()) {
auto op = operands[0].cast<AttrElementT>();
- return AttrElementT::get(op.getType(), calculate(op.getValue()));
+ auto res = calculate(op.getValue());
+ if (!res)
+ return {};
+ return AttrElementT::get(op.getType(), *res);
}
if (operands[0].isa<SplatElementsAttr>()) {
// Both operands are splats so we can avoid expanding the values out and
auto op = operands[0].cast<SplatElementsAttr>();
auto elementResult = calculate(op.getSplatValue<ElementValueT>());
- return DenseElementsAttr::get(op.getType(), elementResult);
+ if (!elementResult)
+ return {};
+ return DenseElementsAttr::get(op.getType(), *elementResult);
} else if (operands[0].isa<ElementsAttr>()) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
// expanding the values.
auto opIt = op.value_begin<ElementValueT>();
SmallVector<ElementValueT> elementResults;
elementResults.reserve(op.getNumElements());
- for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt)
- elementResults.push_back(calculate(*opIt));
+ for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
+ auto elementResult = calculate(*opIt);
+ if (!elementResult)
+ return {};
+ elementResults.push_back(*elementResult);
+ }
return DenseElementsAttr::get(op.getType(), elementResults);
}
return {};
}
+template <class AttrElementT,
+ class ElementValueT = typename AttrElementT::ValueType,
+ class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+ const CalculationT &&calculate) {
+ return constFoldUnaryOpConditional<AttrElementT>(
+ operands,
+ [&](ElementValueT a) -> Optional<ElementValueT> { return calculate(a); });
+}
+
template <
class AttrElementT, class TargetAttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
//===----------------------------------------------------------------------===//
OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
- if (!constOperand)
- return {};
-
- auto attr = constOperand.dyn_cast<FloatAttr>();
- if (!attr)
- return {};
+ return constFoldUnaryOpConditional<FloatAttr>(
+ operands, [](const APFloat &a) -> Optional<APFloat> {
+ if (a.isNegative())
+ return {};
- auto ft = getType().cast<FloatType>();
+ if (a.getSizeInBits(a.getSemantics()) == 64)
+ return APFloat(log2(a.convertToDouble()));
- APFloat apf = attr.getValue();
+ if (a.getSizeInBits(a.getSemantics()) == 32)
+ return APFloat(log2f(a.convertToFloat()));
- if (apf.isNegative())
- return {};
-
- if (ft.getWidth() == 64)
- return FloatAttr::get(getType(), log2(apf.convertToDouble()));
-
- if (ft.getWidth() == 32)
- return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
-
- return {};
+ return {};
+ });
}
//===----------------------------------------------------------------------===//
return %r : f64
}
+// CHECK-LABEL: @log2_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 1.58496249, 2.000000e+00]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @log2_fold_vec() -> (vector<4xf32>) {
+ %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+ %0 = math.log2 %v1 : vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
// CHECK-LABEL: @powf_fold
// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
// CHECK: return %[[cst]]