using namespace mlir::linalg;
/// Forward declarations.
-
-/// Generic entry point to create the block for the region of a LinalgOp.
-/// This is used by both named structured ops created by ods-gen and by manually
-/// defined C++ ops.
-/// This is used by both builders and parsers.
-/// This function creates the block in the region with arguments corresponding
-/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
-/// to be ShapedType.
-template <typename NamedStructuredOpType>
-static void fillStructuredOpRegion(
- OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
- TypeRange outputTypes, ValueRange captures = {},
- std::function<void(unsigned, unsigned)> errorHandler = [](unsigned,
- unsigned) {});
-
-/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
-static void
-createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
- TypeRange inputTypes, TypeRange outputTypes,
- ValueRange captures = {});
+static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+ OperationState &result,
+ TypeRange inputTypes,
+ TypeRange outputTypes);
-/// Common parsing and printing used for both named structured ops created by
-/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
SmallVectorImpl<Type> &outputTypes);
-template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
- NamedStructuredOpType op);
-/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
- TypeRange inputTypes, TypeRange outputTypes,
- ArrayRef<OpAsmParser::OperandType> captures = {});
-
+ TypeRange inputTypes, TypeRange outputTypes);
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes);
template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
- ArrayRef<OpAsmParser::OperandType> captures = {});
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+ OperationState &result);
+
+template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+ NamedStructuredOpType op);
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes);
}
//===----------------------------------------------------------------------===//
-// CopyOp
-//===----------------------------------------------------------------------===//
-void CopyOp::regionBuilder(Block &block, ValueRange captures) {
- using namespace edsc::intrinsics;
- assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
- (linalg_yield(block.getArgument(0)));
-}
-
-void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
- Value output, AffineMap inputPermutation,
- AffineMap outputPermutation,
- ArrayRef<NamedAttribute> namedAttrs) {
- result.addOperands({input, output});
- result.addAttributes(namedAttrs);
- if (inputPermutation)
- result.addAttribute("inputPermutation",
- AffineMapAttr::get(inputPermutation));
- if (outputPermutation)
- result.addAttribute("outputPermutation",
- AffineMapAttr::get(outputPermutation));
- result.addRegion();
- fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
- TypeRange{input.getType()},
- TypeRange{output.getType()});
-}
-
-ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
- Type outputType) {
- OpBuilder opBuilder(parser.getBuilder().getContext());
- fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
- TypeRange{outputType});
- return success();
-}
-
-/// CopyOp region is elided when printing.
-void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
-
-static LogicalResult verify(CopyOp op) {
- auto outputViewType = op.getOutputShapedType(0);
- auto inputViewType = op.getInputShapedType(0);
- if (inputViewType.getElementType() != outputViewType.getElementType())
- return op.emitOpError("expects views of the same type");
- if (inputViewType.getRank() != outputViewType.getRank())
- return op.emitOpError("expects views of the same rank");
- auto rank = op.getNumParallelLoops();
- auto inputPermutationMap = op.inputPermutation();
- if (inputPermutationMap) {
- if (inputPermutationMap->getNumInputs() != rank)
- return op.emitOpError("expects optional input_permutation map of rank ")
- << rank;
- if (!inputPermutationMap->isPermutation())
- return op.emitOpError(
- "expects optional input_permutation map to be a permutation");
- }
- auto outputPermutationMap = op.outputPermutation();
- if (outputPermutationMap) {
- if (outputPermutationMap->getNumInputs() != rank)
- return op.emitOpError("expects optional output_permutation map of rank ")
- << rank;
- if (!outputPermutationMap->isPermutation())
- return op.emitOpError(
- "expects optional output_permutation map to be a permutation");
- }
- if (rank == 0 && inputPermutationMap)
- return op.emitOpError("expected no input permutation when rank == 0");
- if (rank == 0 && outputPermutationMap)
- return op.emitOpError("expected no output permutation when rank == 0");
- return success();
-}
-
-void CopyOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- effects.emplace_back(MemoryEffects::Read::get(), input(),
- SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), output(),
- SideEffects::DefaultResource::get());
-}
-
-//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
-void FillOp::regionBuilder(Block &block, ValueRange captures) {
- using namespace edsc::intrinsics;
- assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
- (linalg_yield(captures));
-}
void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
Value value) {
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
value);
- fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
- TypeRange{output.getType()}, value);
-}
-
-ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
- OpAsmParser::OperandType valueRef) {
- OpBuilder opBuilder(parser.getBuilder().getContext());
- // Resolve `valueRef` into `value` at parse time so we can build the region
- // with captures.
- SmallVector<Value> value;
- parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
- fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
- TypeRange{outputType}, value);
- return success();
-}
-
-/// FillOp region is elided when printing.
-void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
-
-static LogicalResult verify(FillOp op) {
- auto viewType = op.getOutputShapedType(0);
- auto fillType = op.value().getType();
- if (viewType.getElementType() != fillType)
- return op.emitOpError("expects fill type to match view elemental type");
- if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
- return op.emitOpError(
- "expected fill op with no result value to use memref type");
- }
- return success();
-}
-
-void FillOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- if (output().getType().isa<MemRefType>())
- effects.emplace_back(MemoryEffects::Write::get(), output(),
- SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
result.addAttributes(attrs);
}
+
static LogicalResult verify(InitTensorOp op) {
RankedTensorType resultType = op.getType();
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
/////// Operations corresponding to library calls defined with Tablegen ////////
+void FillOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (output().getType().isa<MemRefType>())
+ effects.emplace_back(MemoryEffects::Write::get(), output(),
+ SideEffects::DefaultResource::get());
+}
+
+static LogicalResult verify(FillOp op) {
+ auto viewType = op.getOutputShapedType(0);
+ auto fillType = op.value().getType();
+ if (viewType.getElementType() != fillType)
+ return op.emitOpError("expects fill type to match view elemental type");
+ if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+ return op.emitOpError(
+ "expected fill op with no result value to use memref type");
+ }
+ return success();
+}
+
+void CopyOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), input(),
+ SideEffects::DefaultResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), output(),
+ SideEffects::DefaultResource::get());
+}
+
+static LogicalResult verify(CopyOp op) {
+ auto outputViewType = op.getOutputShapedType(0);
+ auto inputViewType = op.getInputShapedType(0);
+ if (inputViewType.getElementType() != outputViewType.getElementType())
+ return op.emitOpError("expects views of the same type");
+ if (inputViewType.getRank() != outputViewType.getRank())
+ return op.emitOpError("expects views of the same rank");
+ auto rank = op.getNumParallelLoops();
+ auto inputPermutationMap = op.inputPermutation();
+ if (inputPermutationMap) {
+ if (inputPermutationMap->getNumInputs() != rank)
+ return op.emitOpError("expects optional input_permutation map of rank ")
+ << rank;
+ if (!inputPermutationMap->isPermutation())
+ return op.emitOpError(
+ "expects optional input_permutation map to be a permutation");
+ }
+ auto outputPermutationMap = op.outputPermutation();
+ if (outputPermutationMap) {
+ if (outputPermutationMap->getNumInputs() != rank)
+ return op.emitOpError("expects optional output_permutation map of rank ")
+ << rank;
+ if (!outputPermutationMap->isPermutation())
+ return op.emitOpError(
+ "expects optional output_permutation map to be a permutation");
+ }
+ if (rank == 0 && inputPermutationMap)
+ return op.emitOpError("expected no input permutation when rank == 0");
+ if (rank == 0 && outputPermutationMap)
+ return op.emitOpError("expected no output permutation when rank == 0");
+ return success();
+}
+
template <typename LinalgPoolingOp>
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
ArrayRef<Attribute> attrs,
}
//===----------------------------------------------------------------------===//
-// Support for named Linalg ops defined in ods-gen.
+// Auto-generated Linalg named ops.
//===----------------------------------------------------------------------===//
-/// Generic entry point to create the block for the region of a LinalgOp.
-/// This is used by both named structured ops created by ods-gen and by manually
-/// defined C++ ops.
-/// This is used by both builders and parsers.
-/// This function creates the block in the region with arguments corresponding
-/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
-/// to be ShapedType.
template <typename NamedStructuredOpType>
-static void
-fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
- TypeRange inputTypes, TypeRange outputTypes,
- ValueRange captures,
- std::function<void(unsigned, unsigned)> errorHandler) {
- assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
- assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
-
+static void buildNamedStructuredOpRegionAndAttributesImpl(
+ OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
+ TypeRange outputTypes,
+ std::function<void(unsigned, unsigned)> errorHandler) {
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
SmallVector<Type, 8> argTypes;
// RAII.
OpBuilder::InsertionGuard guard(opBuilder);
- Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes);
+ Block *body = opBuilder.createBlock(®ion, {}, argTypes);
unsigned actual = body->getNumArguments();
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
if (expected != actual)
opBuilder.setInsertionPointToStart(body);
mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
- NamedStructuredOpType::regionBuilder(*body, captures);
+ NamedStructuredOpType::regionBuilder(*body);
// indexing_maps is an auto-generated method.
// iterator_types is an auto-generated method.
}
-/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
-void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
- OperationState &result,
- TypeRange inputTypes,
- TypeRange outputTypes,
- ValueRange captures) {
+void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+ OperationState &result,
+ TypeRange inputTypes,
+ TypeRange outputTypes) {
Region ®ion = *result.addRegion();
- fillStructuredOpRegion<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputTypes, captures,
+ buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+ opBuilder, region, inputTypes, outputTypes,
[&](unsigned expected, unsigned actual) {
+ llvm::errs() << "region expects " << expected << " args, got "
+ << actual;
assert(expected != actual && "incorrect number of arguments");
});
}
-/// Common parsing used for both named structured ops created by ods-gen and by
-/// manually defined C++ ops. Does not handle regions.
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
+ TypeRange inputTypes, TypeRange outputTypes) {
+ ParseResult res = success();
+ OpBuilder opBuilder(parser.getBuilder().getContext());
+ buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+ opBuilder, region, inputTypes, outputTypes,
+ [&](unsigned expected, unsigned actual) {
+ res = parser.emitError(parser.getCurrentLocation(),
+ llvm::formatv("region expects {0} args, got {1}",
+ expected, actual));
+ });
+ return res;
+}
+
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+ SmallVectorImpl<Type> &resultTypes) {
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(resultTypes))
+ return failure();
+ return success();
+}
+
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
}
template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
- NamedStructuredOpType op) {
- if (!op.inputs().empty())
- p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
- if (!op.outputs().empty())
- p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
-}
-
-//===----------------------------------------------------------------------===//
-// Specific parsing and printing for named structured ops created by ods-gen.
-//===----------------------------------------------------------------------===//
-
-template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
- TypeRange inputTypes, TypeRange outputTypes,
- ArrayRef<OpAsmParser::OperandType> captures) {
- ParseResult res = success();
- OpBuilder opBuilder(parser.getBuilder().getContext());
- // Resolve `captures` into `capturedValues` at parse time so we can build the
- // region with captures.
- SmallVector<Value> capturedValues;
- fillStructuredOpRegion<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputTypes, capturedValues,
- [&](unsigned expected, unsigned actual) {
- res = parser.emitError(
- parser.getCurrentLocation(),
- llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
- "region expects {0} args, got {1}",
- expected, actual));
- region.front().dump();
- });
- return res;
-}
-
-static ParseResult
-parseNamedStructuredOpResults(OpAsmParser &parser,
- SmallVectorImpl<Type> &resultTypes) {
- if (succeeded(parser.parseOptionalArrow()))
- if (parser.parseTypeList(resultTypes))
- return failure();
- return success();
-}
-
-template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
- ArrayRef<OpAsmParser::OperandType> captures) {
- // TODO: Enable when ods-gen supports captures.
- assert(captures.empty() && "unexpected captures for named structured ops");
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+ OperationState &result) {
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
- parser, *region, inputTypes, outputTypes, captures))
+ parser, *region, inputTypes, outputTypes))
return failure();
result.addRegion(std::move(region));
}
template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+ NamedStructuredOpType op) {
+ if (!op.inputs().empty())
+ p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
+ if (!op.outputs().empty())
+ p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
+}
+
+template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
p << op.getOperationName();
p.printOptionalAttrDict(op.getAttrs(),
return verifyGenericOp<NamedStructuredOpType>(op);
}
-//===----------------------------------------------------------------------===//
-// Canonicalizers and Folders.
-//===----------------------------------------------------------------------===//
-
namespace {
struct EraseDeadLinalgOp : public RewritePattern {
EraseDeadLinalgOp(PatternBenefit benefit = 1)