[mlir][tosa] Make tosa MakeBroadcastable pass handle unranked tensors.
author     Rob Suderman <rob.suderman@gmail.com>
Fri, 23 Jul 2021 00:22:11 +0000 (17:22 -0700)
committer  Rob Suderman <rob.suderman@gmail.com>
Fri, 23 Jul 2021 00:57:05 +0000 (17:57 -0700)
If this pass executes without shape inference having run, it is possible for
unranked tensors to appear in the IR. The pass should handle such tensors gracefully.
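
As an illustration, here is a minimal C++ sketch of the guarding pattern this
change adopts (not code taken from the patch; getRankIfRanked is a hypothetical
helper used only for this example):

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/Value.h"
  #include "mlir/Support/LogicalResult.h"

  using namespace mlir;

  // cast<RankedTensorType>() asserts when the operand is unranked; dyn_cast
  // returns null instead, letting the pattern bail out with failure() and
  // leave the op untouched until shape inference has ranked its operands.
  static LogicalResult getRankIfRanked(Value value, int64_t &rank) {
    auto rankedTy = value.getType().dyn_cast<RankedTensorType>();
    if (!rankedTy)
      return failure(); // Unranked tensor: nothing to broadcast yet.
    rank = rankedTy.getRank();
    return success();
  }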

Differential Revision: https://reviews.llvm.org/D106617

mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

index e850e1f..98df911 100644
@@ -108,18 +108,24 @@ static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
 /// operations equal. Returns the updated input1 and input2 for the original
 /// input. The caller is expected to use these to rewrite the original operator
 /// with the RESHAPE now in the graph.
-static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
-                                RankedTensorType outputType, Value input1,
-                                Value input2, Value &outInput1,
-                                Value &outInput2) {
+static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
+                                          Location loc,
+                                          RankedTensorType outputType,
+                                          Value input1, Value input2,
+                                          Value &outInput1, Value &outInput2) {
+  auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
+  auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
 
-  int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
-  int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
+  if (!input1Ty || !input2Ty)
+    return failure();
+
+  int64_t input1Rank = input1Ty.getRank();
+  int64_t input2Rank = input2Ty.getRank();
 
   Value higherTensorValue, lowerTensorValue;
-  // return if rank already match
+  // Cannot rewrite as it's already correct.
   if (input1Rank == input2Rank)
-    return 1;
+    return failure();
 
   if (input1Rank > input2Rank) {
     higherTensorValue = input1;
@@ -129,24 +135,27 @@ static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
     lowerTensorValue = input1;
   }
 
-  ArrayRef<int64_t> outputRankShape = outputType.getShape();
   ArrayRef<int64_t> higherRankShape =
       higherTensorValue.getType().cast<RankedTensorType>().getShape();
   (void)higherRankShape;
   ArrayRef<int64_t> lowerRankShape =
       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
 
-  // outputRank == higherRank == max(input1Rank, input2Rank)
-  assert(higherRankShape.size() == outputRankShape.size());
-
   SmallVector<int64_t, 4> reshapeOutputShape;
 
-  computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
+  computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
 
   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(
       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
 
+  // Verify the rank agrees with the output type if the output type is ranked.
+  if (outputType) {
+    if (outputType.getShape().size() != reshapeOutputShape.size() ||
+        outputType.getShape().size() != higherRankShape.size())
+      return failure();
+  }
+
   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
       loc, reshapeOutputType, lowerTensorValue,
       rewriter.getI64ArrayAttr(reshapeOutputShape));
@@ -159,7 +168,7 @@ static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
     outInput2 = higherTensorValue;
   }
 
-  return 0;
+  return success();
 }
 
 namespace {
@@ -173,11 +182,13 @@ struct ConvertTosaOp : public OpRewritePattern<OpTy> {
     Value input1 = tosaBinaryOp.input1();
     Value input2 = tosaBinaryOp.input2();
     Value output = tosaBinaryOp.getResult();
-    auto outputType = output.getType().cast<RankedTensorType>();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
@@ -200,11 +211,12 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
     Value input2 = tosaBinaryOp.input2();
     int32_t shift = tosaBinaryOp.shift();
     Value output = tosaBinaryOp.getResult();
-    auto outputType = output.getType().cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
@@ -233,7 +245,8 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(