BitVector is becoming widespread enough that we should add a proper using.
Differential Revision: https://reviews.llvm.org/D118290
/// tensor expression.
struct LatPoint {
LatPoint(unsigned n, unsigned e, unsigned b);
- LatPoint(const llvm::BitVector &b, unsigned e);
+ LatPoint(const BitVector &b, unsigned e);
/// Conjunction of tensor loop indices as bitvector. This represents
/// all indices involved in the tensor expression
- llvm::BitVector bits;
+ BitVector bits;
/// Simplified conjunction of tensor loop indices as bitvector. This
/// represents a simplified condition under which this tensor expression
/// must execute. Pre-computed during codegen to avoid repeated eval.
- llvm::BitVector simple;
+ BitVector simple;
/// Index of the tensor expression.
unsigned exp;
/// within the given set using just two basic rules:
/// (1) multiple dense conditions are reduced to single dense, and
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
- llvm::BitVector simplifyCond(unsigned s0, unsigned p0);
+ BitVector simplifyCond(unsigned s0, unsigned p0);
/// Returns true if Li > Lj.
bool latGT(unsigned i, unsigned j) const;
}
/// Returns true if any set bit corresponds to queried dim.
- bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
+ bool hasAnyDimOf(const BitVector &bits, Dim d) const;
/// Returns true if given tensor iterates *only* in the given tensor
/// expression. For the output tensor, this defines a "simply dynamic"
void dumpExp(unsigned e) const;
void dumpLat(unsigned p) const;
void dumpSet(unsigned s) const;
- void dumpBits(const llvm::BitVector &bits) const;
+ void dumpBits(const BitVector &bits) const;
#endif
/// Builds the iteration lattices in a bottom-up traversal given the remaining
void eraseArguments(ArrayRef<unsigned> argIndices);
/// Erases the arguments that have their corresponding bit set in
/// `eraseIndices` and removes them from the argument list.
- void eraseArguments(const llvm::BitVector &eraseIndices);
+ void eraseArguments(const BitVector &eraseIndices);
/// Erases arguments using the given predicate. If the predicate returns true,
/// that argument is erased.
void eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn);
TypeRange resultTypes);
/// Returns a new function type without the specified arguments and results.
- FunctionType getWithoutArgsAndResults(const llvm::BitVector &argIndices,
- const llvm::BitVector &resultIndices);
+ FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
+ const BitVector &resultIndices);
}];
}
unsigned originalNumResults, Type newType);
/// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(Operation *op, const llvm::BitVector &argIndices,
+void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
Type newType);
/// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(Operation *op, const llvm::BitVector &resultIndices,
+void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
Type newType);
/// Set a FunctionOpInterface operation's type signature.
/// Filters out any elements referenced by `indices`. If any types are removed,
/// `storage` is used to hold the new type list. Returns the new type list.
-TypeRange filterTypesOut(TypeRange types, const llvm::BitVector &indices,
+TypeRange filterTypesOut(TypeRange types, const BitVector &indices,
SmallVectorImpl<Type> &storage);
//===----------------------------------------------------------------------===//
/// Erase a single argument at `argIndex`.
void eraseArgument(unsigned argIndex) {
- llvm::BitVector argsToErase($_op.getNumArguments());
+ BitVector argsToErase($_op.getNumArguments());
argsToErase.set(argIndex);
eraseArguments(argsToErase);
}
/// Erases the arguments listed in `argIndices`.
- void eraseArguments(const llvm::BitVector &argIndices) {
+ void eraseArguments(const BitVector &argIndices) {
Type newType = $_op.getTypeWithoutArgs(argIndices);
function_interface_impl::eraseFunctionArguments(
this->getOperation(), argIndices, newType);
/// Erase a single result at `resultIndex`.
void eraseResult(unsigned resultIndex) {
- llvm::BitVector resultsToErase($_op.getNumResults());
+ BitVector resultsToErase($_op.getNumResults());
resultsToErase.set(resultIndex);
eraseResults(resultsToErase);
}
/// Erases the results listed in `resultIndices`.
- void eraseResults(const llvm::BitVector &resultIndices) {
+ void eraseResults(const BitVector &resultIndices) {
Type newType = $_op.getTypeWithoutResults(resultIndices);
function_interface_impl::eraseFunctionResults(
this->getOperation(), resultIndices, newType);
/// results. This is used to update the function's signature in the
/// `eraseArguments` and `eraseResults` methods.
Type getTypeWithoutArgsAndResults(
- const llvm::BitVector &argIndices, const llvm::BitVector &resultIndices) {
+ const BitVector &argIndices, const BitVector &resultIndices) {
SmallVector<Type> argStorage, resultStorage;
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
$_op.getArgumentTypes(), argIndices, argStorage);
$_op.getResultTypes(), resultIndices, resultStorage);
return $_op.cloneTypeWith(newArgTypes, newResultTypes);
}
- Type getTypeWithoutArgs(const llvm::BitVector &argIndices) {
+ Type getTypeWithoutArgs(const BitVector &argIndices) {
SmallVector<Type> argStorage;
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
$_op.getArgumentTypes(), argIndices, argStorage);
return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes());
}
- Type getTypeWithoutResults(const llvm::BitVector &resultIndices) {
+ Type getTypeWithoutResults(const BitVector &resultIndices) {
SmallVector<Type> resultStorage;
TypeRange newResultTypes = function_interface_impl::filterTypesOut(
$_op.getResultTypes(), resultIndices, resultStorage);
/// Erases the operands that have their corresponding bit set in
/// `eraseIndices` and removes them from the operand list.
- void eraseOperands(const llvm::BitVector &eraseIndices) {
+ void eraseOperands(const BitVector &eraseIndices) {
getOperandStorage().eraseOperands(eraseIndices);
}
/// Erase the operands held by the storage that have their corresponding bit
/// set in `eraseIndices`.
- void eraseOperands(const llvm::BitVector &eraseIndices);
+ void eraseOperands(const BitVector &eraseIndices);
/// Get the operation operands held by the storage.
MutableArrayRef<OpOperand> getOperands() { return {operandStorage, size()}; }
// Containers.
template <typename T> class ArrayRef;
+class BitVector;
namespace detail {
template <typename KeyT, typename ValueT> struct DenseMapPair;
} // namespace detail
//
// Containers.
using llvm::ArrayRef;
+using llvm::BitVector;
template <typename T, typename Enable = void>
using DenseMapInfo = llvm::DenseMapInfo<T, Enable>;
template <typename KeyT, typename ValueT,
// Collect information about the results will become appended arguments.
SmallVector<Type, 6> erasedResultTypes;
- llvm::BitVector erasedResultIndices(functionType.getNumResults());
+ BitVector erasedResultIndices(functionType.getNumResults());
for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
if (resultType.value().isa<BaseMemRefType>()) {
erasedResultIndices.set(resultType.index());
// case due to the output operand. For reductions, we need to check that after
// the fusion, each loop dimension has at least one input that defines it.
if ((consumer.getNumReductionLoops())) {
- llvm::BitVector coveredDims(consumer.getNumLoops(), false);
+ BitVector coveredDims(consumer.getNumLoops(), false);
auto addToCoveredDims = [&](AffineMap map) {
for (auto result : map.getResults())
SmallVectorImpl<int> &segments) {
// Check done[clause] to see if it has been parsed already
- llvm::BitVector done(ClauseType::COUNT, false);
+ BitVector done(ClauseType::COUNT, false);
// See pos[clause] to get position of clause in operand segments
SmallVector<int> pos(ClauseType::COUNT, -1);
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
linalg::GenericOp op, std::vector<unsigned> &topSort,
- unsigned at, llvm::BitVector &inits) {
+ unsigned at, BitVector &inits) {
bool needsUniv = false;
Location loc = op.getLoc();
unsigned idx = topSort[at];
static Operation *genFor(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
bool isOuter, bool isInner, unsigned idx,
- llvm::BitVector &indices) {
+ BitVector &indices) {
unsigned fb = indices.find_first();
unsigned tensor = merger.tensor(fb);
assert(idx == merger.index(fb));
static Operation *genWhile(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
unsigned idx, bool needsUniv,
- llvm::BitVector &indices) {
+ BitVector &indices) {
SmallVector<Type, 4> types;
SmallVector<Value, 4> operands;
// Construct the while-loop with a parameter for each index.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
std::vector<unsigned> &topSort, unsigned at,
- bool needsUniv, llvm::BitVector &indices) {
+ bool needsUniv, BitVector &indices) {
unsigned idx = topSort[at];
if (indices.count() == 1) {
bool isOuter = at == 0;
static void genLocals(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
std::vector<unsigned> &topSort, unsigned at,
- bool needsUniv, llvm::BitVector &locals) {
+ bool needsUniv, BitVector &locals) {
Location loc = op.getLoc();
unsigned idx = topSort[at];
static void genWhileInduction(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
unsigned idx, bool needsUniv,
- llvm::BitVector &induction,
+ BitVector &induction,
scf::WhileOp whileOp) {
Location loc = op.getLoc();
// Finalize each else branch of all if statements.
/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
- unsigned idx, llvm::BitVector &conditions) {
+ unsigned idx, BitVector &conditions) {
Location loc = op.getLoc();
SmallVector<Type, 4> types;
Value cond;
bits.set(b);
}
-LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
+LatPoint::LatPoint(const BitVector &b, unsigned e)
: bits(b), simple(), exp(e) {}
//===----------------------------------------------------------------------===//
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
unsigned p = latPoints.size();
- llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
+ BitVector nb = BitVector(latPoints[p0].bits);
nb |= latPoints[p1].bits;
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
latPoints.push_back(LatPoint(nb, e));
return s;
}
-llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
+BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
// First determine if this lattice point is a *singleton*, i.e.,
// the last point in a lattice, no other is less than this one.
bool isSingleton = true;
}
}
// Now apply the two basic rules.
- llvm::BitVector simple = latPoints[p0].bits;
+ BitVector simple = latPoints[p0].bits;
bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
for (unsigned b = 0, be = simple.size(); b < be; b++) {
if (simple[b] && !isDim(b, kSparse)) {
}
bool Merger::latGT(unsigned i, unsigned j) const {
- const llvm::BitVector &bitsi = latPoints[i].bits;
- const llvm::BitVector &bitsj = latPoints[j].bits;
+ const BitVector &bitsi = latPoints[i].bits;
+ const BitVector &bitsj = latPoints[j].bits;
assert(bitsi.size() == bitsj.size());
if (bitsi.count() > bitsj.count()) {
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
}
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
- llvm::BitVector tmp = latPoints[j].bits;
+ BitVector tmp = latPoints[j].bits;
tmp ^= latPoints[i].bits;
return !hasAnyDimOf(tmp, kSparse);
}
-bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
+bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
for (unsigned b = 0, be = bits.size(); b < be; b++)
if (bits[b] && isDim(b, d))
return true;
llvm::dbgs() << "}\n";
}
-void Merger::dumpBits(const llvm::BitVector &bits) const {
+void Merger::dumpBits(const BitVector &bits) const {
for (unsigned b = 0, be = bits.size(); b < be; b++) {
if (bits[b]) {
unsigned t = tensor(b);
}
void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
- llvm::BitVector eraseIndices(getNumArguments());
+ BitVector eraseIndices(getNumArguments());
for (unsigned i : argIndices)
eraseIndices.set(i);
eraseArguments(eraseIndices);
}
-void Block::eraseArguments(const llvm::BitVector &eraseIndices) {
+void Block::eraseArguments(const BitVector &eraseIndices) {
eraseArguments(
[&](BlockArgument arg) { return eraseIndices.test(arg.getArgNumber()); });
}
/// Returns a new function type without the specified arguments and results.
FunctionType
-FunctionType::getWithoutArgsAndResults(const llvm::BitVector &argIndices,
- const llvm::BitVector &resultIndices) {
+FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
+ const BitVector &resultIndices) {
SmallVector<Type> argStorage, resultStorage;
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
getInputs(), argIndices, argStorage);
}
void mlir::function_interface_impl::eraseFunctionArguments(
- Operation *op, const llvm::BitVector &argIndices, Type newType) {
+ Operation *op, const BitVector &argIndices, Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
}
void mlir::function_interface_impl::eraseFunctionResults(
- Operation *op, const llvm::BitVector &resultIndices, Type newType) {
+ Operation *op, const BitVector &resultIndices, Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
TypeRange
mlir::function_interface_impl::filterTypesOut(TypeRange types,
- const llvm::BitVector &indices,
+ const BitVector &indices,
SmallVectorImpl<Type> &storage) {
if (indices.none())
return types;
}
void detail::OperandStorage::eraseOperands(
- const llvm::BitVector &eraseIndices) {
+ const BitVector &eraseIndices) {
MutableArrayRef<OpOperand> operands = getOperands();
assert(eraseIndices.size() == operands.size());
auto module = getOperation();
for (FuncOp func : module.getOps<FuncOp>()) {
- llvm::BitVector indicesToErase(func.getNumArguments());
+ BitVector indicesToErase(func.getNumArguments());
for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
if (func.getArgAttr(argIndex, "test.erase_this_arg"))
indicesToErase.set(argIndex);
auto module = getOperation();
for (FuncOp func : module.getOps<FuncOp>()) {
- llvm::BitVector indicesToErase(func.getNumResults());
+ BitVector indicesToErase(func.getNumResults());
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
if (func.getResultAttr(resultIndex, "test.erase_this_result"))
indicesToErase.set(resultIndex);
const AttrOrTypeDef &def;
/// Seen attribute or type parameters.
- llvm::BitVector seenParams;
+ BitVector seenParams;
};
} // namespace
// Get a string containing all of the cases that can't be represented with a
// keyword.
- llvm::BitVector nonKeywordCases(cases.size());
+ BitVector nonKeywordCases(cases.size());
bool hasStrCase = false;
for (auto &it : llvm::enumerate(cases)) {
hasStrCase = it.value().isStrCase();
/// but there is no required ordering within groups.
bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
const std::shared_ptr<Pattern> &pattern,
- const llvm::BitVector &bits) {
+ const BitVector &bits) {
for (unsigned i = p; i < p + n; ++i) {
if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
compareBits(s, i, bits))
/// Wrapper over latPointWithinRange for readability of tests.
void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
const std::shared_ptr<Pattern> &pattern,
- const llvm::BitVector &bits) {
+ const BitVector &bits) {
EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
}
/// Wrapper over expectLatPointWithinRange for a single lat point.
void expectLatPoint(unsigned s, unsigned p,
const std::shared_ptr<Pattern> &pattern,
- const llvm::BitVector &bits) {
+ const BitVector &bits) {
EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
}
/// Converts a vector of (loop, tensor) pairs to a bitvector with the
/// corresponding bits set.
- llvm::BitVector
+ BitVector
loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
- llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false);
+ BitVector testBits = BitVector(numTensors + 1, false);
for (auto l : loops) {
auto loop = std::get<0>(l);
auto tensor = std::get<1>(l);
}
/// Returns true if the bits of lattice point p in set s match the given bits.
- bool compareBits(unsigned s, unsigned p, const llvm::BitVector &bits) {
+ bool compareBits(unsigned s, unsigned p, const BitVector &bits) {
return merger.lat(merger.set(s)[p]).bits == bits;
}
// Create an operation with operands to erase.
Operation *user =
createOp(&context, {operand2, operand1, operand2, operand1});
- llvm::BitVector eraseIndices(user->getNumOperands());
+ BitVector eraseIndices(user->getNumOperands());
// Check erasing no operands.
user->eraseOperands(eraseIndices);