From 0312b25df0a872295f8db203fbebfb4a0d7f0f3e Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 31 Mar 2021 11:18:27 -0700 Subject: [PATCH] [mlir][tosa] Add tosa.table lowering to linalg.generic Table op lowering to linalg.generic for both i8 (behaves like a gather) and a pair of gathers with a quantized interpolation. Differential Revision: https://reviews.llvm.org/D99756 --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 177 ++++++++++++++++++--- .../Conversion/TosaToLinalg/tosa-to-linalg.mlir | 43 +++++ 2 files changed, 202 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 3fdbc6f..a6271f7 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1407,37 +1407,178 @@ public: } }; +// Lowerings the TableOp to a series of gathers and numerica operations. This +// includes interpolation between the high/low values. For the I8 varient, this +// simplifies to a single gather operation. +class TableConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TableOp op, + PatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + Value input = op.input(); + Value table = op.table(); + auto inputTy = input.getType().cast(); + auto tableTy = table.getType().cast(); + auto resultTy = op.getType().cast(); + + if (!inputTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "require input type to have static shape"); + + auto inputElementTy = inputTy.getElementType(); + auto tableElementTy = tableTy.getElementType(); + auto resultElementTy = resultTy.getElementType(); + + auto initTensor = + rewriter + .create(loc, ArrayRef{}, + resultTy.getShape(), resultElementTy) + .result(); + + SmallVector affineMaps = { + rewriter.getMultiDimIdentityMap(resultTy.getRank()), + rewriter.getMultiDimIdentityMap(resultTy.getRank())}; + + auto genericOp = rewriter.create( + loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps, + getNParallelLoopsAttrs(resultTy.getRank())); + rewriter.replaceOp(op, genericOp.getResult(0)); + + { + OpBuilder::InsertionGuard regionGuard(rewriter); + Block *block = + rewriter.createBlock(&genericOp.region(), genericOp.region().end(), + TypeRange({inputElementTy, resultElementTy})); + + auto inputValue = block->getArgument(0); + rewriter.setInsertionPointToStart(block); + if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && + resultElementTy.isInteger(8)) { + Value index = rewriter.create(loc, rewriter.getIndexType(), + inputValue); + Value extract = + rewriter.create(loc, table, ValueRange{index}); + rewriter.create(loc, extract); + return success(); + } + + if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && + resultElementTy.isInteger(32)) { + Value extend = rewriter.create( + loc, rewriter.getI32Type(), inputValue); + + auto offset = + rewriter.create(loc, rewriter.getI32IntegerAttr(32768)); + auto seven = + rewriter.create(loc, rewriter.getI32IntegerAttr(7)); + auto one = + rewriter.create(loc, rewriter.getI32IntegerAttr(1)); + auto b1111111 = + rewriter.create(loc, rewriter.getI32IntegerAttr(127)); + + // Compute the index and fractional part from the input value: + // value = value + 32768 + // index = value >> 7; + // fraction = 0x01111111 & value + auto extendAdd = rewriter.create(loc, extend, offset); + Value index = + rewriter.create(loc, extendAdd, seven); + Value fraction = rewriter.create(loc, extendAdd, b1111111); + + // Extract the base and next values from the table. + // base = (int32_t) table[index]; + // next = (int32_t) table[index + 1]; + Value indexPlusOne = rewriter.create(loc, index, one); + + index = + rewriter.create(loc, rewriter.getIndexType(), index); + indexPlusOne = rewriter.create( + loc, rewriter.getIndexType(), indexPlusOne); + + Value base = + rewriter.create(loc, table, ValueRange{index}); + Value next = rewriter.create( + loc, table, ValueRange{indexPlusOne}); + + base = rewriter.create(loc, rewriter.getI32Type(), base); + next = rewriter.create(loc, rewriter.getI32Type(), next); + + // Use the fractional part to interpolate between the input values: + // result = (base << 7) + (next - base) * fraction + Value baseScaled = rewriter.create(loc, base, seven); + Value diff = rewriter.create(loc, next, base); + Value diffScaled = rewriter.create(loc, diff, fraction); + Value result = rewriter.create(loc, baseScaled, diffScaled); + + rewriter.create(loc, result); + + return success(); + } + } + + return rewriter.notifyMatchFailure( + op, "unable to create body for tosa.table op"); + } +}; + } // namespace void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns( RewritePatternSet *patterns) { patterns->add< - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, + // clang-format off + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, + PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, + PointwiseConverter, + PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, IdentityNConverter, - IdentityNConverter, ReduceConverter, - ReduceConverter, ReduceConverter, - ReduceConverter, ReduceConverter, - ReduceConverter, ArgMaxConverter, ConcatConverter, - PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter, - TileConverter, TransposeConverter, MatMulConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + IdentityNConverter, + IdentityNConverter, + ReduceConverter, + ReduceConverter, + ReduceConverter, + ReduceConverter, + ReduceConverter, + ReduceConverter, + ArgMaxConverter, + ConcatConverter, + PadConverter, + ReshapeConverter, + RescaleConverter, + ReverseConverter, + TableConverter, + TileConverter, + TransposeConverter, + MatMulConverter, FullyConnectedConverter>(patterns->getContext()); + // clang-format on } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 1bc4d6a..5d77c93 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -830,3 +830,46 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () { return } + +// ----- + +// CHECK-LABEL: @table8 +func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () { + // CHECK: %[[INIT:.+]] = linalg.init_tensor [6] + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>) + // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8) + // CHECK: %[[CAST:.+]] = index_cast %[[ARG_IN]] + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[CAST]]] + // CHECK: linalg.yield %[[EXTRACT]] + %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi8>, tensor<513xi8>) -> (tensor<6xi8>) + return +} + +// CHECK-LABEL: @table16 +func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () { + // CHECK: %[[INIT:.+]] = linalg.init_tensor [6] + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi16>) outs(%[[INIT]] : tensor<6xi32>) + // CHECK: ^bb0(%arg2: i16, %arg3: i32) + // CHECK: %[[EXT_IN:.+]] = sexti %arg2 + // CHECK: %[[C32768:.+]] = constant 32768 + // CHECK: %[[C7:.+]] = constant 7 + // CHECK: %[[C1:.+]] = constant 1 + // CHECK: %[[C127:.+]] = constant 127 + // CHECK: %[[INADD:.+]] = addi %[[EXT_IN]], %[[C32768]] + // CHECK: %[[IDX:.+]] = shift_right_unsigned %[[INADD]], %[[C7]] + // CHECK: %[[FRACTION:.+]] = and %[[INADD]], %[[C127]] + // CHECK: %[[IDXPLUS1:.+]] = addi %[[IDX]], %[[C1]] + // CHECK: %[[IDX_CAST:.+]] = index_cast %[[IDX]] + // CHECK: %[[IDXPLUS1_CAST:.+]] = index_cast %[[IDXPLUS1]] + // CHECK: %[[BASE:.+]] = tensor.extract %arg1[%[[IDX_CAST]]] + // CHECK: %[[NEXT:.+]] = tensor.extract %arg1[%[[IDXPLUS1_CAST]]] + // CHECK: %[[BASE_EXT:.+]] = sexti %[[BASE]] + // CHECK: %[[NEXT_EXT:.+]] = sexti %[[NEXT]] + // CHECK: %[[BASE_MUL:.+]] = shift_left %[[BASE_EXT]], %[[C7]] + // CHECK: %[[DIFF:.+]] = subi %[[NEXT_EXT]], %[[BASE_EXT]] + // CHECK: %[[DIFF_MUL:.+]] = muli %[[DIFF]], %[[FRACTION]] + // CHECK: %[[RESULT:.+]] = addi %[[BASE_MUL]], %[[DIFF_MUL]] + // CHECK: linalg.yield %[[RESULT]] + %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>) + return +} -- 2.7.4