From 3812d956eaef834eb3794d311aef2097aac268e0 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 9 Jun 2019 07:00:09 -0700 Subject: [PATCH] [ODS] Support variadic operand/result verification This CL enables verification code generation for variadic operands and results. In verify(), we use fallback getter methods to access all the dynamic values belonging to one static variadic operand/result to reuse the value range calculation there. PiperOrigin-RevId: 252288219 --- mlir/include/mlir/IR/OpBase.td | 7 +-- mlir/include/mlir/LLVMIR/LLVMOps.td | 4 +- mlir/include/mlir/TableGen/Operator.h | 5 +- mlir/lib/StandardOps/Ops.cpp | 6 --- mlir/lib/TableGen/Operator.cpp | 4 +- mlir/test/IR/invalid-ops.mlir | 2 +- mlir/test/IR/operand.mlir | 35 ++++++++++++++ mlir/test/IR/result.mlir | 36 +++++++++++++++ mlir/test/TestDialect/TestOps.td | 25 ++++++++++ mlir/test/mlir-tblgen/op-operand.td | 18 -------- mlir/test/mlir-tblgen/op-result.td | 18 -------- mlir/test/mlir-tblgen/predicate.td | 6 ++- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 72 ++++++++++++++++------------- 13 files changed, 152 insertions(+), 86 deletions(-) create mode 100644 mlir/test/IR/operand.mlir create mode 100644 mlir/test/IR/result.mlir diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index ac8c652..0d32213 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -255,9 +255,7 @@ class TypeAlias : // class is used for supporting variadic operands/results. An op can declare no // more than one variadic operand/result, and that operand/result must be the // last one in the operand/result list. -class Variadic - // TODO(b/132908002): support variadic type conditions - : TypeConstraint, descr> { +class Variadic : TypeConstraint { Type baseType = type; } @@ -907,6 +905,9 @@ def Terminator : NativeOpTrait<"IsTerminator">; def FirstAttrDerivedResultType : GenInternalOpTrait<"FirstAttrDerivedResultType">; +// TODO(antiagainst): Turn the following into normal traits and generate +// verification for them. + // All variadic operands of the op have the same number of values. // A variadic operand contains an array of values whose array size is only // known at runtime. This trait requires all variadic operands of an op diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index a207e94..e9f235a 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -203,7 +203,9 @@ def LLVM_PtrToIntOp // Call-related operations. def LLVM_CallOp : LLVM_Op<"call">, Arguments<(ins OptionalAttr:$callee, - Variadic)>, + // TODO(b/133216756): fix test failure and + // change to LLVM_Type + Variadic)>, Results<(outs Variadic)>, LLVM_TwoBuilders { diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 6cc6bbc..4551790 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -69,11 +69,12 @@ public: std::string getQualCppClassName() const; using value_iterator = NamedTypeConstraint *; + using value_range = llvm::iterator_range; // Op result iterators. value_iterator result_begin(); value_iterator result_end(); - llvm::iterator_range getResults(); + value_range getResults(); // Returns the number of results this op produces. int getNumResults() const; @@ -110,7 +111,7 @@ public: // Op operand iterators. value_iterator operand_begin(); value_iterator operand_end(); - llvm::iterator_range getOperands(); + value_range getOperands(); int getNumOperands() const { return operands.size(); } NamedTypeConstraint &getOperand(int index) { return operands[index]; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 4b2940a..1ef3fcd 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1595,12 +1595,6 @@ static LogicalResult verify(ExtractElementOp op) { if (op.getType() != aggregateType.getElementType()) return op.emitOpError("result type must match element type of aggregate"); - // TODO(b/132908002) This should be covered by the op specification in - // tablegen, but for some reason it's not. - for (auto *idx : op.getIndices()) - if (!idx->getType().isIndex()) - return op.emitOpError("index to extract_element must have 'index' type"); - // Verify the # indices match if we have a ranked type. if (aggregateType.hasRank() && aggregateType.getRank() != op.getNumOperands() - 1) diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 3c269ba..a5dd9c25 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -95,7 +95,7 @@ auto tblgen::Operator::result_begin() -> value_iterator { auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } -auto tblgen::Operator::getResults() -> llvm::iterator_range { +auto tblgen::Operator::getResults() -> value_range { return {result_begin(), result_end()}; } @@ -205,7 +205,7 @@ auto tblgen::Operator::operand_begin() -> value_iterator { auto tblgen::Operator::operand_end() -> value_iterator { return operands.end(); } -auto tblgen::Operator::getOperands() -> llvm::iterator_range { +auto tblgen::Operator::getOperands() -> value_range { return {operand_begin(), operand_end()}; } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 562c2ce..6baa104 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -639,7 +639,7 @@ func @extract_element_no_indices(%v : vector<3xf32>) { // ----- func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) { - // expected-error@+1 {{index to extract_element must have 'index' type}} + // expected-error@+1 {{operand #1 must be index}} %0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32 return } diff --git a/mlir/test/IR/operand.mlir b/mlir/test/IR/operand.mlir new file mode 100644 index 0000000..0d7939f --- /dev/null +++ b/mlir/test/IR/operand.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test mixed normal and variadic operands +//===----------------------------------------------------------------------===// + +func @correct_variadic_operand(%arg0: tensor, %arg1: f32) { + // CHECK: mixed_normal_variadic_operand + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg0, %arg0) : (tensor, tensor, tensor, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_first_variadic_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #0 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor, f32, tensor, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_normal_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #1 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor, tensor, f32, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_second_variadic_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #2 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor, tensor, tensor, f32, tensor) -> () + return +} diff --git a/mlir/test/IR/result.mlir b/mlir/test/IR/result.mlir new file mode 100644 index 0000000..fc5d597 --- /dev/null +++ b/mlir/test/IR/result.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test mixed normal and variadic results +//===----------------------------------------------------------------------===// + +func @correct_variadic_result() -> tensor { + // CHECK: mixed_normal_variadic_result + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, tensor, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_first_variadic_result() -> tensor { + // expected-error @+1 {{result #0 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, f32, tensor, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_normal_result() -> tensor { + // expected-error @+1 {{result #1 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, f32, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_second_variadic_result() -> tensor { + // expected-error @+1 {{result #2 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, tensor, f32, tensor) + return %0#4 : tensor +} + diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index 845b08d..10c144f 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -60,6 +60,31 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> { let results = (outs NestedTupleOf<[I32, F32]>); } +//===----------------------------------------------------------------------===// +// Test Operands +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicOperandOp : TEST_Op< + "mixed_normal_variadic_operand", [SameVariadicOperandSize]> { + let arguments = (ins + Variadic:$input1, + AnyTensor:$input2, + Variadic:$input3 + ); +} + +//===----------------------------------------------------------------------===// +// Test Results +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicResults : TEST_Op< + "mixed_normal_variadic_result", [SameVariadicResultSize]> { + let results = (outs + Variadic:$output1, + AnyTensor:$output2, + Variadic:$output3 + ); +} //===----------------------------------------------------------------------===// // Test Attributes diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 6055081..ea567e4 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -26,10 +26,6 @@ def OpA : NS_Op<"one_normal_operand_op", []> { // CHECK: assert(operands.size() == 1u && "mismatched number of parameters"); // CHECK: tblgen_state->addOperands(operands); -// CHECK: LogicalResult OpA::verify() { -// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32)))) -// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer"); - def OpB : NS_Op<"one_variadic_operand_op", []> { let arguments = (ins Variadic:$input); } @@ -52,20 +48,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> // CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); -// TODO(b/134305899): Move to use TestDialect after fixing verification. - -// CHECK-LABEL: Operation::operand_range OpD::getODSOperands(unsigned index) -// CHECK-NEXT: bool isVariadic[] = {true, false, true}; -// CHECK-NEXT: int prevVariadicCount = 0; -// CHECK-NEXT: for (int i = 0; i < index; ++i) -// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount; - -// CHECK: int variadicSize = (getOperation()->getNumOperands() - 1) / 2; -// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount; -// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1; - -// CHECK: return {std::next(getOperation()->operand_begin(), offset), std::next(getOperation()->operand_begin(), offset + size)}; - // CHECK-LABEL: Operation::operand_range OpD::input1 // CHECK-NEXT: return getODSOperands(0); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index e0f14e4..83f804a 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -17,10 +17,6 @@ def OpA : NS_Op<"one_normal_result_op", []> { // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK-NEXT: tblgen_state->addTypes(resultTypes); -// CHECK-LABEL: LogicalResult OpA::verify() -// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32)))) -// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer"); - def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> { let arguments = (ins I32:$x); let results = (outs I32:$y); @@ -90,20 +86,6 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> let results = (outs Variadic:$output1, AnyTensor:$output2, Variadic:$output3); } -// TODO(b/134305899): Move to use TestDialect after fixing verification. - -// CHECK-LABEL: Operation::result_range OpI::getODSResults(unsigned index) -// CHECK-NEXT: bool isVariadic[] = {true, false, true}; -// CHECK-NEXT: int prevVariadicCount = 0; -// CHECK-NEXT: for (int i = 0; i < index; ++i) -// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount; - -// CHECK: int variadicSize = (getOperation()->getNumResults() - 1) / 2; -// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount; -// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1; - -// CHECK: return {std::next(getOperation()->result_begin(), offset), std::next(getOperation()->result_begin(), offset + size)}; - // CHECK-LABEL: Operation::result_range OpI::output1 // CHECK-NEXT: return getODSResults(0); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 454a01b..7cf5a8d 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> { } // CHECK-LABEL: OpA::verify -// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32) || this->getOperation()->getOperand(0)->getType().isF32()))) +// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: if (!((v->getType().isInteger(32) || v->getType().isF32()))) def OpB : NS_Op<"op_for_And_PredOpTrait", [ PredOpTrait<"both first and second holds", @@ -103,4 +104,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> { } // CHECK-LABEL: OpK::verify -// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa())) && (((this->getOperation()->getOperand(0)->getType().cast().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast().getElementType().isInteger(32)))))) +// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: if (!(((v->getType().isa())) && (((v->getType().cast().getElementType().isF32())) || ((v->getType().cast().getElementType().isInteger(32)))))) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 7183f34..7718a0d 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -448,6 +448,12 @@ private: // Generates verify method for the operation. void genVerifier(); + // Generates verify statements for operands and results in the operation. + // The generated code will be attached to `body`. + void genOperandResultVerifier(OpMethodBody &body, + Operator::value_range values, + StringRef valueKind); + // Generates verify statements for regions in the operation. // The generated code will be attached to `body`. void genRegionVerifier(OpMethodBody &body); @@ -1022,39 +1028,8 @@ void OpEmitter::genVerifier() { body << " }\n"; } - // Emits verification code for an operand or result. - auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index, - bool isOperand) -> void { - // TODO: Handle variadic operand/result verification. - if (value.isVariadic()) - return; - - // TODO: Commonality between matchers could be extracted to have a more - // concise code. - if (value.hasPredicate()) { - auto description = value.constraint.getDescription(); - body << " if (!(" - << tgfmt( - value.constraint.getConditionTemplate(), - &verifyCtx.withSelf("this->getOperation()->get" + - Twine(isOperand ? "Operand" : "Result") + - "(" + Twine(index) + ")->getType()")) - << ")) {\n"; - body << " return emitOpError(\"" << (isOperand ? "operand" : "result") - << " #" << index - << (description.empty() ? " type precondition failed" - : " must be " + Twine(description)) - << "\");\n }\n"; - } - }; - - for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - verifyValue(op.getOperand(i), i, /*isOperand=*/true); - } - - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - verifyValue(op.getResult(i), i, /*isOperand=*/false); - } + genOperandResultVerifier(body, op.getOperands(), "operand"); + genOperandResultVerifier(body, op.getResults(), "result"); for (auto &trait : op.getTraits()) { if (auto t = dyn_cast(&trait)) { @@ -1073,6 +1048,37 @@ void OpEmitter::genVerifier() { body << " return mlir::success();\n"; } +void OpEmitter::genOperandResultVerifier(OpMethodBody &body, + Operator::value_range values, + StringRef valueKind) { + FmtContext fctx; + unsigned i = 0; + for (auto &staticValue : values) { + if (!staticValue.hasPredicate()) + continue; + + // Emit a loop to check all the dynamic values in the pack. + body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n", + // Capitalize the first letter to match the function name + valueKind.substr(0, 1).upper(), valueKind.substr(1), i); + + auto description = staticValue.constraint.getDescription(); + body << " (void)v;\n"; + body << " if (!(" + << tgfmt(staticValue.constraint.getConditionTemplate(), + &fctx.withSelf("v->getType()")) + << "))\n"; + body << " return emitOpError(\"" + // TODO(b/129706806): Use the name of the operand/result here + << valueKind << " #" << i + << (description.empty() ? " type precondition failed" + : " must be " + Twine(description)) + << "\");\n"; + body << " }\n"; + ++i; + } +} + void OpEmitter::genRegionVerifier(OpMethodBody &body) { unsigned numRegions = op.getNumRegions(); -- 2.7.4