Operator-level performance microbenchmarks (#18740)
authorMingzhe Li <mingzhe0908@fb.com>
Wed, 3 Apr 2019 00:03:23 +0000 (17:03 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 00:06:19 +0000 (17:06 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18740

Test utilities for writing Caffe2/PyTorch performance microbenchmarks. Brief description of the file structure:

* benchmark_core.py : core utilities for running microbenchmark tests
* benchmark_caffe2.py : Caffe2-specific benchmark utilities
* benchmark_pytorch.py: PyTorch-specific benchmark utilities
* benchmark_runner.py : main entry point. Currently it can run the microbenchmark tests in stand-alone mode. The next step is to integrate it with AI-PEP.

The utilities are located at https://github.com/pytorch/pytorch/tree/master/test so that they have access to both the Caffe2 and PyTorch Python frontends.

Includes two operator microbenchmarks that support both Caffe2 and PyTorch:
* MatMul
* Add

Reference: PyTorch benchmarks: https://github.com/pytorch/benchmark/tree/master/timing/python. In this work, we start with two example binary operators, MatMul and Add, but we should eventually also cover unary operators as in the PyTorch benchmark repo.
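
For reference, a typical stand-alone invocation looks like the following. The script path is an assumption based on the file locations in this diff; the flags are the ones defined in benchmark_runner.py:

    # List all registered test cases without running them
    python benchmarks/operator_benchmark/benchmark_runner.py --list_tests

    # Run only the MatMul test cases in short mode with an explicit iteration count
    python benchmarks/operator_benchmark/benchmark_runner.py --operator matmul --run_mode short --iterations 100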

Reviewed By: zheng-xq

Differential Revision: D13887111

fbshipit-source-id: b7a56b95448c9ec3e674b0de0ffb96af4439bfce

benchmarks/operator_benchmark/__init__.py [new file with mode: 0644]
benchmarks/operator_benchmark/benchmark_caffe2.py [new file with mode: 0644]
benchmarks/operator_benchmark/benchmark_core.py [new file with mode: 0644]
benchmarks/operator_benchmark/benchmark_pytorch.py [new file with mode: 0644]
benchmarks/operator_benchmark/benchmark_runner.py [new file with mode: 0644]
benchmarks/operator_benchmark/benchmark_utils.py [new file with mode: 0644]
benchmarks/operator_benchmark/ops/__init__.py [new file with mode: 0644]
benchmarks/operator_benchmark/ops/add.py [new file with mode: 0644]
benchmarks/operator_benchmark/ops/matmul.py [new file with mode: 0644]
caffe2/python/pybind_state.cc
caffe2/python/workspace.py

diff --git a/benchmarks/operator_benchmark/__init__.py b/benchmarks/operator_benchmark/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py
new file mode 100644 (file)
index 0000000..cf341c4
--- /dev/null
@@ -0,0 +1,47 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core, workspace
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+"""Caffe2 performance microbenchmarks.
+
+This module contains Caffe2-specific functionalities for performance
+microbenchmarks.
+"""
+
+
+def Caffe2OperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
+    """Benchmark Tester function for Caffe2 framework.
+    test_case is expected to be a Caffe2OperatorTestCase object. If not, the
+    function will return False.
+    It returns a function that contains the code to benchmarked
+    (operator execution).
+    """
+    idx = 0
+    input_blobs = []
+    for input in input_shapes:
+        blob_name = 'input_' + test_name + str(input_shapes) + str(op_args) + str(idx)
+        input_blobs.append(blob_name)
+        # TODO: figure out the data type from the operator schema,
+        # or accept custom data type for more comprehensive coverage.
+        # Also, consider a more complex range/distribution of numerical inputs.
+        workspace.FeedBlob(blob_name, benchmark_utils.numpy_random_fp32(*input))
+        idx += 1
+
+    # TODO: consider reuse logic in Caffe2's Functional utility to get
+    # these benefits
+    # - Read operator schema to figure out if inplace enforcement is needed
+    # for the operator and name the output blob appropriately.
+    # - Also figure out the number of outputs from operator schema.
+    op = core.CreateOperator(
+        op_type, input_blobs, ['out'], **op_args
+    )
+
+    def benchmark_func(num_runs):
+        if not workspace.RunOperatorMultiple(op, num_runs):
+            raise RuntimeError('Unable to run operator test case: %s' % test_name)
+
+    benchmark_core.add_benchmark_tester("Caffe2", test_name, input_shapes, op_args, run_mode, benchmark_func)
diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py
new file mode 100644 (file)
index 0000000..2693f84
--- /dev/null
@@ -0,0 +1,187 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import functools
+import numpy as np
+import timeit
+import json
+
+from caffe2.benchmarks.operator_benchmark import benchmark_utils
+
+"""Performance microbenchmarks.
+
+This module contains core functionalities for performance microbenchmark tests.
+"""
+
+
+# List of run modes we support.
+# Each benchmark test case is associated with a run mode.
+# A test case is executed when the value of its run mode is less than or
+# equal to the value of the benchmark binary's run mode. For example, a
+# short-mode test case is executed when the binary runs in either short or
+# long mode, while a long-mode test case is only executed when the binary
+# runs in long mode.
+RUN_MODES = {'short': 0, 'long': 1}
+BENCHMARK_TESTER = [{} for _ in range(len(RUN_MODES))]
+BENCHMARK_TEST_GROUP = {}
+
+
+def add_benchmark_tester(framework, op_name, input_shapes, op_args, run_mode, func):
+    func_name = "__".join([framework, op_name, benchmark_utils.shape_to_string(input_shapes),
+                           str(op_args), run_mode])
+    run_mode = RUN_MODES[run_mode]
+    for mode in RUN_MODES.values():
+        # short mode runs with some of the input shapes for an op
+        # long mode runs with all the input shapes for an op
+        if (mode < run_mode):
+            continue
+        BENCHMARK_TESTER[mode][func_name] = func
+
+
+def benchmark_test_group(func):
+    """Decorator to register a benchmark test group.
+    A benchmark test group is a function that returns a list of benchmark test
+    case objects to be run.
+    """
+    BENCHMARK_TEST_GROUP[__name__ + "." + func.__name__] = func
+    return func
+
+
+HEADER_LINE = """
+# {}
+# PyTorch/Caffe2 Operator Micro-benchmarks
+# {}
+# Run_mode : {}
+"""
+
+
+class BenchmarkRunner(object):
+    """BenchmarkRunner is responsible for benchmarking all the registered
+    benchmark test groups.
+
+    Attributes:
+        run_mode (str): Must be one of 'short', 'long'. In long mode, the
+    benchmark runner takes longer to run since it repeats each benchmark
+    test case more times to reduce measurement variance, and it also executes
+    the longer-running test cases that are marked as long mode.
+        operator (str): Only run benchmark test cases that contain
+    this filter string in the test case's id.
+    """
+    def __init__(self, args):
+        # Depending on the run mode, set the execution constraints based on the
+        # number of runs per measurement and the number of measurements.
+        # TODO: consider time-bound constraints as well.
+        self.args = args
+        self.iters = 100
+        self.has_explicit_iteration_count = False
+        self.multiplier = 2
+        self.min_time = 0.8
+        self.max_iters = 1e6
+        # Invoke every registered test group so that it registers its test cases.
+        for test_group_func in BENCHMARK_TEST_GROUP.values():
+            test_group_func()
+        if self.args.iterations:
+            self.has_explicit_iteration_count = True
+            self.iters = self.args.iterations
+
+    def _print_header(self, run_mode):
+        DASH_LINE = '-' * 40
+        print(HEADER_LINE.format(DASH_LINE, DASH_LINE, self.args.run_mode))
+        print("# List of Operators to run:")
+        if self.args.operator is None:
+            ops = set()
+            for full_test_id in BENCHMARK_TESTER[run_mode]:
+                # Ids encode framework__op__shapes__args__mode; only op is needed here.
+                op_name = full_test_id.split("__")[1]
+                if op_name not in ops:
+                    print("# {}".format(op_name))
+                    ops.add(op_name)
+        else:
+            print("# {}".format(self.args.operator))
+        print("\n")
+
+    def _print_perf_result(self, full_test_id, input_shapes, args, reported_run_time):
+        if self.args.ai_pep_format:
+            # Output for AI-PEP
+            print("Caffe2Observer " + json.dumps(
+                {
+                    "type": "NET",
+                    "metric": full_test_id,
+                    "unit": "ms",
+                    "value": str(reported_run_time),
+                }
+            ))
+        else:
+            print("# Input Shape: {}\n"
+                  "Execution Time (us) : {:.3f} \n"
+                  .format(input_shapes, reported_run_time))
+
+    def _predict_num_iter_needed(self, i):
+        return (i * self.multiplier)
+
+    def _report_iteration_result(self, iters, run_time):
+        return (iters > self.max_iters or
+                run_time > 5 * self.min_time)
+
+    def run(self):
+        run_mode = RUN_MODES[self.args.run_mode]
+        self._print_header(run_mode)
+
+        if self.args.list_tests:
+            return
+
+        for full_test_id, benchmark_func in BENCHMARK_TESTER[run_mode].items():
+            # The id encodes the framework, op name, input shapes, op args and
+            # run mode of the test case, separated by double underscores.
+            framework, op_name, input_shapes, args, _ = full_test_id.split("__")
+            # TODO: consider regex matching for test filtering.
+            # Currently, this is a sub-string matching.
+            if self.args.operator and (self.args.operator not in full_test_id):
+                continue
+            if self.args.framework and (self.args.framework not in full_test_id):
+                continue
+
+            # To reduce variance, fix the numpy random seed for each test
+            # case so that the randomly generated input tensors remain the
+            # same for each test case.
+            # The random seed is limited to 32 bits because of a numpy
+            # requirement.
+            np.random.seed(seed=hash(full_test_id) & ((1 << 32) - 1))
+
+            print("# Benchmarking {} {}".format(
+                framework,
+                op_name))
+            # Warmup
+            benchmark_func(self.args.warmup_iterations)
+
+            # Actual Execution
+            run_time = 0
+            iters = self.iters
+            while True:
+                # Use Python's timeit module to measure execution time.
+                # Each experiment consists of repeated execution of
+                # the benchmark_func a number of times (iters)
+                # because otherwise the duration is too short to get
+                # an accurate measure. The benchmark loop is pushed
+                # to C++ to minimize Python overhead.
+                # If the accumulated time is not yet significant, the
+                # iteration count is increased (see _predict_num_iter_needed)
+                # and the experiment is repeated; the total time is divided
+                # by the final iteration count when reporting.
+                run_time = run_time + min(timeit.repeat(functools.partial(benchmark_func, iters),
+                                          repeat=1, number=1))
+                # Analyze time after each run to decide if the result is stable
+                results_are_significant = self.has_explicit_iteration_count or \
+                    self._report_iteration_result(iters, run_time)
+
+                if results_are_significant:
+                    break
+
+                # Re-estimate the hopefully-sufficient
+                # iteration count, and run the benchmark again...
+                iters = self._predict_num_iter_needed(iters)
+
+            reported_run_time = (1e6 * run_time / iters)
+            self._print_perf_result(full_test_id, input_shapes, args, reported_run_time)
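
To make the measurement strategy above easier to follow, here is a simplified stand-alone sketch of the same loop (an illustration, not code from this diff). timeit is called with repeat=1 and number=1 because benchmark_func is expected to loop iters times internally (in C++ for the real benchmarks), and the iteration count is doubled until the accumulated time is considered significant:

    import functools
    import timeit

    def measure(benchmark_func, iters=100, multiplier=2, min_time=0.8, max_iters=1e6):
        # Simplified version of BenchmarkRunner's timing loop.
        run_time = 0.0
        while True:
            run_time += min(timeit.repeat(functools.partial(benchmark_func, iters),
                                          repeat=1, number=1))
            if iters > max_iters or run_time > 5 * min_time:
                break
            iters *= multiplier
        # Report microseconds per iteration, like _print_perf_result does.
        return 1e6 * run_time / iters

    # Example with a trivial Python-level stand-in for an operator benchmark.
    print(measure(lambda num_runs: sum(range(num_runs))))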
diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py
new file mode 100644 (file)
index 0000000..5f30542
--- /dev/null
@@ -0,0 +1,29 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+import torch
+
+"""PyTorch performance microbenchmarks.
+
+This module contains PyTorch-specific functionalities for performance
+microbenchmarks.
+"""
+
+
+def PyTorchOperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
+    """Benchmark Tester function for Pytorch framework.
+    test_case is expected to be a PyTorchOperatorTestCase object. If not, the
+    function will return False.
+    It returns a function that contains the code to benchmarked
+    (operator execution).
+    """
+    inputs = [torch.from_numpy(benchmark_utils.numpy_random_fp32(*input)) for input in input_shapes]
+
+    def benchmark_func(num_runs):
+        op_type(*(inputs + [num_runs]))
+
+    benchmark_core.add_benchmark_tester("PyTorch", test_name, input_shapes, op_args, run_mode, benchmark_func)
diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py
new file mode 100644 (file)
index 0000000..5e06a5a
--- /dev/null
@@ -0,0 +1,90 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import argparse
+
+from caffe2.python import workspace
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core
+
+import caffe2.benchmarks.operator_benchmark.benchmark_caffe2
+import caffe2.benchmarks.operator_benchmark.benchmark_pytorch
+
+import caffe2.benchmarks.operator_benchmark.ops.add
+import caffe2.benchmarks.operator_benchmark.ops.matmul
+
+"""Performance microbenchmarks's main binary.
+
+This is the main function for running performance microbenchmark tests.
+It also registers existing benchmark tests via Python module imports.
+"""
+
+
+if __name__ == "__main__":
+    print("Python version " + str(sys.version_info[0]))
+
+    parser = argparse.ArgumentParser(
+        description="Run microbenchmarks.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        '--run_mode',
+        help='Run mode. '
+        'short: run all operators with a few shapes; '
+        'long: run all operators with all shapes.',
+        choices=benchmark_core.RUN_MODES.keys(),
+        default='short')
+
+    # This option is used to filter the test cases to run.
+    # Currently, the matching is sub-string based, but we can consider
+    # supporting regex. For example, if the filter is 'matmul', it will
+    # match test cases such as:
+    # matmul_benchmark.Caffe2OperatorTestCase.matmul_512_128_512_transa_transb
+    # matmul_benchmark.PyTorchOperatorTestCase.matmul_100_200_150
+    # ...
+    parser.add_argument(
+        '--operator',
+        help='Only run the test cases that contain the provided operator'
+        ' as a substring of their names',
+        default=None)
+
+    parser.add_argument(
+        '--list_tests',
+        help='List all test cases without running them',
+        action='store_true')
+
+    parser.add_argument(
+        "--iterations",
+        help="Repeat each operator for the number of iterations",
+        type=int
+    )
+
+    parser.add_argument(
+        "--warmup_iterations",
+        help="Number of iterations to ignore before measuring performance",
+        default=10,
+        type=int
+    )
+
+    parser.add_argument(
+        "--ai_pep_format",
+        help="Print result when running on AI-PEP",
+        default=False,
+        type=bool
+    )
+
+    parser.add_argument(
+        '--framework',
+        help='Run PyTorch or Caffe2 operators',
+        default=None)
+
+    args = parser.parse_args()
+
+    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
+    workspace.ClearGlobalNetObserver()
+
+    benchmark_core.BenchmarkRunner(args).run()
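
For the planned AI-PEP integration, the runner can also be driven programmatically rather than through argparse. A rough sketch under the same import paths used above (the Namespace fields simply mirror the CLI flags defined in this file):

    import argparse

    from caffe2.python import workspace
    from caffe2.benchmarks.operator_benchmark import benchmark_core
    # Importing the op modules registers their benchmark test groups.
    import caffe2.benchmarks.operator_benchmark.ops.add  # noqa
    import caffe2.benchmarks.operator_benchmark.ops.matmul  # noqa

    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    workspace.ClearGlobalNetObserver()

    args = argparse.Namespace(run_mode='short', operator='matmul', framework=None,
                              list_tests=False, iterations=None,
                              warmup_iterations=10, ai_pep_format=False)
    benchmark_core.BenchmarkRunner(args).run()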
diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py
new file mode 100644 (file)
index 0000000..e0d5231
--- /dev/null
@@ -0,0 +1,35 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import itertools
+import random
+
+
+"""Performance microbenchmarks's utils.
+
+This module contains utilities for writing microbenchmark tests.
+"""
+
+
+def shape_to_string(shape):
+    return ', '.join([str(x) for x in shape])
+
+
+def numpy_random_fp32(*shape):
+    """Return a random numpy tensor of float32 type.
+    """
+    # TODO: consider more complex/custom dynamic ranges for
+    # comprehensive test coverage.
+    return np.random.rand(*shape).astype(np.float32)
+
+
+def cross_product(*inputs):
+    return list(itertools.product(*inputs))
+
+
+def get_n_rand_nums(min_val, max_val, n):
+    random.seed((1 << 32) - 1)
+    return random.sample(range(min_val, max_val), n)
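
A quick illustration of these helpers (the numeric values in the comments are only examples; get_n_rand_nums is deterministic because it seeds Python's random module on every call):

    from caffe2.benchmarks.operator_benchmark import benchmark_utils

    ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)  # e.g. [87, 42]
    shapes = benchmark_utils.cross_product(ms, ms)  # [(87, 87), (87, 42), (42, 87), (42, 42)]
    print(benchmark_utils.shape_to_string([32, 64, 128]))   # "32, 64, 128"
    print(benchmark_utils.numpy_random_fp32(2, 3).dtype)    # float32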
diff --git a/benchmarks/operator_benchmark/ops/__init__.py b/benchmarks/operator_benchmark/ops/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/benchmarks/operator_benchmark/ops/add.py b/benchmarks/operator_benchmark/ops/add.py
new file mode 100644 (file)
index 0000000..23d208c
--- /dev/null
@@ -0,0 +1,68 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
+from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
+
+import torch
+
+
+"""Microbenchmarks for element-wise Add operator. Supports both Caffe2/PyTorch."""
+
+# Input shapes that we test and the run mode for each shape.
+# Sum up two tensors with the same shape
+
+
+def generate_inputs():
+    ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=1)
+    ns = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+    ks = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+    mode = ['long']
+
+    test_cases = benchmark_utils.cross_product([ms], mode)
+
+    two_dims = benchmark_utils.cross_product(ms, ns)
+    two_dims = benchmark_utils.cross_product(two_dims, mode)
+    test_cases.extend(two_dims)
+
+    three_dims = benchmark_utils.cross_product(ms, ns, ks)
+    three_dims = benchmark_utils.cross_product(three_dims, mode)
+    test_cases.extend(three_dims)
+
+    # Representative inputs
+    test_cases.extend([([128], 'short'),
+                       ([64, 128], 'short'),
+                       ([32, 64, 128], 'short')])
+    return test_cases
+
+
+@torch.jit.script
+def torch_add(a, b, iterations):
+    # type: (Tensor, Tensor, int)
+    result = torch.jit.annotate(torch.Tensor, None)
+    for _ in range(iterations):
+        result = torch.add(a, b)
+    return result
+
+
+@benchmark_core.benchmark_test_group
+def add_test_cases():
+    test_cases = generate_inputs()
+    for test_case in test_cases:
+        X, run_mode = test_case
+        Caffe2OperatorTestCase(
+            test_name='add',
+            op_type='Add',
+            input_shapes=[X, X],
+            op_args={},
+            run_mode=run_mode)
+        PyTorchOperatorTestCase(
+            test_name='add',
+            op_type=torch_add,
+            input_shapes=[X, X],
+            op_args={},
+            run_mode=run_mode)
diff --git a/benchmarks/operator_benchmark/ops/matmul.py b/benchmarks/operator_benchmark/ops/matmul.py
new file mode 100644 (file)
index 0000000..214e2a5
--- /dev/null
@@ -0,0 +1,63 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
+from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
+
+import torch
+
+
+"""Microbenchmarks for MatMul operator. Supports both Caffe2/PyTorch."""
+
+
+def generate_inputs():
+    # Random inputs
+    Ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+    Ns = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+    Ks = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+    transpose_a = [False, True]
+    transpose_b = [True, False]
+    mode = ['long']
+    test_cases = benchmark_utils.cross_product(Ms, Ns, Ks, transpose_a, transpose_b, mode)
+
+    # Representative inputs
+    test_cases.extend([(8, 16, 64, False, False, 'short'),
+                       (64, 64, 256, False, False, 'short'),
+                       (256, 256, 256, False, False, 'short')])
+    return test_cases
+
+
+@torch.jit.script
+def torch_matmul(a, b, iterations):
+    # type: (Tensor, Tensor, int)
+    result = torch.jit.annotate(torch.Tensor, None)
+    for _ in range(iterations):
+        result = torch.matmul(a, b)
+    return result
+
+
+@benchmark_core.benchmark_test_group
+def matmul_test_cases():
+    test_cases = generate_inputs()
+    for test_case in test_cases:
+        M, N, K, trans_a, trans_b, run_mode = test_case
+        input_shapes = [(N, M) if trans_a else (M, N), (K, N) if trans_b else (N, K)]
+        Caffe2OperatorTestCase(
+            test_name='matmul',
+            op_type='MatMul',
+            input_shapes=input_shapes,
+            op_args={'trans_a': trans_a, 'trans_b': trans_b},
+            run_mode=run_mode)
+        if not trans_a and not trans_b:
+            # PyTorch's matmul does not take transpose flags, so we only
+            # have a test case when there are no transpose flags.
+            PyTorchOperatorTestCase(
+                test_name='matmul',
+                op_type=torch_matmul,
+                input_shapes=input_shapes,
+                op_args={},
+                run_mode=run_mode)
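
As the summary notes, unary operators are a natural next step. A hypothetical ops/relu.py following the same registration pattern might look like the sketch below; it is not part of this diff and assumes Caffe2's Relu operator and a scripted torch.relu wrapper:

    from caffe2.benchmarks.operator_benchmark import benchmark_core
    from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
    from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase

    import torch


    @torch.jit.script
    def torch_relu(a, iterations):
        # type: (Tensor, int)
        result = torch.jit.annotate(torch.Tensor, None)
        for _ in range(iterations):
            result = torch.relu(a)
        return result


    @benchmark_core.benchmark_test_group
    def relu_test_cases():
        # Hypothetical representative shapes; one input per test case.
        for shape, run_mode in [([128], 'short'), ([64, 128], 'short')]:
            Caffe2OperatorTestCase(
                test_name='relu',
                op_type='Relu',
                input_shapes=[shape],
                op_args={},
                run_mode=run_mode)
            PyTorchOperatorTestCase(
                test_name='relu',
                op_type=torch_relu,
                input_shapes=[shape],
                op_args={},
                run_mode=run_mode)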
index e4f3e6f..dc2b339 100644 (file)
@@ -1190,6 +1190,10 @@ void addGlobalMethods(py::module& m) {
         NetBase* net = gWorkspace->GetNet(net_name);
         net->DetachObserver(observer);
       });
+  m.def("clear_global_net_observer", []() {
+    py::gil_scoped_release g;
+    caffe2::ClearGlobalNetObservers();
+  });
   m.def("num_observers_on_net", [](const std::string& net_name) {
     CAFFE_ENFORCE(gWorkspace);
     CAFFE_ENFORCE(gWorkspace->GetNet(net_name), "Can't find net ", net_name);
@@ -1227,6 +1231,22 @@ void addGlobalMethods(py::module& m) {
     CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
     return true;
   });
+  // Run an operator multiple times.
+  // This is needed for microbenchmarking as we want the benchmark loop to be in
+  // C++ to minimize overhead.
+  m.def("run_operator_multiple", [](const py::bytes& op_def, int num_runs) {
+    CAFFE_ENFORCE(gWorkspace);
+    OperatorDef def;
+    CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def));
+    py::gil_scoped_release g;
+    std::unique_ptr<OperatorBase> op(CreateOperator(def, gWorkspace));
+    for (int i = 0; i < num_runs; i++) {
+      if (!op->Run()) {
+        return false;
+      }
+    }
+    return true;
+  });
   m.def(
       "get_operator_cost",
       [](const py::bytes& op_def, const std::vector<string>& input_blobs) {
index 342bdfc..18fcd9b 100644 (file)
@@ -185,6 +185,10 @@ def RunOperatorOnce(operator):
     return C.run_operator_once(StringifyProto(operator))
 
 
+def RunOperatorMultiple(operator, num_runs):
+    return C.run_operator_multiple(StringifyProto(operator), num_runs)
+
+
 def RunOperatorsOnce(operators):
     for op in operators:
         success = RunOperatorOnce(op)
@@ -193,6 +197,10 @@ def RunOperatorsOnce(operators):
     return True
 
 
+def ClearGlobalNetObserver():
+    return C.clear_global_net_observer()
+
+
 def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
     try:
         return func(*args, **kwargs)
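
For completeness, a minimal sketch of how the new RunOperatorMultiple binding is used directly from Python. This mirrors what benchmark_caffe2.py does; the blob names and shapes here are arbitrary:

    import numpy as np
    from caffe2.python import core, workspace

    workspace.FeedBlob('x', np.random.rand(64, 64).astype(np.float32))
    workspace.FeedBlob('y', np.random.rand(64, 64).astype(np.float32))
    op = core.CreateOperator('Add', ['x', 'y'], ['out'])

    # The loop over num_runs happens inside the C++ binding added above,
    # which keeps Python overhead out of the measured region.
    if not workspace.RunOperatorMultiple(op, 1000):
        raise RuntimeError('operator failed')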