loss_layers.hpp: Doxygen-style documentation
authorJeff Donahue <jeff.donahue@gmail.com>
Thu, 28 Aug 2014 11:17:43 +0000 (04:17 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 3 Sep 2014 17:59:24 +0000 (10:59 -0700)
include/caffe/loss_layers.hpp

index df887b8..1a93e4d 100644 (file)
@@ -19,13 +19,21 @@ namespace caffe {
 
 const float kLOG_THRESHOLD = 1e-20;
 
-/* AccuracyLayer
-  Note: not an actual loss layer! Does not implement backwards step.
 Computes the accuracy of argmax(a) with respect to b.
-*/
+/**
+ * @brief Computes the classification accuracy for a one-of-many
*        classification task.
+ */
 template <typename Dtype>
 class AccuracyLayer : public Layer<Dtype> {
  public:
+  /**
+   * @param param provides AccuracyParameter accuracy_param,
+   *     with AccuracyLayer options:
+   *   - top_k (\b optional, default 1).
+   *     Sets the maximum rank @f$ k @f$ at which a prediction is considered
+   *     correct.  For example, if @f$ k = 5 @f$, a prediction is counted
+   *     correct if the correct label is among the top 5 predicted labels.
+   */
   explicit AccuracyLayer(const LayerParameter& param)
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
@@ -39,8 +47,35 @@ class AccuracyLayer : public Layer<Dtype> {
   virtual inline int ExactNumTopBlobs() const { return 1; }
 
  protected:
+  /**
+   * @param bottom input Blob vector (length 2)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$ x @f$, a Blob with values in
+   *      @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of
+   *      the @f$ K = CHW @f$ classes. Each @f$ x_n @f$ is mapped to a predicted
+   *      label @f$ \hat{l}_n @f$ given by its maximal index:
+   *      @f$ \hat{l}_n = \arg\max\limits_k x_{nk} @f$
+   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+   *      the labels @f$ l @f$, an integer-valued Blob with values
+   *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
+   *      indicating the correct class label among the @f$ K @f$ classes
+   * @param top output Blob vector (length 1)
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      the computed accuracy: @f$
+   *        \frac{1}{N} \sum\limits_{n=1}^N \delta\{ \hat{l}_n = l_n \}
+   *      @f$, where @f$
+   *      \delta\{\mathrm{condition}\} = \left\{
+   *         \begin{array}{lr}
+   *            1 & \mbox{if condition} \\
+   *            0 & \mbox{otherwise}
+   *         \end{array} \right.
+   *      @f$
+   */
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+
+
+  /// @brief Not implemented -- AccuracyLayer cannot be used as a loss.
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
     for (int i = 0; i < propagate_down.size(); ++i) {
@@ -51,10 +86,14 @@ class AccuracyLayer : public Layer<Dtype> {
   int top_k_;
 };
 
-/* LossLayer
-  Takes two inputs of same num (a and b), and has no output.
-  The gradient is propagated to a.
-*/
+/**
+ * @brief An interface for Layer%s that take two Blob%s as input -- usually
+ *        (1) predictions and (2) ground-truth labels -- and output a
+ *        singleton Blob representing the loss.
+ *
+ * LossLayers are typically only capable of backpropagating to their first input
+ * -- the predictions.
+ */
 template <typename Dtype>
 class LossLayer : public Layer<Dtype> {
  public:
@@ -64,24 +103,50 @@ class LossLayer : public Layer<Dtype> {
       const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top);
 
   virtual inline int ExactNumBottomBlobs() const { return 2; }
-  // For convenience and backwards compatibility, LossLayer's automatically
-  // allocate a singleton top blob into which they output their singleton loss,
-  // (even if the user didn't specify one in the prototxt, etc.).
+
+  /**
+   * @brief For convenience and backwards compatibility, instruct the Net to
+   *        automatically allocate a single top Blob for LossLayers, into which
+   *        they output their singleton loss, (even if the user didn't specify
+   *        one in the prototxt, etc.).
+   */
   virtual inline bool AutoTopBlobs() const { return true; }
   virtual inline int ExactNumTopBlobs() const { return 1; }
-  // We usually cannot backpropagate to the labels; ignore force_backward for
-  // these inputs.
+  /**
+   * We usually cannot backpropagate to the labels; ignore force_backward for
+   * these inputs.
+   */
   virtual inline bool AllowForceBackward(const int bottom_index) const {
     return bottom_index != 1;
   }
 };
 
-/* EuclideanLossLayer
-  Compute the L_2 distance between the two inputs.
-
-  loss = (1/2 \sum_i (a_i - b_i)^2)
-  a' = 1/I (a - b)
-*/
+/**
+ * @brief Computes the Euclidean (L2) loss @f$
+ *          E = \frac{1}{2N} \sum\limits_{n=1}^N \left| \left| \hat{y}_n - y_n
+ *        \right| \right|_2^2 @f$ for real-valued regression tasks.
+ *
+ * @param bottom input Blob vector (length 2)
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the predictions @f$ \hat{y} \in [-\infty, +\infty]@f$
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the targets @f$ y \in [-\infty, +\infty]@f$
+ * @param top output Blob vector (length 1)
+ *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+ *      the computed Euclidean loss: @f$ E =
+ *          \frac{1}{2n} \sum\limits_{n=1}^N \left| \left| \hat{y}_n - y_n
+ *        \right| \right|_2^2 @f$
+ *
+ * This can be used for least-squares regression tasks.  An InnerProductLayer
+ * input to a EuclideanLossLayer exactly formulates a linear least squares
+ * regression problem. With non-zero weight decay the problem becomes one of
+ * ridge regression -- see src/caffe/test/test_sgd_solver.cpp for a concrete
+ * example wherein we check that the gradients computed for a Net with exactly
+ * this structure match hand-computed gradient formulas for ridge regression.
+ *
+ * (Note: Caffe, and SGD in general, is certainly \b not the best way to solve
+ * linear least squares problems! We use it only as an instructive example.)
+ */
 template <typename Dtype>
 class EuclideanLossLayer : public LossLayer<Dtype> {
  public:
@@ -93,17 +158,55 @@ class EuclideanLossLayer : public LossLayer<Dtype> {
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_EUCLIDEAN_LOSS;
   }
-  // Unlike most loss layers, in the EuclideanLossLayer we can backpropagate
-  // to both inputs.
+
+  /**
+   * Unlike most loss layers, in the EuclideanLossLayer we can backpropagate
+   * to both inputs -- override to return true and always allow force_backward.
+   */
   virtual inline bool AllowForceBackward(const int bottom_index) const {
     return true;
   }
 
  protected:
+  /// @copydoc EuclideanLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+
+  /**
+   * @brief Computes the Euclidean error gradient w.r.t. the inputs.
+   *
+   * Unlike other children of LossLayer, EuclideanLossLayer \b can compute
+   * gradients with respect to the label inputs bottom[1] (but still only will
+   * if propagate_down[1] is set, due to being produced by learnable parameters
+   * or if force_backward is set). In fact, this layer is "commutative" -- the
+   * result is the same regardless of the order of the two bottoms.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
+   *      as @f$ \lambda @f$ is the coefficient of this layer's output
+   *      @f$\ell_i@f$ in the overall Net loss
+   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
+   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
+   *      (*Assuming that this top Blob is not used as a bottom (input) by any
+   *      other layer of the Net.)
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 2)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$\hat{y}@f$; Backward fills their diff with
+   *      gradients @f$
+   *        \frac{\partial E}{\partial \hat{y}} =
+   *            \frac{1}{n} \sum\limits_{n=1}^N (\hat{y}_n - y_n)
+   *      @f$ if propagate_down[0]
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the targets @f$y@f$; Backward fills their diff with gradients
+   *      @f$ \frac{\partial E}{\partial y} =
+   *          \frac{1}{n} \sum\limits_{n=1}^N (y_n - \hat{y}_n)
+   *      @f$ if propagate_down[1]
+   */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
@@ -112,8 +215,49 @@ class EuclideanLossLayer : public LossLayer<Dtype> {
   Blob<Dtype> diff_;
 };
 
-/* HingeLossLayer
-*/
+/**
+ * @brief Computes the hinge loss for a one-of-many classification task.
+ *
+ * @param bottom input Blob vector (length 2)
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the predictions @f$ t @f$, a Blob with values in
+ *      @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of
+ *      the @f$ K = CHW @f$ classes. In an SVM, @f$ t @f$ is the result of
+ *      taking the inner product @f$ X^T W @f$ of the D-dimensional features
+ *      @f$ X \in \mathcal{R}^{D \times N} @f$ and the learned hyperplane
+ *      parameters @f$ W \in \mathcal{R}^{D \times K} @f$, so a Net with just
+ *      an InnerProductLayer (with num_output = D) providing predictions to a
+ *      HingeLossLayer and no other learnable parameters or losses is
+ *      equivalent to an SVM.
+ *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+ *      the labels @f$ l @f$, an integer-valued Blob with values
+ *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
+ *      indicating the correct class label among the @f$ K @f$ classes
+ * @param top output Blob vector (length 1)
+ *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+ *      the computed hinge loss: @f$ E =
+ *        \frac{1}{N} \sum\limits_{n=1}^N \sum\limits_{k=1}^K
+ *        [\max(0, 1 - \delta\{l_n = k\} t_{nk})] ^ p
+ *      @f$, for the @f$ L^p @f$ norm
+ *      (defaults to @f$ p = 1 @f$, the L1 norm; L2 norm, as in L2-SVM,
+ *      is also available), and @f$
+ *      \delta\{\mathrm{condition}\} = \left\{
+ *         \begin{array}{lr}
+ *            1 & \mbox{if condition} \\
+ *           -1 & \mbox{otherwise}
+ *         \end{array} \right.
+ *      @f$
+ *
+ * In an SVM, @f$ t \in \mathcal{R}^{N \times K} @f$ is the result of taking
+ * the inner product @f$ X^T W @f$ of the features
+ * @f$ X \in \mathcal{R}^{D \times N} @f$
+ * and the learned hyperplane parameters
+ * @f$ W \in \mathcal{R}^{D \times K} @f$. So, a Net with just an
+ * InnerProductLayer (with num_output = @f$k@f$) providing predictions to a
+ * HingeLossLayer is equivalent to an SVM (assuming it has no other learned
+ * outside the InnerProductLayer and no other losses outside the
+ * HingeLossLayer).
+ */
 template <typename Dtype>
 class HingeLossLayer : public LossLayer<Dtype> {
  public:
@@ -125,14 +269,73 @@ class HingeLossLayer : public LossLayer<Dtype> {
   }
 
  protected:
+  /// @copydoc HingeLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+
+  /**
+   * @brief Computes the hinge loss error gradient w.r.t. the predictions.
+   *
+   * Gradients cannot be computed with respect to the label inputs (bottom[1]),
+   * so this method ignores bottom[1] and requires !propagate_down[1], crashing
+   * if propagate_down[1] is set.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
+   *      as @f$ \lambda @f$ is the coefficient of this layer's output
+   *      @f$\ell_i@f$ in the overall Net loss
+   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
+   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
+   *      (*Assuming that this top Blob is not used as a bottom (input) by any
+   *      other layer of the Net.)
+   * @param propagate_down see Layer::Backward.
+   *      propagate_down[1] must be false as we can't compute gradients with
+   *      respect to the labels.
+   * @param bottom input Blob vector (length 2)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$t@f$; Backward computes diff
+   *      @f$ \frac{\partial E}{\partial t} @f$
+   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+   *      the labels -- ignored as we can't compute their error gradients
+   */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
-/* InfogainLossLayer
-*/
+/**
+ * @brief A generalization of MultinomialLogisticLossLayer that takes an
+ *        "information gain" (infogain) matrix specifying the "value" of all label
+ *        pairs.
+ *
+ * Equivalent to the MultinomialLogisticLossLayer if the infogain matrix is the
+ * identity.
+ *
+ * @param bottom input Blob vector (length 2-3)
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the predictions @f$ \hat{p} @f$, a Blob with values in
+ *      @f$ [0, 1] @f$ indicating the predicted probability of each of the
+ *      @f$ K = CHW @f$ classes.  Each prediction vector @f$ \hat{p}_n @f$
+ *      should sum to 1 as in a probability distribution: @f$
+ *      \forall n \sum\limits_{k=1}^K \hat{p}_{nk} = 1 @f$.
+ *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+ *      the labels @f$ l @f$, an integer-valued Blob with values
+ *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
+ *      indicating the correct class label among the @f$ K @f$ classes
+ *   -# @f$ (1 \times 1 \times K \times K) @f$
+ *      (\b optional) the infogain matrix @f$ H @f$.  This must be provided as
+ *      the third bottom blob input if not provided as the infogain_mat in the
+ *      InfogainLossParameter. If @f$ H = I @f$, this layer is equivalent to the
+ *      MultinomialLogisticLossLayer.
+ * @param top output Blob vector (length 1)
+ *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+ *      the computed infogain multinomial logistic loss: @f$ E =
+ *        \frac{-1}{N} \sum\limits_{n=1}^N H_{l_n} \log(\hat{p}_n) =
+ *        \frac{-1}{N} \sum\limits_{n=1}^N \sum\limits_{k=1}^{K} H_{l_n,k}
+ *        \log(\hat{p}_{n,k})
+ *      @f$, where @f$ H_{l_n} @f$ denotes row @f$l_n@f$ of @f$H@f$.
+ */
 template <typename Dtype>
 class InfogainLossLayer : public LossLayer<Dtype> {
  public:
@@ -141,7 +344,7 @@ class InfogainLossLayer : public LossLayer<Dtype> {
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
-  // InfogainLossLayer takes 2-3 bottom blobs; if there are 3 the third should
+  // InfogainLossLayer takes 2-3 bottom Blobs; if there are 3 the third should
   // be the infogain matrix.  (Otherwise the infogain matrix is loaded from a
   // file specified by LayerParameter.)
   virtual inline int ExactNumBottomBlobs() const { return -1; }
@@ -153,16 +356,77 @@ class InfogainLossLayer : public LossLayer<Dtype> {
   }
 
  protected:
+  /// @copydoc InfogainLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+
+  /**
+   * @brief Computes the infogain loss error gradient w.r.t. the predictions.
+   *
+   * Gradients cannot be computed with respect to the label inputs (bottom[1]),
+   * so this method ignores bottom[1] and requires !propagate_down[1], crashing
+   * if propagate_down[1] is set. (The same applies to the infogain matrix, if
+   * provided as bottom[2] rather than in the layer_param.)
+   *
+   * @param top output Blob vector (length 1), providing the error gradient
+   *      with respect to the outputs
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
+   *      as @f$ \lambda @f$ is the coefficient of this layer's output
+   *      @f$\ell_i@f$ in the overall Net loss
+   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
+   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
+   *      (*Assuming that this top Blob is not used as a bottom (input) by any
+   *      other layer of the Net.)
+   * @param propagate_down see Layer::Backward.
+   *      propagate_down[1] must be false as we can't compute gradients with
+   *      respect to the labels (similarly for propagate_down[2] and the
+   *      infogain matrix, if provided as bottom[2])
+   * @param bottom input Blob vector (length 2-3)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$ \hat{p} @f$; Backward computes diff
+   *      @f$ \frac{\partial E}{\partial \hat{p}} @f$
+   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+   *      the labels -- ignored as we can't compute their error gradients
+   *   -# @f$ (1 \times 1 \times K \times K) @f$
+   *      (\b optional) the information gain matrix -- ignored as its error
+   *      gradient computation is not implemented.
+   */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   Blob<Dtype> infogain_;
 };
 
-/* MultinomialLogisticLossLayer
-*/
+/**
+ * @brief Computes the multinomial logistic loss for a one-of-many
+ *        classification task, directly taking a predicted probability
+ *        distribution as input.
+ *
+ * When predictions are not already a probability distribution, you should
+ * instead use the SoftmaxWithLossLayer, which maps predictions to a
+ * distribution using the SoftmaxLayer, before computing the multinomial
+ * logistic loss. The SoftmaxWithLossLayer should be preferred over separate
+ * SoftmaxLayer + MultinomialLogisticLossLayer
+ * as its gradient computation is more numerically stable.
+ *
+ * @param bottom input Blob vector (length 2)
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the predictions @f$ \hat{p} @f$, a Blob with values in
+ *      @f$ [0, 1] @f$ indicating the predicted probability of each of the
+ *      @f$ K = CHW @f$ classes.  Each prediction vector @f$ \hat{p}_n @f$
+ *      should sum to 1 as in a probability distribution: @f$
+ *      \forall n \sum\limits_{k=1}^K \hat{p}_{nk} = 1 @f$.
+ *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+ *      the labels @f$ l @f$, an integer-valued Blob with values
+ *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
+ *      indicating the correct class label among the @f$ K @f$ classes
+ * @param top output Blob vector (length 1)
+ *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+ *      the computed multinomial logistic loss: @f$ E =
+ *        \frac{-1}{N} \sum\limits_{n=1}^N \log(\hat{p}_{n,l_n})
+ *      @f$
+ */
 template <typename Dtype>
 class MultinomialLogisticLossLayer : public LossLayer<Dtype> {
  public:
@@ -176,14 +440,71 @@ class MultinomialLogisticLossLayer : public LossLayer<Dtype> {
   }
 
  protected:
+  /// @copydoc MultinomialLogisticLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+
+  /**
+   * @brief Computes the multinomial logistic loss error gradient w.r.t. the
+   *        predictions.
+   *
+   * Gradients cannot be computed with respect to the label inputs (bottom[1]),
+   * so this method ignores bottom[1] and requires !propagate_down[1], crashing
+   * if propagate_down[1] is set.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
+   *      as @f$ \lambda @f$ is the coefficient of this layer's output
+   *      @f$\ell_i@f$ in the overall Net loss
+   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
+   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
+   *      (*Assuming that this top Blob is not used as a bottom (input) by any
+   *      other layer of the Net.)
+   * @param propagate_down see Layer::Backward.
+   *      propagate_down[1] must be false as we can't compute gradients with
+   *      respect to the labels.
+   * @param bottom input Blob vector (length 2)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$ \hat{p} @f$; Backward computes diff
+   *      @f$ \frac{\partial E}{\partial \hat{p}} @f$
+   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+   *      the labels -- ignored as we can't compute their error gradients
+   */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
-/* SigmoidCrossEntropyLossLayer
-*/
+/**
+ * @brief Computes the cross-entropy (logistic) loss @f$
+ *          E = \frac{-1}{n} \sum\limits_{n=1}^N \left[
+ *                  p_n \log \hat{p}_n +
+ *                  (1 - p_n) \log(1 - \hat{p}_n)
+ *              \right]
+ *        @f$, often used for predicting targets interpreted as probabilities.
+ *
+ * This layer is implemented rather than separate
+ * SigmoidLayer + CrossEntropyLayer
+ * as its gradient computation is more numerically stable.
+ * At test time, this layer can be replaced simply by a SigmoidLayer.
+ *
+ * @param bottom input Blob vector (length 2)
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the scores @f$ x \in [-\infty, +\infty]@f$,
+ *      which this layer maps to probability predictions
+ *      @f$ \hat{p}_n = \sigma(x_n) \in [0, 1] @f$
+ *      using the sigmoid function @f$ \sigma(.) @f$ (see SigmoidLayer).
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the targets @f$ y \in [0, 1] @f$
+ * @param top output Blob vector (length 1)
+ *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+ *      the computed cross-entropy loss: @f$
+ *          E = \frac{-1}{n} \sum\limits_{n=1}^N \left[
+ *                  p_n \log \hat{p}_n + (1 - p_n) \log(1 - \hat{p}_n)
+ *              \right]
+ *      @f$
+ */
 template <typename Dtype>
 class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
  public:
@@ -199,34 +520,88 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   }
 
  protected:
+  /// @copydoc SigmoidCrossEntropyLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+
+  /**
+   * @brief Computes the sigmoid cross-entropy loss error gradient w.r.t. the
+   *        predictions.
+   *
+   * Gradients cannot be computed with respect to the target inputs (bottom[1]),
+   * so this method ignores bottom[1] and requires !propagate_down[1], crashing
+   * if propagate_down[1] is set.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
+   *      as @f$ \lambda @f$ is the coefficient of this layer's output
+   *      @f$\ell_i@f$ in the overall Net loss
+   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
+   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
+   *      (*Assuming that this top Blob is not used as a bottom (input) by any
+   *      other layer of the Net.)
+   * @param propagate_down see Layer::Backward.
+   *      propagate_down[1] must be false as gradient computation with respect
+   *      to the targets is not implemented.
+   * @param bottom input Blob vector (length 2)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$x@f$; Backward computes diff
+   *      @f$ \frac{\partial E}{\partial x} =
+   *          \frac{1}{n} \sum\limits_{n=1}^N (\hat{p}_n - p_n)
+   *      @f$
+   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+   *      the labels -- ignored as we can't compute their error gradients
+   */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
+  /// The internal SigmoidLayer used to map predictions to probabilities.
   shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
-  // sigmoid_output stores the output of the sigmoid layer.
+  /// sigmoid_output stores the output of the SigmoidLayer.
   shared_ptr<Blob<Dtype> > sigmoid_output_;
-  // Vector holders to call the underlying sigmoid layer forward and backward.
+  /// bottom vector holder to call the underlying SigmoidLayer::Forward
   vector<Blob<Dtype>*> sigmoid_bottom_vec_;
+  /// top vector holder to call the underlying SigmoidLayer::Forward
   vector<Blob<Dtype>*> sigmoid_top_vec_;
 };
 
 // Forward declare SoftmaxLayer for use in SoftmaxWithLossLayer.
 template <typename Dtype> class SoftmaxLayer;
 
-/* SoftmaxWithLossLayer
-  Implements softmax and computes the loss.
-
-  It is preferred over separate softmax + multinomiallogisticloss
-  layers due to more numerically stable gradients.
-
-  In test, this layer could be replaced by simple softmax layer.
-*/
+/**
+ * @brief Computes the multinomial logistic loss for a one-of-many
+ *        classification task, passing real-valued predictions through a
+ *        softmax to get a probability distribution over classes.
+ *
+ * This layer should be preferred over separate
+ * SoftmaxLayer + MultinomialLogisticLossLayer
+ * as its gradient computation is more numerically stable.
+ * At test time, this layer can be replaced simply by a SoftmaxLayer.
+ *
+ * @param bottom input Blob vector (length 2)
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the predictions @f$ x @f$, a Blob with values in
+ *      @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of
+ *      the @f$ K = CHW @f$ classes. This layer maps these scores to a
+ *      probability distribution over classes using the softmax function
+ *      @f$ \hat{p}_{nk} = \exp(x_{nk}) /
+ *      \left[\sum_{k'} \exp(x_{nk'})\right] @f$ (see SoftmaxLayer).
+ *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+ *      the labels @f$ l @f$, an integer-valued Blob with values
+ *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
+ *      indicating the correct class label among the @f$ K @f$ classes
+ * @param top output Blob vector (length 1)
+ *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+ *      the computed cross-entropy classification loss: @f$ E =
+ *        \frac{-1}{N} \sum\limits_{n=1}^N \log(\hat{p}_{n,l_n})
+ *      @f$, for softmax output class probabilites @f$ \hat{p} @f$
+ */
 template <typename Dtype>
 class SoftmaxWithLossLayer : public LossLayer<Dtype> {
  public:
@@ -247,20 +622,51 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
   virtual inline int MaxTopBlobs() const { return 2; }
 
  protected:
+  /// @copydoc SoftmaxWithLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+
+  /**
+   * @brief Computes the softmax loss error gradient w.r.t. the predictions.
+   *
+   * Gradients cannot be computed with respect to the label inputs (bottom[1]),
+   * so this method ignores bottom[1] and requires !propagate_down[1], crashing
+   * if propagate_down[1] is set.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
+   *      as @f$ \lambda @f$ is the coefficient of this layer's output
+   *      @f$\ell_i@f$ in the overall Net loss
+   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
+   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
+   *      (*Assuming that this top Blob is not used as a bottom (input) by any
+   *      other layer of the Net.)
+   * @param propagate_down see Layer::Backward.
+   *      propagate_down[1] must be false as we can't compute gradients with
+   *      respect to the labels.
+   * @param bottom input Blob vector (length 2)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$ x @f$; Backward computes diff
+   *      @f$ \frac{\partial E}{\partial x} @f$
+   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
+   *      the labels -- ignored as we can't compute their error gradients
+   */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
+  /// The internal SoftmaxLayer used to map predictions to a distribution.
   shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
-  // prob stores the output probability of the layer.
+  /// prob stores the output probability predictions from the SoftmaxLayer.
   Blob<Dtype> prob_;
-  // Vector holders to call the underlying softmax layer forward and backward.
+  /// bottom vector holder used in call to the underlying SoftmaxLayer::Forward
   vector<Blob<Dtype>*> softmax_bottom_vec_;
+  /// top vector holder used in call to the underlying SoftmaxLayer::Forward
   vector<Blob<Dtype>*> softmax_top_vec_;
 };