[mlir][tosa] Fix tosa.avg_pool2d lowering to normalize correctly
authorRob Suderman <rob.suderman@gmail.com>
Mon, 17 May 2021 17:00:09 +0000 (10:00 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Mon, 17 May 2021 17:00:43 +0000 (10:00 -0700)
Initial version of pooling assumed normalization was across all elements
equally. TOSA actually requires the normalization is performed by how
many elements were summed (edges are not artificially dimmer). Updated
the lowering to reflect this change with corresponding tests.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D102540

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index 06b23cc..a68c8bc 100644 (file)
@@ -2263,7 +2263,7 @@ public:
     pad.resize(2, 0);
     getValuesFromIntArrayAttribute(op.pad(), pad);
     pad.resize(pad.size() + 2, 0);
-    input = applyPad(loc, input, pad, initialAttr, rewriter);
+    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
 
     Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
 
@@ -2273,7 +2273,6 @@ public:
 
     Attribute strideAttr = rewriter.getI64VectorAttr(stride);
     Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
-    int64_t kernelSize = kernel[0] * kernel[1];
 
     // Create the linalg op that performs pooling.
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
@@ -2290,7 +2289,7 @@ public:
           rewriter
               .create<std::remove_pointer_t<decltype(typePtr)>>(
                   loc, ArrayRef<Type>{resultTy},
-                  ValueRange{input, fakeWindowDims}, filledInitTensor,
+                  ValueRange{paddedInput, fakeWindowDims}, filledInitTensor,
                   dilationAttr, strideAttr)
               .getOperation());
     };
@@ -2324,14 +2323,76 @@ public:
     }
 
     if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
-      linalg::LinalgOp poolingOp =
-          createOp(static_cast<linalg::PoolingNHWCSumFOp *>(nullptr));
-      auto constAttr = DenseElementsAttr::get(
-          resultTy, static_cast<float>(1.0 / kernelSize));
-      auto constant = rewriter.create<ConstantOp>(loc, constAttr);
-      auto mul = rewriter.create<tosa::MulOp>(
-          loc, resultTy, poolingOp->getResult(0), constant, 0);
-      rewriter.replaceOp(op, mul.output());
+      Value poolingOp =
+          createOp(static_cast<linalg::PoolingNHWCSumFOp *>(nullptr))
+              ->getResult(0);
+      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+      auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
+          ArrayRef<AffineMap>({affineMap}),
+          getNParallelLoopsAttrs(resultTy.getRank()),
+          [&](OpBuilder &b, Location loc, ValueRange indices, ValueRange args) {
+            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+            auto iH = rewriter.create<ConstantIndexOp>(
+                loc, poolingOpTy.getDimSize(1) - 1);
+            auto iW = rewriter.create<ConstantIndexOp>(
+                loc, poolingOpTy.getDimSize(2) - 1);
+
+            // Compute the indices from either end.
+            auto y0 = indices[1];
+            auto x0 = indices[2];
+            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
+            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
+
+            // Determines what the portion of valid input is covered by the
+            // kernel.
+            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+              if (pad == 0)
+                return v;
+
+              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
+              Value dx = rewriter.create<SubIOp>(loc, x, padVal);
+
+              Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+                                                        dx, zero);
+              Value offset =
+                  rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
+              return rewriter.create<mlir::AddIOp>(loc, v, offset)
+                  ->getResult(0);
+            };
+
+            // Compute the vertical component of coverage.
+            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
+            auto kH1 = padFn(kH0, y0, pad[2]);
+            auto kH2 = padFn(kH1, y1, pad[3]);
+            auto kHCmp =
+                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
+            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+
+            // Compute the horizontal component of coverage.
+            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
+            auto kW1 = padFn(kW0, x0, pad[4]);
+            auto kW2 = padFn(kW1, x1, pad[5]);
+            auto kWCmp =
+                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
+            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+
+            // Compute the total number of elements and normalize.
+            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
+            auto countI = rewriter.create<mlir::IndexCastOp>(
+                loc, rewriter.getI32Type(), count);
+            auto countF =
+                rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
+
+            auto div =
+                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
+
+            rewriter.create<linalg::YieldOp>(loc, div);
+          });
+
+      rewriter.replaceOp(op, genericOp.getResult(0));
       return success();
     }
 
index c2e6b07..b789072 100644 (file)
@@ -1087,17 +1087,59 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
 // -----
 
 // CHECK-LABEL: @avg_pool
-func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
-  // CHECK-DAG: [[CONST:%.+]] = constant 0
-  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 3, 31, 62]
-  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
-  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
-  // CHECK: linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x3x31x62xf32>)
-  // CHECK: constant dense<6.250000e-02>
-  // CHECK: linalg.generic
-  // CHECK: mulf
-  %0 = "tosa.avg_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x3x31x62xf32>)
-  return
+func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
+  // Initial piece computes the sum of the pooling region, with appropriate padding.
+  // CHECK: [[CONST:%.+]] = constant 0
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] 
+  // CHECK: [[CONST:%.+]] = constant 0
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+  // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
+  // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
+  // CHECK: [[GENERIC:%.+]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs([[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK:   [[ZERO:%.+]] = constant 0
+  // CHECK:   [[ONE:%.+]] = constant 1
+  // CHECK:   [[HEIGHT:%.+]] = constant 4
+  // CHECK:   [[WIDTH:%.+]] = constant 32
+
+  // The large block below computes what portion of the kernel is within non-padded input.
+  // CHECK:   [[NY:%.+]] = subi [[HEIGHT]], %arg2
+  // CHECK:   [[NX:%.+]] = subi [[WIDTH]], %arg3
+  // CHECK:   [[KH:%.+]] = constant 4
+  // CHECK:   [[PAD0:%.+]] = constant 1
+  // CHECK:   [[SUBP0:%.+]] = subi %arg2, [[PAD0]]
+  // CHECK:   [[P0CMP:%.+]] = cmpi slt, [[SUBP0]], [[ZERO]]
+  // CHECK:   [[SELP0:%.+]] = select [[P0CMP]], [[SUBP0]], [[ZERO]]
+  // CHECK:   [[ADDP0:%.+]] = addi [[KH]], [[SELP0]]
+  // CHECK:   [[PAD1:%.+]] = constant 1
+  // CHECK:   [[SUBP1:%.+]] = subi [[NY]], [[PAD1]]
+  // CHECK:   [[P1CMP:%.+]] = cmpi slt, [[SUBP1]], [[ZERO]]
+  // CHECK:   [[SELP1:%.+]] = select [[P1CMP]], [[SUBP1]], [[ZERO]]
+  // CHECK:   [[ADDP1:%.+]] = addi [[ADDP0]], [[SELP1]]
+  // CHECK:   [[YCMP:%.+]] = cmpi slt, [[ADDP1]], [[ONE]]
+  // CHECK:   [[YSEL:%.+]] = select [[YCMP]], [[ONE]], [[ADDP1]]
+  // CHECK:   [[KW:%.+]] = constant 4 : index
+  // CHECK:   [[PAD2:%.+]] = constant 1 : index
+  // CHECK:   [[SUBP2:%.+]] = subi %arg3, [[PAD2]]
+  // CHECK:   [[P2CMP:%.+]] = cmpi slt, [[SUBP2]], [[ZERO]]
+  // CHECK:   [[SELP2:%.+]] = select [[P2CMP]], [[SUBP2]], [[ZERO]]
+  // CHECK:   [[ADDP2:%.+]] = addi [[KW]], [[SELP2]]
+  // CHECK:   [[PAD3:%.+]] = constant 1 : index
+  // CHECK:   [[SUBP3:%.+]] = subi [[NX]], [[PAD3]]
+  // CHECK:   [[P3CMP:%.+]] = cmpi slt, [[SUBP3]], [[ZERO]]
+  // CHECK:   [[SELP3:%.+]] = select [[P3CMP]], [[SUBP3]], [[ZERO]]
+  // CHECK:   [[ADDP3:%.+]] = addi [[ADDP2]], [[SELP3]]
+  // CHECK:   [[XCMP:%.+]] = cmpi slt, [[ADDP3]], [[ONE]]
+  // CHECK:   [[XSEL:%.+]] = select [[XCMP]], [[ONE]], [[ADDP3]]
+
+  // Given the valid coverage of the pooling region, normalize the summation.
+  // CHECK:   [[C:%.+]] = muli [[YSEL]], [[XSEL]]
+  // CHECK:   [[CI:%.+]] = index_cast [[C]]
+  // CHECK:   [[CF:%.+]] = sitofp [[CI]]
+  // CHECK:   [[RESULT:%.+]] = divf %arg5, [[CF]]
+  // CHECK:   linalg.yield [[RESULT]]
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x5x33x62xf32>)
+  return %0 : tensor<1x5x33x62xf32>
 }
 
 // -----