From aad6f978987aa4ec507a06261860e0daea2c6b59 Mon Sep 17 00:00:00 2001 From: Zafar Takhirov Date: Wed, 17 Apr 2019 11:19:19 -0700 Subject: [PATCH] Decorator to make sure we can import `core` from caffe2 (#19273) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19273 Some of the CIs are failing if the protobuf is not installed. Protobuf is imported as part of the `caffe2.python.core`, and this adds a skip decorator to avoid running tests that depend on `caffe2.python.core` Reviewed By: jianyuh Differential Revision: D14936387 fbshipit-source-id: e508a1858727bbd52c951d3018e2328e14f126be --- test/common_utils.py | 20 ++++++++++++++++++++ test/test_quantized.py | 7 +++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 6fb0d00..ce23a2d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -140,6 +140,26 @@ def skipIfNoLapack(fn): return wrapper +def skipIfNotRegistered(op_name, message): + """Wraps the decorator to hide the import of the `core`. + + Args: + op_name: Check if this op is registered in `core._REGISTERED_OPERATORS`. + message: mesasge to fail with. + + Usage: + @skipIfNotRegistered('MyOp', 'MyOp is not linked!') + This will check if 'MyOp' is in the caffe2.python.core + """ + try: + from caffe2.python import core + skipper = unittest.skipIf(op_name not in core._REGISTERED_OPERATORS, + message) + except ImportError: + skipper = unittest.skip("Cannot import `caffe2.python.core`") + return skipper + + def slowTest(fn): @wraps(fn) def wrapper(*args, **kwargs): diff --git a/test/test_quantized.py b/test/test_quantized.py index 8b486b9..2023d36 100644 --- a/test/test_quantized.py +++ b/test/test_quantized.py @@ -4,16 +4,15 @@ import torch import torch.jit import numpy as np import unittest -# from caffe2.python import core -from common_utils import TestCase, run_tests +from common_utils import TestCase, run_tests, skipIfNotRegistered def canonical(graph): return str(torch._C._jit_pass_canonicalize(graph)) -@unittest.skip("Skipping due to the protobuf dependency in the CI's") -# @unittest.skipIf("Relu_ENGINE_DNNLOWP" not in core._REGISTERED_OPERATORS, "fbgemm-based Caffe2 ops are not linked") +@skipIfNotRegistered("Relu_ENGINE_DNNLOWP", + "fbgemm-based Caffe2 ops are not linked") class TestQuantized(TestCase): def test_relu(self): a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5) -- 2.7.4