From 286e7bdd3ea4f1c7a90a2877e28f353dcd9a7493 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 22 Jul 2021 17:22:11 -0700 Subject: [PATCH] [mlir][tosa] Make tosa MakeBroadcastable pass handle unreanked tensors. If this pass executes without shape inference its possible for unranked tensors to appear in the IR. This pass should gracefully handle unranked tensors. Differential Revision: https://reviews.llvm.org/D106617 --- .../Tosa/Transforms/TosaMakeBroadcastable.cpp | 51 ++++++++++++++-------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp index e850e1f..98df911 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -108,18 +108,24 @@ static void computeReshapeOutput(ArrayRef higherRankShape, /// operations equal. Returns the updated input1 and input2 for the original /// input. The caller is expected to use these to rewrite the original operator /// with the RESHAPE now in the graph. -static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, - RankedTensorType outputType, Value input1, - Value input2, Value &outInput1, - Value &outInput2) { +static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, + Location loc, + RankedTensorType outputType, + Value input1, Value input2, + Value &outInput1, Value &outInput2) { + auto input1Ty = input1.getType().dyn_cast(); + auto input2Ty = input2.getType().dyn_cast(); - int64_t input1Rank = input1.getType().cast().getRank(); - int64_t input2Rank = input2.getType().cast().getRank(); + if (!input1Ty || !input2Ty) + return failure(); + + int64_t input1Rank = input1Ty.getRank(); + int64_t input2Rank = input2Ty.getRank(); Value higherTensorValue, lowerTensorValue; - // return if rank already match + // Cannot rewrite as its already correct. if (input1Rank == input2Rank) - return 1; + return failure(); if (input1Rank > input2Rank) { higherTensorValue = input1; @@ -129,24 +135,27 @@ static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, lowerTensorValue = input1; } - ArrayRef outputRankShape = outputType.getShape(); ArrayRef higherRankShape = higherTensorValue.getType().cast().getShape(); (void)higherRankShape; ArrayRef lowerRankShape = lowerTensorValue.getType().cast().getShape(); - // outputRank == higherRank == max(input1Rank, input2Rank) - assert(higherRankShape.size() == outputRankShape.size()); - SmallVector reshapeOutputShape; - computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape); + computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape); auto reshapeInputType = lowerTensorValue.getType().cast(); auto reshapeOutputType = RankedTensorType::get( ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); + // Verify the rank agrees with the output type if the output type is ranked. + if (outputType) { + if (outputType.getShape().size() != reshapeOutputShape.size() || + outputType.getShape().size() != higherRankShape.size()) + return failure(); + } + auto reshapeLower = rewriter.create( loc, reshapeOutputType, lowerTensorValue, rewriter.getI64ArrayAttr(reshapeOutputShape)); @@ -159,7 +168,7 @@ static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, outInput2 = higherTensorValue; } - return 0; + return success(); } namespace { @@ -173,11 +182,13 @@ struct ConvertTosaOp : public OpRewritePattern { Value input1 = tosaBinaryOp.input1(); Value input2 = tosaBinaryOp.input2(); Value output = tosaBinaryOp.getResult(); - auto outputType = output.getType().cast(); + + auto outputType = output.getType().dyn_cast(); Value outInput1, outInput2; if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, - input1, input2, outInput1, outInput2)) + input1, input2, outInput1, outInput2) + .failed()) return failure(); rewriter.replaceOpWithNewOp(tosaBinaryOp, outputType, outInput1, @@ -200,11 +211,12 @@ struct ConvertTosaOp : public OpRewritePattern { Value input2 = tosaBinaryOp.input2(); int32_t shift = tosaBinaryOp.shift(); Value output = tosaBinaryOp.getResult(); - auto outputType = output.getType().cast(); + auto outputType = output.getType().dyn_cast(); Value outInput1, outInput2; if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, - input1, input2, outInput1, outInput2)) + input1, input2, outInput1, outInput2) + .failed()) return failure(); rewriter.replaceOpWithNewOp(tosaBinaryOp, outputType, @@ -233,7 +245,8 @@ struct ConvertTosaOp Value outInput1, outInput2; if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, - input1, input2, outInput1, outInput2)) + input1, input2, outInput1, outInput2) + .failed()) return failure(); rewriter.replaceOpWithNewOp( -- 2.7.4