Add SameOperandsAndResultElementType trait.
authorJacques Pienaar <jpienaar@google.com>
Sat, 4 May 2019 18:14:40 +0000 (11:14 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:28:59 +0000 (08:28 -0700)
    This trait only works for tensor and vector types at the moment, verifying that the element type of an op with only tensor and vector types match. Added a unit test for it as there is no op currently in core that uses this trait.

--

PiperOrigin-RevId: 246661697

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/unittests/IR/CMakeLists.txt
mlir/unittests/IR/OpDefinitionTest.cpp [new file with mode: 0644]

index a2b3ca3..8d3cdb3 100644 (file)
@@ -781,6 +781,9 @@ def NoSideEffect     : NativeOpTrait<"HasNoSideEffect">;
 def SameValueShape   : NativeOpTrait<"SameOperandsAndResultShape">;
 // Op has the same operand and result type.
 def SameValueType    : NativeOpTrait<"SameOperandsAndResultType">;
+// Op has the same operand and result element type.
+def SameOperandsAndResultElementType :
+  NativeOpTrait<"SameOperandsAndResultElementType">;
 // Op is a terminator.
 def Terminator       : NativeOpTrait<"IsTerminator">;
 
index 2532dc8..e551140 100644 (file)
@@ -317,6 +317,7 @@ LogicalResult verifyOneResult(Operation *op);
 LogicalResult verifyNResults(Operation *op, unsigned numOperands);
 LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
 LogicalResult verifySameOperandsAndResultShape(Operation *op);
+LogicalResult verifySameOperandsAndResultElementType(Operation *op);
 LogicalResult verifySameOperandsAndResultType(Operation *op);
 LogicalResult verifyResultsAreBoolLike(Operation *op);
 LogicalResult verifyResultsAreFloatLike(Operation *op);
@@ -573,11 +574,23 @@ public:
 };
 
 /// This class provides verification for ops that are known to have the same
+/// operand and result element type.
+///
+/// TODO: This only works for VectorOrTensorType at the moment.
+template <typename ConcreteType>
+class SameOperandsAndResultElementType
+    : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultElementType(op);
+  }
+};
+
+/// This class provides verification for ops that are known to have the same
 /// operand and result type.
 ///
-/// Note: this trait subsumes the SameOperandsAndResultShape trait.
-/// Additionally, it requires all operands and results should also have
-/// the same element type.
+/// Note: this trait subsumes the SameOperandsAndResultShape and
+/// SameOperandsAndResultElementType traits.
 template <typename ConcreteType>
 class SameOperandsAndResultType
     : public TraitBase<ConcreteType, SameOperandsAndResultType> {
index 9a13cee..067e86a 100644 (file)
@@ -798,6 +798,39 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
   return success();
 }
 
+LogicalResult
+OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return failure();
+
+  auto type = op->getResult(0)->getType().dyn_cast<VectorOrTensorType>();
+  if (!type)
+    return op->emitOpError("requires vector or tensor type results");
+  auto elementType = type.getElementType();
+
+  // Verify result element type matches first result's element type.
+  for (auto result : drop_begin(op->getResults(), 1)) {
+    auto resultType = result->getType().dyn_cast<VectorOrTensorType>();
+    if (!resultType)
+      return op->emitOpError("requires vector or tensor type results");
+    if (resultType.getElementType() != elementType)
+      return op->emitOpError(
+          "requires the same element type for all operands and results");
+  }
+
+  // Verify operand's element type matches first result's element type.
+  for (auto operand : op->getOperands()) {
+    auto operandType = operand->getType().dyn_cast<VectorOrTensorType>();
+    if (!operandType)
+      return op->emitOpError("requires vector or tensor type operands");
+    if (operandType.getElementType() != elementType)
+      return op->emitOpError(
+          "requires the same element type for all operands and results");
+  }
+
+  return success();
+}
+
 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
     return failure();
index 5d72c1f..5337eb4 100644 (file)
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRIRTests
   DialectTest.cpp
   OperationSupportTest.cpp
+  OpDefinitionTest.cpp
   SDBMTest.cpp
 )
 target_link_libraries(MLIRIRTests
diff --git a/mlir/unittests/IR/OpDefinitionTest.cpp b/mlir/unittests/IR/OpDefinitionTest.cpp
new file mode 100644 (file)
index 0000000..4074009
--- /dev/null
@@ -0,0 +1,91 @@
+//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+using namespace mlir::OpTrait::impl;
+
+namespace {
+
+// TODO: Replace with regular test once this trait is used by operation in core.
+TEST(OpDefinitionTest, SameOperandAndResultElementType) {
+  MLIRContext context;
+#define FILE_LOC                                                               \
+  FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0,   \
+                      &context)
+
+  Builder b(&context);
+  auto *operandtF32x10x10 = Operation::create(
+      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+      /*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
+      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+      /*resizableOperandList=*/false, &context);
+  auto *operandtF32x1 = Operation::create(
+      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+      /*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
+      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+      /*resizableOperandList=*/false, &context);
+  auto *operandvF32x1 = Operation::create(
+      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+      /*resultTypes=*/{b.getVectorType({1}, b.getF32Type())},
+      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+      /*resizableOperandList=*/false, &context);
+  auto *operandtI32x1 = Operation::create(
+      FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+      /*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))},
+      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+      /*resizableOperandList=*/false, &context);
+
+  // Verifies whether an op with x and y as inputs and resultType satisfies the
+  // SameOperandAndResultElementType trait.
+  auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
+    auto op = Operation::create(loc, OperationName("some_op", &context),
+                                /*operands=*/{x->getResult(0), y->getResult(0)},
+                                /*resultTypes=*/{resultType},
+                                /*attributes=*/llvm::None, /*successors=*/{},
+                                /*numRegions=*/0,
+                                /*resizableOperandList=*/false, &context);
+    return succeeded(verifySameOperandsAndResultElementType(op));
+  };
+
+  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
+                    b.getTensorType({12}, b.getF32Type())));
+  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
+                    b.getTensorType({5}, b.getF32Type())));
+  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
+                     b.getTensorType({7}, b.getF32Type())));
+  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
+                     b.getTensorType({12}, b.getIntegerType(32))));
+  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
+                     b.getTensorType({9}, b.getIntegerType(32))));
+  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
+                    b.getVectorType({9}, b.getF32Type())));
+  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
+                    b.getVectorType({9}, b.getF32Type())));
+  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
+                    b.getTensorType({5}, b.getF32Type())));
+  EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
+                     b.getTensorType({5}, b.getF32Type())));
+
+#undef FILE_LOC
+}
+
+} // end namespace