static StringRef getOperationName() { return "affine.terminator"; }
};
+/// The "affine.load" op reads an element from a memref, where the index
+/// for each memref dimension is an affine expression of loop induction
+/// variables and symbols. The output of 'affine.load' is a new value with the
+/// same type as the elements of the memref. An affine expression of loop IVs
+/// and symbols must be specified for each dimension of the memref. The keyword
+/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
+//
+// Example 1:
+//
+// %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+//
+// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
+//
+// %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+// : memref<100x100xf32>
+//
+class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult,
+ OpTrait::AtLeastNOperands<1>::Impl> {
+public:
+ using Op::Op;
+
+ /// Builds an affine load op with the specified map and operands.
+ static void build(Builder *builder, OperationState *result, AffineMap map,
+ ArrayRef<Value *> operands);
+
+ /// Get memref operand.
+ Value *getMemRef() { return getOperand(0); }
+ void setMemRef(Value *value) { setOperand(0, value); }
+ MemRefType getMemRefType() {
+ return getMemRef()->getType().cast<MemRefType>();
+ }
+
+ /// Get affine map operands.
+ operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
+
+ /// Returns the affine map used to index the memref for this operation.
+ AffineMap getAffineMap() {
+ return getAttrOfType<AffineMapAttr>("map").getValue();
+ }
+
+ static StringRef getOperationName() { return "affine.load"; }
+
+ // Hooks to customize behavior of this op.
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+ LogicalResult verify();
+};
+
+/// The "affine.store" op writes an element to a memref, where the index
+/// for each memref dimension is an affine expression of loop induction
+/// variables and symbols. The 'affine.store' op stores a new value which is the
+/// same type as the elements of the memref. An affine expression of loop IVs
+/// and symbols must be specified for each dimension of the memref. The keyword
+/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
+//
+// Example 1:
+//
+// affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+//
+// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
+//
+// affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+// : memref<100x100xf32>
+//
+class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
+ OpTrait::AtLeastNOperands<1>::Impl> {
+public:
+ using Op::Op;
+
+ /// Builds an affine store operation with the specified map and operands.
+ static void build(Builder *builder, OperationState *result,
+ Value *valueToStore, AffineMap map,
+ ArrayRef<Value *> operands);
+
+ /// Get value to be stored by store operation.
+ Value *getValueToStore() { return getOperand(0); }
+
+ /// Get memref operand.
+ Value *getMemRef() { return getOperand(1); }
+ void setMemRef(Value *value) { setOperand(1, value); }
+
+ MemRefType getMemRefType() {
+ return getMemRef()->getType().cast<MemRefType>();
+ }
+
+ /// Get affine map operands.
+ operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
+
+ /// Returns the affine map used to index the memref for this operation.
+ AffineMap getAffineMap() {
+ return getAttrOfType<AffineMapAttr>("map").getValue();
+ }
+
+ static StringRef getOperationName() { return "affine.store"; }
+
+ // Hooks to customize behavior of this op.
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+ LogicalResult verify();
+};
+
/// Returns true if the given Value can be used as a dimension id.
bool isValidDim(Value *value);
virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
bool printBlockTerminators = true) = 0;
+ /// Prints an affine map of SSA ids, where SSA id names are used in place
+ /// of dims/symbols.
+ virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
+ ArrayRef<Value *> operands) = 0;
+
/// Print an optional arrow followed by a type list.
void printOptionalArrowTypeList(ArrayRef<Type> types) {
if (types.empty())
/// Parse a `)` token if present.
virtual ParseResult parseOptionalRParen() = 0;
+ /// Parse a `[` token.
+ virtual ParseResult parseLSquare() = 0;
+
+ /// Parse a `]` token.
+ virtual ParseResult parseRSquare() = 0;
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
return success();
}
+ /// Parses an affine map attribute where dims and symbols are SSA operands.
+ virtual ParseResult
+ parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
+ StringRef attrName,
+ SmallVectorImpl<NamedAttribute> &attrs) = 0;
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
- addOperations<AffineApplyOp, AffineForOp, AffineIfOp, AffineTerminatorOp>();
+ addOperations<AffineApplyOp, AffineForOp, AffineIfOp, AffineLoadOp,
+ AffineStoreOp, AffineTerminatorOp>();
}
/// A utility function to check if a value is defined at the top level of a
/// Returns the list of 'else' blocks.
Region &AffineIfOp::getElseBlocks() { return getOperation()->getRegion(1); }
+
+//===----------------------------------------------------------------------===//
+// AffineLoadOp
+//===----------------------------------------------------------------------===//
+
+void AffineLoadOp::build(Builder *builder, OperationState *result,
+ AffineMap map, ArrayRef<Value *> operands) {
+ // TODO(b/133776335) Check that map operands are loop IVs or symbols.
+ result->addOperands(operands);
+ result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ auto affineIntTy = builder.getIndexType();
+
+ MemRefType type;
+ OpAsmParser::OperandType memrefInfo;
+ AffineMapAttr mapAttr;
+ SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+ return failure(
+ parser->parseOperand(memrefInfo) || parser->parseLSquare() ||
+ parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map",
+ result->attributes) ||
+ parser->parseRSquare() || parser->parseColonType(type) ||
+ parser->resolveOperand(memrefInfo, type, result->operands) ||
+ parser->resolveOperands(mapOperands, affineIntTy, result->operands) ||
+ parser->addTypeToList(type.getElementType(), result->types));
+}
+
+void AffineLoadOp::print(OpAsmPrinter *p) {
+ *p << "affine.load " << *getMemRef() << '[';
+ AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>("map");
+ SmallVector<Value *, 2> operands(getIndices());
+ p->printAffineMapOfSSAIds(mapAttr, operands);
+ *p << "] : " << getMemRefType();
+}
+
+LogicalResult AffineLoadOp::verify() {
+ if (getType() != getMemRefType().getElementType())
+ return emitOpError("result type must match element type of memref");
+
+ AffineMap map = getAttrOfType<AffineMapAttr>("map").getValue();
+ if (map.getNumResults() != getMemRefType().getRank())
+ return emitOpError("affine.load affine map num results must equal memref "
+ "rank");
+
+ for (auto *idx : getIndices())
+ if (!idx->getType().isIndex())
+ return emitOpError("index to load must have 'index' type");
+ // TODO(b/133776335) Verify that map operands are loop IVs or symbols.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineStoreOp
+//===----------------------------------------------------------------------===//
+
+void AffineStoreOp::build(Builder *builder, OperationState *result,
+ Value *valueToStore, AffineMap map,
+ ArrayRef<Value *> operands) {
+ // TODO(b/133776335) Check that map operands are loop IVs or symbols.
+ result->addOperands(valueToStore);
+ result->addOperands(operands);
+ result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto affineIntTy = parser->getBuilder().getIndexType();
+
+ MemRefType type;
+ OpAsmParser::OperandType storeValueInfo;
+ OpAsmParser::OperandType memrefInfo;
+ AffineMapAttr mapAttr;
+ SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+ return failure(
+ parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+ parser->parseOperand(memrefInfo) || parser->parseLSquare() ||
+ parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map",
+ result->attributes) ||
+ parser->parseRSquare() || parser->parseColonType(type) ||
+ parser->resolveOperand(storeValueInfo, type.getElementType(),
+ result->operands) ||
+ parser->resolveOperand(memrefInfo, type, result->operands) ||
+ parser->resolveOperands(mapOperands, affineIntTy, result->operands));
+}
+
+void AffineStoreOp::print(OpAsmPrinter *p) {
+ *p << "affine.store " << *getValueToStore();
+ *p << ", " << *getMemRef() << '[';
+ AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>("map");
+ SmallVector<Value *, 2> operands(getIndices());
+ p->printAffineMapOfSSAIds(mapAttr, operands);
+ *p << "] : " << getMemRefType();
+}
+
+LogicalResult AffineStoreOp::verify() {
+ // First operand must have same type as memref element type.
+ if (getValueToStore()->getType() != getMemRefType().getElementType())
+ return emitOpError("first operand must have same type memref element type");
+
+ AffineMap map = getAttrOfType<AffineMapAttr>("map").getValue();
+ if (map.getNumResults() != getMemRefType().getRank())
+ return emitOpError("affine.store affine map num results must equal memref "
+ "rank");
+
+ for (auto *idx : getIndices())
+ if (!idx->getType().isIndex())
+ return emitOpError("index to load must have 'index' type");
+ // TODO(b/133776335) Verify that map operands are loop IVs or symbols.
+ return success();
+}
void printLocation(LocationAttr loc);
void printAffineMap(AffineMap map);
- void printAffineExpr(AffineExpr expr);
+ void printAffineExpr(AffineExpr expr, ArrayRef<StringRef> dimValueNames = {},
+ ArrayRef<StringRef> symbolValueNames = {});
void printAffineConstraint(AffineExpr expr, bool isEq);
void printIntegerSet(IntegerSet set);
Strong, // All other binary operators.
};
void printAffineExprInternal(AffineExpr expr,
- BindingStrength enclosingTightness);
+ BindingStrength enclosingTightness,
+ ArrayRef<StringRef> dimValueNames = {},
+ ArrayRef<StringRef> symbolValueNames = {});
};
} // end anonymous namespace
// Affine expressions and maps
//===----------------------------------------------------------------------===//
-void ModulePrinter::printAffineExpr(AffineExpr expr) {
- printAffineExprInternal(expr, BindingStrength::Weak);
+void ModulePrinter::printAffineExpr(AffineExpr expr,
+ ArrayRef<StringRef> dimValueNames,
+ ArrayRef<StringRef> symbolValueNames) {
+ printAffineExprInternal(expr, BindingStrength::Weak, dimValueNames,
+ symbolValueNames);
}
void ModulePrinter::printAffineExprInternal(
- AffineExpr expr, BindingStrength enclosingTightness) {
+ AffineExpr expr, BindingStrength enclosingTightness,
+ ArrayRef<StringRef> dimValueNames, ArrayRef<StringRef> symbolValueNames) {
const char *binopSpelling = nullptr;
switch (expr.getKind()) {
- case AffineExprKind::SymbolId:
- os << 's' << expr.cast<AffineSymbolExpr>().getPosition();
+ case AffineExprKind::SymbolId: {
+ unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
+ if (pos < symbolValueNames.size())
+ os << "symbol(%" << symbolValueNames[pos] << ')';
+ else
+ os << 's' << pos;
return;
- case AffineExprKind::DimId:
- os << 'd' << expr.cast<AffineDimExpr>().getPosition();
+ }
+ case AffineExprKind::DimId: {
+ unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ if (pos < dimValueNames.size())
+ os << '%' << dimValueNames[pos];
+ else
+ os << 'd' << pos;
return;
+ }
case AffineExprKind::Constant:
os << expr.cast<AffineConstantExpr>().getValue();
return;
auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
if (rhsConst && rhsConst.getValue() == -1) {
os << "-";
- printAffineExprInternal(lhsExpr, BindingStrength::Strong);
+ printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
+ symbolValueNames);
return;
}
- printAffineExprInternal(lhsExpr, BindingStrength::Strong);
+ printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
+ symbolValueNames);
os << binopSpelling;
- printAffineExprInternal(rhsExpr, BindingStrength::Strong);
+ printAffineExprInternal(rhsExpr, BindingStrength::Strong, dimValueNames,
+ symbolValueNames);
if (enclosingTightness == BindingStrength::Strong)
os << ')';
AffineExpr rrhsExpr = rhs.getRHS();
if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
if (rrhs.getValue() == -1) {
- printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+ printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+ symbolValueNames);
os << " - ";
if (rhs.getLHS().getKind() == AffineExprKind::Add) {
- printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
+ printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
+ dimValueNames, symbolValueNames);
} else {
- printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak);
+ printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
+ dimValueNames, symbolValueNames);
}
if (enclosingTightness == BindingStrength::Strong)
}
if (rrhs.getValue() < -1) {
- printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+ printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+ symbolValueNames);
os << " - ";
- printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
+ printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
+ dimValueNames, symbolValueNames);
os << " * " << -rrhs.getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
// Pretty print addition to a negative number as a subtraction.
if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
if (rhsConst.getValue() < 0) {
- printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+ printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+ symbolValueNames);
os << " - " << -rhsConst.getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
}
}
- printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+ printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+ symbolValueNames);
os << " + ";
- printAffineExprInternal(rhsExpr, BindingStrength::Weak);
+ printAffineExprInternal(rhsExpr, BindingStrength::Weak, dimValueNames,
+ symbolValueNames);
if (enclosingTightness == BindingStrength::Strong)
os << ')';
os.indent(currentIndent) << "}";
}
+ void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
+ ArrayRef<Value *> operands) {
+ AffineMap map = mapAttr.getValue();
+ unsigned numDims = map.getNumDims();
+ SmallVector<StringRef, 2> dimValueNames;
+ SmallVector<StringRef, 1> symbolValueNames;
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ if (i < numDims)
+ dimValueNames.push_back(valueNames[operands[i]]);
+ else
+ symbolValueNames.push_back(valueNames[operands[i]]);
+ }
+ interleaveComma(map.getResults(), [&](AffineExpr expr) {
+ printAffineExpr(expr, dimValueNames, symbolValueNames);
+ });
+ }
+
// Number of spaces used for indenting nested operations.
const static unsigned indentWidth = 2;
ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
IntegerSet &set);
+ /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+ ParseResult
+ parseAffineMapOfSSAIds(AffineMap &map,
+ llvm::function_ref<ParseResult(bool)> parseElement);
+
private:
/// The Parser is subclassed and reinstantiated. Do not add additional
/// non-trivial state here, add it to the ParserState class.
/// bodies.
class AffineParser : public Parser {
public:
- explicit AffineParser(ParserState &state) : Parser(state) {}
+ AffineParser(ParserState &state, bool allowParsingSSAIds = false,
+ llvm::function_ref<ParseResult(bool)> parseElement = nullptr)
+ : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
+ parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {}
AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
+ ParseResult parseAffineMapOfSSAIds(AffineMap &map);
+ void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
+ unsigned &numDims);
private:
// Binary affine op parsing.
AffineExpr parseNegateExpression(AffineExpr lhs);
AffineExpr parseIntegerExpr();
AffineExpr parseBareIdExpr();
+ AffineExpr parseSSAIdExpr(bool isSymbol);
+ AffineExpr parseSymbolSSAIdExpr();
AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
AffineExpr rhs, SMLoc opLoc);
AffineExpr parseAffineConstraint(bool *isEq);
private:
+ bool allowParsingSSAIds;
+ llvm::function_ref<ParseResult(bool)> parseElement;
+ unsigned numDimOperands;
+ unsigned numSymbolOperands;
SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
};
} // end anonymous namespace
return (emitError("use of undeclared identifier"), nullptr);
}
+/// Parse an SSA id which may appear in an affine expression.
+AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
+ if (!allowParsingSSAIds)
+ return (emitError("unexpected ssa identifier"), nullptr);
+ if (getToken().isNot(Token::percent_identifier))
+ return (emitError("expected ssa identifier"), nullptr);
+ auto name = getTokenSpelling();
+ // Check if we already parsed this SSA id.
+ for (auto entry : dimsAndSymbols) {
+ if (entry.first == name) {
+ consumeToken(Token::percent_identifier);
+ return entry.second;
+ }
+ }
+ // Parse the SSA id and add an AffineDim/SymbolExpr to represent it.
+ if (parseElement(isSymbol))
+ return (emitError("failed to parse ssa identifier"), nullptr);
+ auto idExpr = isSymbol
+ ? getAffineSymbolExpr(numSymbolOperands++, getContext())
+ : getAffineDimExpr(numDimOperands++, getContext());
+ dimsAndSymbols.push_back({name, idExpr});
+ return idExpr;
+}
+
+AffineExpr AffineParser::parseSymbolSSAIdExpr() {
+ if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
+ parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
+ return nullptr;
+ AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
+ if (!symbolExpr)
+ return nullptr;
+ if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
+ return nullptr;
+ return symbolExpr;
+}
+
/// Parse a positive integral constant appearing in an affine expression.
///
/// affine-expr ::= integer-literal
switch (getToken().getKind()) {
case Token::bare_identifier:
return parseBareIdExpr();
+ case Token::kw_symbol:
+ return parseSymbolSSAIdExpr();
+ case Token::percent_identifier:
+ return parseSSAIdExpr(/*isSymbol=*/false);
case Token::integer:
return parseIntegerExpr();
case Token::l_paren:
return failure();
}
+/// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) {
+ SmallVector<AffineExpr, 4> exprs;
+ auto parseElt = [&]() -> ParseResult {
+ auto elt = parseAffineExpr();
+ exprs.push_back(elt);
+ return elt ? success() : failure();
+ };
+
+ // Parse a multi-dimensional affine expression (a comma-separated list of
+ // 1-d affine expressions); the list cannot be empty. Grammar:
+ // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+ if (parseCommaSeparatedList(parseElt))
+ return failure();
+ // Parsed a valid affine map.
+ map = builder.getAffineMap(numDimOperands,
+ dimsAndSymbols.size() - numDimOperands, exprs);
+ return success();
+}
+
/// Parse the range and sizes affine map definition inline.
///
/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set);
}
+/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to
+/// parse SSA value uses encountered while parsing affine expressions.
+ParseResult Parser::parseAffineMapOfSSAIds(
+ AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) {
+ return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
+ .parseAffineMapOfSSAIds(map);
+}
+
//===----------------------------------------------------------------------===//
// OperationParser
//===----------------------------------------------------------------------===//
return success(parser.consumeIf(Token::r_paren));
}
+ /// Parse a `[` token.
+ ParseResult parseLSquare() override {
+ return parser.parseToken(Token::l_square, "expected '['");
+ }
+
+ /// Parse a `]` token.
+ ParseResult parseRSquare() override {
+ return parser.parseToken(Token::r_square, "expected ']'");
+ }
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
return failure();
}
+ /// Parse an AffineMap of SSA ids.
+ ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands,
+ Attribute &mapAttr, StringRef attrName,
+ SmallVectorImpl<NamedAttribute> &attrs) {
+ SmallVector<OperandType, 2> dimOperands;
+ SmallVector<OperandType, 1> symOperands;
+
+ auto parseElement = [&](bool isSymbol) -> ParseResult {
+ OperandType operand;
+ if (parseOperand(operand))
+ return failure();
+ if (isSymbol)
+ symOperands.push_back(operand);
+ else
+ dimOperands.push_back(operand);
+ return success();
+ };
+
+ AffineMap map;
+ if (parser.parseAffineMapOfSSAIds(map, parseElement))
+ return failure();
+ // Add AffineMap attribute.
+ mapAttr = parser.builder.getAffineMapAttr(map);
+ attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr));
+
+ // Add dim operands before symbol operands in 'operands'.
+ operands.assign(dimOperands.begin(), dimOperands.end());
+ operands.append(symOperands.begin(), symOperands.end());
+ return success();
+ }
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
TOK_KEYWORD(size)
TOK_KEYWORD(sparse)
TOK_KEYWORD(step)
+TOK_KEYWORD(symbol)
TOK_KEYWORD(tensor)
TOK_KEYWORD(to)
TOK_KEYWORD(true)
--- /dev/null
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
+
+// Test with just loop IVs.
+func @test0(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<100x100xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ %1 = affine.load %0[%i0, %i1] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0, %i1] : memref<100x100xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0 + 3, d1 + 7)
+
+// Test with loop IVs and constants.
+func @test1(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<100x100xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+ affine.store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (d0 + d1, d2 + d3)
+
+// Test with loop IVs and function args without 'symbol' keyword (should
+// be parsed as dim identifiers).
+func @test2(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<100x100xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ %1 = affine.load %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+ affine.store %1, %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)
+
+// Test with loop IVs and function args with 'symbol' keyword (should
+// be parsed as symbol identifiers).
+func @test3(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<100x100xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ %1 = affine.load %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)]
+ : memref<100x100xf32>
+ affine.store %1, %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)]
+ : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] : memref<100x100xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0, s1] -> ((d0 + s0) floordiv 3 + 11, (d1 + s1) mod 4 + 7)
+
+// Test with loop IVs, symbols and constants in nested affine expressions.
+func @test4(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<100x100xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ %1 = affine.load %0[(%i0 + symbol(%arg0)) floordiv 3 + 11,
+ (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+ affine.store %1, %0[(%i0 + symbol(%arg0)) floordiv 3 + 11,
+ (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2)
+
+// Test with swizzled loop IVs.
+func @test5(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<10x10x10xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ affine.for %i2 = 0 to 10 {
+ %1 = affine.load %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+ affine.store %1, %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+// CHECK: %1 = affine.load %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+// CHECK: affine.store %1, %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3, d4) -> (d0 + d1, d2 + d3, d3 + d1 + d4)
+
+// Test with swizzled loop IVs, duplicate args, and function args used as dims.
+// Dim identifiers are assigned in parse order:
+// d0 = %i2, d1 = %arg0, d2 = %i0, d3 = %i1, d4 = %arg1
+func @test6(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<10x10x10xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ affine.for %i2 = 0 to 10 {
+ %1 = affine.load %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1]
+ : memref<10x10x10xf32>
+ affine.store %1, %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1]
+ : memref<10x10x10xf32>
+// CHECK: %1 = affine.load %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] : memref<10x10x10xf32>
+// CHECK: affine.store %1, %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] : memref<10x10x10xf32>
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2)[s0, s1] -> (d0 + s0, d1 + d2, d2 + s0 + s1)
+
+// Test with swizzled loop IVs, duplicate args, and function args used as syms.
+// Dim and symbol identifiers are assigned in parse order:
+// d0 = %i2, d1 = %i0, d2 = %i1
+// s0 = %arg0, s1 = %arg1
+func @test6(%arg0 : index, %arg1 : index) {
+ %0 = alloc() : memref<10x10x10xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ affine.for %i2 = 0 to 10 {
+ %1 = affine.load %0[%i2 + symbol(%arg0),
+ %i0 + %i1,
+ %i1 + symbol(%arg0) + symbol(%arg1)]
+ : memref<10x10x10xf32>
+ affine.store %1, %0[%i2 + symbol(%arg0),
+ %i0 + %i1,
+ %i1 + symbol(%arg0) + symbol(%arg1)]
+ : memref<10x10x10xf32>
+// CHECK: %1 = affine.load %0[%i2 + symbol(%arg0), %i0 + %i1, %i1 + symbol(%arg0) + symbol(%arg1)] : memref<10x10x10xf32>
+// CHECK: affine.store %1, %0[%i2 + symbol(%arg0), %i0 + %i1, %i1 + symbol(%arg0) + symbol(%arg1)] : memref<10x10x10xf32>
+ }
+ }
+ }
+ return
+}
\ No newline at end of file