From 33481c9997e3f0afb6c10b1a33a4d48014cf6100 Mon Sep 17 00:00:00 2001
From: Mircea Trofin
Date: Thu, 11 Feb 2021 22:17:59 -0800
Subject: [PATCH] [mlgo] Fetch models from path / URL

Allow custom location for pre-trained models used when AOT-compiling
policies.

Differential Revision: https://reviews.llvm.org/D96796
---
 llvm/cmake/modules/TensorFlowCompile.cmake | 20 +++++++++++++-------
 llvm/lib/Analysis/CMakeLists.txt           |  5 ++++-
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/llvm/cmake/modules/TensorFlowCompile.cmake b/llvm/cmake/modules/TensorFlowCompile.cmake
index a8ba56e..cbb450e 100644
--- a/llvm/cmake/modules/TensorFlowCompile.cmake
+++ b/llvm/cmake/modules/TensorFlowCompile.cmake
@@ -1,3 +1,14 @@
+# Ensure the ${model} is available at ${final_path}.
+#
+function(tfgetmodel model final_path)
+  if (IS_ABSOLUTE ${model})
+    set(${final_path} ${model} PARENT_SCOPE)
+  else()
+    set(${final_path}
+      ${CMAKE_CURRENT_SOURCE_DIR}/${model} PARENT_SCOPE)
+  endif()
+endfunction()
+
 # Run the tensorflow compiler (saved_model_cli) on the saved model in the
 # ${model} directory, looking for the ${tag_set} tag set, and the SignatureDef
 # ${signature_def_key}.
@@ -5,13 +16,8 @@
 # ${CMAKE_CURRENT_BINARY_DIR}. The generated header will define a C++ class
 # called ${cpp_class} - which may be a namespace-qualified class name.
 function(tfcompile model tag_set signature_def_key fname cpp_class)
-  if (IS_ABSOLUTE ${model})
-    set(LLVM_ML_MODELS_ABSOLUTE ${model})
-  else()
-    set(LLVM_ML_MODELS_ABSOLUTE
-      ${CMAKE_CURRENT_SOURCE_DIR}/${model})
-  endif()
-
+  tfgetmodel(${model} LLVM_ML_MODELS_ABSOLUTE)
+  message("Using model at " ${LLVM_ML_MODELS_ABSOLUTE})
   set(prefix ${CMAKE_CURRENT_BINARY_DIR}/${fname})
   set(obj_file ${prefix}.o)
   set(hdr_file ${prefix}.h)
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index f31cf34..c866d6b 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -1,7 +1,10 @@
 if (DEFINED LLVM_HAVE_TF_AOT OR DEFINED LLVM_HAVE_TF_API)
   if (DEFINED LLVM_HAVE_TF_AOT)
+    set(LLVM_INLINER_MODEL_PATH "models/inliner"
+      CACHE STRING
+      "ML-driven inliner policy location (path to saved model)")
     include(TensorFlowCompile)
-    tfcompile(models/inliner serve action InlinerSizeModel llvm::InlinerSizeModel)
+    tfcompile(${LLVM_INLINER_MODEL_PATH} serve action InlinerSizeModel llvm::InlinerSizeModel)
     list(APPEND GeneratedMLSources
       $
       ${GENERATED_OBJS}
-- 
2.7.4
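
Usage note (not part of the patch): the sketch below shows how the new
tfgetmodel function resolves relative versus absolute model locations. It
assumes the patched TensorFlowCompile.cmake is reachable via CMAKE_MODULE_PATH;
the module path and model paths used here are hypothetical, chosen only for
illustration.

  # check_model_path.cmake -- a standalone sketch, run with `cmake -P check_model_path.cmake`.
  # Assumption: llvm/cmake/modules from a checkout containing this patch is on
  # CMAKE_MODULE_PATH so that include(TensorFlowCompile) finds the patched module.
  list(APPEND CMAKE_MODULE_PATH "/path/to/llvm-project/llvm/cmake/modules")
  include(TensorFlowCompile)

  # A relative model path is resolved against the calling directory.
  tfgetmodel("models/inliner" MODEL_PATH)
  message(STATUS "relative model resolves to: ${MODEL_PATH}")

  # An absolute model path is passed through unchanged.
  tfgetmodel("/opt/ml/inliner_saved_model" MODEL_PATH)
  message(STATUS "absolute model resolves to: ${MODEL_PATH}")

In an LLVM build configured with TF AOT support, the same mechanism is reached
by overriding the new cache entry at configure time, e.g.
-DLLVM_INLINER_MODEL_PATH=/abs/path/to/saved_model (hypothetical path), which
tfcompile now resolves through tfgetmodel instead of its previous inline logic.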