From a5cee3e386bde28ce21ff2ead3fc420f018604ca Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Mon, 17 Jul 2023 16:14:16 -0700 Subject: [PATCH] [mlir][linalg] Add a padding case for `ComplexType` If the paddingAttr is an ArrayAttr with two values we know that the element type is a `ComplexType` and we should pad the value accordingly. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D154908 --- mlir/lib/Dialect/Linalg/Transforms/Padding.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index fe720aa..f87fbbe 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" @@ -125,8 +126,17 @@ static FailureOr padOperandToSmallestStaticBoundingBox( return rewriter.notifyMatchFailure(opToPad, "--no padding value specified"); } Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()]; - Value paddingValue = rewriter.create( - opToPad.getLoc(), cast(paddingAttr)); + + Value paddingValue; + if (auto complexTy = dyn_cast( + getElementTypeOrSelf(opOperand->get().getType()))) { + auto complexAttr = cast(paddingAttr); + paddingValue = rewriter.create(opToPad.getLoc(), + complexTy, complexAttr); + } else { + paddingValue = rewriter.create( + opToPad.getLoc(), cast(paddingAttr)); + } // Pad the operand to the bounding box defined by `paddedShape`. auto paddedTensorType = RankedTensorType::get( -- 2.7.4