To move from TF C API to TFLite, we found that the argmax op in TFLite does not work for int64 inputs, so cast the int64 inputs to int32 inputs to make TFLite argmax op work
Differential Revision: https://reviews.llvm.org/D131462
module.var = tf.Variable(0, dtype=tf.int64)
def action(*inputs):
- result = tf.math.argmax(inputs[0]['mask'], axis=-1) + module.var
+ result = tf.math.argmax(tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
return {POLICY_DECISION_LABEL: result}
module.action = tf.function()(action)
action = {