Sort VM stats by time (#4601)
authorZhi <5145158+zhiics@users.noreply.github.com>
Tue, 31 Dec 2019 19:16:12 +0000 (11:16 -0800)
committerYao Wang <kevinthesunwy@gmail.com>
Tue, 31 Dec 2019 19:16:12 +0000 (11:16 -0800)
python/tvm/relay/backend/profiler_vm.py
src/runtime/vm/profiler/vm.cc
tests/python/unittest/test_runtime_vm_profiler.py

index 5ee2d66..fa0326e 100644 (file)
@@ -38,8 +38,21 @@ class VirtualMachineProfiler(vm.VirtualMachine):
         self._set_input = self.mod["set_input"]
         self._reset = self.mod["reset"]
 
-    def get_stat(self):
-        return self._get_stat()
+    def get_stat(self, sort_by_time=True):
+        """Get the statistics of executed ops.
+
+        Parameters
+        ----------
+        sort_by_time: Optional[Boolean]
+           Set to indicate the returned results are sorted by execution time in
+           the descending order. It is printed in the random order if this
+           field is not set.
+
+        Returns
+        -------
+            The execution statistics in string.
+        """
+        return self._get_stat(sort_by_time)
 
     def reset(self):
         self._reset()
index b004f67..3b7b7aa 100644 (file)
@@ -31,6 +31,7 @@
 #include <memory>
 #include <numeric>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "vm.h"
@@ -43,16 +44,32 @@ PackedFunc VirtualMachineDebug::GetFunction(
     const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
   if (name == "get_stat") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size(), 1U);
+      std::vector<std::pair<Index, double>> op_acc_time;
+      for (auto kv : op_durations_) {
+        auto val = std::make_pair(
+            kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0));
+        op_acc_time.push_back(val);
+      }
+      bool sort_by_time = args[0];
+      if (sort_by_time) {
+        auto comp = [](const std::pair<Index, double>& lhs,
+                       const std::pair<Index, double>& rhs) {
+          return lhs.second > rhs.second;
+        };
+        std::sort(op_acc_time.begin(), op_acc_time.end(), comp);
+      }
       double total_duration = 0.0;
+      int64_t total_packed_funcs = 0;
       std::ostringstream os;
       os << std::setw(30) << std::left << "#OpName"
          << "\t" << std::setw(10) << std::left << "#InvokeCount"
          << "\t"
          << "#Duration(us): Sum/Mean/Min/Max" << std::endl;
 
-      for (auto kv : op_durations_) {
+      for (auto kv : op_acc_time) {
         auto vals = op_durations_[kv.first];
-        auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);;
+        auto sum = kv.second;
         auto mean = sum / static_cast<double>(vals.size());
         auto min_value = *std::min_element(vals.begin(), vals.end());
         auto max_value = *std::max_element(vals.begin(), vals.end());
@@ -62,8 +79,10 @@ PackedFunc VirtualMachineDebug::GetFunction(
            <<  sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;
 
         total_duration += sum;
+        total_packed_funcs += op_invokes_[kv.first];
       }
-      os << "Total Duration " << total_duration << " us" << std::endl;
+      os << "\nTotal Duration: " << total_duration << " us.\t"
+         << "Total Packed Functions: " << total_packed_funcs << std::endl;
       *rv = os.str();
     });
   } else if (name == "reset") {
index 6cfe6e8..b7bbe2f 100644 (file)
@@ -35,6 +35,7 @@ def test_basic():
     data = np.random.rand(1, 3, 224, 224).astype('float32')
     res = vm.invoke("main", [data])
     print("\n{}".format(vm.get_stat()))
+    print("\n{}".format(vm.get_stat(False)))
 
 if __name__ == "__main__":
     test_basic()