From: Jacques Pienaar Date: Sat, 4 May 2019 18:14:40 +0000 (-0700) Subject: Add SameOperandsAndResultElementType trait. X-Git-Tag: llvmorg-11-init~1466^2~1825 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dcab80115ff8379c269748f1b5565ea88aff176a;p=platform%2Fupstream%2Fllvm.git Add SameOperandsAndResultElementType trait. This trait only works for tensor and vector types at the moment, verifying that the element type of an op with only tensor and vector types match. Added a unit test for it as there is no op currently in core that uses this trait. -- PiperOrigin-RevId: 246661697 --- diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index a2b3ca3..8d3cdb3 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -781,6 +781,9 @@ def NoSideEffect : NativeOpTrait<"HasNoSideEffect">; def SameValueShape : NativeOpTrait<"SameOperandsAndResultShape">; // Op has the same operand and result type. def SameValueType : NativeOpTrait<"SameOperandsAndResultType">; +// Op has the same operand and result element type. +def SameOperandsAndResultElementType : + NativeOpTrait<"SameOperandsAndResultElementType">; // Op is a terminator. def Terminator : NativeOpTrait<"IsTerminator">; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 2532dc8..e551140 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -317,6 +317,7 @@ LogicalResult verifyOneResult(Operation *op); LogicalResult verifyNResults(Operation *op, unsigned numOperands); LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands); LogicalResult verifySameOperandsAndResultShape(Operation *op); +LogicalResult verifySameOperandsAndResultElementType(Operation *op); LogicalResult verifySameOperandsAndResultType(Operation *op); LogicalResult verifyResultsAreBoolLike(Operation *op); LogicalResult verifyResultsAreFloatLike(Operation *op); @@ -573,11 +574,23 @@ public: }; /// This class provides verification for ops that are known to have the same +/// operand and result element type. +/// +/// TODO: This only works for VectorOrTensorType at the moment. +template +class SameOperandsAndResultElementType + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultElementType(op); + } +}; + +/// This class provides verification for ops that are known to have the same /// operand and result type. /// -/// Note: this trait subsumes the SameOperandsAndResultShape trait. -/// Additionally, it requires all operands and results should also have -/// the same element type. +/// Note: this trait subsumes the SameOperandsAndResultShape and +/// SameOperandsAndResultElementType traits. template class SameOperandsAndResultType : public TraitBase { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 9a13cee..067e86a 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -798,6 +798,39 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { return success(); } +LogicalResult +OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { + if (op->getNumOperands() == 0 || op->getNumResults() == 0) + return failure(); + + auto type = op->getResult(0)->getType().dyn_cast(); + if (!type) + return op->emitOpError("requires vector or tensor type results"); + auto elementType = type.getElementType(); + + // Verify result element type matches first result's element type. + for (auto result : drop_begin(op->getResults(), 1)) { + auto resultType = result->getType().dyn_cast(); + if (!resultType) + return op->emitOpError("requires vector or tensor type results"); + if (resultType.getElementType() != elementType) + return op->emitOpError( + "requires the same element type for all operands and results"); + } + + // Verify operand's element type matches first result's element type. + for (auto operand : op->getOperands()) { + auto operandType = operand->getType().dyn_cast(); + if (!operandType) + return op->emitOpError("requires vector or tensor type operands"); + if (operandType.getElementType() != elementType) + return op->emitOpError( + "requires the same element type for all operands and results"); + } + + return success(); +} + LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return failure(); diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 5d72c1f..5337eb4 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRIRTests DialectTest.cpp OperationSupportTest.cpp + OpDefinitionTest.cpp SDBMTest.cpp ) target_link_libraries(MLIRIRTests diff --git a/mlir/unittests/IR/OpDefinitionTest.cpp b/mlir/unittests/IR/OpDefinitionTest.cpp new file mode 100644 index 0000000..4074009 --- /dev/null +++ b/mlir/unittests/IR/OpDefinitionTest.cpp @@ -0,0 +1,91 @@ +//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "gmock/gmock.h" + +using namespace mlir; +using namespace mlir::OpTrait::impl; + +namespace { + +// TODO: Replace with regular test once this trait is used by operation in core. +TEST(OpDefinitionTest, SameOperandAndResultElementType) { + MLIRContext context; +#define FILE_LOC \ + FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \ + &context) + + Builder b(&context); + auto *operandtF32x10x10 = Operation::create( + FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, + /*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())}, + /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + auto *operandtF32x1 = Operation::create( + FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, + /*resultTypes=*/{b.getTensorType({1}, b.getF32Type())}, + /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + auto *operandvF32x1 = Operation::create( + FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, + /*resultTypes=*/{b.getVectorType({1}, b.getF32Type())}, + /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + auto *operandtI32x1 = Operation::create( + FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, + /*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))}, + /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + + // Verifies whether an op with x and y as inputs and resultType satisfies the + // SameOperandAndResultElementType trait. + auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) { + auto op = Operation::create(loc, OperationName("some_op", &context), + /*operands=*/{x->getResult(0), y->getResult(0)}, + /*resultTypes=*/{resultType}, + /*attributes=*/llvm::None, /*successors=*/{}, + /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + return succeeded(verifySameOperandsAndResultElementType(op)); + }; + + EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1, + b.getTensorType({12}, b.getF32Type()))); + EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1, + b.getTensorType({5}, b.getF32Type()))); + EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1, + b.getTensorType({7}, b.getF32Type()))); + EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1, + b.getTensorType({12}, b.getIntegerType(32)))); + EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1, + b.getTensorType({9}, b.getIntegerType(32)))); + EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1, + b.getVectorType({9}, b.getF32Type()))); + EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1, + b.getVectorType({9}, b.getF32Type()))); + EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1, + b.getTensorType({5}, b.getF32Type()))); + EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1, + b.getTensorType({5}, b.getF32Type()))); + +#undef FILE_LOC +} + +} // end namespace