tvm_option(USE_MICRO "Build with Micro" OFF)
tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
+# Opt-in switch for the TensorFlow custom op library; it is consumed by
+# cmake/modules/contrib/TF_TVMDSOOP.cmake, which is included further down.
+tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
# 3rdparty libraries
tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
+# Defines the tvm_dso_op shared library unless USE_TF_TVMDSOOP is "OFF".
+include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
if(NOT MSVC)
include(CheckCXXCompilerFlag)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+cmake_minimum_required(VERSION 3.2)
+project(tf_tvmdsoop C CXX)
+
+# Standalone build of the TensorFlow TVMDSOOp library against an existing TVM
+# checkout. TVM_ROOT must be passed on the command line (-DTVM_ROOT=<path>);
+# without it every path below would silently resolve relative to "/".
+if(NOT TVM_ROOT)
+  message(FATAL_ERROR
+    "TVM_ROOT is not set. Configure with -DTVM_ROOT=<path to the TVM source tree>.")
+endif()
+
+set(TFTVM_COMPILE_FLAGS -std=c++11)
+# TF_TVMDSOOP.cmake skips its add_dependencies(... tvm) call when this is ON,
+# because no `tvm` target exists in this standalone project.
+set(BUILD_TVMDSOOP_ONLY ON)
+# TF_TVMDSOOP.cmake globs its sources under ${CMAKE_CURRENT_SOURCE_DIR} and
+# passes -L${CMAKE_CURRENT_BINARY_DIR} to the linker, so both are redirected
+# to the TVM tree before the module is included.
+set(CMAKE_CURRENT_SOURCE_DIR "${TVM_ROOT}")
+set(CMAKE_CURRENT_BINARY_DIR "${TVM_ROOT}/build")
+
+include_directories("${TVM_ROOT}/3rdparty/dlpack/include/")
+include_directories("${TVM_ROOT}/3rdparty/dmlc-core/include/")
+include_directories("${TVM_ROOT}/include")
+
+link_directories("${TVM_ROOT}/build")
+
+include("${TVM_ROOT}/cmake/util/FindCUDA.cmake")
+include("${TVM_ROOT}/cmake/modules/CUDA.cmake")
+
+include("${TVM_ROOT}/cmake/modules/contrib/TF_TVMDSOOP.cmake")
--- /dev/null
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Resolve the TVM checkout two directories above this script. The quoting keeps
+# paths containing spaces intact, and `&&` avoids reporting the wrong directory
+# when the cd fails.
+TVM_ROOT=$(cd "$(dirname "$0")/../.." && pwd)
+echo "TVM_ROOT=${TVM_ROOT}"
+
+export PYTHONPATH=${TVM_ROOT}/python
+
+# Build and test only when the GPU probe output contains "1"; otherwise the
+# script does nothing further.
+if python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e 1; then
+  echo "Build TF_TVMDSOOP with gpu support and execute tests"
+
+  mkdir -p build
+  # The subshell restores the working directory even when a step fails.
+  # Abort on a failed configure/build instead of running the tests against a
+  # missing or stale library.
+  (
+    cd build &&
+      cmake .. -DUSE_CUDA=ON -DPython3_EXECUTABLE=python3 "-DTVM_ROOT=${TVM_ROOT}" &&
+      make
+  ) || exit 1
+
+  LD_LIBRARY_PATH="${TVM_ROOT}/build:./build:${LD_LIBRARY_PATH}" python3 -m pytest -v ./tests
+fi
+
--- /dev/null
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tf op module"""
+import tempfile
+import os
+import logging
+import tensorflow as tf
+import numpy as np
+import tvm
+from tvm import te
+from tvm.contrib import tf_op
+
+
+def test_use_tvmdso_op():
+ """main test function"""
+
+ def export_cpu_add_lib():
+ """create cpu add op lib"""
+ n = te.var("n")
+ ph_a = te.placeholder((n,), name='ph_a')
+ ph_b = te.placeholder((n,), name='ph_b')
+ ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+ sched = te.create_schedule(ph_c.op)
+ fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")
+ lib_path = tempfile.mktemp("tvm_add_dll.so")
+ fadd_dylib.export_library(lib_path)
+ return lib_path
+
+
+ def export_gpu_add_lib():
+ """create gpu add op lib"""
+ n = te.var("n")
+ ph_a = te.placeholder((n,), name='ph_a')
+ ph_b = te.placeholder((n,), name='ph_b')
+ ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+ sched = te.create_schedule(ph_c.op)
+ b_axis, t_axis = sched[ph_c].split(ph_c.op.axis[0], factor=64)
+ sched[ph_c].bind(b_axis, te.thread_axis("blockIdx.x"))
+ sched[ph_c].bind(t_axis, te.thread_axis("threadIdx.x"))
+ fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "cuda", name="vector_add")
+ lib_path = tempfile.mktemp("tvm_add_cuda_dll.so")
+ fadd_dylib.export_library(lib_path)
+ return lib_path
+
+
+ def test_add(session, lib_path, tf_device):
+ """test add lib with TensorFlow wrapper"""
+ module = tf_op.OpModule(lib_path)
+
+ left = tf.placeholder("float32", shape=[4])
+ right = tf.placeholder("float32", shape=[4])
+
+ feed_dict = {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]}
+ expect = np.asarray([6.0, 8.0, 10.0, 12.0])
+
+ add1 = module.func("vector_add", output_shape=[4], output_dtype="float")
+ add2 = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")
+ add3 = module.func("vector_add", output_shape=[tf.shape(left)[0]], output_dtype="float")
+
+ with tf.device(tf_device):
+ output1 = session.run(add1(left, right), feed_dict)
+ np.testing.assert_equal(output1, expect)
+
+ output2 = session.run(add2(left, right), feed_dict)
+ np.testing.assert_equal(output2, expect)
+
+ output3 = session.run(add3(left, right), feed_dict)
+ np.testing.assert_equal(output3, expect)
+
+
+ def cpu_test(session):
+ """test function for cpu"""
+ cpu_lib = None
+ try:
+ cpu_lib = export_cpu_add_lib()
+ test_add(session, cpu_lib, "/cpu:0")
+ finally:
+ if cpu_lib is not None:
+ os.remove(cpu_lib)
+
+
+ def gpu_test(session):
+ """test function for gpu"""
+ gpu_lib = None
+ try:
+ gpu_lib = export_gpu_add_lib()
+ test_add(session, gpu_lib, "/gpu:0")
+ finally:
+ if gpu_lib is not None:
+ os.remove(gpu_lib)
+
+ with tf.Session() as session:
+ if tvm.runtime.enabled("cpu"):
+ logging.info("Test TensorFlow op on cpu kernel")
+ cpu_test(session)
+ if tvm.runtime.enabled("gpu"):
+ logging.info("Test TensorFlow op on gpu kernel")
+ gpu_test(session)
+
+
+if __name__ == "__main__":
+ test_use_tvmdso_op()
# Whether to use Thrust
set(USE_THRUST OFF)
+
+# Whether to build the TensorFlow TVMDSOOp module.
+# Any value other than OFF enables it (see cmake/modules/contrib/TF_TVMDSOOP.cmake).
+set(USE_TF_TVMDSOOP OFF)
+
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Builds the tvm_dso_op shared library (a TensorFlow custom op that calls into
+# TVM) unless USE_TF_TVMDSOOP is "OFF".
+#
+# Inputs read from the including scope:
+#   USE_CUDA               - when not "OFF", TF_TVMDSOOP_ENABLE_GPU is defined for the op sources
+#   BUILD_TVMDSOOP_ONLY    - when "ON", no build dependency on the `tvm` target is added
+#   TFTVM_COMPILE_FLAGS    - extra compile options for the op library
+#   CMAKE_CURRENT_SOURCE_DIR / CMAKE_CURRENT_BINARY_DIR - source glob root and -L link path
+if(NOT USE_TF_TVMDSOOP STREQUAL "OFF")
+  find_package(Python3 COMPONENTS Interpreter)
+  if(NOT Python3_EXECUTABLE)
+    message(FATAL_ERROR "A Python3 interpreter is required to query the TensorFlow build flags")
+  endif()
+
+  # Ask the installed TensorFlow for the flags needed to compile a custom op.
+  execute_process(
+    COMMAND "${Python3_EXECUTABLE}" -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
+    OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
+    RESULT_VARIABLE TF_STATUS
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+  # RESULT_VARIABLE may hold an error string rather than a number, so test the
+  # variable by name instead of expanding it into the condition.
+  if(NOT TF_STATUS EQUAL 0)
+    message(FATAL_ERROR "Fail to get TensorFlow compile flags")
+  endif()
+
+  # Same query for the link flags.
+  execute_process(
+    COMMAND "${Python3_EXECUTABLE}" -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
+    OUTPUT_VARIABLE TF_LINK_FLAGS_STR
+    RESULT_VARIABLE TF_STATUS
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+  if(NOT TF_STATUS EQUAL 0)
+    message(FATAL_ERROR "Fail to get TensorFlow link flags")
+  endif()
+
+  # Combined flag string, built from the raw query results. The list variables
+  # TF_COMPILE_FLAGS / TF_LINK_FLAGS are only defined by the calls below, so
+  # they cannot be used as the source here.
+  string(REGEX REPLACE "\n" " " TF_FLAGS "${TF_COMPILE_FLAGS_STR} ${TF_LINK_FLAGS_STR}")
+  # Quoting keeps separate_arguments well-formed even if a query printed nothing.
+  separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND "${TF_COMPILE_FLAGS_STR}")
+  separate_arguments(TF_LINK_FLAGS UNIX_COMMAND "${TF_LINK_FLAGS_STR}")
+
+  set(OP_LIBRARY_NAME tvm_dso_op)
+  file(GLOB_RECURSE TFTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/tf_op/*.cc)
+  add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
+  # Empty prefix: the output is tvm_dso_op.so, not libtvm_dso_op.so.
+  set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")
+  set(TFTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
+
+  # Scope the GPU switch to the op library so it does not leak to every other
+  # target declared in this directory.
+  if(NOT USE_CUDA STREQUAL "OFF")
+    target_compile_definitions(${OP_LIBRARY_NAME} PRIVATE TF_TVMDSOOP_ENABLE_GPU)
+  endif()
+
+  if(NOT BUILD_TVMDSOOP_ONLY STREQUAL "ON")
+    add_dependencies(${OP_LIBRARY_NAME} tvm)
+  endif()
+
+  target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
+  target_link_libraries(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})
+
+endif()
+
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of TensorFlow TVMDSO op"""
+from . import module
+
+# Re-export the container class so callers can write `tf_op.OpModule(lib_path)`.
+OpModule = module.OpModule
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of TensorFlow TVMDSO op"""
+import tensorflow as tf
+from tensorflow.python.framework import load_library
+
+
+class OpModule:
+ """Module container of TensorFlow TVMDSO op which wraps exported
+ TVM op implementation library to be called on TensorFlow side"""
+
+ def __init__(self, lib_path):
+ self.lib_path = lib_path
+
+ def func(self, name, output_dtype=None, output_shape=None):
+ """Get tvm op function wrapped as TensorFlow tensor to tensor function
+
+ Parameters
+ ----------
+ name: str
+ function name
+ output_dtype: str or TensorFlow datatype
+ Output datatype, default is float32
+ output_shape: List of integer/tf scalar tensor or tf shape tensor
+ Output shape, default the same with first input's shape
+
+ Returns
+ ----------
+ Func object that acts as TensorFlow tensor to tensor function.
+ """
+ return TensorFunc(self.lib_path, name, output_dtype, output_shape)
+
+ def __getitem__(self, func_name):
+ return self.func(func_name)
+
+
+class TensorFunc:
+ """Function object that acts as TensorFlow tensor to tensor function."""
+
+ def __init__(self, lib_path, func_name, output_dtype, output_shape):
+ self.lib_path = lib_path
+ self.func_name = func_name
+ self.output_dtype = output_dtype
+
+ # const(0) indicate invalid dynamic shape
+ self.dynamic_output_shape = tf.constant(0, tf.int64)
+ self.static_output_shape = None
+ self.has_static_output_shape = False # extra flag is required
+
+ if self._is_static_shape(output_shape):
+ self.static_output_shape = output_shape
+ self.has_static_output_shape = True
+ elif output_shape is not None:
+ self.dynamic_output_shape = self._pack_shape_tensor(output_shape)
+
+ self.module = load_library.load_op_library('tvm_dso_op.so')
+ self.tvm_dso_op = self.module.tvm_dso_op
+
+ def apply(self, *params):
+ return self.tvm_dso_op(params,
+ dynamic_output_shape=self.dynamic_output_shape,
+ static_output_shape=self.static_output_shape,
+ has_static_output_shape=self.has_static_output_shape,
+ lib_path=self.lib_path,
+ func_name=self.func_name,
+ output_dtype=self.output_dtype)
+
+ def __call__(self, *params):
+ return self.apply(*params)
+
+ def _is_static_shape(self, shape):
+ if shape is None or not isinstance(shape, list):
+ return False
+ for dim_value in shape:
+ if not isinstance(dim_value, int):
+ return False
+ if dim_value < 0:
+ raise Exception("Negative dimension is illegal: %d" % dim_value)
+ return True
+
+ def _pack_shape_tensor(self, shape):
+ if isinstance(shape, tf.Tensor):
+ if shape.dtype == tf.int32:
+ shape = tf.cast(shape, tf.int64)
+ elif isinstance(shape, list):
+ shape_dims = []
+ for dim_value in shape:
+ if isinstance(dim_value, int):
+ shape_dims.append(tf.constant(dim_value, tf.int64))
+ elif isinstance(dim_value, tf.Tensor) and dim_value.shape.rank == 0:
+ if dim_value.dtype == tf.int32:
+ dim_value = tf.cast(dim_value, tf.int64)
+ shape_dims.append(dim_value)
+ else:
+ raise TypeError("Input shape dimension is neither scalar tensor nor int")
+ shape = tf.stack(shape_dims)
+ else:
+ raise TypeError("Input shape is neither tensor nor list")
+ return shape
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+#include <cuda_runtime.h>
+#endif
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+typedef tensorflow::gtl::InlinedVector<tensorflow::int64, 4> ShapeContainer;
+
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+
+using tvm::runtime::TVMArgs;
+using tvm::runtime::TVMArgsSetter;
+using tvm::runtime::TVMRetValue;
+
+// Op utility trait, specialized per device type. Each specialization provides
+// device_type, device_id() and make_shape_from_tensor().
+template <typename DEVICE_TYPE>
+class TVMDSOOpTrait;
+
+// Buffer information used for actual computation.
+// Each buffer is associated with one TensorFlow tensor whose underlying
+// storage is recorded in "origin_buf". When a separate staging buffer is in
+// use (buf != origin_buf), input data is copied from origin_buf into
+// buf + offset before the call, and output data is copied back afterwards.
+class TensorAsBuf {
+ public:
+  // Staging tensor; filled in by EnsureAlignment when the original storage cannot be used.
+  tensorflow::Tensor inline_tensor;
+  // Tensor that owns "buf": either the original tensor or &inline_tensor.
+  tensorflow::Tensor* tensor = nullptr;
+
+  // Number of payload bytes, and where the payload starts inside "buf".
+  size_t size = 0;
+  size_t offset = 0;
+
+  // DLPack device type (kDLCPU or kDLGPU) that selects the copy routine.
+  int device_type = kDLCPU;
+
+  char* origin_buf = nullptr;
+  char* buf = nullptr;
+
+  // Copy the payload from the staging buffer back into the original tensor.
+  // No-op when the original storage is used directly.
+  void CopyToOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    CopyBytes(origin_buf, buf + offset);
+  }
+
+  // Copy the payload from the original tensor into the staging buffer.
+  // No-op when the original storage is used directly.
+  void CopyFromOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    CopyBytes(buf + offset, origin_buf);
+  }
+
+ private:
+  // Copy "size" bytes from src to dst on the device selected by device_type.
+  // A failed CUDA copy is fatal; silently continuing would hand stale or
+  // garbage data to the TVM function or to TensorFlow.
+  void CopyBytes(char* dst, const char* src) const {
+    if (device_type == kDLCPU) {
+      memcpy(dst, src, size);
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+    } else if (device_type == kDLGPU) {
+      cudaError_t err = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
+      if (err != cudaSuccess) {
+        LOG(FATAL) << "cudaMemcpy failed: " << cudaGetErrorString(err);
+      }
+#endif
+    } else {
+      LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type
+                 << " is not implemented currently";
+    }
+  }
+};
+
+// Translate the element type of a TensorFlow tensor into a DLPack data type.
+// Covers every type accepted by the TvmDsoOp registration (int8, int32, int64,
+// float16, float32); any other type yields an INTERNAL error status and
+// leaves *res untouched.
+tensorflow::Status GetDLPackDtype(const tensorflow::Tensor& tf_tensor, DLDataType* res) {
+  switch (tf_tensor.dtype()) {
+    case tensorflow::DT_FLOAT:
+      *res = {kDLFloat, 32, 1};
+      break;
+    case tensorflow::DT_HALF:
+      *res = {kDLFloat, 16, 1};
+      break;
+    case tensorflow::DT_INT64:
+      *res = {kDLInt, 64, 1};
+      break;
+    case tensorflow::DT_INT32:
+      *res = {kDLInt, 32, 1};
+      break;
+    case tensorflow::DT_INT8:
+      *res = {kDLInt, 8, 1};
+      break;
+    default:
+      return tensorflow::Status(tensorflow::error::INTERNAL, "Fail to get dlpack datatype");
+  }
+  return tensorflow::Status::OK();
+}
+
+// Make sure the buffer handed to the computation starts on a 64-byte boundary.
+// If the tensor's own storage is already aligned it is used directly
+// (out->buf == out->origin_buf, offset 0). Otherwise a temporary tensor with
+// TotalBytes() + 64 elements of the same dtype is allocated, and out->offset
+// records where the next 64-byte boundary lies inside it.
+void EnsureAlignment(OpKernelContext* ctx, const tensorflow::Tensor& tensor, TensorAsBuf* out) {
+  const int alignment = 64;
+  const uint64_t align_mask = ~static_cast<uint64_t>(alignment - 1);
+
+  char* origin = const_cast<char*>(tensor.tensor_data().data());
+  out->origin_buf = origin;
+  out->size = tensor.TotalBytes();
+
+  const uint64_t origin_addr = reinterpret_cast<uint64_t>(origin);
+  const bool already_aligned = ((origin_addr + alignment - 1) & align_mask) == origin_addr;
+  if (already_aligned) {
+    out->tensor = const_cast<tensorflow::Tensor*>(&tensor);
+    out->buf = origin;
+    out->offset = 0;
+    return;
+  }
+
+  // Staging tensor: one dimension, padded by `alignment` so an aligned start always fits.
+  tensorflow::int64 staging_dims[1] = {
+      static_cast<tensorflow::int64>(tensor.TotalBytes() + alignment)};
+  tensorflow::TensorShape staging_shape;
+  tensorflow::TensorShapeUtils::MakeShape(staging_dims, 1, &staging_shape);
+
+  out->tensor = &out->inline_tensor;
+  ctx->allocate_temp(tensor.dtype(), staging_shape, out->tensor);
+
+  char* staging = const_cast<char*>(out->tensor->tensor_data().data());
+  const uint64_t staging_addr = reinterpret_cast<uint64_t>(staging);
+  out->buf = staging;
+  // Always steps forward to the next boundary, even if `staging` is itself aligned.
+  out->offset = ((staging_addr + alignment) & align_mask) - staging_addr;
+}
+
+// Fill a DLPack tensor descriptor that views the buffer described by `src`.
+// `tf_shape` is stored as-is, so it must outlive `out`. Returns the error from
+// GetDLPackDtype when the element type has no DLPack equivalent.
+tensorflow::Status MakeDLTensor(const TensorAsBuf& src, const DLContext& ctx, int64_t* tf_shape,
+                                DLTensor* out) {
+  const tensorflow::Tensor& backing = *src.tensor;
+
+  DLDataType element_type;
+  const tensorflow::Status dtype_status = GetDLPackDtype(backing, &element_type);
+  if (!dtype_status.ok()) {
+    return dtype_status;
+  }
+
+  // The payload starts `offset` bytes into the buffer; byte_offset stays 0.
+  out->data = src.buf + src.offset;
+  out->byte_offset = 0;
+  out->ctx = ctx;
+  out->dtype = element_type;
+  // NOTE(review): ndim comes from src.tensor, which is the 1-D staging tensor
+  // when EnsureAlignment took its unaligned path, while tf_shape is supplied
+  // by the caller -- confirm the two always agree.
+  out->ndim = backing.shape().dims();
+  out->shape = tf_shape;
+  out->strides = nullptr;
+  return tensorflow::Status::OK();
+}
+
+// CPU specialization: device id is always 0 and the shape tensor is read in place.
+template <>
+class TVMDSOOpTrait<CPUDevice> {
+ public:
+  static const int device_type = kDLCPU;
+
+  static int device_id(OpKernelContext* context) { return 0; }
+
+  // Build `output_shape` from the int64 values stored in `shape_tensor`.
+  static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor,
+                                     tensorflow::TensorShape* output_shape) {
+    auto flat_dims = shape_tensor.flat<tensorflow::int64>();
+    tensorflow::TensorShapeUtils::MakeShape(flat_dims.data(), shape_tensor.NumElements(),
+                                            output_shape);
+  }
+};
+
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+// GPU specialization: the device id comes from TensorFlow's GPU device info and
+// the shape tensor is copied to the host before it is read.
+template <>
+class TVMDSOOpTrait<GPUDevice> {
+ public:
+  static const int device_type = kDLGPU;
+
+  static int device_id(OpKernelContext* context) {
+    auto device_base = context->device();
+    auto gpu_device_info = device_base->tensorflow_gpu_device_info();
+    return gpu_device_info->gpu_id;
+  }
+
+  // Build `output_shape` from the int64 values stored in `shape_tensor`.
+  // The values are copied with cudaMemcpyDeviceToHost into a host-side vector,
+  // which also releases the scratch memory correctly (the array was previously
+  // freed with scalar delete).
+  static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor,
+                                     tensorflow::TensorShape* output_shape) {
+    const tensorflow::int64 num_dims = shape_tensor.NumElements();
+    const tensorflow::int64* device_dims = shape_tensor.flat<tensorflow::int64>().data();
+    std::vector<tensorflow::int64> host_dims(num_dims);
+    cudaError_t err = cudaMemcpy(host_dims.data(), device_dims,
+                                 sizeof(tensorflow::int64) * num_dims, cudaMemcpyDeviceToHost);
+    if (err != cudaSuccess) {
+      // Without the copy the dimensions are garbage, so do not build a shape from them.
+      LOG(FATAL) << "cudaMemcpy failed: " << cudaGetErrorString(err);
+    }
+    tensorflow::TensorShapeUtils::MakeShape(host_dims.data(), num_dims, output_shape);
+  }
+};
+#endif
+
+// TensorFlow kernel that forwards its inputs to a TVM packed function loaded
+// from a dynamic library. The last op input is the dynamic output shape spec;
+// every other input becomes one DLTensor argument, followed by the output.
+template <typename DEVICE_TYPE>
+class TVMDSOOp : public OpKernel {
+ private:
+  tvm::runtime::PackedFunc tvm_func;
+  std::string lib_path;
+  std::string func_name;
+
+  tensorflow::DataType output_dtype;
+
+  bool has_static_output_shape;
+  std::vector<tensorflow::int64> static_output_shape;
+
+  // Read all op attributes; returns the first failing status so the
+  // constructor can report it instead of continuing with unset members.
+  tensorflow::Status initAttributes(OpKernelConstruction* context) {
+    tensorflow::Status status = context->GetAttr("lib_path", &lib_path);
+    if (!status.ok()) return status;
+    status = context->GetAttr("func_name", &func_name);
+    if (!status.ok()) return status;
+    status = context->GetAttr("output_dtype", &output_dtype);
+    if (!status.ok()) return status;
+    status = context->GetAttr("has_static_output_shape", &has_static_output_shape);
+    if (!status.ok()) return status;
+    return context->GetAttr("static_output_shape", &static_output_shape);
+  }
+
+ public:
+  explicit TVMDSOOp(OpKernelConstruction* context) : OpKernel(context) {
+    // Get attr
+    OP_REQUIRES_OK(context, initAttributes(context));
+
+    // Load TVM function from dynamic library
+    tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile(lib_path);
+    tvm_func = mod_dylib.GetFunction(func_name);
+    CHECK(tvm_func != nullptr) << "Function " << func_name << " not found in " << lib_path;
+  }
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    // the last input is output shape spec
+    const int num_inputs = context->num_inputs() - 1;
+    const int num_total_args = num_inputs + 1;
+    std::vector<DLTensor> args(num_total_args);
+    std::vector<TensorAsBuf> buf_info(num_inputs);
+    std::vector<ShapeContainer> shapes(num_inputs);
+
+    tensorflow::Status status;
+    int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
+    int device_type = TVMDSOOpTrait<DEVICE_TYPE>::device_type;
+
+    DLContext dl_ctx = {DLDeviceType(device_type), device_id};
+
+    // Get output shape
+    tensorflow::TensorShape output_shape;
+    auto& output_shape_tensor = context->input(num_inputs);
+    if (has_static_output_shape) {
+      // use static output shape
+      const tensorflow::int64* dims = static_output_shape.data();
+      OP_REQUIRES_OK(context, tensorflow::TensorShapeUtils::MakeShape(
+                                  dims, static_output_shape.size(), &output_shape));
+    } else if (output_shape_tensor.dims() == 1) {
+      // use shape tensor values as output shape
+      TVMDSOOpTrait<DEVICE_TYPE>::make_shape_from_tensor(output_shape_tensor, &output_shape);
+    } else {
+      // use input tensor shape by default
+      output_shape = context->input(0).shape();
+    }
+
+    for (int i = 0; i < num_inputs; ++i) {
+      // Grab the input tensor
+      auto& input_tensor = context->input(i);
+
+      // Create shape container, should keep ref during execution
+      shapes[i] = input_tensor.shape().dim_sizes();
+      auto shape_ptr = reinterpret_cast<int64_t*>(shapes[i].data());
+
+      TensorAsBuf& input = buf_info[i];
+      input.device_type = device_type;
+
+      EnsureAlignment(context, input_tensor, &input);
+      input.CopyFromOrigin();
+
+      status = MakeDLTensor(input, dl_ctx, shape_ptr, &args[i]);
+      OP_REQUIRES_OK(context, status);
+      // MakeDLTensor derives ndim from the buffer's backing tensor, which is a
+      // 1-D staging tensor when the input was not 64-byte aligned. The rank
+      // must match shape_ptr, i.e. the real input tensor.
+      args[i].ndim = input_tensor.dims();
+    }
+
+    // Allocate output tensor
+    tensorflow::Tensor* output_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
+    // shape dimension buf should keep alive on stack
+    auto output_shape_dim_buf = output_tensor->shape().dim_sizes();
+    auto output_shape_ptr = reinterpret_cast<int64_t*>(output_shape_dim_buf.data());
+
+    TensorAsBuf output;
+    output.device_type = device_type;
+    EnsureAlignment(context, *output_tensor, &output);
+
+    status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[num_inputs]);
+    OP_REQUIRES_OK(context, status);
+    // Same rank correction as for the inputs.
+    args[num_inputs].ndim = output_tensor->dims();
+
+    // Prepare PackedFunc arguments
+    std::vector<TVMValue> tvm_values(num_total_args);
+    std::vector<int> tvm_type_codes(num_total_args);
+    TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+    for (int k = 0; k < num_total_args; ++k) {
+      setter(k, &args[k]);
+    }
+    TVMRetValue rv;
+    tvm_func.CallPacked(TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);
+
+    // Copy the result back if the output went through a staging buffer.
+    output.CopyToOrigin();
+  }
+};
+
+// The CPU kernel is always registered; the GPU kernel is added only when
+// TF_TVMDSOOP_ENABLE_GPU is defined.
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>);
+#endif
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "tensorflow/core/framework/op.h"
+
+// Op schema for "TvmDsoOp".
+//   input_args              - list of input tensors; element types are given by ListT
+//   dynamic_output_shape    - int64 tensor, passed as the last op input
+//   output                  - single output whose type is output_dtype (DT_FLOAT by default)
+//   lib_path, func_name     - string attributes
+//   static_output_shape     - list of ints, empty by default
+//   has_static_output_shape - bool attribute with no default, so it must always be supplied
+REGISTER_OP("TvmDsoOp")
+ .Input("input_args: ListT")
+ .Attr("ListT: list({int8, int32, int64, float16, float32})")
+ .Input("dynamic_output_shape: int64")
+ .Output("output: output_dtype")
+ .Attr("lib_path: string")
+ .Attr("func_name: string")
+ .Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT")
+ .Attr("static_output_shape: list(int) >= 0 = []")
+ .Attr("has_static_output_shape: bool");
TVM_FFI=cython python3 -m pytest -v apps/dso_plugin_module
TVM_FFI=ctypes python3 -m pytest -v apps/dso_plugin_module
+# The TensorFlow op tests are disabled; the commands are kept below for reference.
+# TVM_FFI=cython sh prepare_and_test_tfop_module.sh
+# TVM_FFI=ctypes sh prepare_and_test_tfop_module.sh
TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib