From 7438dcb71f4f3a38ef0add60c6f307736412d188 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 4 Jun 2019 19:31:30 -0700 Subject: [PATCH] ODG: Always deference operand/result when using named arg/result. Considered adding more placeholders to designate types in the replacement pattern, but convinced for now sticking to simpler approach. This should at least enable specifying constraints across operands/results/attributes and we can start getting rid of the special cases. PiperOrigin-RevId: 251564893 --- mlir/include/mlir/IR/OpBase.td | 4 ++-- mlir/test/TestDialect/TestOps.td | 3 +-- mlir/test/mlir-tblgen/types.mlir | 11 +++++------ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 10 ++++++---- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index ff3da38..dd4c276 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1023,9 +1023,9 @@ class TCopVTEtIs : And<[ // Predicate to verify that a named argument or result's element type matches a // given type. class ArgOrResultElementTypeIs : And<[ - SubstLeaves<"$_self", "$" # name # "->getType()", IsShapedTypePred>, + SubstLeaves<"$_self", "$" # name # ".getType()", IsShapedTypePred>, SubstLeaves<"$_self", "$" # name # - "->getType().cast().getElementType()", + ".getType().cast().getElementType()", type.predicate>]>; // Predicate to verify that the i'th operand and the j'th operand have the same diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index a721510..b0296e1 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -91,8 +91,7 @@ def ArgAndResHaveFixedElementTypesOp : Or<[And<[ArgOrResultElementTypeIs<"x", I32>, ArgOrResultElementTypeIs<"y", F32>, ArgOrResultElementTypeIs<"res", I16>]>, - // TODO(jpienaar): change back to attr. - ArgOrResultElementTypeIs<"x", I8>]>>]> { + ArgOrResultElementTypeIs<"attr", I8>]>>]> { let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y, AnyAttr:$attr); let results = (outs AnyVectorOrTensor:$res); } diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir index 0b853b2..e84adcf 100644 --- a/mlir/test/mlir-tblgen/types.mlir +++ b/mlir/test/mlir-tblgen/types.mlir @@ -89,12 +89,11 @@ func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { // ----- -// TODO(jpienaar): re-enable post supporting attributes again. -// DISABLED_CHECK-LABEL: @fixed_element_types -//func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { -// %0 = "test.arg_and_res_have_fixed_element_types"(%arg0, %arg1) {attr: splat, 1>}: (tensor<* x i32>, tensor<* x f32>) -> tensor<* x i32> -// return -//} +// CHECK-LABEL: @fixed_element_types +func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { + %0 = "test.arg_and_res_have_fixed_element_types"(%arg0, %arg1) {attr: splat, 1>}: (tensor<* x i32>, tensor<* x f32>) -> tensor<* x i32> + return +} // ----- diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 4261faf..9d4ab12 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -985,7 +985,9 @@ void OpEmitter::genVerifier() { auto &body = method.body(); // Populate substitutions for attributes and named operands and results. - // TODO(jpienaar): Add attributes back. + for (const auto &namedAttr : op.getAttributes()) + verifyCtx.addSubst(namedAttr.name, + formatv("this->getAttr(\"{0}\")", namedAttr.name)); for (int i = 0, e = op.getNumOperands(); i < e; ++i) { auto &value = op.getOperand(i); // Skip from from first variadic operands for now. Else getOperand index @@ -993,8 +995,8 @@ void OpEmitter::genVerifier() { if (value.isVariadic()) break; if (!value.name.empty()) - verifyCtx.addSubst(value.name, - formatv("this->getOperation()->getOperand({0})", i)); + verifyCtx.addSubst( + value.name, formatv("(*this->getOperation()->getOperand({0}))", i)); } for (int i = 0, e = op.getNumResults(); i < e; ++i) { auto &value = op.getResult(i); @@ -1004,7 +1006,7 @@ void OpEmitter::genVerifier() { break; if (!value.name.empty()) verifyCtx.addSubst(value.name, - formatv("this->getOperation()->getResult({0})", i)); + formatv("(*this->getOperation()->getResult({0}))", i)); } // Verify the attributes have the correct type. -- 2.7.4