```
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
+ let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
"any tensor or memref type">:$memrefOrTensor,
Index:$index);
let results = (outs Index:$result);
def RankOp : Std_Op<"rank", [NoSideEffect]> {
let summary = "rank operation";
let description = [{
- The `rank` operation takes a tensor operand and returns its rank.
+ The `rank` operation takes a memref/tensor operand and returns its rank.
Example:
```mlir
- %1 = rank %0 : tensor<*xf32>
+ %1 = rank %arg0 : tensor<*xf32>
+ %2 = rank %arg1 : memref<*xf32>
```
}];
- let arguments = (ins AnyTensor);
+ let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
+ "any tensor or memref type">:$memrefOrTensor);
let results = (outs Index);
let verifier = ?;
}]>];
let hasFolder = 1;
- let assemblyFormat = "operands attr-dict `:` type(operands)";
+ let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
}
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
- // Constant fold rank when the rank of the tensor is known.
+ // Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
- if (auto tensorType = type.dyn_cast<RankedTensorType>())
- return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
+ if (auto shapedType = type.dyn_cast<ShapedType>())
+ if (shapedType.hasRank())
+ return IntegerAttr::get(IndexType::get(getContext()),
+ shapedType.getRank());
return IntegerAttr();
}
// -----
+// CHECK-LABEL: func @fold_rank_memref
+func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
+ // Fold a rank into a constant
+ // CHECK-NEXT: [[C2:%.+]] = constant 2 : index
+ %rank_0 = rank %arg0 : memref<?x?xf32>
+
+ // CHECK-NEXT: return [[C2]]
+ return %rank_0 : index
+}
+
+// -----
+
// CHECK-LABEL: func @nested_isolated_region
func @nested_isolated_region() {
// CHECK-NEXT: func @isolated_op