Remove zipped argument, and use an implicit dispatch mechanism
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 7 Apr 2018 00:13:13 +0000 (17:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 7 Apr 2018 00:15:57 +0000 (17:15 -0700)
PiperOrigin-RevId: 191962157

tensorflow/contrib/lite/build_def.bzl
tensorflow/contrib/lite/testing/generate_examples.py

index 2813d1c..b8f6b7f 100644 (file)
@@ -200,8 +200,7 @@ def gen_zipped_test_files(name, files):
     native.genrule(
         name = name + "_" + f + ".files",
         cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco
-               + " --zip_to_output " + f +
-               " $(@D) zipped"),
+               + " --zip_to_output " + f + " $(@D)"),
         outs = [out_file],
         tools = [
             ":generate_examples",
index 8045052..f919517 100644 (file)
 
 Usage:
 
-generate_examples <output directory> zipped
+generate_examples <output directory>
 
 bazel run //tensorflow/contrib/lite/testing:generate_examples
-    third_party/tensorflow/contrib/lite/testing/generated_examples zipped
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -52,8 +51,6 @@ from tensorflow.python.ops import rnn
 parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
 parser.add_argument("output_path",
                     help="Directory where the outputs will be go.")
-# TODO(ahentz): remove this flag
-parser.add_argument("type", help="zipped")
 parser.add_argument("--zip_to_output",
                     type=str,
                     help="Particular zip to output.",
@@ -543,6 +540,18 @@ def make_pool_tests(pool_op_in):
   return f
 
 
+def make_l2_pool_tests(zip_path):
+  make_pool_tests(make_l2_pool)(zip_path)
+
+
+def make_avg_pool_tests(zip_path):
+  make_pool_tests(tf.nn.avg_pool)(zip_path)
+
+
+def make_max_pool_tests(zip_path):
+  make_pool_tests(tf.nn.max_pool)(zip_path)
+
+
 def make_relu_tests(zip_path):
   """Make a set of tests to do relu."""
 
@@ -902,6 +911,22 @@ def make_binary_op_tests_func(binary_operator):
   return lambda zip_path: make_binary_op_tests(zip_path, binary_operator)
 
 
+def make_add_tests(zip_path):
+  make_binary_op_tests(zip_path, tf.add)
+
+
+def make_div_tests(zip_path):
+  make_binary_op_tests(zip_path, tf.div)
+
+
+def make_sub_tests(zip_path):
+  make_binary_op_tests(zip_path, tf.subtract)
+
+
+def make_mul_tests(zip_path):
+  make_binary_op_tests(zip_path, tf.multiply)
+
+
 def make_gather_tests(zip_path):
   """Make a set of tests to do gather."""
 
@@ -1169,7 +1194,7 @@ def make_split_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
-def make_concatenation_tests(zip_path):
+def make_concat_tests(zip_path):
   """Make a set of tests to do concatenation."""
 
   test_parameters = [{
@@ -1966,69 +1991,26 @@ def main(unused_args):
       if not os.path.isdir(x):
         raise RuntimeError("Failed to create dir %r" % x)
 
-  if FLAGS.type == "zipped":
-    opstest_path = os.path.join(FLAGS.output_path)
-    mkdir_if_not_exist(opstest_path)
-    def _path(filename):
-      return os.path.join(opstest_path, filename)
-
-    dispatch = {
-        "control_dep.zip": make_control_dep_tests,
-        "add.zip": make_binary_op_tests_func(tf.add),
-        "space_to_batch_nd.zip": make_space_to_batch_nd_tests,
-        "div.zip": make_binary_op_tests_func(tf.div),
-        "sub.zip": make_binary_op_tests_func(tf.subtract),
-        "batch_to_space_nd.zip": make_batch_to_space_nd_tests,
-        "conv.zip": make_conv_tests,
-        "constant.zip": make_constant_tests,
-        "depthwiseconv.zip": make_depthwiseconv_tests,
-        "concat.zip": make_concatenation_tests,
-        "fully_connected.zip": make_fully_connected_tests,
-        "global_batch_norm.zip": make_global_batch_norm_tests,
-        "gather.zip": make_gather_tests,
-        "fused_batch_norm.zip": make_fused_batch_norm_tests,
-        "l2norm.zip": make_l2norm_tests,
-        "local_response_norm.zip": make_local_response_norm_tests,
-        "mul.zip": make_binary_op_tests_func(tf.multiply),
-        "relu.zip": make_relu_tests,
-        "relu1.zip": make_relu1_tests,
-        "relu6.zip": make_relu6_tests,
-        "prelu.zip": make_prelu_tests,
-        "l2_pool.zip": make_pool_tests(make_l2_pool),
-        "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
-        "max_pool.zip": make_pool_tests(tf.nn.max_pool),
-        "pad.zip": make_pad_tests,
-        "reshape.zip": make_reshape_tests,
-        "resize_bilinear.zip": make_resize_bilinear_tests,
-        "sigmoid.zip": make_sigmoid_tests,
-        "softmax.zip": make_softmax_tests,
-        "space_to_depth.zip": make_space_to_depth_tests,
-        "topk.zip": make_topk_tests,
-        "split.zip": make_split_tests,
-        "transpose.zip": make_transpose_tests,
-        "mean.zip": make_mean_tests,
-        "squeeze.zip": make_squeeze_tests,
-        "strided_slice.zip": make_strided_slice_tests,
-        "exp.zip": make_exp_tests,
-        "log_softmax.zip": make_log_softmax_tests,
-        "lstm.zip": make_lstm_tests,
-        "maximum.zip": make_maximum_tests,
-    }
-    out = FLAGS.zip_to_output
-    bin_path = FLAGS.toco
-    if out in dispatch:
-      dispatch[out](_path(out))
-    else:
-      raise RuntimeError("Invalid zip to output %r" % out)
+  opstest_path = os.path.join(FLAGS.output_path)
+  mkdir_if_not_exist(opstest_path)
 
-  else:
-    raise RuntimeError("Invalid argument for type of generation.")
+  out = FLAGS.zip_to_output
+  bin_path = FLAGS.toco
+  test_function = ("make_%s_tests" % out.replace(".zip", ""))
+  if test_function not in globals():
+    raise RuntimeError("Can't find a test function to create %r. Tried %r" %
+                       (out, test_function))
+
+  # TODO(ahentz): accessing globals() is not very elegant. We should either
+  # break this file into multiple tests or use decorator-based registration to
+  # avoid using globals().
+  globals()[test_function](os.path.join(opstest_path, out))
 
 
 if __name__ == "__main__":
   FLAGS, unparsed = parser.parse_known_args()
 
   if unparsed:
-    print("Usage: %s <path out> zipped <zip file to generate>")
+    print("Usage: %s <path out> <zip file to generate>")
   else:
     tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)