Remove use of deprecated API from RNN Colorbot example.
authorAkshay Agrawal <akshayka@google.com>
Thu, 22 Mar 2018 21:24:23 +0000 (14:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 21:27:00 +0000 (14:27 -0700)
PiperOrigin-RevId: 190125356

tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py

index 29f0232..88fffc9 100644 (file)
@@ -60,6 +60,7 @@ import functools
 import os
 import sys
 import time
+import urllib
 
 import six
 import tensorflow as tf
@@ -89,13 +90,35 @@ def parse(line):
   return rgb, chars, length
 
 
+def maybe_download(filename, work_directory, source_url):
+  """Download the data from source url, unless it's already here.
+
+  Args:
+    filename: string, name of the file in the directory.
+    work_directory: string, path to working directory.
+    source_url: url to download from if file doesn't exist.
+
+  Returns:
+    Path to resulting file.
+  """
+  if not tf.gfile.Exists(work_directory):
+    tf.gfile.MakeDirs(work_directory)
+  filepath = os.path.join(work_directory, filename)
+  if not tf.gfile.Exists(filepath):
+    temp_file_name, _ = urllib.request.urlretrieve(source_url)
+    tf.gfile.Copy(temp_file_name, filepath)
+    with tf.gfile.GFile(filepath) as f:
+      size = f.size()
+      print("Successfully downloaded", filename, size, "bytes.")
+  return filepath
+
+
 def load_dataset(data_dir, url, batch_size):
   """Loads the colors data at path into a PaddedDataset."""
 
   # Downloads data at url into data_dir/basename(url). The dataset has a header
   # row (color_name, r, g, b) followed by comma-separated lines.
-  path = tf.contrib.learn.datasets.base.maybe_download(
-      os.path.basename(url), data_dir, url)
+  path = maybe_download(os.path.basename(url), data_dir, url)
 
   # This chain of commands loads our data by:
   #   1. skipping the header; (.skip(1))