--- /dev/null
+import flatbuffers
+import numpy as np
+import tensorflow as tf
+
+import re
+import os.path as path
+import argparse
+from collections import defaultdict
+
+from opinfo import OperatorInfoList
+from opinfo import OperatorInfo
+from opinfo import OperatorType
+from opinfo import Tensor
+from opinfo import Shape
+from opinfo import PadType
+from opinfo import PoolType
+
# Maps each supported operator name to the ordered tuple of attribute fields
# that follow the input-shape field on an operator description line; the
# parser consumes the remaining whitespace-separated items in this order.
# 'axis' for CAPPED_RELU is not an error, it just denotes a numeric parameter.
OP_FORMATS = {
    'FULLY_CONNECTED': ('kernels',),
    'CONV_2D': ('kernels', 'padType', 'shapes'),
    'DEPTHWISE_CONV_2D': ('kernels', 'padType', 'shapes'),
    'POOL_2D': ('padType', 'poolType', 'shapes'),
    'CONCATENATION': ('axis',),
    'RESHAPE': ('shapes',),
    'RELU': (),
    'CAPPED_RELU': ('axis',),
    'BIAS_ADD': ('kernels',),
    'SOFTMAX': ('axis',)
}
+
+
class OpInfoSerializer:
    """Serializes OpInfoProvider objects into one FlatBuffers OperatorInfoList
    buffer via the generated `opinfo` modules.

    FlatBuffers are built bottom-up: every child (vector, Tensor, Shape table)
    must be finished before the table referencing it is started, which is why
    each make_* method constructs all of its children first.
    """

    def __init__(self):
        # 1 MiB initial buffer; the builder grows it on demand.
        self.builder = flatbuffers.Builder(1024*1024)
        # Element kind -> builder method that prepends one element to the
        # vector currently under construction.
        self.dataMap = {'int': self.builder.PrependInt32,
                        'float': self.builder.PrependFloat32,
                        'obj': self.builder.PrependUOffsetTRelative}

    @staticmethod
    def vector_start_method(mdl, fld):
        """Look up the generated '<Type>Start<Field>Vector' helper in module
        *mdl*, e.g. (Tensor, 'data') -> Tensor.TensorStartDataVector."""
        method_name = mdl.__name__.split('.')[-1] + 'Start' + fld.capitalize() + 'Vector'
        return getattr(mdl, method_name)

    def make_vector(self, mdl, fld, data_type, data, data_size):
        """Serialize *data* as vector field *fld* of table type *mdl* and
        return the finished vector's offset.

        *data_type* selects the prepend method ('int', 'float' or 'obj');
        'obj' means *data* holds offsets of already-built child tables.
        """
        if data_type == 'float':
            print('Serializing vector {}.{} of {} floats'.format(mdl.__name__, fld, data_size))

        add_vector_item = self.dataMap[data_type]
        self.vector_start_method(mdl, fld)(self.builder, data_size)
        # The builder writes back-to-front, so elements are prepended in
        # reverse to preserve their original order in the buffer.
        for item in reversed(data):
            add_vector_item(item)
        # NOTE(review): EndVector(count) is the pre-2.0 flatbuffers Python
        # API; flatbuffers >= 2.0 takes no argument — pin the dependency.
        return self.builder.EndVector(data_size)

    def make_tensor_vector(self, tensors, field_name):
        """Serialize numpy arrays as a vector of Tensor tables on
        OperatorInfo; returns None when *tensors* is None (optional field)."""
        if tensors is None:
            return None
        tensors = [self.make_tensor(tensor) for tensor in tensors]
        return self.make_vector(OperatorInfo, field_name, 'obj', tensors, len(tensors))

    def make_shape_vector(self, shapes):
        """Serialize a list of int-list shapes as OperatorInfo.shapes;
        returns None when *shapes* is None (optional field)."""
        if shapes is None:
            return None
        shapes = [self.make_shape(shape) for shape in shapes]
        return self.make_vector(OperatorInfo, 'shapes', 'obj', shapes, len(shapes))

    def make_op_info_vector(self, providers):
        """Serialize one OperatorInfo table per provider into the
        OperatorInfoList.infos vector."""
        op_infos = [self.make_operator_info(p) for p in providers]
        return self.make_vector(OperatorInfoList, 'infos', 'obj', op_infos, len(op_infos))

    def make_shape(self, shape):
        """Serialize one shape (iterable of ints) as a Shape table."""
        dims = self.make_vector(Shape, 'dims', 'int', shape, len(shape))

        Shape.ShapeStart(self.builder)
        Shape.ShapeAddDims(self.builder, dims)
        return Shape.ShapeEnd(self.builder)

    def make_tensor(self, tensor):
        """Serialize a numpy array as a Tensor table (shape + flat float data)."""
        shape = self.make_shape(tensor.shape)

        flat_tensor = tensor.flatten()
        data = self.make_vector(Tensor, 'data', 'float', flat_tensor, tensor.size)

        Tensor.TensorStart(self.builder)
        Tensor.TensorAddShape(self.builder, shape)
        Tensor.TensorAddData(self.builder, data)
        return Tensor.TensorEnd(self.builder)

    def add_optional(self, method, value):
        # Optional table fields are simply not written when the value is None.
        if value is not None:
            method(self.builder, value)

    def add_optional_enum(self, method, cls, value):
        # *value* is an enum member name (a string); translate it to the
        # generated enum's numeric constant before writing.
        if value is not None:
            method(self.builder, getattr(cls, value))

    def make_operator_info(self, provider):
        """Serialize one OpInfoProvider as an OperatorInfo table."""
        # All child vectors must be finished before OperatorInfoStart.
        inputs = self.make_tensor_vector(provider.get_inputs(), 'inputs')
        kernels = self.make_tensor_vector(provider.get_kernels(), 'kernels')
        results = self.make_tensor_vector(provider.get_results(), 'results')

        op = getattr(OperatorType.OperatorType, provider.get_op())
        shapes = self.make_shape_vector(provider.get_shapes())

        OperatorInfo.OperatorInfoStart(self.builder)
        OperatorInfo.OperatorInfoAddOp(self.builder, op)
        OperatorInfo.OperatorInfoAddInputs(self.builder, inputs)
        OperatorInfo.OperatorInfoAddResults(self.builder, results)

        self.add_optional(OperatorInfo.OperatorInfoAddKernels, kernels)
        self.add_optional_enum(OperatorInfo.OperatorInfoAddPadType, PadType.PadType, provider.get_pad_type())
        self.add_optional_enum(OperatorInfo.OperatorInfoAddPoolType, PoolType.PoolType, provider.get_pool_type())
        self.add_optional(OperatorInfo.OperatorInfoAddShapes, shapes)
        self.add_optional(OperatorInfo.OperatorInfoAddAxis, provider.get_axis())
        return OperatorInfo.OperatorInfoEnd(self.builder)

    def serialize(self, providers):
        """Build the complete OperatorInfoList and return the buffer bytes."""
        op_infos = self.make_op_info_vector(providers)

        OperatorInfoList.OperatorInfoListStart(self.builder)
        OperatorInfoList.OperatorInfoListAddInfos(self.builder, op_infos)
        info_list = OperatorInfoList.OperatorInfoListEnd(self.builder)

        self.builder.Finish(info_list)
        return self.builder.Output()

    def save_to_file(self, providers, filename):
        """Serialize *providers* and write the resulting buffer to *filename*."""
        result_data = self.serialize(providers)
        with open(filename, 'wb') as f:
            f.write(result_data)
+
+
class OpInfoProvider:
    """Describes a single operator test case and lazily generates its
    input/kernel tensors plus the expected results (computed via TensorFlow).

    The attribute fields (input_shapes, kernel_shapes, padType, poolType,
    shapes, axis) are populated externally by OpInfoParser.create_provider.
    """

    # Shared TF1 session, created as a side effect at class-definition time;
    # it is what makes the bare .eval() calls below work.
    sess = tf.InteractiveSession()

    def __init__(self, op_type):
        # Only 'RANDOM' generation is implemented (see gen_tensor).
        self.kernel_gen_method = 'RANDOM'
        self.input_gen_method = 'RANDOM'

        self.op = op_type
        self.input_shapes = None   # list of input tensor shapes
        self.kernel_shapes = None  # list of kernel tensor shapes, or None

        self.padType = None        # padding string forwarded to TF (e.g. 'SAME'/'VALID')
        self.poolType = None       # pooling mode forwarded to tf.nn.pool (POOL_2D only)
        self.shapes = None         # op-specific shapes: strides/windows/target shapes
        self.axis = None           # numeric parameter (axis, or the cap for CAPPED_RELU)

        # Lazily generated, cached tensors.
        self._inputs = None
        self._kernels = None
        self._results = None

    def __repr__(self):
        return '{}: input {}'.format(self.op, self.input_shapes)

    @staticmethod
    def gen_tensor(shape, method):
        """Generate a float32 tensor of *shape*; RANDOM is uniform over [-5, 5)."""
        if method == 'RANDOM':
            return np.random.rand(*shape).astype(np.float32) * 10 - 5
        else:
            raise Exception("So far only RANDOM tensor generation method is supported")

    def get_op(self):
        return self.op

    def get_pad_type(self):
        return self.padType

    def get_pool_type(self):
        # Append 'POOL' so the value matches the serialized enum member name
        # looked up via getattr on PoolType; None passes through for non-pool ops.
        return self.poolType + 'POOL' if self.poolType else self.poolType

    def get_shapes(self):
        if self.op in ('CONV_2D', 'DEPTHWISE_CONV_2D', 'POOL_2D'):
            # Current NN interpreter implementation requires that strides and pooling kernels are 3d - [h, w, c]
            return [shape + [1] for shape in self.shapes]
        else:
            return self.shapes

    def get_axis(self):
        return self.axis

    def get_inputs(self):
        """Return the input tensors, generating and caching them on first use."""
        if self._inputs is not None:
            return self._inputs

        self._inputs = [self.gen_tensor(shape, self.input_gen_method) for shape in self.input_shapes]
        return self._inputs

    def get_kernels(self):
        """Return the kernel tensors, or None for ops without kernels."""
        if self.kernel_shapes is None:
            return None

        if self._kernels is not None:
            return self._kernels

        self._kernels = [self.gen_tensor(shape, self.kernel_gen_method) for shape in self.kernel_shapes]
        return self._kernels

    def get_results(self):
        """Compute (once) the expected results, dispatching by operator name
        to the matching get_<op>_result method."""
        if self._results is not None:
            return self._results

        # Ensure cached inputs/kernels exist before the op-specific code runs.
        self.get_inputs()
        self.get_kernels()

        self._results = getattr(OpInfoProvider, 'get_{}_result'.format(self.op.lower()))(self)
        return self._results

    def get_fully_connected_result(self):
        x = self._inputs[0]
        kernel = self._kernels[0]
        return [tf.matmul(x, kernel).eval()]

    def get_conv_2d_result(self):
        x = self._inputs[0]
        kernel = self._kernels[0]
        # shapes[0] holds the [h, w] strides; pad to TF's 4d stride format.
        strides = [1] + self.shapes[0] + [1]

        # Input is stored 3d; add/remove the batch dimension around the TF call.
        net = tf.nn.conv2d(tf.expand_dims(x, 0), kernel, strides, self.padType)
        return [tf.squeeze(net, axis=0).eval()]

    def get_depthwise_conv_2d_result(self):
        x = self._inputs[0]
        # depthwise_conv2d wants a trailing channel-multiplier dim; add 1.
        kernel = tf.expand_dims(self._kernels[0], -1)
        strides = [1] + self.shapes[0] + [1]

        net = tf.nn.depthwise_conv2d(tf.expand_dims(x, 0), kernel, strides, self.padType)
        return [tf.squeeze(net, axis=0).eval()]

    def get_pool_2d_result(self):
        x = self._inputs[0]
        # shapes[0] is the pooling window, shapes[1] the strides.
        net = tf.nn.pool(tf.expand_dims(x, 0), self.shapes[0], self.poolType, self.padType, strides=self.shapes[1])
        return [tf.squeeze(net, axis=0).eval()]

    def get_concatenation_result(self):
        return [tf.concat(self._inputs, self.axis).eval()]

    def get_reshape_result(self):
        return [tf.reshape(self._inputs[0], self.shapes[0]).eval()]

    def get_relu_result(self):
        return [tf.nn.relu(self._inputs[0]).eval()]

    def get_capped_relu_result(self):
        # 'axis' is reused here as the numeric cap value: min(max(x, 0), cap).
        return [tf.maximum(0.0, tf.minimum(self._inputs[0], self.axis)).eval()]

    def get_bias_add_result(self):
        return [tf.add(self._inputs[0], self._kernels[0]).eval()]

    def get_softmax_result(self):
        return [tf.nn.softmax(self._inputs[0], self.axis).eval()]
+
+
class OpInfoParser:
    """Parses one operator description line into an OpInfoProvider."""

    # Matches runs of whitespace that are NOT enclosed in square brackets,
    # so '[1, 2] [3] SAME' splits into '[1, 2]', '[3]', 'SAME'.
    _split_regex = re.compile(r'\s+(?!(?:[^\[\]]*\[[^\[\]]*\])*[^\[\]]*\])')

    def __init__(self, op_type):
        self.op = op_type

    @staticmethod
    def info_split(info_string):
        """Split a description line on whitespace occurring outside [...] groups."""
        return re.split(OpInfoParser._split_regex, info_string)

    @staticmethod
    def get_shape_list(shape_string):
        """Parse '[a,b] [c,d]' (optionally wrapped in an outer [ ]) into
        a list of integer shape lists, e.g. [[a, b], [c, d]]."""
        compact = re.sub(r'\s*', '', shape_string)
        # Adjacent shapes become space-separated chunks.
        compact = re.sub(r'\]\[', ' ', compact)
        if compact.startswith('[['):
            # Drop the enclosing outer brackets.
            compact = compact[1:-1]
        return [[int(dim.strip()) for dim in chunk.strip('[]').split(',')]
                for chunk in compact.split(' ')]

    def create_provider(self, info_string):
        """Build an OpInfoProvider for self.op from one description line.

        The first item is always the input shapes; the remaining items are
        interpreted in the order dictated by OP_FORMATS[self.op].
        """
        fields = self.info_split(info_string)

        provider = OpInfoProvider(self.op)
        provider.input_shapes = self.get_shape_list(fields[0])

        for position, attr in enumerate(OP_FORMATS[self.op]):
            value = fields[position + 1]
            if attr == 'kernels':
                provider.kernel_shapes = self.get_shape_list(value)
            elif attr == 'shapes':
                provider.shapes = self.get_shape_list(value)
            elif attr == 'padType':
                provider.padType = value
            elif attr == 'poolType':
                provider.poolType = value
            elif attr == 'axis':
                provider.axis = int(value)
            else:
                raise Exception('Encountered unknown op attr type')

        return provider
+
+
def preprocess(filename):
    """Read *filename* and return its lines stripped of surrounding
    whitespace, with blank lines and '#'-comment lines removed."""
    with open(filename, 'r') as f:
        stripped_lines = [raw_line.strip() for raw_line in f]
    return [line for line in stripped_lines
            if line and not line.startswith('#')]
+
+
def get_opwise_opinfo_lines(lines):
    """Group operator description lines by the operator name preceding them.

    *lines* is a preprocessed (stripped, comment-free) sequence in which any
    line matching a key of OP_FORMATS opens a new operator section; every
    other line is a description belonging to the most recent operator.

    Returns a defaultdict mapping operator name -> list of description lines.
    Raises Exception if a description line appears before any operator name.
    """
    # defaultdict(list) instead of defaultdict(lambda: []) — same behavior,
    # idiomatic form.
    opwise_opinfo = defaultdict(list)
    current_op = None
    for line in lines:
        if line in OP_FORMATS:
            current_op = line
        else:
            # Fail fast instead of silently accumulating lines under a None
            # key and probing for it after the loop.
            if current_op is None:
                raise Exception('Operator info description file doesn\'t start with an operator name.')
            opwise_opinfo[current_op].append(line)

    return opwise_opinfo
+
+
def prepare_save_paths(given_path):
    """Resolve *given_path* into a (directory, file name) pair for saving.

    The path is made absolute first. Raises Exception when its directory
    part does not exist; substitutes 'result.fb' when the path carries no
    file-name component.
    """
    absolute = path.abspath(given_path)
    dirname, basename = path.split(absolute)

    if not path.exists(dirname):
        raise Exception('Indicated path for saving result does not exist.')

    return dirname, basename or 'result.fb'
+
+
def main():
    """CLI entry point: parse an operator-description file and serialize the
    generated test data to FlatBuffers file(s), per-operator or in bulk."""
    # (sic: 'operatorinfo' reproduces the original concatenated literal.)
    arg_parser = argparse.ArgumentParser(
        description='Generate Flatbuffers files containing NN operatorinfo needed for testing.')
    arg_parser.add_argument('input', type=str)
    arg_parser.add_argument('-o', '--output', type=str, default='./result.fb',
                            help='output file path; if used with "-b" will use the name of the file as a postfix')
    arg_parser.add_argument('-b', '--bulk', action='store_true',
                            help='save result to a single file (default is to save a file for each operation)')

    args = arg_parser.parse_args()

    dirname, basename = prepare_save_paths(args.output)

    opwise_raw = get_opwise_opinfo_lines(preprocess(args.input))

    # One list of providers per operator type, parsed by a per-op parser.
    opwise_data_providers = {}
    for op_type, info_lines in opwise_raw.items():
        op_parser = OpInfoParser(op_type)
        opwise_data_providers[op_type] = [op_parser.create_provider(info_string)
                                          for info_string in info_lines]

    if args.bulk:
        all_providers = [provider
                         for providers in opwise_data_providers.values()
                         for provider in providers]
        OpInfoSerializer().save_to_file(all_providers, path.join(dirname, basename))
    else:
        for op_type, providers in opwise_data_providers.items():
            path_to_save = path.join(dirname, '{}_{}'.format(op_type.lower(), basename))
            OpInfoSerializer().save_to_file(providers, path_to_save)
+
+
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
--- /dev/null
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: opinfo
+
+import flatbuffers
+
class OperatorInfo(object):
    """Generated FlatBuffers read accessor for the OperatorInfo table.

    Each accessor resolves its field through the table's vtable
    (self._tab.Offset) and returns the stored value, a child object, or the
    schema default (0 / None) when the field is absent. Machine-generated:
    do not hand-edit the logic.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAsOperatorInfo(cls, buf, offset):
        # Read the root uoffset and position a new accessor on the table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = OperatorInfo()
        x.Init(buf, n + offset)
        return x

    # OperatorInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # OperatorInfo: operator kind as an int8 enum value; 0 when absent.
    def Op(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # OperatorInfo: j-th input Tensor, or None when the vector is absent.
    def Inputs(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from .Tensor import Tensor
            obj = Tensor()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # OperatorInfo: number of input Tensors; 0 when the vector is absent.
    def InputsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # OperatorInfo: j-th kernel Tensor, or None when the vector is absent.
    def Kernels(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from .Tensor import Tensor
            obj = Tensor()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # OperatorInfo: number of kernel Tensors; 0 when the vector is absent.
    def KernelsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # OperatorInfo: j-th result Tensor, or None when the vector is absent.
    def Results(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from .Tensor import Tensor
            obj = Tensor()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # OperatorInfo: number of result Tensors; 0 when the vector is absent.
    def ResultsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # OperatorInfo: padding kind as an int8 enum value; 0 when absent.
    def PadType(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # OperatorInfo: pooling kind as an int8 enum value; 0 when absent.
    def PoolType(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # OperatorInfo: j-th op-specific Shape, or None when the vector is absent.
    def Shapes(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from .Shape import Shape
            obj = Shape()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # OperatorInfo: number of Shapes; 0 when the vector is absent.
    def ShapesLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # OperatorInfo: numeric axis/cap parameter as int32; 0 when absent.
    def Axis(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
+
# Machine-generated builder helpers for the OperatorInfo table.
# Usage: build every child vector first (OperatorInfoStart*Vector plus
# builder prepends), then OperatorInfoStart -> OperatorInfoAdd* ->
# OperatorInfoEnd. Slot defaults of 0 mean an unset field is omitted.
def OperatorInfoStart(builder): builder.StartObject(8)
def OperatorInfoAddOp(builder, op): builder.PrependInt8Slot(0, op, 0)
def OperatorInfoAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0)
def OperatorInfoStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def OperatorInfoAddKernels(builder, kernels): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(kernels), 0)
def OperatorInfoStartKernelsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def OperatorInfoAddResults(builder, results): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(results), 0)
def OperatorInfoStartResultsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def OperatorInfoAddPadType(builder, padType): builder.PrependInt8Slot(4, padType, 0)
def OperatorInfoAddPoolType(builder, poolType): builder.PrependInt8Slot(5, poolType, 0)
def OperatorInfoAddShapes(builder, shapes): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(shapes), 0)
def OperatorInfoStartShapesVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def OperatorInfoAddAxis(builder, axis): builder.PrependInt32Slot(7, axis, 0)
def OperatorInfoEnd(builder): return builder.EndObject()