Add a rank op to MLIR. Example:
authorRasmus Munk Larsen <rmlarsen@google.com>
Wed, 29 May 2019 16:22:30 +0000 (09:22 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:06:51 +0000 (20:06 -0700)
      %1 = rank %0 : index

--

PiperOrigin-RevId: 250505411

mlir/include/mlir/StandardOps/Ops.td
mlir/lib/StandardOps/Ops.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/constant-fold.mlir

index a9f681a..e825e89 100644 (file)
@@ -573,6 +573,26 @@ def OrOp : IntArithmeticOp<"or", [Commutative]> {
   let hasFolder = 1;
 }
 
+def RankOp : Std_Op<"rank", [NoSideEffect]> {
+  let summary = "rank operation";
+  let description = [{
+    The "rank" operation takes a tensor operand and returns its rank.
+
+      %1 = rank %0 : index
+  }];
+
+  let arguments = (ins AnyTensor);
+  let results = (outs Index);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *tensor", [{
+      auto indexType = builder->getIndexType();
+      build(builder, result, indexType, tensor);
+    }]>];
+
+  let hasFolder = 1;
+}
+
 def RemFOp : FloatArithmeticOp<"remf"> {
   let summary = "floating point division remainder operation";
 }
index 316d36d..7c73d5b 100644 (file)
@@ -1788,6 +1788,43 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, RankOp op) {
+  *p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
+}
+
+static ParseResult parseRankOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType operandInfo;
+  Type type;
+  Type indexType = parser->getBuilder().getIndexType();
+
+  return failure(parser->parseOperand(operandInfo) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(operandInfo, type, result->operands) ||
+                 parser->addTypeToList(indexType, result->types));
+}
+
+static LogicalResult verify(RankOp op) {
+  auto type = op.getOperand()->getType();
+  if (!type.isa<TensorType>()) {
+    return op.emitOpError("requires an operand that is a tensor");
+  }
+  return success();
+}
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  // Constant fold rank when the rank of the tensor is known.
+  auto type = getOperand()->getType();
+  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+    int64_t rank = tensorType.getRank();
+    return IntegerAttr::get(IndexType::get(getContext()), rank);
+  }
+  return IntegerAttr();
+}
+
+//===----------------------------------------------------------------------===//
 // RemISOp
 //===----------------------------------------------------------------------===//
 
index ac97ea5..1a45d9d 100644 (file)
@@ -276,6 +276,12 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) {
   // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_8, %cst_8 : vector<4xf32>
   %70 = cmpf "oeq", %vcf32, %vcf32 : vector<4 x f32>
 
+  // CHECK: %{{[0-9]+}} = rank %arg0 : tensor<4x4x?xf32>
+  %71 = "std.rank"(%t) : (tensor<4x4x?xf32>) -> index
+
+  // CHECK: %{{[0-9]+}} = rank %arg0 : tensor<4x4x?xf32>
+  %72 = rank %t : tensor<4x4x?xf32>
+
   return
 }
 
index 2a47730..562c2ce 100644 (file)
@@ -24,6 +24,14 @@ func @dim3(tensor<1xf32>) {
 
 // -----
 
+func @rank(f32) {
+^bb(%0: f32):
+  "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be tensor of any type values}}
+  return
+}
+
+// -----
+
 func @constant() {
 ^bb:
   %x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires a result type that aligns with the 'value' attribute}}
index 927d2d4..647c718 100644 (file)
@@ -420,3 +420,17 @@ func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
   // CHECK-NEXT: return
   return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
 }
+
+
+// CHECK-LABEL: func @fold_rank
+func @fold_rank() -> (index) {
+  %const_0 = constant dense<tensor<2x1x4xi32>, [[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
+
+  // Fold a rank into a constant
+  // CHECK-NEXT: {{.*}} = constant 3 : index
+  %rank_0 = rank %const_0 : tensor<2x1x4xi32>
+
+  // CHECK-NEXT: return
+  return %rank_0 : index
+}
+