From: Alexander Belyaev Date: Thu, 8 Oct 2020 09:07:36 +0000 (+0200) Subject: [mlir] Add basic support for dynamic tensor results in TensorToBuffers.cpp. X-Git-Tag: llvmorg-13-init~9834 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c1fd4305b68500c754a7ce6a86fe297c36e21d3b;p=platform%2Fupstream%2Fllvm.git [mlir] Add basic support for dynamic tensor results in TensorToBuffers.cpp. The simplest case is when the indexing maps are DimIds in every component. This covers cwise ops. Also: * Expose populateConvertLinalgOnTensorsToBuffersPatterns in Transforms.h * Expose emitLoopRanges in Transforms.h Differential Revision: https://reviews.llvm.org/D88781 --- diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 2e566c9..395db39 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -16,6 +16,9 @@ #include "llvm/ADT/SmallBitVector.h" namespace mlir { + +class BufferAssignmentTypeConverter; + namespace linalg { struct LinalgFusionOptions; @@ -45,6 +48,12 @@ void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, ArrayRef tileSizes); +/// Populates the given list with patterns to convert Linalg operations on +/// tensors to buffers. +void populateConvertLinalgOnTensorsToBuffersPatterns( + MLIRContext *context, BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns); + /// Performs standalone tiling of a single LinalgOp by `tileSizes`. /// and permute the loop nest according to `interchangeVector` /// The permutation is expressed as a list of integers that specify @@ -246,6 +255,16 @@ Optional promoteSubViews(OpBuilder &b, LinalgOp op, LinalgPromotionOptions options, OperationFolder *folder = nullptr); +/// Creates a number of ranges equal to the number of dimensions in the `map`. +/// The returned ranges correspond to the loop ranges, in the proper order, for +/// which new loops will be created. +/// The function supports only maps that are invertible and have results of type +/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr). +/// It expects a non-inverted, concatenated map and last values in +/// allViewSizes will be applied to the symbols in the map if it contains any. +SmallVector emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, + ValueRange viewSizes); + /// Emit a suitable vector form for a Linalg op with fully static shape. void vectorizeLinalgOp(OpBuilder &builder, Operation *op); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 9e96c8c..b95469d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -58,77 +58,6 @@ static SmallVector permuteIvs(ArrayRef ivs, : SmallVector(ivs.begin(), ivs.end()); } -/// Creates a number of ranges equal to the number of dimensions in the `map`. -/// The returned ranges correspond to the loop ranges, in the proper order, for -/// which new loops will be created. -/// The function supports only maps that are invertible and have results of type -/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr). -/// It expects a non-inverted, concatenated map and last values in -/// allViewSizes will be applied to the symbols in the map if it contains any. -static SmallVector emitLoopRanges(OpBuilder &b, Location loc, - AffineMap map, - ValueRange viewSizes) { - unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); - unsigned numSym = map.getNumSymbols(); - assert(viewSizes.size() == numRes + numSym && - "viewSizes must contain sizes of all views and values for symbols"); - SmallVector res(numDims); - for (unsigned idx = 0; idx < numRes; ++idx) { - auto result = map.getResult(idx); - if (auto d = result.dyn_cast()) { - if (res[d.getPosition()].offset) - continue; - res[d.getPosition()] = - Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)}; - } - - // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2), - // then the bounds are: - // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1). - // where size(n) is applied to the symbol s. - // This is done statically now. - if (auto binOp = result.dyn_cast()) { - auto lhs = binOp.getLHS().dyn_cast(); - auto rhs = binOp.getRHS().dyn_cast(); - if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add || - lhs.getKind() != AffineExprKind::Add || - rhs.getKind() != mlir::AffineExprKind::Mul) - continue; - - auto m = lhs.getLHS().dyn_cast(); - auto n = lhs.getRHS().dyn_cast(); - auto fDiv = rhs.getLHS().dyn_cast(); - auto minusOne = rhs.getRHS().dyn_cast(); - if (!m || !n || !fDiv || !minusOne || - fDiv.getKind() != AffineExprKind::FloorDiv || - fDiv.getLHS().getKind() != AffineExprKind::SymbolId || - fDiv.getRHS().getKind() != AffineExprKind::Constant) - continue; - - auto s = fDiv.getLHS().dyn_cast(); - if (minusOne.getValue() != -1) - continue; - - int mPos = m.getPosition(); - AffineExpr one = getAffineConstantExpr(1, s.getContext()); - AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext()); - // Construction of upper bound (size(m) + s floordiv 2 - s + 1). - AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s; - AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv); - AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr); - SmallVector values(viewSizes.begin(), - viewSizes.begin() + numDims); - values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end()); - values.push_back(viewSizes[mPos]); - // Construction of the lower bound (s floordiv 2). - Value from = applyMapToValues(b, loc, fromMap, values).front(); - Value to = applyMapToValues(b, loc, toMap, values).front(); - res[mPos] = Range{from, to, std_constant_index(1)}; - } - } - return res; -} - template static void inlineRegionAndEmitStore(OpType op, ArrayRef indexedValues, ArrayRef> indexing, @@ -708,6 +637,70 @@ static Optional linalgOpToLoopsImplSwitch(Operation *op, llvm_unreachable("Unexpected op in linalgOpToLoopsImpl"); } +SmallVector mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc, + AffineMap map, + ValueRange viewSizes) { + unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); + unsigned numSym = map.getNumSymbols(); + assert(viewSizes.size() == numRes + numSym && + "viewSizes must contain sizes of all views and values for symbols"); + SmallVector res(numDims); + for (unsigned idx = 0; idx < numRes; ++idx) { + auto result = map.getResult(idx); + if (auto d = result.dyn_cast()) { + if (res[d.getPosition()].offset) + continue; + res[d.getPosition()] = + Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)}; + } + + // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2), + // then the bounds are: + // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1). + // where size(n) is applied to the symbol s. + // This is done statically now. + if (auto binOp = result.dyn_cast()) { + auto lhs = binOp.getLHS().dyn_cast(); + auto rhs = binOp.getRHS().dyn_cast(); + if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add || + lhs.getKind() != AffineExprKind::Add || + rhs.getKind() != mlir::AffineExprKind::Mul) + continue; + + auto m = lhs.getLHS().dyn_cast(); + auto n = lhs.getRHS().dyn_cast(); + auto fDiv = rhs.getLHS().dyn_cast(); + auto minusOne = rhs.getRHS().dyn_cast(); + if (!m || !n || !fDiv || !minusOne || + fDiv.getKind() != AffineExprKind::FloorDiv || + fDiv.getLHS().getKind() != AffineExprKind::SymbolId || + fDiv.getRHS().getKind() != AffineExprKind::Constant) + continue; + + auto s = fDiv.getLHS().dyn_cast(); + if (minusOne.getValue() != -1) + continue; + + int mPos = m.getPosition(); + AffineExpr one = getAffineConstantExpr(1, s.getContext()); + AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext()); + // Construction of upper bound (size(m) + s floordiv 2 - s + 1). + AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s; + AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv); + AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr); + SmallVector values(viewSizes.begin(), + viewSizes.begin() + numDims); + values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end()); + values.push_back(viewSizes[mPos]); + // Construction of the lower bound (s floordiv 2). + Value from = applyMapToValues(b, loc, fromMap, values).front(); + Value to = applyMapToValues(b, loc, toMap, values).front(); + res[mPos] = Range{from, to, std_constant_index(1)}; + } + } + return res; +} + /// Emits a loop nest with the proper body for `op`. template Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp index b714a1f..3282358 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -14,14 +14,119 @@ #include "PassDetail.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/BufferPlacement.h" -using namespace mlir; - namespace { + +using namespace ::mlir; +using namespace ::mlir::linalg; + +SmallVector +computeLoopRanges(Location loc, linalg::GenericOp linalgOp, OpBuilder *b) { + auto indexingMaps = llvm::to_vector<4>( + linalgOp.indexing_maps().getAsValueRange()); + auto inputIndexingMaps = + llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs()); + + mlir::edsc::ScopedContext scope(*b, loc); + return emitLoopRanges(scope.getBuilderRef(), loc, + concatAffineMaps(inputIndexingMaps), + getShape(*b, linalgOp)); +} + +Value maybeConvertToIndex(Location loc, Value val, OpBuilder *b) { + if (val.getType().isIndex()) + return val; + return b->create(loc, val, b->getIndexType()); +} + +LogicalResult allocateBuffersForResults(Location loc, + linalg::GenericOp linalgOp, + linalg::GenericOpAdaptor &adaptor, + SmallVectorImpl *resultBuffers, + OpBuilder *b) { + // Lazily compute loopRanges. + SmallVector loopRanges; + + // Allocate a buffer for every tensor result. + for (auto en : llvm::enumerate(linalgOp.getResultTypes())) { + size_t resultIndex = en.index(); + Type resultType = en.value(); + + auto tensorType = resultType.dyn_cast(); + if (tensorType == nullptr) { + linalgOp.emitOpError() + << "tensor to buffer conversion expects ranked tensor results"; + return failure(); + } + auto tensorShape = tensorType.getShape(); + auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); + + // Allocate buffers for init tensors that are assumed to fold onto the first + // results. + // TODO: update this assumption because the reality is more complex + // under linalg on tensor based transformations. + bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors(); + if (foldedInitTensor) { + // Dealing with an init tensor requires distinguishing between 1-use + // and many-use cases which would create aliasing and WAR hazards. + Value initTensor = linalgOp.getInitTensor(resultIndex); + Value initBuffer = adaptor.init_tensors()[resultIndex]; + if (initTensor.hasOneUse()) { + resultBuffers->push_back(initBuffer); + continue; + } + SmallVector dynOperands; + for (auto dim : llvm::enumerate(tensorShape)) { + if (dim.value() == TensorType::kDynamicSize) { + dynOperands.push_back(b->create(loc, initTensor, dim.index())); + } + } + auto alloc = b->create(loc, memrefType, dynOperands); + b->create(loc, initBuffer, alloc); + resultBuffers->push_back(alloc); + continue; + } + + // Allocate buffers for statically-shaped results. + if (memrefType.hasStaticShape()) { + resultBuffers->push_back(b->create(loc, memrefType)); + continue; + } + + // Perform a naive shape inference for the dynamically-shaped results. + // Extract the required element out of the vector. + SmallVector dynOperands; + auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex); + for (auto shapeElement : llvm::enumerate(tensorType.getShape())) { + if (loopRanges.empty()) + loopRanges = computeLoopRanges(loc, linalgOp, b); + + if (shapeElement.value() != ShapedType::kDynamicSize) + continue; + + AffineExpr expr = resultIndexingMap.getResult(shapeElement.index()); + switch (expr.getKind()) { + case AffineExprKind::DimId: { + int64_t loopIndex = expr.cast().getPosition(); + Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b); + dynOperands.push_back(size); + break; + } + default: + return failure(); + } + } + resultBuffers->push_back(b->create(loc, memrefType, dynOperands)); + } + return success(); +} + /// A pattern to convert Generic Linalg operations which work on tensors to /// use buffers. A buffer is allocated using BufferAssignmentPlacer for /// each operation result. BufferPlacement pass should be later used to move @@ -34,10 +139,10 @@ public: linalg::GenericOp>::BufferAssignmentOpConversionPattern; LogicalResult - matchAndRewrite(linalg::GenericOp op, ArrayRef operands, + matchAndRewrite(linalg::GenericOp linalgOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - linalg::GenericOpAdaptor adaptor(operands, - op.getOperation()->getAttrDictionary()); + linalg::GenericOpAdaptor adaptor( + operands, linalgOp.getOperation()->getAttrDictionary()); // All inputs need to be turned into buffers first. Until then, bail out. if (llvm::any_of(adaptor.inputs(), @@ -50,93 +155,54 @@ public: [](Value in) { return !in.getType().isa(); })) return failure(); - Location loc = op.getLoc(); - SmallVector newOutputBuffers; - newOutputBuffers.reserve(op.getNumOutputs()); - newOutputBuffers.append(adaptor.output_buffers().begin(), - adaptor.output_buffers().end()); - - // Update all types to memref types. - // Assume the init tensors fold onto the first results. - // TODO: update this assumption because the reality is more complex under - // linalg on tensor based transformations. - for (auto en : llvm::enumerate(op.getResultTypes())) { - auto type = en.value().cast(); - if (!type.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "dynamic shapes not currently supported"); - auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); - bool foldedInitTensor = en.index() < op.getNumInitTensors(); - if (foldedInitTensor) { - // Dealing with an init tensor requires distinguishing between 1-use - // and many-use cases which would create aliasing and WAR hazards. - Value initTensor = op.getInitTensor(en.index()); - Value initBuffer = adaptor.init_tensors()[en.index()]; - if (initTensor.hasOneUse()) { - newOutputBuffers.push_back(initBuffer); - continue; - } - auto alloc = rewriter.create(loc, memrefType); - rewriter.create(loc, initBuffer, alloc); - newOutputBuffers.push_back(alloc); - } else { - auto alloc = rewriter.create(loc, memrefType); - newOutputBuffers.push_back(alloc); - } + Location loc = linalgOp.getLoc(); + SmallVector newOutputBuffers(adaptor.output_buffers().begin(), + adaptor.output_buffers().end()); + + if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, + &newOutputBuffers, &rewriter))) { + linalgOp.emitOpError() + << "Failed to allocate buffers for tensor results."; + return failure(); } // Generate a new linalg operation that works on buffers. - auto linalgOp = rewriter.create( + auto newLinalgOp = rewriter.create( loc, - /*resultTensorTypes=*/ArrayRef{}, + /*resultTensorTypes=*/llvm::None, /*inputs=*/adaptor.inputs(), /*outputBuffers=*/newOutputBuffers, - /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(), - op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr()); + /*initTensors=*/llvm::None, linalgOp.indexing_maps(), + linalgOp.iterator_types(), linalgOp.docAttr(), + linalgOp.library_callAttr(), linalgOp.symbol_sourceAttr()); // Create a new block in the region of the new Generic Op. - Block &oldBlock = op.getRegion().front(); - Region &newRegion = linalgOp.region(); + Block *oldBlock = linalgOp.getBody(); + Region &newRegion = newLinalgOp.region(); Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), - oldBlock.getArgumentTypes()); - - // Add the result arguments that do not come from init_tensors to the new - // block. - // TODO: update this assumption because the reality is more complex under - // linalg on tensor based transformations. - for (Value v : - ValueRange(newOutputBuffers).drop_front(adaptor.init_tensors().size())) + oldBlock->getArgumentTypes()); + + // Add the result arguments to the new block. + for (Value v : newOutputBuffers) newBlock->addArgument(v.getType().cast().getElementType()); // Clone the body of the old block to the new block. BlockAndValueMapping mapping; - for (unsigned i = 0; i < oldBlock.getNumArguments(); i++) - mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i)); + mapping.map(oldBlock->getArguments(), newBlock->getArguments()); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(newBlock); - for (auto &op : oldBlock.getOperations()) { + for (auto &op : oldBlock->getOperations()) { Operation *clonedOp = rewriter.clone(op, mapping); mapping.map(op.getResults(), clonedOp->getResults()); } // Replace the results of the old op with the new output buffers. - rewriter.replaceOp(op, newOutputBuffers); + rewriter.replaceOp(linalgOp, newOutputBuffers); return success(); } }; -/// Populate the given list with patterns to convert Linalg operations on -/// tensors to buffers. -static void populateConvertLinalgOnTensorsToBuffersPattern( - MLIRContext *context, BufferAssignmentTypeConverter *converter, - OwningRewritePatternList *patterns) { - populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter, - patterns); - patterns->insert(context, converter); -} - /// Converts Linalg operations that work on tensor-type operands or results to /// work on buffers. struct ConvertLinalgOnTensorsToBuffers @@ -176,8 +242,11 @@ struct ConvertLinalgOnTensorsToBuffers BufferAssignmentTypeConverter::AppendToArgumentsList); OwningRewritePatternList patterns; - populateConvertLinalgOnTensorsToBuffersPattern(&context, &converter, - &patterns); + populateConvertLinalgOnTensorsToBuffersPatterns(&context, &converter, + &patterns); + populateWithBufferAssignmentOpConversionPatterns< + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, &converter, + &patterns); if (failed(applyFullConversion(this->getOperation(), target, patterns))) this->signalPassFailure(); } @@ -188,3 +257,9 @@ std::unique_ptr> mlir::createConvertLinalgOnTensorsToBuffersPass() { return std::make_unique(); } + +void mlir::linalg::populateConvertLinalgOnTensorsToBuffersPatterns( + MLIRContext *context, BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { + patterns->insert(context, converter); +} diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir index 654a13f..4339b33 100644 --- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir +++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir @@ -2,11 +2,13 @@ #map0 = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func @multiple_results_generic_op -func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0, %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} - ins(%arg0 : tensor<4xf32>) { - ^bb0(%gen_arg1: f32): +// CHECK-LABEL: func @multiple_results +func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0, %1 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel"] + } ins(%arg0 : tensor<4xf32>) { + ^bb0(%gen_arg1: f32): %tmp1 = exp %gen_arg1 : f32 linalg.yield %tmp1, %tmp1 : f32, f32 } -> tensor<4xf32>, tensor<4xf32> @@ -34,15 +36,20 @@ func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tenso // CHECK-LABEL: func @chained_operations func @chained_operations(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} - ins(%arg0 : tensor<4xf32>) { - ^bb0(%gen_arg1: f32): + %0 = linalg.generic { + indexing_maps = [#map0, #map0], + iterator_types = ["parallel"] + } ins(%arg0 : tensor<4xf32>) { + ^bb0(%gen_arg1: f32): %tmp1 = exp %gen_arg1 : f32 linalg.yield %tmp1 : f32 } -> tensor<4xf32> - %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} - ins(%0 : tensor<4xf32>) { - ^bb0(%gen_arg2: f32): + + %1 = linalg.generic { + indexing_maps = [#map0, #map0], + iterator_types = ["parallel"] + } ins(%0 : tensor<4xf32>) { + ^bb0(%gen_arg2: f32): %tmp2 = exp %gen_arg2 : f32 linalg.yield %tmp2 : f32 } -> tensor<4xf32> @@ -73,6 +80,46 @@ func @no_linalg_op(%arg0: f32) -> (f32, f32) { %0 = mulf %arg0, %arg0 : f32 return %0, %0 : f32, f32 } -// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]]) -// CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]] -// CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]] +// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]]) +// CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]] +// CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]] + +// ----- + +#map_2d = affine_map<(d0, d1) -> (d0, d1)> +#map_2d_inv = affine_map<(d0, d1) -> (d1, d0)> + +func @dynamic_results(%arg0: tensor) + -> (tensor, tensor) { + %0, %1 = linalg.generic { + indexing_maps = [#map_2d, #map_2d, #map_2d_inv], + iterator_types = ["parallel", "parallel"] + } ins(%arg0 : tensor) { + ^bb0(%gen_arg1: f32): + %tmp1 = exp %gen_arg1 : f32 + linalg.yield %tmp1, %tmp1 : f32, f32 + } -> tensor, tensor + return %0, %1 : tensor, tensor +} + +// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #map1 = affine_map<(d0, d1) -> (d1, d0)> + +// CHECK-LABEL: func @dynamic_results +// CHECK-SAME: (%[[INPUT:.*]]: [[TYPE:.*]], %[[OUT_1:.*]]: [[TYPE]], %[[OUT_2:.*]]: [[TYPE]]) { +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[DIM_0:.*]] = dim %[[INPUT]], %[[C0]] : [[TYPE]] +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[DIM_1:.*]] = dim %[[INPUT]], %[[C1]] : [[TYPE]] +// CHECK: %[[OUT_BUF_1:.*]] = alloc(%[[DIM_0]], %[[DIM_1]]) : [[TYPE]] +// CHECK: %[[OUT_BUF_2:.*]] = alloc(%[[DIM_1]], %[[DIM_0]]) : [[TYPE]] + +// CHECK: linalg.generic {indexing_maps = [#map0, #map0, #map1], {{.*}}} +// CHECK-SAME: ins(%[[INPUT]] : [[TYPE]]) +// CHECK-SAME: outs(%[[OUT_BUF_1]], %[[OUT_BUF_2]] : [[TYPE]], [[TYPE]]) { + +// CHECK: linalg.copy(%[[OUT_BUF_1]], %[[OUT_1]]) : [[TYPE]], [[TYPE]] +// CHECK: dealloc %[[OUT_BUF_1]] : [[TYPE]] +// CHECK: linalg.copy(%[[OUT_BUF_2]], %[[OUT_2]]) : [[TYPE]], [[TYPE]] +// CHECK: dealloc %[[OUT_BUF_2]] : [[TYPE]] +// CHECK: return