#pragma once
#include "api/CPP/batch_norm.hpp"
#include "primitive_inst.h"
+#include "mutable_data_inst.h"
namespace cldnn
{
program_node& input() const { return get_dependency(0); }
program_node& mean() const { return get_dependency(1); }
program_node& variance() const { return get_dependency(2); }
- program_node& inv_variance() const { return get_dependency(1); };
+ program_node& scale() const
+ {
+ if(get_dependencies().size() >= 5)
+ return get_dependency(3);
+ else
+ return get_dependency(1);
+ }
+ program_node& shift() const
+ {
+ if (get_dependencies().size() >= 5)
+ return get_dependency(4);
+ else
+ return get_dependency(2);
+ }
+ program_node& inv_variance() const
+ {
+ if (get_dependencies().size() == 2)
+ return get_dependency(1);
+ else if (get_dependencies().size() == 6)
+ return get_dependency(5);
+ else
+ return get_dependency(3);
+ };
bool variance_term() const { return !get_primitive()->variance.empty(); }
bool use_global_stats() const { return !get_primitive()->mean.empty() && !get_primitive()->variance.empty(); };
+ bool use_scale_shift() const { return !get_primitive()->scale.empty() && !get_primitive()->shift.empty(); };
bool forwad_pass() const { return !get_primitive()->inv_variance.empty(); };
+ bool calc_mean_var() const { return (use_global_stats() && mean().is_type<mutable_data>() && variance().is_type<mutable_data>()); };
};
memory_impl& mean_memory() const { return dep_memory(1); }
memory_impl& variance_memory() const { return dep_memory(2); }
- memory_impl& inv_variance_memory() const { return dep_memory(1); };
+ memory_impl& scale_memory() const
+ {
+ if (dependencies().size() >= 5)
+ return dep_memory(3);
+ else
+ return dep_memory(1);
+ }
+ memory_impl& shift_memory() const
+ {
+ if (dependencies().size() >= 5)
+ return dep_memory(4);
+ else
+ return dep_memory(2);
+ }
+ memory_impl& inv_variance_memory() const
+ {
+ if (dependencies().size() == 2)
+ return dep_memory(1);
+ else if (dependencies().size() == 6)
+ return dep_memory(5);
+ else
+ return dep_memory(3);
+ };
bool use_global_stats() const { return !argument.mean.empty() && !argument.variance.empty(); };
+ bool use_scale_shift() const { return !argument.scale.empty() && !argument.scale.empty(); };
bool forwad_pass() const { return !argument.inv_variance.empty(); };
+ bool calc_mean_var() const { return node.calc_mean_var(); };
};
using batch_norm_inst = typed_primitive_inst<batch_norm>;