/// memory.
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB);
+
+/// Return the result value of reducing two scalar/vector values with the
+/// corresponding arith operation.
+Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
+ Value v1, Value v2);
} // namespace vector
} // namespace mlir
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
-
-/// Return the result value of reducing two scalar/vector values with the
-/// corresponding arith operation.
-Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value v2);
} // namespace vector
/// Return the number of elements of basis, `0` if empty.
reductionOp.getVector(),
rewriter.getI64ArrayAttr(0));
- if (Value acc = reductionOp.getAcc()) {
- assert(reductionOp.getType().isa<FloatType>());
- switch (reductionOp.getKind()) {
- case CombiningKind::ADD:
- result = rewriter.create<arith::AddFOp>(loc, result, acc);
- break;
- case CombiningKind::MUL:
- result = rewriter.create<arith::MulFOp>(loc, result, acc);
- break;
- default:
- assert(false && "invalid op!");
- }
- }
+ if (Value acc = reductionOp.getAcc())
+ result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
+ result, acc);
rewriter.replaceOp(reductionOp, result);
return success();
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
+Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
+ CombiningKind kind, Value v1, Value v2) {
+ Type t1 = getElementTypeOrSelf(v1.getType());
+ Type t2 = getElementTypeOrSelf(v2.getType());
+ switch (kind) {
+ case CombiningKind::ADD:
+ if (t1.isIntOrIndex() && t2.isIntOrIndex())
+ return b.createOrFold<arith::AddIOp>(loc, v1, v2);
+ else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+ return b.createOrFold<arith::AddFOp>(loc, v1, v2);
+ llvm_unreachable("invalid value types for ADD reduction");
+ case CombiningKind::AND:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+ case CombiningKind::MAXF:
+ assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ "expected float values");
+ return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+ case CombiningKind::MINF:
+ assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ "expected float values");
+ return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+ case CombiningKind::MAXSI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+ case CombiningKind::MINSI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+ case CombiningKind::MAXUI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+ case CombiningKind::MINUI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+ case CombiningKind::MUL:
+ if (t1.isIntOrIndex() && t2.isIntOrIndex())
+ return b.createOrFold<arith::MulIOp>(loc, v1, v2);
+ else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+ return b.createOrFold<arith::MulFOp>(loc, v1, v2);
+ llvm_unreachable("invalid value types for MUL reduction");
+ case CombiningKind::OR:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+ case CombiningKind::XOR:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+ };
+ llvm_unreachable("unknown CombiningKind");
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
llvm_unreachable("Expected MemRefType or TensorType");
}
-Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
- CombiningKind kind, Value v1, Value v2) {
- Type t1 = getElementTypeOrSelf(v1.getType());
- Type t2 = getElementTypeOrSelf(v2.getType());
- switch (kind) {
- case CombiningKind::ADD:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::AddIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::AddFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for ADD reduction");
- case CombiningKind::AND:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::AndIOp>(loc, v1, v2);
- case CombiningKind::MAXF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
- "expected float values");
- return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
- case CombiningKind::MINF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
- "expected float values");
- return b.createOrFold<arith::MinFOp>(loc, v1, v2);
- case CombiningKind::MAXSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
- case CombiningKind::MINSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
- case CombiningKind::MAXUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
- case CombiningKind::MINUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
- case CombiningKind::MUL:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::MulIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::MulFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for MUL reduction");
- case CombiningKind::OR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::OrIOp>(loc, v1, v2);
- case CombiningKind::XOR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
- };
- llvm_unreachable("unknown CombiningKind");
-}
-
/// Return the number of elements of basis, `0` if empty.
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
if (basis.empty())
// -----
+// CHECK-LABEL: func @reduce_one_element_vector_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
+ %s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @bitcast(
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
// CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>