dimension, i.e. `low`.
* high: A list containing the padding along the end of each
dimension, i.e. `high`.
+ * packing: whether the padding operation is guaranteed to create a new
+ tensor suitable for packing, i.e. a copy.
The result tensor dimensions are `low` + `dim` + `high` along that
dimension. The number of elements of `low` and `high` must match
the rank of the `source` tensor. The value `yield`-ed by the
region is used as the value of the view at the given position.
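+
+ For example (values chosen for illustration), padding a `tensor<2x3xf32>`
+ source with `low = [1, 2]` and `high = [2, 3]` produces a `tensor<5x8xf32>`
+ result, since 1 + 2 + 2 = 5 along the first dimension and 2 + 3 + 3 = 8
+ along the second:
+
+ ```mlir
+ %pad_value = ... : f32
+ %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad_value : f32
+ } : tensor<2x3xf32> to tensor<5x8xf32>
+ ```
+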
+ If `packing` is set, the padding is guaranteed to produce a new
+ tensor, e.g., for use in packing or for promotion to faster memory. Such an
+ operation is not optimized away even when the source and result types
+ have the same static shape.
+
Example 1:
```mlir
linalg.yield %pad_value : f32
} : tensor<2x3xf32> to tensor<?x?xf32>
```
+
+ Example 4:
+
+ ```mlir
+ // Force a new padded tensor to be created by specifying `packing`, even
+ // though no padding values are added.
+ %pad_value = ... : f32
+ %0 = linalg.pad_tensor %arg0 packing low[0, 0] high[0, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad_value : f32
+ } : tensor<2x3xf32> to tensor<2x3xf32>
+ ```
}];
let arguments = (ins
Variadic<Index>:$low,
Variadic<Index>:$high,
I64ArrayAttr:$static_low,
- I64ArrayAttr:$static_high);
+ I64ArrayAttr:$static_high,
+ UnitAttr:$packing);
let regions = (region SizedRegion<1>:$region);
// TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
let assemblyFormat = [{
$source
+ (`packing` $packing^)?
`low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
`high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
$region attr-dict `:` type($source) `to` type($result)
// "high" padding (i.e. it adds trailing padding values until the desired
// size is met).
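+ // For instance (shapes and value names below are illustrative), padding a
+ // `tensor<?x4xf32>` source to `tensor<8x4xf32>` with `packing` set creates
+ // something like
+ //   linalg.pad_tensor %source packing low[%c0, %c0] high[%h, %c0] {...}
+ //       : tensor<?x4xf32> to tensor<8x4xf32>
+ // where %h covers the difference along the dynamic dimension.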
static linalg::PadTensorOp createPadHighOp(
- Type type, Value source, Value pad, Location loc, OpBuilder & builder);
+ Type type, Value source, Value pad, bool packing, Location loc,
+ OpBuilder & builder);
// Return a PadTensorOp that pads `source` to `type` size with `pad` value.
// I.e., a block will be created and the `pad` value will be yielded
// directly. If the type passed is nullptr, it is inferred.
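+ // The body of the created op reduces to a single yield of `pad`, e.g.
+ // (illustrative):
+ //   %0 = linalg.pad_tensor %source low[...] high[...] {
+ //   ^bb0(%i: index, %j: index):
+ //     linalg.yield %pad : f32
+ //   } : ... to ...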
static linalg::PadTensorOp createPadScalarOp(
Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
- ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);
+ ArrayRef<OpFoldResult> high, bool packing, Location loc,
+ OpBuilder & builder);
// Return the pad value if it is a constant. Return null value otherwise.
Value getConstantPaddingValue();
// Build a PadTensorOp with mixed static and dynamic entries.
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
"ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
+ CArg<"bool", "false">:$packing,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a PadTensorOp with all dynamic entries.
OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
+ CArg<"bool", "false">:$packing,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a PadTensorOp with mixed static and dynamic entries and custom
// result type. If the type passed is nullptr, it is inferred.
OpBuilder<(ins "Type":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
+ CArg<"bool", "false">:$packing,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
return linalg::PadTensorOp::createPadScalarOp(
RankedTensorType::get(paddedShape, inputETy), input, padValue,
- lowIndices, highIndices, loc, rewriter)
+ lowIndices, highIndices, /*packing=*/false, loc, rewriter)
.result();
}
Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
- padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);
+ padOp.getType(), input, constant, lowValues, highValues,
+ /*packing=*/false, loc, rewriter);
rewriter.replaceOp(padOp, newPadOp.getResult());
return success();
void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<int64_t> staticLow,
ArrayRef<int64_t> staticHigh, ValueRange low,
- ValueRange high, ArrayRef<NamedAttribute> attrs) {
+ ValueRange high, bool packing,
+ ArrayRef<NamedAttribute> attrs) {
auto sourceType = source.getType().cast<RankedTensorType>();
auto resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
- b.getI64ArrayAttr(staticHigh));
+ b.getI64ArrayAttr(staticHigh), packing ? b.getUnitAttr() : UnitAttr());
result.addAttributes(attrs);
}
void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
- ValueRange low, ValueRange high,
+ ValueRange low, ValueRange high, bool packing,
ArrayRef<NamedAttribute> attrs) {
auto sourceType = source.getType().cast<RankedTensorType>();
unsigned rank = sourceType.getRank();
SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
- build(b, result, source, staticVector, staticVector, low, high, attrs);
+ build(b, result, source, staticVector, staticVector, low, high, packing,
+ attrs);
}
void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ArrayRef<OpFoldResult> low,
- ArrayRef<OpFoldResult> high,
+ ArrayRef<OpFoldResult> high, bool packing,
ArrayRef<NamedAttribute> attrs) {
assert(resultType.isa<RankedTensorType>());
auto sourceType = source.getType().cast<RankedTensorType>();
PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
}
build(b, result, resultType, source, dynamicLow, dynamicHigh,
- b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
+ b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
+ packing ? b.getUnitAttr() : UnitAttr());
+ result.addAttributes(attrs);
}
PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
ArrayRef<OpFoldResult> low,
ArrayRef<OpFoldResult> high,
- Location loc, OpBuilder &builder) {
- auto padTensorOp =
- builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
+ bool packing, Location loc,
+ OpBuilder &builder) {
+ auto padTensorOp = builder.create<linalg::PadTensorOp>(loc, type, source, low,
+ high, packing);
int rank = padTensorOp.getResultType().getRank();
SmallVector<Type, 4> blockArgTypes;
blockArgTypes.assign(rank, builder.getIndexType());
}
PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
- Location loc, OpBuilder &builder) {
+ bool packing, Location loc,
+ OpBuilder &builder) {
SmallVector<OpFoldResult, 4> low, high;
auto rankedTensorType = type.cast<RankedTensorType>();
assert(rankedTensorType.hasStaticShape());
high.push_back(highValue);
low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
}
- return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc,
- builder);
+ return PadTensorOp::createPadScalarOp(type, source, pad, low, high, packing,
+ loc, builder);
}
LogicalResult PadTensorOp::reifyResultShapes(
}
namespace {
-// Folds linalg.pad_tensor when padding is static zeros.
+// Folds linalg.pad_tensor when padding is static zeros and packing is not
+// requested.
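+// For example (illustrative), without `packing` an op such as
+//   %0 = linalg.pad_tensor %arg0 low[0, 0] high[0, 0] {...}
+//       : tensor<2x3xf32> to tensor<2x3xf32>
+// is replaced by a (here trivial) tensor.cast of %arg0, whereas the same op
+// with `packing` set is left untouched.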
struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
PatternRewriter &rewriter) const override {
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
return failure();
+ if (padTensorOp.packing())
+ return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(
padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
return success();
auto newOp = rewriter.create<PadTensorOp>(
padTensorOp->getLoc(), newResultType, padTensorOp.source(),
padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
- padTensorOp.static_high());
+ padTensorOp.static_high(), padTensorOp.packing());
BlockAndValueMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
auto replacementOp = rewriter.create<PadTensorOp>(
padTensorOp.getLoc(), tensorCastOp.dest().getType(),
padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
- padTensorOp.static_low(), padTensorOp.static_high());
+ padTensorOp.static_low(), padTensorOp.static_high(),
+ padTensorOp.packing());
replacementOp.region().takeBody(padTensorOp.region());
rewriter.replaceOp(padTensorOp, replacementOp.result());
}
OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
- if (getResultType().hasStaticShape() && getResultType() == getSourceType())
+ if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
+ !packing())
return source();
return {};
}
return *this;
}
-/// Try to compute a static bounding box for `operand`
-/// Return success if either:
-/// 1. The operand is already statically shaped, `result` is left unchanged.
-/// 2. The operand is (partially) dynamic, `result` is the result of a freshly
-/// created PadTensorOp.
-/// Return failure if the operand cannot be padded to a static shape.
+/// Try to compute a static bounding box for `operand`. The padding is
+/// performed even if the operand already has a static shape. `result` is set
+/// to the result of a freshly created PadTensorOp. Return failure if the
+/// operand cannot be padded to a static shape.
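+/// For example (shapes are illustrative), a `tensor<?x?xf32>` slice whose
+/// static bounding box is known to be 4x8 is padded as
+///   linalg.pad_tensor %slice packing low[%c0, %c0] high[%h0, %h1] {...}
+///       : tensor<?x?xf32> to tensor<4x8xf32>
+/// so that the consuming linalg op operates on statically shaped values.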
static LogicalResult padOperandToSmallestStaticBoundingBox(
PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
const PaddingValueComputationFunction &paddingFunc, Value &result) {
- // Already static shape, no need to pad.
- if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
+ // Can't pad scalars.
+ if (opToPad.getShape(opOperand).empty())
return success();
auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
// Not a slice op, cannot construct a static bounding box.
auto staticTensorType = RankedTensorType::get(
staticSizes, getElementTypeOrSelf(opOperand->get()));
result = linalg::PadTensorOp::createPadHighOp(
- staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
+ staticTensorType, opOperand->get(), pad, /*packing=*/true,
+ opToPad->getLoc(), rewriter);
return success();
}
LinalgOp &paddedOp) {
Location loc = opToPad->getLoc();
- // If the op is fully static, it does not need padding.
// TODO: there are cases where we may still want to pad to larger sizes.
assert(opToPad.hasTensorSemantics() &&
"expected operation to have tensor semantics");
- if (!opToPad.hasDynamicShape())
- return success();
OpBuilder::InsertionGuard g(rewriter);
// Set IP after op because we also take the dims of the original output.
// -----
+// CHECK-LABEL: func @pad_tensor_packing_same_static_shape(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
+// CHECK: %[[PAD:.*]] = linalg.pad_tensor
+// CHECK: return %[[PAD]]
+func @pad_tensor_packing_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+ -> tensor<5x6xf32> {
+ %cst = constant 0.000000e+00 : f32
+ %0 = linalg.pad_tensor %arg0 packing low[%a, 0] high[0, %a] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %cst : f32
+ } : tensor<5x6xf32> to tensor<5x6xf32>
+ return %0 : tensor<5x6xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
// -----
+// CHECK-LABEL: func @pad_packing_static_zero(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
+// CHECK: %[[PAD:.*]] = linalg.pad_tensor
+// CHECK: return %[[PAD]]
+func @pad_packing_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+ %c0 = constant 0 : index
+ %0 = linalg.pad_tensor %arg0 packing low[0, %c0, 0] high[0, 0, %c0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index):
+ linalg.yield %pad_value : f32
+ } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+ return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
func private @some_use(%i : index, %j : index)
// CHECK-LABEL: func @init_canonicalize
// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xi8>
// Padding injects static information.
-// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK: : tensor<?x?xi8> to tensor<2x4xi8>
-// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK: : tensor<?x?xi8> to tensor<4x3xi8>
-// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK: : tensor<?x?xi32> to tensor<2x3xi32>
// CHECK: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
// CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
// Padding injects static information.
-// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
// CHECK: : tensor<?x?x?xf32> to tensor<2x3x4xf32>
// CHECK: %[[pD:.*]] = linalg.generic
// CHECK-SAME: ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>)
// CHECK-1DIM-TILE: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
// CHECK-1DIM-TILE: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
// CHECK-1DIM-TILE: %[[sTC:.*]] = tensor.extract_slice %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK-1DIM-TILE: : tensor<?x8xi8> to tensor<2x8xi8>
-// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK-1DIM-TILE: : tensor<8x?xi8> to tensor<8x3xi8>
-// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK-1DIM-TILE: : tensor<?x?xi32> to tensor<2x3xi32>
// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
// CHECK-1DIM-TILE: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
+
+// Check that the tile-and-pad transformation actually introduces the padding
+// as requested, even if the original operation already operates on static
+// shapes.
+// CHECK-LABEL: @pad_to_same_static_size
+func @pad_to_same_static_size(%arg0: tensor<2x3x4xf32>, %arg1: f32) -> tensor<2x3x4xf32> {
+ // CHECK: %[[c0:.*]] = constant 0 : index
+ // CHECK-NOT: scf.for
+ // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+ // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> ()>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ {__internal_linalg_transform__ = "tile"}
+ ins(%arg1 : f32) outs(%arg0 : tensor<2x3x4xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<2x3x4xf32>
+ return %0 : tensor<2x3x4xf32>
+}
+
+// CHECK-LABEL: @pad_static_divisible_size
+func @pad_static_divisible_size(%arg0: tensor<4x6x8xf32>, %arg1: f32) -> tensor<4x6x8xf32> {
+ // CHECK: %[[c0:.*]] = constant 0 : index
+ // CHECK-COUNT-3: scf.for
+ // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+ // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> ()>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ {__internal_linalg_transform__ = "tile"}
+ ins(%arg1 : f32) outs(%arg0 : tensor<4x6x8xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<4x6x8xf32>
+ return %0 : tensor<4x6x8xf32>
+}