add_executable(test_jit
${TORCH_ROOT}/test/cpp/common/main.cpp
- ${JIT_TEST_ROOT}/gtest.cpp)
+ ${JIT_TEST_ROOT}/test.cpp)
target_link_libraries(test_jit PRIVATE torch gtest)
target_compile_definitions(test_jit PRIVATE USE_GTEST)
+++ /dev/null
-#include <gtest/gtest.h>
-
-#include <test/cpp/jit/test_alias_analysis.h>
-#include <test/cpp/jit/test_class_parser.h>
-#include <test/cpp/jit/test_constant_pooling.h>
-#include <test/cpp/jit/test_irparser.h>
-#include <test/cpp/jit/test_misc.h>
-#include <test/cpp/jit/test_netdef_converter.h>
-
-using namespace torch;
-using namespace torch::jit;
-using namespace torch::jit::script;
-
-#define JIT_TEST(name) \
- TEST(JitTest, name) { \
- test##name(); \
- }
-
-JIT_TEST(ADFormulas)
-JIT_TEST(Attributes)
-JIT_TEST(Blocks)
-JIT_TEST(CodeTemplate)
-JIT_TEST(ControlFlow)
-JIT_TEST(CreateAutodiffSubgraphs)
-JIT_TEST(CustomOperators)
-JIT_TEST(Differentiate)
-JIT_TEST(DifferentiateWithRequiresGrad)
-JIT_TEST(DynamicDAG)
-JIT_TEST(EvalModeForLoadedModule)
-JIT_TEST(FromQualString)
-JIT_TEST(InternedStrings)
-JIT_TEST(IValue)
-JIT_TEST(Proto)
-JIT_TEST(RegisterFusionCachesKernel)
-JIT_TEST(SchemaParser)
-JIT_TEST(TopologicalIndex)
-JIT_TEST(TopologicalMove)
-JIT_TEST(SubgraphUtils)
-JIT_TEST(AliasAnalysis)
-JIT_TEST(WriteTracking)
-JIT_TEST(Wildcards)
-JIT_TEST(MemoryDAG)
-JIT_TEST(IRParser)
-JIT_TEST(ConstantPooling)
-
-JIT_TEST(NetDefConverter)
-
-JIT_TEST(THNNConv)
-JIT_TEST(ATenNativeBatchNorm)
-JIT_TEST(NoneSchemaMatch)
-JIT_TEST(ClassParser)
-
-#define JIT_TEST_CUDA(name) \
- TEST(JitTest, name##_CUDA) { \
- test##name(); \
- }
-
-JIT_TEST_CUDA(ArgumentSpec)
-JIT_TEST_CUDA(Fusion)
-JIT_TEST_CUDA(GraphExecutor)
-JIT_TEST_CUDA(Interp)
+++ /dev/null
-#include <test/cpp/jit/test_alias_analysis.h>
-#include <test/cpp/jit/test_constant_pooling.h>
-#include <test/cpp/jit/test_class_parser.h>
-#include <test/cpp/jit/test_irparser.h>
-#include <test/cpp/jit/test_misc.h>
-#include <test/cpp/jit/test_netdef_converter.h>
-
-#include <sstream>
-#include <string>
-
-using namespace torch::jit::script;
-namespace torch {
-namespace jit {
-void runJITCPPTests() {
- testNoneSchemaMatch();
- testAutogradProfiler();
- testADFormulas();
- testArgumentSpec();
- testAttributes();
- testBlocks();
- testCodeTemplate();
- testControlFlow();
- testCreateAutodiffSubgraphs();
- testCustomOperators();
- testDifferentiate();
- testDifferentiateWithRequiresGrad();
- testDynamicDAG();
- testEvalModeForLoadedModule();
- testFromQualString();
- testFusion();
- testGraphExecutor();
- testInternedStrings();
- testInterp();
- testIValue();
- testProto();
- testSchemaParser();
- testTopologicalIndex();
- testTopologicalMove();
- testSubgraphUtils();
- testTHNNConv();
- testATenNativeBatchNorm();
- testRegisterFusionCachesKernel();
- testAliasAnalysis();
- testWriteTracking();
- testWildcards();
- testMemoryDAG();
- testNetDefConverter();
- testIRParser();
- testConstantPooling();
- testClassParser();
-}
-
-} // namespace jit
-} // namespace torch
--- /dev/null
+#if defined(USE_GTEST)
+#include <gtest/gtest.h>
+#endif
+
+// To add a new test file:
+// 1. Add a test_foo.h file in this directory
+// 2. include test_base.h
+// 3. Write your tests as pure functions starting with "test", like "testFoo"
+// 4. Include test_foo.h here and add it to the appropriate macro listing
+#include <test/cpp/jit/test_alias_analysis.h>
+#include <test/cpp/jit/test_class_parser.h>
+#include <test/cpp/jit/test_constant_pooling.h>
+#include <test/cpp/jit/test_irparser.h>
+#include <test/cpp/jit/test_misc.h>
+#include <test/cpp/jit/test_netdef_converter.h>
+
+using namespace torch::jit::script;
+namespace torch {
+namespace jit {
+#define TH_FORALL_TESTS(_) \
+ _(ADFormulas) \
+ _(Attributes) \
+ _(Blocks) \
+ _(CodeTemplate) \
+ _(ControlFlow) \
+ _(CreateAutodiffSubgraphs) \
+ _(CustomOperators) \
+ _(Differentiate) \
+ _(DifferentiateWithRequiresGrad) \
+ _(DynamicDAG) \
+ _(FromQualString) \
+ _(InternedStrings) \
+ _(IValue) \
+ _(Proto) \
+ _(RegisterFusionCachesKernel) \
+ _(SchemaParser) \
+ _(TopologicalIndex) \
+ _(TopologicalMove) \
+ _(SubgraphUtils) \
+ _(AliasAnalysis) \
+ _(WriteTracking) \
+ _(Wildcards) \
+ _(MemoryDAG) \
+ _(IRParser) \
+ _(ConstantPooling) \
+ _(NetDefConverter) \
+ _(THNNConv) \
+ _(ATenNativeBatchNorm) \
+ _(NoneSchemaMatch) \
+ _(ClassParser)
+
+#define TH_FORALL_TESTS_CUDA(_) \
+ _(ArgumentSpec) \
+ _(Fusion) \
+ _(GraphExecutor) \
+ _(Interp)
+
+#if defined(USE_GTEST)
+
+#define JIT_GTEST(name) \
+ TEST(JitTest, name) { \
+ test##name(); \
+ }
+TH_FORALL_TESTS(JIT_GTEST)
+#undef JIT_TEST
+
+#define JIT_GTEST_CUDA(name) \
+ TEST(JitTest, name##_CUDA) { \
+ test##name(); \
+ }
+TH_FORALL_TESTS_CUDA(JIT_GTEST_CUDA)
+#undef JIT_TEST_CUDA
+#endif
+
+#define JIT_TEST(name) test##name();
+void runJITCPPTests() {
+ TH_FORALL_TESTS(JIT_TEST)
+ TH_FORALL_TESTS_CUDA(JIT_TEST)
+
+ // This test is special since it requires prior setup in python.
+ // So it's included here but not in the pure cpp gtest suite
+ testEvalModeForLoadedModule();
+}
+#undef JIT_TEST
+
+} // namespace jit
+} // namespace torch
// This file defines assertion macros that work in both gtest and non-gtest
// builds, and has some common includes.
-//
-// To add a new test file:
-// 1. Add a test_foo.h file in this directory
-// 2. include test_base.h
-// 3. Write your tests as pure functions
-// 4. Include test_foo.h in gtest.cpp and no-gtest.cpp and register the tests
-// there.
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/operator.h"
${TORCH_SRC_DIR}/csrc/jit/fuser/executor.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/codegen.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp
- ${TORCH_ROOT}/test/cpp/jit/no-gtest.cpp
+ ${TORCH_ROOT}/test/cpp/jit/test.cpp
)
if (WIN32)