Add RoiAlign to Onnx frontend (#5454)
authorMatthew Brookhart <matthewbrookhart@gmail.com>
Tue, 28 Apr 2020 01:20:23 +0000 (18:20 -0700)
committerGitHub <noreply@github.com>
Tue, 28 Apr 2020 01:20:23 +0000 (10:20 +0900)
python/tvm/relay/frontend/onnx.py
src/relay/op/vision/rcnn_op.cc
tests/python/frontend/onnx/test_forward.py

index 7782400..2ef9450 100644 (file)
@@ -26,6 +26,7 @@ from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
+from .. import vision as _vision
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels
 from .common import infer_type, infer_value, infer_value_simulated, get_name
@@ -1495,6 +1496,34 @@ class TopK(OnnxOpConverter):
 
         return _op.topk(inputs[0], k=K, axis=axis)
 
+
+class RoiAlign(OnnxOpConverter):
+    """Operator converter for TopK
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if len(inputs) != 3:
+            raise ValueError("Expect 3 inputs only")
+        x = inputs[0]
+        rois = inputs[1]
+        batch_indices = inputs[2]
+        mode = attr.get("mode", "avg")
+        if mode != b'avg':
+            raise ValueError("RoiAlign in Relay only uses avg mode")
+        output_height = attr.get("output_height", 1)
+        output_width = attr.get("output_width", 1)
+
+        sampling_ratio = attr.get("sampling_ratio", 0)
+        spatial_scale = attr.get("spatial_scale", 1.0)
+
+        batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1)
+        batch_indices = _op.cast(
+            batch_indices, infer_type(rois).type_annotation.dtype)
+        rois = _op.concatenate([batch_indices, rois], 1)
+
+        return _vision.roi_align(x, rois, [output_height, output_width],
+                                 spatial_scale, sampling_ratio)
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1592,6 +1621,9 @@ def _get_convert_map(opset):
         # Recurrent Layers
         'LSTM': LSTM.get_converter(opset),
 
+        # defs/vision
+        'RoiAlign': RoiAlign.get_converter(opset),
+
         # defs/reduction
         'ReduceMax': ReduceMax.get_converter(opset),
         'ReduceMin': ReduceMin.get_converter(opset),
index 6b221a2..5661ebb 100644 (file)
@@ -36,6 +36,8 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* rois = types[1].as<TensorTypeNode>();
+  CHECK(data);
+  CHECK(rois);
   const auto& dshape = data->shape;
   const auto& rshape = rois->shape;
   CHECK(roi_align_attrs);
index f33c5f9..1185a5c 100644 (file)
@@ -2432,6 +2432,68 @@ def test_topk():
         verify_topk([n, n, n], 5, 2)
     
 
+def test_roi_align():
+    def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0):
+        output_dims = [num_roi, input_dims[1], output_height, output_width]
+
+        node = helper.make_node('RoiAlign',
+                                inputs=['X', 'rois', 'batch_indicies'],
+                                outputs=['Y'],
+                                mode="avg",
+                                output_height=output_height,
+                                output_width=output_width,
+                                sampling_ratio=sampling_ratio,
+                                spatial_scale=spatial_scale,
+                                )
+
+        graph = helper.make_graph([node],
+                                  "roialign_test",
+                                  inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
+                                          helper.make_tensor_value_info(
+                                              "rois", TensorProto.FLOAT, [num_roi, 4]),
+                                          helper.make_tensor_value_info(
+                                              "batch_indicies", TensorProto.INT64, [num_roi, ]),
+                                          ],
+                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)])
+
+        model = helper.make_model(graph, producer_name='roialign_test')
+
+        np_data = np.random.uniform(size=input_dims).astype("float32")
+        np_rois = np.random.uniform(size=[num_roi, 4]).astype(
+            'float32') * input_dims[2]
+        np_batch_indicies = np.random.randint(
+            low=0, high=input_dims[0], size=num_roi)
+
+        onnx_out = get_onnxruntime_output(
+            model, [np_data, np_rois, np_batch_indicies])
+        for target, ctx in [('llvm', tvm.cpu())]:
+            tvm_out = get_tvm_output(model, [np_data, np_rois, np_batch_indicies], target, ctx, output_dims,
+                                     output_dtype='float32')
+            tvm.testing.assert_allclose(
+                onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
+
+    verify_roi_align((1, 4, 16, 16), 32, 7, 7,
+                     sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((4, 4, 16, 32), 32, 7, 7,
+                     sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 8, 16, 16), 32, 7, 7,
+                     sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 8, 8), 32, 7, 7,
+                     sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 16), 16, 5, 7,
+                     sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 12), 8, 7, 3,
+                     sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 16), 32, 7, 7,
+                     sampling_ratio=0, spatial_scale=0.5)
+    verify_roi_align((3, 4, 12, 16), 32, 7, 7,
+                     sampling_ratio=0, spatial_scale=1.5)
+    verify_roi_align((5, 4, 16, 14), 32, 7, 7,
+                     sampling_ratio=1, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 16), 32, 7, 7,
+                     sampling_ratio=2, spatial_scale=1.0)
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -2498,3 +2560,4 @@ if __name__ == '__main__':
     test_resize()
     test_nonzero()
     test_topk()
+    test_roialign()