From 861c55e1504b44602105db7c33972340ae341adc Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 29 May 2019 09:22:30 -0700 Subject: [PATCH] Add a rank op to MLIR. Example: %1 = rank %0 : index -- PiperOrigin-RevId: 250505411 --- mlir/include/mlir/StandardOps/Ops.td | 20 ++++++++++++++++++ mlir/lib/StandardOps/Ops.cpp | 37 +++++++++++++++++++++++++++++++++ mlir/test/IR/core-ops.mlir | 6 ++++++ mlir/test/IR/invalid-ops.mlir | 8 +++++++ mlir/test/Transforms/constant-fold.mlir | 14 +++++++++++++ 5 files changed, 85 insertions(+) diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index a9f681a..e825e89 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -573,6 +573,26 @@ def OrOp : IntArithmeticOp<"or", [Commutative]> { let hasFolder = 1; } +def RankOp : Std_Op<"rank", [NoSideEffect]> { + let summary = "rank operation"; + let description = [{ + The "rank" operation takes a tensor operand and returns its rank. + + %1 = rank %0 : index + }]; + + let arguments = (ins AnyTensor); + let results = (outs Index); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *tensor", [{ + auto indexType = builder->getIndexType(); + build(builder, result, indexType, tensor); + }]>]; + + let hasFolder = 1; +} + def RemFOp : FloatArithmeticOp<"remf"> { let summary = "floating point division remainder operation"; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 316d36d..7c73d5b 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1788,6 +1788,43 @@ OpFoldResult MulIOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, RankOp op) { + *p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType(); +} + +static ParseResult parseRankOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType operandInfo; + Type type; + Type indexType = parser->getBuilder().getIndexType(); + + return failure(parser->parseOperand(operandInfo) || + parser->parseColonType(type) || + parser->resolveOperand(operandInfo, type, result->operands) || + parser->addTypeToList(indexType, result->types)); +} + +static LogicalResult verify(RankOp op) { + auto type = op.getOperand()->getType(); + if (!type.isa()) { + return op.emitOpError("requires an operand that is a tensor"); + } + return success(); +} + +OpFoldResult RankOp::fold(ArrayRef operands) { + // Constant fold rank when the rank of the tensor is known. + auto type = getOperand()->getType(); + if (auto tensorType = type.dyn_cast()) { + int64_t rank = tensorType.getRank(); + return IntegerAttr::get(IndexType::get(getContext()), rank); + } + return IntegerAttr(); +} + +//===----------------------------------------------------------------------===// // RemISOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index ac97ea5..1a45d9d 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -276,6 +276,12 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) { // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_8, %cst_8 : vector<4xf32> %70 = cmpf "oeq", %vcf32, %vcf32 : vector<4 x f32> + // CHECK: %{{[0-9]+}} = rank %arg0 : tensor<4x4x?xf32> + %71 = "std.rank"(%t) : (tensor<4x4x?xf32>) -> index + + // CHECK: %{{[0-9]+}} = rank %arg0 : tensor<4x4x?xf32> + %72 = rank %t : tensor<4x4x?xf32> + return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 2a477300..562c2ce 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -24,6 +24,14 @@ func @dim3(tensor<1xf32>) { // ----- +func @rank(f32) { +^bb(%0: f32): + "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be tensor of any type values}} + return +} + +// ----- + func @constant() { ^bb: %x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires a result type that aligns with the 'value' attribute}} diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 927d2d4..647c718 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -420,3 +420,17 @@ func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) { // CHECK-NEXT: return return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32 } + + +// CHECK-LABEL: func @fold_rank +func @fold_rank() -> (index) { + %const_0 = constant dense, [[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> + + // Fold a rank into a constant + // CHECK-NEXT: {{.*}} = constant 3 : index + %rank_0 = rank %const_0 : tensor<2x1x4xi32> + + // CHECK-NEXT: return + return %rank_0 : index +} + -- 2.7.4