import os
import sys
import time
+import urllib
import six
import tensorflow as tf
return rgb, chars, length
def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
    filename: string, name of the file in the directory.
    work_directory: string, path to working directory.
    source_url: url to download from if file doesn't exist.

  Returns:
    Path to resulting file.
  """
  # Explicit submodule import: a bare `import urllib` does not load
  # `urllib.request` on Python 3, so relying on the top-level import
  # would raise AttributeError here.
  import urllib.request

  if not tf.gfile.Exists(work_directory):
    tf.gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not tf.gfile.Exists(filepath):
    # urlretrieve downloads to a temporary file; copying afterwards means a
    # failed download never leaves a truncated file at `filepath`.
    temp_file_name, _ = urllib.request.urlretrieve(source_url)
    tf.gfile.Copy(temp_file_name, filepath)
    with tf.gfile.GFile(filepath) as f:
      size = f.size()
    print("Successfully downloaded", filename, size, "bytes.")
  return filepath
+
+
def load_dataset(data_dir, url, batch_size):
"""Loads the colors data at path into a PaddedDataset."""
# Downloads data at url into data_dir/basename(url). The dataset has a header
# row (color_name, r, g, b) followed by comma-separated lines.
- path = tf.contrib.learn.datasets.base.maybe_download(
- os.path.basename(url), data_dir, url)
+ path = maybe_download(os.path.basename(url), data_dir, url)
# This chain of commands loads our data by:
# 1. skipping the header; (.skip(1))