/// operations equal. Returns the updated input1 and input2 for the original
/// input. The caller is expected to use these to rewrite the original operator
/// with the RESHAPE now in the graph.
-static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
- RankedTensorType outputType, Value input1,
- Value input2, Value &outInput1,
- Value &outInput2) {
+static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
+ Location loc,
+ RankedTensorType outputType,
+ Value input1, Value input2,
+ Value &outInput1, Value &outInput2) {
+ auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
+ auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
- int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
- int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
+ if (!input1Ty || !input2Ty)
+ return failure();
+
+ int64_t input1Rank = input1Ty.getRank();
+ int64_t input2Rank = input2Ty.getRank();
Value higherTensorValue, lowerTensorValue;
- // return if rank already match
+ // Cannot rewrite as its already correct.
if (input1Rank == input2Rank)
- return 1;
+ return failure();
if (input1Rank > input2Rank) {
higherTensorValue = input1;
lowerTensorValue = input1;
}
- ArrayRef<int64_t> outputRankShape = outputType.getShape();
ArrayRef<int64_t> higherRankShape =
higherTensorValue.getType().cast<RankedTensorType>().getShape();
(void)higherRankShape;
ArrayRef<int64_t> lowerRankShape =
lowerTensorValue.getType().cast<RankedTensorType>().getShape();
- // outputRank == higherRank == max(input1Rank, input2Rank)
- assert(higherRankShape.size() == outputRankShape.size());
-
SmallVector<int64_t, 4> reshapeOutputShape;
- computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
+ computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+ // Verify the rank agrees with the output type if the output type is ranked.
+ if (outputType) {
+ if (outputType.getShape().size() != reshapeOutputShape.size() ||
+ outputType.getShape().size() != higherRankShape.size())
+ return failure();
+ }
+
auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
loc, reshapeOutputType, lowerTensorValue,
rewriter.getI64ArrayAttr(reshapeOutputShape));
outInput2 = higherTensorValue;
}
- return 0;
+ return success();
}
namespace {
Value input1 = tosaBinaryOp.input1();
Value input2 = tosaBinaryOp.input2();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().cast<RankedTensorType>();
+
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
Value outInput1, outInput2;
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
- input1, input2, outInput1, outInput2))
+ input1, input2, outInput1, outInput2)
+ .failed())
return failure();
rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
Value input2 = tosaBinaryOp.input2();
int32_t shift = tosaBinaryOp.shift();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().cast<RankedTensorType>();
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
Value outInput1, outInput2;
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
- input1, input2, outInput1, outInput2))
+ input1, input2, outInput1, outInput2)
+ .failed())
return failure();
rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
Value outInput1, outInput2;
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
- input1, input2, outInput1, outInput2))
+ input1, input2, outInput1, outInput2)
+ .failed())
return failure();
rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(