From b939c015a4ad1f1d07f93d322e7dbe2feb0a13bc Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Fri, 30 Jun 2023 13:05:01 -0700 Subject: [PATCH] [mlir][sparse] add affine parsing to new surface syntax for STEA (1) uses the previously introduced API to reuse AffineExpr parser without code duplication (2) solves the look-ahead problem when parsing level spec Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D154254 --- .../Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp | 10 +- .../SparseTensor/IR/Detail/DimLvlMapParser.cpp | 142 ++++++++------------- .../SparseTensor/IR/Detail/DimLvlMapParser.h | 34 +++-- mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp | 12 ++ mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h | 48 +------ .../Dialect/SparseTensor/roundtrip_encoding.mlir | 67 +++++++++- 6 files changed, 159 insertions(+), 154 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp index b8eeb1a..d872587 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp @@ -241,7 +241,7 @@ void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const { if (!wantElision || !elideVar) os << var << " = "; os << expr; - os << ": \"" << toMLIRString(type) << "\""; + os << ": " << toMLIRString(type); } //===----------------------------------------------------------------------===// @@ -264,10 +264,10 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef dimSpecs, // Third, we set every `LvlSpec::elideVar` according to whether that // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr). 
VarSet usedVars(getRanks()); - for (const auto &dimSpec : dimSpecs) - // NOTE TO Wren: bypassed for empty - if (dimSpec.hasExpr() && !dimSpec.canElideExpr()) - usedVars.add(dimSpec.getExpr()); + // NOTE TO Wren: bypassed for now + // for (const auto &dimSpec : dimSpecs) + // if (!dimSpec.canElideExpr()) + // usedVars.add(dimSpec.getExpr()); for (auto &lvlSpec : this->lvlSpecs) lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar())); } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp index 6ff72c1..361425d 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp @@ -1,8 +1,4 @@ //===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===// -// These two lookup methods are probably small enough to benefit from -// being defined inline/in-class, expecially since doing so may allow the -// compiler to optimize the `std::optional` away. But we put the defns -// here until benchmarks prove the benefit of doing otherwise. // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -16,32 +12,12 @@ using namespace mlir; using namespace mlir::sparse_tensor; using namespace mlir::sparse_tensor::ir_detail; -//===----------------------------------------------------------------------===// -// TODO(wrengr): rephrase these to do the trick for gobbling up any trailing -// semicolon -// -// NOTE: There's no way for `FAILURE_IF_FAILED` to simultaneously support -// both `OptionalParseResult` and `InFlightDiagnostic` return types. -// We can get the compiler to accept the code if we returned "`{}`", -// however for `OptionalParseResult` that would become the nullopt result, -// whereas for `InFlightDiagnostic` it would become a result that can -// be implicitly converted to success. 
By using "`failure()`" we ensure -// that `OptionalParseResult` behaves as intended, however that means the -// macro cannot be used for `InFlightDiagnostic` since there's no implicit -// conversion. #define FAILURE_IF_FAILED(STMT) \ if (failed(STMT)) { \ return failure(); \ } -// Although `ERROR_IF` is phrased to return `InFlightDiagnostic`, that type -// can be implicitly converted to all four of `LogicalResult, `FailureOr`, -// `ParseResult`, and `OptionalParseResult`. (However, beware that the -// conversion to `OptionalParseResult` doesn't properly delegate to -// `InFlightDiagnostic::operator ParseResult`.) -// // NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope. -// NOTE_TO_SELF(wrengr): The LOC used to always be `parser.getNameLoc()` #define ERROR_IF(COND, MSG) \ if (COND) { \ return parser.emitError(loc, MSG); \ @@ -107,11 +83,8 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional, FailureOr DimLvlMapParser::parseVarUsage(VarKind vk) { VarInfo::ID varID; bool didCreate; - // We use the policy `May` because we want to allow parsing free/unbound - // variables. If we wanted to distinguish between parsing free-var uses - // vs bound-var uses, then the latter should use `MustNot`. - const auto res = - parseVar(vk, /*isOptional=*/false, CreationPolicy::May, varID, didCreate); + const auto res = parseVar(vk, /*isOptional=*/false, CreationPolicy::MustNot, + varID, didCreate); if (!res.has_value() || failed(*res)) return failure(); return varID; @@ -126,9 +99,19 @@ DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) { if (res.has_value()) { FAILURE_IF_FAILED(*res) return std::make_pair(env.bindVar(id), true); - } else { - return std::make_pair(env.bindUnusedVar(vk), false); } + return std::make_pair(env.bindUnusedVar(vk), false); +} + +FailureOr DimLvlMapParser::parseLvlVarBinding(bool directAffine) { + // Nothing to parse, create a new lvl var right away. 
+ if (directAffine) + return env.bindUnusedVar(VarKind::Level).cast(); + // Parse a lvl var, always pulling from the existing pool. + const auto use = parseVarUsage(VarKind::Level); + FAILURE_IF_FAILED(use) + FAILURE_IF_FAILED(parser.parseEqual()) + return env.toVar(*use); } //===----------------------------------------------------------------------===// @@ -136,7 +119,10 @@ DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) { //===----------------------------------------------------------------------===// FailureOr DimLvlMapParser::parseDimLvlMap() { - FAILURE_IF_FAILED(parseOptionalSymbolIdList()) + FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Symbol, + OpAsmParser::Delimiter::OptionalSquare)) + FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Level, + OpAsmParser::Delimiter::OptionalBraces)) FAILURE_IF_FAILED(parseDimSpecList()) FAILURE_IF_FAILED(parser.parseArrow()) FAILURE_IF_FAILED(parseLvlSpecList()) @@ -148,19 +134,14 @@ FailureOr DimLvlMapParser::parseDimLvlMap() { return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs); } -using Delimiter = mlir::OpAsmParser::Delimiter; - -ParseResult DimLvlMapParser::parseOptionalSymbolIdList() { - const auto parseSymVarBinding = [&]() -> ParseResult { - return ParseResult(parseVarBinding(VarKind::Symbol, /*isOptional=*/false)); +ParseResult +DimLvlMapParser::parseOptionalIdList(VarKind vk, + OpAsmParser::Delimiter delimiter) { + const auto parseIdBinding = [&]() -> ParseResult { + return ParseResult(parseVarBinding(vk, /*isOptional=*/false)); }; - // If I've correctly unpacked how exactly `Parser::parseCommaSeparatedList` - // handles the "optional" delimiters vs the non-optional ones, then - // the following call to `AsmParser::parseCommaSeparatedList` should - // be equivalent to the whole `AffineParse::parseOptionalSymbolIdList` - // method (which uses `Parser` methods to handle the optionality instead). 
- return parser.parseCommaSeparatedList(Delimiter::OptionalSquare, - parseSymVarBinding, " in symbol list"); + return parser.parseCommaSeparatedList(delimiter, parseIdBinding, + " in id list"); } //===----------------------------------------------------------------------===// @@ -169,7 +150,8 @@ ParseResult DimLvlMapParser::parseOptionalSymbolIdList() { ParseResult DimLvlMapParser::parseDimSpecList() { return parser.parseCommaSeparatedList( - Delimiter::Paren, [&]() -> ParseResult { return parseDimSpec(); }, + OpAsmParser::Delimiter::Paren, + [&]() -> ParseResult { return parseDimSpec(); }, " in dimension-specifier list"); } @@ -178,22 +160,17 @@ ParseResult DimLvlMapParser::parseDimSpec() { FAILURE_IF_FAILED(res) const DimVar var = res->first.cast(); - DimExpr expr{AffineExpr()}; + // Parse an optional dimension expression. + AffineExpr affine; if (succeeded(parser.parseOptionalEqual())) { - // FIXME(wrengr): I don't think there's any way to implement this - // without replicating the bulk of `AffineParser::parseAffineExpr` - // TODO(wrengr): Also, need to make sure the parser uses - // `parseVarUsage(VarKind::Level)` so that every `AffineDimExpr` - // necessarily corresponds to a `LvlVar` (never a `DimVar`). - // - // FIXME: proof of concept, parse trivial level vars (viz d0 = l0). - auto use = parseVarUsage(VarKind::Level); - FAILURE_IF_FAILED(use) - AffineExpr a = getAffineDimExpr(var.getNum(), parser.getContext()); - DimExpr dexpr{a}; - expr = dexpr; + // Parse the dim affine expr, with only any lvl-vars in scope. + SmallVector, 4> dimsAndSymbols; + env.addVars(dimsAndSymbols, VarKind::Level, parser.getContext()); + FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine)) } + DimExpr expr{affine}; + // Parse an optional slice. 
SparseTensorDimSliceAttr slice; if (succeeded(parser.parseOptionalColon())) { const auto loc = parser.getCurrentLocation(); @@ -212,40 +189,29 @@ ParseResult DimLvlMapParser::parseDimSpec() { //===----------------------------------------------------------------------===// ParseResult DimLvlMapParser::parseLvlSpecList() { + // If no level variable is declared at this point, the following level + // specification consists of direct affine expressions only, as in: + // (d0, d1) -> (d0 : dense, d1 : compressed) + // Otherwise, we are looking for a leading lvl-var, as in: + // {l0, l1} ( d0 = l0, d1 = l1) -> ( l0 = d0 : dense, l1 = d1: compressed) + const bool directAffine = env.getRanks().getLvlRank() == 0; return parser.parseCommaSeparatedList( - Delimiter::Paren, [&]() -> ParseResult { return parseLvlSpec(); }, + mlir::OpAsmParser::Delimiter::Paren, + [&]() -> ParseResult { return parseLvlSpec(directAffine); }, " in level-specifier list"); } -ParseResult DimLvlMapParser::parseLvlSpec() { - // FIXME(wrengr): This implementation isn't actually going to work as-is, - // due to grammar ambiguity. That is, assuming the current token is indeed - // a variable, we don't yet know whether that variable is supposed to - // be a binding vs being a usage that's part of the following AffineExpr. - // We can only disambiguate that by peeking at the next token to see whether - // it's the equals symbol or not. - // - // FIXME: proof of concept, assume it is new (viz. l0 = d0). 
- const auto res = parseVarBinding(VarKind::Level, /*isOptional=*/true); - FAILURE_IF_FAILED(res) - if (res->second) { - FAILURE_IF_FAILED(parser.parseEqual()) - } - const LvlVar var = res->first.cast(); - - // FIXME(wrengr): I don't think there's any way to implement this - // without replicating the bulk of `AffineParser::parseAffineExpr` - // - // TODO(wrengr): Also, need to make sure the parser uses - // `parseVarUsage(VarKind::Dimension)` so that every `AffineDimExpr` - // necessarily corresponds to a `DimVar` (never a `LvlVar`). - // - // FIXME: proof of concept, parse trivial dim vars (viz l0 = d0). - auto use = parseVarUsage(VarKind::Dimension); - FAILURE_IF_FAILED(use) - AffineExpr a = - getAffineDimExpr(env.toVar(*use).getNum(), parser.getContext()); - LvlExpr expr{a}; +ParseResult DimLvlMapParser::parseLvlSpec(bool directAffine) { + auto res = parseLvlVarBinding(directAffine); + FAILURE_IF_FAILED(res); + LvlVar var = res->cast(); + + // Parse the lvl affine expr, with only the dim-vars in scope. + AffineExpr affine; + SmallVector, 4> dimsAndSymbols; + env.addVars(dimsAndSymbols, VarKind::Dimension, parser.getContext()); + FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine)) + LvlExpr expr{affine}; FAILURE_IF_FAILED(parser.parseColon()) diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h index a77879e..2cdf66b 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h @@ -16,14 +16,23 @@ namespace mlir { namespace sparse_tensor { namespace ir_detail { -//===----------------------------------------------------------------------===// -// NOTE(wrengr): The idea here was originally based on the -// "lib/AsmParser/AffineParser.cpp"-static class `AffineParser`. -// Unfortunately, we can't use that class directly since it's file-local. 
-// Even worse, both `mlir::detail::Parser` and `mlir::detail::ParserState` -// are also file-local classes. I've been attempting to convert things -// over to using `AsmParser` wherever possible, though it's not clear that -// that'll work... +/// +/// Parses the Sparse Tensor Encoding Attribute (STEA). +/// +/// General syntax is as follows, +/// +/// [s0, ...] // optional forward decl sym-vars +/// {l0, ...} // optional forward decl lvl-vars +/// ( +/// d0 = ..., // dim-var = dim-exp +/// ... +/// ) -> ( +/// l0 = ..., // lvl-var = lvl-exp +/// ... +/// ) +/// +/// with simplifications when variables are implicit. +/// class DimLvlMapParser final { public: explicit DimLvlMapParser(AsmParser &parser) : parser(parser) {} @@ -33,18 +42,17 @@ public: FailureOr parseDimLvlMap(); private: - // TODO(wrengr): rather than using `OptionalParseResult` and two - // out-parameters, should we define a type to encapsulate all that? OptionalParseResult parseVar(VarKind vk, bool isOptional, CreationPolicy creationPolicy, VarInfo::ID &id, bool &didCreate); FailureOr parseVarUsage(VarKind vk); FailureOr> parseVarBinding(VarKind vk, bool isOptional); + FailureOr parseLvlVarBinding(bool directAffine); - ParseResult parseOptionalSymbolIdList(); + ParseResult parseOptionalIdList(VarKind vk, OpAsmParser::Delimiter delimiter); ParseResult parseDimSpec(); ParseResult parseDimSpecList(); - ParseResult parseLvlSpec(); + ParseResult parseLvlSpec(bool directAffine); ParseResult parseLvlSpecList(); AsmParser &parser; @@ -54,8 +62,6 @@ private: SmallVector lvlSpecs; }; -//===----------------------------------------------------------------------===// - } // namespace ir_detail } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index 5cc9c22..35022d7 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -289,4 +289,16 @@ 
InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const { return {}; } +void VarEnv::addVars( + SmallVectorImpl> &dimsAndSymbols, + VarKind vk, MLIRContext *context) const { + for (const auto &var : vars) { + if (var.getKind() == vk) { + assert(var.hasNum()); + dimsAndSymbols.push_back(std::make_pair( + var.getName(), getAffineDimExpr(*var.getNum(), context))); + } + } +} + //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h index e3b1038..0c9a2cf 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h @@ -162,11 +162,6 @@ public: }; static_assert(IsZeroCostAbstraction); -// TODO(wrengr): I'd like to give the ctors the types `DimVar(Dimension)` -// and `LvlVar(Level)`, instead of their current types using `Num`; -// however, that'd require importing "IR/SparseTensor.h" which nothing else -// in this file requires. Also beware the issues about implicit-conversion -// from `uint64_t` to `Num`. class DimVar final : public Var { public: static constexpr VarKind Kind = VarKind::Dimension; @@ -189,32 +184,6 @@ public: }; static_assert(IsZeroCostAbstraction); -// FIXME(wrengr): In order to get the `llvm::{isa,cast,dyn_cast}` -// free-functions to work (instead of using our hand-rolled methods), -// we'll need to define something like this: -// ``` -// namespace llvm { -// template struct CastInfo : OptionalValueCast {}; -// template <> struct ValueIsPresent { -// using UnwrappedType = Var; -// static inline bool isPresent(Var const&) { return true; } -// }; -// } // namespace llvm -// ``` -// The above will enable the type `llvm::dyn_cast(Var) -> std::optional`. 
-// -// FIXME(wrengr): The default `OptionalValueCast::doCast(Var const&)` -// implementation uses the expression "`U(var)`", which means that all the -// subclasses will need to define that upcasting-copy-ctor, and to ensure -// safety/correctness will need to mark that ctor as private/protected, -// which in turn means they'll need make the `CastInfo`/`OptionalValueCast` -// classes friends. -// -// We run into similar issues with our hand-rolled methods, the only -// difference is that the upcasting-copy-ctor would have type `U(Impl)` -// instead of `U(Var)` and that we'd need to make the `Var` class a friend -// rather than the `CastInfo`/`OptionalValueCast` classes. -// template constexpr bool Var::isa() const { if constexpr (std::is_same_v) @@ -257,8 +226,6 @@ class Ranks final { } public: - // NOTE_TO_SELF(wrengr): According to - // we should be able to do this just fine, even though `constexpr` constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank) : impl() { impl[to_index(VarKind::Symbol)] = symRank; @@ -305,16 +272,6 @@ public: }; //===----------------------------------------------------------------------===// -// TODO(wrengr): For good error messages we'll need to define something like: -// ```class LocatedVar final { llvm::SMLoc loc; VarInfo::ID id; };``` -// to be the actual thing occuring in our variant of AffineExpr. -// Though we may also want that struct to contain a pointer back to the -// `VarEnv` which contains the `VarInfo` for that `VarInfo::ID`. -// -// To go along with this, the `VarInfo` record should drop its own `SMLoc` -// field. - -//===----------------------------------------------------------------------===// /// A record of metadata for/about a variable, used by `VarEnv`. 
/// The principal goal of this record is to enable `VarEnv` to be used for /// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to @@ -456,6 +413,11 @@ public: InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const; Ranks getRanks() const { return Ranks(nextNum); } + + /// Adds all variables of given kind to the vector. + void + addVars(SmallVectorImpl> &dimsAndSymbols, + VarKind vk, MLIRContext *context) const; }; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir index e9efbd0..5a7c5a6 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir @@ -129,19 +129,78 @@ func.func private @sparse_slice(tensor) // CHECK-SAME: tensor> func.func private @sparse_slice(tensor) -// ----- +/////////////////////////////////////////////////////////////////////////////// // Migration plan for new STEA surface syntax, // use the NEW_SYNTAX on selected examples // and then TODO: remove when fully migrated +/////////////////////////////////////////////////////////////////////////////// + +// ----- + +#CSR_implicit = #sparse_tensor.encoding<{ + NEW_SYNTAX = + (d0, d1) -> (d0 : dense, d1 : compressed) +}> + +// CHECK-LABEL: func private @foo( +// CHECK-SAME: tensor> +func.func private @foo(%arg0: tensor) { + return +} + +// ----- -#NewSurfaceSyntax = #sparse_tensor.encoding<{ +#CSR_explicit = #sparse_tensor.encoding<{ NEW_SYNTAX = - (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed) + {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed) }> // CHECK-LABEL: func private @foo( // CHECK-SAME: tensor> -func.func private @foo(%arg0: tensor) { +func.func private @foo(%arg0: tensor) { + return +} + +// ----- + +#BCSR_implicit = #sparse_tensor.encoding<{ + NEW_SYNTAX = + ( i, j ) -> + ( i floordiv 2 : compressed, + j 
floordiv 3 : compressed, + i mod 2 : dense, + j mod 3 : dense + ) +}> + +// FIXME: should not have to use 4 dims ;-) +// +// CHECK-LABEL: func private @foo( +// CHECK-SAME: tensor> +func.func private @foo(%arg0: tensor) { + return +} + +// ----- + +#BCSR_explicit = #sparse_tensor.encoding<{ + NEW_SYNTAX = + {il, jl, ii, jj} + ( i = il * 2 + ii, + j = jl * 3 + jj + ) -> + ( il = i floordiv 2 : compressed, + jl = j floordiv 3 : compressed, + ii = i mod 2 : dense, + jj = j mod 3 : dense + ) +}> + +// FIXME: should not have to use 4 dims ;-) +// +// CHECK-LABEL: func private @foo( +// CHECK-SAME: tensor> +func.func private @foo(%arg0: tensor) { return } -- 2.7.4