Replace trivial backend calls with calls to underlying TensorFlow functions - Part 2
author    Pavithra Vijay <psv@google.com>
          Wed, 4 Apr 2018 23:53:51 +0000 (16:53 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Wed, 4 Apr 2018 23:56:14 +0000 (16:56 -0700)
PiperOrigin-RevId: 191669725

16 files changed:
tensorflow/python/keras/_impl/keras/activations.py
tensorflow/python/keras/_impl/keras/constraints.py
tensorflow/python/keras/_impl/keras/engine/training_utils.py
tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
tensorflow/python/keras/_impl/keras/layers/core.py
tensorflow/python/keras/_impl/keras/layers/core_test.py
tensorflow/python/keras/_impl/keras/layers/embeddings.py
tensorflow/python/keras/_impl/keras/layers/merge.py
tensorflow/python/keras/_impl/keras/layers/noise.py
tensorflow/python/keras/_impl/keras/layers/recurrent.py
tensorflow/python/keras/_impl/keras/layers/wrappers.py
tensorflow/python/keras/_impl/keras/losses.py
tensorflow/python/keras/_impl/keras/metrics.py
tensorflow/python/keras/_impl/keras/metrics_test.py
tensorflow/python/keras/_impl/keras/optimizers.py
tensorflow/python/keras/_impl/keras/regularizers.py
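
The change applied across all sixteen files is the same mechanical substitution: a trivial Keras backend wrapper (K.foo) is replaced by a direct call to the underlying TensorFlow op (math_ops.foo, nn.foo, or array_ops.foo), with the matching import added where needed. A minimal sketch of the pattern, assuming the TF 1.x internal module layout used in these files (illustrative only, not part of the commit):

# Illustrative sketch only; not part of the commit. Assumes the TensorFlow 1.x
# internal module layout imported by the files listed above.
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.ops import math_ops

x = K.constant([1.0, -2.0, 3.0])
before = K.square(x)         # thin Keras backend wrapper
after = math_ops.square(x)   # underlying TensorFlow op, called directly
# Both calls produce the same tensor; only the call site changes.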

index 74ec373..b518898 100644
@@ -24,6 +24,7 @@ from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.layers.base import Layer
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import tf_export
 
@@ -44,9 +45,9 @@ def softmax(x, axis=-1):
   """
   ndim = K.ndim(x)
   if ndim == 2:
-    return K.softmax(x)
+    return nn.softmax(x)
   elif ndim > 2:
-    e = K.exp(x - K.max(x, axis=axis, keepdims=True))
+    e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
     s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
     return e / s
   else:
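
For reference, the ndim > 2 branch above is the standard numerically stable softmax over an arbitrary axis: subtracting the per-axis maximum before exponentiating leaves the result unchanged but keeps the exponential from overflowing. A NumPy sketch of the same computation (illustrative only):

# Illustrative only; mirrors the ndim > 2 branch of softmax above.
import numpy as np

def softmax_along_axis(x, axis=-1):
  # Subtracting the max cancels in the final ratio but avoids overflow in exp.
  e = np.exp(x - np.max(x, axis=axis, keepdims=True))
  return e / np.sum(e, axis=axis, keepdims=True)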
@@ -80,12 +81,12 @@ def selu(x):
 
 @tf_export('keras.activations.softplus')
 def softplus(x):
-  return K.softplus(x)
+  return nn.softplus(x)
 
 
 @tf_export('keras.activations.softsign')
 def softsign(x):
-  return K.softsign(x)
+  return nn.softsign(x)
 
 
 @tf_export('keras.activations.relu')
@@ -95,12 +96,12 @@ def relu(x, alpha=0., max_value=None):
 
 @tf_export('keras.activations.tanh')
 def tanh(x):
-  return K.tanh(x)
+  return nn.tanh(x)
 
 
 @tf_export('keras.activations.sigmoid')
 def sigmoid(x):
-  return K.sigmoid(x)
+  return nn.sigmoid(x)
 
 
 @tf_export('keras.activations.hard_sigmoid')
index aac4d0f..abe95d8 100644
@@ -67,7 +67,7 @@ class MaxNorm(Constraint):
 
   def __call__(self, w):
     norms = K.sqrt(
-        math_ops.reduce_sum(K.square(w), axis=self.axis, keepdims=True))
+        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
     desired = K.clip(norms, 0, self.max_value)
     return w * (desired / (K.epsilon() + norms))
 
@@ -81,7 +81,7 @@ class NonNeg(Constraint):
   """
 
   def __call__(self, w):
-    return w * math_ops.cast(K.greater_equal(w, 0.), K.floatx())
+    return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())
 
 
 @tf_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
@@ -108,7 +108,8 @@ class UnitNorm(Constraint):
   def __call__(self, w):
     return w / (
         K.epsilon() + K.sqrt(
-            math_ops.reduce_sum(K.square(w), axis=self.axis, keepdims=True)))
+            math_ops.reduce_sum(
+                math_ops.square(w), axis=self.axis, keepdims=True)))
 
   def get_config(self):
     return {'axis': self.axis}
@@ -152,7 +153,7 @@ class MinMaxNorm(Constraint):
 
   def __call__(self, w):
     norms = K.sqrt(
-        math_ops.reduce_sum(K.square(w), axis=self.axis, keepdims=True))
+        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
     desired = (
         self.rate * K.clip(norms, self.min_value, self.max_value) +
         (1 - self.rate) * norms)
index 58d2c78..a3fc8ef 100644
@@ -451,7 +451,8 @@ def weighted_masked_objective(fn):
       weight_ndim = K.ndim(weights)
       score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
       score_array *= weights
-      score_array /= K.mean(math_ops.cast(K.not_equal(weights, 0), K.floatx()))
+      score_array /= K.mean(
+          math_ops.cast(math_ops.not_equal(weights, 0), K.floatx()))
     return K.mean(score_array)
 
   return weighted
index 45b0c6c..11ca89d 100644
@@ -147,7 +147,7 @@ class PReLU(Layer):
     if K.backend() == 'theano':
       neg = (
           K.pattern_broadcast(self.alpha, self.param_broadcast) *
-          (inputs - K.abs(inputs)) * 0.5)
+          (inputs - math_ops.abs(inputs)) * 0.5)
     else:
       neg = -self.alpha * K.relu(-inputs)
     return pos + neg
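
Note on the Theano branch above: (inputs - abs(inputs)) * 0.5 equals min(inputs, 0), so neg is alpha times the negative part of the input; added to the positive part computed earlier in call(), this gives the usual PReLU output. A quick check of the identity (illustrative only):

# Illustrative check of the identity used in the Theano branch; not part of the commit.
import numpy as np

x = np.array([-3.0, -0.5, 0.0, 2.0])
assert np.allclose((x - np.abs(x)) * 0.5, np.minimum(x, 0.0))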
@@ -233,7 +233,8 @@ class ThresholdedReLU(Layer):
     self.theta = K.cast_to_floatx(theta)
 
   def call(self, inputs, mask=None):
-    return inputs * math_ops.cast(K.greater(inputs, self.theta), K.floatx())
+    return inputs * math_ops.cast(
+        math_ops.greater(inputs, self.theta), K.floatx())
 
   def get_config(self):
     config = {'theta': float(self.theta)}
index a709a07..c74fc1e 100644
@@ -77,11 +77,11 @@ class Masking(Layer):
     self.mask_value = mask_value
 
   def compute_mask(self, inputs, mask=None):
-    return K.any(K.not_equal(inputs, self.mask_value), axis=-1)
+    return K.any(math_ops.not_equal(inputs, self.mask_value), axis=-1)
 
   def call(self, inputs):
     boolean_mask = K.any(
-        K.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
+        math_ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
     return inputs * math_ops.cast(boolean_mask, inputs.dtype)
 
   def compute_output_shape(self, input_shape):
@@ -416,7 +416,8 @@ class Reshape(Layer):
     return tensor_shape.TensorShape(output_shape)
 
   def call(self, inputs):
-    return K.reshape(inputs, (array_ops.shape(inputs)[0],) + self.target_shape)
+    return array_ops.reshape(inputs,
+                             (array_ops.shape(inputs)[0],) + self.target_shape)
 
   def get_config(self):
     config = {'target_shape': self.target_shape}
@@ -469,7 +470,7 @@ class Permute(Layer):
     return tensor_shape.TensorShape(output_shape)
 
   def call(self, inputs):
-    return K.permute_dimensions(inputs, (0,) + self.dims)
+    return array_ops.transpose(inputs, perm=(0,) + self.dims)
 
   def get_config(self):
     config = {'dims': self.dims}
index 2ca816a..551d1b1 100644
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
@@ -159,7 +160,7 @@ class CoreLayersTest(test.TestCase):
 
     # test with lambda
     ld = keras.layers.Lambda(
-        lambda x: keras.backend.concatenate([keras.backend.square(x), x]))
+        lambda x: keras.backend.concatenate([math_ops.square(x), x]))
     config = ld.get_config()
     ld = keras.layers.Lambda.from_config(config)
 
@@ -235,4 +236,3 @@ class CoreLayersTest(test.TestCase):
 
 if __name__ == '__main__':
   test.main()
-
index a0fd7a9..540e2d9 100644
@@ -128,7 +128,7 @@ class Embedding(Layer):
     if not self.mask_zero:
       return None
     else:
-      return K.not_equal(inputs, 0)
+      return math_ops.not_equal(inputs, 0)
 
   @shape_type_conversion
   def compute_output_shape(self, input_shape):
index 6290db2..7c87e6c 100644
@@ -24,6 +24,8 @@ from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
 from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -128,7 +130,7 @@ class _Merge(Layer):
         for x in inputs:
           x_ndim = K.ndim(x)
           for _ in range(max_ndim - x_ndim):
-            x = K.expand_dims(x, 1)
+            x = array_ops.expand_dims(x, axis=1)
           reshaped_inputs.append(x)
         return self._merge_function(reshaped_inputs)
       else:
@@ -140,17 +142,20 @@ class _Merge(Layer):
           if x_ndim is None:
             x_shape = array_ops.shape(x)
             batch_size = x_shape[0]
-            new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
-            x_transposed = K.reshape(x,
-                                     K.stack([batch_size,
-                                              K.prod(x_shape[1:])]))
-            x_transposed = K.permute_dimensions(x_transposed, (1, 0))
-            x_transposed = K.reshape(x_transposed, new_shape)
+            new_shape = K.concatenate(
+                [x_shape[1:],
+                 array_ops.expand_dims(batch_size, axis=-1)])
+            x_transposed = array_ops.reshape(
+                x,
+                array_ops.stack(
+                    [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
+            x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
+            x_transposed = array_ops.reshape(x_transposed, new_shape)
             reshaped_inputs.append(x_transposed)
             transposed = True
           elif x_ndim > 1:
             dims = list(range(1, x_ndim)) + [0]
-            reshaped_inputs.append(K.permute_dimensions(x, dims))
+            reshaped_inputs.append(array_ops.transpose(x, perm=dims))
             transposed = True
           else:
             # We don't transpose inputs if they are 1D vectors or scalars.
@@ -163,14 +168,15 @@ class _Merge(Layer):
             y_shape = array_ops.shape(y)
             y_ndim = array_ops.shape(y_shape)[0]
             batch_size = y_shape[y_ndim - 1]
-            new_shape = K.concatenate(
-                [K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
-            y = K.reshape(y, (-1, batch_size))
-            y = K.permute_dimensions(y, (1, 0))
-            y = K.reshape(y, new_shape)
+            new_shape = K.concatenate([
+                array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
+            ])
+            y = array_ops.reshape(y, (-1, batch_size))
+            y = array_ops.transpose(y, perm=(1, 0))
+            y = array_ops.reshape(y, new_shape)
           elif y_ndim > 1:
             dims = [y_ndim - 1] + list(range(y_ndim - 1))
-            y = K.permute_dimensions(y, dims)
+            y = array_ops.transpose(y, perm=dims)
         return y
     else:
       return self._merge_function(inputs)
@@ -208,7 +214,7 @@ class _Merge(Layer):
                        'should have the same length.')
     if all([m is None for m in mask]):
       return None
-    masks = [K.expand_dims(m, 0) for m in mask if m is not None]
+    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
     return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
 
 
@@ -326,7 +332,7 @@ class Maximum(_Merge):
   def _merge_function(self, inputs):
     output = inputs[0]
     for i in range(1, len(inputs)):
-      output = K.maximum(output, inputs[i])
+      output = math_ops.maximum(output, inputs[i])
     return output
 
 
@@ -341,7 +347,7 @@ class Minimum(_Merge):
   def _merge_function(self, inputs):
     output = inputs[0]
     for i in range(1, len(inputs)):
-      output = K.minimum(output, inputs[i])
+      output = math_ops.minimum(output, inputs[i])
     return output
 
 
@@ -422,7 +428,7 @@ class Concatenate(_Merge):
         masks.append(array_ops.ones_like(input_i, dtype='bool'))
       elif K.ndim(mask_i) < K.ndim(input_i):
         # Mask is smaller than the input, expand it
-        masks.append(K.expand_dims(mask_i))
+        masks.append(array_ops.expand_dims(mask_i, axis=-1))
       else:
         masks.append(mask_i)
     concatenated = K.concatenate(masks, axis=self.axis)
@@ -512,8 +518,8 @@ class Dot(_Merge):
         else:
           axes.append(self.axes[i])
     if self.normalize:
-      x1 = K.l2_normalize(x1, axis=axes[0])
-      x2 = K.l2_normalize(x2, axis=axes[1])
+      x1 = nn.l2_normalize(x1, axis=axes[0])
+      x2 = nn.l2_normalize(x2, axis=axes[1])
     output = K.batch_dot(x1, x2, axes)
     return output
 
index 4366b65..72dc7a1 100644
@@ -166,7 +166,7 @@ class AlphaDropout(Layer):
         scale = 1.0507009873554804934193349852946
         alpha_p = -alpha * scale
 
-        kept_idx = K.greater_equal(
+        kept_idx = math_ops.greater_equal(
             K.random_uniform(noise_shape, seed=seed), rate)
         kept_idx = math_ops.cast(kept_idx, K.floatx())
 
index bd7c42e..7f9f77c 100644
@@ -510,7 +510,8 @@ class RNN(Layer):
     # shape of initial_state = (samples, timesteps, input_dim)
     initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2))
     # shape of initial_state = (samples,)
-    initial_state = K.expand_dims(initial_state)  # (samples, 1)
+    initial_state = array_ops.expand_dims(initial_state, axis=-1)
+    # shape of initial_state = (samples, 1)
     if hasattr(self.cell.state_size, '__len__'):
       return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size]
     else:
@@ -2357,7 +2358,8 @@ class Recurrent(Layer):
     # shape of initial_state = (samples, timesteps, input_dim)
     initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2))
     # shape of initial_state = (samples,)
-    initial_state = K.expand_dims(initial_state)  # (samples, 1)
+    initial_state = array_ops.expand_dims(initial_state, axis=-1)
+    # shape of initial_state = (samples, 1)
     initial_state = K.tile(initial_state, [1,
                                            self.units])  # (samples, output_dim)
     initial_state = [initial_state for _ in range(len(self.states))]
index 12f3361..c510e46 100644
@@ -214,7 +214,7 @@ class TimeDistributed(Wrapper):
       # Shape: (num_samples * timesteps, ...). And track the
       # transformation in self._input_map.
       input_uid = tf_layers_util.object_list_uid(inputs)
-      inputs = K.reshape(inputs, (-1,) + input_shape[2:])
+      inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
       self._input_map[input_uid] = inputs
       # (num_samples * timesteps, ...)
       y = self.layer.call(inputs, **kwargs)
@@ -222,7 +222,7 @@ class TimeDistributed(Wrapper):
         uses_learning_phase = y._uses_learning_phase
       # Shape: (num_samples, timesteps, ...)
       output_shape = self.compute_output_shape(input_shape).as_list()
-      y = K.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
+      y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
 
     # Apply activity regularizer if any:
     if (hasattr(self.layer, 'activity_regularizer') and
index 859bda0..1d634d3 100644
@@ -25,51 +25,54 @@ from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.util.tf_export import tf_export
 
 
 @tf_export('keras.metrics.mean_squared_error',
            'keras.losses.mean_squared_error')
 def mean_squared_error(y_true, y_pred):
-  return K.mean(K.square(y_pred - y_true), axis=-1)
+  return K.mean(math_ops.square(y_pred - y_true), axis=-1)
 
 
 @tf_export('keras.metrics.mean_absolute_error',
            'keras.losses.mean_absolute_error')
 def mean_absolute_error(y_true, y_pred):
-  return K.mean(K.abs(y_pred - y_true), axis=-1)
+  return K.mean(math_ops.abs(y_pred - y_true), axis=-1)
 
 
 @tf_export('keras.metrics.mean_absolute_percentage_error',
            'keras.losses.mean_absolute_percentage_error')
 def mean_absolute_percentage_error(y_true, y_pred):
-  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true), K.epsilon(), None))
+  diff = math_ops.abs(
+      (y_true - y_pred) / K.clip(math_ops.abs(y_true), K.epsilon(), None))
   return 100. * K.mean(diff, axis=-1)
 
 
 @tf_export('keras.metrics.mean_squared_logarithmic_error',
            'keras.losses.mean_squared_logarithmic_error')
 def mean_squared_logarithmic_error(y_true, y_pred):
-  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
-  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
-  return K.mean(K.square(first_log - second_log), axis=-1)
+  first_log = math_ops.log(K.clip(y_pred, K.epsilon(), None) + 1.)
+  second_log = math_ops.log(K.clip(y_true, K.epsilon(), None) + 1.)
+  return K.mean(math_ops.square(first_log - second_log), axis=-1)
 
 
 @tf_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
 def squared_hinge(y_true, y_pred):
-  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)
+  return K.mean(
+      math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1)
 
 
 @tf_export('keras.metrics.hinge', 'keras.losses.hinge')
 def hinge(y_true, y_pred):
-  return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1)
+  return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1)
 
 
 @tf_export('keras.losses.categorical_hinge')
 def categorical_hinge(y_true, y_pred):
   pos = math_ops.reduce_sum(y_true * y_pred, axis=-1)
-  neg = K.max((1. - y_true) * y_pred, axis=-1)
-  return K.maximum(0., neg - pos + 1.)
+  neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1)
+  return math_ops.maximum(0., neg - pos + 1.)
 
 
 @tf_export('keras.losses.logcosh')
@@ -90,7 +93,7 @@ def logcosh(y_true, y_pred):
   """
 
   def _logcosh(x):
-    return x + K.softplus(-2. * x) - K.log(2.)
+    return x + nn.softplus(-2. * x) - math_ops.log(2.)
 
   return K.mean(_logcosh(y_pred - y_true), axis=-1)
 
@@ -118,18 +121,18 @@ def binary_crossentropy(y_true, y_pred):
 def kullback_leibler_divergence(y_true, y_pred):
   y_true = K.clip(y_true, K.epsilon(), 1)
   y_pred = K.clip(y_pred, K.epsilon(), 1)
-  return math_ops.reduce_sum(y_true * K.log(y_true / y_pred), axis=-1)
+  return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1)
 
 
 @tf_export('keras.metrics.poisson', 'keras.losses.poisson')
 def poisson(y_true, y_pred):
-  return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)
+  return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1)
 
 
 @tf_export('keras.metrics.cosine_proximity', 'keras.losses.cosine_proximity')
 def cosine_proximity(y_true, y_pred):
-  y_true = K.l2_normalize(y_true, axis=-1)
-  y_pred = K.l2_normalize(y_pred, axis=-1)
+  y_true = nn.l2_normalize(y_true, axis=-1)
+  y_pred = nn.l2_normalize(y_pred, axis=-1)
   return -math_ops.reduce_sum(y_true * y_pred, axis=-1)
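
cosine_proximity above is the negative cosine similarity: both tensors are l2-normalized along the last axis and the elementwise products are summed, so identical directions give -1 and orthogonal ones give 0. A NumPy sketch of the same quantity (illustrative only; eps here is a stand-in for the small epsilon nn.l2_normalize uses internally):

# Illustrative only; mirrors cosine_proximity above with NumPy.
import numpy as np

def cosine_proximity_np(y_true, y_pred, eps=1e-12):
  # eps is an illustrative stand-in for l2_normalize's internal epsilon.
  y_true = y_true / (np.linalg.norm(y_true, axis=-1, keepdims=True) + eps)
  y_pred = y_pred / (np.linalg.norm(y_pred, axis=-1, keepdims=True) + eps)
  return -np.sum(y_true * y_pred, axis=-1)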
 
 
index 24192cf..747c3e6 100644
@@ -38,39 +38,45 @@ from tensorflow.python.keras._impl.keras.losses import squared_hinge
 from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.util.tf_export import tf_export
 
 
 @tf_export('keras.metrics.binary_accuracy')
 def binary_accuracy(y_true, y_pred):
-  return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
+  return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1)
 
 
 @tf_export('keras.metrics.categorical_accuracy')
 def categorical_accuracy(y_true, y_pred):
   return math_ops.cast(
-      K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())
+      math_ops.equal(
+          math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
+      K.floatx())
 
 
 def sparse_categorical_accuracy(y_true, y_pred):
   return math_ops.cast(
-      K.equal(
-          K.max(y_true, axis=-1),
-          math_ops.cast(K.argmax(y_pred, axis=-1), K.floatx())), K.floatx())
+      math_ops.equal(
+          math_ops.reduce_max(y_true, axis=-1),
+          math_ops.cast(math_ops.argmax(y_pred, axis=-1), K.floatx())),
+      K.floatx())
 
 
 @tf_export('keras.metrics.top_k_categorical_accuracy')
 def top_k_categorical_accuracy(y_true, y_pred, k=5):
-  return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
+  return K.mean(
+      nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), axis=-1)
 
 
 @tf_export('keras.metrics.sparse_top_k_categorical_accuracy')
 def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
   return K.mean(
-      K.in_top_k(y_pred, math_ops.cast(K.max(y_true, axis=-1), 'int32'), k),
+      nn.in_top_k(y_pred,
+                  math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'),
+                  k),
       axis=-1)
 
-
 # Aliases
 
 mse = MSE = mean_squared_error
index 2b73e0c..9deaab0 100644
@@ -107,9 +107,8 @@ class KerasMetricsTest(test.TestCase):
                 completion of the batch.
         """
         y_true = math_ops.cast(y_true, 'int32')
-        y_pred = math_ops.cast(keras.backend.round(y_pred), 'int32')
-        correct_preds = math_ops.cast(
-            keras.backend.equal(y_pred, y_true), 'int32')
+        y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
+        correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32')
         true_pos = math_ops.cast(
             math_ops.reduce_sum(correct_preds * y_true), 'int32')
         current_true_pos = self.true_positives * 1
index dc0e472..9f383de 100644
@@ -119,7 +119,8 @@ class Optimizer(object):
                        'Common ops without gradient: '
                        'K.argmax, K.round, K.eval.')
     if hasattr(self, 'clipnorm') and self.clipnorm > 0:
-      norm = K.sqrt(sum([math_ops.reduce_sum(K.square(g)) for g in grads]))
+      norm = K.sqrt(
+          sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads]))
       grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
     if hasattr(self, 'clipvalue') and self.clipvalue > 0:
       grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
@@ -288,7 +289,7 @@ class RMSprop(Optimizer):
 
     for p, g, a in zip(params, grads, accumulators):
       # update accumulator
-      new_a = self.rho * a + (1. - self.rho) * K.square(g)
+      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
       self.updates.append(state_ops.assign(a, new_a))
       new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
 
@@ -349,7 +350,7 @@ class Adagrad(Optimizer):
                                                 K.dtype(self.decay))))
 
     for p, g, a in zip(params, grads, accumulators):
-      new_a = a + K.square(g)  # update accumulator
+      new_a = a + math_ops.square(g)  # update accumulator
       self.updates.append(state_ops.assign(a, new_a))
       new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
 
@@ -414,7 +415,7 @@ class Adadelta(Optimizer):
 
     for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
       # update accumulator
-      new_a = self.rho * a + (1. - self.rho) * K.square(g)
+      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
       self.updates.append(state_ops.assign(a, new_a))
 
       # use the new accumulator and the *old* delta_accumulator
@@ -428,7 +429,7 @@ class Adadelta(Optimizer):
       self.updates.append(state_ops.assign(p, new_p))
 
       # update delta_accumulator
-      new_d_a = self.rho * d_a + (1 - self.rho) * K.square(update)
+      new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update)
       self.updates.append(state_ops.assign(d_a, new_d_a))
     return self.updates
 
@@ -494,7 +495,8 @@ class Adam(Optimizer):
 
     t = math_ops.cast(self.iterations, K.floatx()) + 1
     lr_t = lr * (
-        K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))
+        K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
+        (1. - math_ops.pow(self.beta_1, t)))
 
     ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
     vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
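
For reference, the lr_t expression above folds Adam's bias corrections into the step size: with m_hat = m_t / (1 - beta_1**t) and v_hat = v_t / (1 - beta_2**t), the update lr * m_hat / sqrt(v_hat) equals lr_t * m_t / sqrt(v_t), where lr_t = lr * sqrt(1 - beta_2**t) / (1 - beta_1**t), up to where the small epsilon is added. A numeric check (illustrative only, arbitrary sample values):

# Illustrative numeric check of the bias-correction identity; not part of the commit.
import math

lr, beta_1, beta_2, t = 0.001, 0.9, 0.999, 7
m_t, v_t = 0.3, 0.02

m_hat = m_t / (1. - beta_1 ** t)
v_hat = v_t / (1. - beta_2 ** t)
lr_t = lr * math.sqrt(1. - beta_2 ** t) / (1. - beta_1 ** t)

# Both expressions agree (ignoring where the small epsilon is added).
assert abs(lr * m_hat / math.sqrt(v_hat) - lr_t * m_t / math.sqrt(v_t)) < 1e-12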
@@ -506,9 +508,9 @@ class Adam(Optimizer):
 
     for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
       m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
-      v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
+      v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
       if self.amsgrad:
-        vhat_t = K.maximum(vhat, v_t)
+        vhat_t = math_ops.maximum(vhat, v_t)
         p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
         self.updates.append(state_ops.assign(vhat, vhat_t))
       else:
@@ -583,7 +585,7 @@ class Adamax(Optimizer):
                                                 K.dtype(self.decay))))
 
     t = math_ops.cast(self.iterations, K.floatx()) + 1
-    lr_t = lr / (1. - K.pow(self.beta_1, t))
+    lr_t = lr / (1. - math_ops.pow(self.beta_1, t))
 
     shapes = [K.int_shape(p) for p in params]
     # zero init of 1st moment
@@ -595,7 +597,7 @@ class Adamax(Optimizer):
     for p, g, m, u in zip(params, grads, ms, us):
 
       m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
-      u_t = K.maximum(self.beta_2 * u, K.abs(g))
+      u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g))
       p_t = p - lr_t * m_t / (u_t + self.epsilon)
 
       self.updates.append(state_ops.assign(m, m_t))
@@ -666,10 +668,11 @@ class Nadam(Optimizer):
 
     # Due to the recommendations in [2], i.e. warming momentum schedule
     momentum_cache_t = self.beta_1 * (
-        1. - 0.5 * (K.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
+        1. - 0.5 *
+        (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
     momentum_cache_t_1 = self.beta_1 * (
         1. - 0.5 *
-        (K.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
+        (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
     m_schedule_new = self.m_schedule * momentum_cache_t
     m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
     self.updates.append((self.m_schedule, m_schedule_new))
@@ -685,8 +688,8 @@ class Nadam(Optimizer):
       g_prime = g / (1. - m_schedule_new)
       m_t = self.beta_1 * m + (1. - self.beta_1) * g
       m_t_prime = m_t / (1. - m_schedule_next)
-      v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
-      v_t_prime = v_t / (1. - K.pow(self.beta_2, t))
+      v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
+      v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
       m_t_bar = (
           1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
 
index fdb9d33..74c37d3 100644
@@ -56,9 +56,9 @@ class L1L2(Regularizer):
   def __call__(self, x):
     regularization = 0.
     if self.l1:
-      regularization += math_ops.reduce_sum(self.l1 * K.abs(x))
+      regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
     if self.l2:
-      regularization += math_ops.reduce_sum(self.l2 * K.square(x))
+      regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
     return regularization
 
   def get_config(self):