Update onnx coverage script for more accurate result (#15029)
authorzrphercule <zrphercule@gmail.com>
Tue, 11 Dec 2018 21:12:23 +0000 (13:12 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 11 Dec 2018 21:14:35 +0000 (13:14 -0800)
Summary:
The coverage of scalar-input test cases were not accurate. This patch fixed that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15029

Differential Revision: D13419764

Pulled By: zrphercule

fbshipit-source-id: a14a5cbef432bea8c9126156f5deb1125e1aeb47

test/onnx/export_onnx_tests_generator.py

index 574d482..39e5ec7 100644 (file)
@@ -33,6 +33,8 @@ def get_test_name(testcase):
 # Take a test case (a dict) as input, return the input for the module.
 def gen_input(testcase):
     if "input_size" in testcase:
+        if testcase["input_size"] == () and "desc" in testcase and testcase["desc"][-6:] == "scalar":
+            testcase["input_size"] = (1,)
         return Variable(torch.randn(*testcase["input_size"]))
     elif "input_fn" in testcase:
         input = testcase["input_fn"]()
@@ -91,6 +93,8 @@ def convert_tests(testcases, sets=1):
         test_name = get_test_name(t)
         module = gen_module(t)
         module_name = str(module).split("(")[0]
+        if (module_name != "LogSoftmax"):
+            continue
         if (module_name == "FunctionalModule"):
             FunctionalModule_nums += 1
         else: