///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
-#include "../C/apply_adam.h"
#include "primitive.hpp"
#include <vector>
/// @n float v[t] = beta2 * v[t-1] + (1 - beta2) * grad[t] * grad[t];
/// @n float result = result - lr[t] * m[t] / (sqrt(v[t]) + epsilon);
-struct apply_adam : public primitive_base<apply_adam, CLDNN_PRIMITIVE_DESC(apply_adam)> {
+struct apply_adam : public primitive_base<apply_adam> {
CLDNN_DECLARE_PRIMITIVE(apply_adam)
/// @brief Constructs apply Adam primitive.
epsilon(epsilon),
dependency_id(dependency_id) {}
- /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{apply_adam}
- apply_adam(const dto* dto)
- : primitive_base(dto),
- m(dto->m),
- v(dto->v),
- beta1_power(dto->beta1_power),
- beta2_power(dto->beta2_power),
- lr(dto->lr),
- beta1(dto->beta1),
- beta2(dto->beta2),
- epsilon(dto->epsilon),
- dependency_id(dto->dependency_id) {}
-
/// @brief Primitive id containing m data.
primitive_id m;
/// @brief Primitive id containing v data.
ret.push_back(dependency_id);
return ret;
}
-
- void update_dto(dto& dto) const override {
- dto.m = m.c_str();
- dto.v = v.c_str();
- dto.beta1_power = beta1_power.c_str();
- dto.beta2_power = beta2_power.c_str();
- dto.lr = lr;
- dto.beta1 = beta1;
- dto.beta2 = beta2;
- dto.epsilon = epsilon;
- dto.dependency_id = dependency_id.c_str();
- }
};
/// @}
/// @}