using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
-#define FAILURE_IF_FAILED(STMT) \
- if (failed(STMT)) { \
+#define FAILURE_IF_FAILED(RES) \
+ if (failed(RES)) { \
+ return failure(); \
+ }
+
+/// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
+/// its `RES` parameter.
+static inline bool didntSucceed(OptionalParseResult res) {
+ return !res.has_value() || failed(*res);
+}
+
+#define FAILURE_IF_NULLOPT_OR_FAILED(RES) \
+ if (didntSucceed(RES)) { \
return failure(); \
}
llvm_unreachable("unknown Policy");
}
-FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk) {
- VarInfo::ID varID;
+FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
+ bool requireKnown) {
+ VarInfo::ID id;
bool didCreate;
- const auto res =
- parseVar(vk, /*isOptional=*/false, Policy::MustNot, varID, didCreate);
- if (!res.has_value() || failed(*res))
- return failure();
- return varID;
+ const bool isOptional = false;
+ const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
+ const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
+ FAILURE_IF_NULLOPT_OR_FAILED(res)
+ assert(requireKnown ? !didCreate : true);
+ return id;
+}
+
+FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
+ bool requireKnown) {
+ const auto loc = parser.getCurrentLocation();
+ VarInfo::ID id;
+ bool didCreate;
+ const bool isOptional = false;
+ const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
+ const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
+ FAILURE_IF_NULLOPT_OR_FAILED(res)
+ assert(requireKnown ? !didCreate : didCreate);
+ bindVar(loc, id);
+ return id;
}
FailureOr<std::pair<Var, bool>>
-DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) {
+DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
+ const auto loc = parser.getCurrentLocation();
VarInfo::ID id;
bool didCreate;
- const auto res = parseVar(vk, isOptional, Policy::Must, id, didCreate);
+ const bool isOptional = true;
+ const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
+ const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
if (res.has_value()) {
FAILURE_IF_FAILED(*res)
- return std::make_pair(env.bindVar(id), true);
+ assert(didCreate);
+ return std::make_pair(bindVar(loc, id), true);
}
+ assert(!didCreate);
return std::make_pair(env.bindUnusedVar(vk), false);
}
-FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
- // Nothing to parse, create a new lvl var right away.
- if (directAffine)
- return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
- // Parse a lvl var, always pulling from the existing pool.
- const auto use = parseVarUsage(VarKind::Level);
- FAILURE_IF_FAILED(use)
- FAILURE_IF_FAILED(parser.parseEqual())
- return env.toVar(*use);
+Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
+ MLIRContext *context = parser.getContext();
+ const auto var = env.bindVar(id);
+ const auto &info = std::as_const(env).access(id);
+ const auto name = info.getName();
+ const auto num = *info.getNum();
+ switch (info.getKind()) {
+ case VarKind::Symbol: {
+ const auto affine = getAffineSymbolExpr(num, context);
+ dimsAndSymbols.emplace_back(name, affine);
+ lvlsAndSymbols.emplace_back(name, affine);
+ return var;
+ }
+ case VarKind::Dimension:
+ dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
+ return var;
+ case VarKind::Level:
+ lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
+ return var;
+ }
+ llvm_unreachable("unknown VarKind");
}
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
- FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Symbol,
- OpAsmParser::Delimiter::OptionalSquare))
- FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Level,
- OpAsmParser::Delimiter::OptionalBraces))
+ FAILURE_IF_FAILED(parseSymbolBindingList())
+ FAILURE_IF_FAILED(parseLvlVarBindingList())
FAILURE_IF_FAILED(parseDimSpecList())
FAILURE_IF_FAILED(parser.parseArrow())
FAILURE_IF_FAILED(parseLvlSpecList())
return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
}
-ParseResult
-DimLvlMapParser::parseOptionalIdList(VarKind vk,
- OpAsmParser::Delimiter delimiter) {
- const auto parseIdBinding = [&]() -> ParseResult {
- return ParseResult(parseVarBinding(vk, /*isOptional=*/false));
- };
- return parser.parseCommaSeparatedList(delimiter, parseIdBinding,
- " in id list");
+ParseResult DimLvlMapParser::parseSymbolBindingList() {
+ return parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::OptionalSquare,
+ [this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); },
+ " in symbol binding list");
+}
+
+// FIXME: The forward-declaration of level-vars is a stop-gap workaround
+// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
+// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the
+// variables must be bound before entering `AsmParser::parseAffineExpr`,
+// since that method requires every variable to already have a fixed/known
+// `Var::Num`.)
+//
+// However, the forward-declaration list duplicates information which is
+// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
+// the names of the variables themselves, and the order in which the names
+// are bound). This redundancy causes bad UX, and also means we must be
+// sure to verify consistency between the two sources of information.
+//
+// Therefore, it would be best to remove the forward-declaration list from
+// the syntax. This can be achieved by implementing our own version of
+// `AffineParser::parseAffineExpr` which calls
+// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
+// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
+// some `Var::Num` immediately). This would also enable us to use the `VarEnv`
+// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
+// side, and thus would also enable us to avoid the O(n^2) behavior of copying
+// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
+// every time `AsmParser::parseAffineExpr` is called.
+ParseResult DimLvlMapParser::parseLvlVarBindingList() {
+ return parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::OptionalBraces,
+ [this]() { return ParseResult(parseVarBinding(VarKind::Level)); },
+ " in level declaration list");
}
//===----------------------------------------------------------------------===//
ParseResult DimLvlMapParser::parseDimSpecList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::Paren,
- [&]() -> ParseResult { return parseDimSpec(); },
+ [this]() -> ParseResult { return parseDimSpec(); },
" in dimension-specifier list");
}
ParseResult DimLvlMapParser::parseDimSpec() {
- const auto res = parseVarBinding(VarKind::Dimension, /*isOptional=*/false);
- FAILURE_IF_FAILED(res)
- const DimVar var = res->first.cast<DimVar>();
+ // Parse the requisite dim-var binding.
+ const auto varID = parseVarBinding(VarKind::Dimension);
+ FAILURE_IF_FAILED(varID)
+ const DimVar var = env.getVar(*varID).cast<DimVar>();
// Parse an optional dimension expression.
AffineExpr affine;
if (succeeded(parser.parseOptionalEqual())) {
// Parse the dim affine expr, with only any lvl-vars in scope.
- SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
- env.addVars(dimsAndSymbols, VarKind::Level, parser.getContext());
- FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
+ // FIXME(wrengr): This still has the O(n^2) behavior of copying
+ // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
+ // field every time `parseDimSpec` is called.
+ FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
}
DimExpr expr{affine};
//===----------------------------------------------------------------------===//
ParseResult DimLvlMapParser::parseLvlSpecList() {
- // If no level variable is declared at this point, the following level
- // specification consists of direct affine expressions only, as in:
- // (d0, d1) -> (d0 : dense, d1 : compressed)
- // Otherwise, we are looking for a leading lvl-var, as in:
- // {l0, l1} ( d0 = l0, d1 = l1) -> ( l0 = d0 : dense, l1 = d1: compressed)
- const bool directAffine = env.getRanks().getLvlRank() == 0;
- return parser.parseCommaSeparatedList(
+ // This method currently only supports two syntaxes:
+ //
+ // (1) There are no forward-declarations, and no lvl-var bindings:
+ // (d0, d1) -> (d0 : dense, d1 : compressed)
+ // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
+ // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
+ // the level-rank is correct at the end of parsing.
+ //
+ // (2) There are forward-declarations, and every lvl-spec must have
+ // a lvl-var binding:
+ // {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+ // However, this introduces duplicate information since the order of
+ // the lvl-vars in `parseLvlVarBindingList` must agree with their order
+ // in the list of lvl-specs. Therefore, `parseLvlSpec` will not call
+ // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
+ // and must also validate the consistency between the two lvl-var orders.
+ const auto declaredLvlRank = env.getRanks().getLvlRank();
+ const bool requireLvlVarBinding = declaredLvlRank != 0;
+ // Have `ERROR_IF` point to the start of the list.
+ const auto loc = parser.getCurrentLocation();
+ const auto res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::Paren,
- [&]() -> ParseResult { return parseLvlSpec(directAffine); },
+ [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
" in level-specifier list");
+ FAILURE_IF_FAILED(res)
+ const auto specLvlRank = lvlSpecs.size();
+ ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
+ "Level-rank mismatch between forward-declarations and specifiers. "
+ "Declared " +
+ Twine(declaredLvlRank) + " level-variables; but got " +
+ Twine(specLvlRank) + " level-specifiers.")
+ return success();
+}
+
+static inline Twine nth(Var::Num n) {
+ switch (n) {
+ case 1:
+ return "1st";
+ case 2:
+ return "2nd";
+ default:
+ return Twine(n) + "th";
+ }
}
-ParseResult DimLvlMapParser::parseLvlSpec(bool directAffine) {
- auto res = parseLvlVarBinding(directAffine);
- FAILURE_IF_FAILED(res);
- LvlVar var = res->cast<LvlVar>();
+// NOTE: This is factored out as a separate method only because `Var`
+// lacks a default-ctor, which makes this conditional difficult to inline
+// at the one call-site.
+FailureOr<LvlVar>
+DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
+ // Nothing to parse, just bind an unnamed variable.
+ if (!requireLvlVarBinding)
+ return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
+
+ const auto loc = parser.getCurrentLocation();
+ // NOTE: Calling `parseVarUsage` here is semantically inappropriate,
+ // since the thing we're parsing is supposed to be a variable *binding*
+ // rather than a variable *use*. However, the call to `VarEnv::bindVar`
+ // (and its corresponding call to `DimLvlMapParser::recordVarBinding`)
+ // already occured in `parseLvlVarBindingList`, and therefore we must
+ // use `parseVarUsage` here in order to operationally do the right thing.
+ const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
+ FAILURE_IF_FAILED(varID)
+ const auto &info = std::as_const(env).access(*varID);
+ const auto var = info.getVar().cast<LvlVar>();
+ const auto forwardNum = var.getNum();
+ const auto specNum = lvlSpecs.size();
+ ERROR_IF(forwardNum != specNum,
+ "Level-variable ordering mismatch. The variable '" + info.getName() +
+ "' was forward-declared as the " + nth(forwardNum) +
+ " level; but is bound by the " + nth(specNum) +
+ " specification.")
+ FAILURE_IF_FAILED(parser.parseEqual())
+ return var;
+}
+
+ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
+ // Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding`
+ // specifies whether that "optional" is actually Must or MustNot.)
+ const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
+ FAILURE_IF_FAILED(varRes)
+ const LvlVar var = *varRes;
// Parse the lvl affine expr, with only the dim-vars in scope.
AffineExpr affine;
- SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
- env.addVars(dimsAndSymbols, VarKind::Dimension, parser.getContext());
+ // FIXME(wrengr): This still has the O(n^2) behavior of copying
+ // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
+ // field every time `parseLvlSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
LvlExpr expr{affine};
FAILURE_IF_FAILED(parser.parseColon())
-
const auto type = lvlTypeParser.parseLvlType(parser);
FAILURE_IF_FAILED(type)
FailureOr<DimLvlMap> parseDimLvlMap();
private:
+ /// The core code for parsing `Var`. This method abstracts out a lot
+ /// of complex details to avoid code duplication; however, client code
+ /// should prefer using `parseVarUsage` and `parseVarBinding` rather than
+ /// calling this method directly.
OptionalParseResult parseVar(VarKind vk, bool isOptional,
Policy creationPolicy, VarInfo::ID &id,
bool &didCreate);
- FailureOr<VarInfo::ID> parseVarUsage(VarKind vk);
- FailureOr<std::pair<Var, bool>> parseVarBinding(VarKind vk, bool isOptional);
- FailureOr<Var> parseLvlVarBinding(bool directAffine);
- ParseResult parseOptionalIdList(VarKind vk, OpAsmParser::Delimiter delimiter);
+ /// Parse a variable occurence which is a *use* of that variable.
+ /// The `requireKnown` parameter specifies how to handle the case of
+ /// encountering a valid variable name which is currently unused: when
+ /// `requireKnown=true`, an error is raised; when `requireKnown=false`,
+ /// a new unbound variable will be created.
+ ///
+ /// NOTE: Just because a variable is *known* (i.e., the name has been
+ /// associated with an `VarInfo::ID`), does not mean that the variable
+ /// is actually *in scope*.
+ FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
+
+ /// Parse a variable occurence which is a *binding* of that variable.
+ /// The `requireKnown` parameter is for handling the binding of
+ /// forward-declared variables.
+ FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
+
+ /// Parse an optional variable binding. When the next token is
+ /// not a valid variable name, this will bind a new unnamed variable.
+ /// The returned `bool` indicates whether a variable name was parsed.
+ FailureOr<std::pair<Var, bool>>
+ parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
+
+ /// Binds the given variable: both updating the `VarEnv` itself, and
+ /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
+ /// to `AsmParser::parseAffineExpr`). This method is already called by the
+ /// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
+ /// not need to be called elsewhere.
+ Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
+
+ ParseResult parseSymbolBindingList();
+ ParseResult parseLvlVarBindingList();
ParseResult parseDimSpec();
ParseResult parseDimSpecList();
- ParseResult parseLvlSpec(bool directAffine);
+ FailureOr<LvlVar> parseLvlVarBinding(bool requireLvlVarBinding);
+ ParseResult parseLvlSpec(bool requireLvlVarBinding);
ParseResult parseLvlSpecList();
AsmParser &parser;
LvlTypeParser lvlTypeParser;
VarEnv env;
+ // The parser maintains the `{dims,lvls}AndSymbols` lists to avoid
+ // the O(n^2) cost of repeatedly constructing them inside of the
+ // `parse{Dim,Lvl}Spec` methods.
+ SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+ SmallVector<std::pair<StringRef, AffineExpr>, 4> lvlsAndSymbols;
SmallVector<DimSpec> dimSpecs;
SmallVector<LvlSpec> lvlSpecs;
};