from tensorflow.python.estimator import export as export_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
estimator = estimator_lib.Estimator(
keras_model_fn, model_dir=model_dir, config=config)
+ old_session = K._SESSION
# Pass the config into keras backend's default session.
sess = session.Session(config=estimator._session_config)
K.set_session(sess)
+ try:
+ keras_weights = keras_model.get_weights()
+ except errors.FailedPreconditionError as e:
+ if old_session is None:
+ raise e
+ logging.warning(
+ 'The Keras backend session has already been '
+ 'set. The _session_config passed to model_to_estimator is not used.')
+ K.set_session(old_session)
+ keras_weights = keras_model.get_weights()
- keras_weights = keras_model.get_weights()
if keras_model._is_graph_network:
# TODO(yifeif): move checkpoint initialization to scaffold.init_fn
_save_first_checkpoint(keras_model,
._config.gpu_options.per_process_gpu_memory_fraction,
gpu_options.per_process_gpu_memory_fraction)
+ def test_pretrained_weights(self):
+ keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ keras_model.train_on_batch(
+ np.random.random((10,) + _INPUT_SIZE), np.random.random((10,
+ _NUM_CLASS)))
+ weights = keras_model.get_weights()
+ keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
+ keras_model.set_weights(weights)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+ keras.estimator.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+
if __name__ == '__main__':
test.main()