[mlir][Tensor] Allow builders of `tensor.empty` to accept encoding attribute.
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 3 Nov 2022 20:30:12 +0000 (20:30 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Thu, 3 Nov 2022 20:30:12 +0000 (20:30 +0000)
The `RankedTensorType` can have an optional encoding
attribute. Allowing the builders of `tensor.empty` to accept the
encoding attribute (optionally), allows building empty tensors with
the type having the encoding attribute.

Reviewed By: nicolasvasilache, hanchung, springerm

Differential Revision: https://reviews.llvm.org/D137297

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/ops.mlir

index 2cfdc6d..552d2db 100644 (file)
@@ -191,14 +191,17 @@ def Tensor_EmptyOp : Tensor_Op<"empty",
 
   let builders = [
     // Build with fully static sizes.
-    OpBuilder<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType)>,
+    OpBuilder<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType,
+                   CArg<"Attribute", "{}">:$encoding)>,
 
     // Build with mixed static/dynamic sizes.
     OpBuilder<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType,
-                   "ValueRange":$dynamicSizes)>,
+                   "ValueRange":$dynamicSizes,
+                   CArg<"Attribute", "{}">:$encoding)>,
 
     // Build with mixed static/dynamic sizes.
-    OpBuilder<(ins "ArrayRef<OpFoldResult>":$sizes, "Type":$elementType)>
+    OpBuilder<(ins "ArrayRef<OpFoldResult>":$sizes, "Type":$elementType,
+                   CArg<"Attribute", "{}">:$encoding)>
   ];
 
   let hasCanonicalizer = 1;
index 445e78e..31d892f 100644 (file)
@@ -497,27 +497,29 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 void EmptyOp::build(OpBuilder &builder, OperationState &result,
-                    ArrayRef<int64_t> staticShape, Type elementType) {
+                    ArrayRef<int64_t> staticShape, Type elementType,
+                    Attribute encoding) {
   assert(all_of(staticShape,
                 [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
          "expected only static sizes");
-  build(builder, result, staticShape, elementType, {});
+  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
 }
 
 void EmptyOp::build(OpBuilder &builder, OperationState &result,
                     ArrayRef<int64_t> staticShape, Type elementType,
-                    ValueRange dynamicSizes) {
-  auto tensorType = RankedTensorType::get(staticShape, elementType);
+                    ValueRange dynamicSizes, Attribute encoding) {
+  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
   build(builder, result, tensorType, dynamicSizes);
 }
 
 void EmptyOp::build(OpBuilder &builder, OperationState &result,
-                    ArrayRef<OpFoldResult> sizes, Type elementType) {
+                    ArrayRef<OpFoldResult> sizes, Type elementType,
+                    Attribute encoding) {
   SmallVector<int64_t> staticShape;
   SmallVector<Value> dynamicSizes;
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape,
                              ShapedType::kDynamicSize);
-  build(builder, result, staticShape, elementType, dynamicSizes);
+  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
 }
 
 LogicalResult EmptyOp::verify() {
index aadf6ab..4afe128 100644 (file)
@@ -21,6 +21,15 @@ func.func @empty(%sz: index) -> tensor<5x?x6xf32> {
   return %0 : tensor<5x?x6xf32>
 }
 
+// CHECK-LABEL: func @empty_with_encoding(
+//  CHECK-SAME:             %[[sz:.*]]: index
+func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
+  // CHECK: tensor.empty(%[[sz]]) : tensor<5x?x6xf32, "foo">
+  %0 = tensor.empty(%sz) : tensor<5x?x6xf32, "foo">
+  return %0 : tensor<5x?x6xf32, "foo">
+}
+
+
 // CHECK-LABEL:   func @extract(
 // CHECK-SAME:                  %[[TENSOR:.*]]: tensor<?x?x?xf32>,
 // CHECK-SAME:                  %[[INDEX:.*]]: index) {