[Frontend][Relay] Keras softmax and prelu fix (#6278)
author Dongming Yang <50566938+domin1985@users.noreply.github.com>
Tue, 25 Aug 2020 00:08:41 +0000 (08:08 +0800)
committer GitHub <noreply@github.com>
Tue, 25 Aug 2020 00:08:41 +0000 (09:08 +0900)
* PReLU and softmax conversion with NHWC layout taken into account

* fix lint

* fix lint

Co-authored-by: Dongming Yang <dongming.yang@streamcomputing.com>
python/tvm/relay/frontend/keras.py
tests/python/frontend/keras/test_forward.py

index 32de471..b469ed0 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -63,7 +63,7 @@ def _convert_recurrent_activation(inexpr, keras_layer):
     return _convert_activation(inexpr, act_type, None)
 
 
-def _convert_activation(inexpr, keras_layer, _):
+def _convert_activation(inexpr, keras_layer, etab):
     if isinstance(keras_layer, str):
         act_type = keras_layer
     else:
@@ -80,7 +80,8 @@ def _convert_activation(inexpr, keras_layer, _):
         beta = _expr.const(beta, dtype='float32')
         return _op.add(_op.multiply(inexpr, alpha), beta)
     if act_type == 'softmax':
-        return _op.nn.softmax(inexpr, axis=1)
+        axis = 1 if etab.data_layout == 'NCHW' else -1
+        return _op.nn.softmax(inexpr, axis)
     if act_type == 'sigmoid':
         return _op.sigmoid(inexpr)
     if act_type == 'tanh':
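
The hunk above picks the softmax axis from the data layout: the Keras softmax activation normalizes over the channel dimension, which is the last axis of an NHWC tensor and axis 1 of an NCHW tensor. A minimal NumPy sketch (hypothetical shapes, independent of Relay) of why both choices normalize the same channels:

    import numpy as np

    def softmax(x, axis):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    x_nhwc = np.random.rand(1, 4, 4, 3).astype("float32")  # N, H, W, C
    x_nchw = x_nhwc.transpose(0, 3, 1, 2)                   # N, C, H, W

    # axis=-1 for NHWC and axis=1 for NCHW both normalize over the 3 channels.
    assert np.allclose(softmax(x_nhwc, -1).sum(-1), 1.0)
    assert np.allclose(softmax(x_nchw, 1).sum(1), 1.0)
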
@@ -123,10 +124,11 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
         if isinstance(axis, list):
             raise tvm.error.OpAttributeUnImplemented(
                 'Softmax with axes {} is not supported.'.format(axis))
-        if axis == -1:
-            axis = 1
-        else:
-            axis = axis + 1 if axis < dims - 1 else 1
+        if etab.data_layout == 'NCHW':
+            if axis == -1:
+                axis = 1
+            else:
+                axis = axis + 1 if axis < dims - 1 else 1
         return _op.nn.softmax(inexpr, axis=axis)
     if act_type == 'ReLU':
         threshold = _expr.const(keras_layer.threshold, dtype='float32')
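
When the converter transposes a Keras model to NCHW, the axis given to a Softmax layer (counted in Keras' NHWC order) has to be remapped: the channel axis (-1, or dims - 1) becomes axis 1, and every other axis shifts by one because the channel dimension is moved in front of it. With an NHWC output layout, which the hunk above now supports, the Keras axis is used unchanged. A stand-alone restatement of that branch, with a hypothetical helper name:

    def remap_softmax_axis(axis, dims=4, data_layout="NCHW"):
        if data_layout != "NCHW":           # NHWC: Keras axis already matches
            return axis
        if axis == -1:
            return 1                        # channel axis moves to position 1
        return axis + 1 if axis < dims - 1 else 1

    # NHWC axes:        N=0  H=1  W=2  C=3
    # after NCHW move:  N=0  C=1  H=2  W=3
    assert remap_softmax_axis(-1) == 1
    assert remap_softmax_axis(3) == 1       # explicit channel axis
    assert remap_softmax_axis(2) == 3       # width shifts by one
    assert remap_softmax_axis(2, data_layout="NHWC") == 2
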
@@ -149,8 +151,11 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
         assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
         _check_data_format(keras_layer)
         size = len(keras_layer.alpha.shape)
-        alpha = etab.new_const(keras_layer.get_weights()[0] \
-                               .transpose(np.roll(range(size), 1)))
+        if etab.data_layout == 'NCHW':
+            alpha = etab.new_const(keras_layer.get_weights()[0]
+                                   .transpose(np.roll(range(size), 1)))
+        else:
+            alpha = etab.new_const(keras_layer.get_weights()[0])
         return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
     if act_type == 'ThresholdedReLU':
         theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
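
The PReLU hunk above only reorders the alpha weights when the graph is converted to NCHW; Keras stores them in the input's own (H, W, C)-style order, so an NHWC graph can use them directly. The Relay expression at the end of the branch is the usual identity prelu(x) = relu(x) - alpha * relu(-x). A small NumPy sketch with made-up shapes:

    import numpy as np

    alpha_hwc = np.random.rand(4, 4, 3)     # hypothetical Keras PReLU alpha, (H, W, C)
    size = alpha_hwc.ndim

    # NCHW graph: roll the channel axis to the front, (H, W, C) -> (C, H, W).
    alpha_chw = alpha_hwc.transpose(np.roll(range(size), 1))
    assert alpha_chw.shape == (3, 4, 4)

    # NHWC graph: no transpose needed; apply prelu(x) = relu(x) - alpha * relu(-x).
    x = np.random.randn(1, 4, 4, 3)
    out = np.maximum(x, 0) - alpha_hwc * np.maximum(-x, 0)
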
index 8ddae96..f940255 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -182,6 +182,7 @@ class TestKeras:
             x = act_func(data)
             keras_model = keras.models.Model(data, x)
             verify_keras_frontend(keras_model)
+            verify_keras_frontend(keras_model, need_transpose=False, layout='NHWC')
 
 
     def test_forward_dense(self, keras):
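
With the test above exercising the NHWC path, an NHWC Keras model can be imported without forcing a transpose to NCHW. A hedged usage sketch (the model, layer choice, and input name lookup are made up for illustration; from_keras accepts a layout argument):

    from tensorflow import keras
    from tvm import relay

    data = keras.layers.Input(shape=(32, 32, 3))
    x = keras.layers.PReLU()(data)
    x = keras.layers.Softmax()(x)
    model = keras.models.Model(data, x)

    # Keep the native NHWC layout instead of transposing to NCHW.
    shape_dict = {model.input_names[0]: (1, 32, 32, 3)}
    mod, params = relay.frontend.from_keras(model, shape_dict, layout="NHWC")
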