From: Pavithra Vijay
Date: Wed, 4 Apr 2018 23:53:51 +0000 (-0700)
Subject: Replace trivial backend calls with calls to underlying TensorFlow functions - Part 2
X-Git-Tag: tflite-v0.1.7~39^2^2~10
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=faeebb7daef6d1fdd0e4eb3a3e0afedcd2d3350d;p=platform%2Fupstream%2Ftensorflow.git

Replace trivial backend calls with calls to underlying TensorFlow functions - Part 2

PiperOrigin-RevId: 191669725
---
diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py
index 74ec373..b518898 100644
--- a/tensorflow/python/keras/_impl/keras/activations.py
+++ b/tensorflow/python/keras/_impl/keras/activations.py
@@ -24,6 +24,7 @@ from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.layers.base import Layer
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import tf_export
@@ -44,9 +45,9 @@ def softmax(x, axis=-1):
   """
   ndim = K.ndim(x)
   if ndim == 2:
-    return K.softmax(x)
+    return nn.softmax(x)
   elif ndim > 2:
-    e = K.exp(x - K.max(x, axis=axis, keepdims=True))
+    e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
     s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
     return e / s
   else:
@@ -80,12 +81,12 @@ def selu(x):

 @tf_export('keras.activations.softplus')
 def softplus(x):
-  return K.softplus(x)
+  return nn.softplus(x)


 @tf_export('keras.activations.softsign')
 def softsign(x):
-  return K.softsign(x)
+  return nn.softsign(x)


 @tf_export('keras.activations.relu')
@@ -95,12 +96,12 @@ def relu(x, alpha=0., max_value=None):

 @tf_export('keras.activations.tanh')
 def tanh(x):
-  return K.tanh(x)
+  return nn.tanh(x)


 @tf_export('keras.activations.sigmoid')
 def sigmoid(x):
-  return K.sigmoid(x)
+  return nn.sigmoid(x)


 @tf_export('keras.activations.hard_sigmoid')
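The ndim > 2 branch of softmax above is the standard numerically stable formulation: subtracting the per-axis maximum before exponentiating prevents overflow, and the shift cancels in the ratio. A minimal NumPy sketch of the same computation, for reference only (not part of the patch):

import numpy as np

def stable_softmax(x, axis=-1):
  # Shift by the max so exp() never sees large positive inputs.
  e = np.exp(x - np.max(x, axis=axis, keepdims=True))
  return e / np.sum(e, axis=axis, keepdims=True)

x = np.array([[1000., 1001.], [3., 4.]])
print(stable_softmax(x))  # finite values; each row sums to 1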
diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py
index aac4d0f..abe95d8 100644
--- a/tensorflow/python/keras/_impl/keras/constraints.py
+++ b/tensorflow/python/keras/_impl/keras/constraints.py
@@ -67,7 +67,7 @@ class MaxNorm(Constraint):

   def __call__(self, w):
     norms = K.sqrt(
-        math_ops.reduce_sum(K.square(w), axis=self.axis, keepdims=True))
+        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
     desired = K.clip(norms, 0, self.max_value)
     return w * (desired / (K.epsilon() + norms))

@@ -81,7 +81,7 @@ class NonNeg(Constraint):
   """

   def __call__(self, w):
-    return w * math_ops.cast(K.greater_equal(w, 0.), K.floatx())
+    return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())


 @tf_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
@@ -108,7 +108,8 @@ class UnitNorm(Constraint):
   def __call__(self, w):
     return w / (
         K.epsilon() + K.sqrt(
-            math_ops.reduce_sum(K.square(w), axis=self.axis, keepdims=True)))
+            math_ops.reduce_sum(
+                math_ops.square(w), axis=self.axis, keepdims=True)))

   def get_config(self):
     return {'axis': self.axis}
@@ -152,7 +153,7 @@ class MinMaxNorm(Constraint):

   def __call__(self, w):
     norms = K.sqrt(
-        math_ops.reduce_sum(K.square(w), axis=self.axis, keepdims=True))
+        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
     desired = (
         self.rate * K.clip(norms, self.min_value, self.max_value) +
         (1 - self.rate) * norms)
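All four constraints above share the same norm computation. As a reference for the MaxNorm case, a NumPy sketch of the rescaling it performs (illustrative only; the example values are made up):

import numpy as np

def max_norm(w, max_value=2., axis=0, epsilon=1e-7):
  # Cap each weight vector's L2 norm along `axis` at `max_value`.
  norms = np.sqrt(np.sum(np.square(w), axis=axis, keepdims=True))
  desired = np.clip(norms, 0, max_value)
  return w * (desired / (epsilon + norms))

w = np.array([[3., 0.1], [4., 0.2]])        # column norms: 5.0 and ~0.22
print(np.linalg.norm(max_norm(w), axis=0))  # ~[2.0, 0.22]: only the large one shrinks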
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
index 58d2c78..a3fc8ef 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
@@ -451,7 +451,8 @@ def weighted_masked_objective(fn):
       weight_ndim = K.ndim(weights)
       score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
       score_array *= weights
-      score_array /= K.mean(math_ops.cast(K.not_equal(weights, 0), K.floatx()))
+      score_array /= K.mean(
+          math_ops.cast(math_ops.not_equal(weights, 0), K.floatx()))
     return K.mean(score_array)

   return weighted
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
index 45b0c6c..11ca89d 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
@@ -147,7 +147,7 @@ class PReLU(Layer):
     if K.backend() == 'theano':
       neg = (
           K.pattern_broadcast(self.alpha, self.param_broadcast) *
-          (inputs - K.abs(inputs)) * 0.5)
+          (inputs - math_ops.abs(inputs)) * 0.5)
     else:
       neg = -self.alpha * K.relu(-inputs)
     return pos + neg
@@ -233,7 +233,8 @@ class ThresholdedReLU(Layer):
     self.theta = K.cast_to_floatx(theta)

   def call(self, inputs, mask=None):
-    return inputs * math_ops.cast(K.greater(inputs, self.theta), K.floatx())
+    return inputs * math_ops.cast(
+        math_ops.greater(inputs, self.theta), K.floatx())

   def get_config(self):
     config = {'theta': float(self.theta)}
diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py
index a709a07..c74fc1e 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core.py
@@ -77,11 +77,11 @@ class Masking(Layer):
     self.mask_value = mask_value

   def compute_mask(self, inputs, mask=None):
-    return K.any(K.not_equal(inputs, self.mask_value), axis=-1)
+    return K.any(math_ops.not_equal(inputs, self.mask_value), axis=-1)

   def call(self, inputs):
     boolean_mask = K.any(
-        K.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
+        math_ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
     return inputs * math_ops.cast(boolean_mask, inputs.dtype)

   def compute_output_shape(self, input_shape):
@@ -416,7 +416,8 @@ class Reshape(Layer):
     return tensor_shape.TensorShape(output_shape)

   def call(self, inputs):
-    return K.reshape(inputs, (array_ops.shape(inputs)[0],) + self.target_shape)
+    return array_ops.reshape(inputs,
+                             (array_ops.shape(inputs)[0],) + self.target_shape)

   def get_config(self):
     config = {'target_shape': self.target_shape}
@@ -469,7 +470,7 @@ class Permute(Layer):
     return tensor_shape.TensorShape(output_shape)

   def call(self, inputs):
-    return K.permute_dimensions(inputs, (0,) + self.dims)
+    return array_ops.transpose(inputs, perm=(0,) + self.dims)

   def get_config(self):
     config = {'dims': self.dims}
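The Permute change above relies on transpose with an explicit perm being equivalent to K.permute_dimensions; prepending 0 to dims keeps the batch axis in place. A quick NumPy check of the same index logic (illustrative only):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)         # (batch, d1, d2)
dims = (2, 1)                              # Permute((2, 1)) swaps d1 and d2
print(np.transpose(x, (0,) + dims).shape)  # (2, 4, 3); batch axis untouched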
diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py
index 2ca816a..551d1b1 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -159,7 +160,7 @@ class CoreLayersTest(test.TestCase):

     # test with lambda
     ld = keras.layers.Lambda(
-        lambda x: keras.backend.concatenate([keras.backend.square(x), x]))
+        lambda x: keras.backend.concatenate([math_ops.square(x), x]))
     config = ld.get_config()
     ld = keras.layers.Lambda.from_config(config)
@@ -235,4 +236,3 @@ class CoreLayersTest(test.TestCase):

 if __name__ == '__main__':
   test.main()
-
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
index a0fd7a9..540e2d9 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -128,7 +128,7 @@ class Embedding(Layer):
     if not self.mask_zero:
       return None
     else:
-      return K.not_equal(inputs, 0)
+      return math_ops.not_equal(inputs, 0)

   @shape_type_conversion
   def compute_output_shape(self, input_shape):
diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py
index 6290db2..7c87e6c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/merge.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge.py
@@ -24,6 +24,8 @@ from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
 from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.util.tf_export import tf_export
@@ -128,7 +130,7 @@ class _Merge(Layer):
         for x in inputs:
           x_ndim = K.ndim(x)
           for _ in range(max_ndim - x_ndim):
-            x = K.expand_dims(x, 1)
+            x = array_ops.expand_dims(x, axis=1)
           reshaped_inputs.append(x)
         return self._merge_function(reshaped_inputs)
       else:
@@ -140,17 +142,20 @@ class _Merge(Layer):
           if x_ndim is None:
             x_shape = array_ops.shape(x)
             batch_size = x_shape[0]
-            new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
-            x_transposed = K.reshape(x,
-                                     K.stack([batch_size,
-                                              K.prod(x_shape[1:])]))
-            x_transposed = K.permute_dimensions(x_transposed, (1, 0))
-            x_transposed = K.reshape(x_transposed, new_shape)
+            new_shape = K.concatenate(
+                [x_shape[1:],
+                 array_ops.expand_dims(batch_size, axis=-1)])
+            x_transposed = array_ops.reshape(
+                x,
+                array_ops.stack(
+                    [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
+            x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
+            x_transposed = array_ops.reshape(x_transposed, new_shape)
             reshaped_inputs.append(x_transposed)
             transposed = True
           elif x_ndim > 1:
             dims = list(range(1, x_ndim)) + [0]
-            reshaped_inputs.append(K.permute_dimensions(x, dims))
+            reshaped_inputs.append(array_ops.transpose(x, perm=dims))
             transposed = True
           else:
             # We don't transpose inputs if they are 1D vectors or scalars.
@@ -163,14 +168,15 @@ class _Merge(Layer):
           y_shape = array_ops.shape(y)
           y_ndim = array_ops.shape(y_shape)[0]
           batch_size = y_shape[y_ndim - 1]
-          new_shape = K.concatenate(
-              [K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
-          y = K.reshape(y, (-1, batch_size))
-          y = K.permute_dimensions(y, (1, 0))
-          y = K.reshape(y, new_shape)
+          new_shape = K.concatenate([
+              array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
+          ])
+          y = array_ops.reshape(y, (-1, batch_size))
+          y = array_ops.transpose(y, perm=(1, 0))
+          y = array_ops.reshape(y, new_shape)
         elif y_ndim > 1:
           dims = [y_ndim - 1] + list(range(y_ndim - 1))
-          y = K.permute_dimensions(y, dims)
+          y = array_ops.transpose(y, perm=dims)
       return y
     else:
       return self._merge_function(inputs)
@@ -208,7 +214,7 @@ class _Merge(Layer):
                        'should have the same length.')
     if all([m is None for m in mask]):
       return None
-    masks = [K.expand_dims(m, 0) for m in mask if m is not None]
+    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
     return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
@@ -326,7 +332,7 @@ class Maximum(_Merge):
   def _merge_function(self, inputs):
     output = inputs[0]
     for i in range(1, len(inputs)):
-      output = K.maximum(output, inputs[i])
+      output = math_ops.maximum(output, inputs[i])
     return output
@@ -341,7 +347,7 @@ class Minimum(_Merge):
   def _merge_function(self, inputs):
     output = inputs[0]
     for i in range(1, len(inputs)):
-      output = K.minimum(output, inputs[i])
+      output = math_ops.minimum(output, inputs[i])
     return output
@@ -422,7 +428,7 @@ class Concatenate(_Merge):
         masks.append(array_ops.ones_like(input_i, dtype='bool'))
       elif K.ndim(mask_i) < K.ndim(input_i):
         # Mask is smaller than the input, expand it
-        masks.append(K.expand_dims(mask_i))
+        masks.append(array_ops.expand_dims(mask_i, axis=-1))
       else:
         masks.append(mask_i)
     concatenated = K.concatenate(masks, axis=self.axis)
@@ -512,8 +518,8 @@ class Dot(_Merge):
       else:
         axes.append(self.axes[i])
     if self.normalize:
-      x1 = K.l2_normalize(x1, axis=axes[0])
-      x2 = K.l2_normalize(x2, axis=axes[1])
+      x1 = nn.l2_normalize(x1, axis=axes[0])
+      x2 = nn.l2_normalize(x2, axis=axes[1])
     output = K.batch_dot(x1, x2, axes)
     return output
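In the Dot change above, l2_normalize divides by the L2 norm along the dot axis, so with normalize=True the layer returns cosine proximity. A NumPy sketch of the documented tf.nn.l2_normalize formula, x / sqrt(max(sum(x**2), epsilon)) (illustrative only):

import numpy as np

def l2_normalize(x, axis=-1, epsilon=1e-12):
  # Divide by the L2 norm; `epsilon` guards against division by zero.
  norm_sq = np.sum(np.square(x), axis=axis, keepdims=True)
  return x / np.sqrt(np.maximum(norm_sq, epsilon))

a = np.array([[1., 2., 3.]])
b = np.array([[2., 4., 6.]])
print(np.sum(l2_normalize(a) * l2_normalize(b), axis=-1))  # ~[1.0]: parallel vectors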
diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py
index 4366b65..72dc7a1 100644
--- a/tensorflow/python/keras/_impl/keras/layers/noise.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise.py
@@ -166,7 +166,7 @@ class AlphaDropout(Layer):
       scale = 1.0507009873554804934193349852946
       alpha_p = -alpha * scale

-      kept_idx = K.greater_equal(
+      kept_idx = math_ops.greater_equal(
           K.random_uniform(noise_shape, seed=seed), rate)
       kept_idx = math_ops.cast(kept_idx, K.floatx())
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index bd7c42e..7f9f77c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -510,7 +510,8 @@ class RNN(Layer):
       # shape of initial_state = (samples, timesteps, input_dim)
       initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2))
       # shape of initial_state = (samples,)
-      initial_state = K.expand_dims(initial_state)  # (samples, 1)
+      initial_state = array_ops.expand_dims(initial_state, axis=-1)
+      # shape of initial_state = (samples, 1)
       if hasattr(self.cell.state_size, '__len__'):
         return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size]
       else:
@@ -2357,7 +2358,8 @@ class Recurrent(Layer):
     # shape of initial_state = (samples, timesteps, input_dim)
     initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2))
     # shape of initial_state = (samples,)
-    initial_state = K.expand_dims(initial_state)  # (samples, 1)
+    initial_state = array_ops.expand_dims(initial_state, axis=-1)
+    # shape of initial_state = (samples, 1)
     initial_state = K.tile(initial_state, [1, self.units])  # (samples, output_dim)
     initial_state = [initial_state for _ in range(len(self.states))]
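The get_initial_state hunks above build an all-zeros state of shape (samples, units) without knowing the static batch size: the surrounding code (not shown in this diff) starts from a zeros-like of the inputs, reduces over the time and feature axes to get zeros of shape (samples,), then expands and tiles. A NumPy rendering of that shape bookkeeping (illustrative only; the zeros-like starting point is an assumption):

import numpy as np

inputs = np.random.randn(4, 7, 5)                   # (samples, timesteps, input_dim)
state = np.sum(np.zeros_like(inputs), axis=(1, 2))  # (samples,) of zeros
state = np.expand_dims(state, axis=-1)              # (samples, 1)
state = np.tile(state, (1, 8))                      # (samples, units), e.g. units=8
print(state.shape, state.max())                     # (4, 8) 0.0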
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 12f3361..c510e46 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -214,7 +214,7 @@ class TimeDistributed(Wrapper):
       # Shape: (num_samples * timesteps, ...). And track the
       # transformation in self._input_map.
       input_uid = tf_layers_util.object_list_uid(inputs)
-      inputs = K.reshape(inputs, (-1,) + input_shape[2:])
+      inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
       self._input_map[input_uid] = inputs
       # (num_samples * timesteps, ...)
       y = self.layer.call(inputs, **kwargs)
@@ -222,7 +222,7 @@ class TimeDistributed(Wrapper):
         uses_learning_phase = y._uses_learning_phase
       # Shape: (num_samples, timesteps, ...)
       output_shape = self.compute_output_shape(input_shape).as_list()
-      y = K.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
+      y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))

       # Apply activity regularizer if any:
       if (hasattr(self.layer, 'activity_regularizer') and
diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py
index 859bda0..1d634d3 100644
--- a/tensorflow/python/keras/_impl/keras/losses.py
+++ b/tensorflow/python/keras/_impl/keras/losses.py
@@ -25,51 +25,54 @@ from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.util.tf_export import tf_export


 @tf_export('keras.metrics.mean_squared_error', 'keras.losses.mean_squared_error')
 def mean_squared_error(y_true, y_pred):
-  return K.mean(K.square(y_pred - y_true), axis=-1)
+  return K.mean(math_ops.square(y_pred - y_true), axis=-1)


 @tf_export('keras.metrics.mean_absolute_error', 'keras.losses.mean_absolute_error')
 def mean_absolute_error(y_true, y_pred):
-  return K.mean(K.abs(y_pred - y_true), axis=-1)
+  return K.mean(math_ops.abs(y_pred - y_true), axis=-1)


 @tf_export('keras.metrics.mean_absolute_percentage_error', 'keras.losses.mean_absolute_percentage_error')
 def mean_absolute_percentage_error(y_true, y_pred):
-  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true), K.epsilon(), None))
+  diff = math_ops.abs(
+      (y_true - y_pred) / K.clip(math_ops.abs(y_true), K.epsilon(), None))
   return 100. * K.mean(diff, axis=-1)


 @tf_export('keras.metrics.mean_squared_logarithmic_error', 'keras.losses.mean_squared_logarithmic_error')
 def mean_squared_logarithmic_error(y_true, y_pred):
-  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
-  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
-  return K.mean(K.square(first_log - second_log), axis=-1)
+  first_log = math_ops.log(K.clip(y_pred, K.epsilon(), None) + 1.)
+  second_log = math_ops.log(K.clip(y_true, K.epsilon(), None) + 1.)
+  return K.mean(math_ops.square(first_log - second_log), axis=-1)


 @tf_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
 def squared_hinge(y_true, y_pred):
-  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)
+  return K.mean(
+      math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1)


 @tf_export('keras.metrics.hinge', 'keras.losses.hinge')
 def hinge(y_true, y_pred):
-  return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1)
+  return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1)


 @tf_export('keras.losses.categorical_hinge')
 def categorical_hinge(y_true, y_pred):
   pos = math_ops.reduce_sum(y_true * y_pred, axis=-1)
-  neg = K.max((1. - y_true) * y_pred, axis=-1)
-  return K.maximum(0., neg - pos + 1.)
+  neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1)
+  return math_ops.maximum(0., neg - pos + 1.)


 @tf_export('keras.losses.logcosh')
@@ -90,7 +93,7 @@ def logcosh(y_true, y_pred):
   """

   def _logcosh(x):
-    return x + K.softplus(-2. * x) - K.log(2.)
+    return x + nn.softplus(-2. * x) - math_ops.log(2.)

   return K.mean(_logcosh(y_pred - y_true), axis=-1)
@@ -118,18 +121,18 @@ def binary_crossentropy(y_true, y_pred):
 def kullback_leibler_divergence(y_true, y_pred):
   y_true = K.clip(y_true, K.epsilon(), 1)
   y_pred = K.clip(y_pred, K.epsilon(), 1)
-  return math_ops.reduce_sum(y_true * K.log(y_true / y_pred), axis=-1)
+  return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1)


 @tf_export('keras.metrics.poisson', 'keras.losses.poisson')
 def poisson(y_true, y_pred):
-  return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)
+  return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1)


 @tf_export('keras.metrics.cosine_proximity', 'keras.losses.cosine_proximity')
 def cosine_proximity(y_true, y_pred):
-  y_true = K.l2_normalize(y_true, axis=-1)
-  y_pred = K.l2_normalize(y_pred, axis=-1)
+  y_true = nn.l2_normalize(y_true, axis=-1)
+  y_pred = nn.l2_normalize(y_pred, axis=-1)
   return -math_ops.reduce_sum(y_true * y_pred, axis=-1)
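The _logcosh helper above uses the identity log(cosh(x)) = x + softplus(-2x) - log(2), which avoids overflowing cosh for large |x|. A NumPy check of the identity (illustrative only; the overflow-safe softplus form is this sketch's own choice):

import numpy as np

def softplus(x):
  # Overflow-safe log(1 + exp(x)).
  return np.maximum(x, 0.) + np.log1p(np.exp(-np.abs(x)))

x = np.array([-3., 0.5, 10.])
lhs = np.log(np.cosh(x))
rhs = x + softplus(-2. * x) - np.log(2.)
print(np.allclose(lhs, rhs))  # True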
diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py
index 24192cf..747c3e6 100644
--- a/tensorflow/python/keras/_impl/keras/metrics.py
+++ b/tensorflow/python/keras/_impl/keras/metrics.py
@@ -38,39 +38,45 @@ from tensorflow.python.keras._impl.keras.losses import squared_hinge
 from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
 from tensorflow.python.util.tf_export import tf_export


 @tf_export('keras.metrics.binary_accuracy')
 def binary_accuracy(y_true, y_pred):
-  return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
+  return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1)


 @tf_export('keras.metrics.categorical_accuracy')
 def categorical_accuracy(y_true, y_pred):
   return math_ops.cast(
-      K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())
+      math_ops.equal(
+          math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
+      K.floatx())


 def sparse_categorical_accuracy(y_true, y_pred):
   return math_ops.cast(
-      K.equal(
-          K.max(y_true, axis=-1),
-          math_ops.cast(K.argmax(y_pred, axis=-1), K.floatx())), K.floatx())
+      math_ops.equal(
+          math_ops.reduce_max(y_true, axis=-1),
+          math_ops.cast(math_ops.argmax(y_pred, axis=-1), K.floatx())),
+      K.floatx())


 @tf_export('keras.metrics.top_k_categorical_accuracy')
 def top_k_categorical_accuracy(y_true, y_pred, k=5):
-  return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
+  return K.mean(
+      nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), axis=-1)


 @tf_export('keras.metrics.sparse_top_k_categorical_accuracy')
 def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
   return K.mean(
-      K.in_top_k(y_pred, math_ops.cast(K.max(y_true, axis=-1), 'int32'), k),
+      nn.in_top_k(y_pred,
+                  math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'),
+                  k),
       axis=-1)

-
 # Aliases

 mse = MSE = mean_squared_error
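categorical_accuracy above compares argmax indices of one-hot labels and predictions and averages the result. The same logic in NumPy (illustrative only; made-up values):

import numpy as np

y_true = np.array([[0., 0., 1.], [0., 1., 0.]])
y_pred = np.array([[.1, .2, .7], [.6, .3, .1]])
acc = (np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)).astype('float32')
print(acc.mean())  # 0.5: first sample correct, second not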
diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py
index 2b73e0c..9deaab0 100644
--- a/tensorflow/python/keras/_impl/keras/metrics_test.py
+++ b/tensorflow/python/keras/_impl/keras/metrics_test.py
@@ -107,9 +107,8 @@ class KerasMetricsTest(test.TestCase):
         completion of the batch.
       """
       y_true = math_ops.cast(y_true, 'int32')
-      y_pred = math_ops.cast(keras.backend.round(y_pred), 'int32')
-      correct_preds = math_ops.cast(
-          keras.backend.equal(y_pred, y_true), 'int32')
+      y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
+      correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32')
       true_pos = math_ops.cast(
           math_ops.reduce_sum(correct_preds * y_true), 'int32')
       current_true_pos = self.true_positives * 1
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index dc0e472..9f383de 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -119,7 +119,8 @@ class Optimizer(object):
                        'Common ops without gradient: '
                        'K.argmax, K.round, K.eval.')
     if hasattr(self, 'clipnorm') and self.clipnorm > 0:
-      norm = K.sqrt(sum([math_ops.reduce_sum(K.square(g)) for g in grads]))
+      norm = K.sqrt(
+          sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads]))
       grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
     if hasattr(self, 'clipvalue') and self.clipvalue > 0:
       grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
@@ -288,7 +289,7 @@ class RMSprop(Optimizer):

     for p, g, a in zip(params, grads, accumulators):
       # update accumulator
-      new_a = self.rho * a + (1. - self.rho) * K.square(g)
+      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
       self.updates.append(state_ops.assign(a, new_a))
       new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
@@ -349,7 +350,7 @@ class Adagrad(Optimizer):
                                                   K.dtype(self.decay))))

     for p, g, a in zip(params, grads, accumulators):
-      new_a = a + K.square(g)  # update accumulator
+      new_a = a + math_ops.square(g)  # update accumulator
       self.updates.append(state_ops.assign(a, new_a))
       new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
@@ -414,7 +415,7 @@ class Adadelta(Optimizer):

     for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
       # update accumulator
-      new_a = self.rho * a + (1. - self.rho) * K.square(g)
+      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
       self.updates.append(state_ops.assign(a, new_a))

       # use the new accumulator and the *old* delta_accumulator
@@ -428,7 +429,7 @@ class Adadelta(Optimizer):
       self.updates.append(state_ops.assign(p, new_p))

       # update delta_accumulator
-      new_d_a = self.rho * d_a + (1 - self.rho) * K.square(update)
+      new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update)
       self.updates.append(state_ops.assign(d_a, new_d_a))
     return self.updates
@@ -494,7 +495,8 @@ class Adam(Optimizer):

     t = math_ops.cast(self.iterations, K.floatx()) + 1
     lr_t = lr * (
-        K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))
+        K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
+        (1. - math_ops.pow(self.beta_1, t)))

     ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
     vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
@@ -506,9 +508,9 @@ class Adam(Optimizer):

     for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
       m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
-      v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
+      v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
       if self.amsgrad:
-        vhat_t = K.maximum(vhat, v_t)
+        vhat_t = math_ops.maximum(vhat, v_t)
         p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
         self.updates.append(state_ops.assign(vhat, vhat_t))
       else:
@@ -583,7 +585,7 @@ class Adamax(Optimizer):
                                                   K.dtype(self.decay))))

     t = math_ops.cast(self.iterations, K.floatx()) + 1
-    lr_t = lr / (1. - K.pow(self.beta_1, t))
+    lr_t = lr / (1. - math_ops.pow(self.beta_1, t))

     shapes = [K.int_shape(p) for p in params]
     # zero init of 1st moment
@@ -595,7 +597,7 @@ class Adamax(Optimizer):

     for p, g, m, u in zip(params, grads, ms, us):
       m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
-      u_t = K.maximum(self.beta_2 * u, K.abs(g))
+      u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g))
       p_t = p - lr_t * m_t / (u_t + self.epsilon)

       self.updates.append(state_ops.assign(m, m_t))
@@ -666,10 +668,11 @@ class Nadam(Optimizer):

     # Due to the recommendations in [2], i.e. warming momentum schedule
     momentum_cache_t = self.beta_1 * (
-        1. - 0.5 * (K.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
+        1. - 0.5 *
+        (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
     momentum_cache_t_1 = self.beta_1 * (
         1. - 0.5 *
-        (K.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
+        (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
     m_schedule_new = self.m_schedule * momentum_cache_t
     m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
     self.updates.append((self.m_schedule, m_schedule_new))
@@ -685,8 +688,8 @@ class Nadam(Optimizer):
       g_prime = g / (1. - m_schedule_new)
       m_t = self.beta_1 * m + (1. - self.beta_1) * g
       m_t_prime = m_t / (1. - m_schedule_next)
-      v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
-      v_t_prime = v_t / (1. - K.pow(self.beta_2, t))
+      v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
+      v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
       m_t_bar = (
           1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
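The lr_t expression in Adam.get_updates above is the usual bias-corrected step size, lr * sqrt(1 - beta_2**t) / (1 - beta_1**t), which approaches plain lr as t grows. Evaluated numerically (illustrative only; default Adam hyperparameters assumed):

import numpy as np

lr, beta_1, beta_2 = 0.001, 0.9, 0.999
for t in (1, 10, 1000):
  lr_t = lr * np.sqrt(1. - beta_2**t) / (1. - beta_1**t)
  print(t, lr_t)  # strong correction early; lr_t -> lr as t grows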
diff --git a/tensorflow/python/keras/_impl/keras/regularizers.py b/tensorflow/python/keras/_impl/keras/regularizers.py
index fdb9d33..74c37d3 100644
--- a/tensorflow/python/keras/_impl/keras/regularizers.py
+++ b/tensorflow/python/keras/_impl/keras/regularizers.py
@@ -56,9 +56,9 @@ class L1L2(Regularizer):
   def __call__(self, x):
     regularization = 0.
     if self.l1:
-      regularization += math_ops.reduce_sum(self.l1 * K.abs(x))
+      regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
     if self.l2:
-      regularization += math_ops.reduce_sum(self.l2 * K.square(x))
+      regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
     return regularization

   def get_config(self):
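For reference, the L1L2 penalty above reduces to l1 * sum(|x|) + l2 * sum(x**2). A NumPy sketch (illustrative only; made-up weights):

import numpy as np

def l1_l2_penalty(x, l1=0.01, l2=0.01):
  # Elementwise L1 and L2 terms, each summed over the whole tensor.
  return l1 * np.sum(np.abs(x)) + l2 * np.sum(np.square(x))

w = np.array([[-1., 2.], [.5, -.5]])
print(l1_l2_penalty(w))  # 0.01 * 4.0 + 0.01 * 5.5 = 0.095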