Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / res / TensorFlowPythonExamples / examples / while_2 / __init__.py
1 import tensorflow as tf
2
3 i = tf.constant(0, shape=[1, 0], dtype=tf.int32, name='i')
4 x = tf.compat.v1.placeholder(shape=[1, 1], dtype=tf.int32, name='Hole')
5
6 c = lambda i: tf.compat.v1.less(tf.compat.v1.size(i[0]), 10)
7 b = lambda i: tf.concat([i, x], axis=1)
8
9 # this loop changs i's shape from [1, 0] -> [1, 1] -> [1, 2] -> ... -> [1, 10]
10 r = tf.compat.v1.while_loop(
11     c, b, [i], name="While", shape_invariants=[tf.TensorShape([1, None])])
12
13 output = tf.compat.v1.identity(r, name="Output")
14
15 # by adding the following code, [[1 1 1 1 1 1 1 1 1 1]] and (1, 10) will be printed
16 #
17 # import numpy as np
18 # x_val = np.array([[1]])
19 # with tf.Session() as sess:
20 #   result = sess.run(r, feed_dict={x:x_val})
21 #   print(result)
22 #   print(result.shape)
23
24 # with TF 2.3, tf2tflite throws the following error
25 #
26 # Exception: venv/tf-2.3/lib/python3.6/site-packages/tensorflow/python/eager/lift_to_graph.py:339:0:
27 # error: body function result type tensor<1x1xi32> is incompatible with result type tensor<1x0xi32>
28 # at index 0
29 # ...
30 # note: see current operation: %1:2 = "tf.While"(%0, %arg0)
31 # {body = @_functionalize_body_00, cond = @_functionalize_cond_00, device = "", is_stateless = false, output_shapes = [], parallel_iterations = 10 : i64}
32 # : (tensor<1x0xi32>, tensor<1x1xi32>) -> (tensor<1x0xi32>, tensor<1x1xi32>)