[nnc] Support for LSTM Caffe layer (#2721)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Wed, 9 Jan 2019 16:01:01 +0000 (19:01 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Wed, 9 Jan 2019 16:01:01 +0000 (19:01 +0300)
* Add support for LSTM Caffe layer.
* Rename "SoftmaxLoss" Caffe layer to "SoftmaxWithLoss" as it should be.
* Add support for non-default axis for InnerProduct Caffe layer.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/passes/caffe_frontend/caffe_importer.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.h
contrib/nnc/passes/caffe_frontend/caffe_op_types.h
contrib/nnc/unittests/caffe_frontend/unsupported_caffe_model.cpp

index 149c90c..62f64ce 100644 (file)
@@ -138,6 +138,9 @@ void CaffeImporter::createMIRNodesFromLayer(const LayerParameter& layer) {
     case CaffeOpType::sigmoid:
       outputs = _opCreator->convertSigmoid(layer, inputs);
       break;
+    case CaffeOpType::LSTM:
+      outputs = _opCreator->convertLSTM(layer, inputs);
+      break;
     default:
       assert(false && "All unsupported types should have been found before this pass.");
   }
@@ -169,15 +172,13 @@ void CaffeImporter::collectUnsupportedOp(const LayerParameter& lp) {
     case CaffeOpType::embed:
     case CaffeOpType::sigmoid:
     case CaffeOpType::tanh:
+    case CaffeOpType::innerProduct:
       // No checks
       break;
     case CaffeOpType::deconvolution:
     case CaffeOpType::convolution:
       _opCreator->checkConvolution(lp.convolution_param(), _problemsOpSet);
       break;
-    case CaffeOpType::innerProduct:
-      _opCreator->checkInnerProduct(lp.inner_product_param(), _problemsOpSet);
-      break;
     case CaffeOpType::pooling:
       _opCreator->checkPooling(lp.pooling_param(), _problemsOpSet);
       break;
@@ -190,6 +191,9 @@ void CaffeImporter::collectUnsupportedOp(const LayerParameter& lp) {
     case CaffeOpType::batchNorm:
       _opCreator->checkBatchNorm(lp, _problemsOpSet);
       break;
+    case CaffeOpType::LSTM:
+      _opCreator->checkLSTM(lp, _problemsOpSet);
+      break;
     default:
       _problemsOpSet.insert(lp.type() + ": unsupported layer");
       break;
@@ -275,7 +279,7 @@ const std::map<std::string, CaffeOpType> CaffeImporter::_operatorTypes = {
         {"Sigmoid",                 CaffeOpType::sigmoid},
         {"Silence",                 CaffeOpType::silence},
         {"Softmax",                 CaffeOpType::softmax},
-        {"SoftmaxLoss",             CaffeOpType::softmaxLoss},
+        {"SoftmaxWithLoss",         CaffeOpType::softmaxWithLoss},
         {"SPP",                     CaffeOpType::SPP},
         {"Split",                   CaffeOpType::split},
         {"Slice",                   CaffeOpType::slice},
index 74b1570..0effb9f 100644 (file)
@@ -32,6 +32,7 @@
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/TransposeOp.h"
@@ -76,6 +77,69 @@ mir::IODescriptor CaffeOpCreator::convertMIRToCaffe(const mir::IODescriptor& arg
   }
 }
 
+mir::IODescriptor CaffeOpCreator::createAdd(mir::IODescriptor arg1, mir::IODescriptor arg2) {
+  std::vector<IODescriptor> inputs{arg1, arg2};
+  auto op = createOp<ops::ElementwiseOp>("", inputs, ops::ElementwiseOp::OpType::add);
+  return op->getOutput(0);
+}
+
+mir::IODescriptor CaffeOpCreator::createMul(mir::IODescriptor arg1, mir::IODescriptor arg2) {
+  std::vector<IODescriptor> inputs{arg1, arg2};
+  auto op = createOp<ops::ElementwiseOp>("", inputs, ops::ElementwiseOp::OpType::mul);
+  return op->getOutput(0);
+}
+
+/// @brief Split arg into @p num_parts equal parts along @p axis axis.
+std::vector<mir::IODescriptor>
+CaffeOpCreator::createSplit(mir::IODescriptor arg, int32_t num_parts, int32_t axis) {
+  const auto& arg_shape = arg.op->getOutputShape(arg.index);
+
+  assert(axis >= 0 && axis < arg_shape.rank());
+  int32_t part_size = arg_shape.dim(axis) / num_parts;
+  assert(part_size * num_parts == arg_shape.dim(axis));
+
+  Shape starts(arg_shape.rank());
+  Shape sizes(arg_shape);
+  sizes.dim(axis) = part_size;
+
+  std::vector<mir::IODescriptor> outputs(num_parts);
+  for (int32_t i = 0; i < num_parts; ++i) {
+    outputs[i] = createOp<ops::SliceOp>("", arg, starts, sizes)->getOutput(0);
+    starts.dim(axis) += part_size;
+  }
+
+  return outputs;
+}
+
+/// @brief Helper function for creating FullyConnected operation with non-square input.
+IODescriptor
+CaffeOpCreator::createFullyConnected(const mir::IODescriptor& input,
+                                     const mir::TensorVariant& weights,
+                                     int32_t axis) {
+  const auto& input_shape = input.op->getOutputShape(input.index);
+  const auto& weights_shape = weights.getShape();
+
+  assert(axis >= 0 && axis < input_shape.rank());
+  assert(weights_shape.rank() == 2);
+
+  // Result shape is: input.shape[0:axis] + weights.shape[1].
+  Shape result_shape = input_shape;
+  result_shape.resize(axis + 1);
+  result_shape.dim(axis) = weights_shape.dim(1);
+
+  // Flatten input to 2-D shape.
+  int32_t outer_size = 1;
+  for (int32_t i = 0; i < axis; ++i)
+    outer_size *= input_shape.dim(i);
+  int32_t inner_size = 1;
+  for (int32_t i = axis; i < input_shape.rank(); ++i)
+    inner_size *= input_shape.dim(i);
+
+  auto flatten = createOp<ops::ReshapeOp>("", input, Shape{outer_size, inner_size})->getOutput(0);
+  auto fc = createOp<ops::FullyConnectedOp>("", flatten, weights)->getOutput(0);
+  return createOp<ops::ReshapeOp>("", fc, result_shape)->getOutput(0);
+}
+
 TensorVariant CaffeOpCreator::convertBlob(const BlobProto& blob) {
   size_t element_size;
   const char* src_data;
@@ -254,20 +318,6 @@ CaffeOpCreator::convertDeconvolution(const caffe::LayerParameter& layer,
   }
 }
 
-void CaffeOpCreator::checkInnerProduct(const InnerProductParameter& opts,
-                                       std::set<std::string>& problemsOpSet) {
-  if (opts.axis() != 1)
-    problemsOpSet.insert("InnerProduct: unsupported axis");
-}
-
-/**
- * @brief Converts Caffe InnerProduct layer to Model IR FullyConnected operation.
- * @todo InnerProduct layer take NCHW input and flattens the CHW part. We insert the
- * Model IR Reshape operation here to account for that, but its result may not be
- * equivalent to how Caffe flattens inputs. Need to check how Caffe does this and
- * implement it correctly.
- * @todo Support axis and transpose parameters as needed.
- */
 std::vector<mir::IODescriptor>
 CaffeOpCreator::convertInnerProduct(const LayerParameter& layer,
                                     const std::vector<mir::IODescriptor>& inputs) {
@@ -277,20 +327,15 @@ CaffeOpCreator::convertInnerProduct(const LayerParameter& layer,
   if (!params.transpose())
     weights = transposeTensor<1, 0>(weights);
 
-  auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
-  // Transform input into 2-D tensor by flattening axes before/after params.axis().
-  assert(params.axis() == 1);
-  Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
-  auto reshape = createOp<ops::ReshapeOp>(layer.name() + ".reshape", inputs[0], shape);
-  auto result = createOp<ops::FullyConnectedOp>(layer.name(), reshape->getOutput(0), weights);
+  auto result = createFullyConnected(inputs[0], weights, params.axis());
 
   // Add the bias, if any.
   if (params.bias_term()) {
-    auto bias_weights = convertBlob(layer.blobs(1));
-    result = createOp<ops::BiasAddOp>(layer.name() + ".bias", result->getOutput(0), bias_weights);
+    const auto& bias_weights = convertBlob(layer.blobs(1));
+    result = createOp<ops::BiasAddOp>(layer.name() + ".bias", result, bias_weights)->getOutput(0);
   }
 
-  return {result->getOutput(0)};
+  return {result};
 }
 
 std::vector<mir::IODescriptor>
@@ -611,4 +656,132 @@ CaffeOpCreator::convertSplit(const caffe::LayerParameter& layer,
   return outputs;
 }
 
+void CaffeOpCreator::checkLSTM(const caffe::LayerParameter& layer,
+                               std::set<std::string>& problems_op_set) {
+  const auto& params = layer.recurrent_param();
+  if (params.expose_hidden())
+    problems_op_set.insert("LSTM: parameter 'expose_hidden' has unsupported value: " +
+                           std::to_string(params.expose_hidden()));
+}
+
+static TensorVariant createZeroedTensor(const mir::Shape& shape) {
+  // For now it is hardcoded float32.
+  auto elem_type = mir::DTYPE::FLOAT32;
+  auto elem_size = sizeof(float);
+  auto num_elems = static_cast<std::size_t>(shape.numElements());
+  std::shared_ptr<char> data(new char[num_elems * elem_size], std::default_delete<char[]>());
+  std::memset(data.get(), 0, num_elems * elem_size);
+  return TensorVariant(shape, data, elem_type, elem_size);
+}
+
+/* See the following links for details on implementation:
+ * https://github.com/BVLC/caffe/blob/master/src/caffe/layers/recurrent_layer.cpp
+ * https://github.com/BVLC/caffe/blob/master/src/caffe/layers/lstm_layer.cpp
+ * https://github.com/BVLC/caffe/blob/master/src/caffe/layers/lstm_unit_layer.cpp
+ *
+ * Inputs:
+ *   x        -- The time-varying input. Shape: [T, N, d0, d1, ..., dn].
+ *   cont     -- The sequence continuation indicators. Shape: [T, N].
+ *   x_static -- The static (non-time-varying) input. Shape: [N, ...].
+ *               This parameter is optional and not currently supported.
+ *
+ * Additional inputs when parameter "expose_hidden" is true (not currently supported):
+ *   h_0  -- The initial value of the hidden state. Shape: [1, N, D].
+ *   c_0  -- The initial value of the cell state. Shape: [1, N, D].
+ *
+ * Learned parameters:
+ *   xw -- x weights for input, output, forget and cell gates concatenated.
+ *         Shape: [4 * D, d0 * d1 * ... * dn].
+ *   xb -- x biases for input, output, forget and cell gates concatenated. Shape: [4 * D].
+ *   hw -- h weights for input, output, forget and cell gates concatenated. Shape: [4 * D, D].
+ *
+ * Outputs:
+ *   h   -- The time-varying output. Shape: [T, N, D].
+ *
+ * Additional outputs when parameter "expose_hidden" is true (not currently supported):
+ *   h_T -- The value of the hidden state at the last timestep. Shape: [1, N, D].
+ *   c_T -- The value of the cell state at the last timestep. Shape: [1, N, D].
+ *
+ * Here:
+ *   T - the number of timesteps,
+ *   N - the number of independent streams.
+ *   D - the number of hidden parameters.
+ *
+ * Formulas:
+ *   c_cont = c[t-1] * cont[t]
+ *   h_cont = h[t-1] * cont[t]
+ *   i[t] = Sigmoid(x[t] . xw_i + xb_i + h_cont . hw_i)
+ *   f[t] = Sigmoid(x[t] . xw_f + xb_f + h_cont . hw_f)
+ *   o[t] = Sigmoid(x[t] . xw_o + xb_o + h_cont . hw_o)
+ *   g[t] =    Tanh(x[t] . xw_g + xb_g + h_cont . hw_g)
+ *   c[t] = c_cont * f[t] + i[t] * g[t]
+ *   h[t] = o[t] * Tanh(c[t])
+ *
+ * Here:
+ *   t -- the timestep (ranges from 1 to T),
+ *   * -- the inner product,
+ *   . -- the Hadamard product (elementwise product).
+ *
+ * In this implementation the inner products for all gates are performed as single inner product for
+ * efficiency.
+ */
+std::vector<mir::IODescriptor>
+CaffeOpCreator::convertLSTM(const caffe::LayerParameter& layer,
+                            const std::vector<mir::IODescriptor>& inputs) {
+  const auto& params = layer.recurrent_param();
+
+  // Inputs to the layer.
+  auto x = inputs[0];
+  auto cont = inputs[1];
+  assert(inputs.size() == 2);
+
+  const auto& x_shape = x.op->getOutputShape(x.index);
+  const int32_t seq_length = x_shape.dim(0);
+  const int32_t batch_size = x_shape.dim(1);
+  const int32_t hidden_size = params.num_output();
+
+  // Learned parameters of the layer. Tensors are transposed to match the ModelIR.
+  const auto& xw = transposeTensor<1, 0>(convertBlob(layer.blobs(0)));
+  const auto& xb = convertBlob(layer.blobs(1));
+  const auto& hw = transposeTensor<1, 0>(convertBlob(layer.blobs(2)));
+
+  // Add a dummy dimension so that element-wise operations perform properly.
+  cont = createOp<ops::ReshapeOp>("", cont, Shape{seq_length, batch_size, 1})->getOutput(0);
+
+  // Initialize cell and hidden states with zeros.
+  auto zero_tensor = createZeroedTensor(Shape{1, batch_size, hidden_size});
+  auto c_t = createOp<ops::ConstantOp>("", zero_tensor)->getOutput(0);
+  auto h_t = createOp<ops::ConstantOp>("", zero_tensor)->getOutput(0);
+
+  auto x_xw = createFullyConnected(x, xw, 2);
+  auto x_xw_b = createOp<ops::BiasAddOp>("", x_xw, xb)->getOutput(0);
+
+  // Split input and continuation tensors into seq_length slices.
+  std::vector<mir::IODescriptor> x_xw_b_slices = createSplit(x_xw_b, seq_length, 0);
+  std::vector<mir::IODescriptor> cont_slices = createSplit(cont, seq_length, 0);
+  std::vector<mir::IODescriptor> h_slices(seq_length);
+
+  for (int32_t t = 0; t < seq_length; t++) {
+    auto c_cont_t = createMul(c_t, cont_slices[t]);
+    auto h_cont_t = createMul(h_t, cont_slices[t]);
+
+    auto x_xw_b_t = x_xw_b_slices[t];
+    auto h_hw_t = createFullyConnected(h_cont_t, hw, 2);
+    auto activation_inputs_concat = createAdd(x_xw_b_t, h_hw_t);
+    std::vector<mir::IODescriptor> activation_inputs = createSplit(activation_inputs_concat, 4, 2);
+
+    auto i_t = createOp<ops::SigmoidOp>("", activation_inputs[0])->getOutput(0);
+    auto f_t = createOp<ops::SigmoidOp>("", activation_inputs[1])->getOutput(0);
+    auto o_t = createOp<ops::SigmoidOp>("", activation_inputs[2])->getOutput(0);
+    auto g_t = createOp<ops::TanhOp>("", activation_inputs[3])->getOutput(0);
+
+    c_t = createAdd(createMul(c_cont_t, f_t), createMul(i_t, g_t));
+    h_t = createMul(createOp<ops::TanhOp>("", c_t)->getOutput(0), o_t);
+
+    h_slices[t] = h_t;
+  }
+
+  return {createOp<ops::ConcatOp>("", h_slices, 0)->getOutput(0)};
+}
+
 } // namespace nnc
index 25e8265..4f02a74 100644 (file)
@@ -108,9 +108,11 @@ public:
   convertSplit(const caffe::LayerParameter& layer,
                const std::vector<mir::IODescriptor>& inputs);
 
-  void checkConvolution(const caffe::ConvolutionParameter& layer, std::set<std::string>&);
+  std::vector<mir::IODescriptor>
+  convertLSTM(const caffe::LayerParameter& layer,
+              const std::vector<mir::IODescriptor>& inputs);
 
-  void checkInnerProduct(const caffe::InnerProductParameter& opts, std::set<std::string>&);
+  void checkConvolution(const caffe::ConvolutionParameter& layer, std::set<std::string>&);
 
   void checkPooling(const caffe::PoolingParameter& opts, std::set<std::string>&);
 
@@ -120,6 +122,8 @@ public:
 
   void checkBatchNorm(const caffe::LayerParameter& layer, std::set<std::string>&);
 
+  void checkLSTM(const caffe::LayerParameter& layer, std::set<std::string>&);
+
 private:
   mir::Graph* _graph = nullptr;
 
@@ -127,6 +131,18 @@ private:
 
   mir::IODescriptor convertMIRToCaffe(const mir::IODescriptor& arg);
 
+  mir::IODescriptor createAdd(mir::IODescriptor arg1, mir::IODescriptor arg2);
+
+  mir::IODescriptor createMul(mir::IODescriptor arg1, mir::IODescriptor arg2);
+
+  std::vector<mir::IODescriptor>
+  createSplit(mir::IODescriptor arg, int32_t num_parts, int32_t axis);
+
+  mir::IODescriptor
+  createFullyConnected(const mir::IODescriptor& input,
+                       const mir::TensorVariant& weights,
+                       int32_t axis);
+
   TensorVariant convertBlob(const caffe::BlobProto& blob);
 
   template<typename OpType, typename... Types>
index 887112b..5247503 100644 (file)
@@ -73,7 +73,7 @@ enum class CaffeOpType {
   silence,
   slice,
   softmax,
-  softmaxLoss,
+  softmaxWithLoss,
   split,
   SPP,
   tanh,
index c3ee986..0227b90 100644 (file)
@@ -6,7 +6,7 @@
 
 const char *ErrorMsg = "Detected problems:\n"
                        "DummyData: unsupported layer\n"
-                       "LSTM: unsupported layer\n"
+                       "LSTM: parameter 'expose_hidden' has unsupported value: 1\n"
                        "UnexcitingLayerType: unknown layer\n";
 
 // When adding support for new layers, change the model, not the test