Internal Change.
authorMichael Case <mikecase@google.com>
Wed, 9 May 2018 19:05:18 +0000 (12:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 20:12:15 +0000 (13:12 -0700)
PiperOrigin-RevId: 196007623

tensorflow/python/estimator/canned/dnn.py

index e7fbf8e..1feac36 100644 (file)
@@ -126,7 +126,8 @@ def _dnn_model_fn(features,
                   activation_fn=nn.relu,
                   dropout=None,
                   input_layer_partitioner=None,
-                  config=None):
+                  config=None,
+                  tpu_estimator_spec=False):
   """Deep Neural Net model_fn.
 
   Args:
@@ -147,6 +148,8 @@ def _dnn_model_fn(features,
     input_layer_partitioner: Partitioner for input layer. Defaults
       to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
     config: `RunConfig` object to configure the runtime settings.
+    tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or
+      or `model_fn.EstimatorSpec` instance.
 
   Returns:
     An `EstimatorSpec` instance.
@@ -154,59 +157,6 @@ def _dnn_model_fn(features,
   Raises:
     ValueError: If features has the wrong type.
   """
-  tpu_estimator_spec = _tpu_dnn_model_fn(
-      features=features,
-      labels=labels,
-      mode=mode,
-      head=head,
-      hidden_units=hidden_units,
-      feature_columns=feature_columns,
-      optimizer=optimizer,
-      activation_fn=activation_fn,
-      dropout=dropout,
-      input_layer_partitioner=input_layer_partitioner,
-      config=config)
-  return tpu_estimator_spec.as_estimator_spec()
-
-
-def _tpu_dnn_model_fn(features,
-                      labels,
-                      mode,
-                      head,
-                      hidden_units,
-                      feature_columns,
-                      optimizer='Adagrad',
-                      activation_fn=nn.relu,
-                      dropout=None,
-                      input_layer_partitioner=None,
-                      config=None):
-  """Deep Neural Net model_fn for TPUEstimator.
-
-  Args:
-    features: dict of `Tensor`.
-    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
-      dtype `int32` or `int64` in the range `[0, n_classes)`.
-    mode: Defines whether this is training, evaluation or prediction.
-      See `ModeKeys`.
-    head: A `head_lib._Head` instance.
-    hidden_units: Iterable of integer number of hidden units per layer.
-    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
-    optimizer: String, `tf.Optimizer` object, or callable that creates the
-      optimizer to use for training. If not specified, will use the Adagrad
-      optimizer with a default learning rate of 0.05.
-    activation_fn: Activation function applied to each layer.
-    dropout: When not `None`, the probability we will drop out a given
-      coordinate.
-    input_layer_partitioner: Partitioner for input layer. Defaults
-      to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
-    config: `RunConfig` object to configure the runtime settings.
-
-  Returns:
-    A `model_fn.TPUEstimatorSpec` instance.
-
-  Raises:
-    ValueError: If features has the wrong type.
-  """
   if not isinstance(features, dict):
     raise ValueError('features should be a dictionary of `Tensor`s. '
                      'Given type: {}'.format(type(features)))
@@ -235,12 +185,20 @@ def _tpu_dnn_model_fn(features,
         input_layer_partitioner=input_layer_partitioner)
     logits = logit_fn(features=features, mode=mode)
 
-    return head._create_tpu_estimator_spec(  # pylint: disable=protected-access
-        features=features,
-        mode=mode,
-        labels=labels,
-        optimizer=optimizer,
-        logits=logits)
+    if tpu_estimator_spec:
+      return head._create_tpu_estimator_spec(  # pylint: disable=protected-access
+          features=features,
+          mode=mode,
+          labels=labels,
+          optimizer=optimizer,
+          logits=logits)
+    else:
+      return head.create_estimator_spec(
+          features=features,
+          mode=mode,
+          labels=labels,
+          optimizer=optimizer,
+          logits=logits)
 
 
 @tf_export('estimator.DNNClassifier')