# Take a test case (a dict) as input, return the input for the module.
def gen_input(testcase):
if "input_size" in testcase:
+ if testcase["input_size"] == () and "desc" in testcase and testcase["desc"][-6:] == "scalar":
+ testcase["input_size"] = (1,)
return Variable(torch.randn(*testcase["input_size"]))
elif "input_fn" in testcase:
input = testcase["input_fn"]()
test_name = get_test_name(t)
module = gen_module(t)
module_name = str(module).split("(")[0]
+ if (module_name != "LogSoftmax"):
+ continue
if (module_name == "FunctionalModule"):
FunctionalModule_nums += 1
else: