// domain context.
static bool getIterationDomainContext(const Statement *stmt,
IterationDomainContext *ctx) {
- // Walk up tree storing parent statements in 'loops'.
// TODO(andydavis) Extend this to gather enclosing IfStmts and consider
// factoring it out into a utility function.
SmallVector<ForStmt *, 4> loops;
- auto *currStmt = stmt->getParentStmt();
- while (currStmt != nullptr) {
- if (isa<IfStmt>(currStmt))
- return false;
- assert(isa<ForStmt>(currStmt));
- auto *forStmt = dyn_cast<ForStmt>(currStmt);
- loops.push_back(forStmt);
- currStmt = currStmt->getParentStmt();
- }
+ getLoopIVs(*stmt, &loops);
+
// Iterate through 'loops' from outer-most loop to inner-most loop.
// Populate 'values'.
ctx->values.reserve(loops.size());
- for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
- auto *forStmt = loops[i];
- // TODO(andydavis) Compose affine maps into lower/upper bounds of 'forStmt'
- // and add de-duped symbols to ctx.symbols.
- if (!forStmt->hasConstantBounds())
- return false;
- ctx->values.push_back(forStmt);
- ctx->numDims++;
- }
+ ctx->numDims += loops.size();
+ ctx->values.insert(ctx->values.end(), loops.begin(), loops.end());
+
// Resize flat affine constraint system based on num dims symbols found.
unsigned numDims = ctx->getNumDims();
unsigned numSymbols = ctx->getNumSymbols();