Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / roi_pooling.hpp
index 3007f8c..1b5afa6 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2017 Intel Corporation
+// Copyright (c) 2017-2019 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -39,43 +39,58 @@ struct roi_pooling : public primitive_base<roi_pooling, CLDNN_PRIMITIVE_DESC(roi
         const primitive_id& input_data,
         const primitive_id& input_rois,
         pooling_mode mode,
+        bool position_sensitive,
         int pooled_width,
         int pooled_height,
         float spatial_scale,
-        int group_sz = 0,
+        int output_dim = 0,
+        int spatial_bins_x = 1,
+        int spatial_bins_y = 1,
         const padding& output_padding = padding()
         )
         : primitive_base(id, {input_data, input_rois}, output_padding)
         , mode(mode)
+        , position_sensitive(position_sensitive)
         , pooled_width(pooled_width)
         , pooled_height(pooled_height)
         , spatial_scale(spatial_scale)
-        , group_sz(group_sz)
+        , output_dim(output_dim)
+        , spatial_bins_x(spatial_bins_x)
+        , spatial_bins_y(spatial_bins_y)
     {}
 
     roi_pooling(const dto* dto)
         : primitive_base(dto)
         , mode(static_cast<pooling_mode>(dto->mode))
+        , position_sensitive(dto->position_sensitive)
         , pooled_width(dto->pooled_width)
         , pooled_height(dto->pooled_height)
         , spatial_scale(dto->spatial_scale)
-        , group_sz(dto->group_sz)
+        , output_dim(dto->output_dim)
+        , spatial_bins_x(dto->spatial_bins_x)
+        , spatial_bins_y(dto->spatial_bins_y)
     {}
 
     pooling_mode mode;
+    bool position_sensitive;
     int pooled_width;
     int pooled_height;
     float spatial_scale;
-    int group_sz;
+    int output_dim;
+    int spatial_bins_x;
+    int spatial_bins_y;
 
 protected:
     void update_dto(dto& dto) const override
     {
         dto.mode = static_cast<int32_t>(mode);
+        dto.position_sensitive = position_sensitive;
         dto.pooled_width = pooled_width;
         dto.pooled_height = pooled_height;
         dto.spatial_scale = spatial_scale;
-        dto.group_sz = group_sz;
+        dto.output_dim = output_dim;
+        dto.spatial_bins_x = spatial_bins_x;
+        dto.spatial_bins_y = spatial_bins_y;
     }
 };