#include "core/modelIR/operations/ReshapeOp.h"
#include "core/modelIR/operations/ScaleOp.h"
#include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
#include "core/modelIR/operations/SoftmaxOp.h"
#include "core/modelIR/operations/TanhOp.h"
#include "core/modelIR/operations/TransposeOp.h"
}
}
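+/// @brief Creates an elementwise addition of @p arg1 and @p arg2.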
+mir::IODescriptor CaffeOpCreator::createAdd(mir::IODescriptor arg1, mir::IODescriptor arg2) {
+ std::vector<IODescriptor> inputs{arg1, arg2};
+ auto op = createOp<ops::ElementwiseOp>("", inputs, ops::ElementwiseOp::OpType::add);
+ return op->getOutput(0);
+}
+
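+/// @brief Creates an elementwise multiplication of @p arg1 and @p arg2.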
+mir::IODescriptor CaffeOpCreator::createMul(mir::IODescriptor arg1, mir::IODescriptor arg2) {
+ std::vector<IODescriptor> inputs{arg1, arg2};
+ auto op = createOp<ops::ElementwiseOp>("", inputs, ops::ElementwiseOp::OpType::mul);
+ return op->getOutput(0);
+}
+
+/// @brief Splits @p arg into @p num_parts equal parts along the given @p axis.
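+/// For example, splitting a tensor of shape {2, 8} into 4 parts along axis 1 yields four
+/// slices of size {2, 2} with starts {0, 0}, {0, 2}, {0, 4} and {0, 6}.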
+std::vector<mir::IODescriptor>
+CaffeOpCreator::createSplit(mir::IODescriptor arg, int32_t num_parts, int32_t axis) {
+ const auto& arg_shape = arg.op->getOutputShape(arg.index);
+
+ assert(axis >= 0 && axis < arg_shape.rank());
+ int32_t part_size = arg_shape.dim(axis) / num_parts;
+ assert(part_size * num_parts == arg_shape.dim(axis));
+
+ Shape starts(arg_shape.rank());
+ Shape sizes(arg_shape);
+ sizes.dim(axis) = part_size;
+
+ std::vector<mir::IODescriptor> outputs(num_parts);
+ for (int32_t i = 0; i < num_parts; ++i) {
+ outputs[i] = createOp<ops::SliceOp>("", arg, starts, sizes)->getOutput(0);
+ starts.dim(axis) += part_size;
+ }
+
+ return outputs;
+}
+
+/// @brief Helper function for creating a FullyConnected operation with an input of arbitrary
+/// rank: the input is flattened to 2-D around @p axis and the result is reshaped back.
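+/// For example, an input of shape {2, 3, 4, 5} with weights of shape {20, 7} and axis = 2 is
+/// flattened to {6, 20}, multiplied by the weights to get {6, 7}, then reshaped to {2, 3, 7}.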
+mir::IODescriptor
+CaffeOpCreator::createFullyConnected(const mir::IODescriptor& input,
+ const mir::TensorVariant& weights,
+ int32_t axis) {
+ const auto& input_shape = input.op->getOutputShape(input.index);
+ const auto& weights_shape = weights.getShape();
+
+ assert(axis >= 0 && axis < input_shape.rank());
+ assert(weights_shape.rank() == 2);
+
+ // Result shape is: input.shape[0:axis] + weights.shape[1].
+ Shape result_shape = input_shape;
+ result_shape.resize(axis + 1);
+ result_shape.dim(axis) = weights_shape.dim(1);
+
+ // Flatten input to 2-D shape.
+ int32_t outer_size = 1;
+ for (int32_t i = 0; i < axis; ++i)
+ outer_size *= input_shape.dim(i);
+ int32_t inner_size = 1;
+ for (int32_t i = axis; i < input_shape.rank(); ++i)
+ inner_size *= input_shape.dim(i);
+
+ auto flatten = createOp<ops::ReshapeOp>("", input, Shape{outer_size, inner_size})->getOutput(0);
+ auto fc = createOp<ops::FullyConnectedOp>("", flatten, weights)->getOutput(0);
+ return createOp<ops::ReshapeOp>("", fc, result_shape)->getOutput(0);
+}
+
TensorVariant CaffeOpCreator::convertBlob(const BlobProto& blob) {
size_t element_size;
const char* src_data;
}
}
-void CaffeOpCreator::checkInnerProduct(const InnerProductParameter& opts,
- std::set<std::string>& problemsOpSet) {
- if (opts.axis() != 1)
- problemsOpSet.insert("InnerProduct: unsupported axis");
-}
-
-/**
- * @brief Converts Caffe InnerProduct layer to Model IR FullyConnected operation.
- * @todo InnerProduct layer take NCHW input and flattens the CHW part. We insert the
- * Model IR Reshape operation here to account for that, but its result may not be
- * equivalent to how Caffe flattens inputs. Need to check how Caffe does this and
- * implement it correctly.
- * @todo Support axis and transpose parameters as needed.
- */
std::vector<mir::IODescriptor>
CaffeOpCreator::convertInnerProduct(const LayerParameter& layer,
const std::vector<mir::IODescriptor>& inputs) {
if (!params.transpose())
weights = transposeTensor<1, 0>(weights);
- auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
- // Transform input into 2-D tensor by flattening axes before/after params.axis().
- assert(params.axis() == 1);
- Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
- auto reshape = createOp<ops::ReshapeOp>(layer.name() + ".reshape", inputs[0], shape);
- auto result = createOp<ops::FullyConnectedOp>(layer.name(), reshape->getOutput(0), weights);
+ auto result = createFullyConnected(inputs[0], weights, params.axis());
// Add the bias, if any.
if (params.bias_term()) {
- auto bias_weights = convertBlob(layer.blobs(1));
- result = createOp<ops::BiasAddOp>(layer.name() + ".bias", result->getOutput(0), bias_weights);
+ const auto& bias_weights = convertBlob(layer.blobs(1));
+ result = createOp<ops::BiasAddOp>(layer.name() + ".bias", result, bias_weights)->getOutput(0);
}
- return {result->getOutput(0)};
+ return {result};
}
std::vector<mir::IODescriptor>
return outputs;
}
+void CaffeOpCreator::checkLSTM(const caffe::LayerParameter& layer,
+ std::set<std::string>& problems_op_set) {
+ const auto& params = layer.recurrent_param();
+  if (params.expose_hidden())
+    problems_op_set.insert("LSTM: parameter 'expose_hidden' is not supported");
+}
+
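+/// @brief Creates a tensor of the given shape filled with zeros.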
+static TensorVariant createZeroedTensor(const mir::Shape& shape) {
+  // The element type is hardcoded to float32 for now.
+ auto elem_type = mir::DTYPE::FLOAT32;
+ auto elem_size = sizeof(float);
+ auto num_elems = static_cast<std::size_t>(shape.numElements());
+ std::shared_ptr<char> data(new char[num_elems * elem_size], std::default_delete<char[]>());
+ std::memset(data.get(), 0, num_elems * elem_size);
+ return TensorVariant(shape, data, elem_type, elem_size);
+}
+
+/* See the following links for details on implementation:
+ * https://github.com/BVLC/caffe/blob/master/src/caffe/layers/recurrent_layer.cpp
+ * https://github.com/BVLC/caffe/blob/master/src/caffe/layers/lstm_layer.cpp
+ * https://github.com/BVLC/caffe/blob/master/src/caffe/layers/lstm_unit_layer.cpp
+ *
+ * Inputs:
+ * x -- The time-varying input. Shape: [T, N, d0, d1, ..., dn].
+ * cont -- The sequence continuation indicators. Shape: [T, N].
+ * x_static -- The static (non-time-varying) input. Shape: [N, ...].
+ *               This input is optional and not currently supported.
+ *
+ * Additional inputs when parameter "expose_hidden" is true (not currently supported):
+ * h_0 -- The initial value of the hidden state. Shape: [1, N, D].
+ * c_0 -- The initial value of the cell state. Shape: [1, N, D].
+ *
+ * Learned parameters:
+ *   xw -- x weights for the input, forget, output and cell gates concatenated.
+ *         Shape: [4 * D, d0 * d1 * ... * dn].
+ *   xb -- x biases for the input, forget, output and cell gates concatenated. Shape: [4 * D].
+ *   hw -- h weights for the input, forget, output and cell gates concatenated. Shape: [4 * D, D].
+ *
+ * Outputs:
+ * h -- The time-varying output. Shape: [T, N, D].
+ *
+ * Additional outputs when parameter "expose_hidden" is true (not currently supported):
+ * h_T -- The value of the hidden state at the last timestep. Shape: [1, N, D].
+ * c_T -- The value of the cell state at the last timestep. Shape: [1, N, D].
+ *
+ * Here:
+ *   T -- the number of timesteps,
+ *   N -- the number of independent streams,
+ *   D -- the number of hidden units (num_output).
+ *
+ * Formulas:
+ * c_cont = c[t-1] * cont[t]
+ * h_cont = h[t-1] * cont[t]
+ * i[t] = Sigmoid(x[t] . xw_i + xb_i + h_cont . hw_i)
+ * f[t] = Sigmoid(x[t] . xw_f + xb_f + h_cont . hw_f)
+ * o[t] = Sigmoid(x[t] . xw_o + xb_o + h_cont . hw_o)
+ * g[t] = Tanh(x[t] . xw_g + xb_g + h_cont . hw_g)
+ * c[t] = c_cont * f[t] + i[t] * g[t]
+ * h[t] = o[t] * Tanh(c[t])
+ *
+ * Here:
+ * t -- the timestep (ranges from 1 to T),
+ *   * -- the Hadamard product (elementwise product),
+ *   . -- the inner product.
+ *
+ * In this implementation the inner products for all gates are performed as a single inner
+ * product for efficiency.
+ */
+std::vector<mir::IODescriptor>
+CaffeOpCreator::convertLSTM(const caffe::LayerParameter& layer,
+ const std::vector<mir::IODescriptor>& inputs) {
+ const auto& params = layer.recurrent_param();
+
+  // Inputs to the layer.
+  assert(inputs.size() == 2);
+  auto x = inputs[0];
+  auto cont = inputs[1];
+
+ const auto& x_shape = x.op->getOutputShape(x.index);
+ const int32_t seq_length = x_shape.dim(0);
+ const int32_t batch_size = x_shape.dim(1);
+ const int32_t hidden_size = params.num_output();
+
+  // Learned parameters of the layer. Tensors are transposed to match the Model IR layout.
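+  // After transposition, xw has shape [d0 * ... * dn, 4 * D], hw has shape [D, 4 * D],
+  // and xb has shape [4 * D].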
+ const auto& xw = transposeTensor<1, 0>(convertBlob(layer.blobs(0)));
+ const auto& xb = convertBlob(layer.blobs(1));
+ const auto& hw = transposeTensor<1, 0>(convertBlob(layer.blobs(2)));
+
+  // Add a trailing dimension of size 1 so that elementwise operations broadcast as expected.
+ cont = createOp<ops::ReshapeOp>("", cont, Shape{seq_length, batch_size, 1})->getOutput(0);
+
+ // Initialize cell and hidden states with zeros.
+ auto zero_tensor = createZeroedTensor(Shape{1, batch_size, hidden_size});
+ auto c_t = createOp<ops::ConstantOp>("", zero_tensor)->getOutput(0);
+ auto h_t = createOp<ops::ConstantOp>("", zero_tensor)->getOutput(0);
+
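+  // Compute x . xw + xb for all gates and all timesteps with a single FullyConnected
+  // operation; the result has shape [T, N, 4 * D].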
+ auto x_xw = createFullyConnected(x, xw, 2);
+ auto x_xw_b = createOp<ops::BiasAddOp>("", x_xw, xb)->getOutput(0);
+
+ // Split input and continuation tensors into seq_length slices.
+ std::vector<mir::IODescriptor> x_xw_b_slices = createSplit(x_xw_b, seq_length, 0);
+ std::vector<mir::IODescriptor> cont_slices = createSplit(cont, seq_length, 0);
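+  // Each x_xw_b slice has shape [1, N, 4 * D]; each cont slice has shape [1, N, 1].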
+ std::vector<mir::IODescriptor> h_slices(seq_length);
+
+ for (int32_t t = 0; t < seq_length; t++) {
+ auto c_cont_t = createMul(c_t, cont_slices[t]);
+ auto h_cont_t = createMul(h_t, cont_slices[t]);
+
+ auto x_xw_b_t = x_xw_b_slices[t];
+ auto h_hw_t = createFullyConnected(h_cont_t, hw, 2);
+ auto activation_inputs_concat = createAdd(x_xw_b_t, h_hw_t);
+ std::vector<mir::IODescriptor> activation_inputs = createSplit(activation_inputs_concat, 4, 2);
+
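+    // The four gate pre-activations, in the same order as the concatenated weights:
+    // input, forget, output, cell.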
+ auto i_t = createOp<ops::SigmoidOp>("", activation_inputs[0])->getOutput(0);
+ auto f_t = createOp<ops::SigmoidOp>("", activation_inputs[1])->getOutput(0);
+ auto o_t = createOp<ops::SigmoidOp>("", activation_inputs[2])->getOutput(0);
+ auto g_t = createOp<ops::TanhOp>("", activation_inputs[3])->getOutput(0);
+
+ c_t = createAdd(createMul(c_cont_t, f_t), createMul(i_t, g_t));
+ h_t = createMul(createOp<ops::TanhOp>("", c_t)->getOutput(0), o_t);
+
+ h_slices[t] = h_t;
+ }
+
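+  // Concatenate the per-timestep hidden states into the output of shape [T, N, D].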
+ return {createOp<ops::ConcatOp>("", h_slices, 0)->getOutput(0)};
+}
+
} // namespace nnc