Use TestDialect to test traits instead of unittest.
authorJacques Pienaar <jpienaar@google.com>
Fri, 24 May 2019 23:17:52 +0000 (16:17 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:01:12 +0000 (20:01 -0700)
--

PiperOrigin-RevId: 249916947

mlir/test/IR/traits.mlir [new file with mode: 0644]
mlir/test/TestDialect/TestOps.td
mlir/unittests/IR/CMakeLists.txt
mlir/unittests/IR/OpDefinitionTest.cpp [deleted file]

diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
new file mode 100644 (file)
index 0000000..c9e1f86
--- /dev/null
@@ -0,0 +1,42 @@
+// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
+
+// CHECK: succeededSameOperandAndResultElementType
+func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
+  %0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %1 = "test.same_operand_and_result_type"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xf32>
+  %2 = "test.same_operand_and_result_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xf32>
+  %3 = "test.same_operand_and_result_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %4 = "test.same_operand_and_result_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xf32>
+  return
+}
+
+// -----
+
+func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
+  // expected-error@+1 {{requires the same element type for all operands and results}}
+  %0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32>
+}
+
+// -----
+
+func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
+  // expected-error@+1 {{requires the same element type for all operands and results}}
+  %0 = "test.same_operand_and_result_type"(%t1, %t1i) : (tensor<1xf32>, tensor<1xi32>) -> tensor<1xf32>
+}
+
+// -----
+
+// CHECK: succeededSameOperandAndResultShape
+func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
+  %0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %1 = "test.same_operand_and_result_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %2 = "test.same_operand_and_result_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
+  return
+}
+
+// -----
+
+func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>) {
+  // expected-error@+1 {{requires the same shape for all operands and results}}
+  %0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+}
index 984b9f2..65a21b8 100644 (file)
@@ -50,6 +50,8 @@ def VUVFoldTwoResultOp : Pattern<(VUVTwoResultOp $input), [
 // Test Types
 //===----------------------------------------------------------------------===//
 
+def AnyVectorOrTensor: AnyTypeOf<[AnyVector, AnyTensor]>;
+
 def TupleOp : TEST_Op<"tuple_32_bit"> {
   let results = (outs TupleOf<[I32, F32]>);
 }
@@ -58,4 +60,21 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
   let results = (outs NestedTupleOf<[I32, F32]>);
 }
 
-#endif // TEST_OPS
\ No newline at end of file
+
+//===----------------------------------------------------------------------===//
+// Test Traits
+//===----------------------------------------------------------------------===//
+
+def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type",
+    [SameOperandsAndResultElementType]> {
+  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
+def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
+    [SameValueShape]> {
+  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
+#endif // TEST_OPS
index 20e8722..5236ea7 100644 (file)
@@ -1,7 +1,6 @@
 add_mlir_unittest(MLIRIRTests
   DialectTest.cpp
   OperationSupportTest.cpp
-  OpDefinitionTest.cpp
 )
 target_link_libraries(MLIRIRTests
   PRIVATE
diff --git a/mlir/unittests/IR/OpDefinitionTest.cpp b/mlir/unittests/IR/OpDefinitionTest.cpp
deleted file mode 100644 (file)
index dcc83ac..0000000
+++ /dev/null
@@ -1,131 +0,0 @@
-//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-#include "gmock/gmock.h"
-
-using namespace mlir;
-using namespace mlir::OpTrait::impl;
-
-namespace {
-
-#define FILE_LOC                                                               \
-  FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0,   \
-                      &context)
-
-// TODO: Replace with regular test once this trait is used by operation in core.
-// TODO(b/132891206): Replace with dialect test.
-TEST(OpDefinitionTest, SameOperandAndResultElementType) {
-  MLIRContext context;
-  Builder b(&context);
-  auto *operandtF32x10x10 = Operation::create(
-      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
-      /*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
-      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
-      /*resizableOperandList=*/false, &context);
-  auto *operandtF32x1 = Operation::create(
-      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
-      /*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
-      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
-      /*resizableOperandList=*/false, &context);
-  auto *operandvF32x1 = Operation::create(
-      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
-      /*resultTypes=*/{b.getVectorType({1}, b.getF32Type())},
-      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
-      /*resizableOperandList=*/false, &context);
-  auto *operandtI32x1 = Operation::create(
-      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
-      /*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))},
-      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
-      /*resizableOperandList=*/false, &context);
-
-  // Verifies whether an op with x and y as inputs and resultType satisfies the
-  // SameOperandAndResultElementType trait.
-  auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
-    auto op = Operation::create(loc, OperationName("some_op", &context),
-                                /*operands=*/{x->getResult(0), y->getResult(0)},
-                                /*resultTypes=*/{resultType},
-                                /*attributes=*/llvm::None, /*successors=*/{},
-                                /*numRegions=*/0,
-                                /*resizableOperandList=*/false, &context);
-    return succeeded(verifySameOperandsAndResultElementType(op));
-  };
-
-  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
-                    b.getTensorType({12}, b.getF32Type())));
-  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
-                    b.getTensorType({5}, b.getF32Type())));
-  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
-                     b.getTensorType({7}, b.getF32Type())));
-  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
-                     b.getTensorType({12}, b.getIntegerType(32))));
-  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
-                     b.getTensorType({9}, b.getIntegerType(32))));
-  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
-                    b.getVectorType({9}, b.getF32Type())));
-  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
-                    b.getVectorType({9}, b.getF32Type())));
-  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
-                    b.getTensorType({5}, b.getF32Type())));
-  EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
-                     b.getTensorType({5}, b.getF32Type())));
-}
-
-TEST(OpDefinitionTest, SameOperandAndResultShape) {
-  MLIRContext context;
-  Builder b(&context);
-  auto *operandtF32x10x10 = Operation::create(
-      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
-      /*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
-      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
-      /*resizableOperandList=*/false, &context);
-  auto *operandtF32x1 = Operation::create(
-      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
-      /*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
-      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
-      /*resizableOperandList=*/false, &context);
-  auto *operandtF32xunranked = Operation::create(
-      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
-      /*resultTypes=*/{b.getTensorType(b.getF32Type())},
-      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
-      /*resizableOperandList=*/false, &context);
-
-  // SameOperandAndResultShape trait.
-  auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
-    auto op = Operation::create(loc, OperationName("some_op", &context),
-                                /*operands=*/{x->getResult(0), y->getResult(0)},
-                                /*resultTypes=*/{resultType},
-                                /*attributes=*/llvm::None, /*successors=*/{},
-                                /*numRegions=*/0,
-                                /*resizableOperandList=*/false, &context);
-    return succeeded(verifySameOperandsAndResultShape(op));
-  };
-
-  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
-                    b.getTensorType({1}, b.getF32Type())));
-  EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
-                     b.getTensorType({12}, b.getF32Type())));
-  EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x10x10,
-                     b.getTensorType({1}, b.getF32Type())));
-  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32xunranked,
-                    b.getTensorType({1}, b.getF32Type())));
-}
-
-#undef FILE_LOC
-} // end namespace