From: Benjamin Kramer Date: Fri, 18 Feb 2022 13:40:11 +0000 (+0100) Subject: [mlir][ODS] Infer return types if the operands are variadic but the results are not X-Git-Tag: upstream/15.0.7~15924 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3ce2ee28f042c2a00d09c228c76f2692778bd607;p=platform%2Fupstream%2Fllvm.git [mlir][ODS] Infer return types if the operands are variadic but the results are not Clean up code that worked around this limitation. Differential Revision: https://reviews.llvm.org/D120119 --- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 19b0864..31a6952 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -192,8 +192,7 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect, }]; } -def Shape_ShapeEqOp : Shape_Op<"shape_eq", - [NoSideEffect, Commutative, InferTypeOpInterface]> { +def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative]> { let summary = "Returns whether the input shapes or extent tensors are equal"; let description = [{ Takes one or more shape or extent tensor operands and determines whether @@ -211,17 +210,6 @@ def Shape_ShapeEqOp : Shape_Op<"shape_eq", OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs), [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>, ]; - let extraClassDeclaration = [{ - // TODO: This should really be automatic. Figure out how to not need this defined. - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { - inferredReturnTypes.push_back(::mlir::IntegerType::get(context, - /*width=*/1)); - return success(); - }; - }]; let assemblyFormat = "$shapes attr-dict `:` type($shapes)"; let hasFolder = 1; @@ -262,8 +250,7 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> { let assemblyFormat = "$input attr-dict `:` type($input)"; } -def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", - [Commutative, InferTypeOpInterface]> { +def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> { let summary = "Determines if 2+ shapes can be successfully broadcasted"; let description = [{ Given multiple input shapes or extent tensors, return a predicate specifying @@ -289,17 +276,6 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs), [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>, ]; - let extraClassDeclaration = [{ - // TODO: This should really be automatic. Figure out how to not need this defined. - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { - inferredReturnTypes.push_back(::mlir::IntegerType::get(context, - /*width=*/1)); - return success(); - }; - }]; let hasFolder = 1; let hasCanonicalizer = 1; @@ -850,12 +826,6 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> let arguments = (ins Variadic:$inputs); let results = (outs Shape_WitnessType:$result); - // Only needed while tablegen is unable to generate this for ops with variadic - // arguments. - let builders = [ - OpBuilder<(ins "ValueRange":$inputs)>, - ]; - let assemblyFormat = "$inputs attr-dict"; let hasFolder = 1; @@ -917,8 +887,7 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } -def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", - [Commutative, InferTypeOpInterface]> { +def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> { let summary = "Determines if 2+ shapes can be successfully broadcasted"; let description = [{ Given input shapes or extent tensors, return a witness specifying if they @@ -944,23 +913,12 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>, ]; - let extraClassDeclaration = [{ - // TODO: This should really be automatic. Figure out how to not need this defined. - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { - inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context)); - return success(); - }; - }]; - let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } -def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> { +def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> { let summary = "Determines if all input shapes are equal"; let description = [{ Given 1 or more input shapes, determine if all shapes are the exact same. @@ -978,17 +936,6 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> { let assemblyFormat = "$shapes attr-dict `:` type($shapes)"; - let extraClassDeclaration = [{ - // TODO: This should really be automatic. Figure out how to not need this defined. - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { - inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context)); - return success(); - }; - }]; - let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index aec8dd5..2e7f069 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -803,8 +803,6 @@ def Vector_InsertMapOp : into vector<64x4x32xf32> ``` }]; - let builders = [OpBuilder<(ins "Value":$vector, "Value":$dest, - "ValueRange":$ids)>]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { return vector().getType().cast(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 5c851f5..0f63331 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -663,11 +663,6 @@ LogicalResult AssumingAllOp::verify() { return success(); } -void AssumingAllOp::build(OpBuilder &b, OperationState &state, - ValueRange inputs) { - build(b, state, b.getType(), inputs); -} - //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ddfe0d8..f6547e4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1900,11 +1900,6 @@ OpFoldResult vector::InsertOp::fold(ArrayRef operands) { // InsertMapOp //===----------------------------------------------------------------------===// -void InsertMapOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value dest, ValueRange ids) { - InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids); -} - LogicalResult InsertMapOp::verify() { if (getSourceVectorType().getRank() != getResultType().getRank()) return emitOpError("expected source and destination vectors of same rank"); diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index a71ae4d..2a0d49f 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -327,9 +327,8 @@ void Operator::populateTypeInferenceInfo( if (getNumResults() == 0) return; - // Skip for ops with variadic operands/results. - // TODO: This can be relaxed. - if (isVariadic()) + // Skip ops with variadic or optional results. + if (getNumVariableLengthResults() > 0) return; // Skip cases currently being custom generated.