From: River Riddle Date: Sun, 6 Feb 2022 20:33:08 +0000 (-0800) Subject: [mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition X-Git-Tag: upstream/15.0.7~17448 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=60cac0c0816193f3d910cd8bdaebac8e6694a6bd;p=platform%2Fupstream%2Fllvm.git [mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition These have generally been replaced by better ODS functionality, and do not need to be explicitly provided anymore. Differential Revision: https://reviews.llvm.org/D119065 --- diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index e79fafa..cf20385 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2534,19 +2534,15 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> { class fir_ArithmeticOp traits = []> : fir_Op, - Results<(outs AnyType)> { - let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);"; - - let printer = "return printBinaryOp(this->getOperation(), p);"; + Results<(outs AnyType:$result)> { + let assemblyFormat = "operands attr-dict `:` type($result)"; } class fir_UnaryArithmeticOp traits = []> : fir_Op, - Results<(outs AnyType)> { - let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);"; - - let printer = "return printUnaryOp(this->getOperation(), p);"; + Results<(outs AnyType:$result)> { + let assemblyFormat = "operands attr-dict `:` type($result)"; } def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index bc91111..33d34b6 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -3211,26 +3211,6 @@ mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, return mlir::success(); } -/// Generic pretty-printer of a binary operation -static void printBinaryOp(Operation *op, OpAsmPrinter &p) { - assert(op->getNumOperands() == 2 && "binary op must have two operands"); - assert(op->getNumResults() == 1 && "binary op must have one result"); - - p << ' ' << op->getOperand(0) << ", " << op->getOperand(1); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getResult(0).getType(); -} - -/// Generic pretty-printer of an unary operation -static void printUnaryOp(Operation *op, OpAsmPrinter &p) { - assert(op->getNumOperands() == 1 && "unary op must have one operand"); - assert(op->getNumResults() == 1 && "unary op must have one result"); - - p << ' ' << op->getOperand(0); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getResult(0).getType(); -} - bool fir::isReferenceLike(mlir::Type type) { return type.isa() || type.isa() || type.isa(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 78b3af4..4a40132 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -419,8 +419,7 @@ class LLVM_CastOpgetOperation(), p); }]; + let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)"; } def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> { diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 79ad1ed..abb8231 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -383,7 +383,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [ }]; let hasFolder = 1; - let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index ec9116c..c600b31 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -34,6 +34,7 @@ class SPV_ArithmeticBinaryOp:$result ); + let assemblyFormat = "operands attr-dict `:` type($result)"; } class SPV_ArithmeticUnaryOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; // No additional verification needed in addition to the ODS-generated ones. let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td index e0f1603..b9b31f6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -21,7 +21,9 @@ class SPV_BitBinaryOp traits = []> : // All the operands type used in bit instructions are SPV_Integer. SPV_BinaryOp; + [NoSideEffect, SameOperandsAndResultType])> { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} class SPV_BitFieldExtractOp traits = []> : SPV_Op:$result ); - - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; } // ----- @@ -85,9 +85,9 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> { SPV_ScalarOrVectorOrPtr:$result ); - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; - + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td index 1532d3e..f0d5515 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -72,10 +72,6 @@ class SPV_GLSLBinaryOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; - let hasVerifier = 0; } @@ -83,7 +79,10 @@ class SPV_GLSLBinaryOp traits = []> : - SPV_GLSLBinaryOp; + SPV_GLSLBinaryOp { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} // Base class for GLSL ternary ops. class SPV_GLSLTernaryArithmeticOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; + let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ return printOneResultOp(getOperation(), p); }]; let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td index 92fb8a4..0b3c08a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -71,10 +71,6 @@ class SPV_OCLBinaryOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; - let hasVerifier = 0; } @@ -82,7 +78,10 @@ class SPV_OCLBinaryOp traits = []> : - SPV_OCLBinaryOp; + SPV_OCLBinaryOp { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} // ----- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index ee745e3..f4a499a 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -14,6 +14,7 @@ #define SHAPE_OPS include "mlir/Dialect/Shape/IR/ShapeBase.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -331,7 +332,9 @@ def Shape_RankOp : Shape_Op<"rank", }]; } -def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { +def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "Creates a dimension tensor from a shape"; let description = [{ Converts a shape to a 1D integral tensor of extents. The number of elements @@ -624,7 +627,9 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", }]; } -def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { +def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "Casts between index types of the shape and standard dialect"; let description = [{ Converts a `shape.size` to a standard index. This operation and its diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index ac51c7b..e29b35f 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1897,26 +1897,9 @@ protected: }; //===----------------------------------------------------------------------===// -// Common Operation Folders/Parsers/Printers +// CastOpInterface utilities //===----------------------------------------------------------------------===// -// These functions are out-of-line implementations of the methods in UnaryOp and -// BinaryOp, which avoids them being template instantiated/duplicated. -namespace impl { -ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser, - OperationState &result); - -void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, - Value rhs); -ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, - OperationState &result); - -// Prints the given binary `op` in custom assembly form if both the two operands -// and the result have the same time. Otherwise, prints the generic assembly -// form. -void printOneResultOp(Operation *op, OpAsmPrinter &p); -} // namespace impl - // These functions are out-of-line implementations of the methods in // CastOpInterface, which avoids them being template instantiated/duplicated. namespace impl { @@ -1927,20 +1910,6 @@ LogicalResult foldCastInterfaceOp(Operation *op, /// Attempt to verify the given cast operation. LogicalResult verifyCastInterfaceOp( Operation *op, function_ref areCastCompatible); - -// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the -// need for them, but some older ODS code in `std` still depends on them). -void buildCastOp(OpBuilder &builder, OperationState &result, Value source, - Type destType); -ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); -void printCastOp(Operation *op, OpAsmPrinter &p); -// TODO: These methods are deprecated in favor of CastOpInterface. Remove them -// when all uses have been updated. Also, consider adding functionality to -// CastOpInterface to be able to perform the ChainedTensorCast canonicalization -// generically. -Value foldCastOp(Operation *op); -LogicalResult verifyCastOp(Operation *op, - function_ref areCastCompatible); } // namespace impl } // namespace mlir diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index da672dc..33192e1 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -65,10 +65,6 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) { return NoneType::get(type.getContext()); } -LogicalResult memref::CastOp::verify() { - return impl::verifyCastOp(*this, areCastCompatible); -} - //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index cb476dc..1d06047 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -64,6 +64,54 @@ static constexpr const char kCompositeSpecConstituentsName[] = "constituents"; // Common utility functions //===----------------------------------------------------------------------===// +static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, + OperationState &result) { + SmallVector ops; + Type type; + // If the operand list is in-between parentheses, then we have a generic form. + // (see the fallback in `printOneResultOp`). + SMLoc loc = parser.getCurrentLocation(); + if (!parser.parseOptionalLParen()) { + if (parser.parseOperandList(ops) || parser.parseRParen() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseType(type)) + return failure(); + auto fnType = type.dyn_cast(); + if (!fnType) { + parser.emitError(loc, "expected function type"); + return failure(); + } + if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) + return failure(); + result.addTypes(fnType.getResults()); + return success(); + } + return failure(parser.parseOperandList(ops) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperands(ops, type, result.operands) || + parser.addTypeToList(type, result.types)); +} + +static void printOneResultOp(Operation *op, OpAsmPrinter &p) { + assert(op->getNumResults() == 1 && "op should have one result"); + + // If not all the operand and result types are the same, just use the + // generic assembly form to avoid omitting information in printing. + auto resultType = op->getResult(0).getType(); + if (llvm::any_of(op->getOperandTypes(), + [&](Type type) { return type != resultType; })) { + p.printGenericOp(op, /*printOpName=*/false); + return; + } + + p << ' '; + p.printOperands(op->getOperands()); + p.printOptionalAttrDict(op->getAttrs()); + // Now we can output only one type for all operands and the result. + p << " : " << resultType; +} + /// Returns true if the given op is a function-like op or nested in a /// function-like op without a module-like op in the middle. static bool isNestedInFunctionOpInterface(Operation *op) { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index a25f6dd..9241e2d 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1692,7 +1692,7 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef operands) { // `IntegerAttr`s which makes constant folding simple. if (Attribute arg = operands[0]) return arg; - return impl::foldCastOp(*this); + return OpFoldResult(); } void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -1700,6 +1700,12 @@ void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(context); } +bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + return inputs[0].isa() && outputs[0].isa(); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// @@ -1750,7 +1756,7 @@ LogicalResult SplitAtOp::fold(ArrayRef operands, OpFoldResult ToExtentTensorOp::fold(ArrayRef operands) { if (!operands[0]) - return impl::foldCastOp(*this); + return OpFoldResult(); Builder builder(getContext()); auto shape = llvm::to_vector<6>( operands[0].cast().getValues()); @@ -1759,6 +1765,21 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef operands) { return DenseIntElementsAttr::get(type, shape); } +bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + if (auto inputTensor = inputs[0].dyn_cast()) { + if (!inputTensor.getElementType().isa() || + inputTensor.getRank() != 1 || !inputTensor.isDynamicDim(0)) + return false; + } else if (!inputs[0].isa()) { + return false; + } + + TensorType outputTensor = outputs[0].dyn_cast(); + return outputTensor && outputTensor.getElementType().isa(); +} + //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index e679337..7273897 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1125,69 +1125,7 @@ bool OpTrait::hasElementwiseMappableTraits(Operation *op) { } //===----------------------------------------------------------------------===// -// BinaryOp implementation -//===----------------------------------------------------------------------===// - -// These functions are out-of-line implementations of the methods in BinaryOp, -// which avoids them being template instantiated/duplicated. - -void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, - Value rhs) { - assert(lhs.getType() == rhs.getType()); - result.addOperands({lhs, rhs}); - result.types.push_back(lhs.getType()); -} - -ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type type; - // If the operand list is in-between parentheses, then we have a generic form. - // (see the fallback in `printOneResultOp`). - SMLoc loc = parser.getCurrentLocation(); - if (!parser.parseOptionalLParen()) { - if (parser.parseOperandList(ops) || parser.parseRParen() || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(type)) - return failure(); - auto fnType = type.dyn_cast(); - if (!fnType) { - parser.emitError(loc, "expected function type"); - return failure(); - } - if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) - return failure(); - result.addTypes(fnType.getResults()); - return success(); - } - return failure(parser.parseOperandList(ops) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperands(ops, type, result.operands) || - parser.addTypeToList(type, result.types)); -} - -void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { - assert(op->getNumResults() == 1 && "op should have one result"); - - // If not all the operand and result types are the same, just use the - // generic assembly form to avoid omitting information in printing. - auto resultType = op->getResult(0).getType(); - if (llvm::any_of(op->getOperandTypes(), - [&](Type type) { return type != resultType; })) { - p.printGenericOp(op, /*printOpName=*/false); - return; - } - - p << ' '; - p.printOperands(op->getOperands()); - p.printOptionalAttrDict(op->getAttrs()); - // Now we can output only one type for all operands and the result. - p << " : " << resultType; -} - -//===----------------------------------------------------------------------===// -// CastOp implementation +// CastOpInterface //===----------------------------------------------------------------------===// /// Attempt to fold the given cast operation. @@ -1232,50 +1170,6 @@ LogicalResult impl::verifyCastInterfaceOp( return success(); } -void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, - Type destType) { - result.addOperands(source); - result.addTypes(destType); -} - -ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType srcInfo; - Type srcType, dstType; - return failure(parser.parseOperand(srcInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.resolveOperand(srcInfo, srcType, result.operands) || - parser.parseKeywordType("to", dstType) || - parser.addTypeToList(dstType, result.types)); -} - -void impl::printCastOp(Operation *op, OpAsmPrinter &p) { - p << ' ' << op->getOperand(0); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getOperand(0).getType() << " to " - << op->getResult(0).getType(); -} - -Value impl::foldCastOp(Operation *op) { - // Identity cast - if (op->getOperand(0).getType() == op->getResult(0).getType()) - return op->getOperand(0); - return nullptr; -} - -LogicalResult -impl::verifyCastOp(Operation *op, - function_ref areCastCompatible) { - auto opType = op->getOperand(0).getType(); - auto resType = op->getResult(0).getType(); - if (!areCastCompatible(opType, resType)) - return op->emitError("operand type ") - << opType << " and result type " << resType - << " are cast incompatible"; - - return success(); -} - //===----------------------------------------------------------------------===// // Misc. utils //===----------------------------------------------------------------------===//