Updated bbox_transform and nms unit test for caffe2 ops. (#16722)
authorPeizhao Zhang <stzpz@fb.com>
Fri, 15 Feb 2019 08:14:45 +0000 (00:14 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Feb 2019 08:21:55 +0000 (00:21 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16722

Updated bbox_transform and nms unit test for caffe2 ops.

Differential Revision: D13937416

fbshipit-source-id: 034743d29671c6e73d323a935e2d734ecc071bff

caffe2/operators/generate_proposals_op_util_boxes_test.cc
caffe2/operators/generate_proposals_op_util_nms_test.cc

index f9ff7e9..4e85da4 100644 (file)
@@ -21,18 +21,18 @@ TEST(UtilsBoxesTest, TestBboxTransformRandom) {
       -0.08533806, -0.60335309, 0.79052375;
 
   EMatXf result_gt(5, 4);
-  result_gt << 206.94953073, -30.71519157, 298.3876512, 245.44846569,
-      143.8712194, -83.34289038, 291.50227513, 122.05339902, 177.43029521,
-      198.66623633, 197.29527254, 229.70308414, 152.25190373, 145.43156421,
-      388.21547899, 275.59425266, 5.06242193, 11.04094661, 67.32890274,
-      270.68622005;
+  result_gt << 206.949539, -30.715202, 297.387665, 244.448486, 143.871216,
+      -83.342888, 290.502289, 121.053398, 177.430283, 198.666245, 196.295273,
+      228.703079, 152.251892, 145.431564, 387.215454, 274.594238, 5.062420,
+      11.040955, 66.328903, 269.686218;
 
   const float BBOX_XFORM_CLIP = log(1000.0 / 16.0);
   auto result = utils::bbox_transform(
       bbox.array(),
       deltas.array(),
       std::vector<float>{1.0, 1.0, 1.0, 1.0},
-      BBOX_XFORM_CLIP);
+      BBOX_XFORM_CLIP,
+      true);
   EXPECT_NEAR((result.matrix() - result_gt).norm(), 0.0, 1e-4);
 }
 
index 696ff83..b7da35b 100644 (file)
@@ -46,6 +46,57 @@ TEST(UtilsNMSTest, TestNMS) {
   }
 }
 
+TEST(UtilsNMSTest, TestNMS1) {
+  Eigen::ArrayXXf proposals(53, 4);
+  proposals << 350.9821, 161.8200, 369.9685, 205.2372, 250.5236, 154.2844,
+      274.1773, 204.9810, 471.4920, 160.4118, 496.0094, 213.4244, 352.0421,
+      164.5933, 366.4458, 205.9624, 166.0765, 169.7707, 183.0102, 232.6606,
+      252.3000, 183.1449, 269.6541, 210.6747, 469.7862, 162.0192, 482.1673,
+      187.0053, 168.4862, 174.2567, 181.7437, 232.9379, 470.3290, 162.3442,
+      496.4272, 214.6296, 251.0450, 155.5911, 272.2693, 203.3675, 252.0326,
+      154.7950, 273.7404, 195.3671, 351.7479, 161.9567, 370.6432, 204.3047,
+      496.3306, 161.7157, 515.0573, 210.7200, 471.0749, 162.6143, 485.3374,
+      207.3448, 250.9745, 160.7633, 264.1924, 206.8350, 470.4792, 169.0351,
+      487.1934, 220.2984, 474.4227, 161.9546, 513.1018, 215.5193, 251.9428,
+      184.1950, 262.6937, 207.6416, 252.6623, 175.0252, 269.8806, 213.7584,
+      260.9884, 157.0351, 288.3554, 206.6027, 251.3629, 164.5101, 263.2179,
+      202.4203, 471.8361, 190.8142, 485.6812, 220.8586, 248.6243, 156.9628,
+      264.3355, 199.2767, 495.1643, 158.0483, 512.6261, 184.4192, 376.8718,
+      168.0144, 387.3584, 201.3210, 122.9191, 160.7433, 172.5612, 231.3837,
+      350.3857, 175.8806, 366.2500, 205.4329, 115.2958, 162.7822, 161.9776,
+      229.6147, 168.4375, 177.4041, 180.8028, 232.4551, 169.7939, 184.4330,
+      181.4767, 232.1220, 347.7536, 175.9356, 355.8637, 197.5586, 495.5434,
+      164.6059, 516.4031, 207.7053, 172.1216, 194.6033, 183.1217, 235.2653,
+      264.2654, 181.5540, 288.4626, 214.0170, 111.7971, 183.7748, 137.3745,
+      225.9724, 253.4919, 186.3945, 280.8694, 210.0731, 165.5334, 169.7344,
+      185.9159, 232.8514, 348.3662, 184.5187, 354.9081, 201.4038, 164.6562,
+      162.5724, 186.3108, 233.5010, 113.2999, 186.8410, 135.8841, 219.7642,
+      117.0282, 179.8009, 142.5375, 221.0736, 462.1312, 161.1004, 495.3576,
+      217.2208, 462.5800, 159.9310, 501.2937, 224.1655, 503.5242, 170.0733,
+      518.3792, 209.0113, 250.3658, 195.5925, 260.6523, 212.4679, 108.8287,
+      163.6994, 146.3642, 229.7261, 256.7617, 187.3123, 288.8407, 211.2013,
+      161.2781, 167.4801, 186.3751, 232.7133, 115.3760, 177.5859, 163.3512,
+      236.9660, 248.9077, 188.0919, 264.8579, 207.9718, 108.1349, 160.7851,
+      143.6370, 229.6243, 465.0900, 156.7555, 490.3561, 213.5704, 107.5338,
+      173.4323, 141.0704, 235.2910;
+
+  Eigen::ArrayXXf scores(53, 1);
+  scores << 0.1919, 0.3293, 0.0860, 0.1600, 0.1885, 0.4297, 0.0974, 0.2711,
+      0.1483, 0.1173, 0.1034, 0.2915, 0.1993, 0.0677, 0.3217, 0.0966, 0.0526,
+      0.5675, 0.3130, 0.1592, 0.1353, 0.0634, 0.1557, 0.1512, 0.0699, 0.0545,
+      0.2692, 0.1143, 0.0572, 0.1990, 0.0558, 0.1500, 0.2214, 0.1878, 0.2501,
+      0.1343, 0.0809, 0.1266, 0.0743, 0.0896, 0.0781, 0.0983, 0.0557, 0.0623,
+      0.5808, 0.3090, 0.1050, 0.0524, 0.0513, 0.4501, 0.4167, 0.0623, 0.1749;
+
+  std::vector<int> output_gt{1,  6,  7,  8,  11, 12, 13, 14, 17,
+                             18, 19, 21, 23, 24, 25, 26, 30, 32,
+                             33, 34, 35, 37, 43, 44, 47, 50};
+
+  auto cur_out = utils::nms_cpu(proposals, scores, 0.5);
+  std::sort(cur_out.begin(), cur_out.end());
+  EXPECT_EQ(output_gt, cur_out);
+}
+
 TEST(UtilsNMSTest, TestSoftNMS) {
   Eigen::ArrayXXf input(5, 5);
   input.row(0) << 5.18349426e+02, 1.77783920e+02, 9.06085266e+02,