fix mlgo regalloc test model generation for tflite
authoryundiqian <yundi@google.com>
Tue, 9 Aug 2022 19:10:08 +0000 (12:10 -0700)
committerMircea Trofin <mtrofin@google.com>
Tue, 9 Aug 2022 19:36:28 +0000 (12:36 -0700)
To move from TF C API to TFLite, we found that the argmax op in TFLite does not work for int64 inputs, so cast the int64 inputs to int32 inputs to make TFLite argmax op work

Differential Revision: https://reviews.llvm.org/D131462

llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py

index 476163d..11bc3f2 100644 (file)
@@ -46,7 +46,7 @@ def build_mock_model(path):
   module.var = tf.Variable(0, dtype=tf.int64)
 
   def action(*inputs):
-    result = tf.math.argmax(inputs[0]['mask'], axis=-1) + module.var
+    result = tf.math.argmax(tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
     return {POLICY_DECISION_LABEL: result}
   module.action = tf.function()(action)
   action = {