[Frontend][ONNX] LSTM Support (#4825)
authorJosh Fromm <jwfromm@uw.edu>
Fri, 7 Feb 2020 02:09:10 +0000 (18:09 -0800)
committerGitHub <noreply@github.com>
Fri, 7 Feb 2020 02:09:10 +0000 (11:09 +0900)
* Initial version working and passing tests.

* WIP on supporting other activations.

* add support for multiple activation functions in lstm

* All tests working and code cleaned up.

* Undo import swap to avoid conflict with masahi.

* Added new tests and related bug fixes.

Co-authored-by: Matthew Brookhart <mbrookhart@octoml.ai>
python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index b275e85..fdf6bd7 100644 (file)
@@ -32,6 +32,55 @@ from .common import infer_type, infer_value, infer_value_simulated, get_name
 __all__ = ['from_onnx']
 
 
+class onnx_input():
+    """ Dual-purpose list or dictionary access object."""
+
+    def __init__(self):
+        self.input_keys = []
+        self.input_dict = {}
+
+    def __getitem__(self, item):
+        if isinstance(item, int):
+            return self.input_dict[self.input_keys[item]]
+        if isinstance(item, str):
+            if item not in self.input_keys:
+                return None
+            return self.input_dict[item]
+        if isinstance(item, slice):
+            keys = self.input_keys[item]
+            return [self.input_dict[key] for key in keys]
+
+        raise ValueError("Only integer, string, and slice accesses allowed.")
+
+    def __setitem__(self, item, value):
+        if isinstance(item, int):
+            self.input_dict[self.input_keys[item]] = value
+        elif isinstance(item, str):
+            if item not in self.input_dict:
+                self.input_keys.append(item)
+            self.input_dict[item] = value
+        else:
+            raise ValueError("Only integer and string indexed writes allowed.")
+
+    def keys(self):
+        return self.input_keys
+
+    def __len__(self):
+        return len(self.input_keys)
+
+    def __iter__(self):
+        self.n = 0
+        return self
+
+    def __next__(self):
+        if self.n < len(self.input_keys):
+            output = self.input_dict[self.input_keys[self.n]]
+            self.n += 1
+            return output
+
+        raise StopIteration
+
+
 def get_numpy(tensor_proto):
     """Grab data in TensorProto and convert to numpy array."""
     try:
@@ -664,13 +713,24 @@ class Sum(OnnxOpConverter):
         return inputs[len(inputs) - 1]
 
 
+class Affine(OnnxOpConverter):
+    """ Operator converter for Affine transformation.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = _expr.const(attr.get('alpha', 1.0))
+        beta = _expr.const(attr.get('beta', 0.0))
+        return (alpha * inputs[0]) + beta
+
+
 class ThresholdedRelu(OnnxOpConverter):
     """ Operator converter for ThresholdedRelu.
     """
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 0.0))
+        alpha = float(attr.get('alpha', 1.0))
         alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
         mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
         return inputs[0] * mask
@@ -893,7 +953,7 @@ class Maximum(OnnxOpConverter):
     """
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
             raise ValueError("Expect minimum 2 inputs")
         _max = inputs[0]
         for i in range(1, len(inputs)):
@@ -905,7 +965,7 @@ class Minimum(OnnxOpConverter):
     """
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
             raise ValueError("Expect minimum 2 inputs")
         _min = inputs[0]
         for i in range(1, len(inputs)):
@@ -917,7 +977,7 @@ class Mean(OnnxOpConverter):
     """
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
             raise ValueError("Expect minimum 2 inputs")
         # avoid overflow
         concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
@@ -1190,6 +1250,151 @@ class Expand(OnnxOpConverter):
         return _op.broadcast_to(inputs[0], shape=tuple(shape))
 
 
+class LSTM(OnnxOpConverter):
+    """ Operator converter for LSTM.
+    """
+
+    @classmethod
+    def _activation_helper(cls, activation, alpha, beta):
+        convert_map = _get_convert_map(1)
+        attrs = {}
+        if alpha is not None:
+            attrs['alpha'] = alpha
+        if beta is not None:
+            attrs['beta'] = beta
+        return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {})
+
+    @classmethod
+    def _activation_needs_alpha(cls, activation):
+        needs_alpha = [
+            "Affine",
+            "LeakyRelu",
+            "ThresholdedRelu",
+            "ScaledTanh",
+            "HardSigmoid",
+            "Elu",
+        ]
+        return activation.decode("utf-8") in needs_alpha
+
+    @classmethod
+    def _activation_needs_beta(cls, activation):
+        needs_beta = [
+            "Affine",
+            "ScaledTanh",
+            "HardSigmoid",
+        ]
+        return activation.decode("utf-8") in needs_beta
+
+    @classmethod
+    def _impl_v7(cls, inputs, attr, params):
+        # Unpack inputs; note that any optional input not provided will be None.
+        X = inputs[0]
+        W = inputs[1]
+        R = inputs[2]
+        B = inputs['B']
+        # Sequence length currently unused as it can be inferred from shapes.
+        #sequence_lens = inputs['sequence_lens']
+        h_0 = inputs['initial_h']
+        c_0 = inputs['initial_c']
+        P = inputs['P']
+
+        num_directions = infer_shape(W)[0]
+        W_dtype = infer_type(W).type_annotation.dtype
+
+        if num_directions != 1:
+            raise NotImplementedError("Bidirectional LSTMs not yet supported.")
+        # Remove num_directions axis from weights.
+        W = _op.squeeze(W, axis=[0])
+        R = _op.squeeze(R, axis=[0])
+        if B is not None:
+            B = _op.squeeze(B, axis=[0])
+
+        X_shape = infer_shape(X)
+        hidden_size = infer_shape(R)[-1]
+        batch_size = X_shape[1]
+
+        # Initialize state if not provided.
+        # Otherwise remove bidirectional axis.
+        if h_0 is None:
+            h_0 = _op.zeros((batch_size, hidden_size), W_dtype)
+        else:
+            h_0 = _op.squeeze(h_0, axis=[0])
+        if c_0 is None:
+            c_0 = _op.zeros((batch_size, hidden_size), W_dtype)
+        else:
+            c_0 = _op.squeeze(c_0, axis=[0])
+
+        if P is not None:
+            P = _op.squeeze(P, axis=[0])
+            p_i, p_o, p_f = _op.split(P, 3)
+        H_t = h_0
+        C_t = c_0
+        h_list = []
+
+        if 'activations' in attr:
+            activations = attr['activations']
+            if len(activations) != 3:
+                raise NotImplementedError("LSTM assumes 3 activation functions are provided")
+            alpha_loc = 0
+            alphas = attr.get('activation_alpha', [])
+            if isinstance(alphas, float):
+                alphas = [alphas]
+            beta_loc = 0
+            betas = attr.get('activation_beta', [])
+            if isinstance(betas, float):
+                betas = [betas]
+            acts = []
+            for i in range(3):
+                alpha = None
+                beta = None
+                activation = activations[i]
+                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
+                    alpha = alphas[alpha_loc]
+                    alpha_loc += 1
+                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
+                    beta = betas[beta_loc]
+                    beta_loc += 1
+                acts.append(cls._activation_helper(activation, alpha, beta))
+            f_act, g_act, h_act = acts
+        else:
+            f_act = _op.sigmoid
+            g_act = _op.tanh
+            h_act = _op.tanh
+
+        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
+        for step in X_steps:
+            step = _op.squeeze(step, axis=[0])
+            gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
+            if B is not None:
+                WB, RB = _op.split(B, 2)
+                gates += WB + RB
+            i, o, f, c = _op.split(gates, 4, axis=-1)
+            if P is not None:
+                i = f_act(i + p_i * C_t)
+                f = f_act(f + p_f * C_t)
+
+            else:
+                i = f_act(i)
+                f = f_act(f)
+            c = g_act(c)
+            C = f * C_t + i * c
+            if P is not None:
+                o = f_act(o + p_o * C)
+            else:
+                o = f_act(o)
+            H = o * h_act(C)
+            H_t = H
+            C_t = C
+            h_list.append(_op.expand_dims(H, axis=0))
+        # Concatenate outputs and add back in direction axis.
+        concatenated = _op.concatenate(h_list, 0)
+        output = _op.expand_dims(concatenated, axis=1)
+        H_t = _op.expand_dims(H_t, axis=0)
+        C_t = _op.expand_dims(C_t, axis=0)
+
+        return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1203,7 +1408,7 @@ def _get_convert_map(opset):
     return {
         # defs/experimental
         'Identity': Renamer('copy'),
-        # 'Affine'
+        'Affine': Affine.get_converter(opset),
         'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
         'ScaledTanh': ScaledTanh.get_converter(opset),
         'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
@@ -1281,6 +1486,8 @@ def _get_convert_map(opset):
         'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
         'Flatten': Flatten.get_converter(opset),
         'LRN': LRN.get_converter(opset),
+        # Recurrent Layers
+        'LSTM': LSTM.get_converter(opset),
 
         # defs/reduction
         'ReduceMax': ReduceMax.get_converter(opset),
@@ -1414,7 +1621,11 @@ class GraphProto(object):
         for node in graph.node:
             op_name = node.op_type
             attr = self._parse_attr(node.attribute)
-            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            # Create and populate onnx input object.
+            inputs = onnx_input()
+            for i in node.input:
+                if i != '':
+                    inputs[i] = self._nodes[self._renames.get(i, i)]
             if op_name == "Constant":
                 t_proto = self._parse_attr(node.attribute)["value"]
                 self._num_param += 1
index f9e8aac..ef96c11 100644 (file)
@@ -56,6 +56,12 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
     # set inputs
     if isinstance(input_data, list):
         for i, e in enumerate(input_names):
+            # It's possible for some onnx inputs to not be needed in the tvm
+            # module, so confirm it's present before setting.
+            try:
+                m.get_input(input_names[i])
+            except:
+                continue
             m.set_input(input_names[i], tvm.nd.array(
                 input_data[i].astype(input_data[i].dtype)))
     else:
@@ -1962,6 +1968,175 @@ def test_pooling():
                        auto_pad='SAME_UPPER')
 
 
+def verify_lstm(seq_length,
+                batch_size,
+                input_size,
+                hidden_size,
+                use_bias=False,
+                activations=None,
+                alphas=None,
+                betas=None,
+                use_initial_state=False,
+                use_peep=False):
+    x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype('float32')
+    w_np = np.random.uniform(size=(1, 4 * hidden_size, input_size)).astype('float32')
+    r_np = np.random.uniform(size=(1, 4 * hidden_size, hidden_size)).astype('float32')
+    input_names = ["X", "W", "R"]
+    input_tensors = [
+        helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_np.shape)),
+        helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_np.shape)),
+        helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape))
+    ]
+    input_values = [x_np, w_np, r_np]
+
+    if use_bias:
+        b_np = np.random.uniform(size=(1, 8 * hidden_size)).astype('float32')
+        input_names.append("B")
+        input_tensors.append(
+            helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 8 * hidden_size]))
+        input_values.append(b_np)
+
+    if use_initial_state:
+        assert use_bias == True, "Initial states must have bias specified."
+        sequence_np = np.repeat(seq_length, batch_size).astype('int32')
+        input_names.append("sequence_lens")
+        input_tensors.append(helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch_size]))
+        input_values.append(sequence_np)
+
+        initial_h_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype('float32')
+        input_names.append("initial_h")
+        input_tensors.append(
+            helper.make_tensor_value_info("initial_h", TensorProto.FLOAT,
+                                          [1, batch_size, hidden_size]))
+        input_values.append(initial_h_np)
+
+        initial_c_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype('float32')
+        input_names.append("initial_c")
+        input_tensors.append(
+            helper.make_tensor_value_info("initial_c", TensorProto.FLOAT,
+                                          [1, batch_size, hidden_size]))
+        input_values.append(initial_c_np)
+
+    if use_peep:
+        assert use_initial_state == True, "Peepholes require initial state to be specified."
+        p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype('float32')
+        input_names.append("P")
+        input_tensors.append(
+            helper.make_tensor_value_info("P", TensorProto.FLOAT, [1, 3 * hidden_size]))
+        input_values.append(p_np)
+
+    Y_shape = [seq_length, 1, batch_size, hidden_size]
+    Y_h_shape = [1, batch_size, hidden_size]
+    Y_c_shape = [1, batch_size, hidden_size]
+
+    if activations is None:
+        lstm_node = helper.make_node(
+            'LSTM', inputs=input_names, outputs=["Y", "Y_h", "Y_c"], hidden_size=hidden_size)
+    elif alphas is None:
+        lstm_node = helper.make_node(
+            'LSTM',
+            inputs=input_names,
+            outputs=["Y", "Y_h", "Y_c"],
+            hidden_size=hidden_size,
+            activations=activations)
+    else:
+        lstm_node = helper.make_node(
+            'LSTM',
+            inputs=input_names,
+            outputs=["Y", "Y_h", "Y_c"],
+            hidden_size=hidden_size,
+            activations=activations,
+            activation_alpha=alphas,
+            activation_beta=betas)
+
+    graph = helper.make_graph([lstm_node],
+                              "lstm_test",
+                              inputs=input_tensors,
+                              outputs=[
+                                  helper.make_tensor_value_info("Y", TensorProto.FLOAT,
+                                                                list(Y_shape)),
+                                  helper.make_tensor_value_info("Y_h", TensorProto.FLOAT,
+                                                                list(Y_h_shape)),
+                                  helper.make_tensor_value_info("Y_c", TensorProto.FLOAT,
+                                                                list(Y_c_shape))
+                              ])
+
+    model = helper.make_model(graph, producer_name='lstm_test')
+
+    for target, ctx in ctx_list():
+        onnx_out = get_onnxruntime_output(model, input_values, 'float32')
+        tvm_out = get_tvm_output(
+            model,
+            input_values,
+            target,
+            ctx, [Y_shape, Y_h_shape, Y_c_shape],
+            output_dtype=['float32', 'float32', 'float32'])
+        for o_out, t_out in zip(onnx_out, tvm_out):
+            tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3)
+
+
+def test_lstm():
+    # No bias.
+    verify_lstm(seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False)
+    # large batch.
+    verify_lstm(seq_length=4, batch_size=8, input_size=16, hidden_size=32, use_bias=True)
+    # Non power of two.
+    verify_lstm(seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True)
+    # Long sequence.
+    verify_lstm(seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True)
+    # Large hidden.
+    verify_lstm(seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True)
+    # Large input.
+    verify_lstm(seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True)
+
+    # Different activation testing.
+    # Default value hardsigmoid.
+    verify_lstm(
+        seq_length=2,
+        batch_size=1,
+        input_size=16,
+        hidden_size=32,
+        use_bias=False,
+        activations=['HardSigmoid', 'Tanh', 'Tanh'])
+    # Multiple parameterized activations.
+    verify_lstm(
+        seq_length=2,
+        batch_size=1,
+        input_size=16,
+        hidden_size=32,
+        use_bias=False,
+        activations=['HardSigmoid', 'LeakyRelu', 'Tanh'],
+        alphas=[2.0, 0.5],
+        betas=[.3])
+    # All parameterized with new Affine activation.
+    verify_lstm(
+        seq_length=2,
+        batch_size=1,
+        input_size=16,
+        hidden_size=32,
+        use_bias=False,
+        activations=['HardSigmoid', 'LeakyRelu', 'Affine'],
+        alphas=[2.0, 0.5, 0.8],
+        betas=[.3, 0.1])
+
+    # Testing with initial state and peepholes
+    verify_lstm(
+        seq_length=2,
+        batch_size=1,
+        input_size=16,
+        hidden_size=32,
+        use_bias=True,
+        use_initial_state=True)
+    verify_lstm(
+        seq_length=2,
+        batch_size=1,
+        input_size=16,
+        hidden_size=32,
+        use_bias=True,
+        use_initial_state=True,
+        use_peep=True)
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -2020,3 +2195,4 @@ if __name__ == '__main__':
     test_convtranspose()
     test_unsqueeze_constant()
     test_pooling()
+    test_lstm()