From 346830051105a849d7fc3ceb246e65acbc0264ae Mon Sep 17 00:00:00 2001 From: Ehsan Toosi Date: Mon, 4 May 2020 16:06:59 +0200 Subject: [PATCH] [MLIR] Update the FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter of Buffer Assignment Making these two converters more generic. FunctionAndBlockSignatureConverter now moves only memref results (after type conversion) to the function argument and keeps other legal function results unchanged. NonVoidToVoidReturnOpConverter is renamed to NoBufferOperandsReturnOpConverter. It removes only the buffer operands from the operands of the converted ReturnOp and inserts CopyOps to copy each buffer to the target function argument. Differential Revision: https://reviews.llvm.org/D79329 --- mlir/include/mlir/Transforms/BufferPlacement.h | 66 ++++++++++++---------- .../Dialect/Linalg/Transforms/TensorsToBuffers.cpp | 28 +-------- mlir/lib/Transforms/BufferPlacement.cpp | 37 +++++++++--- ...tion.mlir => buffer-placement-preparation.mlir} | 40 ++++++++++++- mlir/test/lib/Transforms/TestBufferPlacement.cpp | 17 +++--- 5 files changed, 117 insertions(+), 71 deletions(-) rename mlir/test/Transforms/{buffer-placement-prepration.mlir => buffer-placement-preparation.mlir} (80%) diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h index 013a55f..030b875 100644 --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -76,12 +76,11 @@ protected: TypeConverter *converter; }; -/// This conversion adds an extra argument for each function result which makes -/// the converted function a void function. A type converter must be provided -/// for this conversion to convert a non-shaped type to memref. -/// BufferAssignmentTypeConverter is an helper TypeConverter for this -/// purpose. All the non-shaped type of the input function will be converted to -/// memref. +/// Converts the signature of the function using the type converter. +/// It adds an extra argument for each illegally-typed function +/// result to the function arguments. `BufferAssignmentTypeConverter` +/// is a helper `TypeConverter` for this purpose. All the non-shaped types +/// of the input function will be converted to memref. class FunctionAndBlockSignatureConverter : public BufferAssignmentOpConversionPattern { public: @@ -94,12 +93,12 @@ public: ConversionPatternRewriter &rewriter) const final; }; -/// This pattern converter transforms a non-void ReturnOpSourceTy into a void -/// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy -/// to copy the results to the output buffer. +/// Converts the source `ReturnOp` to target `ReturnOp`, removes all +/// the buffer operands from the operands list, and inserts `CopyOp`s +/// for all buffer operands instead. template -class NonVoidToVoidReturnOpConverter +class NoBufferOperandsReturnOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< @@ -109,29 +108,38 @@ public: LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - unsigned numReturnValues = returnOp.getNumOperands(); Block &entryBlock = returnOp.getParentRegion()->front(); unsigned numFuncArgs = entryBlock.getNumArguments(); Location loc = returnOp.getLoc(); - // Find the corresponding output buffer for each operand. - assert(numReturnValues <= numFuncArgs && - "The number of operands of return operation is more than the " - "number of function argument."); - unsigned firstReturnParameter = numFuncArgs - numReturnValues; - for (auto operand : llvm::enumerate(operands)) { - unsigned returnArgNumber = firstReturnParameter + operand.index(); - BlockArgument dstBuffer = entryBlock.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) - continue; - - // Insert the copy operation to copy before the return. - rewriter.setInsertionPoint(returnOp); - rewriter.create(loc, operand.value(), - entryBlock.getArgument(returnArgNumber)); - } - // Insert the new target return operation. - rewriter.replaceOpWithNewOp(returnOp); + // The target `ReturnOp` should not contain any memref operands. + SmallVector newOperands(operands.begin(), operands.end()); + llvm::erase_if(newOperands, [](Value operand) { + return operand.getType().isa(); + }); + + // Find the index of the first destination buffer. + unsigned numBufferOperands = operands.size() - newOperands.size(); + unsigned destArgNum = numFuncArgs - numBufferOperands; + + rewriter.setInsertionPoint(returnOp); + // Find the corresponding destination buffer for each memref operand. + for (Value operand : operands) + if (operand.getType().isa()) { + assert(destArgNum < numFuncArgs && + "The number of operands of return operation is more than the " + "number of function argument."); + + // For each memref type operand of the source `ReturnOp`, a new `CopyOp` + // is inserted that copies the buffer content from the operand to the + // target. + rewriter.create(loc, operand, + entryBlock.getArgument(destArgNum)); + ++destArgNum; + } + + // Insert the new target Return operation. + rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp index 9350101..9b5855d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -21,8 +21,8 @@ using namespace mlir; using ReturnOpConverter = - NonVoidToVoidReturnOpConverter; + NoBufferOperandsReturnOpConverter; namespace { /// A pattern to convert Generic Linalg operations which work on tensors to @@ -132,30 +132,6 @@ struct ConvertLinalgOnTensorsToBuffers Optional( isLegalOperation)); - // TODO: Considering the following dynamic legality checks, the current - // implementation of FunctionAndBlockSignatureConverter of Buffer Assignment - // will convert the function signature incorrectly. This converter moves - // all the return values of the function to the input argument list without - // considering the return value types and creates a void function. However, - // the NonVoidToVoidReturnOpConverter doesn't change the return operation if - // its operands are not tensors. The following example leaves the IR in a - // broken state. - // - // @function(%arg0: f32, %arg1: tensor<4xf32>) -> (f32, f32) { - // %0 = mulf %arg0, %arg0 : f32 - // return %0, %0 : f32, f32 - // } - // - // broken IR after conversion: - // - // func @function(%arg0: f32, %arg1: memref<4xf32>, f32, f32) { - // %0 = mulf %arg0, %arg0 : f32 - // return %0, %0 : f32, f32 - // } - // - // This issue must be fixed in FunctionAndBlockSignatureConverter and - // NonVoidToVoidReturnOpConverter. - // Mark Standard Return operations illegal as long as one operand is tensor. target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { return llvm::none_of(returnOp.getOperandTypes(), isIllegalType); diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp index 24c228e..cd0641c 100644 --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -43,7 +43,8 @@ // The current implementation does not support loops and the resulting code will // be invalid with respect to program semantics. The only thing that is // currently missing is a high-level loop analysis that allows us to move allocs -// and deallocs outside of the loop blocks. +// and deallocs outside of the loop blocks. Furthermore, it doesn't also accept +// functions which return buffers already. // //===----------------------------------------------------------------------===// @@ -429,19 +430,39 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite( "FunctionAndBlockSignatureConverter"); return failure(); } - // Converting shaped type arguments to memref type. auto funcType = funcOp.getType(); + TypeRange resultTypes = funcType.getResults(); + if (llvm::any_of(resultTypes, + [](Type type) { return type.isa(); })) + return funcOp.emitError("BufferAssignmentPlacer doesn't currently support " + "functions which return memref typed values"); + + // Convert function arguments using the provided TypeConverter. TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); for (auto argType : llvm::enumerate(funcType.getInputs())) conversion.addInputs(argType.index(), converter->convertType(argType.value())); - // Adding function results to the arguments of the converted function as - // memref type. The converted function will be a void function. - for (Type resType : funcType.getResults()) - conversion.addInputs(converter->convertType((resType))); + + // Adding a function argument for each function result which is going to be a + // memref type after type conversion. + SmallVector newResultTypes; + newResultTypes.reserve(funcOp.getNumResults()); + for (Type resType : resultTypes) { + Type convertedType = converter->convertType(resType); + + // If the result type is memref after the type conversion, a new argument + // should be appended to the function arguments list for this result. + // Otherwise, it remains unchanged as a function result. + if (convertedType.isa()) + conversion.addInputs(convertedType); + else + newResultTypes.push_back(convertedType); + } + + // Update the signature of the function. rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); + funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); rewriter.applySignatureConversion(&funcOp.getBody(), conversion); }); return success(); diff --git a/mlir/test/Transforms/buffer-placement-prepration.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir similarity index 80% rename from mlir/test/Transforms/buffer-placement-prepration.mlir rename to mlir/test/Transforms/buffer-placement-preparation.mlir index 7621253..ef7a2e3 100644 --- a/mlir/test/Transforms/buffer-placement-prepration.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file %s | FileCheck %s -dump-input-on-failure +// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file -verify-diagnostics %s | FileCheck %s -dump-input-on-failure // CHECK-LABEL: func @func_signature_conversion func @func_signature_conversion(%arg0: tensor<4x8xf32>) { @@ -8,6 +8,44 @@ func @func_signature_conversion(%arg0: tensor<4x8xf32>) { // ----- +// expected-error @below {{BufferAssignmentPlacer doesn't currently support functions which return memref typed values}} +// expected-error @below {{failed to legalize operation 'func'}} +func @memref_in_function_results(%arg0: tensor<4x8xf32>) -> (tensor<4x8xf32>, memref<5xf32>) { + %0 = alloc() : memref<5xf32> + return %arg0, %0 : tensor<4x8xf32>, memref<5xf32> +} + +// ----- + +// CHECK-LABEL: func @no_signature_conversion_is_needed +func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) { + return +} +// CHECK: ({{.*}}: memref<4x8xf32>) { + +// ----- + +// CHECK-LABEL: func @no_signature_conversion_is_needed +func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){ + return %arg0, %arg1 : i1, f16 +} +// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16) +// CHECK: return %[[ARG0]], %[[ARG1]] + +// ----- + +// CHECK-LABEL: func @complex_signature_conversion +func @complex_signature_conversion(%arg0: tensor<4x8xf32>, %arg1: i1, %arg2: tensor<5x5xf64>,%arg3: f16) -> (i1, tensor<5x5xf64>, f16, tensor<4x8xf32>) { + return %arg1, %arg2, %arg3, %arg0 : i1, tensor<5x5xf64>, f16, tensor<4x8xf32> +} +// CHECK: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5x5xf64>, %[[ARG3:.*]]: f16, +// CHECK-SAME: %[[RESULT1:.*]]: memref<5x5xf64>, %[[RESULT2:.*]]: memref<4x8xf32>) -> (i1, f16) { +// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) +// CHECK-NEXT: linalg.copy(%[[ARG0]], %[[RESULT2]]) +// CHECK-NEXT: return %[[ARG1]], %[[ARG3]] + +// ----- + // CHECK-LABEL: func @non_void_to_void_return_op_converter func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { return %arg0 : tensor<4x8xf32> diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp index 03c6a2a..2d781e6 100644 --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -23,7 +23,7 @@ using namespace mlir; namespace { /// This pass tests the computeAllocPosition helper method and two provided /// operation converters, FunctionAndBlockSignatureConverter and -/// NonVoidToVoidReturnOpConverter. Furthermore, this pass converts linalg +/// NoBufferOperandsReturnOpConverter. Furthermore, this pass converts linalg /// operations on tensors to linalg operations on buffers to prepare them for /// the BufferPlacement pass that can be applied afterwards. struct TestBufferPlacementPreparationPass @@ -82,7 +82,6 @@ struct TestBufferPlacementPreparationPass auto type = result.getType().cast(); entryBlock.addArgument(type.getElementType()); } - rewriter.eraseOp(op); return success(); } @@ -95,7 +94,7 @@ struct TestBufferPlacementPreparationPass patterns->insert< FunctionAndBlockSignatureConverter, GenericOpConverter, - NonVoidToVoidReturnOpConverter< + NoBufferOperandsReturnOpConverter< ReturnOp, ReturnOp, linalg::CopyOp> >(context, placer, converter); // clang-format on @@ -105,8 +104,9 @@ struct TestBufferPlacementPreparationPass auto &context = getContext(); ConversionTarget target(context); BufferAssignmentTypeConverter converter; - // Make all linalg operations illegal as long as they work on tensors. target.addLegalDialect(); + + // Make all linalg operations illegal as long as they work on tensors. target.addDynamicallyLegalDialect( Optional( [&](Operation *op) { @@ -117,9 +117,12 @@ struct TestBufferPlacementPreparationPass llvm::none_of(op->getResultTypes(), isIllegalType); })); - // Mark return operations illegal as long as they return values. - target.addDynamicallyLegalOp( - [](mlir::ReturnOp returnOp) { return returnOp.getNumOperands() == 0; }); + // Mark std.ReturnOp illegal as long as an operand is tensor or buffer. + target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { + return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) { + return type.isa() || !converter.isLegal(type); + }); + }); // Mark the function whose arguments are in tensor-type illegal. target.addDynamicallyLegalOp([&](FuncOp funcOp) { -- 2.7.4