[Profiler] Change FLOP/s to Total FLOPs (#62779)
author     Lucas Kabela <lucaskabela@fb.com>
           Mon, 16 Aug 2021 20:34:56 +0000 (13:34 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Mon, 16 Aug 2021 20:43:32 +0000 (13:43 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62779

Change the reported metric from floating point operations per second (FLOP/s) to total floating point operations (FLOPs). This requires removing the division by execution time from the Kineto-computed FLOP counts and updating the affected documentation.
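
For context, the old column value was a rate: the operator's Kineto FLOP count divided by its elapsed time. The new column drops that division and reports the count itself. A minimal before/after sketch, mirroring the removed `flops_rate` helper shown in the diff below (names are illustrative, not profiler internals):

```
US_IN_SECOND = 1000.0 * 1000.0

def old_column_value(flops, elapsed_us):
    # Before: FLOP/s, i.e. the FLOP count divided by elapsed time.
    return float(flops) / elapsed_us * US_IN_SECOND

def new_column_value(flops):
    # After: the raw total FLOPs, passed through unchanged.
    return flops
```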

Test Plan:
Running the following script:

```
import torch
from torch.profiler import profile
import torchvision.models as models

# Profile one no-grad forward pass of ResNet-18 with FLOPs estimation enabled.
model = models.resnet18().eval()
inputs = torch.randn(5, 3, 224, 224)
with torch.no_grad():
    with profile(record_shapes=True, with_flops=True) as prof:
        model(inputs)
print(prof.key_averages().table(sort_by="cpu_time_total"))
```

Before this diff, the script produces:

{F636640118}

After this diff, the expected total is about `(27.78 * 10^9 FLOP/s) * 0.652838 s = 18,135,839,640 FLOPs ≈ 18.136 GFLOPs`. Running the script again yields this answer:

{F636655686}
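
The same check as a quick script; the rate and elapsed time are read off the before/after tables, and only the arithmetic is verified here:

```
import math

rate_flop_per_s = 27.78e9   # GFLOPS value from the table before this diff
elapsed_s = 0.652838        # elapsed time of the profiled ops, in seconds

total_flops = rate_flop_per_s * elapsed_s
assert math.isclose(total_flops, 18.136e9, rel_tol=1e-3)
print('{:.3f} GFLOPs'.format(total_flops / 1e9))  # 18.136 GFLOPs
```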

------------------------------------

Reviewed By: gdankel

Differential Revision: D29972997

fbshipit-source-id: 0f8d9f264b7d9f8f6bb3f10ab7c2c9794291e28b

test/test_profiler.py
torch/autograd/profiler.py
torch/autograd/profiler_util.py
torch/profiler/profiler.py

diff --git a/test/test_profiler.py b/test/test_profiler.py
index 28d9671..25695a8 100644
@@ -487,8 +487,7 @@ class TestProfiler(TestCase):
         with _profile(record_shapes=True, with_flops=True, use_kineto=kineto_available()) as prof:
             model(inputs)
         profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)
-        self.assertIn("FLOPS", profiler_output)
-
+        self.assertIn("Total MFLOPs", profiler_output)
         if not (kineto_available() and torch.cuda.is_available()):
             return
 
@@ -501,7 +500,7 @@ class TestProfiler(TestCase):
             model(inputs)
         profiler_output = kineto_profiler.key_averages().table(
             sort_by="self_cuda_time_total", row_limit=-1)
-        self.assertIn("FLOPS", profiler_output)
+        self.assertIn("Total MFLOPs", profiler_output)
 
     def test_kineto_profiler_api(self):
         called_num = [0]
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index ab95fdb..c38ad99 100644
@@ -63,8 +63,8 @@ class profile(object):
             collection.
 
         with_flops (bool, optional): If with_flops is set, the profiler will estimate
-            the FLOPS (floating pointer operations per second) value using the operator's input shape
-            and total time. This allows one to estimate the hardware performance. Currently,
+            the FLOPs (floating point operations) value using the operator's input shape.
+            This allows one to estimate the hardware performance. Currently,
             this option only works for the matrix multiplication and 2D convolution operators.
 
         profile_memory (bool, optional): track tensor memory allocation/deallocation.
diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py
index 1385e44..6062c09 100644
@@ -398,7 +398,7 @@ class FunctionEvent(FormattedTimesMixin):
         self.device_type: DeviceType = device_type
         self.device_index: int = device_index
         self.is_legacy: bool = is_legacy
-        self.flops: Optional[float] = flops
+        self.flops: Optional[int] = flops
 
     def append_kernel(self, name, device, duration):
         assert self.device_type == DeviceType.CPU
@@ -541,7 +541,7 @@ class FunctionEventAvg(FormattedTimesMixin):
         self.cpu_parent: Optional[FunctionEvent] = None
         self.device_type: DeviceType = DeviceType.CPU
         self.is_legacy: bool = False
-        self.flops: float = 0.0
+        self.flops: int = 0
 
     def add(self, other):
         if self.key is None:
@@ -752,28 +752,18 @@ def _build_table(
 
     def auto_scale_flops(flops):
         flop_headers = [
-            'FLOPS',
-            'KFLOPS',
-            'MFLOPS',
-            'GFLOPS',
-            'TFLOPS',
-            'PFLOPS',
+            'FLOPs',
+            'KFLOPs',
+            'MFLOPs',
+            'GFLOPs',
+            'TFLOPs',
+            'PFLOPs',
         ]
         assert flops > 0
         log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
         assert log_flops >= 0 and log_flops < len(flop_headers)
         return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])
 
-    def flops_rate(evt):
-        US_IN_SECOND = 1000.0 * 1000.0
-        if evt.flops > 0:
-            if evt.cuda_time_total != 0:
-                return float(evt.flops) / evt.cuda_time_total * US_IN_SECOND
-            else:
-                return float(evt.flops) / evt.cpu_time_total * US_IN_SECOND
-        else:
-            return -1
-
     add_column(name_column_width)
     for _ in headers[1:]:
         add_column(DEFAULT_COLUMN_WIDTH)
@@ -790,12 +780,11 @@ def _build_table(
         # Auto-scaling of flops header
         raw_flops = []
         for evt in events:
-            rate = flops_rate(evt)
-            if rate > 0:
-                raw_flops.append(rate)
+            if evt.flops > 0:
+                raw_flops.append(evt.flops)
         if len(raw_flops) != 0:
             (flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
-            headers.append(flops_header)
+            headers.append('Total {}'.format(flops_header))
             add_column(flops_column_width)
         else:
             with_flops = False  # can't find any valid flops
@@ -895,11 +884,10 @@ def _build_table(
         if has_input_shapes:
             row_values.append(str(evt.input_shapes)[:shapes_column_width])
         if with_flops:
-            rate = flops_rate(evt)
-            if rate <= 0.0:
+            if evt.flops <= 0:
                 row_values.append("--")
             else:
-                row_values.append('{0:8.3f}'.format(rate * flops_scale))
+                row_values.append('{0:8.3f}'.format(evt.flops * flops_scale))
         if has_stack:
             src_field = ""
             if len(evt.stack) > 0:
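
For reference, the surviving `auto_scale_flops` helper picks a single unit for the whole column from the smallest positive FLOP count among the events, so the header becomes e.g. `Total GFLOPs`. A standalone sketch of that selection, mirroring the logic above rather than calling the profiler itself:

```
import math

def pick_unit(flops):
    headers = ['FLOPs', 'KFLOPs', 'MFLOPs', 'GFLOPs', 'TFLOPs', 'PFLOPs']
    log_flops = max(0, min(math.log10(flops) / 3, float(len(headers) - 1)))
    return pow(10, math.floor(log_flops) * -3.0), headers[int(log_flops)]

scale, header = pick_unit(18135839640)
print('Total {}: {:8.3f}'.format(header, 18135839640 * scale))  # Total GFLOPs:   18.136
```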
diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py
index 20bdaa3..34ffb03 100644
@@ -114,7 +114,7 @@ class profile(object):
         record_shapes (bool): save information about operator's input shapes.
         profile_memory (bool): track tensor memory allocation/deallocation.
         with_stack (bool): record source information (file and line number) for the ops.
-        with_flops (bool): use formula to estimate the FLOPS of specific operators
+        with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
             (matrix multiplication and 2D convolution).
         with_modules (bool): record module hierarchy (including function names)
             corresponding to the callstack of the op. e.g. If module A's forward call's