outputs=["fc_1"]
)
+ def test_infer_device_adam(self):
+ in_options = [self.cuda_option] * 6
+ in_options[5] = self.cpu_option
+ out_options = [self.cuda_option] * 4
+ self._test_op(
+ "Adam",
+ in_options,
+ out_options,
+ op_option=self.cuda_option,
+ inputs=["param", "moment_1", "moment_2", "grad", "lr", "iter"],
+ outputs=["output_param", "output_moment_1", "output_moment_2",
+ "output_grad"]
+ )
+
def test_infer_device_cross_device(self):
self._test_op("CopyGPUToCPU", self.cuda_option, self.cpu_option)
self._test_op("CopyCPUToGPU", self.cpu_option, self.cuda_option)
.NumInputs(6)
.NumOutputs(3, 4)
.AllowInplace({{0, 0}, {1, 1}, {2, 2}})
+ .DeviceInferenceFunction([](const OperatorDef& def) {
+ auto op_device =
+ def.has_device_option() ? def.device_option() : DeviceOption();
+ vector<DeviceOption> in_dev(def.input_size(), op_device);
+ vector<DeviceOption> out_dev(def.output_size(), op_device);
+ // ITER input lives on CPU
+ in_dev[5] = DeviceOption();
+ return std::make_pair(in_dev, out_dev);
+ })
.SetDoc(R"DOC(
Computes the Adam update (https://arxiv.org/abs/1412.6980) for an