From 0ce25b12357b24d06cf08cc02719c144d567d5db Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Sat, 28 Jan 2023 19:18:19 -0500 Subject: [PATCH] [mlir] Require explicit casts when using TypedValue Currently `TypedValue` can be constructed directly from `Value`, hiding errors that could be caught at compile time. For example the following will compile, but crash/assert at runtime: ``` void foo(TypedValue); void bar(TypedValue v) { foo(v); } ``` This change removes the constructors and replaces them with explicit llvm casts. Depends on D142852 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D142855 --- mlir/docs/Tutorials/Toy/Ch-4.md | 2 +- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 6 ++--- mlir/examples/toy/Ch5/mlir/Dialect.cpp | 6 ++--- mlir/examples/toy/Ch6/mlir/Dialect.cpp | 6 ++--- mlir/examples/toy/Ch7/mlir/Dialect.cpp | 10 ++++---- mlir/include/mlir/IR/OpDefinition.h | 2 +- mlir/include/mlir/IR/Value.h | 18 ++++----------- mlir/include/mlir/TableGen/Class.h | 6 +++++ mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 2 +- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 2 +- .../Transforms/SparseTensorRewriting.cpp | 2 +- .../Transforms/SparseTensorStorageLayout.h | 3 ++- mlir/test/mlir-tblgen/op-operand.td | 2 +- mlir/test/mlir-tblgen/op-result.td | 2 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 27 ++++++++++++++-------- mlir/unittests/IR/IRMapping.cpp | 2 +- 16 files changed, 54 insertions(+), 44 deletions(-) diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md index 95cdca0..df82141 100644 --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -375,7 +375,7 @@ inferred as the shape of the inputs. ```c++ /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } ``` At this point, each of the necessary Toy operations provide a mechanism by which diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index c87e107..17a42d6 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. -void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index 04ae314..77ceb63 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. -void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index 04ae314..77ceb63 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. -void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 3382cbc..188b94f 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -243,7 +243,9 @@ mlir::LogicalResult StructConstantOp::verify() { /// Infer the output shape of the ConstantOp, this is required by the shape /// inference interface. -void ConstantOp::inferShapes() { getResult().setType(getValue().getType()); } +void ConstantOp::inferShapes() { + getResult().setType(cast(getValue().getType())); +} //===----------------------------------------------------------------------===// // AddOp @@ -264,7 +266,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. -void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -272,7 +274,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -376,7 +378,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 60c961f..4c0a849 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -632,7 +632,7 @@ public: : public TraitBase::Impl> { public: TypedValue getResult() { - return this->getOperation()->getResult(0); + return cast>(this->getOperation()->getResult(0)); } /// If the operation returns a single value, then the Op can be implicitly diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index c84ae97..e95bfcf 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -427,21 +427,13 @@ inline unsigned OpResultImpl::getResultNumber() const { /// TypedValue can be null/empty template struct TypedValue : Value { + using Value::Value; + + static bool classof(Value value) { return llvm::isa(value.getType()); } + /// Return the known Type Ty getType() { return Value::getType().template cast(); } - void setType(mlir::Type ty) { - assert(ty.template isa()); - Value::setType(ty); - } - - TypedValue(Value val) : Value(val) { - assert(!val || val.getType().template isa()); - } - TypedValue &operator=(const Value &other) { - assert(!other || other.getType().template isa()); - Value::operator=(other); - return *this; - } + void setType(Ty ty) { Value::setType(ty); } }; } // namespace detail diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h index 954ef5b..9c4efcd 100644 --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -152,6 +152,9 @@ public: /// Get the name of the method. StringRef getName() const { return methodName; } + /// Get the return type of the method + StringRef getReturnType() const { return returnType; } + /// Get the number of parameters. unsigned getNumParameters() const { return parameters.getNumParameters(); } @@ -344,6 +347,9 @@ public: /// Returns the name of this method. StringRef getName() const { return methodSignature.getName(); } + /// Returns the return type of this method + StringRef getReturnType() const { return methodSignature.getReturnType(); } + /// Returns if this method makes the `other` method redundant. bool makesRedundant(const Method &other) const { return methodSignature.makesRedundant(other.methodSignature); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index aa41c96..46f5292 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1884,7 +1884,7 @@ public: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { llvm::SmallVector indices; for (unsigned int i = 0; i < inputTy.getRank(); i++) { - auto index = + Value index = rewriter.create(nestedLoc, i).getResult(); if (i == axis) { auto one = rewriter.create(nestedLoc, 1); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 931a85b..d8070b3 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1033,7 +1033,7 @@ struct UnrollTransferWriteConversion auto vec = getDataVector(xferOp); auto xferVecType = xferOp.getVectorType(); int64_t dimSize = xferVecType.getShape()[0]; - auto source = xferOp.getSource(); // memref or tensor to be written to. + Value source = xferOp.getSource(); // memref or tensor to be written to. auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); // Generate fully unrolled loop of transfer ops. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 1e2a7b0..9cd2331 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1056,7 +1056,7 @@ struct NewRewriter : public OpRewritePattern { // %t = sparse_tensor.ConvertOp %tmp RankedTensorType cooTp = getUnorderedCOOFromTypeWithOrdering(dstTp, encDst.getDimOrdering()); - auto cooBuffer = + Value cooBuffer = rewriter.create(loc, cooTp, dynSizesArray).getResult(); Value c0 = constantIndex(rewriter, loc, 0); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h index d5554ea..98c7551 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -173,7 +173,8 @@ private: class SparseTensorSpecifier { public: - explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {} + explicit SparseTensorSpecifier(Value specifier) + : specifier(cast>(specifier)) {} // Undef value for dimension sizes, all zero value for memory sizes. static Value getInitValue(OpBuilder &builder, Location loc, diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 7a76f98..68a9def 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -43,7 +43,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> // CHECK-NEXT: return getODSOperands(0); // CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::getInput2 -// CHECK-NEXT: return *getODSOperands(1).begin(); +// CHECK-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); // CHECK-LABEL: OpD::build // CHECK-NEXT: odsState.addOperands(input1); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 06178f8..d49bffa 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -100,7 +100,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> // CHECK-NEXT: return getODSResults(0); // CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpI::getOutput2 -// CHECK-NEXT: return *getODSResults(1).begin(); +// CHECK-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(1).begin()); // CHECK-LABEL: OpI::build // CHECK-NEXT: odsState.addTypes(output1); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 7ed29f9..bc3e259 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1337,10 +1337,12 @@ generateNamedOperandGetters(const Operator &op, Class &opClass, : generateTypeForGetter(operand), name); ERROR_IF_PRUNED(m, name, op); - m->body().indent() << formatv( - "auto operands = getODSOperands({0});\n" - "return operands.empty() ? {1}{{} : *operands.begin();", - i, rangeElementType); + m->body().indent() << formatv("auto operands = getODSOperands({0});\n" + "return operands.empty() ? {1}{{} : ", + i, m->getReturnType()); + if (!isGenericAdaptorBase) + m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); + m->body() << "(*operands.begin());"; } else if (operand.isVariadicOfVariadic()) { std::string segmentAttr = op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); @@ -1366,7 +1368,10 @@ generateNamedOperandGetters(const Operator &op, Class &opClass, : generateTypeForGetter(operand), name); ERROR_IF_PRUNED(m, name, op); - m->body() << " return *getODSOperands(" << i << ").begin();"; + m->body().indent() << "return "; + if (!isGenericAdaptorBase) + m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); + m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i); } } } @@ -1489,9 +1494,11 @@ void OpEmitter::genNamedResultGetters() { if (result.isOptional()) { m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); - m->body() - << " auto results = getODSResults(" << i << ");\n" - << " return results.empty() ? ::mlir::Value() : *results.begin();"; + m->body() << " auto results = getODSResults(" << i << ");\n" + << llvm::formatv(" return results.empty()" + " ? {0}()" + " : ::llvm::cast<{0}>(*results.begin());", + m->getReturnType()); } else if (result.isVariadic()) { m = opClass.addMethod("::mlir::Operation::result_range", name); ERROR_IF_PRUNED(m, name, op); @@ -1499,7 +1506,9 @@ void OpEmitter::genNamedResultGetters() { } else { m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); - m->body() << " return *getODSResults(" << i << ").begin();"; + m->body() << llvm::formatv( + " return ::llvm::cast<{0}>(*getODSResults({1}).begin());", + m->getReturnType(), i); } } } diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp index bf00e5e..8362797 100644 --- a/mlir/unittests/IR/IRMapping.cpp +++ b/mlir/unittests/IR/IRMapping.cpp @@ -32,7 +32,7 @@ TEST(IRMapping, TypedValue) { IRMapping mapping; mapping.map(i64Val, f64Val); - TypedValue typedI64Val = i64Val; + auto typedI64Val = cast>(i64Val); EXPECT_EQ(mapping.lookup(typedI64Val), f64Val); } -- 2.7.4