Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / batch_norm.hpp
index 3933203..7962d71 100644 (file)
@@ -35,7 +35,7 @@ namespace cldnn
 /// 
 /// <b>Algorithm:</b>
 /// @n global stats can be computed as:
-/// @n out[i] = (in[i] - mean[b]) / sqrt(variance[b] + epsilon)
+/// @n out[i] = ( (in[i] - mean[b]) / sqrt(variance[b] + epsilon) ) * scale[b] + shift[b]
 
 struct batch_norm : public primitive_base<batch_norm, CLDNN_PRIMITIVE_DESC(batch_norm)>
 {
@@ -63,6 +63,34 @@ struct batch_norm : public primitive_base<batch_norm, CLDNN_PRIMITIVE_DESC(batch
     {
     }
 
+       /// @brief Constructs batch normalization primitive.
+       /// @param id This primitive id.
+       /// @param input Input primitive id.
+       /// @param mean Primitive id containing mean data.
+       /// @param variance Primitive id containing variance.
+       /// @brief scale Primitive id containing scale.
+       /// @brief shift Primitive id containing shift.
+       /// @param epsilon Epsilon.
+       batch_norm(
+               const primitive_id& id,
+               const primitive_id& input,
+               const primitive_id& mean,
+               const primitive_id& variance,
+               const primitive_id& scale,
+               const primitive_id& shift,
+               float epsilon,
+               const padding& output_padding = padding()
+       )
+               :primitive_base(id, { input }, output_padding)
+               , mean(mean)
+               , variance(variance)
+               , scale(scale)
+               , shift(shift)
+               , inv_variance("")
+               , epsilon(epsilon)
+       {
+       }
+
     /// @brief Constructs batch normalization primitive with mean and variance calculation (used for training).
     /// @param id This primitive id.
     /// @param input Input primitive id.
@@ -83,11 +111,69 @@ struct batch_norm : public primitive_base<batch_norm, CLDNN_PRIMITIVE_DESC(batch
     {
     }
 
+       /// @brief Constructs batch normalization primitive with mean and variance calculation (used for training).
+       /// @param id This primitive id.
+       /// @param input Input primitive id.
+       /// @brief scale Primitive id containing scale.
+       /// @brief shift Primitive id containing shift.
+       /// @param epsilon Epsilon.
+       /// @param inv_variance Primitive id containing inverted variance calculated in this primitive. For inference leave empty.
+       batch_norm(
+               const primitive_id& id,
+               const primitive_id& input,
+               float epsilon,
+               const primitive_id& scale,
+               const primitive_id& shift,
+               const primitive_id& inv_variance = "",
+               const padding& output_padding = padding()
+       )
+               :primitive_base(id, { input }, output_padding)
+               , mean("")
+               , variance("")
+               , scale(scale)
+               , shift(shift)
+               , inv_variance(inv_variance)
+               , epsilon(epsilon)
+       {
+       }
+
+       /// @brief Constructs batch normalization primitive with mean and variance calculation (used for training).
+       /// @param id This primitive id.
+       /// @param input Input primitive id.
+       /// @brief scale Primitive id containing scale.
+       /// @brief shift Primitive id containing shift.
+       /// @brief mean_out Primitive id containing mean output.
+       /// @brief variance_out Primitive id containing variance output.
+       /// @param epsilon Epsilon.
+       /// @param inv_variance Primitive id containing inverted variance calculated in this primitive. For inference leave empty.
+       batch_norm(
+               const primitive_id& id,
+               const primitive_id& input,
+               float epsilon,
+               const primitive_id& mean_out,
+               const primitive_id& variance_out,
+               const primitive_id& scale,
+               const primitive_id& shift,
+               const primitive_id& inv_variance = "",
+               const padding& output_padding = padding()
+       )
+               :primitive_base(id, { input }, output_padding)
+               , mean(mean_out)
+               , variance(variance_out)
+               , scale(scale)
+               , shift(shift)
+               , inv_variance(inv_variance)
+               , epsilon(epsilon)
+       {
+       }
+
     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{batch_norm}
     batch_norm(const dto* dto)
         :primitive_base(dto)
         , mean(dto->mean)
         , variance(dto->variance)
+               , scale(dto->scale)
+               , shift(dto->shift)
         , inv_variance(dto->inv_variance)
         , epsilon(dto->epsilon)
     {
@@ -97,20 +183,36 @@ struct batch_norm : public primitive_base<batch_norm, CLDNN_PRIMITIVE_DESC(batch
     primitive_id mean;
     /// @brief Primitive id containing variance.
     primitive_id variance;
+       /// @brief Primitive id containing scale.
+       primitive_id scale;
+       /// @brief Primitive id containing shift.
+       primitive_id shift;
     /// @brief Primitive id containing inverted variance used in future gradient computing.
     primitive_id inv_variance;
     /// @brief Epsilon.
     float epsilon;
 
 protected:
-    std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override 
-    { 
-        if (!mean.empty() && !variance.empty())
-            return{ mean, variance };
-        else if (!inv_variance.empty())
-            return{ inv_variance };
-        else
-            return{};
+       std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
+       {
+               std::vector<std::reference_wrapper<const primitive_id>> deps;
+
+               if (!mean.empty() && !variance.empty())
+               {
+                       deps.push_back(mean);
+                       deps.push_back(variance);
+               }
+
+               if (!scale.empty() && !shift.empty())
+               {
+                       deps.push_back(scale);
+                       deps.push_back(shift);
+               }
+
+               if (!inv_variance.empty())
+                       deps.push_back(inv_variance);
+
+               return deps;
     }
 
     void update_dto(dto& dto) const override
@@ -118,6 +220,8 @@ protected:
         dto.mean = mean.c_str();
         dto.variance = variance.c_str();
         dto.inv_variance = inv_variance.c_str();
+        dto.scale = scale.c_str();
+        dto.shift = shift.c_str();
         dto.epsilon = epsilon;
     }
 };