[RUNTIME] Implement TVMDSOOp(TensorFlow custom op) for TVM runtime (#4459)
author tobe <tobeg3oogle@gmail.com>
Tue, 7 Apr 2020 23:59:32 +0000 (07:59 +0800)
committer GitHub <noreply@github.com>
Tue, 7 Apr 2020 23:59:32 +0000 (16:59 -0700)
* Add implementation of TVMDSOOp

* feat: Update cmake script to work with c++11 and in-repo build

* feat: Use libtvm as oplib dependency

* fix: Add missing link dependency to libtvm

* feat: Update tf tvmdso op per review comments

* fix: Update with pr comments

* fix: Fix lint

* feat: Add test script and fix gpu shape

* fix: Conditionally build tftvm op for gpu

* fix: Fix pylint of tf_op module.py

* feat: Conditionally enable gpu test for tftvm op

* feat: Add tf_tvmdsoop test script as an app test

* fix: Fix gpu/cpu enabled check on tvm in test script

* fix: Make tf tvmdso op test script runnable with pytest

* remove unused test script test_tfop_module.py

* fix: Remove pushd & popd in tfdsoop test script

* fix: Upgrade tftvmop to use python3 to find TensorFlow

* fix: Change target_link_options to target_link_libraries

* fix: Add tftvmop build script's c++ option

* fix: Add tvm library path to tf op test library path

* fix: Debug ci build for tftvm dso op

* fix: Fix cmake error and skip tfop test

* fix: Fix typo and indentation issues

* feat: Use TF list input op def

* fix: Fix style and unexpected changes

Co-authored-by: baoxinqi <baoxinqi@4paradigm.com>
Co-authored-by: Chen Dihao <chendihao@4paradigm.com>
Co-authored-by: wrongtest <wrongtest@4paradigm.com>
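
For orientation, here is a minimal usage sketch distilled from the test script below; the library path is hypothetical, and the library is assumed to export a function named vector_add (TF 1.x API, as used throughout this patch):

    import tensorflow as tf
    from tvm.contrib import tf_op

    # Wrap a TVM library previously exported with Module.export_library().
    # "/tmp/tvm_add.so" is a hypothetical path; "vector_add" matches the test below.
    module = tf_op.OpModule("/tmp/tvm_add.so")
    add = module.func("vector_add", output_shape=[4], output_dtype="float")

    left = tf.placeholder("float32", shape=[4])
    right = tf.placeholder("float32", shape=[4])
    with tf.Session() as session:
        feed = {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]}
        print(session.run(add(left, right), feed))  # expect [6. 8. 10. 12.]
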
CMakeLists.txt
apps/tf_tvmdsoop/CMakeLists.txt [new file with mode: 0644]
apps/tf_tvmdsoop/prepare_and_test_tfop_module.sh [new file with mode: 0644]
apps/tf_tvmdsoop/tests/test_tfop_module.py [new file with mode: 0644]
cmake/config.cmake
cmake/modules/contrib/TF_TVMDSOOP.cmake [new file with mode: 0644]
python/tvm/contrib/tf_op/__init__.py [new file with mode: 0644]
python/tvm/contrib/tf_op/module.py [new file with mode: 0644]
src/contrib/tf_op/tvm_dso_op_kernels.cc [new file with mode: 0644]
src/contrib/tf_op/tvm_dso_ops.cc [new file with mode: 0644]
tests/scripts/task_python_integration.sh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7428802..6993f67 100644 (file)
@@ -41,6 +41,7 @@ tvm_option(USE_MSVC_MT "Build with MT" OFF)
 tvm_option(USE_MICRO "Build with Micro" OFF)
 tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
 tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
+tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
 
 # 3rdparty libraries
 tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
@@ -259,6 +260,7 @@ include(cmake/modules/contrib/Sort.cmake)
 include(cmake/modules/contrib/NNPack.cmake)
 include(cmake/modules/contrib/HybridDump.cmake)
 include(cmake/modules/contrib/TFLite.cmake)
+include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
 
 if(NOT MSVC)
   include(CheckCXXCompilerFlag)
diff --git a/apps/tf_tvmdsoop/CMakeLists.txt b/apps/tf_tvmdsoop/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cb601ef
--- /dev/null
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+cmake_minimum_required(VERSION 3.2)
+project(tf_tvmdsoop C CXX)
+
+set(TFTVM_COMPILE_FLAGS -std=c++11)
+set(BUILD_TVMDSOOP_ONLY ON)
+set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
+set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)
+
+include_directories(${TVM_ROOT}/3rdparty/dlpack/include/)
+include_directories(${TVM_ROOT}/3rdparty/dmlc-core/include/)
+include_directories(${TVM_ROOT}/include)
+
+link_directories(${TVM_ROOT}/build)
+
+include(${TVM_ROOT}/cmake/util/FindCUDA.cmake)
+include(${TVM_ROOT}/cmake/modules/CUDA.cmake)
+
+include(${TVM_ROOT}/cmake/modules/contrib/TF_TVMDSOOP.cmake)
diff --git a/apps/tf_tvmdsoop/prepare_and_test_tfop_module.sh b/apps/tf_tvmdsoop/prepare_and_test_tfop_module.sh
new file mode 100644 (file)
index 0000000..2bde4f8
--- /dev/null
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+TVM_ROOT=$(cd $(dirname $0)/../..; pwd)
+echo "TVM_ROOT=${TVM_ROOT}"
+
+export PYTHONPATH=${TVM_ROOT}/python
+
+python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e "True"
+if [ "$?" -eq 0 ]; then
+    echo "Build TF_TVMDSOOP with gpu support and execute tests"
+    CMAKE_OPTIONS="-DUSE_CUDA=ON -DPython3_EXECUTABLE=python3 -DTVM_ROOT=${TVM_ROOT}"
+    mkdir -p build
+    cd build; cmake .. ${CMAKE_OPTIONS} && make
+    cd ..
+
+    LD_LIBRARY_PATH=${TVM_ROOT}/build:./build:$LD_LIBRARY_PATH python3 -m pytest -v ./tests
+fi
+
diff --git a/apps/tf_tvmdsoop/tests/test_tfop_module.py b/apps/tf_tvmdsoop/tests/test_tfop_module.py
new file mode 100644 (file)
index 0000000..1672b58
--- /dev/null
@@ -0,0 +1,118 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tf op module"""
+import tempfile
+import os
+import logging
+import tensorflow as tf
+import numpy as np
+import tvm
+from tvm import te
+from tvm.contrib import tf_op
+
+
+def test_use_tvmdso_op():
+    """main test function"""
+
+    def export_cpu_add_lib():
+        """create cpu add op lib"""
+        n = te.var("n")
+        ph_a = te.placeholder((n,), name='ph_a')
+        ph_b = te.placeholder((n,), name='ph_b')
+        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+        sched = te.create_schedule(ph_c.op)
+        fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")
+        lib_path = tempfile.mktemp("tvm_add_dll.so")
+        fadd_dylib.export_library(lib_path)
+        return lib_path
+
+
+    def export_gpu_add_lib():
+        """create gpu add op lib"""
+        n = te.var("n")
+        ph_a = te.placeholder((n,), name='ph_a')
+        ph_b = te.placeholder((n,), name='ph_b')
+        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+        sched = te.create_schedule(ph_c.op)
+        b_axis, t_axis = sched[ph_c].split(ph_c.op.axis[0], factor=64)
+        sched[ph_c].bind(b_axis, te.thread_axis("blockIdx.x"))
+        sched[ph_c].bind(t_axis, te.thread_axis("threadIdx.x"))
+        fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "cuda", name="vector_add")
+        lib_path = tempfile.mktemp("tvm_add_cuda_dll.so")
+        fadd_dylib.export_library(lib_path)
+        return lib_path
+
+
+    def test_add(session, lib_path, tf_device):
+        """test add lib with TensorFlow wrapper"""
+        module = tf_op.OpModule(lib_path)
+
+        left = tf.placeholder("float32", shape=[4])
+        right = tf.placeholder("float32", shape=[4])
+
+        feed_dict = {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]}
+        expect = np.asarray([6.0, 8.0, 10.0, 12.0])
+
+        add1 = module.func("vector_add", output_shape=[4], output_dtype="float")
+        add2 = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")
+        add3 = module.func("vector_add", output_shape=[tf.shape(left)[0]], output_dtype="float")
+
+        with tf.device(tf_device):
+            output1 = session.run(add1(left, right), feed_dict)
+            np.testing.assert_equal(output1, expect)
+
+            output2 = session.run(add2(left, right), feed_dict)
+            np.testing.assert_equal(output2, expect)
+
+            output3 = session.run(add3(left, right), feed_dict)
+            np.testing.assert_equal(output3, expect)
+
+
+    def cpu_test(session):
+        """test function for cpu"""
+        cpu_lib = None
+        try:
+            cpu_lib = export_cpu_add_lib()
+            test_add(session, cpu_lib, "/cpu:0")
+        finally:
+            if cpu_lib is not None:
+                os.remove(cpu_lib)
+
+
+    def gpu_test(session):
+        """test function for gpu"""
+        gpu_lib = None
+        try:
+            gpu_lib = export_gpu_add_lib()
+            test_add(session, gpu_lib, "/gpu:0")
+        finally:
+            if gpu_lib is not None:
+                os.remove(gpu_lib)
+
+    with tf.Session() as session:
+        if tvm.runtime.enabled("cpu"):
+            logging.info("Test TensorFlow op on cpu kernel")
+            cpu_test(session)
+        if tvm.runtime.enabled("gpu"):
+            logging.info("Test TensorFlow op on gpu kernel")
+            gpu_test(session)
+
+
+if __name__ == "__main__":
+    test_use_tvmdso_op()
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 6451ea8..8c448b4 100644 (file)
@@ -204,3 +204,7 @@ set(USE_EXAMPLE_EXT_RUNTIME OFF)
 
 # Whether use Thrust
 set(USE_THRUST OFF)
+
+# Whether to build the TensorFlow TVMDSOOp module
+set(USE_TF_TVMDSOOP OFF)
+
diff --git a/cmake/modules/contrib/TF_TVMDSOOP.cmake b/cmake/modules/contrib/TF_TVMDSOOP.cmake
new file mode 100644 (file)
index 0000000..e92822a
--- /dev/null
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(NOT USE_TF_TVMDSOOP STREQUAL "OFF")
+  find_package(Python3 COMPONENTS Interpreter)
+  
+  execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
+    OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
+    RESULT_VARIABLE TF_STATUS)
+  if (NOT ${TF_STATUS} EQUAL 0)
+    message(FATAL_ERROR "Fail to get TensorFlow compile flags")
+  endif()
+  
+  if(NOT USE_CUDA STREQUAL "OFF")
+    add_definitions(-DTF_TVMDSOOP_ENABLE_GPU)
+  endif()
+
+  execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
+    OUTPUT_VARIABLE TF_LINK_FLAGS_STR
+    RESULT_VARIABLE TF_STATUS)
+  if (NOT ${TF_STATUS} EQUAL 0)
+    message(FATAL_ERROR "Fail to get TensorFlow link flags")
+  endif()
+
+  string(REGEX REPLACE "\n" " " TF_COMPILE_FLAGS_STR "${TF_COMPILE_FLAGS_STR}")
+  string(REGEX REPLACE "\n" " " TF_LINK_FLAGS_STR "${TF_LINK_FLAGS_STR}")
+  separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND ${TF_COMPILE_FLAGS_STR})
+  separate_arguments(TF_LINK_FLAGS UNIX_COMMAND ${TF_LINK_FLAGS_STR})
+
+
+  set(OP_LIBRARY_NAME tvm_dso_op)
+  file(GLOB_RECURSE TFTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/tf_op/*.cc)
+  add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
+  set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")
+  set(TFTVM_LINK_FLAGS  -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
+  
+  if (NOT BUILD_TVMDSOOP_ONLY STREQUAL "ON")
+      add_dependencies(${OP_LIBRARY_NAME} tvm) 
+  endif()
+
+  target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
+  target_link_libraries(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})
+
+endif()
+
diff --git a/python/tvm/contrib/tf_op/__init__.py b/python/tvm/contrib/tf_op/__init__.py
new file mode 100644 (file)
index 0000000..05d0ecc
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of TensorFlow TVMDSO op"""
+from . import module
+
+OpModule = module.OpModule
diff --git a/python/tvm/contrib/tf_op/module.py b/python/tvm/contrib/tf_op/module.py
new file mode 100644 (file)
index 0000000..f13670e
--- /dev/null
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of TensorFlow TVMDSO op"""
+import tensorflow as tf
+from tensorflow.python.framework import load_library
+
+
+class OpModule:
+    """Module container of TensorFlow TVMDSO op which wraps exported
+    TVM op implementation library to be called on TensorFlow side"""
+
+    def __init__(self, lib_path):
+        self.lib_path = lib_path
+
+    def func(self, name, output_dtype=None, output_shape=None):
+        """Get tvm op function wrapped as TensorFlow tensor to tensor function
+
+        Parameters
+        ----------
+        name: str
+            function name
+        output_dtype: str or TensorFlow datatype
+            Output datatype, defaults to float32
+        output_shape: list of integers / tf scalar tensors, or a tf shape tensor
+            Output shape, defaults to the shape of the first input
+
+        Returns
+        ----------
+        A TensorFunc object that acts as a TensorFlow tensor-to-tensor function.
+        """
+        return TensorFunc(self.lib_path, name, output_dtype, output_shape)
+
+    def __getitem__(self, func_name):
+        return self.func(func_name)
+
+
+class TensorFunc:
+    """Function object that acts as TensorFlow tensor to tensor function."""
+
+    def __init__(self, lib_path, func_name, output_dtype, output_shape):
+        self.lib_path = lib_path
+        self.func_name = func_name
+        self.output_dtype = output_dtype
+
+        # const(0) indicates an invalid (unset) dynamic shape
+        self.dynamic_output_shape = tf.constant(0, tf.int64)
+        self.static_output_shape = None
+        self.has_static_output_shape = False  # extra flag is required
+
+        if self._is_static_shape(output_shape):
+            self.static_output_shape = output_shape
+            self.has_static_output_shape = True
+        elif output_shape is not None:
+            self.dynamic_output_shape = self._pack_shape_tensor(output_shape)
+
+        self.module = load_library.load_op_library('tvm_dso_op.so')
+        self.tvm_dso_op = self.module.tvm_dso_op
+
+    def apply(self, *params):
+        return self.tvm_dso_op(params,
+                               dynamic_output_shape=self.dynamic_output_shape,
+                               static_output_shape=self.static_output_shape,
+                               has_static_output_shape=self.has_static_output_shape,
+                               lib_path=self.lib_path,
+                               func_name=self.func_name,
+                               output_dtype=self.output_dtype)
+
+    def __call__(self, *params):
+        return self.apply(*params)
+
+    def _is_static_shape(self, shape):
+        if shape is None or not isinstance(shape, list):
+            return False
+        for dim_value in shape:
+            if not isinstance(dim_value, int):
+                return False
+            if dim_value < 0:
+                raise ValueError("Negative dimension is illegal: %d" % dim_value)
+        return True
+
+    def _pack_shape_tensor(self, shape):
+        if isinstance(shape, tf.Tensor):
+            if shape.dtype == tf.int32:
+                shape = tf.cast(shape, tf.int64)
+        elif isinstance(shape, list):
+            shape_dims = []
+            for dim_value in shape:
+                if isinstance(dim_value, int):
+                    shape_dims.append(tf.constant(dim_value, tf.int64))
+                elif isinstance(dim_value, tf.Tensor) and dim_value.shape.rank == 0:
+                    if dim_value.dtype == tf.int32:
+                        dim_value = tf.cast(dim_value, tf.int64)
+                    shape_dims.append(dim_value)
+                else:
+                    raise TypeError("Input shape dimension is neither scalar tensor nor int")
+            shape = tf.stack(shape_dims)
+        else:
+            raise TypeError("Input shape is neither tensor nor list")
+        return shape
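
As the branches in _is_static_shape and _pack_shape_tensor above show, output_shape may be given three ways: a list of plain ints (forwarded as the static_output_shape attr), a rank-1 shape tensor, or a list mixing ints with scalar tensors (both packed into the int64 dynamic_output_shape input). A sketch, reusing the hypothetical module and left placeholder from the earlier example:

    # Static shape, known when the graph is built.
    add_static = module.func("vector_add", output_shape=[4], output_dtype="float")

    # Dynamic shape taken from a rank-1 shape tensor.
    add_dynamic = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")

    # Mixed list of ints and scalar tensors, stacked into one int64 shape tensor.
    add_mixed = module.func("vector_add", output_shape=[tf.shape(left)[0]], output_dtype="float")
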
diff --git a/src/contrib/tf_op/tvm_dso_op_kernels.cc b/src/contrib/tf_op/tvm_dso_op_kernels.cc
new file mode 100644 (file)
index 0000000..d74d8fb
--- /dev/null
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+#include <cuda_runtime.h>
+#endif
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+typedef tensorflow::gtl::InlinedVector<tensorflow::int64, 4> ShapeContainer;
+
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+
+using tvm::runtime::TVMArgs;
+using tvm::runtime::TVMArgsSetter;
+using tvm::runtime::TVMRetValue;
+
+// Op utility trait for different device type templates
+template <typename DEVICE_TYPE>
+class TVMDSOOpTrait;
+
+// Buffer information used for actual computation.
+// Each buffer is associated with one TensorFlow tensor,
+// whose underlying buffer is recorded in "origin_buf".
+// For an input tensor, data is copied from origin_buf to buf;
+// for an output tensor, from buf to origin_buf.
+class TensorAsBuf {
+ public:
+  tensorflow::Tensor inline_tensor;
+  tensorflow::Tensor* tensor;
+
+  size_t size;
+  size_t offset;
+
+  int device_type;
+
+  char* origin_buf;
+  char* buf;
+
+  void CopyToOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    if (device_type == kDLCPU) {
+      memcpy(origin_buf, buf + offset, size);
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+    } else if (device_type == kDLGPU) {
+      cudaMemcpy(origin_buf, buf + offset, size, cudaMemcpyDeviceToDevice);
+#endif
+    } else {
+      LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type
+                 << " is not implemented currently";
+    }
+  }
+
+  void CopyFromOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    if (device_type == kDLCPU) {
+      memcpy(buf + offset, origin_buf, size);
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+    } else if (device_type == kDLGPU) {
+      cudaMemcpy(buf + offset, origin_buf, size, cudaMemcpyDeviceToDevice);
+#endif
+    } else {
+      LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type
+                 << " is not implemented currently";
+    }
+  }
+};
+
+tensorflow::Status GetDLPackDtype(const tensorflow::Tensor& tf_tensor, DLDataType* res) {
+  auto dtype = tf_tensor.dtype();
+  if (dtype == tensorflow::DT_FLOAT) {
+    *res = {kDLFloat, 32, 1};
+  } else if (dtype == tensorflow::DT_INT64) {
+    *res = {kDLInt, 64, 1};
+  } else if (dtype == tensorflow::DT_INT32) {
+    *res = {kDLInt, 32, 1};
+  } else {
+    return tensorflow::Status(tensorflow::error::INTERNAL, "Failed to get DLPack datatype");
+  }
+  return tensorflow::Status::OK();
+}
+
+// Ensure the buffer used for actual computation is 64-byte aligned
+void EnsureAlignment(OpKernelContext* ctx, const tensorflow::Tensor& tensor, TensorAsBuf* out) {
+  char* buf = const_cast<char*>(tensor.tensor_data().data());
+  out->origin_buf = buf;
+  out->size = tensor.TotalBytes();
+
+  int alignment = 64;
+  char* aligned = reinterpret_cast<char*>(((uint64_t)buf + alignment - 1) & (~(alignment - 1)));
+  if (buf == aligned) {
+    out->tensor = const_cast<tensorflow::Tensor*>(&tensor);
+    out->buf = buf;
+    out->offset = 0;
+  } else {
+    tensorflow::TensorShape buf_shape;
+    tensorflow::int64 dims[1] = {(tensorflow::int64)(tensor.TotalBytes() + alignment)};
+    tensorflow::TensorShapeUtils::MakeShape(dims, 1, &buf_shape);
+
+    out->tensor = &out->inline_tensor;
+    ctx->allocate_temp(tensor.dtype(), buf_shape, out->tensor);
+
+    buf = const_cast<char*>(out->tensor->tensor_data().data());
+    char* buf_aligned = reinterpret_cast<char*>(((uint64_t)buf + alignment) & (~(alignment - 1)));
+    out->buf = buf;
+    out->offset = buf_aligned - buf;
+  }
+}
+
+// Create DLPack tensor from TensorFlow tensor
+tensorflow::Status MakeDLTensor(const TensorAsBuf& src, const DLContext& ctx, int64_t* tf_shape,
+                                DLTensor* out) {
+  DLDataType dlpack_type;
+  const tensorflow::Tensor& tensor = *src.tensor;
+
+  auto status = GetDLPackDtype(tensor, &dlpack_type);
+  if (!status.ok()) {
+    return status;
+  }
+  out->ctx = ctx;
+  out->ndim = tensor.shape().dims();
+  out->shape = tf_shape;
+  out->strides = nullptr;
+  out->byte_offset = 0;
+  out->dtype = dlpack_type;
+  out->data = src.buf + src.offset;
+  return tensorflow::Status::OK();
+}
+
+template <>
+class TVMDSOOpTrait<CPUDevice> {
+ public:
+  static const int device_type = kDLCPU;
+
+  static int device_id(OpKernelContext* context) { return 0; }
+
+  static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor,
+                                     tensorflow::TensorShape* output_shape) {
+    tensorflow::int64 num_dims = shape_tensor.NumElements();
+    const tensorflow::int64* dims = shape_tensor.flat<tensorflow::int64>().data();
+    tensorflow::TensorShapeUtils::MakeShape(dims, num_dims, output_shape);
+  }
+};
+
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+template <>
+class TVMDSOOpTrait<GPUDevice> {
+ public:
+  static const int device_type = kDLGPU;
+
+  static int device_id(OpKernelContext* context) {
+    auto device_base = context->device();
+    auto gpu_device_info = device_base->tensorflow_gpu_device_info();
+    return gpu_device_info->gpu_id;
+  }
+
+  static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor,
+                                     tensorflow::TensorShape* output_shape) {
+    tensorflow::int64 num_dims = shape_tensor.NumElements();
+    const tensorflow::int64* flat = shape_tensor.flat<tensorflow::int64>().data();
+    tensorflow::int64* dims = new tensorflow::int64[num_dims];
+    cudaMemcpy(dims, flat, sizeof(tensorflow::int64) * num_dims, cudaMemcpyDeviceToHost);
+    tensorflow::TensorShapeUtils::MakeShape(dims, num_dims, output_shape);
+    delete[] dims;
+  }
+};
+#endif
+
+template <typename DEVICE_TYPE>
+class TVMDSOOp : public OpKernel {
+ private:
+  tvm::runtime::PackedFunc tvm_func;
+  std::string lib_path;
+  std::string func_name;
+
+  tensorflow::DataType output_dtype;
+
+  bool has_static_output_shape;
+  std::vector<tensorflow::int64> static_output_shape;
+
+  void initAttributes(OpKernelConstruction* context) {
+    context->GetAttr("lib_path", &lib_path);
+    context->GetAttr("func_name", &func_name);
+    context->GetAttr("output_dtype", &output_dtype);
+
+    context->GetAttr("has_static_output_shape", &has_static_output_shape);
+    context->GetAttr("static_output_shape", &static_output_shape);
+  }
+
+ public:
+  explicit TVMDSOOp(OpKernelConstruction* context) : OpKernel(context) {
+    // Get attr
+    initAttributes(context);
+
+    // Load TVM function from dynamic library
+    tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile(lib_path);
+    tvm_func = mod_dylib.GetFunction(func_name);
+    CHECK(tvm_func != nullptr);
+  }
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    // the last input is the output shape spec
+    const int num_inputs = context->num_inputs() - 1;
+    const int num_total_args = num_inputs + 1;
+    std::vector<DLTensor> args(num_total_args);
+    std::vector<TensorAsBuf> buf_info(num_inputs);
+    std::vector<ShapeContainer> shapes(num_inputs);
+
+    tensorflow::Status status;
+    int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
+    int device_type = TVMDSOOpTrait<DEVICE_TYPE>::device_type;
+
+    DLContext dl_ctx = {DLDeviceType(device_type), device_id};
+
+    // Get output shape
+    tensorflow::TensorShape output_shape;
+    auto& output_shape_tensor = context->input(num_inputs);
+    if (has_static_output_shape) {
+      // use static output shape
+      const tensorflow::int64* dims = static_output_shape.data();
+      tensorflow::TensorShapeUtils::MakeShape(dims, static_output_shape.size(), &output_shape);
+    } else if (output_shape_tensor.dims() == 1) {
+      // use shape tensor values as output shape
+      TVMDSOOpTrait<DEVICE_TYPE>::make_shape_from_tensor(output_shape_tensor, &output_shape);
+    } else {
+      // use input tensor shape by default
+      output_shape = context->input(0).shape();
+    }
+
+    for (int i = 0; i < num_inputs; ++i) {
+      // Grab the input tensor
+      auto& input_tensor = context->input(i);
+
+      // Create shape container, should keep ref during execution
+      shapes[i] = input_tensor.shape().dim_sizes();
+      auto shape_ptr = reinterpret_cast<int64_t*>(shapes[i].data());
+
+      TensorAsBuf& input = buf_info[i];
+      input.device_type = device_type;
+
+      EnsureAlignment(context, input_tensor, &input);
+      input.CopyFromOrigin();
+
+      status = MakeDLTensor(input, dl_ctx, shape_ptr, &args[i]);
+      OP_REQUIRES_OK(context, status);
+    }
+
+    // Allocate output tensor
+    tensorflow::Tensor* output_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
+    // the shape dimension buffer should be kept alive on the stack
+    auto output_shape_dim_buf = output_tensor->shape().dim_sizes();
+    auto output_shape_ptr = reinterpret_cast<int64_t*>(output_shape_dim_buf.data());
+
+    TensorAsBuf output;
+    output.device_type = device_type;
+    EnsureAlignment(context, *output_tensor, &output);
+
+    status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[num_inputs]);
+    OP_REQUIRES_OK(context, status);
+
+    // Prepare PackedFunc arguments
+    std::vector<TVMValue> tvm_values(num_total_args);
+    std::vector<int> tvm_type_codes(num_total_args);
+    TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+    for (int k = 0; k < num_total_args; ++k) {
+      setter(k, &args[k]);
+    }
+    TVMRetValue rv;
+    tvm_func.CallPacked(TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);
+
+    output.CopyToOrigin();
+  }
+};
+
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>);
+#else
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
+#endif
diff --git a/src/contrib/tf_op/tvm_dso_ops.cc b/src/contrib/tf_op/tvm_dso_ops.cc
new file mode 100644 (file)
index 0000000..1183b2e
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "tensorflow/core/framework/op.h"
+
+REGISTER_OP("TvmDsoOp")
+    .Input("input_args: ListT")
+    .Attr("ListT: list({int8, int32, int64, float16, float32})")
+    .Input("dynamic_output_shape: int64")
+    .Output("output: output_dtype")
+    .Attr("lib_path: string")
+    .Attr("func_name: string")
+    .Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT")
+    .Attr("static_output_shape: list(int) >= 0 = []")
+    .Attr("has_static_output_shape: bool");
diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh
index d9cc0ec..f396aaa 100755 (executable)
@@ -53,6 +53,9 @@ cd ../..
 TVM_FFI=cython python3 -m pytest -v apps/dso_plugin_module
 TVM_FFI=ctypes python3 -m pytest -v apps/dso_plugin_module
 
+# Do not enable TensorFlow op
+# TVM_FFI=cython sh prepare_and_test_tfop_module.sh
+# TVM_FFI=ctypes sh prepare_and_test_tfop_module.sh
 
 TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
 TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib