RNN operators should inherit step_net device_options (#16086)
authorNikita Shulga <nikita.shulga@oculus.com>
Fri, 18 Jan 2019 19:33:40 +0000 (11:33 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 18 Jan 2019 19:36:38 +0000 (11:36 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16086

[caffe2] RNN operators should inherit step_net device_options
According to NetDef documentaiton, if network has a specific device option it applies to all network operators that do not explicitly specifiy it.
But this does not seem to be the case for RecurrentNetwork operators

Reviewed By: orionr

Differential Revision: D13699552

fbshipit-source-id: 14529bc9504e3b02f763e3c2429be21e46f82b68

caffe2/operators/rnn/recurrent_network_executor.h

index 3300f78..2022cc8 100644 (file)
@@ -37,7 +37,16 @@ class RecurrentNetworkExecutorBase {
       : step_net_def_(step_net_def),
         recurrent_input_map_(recurrent_input_map),
         timestep_blob_(timestep_blob) {
+    const bool net_def_has_device_option = step_net_def_.has_device_option();
     for (int i = 0; i < step_net_def_.op_size(); i++) {
+      if (!step_net_def_.op(i).has_device_option() &&
+          net_def_has_device_option) {
+        // In the case that the operator def does not specify a device option
+        // but the net def has a default option, we copy the device option over
+        // to the operator def.
+        step_net_def_.mutable_op(i)->mutable_device_option()->CopyFrom(
+            step_net_def_.device_option());
+      }
       op_deps_.push_back(op_deps(i));
     }
   }