From 24f9293de8794963bd29c731745a71ef6a1aab9d Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Thu, 3 Nov 2022 20:30:12 +0000 Subject: [PATCH] [mlir][Tensor] Allow builders of `tensor.empty` to accept encoding attribute. The `RankedTensorType` can have an optional encoding attribute. Allowing the builders of `tensor.empty` to accept the encoding attribute (optionally), allows building empty tensors with the type having the encoding attribute. Reviewed By: nicolasvasilache, hanchung, springerm Differential Revision: https://reviews.llvm.org/D137297 --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 9 ++++++--- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 14 ++++++++------ mlir/test/Dialect/Tensor/ops.mlir | 9 +++++++++ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 2cfdc6d..552d2db 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -191,14 +191,17 @@ def Tensor_EmptyOp : Tensor_Op<"empty", let builders = [ // Build with fully static sizes. - OpBuilder<(ins "ArrayRef":$staticShape, "Type":$elementType)>, + OpBuilder<(ins "ArrayRef":$staticShape, "Type":$elementType, + CArg<"Attribute", "{}">:$encoding)>, // Build with mixed static/dynamic sizes. OpBuilder<(ins "ArrayRef":$staticShape, "Type":$elementType, - "ValueRange":$dynamicSizes)>, + "ValueRange":$dynamicSizes, + CArg<"Attribute", "{}">:$encoding)>, // Build with mixed static/dynamic sizes. - OpBuilder<(ins "ArrayRef":$sizes, "Type":$elementType)> + OpBuilder<(ins "ArrayRef":$sizes, "Type":$elementType, + CArg<"Attribute", "{}">:$encoding)> ]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 445e78e..31d892f 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -497,27 +497,29 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// void EmptyOp::build(OpBuilder &builder, OperationState &result, - ArrayRef staticShape, Type elementType) { + ArrayRef staticShape, Type elementType, + Attribute encoding) { assert(all_of(staticShape, [](int64_t sz) { return !ShapedType::isDynamic(sz); }) && "expected only static sizes"); - build(builder, result, staticShape, elementType, {}); + build(builder, result, staticShape, elementType, ValueRange{}, encoding); } void EmptyOp::build(OpBuilder &builder, OperationState &result, ArrayRef staticShape, Type elementType, - ValueRange dynamicSizes) { - auto tensorType = RankedTensorType::get(staticShape, elementType); + ValueRange dynamicSizes, Attribute encoding) { + auto tensorType = RankedTensorType::get(staticShape, elementType, encoding); build(builder, result, tensorType, dynamicSizes); } void EmptyOp::build(OpBuilder &builder, OperationState &result, - ArrayRef sizes, Type elementType) { + ArrayRef sizes, Type elementType, + Attribute encoding) { SmallVector staticShape; SmallVector dynamicSizes; dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape, ShapedType::kDynamicSize); - build(builder, result, staticShape, elementType, dynamicSizes); + build(builder, result, staticShape, elementType, dynamicSizes, encoding); } LogicalResult EmptyOp::verify() { diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index aadf6ab9..4afe128 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -21,6 +21,15 @@ func.func @empty(%sz: index) -> tensor<5x?x6xf32> { return %0 : tensor<5x?x6xf32> } +// CHECK-LABEL: func @empty_with_encoding( +// CHECK-SAME: %[[sz:.*]]: index +func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> { + // CHECK: tensor.empty(%[[sz]]) : tensor<5x?x6xf32, "foo"> + %0 = tensor.empty(%sz) : tensor<5x?x6xf32, "foo"> + return %0 : tensor<5x?x6xf32, "foo"> +} + + // CHECK-LABEL: func @extract( // CHECK-SAME: %[[TENSOR:.*]]: tensor, // CHECK-SAME: %[[INDEX:.*]]: index) { -- 2.7.4