--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
+
+#include <sstream>
+
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace profiling {
+namespace {
+
+using Detail = tensorflow::StatsCalculator::Detail;
+
+struct OperatorDetails {
+ std::string name;
+ std::vector<std::string> inputs;
+ std::vector<std::string> outputs;
+};
+
+std::string GetTensorName(const tflite::Interpreter& interpreter,
+ int tensor_index) {
+ const auto tensor = interpreter.tensor(tensor_index);
+ if (tensor == nullptr || tensor->name == nullptr) {
+ return "Unknown";
+ }
+ return tensor->name;
+}
+std::vector<std::string> GetTensorNames(const tflite::Interpreter& interpreter,
+ const TfLiteIntArray* tensor_indices) {
+ std::vector<std::string> tensors;
+ tensors.reserve(tensor_indices->size);
+ for (int i = 0; i < tensor_indices->size; i++) {
+ tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i]));
+ }
+ return tensors;
+}
+
+std::string ToString(const std::vector<std::string>& str_vector) {
+ std::stringstream stream;
+ stream << "[";
+ bool first = true;
+ for (const auto& s : str_vector) {
+ if (!first) {
+ stream << ", ";
+ } else {
+ first = false;
+ }
+ stream << s;
+ }
+ stream << "]";
+ return stream.str();
+}
+
+OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
+ int node_index) {
+ auto node_reg = interpreter.node_and_registration(node_index);
+ auto inputs = node_reg->first.inputs;
+ auto outputs = node_reg->first.outputs;
+ int code = node_reg->second.builtin_code;
+ const char* op_name = nullptr;
+ if (code == tflite::BuiltinOperator_CUSTOM) {
+ const char* custom_name = node_reg->second.custom_name;
+ op_name = custom_name ? custom_name : "UnknownCustomOp";
+ } else {
+ op_name = tflite::EnumNamesBuiltinOperator()[code];
+ }
+ OperatorDetails details;
+ details.name = op_name;
+ details.inputs = GetTensorNames(interpreter, inputs);
+ details.outputs = GetTensorNames(interpreter, outputs);
+ return details;
+}
+
+} // namespace
+
+ProfileSummarizer::ProfileSummarizer()
+ : stats_calculator_(new ::tensorflow::StatsCalculator(
+ tensorflow::StatSummarizerOptions())) {}
+
+void ProfileSummarizer::ProcessProfiles(
+ const std::vector<const ProfileEvent*>& profile_stats,
+ const tflite::Interpreter& interpreter) {
+ std::vector<const ProfileEvent*> events;
+ std::copy_if(profile_stats.begin(), profile_stats.end(),
+ std::back_inserter(events), [](const ProfileEvent* e) {
+ return e->event_type ==
+ ProfileEvent::EventType::OPERATOR_INVOKE_EVENT &&
+ e->end_timestamp_us >= e->begin_timestamp_us;
+ });
+ // Sort with begin_time.
+ std::sort(events.begin(), events.end(),
+ [](const ProfileEvent* const& a, const ProfileEvent* const& b) {
+ return a->begin_timestamp_us < b->begin_timestamp_us;
+ });
+ if (events.empty()) {
+ return;
+ }
+
+ int64_t base_start_us = events[0]->begin_timestamp_us;
+ int node_num = 0;
+ int64_t curr_total_us = 0;
+ std::map<std::string, Detail> details;
+ int prev_op_idx = -1;
+ int seq_no = 1;
+ for (auto event : events) {
+ auto op_details = GetOperatorDetails(interpreter, event->event_metadata);
+ bool is_continued = (prev_op_idx == event->event_metadata);
+ seq_no = is_continued ? seq_no + 1 : 1;
+ auto node_name = ToString(op_details.outputs) + "#" + std::to_string(seq_no);
+ auto result = details.emplace(node_name, Detail());
+ Detail* detail = &(result.first->second);
+ detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us);
+ int64_t node_exec_time =
+ event->end_timestamp_us - event->begin_timestamp_us;
+ detail->rel_end_us.UpdateStat(node_exec_time);
+ curr_total_us += node_exec_time;
+ ++node_num;
+
+ if (result.second) {
+ detail->name = node_name;
+ detail->type = op_details.name;
+ detail->run_order = node_num;
+ detail->times_called = 0;
+ }
+ if (!is_continued) {
+ ++detail->times_called;
+ }
+ prev_op_idx = event->event_metadata;
+ }
+ stats_calculator_->UpdateDetails(details);
+ stats_calculator_->UpdateRunTotalUs(curr_total_us);
+}
+} // namespace profiling
+} // namespace tflite
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
+#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/core/util/stats_calculator.h"
+
+namespace tflite {
+namespace profiling {
+
+// Creates a summary of operator invocations in the interpreter.
+class ProfileSummarizer {
+ public:
+ ProfileSummarizer();
+ virtual ~ProfileSummarizer() {}
+
+ // Process profile events to update statistics for operator invocations.
+ void ProcessProfiles(const std::vector<const ProfileEvent*>& profile_stats,
+ const tflite::Interpreter& interpreter);
+
+ // Returns a string detailing the accumulated runtime stats in a tab-separated
+ // format which can be pasted into a spreadsheet for further analysis.
+ std::string GetOutputString() const {
+ return stats_calculator_->GetOutputString();
+ }
+
+ std::string GetShortSummary() const {
+ return stats_calculator_->GetShortSummary();
+ }
+
+ private:
+ std::unique_ptr<tensorflow::StatsCalculator> stats_calculator_;
+};
+
+} // namespace profiling
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_