device inference for Adam (#13990)
author Yan Zhu <yzhu@fb.com>
Thu, 15 Nov 2018 01:19:14 +0000 (17:19 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 15 Nov 2018 01:21:08 +0000 (17:21 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13990

Add a device inference function to the Adam operator schema to make sure the ITER blob lives on CPU.

Reviewed By: xianjiec

Differential Revision: D13056070

fbshipit-source-id: 148edbf745e50e886da3eb99d4e485d11c1924e2

caffe2/python/core_test.py
caffe2/sgd/adam_op.cc

index 2f6dedb..6c23d88 100644 (file)
@@ -707,6 +707,20 @@ class TestInferDevice(test_util.TestCase):
             outputs=["fc_1"]
         )
 
+    def test_infer_device_adam(self):
+        in_options = [self.cuda_option] * 6  # start with cuda_option for all six inputs
+        in_options[5] = self.cpu_option  # input 5 ("iter") is the only one given cpu_option
+        out_options = [self.cuda_option] * 4  # cuda_option for all four outputs
+        self._test_op(
+            "Adam",
+            in_options,
+            out_options,
+            op_option=self.cuda_option,  # the operator itself is given cuda_option
+            inputs=["param", "moment_1", "moment_2", "grad", "lr", "iter"],
+            outputs=["output_param", "output_moment_1", "output_moment_2",
+                "output_grad"]
+        )
+
     def test_infer_device_cross_device(self):
         self._test_op("CopyGPUToCPU", self.cuda_option, self.cpu_option)
         self._test_op("CopyCPUToGPU", self.cpu_option, self.cuda_option)
index 623e93a..d12f676 100644 (file)
@@ -7,6 +7,15 @@ OPERATOR_SCHEMA(Adam)
     .NumInputs(6)
     .NumOutputs(3, 4)
     .AllowInplace({{0, 0}, {1, 1}, {2, 2}})
+    .DeviceInferenceFunction([](const OperatorDef& def) {
+      auto op_device =
+          def.has_device_option() ? def.device_option() : DeviceOption();  // default-constructed when unset
+      vector<DeviceOption> in_dev(def.input_size(), op_device);  // every input starts on the op's device
+      vector<DeviceOption> out_dev(def.output_size(), op_device);  // every output stays on the op's device
+      // ITER (input 5) gets a default DeviceOption, i.e. CPU. NOTE(review): unchecked index; relies on NumInputs(6)
+      in_dev[5] = DeviceOption();
+      return std::make_pair(in_dev, out_dev);
+    })
     .SetDoc(R"DOC(
 
 Computes the Adam update (https://arxiv.org/abs/1412.6980) for an