[nnpack] Preallocate workspace buffer (#2369)
authorhlu1 <14827759+hlu1@users.noreply.github.com>
Sat, 5 Jan 2019 18:12:30 +0000 (10:12 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sat, 5 Jan 2019 18:12:30 +0000 (10:12 -0800)
src/contrib/nnpack/convolution.cc

index e600360c67f1b4cc36195c1f9dbde5b468131dfe..887129819bc2ec08abf12427555fe8b324e6c82e 100644 (file)
@@ -2,6 +2,7 @@
  *  Copyright (c) 2017 by Contributors
  * \file Use external nnpack library call.
  */
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
@@ -72,6 +73,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
         zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
       }
 
+      size_t workspace_size = 0;
+      nnp_status status = nnp_convolution_inference(
+          algo, nnp_convolution_transform_strategy_compute, input_channels,
+          output_channels, input_size, input_padding, kernel_size, stride_size,
+          nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+          nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+      CHECK_EQ(status, nnp_status_success);
+
+      // Division with rounding up, in case size is not multiple of sizeof(float)
+      const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
+
+      TVMContext ctx = input->ctx;
+      TVMType type_hint = input->dtype;
+
+      DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+      void* workspace_buffer =
+        cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+      CHECK(workspace_buffer != nullptr);
+
       for (auto n = 0; n < input->shape[0]; ++n) {
         nnp_status status = nnp_convolution_inference(
             algo, nnp_convolution_transform_strategy_compute, input_channels,
@@ -85,10 +105,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
             static_cast<float *>(output->data) + n * output->shape[1] *
                                                     output->shape[2] *
                                                     output->shape[3],
-            NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
+            workspace_buffer, &workspace_size,
+            nnp_activation_identity, nullptr, entry->threadpool, nullptr);
 
         CHECK_EQ(status, nnp_status_success);
       }
+      cpu_api->FreeWorkspace(ctx, workspace_buffer);
     });
 
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
@@ -147,6 +169,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
         zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
       }
 
+      size_t workspace_size = 0;
+      nnp_status status = nnp_convolution_inference(
+          algo, nnp_convolution_transform_strategy_reuse, input_channels,
+          output_channels, input_size, input_padding, kernel_size, stride_size,
+          nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+          nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+      CHECK_EQ(status, nnp_status_success);
+
+      // Division with rounding up, in case size is not multiple of sizeof(float)
+      const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
+
+      TVMContext ctx = input->ctx;
+      TVMType type_hint = input->dtype;
+
+      DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+      void* workspace_buffer =
+        cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+      CHECK(workspace_buffer != nullptr);
+
       for (auto n = 0; n < input->shape[0]; ++n) {
       nnp_status status = nnp_convolution_inference(
           algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
@@ -159,10 +200,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
           static_cast<float *>(output->data) + n * output->shape[1] *
                                output->shape[2] *
                                output->shape[3],
-          NULL, NULL,
-          nnp_activation_identity, NULL, entry->threadpool, NULL);
+          workspace_buffer, &workspace_size,
+          nnp_activation_identity, nullptr, entry->threadpool, nullptr);
       CHECK_EQ(status, nnp_status_success);
       }
+
+      cpu_api->FreeWorkspace(ctx, workspace_buffer);
     });
 
 TVM_REGISTER_GLOBAL(