Add NNC AOT Compiler executable (#63994)
authorPriya Ramani <priyaramani@fb.com>
Thu, 16 Sep 2021 02:12:47 +0000 (19:12 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 02:18:24 +0000 (19:18 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63994

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D30582149

Pulled By: priyaramani

fbshipit-source-id: 3bbf085428824c3cb308e006c18bb0a57f50fef6

binaries/CMakeLists.txt
binaries/aot_model_compiler.cc [new file with mode: 0644]
tools/build_variables.bzl
torch/CMakeLists.txt
torch/csrc/jit/mobile/nnc/aot_compiler.cpp [new file with mode: 0644]
torch/csrc/jit/mobile/nnc/aot_compiler.h [new file with mode: 0644]

index 4dfe767..f048aba 100644 (file)
@@ -108,3 +108,6 @@ caffe2_binary_target("tutorial_blob.cc")
 
 caffe2_binary_target("dump_operator_names.cc")
 caffe2_binary_target("optimize_for_mobile.cc")
+
+caffe2_binary_target(aot_model_compiler "aot_model_compiler.cc")
+target_link_libraries(aot_model_compiler aot_compiler)
diff --git a/binaries/aot_model_compiler.cc b/binaries/aot_model_compiler.cc
new file mode 100644 (file)
index 0000000..d757af1
--- /dev/null
@@ -0,0 +1,170 @@
+#include <sstream>
+#include <string>
+
+#include <torch/csrc/jit/backends/backend.h>
+#include <torch/csrc/jit/backends/backend_detail.h>
+#include <torch/csrc/jit/backends/backend_preprocess.h>
+#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
+#include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/serialization/export.h>
+#include <torch/csrc/jit/serialization/import.h>
+#include <torch/script.h>
+
+
+C10_DEFINE_string(model, "", "The torch script model to optimize.");
+C10_DEFINE_string(model_name, "", "The name of the model.");
+C10_DEFINE_string(model_version, "", "The version of the model.");
+C10_DEFINE_string(
+    input_dims,
+    "",
+    "For input float TensorCPUs, specify the dimension using comma "
+    "separated numbers. If multiple inputs needed, use semicolon "
+    "to separate the dimension of different tensors.");
+C10_DEFINE_string(
+    output_llvm,
+    "",
+    "Name of the output llvm assembly to be saved.");
+C10_DEFINE_string(output_model, "", "Name of the output model to be saved.");
+
+namespace {
+
+std::vector<std::string> split(
+    char separator,
+    const std::string& string,
+    bool ignore_empty = true) {
+  std::vector<std::string> pieces;
+  std::stringstream ss(string);
+  std::string item;
+  while (getline(ss, item, separator)) {
+    if (!ignore_empty || !item.empty()) {
+      pieces.push_back(std::move(item));
+    }
+  }
+  return pieces;
+}
+
+std::vector<std::vector<int64_t>> parseInputShapes() {
+  CAFFE_ENFORCE_GE(FLAGS_input_dims.size(), 0, "Input dims must be specified.");
+  std::vector<std::string> input_dims_list = split(';', FLAGS_input_dims);
+  std::vector<std::vector<int64_t>> inputs;
+  for (const auto& input_dims_item : input_dims_list) {
+    auto input_dims_str = split(',', input_dims_item);
+    std::vector<int64_t> input_dims;
+    input_dims.reserve(input_dims_str.size());
+    for (const auto& s : input_dims_str) {
+      input_dims.push_back(c10::stoi(s));
+    }
+    inputs.push_back(input_dims);
+  }
+  return inputs;
+}
+
+c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
+  c10::Dict<c10::IValue, c10::IValue> compile_spec(
+      c10::StringType::get(), c10::AnyType::get());
+  c10::Dict<c10::IValue, c10::IValue> method_spec(
+      c10::StringType::get(), c10::AnyType::get());
+  auto input_shapes = parseInputShapes();
+  TORCH_CHECK(
+      input_shapes.size() == 1,
+      "Wrong # of input shapes: ",
+      input_shapes.size());
+  method_spec.insert("sizes", input_shapes[0]); // TODO: support multiple inputs
+  compile_spec.insert("forward", method_spec);
+  return compile_spec;
+}
+
+std::vector<int64_t> getInputSizesForMethod(
+    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
+    const std::string& method_name) {
+  return method_compile_spec.at(method_name)
+      .toGenericDict()
+      .at("sizes")
+      .toIntVector();
+}
+
+std::string getNncKernelId(const std::string& method_name) {
+  // TODO: calculate the version_token.
+  const std::string version_token = "VERTOKEN";
+  return FLAGS_model_name + ":" + FLAGS_model_version + ":" + method_name +
+      ":" + version_token;
+}
+
+void writeOutputLlvmAssembly(const std::string& asm_code) {
+  std::string output_llvm_file_name = FLAGS_output_llvm;
+  if (output_llvm_file_name.empty()) {
+    output_llvm_file_name =
+        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
+  }
+
+  std::ofstream output(output_llvm_file_name);
+  output << asm_code;
+}
+
+c10::IValue preprocess(
+    const torch::jit::Module& mod,
+    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
+    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
+  const std::string& method_name = "forward";
+  auto method = mod.get_method(method_name);
+  auto graph = method.function().graph()->copy();
+  auto sizes = getInputSizesForMethod(method_compile_spec, method_name);
+
+  std::string llvm_asm_code;
+  auto func =
+      torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes, &llvm_asm_code);
+  writeOutputLlvmAssembly(llvm_asm_code);
+
+  func->set_nnc_kernel_id(getNncKernelId(method_name));
+
+  torch::jit::mobile::nnc::CompilationUnit cu;
+  cu.register_function(std::move(func));
+  return cu.serialize();
+}
+
+static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);
+
+} // namespace
+
+int main(int argc, char** argv) {
+  c10::SetUsageMessage(
+      "Run NNC AOT compiler for pytorch model. Example usage:\n"
+      "build/bin/aot_model_compiler"
+      " --model=<model file>"
+      " --model_name=<model name>"
+      " --model_version=<model version>"
+      " --input_dims='1,3,224,224'"
+      " [--output_llvm=<llvm assembly output file path>]"
+      " [--output_model=<output model file path>]");
+
+  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
+    std::cerr << "Failed to parse command line flags!" << std::endl;
+    std::cout << c10::UsageMessage() << std::endl;
+    return 1;
+  }
+
+  CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
+
+  std::string output_model_name = FLAGS_output_model;
+  if (output_model_name.empty()) {
+    output_model_name =
+        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.pt";
+  }
+
+  auto m = torch::jit::load(FLAGS_model);
+  m.eval();
+  auto frozen_m = torch::jit::freeze_module(m.clone());
+  auto graph = frozen_m.get_method("forward").graph();
+  torch::jit::OptimizeFrozenGraph(graph, true);
+
+  auto compile_spec = createCompileSpec();
+  auto any_dict_ty =
+      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
+  auto compiled_module = torch::jit::detail::codegen_backend_module(
+      "nnc", frozen_m, compile_spec, any_dict_ty);
+  compiled_module._save_for_mobile(output_model_name);
+  std::cout << "The compiled model was saved to " << output_model_name
+            << std::endl;
+  return 0;
+}
index a139515..30ee081 100644 (file)
@@ -183,6 +183,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/ir/subgraph_matcher.cpp",
     "torch/csrc/jit/jit_log.cpp",
     "torch/csrc/jit/jit_opt_limit.cpp",
+    "torch/csrc/jit/mobile/nnc/aot_compiler.cpp",
     "torch/csrc/jit/mobile/nnc/backend.cpp",
     "torch/csrc/jit/mobile/nnc/context.cpp",
     "torch/csrc/jit/mobile/nnc/registry.cpp",
index 7c08685..4de3346 100644 (file)
@@ -423,3 +423,9 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   # Pybind11 requires explicit linking of the torch_python library
   target_link_libraries(nnapi_backend torch torch_python)
 endif()
+
+if(BUILD_BINARY)
+  add_library(aot_compiler SHARED
+          ${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
+          )
+endif()
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
new file mode 100644 (file)
index 0000000..0790fdf
--- /dev/null
@@ -0,0 +1,112 @@
+#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
+
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/passes/remove_mutation.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
+#include <torch/csrc/jit/tensorexpr/graph_opt.h>
+#include <torch/csrc/jit/tensorexpr/ir.h>
+#include <torch/csrc/jit/tensorexpr/kernel.h>
+
+using namespace torch::jit;
+using namespace torch::jit::tensorexpr;
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+std::vector<int64_t> getConstSizes(const BufPtr b) {
+  std::vector<int64_t> r;
+  for (const auto& dim : b->dims()) {
+    LongImmPtr imm_dim = to<LongImm>(dim);
+    // TODO: assert it's actually immediate
+    int64_t s = imm_dim->value();
+    r.push_back(s);
+  }
+  return r;
+}
+
+void getCompiledFunction(
+    std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
+    Function* func) {
+  std::vector<at::Tensor> parameters;
+
+  auto const_descriptors = kernel->getConstantDescriptors();
+  for (const auto& cd : const_descriptors) {
+    auto sizes = getConstSizes(cd.buf);
+    at::Tensor const_tensor = at::from_blob(cd.ptr, sizes).clone();
+    parameters.push_back(const_tensor);
+  }
+  func->set_parameters(c10::impl::toList(c10::List<at::Tensor>(parameters)));
+
+  MemoryPlan plan;
+  plan.buffer_sizes_ = {}; // temp_sizes_;
+  // TODO: implement prealloc optimization and fill in temp_sizes
+  func->set_memory_plan(plan);
+
+  int64_t n_inputs = kernel->graph()->inputs().size();
+  int64_t n_outputs = kernel->graph()->outputs().size();
+  std::vector<OutputSpec> out_spec;
+  for (int64_t idx = n_inputs; idx < n_inputs + n_outputs; idx++) {
+    const auto& ba = kernel->getBufferArgs()[idx];
+    OutputSpec output;
+    output.sizes_ = getConstSizes(ba.buf());
+    // TODO: assert the output is a buffer and not a scalar
+    // TODO: use actual dtype
+    output.dtype_ = c10::ScalarType::Float;
+    out_spec.push_back(output);
+  }
+  func->set_output_specs(out_spec);
+}
+
+std::unique_ptr<Function> aotCompile(
+    const std::string& method_name,
+    std::shared_ptr<Graph>& g,
+    const std::vector<int64_t>& sizes,
+    std::string* compiled_assembly) {
+  auto g2 = g->copy();
+  GRAPH_DEBUG("Input sizes ", sizes);
+
+  RemoveTensorMutation(g);
+  EliminateDeadCode(g->block());
+  g = tensorexpr::removeUnusedSelfArgument(g);
+  GRAPH_DUMP("graph before shape propagation ", g);
+
+  std::vector<c10::optional<at::Tensor>> example_inputs = {at::rand(sizes)};
+  tensorexpr::annotateInputShapes(g, example_inputs);
+
+  PropagateShapesOnGraph(g);
+  PeepholeOptimize(g, false);
+  ConstantPropagation(g);
+  PropagateShapesOnGraph(g);
+  GRAPH_DUMP("graph after shape propagation ", g);
+
+  std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
+      std::make_shared<tensorexpr::TensorExprKernel>(g);
+  *compiled_assembly = kernel->getCodeText();
+
+  g = g2;
+
+  auto func = std::make_unique<Function>();
+  func->set_name(method_name);
+
+  InputSpec input;
+  input.sizes_ = sizes;
+  input.dtype_ = c10::ScalarType::Float;
+  func->set_input_specs({input});
+
+  getCompiledFunction(kernel, func.get());
+  return func;
+}
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.h b/torch/csrc/jit/mobile/nnc/aot_compiler.h
new file mode 100644 (file)
index 0000000..71f6d92
--- /dev/null
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/mobile/nnc/context.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+// Performs Ahead Of Time compilation of a given method in a model
+// returning the compiled function and LLVM assembly code
+TORCH_API std::unique_ptr<Function> aotCompile(
+    const std::string& method_name,
+    std::shared_ptr<Graph>& subgraph,
+    const std::vector<int64_t>& sizes,
+    std::string* compiled_assembly);
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch