LogicalResult verifyNResults(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
LogicalResult verifySameOperandsAndResultShape(Operation *op);
+LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
};
/// This class provides verification for ops that are known to have the same
+/// operand and result element type.
+///
+/// TODO: This only works for VectorOrTensorType at the moment.
+template <typename ConcreteType>
+class SameOperandsAndResultElementType
+ : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifySameOperandsAndResultElementType(op);
+ }
+};
+
+/// This class provides verification for ops that are known to have the same
/// operand and result type.
///
-/// Note: this trait subsumes the SameOperandsAndResultShape trait.
-/// Additionally, it requires all operands and results should also have
-/// the same element type.
+/// Note: this trait subsumes the SameOperandsAndResultShape and
+/// SameOperandsAndResultElementType traits.
template <typename ConcreteType>
class SameOperandsAndResultType
: public TraitBase<ConcreteType, SameOperandsAndResultType> {
return success();
}
+LogicalResult
+OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
+ if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+ return failure();
+
+ auto type = op->getResult(0)->getType().dyn_cast<VectorOrTensorType>();
+ if (!type)
+ return op->emitOpError("requires vector or tensor type results");
+ auto elementType = type.getElementType();
+
+ // Verify result element type matches first result's element type.
+ for (auto result : drop_begin(op->getResults(), 1)) {
+ auto resultType = result->getType().dyn_cast<VectorOrTensorType>();
+ if (!resultType)
+ return op->emitOpError("requires vector or tensor type results");
+ if (resultType.getElementType() != elementType)
+ return op->emitOpError(
+ "requires the same element type for all operands and results");
+ }
+
+ // Verify operand's element type matches first result's element type.
+ for (auto operand : op->getOperands()) {
+ auto operandType = operand->getType().dyn_cast<VectorOrTensorType>();
+ if (!operandType)
+ return op->emitOpError("requires vector or tensor type operands");
+ if (operandType.getElementType() != elementType)
+ return op->emitOpError(
+ "requires the same element type for all operands and results");
+ }
+
+ return success();
+}
+
LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return failure();
--- /dev/null
+//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+using namespace mlir::OpTrait::impl;
+
+namespace {
+
+// TODO: Replace with regular test once this trait is used by operation in core.
+TEST(OpDefinitionTest, SameOperandAndResultElementType) {
+ MLIRContext context;
+#define FILE_LOC \
+ FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \
+ &context)
+
+ Builder b(&context);
+ auto *operandtF32x10x10 = Operation::create(
+ FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+ /*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
+ /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+ /*resizableOperandList=*/false, &context);
+ auto *operandtF32x1 = Operation::create(
+ FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+ /*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
+ /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+ /*resizableOperandList=*/false, &context);
+ auto *operandvF32x1 = Operation::create(
+ FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+ /*resultTypes=*/{b.getVectorType({1}, b.getF32Type())},
+ /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+ /*resizableOperandList=*/false, &context);
+ auto *operandtI32x1 = Operation::create(
+ FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
+ /*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))},
+ /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+ /*resizableOperandList=*/false, &context);
+
+ // Verifies whether an op with x and y as inputs and resultType satisfies the
+ // SameOperandAndResultElementType trait.
+ auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
+ auto op = Operation::create(loc, OperationName("some_op", &context),
+ /*operands=*/{x->getResult(0), y->getResult(0)},
+ /*resultTypes=*/{resultType},
+ /*attributes=*/llvm::None, /*successors=*/{},
+ /*numRegions=*/0,
+ /*resizableOperandList=*/false, &context);
+ return succeeded(verifySameOperandsAndResultElementType(op));
+ };
+
+ EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
+ b.getTensorType({12}, b.getF32Type())));
+ EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
+ b.getTensorType({5}, b.getF32Type())));
+ EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
+ b.getTensorType({7}, b.getF32Type())));
+ EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
+ b.getTensorType({12}, b.getIntegerType(32))));
+ EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
+ b.getTensorType({9}, b.getIntegerType(32))));
+ EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
+ b.getVectorType({9}, b.getF32Type())));
+ EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
+ b.getVectorType({9}, b.getF32Type())));
+ EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
+ b.getTensorType({5}, b.getF32Type())));
+ EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
+ b.getTensorType({5}, b.getF32Type())));
+
+#undef FILE_LOC
+}
+
+} // end namespace