using TypePredicate = llvm::function_ref<bool(Type)>;
-// Returns vector width if the element type is matching the predicate (scalars
-// that do match the predicate have width equal to `1`).
-static Optional<int> vectorWidth(Type type, TypePredicate pred) {
- // If the type matches the predicate then its width is `1`.
+// Returns vector shape if the element type matches the predicate (scalars
+// that do match the predicate have shape equal to `{1}`).
+static Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
+ TypePredicate pred) {
+ // If the type matches the predicate then its shape is `{1}`.
if (pred(type))
- return 1;
+ return SmallVector<int64_t, 2>{1};
// Otherwise check if the type is a vector type.
auto vectorType = type.dyn_cast<VectorType>();
if (vectorType && pred(vectorType.getElementType())) {
- assert(vectorType.getRank() == 1 && "only 1d vectors are supported");
- return vectorType.getDimSize(0);
+ return llvm::to_vector<2>(vectorType.getShape());
}
return llvm::None;
}
-// Returns vector width of the type. If the type is a scalar returns `1`.
-static int vectorWidth(Type type) {
+// Returns vector shape of the type. If the type is a scalar returns `{1}`.
+static SmallVector<int64_t, 2> vectorShape(Type type) {
auto vectorType = type.dyn_cast<VectorType>();
- return vectorType ? vectorType.getDimSize(0) : 1;
+ return vectorType ? llvm::to_vector<2>(vectorType.getShape())
+ : SmallVector<int64_t, 2>{1};
}
// Returns vector element type. If the type is a scalar returns the argument.
// Broadcast scalar types and values into vector types and values.
//----------------------------------------------------------------------------//
-// Broadcasts scalar type into vector type (iff width is greater then 1).
-static Type broadcast(Type type, int width) {
+// Returns true if shape != {1}.
+static bool isNonScalarShape(ArrayRef<int64_t> shape) {
+ return shape.size() > 1 || shape[0] > 1;
+}
+
+// Broadcasts scalar type into vector type (iff shape is non-scalar).
+static Type broadcast(Type type, ArrayRef<int64_t> shape) {
assert(!type.isa<VectorType>() && "must be scalar type");
- return width > 1 ? VectorType::get({width}, type) : type;
+ return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
}
-// Broadcasts scalar value into vector (iff width is greater then 1).
-static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) {
+// Broadcasts scalar value into vector (iff shape is non-scalar).
+static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
+ ArrayRef<int64_t> shape) {
assert(!value.getType().isa<VectorType>() && "must be scalar value");
- auto type = broadcast(value.getType(), width);
- return width > 1 ? builder.create<BroadcastOp>(type, value) : value;
+ auto type = broadcast(value.getType(), shape);
+ return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value)
+ : value;
}
//----------------------------------------------------------------------------//
bool is_positive = false) {
assert(isF32(elementType(arg.getType())) && "argument must be f32 type");
- int width = vectorWidth(arg.getType());
+ auto shape = vectorShape(arg.getType());
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, width);
+ return broadcast(builder, value, shape);
};
auto i32 = builder.getIntegerType(32);
- auto i32Vec = broadcast(i32, width);
- auto f32Vec = broadcast(builder.getF32Type(), width);
+ auto i32Vec = broadcast(i32, shape);
+ auto f32Vec = broadcast(builder.getF32Type(), shape);
Value cst126f = f32Cst(builder, 126.0f);
Value cstHalf = f32Cst(builder, 0.5f);
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
- int width = vectorWidth(arg.getType());
+ auto shape = vectorShape(arg.getType());
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, width);
+ return broadcast(builder, value, shape);
};
- auto f32Vec = broadcast(builder.getF32Type(), width);
+ auto f32Vec = broadcast(builder.getF32Type(), shape);
// The exponent of f32 located at 23-bit.
auto exponetBitLocation = bcast(i32Cst(builder, 23));
// Set the exponent bias to zero.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
// Clamp operand into [plusClamp, minusClamp] range.
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
bool base2) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
Value cstZero = bcast(f32Cst(builder, 0.0f));
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
// Approximate log(1+x) using the following, due to W. Kahan:
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// TODO: Consider a common pattern rewriter with all methods below to
// write the approximations.
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
auto fmla = [&](Value a, Value b, Value c) {
return builder.create<math::FmaOp>(a, b, c);
Value expY = fmla(q1, y2, q0);
expY = fmla(q2, y4, expY);
- auto i32Vec = broadcast(builder.getI32Type(), *width);
+ auto i32Vec = broadcast(builder.getI32Type(), *shape);
// exp2(k)
Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
// expm1(x) = exp(x) - 1 = u - 1.
static_assert(
llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
"SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
auto mul = [&](Value a, Value b) -> Value {
return builder.create<arith::MulFOp>(a, b);
};
auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
- auto i32Vec = broadcast(builder.getI32Type(), *width);
+ auto i32Vec = broadcast(builder.getI32Type(), *shape);
auto fPToSingedInteger = [&](Value a) -> Value {
return builder.create<arith::FPToSIOp>(a, i32Vec);
};
LogicalResult
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
+ auto shape = vectorShape(op.operand().getType(), isF32);
// Only support already-vectorized rsqrt's.
- if (!width.hasValue() || *width != 8)
+ if (!shape.hasValue() || (*shape)[0] != 8)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));