2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include <gtest/gtest.h>
19 #include "api/CPP/memory.hpp"
20 #include <api/CPP/input_layout.hpp>
21 #include "api/CPP/apply_adam.hpp"
22 #include <api/CPP/topology.hpp>
23 #include <api/CPP/network.hpp>
24 #include <api/CPP/engine.hpp>
25 #include "test_utils/test_utils.h"
26 #include <api/CPP/reorder.hpp>
27 #include <api/CPP/data.hpp>
28 #include <api/CPP/activation.hpp>
29 #include <api/CPP/mutable_data.hpp>
31 using namespace cldnn;
32 using namespace tests;
34 TEST(apply_adam_gpu, basic_in2x2x3x2_bfyx) {
35 // Test creates topology with two apply adam primitives (t = [0, 1]) with the same output variable which is updated.
39 auto input_grad = memory::allocate(engine, { data_types::f32, format::bfyx,{ 1, 1, 1, 1 } });
40 auto var = memory::allocate(engine, { data_types::f32, format::bfyx,{ 1, 1, 1, 1 } });
41 auto m = memory::allocate(engine, { data_types::f32, format::bfyx,{ 1, 1, 1, 1 } });
42 auto v = memory::allocate(engine, { data_types::f32, format::bfyx,{ 1, 1, 1, 1 } });
43 auto beta1_power = memory::allocate(engine, { data_types::f32, format::bfyx,{ 1, 1, 1, 1 } });
44 auto beta2_power = memory::allocate(engine, { data_types::f32, format::bfyx,{ 1, 1, 1, 1 } });
46 float input_grad_f = 100.f;
52 float beta1_power_f = beta1;
53 float beta2_power_f = beta2;
55 float epsilon = 0.0001f;
58 topology.add(input_layout("input", input_grad.get_layout()));
59 topology.add(mutable_data("m", m));
60 topology.add(mutable_data("v", v));
61 topology.add(data("beta1_power_t1", beta1_power));
62 topology.add(data("beta2_power_t1", beta2_power));
63 topology.add(apply_adam("apply_adam", "input", "m", "v", "beta1_power_t1", "beta2_power_t1", lr, beta1, beta2, epsilon));
64 topology.add(activation("relu", "input", activation_linear, { 4.f, 0.f }));
65 topology.add(activation("beta1_power_t2", "beta1_power_t1", activation_linear, { beta1, 0.f }));
66 topology.add(activation("beta2_power_t2", "beta2_power_t1", activation_linear, { beta2, 0.f }));
67 topology.add(apply_adam("apply_adam2", "relu", "m", "v", "beta1_power_t2", "beta2_power_t2", lr, beta1, beta2, epsilon));
68 topology.add(mutable_data("var", { "apply_adam", "apply_adam2" }, var));
70 set_values(input_grad, {
74 set_values(m, { m_f });
75 set_values(v, { v_f });
76 set_values(beta1_power, { beta1_power_f });
77 set_values(beta2_power, { beta2_power_f });
78 set_values(var, { var_f });
81 bo.set_option(build_option::optimize_data(true));
82 network network(engine, topology, bo);
84 network.set_input_data("input", input_grad);
86 auto outputs = network.execute();
88 auto output = outputs.at("var").get_memory();
89 auto output_ptr = output.pointer<float>();
90 auto m_ptr = m.pointer<float>();
91 auto v_ptr = v.pointer<float>();
93 float lr_t1 = lr * sqrt(1 - beta2_power_f) / (1 - beta1_power_f);
94 float m_t1 = beta1 * m_f + (1 - beta1) * input_grad_f;
95 float v_t1 = beta2 * v_f + (1 - beta2) * input_grad_f * input_grad_f;
96 float result_t1 = var_f - lr_t1 * m_t1 / (sqrt(v_t1) + epsilon);
98 beta1_power_f *= beta1;
99 beta2_power_f *= beta2;
100 float input_grad2_f = input_grad_f * 4;
101 float lr_t2 = lr * sqrt(1 - beta2_power_f) / (1 - beta1_power_f);
102 float m_t2 = beta1 * m_t1 + (1 - beta1) * input_grad2_f;
103 float v_t2 = beta2 * v_t1 + (1 - beta2) * input_grad2_f * input_grad2_f;
104 float result_t2 = result_t1 - lr_t2 * m_t2 / (sqrt(v_t2) + epsilon);
106 EXPECT_NEAR(m_t2, m_ptr[0], 1e-03F);
107 EXPECT_NEAR(v_t2, v_ptr[0], 1e-03F);
108 EXPECT_NEAR(result_t2, output_ptr[0], 1e-03F);