unify cpp tests (#17947)
authorMichael Suo <suo@fb.com>
Wed, 13 Mar 2019 04:31:59 +0000 (21:31 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Mar 2019 04:35:40 +0000 (21:35 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17947

Instead of having a gtest and a no-gtest file that you have to remember to register tests in, add a single registration point and use some macro magic to make it work for both gtest and non-gtest builds

Reviewed By: eellison

Differential Revision: D14431302

fbshipit-source-id: e1abac135992577a943eaa7abcc81a6ed31fa6e5

test/cpp/jit/CMakeLists.txt
test/cpp/jit/gtest.cpp [deleted file]
test/cpp/jit/no-gtest.cpp [deleted file]
test/cpp/jit/test.cpp [new file with mode: 0644]
test/cpp/jit/test_base.h
torch/CMakeLists.txt

index e1c5231..66860eb 100644 (file)
@@ -2,7 +2,7 @@ set(JIT_TEST_ROOT ${TORCH_ROOT}/test/cpp/jit)
 
 add_executable(test_jit
   ${TORCH_ROOT}/test/cpp/common/main.cpp
-  ${JIT_TEST_ROOT}/gtest.cpp)
+  ${JIT_TEST_ROOT}/test.cpp)
 
 target_link_libraries(test_jit PRIVATE torch gtest)
 target_compile_definitions(test_jit PRIVATE USE_GTEST)
diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp
deleted file mode 100644 (file)
index 1186dd9..0000000
+++ /dev/null
@@ -1,61 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <test/cpp/jit/test_alias_analysis.h>
-#include <test/cpp/jit/test_class_parser.h>
-#include <test/cpp/jit/test_constant_pooling.h>
-#include <test/cpp/jit/test_irparser.h>
-#include <test/cpp/jit/test_misc.h>
-#include <test/cpp/jit/test_netdef_converter.h>
-
-using namespace torch;
-using namespace torch::jit;
-using namespace torch::jit::script;
-
-#define JIT_TEST(name)  \
-  TEST(JitTest, name) { \
-    test##name();       \
-  }
-
-JIT_TEST(ADFormulas)
-JIT_TEST(Attributes)
-JIT_TEST(Blocks)
-JIT_TEST(CodeTemplate)
-JIT_TEST(ControlFlow)
-JIT_TEST(CreateAutodiffSubgraphs)
-JIT_TEST(CustomOperators)
-JIT_TEST(Differentiate)
-JIT_TEST(DifferentiateWithRequiresGrad)
-JIT_TEST(DynamicDAG)
-JIT_TEST(EvalModeForLoadedModule)
-JIT_TEST(FromQualString)
-JIT_TEST(InternedStrings)
-JIT_TEST(IValue)
-JIT_TEST(Proto)
-JIT_TEST(RegisterFusionCachesKernel)
-JIT_TEST(SchemaParser)
-JIT_TEST(TopologicalIndex)
-JIT_TEST(TopologicalMove)
-JIT_TEST(SubgraphUtils)
-JIT_TEST(AliasAnalysis)
-JIT_TEST(WriteTracking)
-JIT_TEST(Wildcards)
-JIT_TEST(MemoryDAG)
-JIT_TEST(IRParser)
-JIT_TEST(ConstantPooling)
-
-JIT_TEST(NetDefConverter)
-
-JIT_TEST(THNNConv)
-JIT_TEST(ATenNativeBatchNorm)
-JIT_TEST(NoneSchemaMatch)
-JIT_TEST(ClassParser)
-
-#define JIT_TEST_CUDA(name)    \
-  TEST(JitTest, name##_CUDA) { \
-    test##name();              \
-  }
-
-JIT_TEST_CUDA(ArgumentSpec)
-JIT_TEST_CUDA(Fusion)
-JIT_TEST_CUDA(GraphExecutor)
-JIT_TEST_CUDA(Interp)
diff --git a/test/cpp/jit/no-gtest.cpp b/test/cpp/jit/no-gtest.cpp
deleted file mode 100644 (file)
index 845a38e..0000000
+++ /dev/null
@@ -1,54 +0,0 @@
-#include <test/cpp/jit/test_alias_analysis.h>
-#include <test/cpp/jit/test_constant_pooling.h>
-#include <test/cpp/jit/test_class_parser.h>
-#include <test/cpp/jit/test_irparser.h>
-#include <test/cpp/jit/test_misc.h>
-#include <test/cpp/jit/test_netdef_converter.h>
-
-#include <sstream>
-#include <string>
-
-using namespace torch::jit::script;
-namespace torch {
-namespace jit {
-void runJITCPPTests() {
-  testNoneSchemaMatch();
-  testAutogradProfiler();
-  testADFormulas();
-  testArgumentSpec();
-  testAttributes();
-  testBlocks();
-  testCodeTemplate();
-  testControlFlow();
-  testCreateAutodiffSubgraphs();
-  testCustomOperators();
-  testDifferentiate();
-  testDifferentiateWithRequiresGrad();
-  testDynamicDAG();
-  testEvalModeForLoadedModule();
-  testFromQualString();
-  testFusion();
-  testGraphExecutor();
-  testInternedStrings();
-  testInterp();
-  testIValue();
-  testProto();
-  testSchemaParser();
-  testTopologicalIndex();
-  testTopologicalMove();
-  testSubgraphUtils();
-  testTHNNConv();
-  testATenNativeBatchNorm();
-  testRegisterFusionCachesKernel();
-  testAliasAnalysis();
-  testWriteTracking();
-  testWildcards();
-  testMemoryDAG();
-  testNetDefConverter();
-  testIRParser();
-  testConstantPooling();
-  testClassParser();
-}
-
-} // namespace jit
-} // namespace torch
diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp
new file mode 100644 (file)
index 0000000..46fafa4
--- /dev/null
@@ -0,0 +1,87 @@
+#if defined(USE_GTEST)
+#include <gtest/gtest.h>
+#endif
+
+// To add a new test file:
+// 1. Add a test_foo.h file in this directory
+// 2. include test_base.h
+// 3. Write your tests as pure functions starting with "test", like "testFoo"
+// 4. Include test_foo.h here and add it to the appropriate macro listing
+#include <test/cpp/jit/test_alias_analysis.h>
+#include <test/cpp/jit/test_class_parser.h>
+#include <test/cpp/jit/test_constant_pooling.h>
+#include <test/cpp/jit/test_irparser.h>
+#include <test/cpp/jit/test_misc.h>
+#include <test/cpp/jit/test_netdef_converter.h>
+
+using namespace torch::jit::script;
+namespace torch {
+namespace jit {
+#define TH_FORALL_TESTS(_)         \
+  _(ADFormulas)                    \
+  _(Attributes)                    \
+  _(Blocks)                        \
+  _(CodeTemplate)                  \
+  _(ControlFlow)                   \
+  _(CreateAutodiffSubgraphs)       \
+  _(CustomOperators)               \
+  _(Differentiate)                 \
+  _(DifferentiateWithRequiresGrad) \
+  _(DynamicDAG)                    \
+  _(FromQualString)                \
+  _(InternedStrings)               \
+  _(IValue)                        \
+  _(Proto)                         \
+  _(RegisterFusionCachesKernel)    \
+  _(SchemaParser)                  \
+  _(TopologicalIndex)              \
+  _(TopologicalMove)               \
+  _(SubgraphUtils)                 \
+  _(AliasAnalysis)                 \
+  _(WriteTracking)                 \
+  _(Wildcards)                     \
+  _(MemoryDAG)                     \
+  _(IRParser)                      \
+  _(ConstantPooling)               \
+  _(NetDefConverter)               \
+  _(THNNConv)                      \
+  _(ATenNativeBatchNorm)           \
+  _(NoneSchemaMatch)               \
+  _(ClassParser)
+
+#define TH_FORALL_TESTS_CUDA(_) \
+  _(ArgumentSpec)               \
+  _(Fusion)                     \
+  _(GraphExecutor)              \
+  _(Interp)
+
+#if defined(USE_GTEST)
+
+#define JIT_GTEST(name)  \
+  TEST(JitTest, name) { \
+    test##name();       \
+  }
+TH_FORALL_TESTS(JIT_GTEST)
+#undef JIT_TEST
+
+#define JIT_GTEST_CUDA(name)    \
+  TEST(JitTest, name##_CUDA) { \
+    test##name();              \
+  }
+TH_FORALL_TESTS_CUDA(JIT_GTEST_CUDA)
+#undef JIT_TEST_CUDA
+#endif
+
+#define JIT_TEST(name) test##name();
+void runJITCPPTests() {
+  TH_FORALL_TESTS(JIT_TEST)
+  TH_FORALL_TESTS_CUDA(JIT_TEST)
+
+  // This test is special since it requires prior setup in python.
+  // So it's included here but not in the pure cpp gtest suite
+  testEvalModeForLoadedModule();
+}
+#undef JIT_TEST
+
+} // namespace jit
+} // namespace torch
index b7cc81e..5289b76 100644 (file)
@@ -2,13 +2,6 @@
 
 // This file defines assertion macros that work in both gtest and non-gtest
 // builds, and has some common includes.
-//
-// To add a new test file:
-// 1. Add a test_foo.h file in this directory
-// 2. include test_base.h
-// 3. Write your tests as pure functions
-// 4. Include test_foo.h in gtest.cpp and no-gtest.cpp and register the tests
-//    there.
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/operator.h"
 
index 5bcd344..e53c271 100644 (file)
@@ -191,7 +191,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/fuser/executor.cpp
   ${TORCH_SRC_DIR}/csrc/jit/fuser/codegen.cpp
   ${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp
-  ${TORCH_ROOT}/test/cpp/jit/no-gtest.cpp
+  ${TORCH_ROOT}/test/cpp/jit/test.cpp
   )
 
 if (WIN32)