From: Matthias Springer Date: Tue, 4 Oct 2022 08:06:00 +0000 (+0900) Subject: [mlir][tensor][NFC] Rename linalg.init_tensor to tensor.empty X-Git-Tag: upstream/17.0.6~31687 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=81ca5aa452400843235e058bc9c83fe71eccd593;p=platform%2Fupstream%2Fllvm.git [mlir][tensor][NFC] Rename linalg.init_tensor to tensor.empty tensor.empty/linalg.init_tensor produces an uninititalized tensor that can be used as a destination operand for destination-style ops (ops that implement `DestinationStyleOpInterface`). This change makes it possible to implement `TilingInterface` for non-destination-style ops without depending on the Linalg dialect. RFC: https://discourse.llvm.org/t/rfc-add-tensor-from-shape-operation/65101 Differential Revision: https://reviews.llvm.org/D135129 --- diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md index 99136f1..b2868efa 100644 --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -318,8 +318,8 @@ extends to a multi-dimensional pointwise computation. As a result, we may use `fill` with arbitrary ranked output tensors: ```python -tensor_2d = linalg.InitTensorOp([4, 8], f32) -tensor_3d = linalg.InitTensorOp([4, 8, 16], f32) +tensor_2d = tensor.EmptyOp([4, 8], f32) +tensor_3d = tensor.EmptyOp([4, 8, 16], f32) fill(value, outs=[tensor_2d]) fill(value, outs=[tensor_3d]) ``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 16fe89a..aab039e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -24,110 +24,6 @@ include "mlir/Interfaces/ViewLikeInterface.td" class Linalg_Op traits = []> : Op; -def Linalg_InitTensorOp : Linalg_Op<"init_tensor", - [NoSideEffect, - DeclareOpInterfaceMethods]> { - let summary = "operation to define a tensor of particular shape"; - - let description = [{ - `linalg.init_tensor` is an operation that defines a tensor of a particular - shape. The shape could be dynamic or static. The contents of the tensor are - unspecified and the only purpose of the op result is to materialize the - specified shape in IR and make it available to other transformations. - - Note: This op can be lowered to a `bufferization.alloc_tensor`, at which - point it turns into an explicit buffer allocation. - }]; - - let arguments = - (ins Variadic:$sizes, I64ArrayAttr:$static_sizes); - - let results = (outs AnyTensor:$result); - - let assemblyFormat = [{ - custom($sizes, $static_sizes, - "ShapedType::kDynamicSize") - attr-dict `:` type($result) - }]; - - let extraClassDeclaration = [{ - static StringRef getStaticSizesAttrStrName() { - return "static_sizes"; - } - - RankedTensorType getType() { - return getResult().getType().cast(); } - - // Infer the shape of the result tensor given the static shapes - // and element type of the result tensor. - static Type inferResultType(ArrayRef staticSizes, Type elementType, - Attribute encoding = {}); - - // Return true if the size of the tensor is dynamic at `idx` - bool isDynamicSize(unsigned idx) { - APInt v = *(getStaticSizes().getAsValueRange().begin() + idx); - return ShapedType::isDynamic(v.getSExtValue()); - } - - // Assert that the size of the result tensor is static at `idx` - // and return the shape. - int64_t getStaticSize(unsigned idx) { - assert(!isDynamicSize(idx) && "expected static size"); - APInt v = *(getStaticSizes(). - template getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - - // Return the argument position that contains the dynamic size of - // the tensor at dimension `idx`. Asserts that the shape is - // dynamic at that `idx`. - unsigned getIndexOfDynamicSize(unsigned idx) { - assert(isDynamicSize(idx) && "expected dynamic size"); - return std::count_if( - getStaticSizes().getValue().begin(), - getStaticSizes().getValue().begin() + idx, - [&](Attribute attr) { - return ShapedType::isDynamic(attr.cast().getInt()); - }); - } - - // Return both static and dynamic sizes as a list of `OpFoldResult`. - SmallVector getMixedSizes(); - - // Return the Value of the dynamic size of the tensor at dimension - // `idx`. Asserts that the shape is dynamic at that `idx. - Value getDynamicSize(unsigned idx) { - return getOperand(getIndexOfDynamicSize(idx)); - } - }]; - - let builders = [ - OpBuilder<(ins "ValueRange":$shape, - "ArrayRef":$staticShape, "Type":$elementType), - [{ - build($_builder, $_state, - InitTensorOp::inferResultType(staticShape, elementType), - shape, $_builder.getI64ArrayAttr(staticShape)); - }]>, - OpBuilder<(ins "ValueRange":$shape, "Type":$elementType), - [{ - SmallVector staticShape( - shape.size(), ShapedType::kDynamicSize); - build($_builder, $_state, shape, staticShape, elementType); - }]>, - OpBuilder<(ins "ArrayRef":$staticShape, "Type":$elementType), - [{ - build($_builder, $_state, ValueRange{}, staticShape, elementType); - }]>, - OpBuilder<(ins "ArrayRef":$sizes, "Type":$elementType, - CArg<"ArrayRef", "{}">:$attrs)> - ]; - - let hasCanonicalizer = 1; - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; -} - def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, Arguments<(ins Variadic:$values)> { let summary = "Linalg yield operation"; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 59a30b9..6b40102 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -61,8 +61,8 @@ createConvertLinalgToParallelLoopsPass(); std::unique_ptr> createConvertLinalgToAffineLoopsPass(); -/// Create a pass that rewrites init_tensor to alloc_tensor. -std::unique_ptr createLinalgInitTensorToAllocTensorPass(); +/// Create a pass that rewrites tensor.empty to bufferization.alloc_tensor. +std::unique_ptr createEmptyTensorToAllocTensorPass(); /// Create a pass to convert Linalg operations which work on tensors to use /// buffers instead. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index b6ed7a7..6af6a12 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -24,14 +24,14 @@ def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> { let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"]; } -def LinalgInitTensorToAllocTensor : Pass<"linalg-init-tensor-to-alloc-tensor"> { - let summary = "Replace all init_tensor ops by alloc_tensor ops."; +def EmptyTensorToAllocTensor : Pass<"empty-tensor-to-alloc-tensor"> { + let summary = "Replace all empty ops by alloc_tensor ops."; let description = [{ - init_tensor ops return a tensor of unspecified contents who's only purpose + tensor.empty ops return a tensor of unspecified contents who's only purpose is to carry the tensor shape. This pass converts such ops to bufferization.alloc_tensor ops, which bufferize to buffer allocations. }]; - let constructor = "mlir::createLinalgInitTensorToAllocTensorPass()"; + let constructor = "mlir::createEmptyTensorToAllocTensorPass()"; } def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> { diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 33b5c0e..2493e96 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -506,7 +506,7 @@ def SplitReductionOp : Op into tensor<4x8xf32> - %1 = linalg.init_tensor [4] : tensor<4xf32> + %1 = tensor.empty() : tensor<4xf32> %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -557,11 +557,11 @@ def SplitReductionOp : Op (d0, d1, d2)> #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map5 = affine_map<(d0, d1, d2) -> (d0, d1)> - %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32> + %0 = tensor.empty() : tensor<16x32x64xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) -> tensor<16x32x64xf32> - %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1> + %2 = tensor.empty() : tensor<64x4xi1> %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h index 795dd00..e257eba 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h @@ -49,7 +49,7 @@ class GenericOp; /// /// ``` /// scf.for (%i) { -/// %packed_init = linalg.init_tensor range(%j) : tensor +/// %packed_init = tensor.empty range(%j) : tensor /// %packed = scf.for (%k) iter_args(%p : %packed_init) { /// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor /// %0 = tensor.pad %st0 low[0, 0] high[...] { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 992cf72..a88eb84 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1196,7 +1196,7 @@ rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, using OptimizeCopyFn = std::function; -/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp and +/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and /// InsertSliceOp. For now, only constant padding values are supported. /// `OptimizeCopyFn` can be used to customize copying step optimization. struct GeneralizePadOpPattern : public OpRewritePattern { @@ -1407,7 +1407,7 @@ void populateSplitReductionPattern( /// ``` /// %cst = arith.constant 0.000000e+00 : f32 /// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32> -/// %1 = linalg.init_tensor [4] : tensor<4xf32> +/// %1 = tensor.empty [4] : tensor<4xf32> /// %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> /// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, /// affine_map<(d0, d1) -> (d0)>], @@ -1464,11 +1464,11 @@ splitReduction(PatternRewriter &b, LinalgOp op, /// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> /// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> /// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)> -/// %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32> +/// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32> /// %cst = arith.constant 0.000000e+00 : f32 /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) -> /// tensor<16x32x64xf32> -/// %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1> +/// %2 = tensor.empty [64, 4] : tensor<64x4xi1> /// /// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], /// iterator_types = ["parallel", "parallel", "parallel", "reduction"]} diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td index a000fa0..1c380bd 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td @@ -47,6 +47,7 @@ def Tensor_Dialect : Dialect { let hasConstantMaterializer = 1; let dependentDialects = [ + "AffineDialect", "arith::ArithDialect", "complex::ComplexDialect", ]; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 4842196..6e82ca4 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -137,6 +137,65 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect, ShapedDimOpInterface]> { } //===----------------------------------------------------------------------===// +// EmptyOp +//===----------------------------------------------------------------------===// + +def Tensor_EmptyOp : Tensor_Op<"empty", + [NoSideEffect, + DeclareOpInterfaceMethods]> { + let summary = "empty tensor operation"; + + let description = [{ + `tensor.empty` is an operation that defines a tensor of a particular shape. + The shape could be dynamic or static. The contents of the tensor are + unspecified and the only purpose of the op result is to materialize the + specified shape in IR and make it available to other transformations. + + `tensor.empty` is useful in transformations that expect destination style + ops. I.e., ops that implement `DestinationStyleOpInterface`. Ops that are + not in destination style can be made compatible with such transformations + with a `tensor.empty` destination. + + Note: This op can be lowered to a `bufferization.alloc_tensor`, at which + point it turns into an explicit buffer allocation. + }]; + + let arguments = (ins Variadic:$dynamicSizes); + + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = "`(`$dynamicSizes`)` attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + RankedTensorType getType() { + return getResult().getType().cast(); + } + + // Return both static and dynamic sizes as a list of `OpFoldResult`. + SmallVector getMixedSizes(); + + // Return the Value of the dynamic size of the tensor at dimension `idx`. + // Asserts that the shape is dynamic at that `idx`. + Value getDynamicSize(unsigned idx); + }]; + + let builders = [ + // Build with fully static sizes. + OpBuilder<(ins "ArrayRef":$staticShape, "Type":$elementType)>, + + // Build with mixed static/dynamic sizes. + OpBuilder<(ins "ArrayRef":$staticShape, "Type":$elementType, + "ValueRange":$dynamicSizes)>, + + // Build with mixed static/dynamic sizes. + OpBuilder<(ins "ArrayRef":$sizes, "Type":$elementType)> + ]; + + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h index 2ca5562..96b7f99 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h @@ -62,7 +62,7 @@ namespace tensor { /// We can construct %2 by generating the following, which only uses `%0`: /// /// ``` -/// %dest = linalg.init_tensor [%size0, %size1] : tensor +/// %dest = tensor.empty(%size0, %size1) : tensor /// %1 = tensor.dim %0, %c1 : tensor<3x?x?x11x?xf32> /// %2 = tensor.dim %0, %c2 : tensor<3x?x?x11x?xf32> /// %3 = tensor.dim %0, %c4 : tensor<3x?x?x11x?xf32> diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index f693b35..1c584d2 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -400,7 +400,7 @@ getLinearizedDimensions(ArrayRef reassociationIndices); /// ``` /// This class helps build the below IR to replace %2: /// ``` -/// %dest = linalg.init_tensor() : tensor<10x10xf32> +/// %dest = tensor.empty() : tensor<10x10xf32> /// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> { /// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv) /// %3:3 = arith.delinearize_index %iv into (3, 7, 11) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index b54dab8..a634e2b 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -556,7 +556,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); SmallVector opResultTypes; - SmallVector initTensors; + SmallVector emptyTensors; SmallVector dynDims; dynDims.resize(results.front().getType().cast().getRank()); @@ -573,13 +573,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, for (auto result : results) { auto resultTy = result.getType().template cast(); - initTensors.push_back(rewriter.create( - loc, filteredDims, resultTy.getShape(), resultTy.getElementType())); + emptyTensors.push_back(rewriter.create( + loc, resultTy.getShape(), resultTy.getElementType(), filteredDims)); opResultTypes.push_back(result.getType()); } auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range( - initTensors, [](Value v) { return getElementTypeOrSelf(v); })); + emptyTensors, [](Value v) { return getElementTypeOrSelf(v); })); SmallVector operands; SmallVector indexingMaps; @@ -623,7 +623,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, bool didEncounterError = false; auto linalgOp = rewriter.create( - loc, opResultTypes, operands, initTensors, indexingMaps, + loc, opResultTypes, operands, emptyTensors, indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value opResult = createLinalgBodyCalculationForElementwiseOp( @@ -771,10 +771,11 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType()); // First fill the output buffer with the init value. - auto initTensor = rewriter - .create(loc, dynDims, reduceShape, - resultTy.getElementType()) - .getResult(); + auto emptyTensor = + rewriter + .create(loc, reduceShape, resultTy.getElementType(), + dynDims) + .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) @@ -784,7 +785,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, - ValueRange{initTensor}) + ValueRange{emptyTensor}) .result(); SmallVector srcExprs; @@ -1104,8 +1105,8 @@ public: SmallVector filteredDims = condenseValues(dynDims); - auto initTensor = rewriter.create( - loc, filteredDims, resultTy.getShape(), resultTy.getElementType()); + auto emptyTensor = rewriter.create( + loc, resultTy.getShape(), resultTy.getElementType(), filteredDims); SmallVector affineMaps = { AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, @@ -1113,7 +1114,7 @@ public: rewriter.getMultiDimIdentityMap(resultTy.getRank())}; rewriter.replaceOpWithNewOp( - op, resultTy, op.getInput1(), ValueRange{initTensor}, affineMaps, + op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); @@ -1219,12 +1220,12 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. - Value initTensor = rewriter.create( - loc, ArrayRef({dynDims}), outputTy.getShape(), - outputTy.getElementType()); + Value emptyTensor = rewriter.create( + loc, outputTy.getShape(), outputTy.getElementType(), + ArrayRef({dynDims})); auto linalgOp = rewriter.create( - loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps, + loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { @@ -1341,14 +1342,14 @@ public: if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") return failure(); - auto initTensor = rewriter.create( - loc, dynamicDims, resultTy.getShape(), resultElementTy); + auto emptyTensor = rewriter.create( + loc, resultTy.getShape(), resultElementTy, dynamicDims); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto genericOp = rewriter.create( - loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps, + loc, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -1647,15 +1648,15 @@ struct ConcatConverter : public OpConversionPattern { } sizes[axis] = resultDimSize; - Value init = rewriter.create( - loc, dynDims, resultType.getShape(), resultType.getElementType()); + Value emptyTensor = rewriter.create( + loc, resultType.getShape(), resultType.getElementType(), dynDims); Value zeroVal = rewriter.createOrFold( loc, rewriter.getZeroAttr(resultType.getElementType())); - Value result = - rewriter - .create(loc, ValueRange{zeroVal}, ValueRange{init}) - .result(); + Value result = rewriter + .create(loc, ValueRange{zeroVal}, + ValueRange{emptyTensor}) + .result(); auto toOpFoldResult = [](Value v) -> OpFoldResult { auto op = v.getDefiningOp(); @@ -1700,16 +1701,16 @@ public: Value axisDimSize = rewriter.create(loc, input, axis); // First fill the output buffer with the init value. - auto initTensor = rewriter - .create( - loc, ArrayRef({dynDims}), - inputTy.getShape(), inputTy.getElementType()) - .getResult(); + auto emptyTensor = rewriter + .create(loc, inputTy.getShape(), + inputTy.getElementType(), + ArrayRef({dynDims})) + .getResult(); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; rewriter.replaceOpWithNewOp( - op, resultTy, ArrayRef({}), ValueRange{initTensor}, affineMaps, + op, resultTy, ArrayRef({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { llvm::SmallVector indices; @@ -1771,8 +1772,8 @@ struct TileConverter : public OpConversionPattern { } } - auto initTensor = rewriter.create( - op.getLoc(), dynDims, genericShape, elementTy); + auto emptyTensor = rewriter.create( + op.getLoc(), genericShape, elementTy, dynDims); // We needs to map the input shape to the non-broadcasted dimensions. SmallVector dimExprs; @@ -1789,7 +1790,7 @@ struct TileConverter : public OpConversionPattern { auto genericOp = rewriter.create( loc, RankedTensorType::get(genericShape, elementTy), input, - ValueRange{initTensor}, affineMaps, + ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(genericShape.size()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(op.getLoc(), *args.begin()); @@ -1919,24 +1920,23 @@ public: } // First fill the output buffer for the index. - auto initTensorIdx = - rewriter - .create(loc, dynDims, resultTy.getShape(), - outElementTy) - .getResult(); + auto emptyTensorIdx = rewriter + .create(loc, resultTy.getShape(), + outElementTy, dynDims) + .getResult(); auto fillValueIdx = rewriter.create( loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter .create(loc, ValueRange{fillValueIdx}, - ValueRange{initTensorIdx}) + ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. - auto initTensorMax = rewriter - .create( - loc, dynDims, resultTy.getShape(), inElementTy) - .getResult(); + auto emptyTensorMax = rewriter + .create(loc, resultTy.getShape(), + inElementTy, dynDims) + .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); @@ -1949,7 +1949,7 @@ public: auto filledTensorMax = rewriter .create(loc, ValueRange{fillValueMax}, - ValueRange{initTensorMax}) + ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along @@ -2031,10 +2031,10 @@ public: auto loc = op.getLoc(); - auto initTensor = + auto emptyTensor = rewriter - .create(loc, dynamicDims, resultTy.getShape(), - resultElementTy) + .create(loc, resultTy.getShape(), resultElementTy, + dynamicDims) .getResult(); SmallVector affineMaps = { @@ -2046,7 +2046,7 @@ public: auto genericOp = rewriter.create( loc, ArrayRef({resultTy}), ValueRange{indices}, - ValueRange{initTensor}, affineMaps, + ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; @@ -2091,18 +2091,17 @@ public: } } - auto initTensor = - rewriter - .create(loc, dynDims, resultTy.getShape(), - resultElementTy) - .getResult(); + auto emptyTensor = rewriter + .create(loc, resultTy.getShape(), + resultElementTy, dynDims) + .getResult(); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto genericOp = rewriter.create( - loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps, + loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 513ade4..f08ac19 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -241,12 +241,12 @@ public: weightPermValue); Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); - Value initTensor = rewriter.create( - loc, filteredDims, resultTy.getShape(), resultETy); + Value emptyTensor = rewriter.create( + loc, resultTy.getShape(), resultETy, filteredDims); Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, - ValueRange{initTensor}) + ValueRange{emptyTensor}) .result(); // Extract the attributes for convolution. @@ -268,8 +268,8 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); - Value biasInitTensor = rewriter.create( - loc, filteredDims, resultTy.getShape(), resultETy); + Value biasEmptyTensor = rewriter.create( + loc, resultTy.getShape(), resultETy, filteredDims); if (isQuantized) { auto quantizationInfo = @@ -289,7 +289,7 @@ public: Value result = rewriter .create( - loc, resultTy, ValueRange({bias, conv}), biasInitTensor, + loc, resultTy, ValueRange({bias, conv}), biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { @@ -311,7 +311,7 @@ public: Value result = rewriter .create( - loc, resultTy, ValueRange({bias, conv}), biasInitTensor, + loc, resultTy, ValueRange({bias, conv}), biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { @@ -426,16 +426,16 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); - Value initTensor = rewriter.create( - loc, filteredDims, linalgConvTy.getShape(), resultETy); + Value emptyTensor = rewriter.create( + loc, linalgConvTy.getShape(), resultETy, filteredDims); Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, - ValueRange{initTensor}) + ValueRange{emptyTensor}) .result(); - Value biasInitTensor = rewriter.create( - loc, filteredDims, resultTy.getShape(), resultETy); + Value biasEmptyTensor = rewriter.create( + loc, resultTy.getShape(), resultETy, filteredDims); if (!isQuantized) { Value conv = rewriter .create( @@ -452,7 +452,7 @@ public: rewriter .create( loc, resultTy, ValueRange({bias, convReshape}), - biasInitTensor, indexingMaps, + biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { @@ -479,7 +479,7 @@ public: rewriter .create( loc, resultTy, ValueRange({bias, convReshape}), - biasInitTensor, indexingMaps, + biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { @@ -527,11 +527,11 @@ public: auto zeroAttr = rewriter.getZeroAttr(outputElementTy); Value zero = rewriter.create(loc, zeroAttr); - auto initTensor = rewriter.create( - loc, filteredDims, outputTy.getShape(), outputTy.getElementType()); + auto emptyTensor = rewriter.create( + loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, - ValueRange{initTensor}) + ValueRange{emptyTensor}) .result(); if (!op.getQuantizationInfo()) { rewriter.replaceOpWithNewOp( @@ -597,15 +597,15 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank())); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank())); - auto initTensor = rewriter.create( - loc, filteredDims, outputTy.getShape(), outputTy.getElementType()); + auto emptyTensor = rewriter.create( + loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); // When quantized, the input elemeny type is not the same as the output Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy); Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, - ValueRange{initTensor}) + ValueRange{emptyTensor}) .result(); SmallVector permutation{1, 0}; @@ -621,10 +621,10 @@ public: Value transposedWeight = rewriter.create( loc, newWeightTy, weight, permutationValue); - auto biasInitTensor = + auto biasEmptyTensor = rewriter - .create(loc, filteredDims, - outputTy.getShape(), outputETy) + .create(loc, outputTy.getShape(), outputETy, + filteredDims) ->getResults(); if (!op.getQuantizationInfo()) { @@ -637,7 +637,7 @@ public: Value result = rewriter .create( - loc, outputTy, ValueRange({bias, matmul}), biasInitTensor, + loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { @@ -665,7 +665,7 @@ public: Value result = rewriter .create( - loc, outputTy, ValueRange({bias, matmul}), biasInitTensor, + loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { @@ -732,21 +732,21 @@ public: Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value initTensor = rewriter.create( - loc, dynamicDims, resultTy.getShape(), resultTy.getElementType()); + Value emptyTensor = rewriter.create( + loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims); - Value filledInitTensor = + Value filledEmptyTensor = rewriter .create(loc, ValueRange{initialValue}, - ValueRange{initTensor}) + ValueRange{emptyTensor}) .result(); Value fakeWindowDims = - rewriter.create(loc, kernel, resultETy); + rewriter.create(loc, kernel, resultETy); rewriter.replaceOpWithNewOp( op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, - filledInitTensor, strideAttr, dilationAttr); + filledEmptyTensor, strideAttr, dilationAttr); return success(); } }; @@ -794,24 +794,24 @@ public: Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value poolInitTensor = rewriter.create( - loc, dynamicDims, accTy.getShape(), accETy); + Value poolEmptyTensor = rewriter.create( + loc, accTy.getShape(), accETy, dynamicDims); - Value filledInitTensor = + Value filledEmptyTensor = rewriter .create(loc, ValueRange{initialValue}, - ValueRange{poolInitTensor}) + ValueRange{poolEmptyTensor}) .result(); Value fakeWindowDims = - rewriter.create(loc, kernel, accETy); + rewriter.create(loc, kernel, accETy); // Sum across the pooled region. Value poolingOp = rewriter .create( loc, ArrayRef{accTy}, ValueRange{paddedInput, fakeWindowDims}, - filledInitTensor, strideAttr, dilationAttr) + filledEmptyTensor, strideAttr, dilationAttr) .getResult(0); // Normalize the summed value by the number of elements grouped in each @@ -819,12 +819,12 @@ public: auto poolingOpTy = poolingOp.getType().cast(); auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); - Value genericInitTensor = rewriter.create( - loc, dynamicDims, resultTy.getShape(), resultETy); + Value genericEmptyTensor = rewriter.create( + loc, resultTy.getShape(), resultETy, dynamicDims); auto genericOp = rewriter.create( loc, ArrayRef({resultTy}), ValueRange{poolingOp}, - ValueRange{genericInitTensor}, + ValueRange{genericEmptyTensor}, ArrayRef({affineMap, affineMap}), getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 7014864..6289369 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -545,11 +545,11 @@ struct FoldFillWithPad final : public OpRewritePattern { auto oldResultType = padOp.getResultType(); SmallVector staticShape(oldResultType.getRank(), ShapedType::kDynamicSize); - auto newInitOp = rewriter.create( - padOp.getLoc(), reifiedShape.front(), staticShape, - oldResultType.getElementType()); + auto emptyTensor = rewriter.create( + padOp.getLoc(), staticShape, oldResultType.getElementType(), + reifiedShape.front()); auto newFillOp = rewriter.create( - fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp}); + fillOp.getLoc(), ValueRange{padValue}, ValueRange{emptyTensor}); rewriter.replaceOpWithNewOp(padOp, oldResultType, newFillOp.result()); @@ -1387,285 +1387,6 @@ LogicalResult ReduceOp::verify() { } //===----------------------------------------------------------------------===// -// InitTensorOp -//===----------------------------------------------------------------------===// - -void InitTensorOp::build(OpBuilder &b, OperationState &result, - ArrayRef sizes, Type elementType, - ArrayRef attrs) { - SmallVector dynamicSizes; - SmallVector staticSizes; - dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); - auto resultType = RankedTensorType ::get(staticSizes, elementType); - build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); - result.addAttributes(attrs); -} - -LogicalResult InitTensorOp::verify() { - RankedTensorType resultType = getType(); - SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( - getStaticSizes().cast(), - [](Attribute a) -> int64_t { return a.cast().getInt(); })); - - if (failed(verifyListOfOperandsOrIntegers( - *this, "sizes", resultType.getRank(), getStaticSizes(), getSizes(), - ShapedType::isDynamic))) - return failure(); - - if (getStaticSizes().size() != static_cast(resultType.getRank())) - return emitError("expected ") << resultType.getRank() << " sizes values"; - - Type expectedType = InitTensorOp::inferResultType( - staticSizes, resultType.getElementType(), resultType.getEncoding()); - if (resultType != expectedType) { - return emitError("specified type ") - << resultType << " does not match the inferred type " - << expectedType; - } - return success(); -} - -Type InitTensorOp::inferResultType(ArrayRef staticSizes, - Type elementType, Attribute encoding) { - return RankedTensorType::get(staticSizes, elementType, encoding); -} - -SmallVector InitTensorOp::getMixedSizes() { - SmallVector mixedSizes; - mixedSizes.reserve(getType().getRank()); - unsigned dynamicValIndex = 0; - for (Attribute attr : getStaticSizes()) { - auto intAttr = attr.cast(); - if (!ShapedType::isDynamic(intAttr.getInt())) { - mixedSizes.push_back(intAttr); - continue; - } - mixedSizes.push_back(getSizes()[dynamicValIndex++]); - } - return mixedSizes; -} - -namespace { -/// Change the type of the result of a `linalg.init_tensor` by making the result -/// type statically sized along dimension that in the original operation where -/// defined as dynamic, but the size was defined using a `constant` op. For -/// example -/// -/// %c5 = arith.constant 5: index -/// %0 = linalg.init_tensor [%arg0, %c5] : tensor -/// -/// to -/// -/// %0 = linalg.init_tensor [%arg0, 5] : tensor -struct ReplaceStaticShapeDims : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InitTensorOp op, - PatternRewriter &rewriter) const override { - SmallVector dynamicSizes; - SmallVector staticSizes; - for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) { - // If the size is already static, nothing to do. - if (!op.isDynamicSize(i)) { - staticSizes.push_back(op.getStaticSize(i)); - continue; - } - - // If the size is dynamic but defined using a `constant` op, get the - // constant value to find the static size to use. - unsigned operandNum = op.getIndexOfDynamicSize(i); - Value sizeOperand = op.getOperand(operandNum); - if (auto constantIndexOp = - sizeOperand.getDefiningOp()) { - staticSizes.push_back(constantIndexOp.value()); - continue; - } - - // Fallback case. Keep the size dynamic. - dynamicSizes.push_back(sizeOperand); - staticSizes.push_back(ShapedType::kDynamicSize); - } - RankedTensorType newType = - RankedTensorType::get(staticSizes, op.getType().getElementType()); - if (newType == op.getType()) - return failure(); - auto newOp = - rewriter.create(op.getLoc(), newType, dynamicSizes, - rewriter.getI64ArrayAttr(staticSizes)); - rewriter.replaceOpWithNewOp(op, op.getType(), newOp); - return success(); - } -}; -} // namespace - -namespace { -/// Since `init_tensor` operation creates a tensor needed only for its shape, a -/// slice of this is also needed only for its shape. The result can be -/// replaced by a new init_tensor operation of the same size as the extract -/// slice op. -struct FoldInitTensorWithExtractSliceOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, - PatternRewriter &rewriter) const override { - if (!sliceOp.getSource().getDefiningOp()) - return failure(); - // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved - // as well as its result type. - rewriter.replaceOpWithNewOp( - sliceOp, sliceOp.getSizes(), - sliceOp.getResult().getType().cast().getShape(), - sliceOp.getSourceType().getElementType()); - return success(); - } -}; - -template -struct FoldInitTensorWithTensorReshapeOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, - PatternRewriter &rewriter) const override { - if (!reshapeOp.getSrc().template getDefiningOp()) - return failure(); - Location loc = reshapeOp.getLoc(); - ReifiedRankedShapedTypeDims resultShapes; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - cast(reshapeOp.getOperation()); - if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, - resultShapes)) || - !llvm::hasSingleElement(resultShapes)) - return failure(); - Value initTensor = rewriter.create( - loc, getAsOpFoldResult(resultShapes[0]), - reshapeOp.getResultType().getElementType()); - if (initTensor.getType() != reshapeOp.getResultType()) { - rewriter.replaceOpWithNewOp( - reshapeOp, reshapeOp.getResultType(), initTensor); - } else { - rewriter.replaceOp(reshapeOp, initTensor); - } - return success(); - } -}; - -struct FoldInitTensorWithDimOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::DimOp dimOp, - PatternRewriter &rewriter) const override { - Optional maybeConstantIndex = dimOp.getConstantIndex(); - auto initTensorOp = dimOp.getSource().getDefiningOp(); - if (!initTensorOp || !maybeConstantIndex) - return failure(); - if (!initTensorOp.isDynamicSize(*maybeConstantIndex)) - return failure(); - rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex)); - return success(); - } -}; - -/// Canonicalize -/// -/// ```mlir -/// %0 = linalg.init_tensor [%d0, %d1] : tensor -/// %1 = tensor.cast %0 : tensor to tensor<4x?xf32> -/// ``` -/// -/// into -/// -/// ```mlir -/// %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32> -/// ``` -/// -/// This assumes the input program is correct in terms of its shape. So it -/// is safe to assume that `%d0` is in fact 4. If that was not the case, the -/// input program is wrong to begin with, so its undefined behavior anyway (i.e. -/// this optimization can still triggering without violating program semantics). -struct FoldInitTensorWithTensorCastOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::CastOp castOp, - PatternRewriter &rewriter) const override { - if (!canFoldIntoProducerOp(castOp)) - return failure(); - auto producer = castOp.getSource().getDefiningOp(); - if (!producer) - return failure(); - - auto resultType = castOp->getResult(0).getType().cast(); - ArrayRef resultShape = resultType.getShape(); - SmallVector currMixedSizes = producer.getMixedSizes(); - SmallVector newMixedSizes; - newMixedSizes.reserve(currMixedSizes.size()); - assert(resultShape.size() == currMixedSizes.size() && - "mismatch in result shape and sizes of init_tensor op"); - for (auto it : llvm::zip(resultShape, currMixedSizes)) { - int64_t newDim = std::get<0>(it); - OpFoldResult currDim = std::get<1>(it); - // Case 1: The init tensor dim is static. Check that the tensor cast - // result dim matches. - if (auto attr = currDim.dyn_cast()) { - if (ShapedType::isDynamic(newDim) || - newDim != attr.cast().getInt()) { - // Something is off, the cast result shape cannot be more dynamic than - // the init tensor result shape (enforced by `canFoldIntoProducer`). - // Abort for now. - return rewriter.notifyMatchFailure( - producer, "mismatch in static value of shape of init " - "tensor result and cast result"); - } - newMixedSizes.push_back(attr); - continue; - } - - // Case 2 : The tensor cast shape is static, but init tensor result shape - // is dynamic. - if (!ShapedType::isDynamic(newDim)) { - newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); - continue; - } - - // Case 3 : The tensor cast shape is dynamic and init tensor result shape - // is dynamic. Use the dynamic value from the init tensor op. - newMixedSizes.push_back(currDim); - } - - rewriter.replaceOpWithNewOp(castOp, newMixedSizes, - resultType.getElementType()); - return success(); - } -}; - -} // namespace - -void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add, - FoldInitTensorWithTensorReshapeOp, - ReplaceStaticShapeDims>(context); -} - -LogicalResult InitTensorOp::reifyResultShapes( - OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - auto shapes = llvm::to_vector<4>(llvm::map_range( - llvm::seq(0, getType().getRank()), [&](int64_t dim) -> Value { - if (isDynamicSize(dim)) - return getDynamicSize(dim); - return builder.create(getLoc(), - getStaticSize(dim)); - })); - reifiedReturnShapes.emplace_back(std::move(shapes)); - return success(); -} - -//===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp index 6bdb33b..cebc978 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -46,7 +46,7 @@ namespace { /// gets split into /// /// ```mlir -/// %init = linalg.init_tensor ... +/// %init = tensor.empty ... /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) /// outs(%init0, %init1, %init : ...) /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): @@ -65,7 +65,7 @@ namespace { /// After canonicalization this is expected to be /// /// ```mlir -/// %init = linalg.init_tensor ... +/// %init = tensor.empty ... /// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...) /// outs(%init : ...) /// ^bb0(%b0: ... , %b1: ... , %b2: ...): @@ -186,10 +186,10 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, // Fall back path, use an `init_tensor` and identity indexing map. AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size()); - Value initTensor = rewriter.create( - loc, domain, scalarOpResult.getType()); - newInitValues.push_back(initTensor); - newResultTypes.push_back(initTensor.getType()); + Value emptyTensor = + rewriter.create(loc, domain, scalarOpResult.getType()); + newInitValues.push_back(emptyTensor); + newResultTypes.push_back(emptyTensor.getType()); peeledGenericOpIndexingMaps.push_back(indexingMap); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 03b7526..baef90c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -188,7 +188,7 @@ struct LinalgDetensorize /// For the following snippet: /// ... /// ^bb1(%6: tensor, %9: tensor): - /// %7 = linalg.init_tensor [] : tensor + /// %7 = tensor.empty() : tensor /// %8 = linalg.generic #attrs /// ins(%6, %6 : tensor, tensor) /// outs(%7 : tensor) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 9a5614a..361c85a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -540,8 +540,8 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns( RankReducedInsertSliceOp>( context); linalg::FillOp::getCanonicalizationPatterns(patterns, context); - linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context); tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); + tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index b6d1d21..45bc4a8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1005,7 +1005,7 @@ static bool isDimSequencePreserved(AffineMap indexingMap, // ```mlir // #map = affine_map<(d0, d1) -> (d0, d1)> // %1 = tensor.expand_shape %0 [[0, 1]] : tensor into tensor -// %2 = linalg.init_tensor [..] : tensor +// %2 = tensor.empty [..] : tensor // %3 = linalg.generic { // indexing_maps = [#map, #map], // iterator_types = ["parallel" ,"parallel"]} @@ -1016,7 +1016,7 @@ static bool isDimSequencePreserved(AffineMap indexingMap, // // ```mlir // #map = affine_map<(d0) -> (d0)> -// %2 = linalg.init_tensor [..] : tensor +// %2 = tensor.empty [..] : tensor // %3 = linalg.generic { // indexing_maps = [#map, #map], // iterator_types = ["parallel"]} @@ -1030,7 +1030,7 @@ static bool isDimSequencePreserved(AffineMap indexingMap, // #map0 = affine_map<(d0, d1) -> (d0, d1)> // #map1 = affine_map<(d0, d1) -> (d1, d0)> // %1 = tensor.expand_shape %0 [[0, 1]] : tensor into tensor -// %2 = linalg.init_tensor [..] : tensor<4x?xf32> +// %2 = tensor.empty [..] : tensor<4x?xf32> // %2 = linalg.generic { // indexing_maps = [#map0, #map1], // iterator_types = ["parallel" ,"parallel"]} @@ -1643,8 +1643,8 @@ public: //===---------------------------------------------------------------------===// namespace { -/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if -/// the value of the `outs` operand is not used within the op. This is only +/// Forces `outs` operands of linalg operations to use `tensor.empty` if the +/// value of the `outs` operand is not used within the op. This is only /// implemented for `linalg.generic` operations for now, but should hold for all /// linalg structured ops. struct RemoveOutsDependency : public OpRewritePattern { @@ -1666,8 +1666,8 @@ struct RemoveOutsDependency : public OpRewritePattern { if (sparse_tensor::getSparseTensorEncoding(operandVal.getType())) continue; - // If outs is already an `init_tensor` operation, nothing to do. - auto definingOp = operandVal.getDefiningOp(); + // If outs is already an `empty` operation, nothing to do. + auto definingOp = operandVal.getDefiningOp(); if (definingOp) continue; modifiedOutput = true; @@ -1678,10 +1678,10 @@ struct RemoveOutsDependency : public OpRewritePattern { dynamicDims.push_back(rewriter.createOrFold( loc, operandVal, dim.index())); } - Value initTensor = rewriter.create( - loc, dynamicDims, operandType.getShape(), - operandType.getElementType()); - op->setOperand(opOperand->getOperandNumber(), initTensor); + Value emptyTensor = rewriter.create( + loc, operandType.getShape(), operandType.getElementType(), + dynamicDims); + op->setOperand(opOperand->getOperandNumber(), emptyTensor); } } if (!modifiedOutput) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index b9eb930..3740633 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -37,7 +37,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { /// 1. `v.getType() == t` /// 2. If an operand of `op` has type `t`, let `operand_first` be the first /// such operand. Then`v == operand_first`. -/// 3. Otherwise, v is a newly created `linalg::InitTensorOp` with: +/// 3. Otherwise, v is a newly created `tensor::EmptyOp` with: /// a. Static and dynamic dims extracted from the first operand of `op`. /// b. Elemental type equal to the elemental type of `t`. /// @@ -71,8 +71,8 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape()); auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b); - res.push_back(b.create( - loc, dynamicShape, staticShape, rankedTensorType.getElementType())); + res.push_back(b.create( + loc, staticShape, rankedTensorType.getElementType(), dynamicShape)); } return res; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp index ad33ff3..866b411 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -76,14 +76,14 @@ struct FusePadOp : OpRewritePattern { // Create the tensor of same size as output of the pad op. RankedTensorType padResultType = padOp.getResultType(); auto resultSizes = getAsOpFoldResult(resultShape[0]); - auto initTensor = rewriter.create( + auto emptyTensor = rewriter.create( loc, resultSizes, padResultType.getElementType()); // Fill the tensor with the pad value. // TODO: There is an option to fill only the boundaries. For now just // filling the whole tensor. auto fillTensor = - rewriter.create(loc, padValue, initTensor.getResult()); + rewriter.create(loc, padValue, emptyTensor.getResult()); // Construct a slice of the fill result that is to be replaced with the // result of the generic op. The low pad values are the offsets, the size of diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 124d733..2ca15dc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -431,9 +431,9 @@ FailureOr mlir::linalg::hoistPaddingOnTensors( llvm::append_range(packedShape, transposedTensorType->getShape()); auto packedTensorType = RankedTensorType::get( packedShape, transposedTensorType->getElementType()); - Value packedTensor = b.create( - loc, dynamicTensorSizes, packedTensorType.getShape(), - packedTensorType.getElementType()); + Value packedTensor = b.create( + loc, packedTensorType.getShape(), packedTensorType.getElementType(), + dynamicTensorSizes); // Clone the operations involved in the backward slice, iteratively stepping // into the loops that we encounter. @@ -543,11 +543,10 @@ FailureOr mlir::linalg::hoistPaddingOnTensors( // Transpose the packed tensor back to the original storage order. if (!transposeVector.empty()) { - Value initTensor = - b.create(loc, ValueRange{}, paddedTensorType.getShape(), - paddedTensorType.getElementType()); + Value emptyTensor = b.create( + loc, paddedTensorType.getShape(), paddedTensorType.getElementType()); transposeOps.push_back( - makeTransposeOp(b, loc, newResult, initTensor, transposeVector)); + makeTransposeOp(b, loc, newResult, emptyTensor, transposeVector)); newResult = transposeOps.back()->getResult(0); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp b/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp index 518d737..c82bbce 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp @@ -9,52 +9,52 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { -#define GEN_PASS_DEF_LINALGINITTENSORTOALLOCTENSOR +#define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSOR #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::bufferization; -using namespace mlir::linalg; +using namespace mlir::tensor; namespace { -struct InitTensorLoweringPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct EmptyTensorLoweringPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InitTensorOp op, + LogicalResult matchAndRewrite(tensor::EmptyOp op, PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getSizes()); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getDynamicSizes()); return success(); } }; -struct LinalgInitTensorToAllocTensor - : public impl::LinalgInitTensorToAllocTensorBase< - LinalgInitTensorToAllocTensor> { - LinalgInitTensorToAllocTensor() = default; +struct EmptyTensorToAllocTensor + : public impl::EmptyTensorToAllocTensorBase { + EmptyTensorToAllocTensor() = default; void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { registry - .insert(); + .insert(); } }; } // namespace -void LinalgInitTensorToAllocTensor::runOnOperation() { +void EmptyTensorToAllocTensor::runOnOperation() { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); - patterns.insert(op->getContext()); + patterns.insert(op->getContext()); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } -std::unique_ptr mlir::createLinalgInitTensorToAllocTensorPass() { - return std::make_unique(); +std::unique_ptr mlir::createEmptyTensorToAllocTensorPass() { + return std::make_unique(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 80bdc06..7a25713 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -196,20 +196,20 @@ FailureOr mlir::linalg::splitReduction( b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); } } - Value initOrAllocTensor; + Value emptyOrAllocTensor; if (useAlloc) { - initOrAllocTensor = b.create( + emptyOrAllocTensor = b.create( loc, RankedTensorType::get(newOutputShape, op.getRegionOutputArgs()[0].getType()), ValueRange{}); } else { - initOrAllocTensor = b.create( + emptyOrAllocTensor = b.create( loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); } Value constantOp = b.create(loc, identity); Value identityTensor = - b.create(op->getLoc(), constantOp, initOrAllocTensor) + b.create(op->getLoc(), constantOp, emptyOrAllocTensor) .getResult(0); newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, @@ -225,7 +225,7 @@ FailureOr mlir::linalg::splitReduction( // Create the new op matching the original op with an extra parallel // dimension. GenericOp genericOp = b.create( - loc, TypeRange({initOrAllocTensor.getType()}), newInputs, + loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs, ValueRange({identityTensor}), newMaps, newIteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); @@ -259,9 +259,10 @@ FailureOr mlir::linalg::splitReduction( }); b.replaceOp(op, reduction.getResults()); - return SplitReductionResult{ - initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp(), - cast(genericOp.getOperation()), reduction}; + return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(), + identityTensor.getDefiningOp(), + cast(genericOp.getOperation()), + reduction}; } /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...) @@ -357,7 +358,7 @@ FailureOr mlir::linalg::splitReductionByScaling( // TODO: generalize when multi-reduction support is available. SmallVector newOutputs; newOutputs.reserve(op.getNumOutputs()); - SmallVector initOrAllocTensorOps; + SmallVector emptyOrAllocTensorOps; SmallVector fillOps; fillOps.reserve(op.getNumOutputs()); for (auto it : llvm::zip(op.getOutputs(), neutralElements)) { @@ -367,19 +368,19 @@ FailureOr mlir::linalg::splitReductionByScaling( reductionDimSize / splitFactor, insertSplitDimension); SmallVector dims = tensor::createDynamicDimValues(b, loc, rankedTensor); - Value initOrAllocTensor; + Value emptyOrAllocTensor; if (useAlloc) { - initOrAllocTensor = + emptyOrAllocTensor = b.create(loc, newT, dims); } else { - initOrAllocTensor = b.create( - loc, dims, newT.getShape(), t.getElementType()); + emptyOrAllocTensor = b.create(loc, newT.getShape(), + t.getElementType(), dims); } Value constantOp = b.create(loc, std::get<1>(it)); fillOps.push_back( - b.create(op->getLoc(), constantOp, initOrAllocTensor)); + b.create(op->getLoc(), constantOp, emptyOrAllocTensor)); newOutputs.push_back(fillOps.back().getResult(0)); - initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp()); + emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp()); } // Step 2. Reindex / expand indexing maps. @@ -406,7 +407,7 @@ FailureOr mlir::linalg::splitReductionByScaling( auto newInputs = llvm::to_vector<4>(op.getInputs()); // Add a single shape-only tensor to carry the dimensions without resorting to // more complex inversions. - newInputs.push_back(b.create( + newInputs.push_back(b.create( loc, ArrayRef{reductionDimSize / splitFactor, splitFactor}, b.getIntegerType(1))); // Output tensors are already good to go. @@ -469,7 +470,7 @@ FailureOr mlir::linalg::splitReductionByScaling( // TODO: extend when multi-reduction support is available. assert(fillOps.size() == results.size() && results.size() == 1); b.replaceOp(op, results.front()->getResults()); - return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(), + return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(), cast(genericOp.getOperation()), results.front()}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index b389ceb..d377906 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -720,10 +720,9 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns( scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx); tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); + tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); - - InitTensorOp::getCanonicalizationPatterns(patterns, ctx); tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); ctx->getLoadedDialect()->getCanonicalizationPatterns(patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 575dfbbc..fafefae 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -630,7 +630,7 @@ static SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { return SmallVector(nParallelLoops, getParallelIteratorTypeName()); } -/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to +/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to /// initialize with pad_val) and GenericOp (to copy contents). LogicalResult PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, @@ -661,13 +661,13 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, Location loc = padOp.getLoc(); SmallVector indices(resultShapedType.getRank(), rewriter.create(loc, 0)); - Value initTensor = rewriter.create( + Value emptyTensor = rewriter.create( loc, resultShapedType.getShape(), resultShapedType.getElementType()); // Initialize tensor with the pad value Value tmpTensor = rewriter .create(loc, ValueRange{padValue}, - ValueRange{initTensor}) + ValueRange{emptyTensor}) .result(); // Copy original contents into new tensor @@ -725,8 +725,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, }; auto resultType = padOp.getResultType(); - // Compute size of InitTensorOp. Any combination of static/dynamic is - // supported. + // Compute size of EmptyOp. Any combination of static/dynamic is supported. SmallVector dynSizes; SmallVector staticSizes; for (unsigned dim = 0; dim < resultType.getRank(); ++dim) { @@ -744,9 +743,9 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, } // Init tensor and fill it with padding. - Value init = rewriter.create( - padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType()); - Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes); + Value emptyTensor = rewriter.create( + padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes); + Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes); // Try optimize the copy of source. if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index a8397f3..cc98013 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -680,7 +680,7 @@ static SmallVector ofrToIndexValues(OpBuilder &builder, Location loc, return result; } -/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp and +/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and /// InsertSliceOp. For now, only constant padding values are supported. /// If there is enough static type information, TransferReadOps and /// TransferWriteOps may be generated instead of InsertSliceOps. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 7cd9f7f..ed6edd1 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -338,7 +338,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, /// Returns true if tensor materializes uninitialized into the computation. static bool isMaterializing(Value val) { - return val.getDefiningOp() || + return val.getDefiningOp() || val.getDefiningOp(); } diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt index 4b3dee7..020b5c1 100644 --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRTensorDialect Core LINK_LIBS PUBLIC + MLIRAffineDialect MLIRArithDialect MLIRArithUtils MLIRCastInterfaces diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp index 801ebd7..c1d285c 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 4fc2782..7a3f908 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -417,6 +417,283 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// EmptyOp +//===----------------------------------------------------------------------===// + +void EmptyOp::build(OpBuilder &builder, OperationState &result, + ArrayRef staticShape, Type elementType) { + assert(all_of(staticShape, + [](int64_t sz) { return !ShapedType::isDynamic(sz); }) && + "expected only static sizes"); + build(builder, result, staticShape, elementType, {}); +} + +void EmptyOp::build(OpBuilder &builder, OperationState &result, + ArrayRef staticShape, Type elementType, + ValueRange dynamicSizes) { + auto tensorType = RankedTensorType::get(staticShape, elementType); + build(builder, result, tensorType, dynamicSizes); +} + +void EmptyOp::build(OpBuilder &builder, OperationState &result, + ArrayRef sizes, Type elementType) { + SmallVector staticShape; + SmallVector dynamicSizes; + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape, + ShapedType::kDynamicSize); + build(builder, result, staticShape, elementType, dynamicSizes); +} + +LogicalResult EmptyOp::verify() { + if (getType().getNumDynamicDims() != + static_cast(getDynamicSizes().size())) + return emitOpError("incorrect number of dynamic sizes, has ") + << getDynamicSizes().size() << ", expected " + << getType().getNumDynamicDims(); + return success(); +} + +LogicalResult +EmptyOp::reifyResultShapes(OpBuilder &builder, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + reifiedReturnShapes.resize(1, SmallVector(getType().getRank())); + unsigned ctr = 0; + for (int64_t i = 0; i < getType().getRank(); ++i) { + if (getType().isDynamicDim(i)) { + reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++]; + } else { + reifiedReturnShapes[0][i] = + builder.create(getLoc(), i); + } + } + return success(); +} + +Value EmptyOp::getDynamicSize(unsigned idx) { + assert(getType().isDynamicDim(idx) && "expected dynamic dim"); + unsigned ctr = 0; + for (int64_t i = 0; i < static_cast(idx); ++i) + if (getType().isDynamicDim(i)) + ++ctr; + return getDynamicSizes()[ctr]; +} + +SmallVector EmptyOp::getMixedSizes() { + SmallVector result; + unsigned ctr = 0; + OpBuilder b(getContext()); + for (int64_t i = 0; i < getType().getRank(); ++i) { + if (getType().isDynamicDim(i)) { + result.push_back(getDynamicSizes()[ctr++]); + } else { + result.push_back(b.getIndexAttr(getType().getShape()[i])); + } + } + return result; +} + +namespace { +/// Change the type of the result of a `tensor.empty` by making the result +/// type statically sized along dimensions that in the original operation were +/// defined as dynamic, but the size was defined using a `constant` op. For +/// example +/// +/// %c5 = arith.constant 5: index +/// %0 = tensor.empty(%arg0, %c5) : tensor +/// +/// to +/// +/// %0 = tensor.empty(%arg0) : tensor +struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(EmptyOp op, + PatternRewriter &rewriter) const override { + SmallVector staticShape(op.getType().getShape().begin(), + op.getType().getShape().end()); + SmallVector dynamicSizes; + + // Compute new static and dynamic sizes. + unsigned ctr = 0; + bool changedType = false; + for (int64_t i = 0; i < op.getType().getRank(); ++i) { + if (op.getType().isDynamicDim(i)) { + Value dynamicSize = op.getDynamicSizes()[ctr++]; + Optional cst = getConstantIntValue(dynamicSize); + if (cst.has_value()) { + staticShape[i] = *cst; + changedType = true; + } else { + dynamicSizes.push_back(dynamicSize); + } + } + } + + // Stop here if no dynamic size was promoted to static. + if (!changedType) + return failure(); + + auto tensorType = RankedTensorType::get( + staticShape, op.getType().getElementType(), op.getType().getEncoding()); + auto newOp = + rewriter.create(op.getLoc(), tensorType, dynamicSizes); + rewriter.replaceOpWithNewOp(op, op.getType(), newOp); + return success(); + } +}; + +/// `tensor.empty` does not define any tensor contents, so a slice of a +/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`. +struct FoldEmptyTensorWithExtractSliceOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, + PatternRewriter &rewriter) const override { + if (!sliceOp.getSource().getDefiningOp()) + return failure(); + + // ExtractSliceOp may be rank-reducing; its dynamic sizes must be + // preserved as well as its result type. + auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), + sliceOp.getType().getElementType(), + sliceOp.getType().getEncoding()); + rewriter.replaceOpWithNewOp(sliceOp, tensorType, + sliceOp.getSizes()); + return success(); + } +}; + +template +struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + if (!reshapeOp.getSrc().template getDefiningOp()) + return failure(); + Location loc = reshapeOp.getLoc(); + ReifiedRankedShapedTypeDims resultShapes; + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + cast(reshapeOp.getOperation()); + if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, + resultShapes)) || + !llvm::hasSingleElement(resultShapes)) + return failure(); + // TODO: Do not drop tensor type encoding. + Value emptyTensor = + rewriter.create(loc, getAsOpFoldResult(resultShapes[0]), + reshapeOp.getResultType().getElementType()); + if (emptyTensor.getType() != reshapeOp.getResultType()) { + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeOp.getResultType(), emptyTensor); + } else { + rewriter.replaceOp(reshapeOp, emptyTensor); + } + return success(); + } +}; + +struct FoldEmptyTensorWithDimOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::DimOp dimOp, + PatternRewriter &rewriter) const override { + Optional maybeConstantIndex = dimOp.getConstantIndex(); + auto emptyTensorOp = dimOp.getSource().getDefiningOp(); + if (!emptyTensorOp || !maybeConstantIndex) + return failure(); + if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex)) + return failure(); + rewriter.replaceOp(dimOp, + emptyTensorOp.getDynamicSize(*maybeConstantIndex)); + return success(); + } +}; + +/// Canonicalize +/// +/// ```mlir +/// %0 = tensor.empty(%d0, %d1) : tensor +/// %1 = tensor.cast %0 : tensor to tensor<4x?xf32> +/// ``` +/// +/// into +/// +/// ```mlir +/// %0 = tensor.empty(%d1) : tensor<4x?xf32> +/// ``` +/// +/// This assumes the input program is correct in terms of its shape. So it is +/// safe to assume that `%d0` is in fact 4. +struct FoldEmptyTensorWithCastOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CastOp castOp, + PatternRewriter &rewriter) const override { + if (!canFoldIntoProducerOp(castOp)) + return failure(); + auto producer = castOp.getSource().getDefiningOp(); + if (!producer) + return failure(); + + auto resultType = castOp->getResult(0).getType().cast(); + ArrayRef resultShape = resultType.getShape(); + SmallVector currMixedSizes = producer.getMixedSizes(); + SmallVector newMixedSizes; + newMixedSizes.reserve(currMixedSizes.size()); + assert(resultShape.size() == currMixedSizes.size() && + "mismatch in result shape and sizes of empty op"); + for (auto it : llvm::zip(resultShape, currMixedSizes)) { + int64_t newDim = std::get<0>(it); + OpFoldResult currDim = std::get<1>(it); + // Case 1: The empty tensor dim is static. Check that the tensor cast + // result dim matches. + if (auto attr = currDim.dyn_cast()) { + if (ShapedType::isDynamic(newDim) || + newDim != attr.cast().getInt()) { + // Something is off, the cast result shape cannot be more dynamic + // than the empty tensor result shape (enforced by + // `canFoldIntoProducer`). Abort for now. + return rewriter.notifyMatchFailure( + producer, "mismatch in static value of shape of empty tensor " + "result and cast result"); + } + newMixedSizes.push_back(attr); + continue; + } + + // Case 2 : The tensor cast shape is static, but empty tensor result + // shape is dynamic. + if (!ShapedType::isDynamic(newDim)) { + newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); + continue; + } + + // Case 3 : The tensor cast shape is dynamic and empty tensor result + // shape is dynamic. Use the dynamic value from the empty tensor op. + newMixedSizes.push_back(currDim); + } + + // TODO: Do not drop tensor encoding. + rewriter.replaceOpWithNewOp(castOp, newMixedSizes, + resultType.getElementType()); + return success(); + } +}; + +} // namespace + +void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add, + FoldEmptyTensorWithReshapeOp, + ReplaceEmptyTensorStaticShapeDims>(context); +} + +//===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// @@ -430,8 +707,8 @@ LogicalResult ExtractOp::verify() { } OpFoldResult ExtractOp::fold(ArrayRef operands) { - // If this is a splat elements attribute, simply return the value. All of the - // elements of a splat attribute are the same. + // If this is a splat elements attribute, simply return the value. All of + // the elements of a splat attribute are the same. if (Attribute tensor = operands.front()) if (auto splatTensor = tensor.dyn_cast()) return splatTensor.getSplatValue(); @@ -457,8 +734,8 @@ OpFoldResult ExtractOp::fold(ArrayRef operands) { stride *= tensorType.getDimSize(i); flatIndex += indices[i] * stride; } - // Prevent out of bounds accesses. This can happen in invalid code that will - // never execute. + // Prevent out of bounds accesses. This can happen in invalid code that + // will never execute. if (static_cast(fromElementsOp.getElements().size()) <= flatIndex || flatIndex < 0) return {}; @@ -515,7 +792,8 @@ namespace { // // to just %element. // -// Consider expanding this to a template and handle all tensor cast operations. +// Consider expanding this to a template and handle all tensor cast +// operations. struct ExtractElementFromIndexCast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -671,8 +949,8 @@ LogicalResult GenerateOp::reifyResultShapes( } LogicalResult GenerateOp::verify() { - // Ensure that the tensor type has as many dynamic dimensions as are specified - // by the operands. + // Ensure that the tensor type has as many dynamic dimensions as are + // specified by the operands. RankedTensorType resultTy = getType().cast(); if (getNumOperands() != resultTy.getNumDynamicDims()) return emitError("must have as many index operands as dynamic extents " @@ -908,7 +1186,8 @@ SmallVector ExpandShapeOp::getReassociationExprs() { getReassociationIndices()); } -/// Compute the RankedTensorType obtained by applying `reassociation` to `type`. +/// Compute the RankedTensorType obtained by applying `reassociation` to +/// `type`. static RankedTensorType computeTensorReshapeCollapsedType(RankedTensorType type, ArrayRef reassociation) { @@ -1006,8 +1285,8 @@ struct FoldReshapeWithConstant : OpRewritePattern { } }; -/// Reshape of a FromElements can be replaced with a FromElements of the result -/// type +/// Reshape of a FromElements can be replaced with a FromElements of the +/// result type template struct FoldReshapeWithFromElements : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1097,8 +1376,8 @@ RankedTensorType ExtractSliceOp::inferResultType( ShapedType sourceShapedTensorType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides) { // An extract_slice op may specify only a leading subset of offset/sizes/ - // strides in which case we complete with offset=0, sizes from memref type and - // strides=1. + // strides in which case we complete with offset=0, sizes from memref type + // and strides=1. assert(static_cast(staticSizes.size()) == sourceShapedTensorType.getRank() && "unexpected staticSizes not equal to rank of source"); @@ -1122,8 +1401,8 @@ RankedTensorType ExtractSliceOp::inferResultType( } /// If the rank is reduced (i.e. the desiredResultRank is smaller than the -/// number of sizes), drop as many size 1 as needed to produce an inferred type -/// with the desired rank. +/// number of sizes), drop as many size 1 as needed to produce an inferred +/// type with the desired rank. /// /// Note that there may be multiple ways to compute this rank-reduced type: /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors. @@ -1210,8 +1489,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } -/// Build an ExtractSliceOp with mixed static and dynamic entries packed into a -/// Range vector. +/// Build an ExtractSliceOp with mixed static and dynamic entries packed into +/// a Range vector. void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef ranges, ArrayRef attrs) { @@ -1219,8 +1498,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } -/// Build an ExtractSliceOp with dynamic entries and custom result type. If the -/// type passed is nullptr, it is inferred. +/// Build an ExtractSliceOp with dynamic entries and custom result type. If +/// the type passed is nullptr, it is inferred. void ExtractSliceOp::build(OpBuilder &b, OperationState &result, RankedTensorType resultType, Value source, ValueRange offsets, ValueRange sizes, @@ -1393,9 +1672,9 @@ static void sliceElements(IterTy values, ArrayRef counts, } } -/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded -/// operation might introduce more constant data; Users can control their -/// heuristics by the control function. +/// Fold arith.constant and tensor.extract_slice into arith.constant. The +/// folded operation might introduce more constant data; Users can control +/// their heuristics by the control function. class ConstantOpExtractSliceFolder final : public OpRewritePattern { public: @@ -1529,8 +1808,8 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, for (OpFoldResult ofr : op.getMixedOffsets()) if (getConstantIntValue(ofr) != static_cast(0)) return failure(); - // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip - // is appropriate. + // Rank-reducing noops only need to inspect the leading dimensions: + // llvm::zip is appropriate. auto shape = shapedType.getShape(); for (auto it : llvm::zip(op.getMixedSizes(), shape)) if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it)) @@ -1541,8 +1820,8 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, return success(); } -/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, -/// we can return the InsertSliceOp's source directly. +/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same +/// slice, we can return the InsertSliceOp's source directly. // TODO: This only checks the immediate producer; extend to go up the // insert/extract chain if the slices are disjoint. static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) { @@ -1635,7 +1914,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType, ArrayAttr staticOffsets, ArrayAttr staticSizes, ArrayAttr staticStrides, ShapedType *expectedType = nullptr) { - // insert_slice is the inverse of extract_slice, use the same type inference. + // insert_slice is the inverse of extract_slice, use the same type + // inference. RankedTensorType expected = ExtractSliceOp::inferResultType( dstType, extractFromI64ArrayAttr(staticOffsets), extractFromI64ArrayAttr(staticSizes), @@ -1759,9 +2039,9 @@ public: Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) { OpBuilder::InsertionGuard g(rewriter); - // The only difference between InsertSliceOp and ParallelInsertSliceOp is - // the the insertion point is just before the ParallelCombiningOp in the - // parallel case. + // The only difference between InsertSliceOp and ParallelInsertSliceOp + // is the the insertion point is just before the ParallelCombiningOp in + // the parallel case. if (std::is_same::value) rewriter.setInsertionPoint(insertSliceOp->getParentOp()); toInsert = rewriter.create(insertSliceOp.getLoc(), @@ -1774,9 +2054,9 @@ public: } }; -/// Fold tensor_casts with insert_slice operations. If the source or destination -/// tensor is a tensor_cast that removes static type information, the cast is -/// folded into the insert_slice operation. E.g.: +/// Fold tensor_casts with insert_slice operations. If the source or +/// destination tensor is a tensor_cast that removes static type information, +/// the cast is folded into the insert_slice operation. E.g.: /// /// ```mlir /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor @@ -2175,7 +2455,8 @@ struct FoldTargetTensorCast : public OpRewritePattern { /// 5) the tensor::PadOps do not have common padding dimensions, /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and /// zero-offset for every dimension. -/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the +/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for +/// the /// padded source dimensions. /// /// Example: @@ -2276,11 +2557,11 @@ struct FoldOrthogonalPaddings : public OpRewritePattern { padOp, "cannot find zero-offset and zero-padding pair"); } - // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of - // the outer tensor::ExtractSliceOp for the dimensions padded by the outer - // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp - // does not match the size of the padded dimension. Otherwise, take the size - // of the inner tensor::ExtractSliceOp. + // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size + // of the outer tensor::ExtractSliceOp for the dimensions padded by the + // outer tensor::PadOp and fail if the size of the inner + // tensor::ExtractSliceOp does not match the size of the padded dimension. + // Otherwise, take the size of the inner tensor::ExtractSliceOp. SmallVector newSizes = innerSliceOp.getMixedSizes(); for (auto &en : enumerate(newSizes)) { if (!outerDims.test(en.index())) @@ -2306,8 +2587,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern { newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()]; } - // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the - // two paddings in one step. + // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs + // the two paddings in one step. auto newSliceOp = rewriter.create( padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes, innerSliceOp.getMixedStrides()); @@ -2397,8 +2678,8 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, result.addAttributes(attrs); } -/// Build an ParallelInsertSliceOp with mixed static and dynamic entries packed -/// into a Range vector. +/// Build an ParallelInsertSliceOp with mixed static and dynamic entries +/// packed into a Range vector. void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ArrayRef ranges, @@ -2485,7 +2766,8 @@ OpFoldResult SplatOp::fold(ArrayRef operands) { if (!constOperand.isa_and_nonnull()) return {}; - // SplatElementsAttr::get treats single value for second arg as being a splat. + // SplatElementsAttr::get treats single value for second arg as being a + // splat. return SplatElementsAttr::get(getType(), {constOperand}); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 2c306fa..96e4651 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -31,9 +31,9 @@ struct PadOpTiling : public TilingInterface::ExternalModel { auto padOp = cast(op); SmallVector mixedSizes = getAsOpFoldResult(reifiedShapes[0]); - Value initTensor = b.create( + Value emptyTensor = b.create( op->getLoc(), mixedSizes, padOp.getResultType().getElementType()); - return {initTensor}; + return {emptyTensor}; } SmallVector getLoopIteratorTypes(Operation *op) const { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index fe28cb4..ecff102 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -222,7 +222,9 @@ declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TensorOps.td - SOURCES dialects/tensor.py + SOURCES + dialects/tensor.py + dialects/_tensor_ops_ext.py DIALECT_NAME tensor) declare_mlir_dialect_python_bindings( diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index e3fb460..eb9e969 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -20,44 +20,6 @@ def isa(cls: Type, ty: Type): return False -class InitTensorOp: - """Extends the linalg.init_tensor op.""" - - def __init__(self, - sizes: Union[Sequence[int], Sequence[Value]], - element_type: Type, - *, - loc=None, - ip=None): - """Constructs an `init_tensor` with either static or dynamic sizes.""" - context = get_default_loc_context(loc) - operands = [] - attributes = {} - # TODO: Refactor the InitTensorOp to take an element type attribute and - # then use normal result type inference, unifying the Python and C++ side - # with a standard mechanism (versus stashing that in builders). - if sizes and isinstance(sizes[0], Value): - # Dynamic sizes. - operands.extend(sizes) - static_size_ints = [-1] * len(sizes) - result_type = RankedTensorType.get(static_size_ints, element_type) - else: - # Static sizes. - result_type = RankedTensorType.get(sizes, element_type) - static_size_ints = sizes - - i64_type = IntegerType.get_signless(64) - attributes["static_sizes"] = ArrayAttr.get( - [IntegerAttr.get(i64_type, s) for s in static_size_ints], - context=context) - op = self.build_generic(results=[result_type], - operands=operands, - attributes=attributes, - loc=loc, - ip=ip) - OpView.__init__(self, op) - - class StructuredOpMixin: """All structured ops use the same mixin class.""" diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py new file mode 100644 index 0000000..0f1b266 --- /dev/null +++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py @@ -0,0 +1,42 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Any, Optional, Sequence, Union +from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values + + +class EmptyOp: + """Extends the tensor.empty op.""" + + def __init__(self, + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None): + """Constructs an `empty` with mixed static/dynamic sizes.""" + # TODO: Refactor the EmptyOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). + dynamic_sizes = [] + static_sizes = [] + for s in sizes: + if isinstance(s, int): + static_sizes.append(s) + else: + static_sizes.append(-1) + dynamic_sizes.append(s) + result_type = RankedTensorType.get(static_sizes, element_type) + op = self.build_generic( + results=[result_type], + operands=dynamic_sizes, + attributes={}, + loc=loc, + ip=ip) + OpView.__init__(self, op) diff --git a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir index dcb5f15..f11abe4 100644 --- a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir +++ b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir @@ -6,7 +6,7 @@ // CHECK-LABEL: func @generalize_pad_tensor_static_shape( // CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32> +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x32x32x1xf32> // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32> // CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32> // CHECK: return %[[PADDED]] : tensor<1x32x32x1xf32> @@ -31,7 +31,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> // CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index -// CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : tensor<4x?x?x?xf32> // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32> // CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32> // CHECK: return %[[PADDED]] : tensor<4x?x?x?xf32> diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index d956e82..2ee28c4 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: @matmul func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) { // CHECK: [[C0:%.+]] = arith.constant 0 - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) @@ -16,7 +16,7 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor // CHECK-LABEL: @matmul_quantized func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) { // CHECK: [[C0:%.+]] = arith.constant 0 - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : i32) outs([[INIT]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32> // CHECK: [[ONE:%.+]] = arith.constant 1 // CHECK: [[TWO:%.+]] = arith.constant 2 @@ -32,7 +32,7 @@ func.func @matmul_dyn_batch(%arg0: tensor, %arg1: tensor) // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[C0_0:.+]] = arith.constant 0 - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]], 5, 6] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor, tensor) outs(%[[FILLED]] : tensor) -> tensor %0 = "tosa.matmul"(%arg0, %arg1) : (tensor, tensor) -> (tensor) @@ -46,7 +46,7 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x // CHECK: %[[C2:.+]] = arith.constant 2 // CHECK: %[[DIM:.+]] = tensor.dim %arg1, %[[C2]] // CHECK: %[[C0:.+]] = arith.constant 0 - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, %[[DIM]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32> %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> (tensor<1x5x?xf32>) @@ -58,7 +58,7 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x // CHECK-LABEL: @matmul_dyn_independent_dim func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) { // CHECK: %[[C0:.+]] = arith.constant 0 - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, 6] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) @@ -72,12 +72,12 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x // CHECK-LABEL: @fully_connected func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) { - // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6] + // CHECK: [[INITT:%.+]] = tensor.empty() // CHECK: [[ZERO:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]] // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]> // CHECK: [[TRANSPOSE:%.+]] = "tosa.transpose"(%arg1, [[PERM]]) - // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6] + // CHECK: [[INITB:%.+]] = tensor.empty() // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32> // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) { // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): @@ -95,12 +95,12 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2 // CHECK-LABEL: @quantized_fully_connected func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) { - // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6] + // CHECK: [[INITT:%.+]] = tensor.empty() // CHECK: [[ZERO:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]] // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]> // CHECK: [[TRANSPOSE:%.+]] = "tosa.transpose"(%arg1, [[PERM]]) - // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6] + // CHECK: [[INITB:%.+]] = tensor.empty() // CHECK: [[ONE:%.+]] = arith.constant 1 // CHECK: [[TWO:%.+]] = arith.constant 2 // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32> @@ -121,12 +121,12 @@ func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8 func.func @fully_connected_dyn(%arg0: tensor, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor) { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INITT:.+]] = linalg.init_tensor [%[[DIM]], 6] + // CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]]) // CHECK: %[[ZERO:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]] // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> // CHECK: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) - // CHECK: %[[INITB:.+]] = linalg.init_tensor [%[[DIM]], 6] + // CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]]) // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor, tensor<3x6xf32>) outs(%[[FILL]] : tensor) -> tensor // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor) outs(%[[INITB]] : tensor) { // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): @@ -142,9 +142,9 @@ func.func @fully_connected_dyn(%arg0: tensor, %arg1: tensor<6x3xf32>, % // CHECK-LABEL: @max_pool func.func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () { // CHECK-DAG: [[CONST:%.+]] = arith.constant -3.40282347E+38 - // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62] + // CHECK-DAG: [[INIT:%.+]] = tensor.empty() // CHECK-DAG: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[INIT]] - // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3] + // CHECK-DAG: [[KERNEL:%.+]] = tensor.empty() // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>) %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) return @@ -156,9 +156,9 @@ func.func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () { // CHECK-DAG: [[PAD:%.+]] = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 0, 1, 0] // CHECK-DAG: tensor.yield [[CONST]] // CHECK-DAG: [[INITVAL:%.+]] = arith.constant -3.40282347E+38 : f32 - // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62] + // CHECK-DAG: [[INIT:%.+]] = tensor.empty() // CHECK-DAG: [[FILL:%.+]] = linalg.fill ins([[INITVAL]]{{.*}}outs([[INIT]] - // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3] + // CHECK-DAG: [[KERNEL:%.+]] = tensor.empty() // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>) %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x33x62xf32>) return @@ -169,9 +169,9 @@ func.func @max_pool_dyn(%arg0: tensor) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38 - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CONST]]{{.*}}outs(%[[INIT]] - // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3] + // CHECK: %[[KERNEL:.+]] = tensor.empty() // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor, tensor<3x3xf32>) outs(%[[FILL]] : tensor) %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor) -> (tensor) return @@ -208,11 +208,11 @@ func.func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) { // CHECK: [[CONST:%.+]] = arith.constant 0 // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: [[CONST:%.+]] = arith.constant 0 - // CHECK: [[POOLINIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] + // CHECK: [[POOLINIT:%.+]] = tensor.empty() // CHECK: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[POOLINIT]] - // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4] + // CHECK: [[KERNEL:%.+]] = tensor.empty() // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>) - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>) // CHECK: [[ZERO:%.0]] = arith.constant 0 // CHECK: [[ONE:%.+]] = arith.constant 1 @@ -269,11 +269,11 @@ func.func @avg_pool_dyn(%arg0: tensor) -> (tensor) // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] - // CHECK: %[[POOLINIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62] + // CHECK: %[[POOLINIT:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[FILL:.+]] = linalg.fill - // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [4, 4] + // CHECK: %[[KERNEL:.+]] = tensor.empty() // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor, tensor<4x4xf32>) outs(%[[FILL]] : tensor) - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor) outs(%[[INIT]] : tensor) %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor) -> (tensor) return %0 : tensor @@ -346,10 +346,10 @@ func.func @avg_pool_i16(%arg0 : tensor<1x128x128x2xi16>) -> () { func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> // CHECK: %[[W:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) - // CHECK: %[[M_IN:.+]] = linalg.init_tensor [1, 45, 40, 28] + // CHECK: %[[M_IN:.+]] = tensor.empty() // CHECK: %[[CST:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill - // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28] + // CHECK: %[[B_IN:.+]] = tensor.empty() // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>) // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>) // CHECK: arith.addf @@ -369,10 +369,10 @@ func.func @conv2d_dyn(%input: tensor, %weights: tensor<28x3x3x27 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> // CHECK: %[[W:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) - // CHECK: %[[M_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28] + // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[CST:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill - // CHECK: %[[B_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28] + // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor) // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor) outs(%[[B_IN]] : tensor) // CHECK: %[[ADD:.+]] = arith.addf @@ -429,10 +429,10 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x // Running convolution // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> // CHECK: %[[WEIGHT:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) - // CHECK: %[[M_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28] + // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) // CHECK: %[[CST:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill - // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28] + // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>) // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>) // CHECK: %[[ADD:.+]] = arith.addf @@ -472,10 +472,10 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x // CHECK-LABEL: @depthwise_conv func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[CST0:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] - // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] + // CHECK: [[OUT:%.+]] = tensor.empty() // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) { @@ -496,10 +496,10 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf func.func @depthwise_conv_dyn(%arg0 : tensor, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 3, 11] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[CST0:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill - // CHECK: %[[OUT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 33] + // CHECK: %[[OUT:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor) // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor) outs(%[[OUT]] : tensor) { @@ -518,10 +518,10 @@ func.func @depthwise_conv_dyn(%arg0 : tensor, %arg1 : tensor<3x1x3x // CHECK-LABEL: @depthwise_conv_strides func.func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[CST0:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] - // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] + // CHECK: [[OUT:%.+]] = tensor.empty() // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) { @@ -544,10 +544,10 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3 // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield [[PADV]] - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 4, 128] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[CST0:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] - // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 12, 12, 512] + // CHECK: [[OUT:%.+]] = tensor.empty() // CHECK: [[C128:%.+]] = arith.constant -128 // CHECK: [[C42:%.+]] = arith.constant 42 // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>) @@ -568,10 +568,10 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3 // CHECK-LABEL: @depthwise_conv_quant_dilations func.func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 4, 128] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[CST0:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] - // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 10, 10, 512] + // CHECK: [[OUT:%.+]] = tensor.empty() // CHECK: [[C128:%.+]] = arith.constant -128 // CHECK: [[C42:%.+]] = arith.constant 42 // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 685f782..ae7c547 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: @test_abs func.func @test_abs(%arg0: tensor) -> tensor { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%arg0 : tensor) outs([[INIT]] : tensor) { // CHECK: ^bb0(%arg1: f32, %arg2: f32): // CHECK: [[ELEMENT:%.+]] = math.absf %arg1 @@ -23,7 +23,7 @@ func.func @test_abs(%arg0: tensor) -> tensor { // CHECK-LABEL: @test_abs func.func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32> // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) { // CHECK: ^bb0(%arg1: f32, %arg2: f32): // CHECK: [[ELEMENT:%.+]] = math.absf %arg1 @@ -41,7 +41,7 @@ func.func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: @test_abs func.func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32> // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) { // CHECK: ^bb0(%arg1: f32, %arg2: f32): // CHECK: [[ELEMENT:%.+]] = math.absf %arg1 @@ -59,7 +59,7 @@ func.func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { func.func @test_abs(%arg0: tensor) -> tensor { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) // CHECK: linalg.generic // CHECK: math.absf %0 = "tosa.abs"(%arg0) : (tensor) -> tensor @@ -74,7 +74,7 @@ func.func @test_abs(%arg0: tensor) -> tensor { func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { // CHECK: %[[C1:.+]] = arith.constant 1 // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DIM]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) // CHECK: linalg.generic // CHECK: math.absf %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> @@ -88,7 +88,7 @@ func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { // CHECK-LABEL: @test_broadcast func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32> // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg0 // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): @@ -106,7 +106,7 @@ func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor< // CHECK-LABEL: @test_broadcast_swapped_args func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32> // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg1 // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor) outs([[INIT]] : tensor<2xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): @@ -125,7 +125,7 @@ func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32 // CHECK-LABEL: @test_multibroadcast func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32> // CHECK: [[RESHAPE1:%.+]] = tensor.collapse_shape %arg0 {{\[}}[0, 1]] // CHECK: [[RESHAPE2:%.+]] = tensor.collapse_shape %arg1 {{\[}}[0, 1]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) { @@ -630,7 +630,7 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor< // CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>) func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () { %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3, 1] + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3x1xi32> // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>) // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32) // CHECK: linalg.yield [[ARG1]] @@ -650,7 +650,7 @@ func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () { %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32> // CHECK: %[[C1:.+]] = arith.constant 1 // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]], 4, 1, 3] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs([[OUT:%.+]] : tensor) // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32) // CHECK: linalg.yield [[ARG1]] @@ -672,7 +672,7 @@ func.func @test_transpose_dyn_multiple(%arg0: tensor) -> () { // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[C1:.+]] = arith.constant 1 // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM1]], %[[DIM0]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor) outs([[OUT:%.+]] : tensor) // CHECK: ^bb0([[ARG1:%.+]]: f32, [[ARG2:%.+]]: f32) // CHECK: linalg.yield [[ARG1]] @@ -690,7 +690,7 @@ func.func @test_transpose_dyn_multiple(%arg0: tensor) -> () { // CHECK-LABEL: @reduce_float // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32> func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32> // CHECK: [[CST0:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>) @@ -700,7 +700,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () { // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32> // CHECK: [[CST0:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>) @@ -739,7 +739,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () { func.func @reduce_float_dyn(%arg0: tensor) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 4] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor // CHECK: %[[CST0:.+]] = arith.constant 0.0 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) @@ -760,7 +760,7 @@ func.func @reduce_float_dyn(%arg0: tensor) -> () { func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () { // CHECK: %[[C1:.+]] = arith.constant 1 // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C1]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, %[[DYN]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<5x?xf32> // CHECK: %[[CST1:.+]] = arith.constant 1.0 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<5x?x4xf32>) outs(%[[FILL]] : tensor<5x?xf32>) @@ -781,7 +781,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () { func.func @reduce_float_dyn_multiple(%arg0: tensor) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) // CHECK: %[[CMIN:.+]] = arith.constant -3.40282347E+38 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) @@ -802,7 +802,7 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor) -> () { // CHECK-LABEL: @reduce_int // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32> func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[CST0:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>) @@ -812,7 +812,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () { // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[CST0:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>) @@ -852,7 +852,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () { // CHECK-LABEL: @reduce_bool // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi1> func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[CST0:%.+]] = arith.constant true // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<4xi1>) @@ -880,7 +880,7 @@ func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () { // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index // CHECK: [[IDX0:%.+]] = arith.constant 0 : index // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1] + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32> // CHECK: [[CST:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]] // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] @@ -892,7 +892,7 @@ func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () { // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index // CHECK: [[IDX0:%.+]] = arith.constant 0 : index // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2] + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32> // CHECK: [[CST:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]] // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] @@ -913,7 +913,7 @@ func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) - // CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX1]] // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX1_2]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [11, %[[DYN]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32> // CHECK: %[[CST:.+]] = arith.constant 0.0 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]] // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1] @@ -934,7 +934,7 @@ func.func @concat_axis_dyn(%arg0: tensor, %arg1: tensor) -> () // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX0_2]] // CHECK: %[[IDX1:.+]] = arith.constant 1 : index - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 3] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor // CHECK: %[[CST:.+]] = arith.constant 0.0 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]] // CHECK: %[[DYN1:.+]] = tensor.dim %arg0, %[[AXIS]] @@ -953,7 +953,7 @@ func.func @concat_axis_dyn(%arg0: tensor, %arg1: tensor) -> () func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () { // CHECK: [[C0:%.+]] = arith.constant 19689 // CHECK: [[C1:%.+]] = arith.constant 15 - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>) // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8): // CHECK: [[C17:%.+]] = arith.constant 17 @@ -974,7 +974,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () { // CHECK: [[C0:%.+]] = arith.constant 19689 // CHECK: [[C1:%.+]] = arith.constant 15 - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>) // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8): // CHECK: [[C17:%.+]] = arith.constant 17 @@ -1006,13 +1006,13 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () { func.func @rescale_i8_dyn_batch(%arg0 : tensor) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor) outs(%[[INIT]] : tensor) %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor) -> (tensor) // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor) outs(%[[INIT]] : tensor) %1 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor) -> (tensor) @@ -1029,7 +1029,7 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () { // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]] // CHECK: %[[C2:.+]] = arith.constant 2 // CHECK: %[[DIM2:.+]] = tensor.dim %arg0, %[[C2]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, %[[DIM1]], %[[DIM2]], 32] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]]) // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>) %0 = "tosa.rescale"(%arg0) {double_round = true, input_zp = 0 : i32, multiplier = [1376784203 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [38 : i32]} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8> return @@ -1043,7 +1043,7 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () { func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () { // CHECK: [[C0:%.+]] = arith.constant 19689 // CHECK: [[C1:%.+]] = arith.constant 15 - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>) // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8): // CHECK: [[C17:%.+]] = arith.constant 17 @@ -1074,7 +1074,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () { func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) { // CHECK: [[MULTIPLIERS:%.+]] = arith.constant dense<[42, 43, 0]> // CHECK: [[SHIFTS:%.+]] = arith.constant dense<[14, 15, 0]> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [3] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[MULTIPLIERS]], [[SHIFTS]] : tensor<3xi8>, tensor<3xi32>, tensor<3xi8>) outs([[INIT]] : tensor<3xi8>) // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8): // CHECK: [[C243:%.+]] = arith.constant 243 @@ -1126,7 +1126,7 @@ func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) func.func @reverse(%arg0: tensor<5x4xi32>) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>) // CHECK-DAG: %[[I0:.+]] = linalg.index 0 // CHECK-DAG: %[[I1:.+]] = linalg.index 1 @@ -1139,7 +1139,7 @@ func.func @reverse(%arg0: tensor<5x4xi32>) -> () { // CHECK: %[[C1:.+]] = arith.constant 1 // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C1]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>) // CHECK-DAG: %[[I0:.+]] = linalg.index 0 // CHECK-DAG: %[[I1:.+]] = linalg.index 1 @@ -1162,7 +1162,7 @@ func.func @reverse_dyn(%arg0: tensor) -> () { // CHECK: %[[D0_1:.+]] = tensor.dim %arg0, %[[C0_1]] // CHECK: %[[C0_2:.+]] = arith.constant 0 // CHECK: %[[D0_2:.+]] = tensor.dim %arg0, %[[C0_2]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0_1]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[D0_1]]) // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel"]} outs(%[[INIT]] : tensor) // CHECK-DAG: %[[I0:.+]] = linalg.index 0 // CHECK-DAG: %[[SUB1:.+]] = arith.constant 1 @@ -1181,19 +1181,19 @@ func.func @reverse_dyn(%arg0: tensor) -> () { // CHECK-LABEL: @tile func.func @tile(%arg0 : tensor<2x3xi8>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>) // CHECK: linalg.yield %arg1 : i8 // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]] %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>) - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: linalg.yield %arg1 : i8 // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>) - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3] + // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: linalg.yield %arg1 : i8 // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] @@ -1211,7 +1211,7 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () { func.func @tile_dyn_input(%arg0 : tensor) -> () { // CHECK: %[[CST0:.+]] = arith.constant 0 // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]] : tensor - // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DYN]], 1, 3] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor) outs(%[[INIT]] : tensor<2x?x1x3xi8>) // CHECK: linalg.yield %arg1 : i8 // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]] @@ -1230,7 +1230,7 @@ func.func @tile_dyn_input(%arg0 : tensor) -> () { func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () { // CHECK: %[[CST1:.+]] = arith.constant 1 // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]] : tensor<2x3xi8> - // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 2, %[[DYN]], 3] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>) // CHECK: linalg.yield %arg1 : i8 // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]] @@ -1338,10 +1338,10 @@ func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor) { // CHECK: #[[$MAP4:.*]] = affine_map<(d0) -> ()> func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () { - // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [2] + // CHECK: [[IDX_INIT:%.+]] = tensor.empty() // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32 // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]] - // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [2] + // CHECK: [[VAL_INIT:%.+]] = tensor.empty() // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648 // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<2xi32>, tensor<2xi32>) @@ -1353,10 +1353,10 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () { // CHECK: linalg.yield [[SELECT_IDX]], [[SELECT_VAL]] %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x2xi32>) -> (tensor<2xi32>) - // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [3] + // CHECK: [[IDX_INIT:%.+]] = tensor.empty() // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32 // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]] - // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [3] + // CHECK: [[VAL_INIT:%.+]] = tensor.empty() // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648 // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#map0, #map2, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>) @@ -1388,10 +1388,10 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () { func.func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () { // CHECK: %[[CST1:.+]] = arith.constant 1 // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]] - // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [%[[DYN]]] + // CHECK: %[[IDX_INIT:.+]] = tensor.empty(%[[DYN]]) // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32 // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]] - // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [%[[DYN]]] + // CHECK: %[[VAL_INIT:.+]] = tensor.empty(%[[DYN]]) // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648 // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor, tensor) @@ -1411,10 +1411,10 @@ func.func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)> func.func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () { - // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [3] + // CHECK: %[[IDX_INIT:.+]] = tensor.empty() // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32 // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]] - // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [3] + // CHECK: %[[VAL_INIT:.+]] = tensor.empty() // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648 // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>) @@ -1432,7 +1432,7 @@ func.func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () { // CHECK-LABEL: @gather_float func.func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xf32>) // CHECK: ^bb0(%[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32) // CHECK: %[[IDX0:.+]] = linalg.index 0 @@ -1448,7 +1448,7 @@ func.func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () func.func @gather_float_dyn(%arg0: tensor, %arg1: tensor) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 3, 2] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor) outs(%[[INIT]] : tensor) // CHECK: ^bb0(%[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32) // CHECK: %[[IDX0:.+]] = linalg.index 0 @@ -1462,7 +1462,7 @@ func.func @gather_float_dyn(%arg0: tensor, %arg1: tensor) -> // CHECK-LABEL: @gather_int func.func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xi32>) // CHECK: ^bb0(%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) // CHECK: %[[IDX0:.+]] = linalg.index 0 @@ -1478,7 +1478,7 @@ func.func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () { // CHECK-LABEL: @table8 func.func @table8(%arg0: tensor<6xi8>, %arg1: tensor<512xi8>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [6] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>) // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8) // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]] @@ -1494,7 +1494,7 @@ func.func @table8(%arg0: tensor<6xi8>, %arg1: tensor<512xi8>) -> () { // CHECK-LABEL: @table16 func.func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [6] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi16>) outs(%[[INIT]] : tensor<6xi32>) // CHECK: ^bb0(%arg2: i16, %arg3: i32) // CHECK: %[[EXT_IN:.+]] = arith.extsi %arg2 @@ -1527,7 +1527,7 @@ func.func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () { func.func @table8_dyn(%arg0: tensor, %arg1: tensor<512xi8>) -> () { // CHECK: %[[CST0:.+]] = arith.constant 0 // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor) outs(%[[INIT]] : tensor) // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8) // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]] @@ -1543,7 +1543,7 @@ func.func @table8_dyn(%arg0: tensor, %arg1: tensor<512xi8>) -> () { // CHECK-LABEL: @table8_dyn_table func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [6] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>) // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8) // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]] @@ -1559,7 +1559,7 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor) -> () { // CHECK-LABEL: @resize_nearest func.func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX0:.+]] = linalg.index 0 // CHECK: %[[IDX1:.+]] = linalg.index 1 @@ -1628,7 +1628,7 @@ func.func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () { // CHECK-LABEL: @resize_bilinear func.func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX0:.+]] = linalg.index 0 // CHECK: %[[IDX1:.+]] = linalg.index 1 @@ -1710,7 +1710,7 @@ func.func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () { // CHECK-LABEL: @resize_nearest_int func.func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX0:.+]] = linalg.index 0 // CHECK: %[[IDX1:.+]] = linalg.index 1 @@ -1778,7 +1778,7 @@ func.func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () { // CHECK-LABEL: @resize_bilinear_int func.func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] + // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX0:.+]] = linalg.index 0 @@ -1865,7 +1865,7 @@ func.func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () { func.func @resize_dyn(%input: tensor) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 4, 1] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[GENERIC:.+]] = linalg.generic %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor) -> (tensor) return diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir index 51999f5..c6454cf 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir @@ -47,7 +47,8 @@ func.func @buffer_forwarding_no_conflict( %f0 = arith.constant 0.0: f32 // alloc_tensor itself does not alloc but forwards to the insert_slice. - // InitTensorOp replaces the alloc_tensor with an inplace extract_slice. + // AllocTensorOpElimination replaces the alloc_tensor with an inplace + // extract_slice. // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] %a = bufferization.alloc_tensor(%sz) : tensor diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir index 5ff98e1..957dc12 100644 --- a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir +++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir @@ -135,7 +135,7 @@ func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32x %c0 = arith.constant 0 : index %cst = arith.constant 0.0 : f32 - %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> + %init = tensor.empty() : tensor<1x112x112x32xf32> %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> %conv = linalg.conv_2d_nhwc_hwcf @@ -149,7 +149,7 @@ func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32x } // CHECK: func @conv_slice -// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x112x112x32xf32> // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[0, 128, 128, 0] [1, 65, 65, 3] [1, 1, 1, 1] : tensor<1x225x225x3xf32> to tensor<1x65x65x3xf32> // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[0, 0, 0, 16] [3, 3, 3, 16] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x16xf32> // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INIT]][0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32> @@ -162,7 +162,7 @@ func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32x // The slice is not supposed to be bubbled up when it is rank-reducing. func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> { %cst = arith.constant 1.000000e+00 : f32 - %init = linalg.init_tensor [1, %width] : tensor<1x?xf32> + %init = tensor.empty(%width) : tensor<1x?xf32> %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32> %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor into tensor<1x1x1x?xf32> @@ -170,7 +170,7 @@ func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> { } // CHECK: func @rank_reducing_slice -// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: %[[FILL:.+]] = linalg.fill ins // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]] // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir index eb89907..29f27e6 100644 --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -41,18 +41,18 @@ func.func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> { #map0 = affine_map<(d0) -> (d0)> -// Same as above but with linalg.init_tensor op. +// Same as above but with tensor.empty op. // CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func @init_tensor( +// CHECK-LABEL: func @empty_tensor( // CHECK-SAME: %[[IN:.*]]: tensor, %[[SIZE:.*]]: index) // CHECK-DAG: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref // CHECK-DAG: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) {{.*}} : memref // CHECK: linalg.generic // CHECK-SAME: ins(%[[MEMREF]] : memref) // CHECK-SAME: outs(%[[OUT_BUF]] : memref) { -func.func @init_tensor(%in : tensor, %size: index) -> tensor { - %init = linalg.init_tensor [%size] : tensor +func.func @empty_tensor(%in : tensor, %size: index) -> tensor { + %init = tensor.empty(%size) : tensor %0 = linalg.generic { indexing_maps = [#map0, #map0], iterator_types = ["parallel"] @@ -208,7 +208,7 @@ func.func private @csum(%arg0: tensor<6xi64>) -> tensor<6xi64> // CHECK: return %[[call]] func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> { %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xi1> into tensor<6xi1> - %1 = linalg.init_tensor [6] : tensor<6xi64> + %1 = tensor.empty() : tensor<6xi64> %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%0 : tensor<6xi1>) outs(%1 : tensor<6xi64>) { ^bb0(%arg1: i1, %arg2: i64): %4 = arith.extui %arg1 : i1 to i64 diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir index 2201997..5ca63d9 100644 --- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir @@ -131,8 +131,8 @@ func.func @drop_dead_results(%arg0 : tensor) -> (tensor, t #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> ()> func.func @argmax_lowering(%arg0 : tensor) -> tensor { - %init0 = linalg.init_tensor [] : tensor - %init1 = linalg.init_tensor [] : tensor + %init0 = tensor.empty() : tensor + %init1 = tensor.empty() : tensor %0:2 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["reduction"]} @@ -153,8 +153,8 @@ func.func @argmax_lowering(%arg0 : tensor) -> tensor { } // CHECK: func @argmax_lowering( // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [] : tensor -// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [] : tensor +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] : // CHECK: return %[[GENERIC]]#1 @@ -164,7 +164,7 @@ func.func @argmax_lowering(%arg0 : tensor) -> tensor { // Do not remove operand needed for loop dim. func.func @loop_dim_operand(%arg0 : tensor) -> tensor { %cst = arith.constant 0 : i32 - %init = linalg.init_tensor [] : tensor + %init = tensor.empty() : tensor %fill = linalg.fill ins(%cst : i32) outs(%init : tensor) -> tensor %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], @@ -188,8 +188,8 @@ func.func @loop_dim_operand(%arg0 : tensor) -> tensor { // Do not remove outs operand needed for loop bound computation. func.func @loop_dim_outs_operand(%arg0 : index) -> tensor { %cst = arith.constant 0 : i32 - %init1 = linalg.init_tensor [%arg0] : tensor - %init = linalg.init_tensor [] : tensor + %init1 = tensor.empty(%arg0) : tensor + %init = tensor.empty() : tensor %fill = linalg.fill ins(%cst : i32) outs(%init : tensor) -> tensor %0:2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], @@ -205,7 +205,7 @@ func.func @loop_dim_outs_operand(%arg0 : index) -> tensor { } // CHECK: func @loop_dim_outs_operand( // CHECK-SAME: %[[ARG0:.+]]: index -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[ARG0]]] +// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG0]]) // CHECK: linalg.generic // CHECK-SAME: outs(%[[INIT]] diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 51a7bf6..43589f7 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -85,48 +85,6 @@ func.func @linalg_effects(%a : tensor, %b : memref, %c : tenso // ----- -func.func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) { - %c6 = arith.constant 6 : index - %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32> - return %0 : tensor<4x5x?xf32> -} -// CHECK: func @init_tensor_canonicalize -// CHECK: %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32> -// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32> -// CHECK: return %[[T1]] - -// ----- - -func.func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { - %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32> - %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] - : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> - return %1 : tensor<2x3x5x4x?x7xf32> -} -// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> -// CHECK: func @init_tensor_reshape_expansion -// CHECK-SAME: %[[ARG0:.+]]: index -// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[D]], 7] -// CHECK-NEXT: return %[[INIT]] - -// ----- - -func.func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { - %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32> - %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]] - : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> - return %1 : tensor<6x5x?xf32> -} -// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)> -// CHECK: func @init_tensor_reshape_collapse -// CHECK-SAME: %[[ARG0:.+]]: index -// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [6, 5, %[[D]]] -// CHECK-NEXT: return %[[INIT]] - -// ----- - #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func.func @remove_no_op(%arg0 : tensor, %arg1 : tensor) -> (tensor, tensor) { @@ -136,7 +94,7 @@ func.func @remove_no_op(%arg0 : tensor, %arg1 : tensor) %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor %2 = tensor.dim %arg0, %c2 : tensor - %3 = linalg.init_tensor [%0, %1, %2] : tensor + %3 = tensor.empty(%0, %1, %2) : tensor %4, %5 = linalg.generic { indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"] @@ -157,7 +115,7 @@ func.func @remove_no_op(%arg0 : tensor, %arg1 : tensor) #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func.func @remove_no_op_mismatched_types(%arg0 : tensor) -> tensor<1x2x3xf32> { - %out = linalg.init_tensor [1, 2, 3] : tensor<1x2x3xf32> + %out = tensor.empty() : tensor<1x2x3xf32> %g = linalg.generic { indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"] @@ -177,7 +135,7 @@ func.func @remove_no_op_mismatched_types(%arg0 : tensor) #map = affine_map<() -> ()> func.func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor { - %out = linalg.init_tensor [] : tensor + %out = tensor.empty() : tensor %g = linalg.generic { indexing_maps = [#map, #map], iterator_types = [] @@ -200,7 +158,7 @@ func.func @keep_not_noop(%arg0 : tensor) -> tensor { %cst = arith.constant 1.000000e+00 : f32 %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor cf.br ^bb1(%cst : f32) ^bb1(%arg1 : f32): @@ -226,7 +184,7 @@ func.func @keep_not_noop(%arg0 : tensor, %arg1 : tensor) %cst = arith.constant 1.000000e+00 : f32 %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor cf.br ^bb1(%cst : f32) ^bb1(%arg2 : f32): @@ -246,33 +204,6 @@ func.func @keep_not_noop(%arg0 : tensor, %arg1 : tensor) // ----- -func.func @fold_init_tensor_with_slice - (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32> -{ - %0 = linalg.init_tensor[%arg0, 10, 40] : tensor - %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1] - : tensor to tensor<5x?x20xf32> - return %1 : tensor<5x?x20xf32> -} -// CHECK: func @fold_init_tensor_with_slice -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20] -// CHECK: return %[[T0]] - -// ----- - -func.func @fold_init_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> { - %0 = linalg.init_tensor [%arg0, 12] : tensor - %1 = tensor.cast %0 : tensor to tensor<1x12xf32> - return %1 : tensor<1x12xf32> -} -// CHECK: func @fold_init_tensor_with_cast(%[[ARG0:.+]]: index) -// CHECK: %[[T0:.+]] = linalg.init_tensor [1, 12] : tensor<1x12xf32> -// CHECK: return %[[T0]] : tensor<1x12xf32> - -// ----- - #accesses = [ affine_map<(i, j) -> (i, j)> ] @@ -315,7 +246,7 @@ func.func @propogate_casts(%arg0 : tensor, %arg1 : f32, %arg2 : index, %c1 = arith.constant 1 : index %c21 = arith.constant 21 : index %c42 = arith.constant 42 : index - %0 = linalg.init_tensor [%c21, %c42] : tensor + %0 = tensor.empty(%c21, %c42) : tensor %1 = linalg.fill ins(%arg1 : f32) outs(%0 : tensor) -> tensor %2 = tensor.dim %arg0, %c0 : tensor %3 = tensor.dim %arg0, %c1 : tensor @@ -323,7 +254,7 @@ func.func @propogate_casts(%arg0 : tensor, %arg1 : f32, %arg2 : index, return %4 : tensor } // CHECK-LABEL: func @propogate_casts -// CHECK: %[[INIT:.+]] = linalg.init_tensor [21, 42] +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %{{.+}} into %[[FILL]] // CHECK: %[[RESULT:.+]] = tensor.cast %[[INSERTED]] @@ -353,8 +284,8 @@ func.func @remove_deadargs_generic_basic(%arg0: tensor) -> (tensor %c0 = arith.constant 0 : index %cst = arith.constant 7.0 : f32 %0 = tensor.dim %arg0, %c0 : tensor - %1 = linalg.init_tensor [%0] : tensor - %2 = linalg.init_tensor [%0] : tensor + %1 = tensor.empty(%0) : tensor + %2 = tensor.empty(%0) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor, tensor) outs (%2:tensor) { ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): %4 = arith.addf %arg1, %cst : f32 @@ -378,9 +309,9 @@ func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor) -> (tenso %cst2 = arith.constant 6.0 : f32 %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor - %3 = linalg.init_tensor [%1, %0] : tensor - %4 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor + %3 = tensor.empty(%1, %0) : tensor + %4 = tensor.empty(%0, %1) : tensor %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor, tensor) outs (%4:tensor) { ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): %6 = arith.divf %cst1, %cst2 : f32 @@ -393,10 +324,10 @@ func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor) -> (tenso // CHECK-LABEL: func @fold_fill_reshape() func.func @fold_fill_reshape() -> tensor<6x4xf32> { %zero = arith.constant 0.0 : f32 - // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32> - %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32> + // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<6x4xf32> + %empty = tensor.empty() : tensor<1x2x3x4xf32> // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<6x4xf32>) -> tensor<6x4xf32> - %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]] : tensor<1x2x3x4xf32> into tensor<6x4xf32> // CHECK: return %[[FILL]] : tensor<6x4xf32> @@ -418,43 +349,6 @@ func.func @fold_fill_reshape_dynamic(%arg0 : tensor) -> tensor } - -// ----- - -func.func private @some_use(%i : index, %j : index) - -// CHECK-LABEL: func @init_canonicalize -// CHECK-SAME: %[[I:.*]]: index -func.func @init_canonicalize(%i : index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - // CHECK-NOT: init_tensor - %0 = linalg.init_tensor [%i, 42] : tensor - - // CHECK-NOT: tensor.dim - %1 = tensor.dim %0, %c0: tensor - %2 = tensor.dim %0, %c1: tensor - - // CHECK: %[[c42:.*]] = arith.constant 42 : index - // CHECK: call @some_use(%[[I]], %[[c42]]) - call @some_use(%1, %2) : (index, index) -> () - - return -} - -// ----- - -// CHECK-LABEL: func @rank_reducing_init_extract -func.func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> { - // CHECK: linalg.init_tensor [2] : tensor<2xf32> - %a = linalg.init_tensor [%sz, 2] : tensor - - // CHECK-NOT: extract - %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor to tensor<2xf32> - return %r: tensor<2xf32> -} - // ----- // CHECK: func @fold_self_copy @@ -475,13 +369,13 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) { // CHECK-LABEL: func @fold_static_pad_fill // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32> +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32> // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] // CHECK: return %[[FILL]] func.func @fold_static_pad_fill() -> tensor<412x276xf32> { %f0 = arith.constant 0.0 : f32 - %init = linalg.init_tensor [400, 273] : tensor<400x273xf32> - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<400x273xf32>) -> tensor<400x273xf32> + %empty = tensor.empty() : tensor<400x273xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<400x273xf32>) -> tensor<400x273xf32> %pad = tensor.pad %fill low[4, 1] high[8, 2] { ^bb0(%arg1: index, %arg2: index): tensor.yield %f0 : f32 @@ -507,12 +401,12 @@ func.func @fold_static_pad_fill() -> tensor<412x276xf32> { // CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]] // CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]] // CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]] -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor +// CHECK: %[[INIT:.+]] = tensor.empty(%[[S0]], %[[S1]], %[[S2]], %[[S3]]) : tensor // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] // CHECK: return %[[FILL]] -func.func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor { +func.func @fold_dynamic_pad_fill(%empty: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor { %f0 = arith.constant 0.0 : f32 - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x?x16x32xf32>) -> tensor<8x?x16x32xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x?x16x32xf32>) -> tensor<8x?x16x32xf32> %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] { ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): tensor.yield %f0 : f32 @@ -526,8 +420,8 @@ func.func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %lo func.func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> { %f0 = arith.constant 0.0 : f32 %f1 = arith.constant 1.0 : f32 - %init = linalg.init_tensor [400, 273] : tensor<400x273xf32> - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<400x273xf32>) -> tensor<400x273xf32> + %empty = tensor.empty() : tensor<400x273xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<400x273xf32>) -> tensor<400x273xf32> // CHECK: tensor.pad %pad = tensor.pad %fill low[4, 1] high[8, 2] { ^bb0(%arg1: index, %arg2: index): @@ -552,7 +446,7 @@ func.func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32> %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32> - %3 = linalg.init_tensor [%0, %1, %2] : tensor + %3 = tensor.empty(%0, %1, %2) : tensor %4 = linalg.generic { indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"] @@ -582,7 +476,7 @@ func.func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32> %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32> - %3 = linalg.init_tensor [%0, %1, %2] : tensor + %3 = tensor.empty(%0, %1, %2) : tensor %4 = tensor.cast %arg1 : tensor to tensor<2x?x?xf32> %5 = linalg.generic { indexing_maps = [#map, #map, #map], @@ -613,7 +507,7 @@ func.func @static_output_with_cast(%arg0 : tensor, %arg1: tensor %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32> %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32> - %3 = linalg.init_tensor [%0, %1, %2] : tensor + %3 = tensor.empty(%0, %1, %2) : tensor %4 = tensor.cast %3 : tensor to tensor<2x3x4xf32> %5 = tensor.cast %arg1 : tensor to tensor<2x?x?xf32> %6 = linalg.generic { @@ -647,7 +541,7 @@ func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> t %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32> %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32> %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32> - %3 = linalg.init_tensor [%0, %1, %2] : tensor + %3 = tensor.empty(%0, %1, %2) : tensor %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32> %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32> %6 = linalg.generic { @@ -672,7 +566,7 @@ func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> t // CHECK-LABEL: func @cast_dest // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<1x?x?xf32>, func.func @cast_dest(%arg0: tensor, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor { - %0 = linalg.init_tensor [%arg2, %arg3, %arg4] : tensor + %0 = tensor.empty(%arg2, %arg3, %arg4) : tensor %1 = tensor.cast %arg1 : tensor<1x?x?xf32> to tensor %2 = linalg.generic { indexing_maps = [#map, #map, #map], @@ -699,7 +593,7 @@ func.func @cast_dest(%arg0: tensor, %arg1: tensor<1x?x?xf32>, %arg2: // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = linalg.init_tensor [8, 384, 384] +// CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] // CHECK: %[[OFFSET1:.+]] = affine.apply #[[$MAP]]()[%[[LOW1]]] // CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor @@ -713,8 +607,8 @@ func.func @insert_pad_into_fill(%input: tensor, %low0: index, %low1: ^bb0(%arg3: index, %arg4: index, %arg5: index): tensor.yield %f0 : f32 } : tensor to tensor<8x128x128xf32> - %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> + %empty = tensor.empty() : tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> return %0: tensor<8x384x384xf32> } @@ -734,8 +628,8 @@ func.func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor< ^bb0(%arg3: index, %arg4: index, %arg5: index): tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> - %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> + %empty = tensor.empty() : tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> @@ -753,8 +647,8 @@ func.func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: ^bb0(%arg3: index, %arg4: index, %arg5: index): tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> - %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> + %empty = tensor.empty() : tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 0, 129] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> // Range overlap with %1 at dim#3 @@ -773,8 +667,8 @@ func.func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: ^bb0(%arg3: index, %arg4: index, %arg5: index): tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> - %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> + %empty = tensor.empty() : tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 128, 255] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> // Range overlap with %0 at dim#3 @@ -793,8 +687,8 @@ func.func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor< ^bb0(%arg3: index, %arg4: index, %arg5: index): tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> - %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> + %empty = tensor.empty() : tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> // Overlap btween %0 and %1 is fine but not with %2 is fine. // CHECK-COUNT-3: tensor.insert_slice %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> @@ -815,9 +709,9 @@ func.func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a ^bb0(%arg3: index, %arg4: index, %arg5: index): tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> - %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> + %empty = tensor.empty() : tensor<8x384x384xf32> // Different filling value than padding value. - %fill = linalg.fill ins(%f1 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> + %fill = linalg.fill ins(%f1 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> @@ -903,14 +797,14 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor) -> %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg0, %c1 : tensor %d2 = tensor.dim %arg0, %c2 : tensor - %init1 = linalg.init_tensor [%d1, %d2, %d0] : tensor - %init2 = linalg.init_tensor [%d2, %d1, %d0] : tensor + %empty1 = tensor.empty(%d1, %d2, %d0) : tensor + %empty2 = tensor.empty(%d2, %d1, %d0) : tensor %0:2 = linalg.generic { iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1, d0)>]} - ins(%arg0 : tensor) outs(%init1, %init2 : tensor, tensor) { + ins(%arg0 : tensor) outs(%empty1, %empty2 : tensor, tensor) { ^bb0(%b0 : f32, %b1 : f32, %b2 : f32) : linalg.yield %b0, %b0 : f32, f32 } -> (tensor, tensor) @@ -919,9 +813,9 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor) -> } // CHECK: func @fold_multi_use_generic_op_with_consumer // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32> +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<2x3x4xf32> // CHECK-DAG: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor to tensor<4x3x2xf32> -// CHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [3, 2, 4] : tensor<3x2x4xf32> +// CHECK-DAG: %[[INIT2:.+]] = tensor.empty() : tensor<3x2x4xf32> // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-SAME: ins(%[[CAST]] : // CHECK-SAME: outs(%[[INIT2]], %[[INIT1]] : diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir index 4f71a0c..41cc866 100644 --- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir +++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir @@ -75,7 +75,7 @@ func.func @select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor func.func @cmpf(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %[[INIT:.*]] = linalg.init_tensor [] : tensor + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor // CHECK: linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] // CHECK-SAME: outs(%[[INIT]] @@ -98,7 +98,7 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>) // CHECK: %[[D2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<4x?x?x8x2x?xf32> // CHECK: %[[C5:.*]] = arith.constant 5 : index // CHECK: %[[D5:.*]] = tensor.dim %[[ARG0]], %[[C5]] : tensor<4x?x?x8x2x?xf32> - // CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[D1]], %[[D2]], 8, 2, %[[D5]]] : tensor<4x?x?x8x2x?xi1> + // CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]], %[[D2]], %[[D5]]) : tensor<4x?x?x8x2x?xi1> // CHECK: linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] // CHECK-SAME: outs(%[[INIT]] diff --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir index cd8bd98..3eed6d2 100644 --- a/mlir/test/Dialect/Linalg/decompose-ops.mlir +++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir @@ -7,8 +7,8 @@ func.func @simple_op(%arg0 : tensor, %arg1 : tensor, %arg2 : ten %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg0, %c1 : tensor - %init1 = linalg.init_tensor [%d1, %d0] : tensor - %init2 = linalg.init_tensor [%d0, %d1] : tensor + %init1 = tensor.empty(%d1, %d0) : tensor + %init2 = tensor.empty(%d0, %d1) : tensor %result:2 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, @@ -35,8 +35,8 @@ func.func @simple_op(%arg0 : tensor, %arg1 : tensor, %arg2 : ten // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] -// CHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty(%[[D1]], %[[D0]]) +// CHECK-DAG: %[[INIT2:.+]] = tensor.empty(%[[D0]], %[[D1]]) // CHECK-DAG: %[[GENERIC1:.+]]:3 = linalg.generic // CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP0]], #[[MAP3]]] // CHECK-SAME: ["parallel", "parallel"] @@ -81,8 +81,8 @@ func.func @simple_op(%arg0 : tensor, %arg1 : tensor, %arg2 : ten // CANONICALIZECHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CANONICALIZECHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] // CANONICALIZECHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CANONICALIZECHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] -// CANONICALIZECHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CANONICALIZECHECK-DAG: %[[INIT1:.+]] = tensor.empty(%[[D1]], %[[D0]]) +// CANONICALIZECHECK-DAG: %[[INIT2:.+]] = tensor.empty(%[[D0]], %[[D1]]) // CANONICALIZECHECK-DAG: %[[GENERIC1:.+]] = linalg.generic // CANONICALIZECHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] // CANONICALIZECHECK-SAME: ["parallel", "parallel"] @@ -116,8 +116,8 @@ func.func @simple_op_permuted_outputs(%arg0 : tensor, %arg1 : tensor %d1 = tensor.dim %arg0, %c1 : tensor - %init1 = linalg.init_tensor [%d1, %d0] : tensor - %init2 = linalg.init_tensor [%d0, %d1] : tensor + %init1 = tensor.empty(%d1, %d0) : tensor + %init2 = tensor.empty(%d0, %d1) : tensor %result:3 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, @@ -144,8 +144,8 @@ func.func @simple_op_permuted_outputs(%arg0 : tensor, %arg1 : tensor, %arg1 : tensor, %arg1 : tensor (d0)> #map2 = affine_map<(d0, d1) -> (d1, d0)> func.func @multi_statement(%arg0 : tensor<10x20xf32>, %arg1 : tensor<10xi32>) -> tensor<20x10xf64> { - %init = linalg.init_tensor [20, 10] : tensor<20x10xf64> + %init = tensor.empty() : tensor<20x10xf64> %0 = linalg.generic { indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel"]} @@ -242,8 +242,8 @@ func.func @multi_statement(%arg0 : tensor<10x20xf32>, %arg1 : tensor<10xi32>) -> // CHECK: func @multi_statement( // CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf32> // CHECK-SAME: %[[ARG1:.+]]: tensor<10xi32>) -// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [20, 10] : tensor<20x10xf64> -// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf64> +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<20x10xf64> +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<10x20xf64> // CHECK: %[[GENERIC0:.+]]:2 = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]] // CHECK-SAME: iterator_types = ["parallel", "parallel"] @@ -290,8 +290,8 @@ func.func @multi_statement(%arg0 : tensor<10x20xf32>, %arg1 : tensor<10xi32>) -> // CANONICALIZECHECK: func @multi_statement( // CANONICALIZECHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf32> // CANONICALIZECHECK-SAME: %[[ARG1:.+]]: tensor<10xi32>) -// CANONICALIZECHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [20, 10] : tensor<20x10xf64> -// CANONICALIZECHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf64> +// CANONICALIZECHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<20x10xf64> +// CANONICALIZECHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<10x20xf64> // CANONICALIZECHECK: %[[GENERIC0:.+]] = linalg.generic // CANONICALIZECHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CANONICALIZECHECK-SAME: iterator_types = ["parallel", "parallel"] diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir index 218d722..5450580 100644 --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -3,7 +3,7 @@ #map = affine_map<() -> ()> func.func @detensor_simple(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg1, %arg2 : tensor, tensor) outs(%0 : tensor) { @@ -22,7 +22,7 @@ func.func @detensor_simple(%arg1: tensor, %arg2: tensor) -> tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg1, %arg2 : tensor, tensor) outs(%0 : tensor) { @@ -31,7 +31,7 @@ func.func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tenso linalg.yield %2 : f32 } -> tensor - %3 = linalg.init_tensor [] : tensor + %3 = tensor.empty() : tensor %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg1, %1 : tensor, tensor) outs(%3 : tensor) { @@ -40,7 +40,7 @@ func.func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tenso linalg.yield %5 : f32 } -> tensor - %6 = linalg.init_tensor [] : tensor + %6 = tensor.empty() : tensor %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%1, %4 : tensor, tensor) outs(%6 : tensor) { @@ -62,7 +62,7 @@ func.func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tenso // CHECK: return %[[new_tensor_res]] func.func @detensor_multiple_ops(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg1, %arg2 : tensor, tensor) outs(%0 : tensor) { @@ -83,7 +83,7 @@ func.func @detensor_multiple_ops(%arg1: tensor, %arg2: tensor) -> tens // CHECK: return %[[new_tensor_res]] func.func @detensor_foreign_op(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg1, %arg2 : tensor, tensor) outs(%0 : tensor) { diff --git a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir index 66bab42..c0cf7ab 100644 --- a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir @@ -6,7 +6,7 @@ func.func @if_true_test(%arg0: i1, %arg1: i32) -> tensor attributes {} { %arg1_t = tensor.from_elements %arg1 : tensor %cst = arith.constant dense<10> : tensor - %2 = linalg.init_tensor [] : tensor + %2 = tensor.empty() : tensor %3 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%arg0_t : tensor) @@ -19,7 +19,7 @@ func.func @if_true_test(%arg0: i1, %arg1: i32) -> tensor attributes {} { %5 = arith.trunci %4 : i8 to i1 cf.cond_br %5, ^bb1, ^bb2(%arg1_t : tensor) ^bb1: - %6 = linalg.init_tensor [] : tensor + %6 = tensor.empty() : tensor %7 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%arg1_t, %cst : tensor, tensor) diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir index fec9273..1720d6f 100644 --- a/mlir/test/Dialect/Linalg/detensorize_if.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir @@ -15,7 +15,7 @@ func.func @main() -> (tensor) attributes {} { cf.br ^bb1(%0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 - %3 = linalg.init_tensor [] : tensor + %3 = tensor.empty() : tensor %4 = linalg.generic #attrs ins(%2, %1 : tensor, tensor) outs(%3 : tensor) { @@ -27,7 +27,7 @@ func.func @main() -> (tensor) attributes {} { cf.cond_br %5, ^bb2(%2 : tensor), ^bb3(%2 : tensor) ^bb2(%6: tensor): // pred: ^bb1 - %7 = linalg.init_tensor [] : tensor + %7 = tensor.empty() : tensor %8 = linalg.generic #attrs ins(%6, %6 : tensor, tensor) outs(%7 : tensor) { @@ -76,7 +76,7 @@ func.func @main() -> (tensor) attributes {} { cf.br ^bb1(%0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 - %3 = linalg.init_tensor [] : tensor + %3 = tensor.empty() : tensor %4 = linalg.generic #attrs ins(%2, %1 : tensor, tensor) outs(%3 : tensor) { @@ -88,7 +88,7 @@ func.func @main() -> (tensor) attributes {} { cf.cond_br %5, ^bb2(%2 : tensor), ^bb3(%2 : tensor) ^bb2(%6: tensor): // pred: ^bb1 - %7 = linalg.init_tensor [] : tensor + %7 = tensor.empty() : tensor %8 = linalg.generic #attrs ins(%6, %6 : tensor, tensor) outs(%7 : tensor) { @@ -139,7 +139,7 @@ func.func @main() -> (tensor) attributes {} { cf.br ^bb1(%0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 - %3 = linalg.init_tensor [] : tensor + %3 = tensor.empty() : tensor %4 = linalg.generic #attrs ins(%2, %1 : tensor, tensor) outs(%3 : tensor) { @@ -156,7 +156,7 @@ func.func @main() -> (tensor) attributes {} { ^bb2(%6: tensor): // pred: ^bb1 %12 = tensor.from_elements %c10 : tensor - %7 = linalg.init_tensor [] : tensor + %7 = tensor.empty() : tensor %8 = linalg.generic #attrs ins(%6, %12 : tensor, tensor) outs(%7 : tensor) { diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir index 7abbcba..fa65ae3 100644 --- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir @@ -12,7 +12,7 @@ func.func @main(%farg0 : tensor) -> (tensor) attributes {} { %c10 = arith.constant 10 : i32 %1 = tensor.from_elements %c10 : tensor - %3 = linalg.init_tensor [] : tensor + %3 = tensor.empty() : tensor %4 = linalg.generic #attrs ins(%farg0, %1 : tensor, tensor) outs(%3 : tensor) { @@ -34,7 +34,7 @@ func.func @main(%farg0 : tensor) -> (tensor) attributes {} { // DET-CF-LABEL: func @main(%{{.*}}: tensor) // DET-CF-NEXT: arith.constant dense<10> : tensor -// DET-CF-NEXT: linalg.init_tensor [] : tensor +// DET-CF-NEXT: tensor.empty() : tensor // DET-CF-NEXT: linalg.generic // DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1) // DET-CF-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}} diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir index 6d34c9f..7b70053 100644 --- a/mlir/test/Dialect/Linalg/detensorize_while.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir @@ -12,7 +12,7 @@ func.func @main(%farg0: tensor, %farg1: tensor) -> tensor attribu cf.br ^bb1(%farg0 : tensor) ^bb1(%0: tensor): // 2 preds: ^bb0, ^bb2 - %1 = linalg.init_tensor [] : tensor + %1 = tensor.empty() : tensor %2 = linalg.generic #attrs ins(%0, %farg1 : tensor, tensor) outs(%1 : tensor) { @@ -24,7 +24,7 @@ func.func @main(%farg0: tensor, %farg1: tensor) -> tensor attribu cf.cond_br %3, ^bb2(%0 : tensor), ^bb3(%0 : tensor) ^bb2(%4: tensor): // pred: ^bb1 - %5 = linalg.init_tensor [] : tensor + %5 = tensor.empty() : tensor %6 = linalg.generic #attrs ins(%4, %4 : tensor, tensor) outs(%5 : tensor) { diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir index 87a28af..a0d3cff 100644 --- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir @@ -25,7 +25,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor) -> tensor attr cf.br ^bb1(%farg0 : tensor<10xi32>) ^bb1(%0: tensor<10xi32>): // 2 preds: ^bb0, ^bb2 - %1 = linalg.init_tensor [] : tensor + %1 = tensor.empty() : tensor %2 = linalg.generic #sum_reduction_attrs ins(%0: tensor<10xi32>) outs(%1: tensor) { @@ -34,7 +34,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor) -> tensor attr linalg.yield %b : i32 } -> tensor - %3 = linalg.init_tensor [] : tensor + %3 = tensor.empty() : tensor %4 = linalg.generic #attrs ins(%2, %farg1 : tensor, tensor) outs(%3 : tensor) { @@ -46,7 +46,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor) -> tensor attr cf.cond_br %5, ^bb2(%2 : tensor), ^bb3(%2 : tensor) ^bb2(%6: tensor): // pred: ^bb1 - %7 = linalg.init_tensor [10] : tensor<10xi32> + %7 = tensor.empty() : tensor<10xi32> %9 = linalg.generic #broadcast_attrs ins(%6: tensor) outs(%7: tensor<10xi32>) { @@ -66,7 +66,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor) -> tensor attr // DET-ALL-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor) // DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>) // DET-ALL: ^[[bb1]](%{{.*}}: tensor<10xi32>) -// DET-ALL: linalg.init_tensor [] : tensor +// DET-ALL: tensor.empty() : tensor // DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor) { // DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32): // DET-ALL: %{{.*}} = arith.addi %{{.*}}, %{{.*}} @@ -77,7 +77,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor) -> tensor attr // DET-ALL: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) // DET-ALL: ^[[bb2]](%{{.*}}: i32) // DET-ALL: tensor.from_elements %{{.*}} : tensor -// DET-ALL: linalg.init_tensor [10] : tensor<10xi32> +// DET-ALL: tensor.empty() : tensor<10xi32> // DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor) outs(%{{.*}} : tensor<10xi32>) { // DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32): // DET-ALL: linalg.yield %{{.*}} : i32 diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir index 2ff0059..59137f9 100644 --- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir @@ -17,7 +17,7 @@ func.func @main() -> () attributes {} { cf.br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 - %3 = linalg.init_tensor [] : tensor + %3 = tensor.empty() : tensor %4 = linalg.generic #attrs ins(%2, %reshaped1 : tensor, tensor) outs(%3 : tensor) { @@ -29,7 +29,7 @@ func.func @main() -> () attributes {} { cf.cond_br %5, ^bb2(%2 : tensor), ^bb3 ^bb2(%6: tensor): // pred: ^bb1 - %7 = linalg.init_tensor [] : tensor + %7 = tensor.empty() : tensor %8 = linalg.generic #attrs ins(%6, %6 : tensor, tensor) outs(%7 : tensor) { diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index 24b2c73..12ecdda 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -246,7 +246,7 @@ func.func @broadcast_scalar(%arg0 : tensor<1x1xf32>, %shape : tensor) - #map1 = affine_map<(d0, d1, d2) -> (d2)> func.func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32> { - %1 = linalg.init_tensor [1, 2, 5] : tensor<1x2x5xf32> + %1 = tensor.empty() : tensor<1x2x5xf32> %2 = linalg.generic {i64, indexing_maps = [#map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) { @@ -263,9 +263,9 @@ func.func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5x // ----- -func.func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> { +func.func @fold_unit_dim_for_empty_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> { %cst = arith.constant 0.0 : f32 - %init = linalg.init_tensor [1] : tensor<1xf32> + %init = tensor.empty() : tensor<1xf32> %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1xf32>) -> tensor<1xf32> %add = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -282,11 +282,11 @@ func.func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()> -// CHECK: func @fold_unit_dim_for_init_tensor +// CHECK: func @fold_unit_dim_for_empty_tensor // CHECK: %[[INPUT_RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32> -// CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]] @@ -330,7 +330,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> %cst = arith.constant 1.000000e+00 : f32 %c3 = arith.constant 3 : index %0 = tensor.dim %arg0, %c3 : tensor<1x?x1x?xf32> - %1 = linalg.init_tensor [1, %0] : tensor<1x?xf32> + %1 = tensor.empty(%0) : tensor<1x?xf32> %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x?xf32>) -> tensor<1x?xf32> %3 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, @@ -349,7 +349,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> // CHECK: func @unit_dim_for_reduction // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32> // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor +// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] @@ -364,7 +364,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> { %cst = arith.constant 1.000000e+00 : f32 %c3 = arith.constant 3 : index - %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32> + %1 = tensor.empty() : tensor<1x1xf32> %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32> %3 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, @@ -382,7 +382,7 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1 // CHECK: func @unit_dim_for_both_reduction // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32> // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3] -// CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32> +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32> // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]] @@ -398,7 +398,7 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor) -> tensor - %1 = linalg.init_tensor [%0, 1] : tensor + %1 = tensor.empty(%0) : tensor %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor %3 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, @@ -417,7 +417,7 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor) -> tensor // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor +// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] @@ -809,7 +809,7 @@ func.func @input_stays_same(%arg0 : memref>, %arg1 #CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }> func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> { - %0 = linalg.init_tensor [8] : tensor<8xf32> + %0 = tensor.empty() : tensor<8xf32> %1 = linalg.generic #matvec ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>) outs(%0: tensor<8xf32>) { @@ -822,7 +822,7 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te } // CHECK-LABEL: func @sparse_case -// CHECK-NEXT: linalg.init_tensor +// CHECK-NEXT: tensor.empty // CHECK-NEXT: linalg.generic // ----- @@ -831,9 +831,9 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> { %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.init_tensor [4, 2] : tensor<4x2xf32> + %0 = tensor.empty() : tensor<4x2xf32> %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) { - %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32> + %1 = tensor.empty() : tensor<1x1xf32> %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32> scf.foreach_thread.perform_concurrently { // CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}} diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir index fd0bd88..43450c4 100644 --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -11,7 +11,7 @@ func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>, %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> - %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32> + %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32> %generic = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} @@ -31,7 +31,7 @@ func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>, // CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32> // CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32> // CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32> -// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() // CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}} // CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}} // CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} @@ -62,7 +62,7 @@ func.func @fuse_by_collapsing_indexing_op(%arg0 : tensor<2x12x5x336x9xi32>, %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> - %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32> + %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32> %generic = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} @@ -124,7 +124,7 @@ func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi3 %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<9x56x2x60x6xi32> into tensor<9x7x8x2x3x4x5x6xi32> - %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32> + %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32> %generic = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} @@ -145,7 +145,7 @@ func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi3 // CHECK-SAME: %[[ARG0:.+]]: tensor<9x56x2x60x6xi32> // CHECK-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32> // CHECK-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32> -// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() // CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}} // CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}} // CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}} @@ -176,7 +176,7 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor, %d4 = tensor.dim %arg2, %c0 : tensor %d6 = tensor.dim %arg1, %c1 : tensor %d7 = tensor.dim %arg0, %c0 : tensor - %init = linalg.init_tensor [%d0, 3, %d2, 5, %d4, 7, %d6, %d7] : tensor + %init = tensor.empty(%d0, %d2, %d4, %d6, %d7) : tensor %generic = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} @@ -254,7 +254,7 @@ func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) - #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> func.func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x3x4x5xf32> { %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> - %init = linalg.init_tensor [2, 3, 4, 5] : tensor<2x3x4x5xf32> + %init = tensor.empty(): tensor<2x3x4x5xf32> %1 = linalg.generic { indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} @@ -281,7 +281,7 @@ func.func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tenso #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)> func.func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2xf32>) -> tensor<2x4x3x5xf32> { %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> - %init = linalg.init_tensor [2, 4, 3, 5] : tensor<2x4x3x5xf32> + %init = tensor.empty() : tensor<2x4x3x5xf32> %1 = linalg.generic { indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} @@ -308,7 +308,7 @@ func.func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %ar #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3)> func.func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x5xf32> { %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> - %init = linalg.init_tensor [2, 5] : tensor<2x5xf32> + %init = tensor.empty() : tensor<2x5xf32> %1 = linalg.generic { indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} @@ -337,7 +337,7 @@ func.func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 : func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tensor<2x3x4x5xf32> { %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32> %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<20xf32> into tensor<4x5xf32> - %init = linalg.init_tensor [2, 3, 4, 5] : tensor<2x3x4x5xf32> + %init = tensor.empty() : tensor<2x3x4x5xf32> %2 = linalg.generic { indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} @@ -370,7 +370,7 @@ func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tens // CONTROL-SAME: %[[ARG0:.+]]: tensor<6xf32> // CONTROL-SAME: %[[ARG1:.+]]: tensor<20xf32> // CONTROL: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] -// CONTROL: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5] +// CONTROL: %[[INIT:.+]] = tensor.empty() // CONTROL: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]{{\]}} // CONTROL: %[[GENERIC:.+]] = linalg.generic // CONTROL-SAME: ins(%[[EXPAND]], %[[ARG1]] : @@ -383,7 +383,7 @@ func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tens #map = affine_map<(d0) -> (d0)> func.func @zero_D_test(%arg0: tensor) -> tensor<1xf32> { %0 = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> - %init = linalg.init_tensor [1] : tensor<1xf32> + %init = tensor.empty() : tensor<1xf32> %1 = linalg.generic { indexing_maps = [#map, #map], iterator_types = ["parallel"]} @@ -444,7 +444,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor) -> tensor into tensor %d0 = tensor.dim %0, %c0 : tensor %d1 = tensor.dim %0, %c2 : tensor - %init = linalg.init_tensor [%d1, 8, %d0, 4] : tensor + %init = tensor.empty(%d1, %d0) : tensor %1 = linalg.generic { indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} @@ -468,7 +468,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor) -> tensor) // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}} // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -500,7 +500,7 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor) -> te %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor - %init = linalg.init_tensor [] : tensor + %init = tensor.empty() : tensor %1 = linalg.generic { indexing_maps = [#map0, #map1], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index 3002d44..cf69f04 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -10,7 +10,7 @@ func.func @add_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) { @@ -52,7 +52,7 @@ func.func @scalar_add_mul_fusion(%arg0: tensor, %arg1 : f32, %arg2 : f3 %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, f32) outs(%2 : tensor) { @@ -94,7 +94,7 @@ func.func @transpose_add_mul_fusion(%arg0: tensor, %arg1 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) { @@ -128,7 +128,7 @@ func.func @add_transpose_mul_fusion(%arg0: tensor, %arg1 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) { @@ -162,7 +162,7 @@ func.func @add_broadcast_mul_fusion(%arg0: tensor, %arg1 : tensor, %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor - %1 = linalg.init_tensor [%0] : tensor + %1 = tensor.empty(%0) : tensor %2 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%1 : tensor) { @@ -173,7 +173,7 @@ func.func @add_broadcast_mul_fusion(%arg0: tensor, %arg1 : tensor, // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP1]], [[$MAP0]], [[$MAP0]] %3 = tensor.dim %arg2, %c1 : tensor - %4 = linalg.init_tensor [%0, %3] : tensor + %4 = tensor.empty(%0, %3) : tensor %5 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%2, %arg2 : tensor, tensor) outs(%4 : tensor){ @@ -192,7 +192,7 @@ func.func @add_broadcast_mul_fusion(%arg0: tensor, %arg1 : tensor, // CHECK-LABEL: @add_mul_scalar_fusion func.func @add_mul_scalar_fusion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs(%0 : tensor) { @@ -226,7 +226,7 @@ func.func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x? %cst = arith.constant dense<42.0> : tensor<5xf32> %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32> %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32> - %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32> + %2 = tensor.empty(%0, %1) : tensor<5x?x?xf32> %3 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} @@ -258,7 +258,7 @@ func.func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>) %cst = arith.constant dense<42.0> : tensor %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32> %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32> - %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32> + %2 = tensor.empty(%0, %1) : tensor<5x?x?xf32> %3 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} @@ -286,7 +286,7 @@ func.func @producer_indexed_consumer_fusion(%arg0: tensor, %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"] } @@ -337,7 +337,7 @@ func.func @indexed_producer_consumer_fusion(%arg0: tensor) -> tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.generic { indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"] } @@ -391,7 +391,7 @@ func.func @indexed_producer_indexed_consumer_fusion(%arg0: tensor) %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.generic { indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"] } @@ -453,7 +453,7 @@ func.func @one_dim_indexed_producer_consumer_fusion(%arg0 : tensor, %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor - %0 = linalg.init_tensor [%d0] : tensor + %0 = tensor.empty(%d0) : tensor %1 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} @@ -466,7 +466,7 @@ func.func @one_dim_indexed_producer_consumer_fusion(%arg0 : tensor, } -> tensor %2 = tensor.dim %arg1, %c0 : tensor %3 = tensor.dim %arg1, %c1 : tensor - %4 = linalg.init_tensor [%2, %3] : tensor + %4 = tensor.empty(%2, %3) : tensor %5 = linalg.generic {indexing_maps = [#map2, #map3, #map2], iterator_types = ["parallel", "parallel"]} @@ -499,7 +499,7 @@ func.func @scalar_generic_fusion { %c0 = arith.constant 0 : index %cst = arith.constant dense<1.000000e+00> : tensor<10xf32> - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor %1 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} @@ -509,7 +509,7 @@ func.func @scalar_generic_fusion %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32> linalg.yield %4 : f32 } -> tensor - %2 = linalg.init_tensor [10] : tensor<10xf32> + %2 = tensor.empty() : tensor<10xf32> %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], @@ -538,7 +538,7 @@ func.func @scalar_generic_fusion func.func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) { %cst = arith.constant dense<1.0> : tensor<4xf32> - %1 = linalg.init_tensor [4] : tensor<4xf32> + %1 = tensor.empty() : tensor<4xf32> %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], @@ -555,7 +555,7 @@ func.func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)> // CHECK: func @constant_fusion(%[[ARG0:.+]]: tensor<4xf32>) // CHECK-DAG: %[[CST:.+]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [4] : tensor<4xf32> +// CHECK-DAG: %[[T0:.+]] = tensor.empty() : tensor<4xf32> // CHECK: %[[T1:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]] // CHECK-SAME: ins(%[[ARG0]] : tensor<4xf32>) @@ -574,7 +574,7 @@ func.func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) { func.func @consumer_with_reduction(%arg0: tensor<1x10xf32>, %arg1: tensor<1x10xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> { - %init = linalg.init_tensor [1, 10] : tensor<1x10xf32> + %init = tensor.empty() : tensor<1x10xf32> %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} @@ -620,7 +620,7 @@ func.func @sigmoid_dynamic_dim(%0: tensor) -> tensor { %shape = shape.shape_of %0 : tensor -> tensor %extend = shape.to_extent_tensor %shape : tensor -> tensor<2xindex> %extracted = tensor.extract %extend[%c0] : tensor<2xindex> - %init0 = linalg.init_tensor [%extracted, 1] : tensor + %init0 = tensor.empty(%extracted) : tensor %1 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"] @@ -630,7 +630,7 @@ func.func @sigmoid_dynamic_dim(%0: tensor) -> tensor { linalg.yield %cp5 : f32 } -> tensor %d0 = tensor.dim %0, %c0 : tensor - %init1 = linalg.init_tensor [%d0, 1] : tensor + %init1 = tensor.empty(%d0) : tensor %2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, @@ -692,7 +692,7 @@ func.func @no_fuse_constant_with_reduction() -> tensor<3xf32> // CHECK-SAME: ins(%[[CONST]] : tensor<3x2xf32>) // CHECK: return %[[RESULT]] %three = arith.constant dense<3.0> : tensor<3x2xf32> - %init = linalg.init_tensor [3] : tensor<3xf32> + %init = tensor.empty() : tensor<3xf32> %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -732,12 +732,12 @@ func.func @break_outs_dependency(%arg0 : tensor) -> tensor // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) // CHECK: %[[GENERIC1:.+]] = linalg.generic // CHECK-SAME: outs(%[[INIT]] : tensor) // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[GENERIC1]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[GENERIC1]], %[[C1]] -// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: outs(%[[INIT]] : tensor) @@ -750,8 +750,8 @@ func.func @fuse_scalar_constant(%arg0 : tensor) -> (tensor, te %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg0, %c1 : tensor - %0 = linalg.init_tensor[%d0, %d1] : tensor - %1 = linalg.init_tensor[%d0, %d1] : tensor + %0 = tensor.empty(%d0, %d1) : tensor + %1 = tensor.empty(%d0, %d1) : tensor %2:2 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, @@ -925,7 +925,7 @@ func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tenso // CHECK: linalg.generic func.func @no_fusion_missing_reduction_shape(%arg0: tensor, %arg1: index) -> tensor { %cst = arith.constant 0xFF800000 : f32 - %4 = linalg.init_tensor [%arg1, %arg1] : tensor + %4 = tensor.empty(%arg1, %arg1) : tensor %5 = linalg.generic { indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"] @@ -933,7 +933,7 @@ func.func @no_fusion_missing_reduction_shape(%arg0: tensor, %arg1: index) - ^bb0(%arg2: f32, %arg3: f32): linalg.yield %arg2 : f32 } -> tensor - %6 = linalg.init_tensor [%arg1] : tensor + %6 = tensor.empty(%arg1) : tensor %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor) -> tensor %8 = linalg.generic { indexing_maps = [#map2, #map3], @@ -959,7 +959,7 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi %23 = arith.index_cast %22 : index to i64 linalg.yield %23 : i64 } -> tensor<5000xi64> - %1 = linalg.init_tensor [5000] : tensor<5000xi32> + %1 = tensor.empty() : tensor<5000xi32> %2 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["parallel", "parallel"]} @@ -976,8 +976,8 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi // CHECK: func @fusion_different_axes( // CHECK-SAME: %[[ARG0:.+]]: tensor<5000xi64> // CHECK-SAME: %[[ARG1:.+]]: tensor<5000xi32> -// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [5000] : tensor<5000xi64> -// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [5000] : tensor<5000xi32> +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<5000xi64> +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<5000xi32> // CHECK: %[[RESULT:.+]]:2 = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] : @@ -1004,9 +1004,9 @@ func.func @fold_fill_generic_basic(%arg0: tensor) -> (tensor) { %c0 = arith.constant 0 : index %cst = arith.constant 7.0 : f32 %0 = tensor.dim %arg0, %c0 : tensor - %1 = linalg.init_tensor [%0] : tensor + %1 = tensor.empty(%0) : tensor %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %3 = linalg.init_tensor [%0] : tensor + %3 = tensor.empty(%0) : tensor %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor, tensor) outs (%3:tensor) { ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): %5 = arith.addf %arg1, %arg2 : f32 @@ -1031,11 +1031,11 @@ func.func @fold_fill_generic_mixedaccess(%arg0: tensor) -> (tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor) -> tensor - %4 = linalg.init_tensor [%1, %0] : tensor + %4 = tensor.empty(%1, %0) : tensor %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor) -> tensor - %6 = linalg.init_tensor [%0, %1] : tensor + %6 = tensor.empty(%0, %1) : tensor %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor, tensor) outs (%6:tensor) { ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): %8 = arith.divf %arg1, %arg2 : f32 @@ -1049,8 +1049,8 @@ func.func @fold_fill_generic_mixedaccess(%arg0: tensor) -> (tensor ()> module { func.func @fuse_multi_result_producer(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { - %0 = linalg.init_tensor [] : tensor - %1 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor + %1 = tensor.empty() : tensor %2:2 = linalg.generic { indexing_maps = [#map, #map, #map, #map, #map], iterator_types = []} ins(%arg0, %arg1, %arg1 : tensor, tensor, tensor) outs(%0, %1 : tensor, tensor) { @@ -1073,7 +1073,7 @@ module { // CHECK-LABEL: func.func @fuse_multi_result_producer // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : // CHECK-SAME: outs(%[[INIT]] : diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir index b051728..9373bde 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir @@ -18,7 +18,7 @@ func.func @test_fusion_limit( %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg0, %c1 : tensor - %init = linalg.init_tensor [%d0, %d1] : tensor + %init = tensor.empty(%d0, %d1) : tensor %0 = linalg.generic #binary2Dpointwise ins(%arg0, %arg1 : tensor, tensor) outs(%init : tensor) { diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir index 33489cb..e55bac4 100644 --- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir +++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir @@ -34,7 +34,7 @@ func.func @reshape(%A: tensor, %B: tensor<16xf32>, %init: tensor, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>) -// CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> +// CHECK: %[[I:.*]] = tensor.empty() : tensor<112x112x16xf32> // CHECK: %[[RI:.*]] = tensor.collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32> // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} @@ -47,7 +47,7 @@ func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>, : tensor<12544x16xf32> into tensor<112x112x16xf32> %1 = tensor.expand_shape %B [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> - %2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> + %2 = tensor.empty() : tensor<112x112x16xf32> %3 = linalg.generic {indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, @@ -75,7 +75,7 @@ func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>, func.func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> { %20 = tensor.expand_shape %A [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> - %21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> + %21 = tensor.empty() : tensor<112x112x16xf32> %22 = linalg.generic {indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], @@ -98,7 +98,7 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>, %cst_8 = arith.constant 1.1920929E-7 : f32 %25 = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<6x5xi32> into tensor<2x3x5xi32> - %26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32> + %26 = tensor.empty() : tensor<2x3x5xf32> %28 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir index 4c98037..c43c8a7 100644 --- a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir +++ b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @generalize_pad_tensor_static_shape( // CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32> +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x32x32x1xf32> // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32> // CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32> // CHECK: return %[[PADDED]] : tensor<1x32x32x1xf32> @@ -28,7 +28,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t // CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index // CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> // CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index -// CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : tensor<4x?x?x?xf32> // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32> // CHECK: %[[DIM1_1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32> // CHECK: %[[DIM3_1:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir index f3d469d..05b3b43 100644 --- a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir +++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir @@ -6,7 +6,7 @@ // CHECK: func @inline_zerod(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor) func.func @inline_zerod(%arg0: tensor<4xf32>, %scalar: tensor) -> tensor<4xf32> { - %0 = linalg.init_tensor [4] : tensor<4xf32> + %0 = tensor.empty() : tensor<4xf32> // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>) %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2], @@ -31,7 +31,7 @@ func.func @inline_zerod(%arg0: tensor<4xf32>, %scalar: tensor) -> tensor<4x // CHECK: func @inline_oned(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor<1xf32>) func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4xf32> { // CHECK: %[[ZERO:.*]] = arith.constant 0 : index - %0 = linalg.init_tensor [4] : tensor<4xf32> + %0 = tensor.empty() : tensor<4xf32> // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>) %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2], diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index d69a798..16c7729 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -323,36 +323,9 @@ func.func @matching_inits(%m: memref, %t: tensor) { // ----- -func.func @init_tensor_err(%arg0 : index, %arg1 : index) -{ - // expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}} - %1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32> - return -} - -// ----- - -func.func @init_tensor_err(%arg0 : index) -{ - // expected-error @+1 {{expected 4 sizes values}} - %1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32> - return -} - -// ----- - -func.func @init_tensor_err(%arg0 : index) -{ - // expected-error @+1 {{expected 2 dynamic sizes values}} - %1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32> - return -} - -// ----- - func.func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32) { - %0 = linalg.init_tensor [%arg0, %arg1] : tensor + %0 = tensor.empty(%arg0, %arg1) : tensor // expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}} linalg.fill ins(%arg2 : f32) outs(%0 : tensor) } diff --git a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir index 132752b..b98086a 100644 --- a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir +++ b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir @@ -52,7 +52,7 @@ func.func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1 // CHECK: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> // CHECK: %[[CTE:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[TMP:.+]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32> +// CHECK: %[[TMP:.+]] = tensor.empty() : tensor<1x32x32x1xf32> // CHECK: %[[R1c:.+]] = linalg.fill // CHECK: %[[R2c:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP4]], #[[$MAP5]]] diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 5a7b7ff..69bdf89 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> { %zero = arith.constant 0.000000e+00 : f32 - %init = linalg.init_tensor [1, 10, 8, 8] : tensor<1x10x8x8xf32> + %init = tensor.empty() : tensor<1x10x8x8xf32> %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32> // CHECK: depthwise_conv_1d_nwc_wcm %0 = linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} @@ -17,7 +17,7 @@ func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor // CHECK-LABEL: func @depthwise_conv_1d_nwc_wc func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x12x8xf32>, %filter: tensor<3x8xf32>) -> tensor<1x10x8xf32> { %zero = arith.constant 0.000000e+00 : f32 - %init = linalg.init_tensor [1, 10, 8] : tensor<1x10x8xf32> + %init = tensor.empty() : tensor<1x10x8xf32> %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x10x8xf32>) -> tensor<1x10x8xf32> // CHECK: depthwise_conv_1d_nwc_wc %0 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} @@ -31,7 +31,7 @@ func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x12x8xf32>, %filter: tensor< // CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_tensor func.func @depthwise_conv_2d_nhwc_hwcm_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> { %zero = arith.constant 0.000000e+00 : f32 - %init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32> + %init = tensor.empty() : tensor<2x3x4x2x3xf32> %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32> // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm // CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -59,7 +59,7 @@ func.func @depthwise_conv_2d_nhwc_hwcm_memref(%input: memref<2x4x5x2xf32>, %filt // CHECK-LABEL: func @depthwise_conv_1d_nw_tensor func.func @depthwise_conv_1d_nw_tensor(%input: tensor<1x113x96xf32>, %filter: tensor<3x96xf32>) -> tensor<1x56x96xf32> { - %init = linalg.init_tensor [1, 56, 96] : tensor<1x56x96xf32> + %init = tensor.empty() : tensor<1x56x96xf32> // CHECK: %{{.+}} = linalg.depthwise_conv_1d_nw // CHECK-SAME: {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>} // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x96xf32>, tensor<3x96xf32>) @@ -72,7 +72,7 @@ func.func @depthwise_conv_1d_nw_tensor(%input: tensor<1x113x96xf32>, %filter: te // CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor func.func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> { - %init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32> + %init = tensor.empty() : tensor<1x56x56x96xf32> // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc // CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) @@ -97,7 +97,7 @@ func.func @depthwise_conv_2d_nhwc_hwc_memref(%input: memref<1x113x113x96xf32>, % // CHECK-LABEL: func @depthwise_conv_2d_nchw_chw_tensor func.func @depthwise_conv_2d_nchw_chw_tensor(%input: tensor<1x96x113x113xf32>, %filter: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> { - %init = linalg.init_tensor [1, 96, 56, 56] : tensor<1x96x56x56xf32> + %init = tensor.empty() : tensor<1x96x56x56xf32> // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nchw_chw // CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>) @@ -122,7 +122,7 @@ func.func @depthwise_conv_2d_nchw_chw_memref(%input: memref<1x96x113x113xf32>, % func.func @depthwise_conv_2d_nhwc_hwcm_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> { %zero = arith.constant 0.000000e+00 : f32 - %init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32> + %init = tensor.empty() : tensor<2x6x7x2x3xf32> %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32> // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm // CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -186,7 +186,7 @@ func.func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref< // CHECK-LABEL: func @depthwise_conv_3d_ndhwc_dhwcm func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> { %zero = arith.constant 0.000000e+00 : f32 - %init = linalg.init_tensor [2, 3, 13, 4, 6, 6] : tensor<2x3x13x4x6x6xf32> + %init = tensor.empty() : tensor<2x3x13x4x6x6xf32> %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> // CHECK: depthwise_conv_3d_ndhwc_dhwcm %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>} @@ -200,7 +200,7 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<2x6x13x12x6xf32>, %filte // CHECK-LABEL: func @depthwise_conv_3d_ndhwc_dhwc func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> { %zero = arith.constant 0.000000e+00 : f32 - %init = linalg.init_tensor [2, 3, 13, 4, 6] : tensor<2x3x13x4x6xf32> + %init = tensor.empty() : tensor<2x3x13x4x6xf32> %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32> // CHECK: depthwise_conv_3d_ndhwc_dhwc %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>} @@ -410,8 +410,8 @@ func.func @conv_3d_ndhwc_dhwcf(%input: memref, %filter: memref, tensor<3x3xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> func.func @pooling_nhwc_sum_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> - %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> + %fake = tensor.empty() : tensor<3x3xf32> + %init = tensor.empty() : tensor<1x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> %res = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -444,8 +444,8 @@ func.func @pooling_nhwc_sum(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x1x4x4xf32>, tensor<3x3xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> func.func @pooling_nchw_sum_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> - %init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32> + %fake = tensor.empty() : tensor<3x3xf32> + %init = tensor.empty() : tensor<1x1x2x2xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> %res = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -478,8 +478,8 @@ func.func @pooling_nchw_sum(%input: memref<1x1x4x4xf32>, %fake: memref<3x3xf32>, // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> - %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> + %fake = tensor.empty() : tensor<3x3xf32> + %init = tensor.empty() : tensor<1x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -497,8 +497,8 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x // CHECK-SAME: outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> func.func @pooling_nchw_max_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> - %init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32> + %fake = tensor.empty() : tensor<3x3xf32> + %init = tensor.empty() : tensor<1x1x2x2xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> %res = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -531,8 +531,8 @@ func.func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi8>, tensor<3x3xi8>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> func.func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8> - %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8> + %fake = tensor.empty() : tensor<3x3xi8> + %init = tensor.empty() : tensor<1x2x2x1xi8> %cst = arith.constant 0 : i8 %fill = linalg.fill ins(%cst : i8) outs(%init : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -565,8 +565,8 @@ func.func @pooling_nhwc_i8_max(%input: memref<1x4x4x1xi8>, %fake: memref<3x3xi8> // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi16>, tensor<3x3xi16>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16> func.func @pooling_nhwc_i16_max_tensor(%input: tensor<1x4x4x1xi16>) -> tensor<1x2x2x1xi16> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xi16> - %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi16> + %fake = tensor.empty() : tensor<3x3xi16> + %init = tensor.empty() : tensor<1x2x2x1xi16> %cst = arith.constant 0 : i16 %fill = linalg.fill ins(%cst : i16) outs(%init : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -599,8 +599,8 @@ func.func @pooling_nhwc_i16_max(%input: memref<1x4x4x1xi16>, %fake: memref<3x3xi // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> func.func @pooling_nhwc_i32_max_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xi32> - %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi32> + %fake = tensor.empty() : tensor<3x3xi32> + %init = tensor.empty() : tensor<1x2x2x1xi32> %cst = arith.constant 0 : i32 %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -634,8 +634,8 @@ func.func @pooling_nhwc_i32_max(%input: memref<1x4x4x1xi32>, %fake: memref<3x3xi // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> - %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> + %fake = tensor.empty() : tensor<3x3xf32> + %init = tensor.empty() : tensor<1x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> %res = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} @@ -668,8 +668,8 @@ func.func @pooling_nhwc_min(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> func.func @pooling_ndhwc_sum_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> { - %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32> - %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32> + %fake = tensor.empty() : tensor<3x3x3xf32> + %init = tensor.empty() : tensor<1x2x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> %res = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} @@ -702,8 +702,8 @@ func.func @pooling_ndhwc_sum(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3x // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> func.func @pooling_ndhwc_max_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> { - %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32> - %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32> + %fake = tensor.empty() : tensor<3x3x3xf32> + %init = tensor.empty() : tensor<1x2x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> %res = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} @@ -736,8 +736,8 @@ func.func @pooling_ndhwc_max(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3x // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> func.func @pooling_ndhwc_min_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> { - %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32> - %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32> + %fake = tensor.empty() : tensor<3x3x3xf32> + %init = tensor.empty() : tensor<1x2x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> %res = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir index 18302e2..0491196 100644 --- a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -eliminate-alloc-tensors -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// -// InitTensorOp elimination +// AllocTensorOp elimination //===----------------------------------------------------------------------===// // CHECK-LABEL: func @buffer_forwarding_conflict diff --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir index 5d814c3..59b10da 100644 --- a/mlir/test/Dialect/Linalg/pad_fusion.mlir +++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir @@ -6,7 +6,7 @@ func.func @dynamic_pad_fusion(%arg0 : tensor, %arg1 : index, %arg2 : in %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg0, %c1 : tensor - %init = linalg.init_tensor [%d0, %d1] : tensor + %init = tensor.empty(%d0, %d1) : tensor %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -37,7 +37,7 @@ func.func @dynamic_pad_fusion(%arg0 : tensor, %arg1 : index, %arg2 : in // CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]] // CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] // CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]] -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]] +// CHECK: %[[INIT:.+]] = tensor.empty(%[[TARGET_D0]], %[[TARGET_D1]]) // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ARG5]]{{.*}}outs(%[[INIT]] // CHECK-DAG: %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]] // CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] @@ -55,7 +55,7 @@ func.func @mixed_pad_fusion(%arg0 : tensor, %arg1 : index, %arg2 : ind %arg3 : f32) -> tensor<49x?xf32> { %c0 = arith.constant 0 : index %d0 = tensor.dim %arg0, %c0 : tensor - %init = linalg.init_tensor [42, %d0] : tensor<42x?xf32> + %init = tensor.empty(%d0) : tensor<42x?xf32> %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} @@ -81,7 +81,7 @@ func.func @mixed_pad_fusion(%arg0 : tensor, %arg1 : index, %arg2 : ind // CHECK-DAG: %[[SOURCE:.+]] = linalg.generic // CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] // CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]] -// CHECK: %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]] +// CHECK: %[[INIT:.+]] = tensor.empty(%[[TARGET_D1]]) // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ARG3]]{{.*}}outs(%[[INIT]] // CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]] diff --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir index f11c884..ab94898 100644 --- a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir @@ -6,7 +6,7 @@ func.func @control_producer_reshape_fusion(%arg0 : tensor, %arg1 : te %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor %d0 = tensor.dim %0, %c0 : tensor %d1 = tensor.dim %0, %c1 : tensor - %init = linalg.init_tensor [%d0, %d1] : tensor + %init = tensor.empty(%d0, %d1) : tensor %1 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -40,7 +40,7 @@ func.func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : te %cst = arith.constant 0.0 : f32 %d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32> %d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32> - %init = linalg.init_tensor [%d0, %d1] : tensor + %init = tensor.empty(%d0, %d1) : tensor %fill = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index fc556b5..5c4be8a 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -142,7 +142,7 @@ func.func @reshape_as_consumer_permutation func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) -> tensor<8x33x4xf32> { %cst = arith.constant dense<2.000000e+00> : tensor<264x4xf32> - %0 = linalg.init_tensor [264, 4] : tensor<264x4xf32> + %0 = tensor.empty() : tensor<264x4xf32> %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} @@ -162,7 +162,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32> // CHECK-DAG: %[[CST:.+]] = arith.constant // CHECK-SAME: : tensor<8x33x4xf32> -// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [264, 4] +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() // CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1], [2] // CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32> @@ -281,7 +281,7 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor, func.func @reshape_as_consumer_permutation (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>) -> tensor<2x3x4x5x6x7xi32> { - %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32> + %shape = tensor.empty() : tensor<6x4x210xi32> %c = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, @@ -318,7 +318,7 @@ func.func @reshape_as_consumer_permutation // CHECK: func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32> // CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32> -// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [6, 4, 210] +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() // CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3, 4], [5] // CHECK-DAG: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] @@ -455,7 +455,7 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor) -> tensor { %c0 = arith.constant 0 : index %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor into tensor %1 = tensor.dim %0, %c0 : tensor - %2 = linalg.init_tensor [%1] : tensor + %2 = tensor.empty(%1) : tensor %3 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} @@ -477,7 +477,7 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor) -> tensor { func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor) -> tensor<2xi64> { %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64> - %1 = linalg.init_tensor [2] : tensor<2xi64> + %1 = tensor.empty() : tensor<2xi64> %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir index 3667176..f931fe8 100644 --- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir +++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir @@ -1,42 +1,42 @@ // RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s -func.func @init_tensor_static_dim() -> (index, index) { +func.func @empty_tensor_static_dim() -> (index, index) { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %c6 = arith.constant 6 : index - %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32> + %0 = tensor.empty(%c6) : tensor<4x5x?xf32> %1 = tensor.dim %0, %c2 : tensor<4x5x?xf32> %2 = tensor.dim %0, %c0 : tensor<4x5x?xf32> return %1, %2 : index, index } -// CHECK: func @init_tensor_static_dim +// CHECK: func @empty_tensor_static_dim // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index // CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index // CHECK: return %[[C6]], %[[C4]] // ----- -func.func @init_tensor_dynamic_dim(%arg0 : index) -> (index) { +func.func @empty_tensor_dynamic_dim(%arg0 : index) -> (index) { %c2 = arith.constant 2 : index - %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32> + %0 = tensor.empty(%arg0) : tensor<4x5x?xf32> %1 = tensor.dim %0, %c2 : tensor<4x5x?xf32> return %1 : index } -// CHECK: func @init_tensor_dynamic_dim +// CHECK: func @empty_tensor_dynamic_dim // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index // CHECK: return %[[ARG0]] // ----- -func.func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) { +func.func @empty_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %0 = linalg.init_tensor [%arg0, %arg1] : tensor + %0 = tensor.empty(%arg0, %arg1) : tensor %1 = tensor.dim %0, %c0 : tensor %2 = tensor.dim %0, %c1 : tensor return %1, %2 : index, index } -// CHECK: func @init_tensor_dynamic_dim2 +// CHECK: func @empty_tensor_dynamic_dim2 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index // CHECK: return %[[ARG0]], %[[ARG1]] @@ -87,7 +87,7 @@ func.func @remove_dim_result_uses_outs %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor - %0 = linalg.init_tensor [%d0, %arg1] : tensor + %0 = tensor.empty(%d0, %arg1) : tensor %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -149,7 +149,7 @@ func.func @keep_result_dim_uses_sequence2 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor - %0 = linalg.init_tensor [%d0, %arg1] : tensor + %0 = tensor.empty(%d0, %arg1) : tensor %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -173,7 +173,7 @@ func.func @keep_result_dim_uses_sequence2 #map = affine_map<(d0) -> (d0)> -func.func @init_tensor_dim_of_linalg_result(%arg_0 : tensor, +func.func @empty_tensor_dim_of_linalg_result(%arg_0 : tensor, %arg_1: tensor) -> (index, index) { %0, %1 = linalg.generic { indexing_maps = [#map, #map, #map], @@ -190,7 +190,7 @@ func.func @init_tensor_dim_of_linalg_result(%arg_0 : tensor, %num_elem_1 = tensor.dim %1, %c0 : tensor return %num_elem_0, %num_elem_1 : index, index } -// CHECK: func @init_tensor_dim_of_linalg_result( +// CHECK: func @empty_tensor_dim_of_linalg_result( // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor) // CHECK: %[[R0:.+]] = tensor.dim %[[ARG_0]] diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 3d2e9bf..3fb6c3d 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -202,9 +202,9 @@ func.func @generic_with_multiple_tensor_outputs( %arg0: tensor, %arg1: tensor, %arg2: i32) -> (tensor, tensor) { %c0 = arith.constant 0 : index - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty() : tensor %1 = linalg.fill ins(%arg2 : i32) outs(%0 : tensor) -> tensor - %2 = linalg.init_tensor [] : tensor + %2 = tensor.empty() : tensor %3 = linalg.fill ins(%arg2 : i32) outs(%2 : tensor) -> tensor %4:2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], @@ -324,23 +324,8 @@ func.func @named_ops(%a3: memref, %b3: memref, %c3: memref // ----- -#attr = {"foo"} -func.func @init_tensor(%arg0 : index, %arg1 : index) -{ - %0 = linalg.init_tensor [3, 42] : tensor<3x42xf32> - %1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32> - %2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr> - return -} -// CHECK-LABEL: func @init_tensor -// CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32> -// CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32> -// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32, {foo}> - -// ----- - func.func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor { - %0 = linalg.init_tensor [%arg0, %arg1] : tensor + %0 = tensor.empty(%arg0, %arg1) : tensor %1 = linalg.fill ins(%arg2 : f32) outs(%0 : tensor) -> tensor return %1 : tensor } diff --git a/mlir/test/Dialect/Linalg/split_reduction.mlir b/mlir/test/Dialect/Linalg/split_reduction.mlir index 453dd40..25a4999 100644 --- a/mlir/test/Dialect/Linalg/split_reduction.mlir +++ b/mlir/test/Dialect/Linalg/split_reduction.mlir @@ -16,7 +16,7 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten // CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32> // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32> -// CHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32> +// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]} @@ -41,7 +41,7 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten // INNERPARALLELCHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32 // INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32> // INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32> -// INNERPARALLELCHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32> +// INNERPARALLELCHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32> // INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32> // INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // INNERPARALLELCHECK-SAME: , iterator_types = ["parallel", "parallel", "reduction", "parallel"]} @@ -83,7 +83,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor, %out: ten //CHECK-LABEL: @generic_split_1d // CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32> -// CHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32> +// CHECK: %[[INI:.*]] = tensor.empty() : tensor<4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32> // CHECK: %[[G:.*]] = linalg.generic // CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], @@ -107,7 +107,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor, %out: ten //INNERPARALLELCHECK-LABEL: @generic_split_1d // INNERPARALLELCHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32 // INNERPARALLELCHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32> -// INNERPARALLELCHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32> +// INNERPARALLELCHECK: %[[INI:.*]] = tensor.empty() : tensor<4xf32> // INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32> // INNERPARALLELCHECK: %[[G:.*]] = linalg.generic // INNERPARALLELCHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], @@ -153,7 +153,7 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32> // CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32 // CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32> // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32> -// CHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32> +// CHECK: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) { @@ -177,7 +177,7 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32> // INNERPARALLELCHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32 // INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32> // INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32> -// INNERPARALLELCHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32> +// INNERPARALLELCHECK: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> // INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> // INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} // INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) { diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir index a0559d7..6f21e1e 100644 --- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -65,7 +65,7 @@ func.func @conv_tensors_static(%input: tensor<1x225x225x3xf32>, %filter: tensor< %c0 = arith.constant 0 : index %cst = arith.constant 0.0 : f32 - %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> + %init = tensor.empty() : tensor<1x112x112x32xf32> %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> %conv = linalg.conv_2d_nhwc_hwcf @@ -109,7 +109,7 @@ func.func @conv_tensors_static(%input: tensor<1x225x225x3xf32>, %filter: tensor< // CHECK: func @conv_tensors_static // CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x3xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>) -// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x112x112x32xf32> // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> // CHECK-NEXT: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG0:.+]] = %[[FILL]]) @@ -147,7 +147,7 @@ func.func @conv_tensors_dynamic(%input: tensor, %filter: tensor %oc = tensor.dim %elementwise, %c3 : tensor - %init = linalg.init_tensor [%n, %oh, %ow, %oc] : tensor + %init = tensor.empty(%n, %oh, %ow, %oc) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor %conv = linalg.conv_2d_nhwc_hwcf @@ -216,7 +216,7 @@ func.func @conv_tensors_dynamic(%input: tensor, %filter: tensor // CHECK-DAG: %[[ELEM_OC:.+]] = tensor.dim %[[ELEM]], %[[C3]] : tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]] : tensor +// CHECK: %[[INIT:.+]] = tensor.empty(%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]) : tensor // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK-DAG: %[[FILTER_H:.+]] = tensor.dim %[[FILTER]], %[[C0]] : tensor diff --git a/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir index 5a0b2c4..f8f102e 100644 --- a/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir @@ -47,7 +47,7 @@ // CHECK-PEEL-12: } func.func @matmul_static_tensor(%arg0: tensor<1500x1600xf32>, %arg1: tensor<1600x1700xf32>) -> tensor<1500x1700xf32> { - %out = linalg.init_tensor [1500, 1700] : tensor<1500x1700xf32> + %out = tensor.empty() : tensor<1500x1700xf32> %r = linalg.matmul {__internal_linalg_transform__ = "tile"} ins(%arg0, %arg1: tensor<1500x1600xf32>, tensor<1600x1700xf32>) outs(%out: tensor<1500x1700xf32>) -> tensor<1500x1700xf32> @@ -102,7 +102,7 @@ func.func @matmul_dynamic_tensor(%arg0: tensor, %arg1: tensor) %c1 = arith.constant 1 : index %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg1, %c1 : tensor - %out = linalg.init_tensor [%d0, %d1] : tensor + %out = tensor.empty(%d0, %d1) : tensor %r = linalg.matmul {__internal_linalg_transform__ = "tile"} ins(%arg0, %arg1: tensor, tensor) outs(%out: tensor) -> tensor diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir index 9d06888..01f2195 100644 --- a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir +++ b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir @@ -14,7 +14,7 @@ func.func @fill_matmul_tensors( // CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y // CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x // CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x -// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor +// CHECK-DAG: %[[INIT:.+]] = tensor.empty // CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]] // CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]] // CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]] @@ -43,7 +43,7 @@ func.func @fill_matmul_tensors( %cst = arith.constant 0.0 : f32 %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg1, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty(%0, %1) : tensor %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor) -> tensor %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"} ins(%arg0, %arg1: tensor, tensor) diff --git a/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir index e83cab3..9697adf 100644 --- a/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir +++ b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir @@ -19,7 +19,7 @@ func.func @matmul_partly_dynamic_tensor(%arg0: tensor, %arg1: tensor - %out = linalg.init_tensor [%d0, 2000] : tensor + %out = tensor.empty(%d0) : tensor %r = linalg.matmul {__internal_linalg_transform__ = "tile"} ins(%arg0, %arg1: tensor, tensor) outs(%out: tensor) -> tensor diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index cfd68e5..736a0e9 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -37,7 +37,7 @@ func.func @generic_op_tensors( %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor %2 = tensor.dim %arg0, %c2 : tensor - %3 = linalg.init_tensor [%0, %1, %2] : tensor + %3 = tensor.empty(%0, %1, %2) : tensor %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>, @@ -55,7 +55,7 @@ func.func @generic_op_tensors( // CHECK-LABEL: func @generic_op_tensors // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: %[[TD0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC0:.+]] = %[[INIT]]) -> (tensor) { // CHECK: %[[TD1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC1:.+]] = %[[TC0]]) -> (tensor) { // CHECK: %[[TD2:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC2:.+]] = %[[TC1]]) -> (tensor) { diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir index 7348289..924ed54 100644 --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -40,8 +40,8 @@ func.func @conv_2d_nchw_fchw(%input: tensor, %filter: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32> func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> { - // CHECK: %[[RES:.+]] = linalg.init_tensor - %init = linalg.init_tensor [1, 1, 56, 96] : tensor<1x1x56x96xf32> + // CHECK: %[[RES:.+]] = tensor.empty + %init = tensor.empty() : tensor<1x1x56x96xf32> // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]] diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index 77bd3b2..f1236a8 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -66,12 +66,12 @@ module { // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32> // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32> func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { - %0 = linalg.init_tensor [%arg0] : tensor + %0 = tensor.empty(%arg0) : tensor %1 = affine.apply #map0()[%arg0] // CHECK: scf.foreach_thread {{.*}} { %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<64xf32>) { - // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor + // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty %3 = affine.apply #map1(%arg3)[%arg0] %4 = affine.min #map2(%arg3)[%arg0] %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor @@ -91,10 +91,10 @@ module { ^bb0(%arg0: !pdl.operation): transform.sequence %arg0 failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.init_tensor"]} in %arg1 + %0 = transform.structured.match ops{["tensor.empty"]} in %arg1 %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 - // linalg.init_tensor is not tileable. The op is cloned and fused. + // tensor.empty is not tileable. The op is cloned and fused. transform.structured.fuse_into_containing_op %0 into %1 } } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 39227e6..f26462b 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -59,9 +59,9 @@ transform.with_pdl_patterns { // CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>) func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> { %five = arith.constant 5.0 : f32 - %init = linalg.init_tensor [12, 25] : tensor<12x25xf32> + %init = tensor.empty() : tensor<12x25xf32> -// CHECK: %[[INIT:.+]] = linalg.init_tensor [12, 25] +// CHECK: %[[INIT:.+]] = tensor.empty() // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index // CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]]) diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir index 7d31fc4..4b615fc 100644 --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -49,7 +49,7 @@ transform.with_pdl_patterns { #map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>) -> tensor<128x12x32xf32> { - %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32> + %0 = tensor.empty() : tensor<128x12x32xf32> // expected-remark @below {{matched complex attr}} %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index a9ec862..2b8f492 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -182,7 +182,7 @@ transform.with_pdl_patterns { func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tensor<128x12x32xf32> { // CHECK: %[[IN:.+]] = vector.transfer_read // CHECK: vector.transfer_write %[[IN]], {{.+}} permutation_map = #[[MAP]] - %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32> + %0 = tensor.empty() : tensor<128x12x32xf32> %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<12x128x32xf32>) @@ -786,7 +786,7 @@ transform.with_pdl_patterns { // CHECK-NOT: tensor.pad // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32> +// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<2x3x4xf32> // CHECK-DAG: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x3x4xf32> // CHECK: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]]{{.*}} : vector<2x3x4xf32>, tensor<2x3x4xf32> // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, false, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32> @@ -818,7 +818,7 @@ transform.with_pdl_patterns { // CHECK-NOT: tensor.pad // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 6, 4] : tensor<2x6x4xf32> +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x6x4xf32> // CHECK: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x6x4xf32> // CHECK: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<2x6x4xf32>, tensor<2x6x4xf32> // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : tensor<2x5x2xf32>, vector<2x5x2xf32> @@ -858,7 +858,7 @@ transform.with_pdl_patterns { // CHECK: %[[DIM3:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32> // CHECK: %[[V4:.*]] = arith.addi %[[DIM3]], %[[C3]] : index // CHECK: %[[V5:.*]] = arith.addi %[[V4]], %[[C2]] : index -// CHECK: %[[INIT:.*]] = linalg.init_tensor [6, %[[V1]], %[[V2]], %[[V5]]] : tensor<6x?x?x?xf32> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[V1]], %[[V2]], %[[V5]]) : tensor<6x?x?x?xf32> // CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[INIT]] : tensor<6x?x?x?xf32>) -> tensor<6x?x?x?xf32> // CHECK: %[[SRCDIM:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32> // CHECK: %[[RESULT:.*]] = tensor.insert_slice %[[SRC]] into %[[FILL]][2, %[[LOW]], 3, 3] [1, 2, 2, %[[SRCDIM]]] [1, 1, 1, 1] : tensor<1x2x2x?xf32> into tensor<6x?x?x?xf32> @@ -1192,11 +1192,11 @@ transform.with_pdl_patterns { // CHECK-LABEL: func @red_max_2d( func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32> - // CHECK: linalg.init_tensor [4] : tensor<4xf32> + // CHECK: tensor.empty() : tensor<4xf32> // CHECK: vector.multi_reduction , {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant -3.40282e+38 : f32 - %init = linalg.init_tensor [4] : tensor<4xf32> + %init = tensor.empty() : tensor<4xf32> %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -1225,12 +1225,12 @@ transform.with_pdl_patterns { // CHECK-LABEL: func @red_min_2d( func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32> - // CHECK: linalg.init_tensor [4] : tensor<4xf32> + // CHECK: tensor.empty() : tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %maxf32 = arith.constant 3.40282e+38 : f32 - %init = linalg.init_tensor [4] : tensor<4xf32> + %init = tensor.empty() : tensor<4xf32> %fill = linalg.fill ins(%maxf32 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -1258,12 +1258,12 @@ transform.with_pdl_patterns { // CHECK-LABEL: func @red_mul_2d( func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { - // CHECK: linalg.init_tensor [4] : tensor<4xf32> + // CHECK: tensor.empty() : tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant 1.0 : f32 - %init = linalg.init_tensor [4] : tensor<4xf32> + %init = tensor.empty() : tensor<4xf32> %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -1291,12 +1291,12 @@ transform.with_pdl_patterns { // CHECK-LABEL: func @red_or_2d( func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { - // CHECK: linalg.init_tensor [4] : tensor<4xi1> + // CHECK: tensor.empty() : tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false - %init = linalg.init_tensor [4] : tensor<4xi1> + %init = tensor.empty() : tensor<4xi1> %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -1324,12 +1324,12 @@ transform.with_pdl_patterns { // CHECK-LABEL: func @red_and_2d( func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { - // CHECK: linalg.init_tensor [4] : tensor<4xi1> + // CHECK: tensor.empty() : tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant true - %init = linalg.init_tensor [4] : tensor<4xi1> + %init = tensor.empty() : tensor<4xi1> %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -1357,12 +1357,12 @@ transform.with_pdl_patterns { // CHECK-LABEL: func @red_xor_2d( func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { - // CHECK: linalg.init_tensor [4] : tensor<4xi1> + // CHECK: tensor.empty() : tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false - %init = linalg.init_tensor [4] : tensor<4xi1> + %init = tensor.empty() : tensor<4xi1> %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], @@ -1397,7 +1397,7 @@ func.func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> // CHECK: subf {{.*}} : vector<4x4xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32> %c0 = arith.constant 0.0 : f32 - %init = linalg.init_tensor [4, 4] : tensor<4x4xf32> + %init = tensor.empty() : tensor<4x4xf32> %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, 0)>, @@ -1436,7 +1436,7 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32> %c0 = arith.constant 0.0 : f32 - %init = linalg.init_tensor [4] : tensor<4xf32> + %init = tensor.empty() : tensor<4xf32> %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, 0)>, @@ -1478,8 +1478,8 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %f0 = arith.constant 0.000000e+00 : f32 - // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor - %0 = linalg.init_tensor [] : tensor + // CHECK: %[[init:.*]] = tensor.empty() : tensor + %0 = tensor.empty() : tensor %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor) -> tensor // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] @@ -1524,7 +1524,7 @@ transform.with_pdl_patterns { // CHECK-LABEL: func @not_projected_permutation func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf32> { %c0 = arith.constant 0.0 : f32 - %init = linalg.init_tensor [6, 6, 3, 3] : tensor<6x6x3x3xf32> + %init = tensor.empty() : tensor<6x6x3x3xf32> %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<6x6x3x3xf32>) -> tensor<6x6x3x3xf32> // CHECK: linalg.generic %result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>, diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir index b44ffcc..bdba516 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir @@ -49,7 +49,7 @@ func.func @add_d(%arga: tensor<32xf32, #DV>, %argb: f32, %argx: tensor<32xf32>) // CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_4:.*]] = arith.constant 0 : index // CHECK: %[[VAL_5:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_INITTENSOR:.*]] = linalg.init_tensor [32] : tensor<32xf32> +// CHECK: %[[VAL_INITTENSOR:.*]] = tensor.empty() : tensor<32xf32> // CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>> to memref // CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_INITTENSOR]] : memref<32xf32> // CHECK: linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_7]] : memref<32xf32>) @@ -62,7 +62,7 @@ func.func @add_d(%arga: tensor<32xf32, #DV>, %argb: f32, %argx: tensor<32xf32>) // CHECK: return %[[VAL_11]] : tensor<32xf32> // CHECK: } func.func @add_d_init(%arga: tensor<32xf32, #DV>, %argb: f32) -> tensor<32xf32> { - %u = linalg.init_tensor [32] : tensor<32xf32> + %u = tensor.empty() : tensor<32xf32> %0 = linalg.generic #trait1 ins(%arga: tensor<32xf32, #DV>) outs(%u: tensor<32xf32>) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir index d626045..1a52180 100755 --- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir @@ -27,7 +27,7 @@ // CHECK: } func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> { %cst = arith.constant 0.000000e+00 : f64 - %0 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf64> + %0 = tensor.empty() : tensor<1024x1024xf64> %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -46,7 +46,7 @@ func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> { // CHECK: } func.func @fold_yield_direct_zero() -> tensor<32xf64> { %cst = arith.constant 0.000000e+00 : f64 - %0 = linalg.init_tensor [32] : tensor<32xf64> + %0 = tensor.empty() : tensor<32xf64> %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%0 : tensor<32xf64>) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir index 62e0602..70f54ee 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir @@ -32,7 +32,7 @@ // CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_10a:.*]] = linalg.init_tensor [8] : tensor<8xi64> +// CHECK-DAG: %[[VAL_10a:.*]] = tensor.empty() : tensor<8xi64> // CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_10a]] : memref<8xi64> // CHECK-DAG: linalg.fill ins(%[[VAL_5]] : i64) outs(%[[VAL_10]] : memref<8xi64>) // CHECK-DAG: %[[VAL_11:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref @@ -50,7 +50,7 @@ // CHECK: return %[[VAL_20]] : tensor<8xi64> // CHECK: } func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> { - %init = linalg.init_tensor [8] : tensor<8xi64> + %init = tensor.empty() : tensor<8xi64> %r = linalg.generic #trait_1d ins(%arga: tensor<8xi64, #SparseVector>) outs(%init: tensor<8xi64>) { @@ -73,7 +73,7 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8 // CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_9a:.*]] = linalg.init_tensor [8] : tensor<8xi64> +// CHECK-DAG: %[[VAL_9a:.*]] = tensor.empty() : tensor<8xi64> // CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_9a]] : memref<8xi64> // CHECK-DAG: linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_9]] : memref<8xi64>) // CHECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref @@ -112,7 +112,7 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8 // CHECK: return %[[VAL_35]] : tensor<8xi64> // CHECK: } func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> { - %init = linalg.init_tensor [8] : tensor<8xi64> + %init = tensor.empty() : tensor<8xi64> %r = linalg.generic #trait_1d ins(%arga: tensor<8xi64, #SparseVector>) outs(%init: tensor<8xi64>) { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 6d73d87..28bcabc 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1523,3 +1523,108 @@ func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tenso %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> return %1: tensor<1x2x2x4xf32> } + +// ----- + +func.func @empty_canonicalize() -> (tensor<4x5x?xf32>) { + %c6 = arith.constant 6 : index + %0 = tensor.empty(%c6) : tensor<4x5x?xf32> + return %0 : tensor<4x5x?xf32> +} +// CHECK: func @empty_canonicalize +// CHECK: %[[T0:.+]] = tensor.empty() : tensor<4x5x6xf32> +// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32> +// CHECK: return %[[T1]] + +// ----- + +func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { + %0 = tensor.empty(%arg0) : tensor<6x5x?xf32> + %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] + : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> + return %1 : tensor<2x3x5x4x?x7xf32> +} +// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> +// CHECK: func @empty_reshape_expansion +// CHECK-SAME: %[[ARG0:.+]]: index +// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] +// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]]) +// CHECK-NEXT: return %[[INIT]] + +// ----- + +func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { + %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32> + %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]] + : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> + return %1 : tensor<6x5x?xf32> +} +// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)> +// CHECK: func @empty_reshape_collapse +// CHECK-SAME: %[[ARG0:.+]]: index +// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] +// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]]) +// CHECK-NEXT: return %[[INIT]] + +// ----- + +func.func @fold_empty_tensor_with_slice + (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32> +{ + %0 = tensor.empty(%arg0) : tensor + %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1] + : tensor to tensor<5x?x20xf32> + return %1 : tensor<5x?x20xf32> +} +// CHECK: func @fold_empty_tensor_with_slice +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]]) +// CHECK: return %[[T0]] + +// ----- + +func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> { + %0 = tensor.empty(%arg0) : tensor + %1 = tensor.cast %0 : tensor to tensor<1x12xf32> + return %1 : tensor<1x12xf32> +} +// CHECK: func @fold_empty_tensor_with_cast(%[[ARG0:.+]]: index) +// CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32> +// CHECK: return %[[T0]] : tensor<1x12xf32> + +// ----- + +func.func private @some_use(%i : index, %j : index) + +// CHECK-LABEL: func @empty_tensor_canonicalize +// CHECK-SAME: %[[I:.*]]: index +func.func @empty_tensor_canonicalize(%i : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK-NOT: tensor.empty + %0 = tensor.empty(%i) : tensor + + // CHECK-NOT: tensor.dim + %1 = tensor.dim %0, %c0: tensor + %2 = tensor.dim %0, %c1: tensor + + // CHECK: %[[c42:.*]] = arith.constant 42 : index + // CHECK: call @some_use(%[[I]], %[[c42]]) + call @some_use(%1, %2) : (index, index) -> () + + return +} + +// ----- + +// CHECK-LABEL: func @rank_reducing_empty_tensor_extract +func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> { + // CHECK: tensor.empty() : tensor<2xf32> + %a = tensor.empty(%sz) : tensor + + // CHECK-NOT: extract + %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor to tensor<2xf32> + return %r: tensor<2xf32> +} diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir index 02e2502..9a02278 100644 --- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -14,7 +14,7 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3 // CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index // CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index -// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// CHECK-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32> // CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] // CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : @@ -28,7 +28,7 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3 // FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index // FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index // FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index -// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// FOREACH-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32> // FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]]) // FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] // FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : @@ -54,7 +54,7 @@ func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor< // CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index // CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index -// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] : +// CHECK: %[[init:.+]] = tensor.empty() : tensor<10x5xf32> // CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) // CHECK: %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]]) // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]] @@ -80,7 +80,7 @@ func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %si // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index -// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor +// CHECK: %[[init:.+]] = tensor.empty(%[[sz]]) : tensor // CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32> // CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32> // CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) @@ -109,7 +109,7 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0 // CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index // CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index -// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// CHECK: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor // CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : // CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : // CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : @@ -133,7 +133,7 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0 // FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index // FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index // FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index -// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// FOREACH: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor // FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : // FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : // FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : @@ -170,7 +170,7 @@ func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<30x11x100xf32> into tensor<330x100xf32> %slice = tensor.extract_slice %collapsed [0, %offt] [330, %size] [1, 1] : tensor<330x100xf32> to tensor<330x?xf32> // CHECK-NOT: scf.for - // CHECK: %[[init:.+]] = linalg.init_tensor [330, %[[arg2]]] + // CHECK: %[[init:.+]] = tensor.empty(%[[arg2]]) // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, %[[arg1]]] [30, 11, %[[arg2]]] [1, 1, 1] // CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1], [2]] // CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]] diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 6b0dfb8..d1d1c35 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -514,3 +514,11 @@ func.func @scatter_wrong_result_type( (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32> return } + +// ----- + +func.func @empty_wrong_number_of_operands(%sz : index) { + // expected-error@+1 {{incorrect number of dynamic sizes, has 1, expected 2}} + %out = tensor.empty(%sz) : tensor<2x?x?x5xf32> + return +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 436a039..f2f4645 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -13,6 +13,14 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor tensor<5x?x6xf32> { + // CHECK: tensor.empty(%[[sz]]) : tensor<5x?x6xf32> + %0 = tensor.empty(%sz) : tensor<5x?x6xf32> + return %0 : tensor<5x?x6xf32> +} + // CHECK-LABEL: func @extract( // CHECK-SAME: %[[TENSOR:.*]]: tensor, // CHECK-SAME: %[[INDEX:.*]]: index) { diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir index 9eb91e4..59688de 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns \ -// RUN: -linalg-init-tensor-to-alloc-tensor -linalg-bufferize -arith-bufferize \ +// RUN: -empty-tensor-to-alloc-tensor -linalg-bufferize -arith-bufferize \ // RUN: -bufferization-bufferize -tensor-bufferize -func-bufferize \ // RUN: -finalizing-bufferize -buffer-deallocation \ // RUN: -convert-linalg-to-loops -convert-scf-to-cf -convert-linalg-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py index 18dccff..5606035 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -39,6 +39,7 @@ from mlir.dialects import builtin from mlir.dialects import func from mlir.dialects import linalg from mlir.dialects import sparse_tensor +from mlir.dialects import tensor from mlir.dialects.linalg.opdsl import lang from . import mlir_pytaco_utils as utils @@ -899,9 +900,9 @@ class _StructOpInfo: if self.dst_format is None or self.dst_format.rank() == 0: # Initialize the dense tensor. ir_type = _mlir_type_from_taco_type(self.dst_dtype) - tensor = linalg.InitTensorOp(self.dst_dims, ir_type).result + empty = tensor.EmptyOp(self.dst_dims, ir_type).result zero = arith.ConstantOp(ir_type, 0.0) - return linalg.fill(zero, outs=[tensor]) + return linalg.fill(zero, outs=[empty]) # Initialize the sparse tensor. mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims, @@ -1194,12 +1195,12 @@ class Tensor: """ if array.dtype != np.float32 and array.dtype != np.float64: raise ValueError(f"Expected floating point value type: {array.dtype}.") - tensor = Tensor( + t = Tensor( array.shape, dtype=_nptype_to_taco_type(array.dtype.type), is_dense=True) - tensor._dense_storage = np.copy(array) - return tensor + t._dense_storage = np.copy(array) + return t @staticmethod def from_coo( @@ -1234,9 +1235,9 @@ class Tensor: # The size of each dimension is one more that such a maximum coordinate # value. shape = [c + 1 for c in max_coordinate] - tensor = Tensor(shape, fmt, dtype=dtype) - tensor._coords = coordinates - tensor._values = values + t = Tensor(shape, fmt, dtype=dtype) + t._coords = coordinates + t._values = values return tensor @@ -1261,10 +1262,10 @@ class Tensor: sparse_tensor, shape = utils.create_sparse_tensor(filename, fmt.format_pack.formats, _dtype_to_mlir_str(dtype)) - tensor = Tensor(shape.tolist(), fmt, dtype=dtype) - tensor._set_packed_sparse_tensor(sparse_tensor) + t = Tensor(shape.tolist(), fmt, dtype=dtype) + t._set_packed_sparse_tensor(sparse_tensor) - return tensor + return t def to_file(self, filename: str) -> None: """Output the tensor value to a file. diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index dd8631f..955ea9d 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -6,7 +6,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> %cst = arith.constant 0.0 : f32 %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg1, %c1 : tensor - %init = linalg.init_tensor [%d0, %d1] : tensor + %init = tensor.empty(%d0, %d1) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"} ins(%arg0, %arg1 : tensor, tensor) @@ -16,7 +16,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> // CHECK: func.func @gemm_fill_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) -// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = // CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) // CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = @@ -41,7 +41,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, %cst = arith.constant 0.0 : f32 %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg1, %c1 : tensor - %init = linalg.init_tensor [%d0, %d1] : tensor + %init = tensor.empty(%d0, %d1) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor %gemm = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) @@ -61,7 +61,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor) -// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = // CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) // CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = @@ -90,12 +90,12 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %r %cst = arith.constant 0.0 : f32 %d0 = tensor.dim %lhs0, %c0 : tensor %d1 = tensor.dim %rhs0, %c1 : tensor - %init0 = linalg.init_tensor [%d0, %d1] : tensor + %init0 = tensor.empty(%d0, %d1) : tensor %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor %gemm0 = linalg.matmul ins(%lhs0, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor %d2 = tensor.dim %rhs1, %c1 : tensor - %init1 = linalg.init_tensor [%d0, %d2] : tensor + %init1 = tensor.empty(%d0, %d2) : tensor %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"} ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor @@ -109,9 +109,9 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %r // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]] -// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]]) // CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]] -// CHECK: %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]] +// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]]) // CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] = // CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]]) // CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] @@ -140,12 +140,12 @@ func.func @gemm_transpose_fusion(%arg0 : tensor, %arg1 : tensor %d1 = tensor.dim %arg1, %c1 : tensor - %init0 = linalg.init_tensor [%d0, %d1] : tensor + %init0 = tensor.empty(%d0, %d1) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor %gemm = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%fill : tensor) -> tensor - %init1 = linalg.init_tensor [%d1, %d0] : tensor + %init1 = tensor.empty(%d1, %d0) : tensor %transpose = linalg.generic { __internal_linalg_transform__ = "fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], @@ -163,8 +163,8 @@ func.func @gemm_transpose_fusion(%arg0 : tensor, %arg1 : tensor, %arg1 : tensor %d1 = tensor.dim %arg1, %c1 : tensor %cst = arith.constant 0.0 : f32 - %0 = linalg.init_tensor [%d0, %d1] : tensor + %0 = tensor.empty(%d0, %d1) : tensor %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor %2 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) @@ -211,7 +211,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor, %arg1 : tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) -// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[INIT:.+]] = tensor.empty // CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = // CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) // CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = @@ -243,7 +243,7 @@ func.func @matmul_plus_matmul(%arg0: tensor, %arg1: tensor, outs(%arg2 : tensor) -> tensor %3 = tensor.dim %2, %c0 : tensor %4 = tensor.dim %2, %c1 : tensor - %5 = linalg.init_tensor [%3, %4] : tensor + %5 = tensor.empty(%3, %4) : tensor %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, @@ -302,7 +302,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor, %arg1: tensor) -> tensor %3 = tensor.dim %2, %c0 : tensor %4 = tensor.dim %2, %c1 : tensor - %5 = linalg.init_tensor [%3, %4] : tensor + %5 = tensor.empty(%3, %4) : tensor %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir index 4beeede..7c8fa20 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -87,8 +87,8 @@ func.func @simple_matmul_memref(%arg0 : memref, %arg1 : memref #map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) { - %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32> - %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32> + %init0 = tensor.empty() : tensor<128x300x200xf32> + %init1 = tensor.empty() : tensor<300x128x200xf32> %0:2 = linalg.generic { indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} @@ -108,8 +108,8 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x // CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index // CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index // CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index -// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200] -// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200] +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() // CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]] // CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]]) // CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]] diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index 53ce2c8..df9e62e 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -136,8 +136,8 @@ struct RewriteExtractSliceFromCollapseShapeBase // Create the destination tensor using the above values. Type elementType = op.getSourceType().getElementType(); SmallVector outputShape = getAsOpFoldResult(reifiedShapes[0]); - Value dest = rewriter.create( - op->getLoc(), outputShape, elementType); + Value dest = rewriter.create(op->getLoc(), outputShape, + elementType); // Calculate the parameters for the tile loop nest. FailureOr params = diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py index add7d6a..6dff754 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py @@ -4,6 +4,7 @@ from mlir.ir import * from mlir.dialects import builtin from mlir.dialects import func from mlir.dialects import linalg +from mlir.dialects import tensor from mlir.dialects.linalg.opdsl.lang import * @@ -50,7 +51,7 @@ with Context() as ctx, Location.unknown(): # CHECK-LABEL: func @test_matmul_mono # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32> # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32> - # CHECK: %[[INITC:.+]] = linalg.init_tensor [4, 8] : tensor<4x8xf32> + # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32> # CHECK: linalg.generic # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]] # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] @@ -59,7 +60,7 @@ with Context() as ctx, Location.unknown(): @func.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)) def test_matmul_mono(lhs, rhs): - init_result = linalg.InitTensorOp([4, 8], f32) + init_result = tensor.EmptyOp([4, 8], f32) return matmul_mono(lhs, rhs, outs=[init_result.result]) # CHECK-LABEL: @test_i8i8i32_matmul diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index fbb0552..e14ec42 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -1,6 +1,6 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.dialects import arith, builtin, func, linalg +from mlir.dialects import arith, builtin, func, linalg, tensor from mlir.dialects.linalg.opdsl.lang import * from mlir.ir import * @@ -11,46 +11,6 @@ def run(f): return f -# CHECK-LABEL: TEST: testInitTensor -@run -def testInitTensor(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - # CHECK-LABEL: func @static_sizes - # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32> - @func.FuncOp.from_py_func() - def static_sizes(): - return linalg.InitTensorOp([3, 4], f32) - - # CHECK-LABEL: func @dynamic_sizes - # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor - @func.FuncOp.from_py_func(IndexType.get(), IndexType.get()) - def dynamic_sizes(d0, d1): - return linalg.InitTensorOp([d0, d1], f32) - - # CHECK-LABEL: func @zero_d - # CHECK: %0 = linalg.init_tensor [] : tensor - @func.FuncOp.from_py_func() - def zero_d(): - return linalg.InitTensorOp([], f32) - - print(module) - - -# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute -@run -def testInitTensorStaticSizesAttribute(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - op = linalg.InitTensorOp([3, 4], f32) - # CHECK: [3, 4] - print(op.attributes["static_sizes"]) - - # CHECK-LABEL: TEST: testFill @run def testFill(): @@ -92,7 +52,7 @@ def testNamedStructuredOpCustomForm(): @func.FuncOp.from_py_func( RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)) def named_form(lhs, rhs): - init_result = linalg.InitTensorOp([4, 8], f32) + init_result = tensor.EmptyOp([4, 8], f32) # Check for the named form with custom format # CHECK: linalg.elemwise_unary # CHECK-SAME: cast = #linalg.type_fn @@ -127,7 +87,7 @@ def testNamedStructuredOpGenericForm(): RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)) def named_form(lhs, rhs): - init_result = linalg.InitTensorOp([4, 8], f32) + init_result = tensor.EmptyOp([4, 8], f32) # CHECK: "linalg.matmul"(%{{.*}}) # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32): # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 @@ -153,7 +113,7 @@ def testNamedStructuredAsGenericOp(): RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)) def generic_form(lhs, rhs): - init_result = linalg.InitTensorOp([4, 8], f32) + init_result = tensor.EmptyOp([4, 8], f32) # CHECK: linalg.generic return linalg.matmul( lhs, rhs, outs=[init_result.result], emit_generic=True) @@ -178,8 +138,8 @@ def testOpResultFromOtherOp(): lhs = linalg.fill(one, outs=[arg0]) # CHECK: %[[RHS:.*]] = linalg.fill rhs = linalg.fill(one, outs=[arg1]) - # CHECK: %[[INIT:.*]] = linalg.init_tensor - init = linalg.InitTensorOp([4, 8], f32) + # CHECK: %[[INIT:.*]] = tensor.empty + init = tensor.EmptyOp([4, 8], f32) # CHECK: linalg.matmul # CHECK: ins(%[[LHS]], %[[RHS]] # CHECK: outs(%[[INIT]] diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py index 3a47085..cb05feb 100644 --- a/mlir/test/python/dialects/tensor.py +++ b/mlir/test/python/dialects/tensor.py @@ -37,3 +37,37 @@ def testDimOp(): return [d0.result, d1.result] print(module) + + +# CHECK-LABEL: TEST: testEmptyOp +@run +def testEmptyOp(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + # CHECK-LABEL: func @static_sizes + # CHECK: %0 = tensor.empty() : tensor<3x4xf32> + @func.FuncOp.from_py_func() + def static_sizes(): + return tensor.EmptyOp([3, 4], f32) + + # CHECK-LABEL: func @dynamic_sizes + # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor + @func.FuncOp.from_py_func(IndexType.get(), IndexType.get()) + def dynamic_sizes(d0, d1): + return tensor.EmptyOp([d0, d1], f32) + + # CHECK-LABEL: func @mixed_static_dynamic_sizes + # CHECK: %0 = tensor.empty(%arg0) : tensor + @func.FuncOp.from_py_func(IndexType.get()) + def mixed_static_dynamic_sizes(d0): + return tensor.EmptyOp([d0, 4], f32) + + # CHECK-LABEL: func @zero_d + # CHECK: %0 = tensor.empty() : tensor + @func.FuncOp.from_py_func() + def zero_d(): + return tensor.EmptyOp([], f32) + + print(module) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 13f5f5a..0d89158 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5034,6 +5034,7 @@ cc_library( hdrs = ["include/mlir/Dialect/Tensor/IR/Tensor.h"], includes = ["include"], deps = [ + ":AffineDialect", ":ArithDialect", ":ArithUtils", ":CastOpInterfaces",