From: Hanhan Wang
Date: Fri, 22 Jan 2021 06:08:51 +0000 (-0800)
Subject: [mlir][Linalg] Introduce linalg.pad_tensor op.
X-Git-Tag: llvmorg-13-init~517
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=16d4bbef30a9e625e04653047759d5636f9e58a5;p=platform%2Fupstream%2Fllvm.git

[mlir][Linalg] Introduce linalg.pad_tensor op.

`linalg.pad_tensor` is an operation that pads the `source` tensor
with the given `low` and `high` padding configuration.

Example 1:

```mlir
  %pad_value = ... : f32
  %1 = linalg.pad_tensor %0 low[1, 2] high[2, 3] {
  ^bb0(%arg0 : index, %arg1 : index):
    linalg.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
```

Example 2:

```mlir
  %pad_value = ... : f32
  %1 = linalg.pad_tensor %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
    linalg.yield %pad_value : f32
  } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
```
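Each result dimension is computed as `low` + `dim` + `high`, and any dynamic
entry makes the corresponding result dimension dynamic. As a fully static
illustration of that rule (it mirrors the `pad_static` round-trip test added
below):

```mlir
  %pad_value = ... : f32
  // dim 0: 1 + 3 + 2 = 6; dim 1: 2 + 4 + 3 = 9.
  %1 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
  ^bb0(%arg1 : index, %arg2 : index):
    linalg.yield %pad_value : f32
  } : tensor<3x4xf32> to tensor<6x9xf32>
```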
Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D93704
---

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 0ce86e4..ae9f81d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -117,6 +117,101 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
   let hasCanonicalizer = 1;
 }
 
+def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
+    [AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "tensor pad operation";
+  let description = [{
+    `linalg.pad_tensor` is an operation that pads the `source` tensor
+    with the given `low` and `high` padding configuration.
+
+    The PadTensor operation supports the following arguments:
+
+    * source: the "base" tensor on which to pad.
+    * low: a list containing the padding along the start of each
+           dimension, i.e. `low`.
+    * high: a list containing the padding along the end of each
+            dimension, i.e. `high`.
+
+    The result tensor dimensions are `low` + `dim` + `high` along each
+    dimension. The number of elements of `low` and `high` must match
+    the rank of the input tensor (which is also the rank of the output
+    tensor). Each element can be either a constant or a dynamic value.
+
+    The region of the `pad_tensor` operation returns the value to use
+    for the padding. The arguments of the region represent the index
+    of the source being accessed. There should be as many arguments as
+    the rank of the `source` tensor. The value `yield`-ed by the
+    region is used as the value of the result at the given position.
+
+    Example 1:
+
+    ```mlir
+      %pad_value = ... : f32
+      %1 = linalg.pad_tensor %0 low[1, 2] high[2, 3] {
+      ^bb0(%arg0 : index, %arg1 : index):
+        linalg.yield %pad_value : f32
+      } : tensor<?x?xf32> to tensor<?x?xf32>
+    ```
+
+    Example 2:
+
+    ```mlir
+      %pad_value = ... : f32
+      %0 = linalg.pad_tensor %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
+      ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
+        linalg.yield %pad_value : f32
+      } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+    ```
+
+    Example 3:
+
+    ```mlir
+      %pad_value = ... : f32
+      %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
+      ^bb0(%arg1: index, %arg2: index):
+        linalg.yield %pad_value : f32
+      } : tensor<2x3xf32> to tensor<?x?xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyTensor:$source,
+    Variadic<Index>:$low,
+    Variadic<Index>:$high,
+    I64ArrayAttr:$static_low,
+    I64ArrayAttr:$static_high);
+
+  let regions = (region AnyRegion:$region);
+
+  let results = (outs AnyTensor:$result);
+
+  let extraClassDeclaration = [{
+    static StringRef getStaticLowAttrName() {
+      return "static_low";
+    }
+
+    static StringRef getStaticHighAttrName() {
+      return "static_high";
+    }
+
+    // Infer the shape of the result tensor given the static shapes
+    // and element type of the result tensor.
+    static RankedTensorType inferResultType(RankedTensorType sourceType,
+                                            ArrayRef<int64_t> staticLow,
+                                            ArrayRef<int64_t> staticHigh);
+  }];
+
+  let builders = [
+    // Build a PadTensorOp with mixed static and dynamic entries.
+    OpBuilderDAG<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
+      "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a PadTensorOp with all dynamic entries.
+    OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+  ];
+}
+
 def Linalg_RangeOp : Linalg_Op<"range", [NoSideEffect]>,
     Arguments<(ins Index:$min, Index:$max, Index:$step)>,
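For a concrete feel of the region contract encoded above (one `index` block
argument per source dimension, and a `linalg.yield` of the result element
type), here is a hypothetical instance with an `i32` element type; the
function and value names are illustrative only, not from this patch:

```mlir
func @pad_i32(%arg0: tensor<?x4xi32>, %pad: i32, %lo: index) -> tensor<?x9xi32> {
  %0 = linalg.pad_tensor %arg0 low[%lo, 2] high[1, 3] {
  ^bb0(%i: index, %j: index):  // one index argument per source dimension
    linalg.yield %pad : i32    // must match the result element type
  } : tensor<?x4xi32> to tensor<?x9xi32>  // dim 1: 4 + 2 + 3 = 9
  return %0 : tensor<?x9xi32>
}
```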
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fa98ed0c..b500eef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -916,6 +916,151 @@ void InitTensorOp::getCanonicalizationPatterns(
 }
 
 //===----------------------------------------------------------------------===//
+// PadTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
+static LogicalResult verify(PadTensorOp op) {
+  auto sourceType = op.source().getType().cast<RankedTensorType>();
+  auto resultType = op.result().getType().cast<RankedTensorType>();
+  auto expectedType = PadTensorOp::inferResultType(
+      sourceType, extractFromI64ArrayAttr(op.static_low()),
+      extractFromI64ArrayAttr(op.static_high()));
+  if (resultType != expectedType) {
+    return op.emitError("specified type ")
+           << resultType << " does not match the inferred type "
+           << expectedType;
+  }
+
+  auto &region = op.region();
+  if (!llvm::hasSingleElement(region))
+    return op.emitOpError("expected region with 1 block");
+  unsigned rank = resultType.getRank();
+  Block &block = region.front();
+  if (block.getNumArguments() != rank)
+    return op.emitError("expected the block to have ") << rank << " arguments";
+
+  // Note: the number and type of yield values are checked in the YieldOp.
+  for (auto en : llvm::enumerate(block.getArgumentTypes())) {
+    if (!en.value().isIndex())
+      return op.emitOpError("expected block argument ")
+             << (en.index() + 1) << " to be an index";
+  }
+
+  return success();
+}
+
+RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
+                                              ArrayRef<int64_t> staticLow,
+                                              ArrayRef<int64_t> staticHigh) {
+  unsigned rank = sourceType.getRank();
+  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
+  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
+
+  SmallVector<int64_t, 4> resultShape;
+  for (auto i : llvm::seq<unsigned>(0, rank)) {
+    if (sourceType.isDynamicDim(i) ||
+        staticLow[i] == ShapedType::kDynamicSize ||
+        staticHigh[i] == ShapedType::kDynamicSize) {
+      resultShape.push_back(ShapedType::kDynamicSize);
+    } else {
+      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
+      resultShape.push_back(size);
+    }
+  }
+
+  return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+static ParseResult parsePadTensorOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::OperandType baseInfo;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  SmallVector<Type, 4> types;
+  if (parser.parseOperand(baseInfo))
+    return failure();
+
+  IndexType indexType = parser.getBuilder().getIndexType();
+  SmallVector<OpAsmParser::OperandType, 4> lowPadding, highPadding;
+  if (parser.parseKeyword("low") ||
+      parseListOfOperandsOrIntegers(parser, result,
+                                    PadTensorOp::getStaticLowAttrName(),
+                                    ShapedType::kDynamicSize, lowPadding))
+    return failure();
+  if (parser.parseKeyword("high") ||
+      parseListOfOperandsOrIntegers(parser, result,
+                                    PadTensorOp::getStaticHighAttrName(),
+                                    ShapedType::kDynamicSize, highPadding))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 4> regionOperands;
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  SmallVector<Type, 4> operandTypes, regionTypes;
+  if (parser.parseRegion(*region, regionOperands, regionTypes))
+    return failure();
+  result.addRegion(std::move(region));
+
+  Type srcType, dstType;
+  if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
+    return failure();
+
+  if (parser.addTypeToList(dstType, result.types))
+    return failure();
+
+  SmallVector<int32_t, 4> segmentSizesFinal = {1}; // source tensor
+  segmentSizesFinal.append({static_cast<int32_t>(lowPadding.size()),
+                            static_cast<int32_t>(highPadding.size())});
+  result.addAttribute(
+      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
+      parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
+  return failure(
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.resolveOperand(baseInfo, srcType, result.operands) ||
+      parser.resolveOperands(lowPadding, indexType, result.operands) ||
+      parser.resolveOperands(highPadding, indexType, result.operands));
+}
+
+static void print(OpAsmPrinter &p, PadTensorOp op) {
+  p << op->getName().getStringRef() << ' ';
+  p << op.source();
+  p << " low";
+  printListOfOperandsOrIntegers(p, op.low(), op.static_low(),
+                                ShapedType::isDynamic);
+  p << " high";
+  printListOfOperandsOrIntegers(p, op.high(), op.static_high(),
+                                ShapedType::isDynamic);
+  p.printRegion(op.region());
+  p << " : " << op.source().getType() << " to " << op.getType();
+}
+
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
+                        ArrayRef<int64_t> staticLow,
+                        ArrayRef<int64_t> staticHigh, ValueRange low,
+                        ValueRange high, ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = source.getType().cast<RankedTensorType>();
+  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
+  build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
+        b.getI64ArrayAttr(staticHigh));
+  result.addAttributes(attrs);
+}
+
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
+                        ValueRange low, ValueRange high,
+                        ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceType.getRank();
+  // Constructor arguments are (count, value): rank entries, all dynamic.
+  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
+  build(b, result, source, staticVector, staticVector, low, high, attrs);
+}
+
+//===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
 
@@ -1557,6 +1702,13 @@ static LogicalResult verify(linalg::YieldOp op) {
   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
     return verifyYield(op, cast<LinalgOp>(parentOp));
 
+  if (auto padTensorOp = dyn_cast<PadTensorOp>(parentOp)) {
+    return success(
+        op.getNumOperands() == 1 &&
+        op.getOperand(0).getType() ==
+            padTensorOp.getType().cast<RankedTensorType>().getElementType());
+  }
+
   return op.emitOpError("expected parent op with LinalgOp interface");
 }
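To make the rule in `inferResultType` concrete, here is Example 2 from the
op description re-derived by hand (a walk-through only, not additional code
in the patch):

```mlir
// source: tensor<1x2x2x?xf32>, low = [2, %arg1, 3, 3], high = [3, 3, %arg1, 2]
//   dim 0: 2 + 1 + 3 = 6               -> 6 (all entries static)
//   dim 1: low entry %arg1 is dynamic  -> ?
//   dim 2: high entry %arg1 is dynamic -> ?
//   dim 3: source dimension is dynamic -> ?
// => inferred result type: tensor<6x?x?x?xf32>
```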
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 4359eeb..a3ef242 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,3 +617,45 @@ func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref) -> mem
       memref into memref
   return %0 : memref
 }
+
+// -----
+
+func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> tensor<?x?x?x8xf32> {
+  // expected-error @+1 {{specified type 'tensor<?x?x?x8xf32>' does not match the inferred type 'tensor<?x?x?x9xi32>'}}
+  %0 = linalg.pad_tensor %arg0 low[1, %arg1, 2, 2] high[1, 2, %arg1, 3] {
+  ^bb0(%arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %arg2 : i32
+  } : tensor<?x2x3x4xi32> to tensor<?x?x?x8xf32>
+  return %0 : tensor<?x?x?x8xf32>
+}
+
+// -----
+
+func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+1 {{expected the block to have 2 arguments}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %arg1 : i32
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+func @pad_no_block(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+1 {{expected region with 1 block}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+1 {{op expected block argument 1 to be an index}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+    linalg.yield %arg1 : i32
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
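As a counterpart to the failure cases above, a degenerate but well-formed
sketch (illustrative only, not part of the patch): zero padding along every
dimension, for which the inferred type equals the source type:

```mlir
func @pad_nop(%arg0: tensor<3x4xf32>, %pad: f32) -> tensor<3x4xf32> {
  // dim 0: 0 + 3 + 0 = 3; dim 1: 0 + 4 + 0 = 4.
  %0 = linalg.pad_tensor %arg0 low[0, 0] high[0, 0] {
  ^bb0(%i: index, %j: index):
    linalg.yield %pad : f32
  } : tensor<3x4xf32> to tensor<3x4xf32>
  return %0 : tensor<3x4xf32>
}
```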
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d0121b0..c4a3247 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -5,6 +5,58 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
+func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+                  %pad_value: f32) -> tensor<6x?x?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    linalg.yield %pad_value : f32
+  } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+  return %0 : tensor<6x?x?x?xf32>
+}
+// CHECK-LABEL: func @pad_dynamic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[LOW:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[HIGH:[a-zA-Z0-9_]*]]
+// CHECK:         linalg.pad_tensor %[[ARG0]]
+// CHECK-SAME:      low[2, %[[LOW]], 3, 3]
+// CHECK-SAME:      high[3, 3, %[[HIGH]], 2]
+// CHECK:           : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+
+// -----
+
+func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg1 : index, %arg2 : index):
+    linalg.yield %pad_value : f32
+  } : tensor<3x4xf32> to tensor<6x9xf32>
+  return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK:         linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3]
+// CHECK:           : tensor<3x4xf32> to tensor<6x9xf32>
+
+// -----
+
+func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
+                       %pad_value: f32) -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
+  ^bb0(%arg1: index, %arg2: index):
+    linalg.yield %pad_value : f32
+  } : tensor<2x3xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @pad_asymmetrical
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[UB0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[UB1:[a-zA-Z0-9_]*]]
+// CHECK:         linalg.pad_tensor %[[ARG0]]
+// CHECK-SAME:      low[0, 0]
+// CHECK-SAME:      high[%[[UB0]], %[[UB1]]]
+// CHECK:           : tensor<2x3xf32> to tensor<?x?xf32>
+
+// -----
+
 func @range(%arg0: index, %arg1: index, %arg2: index) {
   %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   return