using namespace mlir::linalg;
/// Forward declarations.
+
+/// Generic entry point to create the block for the region of a LinalgOp.
+/// This is used by both named structured ops created by ods-gen and by manually
+/// defined C++ ops.
+/// This is used by both builders and parsers.
+/// This function creates the block in the region with arguments corresponding
+/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
+/// to be ShapedType.
+template <typename NamedStructuredOpType>
+static void fillStructuredOpRegion(
+ OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
+ TypeRange outputTypes, ValueRange captures = {},
+ std::function<void(unsigned, unsigned)> errorHandler = [](unsigned,
+ unsigned) {});
+
+/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
- OperationState &result,
- TypeRange inputTypes,
- TypeRange outputTypes);
+static void
+createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
+ TypeRange inputTypes, TypeRange outputTypes,
+ ValueRange captures = {});
+/// Common parsing and printing used for both named structured ops created by
+/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
SmallVectorImpl<Type> &outputTypes);
+template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+ NamedStructuredOpType op);
+/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
- TypeRange inputTypes, TypeRange outputTypes);
+ TypeRange inputTypes, TypeRange outputTypes,
+ ArrayRef<OpAsmParser::OperandType> captures = {});
+
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes);
template <typename NamedStructuredOpType>
-static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
- OperationState &result);
-
-template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
- NamedStructuredOpType op);
+static ParseResult
+parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
+ ArrayRef<OpAsmParser::OperandType> captures = {});
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes);
}
//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+void CopyOp::regionBuilder(Block &block, ValueRange captures) {
+ using namespace edsc::intrinsics;
+ assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
+ (linalg_yield(block.getArgument(0)));
+}
+
+void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
+ Value output, AffineMap inputPermutation,
+ AffineMap outputPermutation,
+ ArrayRef<NamedAttribute> namedAttrs) {
+ result.addOperands({input, output});
+ result.addAttributes(namedAttrs);
+ if (inputPermutation)
+ result.addAttribute("inputPermutation",
+ AffineMapAttr::get(inputPermutation));
+ if (outputPermutation)
+ result.addAttribute("outputPermutation",
+ AffineMapAttr::get(outputPermutation));
+ result.addRegion();
+ fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
+ TypeRange{input.getType()},
+ TypeRange{output.getType()});
+}
+
+ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
+ Type outputType) {
+ OpBuilder opBuilder(parser.getBuilder().getContext());
+ fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
+ TypeRange{outputType});
+ return success();
+}
+
+/// CopyOp region is elided when printing.
+void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
+
+static LogicalResult verify(CopyOp op) {
+ auto outputViewType = op.getOutputShapedType(0);
+ auto inputViewType = op.getInputShapedType(0);
+ if (inputViewType.getElementType() != outputViewType.getElementType())
+ return op.emitOpError("expects views of the same type");
+ if (inputViewType.getRank() != outputViewType.getRank())
+ return op.emitOpError("expects views of the same rank");
+ auto rank = op.getNumParallelLoops();
+ auto inputPermutationMap = op.inputPermutation();
+ if (inputPermutationMap) {
+ if (inputPermutationMap->getNumInputs() != rank)
+ return op.emitOpError("expects optional input_permutation map of rank ")
+ << rank;
+ if (!inputPermutationMap->isPermutation())
+ return op.emitOpError(
+ "expects optional input_permutation map to be a permutation");
+ }
+ auto outputPermutationMap = op.outputPermutation();
+ if (outputPermutationMap) {
+ if (outputPermutationMap->getNumInputs() != rank)
+ return op.emitOpError("expects optional output_permutation map of rank ")
+ << rank;
+ if (!outputPermutationMap->isPermutation())
+ return op.emitOpError(
+ "expects optional output_permutation map to be a permutation");
+ }
+ if (rank == 0 && inputPermutationMap)
+ return op.emitOpError("expected no input permutation when rank == 0");
+ if (rank == 0 && outputPermutationMap)
+ return op.emitOpError("expected no output permutation when rank == 0");
+ return success();
+}
+
+void CopyOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), input(),
+ SideEffects::DefaultResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), output(),
+ SideEffects::DefaultResource::get());
+}
+
+//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
+void FillOp::regionBuilder(Block &block, ValueRange captures) {
+ using namespace edsc::intrinsics;
+ assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
+ (linalg_yield(captures));
+}
void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
Value value) {
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
value);
+ fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
+ TypeRange{output.getType()}, value);
+}
+
+ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
+ OpAsmParser::OperandType valueRef) {
+ OpBuilder opBuilder(parser.getBuilder().getContext());
+ // Resolve `valueRef` into `value` at parse time so we can build the region
+ // with captures.
+ SmallVector<Value> value;
+ parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
+ fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
+ TypeRange{outputType}, value);
+ return success();
+}
+
+/// FillOp region is elided when printing.
+void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
+
+static LogicalResult verify(FillOp op) {
+ auto viewType = op.getOutputShapedType(0);
+ auto fillType = op.value().getType();
+ if (viewType.getElementType() != fillType)
+ return op.emitOpError("expects fill type to match view elemental type");
+ if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+ return op.emitOpError(
+ "expected fill op with no result value to use memref type");
+ }
+ return success();
+}
+
+void FillOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (output().getType().isa<MemRefType>())
+ effects.emplace_back(MemoryEffects::Write::get(), output(),
+ SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//
-
static LogicalResult verify(InitTensorOp op) {
RankedTensorType resultType = op.getType();
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
/////// Operations corresponding to library calls defined with Tablegen ////////
-void FillOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- if (output().getType().isa<MemRefType>())
- effects.emplace_back(MemoryEffects::Write::get(), output(),
- SideEffects::DefaultResource::get());
-}
-
-static LogicalResult verify(FillOp op) {
- auto viewType = op.getOutputShapedType(0);
- auto fillType = op.value().getType();
- if (viewType.getElementType() != fillType)
- return op.emitOpError("expects fill type to match view elemental type");
- if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
- return op.emitOpError(
- "expected fill op with no result value to use memref type");
- }
- return success();
-}
-
-void CopyOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- effects.emplace_back(MemoryEffects::Read::get(), input(),
- SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), output(),
- SideEffects::DefaultResource::get());
-}
-
-static LogicalResult verify(CopyOp op) {
- auto outputViewType = op.getOutputShapedType(0);
- auto inputViewType = op.getInputShapedType(0);
- if (inputViewType.getElementType() != outputViewType.getElementType())
- return op.emitOpError("expects views of the same type");
- if (inputViewType.getRank() != outputViewType.getRank())
- return op.emitOpError("expects views of the same rank");
- auto rank = op.getNumParallelLoops();
- auto inputPermutationMap = op.inputPermutation();
- if (inputPermutationMap) {
- if (inputPermutationMap->getNumInputs() != rank)
- return op.emitOpError("expects optional input_permutation map of rank ")
- << rank;
- if (!inputPermutationMap->isPermutation())
- return op.emitOpError(
- "expects optional input_permutation map to be a permutation");
- }
- auto outputPermutationMap = op.outputPermutation();
- if (outputPermutationMap) {
- if (outputPermutationMap->getNumInputs() != rank)
- return op.emitOpError("expects optional output_permutation map of rank ")
- << rank;
- if (!outputPermutationMap->isPermutation())
- return op.emitOpError(
- "expects optional output_permutation map to be a permutation");
- }
- if (rank == 0 && inputPermutationMap)
- return op.emitOpError("expected no input permutation when rank == 0");
- if (rank == 0 && outputPermutationMap)
- return op.emitOpError("expected no output permutation when rank == 0");
- return success();
-}
-
template <typename LinalgPoolingOp>
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
ArrayRef<Attribute> attrs,
}
//===----------------------------------------------------------------------===//
-// Auto-generated Linalg named ops.
+// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//
+/// Generic entry point to create the block for the region of a LinalgOp.
+/// This is used by both named structured ops created by ods-gen and by manually
+/// defined C++ ops.
+/// This is used by both builders and parsers.
+/// This function creates the block in the region with arguments corresponding
+/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
+/// to be ShapedType.
template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegionAndAttributesImpl(
- OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
- TypeRange outputTypes,
- std::function<void(unsigned, unsigned)> errorHandler) {
+static void
+fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
+ TypeRange inputTypes, TypeRange outputTypes,
+ ValueRange captures,
+ std::function<void(unsigned, unsigned)> errorHandler) {
+ assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
+ assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
+
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
SmallVector<Type, 8> argTypes;
// RAII.
OpBuilder::InsertionGuard guard(opBuilder);
- Block *body = opBuilder.createBlock(®ion, {}, argTypes);
+ Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes);
unsigned actual = body->getNumArguments();
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
if (expected != actual)
opBuilder.setInsertionPointToStart(body);
mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
- NamedStructuredOpType::regionBuilder(*body);
+ NamedStructuredOpType::regionBuilder(*body, captures);
// indexing_maps is an auto-generated method.
// iterator_types is an auto-generated method.
}
+/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
-void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
- OperationState &result,
- TypeRange inputTypes,
- TypeRange outputTypes) {
+void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
+ OperationState &result,
+ TypeRange inputTypes,
+ TypeRange outputTypes,
+ ValueRange captures) {
Region ®ion = *result.addRegion();
- buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputTypes,
+ fillStructuredOpRegion<NamedStructuredOpType>(
+ opBuilder, region, inputTypes, outputTypes, captures,
[&](unsigned expected, unsigned actual) {
- llvm::errs() << "region expects " << expected << " args, got "
- << actual;
assert(expected != actual && "incorrect number of arguments");
});
}
-template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
- TypeRange inputTypes, TypeRange outputTypes) {
- ParseResult res = success();
- OpBuilder opBuilder(parser.getBuilder().getContext());
- buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputTypes,
- [&](unsigned expected, unsigned actual) {
- res = parser.emitError(parser.getCurrentLocation(),
- llvm::formatv("region expects {0} args, got {1}",
- expected, actual));
- });
- return res;
-}
-
-static ParseResult
-parseNamedStructuredOpResults(OpAsmParser &parser,
- SmallVectorImpl<Type> &resultTypes) {
- if (succeeded(parser.parseOptionalArrow()))
- if (parser.parseTypeList(resultTypes))
- return failure();
- return success();
-}
-
+/// Common parsing used for both named structured ops created by ods-gen and by
+/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
}
template <typename NamedStructuredOpType>
-static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
- OperationState &result) {
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+ NamedStructuredOpType op) {
+ if (!op.inputs().empty())
+ p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
+ if (!op.outputs().empty())
+ p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// Specific parsing and printing for named structured ops created by ods-gen.
+//===----------------------------------------------------------------------===//
+
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
+ TypeRange inputTypes, TypeRange outputTypes,
+ ArrayRef<OpAsmParser::OperandType> captures) {
+ ParseResult res = success();
+ OpBuilder opBuilder(parser.getBuilder().getContext());
+ // Resolve `captures` into `capturedValues` at parse time so we can build the
+ // region with captures.
+ SmallVector<Value> capturedValues;
+ fillStructuredOpRegion<NamedStructuredOpType>(
+ opBuilder, region, inputTypes, outputTypes, capturedValues,
+ [&](unsigned expected, unsigned actual) {
+ res = parser.emitError(
+ parser.getCurrentLocation(),
+ llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
+ "region expects {0} args, got {1}",
+ expected, actual));
+ region.front().dump();
+ });
+ return res;
+}
+
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+ SmallVectorImpl<Type> &resultTypes) {
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(resultTypes))
+ return failure();
+ return success();
+}
+
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
+ ArrayRef<OpAsmParser::OperandType> captures) {
+ // TODO: Enable when ods-gen supports captures.
+ assert(captures.empty() && "unexpected captures for named structured ops");
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
- parser, *region, inputTypes, outputTypes))
+ parser, *region, inputTypes, outputTypes, captures))
return failure();
result.addRegion(std::move(region));
}
template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
- NamedStructuredOpType op) {
- if (!op.inputs().empty())
- p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
- if (!op.outputs().empty())
- p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
-}
-
-template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
p << op.getOperationName();
p.printOptionalAttrDict(op.getAttrs(),
return verifyGenericOp<NamedStructuredOpType>(op);
}
+//===----------------------------------------------------------------------===//
+// Canonicalizers and Folders.
+//===----------------------------------------------------------------------===//
+
namespace {
struct EraseDeadLinalgOp : public RewritePattern {
EraseDeadLinalgOp(PatternBenefit benefit = 1)