self.print_level = 0
# Set tensor index list to print information
- # TODO:
- # Print tensors in list only
- # Print all tensors if argument used and not specified index number
+ self.print_all_tensor = True
if (args.tensor != None):
- if (len(args.tensor) == 0):
- self.print_all_tensor = True
- else:
+ if (len(args.tensor) != 0):
self.print_all_tensor = False
self.print_tensor_index = []
-
for tensor_index in args.tensor:
self.print_tensor_index.append(int(tensor_index))
# Set operator index list to print information
- # TODO:
- # Print operators in list only
- # Print all operators if argument used and not specified index number
+ self.print_all_operator = True
if (args.operator != None):
- if (len(args.operator) == 0):
- self.print_all_oeprator = True
- else:
- self.print_all_oeprator = False
+ if (len(args.operator) != 0):
+ self.print_all_operator = False
self.print_operator_index = []
-
for operator_index in args.operator:
self.print_operator_index.append(int(operator_index))
def PrintModel(self, model_name, op_parser):
- ModelPrinter(self.print_level, op_parser, model_name).PrintAll()
+ printer = ModelPrinter(self.print_level, op_parser, model_name)
+
+ if self.print_all_tensor == False:
+ printer.SetPrintSpecificTensors(self.print_tensor_index)
+
+ if self.print_all_operator == False:
+ printer.SetPrintSpecificOperators(self.print_operator_index)
+
+ printer.PrintInfo()
def main(self):
# Generate Model: top structure of tflite model file
# limitations under the License.
from operator_printer import OperatorPrinter
+from tensor_printer import TensorPrinter
class ModelPrinter(object):
self.verbose = verbose
self.op_parser = op_parser
self.model_name = model_name
+ self.print_all_tensor = True
+ self.print_tensor_index_list = None
+ self.print_all_operator = True
+ self.print_operator_index_list = None
+
+ def SetPrintSpecificTensors(self, tensor_indices):
+ if len(tensor_indices) != 0:
+ self.print_all_tensor = False
+ self.print_tensor_index_list = tensor_indices
+
+ def SetPrintSpecificOperators(self, operator_indices):
+ if len(operator_indices) != 0:
+ self.print_all_operator = False
+ self.print_operator_index_list = operator_indices
+
+ def PrintInfo(self):
+ if self.print_all_tensor == True and self.print_all_operator == True:
+ self.PrintModelInfo()
+ self.PrintAllOperatorsInList()
+ self.PrintAllTypesInfo()
+
+ if self.print_all_tensor == False:
+ print('')
+ self.PrintSpecificTensors()
- def PrintAll(self):
- self.PrintModelInfo()
- self.PrintAllOperatorsInList()
- self.PrintAllTypesInfo()
+ if self.print_all_operator == False:
+ print('')
+ self.PrintSpecificOperators()
def PrintModelInfo(self):
print("[" + self.model_name + "]\n")
print(summary_str)
print('')
+
+ def PrintSpecificTensors(self):
+ for tensor in self.op_parser.GetAllTensors():
+ if tensor.tensor_idx in self.print_tensor_index_list:
+ printer = TensorPrinter(self.verbose, tensor)
+ printer.PrintInfo()
+ print('')
+ print('')
+
+ def PrintSpecificOperators(self):
+ for operator in self.op_parser.operators_in_list:
+ if operator.operator_idx in self.print_operator_index_list:
+ printer = OperatorPrinter(self.verbose, operator)
+ printer.PrintInfo(self.op_parser.perf_predictor)
+ print('')
+
+ print('')
return_list.append(Tensor(tensor_idx, tf_tensor, tf_buffer))
return return_list
+ def GetAllTensors(self):
+ return_list = list()
+ for tensor_idx in range(self.tf_subgraph.TensorsLength()):
+ if (tensor_idx < 0):
+ return_list.append(Tensor(tensor_idx, 0, 0))
+ continue
+ tf_tensor = self.tf_subgraph.Tensors(tensor_idx)
+ buffer_idx = tf_tensor.Buffer()
+ tf_buffer = self.tf_model.Buffers(buffer_idx)
+ return_list.append(Tensor(tensor_idx, tf_tensor, tf_buffer))
+ return return_list
+
def AppendOperator(self, operator):
self.operators_in_list.append(operator)
def PrintInfo(self, perf_predictor=None):
if (self.verbose < 1):
- pass
+ return
op_str = "Operator {0}: {1}".format(self.operator.operator_idx,
self.operator.opcode_str)
op_str = op_str + "(instrs: {0}, cycls: {1})".format(instrs, cycles)
print(op_str)
-
print("\tFused Activation: " + self.operator.fused_activation)
+ self.PrintTensors()
+ def PrintTensors(self):
print("\tInput Tensors" + GetStrTensorIndex(self.operator.inputs))
for tensor in self.operator.inputs:
TensorPrinter(self.verbose, tensor).PrintInfo("\t\t")