==============================================================================*/
#include <tvm/runtime/c_runtime_api.h>
+
#include <cstddef>
#include <cstdint>
#endif
}
-void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst,
- size_t size) {
+void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, size_t size) {
float a_f, b_f;
BFloat16ToFloat(a, &a_f, 1);
BFloat16ToFloat(b, &b_f, 1);
#ifndef VTA_DE10_NANO_KERNEL_MODULE_CMA_H_
#define VTA_DE10_NANO_KERNEL_MODULE_CMA_H_
-
/* Should be defined in settings.mk file */
#ifndef CMA_IOCTL_MAGIC
-#define CMA_IOCTL_MAGIC 0xf2
+#define CMA_IOCTL_MAGIC 0xf2
#endif
+#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4)
+#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4)
+#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4)
+#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4)
+#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4)
-#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4)
-#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4)
-#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4)
-#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4)
-#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4)
-
-#define CMA_IOCTL_MAXNR 5
-
+#define CMA_IOCTL_MAXNR 5
#endif // VTA_DE10_NANO_KERNEL_MODULE_CMA_H_
* \brief Application layer implementation for contigous memory allocation.
*/
+#include <errno.h>
+#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
-#include <fcntl.h>
-#include <unistd.h>
-#include <errno.h>
#include <string.h>
-#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
+#include <sys/types.h>
+#include <unistd.h>
#include "cma_api.h"
#ifndef CMA_IOCTL_MAGIC
-#define CMA_IOCTL_MAGIC 0xf2
+#define CMA_IOCTL_MAGIC 0xf2
#endif
-#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4)
-#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4)
-#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4)
-#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4)
-#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4)
+#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4)
+#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4)
+#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4)
+#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4)
+#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4)
-#define CMA_IOCTL_MAXNR 5
+#define CMA_IOCTL_MAXNR 5
#ifndef CMA_DEBUG
- #define CMA_DEBUG 0
+#define CMA_DEBUG 0
#endif
#ifndef DRIVER_NODE_NAME
- #define DRIVER_NODE_NAME "cma"
+#define DRIVER_NODE_NAME "cma"
#endif
#if CMA_DEBUG == 1
- #define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args)
+#define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args)
#else
- #define __DEBUG(fmt, args...)
+#define __DEBUG(fmt, args...)
#endif
-#define ROUND_UP(N, S) ((((N) + (S) - 1) / (S)) * (S))
-
+#define ROUND_UP(N, S) ((((N) + (S)-1) / (S)) * (S))
/* Private functions */
-void *cma_alloc(size_t size, unsigned ioctl_cmd);
+void* cma_alloc(size_t size, unsigned ioctl_cmd);
/* Global file descriptor */
int cma_fd = 0;
return 0;
}
-void *cma_alloc_cached(size_t size) {
- return cma_alloc(size, CMA_ALLOC_CACHED);
-}
+void* cma_alloc_cached(size_t size) { return cma_alloc(size, CMA_ALLOC_CACHED); }
-void *cma_alloc_noncached(size_t size) {
- return cma_alloc(size, CMA_ALLOC_NONCACHED);
-}
+void* cma_alloc_noncached(size_t size) { return cma_alloc(size, CMA_ALLOC_NONCACHED); }
-int cma_free(void *mem) {
+int cma_free(void* mem) {
__DEBUG("Releasing contigous memory from 0x%x\n", (unsigned)mem);
unsigned data, v_addr;
/* save user space pointer value */
- data = (unsigned)mem;
+ data = (unsigned)mem;
v_addr = (unsigned)mem;
- if ( ioctl(cma_fd, CMA_GET_SIZE, &data) == -1 ) {
+ if (ioctl(cma_fd, CMA_GET_SIZE, &data) == -1) {
__DEBUG("cma_free - ioctl command unsuccsessful - 0\n");
return -1;
}
munmap(mem, data);
/* free cma entry */
- if ( ioctl(cma_fd, CMA_FREE, &v_addr) == -1 ) {
+ if (ioctl(cma_fd, CMA_FREE, &v_addr) == -1) {
__DEBUG("cma_free - ioctl command unsuccsessful - 1\n");
return -1;
}
return 0;
}
-unsigned cma_get_phy_addr(void *mem) {
+unsigned cma_get_phy_addr(void* mem) {
unsigned data;
__DEBUG("Getting physical address from 0x%x\n", (unsigned)mem);
data = (unsigned)mem;
/* get physical address */
- if ( ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1 ) {
+ if (ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1) {
__DEBUG("cma_free - ioctl command unsuccsessful\n");
return 0;
}
return data;
}
-
-void *cma_alloc(size_t size, unsigned ioctl_cmd) {
+void* cma_alloc(size_t size, unsigned ioctl_cmd) {
unsigned data;
- void *mem;
+ void* mem;
__DEBUG("Allocating 0x%x bytes of contigous memory\n", size);
/* Page align size */
/* ioctl cmd to allocate contigous memory */
data = (unsigned)size;
- if ( ioctl(cma_fd, ioctl_cmd, &data) == -1 ) {
+ if (ioctl(cma_fd, ioctl_cmd, &data) == -1) {
__DEBUG("cma_alloc - ioctl command unsuccsessful\n");
return NULL;
}
int n = 32;
uint32_t y;
- y = x >>16; if (y) { n = n -16; x = y; }
- y = x >> 8; if (y) { n = n - 8; x = y; }
- y = x >> 4; if (y) { n = n - 4; x = y; }
- y = x >> 2; if (y) { n = n - 2; x = y; }
- y = x >> 1; if (y) return n - 2;
+ y = x >> 16;
+ if (y) {
+ n = n - 16;
+ x = y;
+ }
+ y = x >> 8;
+ if (y) {
+ n = n - 8;
+ x = y;
+ }
+ y = x >> 4;
+ if (y) {
+ n = n - 4;
+ x = y;
+ }
+ y = x >> 2;
+ if (y) {
+ n = n - 2;
+ x = y;
+ }
+ y = x >> 1;
+ if (y) return n - 2;
return n - x;
}
-template <typename SRC_T, typename SRC_REP_T, int SRC_SIG_BITS,
- typename DST_T, typename DST_REP_T, int DST_SIG_BITS>
+template <typename SRC_T, typename SRC_REP_T, int SRC_SIG_BITS, typename DST_T, typename DST_REP_T,
+ int DST_SIG_BITS>
static inline DST_T __truncXfYf2__(SRC_T a) {
// Various constants whose values follow from the type parameters.
// Any reasonable optimizer will fold and propagate all of these.
const DST_REP_T dstNaNCode = dstQNaN - 1;
// Break a into a sign and representation of the absolute value
- union SrcExchangeType { SRC_T f; SRC_REP_T i; };
+ union SrcExchangeType {
+ SRC_T f;
+ SRC_REP_T i;
+ };
SrcExchangeType src_rep;
src_rep.f = a;
const SRC_REP_T aRep = src_rep.i;
const SRC_REP_T roundBits = aAbs & roundMask;
// Round to nearest
- if (roundBits > halfway)
- absResult++;
- // Ties to even
+ if (roundBits > halfway) absResult++;
+ // Ties to even
else if (roundBits == halfway)
absResult += absResult & 1;
- }
- else if (aAbs > srcInfinity) {
+ } else if (aAbs > srcInfinity) {
// a is NaN.
// Conjure the result by beginning with infinity, setting the qNaN
// bit and inserting the (truncated) trailing NaN field.
absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS;
absResult |= dstQNaN;
absResult |= ((aAbs & srcNaNCode) >> (SRC_SIG_BITS - DST_SIG_BITS)) & dstNaNCode;
- }
- else if (aAbs >= overflow) {
+ } else if (aAbs >= overflow) {
// a overflows to infinity.
absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS;
- }
- else {
+ } else {
// a underflows on conversion to the destination type or is an exact
// zero. The result may be a denormal or zero. Extract the exponent
// to get the shift amount for the denormalization.
absResult = denormalizedSignificand >> (SRC_SIG_BITS - DST_SIG_BITS);
const SRC_REP_T roundBits = denormalizedSignificand & roundMask;
// Round to nearest
- if (roundBits > halfway)
- absResult++;
- // Ties to even
+ if (roundBits > halfway) absResult++;
+ // Ties to even
else if (roundBits == halfway)
absResult += absResult & 1;
}
// Apply the signbit to (DST_T)abs(a).
const DST_REP_T result = absResult | sign >> (srcBits - dstBits);
- union DstExchangeType { DST_T f; DST_REP_T i; };
+ union DstExchangeType {
+ DST_T f;
+ DST_REP_T i;
+ };
DstExchangeType dst_rep;
dst_rep.i = result;
return dst_rep.f;
}
-template<typename SRC_T, typename SRC_REP_T, int SRC_SIG_BITS,
- typename DST_T, typename DST_REP_T, int DST_SIG_BITS>
+template <typename SRC_T, typename SRC_REP_T, int SRC_SIG_BITS, typename DST_T, typename DST_REP_T,
+ int DST_SIG_BITS>
static inline DST_T __extendXfYf2__(SRC_T a) {
// Various constants whose values follow from the type parameters.
// Any reasonable optimizer will fold and propagate all of these.
const SRC_REP_T srcQNaN = SRC_REP_T(1) << (SRC_SIG_BITS - 1);
const SRC_REP_T srcNaNCode = srcQNaN - 1;
- const int dstBits = sizeof(DST_T)*8;
+ const int dstBits = sizeof(DST_T) * 8;
const int dstExpBits = dstBits - DST_SIG_BITS - 1;
const int dstInfExp = (1 << dstExpBits) - 1;
const int dstExpBias = dstInfExp >> 1;
const DST_REP_T dstMinNormal = DST_REP_T(1) << DST_SIG_BITS;
// Break a into a sign and representation of the absolute value
- union SrcExchangeType { SRC_T f; SRC_REP_T i; };
+ union SrcExchangeType {
+ SRC_T f;
+ SRC_REP_T i;
+ };
SrcExchangeType src_rep;
src_rep.f = a;
const SRC_REP_T aRep = src_rep.i;
absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS;
absResult |= (DST_REP_T)(aAbs & srcQNaN) << (DST_SIG_BITS - SRC_SIG_BITS);
absResult |= (DST_REP_T)(aAbs & srcNaNCode) << (DST_SIG_BITS - SRC_SIG_BITS);
- }
- else if (aAbs) {
+ } else if (aAbs) {
// a is denormal.
// renormalize the significand and clear the leading bit, then insert
// the correct adjusted exponent in the destination type.
absResult ^= dstMinNormal;
const int resultExponent = dstExpBias - srcExpBias - scale + 1;
absResult |= (DST_REP_T)resultExponent << DST_SIG_BITS;
- }
- else {
+ } else {
// a is zero.
absResult = 0;
}
// Apply the signbit to (DST_T)abs(a).
const DST_REP_T result = absResult | (DST_REP_T)sign << (dstBits - srcBits);
- union DstExchangeType { DST_T f; DST_REP_T i; };
+ union DstExchangeType {
+ DST_T f;
+ DST_REP_T i;
+ };
DstExchangeType dst_rep;
dst_rep.i = result;
return dst_rep.f;
* \brief Pack all tvm runtime source files
*/
#include <sys/stat.h>
+
#include <fstream>
/* Enable custom logging - this will cause TVM to pass every log message
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
-#include "../src/runtime/workspace_pool.cc"
+#include "../src/runtime/dso_library.cc"
+#include "../src/runtime/file_util.cc"
+#include "../src/runtime/graph/graph_runtime.cc"
#include "../src/runtime/library_module.cc"
-#include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc"
+#include "../src/runtime/ndarray.cc"
+#include "../src/runtime/object.cc"
#include "../src/runtime/registry.cc"
-#include "../src/runtime/file_util.cc"
-#include "../src/runtime/dso_library.cc"
-#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
-#include "../src/runtime/rpc/rpc_server_env.cc"
#include "../src/runtime/rpc/rpc_module.cc"
+#include "../src/runtime/rpc/rpc_server_env.cc"
+#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_socket_impl.cc"
+#include "../src/runtime/system_library.cc"
#include "../src/runtime/thread_pool.cc"
#include "../src/runtime/threading_backend.cc"
-#include "../src/runtime/graph/graph_runtime.cc"
-#include "../src/runtime/ndarray.cc"
-#include "../src/runtime/object.cc"
+#include "../src/runtime/workspace_pool.cc"
#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
#include "../src/runtime/contrib/sort/sort.cc"
#endif
-
#include <android/log.h>
void dmlc::CustomLogMessage::Log(const std::string& msg) {
* \brief Pack all tvm runtime source files
*/
#include <sys/stat.h>
+
#include <fstream>
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
-#include "../src/runtime/workspace_pool.cc"
+#include "../src/runtime/dso_library.cc"
+#include "../src/runtime/file_util.cc"
+#include "../src/runtime/graph/graph_runtime.cc"
#include "../src/runtime/library_module.cc"
-#include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc"
+#include "../src/runtime/ndarray.cc"
+#include "../src/runtime/object.cc"
#include "../src/runtime/registry.cc"
-#include "../src/runtime/file_util.cc"
-#include "../src/runtime/dso_library.cc"
+#include "../src/runtime/system_library.cc"
#include "../src/runtime/thread_pool.cc"
-#include "../src/runtime/object.cc"
#include "../src/runtime/threading_backend.cc"
-#include "../src/runtime/ndarray.cc"
-
-#include "../src/runtime/graph/graph_runtime.cc"
+#include "../src/runtime/workspace_pool.cc"
#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
* \brief Pack all tvm runtime source files
*/
#include <sys/stat.h>
+
#include <fstream>
/* Enable custom logging - this will cause TVM to pass every log message
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
-#include "../src/runtime/workspace_pool.cc"
+#include "../src/runtime/dso_library.cc"
+#include "../src/runtime/file_util.cc"
+#include "../src/runtime/graph/graph_runtime.cc"
#include "../src/runtime/library_module.cc"
-#include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc"
+#include "../src/runtime/ndarray.cc"
+#include "../src/runtime/object.cc"
#include "../src/runtime/registry.cc"
-#include "../src/runtime/file_util.cc"
-#include "../src/runtime/dso_library.cc"
-#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
-#include "../src/runtime/rpc/rpc_server_env.cc"
#include "../src/runtime/rpc/rpc_module.cc"
+#include "../src/runtime/rpc/rpc_server_env.cc"
+#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_socket_impl.cc"
+#include "../src/runtime/system_library.cc"
#include "../src/runtime/thread_pool.cc"
#include "../src/runtime/threading_backend.cc"
-#include "../src/runtime/graph/graph_runtime.cc"
-#include "../src/runtime/ndarray.cc"
-#include "../src/runtime/object.cc"
+#include "../src/runtime/workspace_pool.cc"
#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
#include "../src/runtime/contrib/sort/sort.cc"
#endif
-
#include <android/log.h>
void dmlc::CustomLogMessage::Log(const std::string& msg) {
* under the License.
*/
-#include <memory>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
+#include <memory>
+
#define TVM_BUNDLE_FUNCTION __attribute__((visibility("default")))
extern "C" {
-TVM_BUNDLE_FUNCTION void *tvm_runtime_create(const char * build_graph_json,
- const char * build_params_bin,
+TVM_BUNDLE_FUNCTION void* tvm_runtime_create(const char* build_graph_json,
+ const char* build_params_bin,
const uint64_t build_params_bin_len) {
const int build_graph_json_len = strlen(build_graph_json);
- const std::string json_data(&build_graph_json[0],
- &build_graph_json[0] + build_graph_json_len);
- tvm::runtime::Module mod_syslib =
- (*tvm::runtime::Registry::Get("runtime.SystemLib"))();
+ const std::string json_data(&build_graph_json[0], &build_graph_json[0] + build_graph_json_len);
+ tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("runtime.SystemLib"))();
int device_type = kDLCPU;
int device_id = 0;
- tvm::runtime::Module mod =
- (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))(
- json_data, mod_syslib, device_type, device_id);
+ tvm::runtime::Module mod = (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))(
+ json_data, mod_syslib, device_type, device_id);
TVMByteArray params;
- params.data = reinterpret_cast<const char *>(&build_params_bin[0]);
+ params.data = reinterpret_cast<const char*>(&build_params_bin[0]);
params.size = build_params_bin_len;
mod.GetFunction("load_params")(params);
return new tvm::runtime::Module(mod);
}
-TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void *handle) {
- delete reinterpret_cast<tvm::runtime::Module *>(handle);
+TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void* handle) {
+ delete reinterpret_cast<tvm::runtime::Module*>(handle);
}
-TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void *handle, const char *name,
- void *tensor) {
- reinterpret_cast<tvm::runtime::Module *>(handle)->GetFunction("set_input")(
- name, reinterpret_cast<DLTensor *>(tensor));
+TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void* handle, const char* name, void* tensor) {
+ reinterpret_cast<tvm::runtime::Module*>(handle)->GetFunction("set_input")(
+ name, reinterpret_cast<DLTensor*>(tensor));
}
-TVM_BUNDLE_FUNCTION void tvm_runtime_run(void *handle) {
- reinterpret_cast<tvm::runtime::Module *>(handle)->GetFunction("run")();
+TVM_BUNDLE_FUNCTION void tvm_runtime_run(void* handle) {
+ reinterpret_cast<tvm::runtime::Module*>(handle)->GetFunction("run")();
}
-TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void *handle, int index,
- void *tensor) {
- reinterpret_cast<tvm::runtime::Module *>(handle)->GetFunction("get_output")(
- index, reinterpret_cast<DLTensor *>(tensor));
+TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void* handle, int index, void* tensor) {
+ reinterpret_cast<tvm::runtime::Module*>(handle)->GetFunction("get_output")(
+ index, reinterpret_cast<DLTensor*>(tensor));
}
}
#include <tvm/runtime/c_runtime_api.h>
-TVM_DLL void * tvm_runtime_create(const char * json_data,
- const char * params_data,
- const uint64_t params_size);
+TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data,
+ const uint64_t params_size);
-TVM_DLL void tvm_runtime_destroy(void * runtime);
+TVM_DLL void tvm_runtime_destroy(void* runtime);
-TVM_DLL void tvm_runtime_set_input(void * runtime,
- const char * name,
- DLTensor * tensor);
+TVM_DLL void tvm_runtime_set_input(void* runtime, const char* name, DLTensor* tensor);
-TVM_DLL void tvm_runtime_run(void * runtime);
+TVM_DLL void tvm_runtime_run(void* runtime);
-TVM_DLL void tvm_runtime_get_output(void * runtime,
- int32_t index,
- DLTensor * tensor);
+TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tensor);
#endif /* TVM_APPS_BUNDLE_DEPLOY_BUNDLE_H_ */
* under the License.
*/
+#include <assert.h>
+#include <dlfcn.h> //dlopen
+#include <sys/time.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <assert.h>
-#include <dlfcn.h> //dlopen
#include <iostream>
#include <random>
#include <vector>
-#include <sys/time.h>
#include "build/graph.json.c"
#include "build/params.bin.c"
-template <typename F> auto getFunc(void *bundle, const char *name) {
+template <typename F>
+auto getFunc(void* bundle, const char* name) {
dlerror();
- auto *f =
- reinterpret_cast<typename std::add_pointer<F>::type>(dlsym(bundle, name));
+ auto* f = reinterpret_cast<typename std::add_pointer<F>::type>(dlsym(bundle, name));
assert(!dlerror());
return f;
}
-int main(int argc, char **argv) {
+int main(int argc, char** argv) {
assert(argc == 3 && "Usage: demo <bundle.so> <cat.bin>");
- auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL);
+ auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL);
assert(bundle);
- char * json_data = reinterpret_cast<char*>(build_graph_json);
- char * params_data = reinterpret_cast<char*>(build_params_bin);
+ char* json_data = reinterpret_cast<char*>(build_graph_json);
+ char* params_data = reinterpret_cast<char*>(build_params_bin);
uint64_t params_size = build_params_bin_len;
struct timeval t0, t1, t2, t3, t4, t5;
gettimeofday(&t0, 0);
- auto *handle = getFunc<void *(char*, char*, int)>(bundle, "tvm_runtime_create")(
+ auto* handle = getFunc<void*(char*, char*, int)>(bundle, "tvm_runtime_create")(
json_data, params_data, params_size);
gettimeofday(&t1, 0);
float input_storage[1 * 3 * 224 * 224];
- FILE * fp = fopen(argv[2], "rb");
+ FILE* fp = fopen(argv[2], "rb");
fread(input_storage, 3 * 224 * 224, 4, fp);
fclose(fp);
input.strides = nullptr;
input.byte_offset = 0;
- getFunc<void(void *, const char *, void *)>(bundle, "tvm_runtime_set_input")(
- handle, "data", &input);
+ getFunc<void(void*, const char*, void*)>(bundle, "tvm_runtime_set_input")(handle, "data", &input);
gettimeofday(&t2, 0);
- auto *ftvm_runtime_run =
- (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run");
+ auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run");
assert(!dlerror());
ftvm_runtime_run(handle);
gettimeofday(&t3, 0);
output.strides = nullptr;
output.byte_offset = 0;
- getFunc<void(void *, int, void *)>(bundle, "tvm_runtime_get_output")(
- handle, 0, &output);
+ getFunc<void(void*, int, void*)>(bundle, "tvm_runtime_get_output")(handle, 0, &output);
gettimeofday(&t4, 0);
float max_iter = -std::numeric_limits<float>::max();
}
}
- getFunc<void(void *)>(bundle, "tvm_runtime_destroy")(handle);
+ getFunc<void(void*)>(bundle, "tvm_runtime_destroy")(handle);
gettimeofday(&t5, 0);
- printf("The maximum position in output vector is: %d, with max-value %f.\n",
- max_index, max_iter);
- printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), "
- "%.2f ms (get_output), %.2f ms (destroy)\n",
- (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f,
- (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f,
- (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f,
- (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f,
- (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f);
+ printf("The maximum position in output vector is: %d, with max-value %f.\n", max_index, max_iter);
+ printf(
+ "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), "
+ "%.2f ms (get_output), %.2f ms (destroy)\n",
+ (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f,
+ (t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f,
+ (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f,
+ (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f,
+ (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f);
dlclose(bundle);
-
+
return 0;
}
#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc"
-#include "../../src/runtime/workspace_pool.cc"
+#include "../../src/runtime/file_util.cc"
+#include "../../src/runtime/graph/graph_runtime.cc"
#include "../../src/runtime/library_module.cc"
#include "../../src/runtime/module.cc"
-#include "../../src/runtime/registry.cc"
-#include "../../src/runtime/file_util.cc"
-#include "../../src/runtime/threading_backend.cc"
-#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
+#include "../../src/runtime/registry.cc"
#include "../../src/runtime/system_library.cc"
-#include "../../src/runtime/graph/graph_runtime.cc"
+#include "../../src/runtime/thread_pool.cc"
+#include "../../src/runtime/threading_backend.cc"
+#include "../../src/runtime/workspace_pool.cc"
* under the License.
*/
+#include <assert.h>
+#include <dlfcn.h> //dlopen
+#include <sys/stat.h>
+#include <sys/time.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <assert.h>
-#include <dlfcn.h> //dlopen
#include <iostream>
#include <random>
#include <vector>
-#include <sys/time.h>
-#include <sys/stat.h>
-template <typename F> auto getFunc(void *bundle, const char *name) {
+template <typename F>
+auto getFunc(void* bundle, const char* name) {
dlerror();
- auto *f =
- reinterpret_cast<typename std::add_pointer<F>::type>(dlsym(bundle, name));
+ auto* f = reinterpret_cast<typename std::add_pointer<F>::type>(dlsym(bundle, name));
assert(!dlerror());
return f;
}
-int main(int argc, char **argv) {
+int main(int argc, char** argv) {
assert(argc == 6 && "Usage: test <bundle.so> <data.bin> <output.bin> <graph.json> <params.bin>");
- auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL);
+ auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL);
assert(bundle);
struct stat st;
- char * json_data;
- char * params_data;
+ char* json_data;
+ char* params_data;
uint64_t params_size;
- FILE * fp = fopen(argv[4], "rb");
+ FILE* fp = fopen(argv[4], "rb");
stat(argv[4], &st);
json_data = (char*)malloc(st.st_size);
fread(json_data, st.st_size, 1, fp);
struct timeval t0, t1, t2, t3, t4, t5;
gettimeofday(&t0, 0);
- auto *handle = getFunc<void *(char*, char*, int)>(bundle, "tvm_runtime_create")(
+ auto* handle = getFunc<void*(char*, char*, int)>(bundle, "tvm_runtime_create")(
json_data, params_data, params_size);
gettimeofday(&t1, 0);
input.strides = nullptr;
input.byte_offset = 0;
- getFunc<void(void *, const char *, void *)>(bundle, "tvm_runtime_set_input")(
- handle, "x", &input);
+ getFunc<void(void*, const char*, void*)>(bundle, "tvm_runtime_set_input")(handle, "x", &input);
gettimeofday(&t2, 0);
- auto *ftvm_runtime_run =
- (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run");
+ auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run");
assert(!dlerror());
ftvm_runtime_run(handle);
gettimeofday(&t3, 0);
output.strides = nullptr;
output.byte_offset = 0;
- getFunc<void(void *, int, void *)>(bundle, "tvm_runtime_get_output")(
- handle, 0, &output);
+ getFunc<void(void*, int, void*)>(bundle, "tvm_runtime_get_output")(handle, 0, &output);
gettimeofday(&t4, 0);
for (auto i = 0; i < 10 * 5; ++i) {
}
}
- getFunc<void(void *)>(bundle, "tvm_runtime_destroy")(handle);
+ getFunc<void(void*)>(bundle, "tvm_runtime_destroy")(handle);
gettimeofday(&t5, 0);
- printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), "
- "%.2f ms (get_output), %.2f ms (destroy)\n",
- (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f,
- (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f,
- (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f,
- (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f,
- (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f);
+ printf(
+ "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), "
+ "%.2f ms (get_output), %.2f ms (destroy)\n",
+ (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f,
+ (t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f,
+ (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f,
+ (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f,
+ (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f);
free(json_data);
free(params_data);
dlclose(bundle);
-
+
return 0;
}
* \file rpc_server.cc
* \brief RPC Server for TVM.
*/
-#include <cstdlib>
#include <csignal>
#include <cstdio>
+#include <cstdlib>
#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h>
#endif
#include <dmlc/logging.h>
-#include <iostream>
+
#include <cstring>
-#include <vector>
+#include <iostream>
#include <sstream>
+#include <vector>
-#include "../../src/support/util.h"
#include "../../src/support/socket.h"
+#include "../../src/support/util.h"
#include "rpc_server.h"
#if defined(_WIN32)
using namespace tvm::runtime;
using namespace tvm::support;
-static const string kUsage = \
-"Command line usage\n" \
-" server - Start the server\n" \
-"--host - The hostname of the server, Default=0.0.0.0\n" \
-"--port - The port of the RPC, Default=9090\n" \
-"--port-end - The end search port of the RPC, Default=9199\n" \
-"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \
-"--key - The key used to identify the device type in tracker. Default=\"\"\n" \
-"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \
-"--silent - Whether to run in silent mode. Default=False\n" \
-"\n" \
-" Example\n" \
-" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 "
-" --tracker=127.0.0.1:9190 --key=rasp" \
-"\n";
+static const string kUsage =
+ "Command line usage\n"
+ " server - Start the server\n"
+ "--host - The hostname of the server, Default=0.0.0.0\n"
+ "--port - The port of the RPC, Default=9090\n"
+ "--port-end - The end search port of the RPC, Default=9199\n"
+ "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
+ "--key - The key used to identify the device type in tracker. Default=\"\"\n"
+ "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
+ "--silent - Whether to run in silent mode. Default=False\n"
+ "\n"
+ " Example\n"
+ " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 "
+ " --tracker=127.0.0.1:9190 --key=rasp"
+ "\n";
/*!
* \brief RpcServerArgs.
LOG(INFO) << "tracker = " << args.tracker;
LOG(INFO) << "key = " << args.key;
LOG(INFO) << "custom_addr = " << args.custom_addr;
- LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
+ LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False"));
}
#if defined(__linux__) || defined(__ANDROID__)
* \param tracker The tracker input.
* \return result of operation.
*/
-bool ValidateTracker(string &tracker) {
+bool ValidateTracker(string& tracker) {
vector<string> list = Split(tracker, ':');
if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) {
return false;
* \param argv arg values
* \param args the output structure which holds the parsed values
*/
-void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
+void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
const string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) {
args.silent = true;
}
#if defined(WIN32)
const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
- if(!mmap_path.empty()) {
+ if (!mmap_path.empty()) {
args.mmap_path = mmap_path;
dmlc::InitLogging("--minloglevel=0");
}
#endif
-
}
/*!
* \param argv arg values
* \return result of operation.
*/
-int RpcServer(int argc, char * argv[]) {
+int RpcServer(int argc, char* argv[]) {
RpcServerArgs args;
/* parse the command line args */
#endif
#if defined(WIN32)
- if(!args.mmap_path.empty()) {
+ if (!args.mmap_path.empty()) {
int ret = 0;
try {
- ChildProcSocketHandler(args.mmap_path);
+ ChildProcSocketHandler(args.mmap_path);
} catch (const std::exception&) {
- ret = -1;
+ ret = -1;
}
return ret;
}
#endif
- RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
- args.key, args.custom_addr, args.silent);
+ RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
+ args.silent);
return 0;
}
* \param argv arg values
* \return result of operation.
*/
-int main(int argc, char * argv[]) {
+int main(int argc, char* argv[]) {
if (argc <= 1) {
LOG(INFO) << kUsage;
return 0;
* \file rpc_env.cc
* \brief Server environment of the RPC.
*/
-#include <cerrno>
#include <tvm/runtime/registry.h>
+
+#include <cerrno>
#ifndef _WIN32
#include <dirent.h>
#include <sys/stat.h>
#include <Windows.h>
#include <direct.h>
namespace {
- int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
-}
+int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
+} // namespace
#endif
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
-#include <string>
-#include "../../src/support/util.h"
#include "../../src/runtime/file_util.h"
+#include "../../src/support/util.h"
#include "rpc_env.h"
namespace {
- std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) {
- std::string untar_cmd;
- untar_cmd.reserve(512);
+std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) {
+ std::string untar_cmd;
+ untar_cmd.reserve(512);
#if defined(__linux__) || defined(__ANDROID__)
- untar_cmd += "tar -C ";
- untar_cmd += output_dir;
- untar_cmd += " -zxf ";
- untar_cmd += tar_file;
+ untar_cmd += "tar -C ";
+ untar_cmd += output_dir;
+ untar_cmd += " -zxf ";
+ untar_cmd += tar_file;
#elif defined(_WIN32)
- untar_cmd += "python -m tarfile -e ";
- untar_cmd += tar_file;
- untar_cmd += " ";
- untar_cmd += output_dir;
+ untar_cmd += "python -m tarfile -e ";
+ untar_cmd += tar_file;
+ untar_cmd += " ";
+ untar_cmd += output_dir;
#endif
- return untar_cmd;
- }
+ return untar_cmd;
+}
-}// Anonymous namespace
+} // Anonymous namespace
namespace tvm {
namespace runtime {
RPCEnv::RPCEnv() {
#ifndef _WIN32
char cwd[PATH_MAX];
- if (char *rc = getcwd(cwd, sizeof(cwd))) {
+ if (char* rc = getcwd(cwd, sizeof(cwd))) {
base_ = std::string(cwd) + "/rpc";
} else {
base_ = "./rpc";
* \param options The compiler options
* \param cc The compiler
*/
-void LinuxShared(const std::string output,
- const std::vector<std::string> &files,
- std::string options = "",
- std::string cc = "g++") {
- std::string cmd = cc;
- cmd += " -shared -fPIC ";
- cmd += " -o " + output;
- for (auto f = files.begin(); f != files.end(); ++f) {
- cmd += " " + *f;
- }
- cmd += " " + options;
- std::string err_msg;
- auto executed_status = support::Execute(cmd, &err_msg);
- if (executed_status) {
- LOG(FATAL) << err_msg;
- }
+void LinuxShared(const std::string output, const std::vector<std::string>& files,
+ std::string options = "", std::string cc = "g++") {
+ std::string cmd = cc;
+ cmd += " -shared -fPIC ";
+ cmd += " -o " + output;
+ for (auto f = files.begin(); f != files.end(); ++f) {
+ cmd += " " + *f;
+ }
+ cmd += " " + options;
+ std::string err_msg;
+ auto executed_status = support::Execute(cmd, &err_msg);
+ if (executed_status) {
+ LOG(FATAL) << err_msg;
+ }
}
#endif
* \param options The compiler options
* \param cc The compiler
*/
-void WindowsShared(const std::string& output,
- const std::vector<std::string>& files,
- const std::string& options = "",
- const std::string& cc = "clang") {
+void WindowsShared(const std::string& output, const std::vector<std::string>& files,
+ const std::string& options = "", const std::string& cc = "clang") {
std::string cmd = cc;
cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared ";
cmd += " -o " + output;
* \param fmt The format of file
* \return Module The loaded module
*/
-Module Load(std::string *fileIn, const std::string& fmt) {
+Module Load(std::string* fileIn, const std::string& fmt) {
const std::string& file = *fileIn;
if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) {
return Module::LoadFromFile(file, fmt);
#define TVM_APPS_CPP_RPC_ENV_H_
#include <tvm/runtime/registry.h>
+
#include <string>
namespace tvm {
* \param file The format of file
* \return Module The loaded module
*/
-Module Load(std::string *path, const std::string& fmt = "");
+Module Load(std::string* path, const std::string& fmt = "");
/*!
* \brief CleanDir Removes the files from the directory
* \param dirname THe name of the directory
*/
-void CleanDir(const std::string &dirname);
+void CleanDir(const std::string& dirname);
/*!
* \brief RPCEnv The RPC Environment parameters for c++ rpc server
#include <set>
#include <string>
-#include "../../src/support/socket.h"
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h"
+#include "../../src/support/socket.h"
#include "rpc_env.h"
#include "rpc_server.h"
#include "rpc_tracker_client.h"
while (end < len && !isspace(str[end])) end++;
iss->seekg(end);
- return str.substr(start, end-start);
+ return str.substr(start, end - start);
}
#endif
/*!
* \brief Constructor.
*/
- RPCServer(std::string host, int port, int port_end, std::string tracker_addr,
- std::string key, std::string custom_addr) :
- host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end),
- tracker_addr_(std::move(tracker_addr)), key_(std::move(key)),
- custom_addr_(std::move(custom_addr))
- {
-
- }
+ RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key,
+ std::string custom_addr)
+ : host_(std::move(host)),
+ port_(port),
+ my_port_(0),
+ port_end_(port_end),
+ tracker_addr_(std::move(tracker_addr)),
+ key_(std::move(key)),
+ custom_addr_(std::move(custom_addr)) {}
/*!
* \brief Destructor.
// Free the resources
tracker_sock_.Close();
listen_sock_.Close();
- } catch(...) {
-
+ } catch (...) {
}
}
try {
SpawnRPCChild(conn.sockfd, seconds(timeout));
} catch (const std::exception&) {
-
}
auto dur = high_resolution_clock::now() - start_time;
* \param opts Parsed options for socket
* \param ping_period Timeout for select call waiting
*/
- void AcceptConnection(TrackerClient* tracker,
- support::TCPSocket* conn_sock,
- support::SockAddr* addr,
- std::string* opts,
- int ping_period = 2) {
+ void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock,
+ support::SockAddr* addr, std::string* opts, int ping_period = 2) {
std::set<std::string> old_keyset;
std::string matchkey;
support::TCPSocket conn = listen_sock_.Accept(addr);
int code = kRPCMagic;
- CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
+ CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) {
conn.Close();
LOG(FATAL) << "Client connected is not TVM RPC server";
#if defined(WIN32)
/*!
-* \brief ServerLoopFromChild The Server loop process.
-* \param socket The socket information
-*/
+ * \brief ServerLoopFromChild The Server loop process.
+ * \param socket The socket information
+ */
void ServerLoopFromChild(SOCKET socket) {
// Server loop
tvm::support::TCPSocket sock(socket);
* \param host The hostname of the server, Default=0.0.0.0
* \param port The port of the RPC, Default=9090
* \param port_end The end search port of the RPC, Default=9199
- * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
- * \param key The key used to identify the device type in tracker. Default=""
- * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
- * \param silent Whether run in silent mode. Default=True
+ * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190
+ * Default="" \param key The key used to identify the device type in tracker. Default="" \param
+ * custom_addr Custom IP Address to Report to RPC Tracker. Default="" \param silent Whether run in
+ * silent mode. Default=True
*/
void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
std::string key, std::string custom_addr, bool silent) {
dmlc::InitLogging("--minloglevel=2");
}
// Start the rpc server
- RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr));
+ RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
+ std::move(custom_addr));
rpc.Start();
}
-TVM_REGISTER_GLOBAL("rpc.ServerCreate")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
- });
+TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
+ RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
+});
} // namespace runtime
} // namespace tvm
#define TVM_APPS_CPP_RPC_SERVER_H_
#include <string>
+
#include "tvm/runtime/c_runtime_api.h"
namespace tvm {
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param silent Whether run in silent mode. Default=True
*/
-void RPCServerCreate(std::string host = "",
- int port = 9090,
- int port_end = 9099,
- std::string tracker_addr = "",
- std::string key = "",
- std::string custom_addr = "",
- bool silent = true);
+void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099,
+ std::string tracker_addr = "", std::string key = "",
+ std::string custom_addr = "", bool silent = true);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_
#ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
#define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
-#include <set>
-#include <iostream>
#include <chrono>
+#include <iostream>
#include <random>
-#include <vector>
+#include <set>
#include <string>
+#include <vector>
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/support/socket.h"
public:
/*!
* \brief Constructor.
- */
- TrackerClient(const std::string& tracker_addr,
- const std::string& key,
+ */
+ TrackerClient(const std::string& tracker_addr, const std::string& key,
const std::string& custom_addr)
- : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr),
- gen_(std::random_device{}()), dis_(0.0, 1.0) {
- }
+ : tracker_addr_(tracker_addr),
+ key_(key),
+ custom_addr_(custom_addr),
+ gen_(std::random_device{}()),
+ dis_(0.0, 1.0) {}
/*!
* \brief Destructor.
- */
+ */
~TrackerClient() {
// Free the resources
Close();
}
/*!
* \brief IsValid Check tracker is valid.
- */
- bool IsValid() {
- return (!tracker_addr_.empty() && !tracker_sock_.IsClosed());
- }
+ */
+ bool IsValid() { return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); }
/*!
* \brief TryConnect Connect to tracker if the tracker address is valid.
- */
+ */
void TryConnect() {
if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) {
tracker_sock_ = ConnectWithRetry();
CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker";
std::ostringstream ss;
- ss << "[" << static_cast<int>(TrackerCode::kUpdateInfo)
- << ", {\"key\": \"server:"<< key_ << "\"}]";
+ ss << "[" << static_cast<int>(TrackerCode::kUpdateInfo) << ", {\"key\": \"server:" << key_
+ << "\"}]";
tracker_sock_.SendBytes(ss.str());
// Receive status and validate
}
/*!
* \brief Close Clean up tracker resources.
- */
+ */
void Close() {
// close tracker resource
if (!tracker_sock_.IsClosed()) {
tracker_sock_.Close();
}
}
- /*!
- * \brief ReportResourceAndGetKey Report resource to tracker.
- * \param port listening port.
- * \param matchkey Random match key output.
- */
- void ReportResourceAndGetKey(int port,
- std::string *matchkey) {
+ /*!
+ * \brief ReportResourceAndGetKey Report resource to tracker.
+ * \param port listening port.
+ * \param matchkey Random match key output.
+ */
+ void ReportResourceAndGetKey(int port, std::string* matchkey) {
if (!tracker_sock_.IsClosed()) {
*matchkey = RandomKey(key_ + ":", old_keyset_);
if (custom_addr_.empty()) {
}
std::ostringstream ss;
- ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
- << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
+ ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port
+ << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
tracker_sock_.SendBytes(ss.str());
std::string remote_status = tracker_sock_.RecvBytes();
CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
} else {
- *matchkey = key_;
+ *matchkey = key_;
}
}
* \param port listening port.
* \param ping_period Select wait time.
* \param matchkey Random match key output.
- */
- void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock,
- int port,
- int ping_period,
- std::string *matchkey) {
+ */
+ void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock, int port, int ping_period,
+ std::string* matchkey) {
int unmatch_period_count = 0;
int unmatch_timeout = 4;
while (true) {
// if match key not in pending key set
// it means the key is acquired by a client but not used.
if (pending_keys.find(*matchkey) == std::string::npos) {
- unmatch_period_count += 1;
+ unmatch_period_count += 1;
} else {
- unmatch_period_count = 0;
+ unmatch_period_count = 0;
}
// regenerate match key if key is acquired but not used for a while
if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) {
*matchkey = RandomKey(key_ + ":", old_keyset_);
std::ostringstream ss;
- ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
- << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
+ ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port
+ << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
tracker_sock_.SendBytes(ss.str());
std::string remote_status = tracker_sock_.RecvBytes();
}
auto period = (std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now() - tbegin)).count();
+ std::chrono::system_clock::now() - tbegin))
+ .count();
CHECK(period < timeout) << "Failed to connect to server" << addr.AsString();
- LOG(WARNING) << "Cannot connect to tracker " << addr.AsString()
- << " retry in " << retry_period << " seconds.";
+ LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() << " retry in "
+ << retry_period << " seconds.";
std::this_thread::sleep_for(std::chrono::seconds(retry_period));
}
}
/*!
- * \brief Random Generate a random number between 0 and 1.
- * \return random float value.
- */
- float Random() {
- return dis_(gen_);
- }
+ * \brief Random Generate a random number between 0 and 1.
+ * \return random float value.
+ */
+ float Random() { return dis_(gen_); }
/*!
* \brief Generate a random key.
* \param prefix The string prefix.
* \return cmap The conflict map set.
*/
- std::string RandomKey(const std::string& prefix, const std::set <std::string> &cmap) {
+ std::string RandomKey(const std::string& prefix, const std::set<std::string>& cmap) {
if (!cmap.empty()) {
while (true) {
std::string key = prefix + std::to_string(Random());
std::string key_;
std::string custom_addr_;
support::TCPSocket tracker_sock_;
- std::set <std::string> old_keyset_;
+ std::set<std::string> old_keyset_;
std::mt19937 gen_;
std::uniform_real_distribution<float> dis_;
-
};
} // namespace runtime
} // namespace tvm
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
+#include "win32_process.h"
+
+#include <conio.h>
+#include <dmlc/logging.h>
#include <winsock2.h>
#include <ws2tcpip.h>
+
#include <cstdio>
#include <memory>
-#include <conio.h>
-#include <string>
#include <stdexcept>
-#include <dmlc/logging.h>
-#include "win32_process.h"
+#include <string>
+
#include "rpc_server.h"
using namespace std::chrono;
*/
SOCKET GetSocket(const std::string& mmap_path) {
WSAPROTOCOL_INFO protocol_info;
-
+
const std::string parent_event_name = mmap_path + kParent;
const std::string child_event_name = mmap_path + kChild;
// Open the events
UniqueHandle parent_file_mapping_event;
- if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) {
+ if ((parent_file_mapping_event = MakeUniqueHandle(
+ OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
}
UniqueHandle child_file_mapping_event;
- if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) {
+ if ((child_file_mapping_event = MakeUniqueHandle(
+ OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
}
-
+
// Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read
- if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) {
- LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError();
+ if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) !=
+ WAIT_OBJECT_0) {
+ LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError();
}
- const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE,
- false,
- mmap_path.c_str()));
+ const UniqueHandle file_map =
+ MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, false, mmap_path.c_str()));
if (!file_map) {
- LOG(INFO) << "CreateFileMapping() failed: " << GetLastError();
+ LOG(INFO) << "CreateFileMapping() failed: " << GetLastError();
}
- void* map_view = MapViewOfFile(file_map.get(),
- FILE_MAP_READ | FILE_MAP_WRITE,
- 0, 0, 0);
+ void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
SOCKET sock_duplicated = INVALID_SOCKET;
UnmapViewOfFile(map_view);
// Creates the duplicate socket, that was created in the parent
- sock_duplicated = WSASocket(FROM_PROTOCOL_INFO,
- FROM_PROTOCOL_INFO,
- FROM_PROTOCOL_INFO,
- &protocol_info,
- 0,
- 0);
+ sock_duplicated =
+ WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &protocol_info, 0, 0);
// Let the parent know we are finished dupicating the socket
SetEvent(child_file_mapping_event.get());
return sock_duplicated;
}
-}// Anonymous namespace
+} // Anonymous namespace
namespace tvm {
namespace runtime {
*/
void SpawnRPCChild(SOCKET fd, seconds timeout) {
STARTUPINFOA startup_info;
-
+
memset(&startup_info, 0, sizeof(startup_info));
startup_info.cb = sizeof(startup_info);
// Create an event to let the child know the socket info was set to the mmap file
UniqueHandle parent_file_mapping_event;
- if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) {
+ if ((parent_file_mapping_event = MakeUniqueHandle(
+ CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "CreateEvent for parent file mapping failed";
}
UniqueHandle child_file_mapping_event;
// An event to let the parent know the socket info was read from the mmap file
- if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) {
+ if ((child_file_mapping_event = MakeUniqueHandle(
+ CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "CreateEvent for child file mapping failed";
}
strcpy(command_line_ptr.get(), child_command_line.c_str());
PROCESS_INFORMATION child_process_info;
- if (CreateProcessA(nullptr,
- command_line_ptr.get(),
- nullptr,
- nullptr,
- false,
- CREATE_NO_WINDOW,
- nullptr,
- nullptr,
- &startup_info,
- &child_process_info)) {
+ if (CreateProcessA(nullptr, command_line_ptr.get(), nullptr, nullptr, false, CREATE_NO_WINDOW,
+ nullptr, nullptr, &startup_info, &child_process_info)) {
// Child process and thread handles must be closed, so wrapped in RAII
auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess);
auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread);
WSAPROTOCOL_INFO protocol_info;
// Get info needed to duplicate the socket
- if (WSADuplicateSocket(fd,
- child_process_info.dwProcessId,
- &protocol_info) == SOCKET_ERROR) {
+ if (WSADuplicateSocket(fd, child_process_info.dwProcessId, &protocol_info) == SOCKET_ERROR) {
LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError();
}
// Create a mmap file to store the info needed for duplicating the SOCKET in the child proc
- UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE,
- nullptr,
- PAGE_READWRITE,
- 0,
- sizeof(WSAPROTOCOL_INFO),
- file_map_path.c_str()));
+ UniqueHandle file_map =
+ MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, 0,
+ sizeof(WSAPROTOCOL_INFO), file_map_path.c_str()));
if (!file_map) {
LOG(INFO) << "CreateFileMapping() failed: " << GetLastError();
}
// Let child proc know the mmap file is ready to be read
SetEvent(parent_file_mapping_event.get());
-
+
// Wait for the child to finish reading mmap file
- if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) {
+ if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) !=
+ WAIT_OBJECT_0) {
TerminateProcess(child_process_handle.get(), 0);
- LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process.";
+ LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child "
+ "process.";
}
} else {
TerminateProcess(child_process_handle.get(), 0);
}
}
- const DWORD process_timeout = timeout.count()
- ? uint32_t(duration_cast<milliseconds>(timeout).count())
- : INFINITE;
+ const DWORD process_timeout =
+ timeout.count() ? uint32_t(duration_cast<milliseconds>(timeout).count()) : INFINITE;
// Wait for child process to exit, or hit configured timeout
if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) {
}
}
/*!
- * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket
- * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent
+ * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client
+ * socket \param mmap_path The memory mapped file path that will contain the information to
+ * duplicate the client socket from the parent
*/
void ChildProcSocketHandler(const std::string& mmap_path) {
SOCKET socket;
// Set high thread priority to avoid the thread scheduler from
// interfering with any measurements in the RPC server.
SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL);
-
+
if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) {
tvm::runtime::ServerLoopFromChild(socket);
- }
- else {
+ } else {
LOG(FATAL) << "GetSocket() failed";
}
-
}
} // namespace runtime
} // namespace tvm
\ No newline at end of file
* under the License.
*/
- /*!
- * \file win32_process.h
- * \brief Win32 process code to mimic a POSIX fork()
- */
+/*!
+ * \file win32_process.h
+ * \brief Win32 process code to mimic a POSIX fork()
+ */
#ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
#define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
#include <chrono>
*/
void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout);
/*!
- * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket
- * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent
+ * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client
+ * socket \param mmap_path The memory mapped file path that will contain the information to
+ * duplicate the client socket from the parent
*/
void ChildProcSocketHandler(const std::string& mmap_path);
} // namespace runtime
* \brief Example code that can be compiled and loaded by TVM runtime.
* \file plugin_module.cc
*/
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
namespace tvm_dso_plugin {
class MyModuleNode : public ModuleNode {
public:
- explicit MyModuleNode(int value)
- : value_(value) {}
+ explicit MyModuleNode(int value) : value_(value) {}
- virtual const char* type_key() const final {
- return "MyModule";
- }
+ virtual const char* type_key() const final { return "MyModule"; }
- virtual PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ virtual PackedFunc GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) final {
if (name == "add") {
- return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) {
- return value_ + value;
- });
+ return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) { return value_ + value; });
} else if (name == "mul") {
- return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) {
- return value_ * value;
- });
+ return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) { return value_ * value; });
} else {
LOG(FATAL) << "unknown function " << name;
return PackedFunc();
*rv = Module(make_object<MyModuleNode>(value));
}
-int SubOne_(int x) {
- return x - 1;
-}
+int SubOne_(int x) { return x - 1; }
// USE TVM_DLL_EXPORT_TYPED_PACKED_FUNC to export a
// typed function as packed function.
TVM_DLL_EXPORT_TYPED_FUNC(SubOne, SubOne_);
// TVM_DLL_EXPORT_TYPED_PACKED_FUNC also works for lambda.
-TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int {
- return x + 1;
-});
+TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int { return x + 1; });
// Use TVM_EXPORT_PACKED_FUNC to export a function with
TVM_DLL_EXPORT_PACKED_FUNC(CreateMyModule, tvm_dso_plugin::CreateMyModule_);
* under the License.
*/
-
/*!
* \brief Example package that uses TVM.
* \file tvm_ext.cc
*/
-#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
using namespace tvm;
public:
class SubContainer : public NDArray::Container {
public:
- SubContainer(int additional_info) :
- additional_info_(additional_info) {
+ SubContainer(int additional_info) : additional_info_(additional_info) {
type_index_ = SubContainer::RuntimeTypeIndex();
}
int additional_info_{0};
data_ = GetObjectPtr<Object>(ptr);
}
- NDSubClass AddWith(const NDSubClass &other) const {
- SubContainer *a = static_cast<SubContainer*>(get_mutable());
- SubContainer *b = static_cast<SubContainer*>(other.get_mutable());
+ NDSubClass AddWith(const NDSubClass& other) const {
+ SubContainer* a = static_cast<SubContainer*>(get_mutable());
+ SubContainer* b = static_cast<SubContainer*>(other.get_mutable());
CHECK(a != nullptr && b != nullptr);
return NDSubClass(a->additional_info_ + b->additional_info_);
}
int get_additional_info() const {
- SubContainer *self = static_cast<SubContainer*>(get_mutable());
+ SubContainer* self = static_cast<SubContainer*>(get_mutable());
CHECK(self != nullptr);
return self->additional_info_;
}
namespace tvm_ext {
-TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- auto n = tvm::runtime::make_object<IntVectorObj>();
- for (int i = 0; i < args.size(); ++i) {
- n->vec.push_back(args[i].operator int());
- }
- *rv = IntVector(n);
- });
-
-TVM_REGISTER_GLOBAL("tvm_ext.ivec_get")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- IntVector p = args[0];
- *rv = p->vec[args[1].operator int()];
- });
-
-
-TVM_REGISTER_GLOBAL("tvm_ext.bind_add")
-.set_body([](TVMArgs args_, TVMRetValue *rv_) {
- PackedFunc pf = args_[0];
- int b = args_[1];
- *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue *rv) {
- *rv = pf(b, args[0]);
- });
- });
-
-TVM_REGISTER_GLOBAL("tvm_ext.sym_add")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- Var a = args[0];
- Var b = args[1];
- *rv = a + b;
- });
-
-TVM_REGISTER_GLOBAL("device_api.ext_dev")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
- });
-
-TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("tvm_ext.ivec_create").set_body([](TVMArgs args, TVMRetValue* rv) {
+ auto n = tvm::runtime::make_object<IntVectorObj>();
+ for (int i = 0; i < args.size(); ++i) {
+ n->vec.push_back(args[i].operator int());
+ }
+ *rv = IntVector(n);
+});
+
+TVM_REGISTER_GLOBAL("tvm_ext.ivec_get").set_body([](TVMArgs args, TVMRetValue* rv) {
+ IntVector p = args[0];
+ *rv = p->vec[args[1].operator int()];
+});
+
+TVM_REGISTER_GLOBAL("tvm_ext.bind_add").set_body([](TVMArgs args_, TVMRetValue* rv_) {
+ PackedFunc pf = args_[0];
+ int b = args_[1];
+ *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue* rv) { *rv = pf(b, args[0]); });
+});
+
+TVM_REGISTER_GLOBAL("tvm_ext.sym_add").set_body([](TVMArgs args, TVMRetValue* rv) {
+ Var a = args[0];
+ Var b = args[1];
+ *rv = a + b;
+});
+
+TVM_REGISTER_GLOBAL("device_api.ext_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
+});
+
+TVM_REGISTER_GLOBAL("tvm_ext.nd_create").set_body([](TVMArgs args, TVMRetValue* rv) {
int additional_info = args[0];
*rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kTVMNDArrayHandle);
-
});
-TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two").set_body([](TVMArgs args, TVMRetValue* rv) {
NDSubClass a = args[0];
NDSubClass b = args[1];
*rv = a.AddWith(b);
});
-TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info").set_body([](TVMArgs args, TVMRetValue* rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
});
} // namespace tvm_ext
// External function exposed to runtime.
-extern "C" float TVMTestAddOne(float y) {
- return y + 1;
-}
+extern "C" float TVMTestAddOne(float y) { return y + 1; }
// This callback approach allows extension allows tvm to extract
// This way can be helpful when we want to use a header only
// minimum version of TVM Runtime.
extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) {
- const PackedFunc& fregister =
- *static_cast<PackedFunc*>(pregister);
- auto mul = [](TVMArgs args, TVMRetValue *rv) {
+ const PackedFunc& fregister = *static_cast<PackedFunc*>(pregister);
+ auto mul = [](TVMArgs args, TVMRetValue* rv) {
int x = args[0];
int y = args[1];
*rv = x * y;
* \brief Example code on load and run TVM module.s
* \file cpp_deploy.cc
*/
-#include <cstdio>
#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstdio>
void Verify(tvm::runtime::Module mod, std::string fname) {
// Get the function from the module.
int device_type = kDLCPU;
int device_id = 0;
int64_t shape[1] = {10};
- TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
- device_type, device_id, &x);
- TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
- device_type, device_id, &y);
+ TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &x);
+ TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &y);
for (int i = 0; i < shape[0]; ++i) {
static_cast<float*>(x->data)[i] = i;
}
int main(void) {
// Normally we can directly
- tvm::runtime::Module mod_dylib =
- tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so");
+ tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so");
LOG(INFO) << "Verify dynamic loading from test_addone_dll.so";
Verify(mod_dylib, "addone");
// For libraries that are directly packed as system lib and linked together with the app
*/
#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc"
-#include "../../src/runtime/workspace_pool.cc"
+#include "../../src/runtime/file_util.cc"
#include "../../src/runtime/library_module.cc"
#include "../../src/runtime/module.cc"
-#include "../../src/runtime/registry.cc"
-#include "../../src/runtime/file_util.cc"
-#include "../../src/runtime/threading_backend.cc"
-#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
+#include "../../src/runtime/registry.cc"
+#include "../../src/runtime/thread_pool.cc"
+#include "../../src/runtime/threading_backend.cc"
+#include "../../src/runtime/workspace_pool.cc"
// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
@interface AppDelegate : UIResponder <UIApplicationDelegate>
-@property (strong, nonatomic) UIWindow *window;
-
+@property(strong, nonatomic) UIWindow* window;
@end
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define DMLC_LOG_CUSTOMIZE 1
#define TVM_METAL_RUNTIME 1
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
#include <functional>
namespace tvm {
* \param remote_key The remote key
* \return The event handler.
*/
-FEventHandler CreateServerEventHandler(NSOutputStream *outputStream,
- std::string name,
+FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name,
std::string remote_key);
} // namespace runtime
// Runtime API
#include "../../../src/runtime/c_runtime_api.cc"
#include "../../../src/runtime/cpu_device_api.cc"
-#include "../../../src/runtime/workspace_pool.cc"
-#include "../../../src/runtime/thread_pool.cc"
-#include "../../../src/runtime/threading_backend.cc"
+#include "../../../src/runtime/dso_library.cc"
+#include "../../../src/runtime/file_util.cc"
#include "../../../src/runtime/library_module.cc"
-#include "../../../src/runtime/system_library.cc"
#include "../../../src/runtime/module.cc"
-#include "../../../src/runtime/registry.cc"
-#include "../../../src/runtime/file_util.cc"
-#include "../../../src/runtime/dso_library.cc"
#include "../../../src/runtime/ndarray.cc"
#include "../../../src/runtime/object.cc"
+#include "../../../src/runtime/registry.cc"
+#include "../../../src/runtime/system_library.cc"
+#include "../../../src/runtime/thread_pool.cc"
+#include "../../../src/runtime/threading_backend.cc"
+#include "../../../src/runtime/workspace_pool.cc"
// RPC server
-#include "../../../src/runtime/rpc/rpc_session.cc"
+#include "../../../src/runtime/rpc/rpc_module.cc"
#include "../../../src/runtime/rpc/rpc_server_env.cc"
+#include "../../../src/runtime/rpc/rpc_session.cc"
#include "../../../src/runtime/rpc/rpc_socket_impl.cc"
-#include "../../../src/runtime/rpc/rpc_module.cc"
// Graph runtime
#include "../../../src/runtime/graph/graph_runtime.cc"
// Metal
-#include "../../../src/runtime/metal/metal_module.mm"
#include "../../../src/runtime/metal/metal_device_api.mm"
+#include "../../../src/runtime/metal/metal_module.mm"
// CoreML
#include "../../../src/runtime/contrib/coreml/coreml_runtime.mm"
namespace dmlc {
// Override logging mechanism
-void CustomLogMessage::Log(const std::string& msg) {
- NSLog(@"%s", msg.c_str());
-}
+void CustomLogMessage::Log(const std::string& msg) { NSLog(@"%s", msg.c_str()); }
} // namespace dmlc
namespace tvm {
class NSStreamChannel final : public RPCChannel {
public:
- explicit NSStreamChannel(NSOutputStream* stream)
- : stream_(stream) {}
+ explicit NSStreamChannel(NSOutputStream* stream) : stream_(stream) {}
size_t Send(const void* data, size_t size) final {
- ssize_t nbytes = [stream_ write:reinterpret_cast<const uint8_t*>(data)
- maxLength:size];
+ ssize_t nbytes = [stream_ write:reinterpret_cast<const uint8_t*>(data) maxLength:size];
if (nbytes < 0) {
- NSLog(@"%@",[stream_ streamError].localizedDescription);
+ NSLog(@"%@", [stream_ streamError].localizedDescription);
throw dmlc::Error("Stream error");
}
return nbytes;
NSOutputStream* stream_;
};
-FEventHandler CreateServerEventHandler(
- NSOutputStream *outputStream, std::string name, std::string remote_key) {
+FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name,
+ std::string remote_key) {
std::unique_ptr<NSStreamChannel> ch(new NSStreamChannel(outputStream));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name, remote_key);
return [sess](const std::string& in_bytes, int flag) {
}
}
// Get Path.
- std::string GetPath(const std::string& file_name) {
- return base_ + file_name;
- }
+ std::string GetPath(const std::string& file_name) { return base_ + file_name; }
private:
std::string base_;
// only load dylib from frameworks.
NSBundle* bundle = [NSBundle mainBundle];
NSString* base = [bundle privateFrameworksPath];
- NSString* path = [base stringByAppendingPathComponent: @"tvm/rpc_config.txt"];
+ NSString* path = [base stringByAppendingPathComponent:@"tvm/rpc_config.txt"];
std::string name = [path UTF8String];
std::ifstream fs(name, std::ios::in);
std::string url, key;
int port;
- CHECK(fs >> url >> port >> key)
- << "Invalid RPC config file " << name;
- RPCConnect(url, port, "server:" + key)
- ->ServerLoop();
+ CHECK(fs >> url >> port >> key) << "Invalid RPC config file " << name;
+ RPCConnect(url, port, "server:" + key)->ServerLoop();
}
-TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- static RPCEnv env;
- *rv = env.GetPath(args[0]);
- });
-
-TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- std::string name = args[0];
- std::string fmt = GetFileFormat(name, "");
- NSString* base;
- if (fmt == "dylib") {
- // only load dylib from frameworks.
- NSBundle* bundle = [NSBundle mainBundle];
- base = [[bundle privateFrameworksPath]
- stringByAppendingPathComponent: @"tvm"];
- } else {
- // Load other modules in tempdir.
- base = NSTemporaryDirectory();
- }
- NSString* path = [base stringByAppendingPathComponent:
- [NSString stringWithUTF8String:name.c_str()]];
- name = [path UTF8String];
- *rv = Module::LoadFromFile(name, fmt);
- LOG(INFO) << "Load module from " << name << " ...";
- });
+TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) {
+ static RPCEnv env;
+ *rv = env.GetPath(args[0]);
+});
+
+TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) {
+ std::string name = args[0];
+ std::string fmt = GetFileFormat(name, "");
+ NSString* base;
+ if (fmt == "dylib") {
+ // only load dylib from frameworks.
+ NSBundle* bundle = [NSBundle mainBundle];
+ base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"];
+ } else {
+ // Load other modules in tempdir.
+ base = NSTemporaryDirectory();
+ }
+ NSString* path =
+ [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]];
+ name = [path UTF8String];
+ *rv = Module::LoadFromFile(name, fmt);
+ LOG(INFO) << "Load module from " << name << " ...";
+});
} // namespace runtime
} // namespace tvm
@implementation TVMRuntime
-+(void) launchSyncServer {
++ (void)launchSyncServer {
tvm::runtime::LaunchSyncServer();
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#import <UIKit/UIKit.h>
#include "TVMRuntime.h"
-@interface ViewController : UIViewController<NSStreamDelegate>
-{
+@interface ViewController : UIViewController <NSStreamDelegate> {
// input socket stream
- NSInputStream *inputStream_;
+ NSInputStream* inputStream_;
// output socket stream
- NSOutputStream *outputStream_;
+ NSOutputStream* outputStream_;
// temporal receive buffer.
std::string recvBuffer_;
// Whether connection is initialized.
tvm::runtime::FEventHandler handler_;
}
-@property (weak, nonatomic) IBOutlet UITextField *proxyURL;
-@property (weak, nonatomic) IBOutlet UITextField *proxyPort;
-@property (weak, nonatomic) IBOutlet UITextField *proxyKey;
-@property (weak, nonatomic) IBOutlet UILabel *statusLabel;
-@property (weak, nonatomic) IBOutlet UITextView *infoText;
+@property(weak, nonatomic) IBOutlet UITextField* proxyURL;
+@property(weak, nonatomic) IBOutlet UITextField* proxyPort;
+@property(weak, nonatomic) IBOutlet UITextField* proxyKey;
+@property(weak, nonatomic) IBOutlet UILabel* statusLabel;
+@property(weak, nonatomic) IBOutlet UITextView* infoText;
- (IBAction)connect:(id)sender;
- (IBAction)disconnect:(id)sender;
* \file ViewController.mm
*/
-#include <string>
#import "ViewController.h"
+#include <string>
@implementation ViewController
-- (void)stream:(NSStream *)strm handleEvent:(NSStreamEvent)event {
+- (void)stream:(NSStream*)strm handleEvent:(NSStreamEvent)event {
std::string buffer;
switch (event) {
case NSStreamEventOpenCompleted: {
break;
}
case NSStreamEventErrorOccurred: {
- NSLog(@"%@",[strm streamError].localizedDescription);
+ NSLog(@"%@", [strm streamError].localizedDescription);
break;
}
case NSStreamEventEndEncountered: {
constexpr int kRPCMagic = 0xff271;
if (!initialized_) {
int code;
- size_t nbytes = [inputStream_ read:reinterpret_cast<uint8_t*>(&code)
- maxLength:sizeof(code)];
+ size_t nbytes = [inputStream_ read:reinterpret_cast<uint8_t*>(&code) maxLength:sizeof(code)];
if (nbytes != sizeof(code)) {
self.infoText.text = @"Fail to receive remote confirmation code.";
[self close];
- (void)onWriteAvailable {
if (initSendPtr_ < initBytes_.length()) {
initSendPtr_ += [outputStream_ write:reinterpret_cast<uint8_t*>(&initBytes_[initSendPtr_])
- maxLength:(initBytes_.length() - initSendPtr_)];
+ maxLength:(initBytes_.length() - initSendPtr_)];
}
if (initialized_) {
try {
// Initialize the network.
CFReadStreamRef readStream;
CFWriteStreamRef writeStream;
- CFStreamCreatePairWithSocketToHost(
- NULL,
- (__bridge CFStringRef) self.proxyURL.text,
- [self.proxyPort.text intValue],
- &readStream, &writeStream);
- inputStream_ = (__bridge_transfer NSInputStream *)readStream;
- outputStream_ = (__bridge_transfer NSOutputStream *)writeStream;
+ CFStreamCreatePairWithSocketToHost(NULL, (__bridge CFStringRef)self.proxyURL.text,
+ [self.proxyPort.text intValue], &readStream, &writeStream);
+ inputStream_ = (__bridge_transfer NSInputStream*)readStream;
+ outputStream_ = (__bridge_transfer NSOutputStream*)writeStream;
[inputStream_ setDelegate:self];
[outputStream_ setDelegate:self];
[inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode];
@implementation tvmrpcLauncher
- (void)setUp {
- [super setUp];
+ [super setUp];
}
- (void)tearDown {
- [super tearDown];
+ [super tearDown];
}
- (void)testRPC {
[TVMRuntime launchSyncServer];
}
-
@end
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_USE_MIOPEN 1
#define __HIP_PLATFORM_HCC__ 1
-#include "../../src/runtime/rocm/rocm_device_api.cc"
-#include "../../src/runtime/rocm/rocm_module.cc"
#include "../../src/contrib/miopen/conv_forward.cc"
#include "../../src/contrib/miopen/miopen_utils.cc"
+#include "../../src/runtime/rocm/rocm_device_api.cc"
+#include "../../src/runtime/rocm/rocm_module.cc"
// Standard includes
#include <stddef.h>
+#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
-#include <stdint.h>
// golang string compatible definition
-typedef struct { char *p; int n; } _gostring_;
+typedef struct {
+ char* p;
+ int n;
+} _gostring_;
#include <string>
#ifdef __cplusplus
#endif
// TVM runtime C interface
-#include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h>
+#include <tvm/runtime/c_runtime_api.h>
/*!
* \brief Convert native char array to _gostring_ structure.
* \return _gostring_ object corresponding to native char array.
* Caller is responsible to free the memory block allocated here.
*/
-static _gostring_ _native_to_gostring(const char *p, size_t l) {
+static _gostring_ _native_to_gostring(const char* p, size_t l) {
_gostring_ ret;
ret.p = reinterpret_cast<char*>(malloc(l));
if (NULL == ret.p) {
* \param off is the offset in the string object.
* \param v is the uint64_t value which need to embed into given string.
*/
// Serialize v into (*s) starting at byte offset off, in little-endian byte
// order regardless of host endianness. Exactly 8 bytes are written; the
// caller must have sized the string so that off + 8 <= s->size().
static void putuint64(std::string* s, size_t off, uint64_t v) {
  for (size_t byte = 0; byte < 8; ++byte) {
    (*s)[off + byte] = static_cast<char>((v >> (byte * 8)) & 0xff);
  }
}
// TVM runtime C interface wrappers
* \return char pointer to TVM-VERSION
*/
const char* _TVM_VERSION(void) {
- const char *version = TVM_VERSION;
+ const char* version = TVM_VERSION;
return version;
}
*/
int _TVMFuncListGlobalNames(_gostring_* names) {
int names_size;
- char **names_array;
+ char** names_array;
int result;
- result = TVMFuncListGlobalNames(&names_size, (char const ***)&names_array);
+ result = TVMFuncListGlobalNames(&names_size, (char const***)&names_array);
if (result) {
return result;
}
size_t tot = 8;
- for (int ii = 0; ii < names_size ; ++ii) {
+ for (int ii = 0; ii < names_size; ++ii) {
tot += 8 + strlen(names_array[ii]);
}
str.resize(tot);
putuint64(&str, 0, names_size);
size_t off = 8;
- for (int64_t ii = 0; ii < names_size ; ++ii) {
+ for (int64_t ii = 0; ii < names_size; ++ii) {
putuint64(&str, off, strlen(names_array[ii]));
off += 8;
str.replace(off, strlen(names_array[ii]), names_array[ii]);
* \param array index in native array.
*/
void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) {
- TVMValue *from_p = reinterpret_cast<TVMValue*>(from_ptr);
- TVMValue *to_p = reinterpret_cast<TVMValue*>(to_ptr);
- memcpy(to_p+ind, from_p, sizeof(TVMValue));
+ TVMValue* from_p = reinterpret_cast<TVMValue*>(from_ptr);
+ TVMValue* to_p = reinterpret_cast<TVMValue*>(to_ptr);
+ memcpy(to_p + ind, from_p, sizeof(TVMValue));
}
/*!
* \param array index in native array.
*/
void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) {
- TVMValue *from_p = reinterpret_cast<TVMValue*>(from_ptr);
- TVMValue *to_p = reinterpret_cast<TVMValue*>(to_ptr);
- memcpy(to_p, from_p+ind, sizeof(TVMValue));
+ TVMValue* from_p = reinterpret_cast<TVMValue*>(from_ptr);
+ TVMValue* to_p = reinterpret_cast<TVMValue*>(to_ptr);
+ memcpy(to_p, from_p + ind, sizeof(TVMValue));
}
extern int goTVMCallback(void*, void*, int, void*, void*);
*
* \returns the error status as TVM_DLL
*/
-int _TVMCallback(TVMValue* args,
- int* type_codes,
- int num_args,
- TVMRetValueHandle ret,
+int _TVMCallback(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret,
void* resource_handle) {
- return goTVMCallback(args, type_codes, num_args, ret, resource_handle);
+ return goTVMCallback(args, type_codes, num_args, ret, resource_handle);
}
/*!
 * _TVMPackedCFuncFinalizer is the finalizer for the packed function system.
*
*/
// Finalizer hook for packed functions created from Go. No cleanup is
// performed here; the handle is intentionally ignored.
void _TVMPackedCFuncFinalizer(void* resource_handle) {
  (void)resource_handle;  // unused
}
/*!
 * \brief _ConvertFunction creates a packed function with the given resource handle.
 *
 * \return an int indicating the return status.
*/
-int _ConvertFunction(void* fptr, TVMFunctionHandle *fhandle) {
- int ret = TVMFuncCreateFromCFunc(_TVMCallback,
- fptr,
- _TVMPackedCFuncFinalizer,
- fhandle);
+int _ConvertFunction(void* fptr, TVMFunctionHandle* fhandle) {
+ int ret = TVMFuncCreateFromCFunc(_TVMCallback, fptr, _TVMPackedCFuncFinalizer, fhandle);
return ret;
}
extern "C" {
#endif
+#include <dlpack/dlpack.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <dlpack/dlpack.h>
// Some type definitions for golang "C"
typedef void* native_voidp;
*/
#include "src/runtime/c_runtime_api.cc"
#include "src/runtime/cpu_device_api.cc"
-#include "src/runtime/workspace_pool.cc"
+#include "src/runtime/file_util.cc"
#include "src/runtime/library_module.cc"
#include "src/runtime/module.cc"
-#include "src/runtime/registry.cc"
-#include "src/runtime/file_util.cc"
-#include "src/runtime/threading_backend.cc"
-#include "src/runtime/thread_pool.cc"
#include "src/runtime/ndarray.cc"
#include "src/runtime/object.cc"
+#include "src/runtime/registry.cc"
+#include "src/runtime/thread_pool.cc"
+#include "src/runtime/threading_backend.cc"
+#include "src/runtime/workspace_pool.cc"
// NOTE: all the files after this are optional modules
// that you can include or remove, depending on how many features you use.
#ifndef TVM_ARITH_ANALYZER_H_
#define TVM_ARITH_ANALYZER_H_
-#include <tvm/support/with.h>
-#include <tvm/ir/expr.h>
#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+#include <tvm/support/with.h>
-#include <vector>
-#include <unordered_map>
-#include <memory>
#include <limits>
+#include <memory>
+#include <unordered_map>
+#include <vector>
namespace tvm {
/*! \brief namespace of arithmetic analysis. */
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
- TVM_DLL void Update(const Var& var,
- const ConstIntBound& info,
- bool override = false);
+ TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override = false);
/*!
* \brief Bind variable to a range.
*
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
- TVM_DLL void Update(const Var& var,
- const ModularSet& info,
- bool override = false);
+ TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = false);
private:
friend class Analyzer;
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
- TVM_DLL void Update(const Var& var,
- const PrimExpr& new_expr,
- bool override = false);
+ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
- TVM_DLL void Update(const Var& var,
- const PrimExpr& new_expr,
- bool override = false);
+ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false);
private:
friend class Analyzer;
#ifndef TVM_ARITH_BOUND_H_
#define TVM_ARITH_BOUND_H_
-#include <tvm/node/container.h>
-#include <tvm/ir/expr.h>
#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
}
namespace arith {
-using tir::Var;
-using tir::VarNode;
using tir::Domain;
using tir::Stmt;
+using tir::Var;
+using tir::VarNode;
/*!
* \brief Deduce the bound of the target variable in a expression,
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
-IntSet DeduceBound(PrimExpr v, PrimExpr cond,
- const Map<Var, IntSet>& hint_map,
+IntSet DeduceBound(PrimExpr v, PrimExpr cond, const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
/*!
* \brief Same as DeduceBound with unordered_map signature.
* \param consider_stores If stores are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
-Domain DomainTouched(const Stmt& body,
- const tir::Buffer& buffer,
- bool consider_loads,
+Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
bool consider_stores);
} // namespace arith
#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
+
#include <unordered_map>
namespace tvm {
namespace arith {
+using tir::IterVar;
using tir::Var;
using tir::VarNode;
-using tir::IterVar;
//-----------------------------------------------
// Integer set data structure.
/*!
* \brief Sign type of an integer expression.
*/
-enum SignType {
- kPositive,
- kNegative,
- kZero,
- kUnknown
-};
+enum SignType { kPositive, kNegative, kZero, kUnknown };
/*!
* \brief Base class of all Integer set containers.
* \brief access the internal node container
* \return the pointer to the internal node container
*/
- const IntSetNode* operator->() const {
- return static_cast<const IntSetNode*>(get());
- }
+ const IntSetNode* operator->() const { return static_cast<const IntSetNode*>(get()); }
/*!
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
-IntSet EvalSet(PrimExpr e,
- const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
+IntSet EvalSet(PrimExpr e, const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
-IntSet EvalSet(Range r,
- const Map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains is union over
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
-IntSet EvalSet(IntSet s,
- const std::unordered_map<const VarNode*, IntSet>& dom_map);
+IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
-IntSet EvalSet(Range r,
- const std::unordered_map<const VarNode*, IntSet>& dom_map);
+IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
* \param dom_map The domain of each variable.
* \return the map from the expression to its possible value.
*/
-ExprIntSetMap EvalSetForEachSubExpr(
- PrimExpr e,
- const std::unordered_map<const VarNode*, IntSet>& dom_map);
+ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Create an union set of all sets
#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
+
#include <unordered_map>
#include <vector>
namespace tvm {
namespace arith {
+using tir::IterVar;
using tir::Var;
using tir::VarNode;
-using tir::IterVar;
/*!
* \brief Represent integer constrains including (integer) variables, their ranges and
}
bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
- return
- equal(variables, other->variables) &&
- equal(ranges, other->ranges) &&
- equal(relations, other->relations);
+ return equal(variables, other->variables) && equal(ranges, other->ranges) &&
+ equal(relations, other->relations);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param relations The linear relations between the variables
* (either equations or inequalities)
*/
- TVM_DLL IntConstraints(Array<Var> variables,
- Map<Var, Range> ranges,
- Array<PrimExpr> relations);
+ TVM_DLL IntConstraints(Array<Var> variables, Map<Var, Range> ranges, Array<PrimExpr> relations);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
};
}
bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
- return
- equal(src, other->src) &&
- equal(dst, other->dst) &&
- equal(src_to_dst, other->src_to_dst) &&
- equal(dst_to_src, other->dst_to_src);
+ return equal(src, other->src) && equal(dst, other->dst) &&
+ equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
* e.g., {m -> a, n -> -b}
*/
- TVM_DLL IntConstraintsTransform(IntConstraints src,
- IntConstraints dst,
- Map<Var, PrimExpr> src_to_dst,
- Map<Var, PrimExpr> dst_to_src);
+ TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst,
+ Map<Var, PrimExpr> src_to_dst, Map<Var, PrimExpr> dst_to_src);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};
* NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
* s_i | s_{i+1} (| means divides), the implement here does not guarantee it.
* TODO(yzhliu): From sergei-grechanik:
- * computing the proper Smith normal form may improve stability of automatic differentiation
- * (generating the same gradient code for slightly different but equivalent input code
- * U_{mxm} and V_{nxn} are invertible matrices.
- * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
- * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
- * \param S the original A_{mxn}, it will be modified to S_{mxn}
- * \param V an identity matrix, it will be modified to V_{nxn}
- * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
- * \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1}
+ * computing the proper Smith normal form may improve stability of automatic differentiation
+ * (generating the same gradient code for slightly different but equivalent input code).
+ * U_{mxm} and V_{nxn} are invertible matrices.
+ * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
+ * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
+ * \param S the original A_{mxn}, it will be modified to S_{mxn}
+ * \param V an identity matrix, it will be modified to V_{nxn}
+ * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
+ * \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1}
*/
-void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
- std::vector<std::vector<int64_t>> *V,
- std::vector<PrimExpr>* x,
- std::vector<PrimExpr> *y);
+void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V,
+ std::vector<PrimExpr>* x, std::vector<PrimExpr>* y);
/*!
* \brief Solve linear equations.
* as well as inequalities inferred from the \p system_to_solve.
* You can get the mapping from the original variables to the solution via ret->src_to_dst.
*/
-IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);
+IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve);
} // namespace arith
} // namespace tvm
#ifndef TVM_ARITH_PATTERN_H_
#define TVM_ARITH_PATTERN_H_
-#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
+#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
namespace tvm {
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
-Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
- const Array<tir::Var>& vars);
+Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<tir::Var>& vars);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
-Array<PrimExpr> DetectClipBound(const PrimExpr& e,
- const Array<tir::Var>& vars);
+Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<tir::Var>& vars);
} // namespace arith
} // namespace tvm
#ifndef TVM_DRIVER_DRIVER_API_H_
#define TVM_DRIVER_DRIVER_API_H_
+#include <tvm/ir/module.h>
#include <tvm/runtime/packed_func.h>
-#include <tvm/target/target.h>
#include <tvm/support/with.h>
-#include <tvm/ir/module.h>
+#include <tvm/target/target.h>
#include <tvm/te/schedule_pass.h>
#include <string>
-#include <vector>
-#include <utility>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
+#include <vector>
namespace tvm {
/*!
-* \brief Build an IRModule given a schedule, args and binds
-* \param sch The schedule to lower.
-* \param args The arguments to the function.
-* \param name The name of the lowered function.
-* \param binds Buffer assignments.
-* \param config The build configuration.
-* \return The result module.
-*/
-TVM_DLL IRModule lower(
- te::Schedule sch,
- const Array<te::Tensor>& args,
- const std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- const BuildConfig& config);
+ * \brief Build an IRModule given a schedule, args and binds
+ * \param sch The schedule to lower.
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param config The build configuration.
+ * \return The result module.
+ */
+TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+ const BuildConfig& config);
/*!
-* \brief Build a device and host module for a specific target from an IRModule.
-* \param funcs The functions to be built.
-* \param target The target device to build for.
-* \param target_host The target for building host code. To use the default, pass Target()
-* \param config The build configuration.
-* \return The built module.
-*/
-TVM_DLL runtime::Module build(const IRModule& funcs,
- const Target& target,
- const Target& target_host,
- const BuildConfig& config);
+ * \brief Build a device and host module for a specific target from an IRModule.
+ * \param funcs The functions to be built.
+ * \param target The target device to build for.
+ * \param target_host The target for building host code. To use the default, pass Target()
+ * \param config The build configuration.
+ * \return The built module.
+ */
+TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target,
+ const Target& target_host, const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from a map
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
-TVM_DLL runtime::Module build(const Map<Target, IRModule>& input,
- const Target& target_host,
+TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host,
const BuildConfig& config);
/*!
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
-TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input,
- const Target& target_host,
+TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input, const Target& target_host,
const BuildConfig& config);
} // namespace tvm
#ifndef TVM_IR_ADT_H_
#define TVM_IR_ADT_H_
-#include <tvm/runtime/object.h>
-#include <tvm/node/node.h>
-#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
+#include <tvm/node/container.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
+
#include <string>
namespace tvm {
bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
// Use namehint for now to be consistent with the legacy relay impl
// TODO(tvm-team) revisit, need to check the type var.
- return
- equal(name_hint, other->name_hint) &&
- equal(inputs, other->inputs);
+ return equal(name_hint, other->name_hint) && equal(inputs, other->inputs);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
- TVM_DLL Constructor(std::string name_hint,
- Array<Type> inputs,
- GlobalTypeVar belong_to);
+ TVM_DLL Constructor(std::string name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};
}
bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
- return
- equal.DefEqual(header, other->header) &&
- equal.DefEqual(type_vars, other->type_vars) &&
- equal(constructors, other->constructors);
+ return equal.DefEqual(header, other->header) && equal.DefEqual(type_vars, other->type_vars) &&
+ equal(constructors, other->constructors);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param type_vars type variables.
* \param constructors constructors field.
*/
- TVM_DLL TypeData(GlobalTypeVar header,
- Array<TypeVar> type_vars,
- Array<Constructor> constructors);
+ TVM_DLL TypeData(GlobalTypeVar header, Array<TypeVar> type_vars, Array<Constructor> constructors);
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>
-#include <unordered_map>
-#include <vector>
#include <functional>
-#include <type_traits>
#include <string>
+#include <type_traits>
+#include <unordered_map>
#include <utility>
+#include <vector>
namespace tvm {
/*!
* \param ClassName The name of the class.
* \param TypeKey The type key to be used by the TVM node system.
*/
-#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
- static constexpr const char* _type_key = TypeKey; \
- TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
- template<typename FVisit> \
+#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
+ static constexpr const char* _type_key = TypeKey; \
+ TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
+ template <typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
-
/*!
* \brief Declare an attribute field.
* \param FieldName The field name.
*/
-#define TVM_ATTR_FIELD(FieldName) \
- __fvisit__(#FieldName, &FieldName)
-
+#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)
/*!
* \brief Create a NodeRef type that represents null.
* \tparam TNodeRef the type to be created.
* \return A instance that will represent None.
*/
-template<typename TObjectRef>
+template <typename TObjectRef>
inline TObjectRef NullValue() {
- static_assert(TObjectRef::_type_is_nullable,
- "Can only get NullValue for nullable types");
+ static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
return TObjectRef(ObjectPtr<Object>(nullptr));
}
-template<>
+template <>
inline DataType NullValue<DataType>() {
return DataType(DataType::kHandle, 0, 0);
}
* \brief constructor
* \param msg error message
*/
- explicit AttrError(const std::string &msg)
- : dmlc::Error(msg) {}
+ explicit AttrError(const std::string& msg) : dmlc::Error(msg) {}
};
/*!
* \param args The postional arguments in the form
* [key0, value0, key1, value1, ..., key_n, value_n]
*/
- template<typename... Args>
- inline void InitBySeq(Args&& ...args);
+ template <typename... Args>
+ inline void InitBySeq(Args&&... args);
/*!
* \brief Print readible docstring to ostream, add newline.
* \param os the stream to print the docstring to.
*/
- inline void PrintDocString(std::ostream &os) const; // NOLINT(*)
+ inline void PrintDocString(std::ostream& os) const; // NOLINT(*)
/*!
* \brief Visit attributes that do not equal the default value.
*
return equal(dict, other->dict);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(dict);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }
// implementations
void VisitAttrs(AttrVisitor* v) final;
*/
TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);
-
TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
struct AttrNopEntry {
using TSelf = AttrNopEntry;
- TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
- return *this;
- }
- template<typename T>
+ TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
+ template <typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
return *this;
}
- template<typename T>
+ template <typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
return *this;
}
- template<typename T>
+ template <typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
return *this;
}
// Wrapper for normal visitor.
class AttrNormalVisitor {
public:
- explicit AttrNormalVisitor(AttrVisitor* visitor)
- : visitor_(visitor) {
- }
- template<typename T>
+ explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
+ template <typename T>
AttrNopEntry operator()(const char* key, T* value) {
visitor_->Visit(key, value);
return AttrNopEntry();
bool result_{true};
// constructor
AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
- : lhs_(lhs), rhs_(rhs), equal_(equal) {
- }
- template<typename T>
+ : lhs_(lhs), rhs_(rhs), equal_(equal) {}
+ template <typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
- const T* rhs_value =
- reinterpret_cast<const T*>(
- reinterpret_cast<const char*>(rhs_) +
- (reinterpret_cast<const char*>(lhs_value) -
- reinterpret_cast<const char*>(lhs_)));
+ const T* rhs_value = reinterpret_cast<const T*>(
+ reinterpret_cast<const char*>(rhs_) +
+ (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
class AttrsSHashVisitor {
public:
- explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
- : hash_reducer_(hash_reducer) {}
+ explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}
- template<typename T>
+ template <typename T>
AttrNopEntry operator()(const char* key, T* value) {
hash_reducer_(*value);
return AttrNopEntry();
};
// helper entry that does initialization, set default.
-template<typename T>
+template <typename T>
struct AttrInitEntry {
// The attributes
using TSelf = AttrInitEntry<T>;
~AttrInitEntry() DMLC_THROW_EXCEPTION {
if (value_missing_) {
std::ostringstream os;
- os << type_key_ << ": Cannot find required field \'" << key_
- << "\' during initialization";
+ os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization";
throw AttrError(os.str());
}
}
// override fields.
// This function sets the lower bound of the attribute
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
- if (this->value_missing_) return *this;
+ if (this->value_missing_) return *this;
const T& val = *value_;
if (begin > val) {
std::ostringstream os;
os << type_key_ << "." << key_ << ": "
- << "value " << val
- << " is smaller than the lower bound " << begin;
+ << "value " << val << " is smaller than the lower bound " << begin;
throw AttrError(os.str());
}
return *this;
}
// This function sets the upper bound of the attribute
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
- if (this->value_missing_) return *this;
+ if (this->value_missing_) return *this;
const T& val = *value_;
if (val > end) {
std::ostringstream os;
os << type_key_ << "." << key_ << ": "
- << "value " << val
- << " is bigger than the upper bound " << end;
+ << "value " << val << " is bigger than the upper bound " << end;
throw AttrError(os.str());
}
return *this;
value_missing_ = false;
return *this;
}
- TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
- return *this;
- }
+ TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
};
// Template function to allow smart conversion
// from Expr types into the constants.
-template<typename T>
+template <typename T>
inline void SetValue(T* ptr, const TVMArgValue& val) {
*ptr = val.operator T();
}
-template<typename T>
+template <typename T>
inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
}
}
-template<>
+template <>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kTVMStr) {
*ptr = val.operator std::string();
}
}
-template<>
+template <>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
}
}
}
-template<>
+template <>
inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
-template<>
+template <>
inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
-template<>
+template <>
inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
-template<>
+template <>
inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
// Visitor for value initialization
-template<typename FFind>
+template <typename FFind>
class AttrInitVisitor {
public:
// Counter of number of matched attributes during visit.
// This is used to decide if there is additional unmatched attributes.
size_t hit_count_{0};
// constructor
- AttrInitVisitor(const char* type_key, FFind ffind)
- : type_key_(type_key), ffind_(ffind) {
- }
+ AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
- template<typename T>
+ template <typename T>
AttrInitEntry<T> operator()(const char* key, T* value) {
TVMArgValue val;
AttrInitEntry<T> opt;
FFind ffind_;
};
-template<typename FFind>
-inline AttrInitVisitor<FFind> CreateInitVisitor(
- const char* type_key,
- FFind ffind) {
+template <typename FFind>
+inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
return AttrInitVisitor<FFind>(type_key, ffind);
}
* \brief Helper struct to get the type name known to tvm.
* \tparam T the type we are interested in.
*/
-template<typename T>
+template <typename T>
struct TypeName {
static constexpr const char* value = T::ContainerType::_type_key;
};
-template<>
+template <>
struct TypeName<int> {
static constexpr const char* value = "int";
};
-template<>
+template <>
struct TypeName<int64_t> {
static constexpr const char* value = "int64";
};
-template<>
+template <>
struct TypeName<uint64_t> {
static constexpr const char* value = "uint64_t";
};
-template<>
+template <>
struct TypeName<DataType> {
static constexpr const char* value = "DataType";
};
-template<>
+template <>
struct TypeName<std::string> {
static constexpr const char* value = "str";
};
-template<>
+template <>
struct TypeName<bool> {
static constexpr const char* value = "bool";
};
-template<>
+template <>
struct TypeName<void*> {
static constexpr const char* value = "handle";
};
-template<>
+template <>
struct TypeName<double> {
static constexpr const char* value = "double";
};
public:
using TSelf = AttrDocEntry;
- explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info)
- : info_(info) {
- }
+ explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
info_->description = str;
return *this;
}
- template<typename T>
+ template <typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
std::ostringstream os;
os << info_->type_info << ", default=" << value;
info_->type_info = os.str();
return *this;
}
- template<typename T>
+ template <typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
return *this;
}
- template<typename T>
+ template <typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
return *this;
}
class AttrDocVisitor {
public:
- template<typename T>
+ template <typename T>
AttrDocEntry operator()(const char* key, T* v) {
- ObjectPtr<AttrFieldInfoNode> info
- = make_object<AttrFieldInfoNode>();
+ ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
info->name = key;
info->type_info = TypeName<T>::value;
fields_.push_back(AttrFieldInfo(info));
std::string key_;
bool exist_{false};
- template<typename T>
+ template <typename T>
AttrNopEntry operator()(const char* key, T* v) {
if (exist_) return AttrNopEntry();
if (key == key_) exist_ = true;
}
};
-template<typename T>
+template <typename T>
struct AttrTriggerNonDefaultEntry {
using TSelf = AttrTriggerNonDefaultEntry<T>;
// constructor
- AttrTriggerNonDefaultEntry(
- AttrVisitor* visitor, const char* key, T* data)
+ AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
: visitor_(visitor), key_(key), data_(data) {}
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
visitor_->Visit(key_, data_);
}
}
- TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
- return *this;
- }
+ TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
TSelf& set_default(const T& value) {
if (tvm::StructuralEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
}
- TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
- return *this;
- }
- TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
- return *this;
- }
+ TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
+ TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }
private:
AttrVisitor* visitor_;
- const char * key_;
- T *data_;
+ const char* key_;
+ T* data_;
bool trigger_{true};
};
class AttrNonDefaultVisitor {
public:
- explicit AttrNonDefaultVisitor(AttrVisitor* visitor)
- : visitor_(visitor) {
- }
- template<typename T>
- AttrTriggerNonDefaultEntry<T>
- operator()(const char* key, T* value) {
+ explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
+ template <typename T>
+ AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
}
*
* \tparam DerivedType The final attribute type.
*/
-template<typename DerivedType>
+template <typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) {
CHECK_EQ(args.type_codes[i], kTVMStr);
kwargs[args[i].operator std::string()] = args[i + 1];
}
- auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) {
+ auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
auto it = kwargs.find(key);
if (it != kwargs.end()) {
*val = it->second;
self()->__VisitAttrs__(visitor);
if (!visitor.exist_) {
std::ostringstream os;
- os << DerivedType::_type_key
- << ": does not have field \'" << visitor.key_
+ os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
<< "\', Possible fields:\n";
os << "----------------\n";
this->PrintDocString(os);
private:
DerivedType* self() const {
- return const_cast<DerivedType*>(
- static_cast<const DerivedType*>(this));
+ return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
}
};
-
-template<typename... Args>
-inline void BaseAttrsNode::InitBySeq(Args&& ...args) {
- runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) {
- this->InitByPackedArgs(args);
- });
+template <typename... Args>
+inline void BaseAttrsNode::InitBySeq(Args&&... args) {
+ runtime::PackedFunc pf(
+ [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
pf(std::forward<Args>(args)...);
}
-inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
+inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*)
Array<AttrFieldInfo> entry = this->ListFieldInfo();
for (AttrFieldInfo info : entry) {
os << info->name << " : " << info->type_info << '\n';
/*! \brief constructor */
EnvFuncNode() {}
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
// name uniquely identifies the env function.
EnvFunc() {}
explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The internal global function pointer */
- const EnvFuncNode* operator->() const {
- return static_cast<const EnvFuncNode*>(get());
- }
+ const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
/*!
* \brief Invoke the function.
* \param args The arguments
* \returns The return value.
*/
- template<typename... Args>
+ template <typename... Args>
runtime::TVMRetValue operator()(Args&&... args) const {
const EnvFuncNode* n = operator->();
CHECK(n != nullptr);
/*!
* \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>"
*/
-template<typename FType>
+template <typename FType>
class TypedEnvFunc;
/*!
* \tparam Args The argument signature of the function.
* \sa EnvFunc
*/
-template<typename R, typename... Args>
+template <typename R, typename... Args>
class TypedEnvFunc<R(Args...)> : public ObjectRef {
public:
/*! \brief short hand for this function type */
return *this;
}
/*! \return The internal global function pointer */
- const EnvFuncNode* operator->() const {
- return static_cast<const EnvFuncNode*>(get());
- }
+ const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
/*!
* \brief Invoke the function.
* \param args The arguments
R operator()(Args... args) const {
const EnvFuncNode* n = operator->();
CHECK(n != nullptr);
- return runtime::detail::typed_packed_call_dispatcher<R>
- ::run(n->func, std::forward<Args>(args)...);
+ return runtime::detail::typed_packed_call_dispatcher<R>::run(n->func,
+ std::forward<Args>(args)...);
}
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
#ifndef TVM_IR_ERROR_H_
#define TVM_IR_ERROR_H_
-#include <tvm/ir/span.h>
#include <tvm/ir/module.h>
+#include <tvm/ir/span.h>
-#include <string>
-#include <vector>
#include <sstream>
+#include <string>
#include <unordered_map>
+#include <vector>
namespace tvm {
/*!
*/
struct ErrorBuilder {
public:
- template<typename T>
+ template <typename T>
ErrorBuilder& operator<<(const T& val) { // NOLINT(*)
stream_ << val;
return *this;
* \brief construct error from error builder.
* \param err The error builder
*/
- Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*)
+ Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*)
/*!
* \brief copy constructor.
* \param other The other ereor.
*/
- Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*)
+ Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*)
/*!
* \brief default constructor. */
Error() : dmlc::Error(""), span(nullptr) {}
*/
void RenderErrors(const IRModule& module, bool use_color = true);
- inline bool AnyErrors() {
- return errors_.size() != 0;
- }
+ inline bool AnyErrors() { return errors_.size() != 0; }
private:
std::vector<Error> errors_;
#ifndef TVM_IR_EXPR_H_
#define TVM_IR_EXPR_H_
-#include <tvm/runtime/object.h>
-#include <tvm/node/node.h>
-#include <tvm/node/container.h>
#include <tvm/ir/span.h>
#include <tvm/ir/type.h>
-#include <string>
+#include <tvm/node/container.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
+
#include <algorithm>
#include <limits>
+#include <string>
#include <type_traits>
namespace tvm {
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*! \return the data type of this expression. */
- DataType dtype() const {
- return static_cast<const PrimExprNode*>(get())->dtype;
- }
+ DataType dtype() const { return static_cast<const PrimExprNode*>(get())->dtype; }
TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
- template<typename TTypeNode>
+ template <typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "RelayExpr";
bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
// name matters for global var.
- return
- equal(name_hint, other->name_hint) &&
- equal.FreeVarEqualImpl(this, other);
+ return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
*/
class Bool : public IntImm {
public:
- explicit Bool(bool value)
- : IntImm(DataType::Bool(), value) {
- }
- Bool operator!() const {
- return Bool((*this)->value == 0);
- }
- operator bool() const {
- return (*this)->value != 0;
- }
+ explicit Bool(bool value) : IntImm(DataType::Bool(), value) {}
+ Bool operator!() const { return Bool((*this)->value == 0); }
+ operator bool() const { return (*this)->value != 0; }
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode);
};
// Overload operators to make sure we have the most fine grained types.
-inline Bool operator||(const Bool& a, bool b) {
- return Bool(a.operator bool() || b);
-}
-inline Bool operator||(bool a, const Bool& b) {
- return Bool(a || b.operator bool());
-}
+inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); }
+inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); }
inline Bool operator||(const Bool& a, const Bool& b) {
return Bool(a.operator bool() || b.operator bool());
}
-inline Bool operator&&(const Bool& a, bool b) {
- return Bool(a.operator bool() && b);
-}
-inline Bool operator&&(bool a, const Bool& b) {
- return Bool(a && b.operator bool());
-}
+inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); }
+inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); }
inline Bool operator&&(const Bool& a, const Bool& b) {
return Bool(a.operator bool() && b.operator bool());
}
* \tparam Enum The enum type.
* \param value The enum value.
*/
- template<typename Enum,
- typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
+ template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
"declare enum to be enum int to use visitor");
* \brief convert to int64_t
*/
operator int64_t() const {
- CHECK(data_ != nullptr)
- << " Trying to reference a null Integer";
+ CHECK(data_ != nullptr) << " Trying to reference a null Integer";
return (*this)->value;
}
// comparators
if (data_ == nullptr) return Bool(false);
return Bool((*this)->value == other);
}
- Bool operator!=(int other) const {
- return !(*this == other);
- }
- template<typename Enum,
- typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
+ Bool operator!=(int other) const { return !(*this == other); }
+ template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator==(Enum other) const {
return *this == static_cast<int>(other);
}
- template<typename Enum,
- typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
+ template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator!=(Enum other) const {
return *this != static_cast<int>(other);
}
// implementataions
inline const Type& RelayExprNode::checked_type() const {
- CHECK(checked_type_.defined())
- << "internal error: the type checker has "
- << "not populated the checked_type "
- << "field for "
- << GetRef<RelayExpr>(this);
+ CHECK(checked_type_.defined()) << "internal error: the type checker has "
+ << "not populated the checked_type "
+ << "field for " << GetRef<RelayExpr>(this);
return this->checked_type_;
}
-template<typename TTypeNode>
+template <typename TTypeNode>
inline const TTypeNode* RelayExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
- CHECK(node != nullptr)
- << "Expected type to be " << TTypeNode::_type_key
- << ", but get " << checked_type_->GetTypeKey();
+ CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get "
+ << checked_type_->GetTypeKey();
return node;
}
namespace tvm {
namespace runtime {
-template<>
+template <>
struct PackedFuncValueConverter<PrimExpr> {
// common rule for both RetValue and ArgValue.
static PrimExpr From(const TVMPODValue_& val) {
#ifndef TVM_IR_FUNCTION_H_
#define TVM_IR_FUNCTION_H_
-#include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h>
+#include <tvm/ir/expr.h>
#include <tvm/runtime/container.h>
-#include <type_traits>
-#include <string>
+#include <string>
+#include <type_traits>
namespace tvm {
*
* \endcode
*/
- template<typename TObjectRef>
+ template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
}
}
// variant that uses TObjectRef to enable implicit conversion to default value.
- template<typename TObjectRef>
- Optional<TObjectRef> GetAttr(
- const std::string& attr_key, TObjectRef default_value) const {
+ template <typename TObjectRef>
+ Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
/*!
*
* \endcode
*/
-template<typename TFunc,
- typename = typename std::enable_if<
- std::is_base_of<BaseFunc, TFunc>::value>::type>
-inline TFunc WithAttr(TFunc func,
- const std::string& attr_key,
- ObjectRef attr_value) {
+template <typename TFunc,
+ typename = typename std::enable_if<std::is_base_of<BaseFunc, TFunc>::value>::type>
+inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = func.CopyOnWrite();
#ifndef TVM_IR_MODULE_H_
#define TVM_IR_MODULE_H_
-#include <tvm/ir/type.h>
+#include <tvm/ir/adt.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
-#include <tvm/ir/adt.h>
+#include <tvm/ir/type.h>
#include <tvm/node/container.h>
+
#include <string>
-#include <vector>
#include <unordered_map>
#include <unordered_set>
+#include <vector>
namespace tvm {
class IRModule;
*
* It does not do type checking as AddTypeDef does.
*/
- TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var,
- const TypeData& type,
+ TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
bool update = false);
/*!
*
* \returns The constructed module
*/
- static IRModule Empty() {
- return IRModule(Map<GlobalVar, BaseFunc>());
- }
+ static IRModule Empty() { return IRModule(Map<GlobalVar, BaseFunc>()); }
/*!
* \brief Construct a module from a standalone expression.
*
*
* \returns A module with expr set as the main function.
*/
- TVM_DLL static IRModule FromExpr(
- const RelayExpr& expr,
- const Map<GlobalVar, BaseFunc>& global_funcs = {},
- const Map<GlobalTypeVar, TypeData>& type_definitions = {});
+ TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
+ const Map<GlobalVar, BaseFunc>& global_funcs = {},
+ const Map<GlobalTypeVar, TypeData>& type_definitions = {});
/*!
* \brief Parse text format source file into an IRModule.
* \sa PrettyPrint.
* \return The text representation.
*/
-TVM_DLL String AsText(const ObjectRef& node,
- bool show_meta_data = true,
+TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,
runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
#include <dmlc/registry.h>
#include <tvm/ir/attrs.h>
-#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
+#include <tvm/runtime/registry.h>
#include <string>
#include <utility>
* \param description Description of the argument.
* \return reference to self.
*/
- inline OpRegistry& add_argument(const std::string& name,
- const std::string& type,
+ inline OpRegistry& add_argument(const std::string& name, const std::string& type,
const std::string& description);
/*!
* \brief Attach the type function corresponding to the return type.
*/
inline OpRegistry& add_type_rel(
const std::string& rel_name,
- runtime::TypedPackedFunc<bool(const Array<Type>&,
- int,
- const Attrs&,
- const TypeReporter&)> type_rel_func);
+ runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
+ type_rel_func);
/*!
* \brief Set the the attrs type key and index to be AttrsType.
* \tparam AttrsType the attribute type to b set.
* \return reference to self.
*/
- template<typename AttrsType>
+ template <typename AttrsType>
inline OpRegistry& set_attrs_type();
/*!
* \brief Set the num_inputs
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
- TVM_DLL void UpdateAttr(const std::string& key,
- runtime::TVMRetValue value,
- int plevel);
+ TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value, int plevel);
};
/*!
#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__)
// internal macros to make
-#define TVM_OP_REGISTER_VAR_DEF \
- static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op
+#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op
/*!
* \def TVM_REGISTER_OP
*
* \endcode
*/
-#define TVM_REGISTER_OP(OpName) \
- TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
- ::tvm::OpRegistry::Registry() \
- ->__REGISTER_OR_GET__(OpName) \
- .set_name()
+#define TVM_REGISTER_OP(OpName) \
+ TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
+ ::tvm::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name()
// implementations
-inline const OpNode* Op::operator->() const {
- return static_cast<const OpNode*>(get());
-}
+inline const OpNode* Op::operator->() const { return static_cast<const OpNode*>(get()); }
template <typename ValueType>
inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
return OpMap<ValueType>(Op::GetGenericAttr(key));
}
-inline bool Op::HasAttr(const std::string& key) {
- return Op::HasGenericAttr(key);
-}
+inline bool Op::HasAttr(const std::string& key) { return Op::HasGenericAttr(key); }
-inline OpNode* OpRegistry::get() {
- return const_cast<OpNode*>(op_.operator->());
-}
+inline OpNode* OpRegistry::get() { return const_cast<OpNode*>(op_.operator->()); }
-inline OpRegistry& OpRegistry::describe(
- const std::string& descr) { // NOLINT(*)
+inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*)
get()->description = descr;
return *this;
}
-inline OpRegistry& OpRegistry::add_argument(const std::string& name,
- const std::string& type,
+inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type,
const std::string& description) {
auto n = make_object<AttrFieldInfoNode>();
n->name = name;
inline OpRegistry& OpRegistry::add_type_rel(
const std::string& rel_name,
- runtime::TypedPackedFunc<bool(const Array<Type>&,
- int,
- const Attrs&,
- const TypeReporter&)> type_rel_func) {
+ runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
+ type_rel_func) {
auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
TypeRelationFn env_type_rel_func;
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func;
} else {
- runtime::Registry::Register(func_name)
- .set_body(type_rel_func.packed());
+ runtime::Registry::Register(func_name).set_body(type_rel_func.packed());
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func;
}
// A common example is sum(x, axis), where the choice of axis
// can affect the type of the function.
TypeConstraint type_rel =
- TypeRelation(env_type_rel_func,
- ty_call_args,
- arg_types.size(),
- Attrs());
+ TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());
- auto func_type =
- FuncType(arg_types, out_param, type_params, {type_rel});
+ auto func_type = FuncType(arg_types, out_param, type_params, {type_rel});
get()->op_type = func_type;
return *this;
}
-template<typename AttrsType>
+template <typename AttrsType>
inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*)
get()->attrs_type_key = AttrsType::_type_key;
get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
}
}
-inline const runtime::TVMRetValue&
-GenericOpMap::operator[](const Op& op) const {
+inline const runtime::TVMRetValue& GenericOpMap::operator[](const Op& op) const {
CHECK(op.defined());
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second != 0)
- << "Attribute " << attr_name_ << " has not been registered for Operator "
- << op->name;
+ << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name;
return data_[idx].first;
}
}
template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const Op& op,
- ValueType def_value) const {
+inline ValueType OpMap<ValueType>::get(const Op& op, ValueType def_value) const {
return map_.get<ValueType>(op, def_value);
}
template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const RelayExpr& expr,
- ValueType def_value) const {
+inline ValueType OpMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
return map_.get<ValueType>(expr, def_value);
}
#ifndef TVM_IR_SPAN_H_
#define TVM_IR_SPAN_H_
-#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
+
#include <string>
namespace tvm {
}
bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
- return
- equal(source, other->source) &&
- equal(lineno, other->lineno) &&
- equal(col_offset, other->col_offset);
+ return equal(source, other->source) && equal(lineno, other->lineno) &&
+ equal(col_offset, other->col_offset);
}
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};
-
class Span : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
#ifndef TVM_IR_TENSOR_TYPE_H_
#define TVM_IR_TENSOR_TYPE_H_
-#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
+#include <tvm/ir/type.h>
namespace tvm {
/*!
}
bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
- return
- equal(shape, other->shape) &&
- equal(dtype, other->dtype);
+ return equal(shape, other->shape) && equal(dtype, other->dtype);
}
void SHashReduce(SHashReducer hash_reduce) const {
#ifndef TVM_IR_TRANSFORM_H_
#define TVM_IR_TRANSFORM_H_
-#include <tvm/support/with.h>
-#include <tvm/runtime/container.h>
-#include <tvm/node/container.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
+#include <tvm/node/container.h>
+#include <tvm/runtime/container.h>
+#include <tvm/support/with.h>
+
#include <string>
#include <utility>
*
*/
using TraceFunc =
- runtime::TypedPackedFunc<void(const IRModule& ir_module,
- const PassInfo& ctx,
- bool is_before)>;
+ runtime::TypedPackedFunc<void(const IRModule& ir_module, const PassInfo& ctx, bool is_before)>;
/*!
* \brief PassContextNode contains the information that a pass can rely on,
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
-
/*!
* \brief PassContext that is used to configure the pass behavior.
*
* \param name Name of the pass.
* \param required The passes that are required to perform the current pass.
*/
- TVM_DLL PassInfo(int opt_level,
- std::string name,
- Array<runtime::String> required);
+ TVM_DLL PassInfo(int opt_level, std::string name, Array<runtime::String> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
*
* \return The transformed module.
*/
- virtual IRModule operator()(IRModule mod,
- const PassContext& pass_ctx) const = 0;
+ virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
void VisitAttrs(AttrVisitor* v) {}
*
* \return The transformed module.
*/
- IRModule operator()(IRModule mod,
- const PassContext& pass_ctx) const {
+ IRModule operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(std::move(mod), pass_ctx);
*
* \return The created module pass.
*/
-TVM_DLL Pass CreateModulePass(
- const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const Array<runtime::String>& required);
-
+TVM_DLL Pass
+CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
+ int opt_level, const std::string& name, const Array<runtime::String>& required);
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
#ifndef TVM_IR_TYPE_H_
#define TVM_IR_TYPE_H_
-#include <tvm/runtime/object.h>
-#include <tvm/runtime/data_type.h>
-#include <tvm/node/node.h>
-#include <tvm/node/container.h>
#include <tvm/ir/span.h>
+#include <tvm/node/container.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/object.h>
+
#include <string>
namespace tvm {
*/
runtime::DataType dtype;
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &dtype);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(dtype);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
-
/*
* \brief Managed reference to PrimTypeNode.
* \sa PrimTypeNode
TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
-
/*!
* \brief Low-level raw pointer type.
*
*/
Type element_type;
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("element_type", &element_type);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("element_type", &element_type); }
bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(element_type);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(element_type); }
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
-
/*! \brief Possible kinds of TypeVars. */
enum TypeKind : int {
kType = 0,
}
bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
- return
- equal(kind, other->kind) &&
- equal.FreeVarEqualImpl(this, other);
+ return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
// name matters for now in global type var.
- return
- equal(name_hint, other->name_hint) &&
- equal.FreeVarEqualImpl(this, other);
+ return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
return equal(fields, other->fields);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(fields);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
/*!
* \return a type that represents void.
*/
-inline Type VoidType() {
- return TupleType::Empty();
-}
+inline Type VoidType() { return TupleType::Empty(); }
/*!
* \brief Check whether the type represents void.
bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
// type params first as they defines type vars.
- return
- equal.DefEqual(type_params, other->type_params) &&
- equal(arg_types, other->arg_types) &&
- equal(ret_type, other->ret_type) &&
- equal(type_constraints, other->type_constraints);
+ return equal.DefEqual(type_params, other->type_params) && equal(arg_types, other->arg_types) &&
+ equal(ret_type, other->ret_type) && equal(type_constraints, other->type_constraints);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param type_constraints The type constraints.
* \sa FuncTypeNode for more docs about these fields.
*/
- TVM_DLL FuncType(Array<Type> arg_types,
- Type ret_type,
- Array<TypeVar> type_params,
+ TVM_DLL FuncType(Array<Type> arg_types, Type ret_type, Array<TypeVar> type_params,
Array<TypeConstraint> type_constraints);
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
}
bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
- return
- equal(kind, other->kind) &&
- equal.FreeVarEqualImpl(this, other);
+ return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(kind);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(kind); }
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};
-
/*!
* \brief Reference Type High-level Relay IR.
*
return equal(value, other->value);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(value);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
#define TVM_IR_TYPE_FUNCTOR_H_
#include <tvm/node/functor.h>
-#include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
+#include <tvm/relay/expr.h>
+
#include <string>
-#include <vector>
#include <utility>
+#include <vector>
namespace tvm {
class TypeFunctor;
// functions to be overridden.
-#define TYPE_FUNCTOR_DEFAULT \
+#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
-
-#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \
- vtable.template set_dispatch<OP>( \
- [](const ObjectRef& n, TSelf* self, Args... args) { \
- return self->VisitType_(static_cast<const OP*>(n.get()), \
- std::forward<Args>(args)...); \
- });
+#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \
+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+ return self->VisitType_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+ });
template <typename R, typename... Args>
class TypeFunctor<R(const Type& n, Args...)> {
* \param args Additional arguments.
* \return The result of the call
*/
- R operator()(const Type& n, Args... args) {
- return VisitType(n, std::forward<Args>(args)...);
- }
+ R operator()(const Type& n, Args... args) { return VisitType(n, std::forward<Args>(args)...); }
/*!
* \brief The functor call.
* \param n The expression node.
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overridden by subclass
- virtual R VisitType_(const TensorTypeNode* op,
- Args... args) TYPE_FUNCTOR_DEFAULT;
+ virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
/*!
* \brief A type visitor that recursively visit types.
*/
-class TVM_DLL TypeVisitor :
- public TypeFunctor<void(const Type& n)> {
+class TVM_DLL TypeVisitor : public TypeFunctor<void(const Type& n)> {
public:
void VisitType_(const TypeVarNode* op) override;
void VisitType_(const IncompleteTypeNode* op) override;
/*!
* \brief TypeMutator that mutates expressions.
*/
-class TVM_DLL TypeMutator :
- public TypeFunctor<Type(const Type& n)> {
+class TVM_DLL TypeMutator : public TypeFunctor<Type(const Type& n)> {
public:
Type VisitType(const Type& t) override;
Type VisitType_(const TypeVarNode* op) override;
#ifndef TVM_IR_TYPE_RELATION_H_
#define TVM_IR_TYPE_RELATION_H_
-#include <tvm/ir/type.h>
-#include <tvm/ir/module.h>
-#include <tvm/ir/env_func.h>
#include <tvm/ir/attrs.h>
+#include <tvm/ir/env_func.h>
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
namespace tvm {
}
bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
- return
- equal(func, other->func) &&
- equal(args, other->args);
+ return equal(func, other->func) && equal(args, other->args);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \return false if assertion can be proven to have failed
* true if solver can still proceed.
*/
- TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0;
+ TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0;
/*!
* \brief assert shape expression equals each other.
* \param lhs The left operand.
class TypeReporter : public ObjectRef {
public:
TypeReporter() {}
- explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {
- }
+ explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {}
TypeReporterNode* operator->() const {
- return const_cast<TypeReporterNode*>(
- static_cast<const TypeReporterNode*>(get()));
+ return const_cast<TypeReporterNode*>(static_cast<const TypeReporterNode*>(get()));
}
using ContainerType = TypeReporterNode;
};
* \return false if This relation cannot be resolved.
* true if this relation has been resolved.
*/
-using TypeRelationFn =
- TypedEnvFunc<bool(const Array<Type>& args,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter)>;
+using TypeRelationFn = TypedEnvFunc<bool(const Array<Type>& args, int num_inputs,
+ const Attrs& attrs, const TypeReporter& reporter)>;
/*!
* \brief User defined type relation, it is an input-output relation on types.
}
bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
- return
- equal(func, other->func) &&
- equal(args, other->args) &&
- equal(num_inputs, other->num_inputs) &&
- equal(attrs, other->attrs);
+ return equal(func, other->func) && equal(args, other->args) &&
+ equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param attrs Attributes to the relation function.
* \sa TypeRelationNode for more docs about these fields.
*/
- TVM_DLL TypeRelation(TypeRelationFn func,
- Array<Type> args,
- int num_inputs,
- Attrs attrs);
+ TVM_DLL TypeRelation(TypeRelationFn func, Array<Type> args, int num_inputs, Attrs attrs);
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
};
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
-#include <tvm/runtime/object.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
+#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/container.h>
-#include <type_traits>
-#include <vector>
#include <initializer_list>
+#include <string>
+#include <type_traits>
#include <unordered_map>
#include <utility>
-#include <string>
+#include <vector>
namespace tvm {
-using runtime::String;
-using runtime::StringObj;
+using runtime::make_object;
using runtime::Object;
+using runtime::ObjectEqual;
+using runtime::ObjectHash;
using runtime::ObjectPtr;
using runtime::ObjectRef;
-using runtime::make_object;
-using runtime::ObjectHash;
-using runtime::ObjectEqual;
+using runtime::String;
+using runtime::StringObj;
/*! \brief array node content in array */
class ArrayNode : public Object {
class MapNode : public Object {
public:
/*! \brief The corresponding container type */
- using ContainerType = std::unordered_map<
- ObjectRef,
- ObjectRef,
- ObjectHash, ObjectEqual>;
+ using ContainerType = std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual>;
/*! \brief the data content */
ContainerType data;
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
};
-
/*! \brief specialized map node with string as key */
class StrMapNode : public Object {
public:
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
-template<typename Converter,
- typename TIter>
+template <typename Converter, typename TIter>
class IterAdapter {
public:
using difference_type = typename std::iterator_traits<TIter>::difference_type;
using value_type = typename Converter::ResultType;
using pointer = typename Converter::ResultType*;
- using reference = typename Converter::ResultType&; // NOLINT(*)
+ using reference = typename Converter::ResultType&; // NOLINT(*)
using iterator_category = typename std::iterator_traits<TIter>::iterator_category;
explicit IterAdapter(TIter iter) : iter_(iter) {}
++iter_;
return *this;
}
- inline IterAdapter operator+(difference_type offset) const {
- return IterAdapter(iter_ + offset);
- }
+ inline IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); }
- template<typename T = IterAdapter>
+ template <typename T = IterAdapter>
typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
- typename T::difference_type>::type
- inline operator-(const IterAdapter& rhs) const {
+ typename T::difference_type>::type inline
+ operator-(const IterAdapter& rhs) const {
return iter_ - rhs.iter_;
}
- inline bool operator==(IterAdapter other) const {
- return iter_ == other.iter_;
- }
- inline bool operator!=(IterAdapter other) const {
- return !(*this == other);
- }
- inline const value_type operator*() const {
- return Converter::convert(*iter_);
- }
+ inline bool operator==(IterAdapter other) const { return iter_ == other.iter_; }
+ inline bool operator!=(IterAdapter other) const { return !(*this == other); }
+ inline const value_type operator*() const { return Converter::convert(*iter_); }
private:
TIter iter_;
* operator[] only provides const access, use Set to mutate the content.
* \tparam T The content NodeRef type.
*/
-template<typename T,
- typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
+template <typename T,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
class Array : public ObjectRef {
public:
/*!
* \brief default constructor
*/
- Array() {
- data_ = make_object<ArrayNode>();
- }
+ Array() { data_ = make_object<ArrayNode>(); }
/*!
* \brief move constructor
* \param other source
*/
- Array(Array<T> && other) : ObjectRef() { // NOLINT(*)
+ Array(Array<T>&& other) : ObjectRef() { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief copy constructor
* \param other source
*/
- Array(const Array<T> &other) : ObjectRef() { // NOLINT(*)
+ Array(const Array<T>& other) : ObjectRef() { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \param end end of iterator
* \tparam IterType The type of iterator
*/
- template<typename IterType>
+ template <typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
* \brief constructor from initializer list
* \param init The initializer list
*/
- Array(std::initializer_list<T> init) { // NOLINT(*)
+ Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
- Array(const std::vector<T>& init) { // NOLINT(*)
+ Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \param other The source of assignment
* \return reference to self.
*/
- Array<T>& operator=(Array<T> && other) {
+ Array<T>& operator=(Array<T>&& other) {
data_ = std::move(other.data_);
return *this;
}
* \param other The source of assignment
* \return reference to self.
*/
- Array<T>& operator=(const Array<T> & other) {
+ Array<T>& operator=(const Array<T>& other) {
data_ = other.data_;
return *this;
}
* \param end end of iterator
* \tparam IterType The type of iterator
*/
- template<typename IterType>
+ template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_object<ArrayNode>();
for (IterType it = begin; it != end; ++it) {
* \return the i-th element.
*/
inline const T operator[](size_t i) const {
- return DowncastNoCheck<T>(
- static_cast<const ArrayNode*>(data_.get())->data[i]);
+ return DowncastNoCheck<T>(static_cast<const ArrayNode*>(data_.get())->data[i]);
}
/*! \return The size of the array */
inline size_t size() const {
* \return Handle to the internal node container (which guarantees to be unique)
*/
inline ArrayNode* CopyOnWrite() {
- if (data_.get() == nullptr || !data_.unique()) {
+ if (data_.get() == nullptr || !data_.unique()) {
ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
n->data = static_cast<ArrayNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
n->data[i] = value;
}
/*! \return whether array is empty */
- inline bool empty() const {
- return size() == 0;
- }
+ inline bool empty() const { return size() == 0; }
/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
- template<typename F>
+ template <typename F>
inline void MutateByApply(F fmutate) {
ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
if (ptr == nullptr) return;
struct ValueConverter {
using ResultType = T;
- static inline T convert(const ObjectRef& n) {
- return DowncastNoCheck<T>(n);
- }
+ static inline T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); }
};
- using iterator = IterAdapter<ValueConverter,
- std::vector<ObjectRef>::const_iterator>;
+ using iterator = IterAdapter<ValueConverter, std::vector<ObjectRef>::const_iterator>;
- using reverse_iterator = IterAdapter<
- ValueConverter,
- std::vector<ObjectRef>::const_reverse_iterator>;
+ using reverse_iterator =
+ IterAdapter<ValueConverter, std::vector<ObjectRef>::const_reverse_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
* \tparam K The key NodeRef type.
* \tparam V The value NodeRef type.
*/
-template<typename K,
- typename V,
- typename = typename std::enable_if<
- std::is_base_of<ObjectRef, K>::value ||
- std::is_base_of<std::string, K>::value >::type,
- typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
+template <typename K, typename V,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value ||
+ std::is_base_of<std::string, K>::value>::type,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public ObjectRef {
public:
/*!
* \brief default constructor
*/
- Map() {
- data_ = make_object<MapNode>();
- }
+ Map() { data_ = make_object<MapNode>(); }
/*!
* \brief move constructor
* \param other source
*/
- Map(Map<K, V> && other) { // NOLINT(*)
+ Map(Map<K, V>&& other) { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief copy constructor
* \param other source
*/
- Map(const Map<K, V> &other) : ObjectRef(other.data_) { // NOLINT(*)
+ Map(const Map<K, V>& other) : ObjectRef(other.data_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
* \param end end of iterator
* \tparam IterType The type of iterator
*/
- template<typename IterType>
+ template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
* \brief constructor from initializer list
* \param init The initializer list
*/
- Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
+ Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from unordered_map
* \param init The unordered_map
*/
- template<typename Hash, typename Equal>
- Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
+ template <typename Hash, typename Equal>
+ Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \param other The source of assignment
* \return reference to self.
*/
- Map<K, V>& operator=(Map<K, V> && other) {
+ Map<K, V>& operator=(Map<K, V>&& other) {
data_ = std::move(other.data_);
return *this;
}
* \param other The source of assignment
* \return reference to self.
*/
- Map<K, V>& operator=(const Map<K, V> & other) {
+ Map<K, V>& operator=(const Map<K, V>& other) {
data_ = other.data_;
return *this;
}
* \param end end of iterator
* \tparam IterType The type of iterator
*/
- template<typename IterType>
+ template <typename IterType>
void assign(IterType begin, IterType end) {
ObjectPtr<MapNode> n = make_object<MapNode>();
for (IterType i = begin; i != end; ++i) {
* \return the corresponding element.
*/
inline const V operator[](const K& key) const {
- return DowncastNoCheck<V>(
- static_cast<const MapNode*>(data_.get())->data.at(key));
+ return DowncastNoCheck<V>(static_cast<const MapNode*>(data_.get())->data.at(key));
}
/*!
* \brief Read element from map.
* \return the corresponding element.
*/
inline const V at(const K& key) const {
- return DowncastNoCheck<V>(
- static_cast<const MapNode*>(data_.get())->data.at(key));
+ return DowncastNoCheck<V>(static_cast<const MapNode*>(data_.get())->data.at(key));
}
/*! \return The size of the array */
inline size_t size() const {
* \return Handle to the internal node container (which guarantees to be unique)
*/
inline MapNode* CopyOnWrite() {
- if (data_.get() == nullptr || !data_.unique()) {
+ if (data_.get() == nullptr || !data_.unique()) {
ObjectPtr<MapNode> n = make_object<MapNode>();
n->data = static_cast<const MapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
/*! \return whether array is empty */
- inline bool empty() const {
- return size() == 0;
- }
+ inline bool empty() const { return size() == 0; }
/*! \brief specify container node */
using ContainerType = MapNode;
struct ValueConverter {
using ResultType = std::pair<K, V>;
- static inline ResultType convert(const std::pair<
- ObjectRef,
- ObjectRef>& n) {
- return std::make_pair(DowncastNoCheck<K>(n.first),
- DowncastNoCheck<V>(n.second));
+ static inline ResultType convert(const std::pair<ObjectRef, ObjectRef>& n) {
+ return std::make_pair(DowncastNoCheck<K>(n.first), DowncastNoCheck<V>(n.second));
}
};
- using iterator = IterAdapter<
- ValueConverter, MapNode::ContainerType::const_iterator>;
+ using iterator = IterAdapter<ValueConverter, MapNode::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
}
/*! \return begin iterator */
inline iterator find(const K& key) const {
- return iterator(
- static_cast<const MapNode*>(data_.get())->data.find(key));
+ return iterator(static_cast<const MapNode*>(data_.get())->data.find(key));
}
};
// specialize of string map
-template<typename V, typename T1, typename T2>
+template <typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public ObjectRef {
public:
// for code reuse
- Map() {
- data_ = make_object<StrMapNode>();
- }
- Map(Map<std::string, V> && other) { // NOLINT(*)
+ Map() { data_ = make_object<StrMapNode>(); }
+ Map(Map<std::string, V>&& other) { // NOLINT(*)
data_ = std::move(other.data_);
}
- Map(const Map<std::string, V> &other) : ObjectRef(other.data_) { // NOLINT(*)
+ Map(const Map<std::string, V>& other) : ObjectRef(other.data_) { // NOLINT(*)
}
explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
- template<typename IterType>
+ template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
- Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
+ Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
- template<typename Hash, typename Equal>
- Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
+ template <typename Hash, typename Equal>
+ Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
- Map<std::string, V>& operator=(Map<std::string, V> && other) {
+ Map<std::string, V>& operator=(Map<std::string, V>&& other) {
data_ = std::move(other.data_);
return *this;
}
- Map<std::string, V>& operator=(const Map<std::string, V> & other) {
+ Map<std::string, V>& operator=(const Map<std::string, V>& other) {
data_ = other.data_;
return *this;
}
- template<typename IterType>
+ template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_object<StrMapNode>();
for (IterType i = begin; i != end; ++i) {
data_ = std::move(n);
}
inline const V operator[](const std::string& key) const {
- return DowncastNoCheck<V>(
- static_cast<const StrMapNode*>(data_.get())->data.at(key));
+ return DowncastNoCheck<V>(static_cast<const StrMapNode*>(data_.get())->data.at(key));
}
inline const V at(const std::string& key) const {
- return DowncastNoCheck<V>(
- static_cast<const StrMapNode*>(data_.get())->data.at(key));
+ return DowncastNoCheck<V>(static_cast<const StrMapNode*>(data_.get())->data.at(key));
}
inline size_t size() const {
if (data_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(data_.get())->data.count(key);
}
inline StrMapNode* CopyOnWrite() {
- if (data_.get() == nullptr || !data_.unique()) {
+ if (data_.get() == nullptr || !data_.unique()) {
ObjectPtr<StrMapNode> n = make_object<StrMapNode>();
n->data = static_cast<const StrMapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
StrMapNode* n = this->CopyOnWrite();
n->data[key] = value;
}
- inline bool empty() const {
- return size() == 0;
- }
+ inline bool empty() const { return size() == 0; }
using ContainerType = StrMapNode;
struct ValueConverter {
using ResultType = std::pair<std::string, V>;
- static inline ResultType convert(const std::pair<
- std::string,
- ObjectRef>& n) {
+ static inline ResultType convert(const std::pair<std::string, ObjectRef>& n) {
return std::make_pair(n.first, DowncastNoCheck<V>(n.second));
}
};
- using iterator = IterAdapter<
- ValueConverter, StrMapNode::ContainerType::const_iterator>;
+ using iterator = IterAdapter<ValueConverter, StrMapNode::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
namespace tvm {
namespace runtime {
// Additional overloads for PackedFunc checking.
-template<typename T>
+template <typename T>
struct ObjectTypeChecker<Array<T> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
}
return true;
}
- static std::string TypeName() {
- return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
- }
+ static std::string TypeName() { return "List[" + ObjectTypeChecker<T>::TypeName() + "]"; }
};
-template<typename V>
+template <typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
}
return true;
}
- static std::string TypeName() {
- return "Map[str, " +
- ObjectTypeChecker<V>::TypeName()+ ']';
- }
+ static std::string TypeName() { return "Map[str, " + ObjectTypeChecker<V>::TypeName() + ']'; }
};
-template<typename K, typename V>
+template <typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
return true;
}
static std::string TypeName() {
- return "Map[" +
- ObjectTypeChecker<K>::TypeName() +
- ", " +
- ObjectTypeChecker<V>::TypeName()+ ']';
+ return "Map[" + ObjectTypeChecker<K>::TypeName() + ", " + ObjectTypeChecker<V>::TypeName() +
+ ']';
}
};
} // namespace runtime
#include <dmlc/logging.h>
#include <tvm/runtime/object.h>
-#include <vector>
#include <type_traits>
#include <utility>
+#include <vector>
namespace tvm {
* \tparam FType function signature
* This type is only defined for FType with function signature
*/
-template<typename FType>
+template <typename FType>
class NodeFunctor;
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
class NodeFunctor<R(const ObjectRef& n, Args...)> {
private:
/*! \brief internal function pointer type */
- typedef R (*FPointer)(const ObjectRef&n, Args...);
+ typedef R (*FPointer)(const ObjectRef& n, Args...);
/*! \brief refer to itself. */
- using TSelf = NodeFunctor<R (const ObjectRef& n, Args...)>;
+ using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
/*! \brief internal function table */
std::vector<FPointer> func_;
* \return The result.
*/
R operator()(const ObjectRef& n, Args... args) const {
- CHECK(can_dispatch(n))
- << "NodeFunctor calls un-registered function on type "
- << n->GetTypeKey();
+ CHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
+ << n->GetTypeKey();
return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
}
/*!
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
- template<typename TNode>
+ template <typename TNode>
TSelf& set_dispatch(FPointer f) { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex();
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
- CHECK(func_[tindex] == nullptr)
- << "Dispatch for " << TNode::_type_key
- << " is already set";
+ CHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
func_[tindex] = f;
return *this;
}
/*!
- * \brief unset the dispacher for type TNode
- *
- * \tparam TNode the type of Node to be dispatched.
- * \return reference to self.
- */
- template<typename TNode>
+  * \brief unset the dispatcher for type TNode
+ *
+ * \tparam TNode the type of Node to be dispatched.
+ * \return reference to self.
+ */
+ template <typename TNode>
TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex();
- CHECK_LT(tindex, func_.size())
- << "clear_dispatch: index out of range";
+ CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
func_[tindex] = nullptr;
return *this;
}
};
-
-#define TVM_REG_FUNC_VAR_DEF(ClsName) \
- static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
+#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName
/*!
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of NodeFunctor.
*/
-#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
- TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = \
- ClsName::FField()
+#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
+ TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField()
} // namespace tvm
#endif // TVM_NODE_FUNCTOR_H_
#ifndef TVM_NODE_NODE_H_
#define TVM_NODE_NODE_H_
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/container.h>
-#include <tvm/runtime/object.h>
-#include <tvm/runtime/memory.h>
+#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/node/repr_printer.h>
-#include <tvm/node/container.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/object.h>
#include <string>
-#include <vector>
-#include <utility>
#include <type_traits>
+#include <utility>
+#include <vector>
namespace tvm {
-using runtime::TypeIndex;
+using runtime::Downcast;
+using runtime::GetRef;
+using runtime::make_object;
using runtime::Object;
+using runtime::ObjectEqual;
+using runtime::ObjectHash;
using runtime::ObjectPtr;
using runtime::ObjectRef;
-using runtime::GetRef;
-using runtime::Downcast;
-using runtime::ObjectHash;
-using runtime::ObjectEqual;
-using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
+using runtime::TypeIndex;
} // namespace tvm
#endif // TVM_NODE_NODE_H_
#ifndef TVM_NODE_REFLECTION_H_
#define TVM_NODE_REFLECTION_H_
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/object.h>
+#include <tvm/runtime/data_type.h>
#include <tvm/runtime/memory.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/data_type.h>
-#include <tvm/node/structural_equal.h>
-#include <tvm/node/structural_hash.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
-#include <vector>
#include <string>
#include <type_traits>
+#include <vector>
namespace tvm {
*/
class AttrVisitor {
public:
-//! \cond Doxygen_Suppress
+ //! \cond Doxygen_Suppress
TVM_DLL virtual ~AttrVisitor() = default;
TVM_DLL virtual void Visit(const char* key, double* value) = 0;
TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
- template<typename ENum,
- typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+ template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
-//! \endcond
+ //! \endcond
};
/*!
TVM_DLL static ReflectionVTable* Global();
class Registry;
- template<typename T, typename TraitName>
+ template <typename T, typename TraitName>
inline Registry Register();
private:
std::vector<FVisitAttrs> fvisit_attrs_;
/*! \brief Structural equal function. */
std::vector<FSEqualReduce> fsequal_reduce_;
- /*! \brief Structural hash function. */
+ /*! \brief Structural hash function. */
std::vector<FSHashReduce> fshash_reduce_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
class ReflectionVTable::Registry {
public:
Registry(ReflectionVTable* parent, uint32_t type_index)
- : parent_(parent), type_index_(type_index) { }
+ : parent_(parent), type_index_(type_index) {}
/*!
* \brief Set fcreate function.
* \param f The creator function.
uint32_t type_index_;
};
-
-#define TVM_REFLECTION_REG_VAR_DEF \
- static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \
- __make_reflectiion
+#define TVM_REFLECTION_REG_VAR_DEF \
+ static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflectiion
/*!
* \brief Directly register reflection VTable.
* \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE.
* And can be used to register the related reflection functions for runtime objects.
*/
-#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
- TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
- ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>() \
+#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
+ TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
+ ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>()
/*!
* \brief Register a node type to object registry and reflection registry.
* \param TypeName The name of the type.
* \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well.
*/
-#define TVM_REGISTER_NODE_TYPE(TypeName) \
- TVM_REGISTER_OBJECT_TYPE(TypeName); \
+#define TVM_REGISTER_NODE_TYPE(TypeName) \
+ TVM_REGISTER_OBJECT_TYPE(TypeName); \
TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
- .set_creator([](const std::string&) -> ObjectPtr<Object> { \
- return ::tvm::runtime::make_object<TypeName>(); \
- })
-
+ .set_creator([](const std::string&) -> ObjectPtr<Object> { \
+ return ::tvm::runtime::make_object<TypeName>(); \
+ })
// Implementation details
namespace detail {
-template<typename T,
- bool = T::_type_has_method_visit_attrs>
+template <typename T, bool = T::_type_has_method_visit_attrs>
struct ImplVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
-template<typename T>
+template <typename T>
struct ImplVisitAttrs<T, true> {
- static void VisitAttrs(T* self, AttrVisitor* v) {
- self->VisitAttrs(v);
- }
+ static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); }
};
-template<typename T,
- bool = T::_type_has_method_sequal_reduce>
+template <typename T, bool = T::_type_has_method_sequal_reduce>
struct ImplSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
-template<typename T>
+template <typename T>
struct ImplSEqualReduce<T, true> {
static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
return self->SEqualReduce(other, equal);
}
};
-template<typename T,
- bool = T::_type_has_method_shash_reduce>
+template <typename T, bool = T::_type_has_method_shash_reduce>
struct ImplSHashReduce {
static constexpr const std::nullptr_t SHashReduce = nullptr;
};
-template<typename T>
+template <typename T>
struct ImplSHashReduce<T, true> {
static void SHashReduce(const T* self, SHashReducer hash_reduce) {
self->SHashReduce(hash_reduce);
}
};
-template<typename T>
-struct ReflectionTrait :
- public ImplVisitAttrs<T>,
- public ImplSEqualReduce<T>,
- public ImplSHashReduce<T> {
-};
+template <typename T>
+struct ReflectionTrait : public ImplVisitAttrs<T>,
+ public ImplSEqualReduce<T>,
+ public ImplSHashReduce<T> {};
-template<typename T, typename TraitName,
- bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
+template <typename T, typename TraitName,
+ bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
struct SelectVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
-template<typename T, typename TraitName>
+template <typename T, typename TraitName>
struct SelectVisitAttrs<T, TraitName, false> {
static void VisitAttrs(Object* self, AttrVisitor* v) {
TraitName::VisitAttrs(static_cast<T*>(self), v);
}
};
-template<typename T, typename TraitName,
- bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
+template <typename T, typename TraitName,
+ bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
struct SelectSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
-template<typename T, typename TraitName>
+template <typename T, typename TraitName>
struct SelectSEqualReduce<T, TraitName, false> {
- static bool SEqualReduce(const Object* self,
- const Object* other,
- SEqualReducer equal) {
- return TraitName::SEqualReduce(static_cast<const T*>(self),
- static_cast<const T*>(other),
+ static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) {
+ return TraitName::SEqualReduce(static_cast<const T*>(self), static_cast<const T*>(other),
equal);
}
};
-template<typename T, typename TraitName,
- bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
+template <typename T, typename TraitName,
+ bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
struct SelectSHashReduce {
static constexpr const std::nullptr_t SHashReduce = nullptr;
};
-template<typename T, typename TraitName>
+template <typename T, typename TraitName>
struct SelectSHashReduce<T, TraitName, false> {
- static void SHashReduce(const Object* self,
- SHashReducer hash_reduce) {
- return TraitName::SHashReduce(static_cast<const T*>(self),
- hash_reduce);
+ static void SHashReduce(const Object* self, SHashReducer hash_reduce) {
+ return TraitName::SHashReduce(static_cast<const T*>(self), hash_reduce);
}
};
} // namespace detail
-template<typename T, typename TraitName>
-inline ReflectionVTable::Registry
-ReflectionVTable::Register() {
+template <typename T, typename TraitName>
+inline ReflectionVTable::Registry ReflectionVTable::Register() {
uint32_t tindex = T::RuntimeTypeIndex();
if (tindex >= fvisit_attrs_.size()) {
fvisit_attrs_.resize(tindex + 1, nullptr);
fshash_reduce_.resize(tindex + 1, nullptr);
}
// functor that implemnts the redirection.
- fvisit_attrs_[tindex] =
- ::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
+ fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
- fsequal_reduce_[tindex] =
- ::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
+ fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
- fshash_reduce_[tindex] =
- ::tvm::detail::SelectSHashReduce<T, TraitName>::SHashReduce;
+ fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce<T, TraitName>::SHashReduce;
return Registry(this, tindex);
}
-inline void ReflectionVTable::
-VisitAttrs(Object* self, AttrVisitor* visitor) const {
+inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const {
uint32_t tindex = self->type_index();
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << self->GetTypeKey()
fvisit_attrs_[tindex](self, visitor);
}
-inline bool ReflectionVTable::GetReprBytes(const Object* self,
- std::string* repr_bytes) const {
+inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const {
uint32_t tindex = self->type_index();
if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
if (repr_bytes != nullptr) {
#define TVM_NODE_REPR_PRINTER_H_
#include <tvm/node/functor.h>
+
#include <iostream>
namespace tvm {
#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
#define TVM_NODE_STRUCTURAL_EQUAL_H_
-#include <tvm/runtime/data_type.h>
-#include <tvm/node/functor.h>
#include <tvm/node/container.h>
+#include <tvm/node/functor.h>
+#include <tvm/runtime/data_type.h>
+
#include <string>
namespace tvm {
return diff > -atol && diff < atol;
}
- bool operator()(const int64_t& lhs, const int64_t& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const int& lhs, const int& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const bool& lhs, const bool& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const std::string& lhs, const std::string& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const DataType& lhs, const DataType& rhs) const {
- return lhs == rhs;
- }
- template<typename ENum,
- typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+ bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
+ bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
+ bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
+ bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
+ bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
+ bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
+ template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
return lhs == rhs;
}
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
- virtual bool SEqualReduce(const ObjectRef& lhs,
- const ObjectRef& rhs,
- bool map_free_vars) = 0;
+ virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
* \param rhs The right operand.
* \return the immediate check result.
*/
- template<typename T>
+ template <typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
}
/*! \return Get the internal handler. */
- Handler* operator->() const {
- return handler_;
- }
+ Handler* operator->() const { return handler_; }
private:
/*! \brief Internal class pointer. */
#ifndef TVM_NODE_STRUCTURAL_HASH_H_
#define TVM_NODE_STRUCTURAL_HASH_H_
-#include <tvm/runtime/data_type.h>
-#include <tvm/node/functor.h>
#include <tvm/node/container.h>
-#include <string>
+#include <tvm/node/functor.h>
+#include <tvm/runtime/data_type.h>
+
#include <functional>
+#include <string>
namespace tvm {
*/
class BaseValueHash {
public:
- size_t operator()(const double& key) const {
- return std::hash<double>()(key);
- }
+ size_t operator()(const double& key) const { return std::hash<double>()(key); }
- size_t operator()(const int64_t& key) const {
- return std::hash<int64_t>()(key);
- }
+ size_t operator()(const int64_t& key) const { return std::hash<int64_t>()(key); }
- size_t operator()(const uint64_t& key) const {
- return std::hash<uint64_t>()(key);
- }
+ size_t operator()(const uint64_t& key) const { return std::hash<uint64_t>()(key); }
- size_t operator()(const int& key) const {
- return std::hash<int>()(key);
- }
+ size_t operator()(const int& key) const { return std::hash<int>()(key); }
- size_t operator()(const bool& key) const {
- return std::hash<bool>()(key);
- }
+ size_t operator()(const bool& key) const { return std::hash<bool>()(key); }
- size_t operator()(const std::string& key) const {
- return std::hash<std::string>()(key);
- }
+ size_t operator()(const std::string& key) const { return std::hash<std::string>()(key); }
size_t operator()(const runtime::DataType& key) const {
- return std::hash<int32_t>()(
- static_cast<int32_t>(key.code()) |
- (static_cast<int32_t>(key.bits()) << 8) |
- (static_cast<int32_t>(key.lanes()) << 16));
+ return std::hash<int32_t>()(static_cast<int32_t>(key.code()) |
+ (static_cast<int32_t>(key.bits()) << 8) |
+ (static_cast<int32_t>(key.lanes()) << 16));
}
- template<typename ENum,
- typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+ template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& key) const {
return std::hash<size_t>()(static_cast<size_t>(key));
}
* \brief Push hash of key to the current sequence of hash values.
* \param key The key to be hashed.
*/
- template<typename T,
- typename = typename std::enable_if<
- !std::is_base_of<ObjectRef, T>::value>::type>
+ template <typename T,
+ typename = typename std::enable_if<!std::is_base_of<ObjectRef, T>::value>::type>
void operator()(const T& key) const {
// handle normal values.
handler_->SHashReduceHashedValue(BaseValueHash()(key));
* \brief Push hash of key to the current sequence of hash values.
* \param key The key to be hashed.
*/
- void operator()(const ObjectRef& key) const {
- return handler_->SHashReduce(key, map_free_vars_);
- }
+ void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); }
/*!
* \brief Push hash of key to the current sequence of hash values.
* \param key The key to be hashed.
* \note This function indicate key could contain var defintions.
*/
- void DefHash(const ObjectRef& key) const {
- return handler_->SHashReduce(key, true);
- }
+ void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); }
/*!
* \brief Implementation for hash for a free var.
* \param var The variable.
}
/*! \return Get the internal handler. */
- Handler* operator->() const {
- return handler_;
- }
+ Handler* operator->() const { return handler_; }
private:
/*! \brief Internal class pointer. */
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_
-#include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h>
+#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
-#include <string>
+
#include <functional>
+#include <string>
#include <utility>
namespace tvm {
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("span", &span);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
- bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const {
- return true;
- }
+ bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; }
- void SHashReduce(SHashReducer hash_reduce) const {
- }
+ void SHashReduce(SHashReducer hash_reduce) const {}
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
return equal.DefEqual(var, other->var);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce.DefHash(var);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); }
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
}
bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
- return
- equal(constructor, other->constructor) &&
- equal(patterns, other->patterns);
+ return equal(constructor, other->constructor) && equal(patterns, other->patterns);
}
void SHashReduce(SHashReducer hash_reduce) const {
return equal(patterns, other->patterns);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(patterns);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); }
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
- return
- equal(data, other->data) &&
- equal(clauses, other->clauses) &&
- equal(complete, other->complete);
+ return equal(data, other->data) && equal(clauses, other->clauses) &&
+ equal(complete, other->complete);
}
void SHashReduce(SHashReducer hash_reduce) const {
#ifndef TVM_RELAY_ANALYSIS_H_
#define TVM_RELAY_ANALYSIS_H_
+#include <tvm/ir/module.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
-#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
+
#include <string>
#include <unordered_map>
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
* although x is not shadowed.
*
- * \param expr the expression to check.
+ * \param expr the expression to check.
*
- * \return true iff all Var in expr is bound at most once.
+ * \return true iff all Var in expr is bound at most once.
*/
TVM_DLL bool WellFormed(const Expr& expr);
*
* \return The reference count mapping.
*/
-TVM_DLL std::unordered_map<const Object*, size_t>
-GetExprRefCount(const Expr& body);
+TVM_DLL std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body);
} // namespace relay
} // namespace tvm
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
+
#include <string>
namespace tvm {
DataType dtype;
TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") {
- TVM_ATTR_FIELD(axis).set_default(-1)
- .describe("Axis along which to sort the input tensor."
- "If not given, the flattened array is used.");
- TVM_ATTR_FIELD(is_ascend).set_default(true)
- .describe("Whether to sort in ascending or descending order."
- "By default, sort in ascending order");
- TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
- .describe("DType of the output indices.");
+ TVM_ATTR_FIELD(axis).set_default(-1).describe(
+ "Axis along which to sort the input tensor."
+ "If not given, the flattened array is used.");
+ TVM_ATTR_FIELD(is_ascend).set_default(true).describe(
+ "Whether to sort in ascending or descending order."
+ "By default, sort in ascending order");
+ TVM_ATTR_FIELD(dtype)
+ .set_default(NullValue<DataType>())
+ .describe("DType of the output indices.");
}
};
DataType dtype;
TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
- TVM_ATTR_FIELD(k).set_default(1)
- .describe("Number of top elements to select");
- TVM_ATTR_FIELD(axis).set_default(-1)
- .describe("Axis along which to sort the input tensor.");
- TVM_ATTR_FIELD(ret_type).set_default("both")
- .describe("The return type [both, values, indices]."
- "both - return both top k data and indices."
- "values - return top k data only."
- "indices - return top k indices only.");
- TVM_ATTR_FIELD(is_ascend).set_default(false)
- .describe("Whether to sort in ascending or descending order."
- "By default, sort in descending order");
- TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
- .describe("Data type of the output indices.");
+ TVM_ATTR_FIELD(k).set_default(1).describe("Number of top elements to select");
+ TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor.");
+ TVM_ATTR_FIELD(ret_type).set_default("both").describe(
+ "The return type [both, values, indices]."
+ "both - return both top k data and indices."
+ "values - return top k data only."
+ "indices - return top k indices only.");
+ TVM_ATTR_FIELD(is_ascend).set_default(false).describe(
+ "Whether to sort in ascending or descending order."
+ "By default, sort in descending order");
+ TVM_ATTR_FIELD(dtype)
+ .set_default(NullValue<DataType>())
+ .describe("Data type of the output indices.");
}
};
#define TVM_RELAY_ATTRS_ANNOTATION_H_
#include <tvm/ir/attrs.h>
+
#include <string>
namespace tvm {
TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
- .describe(
- "The virutal device/context type that an expression is annotated with.")
- .set_default(0);
+ .describe("The virutal device/context type that an expression is annotated with.")
+ .set_default(0);
}
};
DataType dtype;
TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") {
- TVM_ATTR_FIELD(dtype)
- .describe(
- "The data type denoted to be cast.");
+ TVM_ATTR_FIELD(dtype).describe("The data type denoted to be cast.");
}
};
std::string compiler;
TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
- TVM_ATTR_FIELD(compiler)
- .describe("A 3rd party compiler used for code generation.");
+ TVM_ATTR_FIELD(compiler).describe("A 3rd party compiler used for code generation.");
}
};
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
+
#include <string>
namespace tvm {
bool unipolar;
TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") {
- TVM_ATTR_FIELD(units)
- .describe("Number of hidden units of the dense transformation.");
- TVM_ATTR_FIELD(data_bits)
- .set_default(1)
- .describe("Number of bits to pack for incoming tensor.");
+ TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
+ TVM_ATTR_FIELD(data_bits).set_default(1).describe(
+ "Number of bits to pack for incoming tensor.");
TVM_ATTR_FIELD(weight_bits)
- .set_default(1)
- .describe("Number of bits to pack for weight tensor.");
+ .set_default(1)
+ .describe("Number of bits to pack for weight tensor.");
TVM_ATTR_FIELD(pack_dtype)
- .set_default(NullValue<DataType>())
- .describe("Datatype to pack bits into before computation.");
- TVM_ATTR_FIELD(out_dtype)
- .set_default(NullValue<DataType>())
- .describe("Output data type.");
- TVM_ATTR_FIELD(unipolar)
- .set_default(true)
- .describe("Whether to use unipolar or bipolar quantization for inputs.");
+ .set_default(NullValue<DataType>())
+ .describe("Datatype to pack bits into before computation.");
+ TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
+ TVM_ATTR_FIELD(unipolar).set_default(true).describe(
+ "Whether to use unipolar or bipolar quantization for inputs.");
}
};
#define TVM_RELAY_ATTRS_DEBUG_H_
#include <tvm/ir/attrs.h>
+#include <tvm/ir/env_func.h>
+
#include <string>
namespace tvm {
EnvFunc debug_func;
TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") {
- TVM_ATTR_FIELD(debug_func)
- .describe("The function to use when debugging.");
+ TVM_ATTR_FIELD(debug_func).describe("The function to use when debugging.");
}
};
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_
#include <tvm/ir/attrs.h>
+
#include <string>
namespace tvm {
TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
TVM_ATTR_FIELD(src_dev_type)
- .describe(
- "The virtual device/context type where the op copies data from.")
- .set_default(0);
+ .describe("The virtual device/context type where the op copies data from.")
+ .set_default(0);
TVM_ATTR_FIELD(dst_dev_type)
- .describe(
- "The virtual device/context type where the op copies data to.")
- .set_default(0);
+ .describe("The virtual device/context type where the op copies data to.")
+ .set_default(0);
}
};
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
+
#include <string>
namespace tvm {
DataType out_dtype;
TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
- TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >())
- .describe("Output Size.");
- TVM_ATTR_FIELD(layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Resize is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(method).set_default("bilinear")
- .describe("Specify the mode to use for scaling."
- "nearest_neighbor - Nearest Neighbor"
- "bilinear - Bilinear Interpolation"
- "bicubic - Bicubic Interpolation");
- TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel")
- .describe("Describes how to transform the coordinate in the resized tensor"
- "to the coordinate in the original tensor."
- "Refer to the ONNX Resize operator specification for details"
- "Available options are half_pixel, align_corners and asymmetric");
- TVM_ATTR_FIELD(out_dtype)
- .set_default(NullValue<DataType>())
- .describe("Output data type.");
+ TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Resize is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(method)
+ .set_default("bilinear")
+ .describe(
+ "Specify the mode to use for scaling."
+ "nearest_neighbor - Nearest Neighbor"
+ "bilinear - Bilinear Interpolation"
+ "bicubic - Bicubic Interpolation");
+ TVM_ATTR_FIELD(coordinate_transformation_mode)
+ .set_default("half_pixel")
+ .describe(
+ "Describes how to transform the coordinate in the resized tensor"
+ "to the coordinate in the original tensor."
+ "Refer to the ONNX Resize operator specification for details"
+ "Available options are half_pixel, align_corners and asymmetric");
+ TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
}
};
DataType out_dtype;
TVM_DECLARE_ATTRS(CropAndResizeAttrs, "relay.attrs.CropAndResizeAttrs") {
- TVM_ATTR_FIELD(crop_size).set_default(NullValue<Array<IndexExpr> >())
- .describe("Target Size.");
- TVM_ATTR_FIELD(layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Resize is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(method).set_default("bilinear")
- .describe("Specify the mode to use for scaling."
- "nearest_neighbor - Nearest Neighbor"
- "bilinear - Bilinear Interpolation");
- TVM_ATTR_FIELD(extrapolation_value).set_default(0.0)
+ TVM_ATTR_FIELD(crop_size).set_default(NullValue<Array<IndexExpr> >()).describe("Target Size.");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Resize is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(method)
+ .set_default("bilinear")
+ .describe(
+ "Specify the mode to use for scaling."
+ "nearest_neighbor - Nearest Neighbor"
+ "bilinear - Bilinear Interpolation");
+ TVM_ATTR_FIELD(extrapolation_value)
+ .set_default(0.0)
.describe("Specify value for extrapolation.");
- TVM_ATTR_FIELD(out_dtype)
- .set_default(NullValue<DataType>())
- .describe("Output data type.");
+ TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
}
};
DataType out_dtype;
TVM_DECLARE_ATTRS(Dilation2DAttrs, "relay.attrs.Dilation2DAttrs") {
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the sliding window. [stride_height, stride_width].");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(dilations).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(dilations)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use. [dilation_height, dilation_width]");
- TVM_ATTR_FIELD(data_layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Convolution is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(kernel_layout).set_default("IHW")
- .describe("Dimension ordering of weight. Can be 'IHW', 'HWI', etc."
- "'I', 'H', 'W' stands for input_channel, height, and width"
- "dimensions respectively.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCHW")
+ .describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Convolution is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("IHW")
+ .describe(
+ "Dimension ordering of weight. Can be 'IHW', 'HWI', etc."
+ "'I', 'H', 'W' stands for input_channel, height, and width"
+ "dimensions respectively.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h>
+
#include <string>
#include <vector>
TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") {
TVM_ATTR_FIELD(dtype)
- .describe(
- "The dtype of the tensor to allocate.")
- .set_default(DataType::Float(32, 1));
- TVM_ATTR_FIELD(device_id)
- .describe(
- "The device id on which to allocate memory.");
- TVM_ATTR_FIELD(device_type)
- .describe(
- "The device type on which to allocate memory.");
+ .describe("The dtype of the tensor to allocate.")
+ .set_default(DataType::Float(32, 1));
+ TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory.");
+ TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory.");
}
};
TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") {
TVM_ATTR_FIELD(dtype)
- .describe(
- "The dtype of the tensor to allocate.")
- .set_default(DataType::Float(32, 1));
- TVM_ATTR_FIELD(const_shape)
- .describe(
- "The shape of constant used to aid in type inference.");
+ .describe("The dtype of the tensor to allocate.")
+ .set_default(DataType::Float(32, 1));
+ TVM_ATTR_FIELD(const_shape).describe("The shape of constant used to aid in type inference.");
TVM_ATTR_FIELD(assert_shape)
- .describe(
- "The shape to cast the return type of the allocation to, "\
- "used to specify the shape obtained via further analysis.");
+ .describe(
+ "The shape to cast the return type of the allocation to, "
+ "used to specify the shape obtained via further analysis.");
}
};
Array<Integer> is_input;
TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
- TVM_ATTR_FIELD(is_input)
- .describe(
- "A bool indicating whether the shape function should"\
- "expect shape or input in each position.");
+ TVM_ATTR_FIELD(is_input).describe(
+ "A bool indicating whether the shape function should"
+ "expect shape or input in each position.");
}
};
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
+
#include <string>
namespace tvm {
int axis;
TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
- TVM_ATTR_FIELD(axis)
- .describe("The axis to add the bias")
- .set_default(1);
+ TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1);
}
};
-
/*! \brief Attributes used in 1D convolution operators */
struct Conv1DAttrs : public tvm::AttrsNode<Conv1DAttrs> {
Array<IndexExpr> strides;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv1DAttrs, "relay.attrs.Conv1DAttrs") {
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, }))
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({
+ 1,
+ }))
.describe("Specifies the stride of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "on both sides for padding number of points");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, }))
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "on both sides for padding number of points");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({
+ 1,
+ }))
.describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Currently unused but may be added in the future.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Currently unused but may be added in the future.");
TVM_ATTR_FIELD(channels)
- .describe("The number of output channels in the convolution."
- " If it is not set, inferred by shape of the weight.")
+ .describe(
+ "The number of output channels in the convolution."
+ " If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(data_layout).set_default("NCW")
- .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
- "'N', 'C', 'W' stands for batch, channel, and width"
- "dimensions respectively. Convolution is applied on the 'W'"
- "dimension.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIW")
- .describe("Dimension ordering of weight. Can be 'OIW', or 'WIO', etc."
- "'O', 'I', 'W' stands for num_filter, input_channel, and width"
- "dimensions respectively.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCW")
+ .describe(
+ "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Convolution is applied on the 'W'"
+ "dimension.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIW")
+ .describe(
+ "Dimension ordering of weight. Can be 'OIW', or 'WIO', etc."
+ "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+ "dimensions respectively.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
}
};
-
/*! \brief Attributes used in convolution operators */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Controls the connections between inputs and outputs."
- "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution"
- "layers side by side, each seeing half the input channels, and producing"
- "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
- .describe("The number of output channels in the convolution."
- " If it is not set, inferred by shape of the weight.")
+ .describe(
+ "The number of output channels in the convolution."
+ " If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(data_layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Convolution is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
- .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
- "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
- "dimensions respectively.");
- TVM_ATTR_FIELD(out_layout).set_default("")
- .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCHW")
+ .describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Convolution is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIHW")
+ .describe(
+ "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .set_default("")
+ .describe(
+ "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
};
/*! \brief Attributes used in winograd weight transformation operators */
-struct ConvWinogradWeightTransformAttrs :
- public tvm::AttrsNode<ConvWinogradWeightTransformAttrs> {
+struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode<ConvWinogradWeightTransformAttrs> {
int tile_size;
TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs,
- "relay.attrs.ConvWinogradWeightTransformAttrs") {
- TVM_ATTR_FIELD(tile_size)
- .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
+ "relay.attrs.ConvWinogradWeightTransformAttrs") {
+ TVM_ATTR_FIELD(tile_size).describe(
+ "Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
}
};
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") {
- TVM_ATTR_FIELD(tile_size)
- .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(tile_size).describe(
+ "The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Controls the connections between inputs and outputs."
- "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution"
- "layers side by side, each seeing half the input channels, and producing"
- "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
- .describe("The number of output channels in the convolution."
- " If it is not set, inferred by shape of the weight.")
+ .describe(
+ "The number of output channels in the convolution."
+ " If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(data_layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Convolution is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
- .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
- "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
- "dimensions respectively.");
- TVM_ATTR_FIELD(out_layout).set_default("")
- .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCHW")
+ .describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Convolution is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIHW")
+ .describe(
+ "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .set_default("")
+ .describe(
+ "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") {
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "three int : back, bottom, right will use same padding as front, top, left"
- "six int : padding width in the order of (front, top, left, back, bottom,"
- "right)");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "three int : back, bottom, right will use same padding as front, top, left"
+ "six int : padding width in the order of (front, top, left, back, bottom,"
+ "right)");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Controls the connections between inputs and outputs."
- "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution"
- "layers side by side, each seeing half the input channels, and producing"
- "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
- .describe("The number of output channels in the convolution."
- " If it is not set, inferred by shape of the weight.")
+ .describe(
+ "The number of output channels in the convolution."
+ " If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
- .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Convolution is applied on the 'D', 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
- .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
- "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
- "and width dimensions respectively.");
- TVM_ATTR_FIELD(out_layout).set_default("")
- .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCDHW")
+ .describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Convolution is applied on the 'D', 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIDHW")
+ .describe(
+ "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
+ "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
+ "and width dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .set_default("")
+ .describe(
+ "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") {
- TVM_ATTR_FIELD(tile_size)
- .describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
+ TVM_ATTR_FIELD(tile_size).describe(
+ "The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "three int : back, bottom, right will use same padding as front, top, left"
- "six int : padding width in the order of (front, top, left, back, bottom,"
- "right)");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "three int : back, bottom, right will use same padding as front, top, left"
+ "six int : padding width in the order of (front, top, left, back, bottom,"
+ "right)");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Controls the connections between inputs and outputs."
- "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution"
- "layers side by side, each seeing half the input channels, and producing"
- "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
- .describe("The number of output channels in the convolution."
- " If it is not set, inferred by shape of the weight.")
+ .describe(
+ "The number of output channels in the convolution."
+ " If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
- .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Convolution is applied on the 'D', 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
- .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
- "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
- "and width dimensions respectively.");
- TVM_ATTR_FIELD(out_layout).set_default("")
- .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCDHW")
+ .describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Convolution is applied on the 'D', 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIDHW")
+ .describe(
+ "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
+ "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
+ "and width dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .set_default("")
+ .describe(
+ "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
}
};
-
/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;
TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
- TVM_ATTR_FIELD(axis).set_default(-1)
- .describe("The axis to sum over when computing softmax.");
+ TVM_ATTR_FIELD(axis).set_default(-1).describe("The axis to sum over when computing softmax.");
}
};
TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
TVM_ATTR_FIELD(channels)
- .set_default(NullValue<IndexExpr>())
- .describe("The dimensionality of the output space"
- "i.e. the number of output channels in the convolution.");
+ .set_default(NullValue<IndexExpr>())
+ .describe(
+ "The dimensionality of the output space"
+ "i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
- .describe("The dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
- .describe("The strides of the convolution.");
- TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("Zero-padding added to one side of the output."
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
- .describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Controls the connections between inputs and outputs."
- "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution"
- "layers side by side, each seeing half the input channels, and producing"
- "half the output channels, and both subsequently concatenated.");
- TVM_ATTR_FIELD(data_layout).set_default("NCHW")
- .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Convolution is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
- .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
- "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
- "dimensions respectively.");
- TVM_ATTR_FIELD(out_layout).set_default("")
- .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Default to be same as input layout.");
+ .describe("The dimensions of the convolution window.")
+ .set_default(NullValue<Array<IndexExpr> >());
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
+ .describe("The strides of the convolution.");
+ TVM_ATTR_FIELD(output_padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "Zero-padding added to one side of the output."
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({1, 1}))
+ .describe("Specifies the dilation rate to use for dilated convolution.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCHW")
+ .describe(
+ "Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Convolution is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIHW")
+ .describe(
+ "Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .set_default("")
+ .describe(
+ "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
Array<IndexExpr> strides;
TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
- .describe("Dilation stride on each dimension, 1 means no dilation.");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
+ .describe("Dilation stride on each dimension, 1 means no dilation.");
}
};
TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") {
TVM_ATTR_FIELD(channels)
- .set_default(NullValue<IndexExpr>())
- .describe("The dimensionality of the output space"
- "i.e. the number of output channels in the convolution.");
+ .set_default(NullValue<IndexExpr>())
+ .describe(
+ "The dimensionality of the output space"
+ "i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
- .describe("The dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
- .describe("The strides of the convolution.");
- TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0}))
- .describe("Zero-padding added to one side of the output.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
- .describe("Symmetric or asymmetric padding."
- "Single value: the input is implicitly zero-padded on both sides."
- "Two values: padding[0] is used for left input padding, "
- "padding[1] is used for right input padding,");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1}))
- .describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Controls the connections between inputs and outputs."
- "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution"
- "layers side by side, each seeing half the input channels, and producing"
- "half the output channels, and both subsequently concatenated.");
- TVM_ATTR_FIELD(data_layout).set_default("NCW")
- .describe("Dimension ordering of data. Can be 'NCW', 'NWC', etc."
- "'N', 'C', 'W' stands for batch, channel, and width"
- "dimensions respectively. Convolution is applied on the"
- "'W' dimension.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIW")
- .describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
- "'O', 'I', 'W' stands for num_filter, input_channel, and width"
- "dimensions respectively.");
- TVM_ATTR_FIELD(out_layout).set_default("")
- .describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc."
- "'N', 'C', 'W' stands for batch, channel, and width"
- "dimensions respectively. Default to be same as input layout.");
+ .describe("The dimensions of the convolution window.")
+ .set_default(NullValue<Array<IndexExpr> >());
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1}))
+ .describe("The strides of the convolution.");
+ TVM_ATTR_FIELD(output_padding)
+ .set_default(Array<IndexExpr>({0}))
+ .describe("Zero-padding added to one side of the output.");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0}))
+ .describe(
+ "Symmetric or asymmetric padding."
+ "Single value: the input is implicitly zero-padded on both sides."
+ "Two values: padding[0] is used for left input padding, "
+ "padding[1] is used for right input padding,");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({1}))
+ .describe("Specifies the dilation rate to use for dilated convolution.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCW")
+ .describe(
+ "Dimension ordering of data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Convolution is applied on the"
+ "'W' dimension.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIW")
+ .describe(
+ "Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
+ "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .set_default("")
+ .describe(
+ "Dimension ordering of output. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
bool ceil_mode;
TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
- TVM_ATTR_FIELD(pool_size)
- .describe("Size of the pooling windows.");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
- .describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Pooling is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(ceil_mode).set_default(false)
- .describe("When true, will use ceil instead of floor to compute the output shape.");
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
+ .describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Pooling is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
+ "When true, will use ceil instead of floor to compute the output shape.");
}
};
bool count_include_pad;
TVM_DECLARE_ATTRS(AvgPool2DAttrs, "relay.attrs.AvgPool2DAttrs") {
- TVM_ATTR_FIELD(pool_size)
- .describe("Size of the pooling windows.");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
- .describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Pooling is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(ceil_mode).set_default(false)
- .describe("When true, will use ceil instead of floor to compute the output shape.");
- TVM_ATTR_FIELD(count_include_pad).set_default(false)
- .describe("When true, will include padding to compute the average");
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
+ .describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Pooling is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
+ "When true, will use ceil instead of floor to compute the output shape.");
+ TVM_ATTR_FIELD(count_include_pad)
+ .set_default(false)
+ .describe("When true, will include padding to compute the average");
}
};
std::string layout;
TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
- TVM_ATTR_FIELD(layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Pooling is applied on the 'H' and"
- "'W' dimensions.");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Pooling is applied on the 'H' and"
+ "'W' dimensions.");
}
};
std::string layout;
TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") {
- TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({}))
- .describe("Output height and width.");
- TVM_ATTR_FIELD(layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Pooling is applied on the 'H' and"
- "'W' dimensions.");
+ TVM_ATTR_FIELD(output_size)
+ .set_default(Array<IndexExpr>({}))
+ .describe("Output height and width.");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Pooling is applied on the 'H' and"
+ "'W' dimensions.");
}
};
std::string layout;
TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") {
- TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({}))
- .describe("Output depth, height and width.");
- TVM_ATTR_FIELD(layout).set_default("NCDHW")
- .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Pooling is applied on 'D', 'H' and"
- "'W' dimensions.");
+ TVM_ATTR_FIELD(output_size)
+ .set_default(Array<IndexExpr>({}))
+ .describe("Output depth, height and width.");
+ TVM_ATTR_FIELD(layout).set_default("NCDHW").describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Pooling is applied on 'D', 'H' and"
+ "'W' dimensions.");
}
};
-
/*! \brief Attributes for 1D max pool operator */
struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
Array<IndexExpr> pool_size;
bool ceil_mode;
TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") {
- TVM_ATTR_FIELD(pool_size)
- .describe("Size of the pooling windows.");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
- .describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "three int : back, bottom, right will use same padding as front, top, left"
- "six int : padding width in the order of (front, top, left, back, bottom, right)");
- TVM_ATTR_FIELD(layout).set_default("NCW")
- .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
- "'N', 'C', 'W' stands for batch, channel, and width"
- "dimensions respectively. Pooling is applied on the 'W' dimensions.");
- TVM_ATTR_FIELD(ceil_mode).set_default(false)
- .describe("When true, will use ceil instead of floor to compute the output shape.");
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1}))
+ .describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0}))
+ .describe(
+            "If padding is non-zero, then the input is implicitly zero-padded. "
+            "Padding supports both symmetric and asymmetric as: "
+            "one int : same padding used on both sides, "
+            "two int : padding width in the order of "
+            "(left, right)");
+ TVM_ATTR_FIELD(layout).set_default("NCW").describe(
+ "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Pooling is applied on the 'W' dimensions.");
+ TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
+ "When true, will use ceil instead of floor to compute the output shape.");
}
};
bool count_include_pad;
TVM_DECLARE_ATTRS(AvgPool1DAttrs, "relay.attrs.AvgPool1DAttrs") {
- TVM_ATTR_FIELD(pool_size)
- .describe("Size of the pooling windows.");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
- .describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "three int : back, bottom, right will use same padding as front, top, left"
- "six int : padding width in the order of (front, top, left, back, bottom, right)");
- TVM_ATTR_FIELD(layout).set_default("NCW")
- .describe("Dimension ordering of input data. Can be 'NCW', 'NHC', etc."
- "'N', 'C', 'W' stands for batch, channel, and width"
- "dimensions respectively. Pooling is applied on the 'W' dimension.");
- TVM_ATTR_FIELD(ceil_mode).set_default(false)
- .describe("When true, will use ceil instead of floor to compute the output shape.");
- TVM_ATTR_FIELD(count_include_pad).set_default(false)
- .describe("When true, will include padding to compute the average");
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1}))
+ .describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0}))
+ .describe(
+            "If padding is non-zero, then the input is implicitly zero-padded. "
+            "Padding supports both symmetric and asymmetric as: "
+            "one int : same padding used on both sides, "
+            "two int : padding width in the order of "
+            "(left, right)");
+ TVM_ATTR_FIELD(layout).set_default("NCW").describe(
+        "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Pooling is applied on the 'W' dimension.");
+ TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
+ "When true, will use ceil instead of floor to compute the output shape.");
+ TVM_ATTR_FIELD(count_include_pad)
+ .set_default(false)
+ .describe("When true, will include padding to compute the average");
}
};
-
/*! \brief Attributes for 3D max pool operator */
struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
Array<IndexExpr> pool_size;
bool ceil_mode;
TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") {
- TVM_ATTR_FIELD(pool_size)
- .describe("Size of the pooling windows.");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
- .describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "three int : back, bottom, right will use same padding as front, top, left"
- "six int : padding width in the order of (front, top, left, back, bottom, right)");
- TVM_ATTR_FIELD(layout).set_default("NCDHW")
- .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Pooling is applied on the 'D', 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(ceil_mode).set_default(false)
- .describe("When true, will use ceil instead of floor to compute the output shape.");
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1, 1}))
+ .describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "three int : back, bottom, right will use same padding as front, top, left"
+ "six int : padding width in the order of (front, top, left, back, bottom, right)");
+ TVM_ATTR_FIELD(layout).set_default("NCDHW").describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
+ "When true, will use ceil instead of floor to compute the output shape.");
}
};
bool count_include_pad;
TVM_DECLARE_ATTRS(AvgPool3DAttrs, "relay.attrs.AvgPool3DAttrs") {
- TVM_ATTR_FIELD(pool_size)
- .describe("Size of the pooling windows.");
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
- .describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "three int : back, bottom, right will use same padding as front, top, left"
- "six int : padding width in the order of (front, top, left, back, bottom, right)");
- TVM_ATTR_FIELD(layout).set_default("NCDHW")
- .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Pooling is applied on the 'D', 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(ceil_mode).set_default(false)
- .describe("When true, will use ceil instead of floor to compute the output shape.");
- TVM_ATTR_FIELD(count_include_pad).set_default(false)
- .describe("When true, will include padding to compute the average");
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1, 1}))
+ .describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "three int : back, bottom, right will use same padding as front, top, left"
+ "six int : padding width in the order of (front, top, left, back, bottom, right)");
+ TVM_ATTR_FIELD(layout).set_default("NCDHW").describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
+ "When true, will use ceil instead of floor to compute the output shape.");
+ TVM_ATTR_FIELD(count_include_pad)
+ .set_default(false)
+ .describe("When true, will include padding to compute the average");
}
};
-
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
DataType out_dtype;
TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
- TVM_ATTR_FIELD(units)
- .describe("Number of hidden units of the dense transformation.");
+ TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
bool align_corners;
TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
- TVM_ATTR_FIELD(scale_h)
- .describe("The upsampling factor for height");
- TVM_ATTR_FIELD(scale_w)
- .describe("The upsampling factor for width");
- TVM_ATTR_FIELD(layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Upsampling is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(method).set_default("nearest_neighbor")
- .describe("Specify the mode to use for scaling."
- "nearest_neighbor - Nearest Neighbor"
- "bilinear - Bilinear Interpolation"
- "bicubic - Bicubic Interpolation");
- TVM_ATTR_FIELD(align_corners).set_default(false)
+ TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height");
+ TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Upsampling is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(method)
+ .set_default("nearest_neighbor")
+ .describe(
+ "Specify the mode to use for scaling."
+ "nearest_neighbor - Nearest Neighbor"
+ "bilinear - Bilinear Interpolation"
+ "bicubic - Bicubic Interpolation");
+ TVM_ATTR_FIELD(align_corners)
+ .set_default(false)
.describe("Should be true to preserve the values at the corner pixels");
}
};
std::string coordinate_transformation_mode;
TVM_DECLARE_ATTRS(UpSampling3DAttrs, "relay.attrs.UpSampling3DAttrs") {
- TVM_ATTR_FIELD(scale_d)
- .describe("The upsampling factor for depth");
- TVM_ATTR_FIELD(scale_h)
- .describe("The upsampling factor for height");
- TVM_ATTR_FIELD(scale_w)
- .describe("The upsampling factor for width");
- TVM_ATTR_FIELD(layout).set_default("NCDHW")
- .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
- "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
- "dimensions respectively. Upsampling is applied on the 'D', 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(method).set_default("nearest_neighbor")
- .describe("Specify the mode to use for scaling."
- "nearest_neighbor - Nearest Neighbor"
- "trilinear - Trilinear Interpolation");
- TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel")
- .describe("Describes how to transform the coordinate in the resized tensor"
- "to the coordinate in the original tensor."
- "Refer to the ONNX Resize operator specification for details"
- "Available options are half_pixel, align_corners and asymmetric");
+ TVM_ATTR_FIELD(scale_d).describe("The upsampling factor for depth");
+ TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height");
+ TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width");
+ TVM_ATTR_FIELD(layout).set_default("NCDHW").describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+ "dimensions respectively. Upsampling is applied on the 'D', 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(method)
+ .set_default("nearest_neighbor")
+ .describe(
+ "Specify the mode to use for scaling."
+ "nearest_neighbor - Nearest Neighbor"
+ "trilinear - Trilinear Interpolation");
+ TVM_ATTR_FIELD(coordinate_transformation_mode)
+ .set_default("half_pixel")
+ .describe(
+ "Describes how to transform the coordinate in the resized tensor"
+ "to the coordinate in the original tensor."
+ "Refer to the ONNX Resize operator specification for details"
+ "Available options are half_pixel, align_corners and asymmetric");
}
};
std::string pad_mode;
TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
- TVM_ATTR_FIELD(pad_value).set_default(0.0)
- .describe("The value used for padding when mode is 'constant'.");
- TVM_ATTR_FIELD(pad_width)
- .describe("Number of values padded to the edges of each axis, "
- "in the format of ((before_1, after_1), ..., (before_N, after_N))");
- TVM_ATTR_FIELD(pad_mode).set_default("constant")
- .describe("Padding type to use. \"constant\" pads with constant_value, "
- "\"edge\" pads using the edge values of the input array, "
- "\"reflect\" pads by reflecting values with respect to the edges.");
+ TVM_ATTR_FIELD(pad_value).set_default(0.0).describe(
+ "The value used for padding when mode is 'constant'.");
+ TVM_ATTR_FIELD(pad_width).describe(
+ "Number of values padded to the edges of each axis, "
+ "in the format of ((before_1, after_1), ..., (before_N, after_N))");
+ TVM_ATTR_FIELD(pad_mode)
+ .set_default("constant")
+ .describe(
+ "Padding type to use. \"constant\" pads with constant_value, "
+ "\"edge\" pads using the edge values of the input array, "
+ "\"reflect\" pads by reflecting values with respect to the edges.");
}
};
Array<Array<IndexExpr> > pad_width;
TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") {
- TVM_ATTR_FIELD(mode).set_default("SYMMETRIC")
- .describe("Specifies how mirroring should be performed.");
- TVM_ATTR_FIELD(pad_width)
- .describe("Number of values padded to the edges of each axis, "
- "in the format of ((before_1, after_1), ..., (before_N, after_N))");
+ TVM_ATTR_FIELD(mode)
+ .set_default("SYMMETRIC")
+ .describe("Specifies how mirroring should be performed.");
+ TVM_ATTR_FIELD(pad_width).describe(
+ "Number of values padded to the edges of each axis, "
+ "in the format of ((before_1, after_1), ..., (before_N, after_N))");
}
};
double alpha;
TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") {
- TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
- .describe("Slope coefficient for the negative half axis.");
+ TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25).describe(
+ "Slope coefficient for the negative half axis.");
}
};
-
/*! \brief Attributes for prelu operator */
struct PReluAttrs : public tvm::AttrsNode<PReluAttrs> {
int axis;
TVM_DECLARE_ATTRS(PReluAttrs, "relay.attrs.PReluAttrs") {
- TVM_ATTR_FIELD(axis).set_default(1)
- .describe("Specify which shape axis the channel is specified.");
+ TVM_ATTR_FIELD(axis).set_default(1).describe(
+ "Specify which shape axis the channel is specified.");
}
};
-
/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
TVM_DECLARE_ATTRS(DropoutAttrs, "relay.attrs.DropoutAttrs") {
TVM_ATTR_FIELD(rate)
- .describe("Fraction of the input that gets dropped out during training time")
- .set_default(0.5);
+ .describe("Fraction of the input that gets dropped out during training time")
+ .set_default(0.5);
}
}; // struct DropoutAttrs
bool scale;
TVM_DECLARE_ATTRS(BatchNormAttrs, "relay.attrs.BatchNormAttrs") {
- TVM_ATTR_FIELD(axis)
- .describe("Specify which shape axis denotes the channel.")
- .set_default(1);
+ TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1);
TVM_ATTR_FIELD(epsilon)
- .describe("Small float added to variance to avoid dividing by zero")
- .set_default(1e-5);
+ .describe("Small float added to variance to avoid dividing by zero")
+ .set_default(1e-5);
TVM_ATTR_FIELD(center)
- .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored")
- .set_default(true);
+ .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored")
+ .set_default(true);
TVM_ATTR_FIELD(scale)
- .describe("If True, multiply by gamma. If False, gamma is not used. "
- "When the next layer is piecewise linear (also, e.g., nn.relu), "
- "this can be disabled since the scaling will be done by the next layer.")
- .set_default(true);
+ .describe(
+ "If True, multiply by gamma. If False, gamma is not used. "
+ "When the next layer is piecewise linear (also, e.g., nn.relu), "
+ "this can be disabled since the scaling will be done by the next layer.")
+ .set_default(true);
}
}; // struct BatchNormAttrs
-
/*! \brief Attributes used in instance_norm operator */
struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
int axis;
bool scale;
TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
- TVM_ATTR_FIELD(axis)
- .describe("Specify which shape axis denotes the channel.")
- .set_default(1);
+ TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1);
TVM_ATTR_FIELD(epsilon)
- .describe("Small float added to variance to avoid dividing by zero")
- .set_default(1e-5);
- TVM_ATTR_FIELD(center).set_default(true)
- .describe("If true, add offset of beta to normalized tensor; "
- "otherwise, beta is ignored.");
- TVM_ATTR_FIELD(scale).set_default(true)
- .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+ .describe("Small float added to variance to avoid dividing by zero")
+ .set_default(1e-5);
+ TVM_ATTR_FIELD(center).set_default(true).describe(
+ "If true, add offset of beta to normalized tensor; "
+ "otherwise, beta is ignored.");
+ TVM_ATTR_FIELD(scale).set_default(true).describe(
+ "If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct InstanceNormAttrs
-
/*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
int axis;
bool scale;
TVM_DECLARE_ATTRS(LayerNormAttrs, "relay.attrs.LayerNormAttrs") {
- TVM_ATTR_FIELD(axis).set_default(-1)
- .describe("Specify which shape axis denotes the channel.");
- TVM_ATTR_FIELD(epsilon).set_default(1e-5)
- .describe("Small float added to variance to avoid dividing by zero");
- TVM_ATTR_FIELD(center).set_default(true)
- .describe("If true, add offset of beta to normalized tensor; "
- "otherwise, beta is ignored.");
- TVM_ATTR_FIELD(scale).set_default(true)
- .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+ TVM_ATTR_FIELD(axis).set_default(-1).describe("Specify which shape axis denotes the channel.");
+ TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe(
+ "Small float added to variance to avoid dividing by zero");
+ TVM_ATTR_FIELD(center).set_default(true).describe(
+ "If true, add offset of beta to normalized tensor; "
+ "otherwise, beta is ignored.");
+ TVM_ATTR_FIELD(scale).set_default(true).describe(
+ "If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct LayerNormAttrs
-
/*! \brief Attributes used in group_norm operator */
struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
int num_groups;
bool scale;
TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") {
- TVM_ATTR_FIELD(num_groups).set_default(0)
- .describe("Specify number of groups to separate the channels into.");
- TVM_ATTR_FIELD(axis).set_default(1)
- .describe("Specify which shape axis denotes the channel.");
- TVM_ATTR_FIELD(epsilon).set_default(1e-5)
- .describe("Small float added to variance to avoid dividing by zero");
- TVM_ATTR_FIELD(center).set_default(true)
- .describe("If true, add offset of beta to normalized tensor; "
- "otherwise, beta is ignored.");
- TVM_ATTR_FIELD(scale).set_default(true)
- .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+ TVM_ATTR_FIELD(num_groups)
+ .set_default(0)
+ .describe("Specify number of groups to separate the channels into.");
+ TVM_ATTR_FIELD(axis).set_default(1).describe("Specify which shape axis denotes the channel.");
+ TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe(
+ "Small float added to variance to avoid dividing by zero");
+ TVM_ATTR_FIELD(center).set_default(true).describe(
+ "If true, add offset of beta to normalized tensor; "
+ "otherwise, beta is ignored.");
+ TVM_ATTR_FIELD(scale).set_default(true).describe(
+ "If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct GroupNormAttrs
-
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
int size;
double beta;
TVM_DECLARE_ATTRS(LRNAttrs, "relay.attrs.LRNAttrs") {
- TVM_ATTR_FIELD(size).set_default(5)
- .describe("The size of the local region to be considered for normalization.");
- TVM_ATTR_FIELD(axis).set_default(1)
- .describe("Axis of input data layout channel.");
- TVM_ATTR_FIELD(bias).set_default(2)
- .describe("The offset parameter to avoid division by 0.");
- TVM_ATTR_FIELD(alpha).set_default(0.0001)
- .describe("The scaling parameter.");
- TVM_ATTR_FIELD(beta).set_default(0.75)
- .describe("The exponent parameter.");
+ TVM_ATTR_FIELD(size).set_default(5).describe(
+ "The size of the local region to be considered for normalization.");
+ TVM_ATTR_FIELD(axis).set_default(1).describe("Axis of input data layout channel.");
+ TVM_ATTR_FIELD(bias).set_default(2).describe("The offset parameter to avoid division by 0.");
+ TVM_ATTR_FIELD(alpha).set_default(0.0001).describe("The scaling parameter.");
+ TVM_ATTR_FIELD(beta).set_default(0.75).describe("The exponent parameter.");
}
};
-
/*! \brief Attributes for L2Normalize operator */
struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
double eps;
Array<Integer> axis;
TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") {
- TVM_ATTR_FIELD(eps)
- .describe("A lower bound value for the norm, to avoid division by 0.");
- TVM_ATTR_FIELD(axis)
- .describe("Axis over the normalization applied.");
+ TVM_ATTR_FIELD(eps).describe("A lower bound value for the norm, to avoid division by 0.");
+ TVM_ATTR_FIELD(axis).describe("Axis over the normalization applied.");
}
};
-
/*! \brief Attributes for DeformableConv2D operator */
struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
Array<IndexExpr> strides;
DataType out_dtype;
TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") {
- TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(strides)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
- TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
- .describe("If padding is non-zero, then the input is implicitly zero-padded"
- "Padding support both symmetric and asymmetric as"
- "one int : same padding used on all sides"
- "two int : bottom, right will use same padding as top, left"
- "four int : padding width in the order of (top, left, bottom, right)");
- TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
+ TVM_ATTR_FIELD(padding)
+ .set_default(Array<IndexExpr>({0, 0}))
+ .describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(dilation)
+ .set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
- TVM_ATTR_FIELD(deformable_groups).set_default(1)
- .describe("Controls the connections between inputs and offsets."
- "Input channels are partitioned into multiple deformable groups. Offsets"
- "are shared across input channels in the same deformable group.");
- TVM_ATTR_FIELD(groups).set_default(1)
- .describe("Controls the connections between inputs and outputs."
- "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution"
- "layers side by side, each seeing half the input channels, and producing"
- "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(deformable_groups)
+ .set_default(1)
+ .describe(
+ "Controls the connections between inputs and offsets."
+ "Input channels are partitioned into multiple deformable groups. Offsets"
+ "are shared across input channels in the same deformable group.");
+ TVM_ATTR_FIELD(groups).set_default(1).describe(
+ "Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
- .describe("The number of output channels in the convolution."
- " If it is not set, inferred by shape of the weight.")
+ .describe(
+ "The number of output channels in the convolution."
+ " If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
- TVM_ATTR_FIELD(data_layout).set_default("NCHW")
- .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Convolution is applied on the 'H' and"
- "'W' dimensions.");
- TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
- .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
- "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
- "dimensions respectively.");
- TVM_ATTR_FIELD(out_layout).set_default("")
- .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
- "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(data_layout)
+ .set_default("NCHW")
+ .describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Convolution is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .set_default("OIHW")
+ .describe(
+ "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .set_default("")
+ .describe(
+ "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
#define TVM_RELAY_ATTRS_REDUCE_H_
#include <tvm/ir/attrs.h>
+
#include <string>
namespace tvm {
bool exclude;
TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
- TVM_ATTR_FIELD(axis).set_default(NullValue<Array<Integer>>())
+ TVM_ATTR_FIELD(axis)
+ .set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");
- TVM_ATTR_FIELD(keepdims).set_default(false)
- .describe("If this is set to `True`, the reduced axes are left "
- "in the result as dimension with size one.");
- TVM_ATTR_FIELD(exclude).set_default(false)
- .describe("Whether to perform reduction on axis that are NOT in axis instead.");
+ TVM_ATTR_FIELD(keepdims).set_default(false).describe(
+ "If this is set to `True`, the reduced axes are left "
+ "in the result as dimension with size one.");
+ TVM_ATTR_FIELD(exclude).set_default(false).describe(
+ "Whether to perform reduction on axis that are NOT in axis instead.");
}
};
} // namespace relay
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
+
#include <string>
namespace tvm {
DataType dtype;
TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
- TVM_ATTR_FIELD(dtype)
- .describe("Target data type");
+ TVM_ATTR_FIELD(dtype).describe("Target data type");
}
}; // struct CastAttrs.
int num_newaxis;
TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") {
- TVM_ATTR_FIELD(axis)
- .describe("The axis at which the input array is expanded."
- "Should lie in range `[-data.ndim - 1, data.ndim]`."
- "If `axis < 0`, it is the first axis inserted;"
- "If `axis >= 0`, it is the last axis inserted in Python's negative indexing.");
+ TVM_ATTR_FIELD(axis).describe(
+ "The axis at which the input array is expanded."
+ "Should lie in range `[-data.ndim - 1, data.ndim]`."
+ "If `axis < 0`, it is the first axis inserted;"
+ "If `axis >= 0`, it is the last axis inserted in Python's negative indexing.");
TVM_ATTR_FIELD(num_newaxis)
.describe("Number of axises to be inserted. Should be >= 0.")
.set_lower_bound(0)
int axis;
TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") {
TVM_ATTR_FIELD(axis)
- .describe("The axis at which the input arrays are concatenated."
- "Should lie in range `[-ndim, ndim)`.")
+ .describe(
+ "The axis at which the input arrays are concatenated."
+ "Should lie in range `[-ndim, ndim)`.")
.set_default(0);
}
}; // struct ConcatenateAttrs
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
Array<Integer> axes;
TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
- TVM_ATTR_FIELD(axes)
- .describe("The target axes order, reverse order if not specified.");
+ TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified.");
}
}; // struct TransposeAttrs
Array<Integer> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
- TVM_ATTR_FIELD(newshape)
- .describe("The new shape. Should be compatible with the original shape.");
+ TVM_ATTR_FIELD(newshape).describe(
+ "The new shape. Should be compatible with the original shape.");
TVM_ATTR_FIELD(reverse)
.describe("Infer the special values from right to left if true")
.set_default(false);
std::string mode;
TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
- TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+ TVM_ATTR_FIELD(axis)
+ .set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
- TVM_ATTR_FIELD(mode).set_default("clip")
- .describe("Specify how out-of-bound indices will behave."
- "clip - clip to the range (default)"
- "wrap - wrap around the indices"
- "fast - no clip or wrap around (user must make sure indices are in-bound)");
+ TVM_ATTR_FIELD(mode).set_default("clip").describe(
+ "Specify how out-of-bound indices will behave."
+ "clip - clip to the range (default)"
+ "wrap - wrap around the indices"
+ "fast - no clip or wrap around (user must make sure indices are in-bound)");
}
};
DataType dtype;
TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
- TVM_ATTR_FIELD(shape)
- .describe("Target shape.");
- TVM_ATTR_FIELD(dtype)
- .describe("Target data type.")
- .set_default(NullValue<DataType>());
+ TVM_ATTR_FIELD(shape).describe("Target shape.");
+ TVM_ATTR_FIELD(dtype).describe("Target data type.").set_default(NullValue<DataType>());
}
}; // struct InitOpAttrs
DataType dtype;
TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
- TVM_ATTR_FIELD(start)
- .describe("Start of interval. The interval includes this value.");
- TVM_ATTR_FIELD(stop)
- .describe("Stop of interval. The interval does not include this value.");
- TVM_ATTR_FIELD(step)
- .describe("Spacing between values.");
- TVM_ATTR_FIELD(dtype)
- .describe("Target data type.");
+ TVM_ATTR_FIELD(start).describe("Start of interval. The interval includes this value.");
+ TVM_ATTR_FIELD(stop).describe("Stop of interval. The interval does not include this value.");
+ TVM_ATTR_FIELD(step).describe("Spacing between values.");
+ TVM_ATTR_FIELD(dtype).describe("Target data type.");
}
}; // struct ArangeAttrs
struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
Integer axis;
TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") {
- TVM_ATTR_FIELD(axis).set_default(0)
- .describe("The axis in the result array along which the input arrays are stacked.");
+ TVM_ATTR_FIELD(axis).set_default(0).describe(
+ "The axis in the result array along which the input arrays are stacked.");
}
}; // struct StackAttrs
Integer repeats;
Integer axis;
TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
- TVM_ATTR_FIELD(repeats)
- .describe("The number of repetitions for each element.");
- TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+ TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element.");
+ TVM_ATTR_FIELD(axis)
+ .set_default(NullValue<Integer>())
.describe(" The axis along which to repeat values.");
}
}; // struct RepeatAttrs
struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
Array<Integer> reps;
TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
- TVM_ATTR_FIELD(reps)
- .describe("The number of times for repeating the tensor a."
- "Each dim sizeof reps must be a positive integer.");
+ TVM_ATTR_FIELD(reps).describe(
+ "The number of times for repeating the tensor a."
+ "Each dim sizeof reps must be a positive integer.");
}
}; // struct TileAttrs
struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
Integer axis;
TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
- TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+ TVM_ATTR_FIELD(axis)
+ .set_default(NullValue<Integer>())
.describe("The axis along which to reverse elements.");
}
}; // struct ReverseAttrs
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axis)
- .describe("The axis to squeeze in the input tensor."
- "If `axis = None`, all axis of dimension 1 get squeezed;"
- "Else, the dimension in axes get squeezed."
- "It is an error if an axis does not has dimension 1.")
+ .describe(
+ "The axis to squeeze in the input tensor."
+ "If `axis = None`, all axis of dimension 1 get squeezed;"
+ "Else, the dimension in axes get squeezed."
+ "It is an error if an axis does not has dimension 1.")
.set_default(NullValue<Array<Integer> >());
}
}; // struct SqueezeAttrs
TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
TVM_ATTR_FIELD(indices_or_sections)
- .describe("Indices or sections to split into. Accepts an int or a tuple"
- "If indices_or_sections is an integer, the input will be divided equally"
- "along given axis. If such a split is not possible, an error is raised."
- "If indices_or_sections is a tuple of sorted integers,"
- "the entries indicate where along axis the array is split.");
- TVM_ATTR_FIELD(axis).set_default(0)
- .describe("the axis to be splitted.");
+ .describe(
+ "Indices or sections to split into. Accepts an int or a tuple"
+ "If indices_or_sections is an integer, the input will be divided equally"
+ "along given axis. If such a split is not possible, an error is raised."
+ "If indices_or_sections is a tuple of sorted integers,"
+ "the entries indicate where along axis the array is split.");
+ TVM_ATTR_FIELD(axis).set_default(0).describe("the axis to be splitted.");
}
};
Array<Integer> strides;
TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
- TVM_ATTR_FIELD(begin)
- .describe("Indices for begin of slice, begin index is also inclusive");
- TVM_ATTR_FIELD(end)
- .describe("Indices for end of slice, end index is exclusive");
- TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
- .describe("Stride values of the slice");
+ TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
+ TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
+ TVM_ATTR_FIELD(strides).set_default(Array<Integer>({})).describe("Stride values of the slice");
}
};
Array<Integer> axes;
TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") {
- TVM_ATTR_FIELD(axes)
- .describe("List of axes on which input data will be sliced according to the "
- "corresponding size of the second input. By default will slice "
- "on all axes. Negative axes mean counting in reverse.");
+ TVM_ATTR_FIELD(axes).describe(
+ "List of axes on which input data will be sliced according to the "
+ "corresponding size of the second input. By default will slice "
+ "on all axes. Negative axes mean counting in reverse.");
}
};
double a_max;
TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
- TVM_ATTR_FIELD(a_min)
- .describe("The minimum clip value.");
- TVM_ATTR_FIELD(a_max)
- .describe("The maximum clip value.");
+ TVM_ATTR_FIELD(a_min).describe("The minimum clip value.");
+ TVM_ATTR_FIELD(a_max).describe("The maximum clip value.");
}
};
std::string dst_layout;
TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
- TVM_ATTR_FIELD(src_layout)
- .describe("The source layout of the tensor. (e.g. NCHW)");
- TVM_ATTR_FIELD(dst_layout)
- .describe("The destination layout of the tensor. (e.g. NCHW16c)");
+ TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. NCHW)");
+ TVM_ATTR_FIELD(dst_layout).describe("The destination layout of the tensor. (e.g. NCHW16c)");
}
};
DataType dtype;
TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") {
- TVM_ATTR_FIELD(dtype)
- .describe("Target data type")
- .set_default(NullValue<DataType>());
+ TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>());
}
};
int axis;
TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") {
- TVM_ATTR_FIELD(mask_value).set_default(0)
- .describe("The masking value.");
- TVM_ATTR_FIELD(axis).set_default(0)
- .describe("The axis of the length dimension. Can only be 0 or 1.");
+ TVM_ATTR_FIELD(mask_value).set_default(0).describe("The masking value.");
+ TVM_ATTR_FIELD(axis).set_default(0).describe(
+ "The axis of the length dimension. Can only be 0 or 1.");
}
}; // struct SequenceMaskAttrs.
DataType dtype;
TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") {
- TVM_ATTR_FIELD(dtype)
- .describe("Target data type")
- .set_default(NullValue<DataType>());
+ TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>());
}
};
DataType dtype;
TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
- TVM_ATTR_FIELD(depth).set_default(1)
- .describe("Depth of the one hot dimension.");
- TVM_ATTR_FIELD(axis).set_default(-1)
- .describe("Axis to fill.");
- TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
- .describe("Output data type.");
+ TVM_ATTR_FIELD(depth).set_default(1).describe("Depth of the one hot dimension.");
+ TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
+ TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>()).describe("Output data type.");
}
}; // struct OneHotAttrs
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
+
#include <string>
namespace tvm {
TVM_DECLARE_ATTRS(MultiBoxPriorAttrs, "relay.attrs.MultiBoxPriorAttrs") {
TVM_ATTR_FIELD(sizes)
- .set_default(Array<IndexExpr>({static_cast<float>(1.0)}))
- .describe("List of sizes of generated MultiBoxPriores.");
+ .set_default(Array<IndexExpr>({static_cast<float>(1.0)}))
+ .describe("List of sizes of generated MultiBoxPriores.");
TVM_ATTR_FIELD(ratios)
- .set_default(Array<IndexExpr>({static_cast<float>(1.0)}))
- .describe("List of aspect ratios of generated MultiBoxPriores.");
+ .set_default(Array<IndexExpr>({static_cast<float>(1.0)}))
+ .describe("List of aspect ratios of generated MultiBoxPriores.");
TVM_ATTR_FIELD(steps)
- .set_default(Array<IndexExpr>({static_cast<float>(-1.0),
- static_cast<float>(-1.0)}))
- .describe("Priorbox step across y and x, -1 for auto calculation.");
+ .set_default(Array<IndexExpr>({static_cast<float>(-1.0), static_cast<float>(-1.0)}))
+ .describe("Priorbox step across y and x, -1 for auto calculation.");
TVM_ATTR_FIELD(offsets)
- .set_default(Array<IndexExpr>({static_cast<float>(0.5),
- static_cast<float>(0.5)}))
- .describe("Priorbox center offsets, y and x respectively.");
- TVM_ATTR_FIELD(clip).set_default(false)
- .describe("Whether to clip out-of-boundary boxes.");
+ .set_default(Array<IndexExpr>({static_cast<float>(0.5), static_cast<float>(0.5)}))
+ .describe("Priorbox center offsets, y and x respectively.");
+ TVM_ATTR_FIELD(clip).set_default(false).describe("Whether to clip out-of-boundary boxes.");
}
};
-struct MultiBoxTransformLocAttrs
- : public tvm::AttrsNode<MultiBoxTransformLocAttrs> {
+struct MultiBoxTransformLocAttrs : public tvm::AttrsNode<MultiBoxTransformLocAttrs> {
bool clip;
double threshold;
Array<IndexExpr> variances;
- TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs,
- "relay.attrs.MultiBoxTransformLocAttrs") {
- TVM_ATTR_FIELD(clip).set_default(true)
- .describe("Clip out-of-boundary boxes.");
- TVM_ATTR_FIELD(threshold).set_default(0.01)
- .describe("Threshold to be a positive prediction.");
+ TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, "relay.attrs.MultiBoxTransformLocAttrs") {
+ TVM_ATTR_FIELD(clip).set_default(true).describe("Clip out-of-boundary boxes.");
+ TVM_ATTR_FIELD(threshold).set_default(0.01).describe("Threshold to be a positive prediction.");
TVM_ATTR_FIELD(variances)
- .set_default(Array<IndexExpr>({0.1f, 0.1f , 0.2f, 0.2f}))
- .describe("Variances to be decoded from box regression output.");
+ .set_default(Array<IndexExpr>({0.1f, 0.1f, 0.2f, 0.2f}))
+ .describe("Variances to be decoded from box regression output.");
}
};
int score_index;
TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
- TVM_ATTR_FIELD(score_threshold).set_default(0.0)
- .describe("Lower limit of score for valid bounding boxes.");
- TVM_ATTR_FIELD(id_index).set_default(0)
- .describe("Axis index of id.");
- TVM_ATTR_FIELD(score_index).set_default(1)
- .describe("Index of the scores/confidence of boxes.");
+ TVM_ATTR_FIELD(score_threshold)
+ .set_default(0.0)
+ .describe("Lower limit of score for valid bounding boxes.");
+ TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id.");
+ TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes.");
}
};
bool invalid_to_bottom;
TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
- TVM_ATTR_FIELD(max_output_size).set_default(-1)
- .describe("Max number of output valid boxes for each instance."
- "By default all valid boxes are returned.");
- TVM_ATTR_FIELD(iou_threshold).set_default(0.5)
- .describe("Non-maximum suppression threshold.");
- TVM_ATTR_FIELD(force_suppress).set_default(false)
- .describe("Suppress all detections regardless of class_id.");
- TVM_ATTR_FIELD(top_k).set_default(-1)
- .describe("Keep maximum top k detections before nms, -1 for no limit.");
- TVM_ATTR_FIELD(coord_start).set_default(2)
- .describe("Start index of the consecutive 4 coordinates.");
- TVM_ATTR_FIELD(score_index).set_default(1)
- .describe("Index of the scores/confidence of boxes.");
- TVM_ATTR_FIELD(id_index).set_default(0)
- .describe("Axis index of id.");
- TVM_ATTR_FIELD(return_indices).set_default(true)
- .describe("Whether to return box indices in input data.");
- TVM_ATTR_FIELD(invalid_to_bottom).set_default(false)
- .describe("Whether to move all invalid bounding boxes to the bottom.");
+ TVM_ATTR_FIELD(max_output_size)
+ .set_default(-1)
+ .describe(
+ "Max number of output valid boxes for each instance."
+ "By default all valid boxes are returned.");
+ TVM_ATTR_FIELD(iou_threshold).set_default(0.5).describe("Non-maximum suppression threshold.");
+ TVM_ATTR_FIELD(force_suppress)
+ .set_default(false)
+ .describe("Suppress all detections regardless of class_id.");
+ TVM_ATTR_FIELD(top_k).set_default(-1).describe(
+ "Keep maximum top k detections before nms, -1 for no limit.");
+ TVM_ATTR_FIELD(coord_start)
+ .set_default(2)
+ .describe("Start index of the consecutive 4 coordinates.");
+ TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes.");
+ TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id.");
+ TVM_ATTR_FIELD(return_indices)
+ .set_default(true)
+ .describe("Whether to return box indices in input data.");
+ TVM_ATTR_FIELD(invalid_to_bottom)
+ .set_default(false)
+ .describe("Whether to move all invalid bounding boxes to the bottom.");
}
};
Integer stride;
TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
- TVM_ATTR_FIELD(stride)
- .set_default(1)
- .describe("Stride value for yolo reorg");
+ TVM_ATTR_FIELD(stride).set_default(1).describe("Stride value for yolo reorg");
}
};
.describe(
"The size of the receptive field each unit in the convolution layer of the rpn,"
"for example the product of all stride's prior to this layer.");
- TVM_ATTR_FIELD(threshold)
- .set_default(0.7)
- .describe(
- "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)");
+ TVM_ATTR_FIELD(threshold).set_default(0.7).describe(
+ "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)");
TVM_ATTR_FIELD(rpn_pre_nms_top_n)
.set_default(6000)
.describe("Number of top scoring boxes to apply NMS. -1 to use all boxes");
#ifndef TVM_RELAY_BASE_H_
#define TVM_RELAY_BASE_H_
-
#include <tvm/ir/span.h>
-#include <tvm/tir/expr.h>
#include <tvm/node/node.h>
+#include <tvm/tir/expr.h>
+
#include <string>
#include <vector>
*/
namespace relay {
-#define RELAY_DEBUG(...) \
-{ auto fdebug = runtime::Registry::Get("relay.debug"); \
- CHECK(fdebug) << "Could not find Relay Python debugger function."; \
- (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
-}
+#define RELAY_DEBUG(...) \
+ { \
+ auto fdebug = runtime::Registry::Get("relay.debug"); \
+ CHECK(fdebug) << "Could not find Relay Python debugger function."; \
+ (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
+ }
-#define RELAY_DEBUG_INTERP(...) \
-{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \
- CHECK(fdebug) << "Could not find Relay Python debugger function."; \
- (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
-}
+#define RELAY_DEBUG_INTERP(...) \
+ { \
+ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \
+ CHECK(fdebug) << "Could not find Relay Python debugger function."; \
+ (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
+ }
/*!
* \brief Symbolic expression for tensor shape.
*/
std::string name_hint;
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("name_hint", &name_hint);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); }
static constexpr const char* _type_key = "relay.Id";
TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object);
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
-#include <tvm/ir/op.h>
#include <tvm/ir/module.h>
-#include <string>
+#include <tvm/ir/op.h>
+
#include <functional>
+#include <string>
+
#include "./base.h"
#include "./type.h"
TensorType tensor_type() const;
/*! \return Whether it is scalar(rank-0 tensor) */
- bool is_scalar() const {
- return data->ndim == 0;
- }
+ bool is_scalar() const { return data->ndim == 0; }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
return equal(data, other->data);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(data);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
Type type_annotation;
/*! \return The name hint of the variable */
- const std::string& name_hint() const {
- return vid->name_hint;
- }
+ const std::string& name_hint() const { return vid->name_hint; }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
- return
- equal(type_annotation, other->type_annotation) &&
- equal.FreeVarEqualImpl(this, other);
+ return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.FreeVarHashImpl(this);
}
- TVM_DLL static Var make(std::string name_hint,
- Type type_annotation);
+ TVM_DLL static Var make(std::string name_hint, Type type_annotation);
- TVM_DLL static Var make(Id vid,
- Type type_annotation);
+ TVM_DLL static Var make(Id vid, Type type_annotation);
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
* \param name_hint The name hint of a variable.
* \param type_annotation The type annotation of a variable.
*/
- TVM_DLL Var(std::string name_hint, Type type_annotation) :
- Var(Id(name_hint), type_annotation) {}
+ TVM_DLL Var(std::string name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {}
/*!
* \brief The constructor
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
// skip type_args check for primitive ops.
equal->MarkGraphNode();
- return
- equal(op, other->op) &&
- equal(args, other->args) &&
- equal(attrs, other->attrs) &&
- (IsPrimitiveOp(op) || equal(type_args, other->type_args));
+ return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
+ (IsPrimitiveOp(op) || equal(type_args, other->type_args));
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param attrs The attributes of the call node.
* \param type_args The type arguments passed to a polymorphic function.
*/
- TVM_DLL Call(Expr op,
- Array<Expr> args,
- Attrs attrs = Attrs(),
+ TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>());
TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
- return
- equal.DefEqual(var, other->var) &&
- equal(value, other->value) &&
- equal(body, other->body);
+ return equal.DefEqual(var, other->var) && equal(value, other->value) &&
+ equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
- return
- equal(cond, other->cond) &&
- equal(true_branch, other->true_branch) &&
- equal(false_branch, other->false_branch);
+ return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
+ equal(false_branch, other->false_branch);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
- return
- equal(tuple, other->tuple) &&
- equal(index, other->index);
+ return equal(tuple, other->tuple) && equal(index, other->index);
}
void SHashReduce(SHashReducer hash_reduce) const {
bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
- return
- equal(ref, other->ref) &&
- equal(value, other->value);
+ return equal(ref, other->ref) && equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_
-#include <tvm/node/functor.h>
#include <tvm/ir/error.h>
+#include <tvm/node/functor.h>
+#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
-#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>
#include <string>
-#include <utility>
#include <unordered_map>
+#include <utility>
namespace tvm {
namespace relay {
class ExprFunctor;
// functions to be overriden.
-#define EXPR_FUNCTOR_DEFAULT \
+#define EXPR_FUNCTOR_DEFAULT \
{ return VisitExprDefault_(op, std::forward<Args>(args)...); }
-#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
- vtable.template set_dispatch<OP>( \
- [](const ObjectRef& n, TSelf* self, Args... args) { \
- return self->VisitExpr_(static_cast<const OP*>(n.get()), \
- std::forward<Args>(args)...); \
- });
+#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+ return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+ });
template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
* \param args Additional arguments.
* \return The result of the call
*/
- R operator()(const Expr& n, Args... args) {
- return VisitExpr(n, std::forward<Args>(args)...);
- }
+ R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); }
/*!
* \brief The functor call.
* \param n The expression node.
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
- virtual R VisitExpr_(const ConstantNode* op,
- Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const TupleNode* op,
- Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const VarNode* op,
- Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const GlobalVarNode* op,
- Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const FunctionNode* op,
- Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const IfNode* op,
- Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const OpNode* op,
- Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
* ExprVisitor treats Expr as dataflow graph,
* and only visit each Expr node once.
*/
-class ExprVisitor
- : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
+class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
public:
void VisitExpr(const Expr& expr) override;
void VisitExpr_(const VarNode* op) override;
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
*/
-class ExprMutator
- : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
+class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
- Expr Mutate(const Expr& expr) {
- return this->VisitExpr(expr);
- }
+ Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
*
- * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive behavior.
+ * Uses Rewrite_ API of ExprRewriter for a cleaner split between recursive and non-recursive
+ * behavior.
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
/*!
- * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
- * able to rewrite the op only with data about the original node `pre` and the same node with
+ * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will
+ * be able to rewrite the op only with data about the original node `pre` and the same node with
* modified inputs `post` and should not recurse.
*
* \param pre The expression node before rewriting.
* \param post The expression with rewritten inputs.
*/
- virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
+ virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
- Expr operator()(const Expr& pre, const Expr& post) {
- return Rewrite(pre, post);
- }
+ Expr operator()(const Expr& pre, const Expr& post) { return Rewrite(pre, post); }
/*!
* \brief The functor call.
* \param pre The expression node before rewriting.
#ifndef TVM_RELAY_FEATURE_H_
#define TVM_RELAY_FEATURE_H_
+#include <tvm/ir/module.h>
#include <tvm/node/container.h>
#include <tvm/relay/expr.h>
-#include <tvm/ir/module.h>
#include <bitset>
public:
FeatureSet(const FeatureSet&) = default;
/*! \brief A singleton set containing a single Feature. */
- explicit FeatureSet(Feature ft) {
- bs_.set(static_cast<size_t>(ft));
- }
+ explicit FeatureSet(Feature ft) { bs_.set(static_cast<size_t>(ft)); }
explicit FeatureSet(const tvm::Array<tvm::Integer>& ft) {
for (Integer i : ft) {
(*this) += Feature(static_cast<int>(i));
FeatureSet fs;
return fs;
}
- template<typename T>
+ template <typename T>
FeatureSet& operator+=(const T& rhs) {
bs_ |= FeatureSet(rhs).bs_;
return *this;
}
/*! \brief Set union. */
- template<typename T>
+ template <typename T>
FeatureSet operator+(const T& rhs) const {
FeatureSet fs(*this);
fs += rhs;
return fs;
}
- template<typename T>
+ template <typename T>
FeatureSet& operator-=(const T& rhs) {
bs_ &= ~(FeatureSet(rhs)).bs_;
return *this;
}
/*! \brief Set difference. */
- template<typename T>
+ template <typename T>
FeatureSet operator-(const T& rhs) const {
FeatureSet fs(*this);
fs -= rhs;
*
* \return true only if this is a subset of rhs.
*/
- bool is_subset_of(const FeatureSet& rhs) const {
- return ((*this) - rhs).bs_.none();
- }
+ bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); }
private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
- explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { }
+ explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) {}
};
/*!
#include <tvm/ir/function.h>
#include <tvm/relay/expr.h>
-#include <string>
+#include <string>
namespace tvm {
namespace relay {
bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
// Important to make def equal first.
equal->MarkGraphNode();
- return
- equal.DefEqual(params, other->params) &&
- equal.DefEqual(type_params, other->type_params) &&
- equal(ret_type, other->ret_type) &&
- equal(attrs, other->attrs) &&
- equal(body, other->body);
+ return equal.DefEqual(params, other->params) &&
+ equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) &&
+ equal(attrs, other->attrs) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
-
/*!
* \brief Managed reference to FunctionNode.
* \sa FunctionNode
* \param ty_params The type parameters.
* \param attrs Additional function attributes.
*/
- TVM_DLL Function(tvm::Array<Var> params,
- Expr body,
- Type ret_type,
- tvm::Array<TypeVar> ty_params,
+ TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,
tvm::DictAttrs attrs = NullValue<DictAttrs>());
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
-#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
+#include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h>
#include <tvm/target/target.h>
-
namespace tvm {
namespace relay {
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
-runtime::TypedPackedFunc<ObjectRef(Expr)>
-CreateInterpreter(IRModule mod, DLContext context, Target target);
+runtime::TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLContext context,
+ Target target);
/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::vm::ClosureObj {
class InterpreterClosure : public runtime::vm::Closure {
public:
TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
- TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure,
- InterpreterClosureObj);
+ TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, InterpreterClosureObj);
};
/*! \brief The container type of RecClosure. */
RefValueObj() {}
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("value", &value);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); }
static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
class ConstructorValue : public ObjectRef {
public:
- TVM_DLL ConstructorValue(int32_t tag,
- tvm::Array<ObjectRef> fields,
- Constructor construtor = {});
+ TVM_DLL ConstructorValue(int32_t tag, tvm::Array<ObjectRef> fields, Constructor construtor = {});
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};
#define TVM_RELAY_OP_H_
#include <tvm/ir/op.h>
-#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
namespace tvm {
namespace relay {
using Op = tvm::Op;
using OpNode = tvm::OpNode;
-#define RELAY_REGISTER_OP(OpName) \
- TVM_REGISTER_OP(OpName)
+#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName)
} // namespace relay
} // namespace tvm
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_
-#include <tvm/te/tensor.h>
-#include <tvm/te/schedule.h>
-#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
-#include <tvm/target/target.h>
+#include <tvm/relay/type.h>
#include <tvm/target/generic_func.h>
+#include <tvm/target/target.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/tensor.h>
#include <tvm/tir/data_layout.h>
+
#include <string>
namespace tvm {
namespace relay {
+using tir::BijectiveLayoutNode;
using tir::Layout;
using tir::LayoutAxis;
-using tir::BijectiveLayoutNode;
/*! \brief operator pattern used in graph fusion */
enum OpPatternKind {
& these are always placeholders.
* \return The output compute description of the operator.
*/
-using FTVMCompute = runtime::TypedPackedFunc<
- Array<te::Tensor>(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type)>;
+using FTVMCompute = runtime::TypedPackedFunc<Array<te::Tensor>(
+ const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type)>;
/*!
* \brief Build the computation schedule for
* \param target The build target.
* \return schedule The computation schedule.
*/
-using FTVMSchedule = runtime::TypedPackedFunc<
- te::Schedule(const Attrs& attrs,
- const Array<te::Tensor>& outs,
- const Target& target)>;
+using FTVMSchedule = runtime::TypedPackedFunc<te::Schedule(
+ const Attrs& attrs, const Array<te::Tensor>& outs, const Target& target)>;
/*!
* \brief Generate the strategy of operators. This function is a generic
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
-using FTVMAlterOpLayout = runtime::TypedPackedFunc<
- Expr(const Attrs& attrs,
- const Array<Expr>& args,
- const Array<te::Tensor>& tinfos,
- const Type& out_type)>;
+using FTVMAlterOpLayout =
+ runtime::TypedPackedFunc<Expr(const Attrs& attrs, const Array<Expr>& args,
+ const Array<te::Tensor>& tinfos, const Type& out_type)>;
/*!
* \brief Convert the layout of operators or replace the
* \param desired_layout The desired layout.
* \return new_expr The modified expression.
*/
-using FTVMConvertOpLayout = runtime::TypedPackedFunc<
- Expr(const Attrs& attrs,
- const Array<Expr>& args,
- const Array<te::Tensor>& tinfos,
- const std::string& desired_layout)>;
+using FTVMConvertOpLayout = runtime::TypedPackedFunc<Expr(
+ const Attrs& attrs, const Array<Expr>& args, const Array<te::Tensor>& tinfos,
+ const std::string& desired_layout)>;
/*!
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
-using FTVMLegalize = runtime::TypedPackedFunc<
- Expr(const Attrs& attrs,
- const Array<Expr>& args,
- const Array<tvm::relay::Type>& arg_types)>;
+using FTVMLegalize = runtime::TypedPackedFunc<Expr(const Attrs& attrs, const Array<Expr>& args,
+ const Array<tvm::relay::Type>& arg_types)>;
/*!
* \brief Annotates an expression to indicate if an op should be compiled using
* \return true if this op should be registered to invoke a specific compiler
* for codegen, otherwise, false.
*/
-using FTVMAnnotateTarget = runtime::TypedPackedFunc<
- bool(const Attrs& attrs, // NOLINT(*)
- const Array<Expr>& args)>;
+using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Attrs& attrs, // NOLINT(*)
+ const Array<Expr>& args)>;
/*!
* \brief Forward rewriting rule for a specific op.
* \note When we register the function, we can register
* a different signature with ctx to be a specific node type.
*/
-using FForwardRewrite = runtime::TypedPackedFunc<
- Expr(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx)>;
+using FForwardRewrite = runtime::TypedPackedFunc<Expr(
+ const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx)>;
/*!
* \brief Gradient for a specific op.
* \param output_grad the gradient of the Expr.
* \return the gradient for each parameters.
*/
-using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
- const Expr& output_grad)>;
+using FPrimalGradient =
+ runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call, const Expr& output_grad)>;
/*!
* \brief The codegeneration strategy for dynamic dimensions.
/*! \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;
-using FShapeFunc = runtime::TypedPackedFunc<
- Array<te::Tensor>(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Array<IndexExpr>& out_ndims)>;
+using FShapeFunc = runtime::TypedPackedFunc<Array<te::Tensor>(
+ const Attrs& attrs, const Array<te::Tensor>& inputs, const Array<IndexExpr>& out_ndims)>;
} // namespace relay
} // namespace tvm
#ifndef TVM_RELAY_OP_STRATEGY_H_
#define TVM_RELAY_OP_STRATEGY_H_
-#include <tvm/te/tensor.h>
-#include <tvm/te/schedule.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/target.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/tensor.h>
+
#include <string>
namespace tvm {
* \param out_type The output type information.
* \return The output compute description of the operator.
*/
- TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+ TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type);
/*!
* \brief Build the computation schedule.
* \param target The build target.
* \return The computation schedule.
*/
- TVM_DLL te::Schedule Schedule(const Attrs& attrs,
- const Array<te::Tensor>& outs,
+ TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array<te::Tensor>& outs,
const Target& target);
TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
- TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
- std::string name, int plevel);
+ TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name,
+ int plevel);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
};
/*! \brief List of operator specializations. */
Array<OpSpecialization> specializations;
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("specializations", &specializations);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); }
static constexpr const char* _type_key = "relay.OpStrategy";
TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
- TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
- std::string name, int plevel);
+ TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name,
+ int plevel);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
};
#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
#define TVM_RELAY_PATTERN_FUNCTOR_H_
-#include <tvm/node/functor.h>
#include <tvm/ir/error.h>
+#include <tvm/node/functor.h>
#include <string>
-#include <utility>
#include <unordered_map>
+#include <utility>
+#include "./adt.h"
#include "./expr.h"
#include "./op.h"
-#include "./adt.h"
namespace tvm {
namespace relay {
class PatternFunctor;
// functions to be overriden.
-#define PATTERN_FUNCTOR_DEFAULT \
+#define PATTERN_FUNCTOR_DEFAULT \
{ return VisitPatternDefault_(op, std::forward<Args>(args)...); }
-#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \
- vtable.template set_dispatch<OP>( \
- [](const ObjectRef& n, TSelf* self, Args... args) { \
- return self->VisitPattern_(static_cast<const OP*>(n.get()), \
- std::forward<Args>(args)...); \
- });
+#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \
+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+ return self->VisitPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+ });
template <typename R, typename... Args>
class PatternFunctor<R(const Pattern& n, Args...)> {
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
- virtual R VisitPattern_(const PatternWildcardNode* op,
- Args... args) PATTERN_FUNCTOR_DEFAULT;
- virtual R VisitPattern_(const PatternVarNode* op,
- Args... args) PATTERN_FUNCTOR_DEFAULT;
- virtual R VisitPattern_(const PatternConstructorNode* op,
- Args... args) PATTERN_FUNCTOR_DEFAULT;
- virtual R VisitPattern_(const PatternTupleNode* op,
- Args... args) PATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitPattern_(const PatternWildcardNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitPattern_(const PatternVarNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw;
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
-class PatternMutator
- : public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> {
+class PatternMutator : public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> {
public:
Pattern Mutate(const Pattern& pat);
Pattern VisitPattern_(const PatternWildcardNode* op) override;
virtual Var VisitVar(const Var& v);
/*! \brief Used to visit the vars inside of patterns. */
virtual Constructor VisitConstructor(const Constructor& c);
+
private:
std::unordered_map<Var, Var, ObjectHash, ObjectEqual> var_map_;
};
#define TVM_RELAY_QNN_ATTRS_H_
#include <tvm/ir/attrs.h>
+
#include <string>
namespace tvm {
TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(axis)
- .describe("The output channel axis for channel wise quantization. Default value is -1,"
- "which corresponds to the last axis.")
- .set_default(-1);
- TVM_ATTR_FIELD(rounding).set_default("UPWARD")
- .describe("Defines the rounding direction when the value is midway between"
- "two representable values. There are two supported modes - UPWARD"
- "or TONEAREST. Both modes behave exactly same except at the"
- "midpoints between the two representable values. At the midpoint,"
- "UPWARD rounds towards positive infinity (for example -1.5 will be"
- "rounded to -1). TONEAREST is the standard rounding where the"
- "value is rounded away from zero at midpoints (for example, -1.5"
- "rounds to -2). More context can be found at following gblic manual"
- "https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
+ .describe(
+ "The output channel axis for channel wise quantization. Default value is -1,"
+ "which corresponds to the last axis.")
+ .set_default(-1);
+ TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe(
+ "Defines the rounding direction when the value is midway between"
+ "two representable values. There are two supported modes - UPWARD"
+ "or TONEAREST. Both modes behave exactly same except at the"
+ "midpoints between the two representable values. At the midpoint,"
+ "UPWARD rounds towards positive infinity (for example -1.5 will be"
+ "rounded to -1). TONEAREST is the standard rounding where the"
+ "value is rounded away from zero at midpoints (for example, -1.5"
+ "rounds to -2). More context can be found at following gblic manual"
+ "https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
int axis;
TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
- TVM_ATTR_FIELD(out_dtype)
- .describe("Output data type, can be one of [int8 or uint8].");
+ TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [int8 or uint8].");
TVM_ATTR_FIELD(axis)
- .describe("The output channel axis for channel wise quantization. Default value is -1,"
- "which corresponds to the last axis.")
- .set_default(-1);
+ .describe(
+ "The output channel axis for channel wise quantization. Default value is -1,"
+ "which corresponds to the last axis.")
+ .set_default(-1);
}
};
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
#define TVM_RELAY_QNN_TRANSFORM_H_
-#include <tvm/runtime/c_runtime_api.h>
#include <tvm/relay/transform.h>
+#include <tvm/runtime/c_runtime_api.h>
namespace tvm {
namespace relay {
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
-#include <tvm/runtime/container.h>
-#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
+#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
-#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
#include <string>
*
* \return The created function pass.
*/
-TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
- Function(Function, IRModule, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const tvm::Array<runtime::String>& required);
+TVM_DLL Pass CreateFunctionPass(
+ const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+ int opt_level, const std::string& name, const tvm::Array<runtime::String>& required);
/*! \brief Remove expressions which does not effect the program result.
*
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*!
-* \brief Convert all expressions of TensorType into GradCell,
-* an algebraic data type defined in gradient.rly.
-*
-* This will delay or decrease memory usage. All calls to
-* ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory,
-* rather only instantiate if needed. It also defines + and * operation
-* between GradCell types which can increase performance when using
-* zero-filled or one-filled tensors, which is the case in reverse mode ad.
-*
-* \return the pass
-*/
+ * \brief Convert all expressions of TensorType into GradCell,
+ * an algebraic data type defined in gradient.rly.
+ *
+ * This will delay or decrease memory usage. All calls to
+ * ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory,
+ * rather only instantiate if needed. It also defines + and * operation
+ * between GradCell types which can increase performance when using
+ * zero-filled or one-filled tensors, which is the case in reverse mode ad.
+ *
+ * \return the pass
+ */
TVM_DLL Pass LazyGradientInit();
/*!
* \return A type checked Function with its checked_type field populated.
* \note this function mutates mod and is not thread-safe.
*/
-TVM_DLL Function InferType(const Function& f,
- const IRModule& mod,
- const GlobalVar& var);
+TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalVar& var);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
-TVM_DLL Expr ForwardRewrite(const Expr& expr,
- const std::string& rewrite_map_attr_name,
+TVM_DLL Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_attr_name,
std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
*
* \return The rewritten expression.
*/
-TVM_DLL Expr ForwardRewrite(const Expr& expr,
- const FForwardRewrite& rewrite_func,
+TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_
-#include <tvm/ir/type.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/env_func.h>
#include <tvm/ir/tensor_type.h>
+#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
-#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
-#include <tvm/ir/env_func.h>
#include <tvm/tir/expr.h>
+
#include <string>
#include "base.h"
-
namespace tvm {
namespace relay {
*
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
*/
-typedef int (*TVMBackendPackedCFunc)(TVMValue* args,
- int* type_codes,
- int num_args,
- TVMValue* out_ret_value,
- int* out_ret_tcode);
+typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args,
+ TVMValue* out_ret_value, int* out_ret_tcode);
/*!
* \brief Backend function for modules to get function
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
-TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
- const char* func_name,
- TVMFunctionHandle *out);
+TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out);
/*!
* \brief Backend function to register system-wide library symbol.
*
* certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success
*/
-TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
- int device_id,
- uint64_t nbytes,
- int dtype_code_hint,
- int dtype_bits_hint);
+TVM_DLL void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes,
+ int dtype_code_hint, int dtype_bits_hint);
/*!
* \brief Backend function to free temporal workspace.
*
* \sa TVMBackendAllocWorkspace
*/
-TVM_DLL int TVMBackendFreeWorkspace(int device_type,
- int device_id,
- void* ptr);
+TVM_DLL int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr);
/*!
* \brief Environment for TVM parallel task.
* \param penv The parallel environment backs the execution.
* \param cdata The supporting closure data.
*/
-typedef int (*FTVMParallelLambda)(
- int task_id, TVMParallelGroupEnv* penv, void* cdata);
+typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata);
/*!
* \brief Backend function for running parallel jobs.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
-TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
- void* cdata,
- int num_task);
+TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task);
/*!
* \brief BSP barrrier between parallel threads
*/
TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
-
/*!
* \brief Simple static initialization function.
* Run f once and set handle to be not null.
* \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens
*/
-TVM_DLL int TVMBackendRunOnce(void** handle,
- int (*f)(void*),
- void *cdata,
- int nbytes);
+TVM_DLL int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes);
#ifdef __cplusplus
} // TVM_EXTERN_C
// TVM version
#define TVM_VERSION "0.7.dev1"
-
// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
#ifdef __cplusplus
extern "C" {
#endif
-#include <stdint.h>
#include <stddef.h>
+#include <stdint.h>
/*! \brief type of array index. */
typedef int64_t tvm_index_t;
* this function is threadsafe and can be called by different thread
* \return error info
*/
-TVM_DLL const char *TVMGetLastError(void);
+TVM_DLL const char* TVMGetLastError(void);
/*!
* \brief Load module from file.
* \param file_name The file name to load the module from.
* \note The resulting module do not contain import relation.
* It can be reconstructed by TVMModImport.
*/
-TVM_DLL int TVMModLoadFromFile(const char* file_name,
- const char* format,
- TVMModuleHandle* out);
+TVM_DLL int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out);
/*!
* \brief Add dep to mod's dependency.
* \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMModImport(TVMModuleHandle mod,
- TVMModuleHandle dep);
+TVM_DLL int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep);
/*!
* \brief Get function from the module.
* \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens
*/
-TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
- const char* func_name,
- int query_imports,
- TVMFunctionHandle *out);
+TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
+ TVMFunctionHandle* out);
/*!
* \brief Free the Module
* The front-end need to call free function (e.g. TVMFuncFree)
* to free these handles.
*/
-TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
- TVMValue* arg_values,
- int* type_codes,
- int num_args,
- TVMValue* ret_val,
- int* ret_type_code);
+TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args,
+ TVMValue* ret_val, int* ret_type_code);
/*!
* \brief Set the return value of TVMPackedCFunc.
* \param type_code The type of the value to be returned.
* \param num_ret Number of return values, for now only 1 is supported.
*/
-TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
- TVMValue* value,
- int* type_code,
- int num_ret);
+TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret);
/*!
* \brief Inplace translate callback argument value to return value.
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
* \sa TVMCFuncSetReturn
*/
-typedef int (*TVMPackedCFunc)(
- TVMValue* args,
- int* type_codes,
- int num_args,
- TVMRetValueHandle ret,
- void* resource_handle);
+typedef int (*TVMPackedCFunc)(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret,
+ void* resource_handle);
/*!
* \brief C callback to free the resource handle in C packed function.
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
- void* resource_handle,
- TVMPackedCFuncFinalizer fin,
- TVMFunctionHandle *out);
+TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle,
+ TVMPackedCFuncFinalizer fin, TVMFunctionHandle* out);
/*!
* \brief Register the function to runtime's global table.
* \param f The function to be registered.
* \param override Whether allow override already registered function.
*/
-TVM_DLL int TVMFuncRegisterGlobal(
- const char* name, TVMFunctionHandle f, int override);
+TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override);
/*!
* \brief Get a global function.
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMFuncListGlobalNames(int* out_size,
- const char*** out_array);
+TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
// Array related apis for quick proptyping
/*!
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
- int ndim,
- int dtype_code,
- int dtype_bits,
- int dtype_lanes,
- int device_type,
- int device_id,
- TVMArrayHandle* out);
+TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits,
+ int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out);
/*!
* \brief Free the TVM Array.
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle,
- void* data,
- size_t nbytes);
+TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes);
/*!
* \brief Copy array data to CPU byte array.
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle,
- void* data,
- size_t nbytes);
+TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
- TVMArrayHandle to,
- TVMStreamHandle stream);
+TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from,
- TVMArrayHandle* out);
+TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from,
- DLManagedTensor** out);
+TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out);
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMStreamStreamSynchronize(int device_type,
- int device_id,
- TVMStreamHandle src,
+TVM_DLL int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src,
TVMStreamHandle dst);
/*!
* \param out_data The allocated device pointer.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx,
- size_t nbytes,
- size_t alignment,
- DLDataType type_hint,
- void** out_data);
+TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment,
+ DLDataType type_hint, void** out_data);
/*!
* \brief Free a data space on device.
* \param stream Optional stream object.
* \return 0 when success, -1 when failure happens.
*/
-TVM_DLL int TVMDeviceCopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t num_bytes,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to,
+ size_t to_offset, size_t num_bytes, TVMContext ctx_from,
+ TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream);
#ifdef __cplusplus
// string_view:
// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations
// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros
-#if defined(__cpp_lib_experimental_string_view) && \
- __cpp_lib_experimental_string_view >= 201411
+#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411
#define TVM_USE_CXX14_STRING_VIEW_HASH 1
#else
#define TVM_USE_CXX14_STRING_VIEW_HASH 0
* \brief Destroy the Inplace Array Base object
*/
~InplaceArrayBase() {
- if (!(std::is_standard_layout<ElemType>::value &&
- std::is_trivial<ElemType>::value)) {
+ if (!(std::is_standard_layout<ElemType>::value && std::is_trivial<ElemType>::value)) {
size_t size = Self()->GetSize();
for (size_t i = 0; i < size; ++i) {
ElemType* fp = reinterpret_cast<ElemType*>(AddressOf(i));
* \return Raw pointer to the element.
*/
void* AddressOf(size_t idx) const {
- static_assert(alignof(ArrayType) % alignof(ElemType) == 0 &&
- sizeof(ArrayType) % alignof(ElemType) == 0,
- "The size and alignment of ArrayType should respect "
- "ElemType's alignment.");
+ static_assert(
+ alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0,
+ "The size and alignment of ArrayType should respect "
+ "ElemType's alignment.");
size_t kDataStart = sizeof(ArrayType);
ArrayType* self = Self();
* \param fields The fields of the ADT object.
* \return The constructed ADT object reference.
*/
- ADT(int32_t tag, std::vector<ObjectRef> fields)
- : ADT(tag, fields.begin(), fields.end()){};
+ ADT(int32_t tag, std::vector<ObjectRef> fields) : ADT(tag, fields.begin(), fields.end()){};
/*!
* \brief construct an ADT object reference.
* \param init The initializer list of fields.
* \return The constructed ADT object reference.
*/
- ADT(int32_t tag, std::initializer_list<ObjectRef> init)
- : ADT(tag, init.begin(), init.end()){};
+ ADT(int32_t tag, std::initializer_list<ObjectRef> init) : ADT(tag, init.begin(), init.end()){};
/*!
* \brief Access element at index.
* \param idx The array index
* \return const ObjectRef
*/
- const ObjectRef& operator[](size_t idx) const {
- return operator->()->operator[](idx);
- }
+ const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); }
/*!
* \brief Return the ADT tag.
*
* \return the comparison result
*/
- bool operator==(const std::string& other) const {
- return this->compare(other) == 0;
- }
+ bool operator==(const std::string& other) const { return this->compare(other) == 0; }
/*!
* \brief Compare is not equal to other std::string
// This function falls back to string copy with c++11 compiler and is
// recommended to be compiled with c++14
#if TVM_USE_CXX17_STRING_VIEW_HASH
- return std::hash<std::string_view>()(
- std::string_view(data, size));
+ return std::hash<std::string_view>()(std::string_view(data, size));
#elif TVM_USE_CXX14_STRING_VIEW_HASH
- return std::hash<std::experimental::string_view>()(
- std::experimental::string_view(data, size));
+ return std::hash<std::experimental::string_view>()(std::experimental::string_view(data, size));
#else
return std::hash<std::string>()(std::string(data, size));
#endif
* \return int zero if both char sequences compare equal. negative if this
* appear before other, positive otherwise.
*/
- static int memncmp(const char* lhs, const char* rhs, size_t lhs_count,
- size_t rhs_count);
+ static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);
};
/*! \brief An object representing string moved from std::string. */
return Downcast<String>(*this);
}
-inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count,
- size_t rhs_count) {
+inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
if (lhs == rhs && lhs_count == rhs_count) return 0;
for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
}
}
-template<>
+template <>
struct PackedFuncValueConverter<::tvm::runtime::String> {
static String From(const TVMArgValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
};
/*! \brief Helper to represent nullptr for optional. */
-struct NullOptType {
-};
+struct NullOptType {};
/*!
* \brief Optional container that to represent to a Nullable variant of T.
*
* \endcode
*/
-template<typename T>
+template <typename T>
class Optional : public ObjectRef {
public:
using ContainerType = typename T::ContainerType;
- static_assert(std::is_base_of<ObjectRef, T>::value,
- "Optional is only defined for ObjectRef.");
+ static_assert(std::is_base_of<ObjectRef, T>::value, "Optional is only defined for ObjectRef.");
// default constructors.
Optional() = default;
Optional(const Optional<T>&) = default;
return *this;
}
// normal value handling.
- Optional(T other) // NOLINT(*)
- : ObjectRef(std::move(other)) {
- }
+ Optional(T other) // NOLINT(*)
+ : ObjectRef(std::move(other)) {}
Optional<T>& operator=(T other) {
ObjectRef::operator=(std::move(other));
return *this;
* \return The contained value if the Optional is not null
* otherwise return the default_value.
*/
- T value_or(T default_value) const {
- return data_ != nullptr ? T(data_) : default_value;
- }
+ T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; }
/*! \return Whether the container is not nullptr.*/
- explicit operator bool() const {
- return *this != nullptr;
- }
+ explicit operator bool() const { return *this != nullptr; }
// operator overloadings
- bool operator==(std::nullptr_t) const {
- return data_ == nullptr;
- }
- bool operator!=(std::nullptr_t) const {
- return data_ != nullptr;
- }
+ bool operator==(std::nullptr_t) const { return data_ == nullptr; }
+ bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
auto operator==(const Optional<T>& other) const {
// support case where sub-class returns a symbolic ref type.
using RetType = decltype(value() == other.value());
if (*this != nullptr) return value() == other;
return RetType(false);
}
- auto operator!=(const T& other) const {
- return !(*this == other);
- }
- template<typename U>
+ auto operator!=(const T& other) const { return !(*this == other); }
+ template <typename U>
auto operator==(const U& other) const {
using RetType = decltype(value() == other);
if (*this == nullptr) return RetType(false);
return value() == other;
}
- template<typename U>
+ template <typename U>
auto operator!=(const U& other) const {
using RetType = decltype(value() != other);
if (*this == nullptr) return RetType(true);
static constexpr bool _type_is_nullable = true;
};
-template<typename T>
+template <typename T>
struct PackedFuncValueConverter<Optional<T>> {
static Optional<T> From(const TVMArgValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
} // namespace runtime
// expose the functions to the root namespace.
-using runtime::String;
using runtime::Optional;
+using runtime::String;
constexpr runtime::NullOptType NullOpt{};
} // namespace tvm
* \param size The size of memory
* \return The virtual address
*/
-void * vmalloc(size_t size);
+void* vmalloc(size_t size);
/*!
* \brief Reallocate memory from manager
* \param size The size of memory
* \return The virtual address
*/
-void * vrealloc(void * ptr, size_t size);
+void* vrealloc(void* ptr, size_t size);
/*!
* \brief Free the memory.
* \param ptr The pointer to the memory to deallocate
* \return The virtual address
*/
-void vfree(void * ptr);
+void vfree(void* ptr);
#endif // TVM_RUNTIME_CRT_MEMORY_H_
#ifndef TVM_RUNTIME_DATA_TYPE_H_
#define TVM_RUNTIME_DATA_TYPE_H_
-#include <tvm/runtime/c_runtime_api.h>
#include <dmlc/logging.h>
-#include <type_traits>
+#include <tvm/runtime/c_runtime_api.h>
+
#include <string>
+#include <type_traits>
namespace tvm {
namespace runtime {
* \brief Constructor
* \param dtype The DLDataType
*/
- explicit DataType(DLDataType dtype)
- : data_(dtype) {}
+ explicit DataType(DLDataType dtype) : data_(dtype) {}
/*!
* \brief Constructor
* \param code The type code.
data_.lanes = static_cast<uint16_t>(lanes);
}
/*! \return The type code. */
- int code() const {
- return static_cast<int>(data_.code);
- }
+ int code() const { return static_cast<int>(data_.code); }
/*! \return number of bits in the data. */
- int bits() const {
- return static_cast<int>(data_.bits);
- }
+ int bits() const { return static_cast<int>(data_.bits); }
/*! \return number of bytes to store each scalar. */
- int bytes() const {
- return (bits() + 7) / 8;
- }
+ int bytes() const { return (bits() + 7) / 8; }
/*! \return number of lanes in the data. */
- int lanes() const {
- return static_cast<int>(data_.lanes);
- }
+ int lanes() const { return static_cast<int>(data_.lanes); }
/*! \return whether type is a scalar type. */
- bool is_scalar() const {
- return lanes() == 1;
- }
+ bool is_scalar() const { return lanes() == 1; }
/*! \return whether type is a scalar type. */
- bool is_bool() const {
- return code() == DataType::kUInt && bits() == 1;
- }
+ bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
- bool is_float() const {
- return code() == DataType::kFloat;
- }
+ bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float16 type. */
- bool is_float16() const {
- return is_float() && bits() == 16;
- }
+ bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is an int type. */
- bool is_int() const {
- return code() == DataType::kInt;
- }
+ bool is_int() const { return code() == DataType::kInt; }
/*! \return whether type is an uint type. */
- bool is_uint() const {
- return code() == DataType::kUInt;
- }
+ bool is_uint() const { return code() == DataType::kUInt; }
/*! \return whether type is a handle type. */
- bool is_handle() const {
- return code() == DataType::kHandle && !is_void();
- }
+ bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
/*! \return whether type is a vector type. */
- bool is_vector() const {
- return lanes() > 1;
- }
+ bool is_vector() const { return lanes() > 1; }
/*! \return whether type is a bool vector type. */
- bool is_vector_bool() const {
- return is_vector() && bits() == 1;
- }
+ bool is_vector_bool() const { return is_vector() && bits() == 1; }
/*! \return whether type is a Void type. */
- bool is_void() const {
- return code() == DataType::kHandle && bits() == 0 && lanes() == 0;
- }
+ bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
/*!
* \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes.
* \return the result type.
*/
- DataType with_lanes(int lanes) const {
- return DataType(data_.code, data_.bits, lanes);
- }
+ DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); }
/*!
* \brief Create a new data type by change bits to a specified value.
* \param bits The target number of bits.
* \return the result type.
*/
- DataType with_bits(int bits) const {
- return DataType(data_.code, bits, data_.lanes);
- }
+ DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); }
/*!
* \brief Get the scalar version of the type.
* \return the result type.
*/
- DataType element_of() const {
- return with_lanes(1);
- }
+ DataType element_of() const { return with_lanes(1); }
/*!
* \brief Equal comparator.
* \param other The data type to compre against.
* \return The comparison resilt.
*/
bool operator==(const DataType& other) const {
- return
- data_.code == other.data_.code &&
- data_.bits == other.data_.bits &&
- data_.lanes == other.data_.lanes;
+ return data_.code == other.data_.code && data_.bits == other.data_.bits &&
+ data_.lanes == other.data_.lanes;
}
/*!
* \brief NotEqual comparator.
* \param other The data type to compre against.
* \return The comparison resilt.
*/
- bool operator!=(const DataType& other) const {
- return !operator==(other);
- }
+ bool operator!=(const DataType& other) const { return !operator==(other); }
/*!
* \brief Converter to DLDataType
* \return the result.
*/
- operator DLDataType () const {
- return data_;
- }
+ operator DLDataType() const { return data_; }
/*!
* \brief Construct an int type.
* \param lanes The number of lanes.
* \return The constructed data type.
*/
- static DataType Int(int bits, int lanes = 1) {
- return DataType(kDLInt, bits, lanes);
- }
+ static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); }
/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType UInt(int bits, int lanes = 1) {
- return DataType(kDLUInt, bits, lanes);
- }
+ static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType Float(int bits, int lanes = 1) {
- return DataType(kDLFloat, bits, lanes);
- }
+ static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType Bool(int lanes = 1) {
- return DataType::UInt(1, lanes);
- }
+ static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
/*!
* \brief Construct a handle type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType Handle(int bits = 64, int lanes = 1) {
- return DataType(kHandle, bits, lanes);
- }
+ static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); }
/*!
* \brief Construct a Void type.
* \return The constructed data type.
*/
- static DataType Void() {
- return DataType(kHandle, 0, 0);
- }
+ static DataType Void() { return DataType(kHandle, 0, 0); }
/*!
* \brief Get the corresponding type of TVMShapeIndex.
* \return The type of TVM shape index.
inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
- if (dtype == DataType::Bool() ||
- dtype == DataType::Int(4) ||
- dtype == DataType::UInt(4) ||
+ if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
dtype == DataType::Int(1)) {
return 1;
}
- CHECK_EQ(data_bits % 8, 0U)
- << "Need to load/store by multiple of bytes";
+ CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
return data_bits / 8;
}
// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
- case kDLInt: return "int";
- case kDLUInt: return "uint";
- case kDLFloat: return "float";
- case kTVMStr: return "str";
- case kTVMBytes: return "bytes";
- case kTVMOpaqueHandle: return "handle";
- case kTVMNullptr: return "NULL";
- case kTVMDLTensorHandle: return "ArrayHandle";
- case kTVMDataType: return "DLDataType";
- case kTVMContext: return "TVMContext";
- case kTVMPackedFuncHandle: return "FunctionHandle";
- case kTVMModuleHandle: return "ModuleHandle";
- case kTVMNDArrayHandle: return "NDArrayContainer";
- case kTVMObjectHandle: return "Object";
- case kTVMObjectRValueRefArg: return "ObjectRValueRefArg";
- default: LOG(FATAL) << "unknown type_code="
- << static_cast<int>(type_code); return "";
+ case kDLInt:
+ return "int";
+ case kDLUInt:
+ return "uint";
+ case kDLFloat:
+ return "float";
+ case kTVMStr:
+ return "str";
+ case kTVMBytes:
+ return "bytes";
+ case kTVMOpaqueHandle:
+ return "handle";
+ case kTVMNullptr:
+ return "NULL";
+ case kTVMDLTensorHandle:
+ return "ArrayHandle";
+ case kTVMDataType:
+ return "DLDataType";
+ case kTVMContext:
+ return "TVMContext";
+ case kTVMPackedFuncHandle:
+ return "FunctionHandle";
+ case kTVMModuleHandle:
+ return "ModuleHandle";
+ case kTVMNDArrayHandle:
+ return "NDArrayContainer";
+ case kTVMObjectHandle:
+ return "Object";
+ case kTVMObjectRValueRefArg:
+ return "ObjectRValueRefArg";
+ default:
+ LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
+ return "";
}
}
inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
- os << "bool"; return os;
+ os << "bool";
+ return os;
}
if (DataType(t).is_void()) {
return os << "void";
return os;
}
-inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
+inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}
t = DataType::Void();
return t;
}
- t.bits = 32; t.lanes = 1;
+ t.bits = 32;
+ t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
- t.code = kDLInt; scan = s.c_str() + 3;
+ t.code = kDLInt;
+ scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
- t.code = kDLUInt; scan = s.c_str() + 4;
+ t.code = kDLUInt;
+ scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
- t.code = kDLFloat; scan = s.c_str() + 5;
+ t.code = kDLFloat;
+ scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
+
#include <string>
namespace tvm {
* as OpenGL, as nbytes & alignment are sufficient for most backends.
* \return The allocated device pointer.
*/
- virtual void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ virtual void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) = 0;
/*!
* \brief Free a data space on device.
* can be useful for cross device endian converison.
* \param stream Optional stream object.
*/
- virtual void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t num_bytes,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
- TVMStreamHandle stream) = 0;
- /*!
+ virtual void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to,
+ DLDataType type_hint, TVMStreamHandle stream) = 0;
+ /*!
* \brief Create a new stream of execution.
*
* \param ctx The context of allocation.
* \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize.
*/
- virtual void SyncStreamFromTo(TVMContext ctx,
- TVMStreamHandle event_src,
+ virtual void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src,
TVMStreamHandle event_dst);
- /*!
+ /*!
* \brief Allocate temporal workspace for backend execution.
*
* \note We have the following assumption about backend temporal
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends.
*/
- virtual void* AllocWorkspace(TVMContext ctx,
- size_t nbytes,
- DLDataType type_hint = {});
+ virtual void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
*/
inline const char* DeviceName(int type) {
switch (type) {
- case kDLCPU: return "cpu";
- case kDLGPU: return "gpu";
- case kDLCPUPinned: return "cpu_pinned";
- case kDLOpenCL: return "opencl";
- case kDLSDAccel: return "sdaccel";
- case kDLAOCL: return "aocl";
- case kDLVulkan: return "vulkan";
- case kDLMetal: return "metal";
- case kDLVPI: return "vpi";
- case kDLROCM: return "rocm";
- case kOpenGL: return "opengl";
- case kDLExtDev: return "ext_dev";
- case kDLWebGPU: return "webgpu";
- case kDLMicroDev: return "micro_dev";
- case kDLHexagon: return "hexagon";
- default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
+ case kDLCPU:
+ return "cpu";
+ case kDLGPU:
+ return "gpu";
+ case kDLCPUPinned:
+ return "cpu_pinned";
+ case kDLOpenCL:
+ return "opencl";
+ case kDLSDAccel:
+ return "sdaccel";
+ case kDLAOCL:
+ return "aocl";
+ case kDLVulkan:
+ return "vulkan";
+ case kDLMetal:
+ return "metal";
+ case kDLVPI:
+ return "vpi";
+ case kDLROCM:
+ return "rocm";
+ case kOpenGL:
+ return "opengl";
+ case kDLExtDev:
+ return "ext_dev";
+ case kDLWebGPU:
+ return "webgpu";
+ case kDLMicroDev:
+ return "micro_dev";
+ case kDLHexagon:
+ return "hexagon";
+ default:
+ LOG(FATAL) << "unknown type =" << type;
+ return "Unknown";
}
}
#define TVM_RUNTIME_MEMORY_H_
#include <tvm/runtime/object.h>
+
#include <cstdlib>
-#include <utility>
#include <type_traits>
+#include <utility>
namespace tvm {
namespace runtime {
* \tparam T the node type.
* \return The ObjectPtr to the allocated object.
*/
-template<typename T, typename... Args>
+template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args);
// Detail implementations after this
*
* \tparam Derived The derived class.
*/
-template<typename Derived>
+template <typename Derived>
class ObjAllocatorBase {
public:
/*!
* \tparam Args The constructor signature.
* \param args The arguments.
*/
- template<typename T, typename... Args>
+ template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
using Handler = typename Derived::template Handler<T>;
- static_assert(std::is_base_of<Object, T>::value,
- "make can only be used to create Object");
- T* ptr = Handler::New(static_cast<Derived*>(this),
- std::forward<Args>(args)...);
+ static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
+ T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
ptr->type_index_ = T::RuntimeTypeIndex();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
* \param num_elems The number of array elements.
* \param args The arguments.
*/
- template<typename ArrayType, typename ElemType, typename... Args>
+ template <typename ArrayType, typename ElemType, typename... Args>
inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
using Handler = typename Derived::template ArrayHandler<ArrayType, ElemType>;
static_assert(std::is_base_of<Object, ArrayType>::value,
"make_inplace_array can only be used to create Object");
- ArrayType* ptr = Handler::New(static_cast<Derived*>(this),
- num_elems,
- std::forward<Args>(args)...);
+ ArrayType* ptr =
+ Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...);
ptr->type_index_ = ArrayType::RuntimeTypeIndex();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<ArrayType>(ptr);
};
// Simple allocator that uses new/delete.
-class SimpleObjAllocator :
- public ObjAllocatorBase<SimpleObjAllocator> {
+class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
public:
- template<typename T>
+ template <typename T>
class Handler {
public:
using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
- template<typename... Args>
+ template <typename... Args>
static T* New(SimpleObjAllocator*, Args&&... args) {
// NOTE: the first argument is not needed for SimpleObjAllocator
// It is reserved for special allocators that needs to recycle
return reinterpret_cast<T*>(data);
}
- static Object::FDeleter Deleter() {
- return Deleter_;
- }
+ static Object::FDeleter Deleter() { return Deleter_; }
private:
static void Deleter_(Object* objptr) {
};
// Array handler that uses new/delete.
- template<typename ArrayType, typename ElemType>
+ template <typename ArrayType, typename ElemType>
class ArrayHandler {
public:
using StorageType = typename std::aligned_storage<sizeof(ArrayType), alignof(ArrayType)>::type;
// for now only support elements that aligns with array header.
static_assert(alignof(ArrayType) % alignof(ElemType) == 0 &&
- sizeof(ArrayType) % alignof(ElemType) == 0,
+ sizeof(ArrayType) % alignof(ElemType) == 0,
"element alignment constraint");
- template<typename... Args>
+ template <typename... Args>
static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) {
// NOTE: the first argument is not needed for ArrayObjAllocator
// It is reserved for special allocators that needs to recycle
return reinterpret_cast<ArrayType*>(data);
}
- static Object::FDeleter Deleter() {
- return Deleter_;
- }
+ static Object::FDeleter Deleter() { return Deleter_; }
private:
static void Deleter_(Object* objptr) {
// call a virtual destructor(which may not be available and is not required).
tptr->ArrayType::~ArrayType();
StorageType* p = reinterpret_cast<StorageType*>(tptr);
- delete []p;
+ delete[] p;
}
};
};
-template<typename T, typename... Args>
+template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
}
-template<typename ArrayType, typename ElemType, typename... Args>
+template <typename ArrayType, typename ElemType, typename... Args>
inline ObjectPtr<ArrayType> make_inplace_array_object(size_t num_elems, Args&&... args) {
- return SimpleObjAllocator().make_inplace_array<ArrayType, ElemType>(
- num_elems, std::forward<Args>(args)...);
+ return SimpleObjAllocator().make_inplace_array<ArrayType, ElemType>(num_elems,
+ std::forward<Args>(args)...);
}
} // namespace runtime
#define TVM_RUNTIME_MODULE_H_
#include <dmlc/io.h>
-
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
+#include <tvm/runtime/object.h>
#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+#include <vector>
namespace tvm {
namespace runtime {
public:
Module() {}
// constructor from container.
- explicit Module(ObjectPtr<Object> n)
- : ObjectRef(n) {}
+ explicit Module(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Get packed function from current module by name.
*
* \note This function won't load the import relationship.
* Re-create import relationship by calling Import.
*/
- TVM_DLL static Module LoadFromFile(const std::string& file_name,
- const std::string& format = "");
+ TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = "");
// refer to the corresponding container.
using ContainerType = ModuleNode;
friend class ModuleNode;
* If the function need resource from the module(e.g. late linking),
* it should capture sptr_to_self.
*/
- virtual PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) = 0;
+ virtual PackedFunc GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
- virtual void SaveToFile(const std::string& file_name,
- const std::string& format);
+ virtual void SaveToFile(const std::string& file_name, const std::string& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
*/
const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The module it imports from */
- const std::vector<Module>& imports() const {
- return imports_;
- }
+ const std::vector<Module>& imports() const { return imports_; }
// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
private:
/*! \brief Cache used by GetImport */
- std::unordered_map<std::string,
- std::shared_ptr<PackedFunc> > import_cache_;
+ std::unordered_map<std::string, std::shared_ptr<PackedFunc> > import_cache_;
};
/*!
// implementations of inline functions.
-inline void Module::Import(Module other) {
- return (*this)->Import(other);
-}
+inline void Module::Import(Module other) { return (*this)->Import(other); }
-inline ModuleNode* Module::operator->() {
- return static_cast<ModuleNode*>(get_mutable());
-}
+inline ModuleNode* Module::operator->() { return static_cast<ModuleNode*>(get_mutable()); }
inline const ModuleNode* Module::operator->() const {
return static_cast<const ModuleNode*>(get());
} // namespace tvm
#include <tvm/runtime/packed_func.h> // NOLINT(*)
-#endif // TVM_RUNTIME_MODULE_H_
+#endif // TVM_RUNTIME_MODULE_H_
#include <tvm/runtime/serializer.h>
#include <atomic>
-#include <vector>
#include <utility>
+#include <vector>
namespace tvm {
namespace runtime {
* \brief constructor.
* \param data ObjectPtr to the data container.
*/
- explicit NDArray(ObjectPtr<Object> data)
- : ObjectRef(data) {}
+ explicit NDArray(ObjectPtr<Object> data) : ObjectRef(data) {}
/*! \brief reset the content of NDArray to be nullptr */
inline void reset();
inline void CopyFrom(const DLTensor* other);
inline void CopyFrom(const NDArray& other);
/*!
- * \brief Copy data content from a byte buffer.
- * \param data The source bytes to be copied from.
- * \param nbytes The size of the buffer in bytes
- * Must be equal to the size of the NDArray.
- * \note The copy may happen asynchronously if it involves a GPU context.
- * TVMSynchronize is necessary.
- */
+ * \brief Copy data content from a byte buffer.
+ * \param data The source bytes to be copied from.
+ * \param nbytes The size of the buffer in bytes
+ * Must be equal to the size of the NDArray.
+ * \note The copy may happen asynchronously if it involves a GPU context.
+ * TVMSynchronize is necessary.
+ */
TVM_DLL void CopyFromBytes(const void* data, size_t nbytes);
/*!
* \brief Copy data content into another array.
* \param dtype The data type of the new array.
* \note The memory size of new array must be smaller than the current one.
*/
- TVM_DLL NDArray CreateView(
- std::vector<int64_t> shape, DLDataType dtype);
+ TVM_DLL NDArray CreateView(std::vector<int64_t> shape, DLDataType dtype);
/*!
* \brief Create a reference view of NDArray that
* represents as DLManagedTensor.
* \param ctx The context of the Array.
* \return The created Array
*/
- TVM_DLL static NDArray Empty(std::vector<int64_t> shape,
- DLDataType dtype,
- DLContext ctx);
+ TVM_DLL static NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
/*!
* \brief Create a NDArray backed by a dlpack tensor.
*
* \param to The target array.
* \param stream The stream used in copy.
*/
- TVM_DLL static void CopyFromTo(
- const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
+ TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to,
+ TVMStreamHandle stream = nullptr);
TVM_DLL std::vector<int64_t> Shape() const;
// internal namespace
* \brief Object container class that backs NDArray.
* \note do not use this function directly, use NDArray.
*/
-class NDArray::Container :
- public Object,
- public NDArray::ContainerBase {
+class NDArray::Container : public Object, public NDArray::ContainerBase {
public:
/*! \brief default constructor */
Container() {
dl_tensor.byte_offset = 0;
}
- Container(void* data,
- std::vector<int64_t> shape,
- DLDataType dtype,
- DLContext ctx) {
+ Container(void* data, std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
// Initialize the type index.
type_index_ = Container::RuntimeTypeIndex();
dl_tensor.data = data;
* \brief Set the deleter field.
* \param deleter The deleter.
*/
- void SetDeleter(FDeleter deleter) {
- deleter_ = deleter;
- }
+ void SetDeleter(FDeleter deleter) { deleter_ = deleter; }
// Expose DecRef and IncRef as public function
// NOTE: they are only for developer purposes only.
inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
- NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
- dptr->dtype, ctx);
+ NDArray ret =
+ Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, ctx);
this->CopyTo(ret);
return ret;
}
-inline int NDArray::use_count() const {
- return data_.use_count();
-}
+inline int NDArray::use_count() const { return data_.use_count(); }
-inline const DLTensor* NDArray::operator->() const {
- return &(get_mutable()->dl_tensor);
-}
+inline const DLTensor* NDArray::operator->() const { return &(get_mutable()->dl_tensor); }
inline NDArray::Container* NDArray::get_mutable() const {
return static_cast<NDArray::Container*>(data_.get());
}
inline ObjectPtr<Object> NDArray::FFIDataFromHandle(TVMArrayHandle handle) {
- return GetObjectPtr<Object>(static_cast<NDArray::Container*>(
- reinterpret_cast<NDArray::ContainerBase*>(handle)));
+ return GetObjectPtr<Object>(
+ static_cast<NDArray::Container*>(reinterpret_cast<NDArray::ContainerBase*>(handle)));
}
inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) {
// NOTE: it is necessary to cast to container then to base
// so that the FFI handle uses the ContainerBase address.
- return reinterpret_cast<TVMArrayHandle>(
- static_cast<NDArray::ContainerBase*>(
- static_cast<NDArray::Container*>(
- const_cast<Object*>(nd.get()))));
+ return reinterpret_cast<TVMArrayHandle>(static_cast<NDArray::ContainerBase*>(
+ static_cast<NDArray::Container*>(const_cast<Object*>(nd.get()))));
}
inline void NDArray::FFIDecRef(TVMArrayHandle handle) {
- static_cast<NDArray::Container*>(
- reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
+ static_cast<NDArray::Container*>(reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
}
inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) {
- return static_cast<NDArray::Container*>(
- reinterpret_cast<NDArray::ContainerBase*>(handle));
+ return static_cast<NDArray::Container*>(reinterpret_cast<NDArray::ContainerBase*>(handle));
}
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
-inline bool SaveDLTensor(dmlc::Stream* strm,
- const DLTensor* tensor) {
+inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
- if (DMLC_IO_NO_ENDIAN_SWAP &&
- tensor->ctx.device_type == kDLCPU &&
- tensor->strides == nullptr &&
+ if (DMLC_IO_NO_ENDIAN_SWAP && tensor->ctx.device_type == kDLCPU && tensor->strides == nullptr &&
tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
std::vector<uint8_t> bytes(data_byte_size);
- CHECK_EQ(TVMArrayCopyToBytes(
- const_cast<DLTensor*>(tensor), dmlc::BeginPtr(bytes), data_byte_size), 0)
+ CHECK_EQ(
+ TVMArrayCopyToBytes(const_cast<DLTensor*>(tensor), dmlc::BeginPtr(bytes), data_byte_size),
+ 0)
<< TVMGetLastError();
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
return true;
}
-inline void NDArray::Save(dmlc::Stream* strm) const {
- SaveDLTensor(strm, operator->());
-}
+inline void NDArray::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); }
inline bool NDArray::Load(dmlc::Stream* strm) {
uint64_t header, reserved;
- CHECK(strm->Read(&header))
- << "Invalid DLTensor file format";
- CHECK(strm->Read(&reserved))
- << "Invalid DLTensor file format";
- CHECK(header == kTVMNDArrayMagic)
- << "Invalid DLTensor file format";
+ CHECK(strm->Read(&header)) << "Invalid DLTensor file format";
+ CHECK(strm->Read(&reserved)) << "Invalid DLTensor file format";
+ CHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format";
DLContext ctx;
int ndim;
DLDataType dtype;
- CHECK(strm->Read(&ctx))
- << "Invalid DLTensor file format";
- CHECK(strm->Read(&ndim))
- << "Invalid DLTensor file format";
- CHECK(strm->Read(&dtype))
- << "Invalid DLTensor file format";
- CHECK_EQ(ctx.device_type, kDLCPU)
- << "Invalid DLTensor context: can only save as CPU tensor";
+ CHECK(strm->Read(&ctx)) << "Invalid DLTensor file format";
+ CHECK(strm->Read(&ndim)) << "Invalid DLTensor file format";
+ CHECK(strm->Read(&dtype)) << "Invalid DLTensor file format";
+ CHECK_EQ(ctx.device_type, kDLCPU) << "Invalid DLTensor context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
- CHECK(strm->ReadArray(&shape[0], ndim))
- << "Invalid DLTensor file format";
+ CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
- CHECK(strm->Read(&data_byte_size))
- << "Invalid DLTensor file format";
- CHECK(data_byte_size == num_elems * elem_bytes)
- << "Invalid DLTensor file format";
- CHECK(strm->Read(ret->data, data_byte_size))
- << "Invalid DLTensor file format";
+ CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format";
+ CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format";
+ CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <type_traits>
+
#include <string>
+#include <type_traits>
#include <utility>
/*!
* Recommendation: set to estimate number of children needed.
* - _type_child_slots_can_overflow:
* Whether we can add additional child classes even if the number of child classes
- * exceeds the _type_child_slots. A fallback mechanism to check global type table will be used.
- * Recommendation: set to false for optimal runtime speed if we know exact number of children.
+ * exceeds the _type_child_slots. A fallback mechanism to check global type table will be
+ * used. Recommendation: set to false for optimal runtime speed if we know exact number of children.
*
* Two macros are used to declare helper functions in the object:
* - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed.
*/
typedef void (*FDeleter)(Object* self);
/*! \return The internal runtime type index of the object. */
- uint32_t type_index() const {
- return type_index_;
- }
+ uint32_t type_index() const { return type_index_; }
/*!
* \return the type key of the object.
* \note this operation is expensive, can be used for error reporting.
*/
- std::string GetTypeKey() const {
- return TypeIndex2Key(type_index_);
- }
+ std::string GetTypeKey() const { return TypeIndex2Key(type_index_); }
/*!
* \return A hash value of the return of GetTypeKey.
*/
- size_t GetTypeKeyHash() const {
- return TypeIndex2KeyHash(type_index_);
- }
+ size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); }
/*!
* Check if the object is an instance of TargetType.
* \tparam TargetType The target type to be checked.
* \return Whether the target type is true.
*/
- template<typename TargetType>
+ template <typename TargetType>
inline bool IsInstance() const;
/*!
static constexpr const char* _type_key = "runtime.Object";
- static uint32_t _GetOrAllocRuntimeTypeIndex() {
- return TypeIndex::kRoot;
- }
- static uint32_t RuntimeTypeIndex() {
- return TypeIndex::kRoot;
- }
+ static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; }
+ static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; }
// Default object type properties for sub-classes
static constexpr bool _type_final = false;
// The type index of Object is TypeIndex::kRoot
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
-
// Default constructor and copy constructor
Object() {}
// Override the copy and assign constructors to do nothing.
}
Object(Object&& other) { // NOLINT(*)
}
- Object& operator=(const Object& other) { //NOLINT(*)
+ Object& operator=(const Object& other) { // NOLINT(*)
return *this;
}
- Object& operator=(Object&& other) { //NOLINT(*)
+ Object& operator=(Object&& other) { // NOLINT(*)
return *this;
}
FDeleter deleter_ = nullptr;
// Invariant checks.
static_assert(sizeof(int32_t) == sizeof(RefCounterType) &&
- alignof(int32_t) == sizeof(RefCounterType),
+ alignof(int32_t) == sizeof(RefCounterType),
"RefCounter ABI check.");
/*!
* \param type_child_slots_can_overflow Whether to allow child to overflow the slots.
* \return The allocated type index.
*/
- TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
- const std::string& key,
- uint32_t static_tindex,
- uint32_t parent_tindex,
- uint32_t type_child_slots,
- bool type_child_slots_can_overflow);
+ TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex,
+ uint32_t parent_tindex,
+ uint32_t type_child_slots,
+ bool type_child_slots_can_overflow);
// reference counter related operations
/*! \brief developer function, increases reference counter. */
*/
TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const;
// friend classes
- template<typename>
+ template <typename>
friend class ObjAllocatorBase;
- template<typename>
+ template <typename>
friend class ObjectPtr;
friend class TVMRetValue;
friend class ObjectInternal;
other.data_ = nullptr;
}
/*! \brief destructor */
- ~ObjectPtr() {
- this->reset();
- }
+ ~ObjectPtr() { this->reset(); }
/*!
* \brief Swap this array with another Object
* \param other The other Object
/*!
* \return Get the content of the pointer
*/
- T* get() const {
- return static_cast<T*>(data_);
- }
+ T* get() const { return static_cast<T*>(data_); }
/*!
* \return The pointer
*/
- T* operator->() const {
- return get();
- }
+ T* operator->() const { return get(); }
/*!
* \return The reference
*/
}
}
/*! \return The use count of the ptr, for debug purposes */
- int use_count() const {
- return data_ != nullptr ? data_->use_count() : 0;
- }
+ int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
/*! \return whether the reference is unique */
- bool unique() const {
- return data_ != nullptr && data_->use_count() == 1;
- }
+ bool unique() const { return data_ != nullptr && data_->use_count() == 1; }
/*! \return Whether two ObjectPtr do not equal each other */
- bool operator==(const ObjectPtr<T>& other) const {
- return data_ == other.data_;
- }
+ bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; }
/*! \return Whether two ObjectPtr equals each other */
- bool operator!=(const ObjectPtr<T>& other) const {
- return data_ != other.data_;
- }
+ bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; }
/*! \return Whether the pointer is nullptr */
- bool operator==(std::nullptr_t null) const {
- return data_ == nullptr;
- }
+ bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
/*! \return Whether the pointer is not nullptr */
- bool operator!=(std::nullptr_t null) const {
- return data_ != nullptr;
- }
+ bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
private:
/*! \brief internal pointer field */
friend class Object;
friend class ObjectRef;
friend struct ObjectHash;
- template<typename>
+ template <typename>
friend class ObjectPtr;
- template<typename>
+ template <typename>
friend class ObjAllocatorBase;
friend class TVMPODValue_;
friend class TVMArgsSetter;
* \param other Another object ref.
* \return the compare result.
*/
- bool same_as(const ObjectRef& other) const {
- return data_ == other.data_;
- }
+ bool same_as(const ObjectRef& other) const { return data_ == other.data_; }
/*!
* \brief Comparator
* \param other Another object ref.
* \return the compare result.
*/
- bool operator==(const ObjectRef& other) const {
- return data_ == other.data_;
- }
+ bool operator==(const ObjectRef& other) const { return data_ == other.data_; }
/*!
* \brief Comparator
* \param other Another object ref.
* \return the compare result.
*/
- bool operator!=(const ObjectRef& other) const {
- return data_ != other.data_;
- }
+ bool operator!=(const ObjectRef& other) const { return data_ != other.data_; }
/*!
* \brief Comparator
* \param other Another object ref by address.
* \return the compare result.
*/
- bool operator<(const ObjectRef& other) const {
- return data_.get() < other.data_.get();
- }
+ bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); }
/*!
* \return whether the object is defined(not null).
*/
- bool defined() const {
- return data_ != nullptr;
- }
+ bool defined() const { return data_ != nullptr; }
/*! \return the internal object pointer */
- const Object* get() const {
- return data_.get();
- }
+ const Object* get() const { return data_.get(); }
/*! \return the internal object pointer */
- const Object* operator->() const {
- return get();
- }
+ const Object* operator->() const { return get(); }
/*! \return whether the reference is unique */
- bool unique() const {
- return data_.unique();
- }
+ bool unique() const { return data_.unique(); }
/*! \return The use count of the ptr, for debug purposes */
- int use_count() const {
- return data_.use_count();
- }
+ int use_count() const { return data_.use_count(); }
/*!
* \brief Try to downcast the internal Object to a
* raw pointer of a corresponding type.
/*! \brief Internal pointer that backs the reference. */
ObjectPtr<Object> data_;
/*! \return return a mutable internal ptr, can be used by sub-classes. */
- Object* get_mutable() const {
- return data_.get();
- }
+ Object* get_mutable() const { return data_.get(); }
/*!
* \brief Internal helper function downcast a ref without check.
* \note Only used for internal dev purposes.
* \tparam T The target reference type.
* \return The casted result.
*/
- template<typename T>
+ template <typename T>
static T DowncastNoCheck(ObjectRef ref) {
return T(std::move(ref.data_));
}
* after we successfully moved the field.
* \param ref The reference data.
*/
- static void FFIClearAfterMove(ObjectRef* ref) {
- ref->data_.data_ = nullptr;
- }
+ static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; }
/*!
* \brief Internal helper function get data_ as ObjectPtr of ObjectType.
* \note only used for internal dev purpose.
* \tparam ObjectType The corresponding object type.
* \return the corresponding type.
*/
- template<typename ObjectType>
+ template <typename ObjectType>
static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) {
return ObjectPtr<ObjectType>(ref.data_.data_);
}
/*! \brief ObjectRef hash functor */
struct ObjectHash {
- size_t operator()(const ObjectRef& a) const {
- return operator()(a.data_);
- }
+ size_t operator()(const ObjectRef& a) const { return operator()(a.data_); }
- template<typename T>
+ template <typename T>
size_t operator()(const ObjectPtr<T>& a) const {
return std::hash<Object*>()(a.get());
}
};
-
/*! \brief ObjectRef equal functor */
struct ObjectEqual {
- bool operator()(const ObjectRef& a, const ObjectRef& b) const {
- return a.same_as(b);
- }
+ bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); }
- template<typename T>
+ template <typename T>
size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const {
return a == b;
}
};
-
/*!
* \brief helper macro to declare a base object type that can be inheritated.
* \param TypeName The name of the current type.
* \param ParentType The name of the ParentType
*/
-#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
- static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
- static uint32_t RuntimeTypeIndex() { \
- static_assert(TypeName::_type_child_slots == 0 || \
- ParentType::_type_child_slots == 0 || \
- TypeName::_type_child_slots < ParentType::_type_child_slots, \
- "Need to set _type_child_slots when parent specifies it."); \
- if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
- return TypeName::_type_index; \
- } \
- return _GetOrAllocRuntimeTypeIndex(); \
- } \
- static uint32_t _GetOrAllocRuntimeTypeIndex() { \
- static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \
- TypeName::_type_key, \
- TypeName::_type_index, \
- ParentType::_GetOrAllocRuntimeTypeIndex(), \
- TypeName::_type_child_slots, \
- TypeName::_type_child_slots_can_overflow); \
- return tidx; \
- } \
-
+#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
+ static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
+ static uint32_t RuntimeTypeIndex() { \
+ static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
+ TypeName::_type_child_slots < ParentType::_type_child_slots, \
+ "Need to set _type_child_slots when parent specifies it."); \
+ if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
+ return TypeName::_type_index; \
+ } \
+ return _GetOrAllocRuntimeTypeIndex(); \
+ } \
+ static uint32_t _GetOrAllocRuntimeTypeIndex() { \
+ static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \
+ TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
+ TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \
+ return tidx; \
+ }
/*!
* \brief helper macro to declare type information in a final class.
- * \param TypeName The name of the current type.
- * \param ParentType The name of the ParentType
- */
-#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
- static const constexpr bool _type_final = true; \
- static const constexpr int _type_child_slots = 0; \
- TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
-
+ * \param TypeName The name of the current type.
+ * \param ParentType The name of the ParentType
+ */
+#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
+ static const constexpr bool _type_final = true; \
+ static const constexpr int _type_child_slots = 0; \
+ TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
-#define TVM_OBJECT_REG_VAR_DEF \
- static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
+#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
/*!
* \brief Helper macro to register the object type to runtime.
*
* Use this macro in the cc file for each terminal class.
*/
-#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
- TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
- TypeName::_GetOrAllocRuntimeTypeIndex()
-
+#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
+ TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex()
/*
* \brief Define the default copy/move constructor and assign opeator
* \param TypeName The class typename.
*/
-#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
- TypeName(const TypeName& other) = default; \
- TypeName(TypeName&& other) = default; \
- TypeName& operator=(const TypeName& other) = default; \
- TypeName& operator=(TypeName&& other) = default; \
+#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName& other) = default; \
+ TypeName(TypeName&& other) = default; \
+ TypeName& operator=(const TypeName& other) = default; \
+ TypeName& operator=(TypeName&& other) = default;
/*
* \brief Define object reference methods.
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
*/
-#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
- TypeName() = default; \
- explicit TypeName( \
- ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
- : ParentType(n) {} \
- TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
- const ObjectName* operator->() const { \
- return static_cast<const ObjectName*>(data_.get()); \
- } \
+#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+ TypeName() = default; \
+ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
+ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
+ const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
using ContainerType = ObjectName;
/*
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
*/
-#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
- explicit TypeName( \
- ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
- : ParentType(n) {} \
- TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
- const ObjectName* operator->() const { \
- return static_cast<const ObjectName*>(data_.get()); \
- } \
- static constexpr bool _type_is_nullable = false; \
+#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
+ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
+ const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
+ static constexpr bool _type_is_nullable = false; \
using ContainerType = ObjectName;
/*
* \note We recommend making objects immutable when possible.
* This macro is only reserved for objects that stores runtime states.
*/
-#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
- TypeName() = default; \
- TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
- explicit TypeName( \
- ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
- : ParentType(n) {} \
- ObjectName* operator->() const { \
- return static_cast<ObjectName*>(data_.get()); \
- } \
+#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+ TypeName() = default; \
+ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
+ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
+ ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
using ContainerType = ObjectName;
/*!
*
* \endcode
*/
-#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
- ObjectName* CopyOnWrite() { \
- CHECK(data_ != nullptr); \
- if (!data_.unique()) { \
- auto n = make_object<ObjectName>(*(operator->())); \
- ObjectPtr<Object>(std::move(n)).swap(data_); \
- } \
- return static_cast<ObjectName*>(data_.get()); \
- }
+#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
+ ObjectName* CopyOnWrite() { \
+ CHECK(data_ != nullptr); \
+ if (!data_.unique()) { \
+ auto n = make_object<ObjectName>(*(operator->())); \
+ ObjectPtr<Object>(std::move(n)).swap(data_); \
+ } \
+ return static_cast<ObjectName*>(data_.get()); \
+ }
// Implementations details below
// Object reference counting.
#if TVM_OBJECT_ATOMIC_REF_COUNTER
-inline void Object::IncRef() {
- ref_counter_.fetch_add(1, std::memory_order_relaxed);
-}
+inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }
inline void Object::DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
}
}
-inline int Object::use_count() const {
- return ref_counter_.load(std::memory_order_relaxed);
-}
+inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); }
#else
-inline void Object::IncRef() {
- ++ref_counter_;
-}
+inline void Object::IncRef() { ++ref_counter_; }
inline void Object::DecRef() {
if (--ref_counter_ == 0) {
}
}
-inline int Object::use_count() const {
- return ref_counter_;
-}
+inline int Object::use_count() const { return ref_counter_; }
#endif // TVM_OBJECT_ATOMIC_REF_COUNTER
-template<typename TargetType>
+template <typename TargetType>
inline bool Object::IsInstance() const {
const Object* self = this;
// NOTE: the following code can be optimized by
}
}
-
template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
- if (data_ != nullptr &&
- data_->IsInstance<ObjectType>()) {
+ if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
return static_cast<ObjectType*>(data_.get());
} else {
return nullptr;
inline SubRef Downcast(BaseRef ref) {
if (ref.defined()) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
- << "Downcast from " << ref->GetTypeKey() << " to "
- << SubRef::ContainerType::_type_key << " failed.";
+ << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key
+ << " failed.";
} else {
- CHECK(SubRef::_type_is_nullable)
- << "Downcast from nullptr to not nullable reference of "
- << SubRef::ContainerType::_type_key;
+ CHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of "
+ << SubRef::ContainerType::_type_key;
}
return SubRef(std::move(ref.data_));
}
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/data_type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
+
#include <functional>
-#include <tuple>
-#include <vector>
-#include <string>
#include <limits>
#include <memory>
-#include <utility>
+#include <string>
+#include <tuple>
#include <type_traits>
-
+#include <utility>
+#include <vector>
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
* }
* \endcode
*/
- using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
+ using FType = std::function<void(TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*! \brief constructor from null */
* }
* \endcode
*/
- template<typename... Args>
- inline TVMRetValue operator()(Args&& ...args) const;
+ template <typename... Args>
+ inline TVMRetValue operator()(Args&&... args) const;
/*!
* \brief Call the function in packed format.
* \param args The arguments
/*! \return the internal body function */
inline FType body() const;
/*! \return Whether the packed function is nullptr */
- bool operator==(std::nullptr_t null) const {
- return body_ == nullptr;
- }
+ bool operator==(std::nullptr_t null) const { return body_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
- bool operator!=(std::nullptr_t null) const {
- return body_ != nullptr;
- }
+ bool operator!=(std::nullptr_t null) const { return body_ != nullptr; }
private:
/*! \brief internal container of packed function */
/*!
* \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args..)>"
*/
-template<typename FType>
+template <typename FType>
class TypedPackedFunc;
/*!
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
*/
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
class TypedPackedFunc<R(Args...)> {
public:
/*! \brief short hand for this function type */
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
*/
- template<typename FLambda,
- typename = typename std::enable_if<
- std::is_convertible<FLambda,
- std::function<R(Args...)>
- >::value>::type>
+ template <typename FLambda, typename = typename std::enable_if<
+ std::is_convertible<FLambda,
+ std::function<R(Args...)> >::value>::type>
TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
}
* \tparam FLambda the type of the lambda function.
* \returns reference to self.
*/
- template<typename FLambda,
- typename = typename std::enable_if<
- std::is_convertible<FLambda,
- std::function<R(Args...)>
- >::value>::type>
+ template <typename FLambda, typename = typename std::enable_if<
+ std::is_convertible<FLambda,
+ std::function<R(Args...)> >::value>::type>
TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
return *this;
* \param args The arguments
* \returns The return value.
*/
- TVM_ALWAYS_INLINE R operator()(Args ...args) const;
+ TVM_ALWAYS_INLINE R operator()(Args... args) const;
/*!
* \brief convert to PackedFunc
* \return the internal PackedFunc
*/
- operator PackedFunc() const {
- return packed();
- }
+ operator PackedFunc() const { return packed(); }
/*!
* \return reference the internal PackedFunc
*/
- const PackedFunc& packed() const {
- return packed_;
- }
+ const PackedFunc& packed() const { return packed_; }
/*! \return Whether the packed function is nullptr */
- bool operator==(std::nullptr_t null) const {
- return packed_ == nullptr;
- }
+ bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
- bool operator!=(std::nullptr_t null) const {
- return packed_ != nullptr;
- }
+ bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }
private:
friend class TVMRetValue;
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
- template<typename FLambda>
+ template <typename FLambda>
inline void AssignTypedLambda(FLambda flambda);
};
* \param type_codes The argument type codes
* \param num_args number of arguments.
*/
- TVMArgs(const TVMValue* values,
- const int* type_codes,
- int num_args)
- : values(values),
- type_codes(type_codes),
- num_args(num_args) { }
+ TVMArgs(const TVMValue* values, const int* type_codes, int num_args)
+ : values(values), type_codes(type_codes), num_args(num_args) {}
/*! \return size of the arguments */
inline int size() const;
/*!
};
// macro to check type code.
-#define TVM_CHECK_TYPE_CODE(CODE, T) \
- CHECK_EQ(CODE, T) << " expected " \
- << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
+#define TVM_CHECK_TYPE_CODE(CODE, T) \
+ CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)
/*!
* \brief Type traits for runtime type check during FFI conversion.
* \tparam T the type to be checked.
*/
-template<typename T>
+template <typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
return value_.v_handle;
}
operator DLTensor*() const {
- if (type_code_ == kTVMDLTensorHandle ||
- type_code_ == kTVMNDArrayHandle) {
+ if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) {
return static_cast<DLTensor*>(value_.v_handle);
} else {
if (type_code_ == kTVMNullptr) return nullptr;
LOG(FATAL) << "Expect "
- << "DLTensor* or NDArray but get "
- << TypeCode2Str(type_code_);
+ << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_);
return nullptr;
}
}
operator NDArray() const {
if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
- return NDArray(NDArray::FFIDataFromHandle(
- static_cast<TVMArrayHandle>(value_.v_handle)));
+ return NDArray(NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle)));
}
operator Module() const {
if (type_code_ == kTVMNullptr) {
return Module(ObjectPtr<Object>(nullptr));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
- return Module(
- ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
+ return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
- int type_code() const {
- return type_code_;
- }
+ int type_code() const { return type_code_; }
/*!
* \brief return handle as specific pointer type.
* \tparam T the data type.
* \return The pointer type.
*/
- template<typename T>
+ template <typename T>
T* ptr() const {
return static_cast<T*>(value_.v_handle);
}
// ObjectRef handling
- template<typename TObjectRef,
- typename = typename std::enable_if<
- std::is_base_of<ObjectRef, TObjectRef>::value>::type>
+ template <typename TObjectRef,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline bool IsObjectRef() const;
- template<typename TObjectRef>
+ template <typename TObjectRef>
inline TObjectRef AsObjectRef() const;
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kTVMNullptr) {}
- TVMPODValue_(TVMValue value, int type_code)
- : value_(value), type_code_(type_code) {}
+ TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {}
/*! \brief The value */
TVMValue value_;
* \param value of the function
* \param type_code The type code.
*/
- TVMArgValue(TVMValue value, int type_code)
- : TVMPODValue_(value, type_code) {
- }
+ TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
// reuse converter from parent
using TVMPODValue_::operator double;
using TVMPODValue_::operator int64_t;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator Module;
- using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
+ using TVMPODValue_::IsObjectRef;
// conversion operator.
operator std::string() const {
// None type
if (type_code_ == kTVMNullptr) {
DLDataType t;
- t.code = kTVMOpaqueHandle; t.bits = 0; t.lanes = 0;
+ t.code = kTVMOpaqueHandle;
+ t.bits = 0;
+ t.lanes = 0;
return t;
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
return value_.v_type;
}
- operator DataType() const {
- return DataType(operator DLDataType());
- }
+ operator DataType() const { return DataType(operator DLDataType()); }
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return *ptr<PackedFunc>();
}
- template<typename FType>
+ template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
- const TVMValue& value() const {
- return value_;
- }
+ const TVMValue& value() const { return value_; }
- template<typename T,
- typename = typename std::enable_if<
- std::is_class<T>::value>::type>
+ template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
};
*/
class TVMMovableArgValue_ : public TVMArgValue {
public:
- TVMMovableArgValue_(TVMValue value, int type_code)
- : TVMArgValue(value, type_code) {
- }
+ TVMMovableArgValue_(TVMValue value, int type_code) : TVMArgValue(value, type_code) {}
// reuse converter from parent
using TVMArgValue::operator double;
using TVMArgValue::operator int64_t;
* Try to move out an argument if possible,
* fall back to normal argument conversion rule otherwise.
*/
- template<typename T,
- typename = typename std::enable_if<
- std::is_base_of<ObjectRef, T>::value>::type>
+ template <typename T,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
inline operator T() const;
};
* \brief move constructor from anoter return value.
* \param other The other return value.
*/
- TVMRetValue(TVMRetValue&& other)
- : TVMPODValue_(other.value_, other.type_code_) {
+ TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) {
other.value_.v_handle = nullptr;
other.type_code_ = kTVMNullptr;
}
/*! \brief destructor */
- ~TVMRetValue() {
- this->Clear();
- }
+ ~TVMRetValue() { this->Clear(); }
// reuse converter from parent
using TVMPODValue_::operator double;
using TVMPODValue_::operator int64_t;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Module;
- using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
+ using TVMPODValue_::IsObjectRef;
- TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
- this->Assign(other);
- }
+ TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); }
// conversion operators
operator std::string() const {
if (type_code_ == kTVMDataType) {
TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
return value_.v_type;
}
- operator DataType() const {
- return DataType(operator DLDataType());
- }
+ operator DataType() const { return DataType(operator DLDataType()); }
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return *ptr<PackedFunc>();
}
- template<typename FType>
+ template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
value_.v_type = t;
return *this;
}
- TVMRetValue& operator=(const DataType& other) {
- return operator=(other.operator DLDataType());
- }
+ TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
}
return *this;
}
- template<typename FType>
+ template <typename FType>
TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
}
* \param ret_value The return value.
* \param ret_type_code The return type code.
*/
- void MoveToCHost(TVMValue* ret_value,
- int* ret_type_code) {
+ void MoveToCHost(TVMValue* ret_value, int* ret_type_code) {
// cannot move str; need specially handle.
CHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes);
*ret_value = value_;
* \param type_code The type code.
* \return The created TVMRetValue.
*/
- static TVMRetValue MoveFromCHost(TVMValue value,
- int type_code) {
+ static TVMRetValue MoveFromCHost(TVMValue value, int type_code) {
// Can move POD and everything under the object system.
- CHECK(type_code <= kTVMPackedFuncHandle ||
- type_code == kTVMNDArrayHandle);
+ CHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle);
TVMRetValue ret;
ret.value_ = value;
ret.type_code_ = type_code;
}
/*! \return The value field, if the data is POD */
const TVMValue& value() const {
- CHECK(type_code_ != kTVMObjectHandle &&
- type_code_ != kTVMPackedFuncHandle &&
- type_code_ != kTVMModuleHandle &&
- type_code_ != kTVMStr) << "TVMRetValue.value can only be used for POD data";
+ CHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle &&
+ type_code_ != kTVMModuleHandle && type_code_ != kTVMStr)
+ << "TVMRetValue.value can only be used for POD data";
return value_;
}
// ObjectRef handling
- template<typename TObjectRef,
- typename = typename std::enable_if<
- std::is_base_of<ObjectRef, TObjectRef>::value>::type>
+ template <typename TObjectRef,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline TVMRetValue& operator=(TObjectRef other);
- template<typename T,
- typename = typename std::enable_if<
- std::is_class<T>::value>::type>
+ template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
private:
- template<typename T>
+ template <typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kTVMStr: {
}
case kTVMObjectHandle: {
// Avoid operator ObjectRef as we already know it is not NDArray/Module
- SwitchToObject(
- kTVMObjectHandle, GetObjectPtr<Object>(
- static_cast<Object*>(other.value_.v_handle)));
+ SwitchToObject(kTVMObjectHandle,
+ GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle)));
break;
}
case kTVMObjectRValueRefArg: {
type_code_ = type_code;
}
}
- template<typename T>
+ template <typename T>
void SwitchToClass(int type_code, T v) {
if (type_code_ != type_code) {
this->Clear();
void Clear() {
if (type_code_ == kTVMNullptr) return;
switch (type_code_) {
- case kTVMStr: case kTVMBytes: delete ptr<std::string>(); break;
- case kTVMPackedFuncHandle: delete ptr<PackedFunc>(); break;
+ case kTVMStr:
+ case kTVMBytes:
+ delete ptr<std::string>();
+ break;
+ case kTVMPackedFuncHandle:
+ delete ptr<PackedFunc>();
+ break;
case kTVMNDArrayHandle: {
NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle));
break;
*
* \tparam TObjectRef the specific ObjectRefType.
*/
-template<typename TObjectRef>
+template <typename TObjectRef>
struct PackedFuncValueConverter {
/*!
* \brief Convert a TObjectRef from an argument value.
* \param val The argument value.
* \return the converted result.
*/
- static TObjectRef From(const TVMArgValue& val) {
- return val.AsObjectRef<TObjectRef>();
- }
+ static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef<TObjectRef>(); }
/*!
* \brief Convert a TObjectRef from a return value.
* \param val The argument value.
* \return the converted result.
*/
- static TObjectRef From(const TVMRetValue& val) {
- return val.AsObjectRef<TObjectRef>();
- }
+ static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef<TObjectRef>(); }
};
/*!
*
* \endcode
*/
-#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \
- extern "C" { \
- TVM_DLL int ExportName(TVMValue* args, \
- int* type_code, \
- int num_args, \
- TVMValue* out_value, \
- int* out_type_code); \
- int ExportName(TVMValue* args, \
- int* type_code, \
- int num_args, \
- TVMValue* out_value, \
- int* out_type_code) { \
- try { \
- ::tvm::runtime::TVMRetValue rv; \
- Function(::tvm::runtime::TVMArgs( \
- args, type_code, num_args), &rv); \
- rv.MoveToCHost(out_value, out_type_code); \
- return 0; \
- } catch (const ::std::runtime_error& _except_) { \
- TVMAPISetLastError(_except_.what()); \
- return -1; \
- } \
- } \
+#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \
+ extern "C" { \
+ TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
+ int* out_type_code); \
+ int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
+ int* out_type_code) { \
+ try { \
+ ::tvm::runtime::TVMRetValue rv; \
+ Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \
+ rv.MoveToCHost(out_value, out_type_code); \
+ return 0; \
+ } catch (const ::std::runtime_error& _except_) { \
+ TVMAPISetLastError(_except_.what()); \
+ return -1; \
+ } \
+ } \
}
/*!
*
* \endcode
*/
-#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
- extern "C" { \
- TVM_DLL int ExportName(TVMValue* args, \
- int* type_code, \
- int num_args, \
- TVMValue* out_value, \
- int* out_type_code) { \
- try { \
- auto f = Function; \
- using FType = ::tvm::runtime::detail:: \
- function_signature<decltype(f)>::FType; \
- ::tvm::runtime::TVMRetValue rv; \
- ::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \
- f, \
- ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \
- rv.MoveToCHost(out_value, out_type_code); \
- return 0; \
- } catch (const ::std::runtime_error& _except_) { \
- TVMAPISetLastError(_except_.what()); \
- return -1; \
- } \
- } \
+#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
+ extern "C" { \
+ TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
+ int* out_type_code) { \
+ try { \
+ auto f = Function; \
+ using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \
+ ::tvm::runtime::TVMRetValue rv; \
+ ::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \
+ f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \
+ rv.MoveToCHost(out_value, out_type_code); \
+ return 0; \
+ } catch (const ::std::runtime_error& _except_) { \
+ TVMAPISetLastError(_except_.what()); \
+ return -1; \
+ } \
+ } \
}
-
inline TVMArgValue TVMArgs::operator[](int i) const {
- CHECK_LT(i, num_args)
- << "not enough argument passed, "
- << num_args << " passed"
- << " but request arg[" << i << "].";
+ CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed"
+ << " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]);
}
-inline int TVMArgs::size() const {
- return num_args;
-}
+inline int TVMArgs::size() const { return num_args; }
-inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
- body_(args, rv);
-}
+inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); }
-inline PackedFunc::FType PackedFunc::body() const {
- return body_;
-}
+inline PackedFunc::FType PackedFunc::body() const { return body_; }
// internal namespace
namespace detail {
-template<bool stop, std::size_t I, typename F>
+template <bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
- template<typename T, typename ...Args>
+ template <typename T, typename... Args>
static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
f(I, std::forward<T>(value));
- for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
- ::run(f, std::forward<Args>(args)...);
+ for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::run(f, std::forward<Args>(args)...);
}
};
-template<std::size_t I, typename F>
-struct for_each_dispatcher<true, I, F> {
+template <std::size_t I, typename F>
+struct for_each_dispatcher<true, I, F> {
static void run(const F& f) {} // NOLINT(*)
};
-template<typename F, typename ...Args>
+template <typename F, typename... Args>
inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
- for_each_dispatcher<sizeof...(Args) == 0, 0, F>
- ::run(f, std::forward<Args>(args)...);
+ for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
}
-template<typename T>
+template <typename T>
struct func_signature_helper {
using FType = void;
};
-template<typename T, typename R, typename ...Args>
+template <typename T, typename R, typename... Args>
struct func_signature_helper<R (T::*)(Args...)> {
using FType = R(Args...);
- static_assert(!std::is_reference<R>::value,
- "TypedPackedFunc return reference");
+ static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};
-template<typename T, typename R, typename ...Args>
+template <typename T, typename R, typename... Args>
struct func_signature_helper<R (T::*)(Args...) const> {
using FType = R(Args...);
- static_assert(!std::is_reference<R>::value,
- "TypedPackedFunc return reference");
+ static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};
/*!
* \brief template class to get function signature of a function or functor.
* \tparam T The funtion/functor type.
*/
-template<typename T>
+template <typename T>
struct function_signature {
using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
};
// handle case of function.
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
struct function_signature<R(Args...)> {
using FType = R(Args...);
- static_assert(!std::is_reference<R>::value,
- "TypedPackedFunc return reference");
+ static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};
// handle case of function ptr.
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
struct function_signature<R (*)(Args...)> {
using FType = R(Args...);
- static_assert(!std::is_reference<R>::value,
- "TypedPackedFunc return reference");
+ static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};
} // namespace detail
/* \brief argument settter to PackedFunc */
class TVMArgsSetter {
public:
- TVMArgsSetter(TVMValue* values, int* type_codes)
- : values_(values), type_codes_(type_codes) {}
+ TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {}
// setters for POD types
- template<typename T,
- typename = typename std::enable_if<
- std::is_integral<T>::value>::type>
+ template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kDLInt;
}
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
- CHECK_LE(value,
- static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
+ CHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
type_codes_[i] = kDLInt;
}
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const {
type_codes_[i] = kTVMNullptr;
}
}
- template<typename FType>
+ template <typename FType>
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
operator()(i, value.packed());
}
}
}
// ObjectRef handling
- template<typename TObjectRef,
- typename = typename std::enable_if<
- std::is_base_of<ObjectRef, TObjectRef>::value>
- ::type>
+ template <typename TObjectRef,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const {
this->SetObject(i, value);
}
- template<typename TObjectRef,
- typename = typename std::enable_if<
- std::is_base_of<ObjectRef,
- typename std::remove_reference<TObjectRef>::type>::value>
- ::type>
+ template <typename TObjectRef,
+ typename = typename std::enable_if<std::is_base_of<
+ ObjectRef, typename std::remove_reference<TObjectRef>::type>::value>::type>
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const {
this->SetObject(i, std::forward<TObjectRef>(value));
}
private:
- template<typename TObjectRef>
+ template <typename TObjectRef>
inline void SetObject(size_t i, TObjectRef&& value) const;
/*! \brief The values fields */
TVMValue* values_;
int* type_codes_;
};
-template<typename... Args>
-inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
+template <typename... Args>
+inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
- detail::for_each(TVMArgsSetter(values, type_codes),
- std::forward<Args>(args)...);
+ detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
namespace detail {
-template<typename R, int nleft, int index, typename F>
+template <typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
- template<typename ...Args>
- TVM_ALWAYS_INLINE static void run(const F& f,
- const TVMArgs& args_pack,
- TVMRetValue* rv,
+ template <typename... Args>
+ TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv,
Args&&... unpacked_args) {
// construct a movable argument value
// which allows potential move of argument to the input of F.
- unpack_call_dispatcher<R, nleft - 1, index + 1, F>
- ::run(f, args_pack, rv,
- std::forward<Args>(unpacked_args)...,
- TVMMovableArgValue_(args_pack.values[index],
- args_pack.type_codes[index]));
+ unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
+ f, args_pack, rv, std::forward<Args>(unpacked_args)...,
+ TVMMovableArgValue_(args_pack.values[index], args_pack.type_codes[index]));
}
};
-template<typename R, int index, typename F>
+template <typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
- template<typename ...Args>
- TVM_ALWAYS_INLINE static void run(const F& f,
- const TVMArgs& args_pack,
- TVMRetValue* rv,
+ template <typename... Args>
+ TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv,
Args&&... unpacked_args) {
using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
if (std::is_same<RetType, R>::value) {
}
};
-template<int index, typename F>
+template <int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
- template<typename ...Args>
- TVM_ALWAYS_INLINE static void run(const F& f,
- const TVMArgs& args_pack,
- TVMRetValue* rv,
+ template <typename... Args>
+ TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv,
Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...);
}
};
-template<typename R, int nargs, typename F>
-TVM_ALWAYS_INLINE void unpack_call(
- const F& f, const TVMArgs& args, TVMRetValue* rv) {
- CHECK_EQ(nargs, args.size())
- << "Expect " << nargs << " arguments but get " << args.size();
+template <typename R, int nargs, typename F>
+TVM_ALWAYS_INLINE void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
+ CHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}
-template<typename FType>
-struct unpack_call_by_signature {
-};
+template <typename FType>
+struct unpack_call_by_signature {};
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
struct unpack_call_by_signature<R(Args...)> {
- template<typename F>
- TVM_ALWAYS_INLINE static void run(
- const F& f,
- const TVMArgs& args,
- TVMRetValue* rv) {
+ template <typename F>
+ TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) {
unpack_call<R, sizeof...(Args)>(f, args, rv);
}
};
-template<typename R, typename ...Args>
-TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&& ...args) {
+template <typename R, typename... Args>
+TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) {
return R(pf(std::forward<Args>(args)...));
}
-template<typename R>
+template <typename R>
struct typed_packed_call_dispatcher {
- template<typename ...Args>
- TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&& ...args) {
+ template <typename... Args>
+ TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) {
return pf(std::forward<Args>(args)...);
}
};
-template<>
+template <>
struct typed_packed_call_dispatcher<void> {
- template<typename ...Args>
- TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&& ...args) {
+ template <typename... Args>
+ TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) {
pf(std::forward<Args>(args)...);
}
};
} // namespace detail
-template<typename R, typename ...Args>
-TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
- : packed_(packed) {}
+template <typename R, typename... Args>
+TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {}
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMRetValue& value)
: packed_(value.operator PackedFunc()) {}
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value)
: packed_(value.operator PackedFunc()) {}
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValue_&& value)
: packed_(value.operator PackedFunc()) {}
-template<typename R, typename ...Args>
-template<typename FType>
+template <typename R, typename... Args>
+template <typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
- detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
- });
+ detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
+ });
}
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
- return detail::typed_packed_call_dispatcher<R>
- ::run(packed_, std::forward<Args>(args)...);
+ return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
}
// ObjectRef related conversion handling
// kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle
//
// We use type traits to eliminate un-necessary checks.
-template<typename T>
+template <typename T>
inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
using ContainerType = typename std::remove_reference<T>::type::ContainerType;
if (value.defined()) {
}
}
-template<typename TObjectRef, typename>
+template <typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType;
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
return type_code_ == kTVMNDArrayHandle &&
- TVMArrayHandleToObjectHandle(
- static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
+ TVMArrayHandleToObjectHandle(static_cast<TVMArrayHandle>(value_.v_handle))
+ ->IsInstance<ContainerType>();
}
if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
return type_code_ == kTVMModuleHandle &&
- static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
+ static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
// NOTE: we don't pass NDArray and runtime::Module as RValue ref.
if (type_code_ == kTVMObjectRValueRefArg) {
- return ObjectTypeChecker<TObjectRef>::Check(
- *static_cast<Object**>(value_.v_handle));
- }
- return
- (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
- type_code_ == kTVMNDArrayHandle) ||
- (std::is_base_of<ContainerType, Module::ContainerType>::value &&
- type_code_ == kTVMModuleHandle) ||
- (type_code_ == kTVMObjectHandle &&
- ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
+ return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle));
+ }
+ return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
+ type_code_ == kTVMNDArrayHandle) ||
+ (std::is_base_of<ContainerType, Module::ContainerType>::value &&
+ type_code_ == kTVMModuleHandle) ||
+ (type_code_ == kTVMObjectHandle &&
+ ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
}
-template<typename TObjectRef>
+template <typename TObjectRef>
inline TObjectRef TVMPODValue_::AsObjectRef() const {
- static_assert(
- std::is_base_of<ObjectRef, TObjectRef>::value,
- "Conversion only works for ObjectRef");
+ static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
+ "Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType;
if (type_code_ == kTVMNullptr) {
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
// Casting to a sub-class of NDArray
TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
- ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
- static_cast<TVMArrayHandle>(value_.v_handle));
+ ObjectPtr<Object> data =
+ NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
CHECK(data->IsInstance<ContainerType>())
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
// normal object type check.
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
- << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
- << " but get " << ptr->GetTypeKey();
+ << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get "
+ << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (type_code_ == kTVMObjectRValueRefArg) {
Object* ptr = *static_cast<Object**>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
- << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
- << " but get " << ptr->GetTypeKey();
+ << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get "
+ << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
type_code_ == kTVMNDArrayHandle) {
// Casting to a base class that NDArray can sub-class
- ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
- static_cast<TVMArrayHandle>(value_.v_handle));
+ ObjectPtr<Object> data =
+ NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
return TObjectRef(data);
} else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) {
}
}
-template<typename TObjectRef, typename>
+template <typename TObjectRef, typename>
inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
using ContainerType = typename TObjectRef::ContainerType;
const Object* ptr = other.get();
return *this;
}
-
-template<typename T, typename>
+template <typename T, typename>
inline TVMArgValue::operator T() const {
return PackedFuncValueConverter<T>::From(*this);
}
-template<typename T, typename>
+template <typename T, typename>
inline TVMMovableArgValue_::operator T() const {
if (type_code_ == kTVMObjectRValueRefArg) {
auto** ref = static_cast<Object**>(value_.v_handle);
return PackedFuncValueConverter<T>::From(*this);
}
-template<typename T, typename>
+template <typename T, typename>
inline TVMRetValue::operator T() const {
return PackedFuncValueConverter<T>::From(*this);
}
#define TVM_RUNTIME_REGISTRY_H_
#include <tvm/runtime/packed_func.h>
+
#include <string>
-#include <vector>
#include <utility>
+#include <vector>
namespace tvm {
namespace runtime {
}
/*!
* \brief set the body of the function to the given function.
- * Note that this will ignore default arg values and always require all arguments to be provided.
+ * Note that this will ignore default arg values and always require all arguments to be
+ * provided.
*
* \code
*
* \param f The function to forward to.
* \tparam FLambda The signature of the function.
*/
- template<typename FLambda>
+ template <typename FLambda>
Registry& set_body_typed(FLambda f) {
using FType = typename detail::function_signature<FLambda>::FType;
return set_body(TypedPackedFunc<FType>(std::move(f)).packed());
}
/*!
* \brief set the body of the function to be the passed method pointer.
- * Note that this will ignore default arg values and always require all arguments to be provided.
+ * Note that this will ignore default arg values and always require all arguments to be
+ * provided.
*
* \code
*
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
- template<typename T, typename R, typename ...Args>
+ template <typename T, typename R, typename... Args>
Registry& set_body_method(R (T::*f)(Args...)) {
- auto fwrap =[f](T target, Args... params) -> R {
+ auto fwrap = [f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
};
/*!
* \brief set the body of the function to be the passed method pointer.
- * Note that this will ignore default arg values and always require all arguments to be provided.
+ * Note that this will ignore default arg values and always require all arguments to be
+ * provided.
*
* \code
*
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
- template<typename T, typename R, typename ...Args>
+ template <typename T, typename R, typename... Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
auto fwrap = [f](const T target, Args... params) -> R {
// call method pointer
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a ObjectRef subclass.
- * Note that this will ignore default arg values and always require all arguments to be provided.
+ * Note that this will ignore default arg values and always require all arguments to be
+ * provided.
*
* \code
*
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
- template<typename TObjectRef, typename TNode, typename R, typename ...Args,
- typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
+ template <typename TObjectRef, typename TNode, typename R, typename... Args,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
auto fwrap = [f](TObjectRef ref, Args... params) {
TNode* target = ref.operator->();
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a ObjectRef subclass.
- * Note that this will ignore default arg values and always require all arguments to be provided.
+ * Note that this will ignore default arg values and always require all arguments to be
+ * provided.
*
* \code
*
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
- template<typename TObjectRef, typename TNode, typename R, typename ...Args,
- typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
+ template <typename TObjectRef, typename TNode, typename R, typename... Args,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
auto fwrap = [f](TObjectRef ref, Args... params) {
const TNode* target = ref.operator->();
friend struct Manager;
};
-#define TVM_FUNC_REG_VAR_DEF \
- static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
+#define TVM_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_##TVM
/*!
* \brief Register a function globally.
* });
* \endcode
*/
-#define TVM_REGISTER_GLOBAL(OpName) \
- TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
- ::tvm::runtime::Registry::Register(OpName)
+#define TVM_REGISTER_GLOBAL(OpName) \
+ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName)
} // namespace runtime
} // namespace tvm
namespace dmlc {
namespace serializer {
-template<>
+template <>
struct Handler<DLDataType> {
- inline static void Write(Stream *strm, const DLDataType& dtype) {
+ inline static void Write(Stream* strm, const DLDataType& dtype) {
Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes);
}
- inline static bool Read(Stream *strm, DLDataType* dtype) {
+ inline static bool Read(Stream* strm, DLDataType* dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
}
};
-template<>
+template <>
struct Handler<DLContext> {
- inline static void Write(Stream *strm, const DLContext& ctx) {
+ inline static void Write(Stream* strm, const DLContext& ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id);
}
- inline static bool Read(Stream *strm, DLContext* ctx) {
+ inline static bool Read(Stream* strm, DLContext* ctx) {
int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type);
public:
class Impl;
- /*!
- * \brief Creates a collection of threads which run a provided function.
- *
- * \param num_workers The total number of worker threads in this group.
- Includes main thread if `exclude_worker0 = true`
- * \param worker_callback A callback which is run in its own thread.
- Receives the worker_id as an argument.
- * \param exclude_worker0 Whether to use the main thread as a worker.
- * If `true`, worker0 will not be launched in a new thread and
- * `worker_callback` will only be called for values >= 1. This
- * allows use of the main thread as a worker.
- */
- ThreadGroup(int num_workers,
- std::function<void(int)> worker_callback,
+ /*!
+ * \brief Creates a collection of threads which run a provided function.
+ *
+ * \param num_workers The total number of worker threads in this group.
+ Includes main thread if `exclude_worker0 = true`
+ * \param worker_callback A callback which is run in its own thread.
+ Receives the worker_id as an argument.
+ * \param exclude_worker0 Whether to use the main thread as a worker.
+ * If `true`, worker0 will not be launched in a new thread and
+ * `worker_callback` will only be called for values >= 1. This
+ * allows use of the main thread as a worker.
+ */
+ ThreadGroup(int num_workers, std::function<void(int)> worker_callback,
bool exclude_worker0 = false);
~ThreadGroup();
- /*!
- * \brief Blocks until all non-main threads in the pool finish.
- */
+ /*!
+ * \brief Blocks until all non-main threads in the pool finish.
+ */
void Join();
enum AffinityMode : int {
*/
int MaxConcurrency();
-
} // namespace threading
} // namespace runtime
} // namespace tvm
#ifndef TVM_RUNTIME_VM_H_
#define TVM_RUNTIME_VM_H_
-#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
+#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
+
#include <memory>
#include <string>
#include <unordered_map>
* \param dst The destination register.
* \return The allocate tensor instruction.
*/
- static Instruction AllocTensor(RegName storage,
- const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
+ static Instruction AllocTensor(RegName storage, const std::vector<int64_t>& shape,
+ DLDataType dtype, RegName dst);
/*!
* \brief Construct an allocate tensor instruction with register.
* \param storage The storage to allocate out of.
* \param dst The destination register.
* \return The allocate tensor instruction.
*/
- static Instruction AllocTensorReg(RegName storage,
- RegName shape_register, DLDataType dtype, RegName dst);
+ static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype,
+ RegName dst);
/*!
* \brief Construct an allocate datatype instruction.
* \param tag The datatype tag.
* \param dst The destination to place the storage.
* \return The alloc storage instruction.
*/
- static Instruction AllocStorage(RegName size, RegName alignment,
- DLDataType dtype_hint, RegName dst);
+ static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint,
+ RegName dst);
Instruction();
Instruction(const Instruction& instr);
Index register_file_size;
VMFunction(const std::string& name, std::vector<std::string> params,
- const std::vector<Instruction>& instructions,
- Index register_file_size)
+ const std::vector<Instruction>& instructions, Index register_file_size)
: name(name),
params(params),
instructions(instructions),
*
* \return PackedFunc or nullptr when it is not available.
*/
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
/*!
* \brief Serialize the executable into global section, constant section, and
virtual ~Executable() {}
- const char* type_key() const final {
- return "VMExecutable";
- }
+ const char* type_key() const final { return "VMExecutable"; }
/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
* If the function needs resource from the module(e.g. late linking),
* it should capture sptr_to_self.
*/
- virtual PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self);
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
virtual ~VirtualMachine() {}
- const char* type_key() const final {
- return "VirtualMachine";
- }
+ const char* type_key() const final { return "VirtualMachine"; }
VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {}
*
* \note The return value will be stored in the last output_size slots of args.
*/
- virtual void InvokePacked(Index packed_index,
- const PackedFunc& func,
- Index arg_count,
- Index output_size,
- const std::vector<ObjectRef>& args);
+ virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
+ Index output_size, const std::vector<ObjectRef>& args);
/*!
* \brief Initialize the virtual machine for a set of contexts.
* a = ...
* b = ...
* // if quit_on_assertion is true, if a==b, continue, otherwise quit.
- * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default behaviour)
- * COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting"
+ * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default
+ * behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting"
* ...
* for (int i = 0; i < N; i++) {
* a = ...
// Not supposed to be used by users directly.
#define COND_CHECK_OP(quit_on_assert, x, y, what, op) \
- if (!quit_on_assert) { \
- if (!((x) op (y))) \
- what; \
- } \
- else /* NOLINT(*) */ \
+ if (!quit_on_assert) { \
+ if (!((x)op(y))) what; \
+ } else /* NOLINT(*) */ \
CHECK_##op(x, y)
#define COND_CHECK_EQ_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, ==)
#define COND_CHECK_GE_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, >=)
#define COND_CHECK_3(quit_on_assert, x, what) \
- if (!quit_on_assert) { \
- if (!(x)) \
- what; \
- } \
- else /* NOLINT(*) */ \
+ if (!quit_on_assert) { \
+ if (!(x)) what; \
+ } else /* NOLINT(*) */ \
CHECK(x)
#define COND_LOG_3(quit_on_assert, x, what) \
- if (!quit_on_assert) { \
- what; \
- } \
- else /* NOLINT(*) */ \
+ if (!quit_on_assert) { \
+ what; \
+ } else /* NOLINT(*) */ \
LOG(x)
#define COND_CHECK_EQ_3(quit_on_assert, x, y) COND_CHECK_EQ_4(quit_on_assert, x, y, return false)
#define COND_CHECK_2(quit_on_assert, x) COND_CHECK_3(quit_on_assert, x, return false)
#define COND_LOG_2(quit_on_assert, x) COND_LOG_3(quit_on_assert, x, return false)
-#endif // TVM_SUPPORT_LOGGING_H_
+#endif // TVM_SUPPORT_LOGGING_H_
#define TVM_SUPPORT_WITH_H_
#include <dmlc/logging.h>
+
#include <utility>
namespace tvm {
*
* \tparam ContextType Type of the context object.
*/
-template<typename ContextType>
+template <typename ContextType>
class With {
public:
/*!
* \brief constructor.
* Enter the scope of the context.
*/
- template<typename ...Args>
- explicit With(Args&& ...args)
- : ctx_(std::forward<Args>(args)...) {
+ template <typename... Args>
+ explicit With(Args&&... args) : ctx_(std::forward<Args>(args)...) {
ctx_.EnterWithScope();
}
/*! \brief destructor, leaves the scope of the context. */
- ~With() DMLC_THROW_EXCEPTION {
- ctx_.ExitWithScope();
- }
+ ~With() DMLC_THROW_EXCEPTION { ctx_.ExitWithScope(); }
private:
/*! \brief internal context type. */
#ifndef TVM_TARGET_CODEGEN_H_
#define TVM_TARGET_CODEGEN_H_
-#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
-#include <tvm/tir/expr.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
+#include <tvm/tir/expr.h>
#include <string>
-
namespace tvm {
/*! \brief namespace for target translation and codegen. */
namespace codegen {
* \param target_triple LLVM target triple
* \return runtime::Module The generated LLVM module.
*/
-runtime::Module PackImportsToLLVM(const runtime::Module& m,
- bool system_lib,
+runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
const std::string& target_triple);
} // namespace codegen
} // namespace tvm
#ifndef TVM_TARGET_GENERIC_FUNC_H_
#define TVM_TARGET_GENERIC_FUNC_H_
-#include <tvm/support/with.h>
#include <tvm/runtime/packed_func.h>
+#include <tvm/support/with.h>
#include <tvm/target/target.h>
-#include <vector>
#include <string>
-#include <utility>
#include <unordered_map>
+#include <utility>
+#include <vector>
namespace tvm {
* false, an error will be logged if the call would override a previously registered function.
* \return reference to self.
*/
- TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value,
- bool allow_override = false);
+ TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
/*!
* \brief Register a specialized function
* \param tags The tags for this specialization
* \return reference to self.
*/
TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
- const runtime::PackedFunc value,
- bool allow_override = false);
+ const runtime::PackedFunc value, bool allow_override = false);
/*!
* \brief Call generic function by directly passing in unpacked format.
* \param args Arguments to be passed.
* }
* \endcode
*/
- template<typename... Args>
- inline runtime::TVMRetValue operator()(Args&& ...args) const;
+ template <typename... Args>
+ inline runtime::TVMRetValue operator()(Args&&... args) const;
/*!
* \brief Invoke the relevant function for the current target context, set by set_target_context.
* Arguments are passed in packed format.
* \param args The arguments to pass to the function.
* \param ret The return value
*/
- TVM_DLL void CallPacked(runtime::TVMArgs args,
- runtime::TVMRetValue* ret) const;
+ TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
/*!
* \brief Find or register the GenericFunc instance corresponding to the give name
friend struct Manager;
};
-template<typename... Args>
-inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
+template <typename... Args>
+inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
- std::forward<Args>(args)...);
+ std::forward<Args>(args)...);
runtime::TVMRetValue rv;
CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
return static_cast<GenericFuncNode*>(get_mutable());
}
-#define TVM_GENERIC_FUNC_REG_VAR_DEF \
- static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM
+#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM
/*!
* \def TVM_REGISTER_GENERIC_FUNC
*
* \param name The name of the function
*/
-#define TVM_REGISTER_GENERIC_FUNC(name) \
- TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \
- ::tvm::GenericFunc::Get(#name)
+#define TVM_REGISTER_GENERIC_FUNC(name) \
+ TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name)
} // namespace tvm
#endif // TVM_TARGET_GENERIC_FUNC_H_
#ifndef TVM_TARGET_TARGET_H_
#define TVM_TARGET_TARGET_H_
-#include <tvm/support/with.h>
-#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
+#include <tvm/node/container.h>
+#include <tvm/support/with.h>
#include <string>
-#include <vector>
#include <unordered_set>
#include <utility>
+#include <vector>
namespace tvm {
/*!
Target() {}
explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
- * \brief Create a Target given a string
- * \param target_str the string to parse
- */
+ * \brief Create a Target given a string
+ * \param target_str the string to parse
+ */
TVM_DLL static Target Create(const std::string& target_str);
/*!
* \brief Get the current target context from thread local storage.
*/
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
- const TargetNode* operator->() const {
- return static_cast<const TargetNode*>(get());
- }
+ const TargetNode* operator->() const { return static_cast<const TargetNode*>(get()); }
using ContainerType = TargetNode;
class Internal;
+
private:
// enable with syntax.
friend class Internal;
namespace target {
/*! \return A target for LLVM */
-TVM_DLL Target llvm(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target llvm(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for CUDA */
-TVM_DLL Target cuda(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target cuda(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for ROCm */
-TVM_DLL Target rocm(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target rocm(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for OpenCL */
-TVM_DLL Target opencl(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target opencl(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for Metal */
-TVM_DLL Target metal(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target metal(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for rasp */
-TVM_DLL Target rasp(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target rasp(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for Mali */
-TVM_DLL Target mali(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target mali(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for Intel Graphics */
-TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target intel_graphics(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for stackvm */
-TVM_DLL Target stackvm(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target stackvm(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for external device */
-TVM_DLL Target ext_dev(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target ext_dev(const std::vector<std::string>& options = std::vector<std::string>());
/*! \return A target for hexagon */
-TVM_DLL Target hexagon(const std::vector<std::string>& options =
- std::vector<std::string>());
+TVM_DLL Target hexagon(const std::vector<std::string>& options = std::vector<std::string>());
} // namespace target
/*!
public:
BuildConfig() {}
explicit BuildConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
- const BuildConfigNode* operator->() const {
- return static_cast<const BuildConfigNode*>(get());
- }
- BuildConfigNode* operator->() {
- return static_cast<BuildConfigNode*>(get_mutable());
- }
+ const BuildConfigNode* operator->() const { return static_cast<const BuildConfigNode*>(get()); }
+ BuildConfigNode* operator->() { return static_cast<BuildConfigNode*>(get_mutable()); }
/*!
* \brief Construct a BuildConfig containing a empty build config node.
* \return The new BuildConfig
#define TVM_TARGET_TARGET_INFO_H_
#include <tvm/ir/expr.h>
+
#include <string>
namespace tvm {
#include <tvm/runtime/object.h>
#include <tvm/tir/expr.h>
+
#include "tensor.h"
namespace tvm {
*
* Differentiate \p output wrt \p input and multiply the result by \p head on the left using tensor
* dot product. \p input must be an immediate dependency of \p output (must be called from within
- * the body of \p output). That is, the function will compute one summand of the adjoint for \p input
- * given the adjoint for \p output (which is called \p head here).
+ * the body of \p output). That is, the function will compute one summand of the adjoint for \p
+ * input given the adjoint for \p output (which is called \p head here).
*
* \param output The tensor to differentiate.
* \param input The input tensor, which \p output should directly use.
* \return The tensor of shape `prefix + input.shape`
* representing the partial adjoint of \p input wrt one of its consumers (output)
*/
-Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head);
+Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head);
/*!
* \brief Perform reverse mode automatic differentiation.
* wrt all tensors the output depends on.
* \param head The adjoint of the output, in other words, some tensor, by which the Jacobians
* will be multiplied (using tensordot axes=`output.shape`).
- * Its shape must be of the form `prefix + output.shape`. If the null pointer is provided,
- * the identity tensor of shape `output.shape + output.shape` will be used.
- * \return An array of adjoints corresponding to \p inputs.
+ * Its shape must be of the form `prefix + output.shape`. If the null pointer is
+ * provided, the identity tensor of shape `output.shape + output.shape` will be used. \return An
+ * array of adjoints corresponding to \p inputs.
*/
-TVM_DLL Array<Tensor> Gradient(
- const Tensor& output,
- const Array<Tensor>& inputs,
- const Tensor& head = Tensor());
+TVM_DLL Array<Tensor> Gradient(const Tensor& output, const Array<Tensor>& inputs,
+ const Tensor& head = Tensor());
} // namespace te
} // namespace tvm
#define TVM_TE_OPERATION_H_
#include <tvm/arith/analyzer.h>
-#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>
-
+#include <tvm/te/tensor.h>
+#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <tvm/tir/buffer.h>
#include <string>
-#include <vector>
#include <unordered_map>
+#include <vector>
namespace tvm {
/*! \brief Tensor expression language DSL. */
*/
struct TensorDom {
// constructor
- explicit TensorDom(int ndim)
- : data(ndim) {}
+ explicit TensorDom(int ndim) : data(ndim) {}
/*! \brief The domain data */
std::vector<std::vector<IntSet> > data;
};
/*! \brief additional attributes of the operation*/
Map<std::string, ObjectRef> attrs;
/*! \return name of the operation */
- const std::string& func_name() const final {
- return name;
- }
+ const std::string& func_name() const final { return name; }
/*!
* \return The list of iteration variable at root
* \note root_iter_vars decides the shape of the outputs.
* \param rmap The replacement map.
* \return self if nothing is replaced, otherwise return replaced op.
*/
- virtual Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
+ virtual Operation ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
/*!
* \brief Propagate the bounds to inputs
* \param self The reference to self.
* The function is only asked to fill the bounds for Tensors that
* is already in the out_dom_map
*/
- virtual void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
+ virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
* \brief Gather the bound from output tensor.
* Set the range of each root_iter_vars in the op to out_dom_map
* \param tensor_dom Domain map of Tensor->access set of each dimension.
* \param out_dom_map The output domain map of each IterVar to be setted.
*/
- virtual void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
+ virtual void GatherBound(const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
/*!
* \brief Build the Realize statement that realizes
* the op's output tensors.
* \param body The body that is going to get
* \return A realization statement that wraps body.
*/
- virtual Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const = 0;
+ virtual Stmt BuildRealize(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const = 0;
/*!
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return A statement that add production and wraps consumer.
*/
- virtual Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const = 0;
+ virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const = 0;
static constexpr const char* _type_key = "Operation";
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
+ Operation ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
- static Operation make(std::string name,
- Array<PrimExpr> shape,
- DataType dtype);
+ static Operation make(std::string name, Array<PrimExpr> shape, DataType dtype);
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
// override functions
Array<IterVar> root_iter_vars() const final;
Array<PrimExpr> output_shape(size_t idx) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
+ void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
virtual size_t num_schedulable_dims() const = 0;
static constexpr const char* _type_key = "BaseComputeOp";
TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
};
-
/*!
* \brief A Compute op that compute a tensor on certain domain.
*/
int num_outputs() const final;
DataType output_dtype(size_t i) const final;
Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
+ Operation ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
- static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<IterVar> axis,
- Array<PrimExpr> body);
+ static Operation make(std::string name, std::string tag, Map<std::string, ObjectRef> attrs,
+ Array<IterVar> axis, Array<PrimExpr> body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
int num_outputs() const final;
DataType output_dtype(size_t i) const final;
Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
+ Operation ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
- static Operation make(std::string name,
- std::string tag,
- Array<IterVar> axis,
- Array<IterVar> reduce_axis,
- int schedulable_ndim,
- TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
+ static Operation make(std::string name, std::string tag, Array<IterVar> axis,
+ Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
+ Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);
static constexpr const char* _type_key = "TensorComputeOp";
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
+ Operation ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_);
}
- static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- IterVar axis,
- Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> input);
+ static Operation make(std::string name, std::string tag, Map<std::string, ObjectRef> attrs,
+ IterVar axis, Array<Tensor> init, Array<Tensor> update,
+ Array<Tensor> state_placeholder, Array<Tensor> input);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
+ Operation ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
- TVM_DLL static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Buffer> input_placeholders,
- Array<Buffer> output_placeholders,
- Stmt body);
+ TVM_DLL static Operation make(std::string name, std::string tag,
+ Map<std::string, ObjectRef> attrs, Array<Tensor> inputs,
+ Array<Buffer> input_placeholders, Array<Buffer> output_placeholders,
+ Stmt body);
static constexpr const char* _type_key = "ExternOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
+ Operation ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("axis", &axis);
v->Visit("body", &body);
}
- TVM_DLL static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Tensor> outputs,
- Stmt body);
+ TVM_DLL static Operation make(std::string name, std::string tag,
+ Map<std::string, ObjectRef> attrs, Array<Tensor> inputs,
+ Array<Tensor> outputs, Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
/*! \brief The compute function to specify the input source of a Tensor */
-using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
+using FCompute = std::function<PrimExpr(const Array<Var>& i)>;
/*! \brief The compute function to specify the inputs source of Tensors */
-using FBatchCompute = std::function<Array<PrimExpr> (const Array<Var>& i)>;
+using FBatchCompute = std::function<Array<PrimExpr>(const Array<Var>& i)>;
/*!
* \brief create a place holder tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
-TVM_DLL Tensor placeholder(Array<PrimExpr> shape,
- DataType dtype = DataType::Float(32),
+TVM_DLL Tensor placeholder(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
std::string name = "placeholder");
/*!
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
-TVM_DLL Tensor compute(Array<PrimExpr> shape,
- FCompute fcompute,
- std::string name = "tensor",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {});
+TVM_DLL Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
+ std::string tag = "", Map<std::string, ObjectRef> attrs = {});
/*!
* \brief Construct a new tensor by computing over shape,
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
-TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape,
- FBatchCompute fcompute,
- std::string name = "tensor",
- std::string tag = "",
+TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute,
+ std::string name = "tensor", std::string tag = "",
Map<std::string, ObjectRef> attrs = {});
/*!
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
-TVM_DLL Array<Tensor> scan(Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> inputs = Array<Tensor>(),
- std::string name = "scan",
- std::string tag = "",
+TVM_DLL Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update,
+ Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
+ std::string name = "scan", std::string tag = "",
Map<std::string, ObjectRef> attrs = {});
// same as compute, specialized for different fcompute function
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var)> f,
- std::string name = "tensor",
- std::string tag = "",
+inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var)> f,
+ std::string name = "tensor", std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
+ FCompute fc = [f](const Array<Var>& i) { return f(i[0]); };
return compute(shape, fc, name, tag, attrs);
}
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var, Var)> f,
- std::string name = "tensor",
- std::string tag = "",
+inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f,
+ std::string name = "tensor", std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
+ FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1]); };
return compute(shape, fc, name, tag, attrs);
}
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var, Var, Var)> f,
- std::string name = "tensor",
- std::string tag = "",
+inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f,
+ std::string name = "tensor", std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
- return compute(shape, fc, name, tag, attrs);
+ FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); };
+ return compute(shape, fc, name, tag, attrs);
}
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var, Var, Var, Var)> f,
- std::string name = "tensor",
- std::string tag = "",
+inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f,
+ std::string name = "tensor", std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
+ FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return compute(shape, fc, name, tag, attrs);
}
#ifndef TVM_TE_SCHEDULE_H_
#define TVM_TE_SCHEDULE_H_
-#include <tvm/tir/expr.h>
+#include <tvm/support/with.h>
#include <tvm/te/tensor.h>
#include <tvm/te/tensor_intrin.h>
-#include <tvm/support/with.h>
+#include <tvm/tir/expr.h>
#include <string>
#include <unordered_map>
* \param scope The iteration point to carry the schedule.
* \return reference to self.
*/
- TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
+ TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline.
* \return reference to self.
*/
- TVM_DLL Stage& compute_inline(); // NOLINT(*)
+ TVM_DLL Stage& compute_inline(); // NOLINT(*)
/*!
* \brief Compute the function at group root.
* \return reference to self.
* \param p_inner The result inner domain.
* \return reference to self.
*/
- TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
+ IterVar* p_inner); // NOLINT(*)
/*!
* \brief Split the iteration with given number of parts.
*
* \param p_inner The result inner domain.
* \return reference to self.
*/
- TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
+ IterVar* p_inner); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param outer The outer domain to be fused.
* \param order The order of iteration variable.
* \return reference to self.
*/
- TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
+ TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
/*!
* \brief Perform tiling on two dimensions
 * The final loop order from outermost to innermost is
* \param p_y_inner Inner axis of y dimension
* \return reference to self.
*/
- TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
- PrimExpr x_factor, PrimExpr y_factor,
- IterVar* p_x_outer, IterVar* p_y_outer,
- IterVar* p_x_inner, IterVar* p_y_inner);
+ TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
+ PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer,
+ IterVar* p_x_inner, IterVar* p_y_inner);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
- TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
+ TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Replace computation of the current stage by tensor intrinsic f.
* \param var The axis marks beginning of tensorization.
* \param f The Tensor compute intrinsics.
* \return reference to self.
*/
- TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
+ TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be unrolled.
* \return reference to self.
*/
- TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
+ TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
- TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
+ TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief Annotate the iteration with pragma
*
*
* \return reference to self.
*/
- TVM_DLL Stage& pragma(IterVar var,
- const std::string& pragma_type,
- const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
+ TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type,
+ const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
/*!
* \brief Fetch data in advance.
* \param domain the tensor to be prefetched
* \param offset the number of iterations be to fetched in advance
* \return reference to self
*/
- TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*)
+ TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset); // NOLINT(*)
/*!
* \brief Set alignment requirement for specific dimension.
*
* \param offset The required offset factor.
* \return reference to self
*/
- TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
+ TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); // NOLINT(*)
/*!
* \brief Compute current stage with double buffering.
* \return reference to self.
*/
- TVM_DLL Stage& double_buffer(); // NOLINT(*)
+ TVM_DLL Stage& double_buffer(); // NOLINT(*)
/*!
* \brief Schedule for OpenGL fragment shader.
* \return reference to self.
*/
- Stage& opengl(); // NOLINT(*)
+ Stage& opengl(); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
* \param tensor The tensor
* \return The stage corresponding to the tensor's op
*/
- TVM_DLL Stage operator[](const Tensor& tensor) {
- return this->operator[](tensor->op);
- }
+ TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); }
/*!
* \brief Create a new stage group for all intermediate
* operations between inputs and outputs.
* \param include_inputs Whether include inputs if they are reachable from outputs.
* \return The new grouped stage.
*/
- TVM_DLL Stage create_group(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
- bool include_inputs = false);
+ TVM_DLL Stage create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
+ bool include_inputs = false);
/*!
* \brief create a cache read of original tensor for readers.
* This will mutate the body of the readers.
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
- TVM_DLL Tensor cache_read(const Tensor& tensor,
- const std::string& scope,
- const Array<Operation>& readers);
+ TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope,
+ const Array<Operation>& readers);
/*!
* \brief Create a cache write tensor for producing tensor.
 * The tensor will take over the body of the original tensor op.
* \param factor_axis The position where the new axis is placed.
* \return The created factored tensors.
*/
- TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
- const IterVar& axis,
- int factor_axis = 0);
+ TVM_DLL Array<Tensor> rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* \param tensor The candidate tensor.
* \return true if the schedule has the tensor. Otherwise, false.
*/
- TVM_DLL bool Contain(const Tensor& tensor) const {
- return Contain(tensor->op);
- }
+ TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }
/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
-inline Schedule create_schedule(Array<Operation> ops) {
- return ScheduleNode::make(ops);
-}
+inline Schedule create_schedule(Array<Operation> ops) { return ScheduleNode::make(ops); }
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Object {
v->Visit("nparts", &nparts);
}
- static IterVarRelation make(IterVar parent,
- IterVar outer,
- IterVar inner,
- PrimExpr factor,
+ static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
PrimExpr nparts);
static constexpr const char* _type_key = "Split";
v->Visit("fused", &fused);
}
- static IterVarRelation make(
- IterVar outer, IterVar inner, IterVar fused);
+ static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
};
-
/*!
* \brief Singleton iterator [0, 1)
*/
/*! \brief The singleton iterator */
IterVar iter;
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("iter", &iter);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }
static IterVarRelation make(IterVar iter);
*/
Array<PrimExpr> clauses;
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("clauses", &clauses);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); }
static constexpr const char* _type_key = "SpecializedCondition";
TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
};
// implementations
-inline const StageNode* Stage::operator->() const {
- return static_cast<const StageNode*>(get());
-}
-inline StageNode* Stage::operator->() {
- return static_cast<StageNode*>(get_mutable());
-}
+inline const StageNode* Stage::operator->() const { return static_cast<const StageNode*>(get()); }
+inline StageNode* Stage::operator->() { return static_cast<StageNode*>(get_mutable()); }
inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(get());
}
-inline ScheduleNode* Schedule::operator->() {
- return static_cast<ScheduleNode*>(get_mutable());
-}
+inline ScheduleNode* Schedule::operator->() { return static_cast<ScheduleNode*>(get_mutable()); }
inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(get());
 * buffer assignment of inputs and outputs.
* \return Transformed stmt.
*/
-Stmt SchedulePostProcRewriteForTensorCore(
- Stmt stmt,
- Schedule schedule,
- Map<Tensor, Buffer> extern_buffer);
+Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
+ Map<Tensor, Buffer> extern_buffer);
/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* \param body The body of the function.
* \param bindings potential Tensor to Buffer bindings for the Tensors in the body.
*/
-PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
- Stmt body,
+PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
Optional<Map<Tensor, Buffer>> bindings);
} // namespace te
#ifndef TVM_TE_TENSOR_H_
#define TVM_TE_TENSOR_H_
-#include <tvm/node/container.h>
#include <tvm/arith/bound.h>
+#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
-#include <vector>
-#include <utility>
#include <type_traits>
+#include <utility>
+#include <vector>
namespace tvm {
namespace te {
* \param args The indices
* \return the result expression representing tensor read.
*/
- template<typename... Args>
- inline PrimExpr operator()(Args&& ...args) const {
+ template <typename... Args>
+ inline PrimExpr operator()(Args&&... args) const {
Array<PrimExpr> indices{std::forward<Args>(args)...};
return operator()(indices);
}
* This is only valid when all the coordinates are fully specified.
* \return the corresponding expression of this slice.
*/
- inline operator PrimExpr() const {
- return tensor_(indices_);
- }
+ inline operator PrimExpr() const { return tensor_(indices_); }
private:
const Tensor& tensor_;
* \param i the index of the coordinate
* \return the subsequent slice.
*/
- inline Slice operator[](PrimExpr i) const {
- return Slice(*this, {i});
- }
+ inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); }
/*! \brief specify container node */
using ContainerType = TensorNode;
};
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
- TVM_DLL static Tensor make(Array<PrimExpr> shape,
- DataType dtype,
- Operation op,
- int value_index);
+ TVM_DLL static Tensor make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
};
-
// Implementations of inline functions
inline const TensorNode* Tensor::operator->() const {
return static_cast<const TensorNode*>(get());
}
-inline size_t Tensor::ndim() const {
- return (*this)->shape.size();
-}
+inline size_t Tensor::ndim() const { return (*this)->shape.size(); }
inline bool Tensor::operator==(const Tensor& other) const {
if (get() == other.get()) return true;
if (get() == nullptr || other.get() == nullptr) return false;
if ((*this)->op.defined() || other->op.defined()) {
- return (*this)->op == other->op &&
- (*this)->value_index == other->value_index;
+ return (*this)->op == other->op && (*this)->value_index == other->value_index;
} else {
return false;
}
}
-inline bool Tensor::operator!=(const Tensor& other) const {
- return !(*this == other);
-}
+inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); }
// macro to turn every operation of slice to expression
-#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
- inline PrimExpr operator Op (const Tensor::Slice& a) { \
- return Op a.operator PrimExpr() ; \
- } \
+#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
+ inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); }
-#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
- template<typename T> \
- inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
- return a.operator PrimExpr() Op b; \
- } \
- template<typename T> \
- inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \
- return a Op b.operator PrimExpr(); \
- } \
- inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
- return a.operator PrimExpr() Op b.operator PrimExpr(); \
+#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
+ template <typename T> \
+ inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \
+ return a.operator PrimExpr() Op b; \
+ } \
+ template <typename T> \
+ inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \
+ return a Op b.operator PrimExpr(); \
+ } \
+ inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \
+ return a.operator PrimExpr() Op b.operator PrimExpr(); \
}
DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
namespace std {
template <>
-struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash {
-};
+struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash {};
template <>
struct hash<::tvm::te::Tensor> {
::tvm::ObjectHash hasher;
if (k.defined() && k->op.defined()) {
return hasher(k->op);
- } else{
+ } else {
return hasher(k);
}
}
v->Visit("reduce_update", &reduce_update);
}
- TVM_DLL static TensorIntrin make(std::string name,
- Operation op,
- Array<Tensor> inputs,
- Array<Buffer> buffers,
- Array<Var> scalar_params,
- Stmt body,
- Stmt reduce_init,
- Stmt reduce_update);
+ TVM_DLL static TensorIntrin make(std::string name, Operation op, Array<Tensor> inputs,
+ Array<Buffer> buffers, Array<Var> scalar_params, Stmt body,
+ Stmt reduce_init, Stmt reduce_update);
static constexpr const char* _type_key = "TensorIntrin";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
/*! \brief regions of input tensors */
Array<Region> regions;
-
/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
v->Visit("reduce_axis", &reduce_axis);
v->Visit("scalar_inputs", &scalar_inputs);
}
- static TensorIntrinCall make(TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
- Array<IterVar> reduce_axis,
- Array<PrimExpr> scalar_inputs);
+ static TensorIntrinCall make(TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions,
+ Array<IterVar> reduce_axis, Array<PrimExpr> scalar_inputs);
static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
+
#include <string>
namespace tvm {
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
*/
-TVM_DLL bool ExprUseVar(const PrimExpr& expr,
- std::function<bool(const VarNode*)> vset_contains);
+TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
/*!
* \brief Whether e expression used var.
* \return Whether e uses v.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
- return ExprUseVar(expr, [&](const VarNode* node) {
- return var.get() == node;
- });
+ return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
}
-
/*!
* \brief Verifies whether the IR stmt or Expr is in SSA form.
* That is: each Var is defined and assigned once(in Let/For)
* \return valid Whether it is a valid GPU code
*
*/
-TVM_DLL bool VerifyGPUCode(const PrimFunc& func,
- Map<std::string, PrimExpr> constraints);
+TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<std::string, PrimExpr> constraints);
// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
#ifndef TVM_TIR_BUFFER_H_
#define TVM_TIR_BUFFER_H_
-#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
+#include <tvm/node/container.h>
#include <tvm/tir/var.h>
-#include <string>
+#include <string>
namespace tvm {
namespace tir {
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
- TVM_DLL PrimExpr access_ptr(int access_mask,
- DataType ptr_type = DataType::Handle(),
+ TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
/*!
bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
// Use DefEqual as buffer can define variables
// in its semantics, skip name as name is not important.
- return
- equal.DefEqual(data, other->data) &&
- equal(dtype, other->dtype) &&
- equal.DefEqual(shape, other->shape) &&
- equal.DefEqual(strides, other->strides) &&
- equal.DefEqual(elem_offset, other->elem_offset) &&
- equal(scope, other->scope) &&
- equal(data_alignment, other->data_alignment) &&
- equal(buffer_type, other->buffer_type);
+ return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
+ equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
+ equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) &&
+ equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
}
void SHashReduce(SHashReducer hash_reduce) const {
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
- TVM_DLL static Buffer make(Var ptr,
- DataType dtype,
- Array<PrimExpr> shape,
- Array<PrimExpr> strides,
- PrimExpr elem_offset,
- std::string name,
- std::string scope,
- int data_alignment,
- int offset_factor,
+ TVM_DLL static Buffer make(Var ptr, DataType dtype, Array<PrimExpr> shape,
+ Array<PrimExpr> strides, PrimExpr elem_offset, std::string name,
+ std::string scope, int data_alignment, int offset_factor,
BufferType buffer_type);
static constexpr const char* _type_key = "Buffer";
* \return The created buffer.
* \sa BufferNode::make for complete constructor.
*/
-TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape,
- DataType dtype = DataType::Float(32),
+TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
std::string name = "buffer");
} // namespace tir
} // namespace tvm
#ifndef TVM_TIR_DATA_LAYOUT_H_
#define TVM_TIR_DATA_LAYOUT_H_
-
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <string>
+#include <algorithm>
#include <sstream>
-#include <vector>
+#include <string>
#include <utility>
-#include <algorithm>
-
+#include <vector>
namespace tvm {
namespace tir {
}
// return the primal axis. If it is already primal, return itself.
- const LayoutAxis& ToPrimal() const {
- return IsPrimal() ? *this : ToDual();
- }
+ const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); }
// return the subordinate axis. If it is already subordinate, return itself.
- const LayoutAxis& ToSubordinate() const {
- return IsPrimal() ? ToDual() : *this;
- }
+ const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; }
- inline bool operator==(const LayoutAxis& rhs) const {
- return name_ == rhs.name_;
- }
+ inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; }
friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
os << l.name();
explicit Layout(const Array<tir::IterVar>& axes);
/*! \brief construct from a string */
- Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
+ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
/*!
* \brief construct from a string.
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
- Layout(const std::string& name); // NOLINT(*)
+ Layout(const std::string& name); // NOLINT(*)
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
- const LayoutNode* operator->() const {
- return static_cast<const LayoutNode*>(get());
- }
+ const LayoutNode* operator->() const { return static_cast<const LayoutNode*>(get()); }
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
- LayoutNode* operator->() {
- return static_cast<LayoutNode*>(get_mutable());
- }
+ LayoutNode* operator->() { return static_cast<LayoutNode*>(get_mutable()); }
/*!
* \brief Return an undefined layout.
* \param factor size of the sub-dimension.
* \return A newly constructed Layout object.
*/
- Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const;
-
+ Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const;
/*! \return number of dimensions */
inline size_t ndim() const {
* \param rhs Another layout.
* \return whether the two layouts are equal.
*/
- inline bool Equals(const Layout &rhs) const {
- return name() == rhs.name();
- }
+ inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); }
/*!
* \brief allow output string of layout to ostream
#ifndef TVM_TIR_EXPR_H_
#define TVM_TIR_EXPR_H_
-#include <tvm/node/node.h>
+#include <tvm/ir/expr.h>
#include <tvm/node/container.h>
#include <tvm/node/functor.h>
+#include <tvm/node/node.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/data_type.h>
-#include <tvm/ir/expr.h>
-#include <tvm/tir/var.h>
#include <tvm/tir/buffer.h>
+#include <tvm/tir/var.h>
-#include <string>
#include <algorithm>
-#include <unordered_map>
#include <iostream>
#include <limits>
+#include <string>
+#include <unordered_map>
#include <utility>
namespace tvm {
return equal(value, other->value);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(value);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
TVM_DLL PrimExpr static make(std::string value);
* \brief Base template to implement binary ops.
* \tparam T The type of the child class.
*/
-template<typename T>
+template <typename T>
class BinaryOpNode : public PrimExprNode {
public:
/*! \brief The left operand. */
}
bool SEqualReduce(const T* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(a, other->a) &&
- equal(b, other->b);
+ return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \brief Base template to implement comparison ops.
* \tparam T The type of the child class.
*/
-template<typename T>
+template <typename T>
class CmpOpNode : public PrimExprNode {
public:
/*! \brief The left operand. */
}
bool SEqualReduce(const T* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(a, other->a) &&
- equal(b, other->b);
+ return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(a, other->a) &&
- equal(b, other->b);
+ return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(a, other->a) &&
- equal(b, other->b);
+ return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(condition, other->condition) &&
- equal(true_value, other->true_value) &&
- equal(false_value, other->false_value);
+ return equal(dtype, other->dtype) && equal(condition, other->condition) &&
+ equal(true_value, other->true_value) && equal(false_value, other->false_value);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(buffer, other->buffer) &&
- equal(indices, other->indices);
+ return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
+ equal(indices, other->indices);
}
void SHashReduce(SHashReducer hash_reduce) const {
class BufferLoad : public PrimExpr {
public:
- TVM_DLL explicit BufferLoad(Buffer buffer,
- Array<PrimExpr> indices);
+ TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices);
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
};
}
bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(buffer_var, other->buffer_var) &&
- equal(index, other->index) &&
- equal(predicate, other->predicate);
+ return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) &&
+ equal(index, other->index) && equal(predicate, other->predicate);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(base, other->base) &&
- equal(stride, other->stride) &&
- equal(lanes, other->lanes);
+ return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) &&
+ equal(lanes, other->lanes);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(value, other->value) &&
- equal(lanes, other->lanes);
+ return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal.DefEqual(var, other->var) &&
- equal(value, other->value) &&
- equal(body, other->body);
+ return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) &&
+ equal(value, other->value) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
return this == other;
}
- void SHashReduce(SHashReducer hash_reduce) const {
- }
+ void SHashReduce(SHashReducer hash_reduce) const {}
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(name, other->name) &&
- equal(args, other->args) &&
- equal(call_type, other->call_type) &&
- equal(func, other->func) &&
- equal(value_index, other->value_index);
+ return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) &&
+ equal(call_type, other->call_type) && equal(func, other->func) &&
+ equal(value_index, other->value_index);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(value_index);
}
- TVM_DLL static PrimExpr make(DataType dtype,
- std::string name,
- Array<PrimExpr> args,
- CallType call_type,
- FunctionRef func = FunctionRef(),
+ TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array<PrimExpr> args,
+ CallType call_type, FunctionRef func = FunctionRef(),
int value_index = 0);
/*! \return Whether call node is pure. */
bool is_pure() const {
- return (call_type == PureExtern ||
- call_type == PureIntrinsic ||
- call_type == Halide);
+ return (call_type == PureExtern || call_type == PureIntrinsic || call_type == Halide);
}
/*!
* \param intrin_name The name of the intrinsic.
*/
bool is_intrinsic(const char* intrin_name) const {
- return
- ((call_type == Intrinsic ||
- call_type == PureIntrinsic) &&
- name == intrin_name);
+ return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name);
}
/*! \return Whether call node can be vectorized. */
}
bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
- return
- equal(dtype, other->dtype) &&
- equal(vectors, other->vectors) &&
- equal(indices, other->indices);
+ return equal(dtype, other->dtype) && equal(vectors, other->vectors) &&
+ equal(indices, other->indices);
}
void SHashReduce(SHashReducer hash_reduce) const {
/*! \brief Function call operator to combine a and b */
Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
- TVM_DLL static CommReducer make(Array<Var> lhs,
- Array<Var> rhs,
- Array<PrimExpr> result,
+ TVM_DLL static CommReducer make(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
Array<PrimExpr> identity_element);
void VisitAttrs(AttrVisitor* v) {
}
bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
- return
- equal.DefEqual(lhs, other->lhs) &&
- equal.DefEqual(rhs, other->rhs) &&
- equal(result, other->result) &&
- equal(identity_element, other->identity_element);
+ return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) &&
+ equal(result, other->result) && equal(identity_element, other->identity_element);
}
void SHashReduce(SHashReducer hash_reduce) const {
inline const CommReducerNode* CommReducer::get() const {
return static_cast<const CommReducerNode*>(data_.get());
}
-inline const CommReducerNode* CommReducer::operator->() const {
- return get();
-}
+inline const CommReducerNode* CommReducer::operator->() const { return get(); }
/*! \brief Reduction operator operator */
class ReduceNode : public PrimExprNode {
int value_index;
/*! \brief construct expr from op and rdom */
- TVM_DLL static PrimExpr make(CommReducer combiner,
- Array<PrimExpr> src,
- Array<IterVar> rdom,
- PrimExpr condition,
- int value_index);
+ TVM_DLL static PrimExpr make(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom,
+ PrimExpr condition, int value_index);
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
// check axis first so IterVars can define the necessary variables.
- return
- equal(dtype, other->dtype) &&
- equal(axis, other->axis) &&
- equal(combiner, other->combiner) &&
- equal(source, other->source) &&
- equal(condition, other->condition) &&
- equal(value_index, other->value_index);
+ return equal(dtype, other->dtype) && equal(axis, other->axis) &&
+ equal(combiner, other->combiner) && equal(source, other->source) &&
+ equal(condition, other->condition) && equal(value_index, other->value_index);
}
void SHashReduce(SHashReducer hash_reduce) const {
public:
void VisitAttrs(AttrVisitor* v) {}
- bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
- return true;
- }
+ bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { return true; }
- void SHashReduce(SHashReducer hash_reduce) const {
- }
+ void SHashReduce(SHashReducer hash_reduce) const {}
/*! \brief Convert to var. */
- Var ToVar() const {
- return Var("any_dim", DataType::Int(32));
- }
+ Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
TVM_DLL static PrimExpr make();
TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
};
-
/*
* \brief Template function to convert Map to unordered_map
* Sometimes useful for API gluing when internal uses unordered_map
* \tparam K the key of the Map.
* \tparam V the value of the Map.
*/
-template<typename K, typename V>
+template <typename K, typename V>
inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
std::unordered_map<K, V> ret;
for (auto kv : dmap) {
* return 0;
* }
*/
-constexpr const char *tvm_call_trace_packed = "tvm_call_trace_packed";
+constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed";
/*!
 * \brief See pseudo code
* Mark the content as thread local context, can get optimized
* TVMRetValue(value_stack + end, tcode_stack + end));
* }
*/
-constexpr const char *tvm_call_trace_packed_lowered =
- "tvm_call_trace_packed_lowered";
+constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered";
/*!
* \brief See pseudo code
*
kTVMValueContent,
kTVMValueKindBound_
};
-} // namespace intrinsic
+} // namespace intrinsic
} // namespace tir
} // namespace tvm
namespace runtime {
// Additional implementation overloads for PackedFunc.
-template<>
+template <>
struct PackedFuncValueConverter<tvm::Integer> {
// common rule for RetValue and ArgValue
static tvm::Integer From(const TVMPODValue_& val) {
namespace std {
template <>
-struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash {
-};
-}
+struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash {};
+} // namespace std
#endif // TVM_TIR_EXPR_H_
 * \tparam FType function signature
 * This type is only defined for FType with function signature R(const Expr&, Args...)
*/
-template<typename FType>
+template <typename FType>
class ExprFunctor;
// functions to be overriden.
-#define EXPR_FUNCTOR_DEFAULT { \
- return VisitExprDefault_(op, std::forward<Args>(args)...); \
- }
+#define EXPR_FUNCTOR_DEFAULT \
+ { return VisitExprDefault_(op, std::forward<Args>(args)...); }
-#define IR_EXPR_FUNCTOR_DISPATCH(OP) \
- vtable.template set_dispatch<OP>( \
- [](const ObjectRef& n, TSelf* self, Args... args) { \
- return self->VisitExpr_(static_cast<const OP*>(n.get()), \
- std::forward<Args>(args)...); \
- }); \
+#define IR_EXPR_FUNCTOR_DISPATCH(OP) \
+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+ return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+ });
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
class ExprFunctor<R(const PrimExpr& n, Args...)> {
private:
using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExprDefault_(const Object* op, Args ...) {
+ virtual R VisitExprDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
}
/*!
* \brief ExprVisitor
*/
-class TVM_DLL ExprVisitor :
- public ExprFunctor<void(const PrimExpr&)> {
+class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
public:
using ExprFunctor::operator();
/*!
* \brief ExprMutator that mutates expressions.
*/
-class TVM_DLL ExprMutator :
- protected ExprFunctor<PrimExpr(const PrimExpr&)> {
+class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
public:
using ExprFunctor::operator();
#define TVM_TIR_FUNCTION_H_
#include <tvm/ir/function.h>
-#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
-#include <string>
+#include <string>
namespace tvm {
namespace tir {
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
// visit params and buffer_map first as they contains defs.
- return
- equal.DefEqual(params, other->params) &&
- equal(buffer_map, other->buffer_map) &&
- equal(ret_type, other->ret_type) &&
- equal(body, other->body) &&
- equal(attrs, other->attrs);
+ return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
+ equal(ret_type, other->ret_type) && equal(body, other->body) &&
+ equal(attrs, other->attrs);
}
void SHashReduce(SHashReducer hash_reduce) const {
* \param buffer_map The buffer map for parameter buffer unpacking.
* \param attrs Additional function attributes.
*/
- TVM_DLL PrimFunc(Array<tir::Var> params,
- Stmt body,
- Type ret_type = VoidType(),
+ TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = NullValue<Map<tir::Var, Buffer>>(),
DictAttrs attrs = NullValue<DictAttrs>());
#include <tvm/tir/stmt.h>
#include <algorithm>
-#include <type_traits>
#include <limits>
-
+#include <type_traits>
namespace tvm {
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
return tir::CallNode::make(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \
- } \
+ }
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(exp2);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);
-
namespace tir {
/*!
* \brief Make a const value with certain data type.
* \return the result expression.
* \tparam ValueType The constant value type
*/
-template<typename ValueType,
- typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
+template <typename ValueType,
+ typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline PrimExpr make_const(DataType t, ValueType value);
/*!
* \brief Make a const zero expr.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
-inline PrimExpr const_true(int lanes = 1) {
- return make_const(DataType::UInt(1, lanes), 1);
-}
+inline PrimExpr const_true(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 1); }
/*!
* \brief Make a constant false expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
-inline PrimExpr const_false(int lanes = 1) {
- return make_const(DataType::UInt(1, lanes), 0);
-}
+inline PrimExpr const_false(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 0); }
/*!
* \brief Get x as constant int expression.
* \param x The expression
* \note This only return true for integer types.
* \return whether x is constant 1
*/
-inline bool is_one(const PrimExpr& x) {
- return is_const_int(x, 1);
-}
+inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
/*!
* \brief Check whether x is a constant integer 0
* \return whether x is constant 0
* \note This only return true for integer types.
*/
-inline bool is_zero(const PrimExpr& x) {
- return is_const_int(x, 0);
-}
+inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
/*!
* \brief Check whether x is a constant.
return false;
}
-template<typename ValueType>
+template <typename ValueType>
inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
if (t.is_uint()) {
return PrimExpr();
}
-template<typename ValueType, typename>
+template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
- return tir::BroadcastNode::make(
- MakeConstScalar(t.element_of(), value), t.lanes());
+ return tir::BroadcastNode::make(MakeConstScalar(t.element_of(), value), t.lanes());
}
}
} // namespace tir
// additional const expression overloading
-#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
- inline PrimExpr Name(PrimExpr& a, PrimExpr b) {\
- a = OpFunc(a, b); \
- return a; \
+#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
+ inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \
+ a = OpFunc(a, b); \
+ return a; \
}
-#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
- inline PrimExpr Name(const PrimExpr& a, float b) { \
- return Name(a, PrimExpr(b)); \
- } \
- inline PrimExpr Name(float a, const PrimExpr& b) { \
- return Name(PrimExpr(a), b); \
- } \
- inline PrimExpr Name(int a, const PrimExpr& b) { \
- return Name(tir::make_const(b.dtype(), a), b); \
- } \
- inline PrimExpr Name(const PrimExpr& a, int b) { \
- return Name(a, tir::make_const(a.dtype(), b)); \
- } \
- inline PrimExpr Name(const PrimExpr& a, double b) { \
- return Name(a, tir::make_const(DataType::Float(64), b)); \
+#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
+ inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
+ inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
+ inline PrimExpr Name(int a, const PrimExpr& b) { \
+ return Name(tir::make_const(b.dtype(), a), b); \
+ } \
+ inline PrimExpr Name(const PrimExpr& a, int b) { \
+ return Name(a, tir::make_const(a.dtype(), b)); \
+ } \
+ inline PrimExpr Name(const PrimExpr& a, double b) { \
+ return Name(a, tir::make_const(DataType::Float(64), b)); \
}
-#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
- inline PrimExpr Name(const PrimExpr& a, bool b) { \
- return Name(a, PrimExpr(b)); \
- } \
- inline PrimExpr Name(bool a, const PrimExpr& b) { \
- return Name(PrimExpr(a), b); \
- }
+#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
+ inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \
+ inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); }
-#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
- inline PrimExpr Name(const PrimExpr& a, int b) { \
- return Name(a, tir::make_const(a.dtype(), b)); \
- } \
- inline PrimExpr Name(int a, const PrimExpr& b) { \
- return Name(tir::make_const(b.dtype(), a), b); \
- }
+#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
+ inline PrimExpr Name(const PrimExpr& a, int b) { \
+ return Name(a, tir::make_const(a.dtype(), b)); \
+ } \
+ inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); }
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
-TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
-TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
 * \note The call to this function will always result in a compiler error.
* \tparam TA Any class type.
*/
-template<typename TA>
+template <typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
// to use the specific division function.
// The second template argument is necessary to make sure the
// code compiles lazily by the compiler during invocation.
-template<typename TB>
+template <typename TB>
inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
-template<typename TB>
+template <typename TB>
inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
-template<typename TB>
+template <typename TB>
inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
DivAmbiguityError(a);
return a;
#include <tvm/tir/expr.h>
-#include <type_traits>
#include <string>
-#include <vector>
+#include <type_traits>
#include <utility>
+#include <vector>
namespace tvm {
namespace tir {
}
bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
- return
- equal.DefEqual(var, other->var) &&
- equal(value, other->value) &&
- equal(body, other->body);
+ return equal.DefEqual(var, other->var) && equal(value, other->value) &&
+ equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
- return
- equal(node, other->node) &&
- equal(attr_key, other->attr_key) &&
- equal(value, other->value) &&
- equal(body, other->body);
+ return equal(node, other->node) && equal(attr_key, other->attr_key) &&
+ equal(value, other->value) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(body);
}
- TVM_DLL static Stmt make(ObjectRef node,
- std::string type_key,
- PrimExpr value,
- Stmt body);
+ TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);
static constexpr const char* _type_key = "AttrStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
}
bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
- return
- equal(condition, other->condition) &&
- equal(message, other->message) &&
- equal(body, other->body);
+ return equal(condition, other->condition) && equal(message, other->message) &&
+ equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
- return
- equal(buffer_var, other->buffer_var) &&
- equal(value, other->value) &&
- equal(index, other->index) &&
- equal(predicate, other->predicate);
+ return equal(buffer_var, other->buffer_var) && equal(value, other->value) &&
+ equal(index, other->index) && equal(predicate, other->predicate);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(predicate);
}
- TVM_DLL static Stmt make(Var buffer_var,
- PrimExpr value,
- PrimExpr index,
- PrimExpr predicate);
+ TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);
static constexpr const char* _type_key = "Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
}
bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
- return
- equal(buffer, other->buffer) &&
- equal(value, other->value) &&
- equal(indices, other->indices);
+ return equal(buffer, other->buffer) && equal(value, other->value) &&
+ equal(indices, other->indices);
}
void SHashReduce(SHashReducer hash_reduce) const {
*/
class BufferStore : public Stmt {
public:
- TVM_DLL explicit BufferStore(Buffer buffer,
- PrimExpr value,
- Array<PrimExpr> indices);
+ TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
};
}
bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
- return
- equal(buffer, other->buffer) &&
- equal(bounds, other->bounds) &&
- equal(condition, other->condition) &&
- equal(body, other->body);
+ return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
+ equal(condition, other->condition) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
BufferRealizeNode() = default;
- BufferRealizeNode(Buffer buffer,
- Array<Range> bounds,
- PrimExpr condition,
- Stmt body)
- : buffer(buffer), bounds(bounds),
- condition(condition), body(body) {}
+ BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body)
+ : buffer(buffer), bounds(bounds), condition(condition), body(body) {}
static constexpr const char* _type_key = "BufferRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
*/
class BufferRealize : public Stmt {
public:
- TVM_DLL explicit BufferRealize(Buffer buffer,
- Array<Range> bounds,
- PrimExpr condition,
- Stmt body);
+ TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
};
}
bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const {
- return
- equal(func, other->func) &&
- equal(value_index, other->value_index) &&
- equal(value, other->value) &&
- equal(args, other->args);
+ return equal(func, other->func) && equal(value_index, other->value_index) &&
+ equal(value, other->value) && equal(args, other->args);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(args);
}
- TVM_DLL static Stmt make(FunctionRef func,
- int value_index,
- PrimExpr value,
- Array<PrimExpr> args);
+ TVM_DLL static Stmt make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args);
static constexpr const char* _type_key = "Provide";
TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode);
}
bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
- return
- equal.DefEqual(buffer_var, other->buffer_var) &&
- equal(dtype, other->dtype) &&
- equal(extents, other->extents) &&
- equal(condition, other->condition) &&
- equal(body, other->body);
+ return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
+ equal(extents, other->extents) && equal(condition, other->condition) &&
+ equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(body);
}
- TVM_DLL static Stmt make(Var buffer_var,
- DataType dtype,
- Array<PrimExpr> extents,
- PrimExpr condition,
- Stmt body);
+ TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
+ PrimExpr condition, Stmt body);
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \return The result.
*/
- int32_t constant_allocation_size() const {
- return constant_allocation_size(extents);
- }
+ int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \return The result.
*/
- TVM_DLL static int32_t constant_allocation_size(
- const Array<PrimExpr>& extents);
+ TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
static constexpr const char* _type_key = "Allocate";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
/*! \brief The buffer variable. */
Var buffer_var;
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("buffer_var", &buffer_var);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); }
bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
- return
- equal(buffer_var, other->buffer_var);
+ return equal(buffer_var, other->buffer_var);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(buffer_var);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }
TVM_DLL static Stmt make(Var buffer_var);
v->Visit("body", &body);
}
- TVM_DLL static Stmt make(FunctionRef func,
- int value_index,
- DataType dtype,
- Region bounds,
- PrimExpr condition,
- Stmt body);
+ TVM_DLL static Stmt make(FunctionRef func, int value_index, DataType dtype, Region bounds,
+ PrimExpr condition, Stmt body);
bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const {
- return
- equal(func, other->func) &&
- equal(value_index, other->value_index) &&
- equal(dtype, other->dtype) &&
- equal(bounds, other->bounds) &&
- equal(condition, other->condition) &&
- equal(body, other->body);
+ return equal(func, other->func) && equal(value_index, other->value_index) &&
+ equal(dtype, other->dtype) && equal(bounds, other->bounds) &&
+ equal(condition, other->condition) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
Array<Stmt> seq;
/*! \return get the size of the sequence */
- size_t size() const {
- return seq.size();
- }
+ size_t size() const { return seq.size(); }
/*!
* \brief Get the index-th element in the sequence.
*/
- Stmt operator[](size_t index) const {
- return seq[index];
- }
+ Stmt operator[](size_t index) const { return seq[index]; }
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("seq", &seq);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("seq", &seq); }
bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
return equal(seq, other->seq);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(seq);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
TVM_DLL explicit SeqStmt(Array<Stmt> seq);
/*! \return get the size of the sequence */
- size_t size() const {
- return operator->()->size();
- }
+ size_t size() const { return operator->()->size(); }
/*!
* \brief Get the index-th element in the sequence.
*/
- Stmt operator[](size_t index) const {
- return (*(operator->()))[index];
- }
+ Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
/*!
* \brief Construct a sequence statement by flattening
* all the arrays and sequences in the arguments
* \tparam Args arguments
* \return The constructed statement
*/
- template<typename ...Args>
+ template <typename... Args>
static Stmt Flatten(Args&&... seq_args) {
Array<Stmt> seq;
- runtime::detail::for_each(
- Flattener(&seq), std::forward<Args>(seq_args)...);
+ runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
if (seq.size() == 1) return seq[0];
return SeqStmt(seq);
}
/*! \brief Helper class to flatten sequence of arguments into Array. */
class Flattener {
public:
- explicit Flattener(Array<Stmt>* seq)
- : seq_(seq) {}
+ explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
void operator()(size_t i, const Stmt& stmt) const {
if (!stmt.defined()) return;
}
}
- template<typename T>
+ template <typename T>
void operator()(size_t i, const T& seq) const {
for (auto v : seq) {
this->operator()(0, v);
}
bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
- return
- equal(condition, other->condition) &&
- equal(then_case, other->then_case) &&
- equal(else_case, other->else_case);
+ return equal(condition, other->condition) && equal(then_case, other->then_case) &&
+ equal(else_case, other->else_case);
}
void SHashReduce(SHashReducer hash_reduce) const {
/*! \brief The expression to be evaluated. */
PrimExpr value;
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("value", &value);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); }
bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(value);
- }
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
TVM_DLL static Stmt make(PrimExpr v);
// Device api of for loop
// kept for backward compatibility
// consider refactor and remove later.
-enum class DeviceAPI: int {
- None = 0
-};
+enum class DeviceAPI : int { None = 0 };
/*!
 * \brief A for loop, with possible type annotations.
/*! \brief The body of the for loop. */
Stmt body;
- TVM_DLL static Stmt make(Var loop_var,
- PrimExpr min,
- PrimExpr extent,
- ForType for_type,
- DeviceAPI device_api,
- Stmt body);
+ TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
+ DeviceAPI device_api, Stmt body);
void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
}
bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
- return
- equal.DefEqual(loop_var, other->loop_var) &&
- equal(min, other->min) &&
- equal(extent, other->extent) &&
- equal(for_type, other->for_type) &&
- equal(device_api, other->device_api) &&
- equal(body, other->body);
+ return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
+ equal(extent, other->extent) && equal(for_type, other->for_type) &&
+ equal(device_api, other->device_api) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(body);
}
-
static constexpr const char* _type_key = "For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
}
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
- return
- equal(buffer, other->buffer) &&
- equal(bounds, other->bounds);
+ return equal(buffer, other->buffer) && equal(bounds, other->bounds);
}
void SHashReduce(SHashReducer hash_reduce) const {
}
PrefetchNode() = default;
- PrefetchNode(Buffer buffer, Array<Range> bounds)
- : buffer(buffer), bounds(bounds) {}
+ PrefetchNode(Buffer buffer, Array<Range> bounds) : buffer(buffer), bounds(bounds) {}
static constexpr const char* _type_key = "Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
* \return Expr a expression with dtype.
*/
inline PrimExpr TypeAnnotation(DataType dtype) {
- return tir::CallNode::make(dtype,
- "type_annotation", {},
- tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic);
}
// overload printing of for type.
#ifndef TVM_TIR_STMT_FUNCTOR_H_
#define TVM_TIR_STMT_FUNCTOR_H_
-#include <tvm/node/functor.h>
#include <tvm/node/container.h>
+#include <tvm/node/functor.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
-#include <utility>
#include <unordered_map>
+#include <utility>
namespace tvm {
namespace tir {
* \tparam FType The function signature.
* \sa ExprFunctor
*/
-template<typename FType>
+template <typename FType>
class StmtFunctor;
-#define STMT_FUNCTOR_DEFAULT { \
- return VisitStmtDefault_(op, std::forward<Args>(args)...); \
- }
-
-#define IR_STMT_FUNCTOR_DISPATCH(OP) \
- vtable.template set_dispatch<OP>( \
- [](const ObjectRef& n, TSelf* self, Args... args) { \
- return self->VisitStmt_(static_cast<const OP*>(n.get()), \
- std::forward<Args>(args)...); \
- }); \
+#define STMT_FUNCTOR_DEFAULT \
+ { return VisitStmtDefault_(op, std::forward<Args>(args)...); }
+#define IR_STMT_FUNCTOR_DISPATCH(OP) \
+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+ return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+ });
-template<typename R, typename ...Args>
+template <typename R, typename... Args>
class StmtFunctor<R(const Stmt& n, Args... args)> {
private:
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
* \param args Additional arguments.
* \return The result of the call
*/
- R operator()(const Stmt& n, Args... args) {
- return VisitStmt(n, std::forward<Args>(args)...);
- }
+ R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward<Args>(args)...); }
/*!
* \brief The functor call.
* \param n The stmt node.
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmtDefault_(const Object* op, Args ...) {
+ virtual R VisitStmtDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
}
/*!
* \brief StmtVisitor.
*/
-class TVM_DLL StmtVisitor :
- protected StmtFunctor<void(const Stmt&)> {
+class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
public:
using StmtFunctor::operator();
/*!
* \brief StmtMutator that mutates the statements.
*/
-class TVM_DLL StmtMutator :
- protected StmtFunctor<Stmt(const Stmt&)> {
+class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
public:
/*!
* \brief Mutate stmt.
*
* \return The result object pointer.
*/
- template<typename TNode>
+ template <typename TNode>
ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
if (allow_copy_on_write_) {
// return the old node.
* or have a class sub-class both StmtMutator and ExprMutator
* and redirect Mutate to ExprMutator::Mutate(Expr)
*/
- virtual PrimExpr VisitExpr(const PrimExpr& e) {
- return e;
- }
+ virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
// statement visitor
Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const IfThenElseNode* op) override;
* \param fmutate The mutate function, can be nullptr, which defaults to Visit.
* \return The mutated result.
*/
- Stmt VisitSeqStmt_(const SeqStmtNode* op,
- bool flatten_before_visit,
+ Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
std::function<Stmt(const Stmt&)> fmutate = nullptr);
// internal helper.
class Internal;
/*!
* \brief Visitor that recursively visit stmts and exprs on them.
*/
-class StmtExprVisitor :
- public StmtVisitor,
- public ExprVisitor {
+class StmtExprVisitor : public StmtVisitor, public ExprVisitor {
public:
using StmtVisitor::operator();
using ExprVisitor::operator();
protected:
- using StmtVisitor::VisitStmt;
using ExprVisitor::VisitExpr;
+ using StmtVisitor::VisitStmt;
- void VisitExpr(const PrimExpr& e) override {
- return ExprVisitor::VisitExpr(e);
- }
+ void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); }
};
/*!
* \brief Mutator that recursively mutates stmts and exprs on them.
*/
-class StmtExprMutator :
- public StmtMutator,
- public ExprMutator {
+class StmtExprMutator : public StmtMutator, public ExprMutator {
public:
using StmtMutator::operator();
using ExprMutator::operator();
protected:
- using StmtMutator::VisitExpr;
using ExprMutator::VisitExpr;
+ using StmtMutator::VisitExpr;
- PrimExpr VisitExpr(const PrimExpr& e) override {
- return ExprMutator::VisitExpr(e);
- }
+ PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); }
};
/*!
* If it is not null, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
-TVM_DLL Stmt IRTransform(Stmt stmt,
- const runtime::PackedFunc& preorder,
+TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
Optional<Array<String>> only_enable = NullOpt);
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
* \return The converted form.
*/
-TVM_DLL Stmt Substitute(Stmt stmt,
- std::function<Optional<PrimExpr>(const Var& var)> vmap);
+TVM_DLL Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var& var)> vmap);
/*!
* \brief Substitute the var specified by vmap.
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
* \return The result.
*/
-TVM_DLL PrimExpr Substitute(PrimExpr expr,
- std::function<Optional<PrimExpr>(const Var& var)> vmap);
+TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var& var)> vmap);
/*!
* \brief Sugar for substitute via a given map.
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
-template<typename T>
+template <typename T>
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var);
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
-template<typename T>
-inline T Substitute(T input,
- const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
+template <typename T>
+inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var.get());
if (it != value_map.end()) return (*it).second;
namespace transform {
using tvm::transform::Pass;
-using tvm::transform::PassNode;
-using tvm::transform::PassInfo;
-using tvm::transform::PassInfoNode;
using tvm::transform::PassContext;
using tvm::transform::PassContextNode;
+using tvm::transform::PassInfo;
+using tvm::transform::PassInfoNode;
+using tvm::transform::PassNode;
using tvm::transform::Sequential;
/*
*
* \return The created function pass.
*/
-TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
- PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const tvm::Array<runtime::String>& required);
-
+TVM_DLL Pass CreatePrimFuncPass(
+ const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
+ int opt_level, const std::string& name, const tvm::Array<runtime::String>& required);
/*!
* \brief Inject prefetch instructions into stmt.
*
* \return The Pass
*/
-TVM_DLL Pass StorageFlatten(int cache_line_size,
- bool create_bound_attribute = false);
+TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false);
/*!
* \brief Inject copy intrinsics with optional pad.
* Expr pad_value)
* \return The pass.
*/
-TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
- runtime::PackedFunc fintrin);
+TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, runtime::PackedFunc fintrin);
/*!
* \brief Detect and insert sync points to co-processor.
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return The pass.
*/
-TVM_DLL Pass UnrollLoop(int auto_max_step,
- int auto_max_depth,
- int auto_max_extent,
+TVM_DLL Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent,
bool explicit_unroll);
/*!
TVM_DLL Pass RewriteUnsafeSelect();
/*!
-* \brief Run arithmetic simplifications on the statements and expressions.
-*
-* \return The pass.
-*/
+ * \brief Run arithmetic simplifications on the statements and expressions.
+ *
+ * \return The pass.
+ */
TVM_DLL Pass Simplify();
/*!
-* \brief Instruments bound checkers.
-*
-* \return The pass.
-*/
+ * \brief Instruments bound checkers.
+ *
+ * \return The pass.
+ */
TVM_DLL Pass InstrumentBoundCheckers();
/*!
*/
TVM_DLL Pass ThreadSync(std::string storage_scope);
-
/*!
 * \brief Lower cross thread allreduce.
*
*/
TVM_DLL Pass CombineContextCall();
-
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
#ifndef TVM_TIR_VAR_H_
#define TVM_TIR_VAR_H_
+#include <tvm/ir/expr.h>
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
-#include <tvm/ir/expr.h>
+
#include <string>
namespace tvm {
* \param name_hint variable name
* \param dtype data type
*/
- TVM_DLL explicit Var(std::string name_hint = "v",
- DataType dtype = DataType::Int(32));
+ TVM_DLL explicit Var(std::string name_hint = "v", DataType dtype = DataType::Int(32));
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
- const VarNode* operator->() const {
- return get();
- }
+ const VarNode* operator->() const { return get(); }
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
- const VarNode* get() const {
- return static_cast<const VarNode*>(data_.get());
- }
+ const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); }
/*! \brief type indicate the container type */
using ContainerType = VarNode;
};
* \param name_hint variable name
* \param t data type
*/
- TVM_DLL explicit SizeVar(std::string name_hint = "s",
- DataType t = DataType::Int(32));
+ TVM_DLL explicit SizeVar(std::string name_hint = "s", DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
- const SizeVarNode* operator->() const {
- return get();
- }
+ const SizeVarNode* operator->() const { return get(); }
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
- const SizeVarNode* get() const {
- return static_cast<const SizeVarNode*>(data_.get());
- }
+ const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); }
/*! \brief type indicate the container type */
using ContainerType = SizeVarNode;
};
-
/*! \brief container class of iteration variable. */
class IterVarNode;
}
bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
- return
- equal(dom, other->dom) &&
- equal.DefEqual(var, other->var) &&
- equal(iter_type, other->iter_type) &&
- equal(thread_tag, other->thread_tag);
+ return equal(dom, other->dom) && equal.DefEqual(var, other->var) &&
+ equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(thread_tag);
}
- TVM_DLL static IterVar make(Range dom, Var var,
- IterVarType iter_type,
+ TVM_DLL static IterVar make(Range dom, Var var, IterVarType iter_type,
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar";
return static_cast<const IterVarNode*>(data_.get());
}
-inline IterVar::operator PrimExpr() const {
- return (*this)->var;
-}
+inline IterVar::operator PrimExpr() const { return (*this)->var; }
inline const char* IterVarType2String(IterVarType t) {
switch (t) {
- case kDataPar: return "DataPar";
- case kThreadIndex: return "ThreadIndex";
- case kCommReduce: return "CommReduce";
- case kOrdered: return "Ordered";
- case kOpaque: return "Opaque";
- case kUnrolled: return "Unrolled";
- case kVectorized: return "Vectorized";
- case kParallelized: return "Parallelized";
- case kTensorized: return "Tensorized";
+ case kDataPar:
+ return "DataPar";
+ case kThreadIndex:
+ return "ThreadIndex";
+ case kCommReduce:
+ return "CommReduce";
+ case kOrdered:
+ return "Ordered";
+ case kOpaque:
+ return "Opaque";
+ case kUnrolled:
+ return "Unrolled";
+ case kVectorized:
+ return "Vectorized";
+ case kParallelized:
+ return "Parallelized";
+ case kTensorized:
+ return "Tensorized";
}
return "Unknown";
}
#define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
// Helper functions for RefXXX getter & setter
-jlong getLongField(JNIEnv *env, jobject obj) {
+jlong getLongField(JNIEnv* env, jobject obj) {
jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
jlong ret = env->GetLongField(obj, refFid);
return ret;
}
-jint getIntField(JNIEnv *env, jobject obj) {
+jint getIntField(JNIEnv* env, jobject obj) {
jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
jint ret = env->GetIntField(obj, refFid);
return ret;
}
-void setIntField(JNIEnv *env, jobject obj, jint value) {
+void setIntField(JNIEnv* env, jobject obj, jint value) {
jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
env->SetIntField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}
-void setLongField(JNIEnv *env, jobject obj, jlong value) {
+void setLongField(JNIEnv* env, jobject obj, jlong value) {
jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
env->SetLongField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}
-void setStringField(JNIEnv *env, jobject obj, const char *value) {
+void setStringField(JNIEnv* env, jobject obj, const char* value) {
jclass refClass = env->FindClass("org/apache/tvm/Base$RefString");
jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;");
env->SetObjectField(obj, refFid, env->NewStringUTF(value));
}
// Helper functions for TVMValue
-jlong getTVMValueLongField(JNIEnv *env, jobject obj,
- const char *clsname = "org/apache/tvm/TVMValueLong") {
+jlong getTVMValueLongField(JNIEnv* env, jobject obj,
+ const char* clsname = "org/apache/tvm/TVMValueLong") {
jclass cls = env->FindClass(clsname);
jfieldID fid = env->GetFieldID(cls, "value", "J");
jlong ret = env->GetLongField(obj, fid);
return ret;
}
-jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) {
+jdouble getTVMValueDoubleField(JNIEnv* env, jobject obj) {
jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble");
jfieldID fid = env->GetFieldID(cls, "value", "D");
jdouble ret = env->GetDoubleField(obj, fid);
return ret;
}
-jstring getTVMValueStringField(JNIEnv *env, jobject obj) {
+jstring getTVMValueStringField(JNIEnv* env, jobject obj) {
jclass cls = env->FindClass("org/apache/tvm/TVMValueString");
jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;");
jstring ret = static_cast<jstring>(env->GetObjectField(obj, fid));
return ret;
}
-jobject newTVMValueHandle(JNIEnv *env, jlong value) {
+jobject newTVMValueHandle(JNIEnv* env, jlong value) {
jclass cls = env->FindClass("org/apache/tvm/TVMValueHandle");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
return object;
}
-jobject newTVMValueLong(JNIEnv *env, jlong value) {
+jobject newTVMValueLong(JNIEnv* env, jlong value) {
jclass cls = env->FindClass("org/apache/tvm/TVMValueLong");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
return object;
}
-jobject newTVMValueDouble(JNIEnv *env, jdouble value) {
+jobject newTVMValueDouble(JNIEnv* env, jdouble value) {
jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(D)V");
jobject object = env->NewObject(cls, constructor, value);
return object;
}
-jobject newTVMValueString(JNIEnv *env, const char *value) {
+jobject newTVMValueString(JNIEnv* env, const char* value) {
jstring jvalue = env->NewStringUTF(value);
jclass cls = env->FindClass("org/apache/tvm/TVMValueString");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(Ljava/lang/String;)V");
return object;
}
-jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) {
+jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) {
jbyteArray jarr = env->NewByteArray(arr->size);
env->SetByteArrayRegion(jarr, 0, arr->size,
- reinterpret_cast<jbyte *>(const_cast<char *>(arr->data)));
+ reinterpret_cast<jbyte*>(const_cast<char*>(arr->data)));
jclass cls = env->FindClass("org/apache/tvm/TVMValueBytes");
jmethodID constructor = env->GetMethodID(cls, "<init>", "([B)V");
jobject object = env->NewObject(cls, constructor, jarr);
return object;
}
-jobject newModule(JNIEnv *env, jlong value) {
+jobject newModule(JNIEnv* env, jlong value) {
jclass cls = env->FindClass("org/apache/tvm/Module");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
return object;
}
-jobject newFunction(JNIEnv *env, jlong value) {
+jobject newFunction(JNIEnv* env, jlong value) {
jclass cls = env->FindClass("org/apache/tvm/Function");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
return object;
}
-jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) {
+jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) {
jclass cls = env->FindClass("org/apache/tvm/NDArrayBase");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
jobject object = env->NewObject(cls, constructor, handle, isview);
return object;
}
-jobject newObject(JNIEnv *env, const char *clsname) {
+jobject newObject(JNIEnv* env, const char* clsname) {
jclass cls = env->FindClass(clsname);
jmethodID constructor = env->GetMethodID(cls, "<init>", "()V");
jobject object = env->NewObject(cls, constructor);
return object;
}
-void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) {
+void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) {
jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType");
dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I")));
dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I")));
env->DeleteLocalRef(tvmTypeClass);
}
-void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
+void fromJavaContext(JNIEnv* env, jobject jctx, TVMContext* ctx) {
jclass tvmContextClass = env->FindClass("org/apache/tvm/TVMContext");
- ctx->device_type = static_cast<DLDeviceType>(env->GetIntField(jctx,
- env->GetFieldID(tvmContextClass, "deviceType", "I")));
- ctx->device_id = static_cast<int>(env->GetIntField(jctx,
- env->GetFieldID(tvmContextClass, "deviceId", "I")));
+ ctx->device_type = static_cast<DLDeviceType>(
+ env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceType", "I")));
+ ctx->device_id =
+ static_cast<int>(env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceId", "I")));
env->DeleteLocalRef(tvmContextClass);
}
-jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
+jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) {
switch (tcode) {
case kDLUInt:
case kDLInt:
case kTVMStr:
return newTVMValueString(env, value.v_str);
case kTVMBytes:
- return newTVMValueBytes(env, reinterpret_cast<TVMByteArray *>(value.v_handle));
+ return newTVMValueBytes(env, reinterpret_cast<TVMByteArray*>(value.v_handle));
case kTVMNullptr:
return newObject(env, "org/apache/tvm/TVMValueNull");
default:
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>
#endif
-#include <iostream>
#include <cstring>
-#include <vector>
+#include <iostream>
#include <thread>
+#include <vector>
#include "jni_helper_func.h"
-JavaVM *_jvm;
-void *_tvmHandle = nullptr;
+JavaVM* _jvm;
+void* _tvmHandle = nullptr;
struct TVMFuncArgsThreadLocalEntry {
std::vector<TVMValue> tvmFuncArgValues;
std::vector<int> tvmFuncArgTypes;
// for later release
- std::vector<std::pair<jstring, const char *> > tvmFuncArgPushedStrs;
- std::vector<std::pair<jbyteArray, TVMByteArray *> > tvmFuncArgPushedBytes;
+ std::vector<std::pair<jstring, const char*> > tvmFuncArgPushedStrs;
+ std::vector<std::pair<jbyteArray, TVMByteArray*> > tvmFuncArgPushedBytes;
};
typedef dmlc::ThreadLocalStore<TVMFuncArgsThreadLocalEntry> TVMFuncArgsThreadLocalStore;
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit
- (JNIEnv *env, jobject obj, jstring jtvmLibFile) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit(JNIEnv* env, jobject obj,
+ jstring jtvmLibFile) {
if (_tvmHandle == NULL && !env->IsSameObject(jtvmLibFile, NULL)) {
- const char *tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0);
+ const char* tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0);
_tvmHandle = dlopen(tvmLibFile, RTLD_LAZY | RTLD_GLOBAL);
env->ReleaseStringUTFChars(jtvmLibFile, tvmLibFile);
if (!_tvmHandle) {
return env->GetJavaVM(&_jvm);
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv *env, jobject obj) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv* env, jobject obj) {
if (_tvmHandle) {
dlclose(_tvmHandle);
}
return 0;
}
-JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv * env, jobject obj) {
+JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv* env, jobject obj) {
return env->NewStringUTF(TVMGetLastError());
}
// Function
-JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong(
- JNIEnv *env, jobject obj, jlong arg) {
+JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong(JNIEnv* env, jobject obj,
+ jlong arg) {
TVMValue value;
value.v_int64 = static_cast<int64_t>(arg);
- TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kDLInt);
}
-JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble(
- JNIEnv *env, jobject obj, jdouble arg) {
+JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble(JNIEnv* env, jobject obj,
+ jdouble arg) {
TVMValue value;
value.v_float64 = static_cast<double>(arg);
- TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kDLFloat);
}
-JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(
- JNIEnv *env, jobject obj, jstring arg) {
+JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(JNIEnv* env, jobject obj,
+ jstring arg) {
TVMValue value;
jstring garg = reinterpret_cast<jstring>(env->NewGlobalRef(arg));
value.v_str = env->GetStringUTFChars(garg, 0);
- TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kTVMStr);
// release string args later
e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str));
}
-JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(
- JNIEnv *env, jobject obj, jlong arg, jint argType) {
+JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* env, jobject obj,
+ jlong arg, jint argType) {
TVMValue value;
- value.v_handle = reinterpret_cast<void *>(arg);
- TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
+ value.v_handle = reinterpret_cast<void*>(arg);
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(static_cast<int>(argType));
}
-JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(
- JNIEnv *env, jobject obj, jbyteArray arg) {
+JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj,
+ jbyteArray arg) {
jbyteArray garg = reinterpret_cast<jbyteArray>(env->NewGlobalRef(arg));
- jbyte *data = env->GetByteArrayElements(garg, 0);
+ jbyte* data = env->GetByteArrayElements(garg, 0);
- TVMByteArray *byteArray = new TVMByteArray();
+ TVMByteArray* byteArray = new TVMByteArray();
byteArray->size = static_cast<size_t>(env->GetArrayLength(garg));
- byteArray->data = reinterpret_cast<const char *>(data);
+ byteArray->data = reinterpret_cast<const char*>(data);
TVMValue value;
- value.v_handle = reinterpret_cast<void *>(byteArray);
+ value.v_handle = reinterpret_cast<void*>(byteArray);
- TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kTVMBytes);
// release (garg, data), byteArray later
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames(
- JNIEnv *env, jobject obj, jobject jfuncNames) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames(JNIEnv* env, jobject obj,
+ jobject jfuncNames) {
int outSize;
- const char **outArray;
+ const char** outArray;
int ret = TVMFuncListGlobalNames(&outSize, &outArray);
if (ret) {
return ret;
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree(
- JNIEnv *env, jobject obj, jlong jhandle) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree(JNIEnv* env, jobject obj,
+ jlong jhandle) {
return TVMFuncFree(reinterpret_cast<TVMFunctionHandle>(jhandle));
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal(
- JNIEnv *env, jobject obj, jstring jname, jobject jhandle) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal(JNIEnv* env, jobject obj,
+ jstring jname,
+ jobject jhandle) {
TVMFunctionHandle handle;
- const char *name = env->GetStringUTFChars(jname, 0);
+ const char* name = env->GetStringUTFChars(jname, 0);
int ret = TVMFuncGetGlobal(name, &handle);
env->ReleaseStringUTFChars(jname, name);
setLongField(env, jhandle, reinterpret_cast<jlong>(handle));
return ret;
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall(
- JNIEnv *env, jobject obj, jlong jhandle, jobject jretVal) {
- TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall(JNIEnv* env, jobject obj,
+ jlong jhandle, jobject jretVal) {
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
int numArgs = e->tvmFuncArgValues.size();
TVMValue retVal;
e->tvmFuncArgTypes.clear();
e->tvmFuncArgValues.clear();
- int ret = TVMFuncCall(reinterpret_cast<TVMFunctionHandle>(jhandle),
- &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode);
+ int ret = TVMFuncCall(reinterpret_cast<TVMFunctionHandle>(jhandle), &argValues[0], &argTypes[0],
+ numArgs, &retVal, &retTypeCode);
if (ret != 0) {
return ret;
env->DeleteGlobalRef(iter->first);
}
for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) {
- env->ReleaseByteArrayElements(iter->first,
- reinterpret_cast<jbyte *>(const_cast<char *>(iter->second->data)), 0);
+ env->ReleaseByteArrayElements(
+ iter->first, reinterpret_cast<jbyte*>(const_cast<char*>(iter->second->data)), 0);
env->DeleteGlobalRef(iter->first);
delete iter->second;
}
// return TVMValue object to Java
jclass refTVMValueCls = env->FindClass("org/apache/tvm/Base$RefTVMValue");
- jfieldID refTVMValueFid
- = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;");
+ jfieldID refTVMValueFid = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;");
env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, retVal, retTypeCode));
}
// Callback function
-extern "C" int funcInvokeCallback(TVMValue *args,
- int *typeCodes, int numArgs, TVMRetValueHandle ret, void *resourceHandle) {
- JNIEnv *env;
- int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6);
+extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs,
+ TVMRetValueHandle ret, void* resourceHandle) {
+ JNIEnv* env;
+ int jniStatus = _jvm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
if (jniStatus == JNI_EDETACHED) {
- #ifdef TVM4J_ANDROID
+#ifdef TVM4J_ANDROID
_jvm->AttachCurrentThread(&env, nullptr);
- #else
- _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr);
- #endif
+#else
+ _jvm->AttachCurrentThread(reinterpret_cast<void**>(&env), nullptr);
+#endif
} else {
CHECK(jniStatus == JNI_OK);
}
for (int i = 0; i < numArgs; ++i) {
TVMValue arg = args[i];
int tcode = typeCodes[i];
- if (tcode == kTVMObjectHandle ||
- tcode == kTVMPackedFuncHandle ||
- tcode == kTVMObjectRValueRefArg ||
- tcode == kTVMModuleHandle) {
+ if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle ||
+ tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle) {
TVMCbArgToReturn(&arg, &tcode);
}
jobject jarg = tvmRetValueToJava(env, arg, tcode);
}
jclass clsFunc = env->FindClass("org/apache/tvm/Function");
- jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID(clsFunc, "invokeRegisteredCbFunc",
+ jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID(
+ clsFunc, "invokeRegisteredCbFunc",
"(Lorg/apache/tvm/Function$Callback;[Lorg/apache/tvm/TVMValue;)Ljava/lang/Object;");
- jmethodID pushArgToStack = env->GetStaticMethodID(clsFunc, "pushArgToStack",
- "(Ljava/lang/Object;)V");
+ jmethodID pushArgToStack =
+ env->GetStaticMethodID(clsFunc, "pushArgToStack", "(Ljava/lang/Object;)V");
jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc,
- reinterpret_cast<jobject>(resourceHandle), jargs);
+ reinterpret_cast<jobject>(resourceHandle), jargs);
- TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size();
const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size();
// release allocated strings.
if (e->tvmFuncArgPushedStrs.size() > prevNumStrArg) {
- const auto &pairArg = e->tvmFuncArgPushedStrs.back();
+ const auto& pairArg = e->tvmFuncArgPushedStrs.back();
env->ReleaseStringUTFChars(pairArg.first, pairArg.second);
env->DeleteGlobalRef(pairArg.first);
e->tvmFuncArgPushedStrs.pop_back();
}
// release allocated bytes.
if (e->tvmFuncArgPushedBytes.size() > prevNumBytesArg) {
- const auto &pairArg = e->tvmFuncArgPushedBytes.back();
- env->ReleaseByteArrayElements(pairArg.first,
- reinterpret_cast<jbyte *>(const_cast<char *>(pairArg.second->data)), 0);
+ const auto& pairArg = e->tvmFuncArgPushedBytes.back();
+ env->ReleaseByteArrayElements(
+ pairArg.first, reinterpret_cast<jbyte*>(const_cast<char*>(pairArg.second->data)), 0);
env->DeleteGlobalRef(pairArg.first);
delete pairArg.second;
e->tvmFuncArgPushedBytes.pop_back();
}
// Free callback function
-extern "C" void funcFreeCallback(void *resourceHandle) {
- JNIEnv *env;
- int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6);
+extern "C" void funcFreeCallback(void* resourceHandle) {
+ JNIEnv* env;
+ int jniStatus = _jvm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
if (jniStatus == JNI_EDETACHED) {
- #ifdef TVM4J_ANDROID
+#ifdef TVM4J_ANDROID
_jvm->AttachCurrentThread(&env, nullptr);
- #else
- _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr);
- #endif
+#else
+ _jvm->AttachCurrentThread(reinterpret_cast<void**>(&env), nullptr);
+#endif
} else {
CHECK(jniStatus == JNI_OK);
}
env->DeleteGlobalRef(reinterpret_cast<jobject>(resourceHandle));
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc(
- JNIEnv *env, jobject obj, jobject jfunction, jobject jretHandle) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc(JNIEnv* env, jobject obj,
+ jobject jfunction,
+ jobject jretHandle) {
TVMFunctionHandle out;
- int ret = TVMFuncCreateFromCFunc(reinterpret_cast<TVMPackedCFunc>(&funcInvokeCallback),
- reinterpret_cast<void *>(env->NewGlobalRef(jfunction)),
- reinterpret_cast<TVMPackedCFuncFinalizer>(&funcFreeCallback),
- &out);
+ int ret =
+ TVMFuncCreateFromCFunc(reinterpret_cast<TVMPackedCFunc>(&funcInvokeCallback),
+ reinterpret_cast<void*>(env->NewGlobalRef(jfunction)),
+ reinterpret_cast<TVMPackedCFuncFinalizer>(&funcFreeCallback), &out);
setLongField(env, jretHandle, reinterpret_cast<jlong>(out));
return ret;
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal(
- JNIEnv *env, jobject obj, jstring jname, jlong jhandle, jint joverride) {
- const char *name = env->GetStringUTFChars(jname, 0);
- int ret = TVMFuncRegisterGlobal(
- name, reinterpret_cast<TVMFunctionHandle>(jhandle), reinterpret_cast<int>(joverride));
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal(JNIEnv* env, jobject obj,
+ jstring jname,
+ jlong jhandle,
+ jint joverride) {
+ const char* name = env->GetStringUTFChars(jname, 0);
+ int ret = TVMFuncRegisterGlobal(name, reinterpret_cast<TVMFunctionHandle>(jhandle),
+ reinterpret_cast<int>(joverride));
env->ReleaseStringUTFChars(jname, name);
return ret;
}
// Module
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree(
- JNIEnv *env, jobject obj, jlong jhandle) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree(JNIEnv* env, jobject obj,
+ jlong jhandle) {
return TVMModFree(reinterpret_cast<TVMModuleHandle>(jhandle));
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport(
- JNIEnv *env, jobject obj, jlong jmod, jlong jdep) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport(JNIEnv* env, jobject obj,
+ jlong jmod, jlong jdep) {
return TVMModImport(reinterpret_cast<TVMModuleHandle>(jmod),
reinterpret_cast<TVMModuleHandle>(jdep));
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction(
- JNIEnv *env, jobject obj, jlong jhandle, jstring jname, jint jimport, jobject jret) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction(JNIEnv* env, jobject obj,
+ jlong jhandle, jstring jname,
+ jint jimport, jobject jret) {
TVMFunctionHandle retFunc;
- const char *name = env->GetStringUTFChars(jname, 0);
- int ret = TVMModGetFunction(reinterpret_cast<TVMFunctionHandle>(jhandle),
- name,
- reinterpret_cast<int>(jimport),
- &retFunc);
+ const char* name = env->GetStringUTFChars(jname, 0);
+ int ret = TVMModGetFunction(reinterpret_cast<TVMFunctionHandle>(jhandle), name,
+ reinterpret_cast<int>(jimport), &retFunc);
env->ReleaseStringUTFChars(jname, name);
setLongField(env, jret, reinterpret_cast<jlong>(retFunc));
}
// NDArray
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree(
- JNIEnv *env, jobject obj, jlong jhandle) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree(JNIEnv* env, jobject obj,
+ jlong jhandle) {
return TVMArrayFree(reinterpret_cast<TVMArrayHandle>(jhandle));
}
// Allocates a TVM NDArray with the given shape, dtype (code/bits/lanes) and
// device (type/id), writing the new handle into `jret`.
// Returns the TVM status code (0 on success).
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc(JNIEnv* env, jobject obj,
                                                                 jlongArray jshape, jint jdtypeCode,
                                                                 jint jdtypeBits, jint jdtypeLanes,
                                                                 jint jdeviceType, jint jdeviceId,
                                                                 jobject jret) {
  // The Java long[] length is the number of dimensions.
  int ndim = static_cast<int>(env->GetArrayLength(jshape));
  TVMArrayHandle out;
  jlong* shapeArray = env->GetLongArrayElements(jshape, NULL);
  // jlong and tvm_index_t are both 64-bit here, so the shape buffer is passed
  // through directly — TODO confirm tvm_index_t is 64-bit on all targets.
  int ret = TVMArrayAlloc(reinterpret_cast<const tvm_index_t*>(shapeArray), ndim,
                          static_cast<int>(jdtypeCode), static_cast<int>(jdtypeBits),
                          static_cast<int>(jdtypeLanes), static_cast<int>(jdeviceType),
                          static_cast<int>(jdeviceId), &out);
  // Mode 0: copy back (no changes were made) and release the elements.
  env->ReleaseLongArrayElements(jshape, shapeArray, 0);
  setLongField(env, jret, reinterpret_cast<jlong>(out));
  return ret;
}
// Intended to expose an NDArray's shape to the Java side via `jshape`.
// NOTE(review): in this excerpt `shape` and `ndim` are computed but never
// used and `jshape` is never written — the "fill shape buffer" step appears
// unimplemented here; confirm against the full source file.
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, jobject obj,
                                                                    jlong jhandle, jobject jshape) {
  DLTensor* array = reinterpret_cast<DLTensor*>(jhandle);
  int64_t* shape = array->shape;
  int ndim = array->ndim;
  // fill shape buffer
  return 0;
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo(
- JNIEnv *env, jobject obj, jlong jfrom, jlong jto) {
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo(JNIEnv* env, jobject obj,
+ jlong jfrom, jlong jto) {
return TVMArrayCopyFromTo(reinterpret_cast<TVMArrayHandle>(jfrom),
reinterpret_cast<TVMArrayHandle>(jto), NULL);
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(
- JNIEnv *env, jobject obj, jbyteArray jarr, jlong jfrom, jlong jto) {
- jbyte *data = env->GetByteArrayElements(jarr, NULL);
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(JNIEnv* env, jobject obj,
+ jbyteArray jarr,
+ jlong jfrom, jlong jto) {
+ jbyte* data = env->GetByteArrayElements(jarr, NULL);
- DLTensor *from = reinterpret_cast<DLTensor *>(jfrom);
- from->data = static_cast<void *>(data);
+ DLTensor* from = reinterpret_cast<DLTensor*>(jfrom);
+ from->data = static_cast<void*>(data);
int ret = TVMArrayCopyFromTo(static_cast<TVMArrayHandle>(from),
reinterpret_cast<TVMArrayHandle>(jto), NULL);
return ret;
}
-JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(
- JNIEnv *env, jobject obj, jlong jfrom, jbyteArray jarr) {
- DLTensor *from = reinterpret_cast<DLTensor *>(jfrom);
+JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(JNIEnv* env, jobject obj,
+ jlong jfrom,
+ jbyteArray jarr) {
+ DLTensor* from = reinterpret_cast<DLTensor*>(jfrom);
int size = static_cast<int>(env->GetArrayLength(jarr));
- jbyte *pdata = env->GetByteArrayElements(jarr, NULL);
+ jbyte* pdata = env->GetByteArrayElements(jarr, NULL);
int ret = 0;
- if (memcpy(static_cast<void *>(pdata), from->data, size) == NULL) {
+ if (memcpy(static_cast<void*>(pdata), from->data, size) == NULL) {
ret = 1;
}
env->ReleaseByteArrayElements(jarr, pdata, 0); // copy back to java array automatically
}
// Context
// Blocks until all pending work on the given device has completed
// (synchronizes the NULL/default stream). Returns the TVM status code.
// NOTE(review): unlike the other wrappers here, this signature has no
// jobject/jclass parameter after JNIEnv* — confirm it matches the Java
// `native` declaration's calling convention.
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jint deviceType,
                                                                  jint deviceId) {
  return TVMSynchronize(static_cast<int>(deviceType), static_cast<int>(deviceId), NULL);
}
#ifndef NNVM_BASE_H_
#define NNVM_BASE_H_
+#include <dmlc/any.h>
+#include <dmlc/array_view.h>
#include <dmlc/base.h>
#include <dmlc/common.h>
-#include <dmlc/any.h>
-#include <dmlc/memory.h>
#include <dmlc/logging.h>
+#include <dmlc/memory.h>
#include <dmlc/registry.h>
-#include <dmlc/array_view.h>
namespace nnvm {
kFloat16 = 2,
kUint8 = 3,
kInt32 = 4,
- kInt8 = 5,
+ kInt8 = 5,
kInt64 = 6,
// kBool = 7,
// 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
typedef unsigned int nn_uint;
/*! \brief handle to a function that takes param and creates symbol */
-typedef void *OpHandle;
+typedef void* OpHandle;
/*! \brief handle to a symbol that can be bind as operator */
-typedef void *SymbolHandle;
+typedef void* SymbolHandle;
/*! \brief handle to Graph */
-typedef void *GraphHandle;
+typedef void* GraphHandle;
#ifdef __cplusplus
extern "C" {
* this function is threadsafe and can be called by different thread
* \return error info
*/
-NNVM_DLL const char *NNGetLastError(void);
+NNVM_DLL const char* NNGetLastError(void);
/*!
* \brief list all the available operator names, include entries.
* \param out_array the output operator name array.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNListAllOpNames(nn_uint *out_size,
- const char*** out_array);
+NNVM_DLL int NNListAllOpNames(nn_uint* out_size, const char*** out_array);
/*!
* \brief Get operator handle given name.
* \param op_name The name of the operator.
 * \param op_out The returning op handle.
*/
-NNVM_DLL int NNGetOpHandle(const char* op_name,
- OpHandle* op_out);
+NNVM_DLL int NNGetOpHandle(const char* op_name, OpHandle* op_out);
/*!
* \brief list all the available operators.
* \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNListUniqueOps(nn_uint *out_size,
- OpHandle **out_array);
+NNVM_DLL int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array);
/*!
* \brief Get the detailed information about atomic symbol.
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNGetOpInfo(OpHandle op,
- const char **real_name,
- const char **description,
- nn_uint *num_doc_args,
- const char ***arg_names,
- const char ***arg_type_infos,
- const char ***arg_descriptions,
- const char **return_type);
+NNVM_DLL int NNGetOpInfo(OpHandle op, const char** real_name, const char** description,
+ nn_uint* num_doc_args, const char*** arg_names,
+ const char*** arg_type_infos, const char*** arg_descriptions,
+ const char** return_type);
/*!
* \brief Create an AtomicSymbol functor.
* \param op The operator handle
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op,
- nn_uint num_param,
- const char **keys,
- const char **vals,
- SymbolHandle *out);
+NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, nn_uint num_param, const char** keys,
+ const char** vals, SymbolHandle* out);
/*!
* \brief Create a Variable Symbol.
* \param name name of the variable
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out);
+NNVM_DLL int NNSymbolCreateVariable(const char* name, SymbolHandle* out);
/*!
* \brief Create a Symbol by grouping list of symbols together
* \param num_symbols number of symbols to be grouped
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols,
- SymbolHandle *symbols,
- SymbolHandle *out);
+NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* symbols, SymbolHandle* out);
/*!
* \brief Add src_dep to the handle as control dep.
* \param handle The symbol to add dependency edges on.
* \param src_dep the source handles.
*/
-NNVM_DLL int NNAddControlDeps(SymbolHandle handle,
- SymbolHandle src_dep);
+NNVM_DLL int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
* \param out used to hold the result of copy
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
+NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out);
/*!
* \brief Print the content of symbol, used for debug.
* \param symbol the symbol
* \param out_str pointer to hold the output string of the printing.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
+NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char** out_str);
/*!
* \brief Get string attribute from symbol
* \param symbol the source symbol
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol,
- const char* key,
- const char** out,
- int *success);
+NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success);
/*!
* \brief Set string attribute from symbol.
- * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph.
+ * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic
+ * graph.
*
 * Safe recommendation: use immutable graph
* - Only allow set attributes during creation of new symbol as optional parameter
* \param values The value to be set
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol,
- nn_uint num_param,
- const char** keys,
+NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys,
const char** values);
/*!
* \brief Get all attributes from symbol, including all descendents.
* \param out 2*out_size strings representing key value pairs.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
- int recursive_option,
- nn_uint *out_size,
+NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, int recursive_option, nn_uint* out_size,
const char*** out);
/*!
* \param out_sym_array the output array.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol,
- int option,
- nn_uint *out_size,
+NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size,
SymbolHandle** out_sym_array);
/*!
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol,
- int option,
- nn_uint *out_size,
- const char ***out_str_array);
+NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size,
+ const char*** out_str_array);
/*!
* \brief List returns names in the symbol.
* \param symbol the symbol
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol,
- nn_uint *out_size,
- const char ***out_str_array);
-
+NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size,
+ const char*** out_str_array);
/*!
* \brief Supply number of outputs of the symbol.
* \param output_count number of outputs
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol,
- nn_uint *output_count);
+NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count);
/*!
* \brief Get a symbol that contains all the internals.
* \param out The output symbol whose outputs are all the internals.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol,
- SymbolHandle *out);
+NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out);
/*!
* \brief Get a symbol that contains only direct children.
* \param symbol The symbol
* \param out The output symbol whose outputs are the direct children.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol,
- SymbolHandle *out);
+NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out);
/*!
* \brief Get index-th outputs of the symbol.
* \param symbol The symbol
* \param out The output symbol whose outputs are the index-th symbol.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol,
- nn_uint index,
- SymbolHandle *out);
+NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out);
/*!
* \brief Compose the symbol on other symbols.
* \param args arguments to sym
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolCompose(SymbolHandle sym,
- const char* name,
- nn_uint num_args,
- const char** keys,
- SymbolHandle* args);
+NNVM_DLL int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args,
+ const char** keys, SymbolHandle* args);
// Graph IR API
/*!
* \param graph The graph handle created.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph);
+NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph);
/*!
* \brief free the graph handle
* \param handle The handle to be freed.
* \param symbol The corresponding symbol
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
+NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol);
/*!
 * \brief Set an attribute in json format.
* Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
- const char* key,
- const char* json_value);
+NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value);
/*!
 * \brief Get a serialized attribute from graph.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle,
- const char* key,
- const char** json_out,
- int *success);
+NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out,
+ int* success);
/*!
* \brief Set a attribute whose type is std::vector<NodeEntry> in c++
* \param list The symbol whose outputs represents the list of NodeEntry to be passed.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
- const char* key,
- SymbolHandle list);
+NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list);
/*!
* \brief Apply passes on the src graph.
* \param src The source graph handle.
* \param dst The result graph.
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
- nn_uint num_pass,
- const char** pass_names,
- GraphHandle *dst);
+NNVM_DLL int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names,
+ GraphHandle* dst);
#ifdef __cplusplus
} /* end extern "C" */
#ifndef NNVM_GRAPH_H_
#define NNVM_GRAPH_H_
-#include <vector>
-#include <string>
-#include <utility>
#include <algorithm>
#include <memory>
+#include <string>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
+#include <vector>
+
#include "base.h"
#include "node.h"
#include "symbolic.h"
* \return the reference to corresponding attribute
* \tparam T the type of the attribute.
*/
- template<typename T>
+ template <typename T>
inline const T& GetAttr(const std::string& attr_name) const;
/*!
* \brief Check whether has a specific attribute.
* \return a new copy of the corresponding attribute.
* \tparam T the type of the attribute.
*/
- template<typename T>
+ template <typename T>
inline T MoveCopyAttr(const std::string& attr_name);
/*!
* \brief get a indexed graph of current graph, if not exist, create it on demand
std::weak_ptr<nnvm::Node> weak_ref;
};
  /*! \return number of nodes in the graph */
  inline size_t num_nodes() const { return nodes_.size(); }
  /*! \return total number of NodeEntry in the graph */
  inline size_t num_node_entries() const { return entry_rptr_.back(); }
  /*!
   * \brief Get a unique entry id between 0 to num_node_entries()
   *        for a given IndexedGraph::NodeEntry.
   * \param e The entry to query for index.
   * \return the unique index.
   */
  inline uint32_t entry_id(const NodeEntry& e) const { return entry_rptr_[e.node_id] + e.index; }
  /*!
   * \brief Get the index of a node from its Node pointer.
   * \param node The Node to query for index.
   * \return the node index.
   */
  inline uint32_t node_id(const nnvm::Node* node) const { return node2index_.at(node); }
  /*!
   * \brief Get the corresponding Node structure for a given node_id.
   * \param node_id The node id
   * \return const reference to the corresponding IndexedGraph::Node
   */
  inline const Node& operator[](uint32_t node_id) const { return nodes_[node_id]; }
  /*!
   * \brief Get the corresponding Node structure from a Node pointer.
   * \param node The pointer to the Node structure
   * \return const reference to the corresponding IndexedGraph::Node
   */
  inline const Node& operator[](const nnvm::Node* node) const { return nodes_[node_id(node)]; }
  /*! \return list of argument nodes */
  inline const std::vector<uint32_t>& input_nodes() const { return input_nodes_; }
  /*! \return list of mutable input nodes */
  inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
    return mutable_input_nodes_;
  }
  /*! \return list of output entries */
  inline const std::vector<NodeEntry>& outputs() const { return outputs_; }
  /*! \return whether a node exists in the indexed graph */
  inline bool exist(const nnvm::Node* node) const { return node2index_.count(node); }
  // disallow copy assign
IndexedGraph(const IndexedGraph&) = delete;
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
-template<typename FVisit>
+template <typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);
// inline function implementations
template <typename T>
inline const T& Graph::GetAttr(const std::string& attr_name) const {
  // Requesting a missing attribute is a hard error, not an empty result.
  auto it = attrs.find(attr_name);
  CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph";
  // unsafe_get assumes the stored `any` actually holds a T; the caller
  // supplies T and vouches for the stored type.
  return nnvm::unsafe_get<T>(*it->second);
}
return it != attrs.end();
}
-template<typename T>
+template <typename T>
inline T Graph::MoveCopyAttr(const std::string& attr_name) {
auto it = attrs.find(attr_name);
- CHECK(it != attrs.end())
- << "Cannot find attribute " << attr_name << " in the graph";
+ CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph";
std::shared_ptr<any> sptr = it->second;
attrs.erase(it);
if (sptr.unique()) {
}
}
-template <typename GNode, typename HashType,
- typename FVisit, typename HashFunc,
- typename InDegree, typename GetInput>
-void PostOrderDFSVisit(const std::vector<GNode>& heads,
- FVisit fvisit,
- HashFunc hash,
- InDegree indegree,
- GetInput getinput) {
+template <typename GNode, typename HashType, typename FVisit, typename HashFunc, typename InDegree,
+ typename GetInput>
+void PostOrderDFSVisit(const std::vector<GNode>& heads, FVisit fvisit, HashFunc hash,
+ InDegree indegree, GetInput getinput) {
std::vector<std::pair<GNode, uint32_t> > stack;
std::unordered_set<HashType> visited;
for (auto& head : heads) {
}
}
-template<typename FVisit>
-inline void DFSVisit(const std::vector<NodeEntry>& heads,
- FVisit fvisit) {
+template <typename FVisit>
+inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit) {
typedef const ObjectPtr* GNode;
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
- [](const NodeEntry& e)->GNode {
- return &e.node;
- });
+ [](const NodeEntry& e) -> GNode { return &e.node; });
PostOrderDFSVisit<GNode, Node*>(
- head_nodes,
- [fvisit](GNode n) {
- fvisit(*n);
- }, // FVisit
- [](GNode n)->Node* {
- return n->get();
- }, // HashFunc
- [](GNode n)->uint32_t { // InDegree
+ head_nodes, [fvisit](GNode n) { fvisit(*n); }, // FVisit
+ [](GNode n) -> Node* { return n->get(); }, // HashFunc
+ [](GNode n) -> uint32_t { // InDegree
if (!(*n)) return 0;
return (*n)->inputs.size() + (*n)->control_deps.size();
- },
- [](GNode n, uint32_t index)->GNode { // GetInput
+ },
+ [](GNode n, uint32_t index) -> GNode { // GetInput
if (index < (*n)->inputs.size()) {
return &(*n)->inputs.at(index).node;
} else {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#ifndef NNVM_GRAPH_ATTR_TYPES_H_
#define NNVM_GRAPH_ATTR_TYPES_H_
-#include <vector>
#include <string>
#include <unordered_map>
-#include "tuple.h"
+#include <vector>
+
#include "layout.h"
+#include "tuple.h"
namespace nnvm {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define NNVM_LAYOUT_H_
#include <dmlc/parameter.h>
-#include <string>
+
+#include <algorithm>
#include <sstream>
-#include <vector>
+#include <string>
#include <utility>
-#include <algorithm>
+#include <vector>
namespace nnvm {
using LayoutDim = char;
  /*! \brief Default constructor: produces the undefined layout. */
  Layout() : name_("__undef__") {}  // NOLINT(*)
  /*!
   * \brief Construct from a layout string; a factor before a lower-case
   *        character indicates the split dimension.
   *        Returns the undefined layout if "__undef__" is passed.
   */
  inline Layout(const std::string& layout) {  // NOLINT(*)
    parse(layout);
  }
  /*!
   * \brief Copy constructor: re-parses the source layout's name string.
   * \param s the source layout
   */
  inline Layout(const Layout& s) {  // NOLINT(*)
    this->parse(s.name_);
  }
  /*!
   * \brief Move constructor: takes over the source layout's state via swap.
   * \param src the source layout
   */
  inline Layout(Layout&& src) {  // NOLINT(*)
    this->swap(src);
  }
  /*!
   * \brief Move assignment: swap-with-temporary idiom.
   * \return reference of self
   */
  inline Layout& operator=(Layout&& src) {
    Layout(std::move(src)).swap(*this);  // NOLINT(*)
    return *this;
  }
  /*!
   * \return whether two layouts are equal (compared by layout string)
   * \param s the layout to compare against
   */
  inline bool operator==(const Layout& s) const { return name_ == s.name_; }
  /*!
   * \return whether two layouts are not equal
   * \param s the layout to compare against
   */
  inline bool operator!=(const Layout& s) const { return !(*this == s); }
  /*!
   * \brief Check whether a given dimension is a super-dimension ('A'-'Z').
   * \param dim input dimension
   * \return Whether a given dimension is a super-dimension.
   */
  static inline bool is_superdim(LayoutDim dim) { return dim >= 'A' && dim <= 'Z'; }
  /*!
   * \brief Check whether a given dimension is a sub-dimension ('a'-'z').
   * \param dim input dimension
   * \return Whether a given dimension is a sub-dimension.
   */
  static inline bool is_subdim(LayoutDim dim) { return dim >= 'a' && dim <= 'z'; }
/*!
* \brief Convert a given dimension to super-dimension.
* \param dst the target layout
* \return Whether can be converted to dst layout.
*/
- inline bool convertible(const Layout &dst) const {
+ inline bool convertible(const Layout& dst) const {
if (!this->defined() || !dst.defined()) return false;
for (size_t i = 0; i < kUniqueDim; ++i) {
if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
* \return A newly constructed Layout object.
*/
inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
- CHECK(target_pos <= this->ndim()) << "Invalid split position "
- << target_pos << " for layout " << name_;
+ CHECK(target_pos <= this->ndim())
+ << "Invalid split position " << target_pos << " for layout " << name_;
CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
- CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
- << " has already been split in "
- << name_;
+ CHECK(!this->contains(to_subdim(dim)))
+ << "Dimension " << dim << " has already been split in " << name_;
CHECK(size > 0) << "Invalid split size " << size;
std::ostringstream new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
  using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
  /*! \return begin iterator over the dimension characters */
  inline iterator begin() const { return layout_simplified_.begin(); }
  /*! \return end iterator */
  inline iterator end() const { return layout_simplified_.end(); }
  /*! \return rbegin iterator */
  inline reverse_iterator rbegin() const { return layout_simplified_.rbegin(); }
  /*! \return rend iterator */
  inline reverse_iterator rend() const { return layout_simplified_.rend(); }
  /*! \return number of dimensions (one entry per dimension character) */
  inline size_t ndim() const { return layout_simplified_.size(); }
/*!
* \brief The description of the \p i-th dimension.
* \return the description of the dimension.
*/
inline std::string at(size_t i) const {
- CHECK_LT(i, this->ndim()) << "position " << i
- << " exceeds ndim=" << this->ndim();
+ CHECK_LT(i, this->ndim()) << "position " << i << " exceeds ndim=" << this->ndim();
std::ostringstream repr;
if (is_subdim(layout_simplified_[i])) {
auto factor = subsizeof(layout_simplified_[i]);
* \return the index or -1 if not found.
*/
inline int32_t indexof(LayoutDim dim) const {
- if (!this->defined()) return -1;
- else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
- else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
+ if (!this->defined())
+ return -1;
+ else if (is_superdim(dim))
+ return superdim_pos_[dim - 'A'];
+ else if (is_subdim(dim))
+ return subdim_pos_[dim - 'a'];
return -1;
}
*/
inline bool contains(LayoutDim dim) const {
if (is_superdim(dim)) {
- return superdim_pos_[dim-'A'] >= 0;
+ return superdim_pos_[dim - 'A'] >= 0;
} else if (is_subdim(dim)) {
- return subdim_pos_[dim-'a'] >= 0;
+ return subdim_pos_[dim - 'a'] >= 0;
}
return false;
}
  /*! \return the i-th dimension character */
  inline LayoutDim operator[](size_t i) const { return layout_simplified_[i]; }
  /*! \return whether the layout is defined (name is not "__undef__") */
  inline bool defined() const { return name_ != "__undef__"; }
  /*! \return the string description of the layout */
  inline const std::string& name() const { return name_; }
  /*!
   * \brief Write layout in JSON format (serialized as its name string).
   * \param writer JSONWriter
   */
  inline void Save(dmlc::JSONWriter* writer) const { writer->Write(name_); }
/*!
* \brief Load layout from JSON.
const LayoutDim c = layout.at(i);
if (is_superdim(c)) {
int pos = c - 'A';
- CHECK_EQ(factor, 0) << "Invalid layout " << layout
- << ": invalid factor size " << factor
+ CHECK_EQ(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor
<< " before dimension " << c;
- CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
- << ": duplicate dimension " << c;
+ CHECK_EQ(superdim_pos_[pos], -1)
+ << "Invalid layout " << layout << ": duplicate dimension " << c;
superdim_pos_[pos] = curr++;
layout_simplified_.push_back(c);
} else if (is_subdim(c)) {
int pos = c - 'a';
- CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
- << factor << " for dimension " << c;
- CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
- << ": duplicate dimension " << c;
- CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
- << ": duplicate dimension " << c;
+ CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor
+ << " for dimension " << c;
+ CHECK_EQ(subdim_pos_[pos], -1)
+ << "Invalid layout " << layout << ": duplicate dimension " << c;
+ CHECK_EQ(subdim_size_[pos], -1)
+ << "Invalid layout " << layout << ": duplicate dimension " << c;
subdim_pos_[pos] = curr++;
subdim_size_[pos] = factor;
layout_simplified_.push_back(c);
}
CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
for (LayoutDim dim : layout_simplified_) {
- CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
- << "Invalid layout " << layout << ": missing axis "
- << static_cast<char>(dim - 'a' + 'A');
+ CHECK(is_superdim(dim) || superdim_pos_[dim - 'a'] >= 0)
+ << "Invalid layout " << layout << ": missing axis " << static_cast<char>(dim - 'a' + 'A');
}
}
};
#include <memory>
#include <string>
-#include <vector>
-#include <utility>
#include <unordered_map>
+#include <utility>
+#include <vector>
+
#include "base.h"
-#include "op.h"
#include "c_api.h"
+#include "op.h"
namespace nnvm {
/*! \brief an entry that represents output data from a node */
struct NodeEntry {
- NodeEntry(ObjectPtr node, uint32_t index, uint32_t version):
- node(std::move(node)),
- index(index),
- version(version)
- {}
-
- explicit NodeEntry(ObjectPtr node):
- node(std::move(node)),
- index(),
- version()
- {}
+ NodeEntry(ObjectPtr node, uint32_t index, uint32_t version)
+ : node(std::move(node)), index(index), version(version) {}
+
+ explicit NodeEntry(ObjectPtr node) : node(std::move(node)), index(), version() {}
/**
* MXNet assumes that a node with a null ptr doesn't have a gradient attached. Don't change this
* constructor.
*/
- NodeEntry():
- node(nullptr),
- index(),
- version()
- {}
+ NodeEntry() : node(nullptr), index(), version() {}
/*! \brief the source node of this data */
ObjectPtr node;
* \brief version of input Variable.
* This field can only be nonzero when this->node is a Variable node.
* version is increased by one each time a Variable get composed to a mutation Op.
- * This information can be helpful to decide order of operations when sequence of mutation happens.
+ * This information can be helpful to decide order of operations when sequence of mutation
+ * happens.
*/
uint32_t version;
};
*/
struct NodeEntryHash {
size_t operator()(const NodeEntry& e) const {
- return std::hash<Node*>()(e.node.get()) ^
- (std::hash<size_t>()(e.index) << 1 >> 1) ^
- (std::hash<size_t>()(e.version) << 1);
+ return std::hash<Node*>()(e.node.get()) ^ (std::hash<size_t>()(e.index) << 1 >> 1) ^
+ (std::hash<size_t>()(e.version) << 1);
}
};
*/
struct NodeEntryEqual {
size_t operator()(const NodeEntry& a, const NodeEntry& b) const {
- return (a.node.get() == b.node.get()) &&
- (a.index == b.index) &&
- (a.version == b.version);
+ return (a.node.get() == b.node.get()) && (a.index == b.index) && (a.version == b.version);
}
};
/*! use NodeEntry as key in unordered_map */
-template<typename ValueType>
+template <typename ValueType>
using NodeEntryMap = std::unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>;
/*!
* \brief The operator this node uses.
* For place holder variable, op == nullptr.
*/
- const Op *op{nullptr};
+ const Op* op{nullptr};
/*! \brief name of the node */
std::string name;
/*! \brief The dictionary representation of attributes */
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
*/
- template<class ...Args>
+ template <class... Args>
static ObjectPtr Create(Args&&... args) {
return std::make_shared<Node>(std::forward<Args>(args)...);
}
* \param attrs The attributes
* \return The created node entry.
*/
-inline NodeEntry MakeNode(
- const char* op_name,
- std::string node_name,
- std::vector<NodeEntry> inputs,
- std::unordered_map<std::string, std::string> attrs =
- std::unordered_map<std::string, std::string>()) {
+inline NodeEntry MakeNode(const char* op_name, std::string node_name, std::vector<NodeEntry> inputs,
+ std::unordered_map<std::string, std::string> attrs =
+ std::unordered_map<std::string, std::string>()) {
ObjectPtr p = Node::Create();
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name);
}
// implementation of functions.
-inline const Op* Node::op() const {
- return this->attrs.op;
-}
+inline const Op* Node::op() const { return this->attrs.op; }
-inline bool Node::is_variable() const {
- return this->op() == nullptr;
-}
+inline bool Node::is_variable() const { return this->op() == nullptr; }
inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define NNVM_OP_H_
#include <dmlc/parameter.h>
+
+#include <functional>
+#include <limits>
#include <string>
-#include <vector>
-#include <utility>
#include <typeinfo>
-#include <limits>
-#include <functional>
+#include <utility>
+#include <vector>
+
#include "base.h"
#include "c_api.h"
// forward declarations
class Node;
struct NodeAttrs;
-template<typename ValueType>
+template <typename ValueType>
class OpMap;
class OpGroup;
class OpRegistryEntry;
* \param description Description of the argument.
* \return reference to self.
*/
- inline Op& add_argument(const std::string &name,
- const std::string &type,
- const std::string &description);
+ inline Op& add_argument(const std::string& name, const std::string& type,
+ const std::string& description);
/*!
* \brief Append list if arguments to the end.
* \param args Additional list of arguments.
* \return reference to self.
*/
- inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
+ inline Op& add_arguments(const std::vector<ParamFieldInfo>& args);
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
* \param fn The function to be set.
* \return reference to self.
*/
- inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
+ inline Op& set_num_inputs(std::function<uint32_t(const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \param fn The function to be set.
* \return reference to self.
*/
- inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
+ inline Op& set_num_outputs(std::function<uint32_t(const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
* \return reference to self.
*/
- inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
+ inline Op& set_attr_parser(std::function<void(NodeAttrs* attrs)> fn); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
*
* \tparam ValueType The type of the value to be set.
*/
- template<typename ValueType>
+ template <typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
- const ValueType& value,
- int plevel = 10);
+ const ValueType& value, int plevel = 10);
/*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
*/
- template<typename ValueType>
+ template <typename ValueType>
static const OpMap<ValueType>& GetAttr(const std::string& attr_name);
private:
- template<typename ValueType>
+ template <typename ValueType>
friend class OpMap;
friend class OpGroup;
friend class dmlc::Registry<Op>;
// get const reference to certain attribute
static const any* GetAttrMap(const std::string& key);
// update the attribute OpMap
- static void UpdateAttrMap(const std::string& key,
- std::function<void(any*)> updater);
+ static void UpdateAttrMap(const std::string& key, std::function<void(any*)> updater);
// add a trigger based on tag matching on certain tag attribute
// This will apply trigger on all the op such that
// include the corresponding group.
// The trigger will also be applied to all future registrations
// that calls include
- static void AddGroupTrigger(const std::string& group_name,
- std::function<void(Op*)> trigger);
+ static void AddGroupTrigger(const std::string& group_name, std::function<void(Op*)> trigger);
};
/*!
* and returns ValueType
* \tparam ValueType The type of the value stored in map.
*/
-template<typename ValueType>
+template <typename ValueType>
class OpMap {
public:
/*!
// internal attribute name
std::string attr_name_;
// internal data
- std::vector<std::pair<ValueType, int> > data_;
+ std::vector<std::pair<ValueType, int>> data_;
OpMap() = default;
};
*
* \tparam ValueType The type of the value to be set.
*/
- template<typename ValueType>
+ template <typename ValueType>
inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*)
- const ValueType& value,
- int plevel = 1);
+ const ValueType& value, int plevel = 1);
};
// internal macros to make
-#define NNVM_REGISTER_VAR_DEF(OpName) \
- static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
+#define NNVM_REGISTER_VAR_DEF(OpName) \
+ static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op& __make_##NnvmOp##_##OpName
-#define NNVM_REGISTER_GVAR_DEF(TagName) \
- static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
+#define NNVM_REGISTER_GVAR_DEF(TagName) \
+ static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_##NnvmOpGroup##_##TagName
/*!
* \def NNVM_REGISTER_OP
*
* \endcode
*/
-#define NNVM_REGISTER_OP(OpName) \
- DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
+#define NNVM_REGISTER_OP(OpName) \
+ DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
/*!
*
* \endcode
*/
-#define NNVM_REGISTER_OP_GROUP(GroupName) \
- DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
- ::nnvm::OpGroup {#GroupName}
+#define NNVM_REGISTER_OP_GROUP(GroupName) \
+ DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = ::nnvm::OpGroup { #GroupName }
// implementations of template functions after this.
// member function of Op
-template<typename ValueType>
+template <typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
// update the attribute map of the key by creating new empty OpMap
UpdateAttrMap(key, [key](any* pmap) {
- // use callback so it is in lockscope
- if (pmap->empty()) {
- OpMap<ValueType> pm;
- pm.attr_name_ = key;
- *pmap = std::move(pm);
- }
- });
+ // use callback so it is in lockscope
+ if (pmap->empty()) {
+ OpMap<ValueType> pm;
+ pm.attr_name_ = key;
+ *pmap = std::move(pm);
+ }
+ });
ref = GetAttrMap(key);
}
- return nnvm::get<OpMap<ValueType> >(*ref);
+ return nnvm::get<OpMap<ValueType>>(*ref);
}
-template<typename ValueType>
+template <typename ValueType>
inline Op& Op::set_attr( // NOLINT(*)
- const std::string& attr_name,
- const ValueType& value,
- int plevel) {
- CHECK_GT(plevel, 0)
- << "plevel in set_attr must be greater than 0";
+ const std::string& attr_name, const ValueType& value, int plevel) {
+ CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
// update the attribute map of the key by creating new empty if needed.
- UpdateAttrMap(attr_name,
- [this, attr_name, value, plevel](any* pmap) {
- // the callback is in lockscope so is threadsafe.
- if (pmap->empty()) {
- OpMap<ValueType> pm;
- pm.attr_name_ = attr_name;
- *pmap = std::move(pm);
- }
- CHECK(pmap->type() == typeid(OpMap<ValueType>))
- << "Attribute " << attr_name
- << " of operator " << this->name
- << " is registered as inconsistent types"
- << " previously " << pmap->type().name()
- << " current " << typeid(OpMap<ValueType>).name();
- std::vector<std::pair<ValueType, int> >& vec =
- nnvm::get<OpMap<ValueType> >(*pmap).data_;
- // resize the value type.
- if (vec.size() <= index_) {
- vec.resize(index_ + 1,
- std::make_pair(ValueType(), 0));
- }
- std::pair<ValueType, int>& p = vec[index_];
- CHECK(p.second != plevel)
- << "Attribute " << attr_name
- << " of operator " << this->name
- << " is already registered with same plevel=" << plevel;
- if (p.second < plevel) {
- vec[index_] = std::make_pair(value, plevel);
- }
- });
+ UpdateAttrMap(attr_name, [this, attr_name, value, plevel](any* pmap) {
+ // the callback is in lockscope so is threadsafe.
+ if (pmap->empty()) {
+ OpMap<ValueType> pm;
+ pm.attr_name_ = attr_name;
+ *pmap = std::move(pm);
+ }
+ CHECK(pmap->type() == typeid(OpMap<ValueType>))
+ << "Attribute " << attr_name << " of operator " << this->name
+ << " is registered as inconsistent types"
+ << " previously " << pmap->type().name() << " current " << typeid(OpMap<ValueType>).name();
+ std::vector<std::pair<ValueType, int>>& vec = nnvm::get<OpMap<ValueType>>(*pmap).data_;
+ // resize the value type.
+ if (vec.size() <= index_) {
+ vec.resize(index_ + 1, std::make_pair(ValueType(), 0));
+ }
+ std::pair<ValueType, int>& p = vec[index_];
+ CHECK(p.second != plevel) << "Attribute " << attr_name << " of operator " << this->name
+ << " is already registered with same plevel=" << plevel;
+ if (p.second < plevel) {
+ vec[index_] = std::make_pair(value, plevel);
+ }
+ });
return *this;
}
-
inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
this->description = descr;
return *this;
}
-inline Op& Op::add_argument(const std::string &name,
- const std::string &type,
- const std::string &description) {
+inline Op& Op::add_argument(const std::string& name, const std::string& type,
+ const std::string& description) {
arguments.push_back({name, type, type, description});
return *this;
}
-inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) {
+inline Op& Op::add_arguments(const std::vector<ParamFieldInfo>& args) {
this->arguments.insert(arguments.end(), args.begin(), args.end());
return *this;
}
return *this;
}
-inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
+inline Op& Op::set_num_inputs(std::function<uint32_t(const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}
return *this;
}
-inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
+inline Op& Op::set_num_outputs(std::function<uint32_t(const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}
-inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*)
+inline Op& Op::set_attr_parser(std::function<void(NodeAttrs* attrs)> fn) { // NOLINT(*)
this->attr_parser = fn;
return *this;
}
// member functions of OpMap
-template<typename ValueType>
+template <typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
if (contains(op)) {
return 1;
}
}
-template<typename ValueType>
+template <typename ValueType>
inline bool OpMap<ValueType>::contains(const Op* op) const {
if (op == nullptr) {
return false;
return idx < data_.size() ? (data_[idx].second != 0) : false;
}
-template<typename ValueType>
+template <typename ValueType>
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
CHECK(op != nullptr);
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second)
- << "Attribute " << attr_name_
- << " has not been registered for Operator " << op->name;
+ << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name;
return data_[idx].first;
}
-template<typename ValueType>
+template <typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
if (op == nullptr) return def_value;
const uint32_t idx = op->index_;
}
}
-template<typename ValueType>
-inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
- const ValueType& value,
+template <typename ValueType>
+inline OpGroup& OpGroup::set_attr(const std::string& attr_name, const ValueType& value,
int plevel) {
auto trigger = [attr_name, value, plevel](Op* op) {
op->set_attr<ValueType>(attr_name, value, plevel);
#ifndef NNVM_OP_ATTR_TYPES_H_
#define NNVM_OP_ATTR_TYPES_H_
-#include <vector>
-#include <string>
-#include <utility>
#include <functional>
+#include <string>
#include <unordered_map>
+#include <utility>
+#include <vector>
+
#include "base.h"
+#include "layout.h"
#include "node.h"
#include "tuple.h"
-#include "layout.h"
namespace nnvm {
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
-using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
+using FListInputNames = std::function<std::vector<std::string>(const NodeAttrs& attrs)>;
/*!
* \brief Return number of visible outputs by the user.
* but the additional outputs can be used to pass information from
* forward to gradient pass.
*/
-using FNumVisibleOutputs = std::function<uint32_t (const NodeAttrs& attrs)>;
+using FNumVisibleOutputs = std::function<uint32_t(const NodeAttrs& attrs)>;
/*!
* \brief Return list of output arguments names of each operator.
*
* FListOutputNames customized naming for operator outputs.
*/
-using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
+using FListOutputNames = std::function<std::vector<std::string>(const NodeAttrs& attrs)>;
/*!
* \brief Check whether operator will mutate k-th input.
* \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly.
*/
-using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
+using FMutateInputs = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
/*!
* \brief Inference function of certain type.
* \tparam AttrType The type of the attribute to be infered.
* \return whether all attributes are inferred.
*/
-template<typename AttrType>
-using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
- std::vector<AttrType> *in_attrs,
- std::vector<AttrType> *out_attrs)>;
+template <typename AttrType>
+using FInferNodeEntryAttr = std::function<bool(
+ const NodeAttrs& attrs, std::vector<AttrType>* in_attrs, std::vector<AttrType>* out_attrs)>;
/*!
* \brief Get attribute dictionary from node.
* \return The attribute dict.
* \note Register under "FUpdateAttrDict"
*/
-using FGetAttrDict = std::function<
- std::unordered_map<std::string, std::string>
- (const NodeAttrs& attrs)>;
+using FGetAttrDict =
+ std::function<std::unordered_map<std::string, std::string>(const NodeAttrs& attrs)>;
/*!
* \brief Shape inference function.
*
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
-using FInplaceOption = std::function<
- std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
+using FInplaceOption = std::function<std::vector<std::pair<int, int> >(const NodeAttrs& attrs)>;
/*!
* \brief Get if the inplace option is an identity
*
* \note Register under "FInplaceIdentity", by default no identities.
*/
-using FInplaceIdentity = std::function<std::vector<bool> (const NodeAttrs& attrs)>;
+using FInplaceIdentity = std::function<std::vector<bool>(const NodeAttrs& attrs)>;
/*!
* \brief Get list of inputs in the op whose content are actually not used by the operator
*
* \note Register under "FIgnoreInputs".
*/
-using FIgnoreInputs = std::function<
- std::vector<uint32_t> (const NodeAttrs& attrs)>;
+using FIgnoreInputs = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
/*!
* \brief Get the gradient node of the op node
*
* \note Register under "FGradient"
*/
-using FGradient = std::function<std::vector<NodeEntry>(
- const ObjectPtr& nodeptr,
- const std::vector<NodeEntry>& out_grads)>;
+using FGradient = std::function<std::vector<NodeEntry>(const ObjectPtr& nodeptr,
+ const std::vector<NodeEntry>& out_grads)>;
/*!
* \brief Set the attributes of input variable.
* \param var the input variable
* \param index index of var in all inputs
*/
-using FSetInputVarAttrOnCompose = std::function<void(
- const NodeAttrs& attrs,
- ObjectPtr var,
- const int index)>;
+using FSetInputVarAttrOnCompose =
+ std::function<void(const NodeAttrs& attrs, ObjectPtr var, const int index)>;
/*!
* \brief Infer & correct function of node layout. See \p Layout for layout convention
* \param olayouts Inferred output layouts.
* \return success flag.
*/
-using FCorrectLayout = std::function<bool(
- const NodeAttrs& attrs,
- std::vector<Layout> *ilayouts,
- const std::vector<Layout> *last_ilayouts,
- std::vector<Layout> *olayouts)>;
+using FCorrectLayout =
+ std::function<bool(const NodeAttrs& attrs, std::vector<Layout>* ilayouts,
+ const std::vector<Layout>* last_ilayouts, std::vector<Layout>* olayouts)>;
/*!
* \brief Get a list of inputs that represent graphs instead of data.
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#ifndef NNVM_PASS_H_
#define NNVM_PASS_H_
-#include <vector>
#include <functional>
+#include <vector>
+
#include "base.h"
#include "graph.h"
* \param src The graph to be transformed.
* \return The generated graph.
*/
-typedef std::function<Graph (Graph src)> PassFunction;
+typedef std::function<Graph(Graph src)> PassFunction;
/*!
* \brief Apply a series of pass transformations on the input graph.
* \param passes A list of pass names to be applied.
* \return The transformed graph
*/
-Graph ApplyPasses(Graph src,
- const std::vector<std::string>& passes);
+Graph ApplyPasses(Graph src, const std::vector<std::string>& passes);
/*!
* \brief Apply one pass to the graph.
* \param pass The name of pass to be applied.
* \return The transformed graph.
*/
-inline Graph ApplyPass(Graph src, const std::string& pass) {
- return ApplyPasses(src, {pass});
-}
-
+inline Graph ApplyPass(Graph src, const std::string& pass) { return ApplyPasses(src, {pass}); }
/*!
* \brief Registry entry for pass functions.
*/
-struct PassFunctionReg
- : public dmlc::FunctionRegEntryBase<PassFunctionReg,
- PassFunction> {
+struct PassFunctionReg : public dmlc::FunctionRegEntryBase<PassFunctionReg, PassFunction> {
/*!
* \brief Whether the pass will change graph structure
* If this is false, the pass will only change attributes.
* });
* \endcode
*/
-#define NNVM_REGISTER_PASS(name) \
+#define NNVM_REGISTER_PASS(name) \
DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name)
} // namespace nnvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#ifndef NNVM_PASS_FUNCTIONS_H_
#define NNVM_PASS_FUNCTIONS_H_
-#include <string>
#include <memory>
-#include <vector>
+#include <string>
#include <utility>
+#include <vector>
+
#include "base.h"
-#include "pass.h"
#include "graph_attr_types.h"
+#include "pass.h"
namespace nnvm {
namespace pass {
return ret.GetAttr<std::string>("json");
}
-
/*!
* \brief Print graph ir
* \param graph The graph to be printed
* \param src The input graph.
* \return A graph with proper control flow dependencies added.
*/
-inline Graph OrderMutation(Graph src) {
- return ApplyPass(std::move(src), "OrderMutation");
-}
+inline Graph OrderMutation(Graph src) { return ApplyPass(std::move(src), "OrderMutation"); }
/*!
* \brief Infer shapes in the graph given the information.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/
-inline Graph InferShape(Graph graph,
- ShapeVector shape_inputs,
- std::string shape_attr_key = "") {
+inline Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key = "") {
if (shape_inputs.size() != 0) {
graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
}
* \return A graph with new attribute "dtype" containing inferred type of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/
-inline Graph InferType(Graph graph,
- DTypeVector dtype_inputs,
- std::string dtype_attr_key = "") {
+inline Graph InferType(Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key = "") {
if (dtype_inputs.size() != 0) {
graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
}
* \param device_copy_op The name of copy op to be inserted when cross device copy happened.
* \return A graph with new attribute "device", cotaining device information of each node.
*/
-inline Graph PlaceDevice(Graph graph,
- std::string device_group_attr_key,
- DeviceAssignMap device_assign_map,
- std::string device_copy_op) {
+inline Graph PlaceDevice(Graph graph, std::string device_group_attr_key,
+ DeviceAssignMap device_assign_map, std::string device_copy_op) {
graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
* \param ys_out_grad The symbol for additional gradient to be propagate back to y.
* \param aggregate_fun Aggregation function applied to aggregate the inputs.
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
- * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like.
- * \param zero_ops Optional, list of operators that outputs a single zero array. The first one
- * must be zeros_like.
- * \param copy_op_str Optional, name of the copy operation required to handle duplicates
- * on the edge of the graph
- * \return A new graph, whose outputs correspond to inputs of xs.
+ * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is
+ *        same as like.
+ * \param zero_ops Optional, list of operators that outputs a single zero array. The first one
+ *        must be zeros_like.
+ * \param copy_op_str Optional, name of the copy operation required to handle duplicates on the
+ *        edge of the graph.
+ * \return A new graph, whose outputs correspond to inputs of xs.
*/
inline Graph Gradient(
- Graph graph,
- std::vector<NodeEntry> ys,
- std::vector<NodeEntry> xs,
+ Graph graph, std::vector<NodeEntry> ys, std::vector<NodeEntry> xs,
std::vector<NodeEntry> ys_out_grad,
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
std::function<int(const Node& node)> mirror_fun = nullptr,
- std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)>
- attr_hint_fun = nullptr,
+ std::function<NodeEntry(const NodeEntry& src, const NodeEntry& like)> attr_hint_fun = nullptr,
std::vector<const Op*> zero_ops = std::vector<const Op*>(),
std::string copy_op_str = std::string()) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
}
if (copy_op_str != std::string()) {
- graph.attrs["copy_op"] = std::make_shared<any>(std::move(copy_op_str));
+ graph.attrs["copy_op"] = std::make_shared<any>(std::move(copy_op_str));
}
return ApplyPass(std::move(graph), "Gradient");
#define NNVM_SYMBOLIC_H_
#include <string>
-#include <vector>
#include <tuple>
-#include <utility>
#include <unordered_map>
+#include <utility>
+#include <vector>
#include "base.h"
#include "node.h"
* \brief Print the symbol info to output stream.
* \param os The output stream to print to.
*/
- void Print(std::ostream &os) const; // NOLINT(*)
+ void Print(std::ostream& os) const; // NOLINT(*)
/*!
* \brief Get the index-th element from the returned tuple.
* \param index Index of multi output.
* \return The symbol corresponds to the indexed element.
*/
- Symbol operator[] (size_t index) const;
+ Symbol operator[](size_t index) const;
/*!
* \brief List the input variable nodes.
*
* \param name Name of returned symbol.
* \return A new Symbol which is the composition of current symbol with its arguments.
*/
- Symbol operator () (const array_view<const Symbol*>& args,
- const std::unordered_map<std::string, const Symbol*>& kwargs,
- const std::string& name) const;
+ Symbol operator()(const array_view<const Symbol*>& args,
+ const std::unordered_map<std::string, const Symbol*>& kwargs,
+ const std::string& name) const;
/*!
* \brief Add control flow dependencies to the operators in symbols.
*
*
* \return The created attribute in format <operator_name, key, value>.
*/
- std::vector<std::tuple<std::string, std::string, std::string> >
- ListAttrsRecursive() const;
+ std::vector<std::tuple<std::string, std::string, std::string> > ListAttrsRecursive() const;
/*!
* \brief Create symbolic functor(AtomicSymbol) by given operator and attributes.
* \param op The operator.
* \param attrs The additional attributes.
* \return Symbol that can be used to call compose further.
*/
- static Symbol CreateFunctor(const Op* op,
- std::unordered_map<std::string, std::string> attrs);
+ static Symbol CreateFunctor(const Op* op, std::unordered_map<std::string, std::string> attrs);
/*!
* \brief Create symbolic functor(AtomicSymbol) by given node attributes.
* \param attrs pre-initialized Node attributes.
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#ifndef NNVM_TUPLE_H_
#define NNVM_TUPLE_H_
-#include <vector>
-#include <type_traits>
#include <algorithm>
-#include <utility>
#include <iostream>
#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
#include "base.h"
namespace nnvm {
* \tparam ValueType The type of data stored inside tuple.
* \sa TShape
*/
-template<typename ValueType>
+template <typename ValueType>
class Tuple {
public:
/*! \brief default constructor */
Tuple() = default;
/*! \brief destructor */
- inline ~Tuple() {
- delete [] data_heap_;
- }
+ inline ~Tuple() { delete[] data_heap_; }
/*!
* \brief copy constructor from another tuple
* \param s the source tuple
*/
- inline Tuple(const Tuple<ValueType>& s) {
- this->assign(s.begin(), s.end());
- }
+ inline Tuple(const Tuple<ValueType>& s) { this->assign(s.begin(), s.end()); }
/*!
* \brief constructor from initializer list
* \param init the initializer_list
*/
- inline Tuple(std::initializer_list<ValueType> init) {
- this->assign(init.begin(), init.end());
- }
+ inline Tuple(std::initializer_list<ValueType> init) { this->assign(init.begin(), init.end()); }
/*!
* \brief constructor from vector
* \param init the vector
* \param src the source shape
*/
- inline Tuple(Tuple<ValueType>&& src) { // NOLINT(runtime/explicit)
+ inline Tuple(Tuple<ValueType>&& src) { // NOLINT(runtime/explicit)
this->swap(src);
}
/*!
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
- template<typename RandomAccessIterator>
- inline Tuple(RandomAccessIterator begin,
- RandomAccessIterator end) {
+ template <typename RandomAccessIterator>
+ inline Tuple(RandomAccessIterator begin, RandomAccessIterator end) {
this->assign(begin, end);
}
/*!
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
- template<typename RandomAccessIterator>
- inline void assign(RandomAccessIterator begin,
- RandomAccessIterator end) {
+ template <typename RandomAccessIterator>
+ inline void assign(RandomAccessIterator begin, RandomAccessIterator end) {
this->SetDim(end - begin);
std::copy(begin, end, this->begin());
}
* \param init the source initializer list
* \return reference of self
*/
- inline Tuple<ValueType> &operator=(std::initializer_list<ValueType> init) {
+ inline Tuple<ValueType>& operator=(std::initializer_list<ValueType> init) {
this->assign(init.begin(), init.end());
return *this;
}
* \return whether two tuple equals
* \param s the tuple to compare against
*/
- inline bool operator==(const Tuple<ValueType> &s) const {
+ inline bool operator==(const Tuple<ValueType>& s) const {
if (ndim_ != s.ndim_) return false;
return std::equal(begin(), end(), s.begin());
}
* \return whether two tuple not equal
* \param s the tuple to compare against
*/
- inline bool operator!=(const Tuple<ValueType> &s) const {
- return !(*this == s);
- }
+ inline bool operator!=(const Tuple<ValueType>& s) const { return !(*this == s); }
/*! \return the begin data pointer to content of the tuple */
- inline const ValueType *begin() const {
- return ndim_ <= kStackCache ? data_stack_ : data_heap_;
- }
+ inline const ValueType* begin() const { return ndim_ <= kStackCache ? data_stack_ : data_heap_; }
/*! \return the begin data pointer to content of the tuple */
- inline ValueType *begin() {
- return ndim_ <= kStackCache ? data_stack_ : data_heap_;
- }
+ inline ValueType* begin() { return ndim_ <= kStackCache ? data_stack_ : data_heap_; }
/*! \return the data pointer to end of the tuple */
inline const ValueType* end() const {
- return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
+ return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_);
}
/*! \return the data pointer to end the tuple */
inline ValueType* end() {
- return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
+ return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_);
}
/*! \return number of dimension of the tuple */
- inline uint32_t ndim() const {
- return ndim_;
- }
+ inline uint32_t ndim() const { return ndim_; }
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
- inline ValueType& operator[](size_t i) {
- return begin()[i];
- }
+ inline ValueType& operator[](size_t i) { return begin()[i]; }
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
- inline const ValueType& operator[](size_t i) const {
- return begin()[i];
- }
+ inline const ValueType& operator[](size_t i) const { return begin()[i]; }
/*!
* \brief Save Tuple to JSON.
* \param writer JSONWriter
* \param t the tuple
* \return the ostream
*/
- friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
+ friend std::ostream& operator<<(std::ostream& os, const Tuple<ValueType>& t) {
os << '[';
const ValueType* begin = t.begin();
const ValueType* end = t.end();
* \param t The tuple
* \return the istream
*/
- friend std::istream &operator>>(std::istream &is, Tuple<ValueType> &t) {
+ friend std::istream& operator>>(std::istream& is, Tuple<ValueType>& t) {
// get (
while (true) {
char ch = is.peek();
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
- }
+ }
}
// Handle empty tuple
while (isspace(is.peek())) {
while (true) {
ch = is.peek();
if (isspace(ch)) {
- is.get(); continue;
+ is.get();
+ continue;
}
if (ch == ')' || ch == ']') {
- is.get(); break;
+ is.get();
+ break;
}
break;
}
* \tparam DType data type that save to
* \tparam TStream any stream type that have write
*/
- template<typename DType = ValueType, typename TStream>
- inline void Save(TStream *strm) const;
+ template <typename DType = ValueType, typename TStream>
+ inline void Save(TStream* strm) const;
/*!
* \brief load the content from binary stream
* \param strm the output stream
* \tparam TStream any stream type that have write
* \return whether the load is successful
*/
- template<typename DType = ValueType, typename TStream>
- inline bool Load(TStream *strm);
+ template <typename DType = ValueType, typename TStream>
+ inline bool Load(TStream* strm);
protected:
// stack cache size
ValueType* data_heap_{nullptr};
// internal function to change the dimension
inline void SetDim(uint32_t ndim) {
- if (ndim > kStackCache &&
- ndim > num_heap_allocated_) {
- delete [] data_heap_;
+ if (ndim > kStackCache && ndim > num_heap_allocated_) {
+ delete[] data_heap_;
data_heap_ = new ValueType[ndim];
num_heap_allocated_ = ndim;
}
* \brief copy constructor of TShape
* \param s source shape.
*/
- inline TShape(const Tuple<dim_t>& s) { // NOLINT(*)
+ inline TShape(const Tuple<dim_t>& s) { // NOLINT(*)
this->assign(s.begin(), s.end());
}
/*!
* \brief constructor from initializer list
* \param init the initializer_list
*/
- inline TShape(std::initializer_list<dim_t> init) {
- this->assign(init.begin(), init.end());
- }
+ inline TShape(std::initializer_list<dim_t> init) { this->assign(init.begin(), init.end()); }
/*!
* \brief move constructor.
* \param s source shape.
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
- template<typename RandomAccessIterator>
- inline TShape(RandomAccessIterator begin,
- RandomAccessIterator end) {
+ template <typename RandomAccessIterator>
+ inline TShape(RandomAccessIterator begin, RandomAccessIterator end) {
this->assign(begin, end);
}
/*!
* \return self.
*/
inline TShape& operator=(Tuple<dim_t>&& src) { // NOLINT(*)
- TShape(std::move(src)).swap(*this); // NOLINT(*)
+ TShape(std::move(src)).swap(*this); // NOLINT(*)
return *this;
}
/*! \return total number of elements in the shape */
inline size_t Size() const {
dim_t size = 1;
- const dim_t* start = begin(), *fin = end();
+ const dim_t *start = begin(), *fin = end();
for (const dim_t* it = start; it != fin; ++it) {
size *= *it;
}
*/
inline size_t ProdShape(int dimstart, int dimend) const {
dim_t num = 1;
- const dim_t *d = this->data();
+ const dim_t* d = this->data();
for (int i = dimstart; i < dimend; ++i) {
num *= d[i];
}
return num;
}
/*! \return the begin data pointer to content of the tuple */
- inline const dim_t *data() const {
- return begin();
- }
+ inline const dim_t* data() const { return begin(); }
/*! \return the begin data pointer to content of the tuple */
- inline dim_t *data() {
- return begin();
- }
+ inline dim_t* data() { return begin(); }
#ifdef MSHADOW_XINLINE
- template<int dim>
- inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*)
+ template <int dim>
+ inline TShape(const mshadow::Shape<dim>& s) { // NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
- template<int dim>
- inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
+ template <int dim>
+ inline TShape(mshadow::Shape<dim>&& s) { // NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
/*!
* \tparam dim shape dimension
* \return reference of self
*/
- template<int dim>
- inline TShape &operator=(const mshadow::Shape<dim> &shape) {
+ template <int dim>
+ inline TShape& operator=(const mshadow::Shape<dim>& shape) {
this->assign(shape.shape_, shape.shape_ + dim);
return *this;
}
* \return the shape requested
* \tparam dim dimension of the tensor
*/
- template<int dim>
+ template <int dim>
inline mshadow::Shape<dim> get() const {
CHECK_EQ(dim, static_cast<int>(ndim()))
<< "dimension do not match target dimension " << dim << " vs " << ndim();
- const dim_t *d = this->data();
+ const dim_t* d = this->data();
mshadow::Shape<dim> s;
for (int i = 0; i < dim; ++i) {
s[i] = d[i];
inline mshadow::Shape<2> FlatTo2D(void) const {
mshadow::Shape<2> s;
if (ndim() == 0) return mshadow::Shape2(0, 0);
- const dim_t *d = this->data();
+ const dim_t* d = this->data();
s.shape_[1] = d[ndim() - 1];
dim_t ymax = 1;
for (size_t i = 1; i < ndim(); ++i) {
CHECK(axis_end >= axis_begin);
mshadow::Shape<3> s;
if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
- const dim_t *d = this->data();
+ const dim_t* d = this->data();
s.shape_[0] = 1;
s.shape_[1] = 1;
s.shape_[2] = 1;
* \param axis The axis specified.
* \return the flat 3d shape
*/
- inline mshadow::Shape<3> FlatTo3D(size_t axis) const {
- return FlatTo3D(axis, axis);
- }
- inline bool operator==(const TShape &s) const {
+ inline mshadow::Shape<3> FlatTo3D(size_t axis) const { return FlatTo3D(axis, axis); }
+ inline bool operator==(const TShape& s) const {
if (ndim() != s.ndim()) return false;
return std::equal(begin(), end(), s.begin());
}
- inline bool operator!=(const TShape &s) const {
- return !(*this == s);
- }
+ inline bool operator!=(const TShape& s) const { return !(*this == s); }
/*!
* \return whether two shape equals
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
- template<int dim>
- inline bool operator==(const mshadow::Shape<dim> &s) const {
+ template <int dim>
+ inline bool operator==(const mshadow::Shape<dim>& s) const {
if (ndim_ != dim) return false;
- const dim_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
+ const dim_t* d = dim <= kStackCache ? data_stack_ : data_heap_;
for (size_t i = 0; i < dim; ++i) {
if (d[i] != s.shape_[i]) return false;
}
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
- template<int dim>
- inline bool operator!=(const mshadow::Shape<dim> &s) const {
+ template <int dim>
+ inline bool operator!=(const mshadow::Shape<dim>& s) const {
return !(*this == s);
}
#endif
};
/*! \brief helper function to cast type of container elements */
-template<typename SrcIter, typename DstIter>
-inline DstIter ShapeTypeCast(const SrcIter begin,
- const SrcIter end,
- DstIter dst_begin) {
+template <typename SrcIter, typename DstIter>
+inline DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin) {
typedef typename std::iterator_traits<SrcIter>::value_type SrcDType;
typedef typename std::iterator_traits<DstIter>::value_type DstDType;
auto cast = [](const SrcDType& dim) { return static_cast<DstDType>(dim); };
}
/*! \brief helper function to transform a container to TShape with type cast */
-template<typename SrcIter>
+template <typename SrcIter>
inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) {
size_t ndim = std::distance(begin, end);
TShape res(ndim);
}
/*! \tparam ValueType The type of data stored inside tuple. */
-template<typename ValueType>
-template<typename DType, typename TStream>
-inline void Tuple<ValueType>::Save(TStream *strm) const {
+template <typename ValueType>
+template <typename DType, typename TStream>
+inline void Tuple<ValueType>::Save(TStream* strm) const {
strm->Write(&ndim_, sizeof(ndim_));
if (typeid(DType) == typeid(ValueType)) {
strm->Write(begin(), sizeof(ValueType) * ndim_);
}
/*! \tparam ValueType The type of data stored inside tuple. */
-template<typename ValueType>
-template<typename DType, typename TStream>
-inline bool Tuple<ValueType>::Load(TStream *strm) {
+template <typename ValueType>
+template <typename DType, typename TStream>
+inline bool Tuple<ValueType>::Load(TStream* strm) {
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
this->SetDim(ndim_);
size_t nread = sizeof(DType) * ndim_;
namespace std {
/*! \brief hash function for Tuple. */
-template<typename T>
+template <typename T>
struct hash<nnvm::Tuple<T> > {
/*! \brief hash a Tuple into unsigned int */
size_t operator()(const nnvm::Tuple<T>& val) const {
};
/*! \brief hash function for TShape. */
-template<>
+template <>
struct hash<nnvm::TShape> {
/*! \brief hash a TShape into unsigned int */
size_t operator()(const nnvm::TShape& val) const {
DMLC_DECLARE_TYPE_NAME(optional<nnvm::TShape>, "Shape or None");
// avoid low version of MSVC
#if !defined(_MSC_VER)
-template<typename T>
+template <typename T>
struct type_name_helper<nnvm::Tuple<T> > {
- static inline std::string value() {
- return "tuple of <" + type_name<T>() + ">";
- }
+ static inline std::string value() { return "tuple of <" + type_name<T>() + ">"; }
};
#endif
} // namespace dmlc
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <dmlc/thread_local.h>
#include <nnvm/c_api.h>
#include <nnvm/symbolic.h>
-#include <vector>
+
#include <string>
-#include <utility>
#include <unordered_map>
+#include <utility>
+#include <vector>
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
-#define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*)
+#define API_END() \
+ } \
+ catch (dmlc::Error & _except_) { \
+ return NNAPIHandleException(_except_); \
+ } \
+ return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
-#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*)
-
+#define API_END_HANDLE_ERROR(Finalize) \
+ } \
+ catch (dmlc::Error & _except_) { \
+ Finalize; \
+ return NNAPIHandleException(_except_); \
+ } \
+ return 0; // NOLINT(*)
/*! \brief entry to to easily hold returning information */
struct NNAPIThreadLocalEntry {
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
- std::vector<const char *> ret_vec_charp;
+ std::vector<const char*> ret_vec_charp;
/*! \brief result holder for returning handles */
- std::vector<void *> ret_handles;
+ std::vector<void*> ret_handles;
/*! \brief argument holder to hold symbol */
std::unordered_map<std::string, const nnvm::Symbol*> kwarg_symbol;
};
* \param e the exception
* \return the return value of API after exception is handled
*/
-inline int NNAPIHandleException(const dmlc::Error &e) {
+inline int NNAPIHandleException(const dmlc::Error& e) {
NNAPISetLastError(e.what());
return -1;
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \brief C error handling
*/
#include <dmlc/thread_local.h>
+
#include "c_api_common.h"
struct ErrorEntry {
typedef dmlc::ThreadLocalStore<ErrorEntry> NNAPIErrorStore;
-const char *NNGetLastError() {
- return NNAPIErrorStore::Get()->last_error.c_str();
-}
+const char* NNGetLastError() { return NNAPIErrorStore::Get()->last_error.c_str(); }
-void NNAPISetLastError(const char* msg) {
- NNAPIErrorStore::Get()->last_error = msg;
-}
+void NNAPISetLastError(const char* msg) { NNAPIErrorStore::Get()->last_error = msg; }
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \file c_api_graph.cc
* \brief C API related to Graph IR.
*/
+#include <dmlc/json.h>
#include <nnvm/c_api.h>
-#include <nnvm/op.h>
-#include <nnvm/symbolic.h>
#include <nnvm/graph.h>
+#include <nnvm/op.h>
#include <nnvm/pass.h>
-#include <dmlc/json.h>
+#include <nnvm/symbolic.h>
+
#include "c_api_common.h"
using namespace nnvm;
-int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) {
+int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph) {
Graph* g = new Graph();
API_BEGIN();
g->outputs = static_cast<Symbol*>(symbol)->outputs;
API_END();
}
-int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
+int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol) {
Symbol* s = new Symbol();
API_BEGIN();
s->outputs = static_cast<Graph*>(graph)->outputs;
API_END_HANDLE_ERROR(delete s);
}
-int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
- const char* key,
- SymbolHandle list) {
+int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list) {
API_BEGIN();
Symbol* s = static_cast<Symbol*>(list);
Graph* g = static_cast<Graph*>(handle);
- g->attrs[std::string(key)]
- = std::make_shared<any>(s->outputs);
+ g->attrs[std::string(key)] = std::make_shared<any>(s->outputs);
API_END();
}
-int NNGraphSetJSONAttr(GraphHandle handle,
- const char* key,
- const char* json_value) {
+int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value) {
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
std::string temp(json_value);
API_END();
}
-int NNGraphGetJSONAttr(GraphHandle handle,
- const char* key,
- const char** json_out,
- int *success) {
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out, int* success) {
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
std::string skey(key);
API_END();
}
-int NNGraphApplyPasses(GraphHandle src,
- nn_uint num_pass,
- const char** pass_names,
- GraphHandle *dst) {
+int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names,
+ GraphHandle* dst) {
Graph* g = new Graph();
API_BEGIN();
std::vector<std::string> vpass;
#include <nnvm/c_api.h>
#include <nnvm/op.h>
#include <nnvm/symbolic.h>
+
#include "c_api_common.h"
using namespace nnvm;
-int NNListAllOpNames(nn_uint *out_size,
- const char*** out_array) {
+int NNListAllOpNames(nn_uint* out_size, const char*** out_array) {
API_BEGIN();
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<Op>::ListAllNames();
ret->ret_vec_charp.resize(0);
ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
API_END();
}
-int NNGetOpHandle(const char* op_name,
- OpHandle* op_out) {
+int NNGetOpHandle(const char* op_name, OpHandle* op_out) {
API_BEGIN();
*op_out = (OpHandle)Op::Get(op_name); // NOLINT(*)
API_END();
}
-int NNListUniqueOps(nn_uint *out_size,
- OpHandle **out_array) {
+int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array) {
API_BEGIN();
- auto &vec = dmlc::Registry<Op>::List();
+ auto& vec = dmlc::Registry<Op>::List();
*out_size = static_cast<nn_uint>(vec.size());
*out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}
-int NNAddControlDeps(SymbolHandle handle,
- SymbolHandle src_dep) {
+int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep) {
API_BEGIN();
- static_cast<Symbol*>(handle)->AddControlDeps(
- *static_cast<Symbol*>(src_dep));
+ static_cast<Symbol*>(handle)->AddControlDeps(*static_cast<Symbol*>(src_dep));
API_END();
}
-int NNGetOpInfo(OpHandle handle,
- const char **name,
- const char **description,
- nn_uint *num_doc_args,
- const char ***arg_names,
- const char ***arg_type_infos,
- const char ***arg_descriptions,
- const char **return_type) {
- const Op *op = static_cast<const Op *>(handle);
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+int NNGetOpInfo(OpHandle handle, const char** name, const char** description, nn_uint* num_doc_args,
+ const char*** arg_names, const char*** arg_type_infos,
+ const char*** arg_descriptions, const char** return_type) {
+ const Op* op = static_cast<const Op*>(handle);
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
*name = op->name.c_str();
API_END();
}
-int NNSymbolCreateAtomicSymbol(OpHandle creator,
- nn_uint num_param,
- const char **keys,
- const char **vals,
- SymbolHandle *out) {
- Symbol *s = new Symbol();
+int NNSymbolCreateAtomicSymbol(OpHandle creator, nn_uint num_param, const char** keys,
+ const char** vals, SymbolHandle* out) {
+ Symbol* s = new Symbol();
API_BEGIN();
const Op* op = static_cast<const Op*>(creator);
std::unordered_map<std::string, std::string> kwargs;
API_END_HANDLE_ERROR(delete s;);
}
-int NNSymbolCreateVariable(const char *name, SymbolHandle *out) {
- Symbol *s = new Symbol();
+int NNSymbolCreateVariable(const char* name, SymbolHandle* out) {
+ Symbol* s = new Symbol();
API_BEGIN();
*s = Symbol::CreateVariable(name);
*out = s;
API_END_HANDLE_ERROR(delete s);
}
-int NNSymbolCreateGroup(nn_uint num_symbols,
- SymbolHandle *symbols,
- SymbolHandle *out) {
- Symbol *s = new Symbol();
- Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*)
+int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* symbols, SymbolHandle* out) {
+ Symbol* s = new Symbol();
+ Symbol** sym_arr = (Symbol**)symbols; // NOLINT(*)
API_BEGIN();
std::vector<Symbol> syms;
for (nn_uint i = 0; i < num_symbols; ++i) {
API_END_HANDLE_ERROR(delete s);
}
-int NNSymbolGetOutput(SymbolHandle symbol,
- nn_uint index,
- SymbolHandle *out) {
- Symbol *s = new Symbol();
+int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out) {
+ Symbol* s = new Symbol();
API_BEGIN();
*s = (*static_cast<Symbol*>(symbol))[index];
*out = s;
API_END_HANDLE_ERROR(delete s);
}
-int NNSymbolGetInternals(SymbolHandle symbol,
- SymbolHandle *out) {
- Symbol *s = new Symbol();
+int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out) {
+ Symbol* s = new Symbol();
API_BEGIN();
*s = static_cast<Symbol*>(symbol)->GetInternals();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
-int NNSymbolGetChildren(SymbolHandle symbol,
- SymbolHandle *out) {
- Symbol *s = new Symbol();
+int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out) {
+ Symbol* s = new Symbol();
API_BEGIN();
*s = static_cast<Symbol*>(symbol)->GetChildren();
*out = s;
API_END();
}
-int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
- Symbol *s = new Symbol();
+int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out) {
+ Symbol* s = new Symbol();
API_BEGIN();
*s = static_cast<const Symbol*>(symbol)->Copy();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
-int NNSymbolPrint(SymbolHandle symbol, const char **out_str) {
- Symbol *s = static_cast<Symbol*>(symbol);
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+int NNSymbolPrint(SymbolHandle symbol, const char** out_str) {
+ Symbol* s = static_cast<Symbol*>(symbol);
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::ostringstream os;
s->Print(os);
API_END();
}
-int NNSymbolGetAttr(SymbolHandle symbol,
- const char* key,
- const char** out,
- int* success) {
- Symbol *s = static_cast<Symbol*>(symbol);
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success) {
+ Symbol* s = static_cast<Symbol*>(symbol);
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
if (s->GetAttr(key, &(ret->ret_str))) {
*out = (ret->ret_str).c_str();
API_END();
}
-int NNSymbolSetAttrs(SymbolHandle symbol,
- nn_uint num_param,
- const char** keys,
- const char** vals) {
- Symbol *s = static_cast<Symbol*>(symbol);
+int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** vals) {
+ Symbol* s = static_cast<Symbol*>(symbol);
API_BEGIN();
std::vector<std::pair<std::string, std::string> > kwargs;
for (nn_uint i = 0; i < num_param; ++i) {
- kwargs.emplace_back(
- std::make_pair(std::string(keys[i]), std::string(vals[i])));
+ kwargs.emplace_back(std::make_pair(std::string(keys[i]), std::string(vals[i])));
}
s->SetAttrs(kwargs);
API_END();
}
-int NNSymbolListAttrs(SymbolHandle symbol,
- int option,
- nn_uint *out_size,
- const char*** out) {
- Symbol *s = static_cast<Symbol*>(symbol);
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+int NNSymbolListAttrs(SymbolHandle symbol, int option, nn_uint* out_size, const char*** out) {
+ Symbol* s = static_cast<Symbol*>(symbol);
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::unordered_map<std::string, std::string> attr =
s->ListAttrs(static_cast<Symbol::ListAttrOption>(option)); // NOLINT(*)
API_END();
}
-int NNSymbolListInputVariables(SymbolHandle symbol,
- int option,
- nn_uint *out_size,
+int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size,
SymbolHandle** out_sym_array) {
- Symbol *s = static_cast<Symbol*>(symbol);
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+ Symbol* s = static_cast<Symbol*>(symbol);
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::vector<ObjectPtr> vs = s->ListInputs(Symbol::ListInputOption(option));
ret->ret_handles.resize(0);
API_END();
}
-int NNSymbolListInputNames(SymbolHandle symbol,
- int option,
- nn_uint *out_size,
- const char ***out_str_array) {
- Symbol *s = static_cast<Symbol*>(symbol);
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size,
+ const char*** out_str_array) {
+ Symbol* s = static_cast<Symbol*>(symbol);
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
- ret->ret_vec_str =
- s->ListInputNames(Symbol::ListInputOption(option));
+ ret->ret_vec_str = s->ListInputNames(Symbol::ListInputOption(option));
ret->ret_vec_charp.resize(0);
ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
API_END();
}
-int NNSymbolListOutputNames(SymbolHandle symbol,
- nn_uint *out_size,
- const char ***out_str_array) {
- Symbol *s = static_cast<Symbol*>(symbol);
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size, const char*** out_str_array) {
+ Symbol* s = static_cast<Symbol*>(symbol);
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = s->ListOutputNames();
ret->ret_vec_charp.resize(0);
API_END();
}
-int NNSymbolGetNumOutputs(SymbolHandle symbol,
- nn_uint *output_count) {
- Symbol *s = static_cast<Symbol*>(symbol);
+int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count) {
+ Symbol* s = static_cast<Symbol*>(symbol);
API_BEGIN();
*output_count = static_cast<nn_uint>(s->outputs.size());
API_END();
}
-int NNSymbolCompose(SymbolHandle sym,
- const char *name,
- nn_uint num_args,
- const char** keys,
+int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, const char** keys,
SymbolHandle* args) {
API_BEGIN();
- NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+ NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get();
std::string& s_name = ret->ret_str;
- std::unordered_map<std::string, const Symbol*>& kwargs
- = ret->kwarg_symbol;
+ std::unordered_map<std::string, const Symbol*>& kwargs = ret->kwarg_symbol;
kwargs.clear();
if (name != nullptr) {
s_name = name;
Symbol* s = static_cast<Symbol*>(sym);
if (keys == nullptr && num_args != 0) {
kwargs.clear();
- array_view<const Symbol*> parg(
- (Symbol**)args, (Symbol**)args + num_args); // NOLINT(*)
+ array_view<const Symbol*> parg((Symbol**)args, (Symbol**)args + num_args); // NOLINT(*)
s->Compose(parg, kwargs, s_name);
} else {
for (nn_uint i = 0; i < num_args; ++i) {
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
+
#include <limits>
namespace nnvm {
// e.g. the main graph is level 0
// subgraphs of the main graph is level 1
// subgraphs of the subgraphs of the main graph is level 2
-static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subgraphs) {
+static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>>& subgraphs) {
std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
std::unordered_map<nnvm::Node*, uint32_t> node2level;
- for (auto &subgraph : subgraphs)
- next_level.push_back(&subgraph->outputs);
+ for (auto& subgraph : subgraphs) next_level.push_back(&subgraph->outputs);
for (uint32_t level = 0; !next_level.empty(); ++level) {
curr_level.swap(next_level);
next_level.clear();
- for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
- const std::vector<NodeEntry> &graph = *graph_ptr;
+ for (const std::vector<NodeEntry>* graph_ptr : curr_level) {
+ const std::vector<NodeEntry>& graph = *graph_ptr;
DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) {
- nnvm::Node *node = n.get();
+ nnvm::Node* node = n.get();
// if the node is visited, but on a different level, then check failed
// if check failed here or before, we stop doing anything, but raise an error
CHECK(!node2level.count(node) || node2level[node] == level)
- << "A subgraph should not depend on the outputs of nodes on higher levels";
+ << "A subgraph should not depend on the outputs of nodes on higher levels";
// otherwise, this node belongs to the current level
node2level[node] = level;
// subgraphs of current node belongs to next level
}
// implement constructor from graph
-IndexedGraph::IndexedGraph(const Graph &g) {
+IndexedGraph::IndexedGraph(const Graph& g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
std::vector<std::shared_ptr<Symbol>> subgraphs;
- DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
- (const ObjectPtr& n) {
- const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
- if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
- CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
- uint32_t nid = static_cast<uint32_t>(nodes_.size());
- CHECK(n);
- for (const auto &subgraph : n->attrs.subgraphs)
- subgraphs.push_back(subgraph);
- // nodes_
- IndexedGraph::Node new_node;
- new_node.source = n.get();
- new_node.weak_ref = n;
- nodes_.emplace_back(std::move(new_node));
- // arg_nodes_
- if (n->is_variable()) {
- input_nodes_.push_back(nid);
- }
- // node2index_
- node2index_[n.get()] = nid;
- // entry rptr
- entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
- // input entries
- for (const auto& e : n->inputs) {
- auto it = node2index_.find(e.node.get());
- if (it == node2index_.end() || it->first != e.node.get()) continue;
- input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
- }
- inputs_rptr.push_back(input_entries_.size());
- // control deps
- for (const auto& nptr : n->control_deps) {
- if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
- auto it = node2index_.find(nptr.get());
- CHECK(it != node2index_.end()) << "control dep not found in graph";
- control_deps_.push_back(it->second);
- }
- control_rptr.push_back(control_deps_.size());
+ DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs](const ObjectPtr& n) {
+ const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
+ if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
+ CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
+ uint32_t nid = static_cast<uint32_t>(nodes_.size());
+ CHECK(n);
+ for (const auto& subgraph : n->attrs.subgraphs) subgraphs.push_back(subgraph);
+ // nodes_
+ IndexedGraph::Node new_node;
+ new_node.source = n.get();
+ new_node.weak_ref = n;
+ nodes_.emplace_back(std::move(new_node));
+ // arg_nodes_
+ if (n->is_variable()) {
+ input_nodes_.push_back(nid);
+ }
+ // node2index_
+ node2index_[n.get()] = nid;
+ // entry rptr
+ entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
+ // input entries
+ for (const auto& e : n->inputs) {
+ auto it = node2index_.find(e.node.get());
+ if (it == node2index_.end() || it->first != e.node.get()) continue;
+ input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
+ }
+ inputs_rptr.push_back(input_entries_.size());
+ // control deps
+ for (const auto& nptr : n->control_deps) {
+ if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
+ auto it = node2index_.find(nptr.get());
+ CHECK(it != node2index_.end()) << "control dep not found in graph";
+ control_deps_.push_back(it->second);
+ }
+ control_rptr.push_back(control_deps_.size());
});
- if (!subgraphs.empty())
- SubgraphSanityCheck(subgraphs);
+ if (!subgraphs.empty()) SubgraphSanityCheck(subgraphs);
for (const auto& e : g.outputs) {
- outputs_.emplace_back(NodeEntry{
- node2index_.at(e.node.get()), e.index, e.version});
+ outputs_.emplace_back(NodeEntry{node2index_.at(e.node.get()), e.index, e.version});
}
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
// input_entries_ and control_rptr must not change after this step.
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
- nodes_[nid].inputs = array_view<NodeEntry>(
- iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
- if (nodes_[nid].source->op() != nullptr &&
- fmutate_inputs.count(nodes_[nid].source->op())) {
+ nodes_[nid].inputs =
+ array_view<NodeEntry>(iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
+ if (nodes_[nid].source->op() != nullptr && fmutate_inputs.count(nodes_[nid].source->op())) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
}
}
const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
- nodes_[nid].control_deps = array_view<uint32_t>(
- cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
+ nodes_[nid].control_deps =
+ array_view<uint32_t>(cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
}
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <nnvm/base.h>
#include <nnvm/op.h>
-#include <memory>
#include <atomic>
+#include <memory>
#include <mutex>
#include <unordered_set>
// storage of additional attribute table.
std::unordered_map<std::string, std::unique_ptr<any> > attr;
// storage of existing triggers
- std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap;
+ std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap;
// group of each operator.
std::vector<std::unordered_set<std::string> > op_group;
// get singleton of the
// find operator by name
const Op* Op::Get(const std::string& name) {
const Op* op = dmlc::Registry<Op>::Find(name);
- CHECK(op != nullptr)
- << "Operator " << name << " is not registered";
+ CHECK(op != nullptr) << "Operator " << name << " is not registered";
return op;
}
// Get attribute map by key
const any* Op::GetAttrMap(const std::string& key) {
- auto& dict = OpManager::Global()->attr;
+ auto& dict = OpManager::Global()->attr;
auto it = dict.find(key);
if (it != dict.end()) {
return it->second.get();
}
// update attribute map
-void Op::UpdateAttrMap(const std::string& key,
- std::function<void(any*)> updater) {
+void Op::UpdateAttrMap(const std::string& key, std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::recursive_mutex>(mgr->mutex);
std::unique_ptr<any>& value = mgr->attr[key];
if (updater != nullptr) updater(value.get());
}
-void Op::AddGroupTrigger(const std::string& group_name,
- std::function<void(Op*)> trigger) {
+void Op::AddGroupTrigger(const std::string& group_name, std::function<void(Op*)> trigger) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::recursive_mutex>(mgr->mutex);
auto& tvec = mgr->tmap[group_name];
tvec.push_back(trigger);
auto& op_group = mgr->op_group;
for (const Op* op : dmlc::Registry<Op>::List()) {
- if (op->index_ < op_group.size() &&
- op_group[op->index_].count(group_name) != 0) {
+ if (op->index_ < op_group.size() && op_group[op->index_].count(group_name) != 0) {
trigger((Op*)op); // NOLINT(*)
}
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \brief Support for pass registry.
*/
#include <nnvm/pass.h>
+
#include <algorithm>
namespace dmlc {
namespace nnvm {
-const PassFunctionReg* FindPassDep(const std::string&attr_name) {
+const PassFunctionReg* FindPassDep(const std::string& attr_name) {
for (auto* r : dmlc::Registry<PassFunctionReg>::List()) {
for (auto& s : r->graph_attr_targets) {
if (s == attr_name) return r;
return nullptr;
}
-Graph ApplyPasses(Graph g,
- const std::vector<std::string>& pass) {
+Graph ApplyPasses(Graph g, const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) {
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
- CHECK(reg != nullptr)
- << "Cannot find pass " << name << " in the registry";
+ CHECK(reg != nullptr) << "Cannot find pass " << name << " in the registry";
fpass.push_back(reg);
}
if (pass_dep != nullptr) {
msg = " The attribute is provided by pass " + pass_dep->name;
}
- LOG(FATAL) << "Graph attr dependency " << dep
- << " is required by pass " << r->name
- << " but is not available "
- << msg;
+ LOG(FATAL) << "Graph attr dependency " << dep << " is required by pass " << r->name
+ << " but is not available " << msg;
}
}
g = r->body(std::move(g));
* \brief Symbolic graph composition API.
*/
#include <nnvm/graph.h>
-#include <nnvm/symbolic.h>
#include <nnvm/op_attr_types.h>
+#include <nnvm/symbolic.h>
namespace nnvm {
namespace symbol_constants {
-const char *kNamespaceSeparator = "$";
+const char* kNamespaceSeparator = "$";
} // namespace symbol_constants
// auxililary version attribute in variable.
// If the node's op mutates a certain input variable,
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
-inline void UpdateNodeVersion(Node *n) {
+inline void UpdateNodeVersion(Node* n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
for (NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
if (fmutate_inputs.count(n->op()) != 0) {
for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) {
NodeEntry& e = n->inputs[i];
- CHECK(e.node->is_variable())
- << "Mutation target can only be Variable";
+ CHECK(e.node->is_variable()) << "Mutation target can only be Variable";
// increase the version of the variable.
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
}
-inline std::string DefaultVarName(const std::string &op_name,
- const std::string &arg_name) {
+inline std::string DefaultVarName(const std::string& op_name, const std::string& arg_name) {
if (op_name.length() == 0) {
return arg_name;
} else {
}
}
-inline void KeywordArgumentMismatch(const char *source,
- const std::vector<std::string>& user_args,
+inline void KeywordArgumentMismatch(const char* source, const std::vector<std::string>& user_args,
const array_view<std::string>& args) {
std::unordered_set<std::string> keys(args.begin(), args.end());
std::ostringstream head, msg;
for (const auto& key : user_args) {
if (keys.count(key) == 0) {
- LOG(FATAL) << source
- << "Keyword argument name " << key << " not found."
- << msg.str();
+ LOG(FATAL) << source << "Keyword argument name " << key << " not found." << msg.str();
}
}
}
-template<typename T>
-inline std::vector<std::string> GetKeys(
- const std::unordered_map<std::string, T>& kwargs) {
+template <typename T>
+inline std::vector<std::string> GetKeys(const std::unordered_map<std::string, T>& kwargs) {
std::vector<std::string> keys(kwargs.size());
std::transform(kwargs.begin(), kwargs.end(), keys.begin(),
[](decltype(*kwargs.begin())& kv) { return kv.first; });
std::unordered_map<Node*, ObjectPtr> old_new;
// use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) {
- ObjectPtr np = Node::Create();
- np->attrs = node->attrs;
- old_new[node.get()] = std::move(np);
- });
+ ObjectPtr np = Node::Create();
+ np->attrs = node->attrs;
+ old_new[node.get()] = std::move(np);
+ });
// connect nodes of new graph
- for (const auto &kv : old_new) {
+ for (const auto& kv : old_new) {
for (const NodeEntry& e : kv.first->inputs) {
- Node *ptr = e.node.get();
+ Node* ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
}
for (const ObjectPtr& p : kv.first->control_deps) {
}
// set the head
Symbol ret;
- for (const NodeEntry &e : outputs) {
+ for (const NodeEntry& e : outputs) {
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version});
}
return ret;
}
-void Symbol::Print(std::ostream &os) const {
- if (outputs.size() == 1 &&
- outputs[0].node->inputs.size() == 0 &&
+void Symbol::Print(std::ostream& os) const {
+ if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0) {
if (outputs[0].node->is_variable()) {
os << "Variable:" << outputs[0].node->attrs.name << '\n';
} else {
- os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n';
+ os << "AtomicFunctor "
+ << " Op:" << outputs[0].node->op()->name << '\n';
}
} else {
// use DFSVisit to copy all the nodes
os << "Symbol Outputs:\n";
for (size_t i = 0; i < outputs.size(); ++i) {
- os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
- << '(' << outputs[i].index << ")\n";
+ os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name << '(' << outputs[i].index
+ << ")\n";
}
DFSVisit(this->outputs, [&os](const ObjectPtr& node) {
- if (node->is_variable()) {
- os << "Variable:" << node->attrs.name << '\n';
- } else {
- os << "--------------------\n";
- os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n'
- << "Inputs:\n";
- for (size_t i = 0; i < node->inputs.size(); ++i) {
- const NodeEntry& e = node->inputs[i];
- os << "\targ[" << i << "]=" << e.node->attrs.name
- << '(' << e.index << ")";
- if (e.node->is_variable()) {
- os << " version=" << e.version << '\n';
- } else {
- os << '\n';
- }
+ if (node->is_variable()) {
+ os << "Variable:" << node->attrs.name << '\n';
+ } else {
+ os << "--------------------\n";
+ os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' << "Inputs:\n";
+ for (size_t i = 0; i < node->inputs.size(); ++i) {
+ const NodeEntry& e = node->inputs[i];
+ os << "\targ[" << i << "]=" << e.node->attrs.name << '(' << e.index << ")";
+ if (e.node->is_variable()) {
+ os << " version=" << e.version << '\n';
+ } else {
+ os << '\n';
}
- if (!node->attrs.dict.empty()) {
- os << "Attrs:\n";
- // make an ordered copy because unordered_map doesn't guarantee order.
- std::map<std::string, std::string> sorted_dict(
- node->attrs.dict.begin(), node->attrs.dict.end());
- for (auto &kv : sorted_dict) {
- os << '\t' << kv.first << '=' << kv.second << '\n';
- }
+ }
+ if (!node->attrs.dict.empty()) {
+ os << "Attrs:\n";
+ // make an ordered copy because unordered_map doesn't guarantee order.
+ std::map<std::string, std::string> sorted_dict(node->attrs.dict.begin(),
+ node->attrs.dict.end());
+ for (auto& kv : sorted_dict) {
+ os << '\t' << kv.first << '=' << kv.second << '\n';
}
- if (node->control_deps.size() != 0) {
- os << "Control deps:\n";
- for (size_t i = 0; i < node->control_deps.size(); ++i) {
- os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n';
- }
+ }
+ if (node->control_deps.size() != 0) {
+ os << "Control deps:\n";
+ for (size_t i = 0; i < node->control_deps.size(); ++i) {
+ os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n';
}
}
- });
+ }
+ });
}
}
-Symbol Symbol::operator[] (size_t index) const {
+Symbol Symbol::operator[](size_t index) const {
size_t nreturn = outputs.size();
CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index";
if (nreturn == 1) {
std::vector<ObjectPtr> ret;
if (option == kAll) {
ret.reserve(this->outputs.size());
- DFSVisit(this->outputs, [&ret](const ObjectPtr &node) {
- if (node->is_variable()) {
- ret.push_back(node);
- }
- });
+ DFSVisit(this->outputs, [&ret](const ObjectPtr& node) {
+ if (node->is_variable()) {
+ ret.push_back(node);
+ }
+ });
} else {
std::unordered_set<Node*> mutable_set;
std::vector<ObjectPtr> vlist;
vlist.reserve(this->outputs.size());
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
- DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) {
- if (node->is_variable()) {
- vlist.push_back(node);
- } else if (fmutate_inputs.count(node->op())) {
- for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
- mutable_set.insert(node->inputs[i].node.get());
- }
+ DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr& node) {
+ if (node->is_variable()) {
+ vlist.push_back(node);
+ } else if (fmutate_inputs.count(node->op())) {
+ for (uint32_t i : fmutate_inputs[node->op()](node->attrs)) {
+ mutable_set.insert(node->inputs[i].node.get());
}
- });
+ }
+ });
ret.reserve(vlist.size());
for (const ObjectPtr& node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
std::vector<std::string> ret;
ret.reserve(outputs.size());
- for (auto &head : outputs) {
+ for (auto& head : outputs) {
if (head.node->is_variable()) {
ret.push_back(head.node->attrs.name);
} else {
Node* n = outputs[0].node.get();
FInputGraph fng = fgraph.get(n->op(), nullptr);
std::vector<uint32_t> garg_idx;
- if (fng != nullptr)
- garg_idx = fng(n->attrs);
+ if (fng != nullptr) garg_idx = fng(n->attrs);
// The names of the arguments that contain graphs.
FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
std::vector<std::string> garg_names(garg_idx.size());
for (size_t i = 0; i < garg_idx.size(); i++) {
size_t idx = garg_idx[i];
- if (idx < arg_names.size())
- garg_names[i] = arg_names[idx];
+ if (idx < arg_names.size()) garg_names[i] = arg_names[idx];
}
// parameter check.
// If the argument isn't a graph, it should have only one output.
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
CHECK_EQ(args[i]->outputs.size(), 1U)
- << "Argument " << i << " is a tuple, single value is required";
+ << "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
- if (garg_names.empty()
- || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
+ if (garg_names.empty() ||
+ std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
CHECK_EQ(kv.second->outputs.size(), 1U)
- << "Keyword Argument " << kv.first << " is a tuple, single value is required";
+ << "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
if (!name.empty()) outputs[0].node->attrs.name = name;
// Atomic functor composition.
if (IsAtomic(outputs)) {
uint32_t n_req = n->num_inputs();
- std::vector<const Symbol *> arg_vec(args.begin(), args.end());
+ std::vector<const Symbol*> arg_vec(args.begin(), args.end());
std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
// If one of the input arguments is a graph, we need to remove it from the
// list.
if (fng != nullptr) {
std::vector<uint32_t> idxes = fng(n->attrs);
for (auto idx : idxes) {
- const Symbol *sym;
+ const Symbol* sym;
if (idx < arg_vec.size()) {
sym = arg_vec[idx];
} else {
sym = it->second;
kwarg_map.erase(it);
}
- if (n_req != kVarg)
- n_req--;
+ if (n_req != kVarg) n_req--;
n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
}
// Because idxes does not contain duplicates, the loop below functions well.
if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(arg_vec.size(), n_req)
- << "Incorrect number of arguments, requires " << n_req
- << ", provided " << arg_vec.size();
+ << "Incorrect number of arguments, requires " << n_req << ", provided " << arg_vec.size();
for (size_t i = 0; i < arg_vec.size(); ++i) {
n->inputs[i] = arg_vec[i]->outputs[0];
}
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
- n->inputs[i] = NodeEntry{
- CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
+ n->inputs[i] = NodeEntry{CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
// copy attribute of parent over automatically created variables
n->inputs[i].node->attrs.dict = n->attrs.dict;
}
}
} else {
// general composition
- CHECK_EQ(args.size(), 0U)
- << "General composition only support kwargs for now";
+ CHECK_EQ(args.size(), 0U) << "General composition only support kwargs for now";
size_t nmatched = 0;
size_t arg_counter = 0;
- std::unordered_map<Node *, const NodeEntry*> replace_map;
+ std::unordered_map<Node*, const NodeEntry*> replace_map;
// replace map stores the existing replacement plan for arguments node
- auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
- (const ObjectPtr &node) {
+ auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs,
+ &replace_map](const ObjectPtr& node) {
if (node->is_variable()) {
if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
++arg_counter;
} else {
- // match kwargs
+ // match kwargs
auto kit = kwargs.find(node->attrs.name);
if (kit != kwargs.end()) {
replace_map[node.get()] = &(kit->second->outputs[0]);
if (nmatched == kwargs.size() && arg_counter <= args.size()) {
std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
- auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
- (const ObjectPtr &node) {
+ auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes](const ObjectPtr& node) {
// visit all the childs, find possible replacement
bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) {
- NodeEntry *e = &(node->inputs[i]);
+ NodeEntry* e = &(node->inputs[i]);
if (e->node->is_variable()) {
auto iter = replace_map.find(e->node.get());
if (iter != replace_map.end()) {
}
}
-Symbol Symbol::operator () (const array_view<const Symbol*>& args,
- const std::unordered_map<std::string, const Symbol*>& kwargs,
- const std::string& name) const {
+Symbol Symbol::operator()(const array_view<const Symbol*>& args,
+ const std::unordered_map<std::string, const Symbol*>& kwargs,
+ const std::string& name) const {
Symbol s = this->Copy();
s.Compose(args, kwargs, name);
return s;
}
void Symbol::AddControlDeps(const Symbol& src) {
- CHECK_EQ(outputs.size(), 1U)
- << "AddControlDeps only works for nongrouped symbol";
+ CHECK_EQ(outputs.size(), 1U) << "AddControlDeps only works for nongrouped symbol";
Node* n = outputs[0].node.get();
for (const NodeEntry& sp : src.outputs) {
n->control_deps.push_back(sp.node);
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
DFSVisit(this->outputs, [&ret](const ObjectPtr& node) {
- Node* n = node.get();
- if (n->is_variable()) {
- // grab version from variable.
- VariableParam& param = nnvm::get<VariableParam>(n->attrs.parsed);
- ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
- } else {
- uint32_t nout = n->num_outputs();
- if (fnum_vis_output.count(n->op())) {
- nout = fnum_vis_output[n->op()](n->attrs);
- }
- for (uint32_t i = 0; i < nout; ++i) {
- ret.outputs.emplace_back(NodeEntry{node, i, 0});
- }
+ Node* n = node.get();
+ if (n->is_variable()) {
+ // grab version from variable.
+ VariableParam& param = nnvm::get<VariableParam>(n->attrs.parsed);
+ ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
+ } else {
+ uint32_t nout = n->num_outputs();
+ if (fnum_vis_output.count(n->op())) {
+ nout = fnum_vis_output[n->op()](n->attrs);
}
- });
+ for (uint32_t i = 0; i < nout; ++i) {
+ ret.outputs.emplace_back(NodeEntry{node, i, 0});
+ }
+ }
+ });
return ret;
}
void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) {
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
- CHECK(node == e.node.get())
- << "Symbol.SetAttrs only works for non-grouped symbol";
+ CHECK(node == e.node.get()) << "Symbol.SetAttrs only works for non-grouped symbol";
}
for (const auto& kv : attrs) {
if (kv.first == "name") {
if (option == kRecursive) {
std::unordered_map<std::string, std::string> ret;
DFSVisit(this->outputs, [&ret](const ObjectPtr& n) {
- for (const auto& it : n->attrs.dict) {
- ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
- }
- });
+ for (const auto& it : n->attrs.dict) {
+ ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
+ }
+ });
return ret;
} else {
return outputs[0].node->attrs.dict;
}
}
-std::vector<std::tuple<std::string, std::string, std::string> >
- Symbol::ListAttrsRecursive() const {
+std::vector<std::tuple<std::string, std::string, std::string> > Symbol::ListAttrsRecursive() const {
std::vector<std::tuple<std::string, std::string, std::string> > ret;
DFSVisit(this->outputs, [&ret](const ObjectPtr& n) {
- for (const auto& it : n->attrs.dict) {
- ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second));
- }
- });
+ for (const auto& it : n->attrs.dict) {
+ ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second));
+ }
+ });
return ret;
}
-Symbol Symbol::CreateFunctor(const Op* op,
- std::unordered_map<std::string, std::string> attrs) {
+Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map<std::string, std::string> attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s;
ObjectPtr n = Node::Create();
return s;
}
-Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
+Symbol Symbol::CreateGroup(const std::vector<Symbol>& symbols) {
Symbol ret;
- for (const auto &s : symbols) {
+ for (const auto& s : symbols) {
ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end());
}
return ret;
* \brief Infer and correct layout.
*/
#include <nnvm/graph.h>
-#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
-#include <nnvm/pass.h>
#include <nnvm/layout.h>
+#include <nnvm/op_attr_types.h>
+#include <nnvm/pass.h>
namespace nnvm {
namespace pass {
-nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src,
- const Layout& dst) {
+nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, const Layout& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__");
static int count = 0;
nnvm::ObjectPtr n = nnvm::Node::Create();
* insert layout transform nodes automatically.
*/
nnvm::Graph CorrectLayout(nnvm::Graph src) {
- static auto& op_correct_layout =
- nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
+ static auto& op_correct_layout = nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
const IndexedGraph& idx = src.indexed_graph();
std::vector<nnvm::ObjectPtr> mirror_vec(idx.num_nodes(), nullptr);
*new_node = *(inode.source);
if (new_node->is_variable()) {
// Variable node. No operator. Only one output entry.
- auto input_iter = std::find(
- idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
+ auto input_iter = std::find(idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
int64_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
if (src.HasAttr("layout_inputs")) {
- new_layouts[new_node.get()] =
- {src.GetAttr<std::vector<Layout> >("layout_inputs")[input_id]};
+ new_layouts[new_node.get()] = {
+ src.GetAttr<std::vector<Layout> >("layout_inputs")[input_id]};
} else {
new_layouts[new_node.get()] = {Layout::Undef()};
}
}
if (op_correct_layout.count(new_node->op())) {
- const auto &flayout = op_correct_layout[new_node->op()];
+ const auto& flayout = op_correct_layout[new_node->op()];
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
- << "Layout infer fail";
+ << "Layout infer fail";
CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs);
}
// register pass
NNVM_REGISTER_PASS(CorrectLayout)
-.describe("Return a layout-transformed graph of src.")
-.set_body(CorrectLayout)
-.provide_graph_attr("layout")
-.set_change_graph(true);
+ .describe("Return a layout-transformed graph of src.")
+ .set_body(CorrectLayout)
+ .provide_graph_attr("layout")
+ .set_change_graph(true);
DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout);
* \brief Passes that takes gradient of the graph
* This code code was modified based on mxnet codebase by Min Lin
*/
-#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
+#include <nnvm/pass.h>
+
#include <algorithm>
#include <functional>
}
}
-bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
- const std::vector<const Op*>& zero_ops) {
+bool CheckGradAllZero(const std::vector<NodeEntry>& grads, const std::vector<const Op*>& zero_ops) {
if (!grads.size() || !zero_ops.size()) return false;
for (const auto& g : grads) {
bool found = false;
Graph Gradient(Graph src) {
using nnvm::FGradient;
- using MirrorFun = std::function<int (const Node& node)>;
- using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;
+ using MirrorFun = std::function<int(const Node& node)>;
+ using AttrHintFun = std::function<NodeEntry(const NodeEntry& src, const NodeEntry& like)>;
- CHECK_NE(src.attrs.count("grad_ys"), 0U)
- << "Gradient require grad_ys to be presented.";
+ CHECK_NE(src.attrs.count("grad_ys"), 0U) << "Gradient require grad_ys to be presented.";
CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
<< "Gradient require grad_ys_out_grad to be presented.";
- CHECK_NE(src.attrs.count("grad_xs"), 0U)
- << "Gradient require grad_xs to be presented.";
- const std::vector<NodeEntry>& ys =
- src.GetAttr<std::vector<NodeEntry> >("grad_ys");
+ CHECK_NE(src.attrs.count("grad_xs"), 0U) << "Gradient require grad_xs to be presented.";
+ const std::vector<NodeEntry>& ys = src.GetAttr<std::vector<NodeEntry> >("grad_ys");
const std::vector<NodeEntry>& ys_out_grad =
src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
- const std::vector<NodeEntry>& xs =
- src.GetAttr<std::vector<NodeEntry> >("grad_xs");
- using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
+ const std::vector<NodeEntry>& xs = src.GetAttr<std::vector<NodeEntry> >("grad_xs");
+ using AggFun = std::function<NodeEntry(std::vector<NodeEntry> && inputs)>;
AggFun agg_fun = DefaultAggregateGradient;
if (src.attrs.count("grad_aggregate_fun") != 0) {
agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
if (src.attrs.count("zero_ops") != 0) {
zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
}
- const Op* copy_op = (src.attrs.count("copy_op") != 0) ?
- Op::Get(src.GetAttr<std::string>("copy_op")) :
- nullptr;
+ const Op* copy_op =
+ (src.attrs.count("copy_op") != 0) ? Op::Get(src.GetAttr<std::string>("copy_op")) : nullptr;
// topo sort
std::vector<ObjectPtr> topo_order;
std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys, [&](const ObjectPtr& node) {
- if (output_grads.count(node.get()) == 0) {
- output_grads[node.get()].resize(node->num_outputs());
- }
- topo_order.push_back(node);
- });
+ if (output_grads.count(node.get()) == 0) {
+ output_grads[node.get()].resize(node->num_outputs());
+ }
+ topo_order.push_back(node);
+ });
CHECK_EQ(ys.size(), ys_out_grad.size());
for (size_t i = 0; i < ys.size(); ++i) {
NodeEntry ograd = ys_out_grad[i];
- output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
+ output_grads[ys[i].node.get()][ys[i].index].grads = {ograd};
}
// Check that all xs are reachable from ys
for (size_t i = 0; i < xs.size(); ++i) {
CHECK(output_grads.find(xs[i].node.get()) != output_grads.end())
- << "Cannot differentiate with respect to the " << i+1 << "-th variable "
+ << "Cannot differentiate with respect to the " << i + 1 << "-th variable "
<< "because it is unreachable from the outputs.";
}
LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
<< "because it didn't register FGradient attribute.";
}
- for (const auto& nodeEntry : input_grads)
- CHECK(nodeEntry.node);
+ for (const auto& nodeEntry : input_grads) CHECK(nodeEntry.node);
auto git = input_grads.begin();
CHECK((*rit)->inputs.size() <= input_grads.size());
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
copy_node->attrs.name = os.str();
copy_node->inputs.emplace_back(entry.sum);
if (copy_node->attrs.op->attr_parser != nullptr) {
- copy_node->attrs.op->attr_parser(&(copy_node->attrs));
+ copy_node->attrs.op->attr_parser(&(copy_node->attrs));
}
unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter));
}
} else {
- ret.outputs[counter] = entry.sum;
+ ret.outputs[counter] = entry.sum;
}
++counter;
}
// register pass
NNVM_REGISTER_PASS(Gradient)
-.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
-.set_body(Gradient)
-.set_change_graph(true)
-.depend_graph_attr("grad_ys")
-.depend_graph_attr("grad_xs")
-.depend_graph_attr("grad_ys_out_grad");
+ .describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
+ .set_body(Gradient)
+ .set_change_graph(true)
+ .depend_graph_attr("grad_ys")
+ .depend_graph_attr("grad_xs")
+ .depend_graph_attr("grad_ys_out_grad");
} // namespace
} // namespace pass
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \brief This header contains graph algorithms on StaticGraph.
* It is used compute informations such as whether two
* operations can run in parallel, and helps allocation.
-*/
+ */
#ifndef NNVM_PASS_GRAPH_ALGORITHM_H_
#define NNVM_PASS_GRAPH_ALGORITHM_H_
#include <nnvm/graph.h>
+
#include <vector>
namespace nnvm {
* \param path the output path of nodes.
* \return the total reward of best path.
*/
-inline uint32_t FindBestPath(
- const IndexedGraph& graph,
- const std::vector<uint32_t>& node_reward,
- std::vector<uint32_t>* path) {
+inline uint32_t FindBestPath(const IndexedGraph& graph, const std::vector<uint32_t>& node_reward,
+ std::vector<uint32_t>* path) {
const uint32_t num_nodes = static_cast<uint32_t>(graph.num_nodes());
CHECK_EQ(num_nodes, node_reward.size());
path->clear();
uint32_t reward = 0;
for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) {
- path->push_back(nid); reward += node_reward[nid];
+ path->push_back(nid);
+ reward += node_reward[nid];
}
CHECK_EQ(reward, best_solution);
return best_solution;
* \param color the color index of each of the node.
* \return the total number of colors.
*/
-inline uint32_t ColorNodeGroup(
- const IndexedGraph &graph,
- std::vector<uint32_t> node_importance,
- uint32_t max_ncolor,
- std::vector<uint32_t> *color) {
+inline uint32_t ColorNodeGroup(const IndexedGraph& graph, std::vector<uint32_t> node_importance,
+ uint32_t max_ncolor, std::vector<uint32_t>* color) {
CHECK_NE(max_ncolor, 0U);
CHECK_EQ(graph.num_nodes(), node_importance.size());
* \file infer_shape.cc
* \brief Inference the shapes given existin information.
*/
-#include <nnvm/pass.h>
-#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
+#include <nnvm/op_attr_types.h>
+#include <nnvm/pass.h>
namespace nnvm {
namespace pass {
namespace {
-template<typename AttrType, typename IsNone, typename FDefault>
-Graph InferAttr(Graph &&ret,
- const AttrType empty_val,
- const char* infer_name,
- const char* input_name,
- const char* attr_key_name,
- const char* attr_name,
- const char* unknown_name,
- IsNone fis_none,
- FDefault fdefault) {
+template <typename AttrType, typename IsNone, typename FDefault>
+Graph InferAttr(Graph&& ret, const AttrType empty_val, const char* infer_name,
+ const char* input_name, const char* attr_key_name, const char* attr_name,
+ const char* unknown_name, IsNone fis_none, FDefault fdefault) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
- static auto& finfer_shape =
- Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
- static auto& is_backward =
- Op::GetAttr<TIsBackward>("TIsBackward");
+ static auto& finfer_shape = Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
+ static auto& is_backward = Op::GetAttr<TIsBackward>("TIsBackward");
// gradient function, used to get node correspondence.
- static auto& fgrad =
- Op::GetAttr<FGradient>("FGradient");
+ static auto& fgrad = Op::GetAttr<FGradient>("FGradient");
// reshape shape vector
AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) {
// get the shape hints
std::string shape_hints_key = std::string(attr_name) + "_hints";
if (ret.attrs.count(shape_hints_key)) {
- NodeEntryMap<AttrType> shape_hints =
- ret.GetAttr<NodeEntryMap<AttrType>>(shape_hints_key);
+ NodeEntryMap<AttrType> shape_hints = ret.GetAttr<NodeEntryMap<AttrType>>(shape_hints_key);
for (const auto& kv : shape_hints) {
NodeEntry e = kv.first;
if (idx.exist(e.node.get())) {
}
} else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
CHECK_GE(inode.control_deps.size(), 1U)
- << "BackwardOp need to have control_deps to its forward op";
+ << "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
ObjectPtr fwd_ptr = inode.source->control_deps[0];
CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
}
// out grad entries
CHECK(igrad_node != nullptr)
- << "Cannot find matching backward op for " << inode.source->attrs.name;
+ << "Cannot find matching backward op for " << inode.source->attrs.name;
for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
const NodeEntry& e = igrad_node->inputs[i];
if (e.node == nullptr) {
throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
}
} else {
- CHECK(!last_iter)
- << "Attribute " << infer_name
- << " is not registered by op " << inode.source->op()->name
- << " we are not able to complete the inference because of this";
+ CHECK(!last_iter) << "Attribute " << infer_name << " is not registered by op "
+ << inode.source->op()->name
+ << " we are not able to complete the inference because of this";
}
}
// Save to the result map.
}
NNVM_REGISTER_PASS(InferShape)
-.describe("Infer the shape of each node entries.")
-.set_body([](Graph ret) {
- return InferAttr<TShape>(
- std::move(ret), TShape(),
- "FInferShape", "shape_inputs", "shape_attr_key",
- "shape", "shape_num_unknown_nodes",
- [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
- nullptr);
- })
-.set_change_graph(false)
-.provide_graph_attr("shape");
+ .describe("Infer the shape of each node entries.")
+ .set_body([](Graph ret) {
+ return InferAttr<TShape>(
+ std::move(ret), TShape(), "FInferShape", "shape_inputs", "shape_attr_key", "shape",
+ "shape_num_unknown_nodes", [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
+ nullptr);
+ })
+ .set_change_graph(false)
+ .provide_graph_attr("shape");
// inference function for same type
-inline bool SameType(const NodeAttrs& attrs,
- std::vector<int> *iattr,
- std::vector<int> *oattr) {
+inline bool SameType(const NodeAttrs& attrs, std::vector<int>* iattr, std::vector<int>* oattr) {
int def_v = -1;
for (int v : *oattr) {
if (v != -1) {
- def_v = v; break;
+ def_v = v;
+ break;
}
}
if (def_v == -1) {
for (int v : *iattr) {
if (v != -1) {
- def_v = v; break;
+ def_v = v;
+ break;
}
}
}
}
NNVM_REGISTER_PASS(InferType)
-.describe("Infer the dtype of each node entries.")
-.set_body([](Graph ret) {
- return InferAttr<int>(
- std::move(ret), -1,
- "FInferType", "dtype_inputs", "dtype_attr_key",
- "dtype", "dtype_num_unknown_nodes",
- [](const int t) { return t == -1; },
- SameType);
- })
-.set_change_graph(false)
-.provide_graph_attr("dtype");
+ .describe("Infer the dtype of each node entries.")
+ .set_body([](Graph ret) {
+ return InferAttr<int>(
+ std::move(ret), -1, "FInferType", "dtype_inputs", "dtype_attr_key", "dtype",
+ "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, SameType);
+ })
+ .set_change_graph(false)
+ .provide_graph_attr("dtype");
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
*/
-#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
+#include <nnvm/pass.h>
namespace nnvm {
namespace pass {
namespace {
-template<typename T>
-inline T get_with_default(const std::unordered_map<Node*, T> &map,
- Node* key,
- const T& def) {
+template <typename T>
+inline T get_with_default(const std::unordered_map<Node*, T>& map, Node* key, const T& def) {
auto it = map.find(key);
if (it != map.end()) return it->second;
return def;
Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) {
- for (const NodeEntry& e : n->inputs) {
- if (e.node->is_variable()) {
- if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
- version_hist[e.node.get()] = std::vector<NodeEntry>{};
- }
+ for (const NodeEntry& e : n->inputs) {
+ if (e.node->is_variable()) {
+ if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
+ version_hist[e.node.get()] = std::vector<NodeEntry>{};
}
}
- });
+ }
+ });
// no mutation happens, everything if fine.
if (version_hist.size() == 0) return src;
// start preparing for remapping the nodes.
std::unordered_map<Node*, ObjectPtr> old_new;
- auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) {
+ auto prepare = [&version_hist, &old_new](const ObjectPtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op())) {
};
DFSVisit(src.outputs, prepare);
// comparator of history entry
- auto comparator = [](const NodeEntry& a, const NodeEntry &b) {
+ auto comparator = [](const NodeEntry& a, const NodeEntry& b) {
if (a.version < b.version) return true;
if (a.version > b.version) return false;
return a.index > b.index;
};
- for (auto &kv : version_hist) {
+ for (auto& kv : version_hist) {
std::sort(kv.second.begin(), kv.second.end(), comparator);
}
// copy the nodes, as well as add control deps
- for (auto &kv : old_new) {
+ for (auto& kv : old_new) {
// copy the nodes
for (const NodeEntry& e : kv.first->inputs) {
auto it = old_new.find(e.node.get());
}
}
for (const ObjectPtr& p : kv.first->control_deps) {
- kv.second->control_deps.emplace_back(
- get_with_default(old_new, p.get(), p));
+ kv.second->control_deps.emplace_back(get_with_default(old_new, p.get(), p));
}
// add control deps
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
const NodeEntry& e = kv.first->inputs[i];
if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
std::vector<NodeEntry>& vec = version_hist.at(e.node.get());
- auto it = std::lower_bound(vec.begin(), vec.end(),
- NodeEntry{nullptr, 1, e.version},
- comparator);
+ auto it =
+ std::lower_bound(vec.begin(), vec.end(), NodeEntry{nullptr, 1, e.version}, comparator);
if (IsMutate(mutate_inputs, i)) {
int read_dep = 0;
while (it != vec.begin()) {
if (it->index != 0) break;
++read_dep;
// depend on previous read
- kv.second->control_deps.push_back(
- get_with_default(old_new, it->node.get(), it->node));
+ kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node));
}
if (read_dep == 0 && it->index != 0) {
// depend on last write
- kv.second->control_deps.push_back(
- get_with_default(old_new, it->node.get(), it->node));
+ kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node));
}
} else {
// depend on last write
if (it->index != 0) {
- kv.second->control_deps.push_back(
- get_with_default(old_new, it->node.get(), it->node));
+ kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node));
}
}
}
}
}
Graph ret;
- for (const NodeEntry &e : src.outputs) {
- ret.outputs.emplace_back(NodeEntry{
- get_with_default(old_new, e.node.get(), e.node), e.index, e.version});
+ for (const NodeEntry& e : src.outputs) {
+ ret.outputs.emplace_back(
+ NodeEntry{get_with_default(old_new, e.node.get(), e.node), e.index, e.version});
}
return ret;
}
NNVM_REGISTER_PASS(OrderMutation)
-.describe("Return a new graph that adds control dependencies, "\
- "to order the mutation and reads if mutation exists.")
-.set_body(OrderMutation)
-.set_change_graph(true);
+ .describe(
+ "Return a new graph that adds control dependencies, "
+ "to order the mutation and reads if mutation exists.")
+ .set_body(OrderMutation)
+ .set_change_graph(true);
} // namespace
} // namespace pass
* \brief Inference the device of each operator given known information.
* Insert a copy node automatically when there is a cross device.
*/
-#include <nnvm/pass.h>
-#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
+#include <nnvm/op_attr_types.h>
+#include <nnvm/pass.h>
namespace nnvm {
namespace pass {
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph();
- static auto& is_backward =
- Op::GetAttr<TIsBackward>("TIsBackward");
+ static auto& is_backward = Op::GetAttr<TIsBackward>("TIsBackward");
DeviceVector device;
// copy on write semanatics
if (src.attrs.count("device") != 0) {
<< "The device assignment not found for group " << device_group;
device[nid] = dit->second;
} else {
- if (!inode.source->is_variable() &&
- is_backward.get(inode.source->op(), false)) {
+ if (!inode.source->is_variable() && is_backward.get(inode.source->op(), false)) {
if (device[inode.control_deps[0]] != -1) {
device[nid] = device[inode.control_deps[0]];
}
} else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] != -1) {
- device[nid] = device[e.node_id]; break;
+ device[nid] = device[e.node_id];
+ break;
}
}
}
auto e = inode.inputs[index];
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
LOG(FATAL) << " mutable state cannot go across device"
- << " op=" << inode.source->op()->name
- << " input_state_index=" << index;
+ << " op=" << inode.source->op()->name << " input_state_index=" << index;
}
}
}
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
- need_mutate = true; break;
+ need_mutate = true;
+ break;
}
}
if (!need_mutate) {
for (const uint32_t cid : inode.control_deps) {
- if (new_node_map[cid] != nullptr) {
- need_mutate = true; break;
+ if (new_node_map[cid] != nullptr) {
+ need_mutate = true;
+ break;
}
}
}
auto copy_key = std::make_tuple(e.node_id, e.index, dev_id);
auto it = copy_map.find(copy_key);
if (it != copy_map.end() && it->first == copy_key) {
- new_node->inputs.emplace_back(
- NodeEntry{it->second, 0, 0});
+ new_node->inputs.emplace_back(NodeEntry{it->second, 0, 0});
} else {
ObjectPtr copy_node = Node::Create();
std::ostringstream os;
- os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
+ os << inode.source->inputs[i].node->attrs.name << "_" << e.index << "_copy";
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str();
if (new_node_map[e.node_id] != nullptr) {
- copy_node->inputs.emplace_back(
- NodeEntry{new_node_map[e.node_id], e.index, 0});
+ copy_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0});
} else {
copy_node->inputs.push_back(inode.source->inputs[i]);
}
}
copy_map[copy_key] = copy_node;
new_device_map[copy_node.get()] = dev_id;
- new_node->inputs.emplace_back(
- NodeEntry{std::move(copy_node), 0, 0});
+ new_node->inputs.emplace_back(NodeEntry{std::move(copy_node), 0, 0});
}
} else {
if (new_node_map[e.node_id] != nullptr) {
- new_node->inputs.emplace_back(
- NodeEntry{new_node_map[e.node_id], e.index, 0});
+ new_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0});
} else {
new_node->inputs.push_back(inode.source->inputs[i]);
}
}
NNVM_REGISTER_PASS(PlaceDevice)
-.describe("Infer the device type of each operator."\
- "Insert a copy node when there is cross device copy")
-.set_body(PlaceDevice)
-.set_change_graph(true)
-.provide_graph_attr("device")
-.depend_graph_attr("device_group_attr_key")
-.depend_graph_attr("device_assign_map")
-.depend_graph_attr("device_copy_op");
+ .describe(
+ "Infer the device type of each operator."
+ "Insert a copy node when there is cross device copy")
+ .set_body(PlaceDevice)
+ .set_change_graph(true)
+ .provide_graph_attr("device")
+ .depend_graph_attr("device_group_attr_key")
+ .depend_graph_attr("device_assign_map")
+ .depend_graph_attr("device_copy_op");
DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int);
* \brief Assign memory tag to each of the data entries.
*/
#include <nnvm/graph.h>
-#include <nnvm/pass.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/op_attr_types.h>
+#include <nnvm/pass.h>
+
#include <memory>
+
#include "graph_algorithm.h"
namespace nnvm {
auto end = free_.upper_bound(size * match_range_);
// search for memory blocks larger than requested
for (auto it = mid; it != end; ++it) {
- StorageEntry *e = it->second;
+ StorageEntry* e = it->second;
if (e->device_id != dev_id) continue;
- if (node_color_.size() != 0 &&
- node_color_[e->released_by_node] != node_color_[node_id]) continue;
+ if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id])
+ continue;
// Use exect matching strategy
e->max_bytes = std::max(size, e->max_bytes);
// find a exact match, erase from map and return
// then search for memory blocks smaller than requested space
for (auto it = mid; it != begin;) {
--it;
- StorageEntry *e = it->second;
+ StorageEntry* e = it->second;
if (e->device_id != dev_id) continue;
- if (node_color_.size() != 0 &&
- node_color_[e->released_by_node] != node_color_[node_id]) continue;
+ if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id])
+ continue;
// Use exect matching strategy
e->max_bytes = std::max(size, e->max_bytes);
// erase from map and return
void Release(StorageID id, uint32_t node_id) {
CHECK_NE(id, kBadStorageID);
if (id == kExternalStorageID || id == kDynamicStorageID) return;
- StorageEntry *e = data_[id].get();
+ StorageEntry* e = data_[id].get();
e->released_by_node = node_id;
free_.insert({e->max_bytes, e});
}
// totoal number of bytes allocated
size_t TotalAllocBytes() const {
size_t total = 0;
- for (auto &p : data_) {
+ for (auto& p : data_) {
total += p->max_bytes;
}
return total;
if ((*idx_)[nid].source->is_variable()) continue;
importance[nid] = 1;
}
- num_match_color_ = pass::ColorNodeGroup(
- *idx_, importance, num_match_color_, &node_color_);
+ num_match_color_ = pass::ColorNodeGroup(*idx_, importance, num_match_color_, &node_color_);
}
}
* Internal method to perform the memory allocation for a graph
* */
size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
- const std::pair<uint32_t, uint32_t>& node_range,
- StorageVector* storage_ptr,
+ const std::pair<uint32_t, uint32_t>& node_range, StorageVector* storage_ptr,
std::vector<int>* storage_inplace_index_ptr,
- const std::vector<uint32_t>& entry_ref_count,
- GraphAllocator* allocator) {
+ const std::vector<uint32_t>& entry_ref_count, GraphAllocator* allocator) {
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
static auto& finplace_identity = Op::GetAttr<FInplaceIdentity>("FInplaceIdentity");
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
// Get reference
- auto &storage = *storage_ptr;
- auto &storage_inplace_index = *storage_inplace_index_ptr;
+ auto& storage = *storage_ptr;
+ auto& storage_inplace_index = *storage_inplace_index_ptr;
// Get attributes from the graph
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
auto sid_out = storage[eid_out];
auto sid_in = storage[eid_in];
bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 &&
- fignore_inputs[inode.source->op()](
- inode.source->attrs).size() == inode.source->num_inputs());
+ fignore_inputs[inode.source->op()](inode.source->attrs).size() ==
+ inode.source->num_inputs());
// Identity should only be true if shape.Size() and types match
bool real_identity = identity[ipair] &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
dtype_vec[eid_out] == dtype_vec[eid_in];
- if (taken[kv.first] == false &&
- sid_out == GraphAllocator::kBadStorageID &&
- sid_in >= 0 &&
+ if (taken[kv.first] == false && sid_out == GraphAllocator::kBadStorageID && sid_in >= 0 &&
((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) &&
- entry_ref_count[eid_out] > 0 &&
- shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
- (dtype_vec[eid_out] == dtype_vec[eid_in] ||
+ entry_ref_count[eid_out] > 0 && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
+ (dtype_vec[eid_out] == dtype_vec[eid_in] ||
GetDTypeSize(dtype_vec[eid_out]) == GetDTypeSize(dtype_vec[eid_in]))) {
// inplace optimization
taken[kv.first] = true;
uint32_t eid = idx.entry_id(nid, index);
// only request memory for kBadStorageID
if (storage[eid] == GraphAllocator::kBadStorageID) {
- auto &eshape = shape_vec[eid];
+ auto& eshape = shape_vec[eid];
size_t esize = 0;
if (eshape.ndim() != 0) esize = eshape.Size();
eids.insert(std::make_pair(esize, eid));
}
}
for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) {
- uint32_t eid = rit->second;
- auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
- if (sid >= 0) {
- storage_ref_count[sid] = entry_ref_count[eid];
- }
- storage[eid] = sid;
+ uint32_t eid = rit->second;
+ auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
+ if (sid >= 0) {
+ storage_ref_count[sid] = entry_ref_count[eid];
+ }
+ storage[eid] = sid;
}
// check if certain inputs is ignored.
std::vector<uint32_t> ignore_inputs;
return num_not_allocated;
}
-
// function to plan memory
Graph PlanMemory(Graph ret) {
// setup ref counter
size_t min_allocated_bytes = -1;
size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16);
size_t min_match_range =
- dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range;
+ dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range;
for (size_t match_range = min_match_range; match_range <= max_match_range; match_range *= 2) {
// Make a copy of related fields
StorageVector storage_vec(storage);
GraphAllocator allocator(&idx, match_range);
// number of entries that are not statically allocated.
- size_t storage_num_not_allocated =
- AllocMemory(ret, idx, node_range, &storage_vec, &storage_inplace_index,
- ref_count, &allocator);
+ size_t storage_num_not_allocated = AllocMemory(ret, idx, node_range, &storage_vec,
+ &storage_inplace_index, ref_count, &allocator);
size_t storage_allocated_bytes = allocator.TotalAllocBytes();
// Choose the plan which leads to minimal memory usage
}
NNVM_REGISTER_PASS(PlanMemory)
-.describe("Plan the memory allocation of each node entries.")
-.set_body(PlanMemory)
-.set_change_graph(false)
-.depend_graph_attr("dtype")
-.depend_graph_attr("shape")
-.provide_graph_attr("storage_id")
-.provide_graph_attr("storage_inplace_index");
+ .describe("Plan the memory allocation of each node entries.")
+ .set_body(PlanMemory)
+ .set_change_graph(false)
+ .depend_graph_attr("dtype")
+ .depend_graph_attr("shape")
+ .provide_graph_attr("storage_id")
+ .provide_graph_attr("storage_inplace_index");
} // namespace
} // namespace pass
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <nnvm/tuple.h>
+
#include <iostream>
namespace nnvm {
using AttrPrinter = std::function<void(uint32_t index, std::ostream& os)>; // NOLINT(*)
-template<typename T>
+template <typename T>
AttrPrinter GetVectorPrinter_(const T& vec) {
return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*)
os << vec[index];
};
}
-AttrPrinter GetVectorPrinter(const Graph& graph,
- const std::string& key) {
+AttrPrinter GetVectorPrinter(const Graph& graph, const std::string& key) {
auto it = graph.attrs.find(key);
- CHECK(it != graph.attrs.end())
- << "Cannot find " << key << " in graph attr";
+ CHECK(it != graph.attrs.end()) << "Cannot find " << key << " in graph attr";
const any& value = *(it->second);
if (value.type() == typeid(std::vector<TShape>)) {
- return GetVectorPrinter_(
- nnvm::get<std::vector<TShape> >(value));
+ return GetVectorPrinter_(nnvm::get<std::vector<TShape> >(value));
} else if (value.type() == typeid(std::vector<int>)) {
- return GetVectorPrinter_(
- nnvm::get<std::vector<int> >(value));
+ return GetVectorPrinter_(nnvm::get<std::vector<int> >(value));
} else if (value.type() == typeid(std::vector<std::string>)) {
- return GetVectorPrinter_(
- nnvm::get<std::vector<std::string> >(value));
+ return GetVectorPrinter_(nnvm::get<std::vector<std::string> >(value));
} else {
LOG(FATAL) << "Cannot handle type " << value.type().name();
return nullptr;
}
}
-
// print the graph ir in readable format
-void PrintGraphIR_(Graph src,
- const std::vector<std::string>& join_entry_attrs,
+void PrintGraphIR_(Graph src, const std::vector<std::string>& join_entry_attrs,
const std::vector<std::string>& join_node_attrs,
- std::ostream& os) { // NOLINT(*)
+ std::ostream& os) { // NOLINT(*)
const IndexedGraph& idx = src.indexed_graph();
std::vector<std::function<void(uint32_t, std::ostream&)> > trigger; // NOLINT(*)
for (const std::string& key : join_entry_attrs) {
AttrPrinter fp = GetVectorPrinter(src, key);
- auto fprint = [&idx, key, fp](
- uint32_t nid, std::ostream& os) { // NOLINT(*)
+ auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*)
const IndexedGraph::Node& inode = idx[nid];
os << ", " << key << "=";
if (inode.source->num_outputs() != 1) {
}
for (const std::string& key : join_node_attrs) {
AttrPrinter fp = GetVectorPrinter(src, key);
- auto fprint = [&idx, key, fp](
- uint32_t nid, std::ostream& os) { // NOLINT(*)
+ auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*)
os << ", " << key << "=";
fp(idx.entry_id(nid, 0), os);
};
if (idx.input_nodes().size() < 4) {
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
uint32_t nid = idx.input_nodes()[i];
- if (i != 0) {
+ if (i != 0) {
os << ", ";
}
os << '%' << idx[nid].source->attrs.name;
} else {
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
uint32_t nid = idx.input_nodes()[i];
- if (i != 0) {
+ if (i != 0) {
os << ",\n ";
}
os << '%' << idx[nid].source->attrs.name;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
- os << " " << "%" << nid << " = "
- << inode.source->op()->name << "(";
+ os << " "
+ << "%" << nid << " = " << inode.source->op()->name << "(";
bool first = true;
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (first) {
std::ostringstream os;
std::vector<std::string> join_entry_attrs, join_node_attrs;
if (src.attrs.count("join_entry_attrs") != 0) {
- join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >(
- "join_entry_attrs");
+ join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >("join_entry_attrs");
}
if (src.attrs.count("join_node_attrs") != 0) {
- join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >(
- "join_node_attrs");
+ join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >("join_node_attrs");
}
PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os);
Graph ret;
// register pass
NNVM_REGISTER_PASS(PrintGraphIR)
-.describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]")
-.set_body(PrintGraphIRPass);
+ .describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]")
+ .set_body(PrintGraphIRPass);
} // namespace pass
} // namespace nnvm
* \file saveload_json.cc
* \brief Save and load graph to/from JSON file.
*/
+#include <dmlc/json.h>
#include <nnvm/pass.h>
#include <nnvm/pass_functions.h>
-#include <dmlc/json.h>
+
#include <algorithm>
namespace dmlc {
namespace json {
// overload handler for shared ptr
-template<>
-struct Handler<std::shared_ptr<any> > {
- inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
+template <>
+struct Handler<std::shared_ptr<any>> {
+ inline static void Write(JSONWriter* writer, const std::shared_ptr<any>& data) {
writer->Write(*data);
}
- inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
+ inline static void Read(JSONReader* reader, std::shared_ptr<any>* data) {
any v;
reader->Read(&v);
*data = std::make_shared<any>(std::move(v));
uint32_t index;
uint32_t version;
Entry() = default;
- Entry(uint32_t node_id, uint32_t index, uint32_t version):
- node_id(node_id), index(index), version(version) {
- }
- void Save(dmlc::JSONWriter *writer) const {
+ Entry(uint32_t node_id, uint32_t index, uint32_t version)
+ : node_id(node_id), index(index), version(version) {}
+ void Save(dmlc::JSONWriter* writer) const {
writer->BeginArray(false);
writer->WriteArrayItem(node_id);
writer->WriteArrayItem(index);
writer->WriteArrayItem(version);
writer->EndArray();
}
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
std::vector<JSONGraph> subgraphs;
// function to save JSON node.
- void Save(dmlc::JSONWriter *writer) const {
+ void Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
if (node->op() != nullptr) {
writer->WriteObjectKeyValue("op", node->op()->name);
writer->WriteObjectKeyValue("name", node->attrs.name);
if (node->attrs.dict.size() != 0) {
// write attributes in order;
- std::map<std::string, std::string> dict(
- node->attrs.dict.begin(), node->attrs.dict.end());
+ std::map<std::string, std::string> dict(node->attrs.dict.begin(), node->attrs.dict.end());
writer->WriteObjectKeyValue("attrs", dict);
}
writer->WriteObjectKeyValue("inputs", inputs);
writer->EndObject();
}
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
node = Node::Create();
control_deps.clear();
dmlc::JSONObjectReadHelper helper;
if (op_type_str != "null") {
try {
node->attrs.op = Op::Get(op_type_str);
- } catch (const dmlc::Error &err) {
+ } catch (const dmlc::Error& err) {
std::ostringstream os;
- os << "Failed loading Op " << node->attrs.name
- << " of type " << op_type_str << ": " << err.what();
+ os << "Failed loading Op " << node->attrs.name << " of type " << op_type_str << ": "
+ << err.what();
throw dmlc::Error(os.str());
}
} else {
std::vector<uint32_t> arg_nodes;
std::vector<uint32_t> node_row_ptr;
std::vector<JSONNode::Entry> heads;
- std::unordered_map<std::string, std::shared_ptr<any> > attrs;
+ std::unordered_map<std::string, std::shared_ptr<any>> attrs;
- void Save(dmlc::JSONWriter *writer) const {
+ void Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("nodes", nodes);
writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
writer->EndObject();
}
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
attrs.clear();
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("nodes", &nodes);
}
};
-void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
+void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph* jgraph) {
std::unordered_map<Node*, uint32_t> node2index;
jgraph->node_row_ptr.push_back(0);
DFSVisit(src->outputs, [&node2index, jgraph](const ObjectPtr& n) {
jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version);
}
// recursively construct subgraphs
- for (JSONNode &jnode : jgraph->nodes) {
+ for (JSONNode& jnode : jgraph->nodes) {
// construct jnode's subgraphs
- const std::vector<std::shared_ptr<Symbol>> &subgraphs = jnode.node->attrs.subgraphs;
- std::vector<JSONGraph> &jsubgraphs = jnode.subgraphs;
+ const std::vector<std::shared_ptr<Symbol>>& subgraphs = jnode.node->attrs.subgraphs;
+ std::vector<JSONGraph>& jsubgraphs = jnode.subgraphs;
jsubgraphs.resize(subgraphs.size());
for (uint32_t i = 0; i < subgraphs.size(); ++i) {
Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]);
}
}
-std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) {
- for (const JSONNode &n : jgraph.nodes) {
+std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph& jgraph, bool no_parse) {
+ for (const JSONNode& n : jgraph.nodes) {
n.node->inputs.reserve(n.inputs.size());
- for (const JSONNode::Entry &e : n.inputs) {
+ for (const JSONNode::Entry& e : n.inputs) {
CHECK(e.node_id < jgraph.nodes.size());
n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
CHECK(nid < jgraph.nodes.size());
n.node->control_deps.push_back(jgraph.nodes[nid].node);
}
- for (const JSONGraph &subgraph : n.subgraphs) {
+ for (const JSONGraph& subgraph : n.subgraphs) {
// The "no_parse" option here, is to be compatible with
// commit cfd3075e85807dcd8f9534c37e053583dee87524
// (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524),
n.node->op()->attr_parser(&(n.node->attrs));
} else if (!no_parse && n.node->is_variable()) {
n.node->attrs.parsed =
- Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed;
+ Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed;
}
}
// consistency check
}
std::shared_ptr<Symbol> symbol = std::make_shared<Symbol>();
symbol->outputs.reserve(jgraph.heads.size());
- for (const JSONNode::Entry &e : jgraph.heads) {
+ for (const JSONNode::Entry& e : jgraph.heads) {
CHECK(e.node_id < jgraph.nodes.size());
symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
// Load a graph from JSON file.
Graph LoadJSON(Graph src) {
- CHECK_NE(src.attrs.count("json"), 0U)
- << "Load JSON require json to be presented.";
- const std::string &json_str =
- nnvm::get<std::string>(*src.attrs.at("json"));
+ CHECK_NE(src.attrs.count("json"), 0U) << "Load JSON require json to be presented.";
+ const std::string& json_str = nnvm::get<std::string>(*src.attrs.at("json"));
bool no_parse = false;
if (src.attrs.count("load_json_no_parse")) {
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
// register pass
NNVM_REGISTER_PASS(LoadJSON)
-.describe("Return a new Graph, loaded from src.attrs[\"json\"]")
-.set_body(LoadJSON)
-.set_change_graph(true)
-.depend_graph_attr("json");
+ .describe("Return a new Graph, loaded from src.attrs[\"json\"]")
+ .set_body(LoadJSON)
+ .set_change_graph(true)
+ .depend_graph_attr("json");
NNVM_REGISTER_PASS(SaveJSON)
-.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]")
-.set_body(SaveJSON)
-.set_change_graph(true)
-.provide_graph_attr("json");
-
+ .describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]")
+ .set_body(SaveJSON)
+ .set_change_graph(true)
+ .provide_graph_attr("json");
DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <nnvm/op.h>
-#include <utility>
-NNVM_REGISTER_OP(add)
-.describe("add two data together")
-.set_num_inputs(2)
-.set_attr("inplace_pair", std::make_pair(0, 0));
+#include <utility>
NNVM_REGISTER_OP(add)
-.set_attr<std::string>("nick_name", "plus");
+ .describe("add two data together")
+ .set_num_inputs(2)
+ .set_attr("inplace_pair", std::make_pair(0, 0));
+NNVM_REGISTER_OP(add).set_attr<std::string>("nick_name", "plus");
TEST(Op, GetAttr) {
using namespace nnvm;
CHECK_EQ(nick[add], "plus");
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <nnvm/tuple.h>
TEST(Tuple, Basic) {
- using nnvm::Tuple;
using nnvm::TShape;
+ using nnvm::Tuple;
Tuple<int> x{1, 2, 3};
Tuple<int> y{1, 2, 3, 5, 6};
x = std::move(y);
CHECK((s == TShape{1, 2, 3}));
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
/*!
* \file tvm/arith/analyzer.cc
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
namespace tvm {
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this),
- int_set(this) {
-}
+ int_set(this) {}
void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
PrimExpr new_expr = expr;
return res;
}
-TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- using runtime::PackedFunc;
- using runtime::TypedPackedFunc;
- auto self = std::make_shared<Analyzer>();
- auto f = [self](std::string name) -> PackedFunc {
- if (name == "const_int_bound") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->const_int_bound(args[0]);
- });
- } else if (name == "modular_set") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->modular_set(args[0]);
- });
- } else if (name == "const_int_bound_update") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- self->const_int_bound.Update(args[0], args[1], args[2]);
- });
- } else if (name == "Simplify") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->Simplify(args[0]);
- });
- } else if (name == "rewrite_simplify") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->rewrite_simplify(args[0]);
- });
- } else if (name == "canonical_simplify") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->canonical_simplify(args[0]);
- });
- } else if (name == "int_set") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->int_set(args[0], args[1]);
- });
- } else if (name == "bind") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- if (args[1].IsObjectRef<Range>()) {
- self->Bind(args[0], args[1].operator Range());
- } else {
- self->Bind(args[0], args[1].operator PrimExpr());
- }
- });
- } else if (name == "enter_constraint_context") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- // can't use make_shared due to noexcept(false) decl in destructor,
- // see https://stackoverflow.com/a/43907314
- auto ctx = std::shared_ptr<With<ConstraintContext> >(
- new With<ConstraintContext>(self.get(), args[0]));
- auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
- ctx.reset();
- };
- *ret = PackedFunc(fexit);
- });
- }
- return PackedFunc();
- };
- *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
+TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) {
+ using runtime::PackedFunc;
+ using runtime::TypedPackedFunc;
+ auto self = std::make_shared<Analyzer>();
+ auto f = [self](std::string name) -> PackedFunc {
+ if (name == "const_int_bound") {
+ return PackedFunc(
+ [self](TVMArgs args, TVMRetValue* ret) { *ret = self->const_int_bound(args[0]); });
+ } else if (name == "modular_set") {
+ return PackedFunc(
+ [self](TVMArgs args, TVMRetValue* ret) { *ret = self->modular_set(args[0]); });
+ } else if (name == "const_int_bound_update") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
+ self->const_int_bound.Update(args[0], args[1], args[2]);
+ });
+ } else if (name == "Simplify") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = self->Simplify(args[0]); });
+ } else if (name == "rewrite_simplify") {
+ return PackedFunc(
+ [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); });
+ } else if (name == "canonical_simplify") {
+ return PackedFunc(
+ [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); });
+ } else if (name == "int_set") {
+ return PackedFunc(
+ [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); });
+ } else if (name == "bind") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
+ if (args[1].IsObjectRef<Range>()) {
+ self->Bind(args[0], args[1].operator Range());
+ } else {
+ self->Bind(args[0], args[1].operator PrimExpr());
+ }
+ });
+ } else if (name == "enter_constraint_context") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
+ // can't use make_shared due to noexcept(false) decl in destructor,
+ // see https://stackoverflow.com/a/43907314
+ auto ctx = std::shared_ptr<With<ConstraintContext> >(
+ new With<ConstraintContext>(self.get(), args[0]));
+ auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); };
+ *ret = PackedFunc(fexit);
+ });
+ }
+ return PackedFunc();
+ };
+ *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
});
} // namespace arith
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
-#include <tvm/arith/analyzer.h>
-#include <unordered_set>
#include <unordered_map>
+#include <unordered_set>
+
#include "interval_set.h"
namespace tvm {
// a visitor to find the path to the target variable
// from a expression.
-class VariablePathFinder: public ExprVisitor {
+class VariablePathFinder : public ExprVisitor {
public:
explicit VariablePathFinder(PrimExpr target) : target_(target) {}
return v.path_;
}
-enum CompareOp {kGreater, kLess, kEqual};
+enum CompareOp { kGreater, kLess, kEqual };
// a visitor to deduce the bound of a variable from a expression
-class BoundDeducer: public ExprVisitor {
+class BoundDeducer : public ExprVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(PrimExpr target, PrimExpr expr,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map)
- : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
+ : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
void Deduce();
result_ += op->b;
} else {
result_ -= op->a;
- result_ = - result_;
+ result_ = -result_;
comp_op = ReverseOp(comp_op);
}
this->VisitExpr(left ? op->a : op->b);
// always use relax bound
bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);
- result_ = floordiv(result_, operand); // rounding down here
+ result_ = floordiv(result_, operand); // rounding down here
if (!divided) {
if (comp_op == kGreater) {
Analyzer analyzer_;
};
-class BoundDeduceInputChecker: public ExprVisitor {
+class BoundDeduceInputChecker : public ExprVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
switch (comp_op) {
- case kEqual: return kEqual; // IntSet can not represent range for `NE
- case kGreater: return kLess;
- case kLess: return kGreater;
+ case kEqual:
+      return kEqual;  // IntSet can not represent range for `NE`
+ case kGreater:
+ return kLess;
+ case kLess:
+ return kGreater;
default:
LOG(FATAL) << "Not a valid compare op";
return kGreater; // return some default value
// Both LHS and RHS of the EQ should behave as constants e.g. i == j,
// can not be resolved when either `i` or `j` or both are variables with
// some Range OR `i` and `j` both should be a single point in IntSet
- if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max())
- || !analyzer_.CanProve(a.min() == a.max()))) {
+ if (comp_op == kEqual &&
+ (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) {
success_ = false;
return;
}
- expr_ = (comp_op == kGreater) ? a.min() : a.max();
+ expr_ = (comp_op == kGreater) ? a.min() : a.max();
result_ = (comp_op == kGreater) ? b.max() : b.min();
}
IntSet DeduceBound(PrimExpr v, PrimExpr e,
- const std::unordered_map<const VarNode*, IntSet>& hint_map,
- const std::unordered_map<const VarNode*, IntSet>& relax_map) {
+ const std::unordered_map<const VarNode*, IntSet>& hint_map,
+ const std::unordered_map<const VarNode*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success_) return IntSet::nothing();
// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
-IntSet DeduceBound(PrimExpr v, PrimExpr e,
- const Map<Var, IntSet>& hint_map,
+IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) {
std::unordered_map<const VarNode*, IntSet> hmap;
for (auto kv : hint_map) {
return DeduceBound(v, e, hmap, rmap);
}
-
TVM_REGISTER_GLOBAL("arith.DeduceBound")
-.set_body_typed([](
- PrimExpr v, PrimExpr cond,
- const Map<Var, IntSet> hint_map,
- const Map<Var, IntSet> relax_map
-) {
- return DeduceBound(v, cond, hint_map, relax_map);
-});
-
+ .set_body_typed([](PrimExpr v, PrimExpr cond, const Map<Var, IntSet> hint_map,
+ const Map<Var, IntSet> relax_map) {
+ return DeduceBound(v, cond, hint_map, relax_map);
+ });
} // namespace arith
} // namespace tvm
* \brief Canonical form based simplification.
*/
#include <tvm/arith/analyzer.h>
-#include <tvm/tir/op.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/op.h>
#include "const_fold.h"
#include "pattern_match.h"
class SumExpr;
class SplitExpr;
-
/*!
* \brief Base class of all temporary expression introduced
* for canonicalization.
virtual PrimExpr Normalize() const = 0;
// overrides
- void VisitAttrs(tvm::AttrVisitor* v) {
- }
+ void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "arith.CanonicalExpr";
static constexpr const uint32_t _type_child_slots = 2;
DivMode div_mode{kTruncDiv};
/*! \brief verify that this is a valid entry. */
- void Verify() const {
- CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0);
- }
+ void Verify() const { CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); }
PrimExpr NormalizeWithScale(int64_t sscale) const {
PrimExpr res = this->index;
return res;
}
- PrimExpr Normalize() const final {
- return NormalizeWithScale(1);
- }
+ PrimExpr Normalize() const final { return NormalizeWithScale(1); }
- void MulToSelf(int64_t scale) {
- this->scale *= scale;
- }
+ void MulToSelf(int64_t scale) { this->scale *= scale; }
inline bool IndexEqual(const SplitExpr& other) const;
inline bool DivModeCompatibleTo(DivMode mode) const;
/*! \brief Base value in the summation. */
int64_t base{0};
/*! \brief The expression equals zero. */
- bool IsZero() const {
- return base == 0 && args.size() == 0;
- }
+ bool IsZero() const { return base == 0 && args.size() == 0; }
/*!
* \brief Return the normal Expr that is equivalent to self.
* \return The normal expression.
if (this->args.size() == 0) {
return make_const(this->dtype, this->base);
}
- return Normalize_(this->dtype,
- SimplifySplitExprs(args),
- base);
+ return Normalize_(this->dtype, SimplifySplitExprs(args), base);
}
/*!
* \brief Whether self is divisible by scale.
* \brief add constant value to self.
* \param value to be added.
*/
- void AddToSelf(int64_t value) {
- this->base += value;
- }
+ void AddToSelf(int64_t value) { this->base += value; }
/*!
* \brief self += other * scale;
* \param other The expression to be added.
if (args[start]->IndexEqual(other)) break;
}
for (size_t j = start; j < args.size(); ++j) {
- if (!args[j]->IndexEqual(other) ||
- other->lower_factor > args[j]->lower_factor) {
+ if (!args[j]->IndexEqual(other) || other->lower_factor > args[j]->lower_factor) {
other.CopyOnWrite()->scale *= scale;
this->args.insert(this->args.begin() + j, other);
return;
* \param args The original list of arguments.
* \return simplified version.
*/
- static std::vector<SplitExpr>
- SimplifySplitExprs(std::vector<SplitExpr> args) {
+ static std::vector<SplitExpr> SimplifySplitExprs(std::vector<SplitExpr> args) {
// NOTE: This algorithm relies on the factor that args are divided into segments
// and each segment is sorted in descending order of lower_factor.
for (size_t i = 0; i < args.size(); ++i) {
SplitExpr& rhs = args[j];
if (!lhs->IndexEqual(rhs)) break;
if (lhs->upper_factor < rhs->lower_factor) break;
- if (lhs->upper_factor == rhs->upper_factor &&
- lhs->lower_factor == rhs->lower_factor &&
+ if (lhs->upper_factor == rhs->upper_factor && lhs->lower_factor == rhs->lower_factor &&
lhs->DivModeCompatibleTo(rhs->div_mode)) {
// folding same co-efficient.
rhs.CopyOnWrite()->scale += lhs->scale;
lhs.CopyOnWrite()->scale = 0;
- } else if (lhs->lower_factor == rhs->upper_factor &&
- rhs->scale != 0 &&
+ } else if (lhs->lower_factor == rhs->upper_factor && rhs->scale != 0 &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor &&
lhs->DivModeCompatibleTo(rhs->div_mode)) {
std::stable_sort(args.begin(), args.end(), fcompare);
return args;
}
- static PrimExpr Normalize_(DataType dtype,
- const std::vector<SplitExpr>& args,
- int64_t base) {
+ static PrimExpr Normalize_(DataType dtype, const std::vector<SplitExpr>& args, int64_t base) {
// Positive scales first
PrimExpr res = make_const(dtype, 0);
for (size_t i = 0; i < args.size(); ++i) {
public:
using Rewriter = RewriteSimplifier::Impl;
- explicit Impl(Analyzer* parent)
- : Rewriter(parent) {}
-
+ explicit Impl(Analyzer* parent) : Rewriter(parent) {}
PrimExpr CanonicalSimplify(PrimExpr expr) {
expr = operator()(expr);
}
// Normal mutation without normalization.
- PrimExpr CanonicalMutate(PrimExpr expr) {
- return Rewriter::VisitExpr(expr);
- }
+ PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); }
using Rewriter::VisitExpr_;
PrimExpr VisitExpr_(const AddNode* op) final;
* \param out_divisible The result divisible component.
* \param out_non_divisible The non-divisible component.
*/
- void SeparateDivisibleParts(const SumExprNode* psum,
- int64_t coeff,
- SumExpr* out_divisible,
+ void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
SumExpr* out_non_divisible);
/*!
* \brief Normalize expr to normal expr.
PrimExpr SimplifyReduceCombiner(const ReduceNode* op);
};
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const AddNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
return std::move(ret);
}
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const SubNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
return std::move(ret);
}
-
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const MulNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
}
}
-void CanonicalSimplifier::Impl::
-SeparateDivisibleParts(const SumExprNode* psum,
- int64_t coeff,
- SumExpr* out_divisible,
- SumExpr* out_non_divisible) {
+void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff,
+ SumExpr* out_divisible,
+ SumExpr* out_non_divisible) {
auto divisible = make_object<SumExprNode>();
auto non_divisible = make_object<SumExprNode>();
divisible->dtype = psum->dtype;
*out_non_divisible = SumExpr(non_divisible);
}
-SplitExpr CanonicalSimplifier::Impl::
-SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
+SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
CHECK_GT(cval, 0);
lhs = ConvertDivMode(lhs, div_mode);
return lhs;
}
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const DivNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (TryCompare(temp, cval) != kLT) {
- lhs.CopyOnWrite()->AddToSelf(
- SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1);
+ lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1);
}
}
return std::move(lhs);
}
}
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const FloorDivNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) {
- lhs.CopyOnWrite()->AddToSelf(
- SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
+ lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
}
}
return std::move(lhs);
}
}
-SplitExpr CanonicalSimplifier::Impl::
-SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
+SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
CHECK_GT(cval, 0);
lhs = ConvertDivMode(lhs, div_mode);
// (x / c1) % c2 => (x % (c1 * c2)) / c2
int64_t new_upper_factor = lhs->lower_factor * scaled_cval;
// try to see if we can reduce the existing upper modular.
- if (lhs->upper_factor == SplitExprNode::kPosInf ||
- lhs->upper_factor % new_upper_factor == 0) {
+ if (lhs->upper_factor == SplitExprNode::kPosInf || lhs->upper_factor % new_upper_factor == 0) {
// we gained a new upper factor that is smaller
// than the original one
// Perhaps there are more chances in simplifying the index
// Do a recursive call to simplify the mod with the new factor.
- if (new_upper_factor < lhs->upper_factor &&
- lhs->upper_factor != SplitExprNode::kPosInf) {
- auto updated = ToSplitExpr(this->VisitExpr(ModImpl(
- lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
+ if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) {
+ auto updated = ToSplitExpr(this->VisitExpr(
+ ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
updated.CopyOnWrite()->scale = lhs->scale;
// re-apply the lower_factor
if (lhs->lower_factor != 1) {
return lhs;
}
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const ModNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
auto cbound = analyzer_->const_int_bound(Normalize(a));
int64_t new_base = psum->base % cval;
- if (cbound->min_value >= 0 &&
- cbound->min_value - psum->base + new_base >= 0) {
+ if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) {
SumExpr sum_expr = Downcast<SumExpr>(a);
sum_expr.CopyOnWrite()->base = new_base;
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv);
}
}
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const FloorModNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
return floormod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
- if (TryCompare(temp, cval) == kLT &&
- analyzer_->CanProveGreaterEqual(temp, 0)) {
+ if (TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0)) {
return temp;
} else {
// contonue to use logic below.
}
// Simplify reduce expression.
-PrimExpr CanonicalSimplifier::Impl::
-SimplifyReduceCombiner(const ReduceNode* op) {
+PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) {
// First simplify the results
Array<PrimExpr> simplified_result;
for (const auto& res : op->combiner->result) {
// components which have side effects should also be preserved
for (size_t i = 0; i < used.size(); ++i) {
- if (HasSideEffect(op->source[i]) ||
- HasSideEffect(op->combiner->identity_element[i]) ||
+ if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) ||
HasSideEffect(op->combiner->result[i])) {
mark_used(i);
}
}
}
- CommReducer new_combiner =
- CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
- return ReduceNode::make(
- new_combiner, new_source, op->axis, op->condition, new_value_index);
+ CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
+ return ReduceNode::make(new_combiner, new_source, op->axis, op->condition, new_value_index);
}
-PrimExpr CanonicalSimplifier::Impl::
-VisitExpr_(const ReduceNode* op) {
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
// Recursively call simplification when necessary.
PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op);
op = ret.as<ReduceNode>();
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
- return this->VisitExpr(
- SelectNode::make(op->condition,
- op->source[op->value_index],
- op->combiner->identity_element[op->value_index]));
+ return this->VisitExpr(SelectNode::make(op->condition, op->source[op->value_index],
+ op->combiner->identity_element[op->value_index]));
}
// combiner simplification.
ret = SimplifyReduceCombiner(op);
return impl_->CanonicalSimplify(expr);
}
-void CanonicalSimplifier::Update(const Var& var,
- const PrimExpr& info,
- bool override) {
+void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool override) {
impl_->Update(var, info, override);
}
-CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent)
- : impl_(new Impl(parent)) {
-}
+CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}
-CanonicalSimplifier::~CanonicalSimplifier() {
- delete impl_;
-}
+CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; }
} // namespace arith
} // namespace tvm
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <limits>
+
#include <algorithm>
+#include <limits>
namespace tvm {
namespace arith {
* \tparam Op the computation operator
* \return The result.
*/
-template<typename OP>
+template <typename OP>
inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) {
return OP::make(lhs, rhs);
}
* \tparam Op The computation operator
* \return The result.
*/
-template<typename Op>
-inline PrimExpr ComputeReduce(
- const Array<PrimExpr>& values, PrimExpr empty_value);
+template <typename Op>
+inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value);
-template<>
+template <>
inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
return a + b;
}
-template<>
+template <>
inline PrimExpr Compute<tir::SubNode>(PrimExpr a, PrimExpr b) {
return a - b;
}
-template<>
+template <>
inline PrimExpr Compute<tir::MulNode>(PrimExpr a, PrimExpr b) {
return a * b;
}
-template<>
+template <>
inline PrimExpr Compute<tir::DivNode>(PrimExpr a, PrimExpr b) {
return truncdiv(a, b);
}
-template<>
+template <>
inline PrimExpr Compute<tir::ModNode>(PrimExpr a, PrimExpr b) {
return truncmod(a, b);
}
-template<>
+template <>
inline PrimExpr Compute<tir::MaxNode>(PrimExpr a, PrimExpr b) {
return max(a, b);
}
-template<>
+template <>
inline PrimExpr Compute<tir::MinNode>(PrimExpr a, PrimExpr b) {
return min(a, b);
}
-template<typename Op>
+template <typename Op>
inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value) {
if (values.size() == 0U) {
CHECK(empty_value.defined());
} // namespace arith
} // namespace tvm
-#endif // TVM_ARITH_COMPUTE_EXPR_H_
+#endif // TVM_ARITH_COMPUTE_EXPR_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
+
#include <algorithm>
#include <cmath>
+
#include "int_operator.h"
namespace tvm {
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
*/
-template<typename Op>
+template <typename Op>
inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) {
return PrimExpr();
}
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
*/
-template<typename Op>
+template <typename Op>
inline PrimExpr TryConstFold(PrimExpr a);
/*!
* \return the checked result.
*/
inline bool IsIndexType(const DataType& type) {
- return type.is_int() && type.lanes() == 1 &&
- (type.bits() == 32 || type.bits() == 64);
+ return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
}
-
-#define TVM_ARITH_CONST_PROPAGATION(BODY) \
- using tir::FloatImmNode; \
- const IntImmNode* pa = a.as<IntImmNode>(); \
- const IntImmNode* pb = b.as<IntImmNode>(); \
- const FloatImmNode* fa = a.as<FloatImmNode>(); \
- const FloatImmNode* fb = b.as<FloatImmNode>(); \
+#define TVM_ARITH_CONST_PROPAGATION(BODY) \
+ using tir::FloatImmNode; \
+ const IntImmNode* pa = a.as<IntImmNode>(); \
+ const IntImmNode* pb = b.as<IntImmNode>(); \
+ const FloatImmNode* fa = a.as<FloatImmNode>(); \
+ const FloatImmNode* fb = b.as<FloatImmNode>(); \
BODY;
-
-#define TVM_INDEX_CONST_PROPAGATION(BODY) \
- const IntImmNode* pa = a.as<IntImmNode>(); \
- const IntImmNode* pb = b.as<IntImmNode>(); \
- const DataType& ta = a.dtype(); \
- const DataType& tb = b.dtype(); \
- if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
- BODY; \
- } \
-
+#define TVM_INDEX_CONST_PROPAGATION(BODY) \
+ const IntImmNode* pa = a.as<IntImmNode>(); \
+ const IntImmNode* pb = b.as<IntImmNode>(); \
+ const DataType& ta = a.dtype(); \
+ const DataType& tb = b.dtype(); \
+ if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
+ BODY; \
+ }
// specialization of constant folders.
-template<>
+template <>
inline PrimExpr TryConstFold<tir::AddNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, pa->value + pb->value);
- if (pa && pa->value == 0) return b;
- if (pb && pb->value == 0) return a;
- if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
- if (fa && fa->value == 0) return b;
- if (fb && fb->value == 0) return a;
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, pa->value + pb->value);
+ if (pa && pa->value == 0) return b;
+ if (pb && pb->value == 0) return a;
+ if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
+ if (fa && fa->value == 0) return b;
+ if (fb && fb->value == 0) return a;
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::SubNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, pa->value - pb->value);
- if (pb && pb->value == 0) return a;
- if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
- if (fb && fb->value == 0) return a;
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, pa->value - pb->value);
+ if (pb && pb->value == 0) return a;
+ if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
+ if (fb && fb->value == 0) return a;
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::MulNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, pa->value * pb->value);
- if (pa) {
- if (pa->value == 1) return b;
- if (pa->value == 0) return a;
- }
- if (pb) {
- if (pb->value == 1) return a;
- if (pb->value == 0) return b;
- }
- if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
- if (fa) {
- if (fa->value == 1) return b;
- if (fa->value == 0) return a;
- }
- if (fb) {
- if (fb->value == 1) return a;
- if (fb->value == 0) return b;
- }
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, pa->value * pb->value);
+ if (pa) {
+ if (pa->value == 1) return b;
+ if (pa->value == 0) return a;
+ }
+ if (pb) {
+ if (pb->value == 1) return a;
+ if (pb->value == 0) return b;
+ }
+ if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
+ if (fa) {
+ if (fa->value == 1) return b;
+ if (fa->value == 0) return a;
+ }
+ if (fb) {
+ if (fb->value == 1) return a;
+ if (fb->value == 0) return b;
+ }
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::DivNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) {
- // due to division and mod can have different modes
- // NOTE: this will assumes truc div.
- CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImm(rtype, pa->value / pb->value);
- }
- if (pa) {
- if (pa->value == 0) return a;
- }
- if (pb) {
- if (pb->value == 1) return a;
- CHECK_NE(pb->value, 0) << "Divide by zero";
- }
- if (fa && fb && fb->value != 0) {
- return FloatImm(rtype, fa->value / fb->value);
- }
- if (fa && fa->value == 0) return a;
- if (fb) {
- if (fb->value == 1) return a;
- CHECK_NE(fb->value, 0) << "Divide by zero";
- }
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) {
+      // Division and mod can have different rounding modes;
+      // NOTE: this assumes truncated (C-style) division.
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ return IntImm(rtype, pa->value / pb->value);
+ }
+ if (pa) {
+ if (pa->value == 0) return a;
+ }
+ if (pb) {
+ if (pb->value == 1) return a;
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ }
+ if (fa && fb && fb->value != 0) {
+ return FloatImm(rtype, fa->value / fb->value);
+ }
+ if (fa && fa->value == 0) return a;
+ if (fb) {
+ if (fb->value == 1) return a;
+ CHECK_NE(fb->value, 0) << "Divide by zero";
+ }
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) {
- CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImm(rtype, pa->value % pb->value);
- }
- if (pa) {
- if (pa->value == 0) return a;
- }
- if (pb) {
- if (pb->value == 1) return tir::make_zero(rtype);
- CHECK_NE(pb->value, 0) << "Divide by zero";
- }
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) {
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ return IntImm(rtype, pa->value % pb->value);
+ }
+ if (pa) {
+ if (pa->value == 0) return a;
+ }
+ if (pb) {
+ if (pb->value == 1) return tir::make_zero(rtype);
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ }
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::FloorDivNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) {
- CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImm(rtype, arith::floordiv(pa->value, pb->value));
- }
- if (pa) {
- if (pa->value == 0) return a;
- }
- if (pb) {
- if (pb->value == 1) return a;
- CHECK_NE(pb->value, 0) << "Divide by zero";
- }
- if (fa && fb && fb->value != 0) {
- return FloatImm(rtype, std::floor(fa->value / fb->value));
- }
- if (fa && fa->value == 0) return a;
- if (fb) {
- if (fb->value == 1) return a;
- CHECK_NE(fb->value, 0) << "Divide by zero";
- }
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) {
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ return IntImm(rtype, arith::floordiv(pa->value, pb->value));
+ }
+ if (pa) {
+ if (pa->value == 0) return a;
+ }
+ if (pb) {
+ if (pb->value == 1) return a;
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ }
+ if (fa && fb && fb->value != 0) {
+ return FloatImm(rtype, std::floor(fa->value / fb->value));
+ }
+ if (fa && fa->value == 0) return a;
+ if (fb) {
+ if (fb->value == 1) return a;
+ CHECK_NE(fb->value, 0) << "Divide by zero";
+ }
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) {
- CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImm(rtype, floormod(pa->value, pb->value));
- }
- if (pa) {
- if (pa->value == 0) return a;
- }
- if (pb) {
- if (pb->value == 1) return tir::make_zero(rtype);
- CHECK_NE(pb->value, 0) << "Divide by zero";
- }
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) {
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ return IntImm(rtype, floormod(pa->value, pb->value));
+ }
+ if (pa) {
+ if (pa->value == 0) return a;
+ }
+ if (pb) {
+ if (pb->value == 1) return tir::make_zero(rtype);
+ CHECK_NE(pb->value, 0) << "Divide by zero";
+ }
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::MinNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
- if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
+ if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
+ });
if (a.same_as(b)) return a;
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::MaxNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
- if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
- });
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
+ if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
+ });
if (a.same_as(b)) return a;
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::GTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
- });
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::GENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
- });
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::LTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
- });
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::LENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
- });
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::EQNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
- });
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::NENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
- });
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
+ });
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::AndNode>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::OrNode>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
return PrimExpr();
}
-template<>
+template <>
inline PrimExpr TryConstFold<tir::NotNode>(PrimExpr a) {
const IntImmNode* pa = a.as<IntImmNode>();
if (pa) {
*
* \return positive infinity.
*/
-inline PrimExpr pos_inf() {
- return SymbolicLimits::pos_inf_;
-}
+inline PrimExpr pos_inf() { return SymbolicLimits::pos_inf_; }
/*!
* \brief Check if value is positive infinity.
*
* \return The check result.
*/
-inline bool is_pos_inf(const PrimExpr& value) {
- return value.same_as(SymbolicLimits::pos_inf_);
-}
+inline bool is_pos_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::pos_inf_); }
/*!
* \brief Opaque expression representing negative infinity.
*
* \return negative infinity.
*/
-inline PrimExpr neg_inf() {
- return SymbolicLimits::neg_inf_;
-}
+inline PrimExpr neg_inf() { return SymbolicLimits::neg_inf_; }
/*!
* \brief Check if value is negative infinity.
*
* \return The check result.
*/
-inline bool is_neg_inf(const PrimExpr& value) {
- return value.same_as(SymbolicLimits::neg_inf_);
-}
+inline bool is_neg_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::neg_inf_); }
} // namespace arith
} // namespace tvm
/*!
* \file tvm/arith/const_int_bound.cc
*/
-#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr_functor.h>
+
#include <algorithm>
+
#include "int_operator.h"
#include "pattern_match.h"
TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
-ConstIntBound::ConstIntBound(
- int64_t min_value, int64_t max_value) {
+ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) {
auto node = make_object<ConstIntBoundNode>();
node->min_value = min_value;
node->max_value = max_value;
return ConstIntBound(min_value, max_value);
}
-TVM_REGISTER_GLOBAL("arith.ConstIntBound")
-.set_body_typed(MakeConstIntBound);
+TVM_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound);
inline void PrintBoundValue(std::ostream& os, int64_t val) {
if (val == ConstIntBound::kPosInf) {
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ConstIntBoundNode*>(node.get());
- p->stream << "ConstIntBound[";
- PrintBoundValue(p->stream, op->min_value);
- p->stream << ',';
- PrintBoundValue(p->stream, op->max_value);
- p->stream << ']';
- });
+ .set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ConstIntBoundNode*>(node.get());
+ p->stream << "ConstIntBound[";
+ PrintBoundValue(p->stream, op->min_value);
+ p->stream << ',';
+ PrintBoundValue(p->stream, op->max_value);
+ p->stream << ']';
+ });
// internal entry for const int bound
struct ConstIntBoundAnalyzer::Entry {
int64_t min_value;
int64_t max_value;
- bool is_const(int64_t value) const {
- return min_value == max_value && min_value == value;
- }
+ bool is_const(int64_t value) const { return min_value == max_value && min_value == value; }
bool operator==(const Entry& other) const {
return min_value == other.min_value && max_value == other.max_value;
}
};
-class ConstIntBoundAnalyzer::Impl :
- public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
+class ConstIntBoundAnalyzer::Impl
+ : public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
public:
/*! \brief additional bound info about expr \in bound */
struct BoundInfo {
Entry bound;
BoundInfo() {}
- BoundInfo(PrimExpr expr, Entry bound)
- : expr(expr), bound(bound) {
- }
+ BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {}
};
void Bind(const Var& var, const Range& range, bool override) {
Update(var, ret, override);
}
- void Update(const Var& var,
- const Entry& info,
- bool override) {
+ void Update(const Var& var, const Entry& info, bool override) {
if (!override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
- CHECK(it->second == info)
- << "Trying to update var \'" << var << "\'"
- << " with a different const bound: "
- << "original=" << ConstIntBound(it->second.min_value, it->second.max_value)
- << ", new=" << ConstIntBound(info.min_value, info.max_value);
+ CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
+ << " with a different const bound: "
+ << "original="
+ << ConstIntBound(it->second.min_value, it->second.max_value)
+ << ", new=" << ConstIntBound(info.min_value, info.max_value);
}
}
var_map_[var] = info;
}
- void Update(const Var& var,
- const ConstIntBound& info,
- bool override) {
+ void Update(const Var& var, const ConstIntBound& info, bool override) {
Update(var, MakeBound(info->min_value, info->max_value), override);
}
// Override visitor behaviors
Entry VisitExprDefault_(const Object* op) final {
- return Everything(
- static_cast<const PrimExprNode*>(op)->dtype);
+ return Everything(static_cast<const PrimExprNode*>(op)->dtype);
}
Entry VisitExpr(const PrimExpr& expr) final {
return Intersect(a, b);
}
- Entry VisitExpr_(const IntImmNode* op) final {
- return MakeBound(op->value, op->value);
- }
+ Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); }
Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
// 0 <= [a_min, a_max] < b_min
if (a.max_value < b.min_value) return a;
// other case, we can get close to 0
- return MakeBound(0,
- std::min(a.max_value, b_max_cap));
+ return MakeBound(0, std::min(a.max_value, b_max_cap));
} else {
return MakeBound(std::max(a.min_value, -b_max_cap),
std::min(std::max(a.max_value, (int64_t)0), b_max_cap));
* \tparam F the operator function type.
* \return The result.
*/
- template<typename F>
+ template <typename F>
static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) {
Entry ret;
// The boundary point must be shihft of the original boundary.
return ConstIntBound(ret.min_value, ret.max_value);
}
-ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
- BoundMapType* bound) {
+ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapType* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
return ConstIntBound(ret.min_value, ret.max_value);
}
-void ConstIntBoundAnalyzer::Update(const Var& var,
- const ConstIntBound& info,
- bool override) {
+void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) {
impl_->Update(var, info, override);
}
return impl_->EnterConstraint(constraint);
}
-ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent)
- : impl_(new Impl()) {
-}
+ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {}
-ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() {
- delete impl_;
-}
+ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }
} // namespace arith
} // namespace tvm
* \file detect_linear_equation.cc
* \brief Utility to detect patterns in the expression.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
-#include <tvm/tir/op.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/arith/analyzer.h>
namespace tvm {
namespace arith {
PrimExpr max_value;
};
-class LinearEqDetector
- : public ExprFunctor<LinearEqEntry(const PrimExpr&, const PrimExpr &)> {
+class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const PrimExpr&)> {
public:
- explicit LinearEqDetector(Var var)
- : var_(var) {}
+ explicit LinearEqDetector(Var var) : var_(var) {}
bool Detect(const PrimExpr& e, LinearEqEntry* ret) {
*ret = VisitExpr(e, e);
}
};
-Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
- const Array<Var>& vars) {
+Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars) {
PrimExpr base = e;
Array<PrimExpr> coeff;
}
std::unordered_set<const VarNode*> vset;
- auto vset_contains = [&](const VarNode* node) {
- return vset.count(node) != 0;
- };
+ auto vset_contains = [&](const VarNode* node) { return vset.count(node) != 0; };
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
}
// Detect clip condition as min max value
-bool DetectClipBound(
- const PrimExpr& cond,
- std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
+bool DetectClipBound(const PrimExpr& cond,
+ std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
int flag = 0;
Var var;
auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
return false;
}
-
-template<typename OP>
+template <typename OP>
void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) {
if (const OP* op = e.as<OP>()) {
SplitCommExpr<OP>(op->a, ret);
return ret;
}
-TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
-.set_body_typed(DetectLinearEquation);
+TVM_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation);
TVM_REGISTER_GLOBAL("arith.DetectClipBound")
-.set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
- return DetectClipBound(e, vars);
-});
+ .set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
+ return DetectClipBound(e, vars);
+ });
} // namespace arith
} // namespace tvm
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/te/tensor.h>
-#include <tvm/runtime/registry.h>
-#include <unordered_set>
#include <unordered_map>
+#include <unordered_set>
namespace tvm {
namespace arith {
// Find Read region of the tensor in the stmt.
class BufferTouchedDomain final : public StmtExprVisitor {
public:
- BufferTouchedDomain(const Buffer &buffer,
- bool consider_loads,
- bool consider_stores)
- : buffer_(buffer),
- consider_loads_(consider_loads),
- consider_stores_(consider_stores) {}
+ BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores)
+ : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {}
Domain Find(const Stmt& stmt) {
operator()(stmt);
return ret;
}
- void VisitStmt_(const ForNode *op) final {
+ void VisitStmt_(const ForNode* op) final {
const VarNode* var = op->loop_var.get();
- dom_map_[var] = IntSet::range(
- Range::make_by_min_extent(op->min, op->extent));
+ dom_map_[var] = IntSet::range(Range::make_by_min_extent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
void VisitStmt_(const LetStmtNode* op) final {
- dom_map_[op->var.get()] =
- arith::EvalSet(op->value, dom_map_);
+ dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->var.get());
}
}
}
- const Buffer &buffer_;
+ const Buffer& buffer_;
bool consider_loads_, consider_stores_;
std::vector<std::vector<IntSet> > bounds_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};
-Domain DomainTouched(const Stmt& stmt,
- const Buffer& buffer,
- bool consider_loads,
+Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
bool consider_stores) {
return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
}
-TVM_REGISTER_GLOBAL("arith.DomainTouched")
-.set_body_typed(DomainTouched);
+TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched);
} // namespace arith
} // namespace tvm
* \brief The integer constraints data structures.
*/
#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
-#include <tvm/runtime/registry.h>
-#include <utility>
#include <algorithm>
#include <unordered_map>
+#include <utility>
namespace tvm {
namespace arith {
-IntConstraints::IntConstraints(Array<Var> variables,
- Map<Var, Range> ranges,
+IntConstraints::IntConstraints(Array<Var> variables, Map<Var, Range> ranges,
Array<PrimExpr> relations) {
ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>();
if (!variables.defined()) {
CHECK(relations.defined());
for (const auto& var : variables) {
CHECK(var.dtype().is_int() || var.dtype().is_uint())
- << "Variables in IntConstraints must be integers";
+ << "Variables in IntConstraints must be integers";
}
node->variables = std::move(variables);
node->ranges = std::move(ranges);
TVM_REGISTER_NODE_TYPE(IntConstraintsNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IntConstraintsNode*>(node.get());
- p->stream << "IntConstraints("
- << op->variables
- << ", " << op->ranges
- << ", " << op->relations
- << ")";
- });
-
+ .set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IntConstraintsNode*>(node.get());
+ p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << op->relations
+ << ")";
+ });
-IntConstraintsTransform::IntConstraintsTransform(IntConstraints src,
- IntConstraints dst,
+IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src) {
ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>();
TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
- p->stream << "IntConstraintsTransform("
- << "\n\t" << op->src
- << "\n\t" << op->dst
- << "\n\t" << op->src_to_dst
- << "\n\t" << op->dst_to_src
- << "\n)";
- });
+ .set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
+ p->stream << "IntConstraintsTransform("
+ << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst << "\n\t"
+ << op->dst_to_src << "\n)";
+ });
} // namespace arith
} // namespace tvm
* \return Whether overflow can happen.
* \tparam Op The integer operator.
*/
-template<typename Op>
-inline bool WillOverflow(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+template <typename Op>
+inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
return false;
}
-template<>
-inline bool WillOverflow<tir::AddNode>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+template <>
+inline bool WillOverflow<tir::AddNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
if ((y > 0) && (x > max_value - y)) return true;
if ((y < 0) && (x < min_value - y)) return true;
return false;
}
-template<>
-inline bool WillOverflow<tir::SubNode>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+template <>
+inline bool WillOverflow<tir::SubNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
if ((y > 0) && (x < min_value + y)) return true;
if ((y < 0) && (x > max_value + y)) return true;
return false;
}
-template<>
-inline bool WillOverflow<tir::MulNode>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+template <>
+inline bool WillOverflow<tir::MulNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
if (y == 0) return false;
if (y > 0) {
- if (x < min_value / y) return true;
- if (x > max_value / y) return true;
+ if (x < min_value / y) return true;
+ if (x > max_value / y) return true;
} else {
if (y == -1 && x == std::numeric_limits<int64_t>::min()) return true;
- if (x > min_value / y) return true;
- if (x < max_value / y) return true;
+ if (x > min_value / y) return true;
+ if (x < max_value / y) return true;
}
return false;
}
-template<>
-inline bool WillOverflow<tir::ModNode>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+template <>
+inline bool WillOverflow<tir::ModNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
return y == 0;
}
* \param y The right operand.
* \return the result.
*/
-inline int64_t truncdiv(int64_t x, int64_t y) {
- return x / y;
-}
+inline int64_t truncdiv(int64_t x, int64_t y) { return x / y; }
/*!
* \brief Compute the truncdiv remainder of two integers.
* \param y The right operand.
* \return the result.
*/
-inline int64_t truncmod(int64_t x, int64_t y) {
- return x % y;
-}
+inline int64_t truncmod(int64_t x, int64_t y) { return x % y; }
/*!
* \brief Peform floor division of two integers.
inline int64_t floordiv(int64_t x, int64_t y) {
int64_t rdiv = x / y;
int64_t rmod = x % y;
- bool is_floor_div =
- (y >= 0 && rmod >= 0) ||
- (y < 0 && rmod <= 0);
+ bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0);
return is_floor_div ? rdiv : (rdiv - 1);
}
-
/*!
* \brief Compute the floordiv remainder of two integers.
* \param x The left operand.
*/
inline int64_t floormod(int64_t x, int64_t y) {
int64_t rmod = x % y;
- bool is_floor_div =
- (y >= 0 && rmod >= 0) ||
- (y < 0 && rmod <= 0);
+ bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0);
return is_floor_div ? rmod : rmod + y;
}
* \brief The integer set functions
*/
#include <tvm/arith/int_set.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
-#include <tvm/runtime/registry.h>
-#include <utility>
#include <algorithm>
#include <unordered_map>
+#include <utility>
+
#include "interval_set.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
+using tir::is_one;
+using tir::is_zero;
using tir::make_const;
using tir::make_zero;
-using tir::is_zero;
-using tir::is_one;
PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
return IntervalSet(min_value, max_value);
}
-TVM_REGISTER_GLOBAL("arith.IntervalSet")
-.set_body_typed(MakeIntervalSet);
-
+TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet);
IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
PrimExpr max_value = min(a->max_value, b->max_value);
}
// type traits
-template<typename OP>
+template <typename OP>
struct is_logical_op {
static const bool value = false;
};
-#define TVM_DECLARE_LOGICAL_OP(OP) \
- template<> \
- struct is_logical_op<tir::OP> { \
- static const bool value = true; \
+#define TVM_DECLARE_LOGICAL_OP(OP) \
+ template <> \
+ struct is_logical_op<tir::OP> { \
+ static const bool value = true; \
};
TVM_DECLARE_LOGICAL_OP(AndNode);
* \brief Combine two interval set under arithmetic operations.
* \note this can possibly relax the set.
*/
-template<typename Op>
-inline IntervalSet Combine(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+template <typename Op>
+inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
if (!res.defined()) res = Op::make(a->min_value, b->min_value);
return IntervalSet::SinglePoint(res);
}
if (is_logical_op<Op>::value) {
- return IntervalSet(make_const(a->min_value.dtype(), 0),
- make_const(a->min_value.dtype(), 1));
+ return IntervalSet(make_const(a->min_value.dtype(), 0), make_const(a->min_value.dtype(), 1));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
return IntervalSet::Everything();
}
-template<>
-inline IntervalSet Combine<tir::AddNode>(Analyzer* analyer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::AddNode>(Analyzer* analyer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value + b->min_value);
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
PrimExpr min_value =
- a->HasLowerBound() && b->HasLowerBound() ?
- a->min_value + b->min_value : neg_inf();
+ a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf();
PrimExpr max_value =
- a->HasUpperBound() && b->HasUpperBound() ?
- a->max_value + b->max_value : pos_inf();
+ a->HasUpperBound() && b->HasUpperBound() ? a->max_value + b->max_value : pos_inf();
return IntervalSet(min_value, max_value);
}
-template<>
-inline IntervalSet Combine<tir::SubNode>(Analyzer* analyer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::SubNode>(Analyzer* analyer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value - b->min_value);
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
PrimExpr min_value =
- a->HasLowerBound() && b->HasUpperBound() ?
- a->min_value - b->max_value : neg_inf();
+ a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf();
PrimExpr max_value =
- a->HasUpperBound() && b->HasLowerBound() ?
- a->max_value - b->min_value : pos_inf();
+ a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
}
-
-template<>
-inline IntervalSet Combine<tir::MulNode>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::MulNode>(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value * b->min_value);
}
return IntervalSet::Everything();
}
-template<>
-inline IntervalSet Combine<tir::DivNode>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::DivNode>(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value / b->min_value);
}
return IntervalSet::Everything();
}
-template<>
-inline IntervalSet Combine<tir::ModNode>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::ModNode>(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
}
return IntervalSet::Everything();
}
-
-template<>
-inline IntervalSet Combine<tir::FloorDivNode>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::FloorDivNode>(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
}
return IntervalSet::Everything();
}
-template<>
-inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
}
return IntervalSet::Everything();
}
-template<>
-inline IntervalSet Combine<tir::MaxNode>(Analyzer* analzyer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::MaxNode>(Analyzer* analzyer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
- return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
+ return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
- return IntervalSet(max(a->min_value, b->min_value),
- max(a->max_value, b->max_value));
+ return IntervalSet(max(a->min_value, b->min_value), max(a->max_value, b->max_value));
}
-template<>
-inline IntervalSet Combine<tir::MinNode>(Analyzer* analzyer,
- IntervalSet a,
- IntervalSet b) {
+template <>
+inline IntervalSet Combine<tir::MinNode>(Analyzer* analzyer, IntervalSet a, IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
- return IntervalSet(min(a->min_value, b->min_value),
- min(a->max_value, b->max_value));
+ return IntervalSet(min(a->min_value, b->min_value), min(a->max_value, b->max_value));
}
// internal helper function to get an interval set
// Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset.
-class IntervalSetEvaluator :
- public ExprFunctor<IntervalSet(const PrimExpr&)> {
+class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
public:
- IntervalSetEvaluator(Analyzer* analyzer,
- const Map<Var, IntSet>& dom_map,
- bool eval_vec = false)
- : analyzer_(analyzer),
- dom_map_(dom_map),
- eval_vec_(eval_vec) {
- }
+ IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map, bool eval_vec = false)
+ : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {}
- IntervalSet Eval(const PrimExpr& val) {
- return this->VisitExpr(val);
- }
+ IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); }
// evaluate and relax the set
IntervalSet Eval(IntervalSet val) {
// avoid recursive indefinite recursive expansion.
auto it = dom_map_.find(var);
if (it != dom_map_.end()) {
IntervalSet res = ToIntervalSet((*it).second);
- if (res->min_value.same_as(var) &&
- res->max_value.same_as(var)) {
+ if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
return res;
}
// recursively evaluate mapped result
}
}
+ IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const AddNode* op) final {
- return VisitBinaryExpr_(op);
- }
-
- IntervalSet VisitExpr_(const SubNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const MulNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const DivNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const ModNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const FloorDivNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const FloorModNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const MinNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const MaxNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const EQNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const NENode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const LTNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const LENode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const GTNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const GENode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const AndNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); }
- IntervalSet VisitExpr_(const OrNode* op) final {
- return VisitBinaryExpr_(op);
- }
+ IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); }
IntervalSet VisitExpr_(const RampNode* op) final {
CHECK(eval_vec_);
if (stride.Match(op->stride)) {
DataType t = op->base.dtype();
int64_t vstride = stride.Eval()->value;
- if (vstride> 0) {
- return Combine<AddNode>(
- analyzer_,
- base,
- IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
+ if (vstride > 0) {
+ return Combine<AddNode>(analyzer_, base,
+ IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
} else {
- return Combine<AddNode>(
- analyzer_,
- base,
- IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
+ return Combine<AddNode>(analyzer_, base,
+ IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
}
}
DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<PrimExpr>(op);
private:
// whether set is exactly single point that equals value.
- bool MatchPoint(const IntervalSet& set,
- const PrimExpr& value) const {
+ bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const {
return set->min_value.same_as(value) && set->max_value.same_as(value);
}
- template<typename T>
+ template <typename T>
inline IntervalSet VisitBinaryExpr_(const T* op) {
IntervalSet a = this->Eval(op->a);
IntervalSet b = this->Eval(op->b);
class IntSetAnalyzer::Impl {
public:
- explicit Impl(Analyzer* analyzer)
- : analyzer_(analyzer) {
- }
+ explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {}
IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
Analyzer* analyzer_;
};
-IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent)
- : impl_(new Impl(parent)) {
-}
+IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
-IntSetAnalyzer::~IntSetAnalyzer() {
- delete impl_;
-}
+IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; }
-IntSet IntSetAnalyzer::operator()(const PrimExpr& expr,
- const Map<Var, IntSet>& dom_map) {
+IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) {
return impl_->Eval(expr, dom_map);
}
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int != nullptr);
if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
- return Range::make_by_min_extent(
- s_int->min_value, analyzer.Simplify(s_int->max_value + 1 - s_int->min_value));
+ return Range::make_by_min_extent(s_int->min_value,
+ analyzer.Simplify(s_int->max_value + 1 - s_int->min_value));
}
return max_range;
}
return s_int->min_value;
}
-IntSet IntSet::nothing() {
- return IntervalSet::Empty();
-}
+IntSet IntSet::nothing() { return IntervalSet::Empty(); }
-IntSet IntSet::everything() {
- return IntervalSet::Everything();
-}
+IntSet IntSet::everything() { return IntervalSet::Everything(); }
-IntSet IntSet::single_point(PrimExpr x) {
- return IntervalSet::SinglePoint(x);
-}
+IntSet IntSet::single_point(PrimExpr x) { return IntervalSet::SinglePoint(x); }
IntSet IntSet::interval(PrimExpr min, PrimExpr max) {
if (min.same_as(max)) {
if (!a_int) return false;
Analyzer ana;
return ProveEqual(&ana, a_int->min_value, b->min) &&
- ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1);
+ ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1);
}
IntSet Union(const Array<IntSet>& sets) {
for (size_t i = 1; i < sets.size(); ++i) {
x = Union(&ana, x, ToIntervalSet(sets[i]));
}
- return IntervalSet(ana.Simplify(x->min_value),
- ana.Simplify(x->max_value));
+ return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value));
}
IntSet Intersect(const Array<IntSet>& sets) {
for (size_t i = 1; i < sets.size(); ++i) {
x = Intersect(&ana, x, ToIntervalSet(sets[i]));
}
- return IntervalSet(ana.Simplify(x->min_value),
- ana.Simplify(x->max_value));
+ return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value));
}
Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
return dmap;
}
-Map<Var, IntSet> ConvertDomMap(
- const std::unordered_map<const VarNode*, IntSet>& dom_map) {
+Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map) {
Map<Var, IntSet> dmap;
for (auto kv : dom_map) {
dmap.Set(GetRef<Var>(kv.first), kv.second);
return dmap;
}
-IntSet EvalSet(PrimExpr e,
- const Map<Var, IntSet>& dom_map) {
+IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
Analyzer ana;
return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
}
return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
}
-IntSet EvalSet(PrimExpr e,
- const Map<IterVar, IntSet>& dom_map) {
+IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
return EvalSet(e, ConvertDomMap(dom_map));
}
-IntSet EvalSet(PrimExpr e,
- const std::unordered_map<const VarNode*, IntSet>& dom_map) {
+IntSet EvalSet(PrimExpr e, const std::unordered_map<const VarNode*, IntSet>& dom_map) {
return EvalSet(e, ConvertDomMap(dom_map));
}
-IntSet EvalSet(Range r,
- const Map<Var, IntSet>& dom_map) {
+IntSet EvalSet(Range r, const Map<Var, IntSet>& dom_map) {
Analyzer ana;
IntervalSetEvaluator m(&ana, dom_map);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
PrimExpr sum = r->min + r->extent - 1;
- auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum)));
+ auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum)));
return std::move(res);
}
-IntSet EvalSet(Range r,
- const std::unordered_map<const VarNode*, IntSet>& dom_map) {
+IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map) {
return EvalSet(r, ConvertDomMap(dom_map));
}
-IntSet EvalSet(IntSet s,
- const std::unordered_map<const VarNode*, IntSet>& dom_map) {
+IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map) {
Analyzer ana;
auto dmap = ConvertDomMap(dom_map);
IntervalSetEvaluator m(&ana, dmap);
const IntervalSetNode* s_int = s.as<IntervalSetNode>();
- PrimExpr vmax = s_int->HasUpperBound() ?
- m.Eval(s_int->max_value).max() : s_int->max_value;
- PrimExpr vmin = s_int->HasLowerBound() ?
- m.Eval(s_int->min_value).min() : s_int->min_value;
+ PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value;
+ PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value;
return IntervalSet(vmin, vmax);
}
class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
public:
- explicit SubExprIntervalSetEvaluator(
- Analyzer* analyzer,
- const Map<Var, IntSet>& dom_map)
+ explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map)
: IntervalSetEvaluator(analyzer, dom_map) {}
IntervalSet VisitExpr(const PrimExpr& n) final {
ExprIntSetMap expr_map;
};
-ExprIntSetMap EvalSetForEachSubExpr(
- PrimExpr e,
- const std::unordered_map<const VarNode*, IntSet>& dom_map) {
+ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map) {
Analyzer ana;
auto dmap = ConvertDomMap(dom_map);
SubExprIntervalSetEvaluator m(&ana, dmap);
return m.expr_map;
}
-IntSet EvalSet(Range r,
- const Map<IterVar, IntSet>& dom_map) {
+IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) {
return EvalSet(r, ConvertDomMap(dom_map));
}
TVM_REGISTER_NODE_TYPE(IntervalSetNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IntervalSetNode*>(node.get());
- p->stream << "IntervalSet"
- << "[" << op->min_value << ", "
- << op->max_value << ']';
- });
-
+ .set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IntervalSetNode*>(node.get());
+ p->stream << "IntervalSet"
+ << "[" << op->min_value << ", " << op->max_value << ']';
+ });
-TVM_REGISTER_GLOBAL("arith.intset_single_point")
-.set_body_typed(IntSet::single_point);
+TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::single_point);
-TVM_REGISTER_GLOBAL("arith.intset_vector")
-.set_body_typed(IntSet::vector);
+TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::vector);
-TVM_REGISTER_GLOBAL("arith.intset_interval")
-.set_body_typed(IntSet::interval);
+TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::interval);
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
-.set_body_method(&IntSet::min);
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min);
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
-.set_body_method(&IntSet::max);
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max);
-TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
-.set_body_method(&IntSet::is_nothing);
+TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::is_nothing);
-TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
-.set_body_method(&IntSet::is_everything);
+TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::is_everything);
} // namespace arith
} // namespace tvm
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
+
#include <limits>
+
#include "const_fold.h"
namespace tvm {
}
/*! \return Whether the interval has upper bound. */
- bool HasUpperBound() const {
- return !is_pos_inf(max_value) && !IsEmpty();
- }
+ bool HasUpperBound() const { return !is_pos_inf(max_value) && !IsEmpty(); }
/*! \return Whether the interval has lower bound. */
- bool HasLowerBound() const {
- return !is_neg_inf(min_value) && !IsEmpty();
- }
+ bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); }
/*! \return Whether the interval is a single point. */
- bool IsSinglePoint() const {
- return min_value.same_as(max_value);
- }
+ bool IsSinglePoint() const { return min_value.same_as(max_value); }
/*! \return whether interval represent nothing */
bool IsEmpty() const {
// during computations, either extreme could occur.
return is_pos_inf(min_value) || is_neg_inf(max_value);
}
/*! \return whether interval represent everything */
- bool IsEverything() const {
- return is_neg_inf(min_value) && is_pos_inf(max_value);
- }
+ bool IsEverything() const { return is_neg_inf(min_value) && is_pos_inf(max_value); }
static constexpr const char* _type_key = "arith.IntervalSet";
TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode);
* \param value The value to be represented.
* \return The result set.
*/
- static IntervalSet SinglePoint(PrimExpr value) {
- return IntervalSet(value, value);
- }
+ static IntervalSet SinglePoint(PrimExpr value) { return IntervalSet(value, value); }
/*!
* \brief Create an IntervalSet that represents everything.
* \param value The value to be represented.
* \return The result set.
*/
- static IntervalSet Everything() {
- return IntervalSet(neg_inf(), pos_inf());
- }
+ static IntervalSet Everything() { return IntervalSet(neg_inf(), pos_inf()); }
/*!
* \brief Create an empty eet.
* \return The result set.
*/
- static IntervalSet Empty() {
- return IntervalSet(pos_inf(), neg_inf());
- }
+ static IntervalSet Empty() { return IntervalSet(pos_inf(), neg_inf()); }
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode);
TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
* \param b The second set.
* \return The result set.
*/
-TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b);
+TVM_DLL IntervalSet Intersect(Analyzer* analzyer, IntervalSet a, IntervalSet b);
} // namespace arith
} // namespace tvm
/*!
* \file tvm/arith/ir_mutator_with_analyzer.cc
*/
+#include "ir_mutator_with_analyzer.h"
+
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
-#include "ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
using namespace tir;
-Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const ForNode* op) {
- analyzer_->Bind(op->loop_var,
- Range::make_by_min_extent(op->min, op->extent));
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
+ analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
return StmtExprMutator::VisitStmt_(op);
}
-Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const LetStmtNode* op) {
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (!tir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
}
}
-Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const IfThenElseNode* op) {
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case;
{
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
- With<ConstraintContext> ctx(analyzer_,
- analyzer_->rewrite_simplify(NotNode::make(condition)));
- else_case = this->VisitStmt(op->else_case);
+ With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(NotNode::make(condition)));
+ else_case = this->VisitStmt(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
return EvaluateNode::make(0);
}
- if (condition.same_as(op->condition) &&
- then_case.same_as(op->then_case) &&
+ if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
}
}
-Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == tir::attr::thread_extent ||
- op->attr_key == tir::attr::virtual_thread) {
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
- analyzer_->Bind(iv->var,
- Range::make_by_min_extent(0, op->value));
+ analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
Stmt stmt = StmtExprMutator::VisitStmt_(op);
return stmt;
} else {
}
}
-Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const AssertStmtNode* op) {
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
Stmt body = this->VisitStmt(op->body);
- if (condition.same_as(op->condition) &&
- message.same_as(op->message) &&
- body.same_as(op->body)) {
+ if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
}
}
-PrimExpr IRMutatorWithAnalyzer::
-VisitExpr_(const CallNode* op) {
+PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) {
PrimExpr cond = this->VisitExpr(op->args[0]);
if (is_one(cond)) {
return true_value;
}
- if (cond.same_as(op->args[0]) &&
- true_value.same_as(op->args[1]) &&
+ if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
- return CallNode::make(op->dtype, op->name,
- {cond, true_value, false_value},
- op->call_type);
+ return CallNode::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type);
}
}
return StmtExprMutator::VisitExpr_(op);
}
-PrimExpr IRMutatorWithAnalyzer::
-VisitExpr_(const LetNode* op) {
+PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (!tir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
PrimExpr body = this->VisitExpr(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
}
}
-PrimExpr IRMutatorWithAnalyzer::
-VisitExpr_(const SelectNode* op) {
+PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) {
PrimExpr cond = this->VisitExpr(op->condition);
PrimExpr true_value, false_value;
{
true_value = VisitExpr(op->true_value);
}
{
- With<ConstraintContext> constraint(analyzer_,
- analyzer_->rewrite_simplify(NotNode::make(cond)));
+ With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(NotNode::make(cond)));
false_value = VisitExpr(op->false_value);
}
if (is_zero(cond)) {
return true_value;
}
// normal path
- if (cond.same_as(op->condition) &&
- true_value.same_as(op->true_value) &&
+ if (cond.same_as(op->condition) && true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
}
}
-PrimExpr IRMutatorWithAnalyzer::
-VisitExpr_(const ReduceNode* op) {
+PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom);
#ifndef TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
#define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
-#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/tir/stmt_functor.h>
+
#include <utility>
namespace tvm {
*/
class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
public:
- explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
- : analyzer_(analyzer) {}
+ explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {}
- using StmtExprMutator::VisitStmt_;
using StmtExprMutator::VisitExpr_;
+ using StmtExprMutator::VisitStmt_;
// override functions that need to populate the context information.
tir::Stmt VisitStmt_(const tir::ForNode* op) override;
class IRVisitorWithAnalyzer final : public StmtExprVisitor {
public:
- PrimExpr Simplify(const PrimExpr& expr) {
- return analyzer_.Simplify(expr);
- }
+ PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); }
void VisitStmt_(const ForNode* op) {
- analyzer_.Bind(op->loop_var,
- Range::make_by_min_extent(op->min, op->extent));
+ analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
return StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
- analyzer_.Bind(iv->var,
- Range::make_by_min_extent(0, op->value));
+ analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value));
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
* \file modular_set.cc
* \brief Modular set analysis
*/
-#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
-#include <tvm/tir/op.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
#include <limits>
-#include <utility>
#include <unordered_map>
+#include <utility>
+
#include "pattern_match.h"
namespace tvm {
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ModularSetNode*>(node.get());
- p->stream << "ModularSet("
- << "coeff=" << op->coeff << ", base="
- << op->base << ')';
- });
-
-ModularSet MakeModularSet(int64_t coeff, int64_t base) {
- return ModularSet(coeff, base);
-}
+ .set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ModularSetNode*>(node.get());
+ p->stream << "ModularSet("
+ << "coeff=" << op->coeff << ", base=" << op->base << ')';
+ });
+
+ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); }
-TVM_REGISTER_GLOBAL("arith.ModularSet")
-.set_body_typed(MakeModularSet);
+TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet);
// internal entry for const int bound
struct ModularSetAnalyzer::Entry {
this->base = base;
}
- bool is_const() const {
- return coeff == 0;
- }
+ bool is_const() const { return coeff == 0; }
- bool operator==(const Entry& other) const {
- return coeff == other.coeff && base == other.base;
- }
+ bool operator==(const Entry& other) const { return coeff == other.coeff && base == other.base; }
bool operator==(const ModularSet& other) const {
- return other.defined() &&
- coeff == other->coeff && base == other->base;
+ return other.defined() && coeff == other->coeff && base == other->base;
}
};
-class ModularSetAnalyzer::Impl :
- public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> {
+class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> {
public:
- explicit Impl(Analyzer* parent)
- : parent_(parent) {}
+ explicit Impl(Analyzer* parent) : parent_(parent) {}
- void Update(const Var& var,
- const ModularSet& info,
- bool override) {
+ void Update(const Var& var, const ModularSet& info, bool override) {
if (!override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
- CHECK(it->second == info)
- << "Trying to update var \'" << var << "\'"
- << " with a different const bound: "
- << "original=" << ModularSet(it->second.coeff, it->second.base)
- << ", new=" << info;
+ CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
+ << " with a different const bound: "
+ << "original=" << ModularSet(it->second.coeff, it->second.base)
+ << ", new=" << info;
}
}
var_map_[var] = Entry(info->coeff, info->base);
}
// Override visitor behaviors
- Entry VisitExprDefault_(const Object* op) final {
- return Everything();
- }
+ Entry VisitExprDefault_(const Object* op) final { return Everything(); }
- Entry VisitExpr_(const CastNode* op) final {
- return VisitExpr(op->value);
- }
+ Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
- Entry VisitExpr_(const IntImmNode* op) final {
- return Entry(0, op->value);
- }
+ Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); }
Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
return Entry(coeff, a.base * b.base);
}
- Entry DivByConst(const PrimExpr& lhs,
- int64_t val,
- bool round_down) {
+ Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) {
Entry a = VisitExpr(lhs);
CHECK_NE(val, 0);
if (a.coeff % val == 0) {
}
// positive division have a clear rounding mode.
// Only handle case where we clearly know we need to round down.
- if (a.base > 0 && val > 0 &&
- (round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
+ if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
return Entry(a.coeff / val, a.base / val);
}
}
}
var_map_[var] = Intersect(old, entry);
// reover function.
- return [this, old, var]() {
- var_map_[var] = old;
- };
+ return [this, old, var]() { var_map_[var] = old; };
}
/*!
* \brief Create union of two sets.
* \brief return everything dtype can represent.
* \return Bound that represent everything dtype can represent.
*/
- static Entry Everything() {
- return Entry(1, 0);
- }
+ static Entry Everything() { return Entry(1, 0); }
/*!
* \brief return an empty set
* \return Bound that represent everything dtype can represent.
*/
- static Entry Nothing() {
- return Entry(0, 1);
- }
+ static Entry Nothing() { return Entry(0, 1); }
};
ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
return ModularSet(ret.coeff, ret.base);
}
-void ModularSetAnalyzer::Update(const Var& var,
- const ModularSet& info,
- bool override) {
+void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool override) {
impl_->Update(var, info, override);
}
return impl_->EnterConstraint(constraint);
}
-ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent)
- : impl_(new Impl(parent)) {
-}
+ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
-ModularSetAnalyzer::~ModularSetAnalyzer() {
- delete impl_;
-}
+ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; }
} // namespace arith
} // namespace tvm
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
+
#include <tuple>
+
#include "const_fold.h"
namespace tvm {
*
* \tparam Derived The type of the derived class.
*/
-template<typename Derived>
+template <typename Derived>
class Pattern {
public:
/*!
*
* \return whether value matches the pattern.
*/
- template<typename NodeType>
+ template <typename NodeType>
bool Match(const NodeType& value) const {
derived().InitMatch_();
return derived().Match_(value);
}
/*! \return Derived instance of current class. */
- const Derived& derived() const {
- return *static_cast<const Derived*>(this);
- }
+ const Derived& derived() const { return *static_cast<const Derived*>(this); }
};
/*!
* \brief Default deep equality checker
* \tparam T the comparison point.
*/
-template<typename T>
+template <typename T>
class PEqualChecker {
public:
- bool operator()(const T& lhs, const T& rhs) const {
- return lhs == rhs;
- }
+ bool operator()(const T& lhs, const T& rhs) const { return lhs == rhs; }
};
-template<>
+template <>
class PEqualChecker<PrimExpr> {
public:
bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
}
};
-template<>
+template <>
class PEqualChecker<IntImm> {
public:
- bool operator()(const IntImm& lhs, const IntImm& rhs) const {
- return lhs->value == rhs->value;
- }
+ bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; }
};
-template<>
+template <>
class PEqualChecker<tir::Var> {
public:
- bool operator()(const tir::Var& lhs, const tir::Var& rhs) const {
- return lhs.same_as(rhs);
- }
+ bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); }
};
/*!
* \note PVar is not thread safe.
* Do not use the same PVar in multiple threads.
*/
-template<typename T>
-class PVar : public Pattern<PVar<T> > {
+template <typename T>
+class PVar : public Pattern<PVar<T>> {
public:
// Store PVars by reference in the expression.
using Nested = const PVar<T>&;
- void InitMatch_() const {
- filled_ = false;
- }
+ void InitMatch_() const { filled_ = false; }
bool Match_(const T& value) const {
if (!filled_) {
}
}
- template<typename NodeRefType,
- typename = typename std::enable_if<
- std::is_base_of<NodeRefType, T>::value>::type>
+ template <typename NodeRefType,
+ typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type>
bool Match_(const NodeRefType& value) const {
if (const auto* ptr = value.template as<typename T::ContainerType>()) {
return Match_(GetRef<T>(ptr));
*
* \tparam T the type of the hole.
*/
-template<typename T>
-class PConst : public Pattern<PConst<T> > {
+template <typename T>
+class PConst : public Pattern<PConst<T>> {
public:
PConst(T value) // NOLINT(*)
: value_(value) {}
void InitMatch_() const {}
- bool Match_(const T& value) const {
- return PEqualChecker<T>()(value_, value);
- }
+ bool Match_(const T& value) const { return PEqualChecker<T>()(value_, value); }
- T Eval() const {
- return value_;
- }
+ T Eval() const { return value_; }
private:
const T value_;
* \tparam TA The pattern type of the first operand.
* \tparam TB The pattern type of the second operand.
*/
-template<typename NodeType, typename TA, typename TB>
-class PBinaryExpr :
- public Pattern<PBinaryExpr<NodeType, TA, TB> > {
+template <typename NodeType, typename TA, typename TB>
+class PBinaryExpr : public Pattern<PBinaryExpr<NodeType, TA, TB>> {
public:
PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}
typename TB::Nested b_;
};
-template<typename TA>
-class PConstWithTypeLike :
- public Pattern<PConstWithTypeLike<TA> > {
+template <typename TA>
+class PConstWithTypeLike : public Pattern<PConstWithTypeLike<TA>> {
public:
- PConstWithTypeLike(const TA& ref, int64_t value)
- : ref_(ref), value_(value) {}
+ PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {}
void InitMatch_() const {}
}
}
- PrimExpr Eval() const {
- return tir::make_const(ref_.Eval().dtype(), value_);
- }
+ PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); }
private:
typename TA::Nested ref_;
int64_t value_;
};
-
-#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \
- template<typename TA, typename TB> \
- inline PBinaryExpr<NodeName, TA, TB> \
- FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
- CheckStep; \
- return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
- } \
- template<typename TA> \
- inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
- FuncName(const Pattern<TA>& a, int64_t b) { \
- CheckStep; \
- return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
- } \
- template<typename TA> \
- inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
- FuncName(int64_t b, const Pattern<TA>& a) { \
- CheckStep; \
- return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
- }
-
-#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
- TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
-
+#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \
+ template <typename TA, typename TB> \
+ inline PBinaryExpr<NodeName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
+ CheckStep; \
+ return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
+ } \
+ template <typename TA> \
+ inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA>> FuncName(const Pattern<TA>& a, \
+ int64_t b) { \
+ CheckStep; \
+ return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
+ } \
+ template <typename TA> \
+ inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> FuncName(int64_t b, \
+ const Pattern<TA>& a) { \
+ CheckStep; \
+ return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
+ }
+
+#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
// raise ambiguity error for operator overload of / and %
TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a));
* \brief Pattern not expression.
* \tparam TA The pattern type of the true operand.
*/
-template<typename TA>
-class PNotExpr : public Pattern<PNotExpr<TA> > {
+template <typename TA>
+class PNotExpr : public Pattern<PNotExpr<TA>> {
public:
- explicit PNotExpr(const TA& value)
- : value_(value) {}
+ explicit PNotExpr(const TA& value) : value_(value) {}
- void InitMatch_() const {
- value_.InitMatch_();
- }
+ void InitMatch_() const { value_.InitMatch_(); }
bool Match_(const ObjectRef& node) const {
if (const tir::NotNode* ptr = node.as<tir::NotNode>()) {
}
}
- PrimExpr Eval() const {
- return tir::NotNode::make(value_.Eval());
- }
+ PrimExpr Eval() const { return tir::NotNode::make(value_.Eval()); }
private:
typename TA::Nested value_;
};
-template<typename TA>
+template <typename TA>
inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
return PNotExpr<TA>(value.derived());
}
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
-template<typename TCond, typename TA, typename TB>
-class PSelectExpr :
- public Pattern<PSelectExpr<TCond, TA, TB> > {
+template <typename TCond, typename TA, typename TB>
+class PSelectExpr : public Pattern<PSelectExpr<TCond, TA, TB>> {
public:
- PSelectExpr(const TCond& condition,
- const TA& true_value,
- const TB& false_value)
- : condition_(condition),
- true_value_(true_value),
- false_value_(false_value) {}
+ PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value)
+ : condition_(condition), true_value_(true_value), false_value_(false_value) {}
void InitMatch_() const {
condition_.InitMatch_();
}
PrimExpr Eval() const {
- return tir::SelectNode::make(
- condition_.Eval(), true_value_.Eval(), false_value_.Eval());
+ return tir::SelectNode::make(condition_.Eval(), true_value_.Eval(), false_value_.Eval());
}
private:
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
-template<typename TCond, typename TA, typename TB>
-inline PSelectExpr<TCond, TA, TB>
-select(const Pattern<TCond>& condition,
- const Pattern<TA>& true_value,
- const Pattern<TB>& false_value) {
- return PSelectExpr<TCond, TA, TB>(
- condition.derived(), true_value.derived(), false_value.derived());
+template <typename TCond, typename TA, typename TB>
+inline PSelectExpr<TCond, TA, TB> select(const Pattern<TCond>& condition,
+ const Pattern<TA>& true_value,
+ const Pattern<TB>& false_value) {
+ return PSelectExpr<TCond, TA, TB>(condition.derived(), true_value.derived(),
+ false_value.derived());
}
/*!
* \tparam DType The Pattern type of dtype.
* \tparam TA The pattern type of the first operand.
*/
-template<typename DType, typename TA>
-class PCastExpr :
- public Pattern<PCastExpr<DType, TA> > {
+template <typename DType, typename TA>
+class PCastExpr : public Pattern<PCastExpr<DType, TA>> {
public:
- PCastExpr(const DType& dtype, const TA& value)
- : dtype_(dtype), value_(value) {
- }
+ PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {}
void InitMatch_() const {
dtype_.InitMatch_();
}
}
- PrimExpr Eval() const {
- return tir::CastNode::make(dtype_.Eval(), value_.Eval());
- }
+ PrimExpr Eval() const { return tir::CastNode::make(dtype_.Eval(), value_.Eval()); }
private:
typename DType::Nested dtype_;
* \tparam DType The pattern type of type.
* \tparam TA The pattern type of value.
*/
-template<typename DType, typename TA>
-inline PCastExpr<DType, TA>
-cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
+template <typename DType, typename TA>
+inline PCastExpr<DType, TA> cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
return PCastExpr<DType, TA>(dtype.derived(), value.derived());
}
* \tparam TStride The pattern type of the stride.
* \tparam TLanes The pattern type of the lanes.
*/
-template<typename TBase, typename TStride, typename TLanes>
-class PRampExpr :
- public Pattern<PRampExpr<TBase, TStride, TLanes> > {
+template <typename TBase, typename TStride, typename TLanes>
+class PRampExpr : public Pattern<PRampExpr<TBase, TStride, TLanes>> {
public:
- PRampExpr(const TBase& base,
- const TStride& stride,
- const TLanes& lanes)
- : base_(base), stride_(stride), lanes_(lanes) {
- }
+ PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes)
+ : base_(base), stride_(stride), lanes_(lanes) {}
void InitMatch_() const {
base_.InitMatch_();
}
}
- PrimExpr Eval() const {
- return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
- }
+ PrimExpr Eval() const { return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); }
private:
typename TBase::Nested base_;
* \tparam TStride The pattern type of the stride.
* \tparam TLanes The pattern type of the lanes.
*/
-template<typename TBase, typename TStride, typename TLanes>
-inline PRampExpr<TBase, TStride, TLanes>
-ramp(const Pattern<TBase>& base,
- const Pattern<TStride>& stride,
- const Pattern<TLanes>& lanes) {
- return PRampExpr<TBase, TStride, TLanes>(
- base.derived(), stride.derived(), lanes.derived());
+template <typename TBase, typename TStride, typename TLanes>
+inline PRampExpr<TBase, TStride, TLanes> ramp(const Pattern<TBase>& base,
+ const Pattern<TStride>& stride,
+ const Pattern<TLanes>& lanes) {
+ return PRampExpr<TBase, TStride, TLanes>(base.derived(), stride.derived(), lanes.derived());
}
-template<typename TBase>
-inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>
-ramp(const Pattern<TBase>& base,
- int stride,
- int lanes) {
+template <typename TBase>
+inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>> ramp(const Pattern<TBase>& base,
+ int stride, int lanes) {
return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
- base.derived(),
- PConstWithTypeLike<TBase>(base.derived(), stride),
- PConst<int>(lanes));
+ base.derived(), PConstWithTypeLike<TBase>(base.derived(), stride), PConst<int>(lanes));
}
/*!
* \tparam TA The pattern type of the value.
* \tparam TLanes The pattern type of the lanes.
*/
-template<typename TA, typename TLanes>
-class PBroadcastExpr :
- public Pattern<PBroadcastExpr<TA, TLanes> > {
+template <typename TA, typename TLanes>
+class PBroadcastExpr : public Pattern<PBroadcastExpr<TA, TLanes>> {
public:
- PBroadcastExpr(const TA& value,
- const TLanes& lanes)
- : value_(value), lanes_(lanes) {
- }
+ PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {}
void InitMatch_() const {
value_.InitMatch_();
}
}
- PrimExpr Eval() const {
- return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
- }
+ PrimExpr Eval() const { return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); }
private:
typename TA::Nested value_;
* \tparam TA The pattern type of the value.
* \tparam TLanes The pattern type of the lanes.
*/
-template<typename TA, typename TLanes>
-inline PBroadcastExpr<TA, TLanes>
-broadcast(const Pattern<TA>& value, const Pattern<TLanes>& lanes) {
+template <typename TA, typename TLanes>
+inline PBroadcastExpr<TA, TLanes> broadcast(const Pattern<TA>& value,
+ const Pattern<TLanes>& lanes) {
return PBroadcastExpr<TA, TLanes>(value.derived(), lanes.derived());
}
// internal namespace
namespace detail {
// implementation details for CallExpr
-template<bool stop, std::size_t I, typename F>
+template <bool stop, std::size_t I, typename F>
struct tuple_for_each_dispatcher {
- template<typename TTuple>
- static void run(F& f, const TTuple& tuple) { // NOLINT(*)
+ template <typename TTuple>
+ static void run(F& f, const TTuple& tuple) { // NOLINT(*)
f(I, std::get<I>(tuple));
- tuple_for_each_dispatcher<
- (I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>
- ::run(f, tuple);
+ tuple_for_each_dispatcher<(I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>::run(f, tuple);
}
};
-template<std::size_t I, typename F>
+template <std::size_t I, typename F>
struct tuple_for_each_dispatcher<true, I, F> {
- template<typename TTuple>
- static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
+ template <typename TTuple>
+ static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
};
-template<typename F, typename TTuple>
+template <typename F, typename TTuple>
inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*)
- tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>
- ::run(f, tuple);
+ tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>::run(f, tuple);
}
struct PCallExprInitMatchFunctor {
- template<typename T>
+ template <typename T>
void operator()(size_t i, const T& pattern) const {
pattern.InitMatch_();
}
const tir::CallNode* call_;
bool matched_{true};
- explicit PCallExprMatchFunctor(const tir::CallNode* call)
- : call_(call) {}
+ explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {}
- template<typename T>
+ template <typename T>
void operator()(size_t i, const T& pattern) {
matched_ = matched_ && pattern.Match_(call_->args[i]);
}
struct PCallExprEvalArgsFunctor {
Array<PrimExpr> args_;
- template<typename T>
+ template <typename T>
void operator()(size_t i, const T& pattern) {
args_.push_back(pattern.Eval());
}
* \note Op functor contains the name of the function and
* the implementation of Eval.
*/
-template<typename Op, typename ...TArgs>
-class PCallExpr :
- public Pattern<PCallExpr<Op, TArgs...> > {
+template <typename Op, typename... TArgs>
+class PCallExpr : public Pattern<PCallExpr<Op, TArgs...>> {
public:
- explicit PCallExpr(const TArgs&... args)
- : args_(args...) {
- }
+ explicit PCallExpr(const TArgs&... args) : args_(args...) {}
void InitMatch_() const {
detail::PCallExprInitMatchFunctor finit;
};
// arithmetic intrinsics
-#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
- struct OpName { \
- static PrimExpr Eval(Array<PrimExpr> args) { \
- return tir::CallNode::make(args[0].dtype(), kName, args, \
- tir::CallNode::PureIntrinsic); \
- } \
- static constexpr const char* kName = IntrinStr; \
- }; \
- template<typename TA, typename TB> \
- inline PCallExpr<OpName, TA, TB> \
- FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
- return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
+#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
+ struct OpName { \
+ static PrimExpr Eval(Array<PrimExpr> args) { \
+ return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \
+ } \
+ static constexpr const char* kName = IntrinStr; \
+ }; \
+ template <typename TA, typename TB> \
+ inline PCallExpr<OpName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
+ return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
}
TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
// unary intrinsics
-#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
- struct OpName { \
- static PrimExpr Eval(Array<PrimExpr> args) { \
- return tir::CallNode::make(args[0].dtype(), kName, args, \
- tir::CallNode::PureIntrinsic); \
- } \
- static constexpr const char* kName = IntrinStr; \
- }; \
- template<typename TA> \
- inline PCallExpr<OpName, TA> \
- FuncName(const Pattern<TA>& a) { \
- return PCallExpr<OpName, TA>(a.derived()); \
+#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
+ struct OpName { \
+ static PrimExpr Eval(Array<PrimExpr> args) { \
+ return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \
+ } \
+ static constexpr const char* kName = IntrinStr; \
+ }; \
+ template <typename TA> \
+ inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) { \
+ return PCallExpr<OpName, TA>(a.derived()); \
}
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
// if_then_else
struct PIfThenElseOp {
static PrimExpr Eval(Array<PrimExpr> args) {
- return tir::CallNode::make(
- args[1].dtype(), kName, args,
- tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic);
}
static constexpr const char* kName = "tvm_if_then_else";
};
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
-template<typename TCond, typename TA, typename TB>
-inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
-if_then_else(const Pattern<TCond>& cond,
- const Pattern<TA>& true_value,
- const Pattern<TB>& false_value) {
- return PCallExpr<PIfThenElseOp, TCond, TA, TB>(
- cond.derived(), true_value.derived(), false_value.derived());
+template <typename TCond, typename TA, typename TB>
+inline PCallExpr<PIfThenElseOp, TCond, TA, TB> if_then_else(const Pattern<TCond>& cond,
+ const Pattern<TA>& true_value,
+ const Pattern<TB>& false_value) {
+ return PCallExpr<PIfThenElseOp, TCond, TA, TB>(cond.derived(), true_value.derived(),
+ false_value.derived());
}
} // namespace arith
* \brief Rewrite-rule based simplification.
*/
// Acknowledgement: Most rewrite-rules are from Halide.
+#include "rewrite_simplify.h"
+
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
+
#include <algorithm>
+
#include "const_fold.h"
#include "pattern_match.h"
-#include "rewrite_simplify.h"
namespace tvm {
namespace arith {
using namespace tir;
// macro for doing simple rewrite
-#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
- if ((SrcExpr).Match(ret)) { \
- return (ResExpr).Eval(); \
+#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
+ if ((SrcExpr).Match(ret)) { \
+ return (ResExpr).Eval(); \
}
// macro for rewrite + recursively rewrite ResExpr
}
// macro rewrite only if CondExpr is true after match.
-#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
- if ((SrcExpr).Match(ret) && (CondExpr)) { \
- return (ResExpr).Eval(); \
+#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
+ if ((SrcExpr).Match(ret) && (CondExpr)) { \
+ return (ResExpr).Eval(); \
}
// macro rewrite + recursive_rewrite only if CondExpr is true after match.
-#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
- if ((SrcExpr).Match(ret) && (CondExpr)) { \
- return RecursiveRewrite((ResExpr).Eval()); \
+#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
+ if ((SrcExpr).Match(ret) && (CondExpr)) { \
+ return RecursiveRewrite((ResExpr).Eval()); \
}
// NOTE for developers:
//
// try to prove x equals val
-RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
-TryCompare(const PrimExpr& x, int64_t val) {
+RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x,
+ int64_t val) {
PrimExpr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImmNode>()) {
if (ptr->value == val) {
return kUnknown;
}
-void RewriteSimplifier::Impl::
-Update(const Var& var, const PrimExpr& info, bool can_override) {
+void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
if (!can_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
- CHECK(ExprDeepEqual()(it->second, info))
- << "Trying to update var \'" << var << "\'"
- << " with a different value: "
- << "original=" << it->second
- << ", new=" << info;
+ CHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'"
+ << " with a different value: "
+ << "original=" << it->second << ", new=" << info;
}
}
var_map_[var] = info;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const AddNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AddNode>();
PrimExpr const_res = TryConstFold<AddNode>(op->a, op->b);
PVar<int> lanes;
// Vector rules
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes),
- ramp(b1 + b2, s1 + s2, lanes));
- TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes),
- ramp(b1 + x, s1, lanes));
- TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes),
- ramp(x + b1, s1, lanes));
- TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes),
- broadcast(x + y, lanes));
+ TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes));
+ TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes));
}
if (IsIndexType(op->dtype)) {
TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y);
TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y);
- TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y),
- c1.Eval()->value == -c2.Eval()->value);
- TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2),
- c1.Eval()->value == -c2.Eval()->value);
- TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y),
- c1.Eval()->value == -c2.Eval()->value);
- TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2),
- c1.Eval()->value == -c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), c1.Eval()->value == -c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), c1.Eval()->value == -c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), c1.Eval()->value == -c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), c1.Eval()->value == -c2.Eval()->value);
// constant folding
// NOTE: canonicalization might better at this.
}
// condition rules.
- TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2),
- select(x, b1 + s1, b2 + s2));
+ TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), select(x, b1 + s1, b2 + s2));
// default value
return ret;
}
return frecover;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const SubNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<SubNode>();
PrimExpr const_res = TryConstFold<SubNode>(op->a, op->b);
PVar<int> lanes;
// Vector rules
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes),
- ramp(b1 - b2, s1 - s2, lanes));
- TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes),
- ramp(b1 - x, s1, lanes));
- TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes),
- ramp(x - b1, 0 - s1, lanes));
- TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes),
- broadcast(x - y, lanes));
+ TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes));
+ TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes));
}
if (IsIndexType(op->dtype)) {
TVM_TRY_REWRITE((y + x) - (z + x), y - z);
TVM_TRY_REWRITE((y + x) - (x + z), y - z);
- TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x));
- TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x));
- TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y));
- TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y));
+ TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x));
+ TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x));
+ TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y));
+ TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y));
- TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x));
- TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x));
- TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y));
- TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y));
+ TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x));
+ TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x));
+ TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y));
+ TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y));
- TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z));
- TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z));
- TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y));
- TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y));
+ TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z));
+ TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z));
+ TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y));
+ TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y));
TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x));
TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x));
// DivMod rules
// truncdiv
// NOTE: c*(x/c) + x % c == x is true all division mode.
- TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1),
- c1.Eval()->value != 0);
- TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1),
- c1.Eval()->value != 0);
+ TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0);
+ TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1),
TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y,
c1.Eval()->value != 0);
- TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
// Proof in the case of floordiv, need positive condition.
// let x = a * c3 + r
// (x + c1) / c3 - x / c3 => (r + c1) / c3
// NOTE: the use of floormod(c2, c3) was intentional to simplify the const.
- TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3),
+ TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3),
truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
- c1.Eval()->value >= c2.Eval()->value &&
- c3.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x, c3),
- truncdiv(truncmod(x, c3) + c1, c3),
- CanProveGreaterEqual(x.Eval(), 0) &&
- c1.Eval()->value >= 0 &&
- c3.Eval()->value > 0);
+ c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(
+ truncdiv(x + c1, c3) - truncdiv(x, c3), truncdiv(truncmod(x, c3) + c1, c3),
+ CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0);
// floordiv
- TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1),
- c1.Eval()->value != 0);
- TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1),
- c1.Eval()->value != 0);
+ TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), c1.Eval()->value != 0);
+ TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1),
TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
c1.Eval()->value != 0);
- TVM_TRY_REWRITE_IF(x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
- TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
- c1.Eval()->value != 0 &&
- c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(
+ floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
+ c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
c3.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3),
- floordiv(floormod(x, c3) + c1, c3),
+ TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), floordiv(floormod(x, c3) + c1, c3),
c3.Eval()->value > 0);
// canonicalization rule
}
// condition rules.
- TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2),
- select(x, b1 - s1, b2 - s2));
- TVM_TRY_REWRITE(select(x, y, z) - z,
- select(x, y - z, ZeroWithTypeLike(z)));
- TVM_TRY_REWRITE(select(x, y, z) - y,
- select(x, ZeroWithTypeLike(y), z - y));
+ TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), select(x, b1 - s1, b2 - s2));
+ TVM_TRY_REWRITE(select(x, y, z) - z, select(x, y - z, ZeroWithTypeLike(z)));
+ TVM_TRY_REWRITE(select(x, y, z) - y, select(x, ZeroWithTypeLike(y), z - y));
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const MulNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MulNode>();
PrimExpr const_res = TryConstFold<MulNode>(op->a, op->b);
PVar<int> lanes;
// Vector rules
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes),
- broadcast(x * y, lanes));
- TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes),
- ramp(b1 * x, s1 * x, lanes));
- TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes),
- ramp(b1 * x, s1 * x, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes));
+ TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes));
}
if (IsIndexType(op->dtype)) {
// canonicalization
TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1);
TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1);
- TVM_TRY_RECURSIVE_REWRITE_IF(
- (x - y) * c1, (y - x) * (0 - c1),
- c1.Eval()->value < 0);
+ TVM_TRY_RECURSIVE_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0);
}
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const DivNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<DivNode>();
PrimExpr const_res = TryConstFold<DivNode>(op->a, op->b);
// Vector rules
if (op->dtype.lanes() != 1) {
// NOTE: use div as the pattern also works for float.
- TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)),
- broadcast(div(x, y), lanes));
+ TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes));
// ramp / bcast
if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) {
int64_t c1val = c1.Eval()->value;
c1.Eval()->value > 0 && c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3),
- c1.Eval()->value > 0 &&
- c2.Eval()->value >= 0 &&
- c3.Eval()->value > 0 &&
- CanProveGreaterEqual(x.Eval(), 0));
+ c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 &&
+ CanProveGreaterEqual(x.Eval(), 0));
if (truncdiv(x * c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1);
// Rules involving 2-operands.
- TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2),
- x * truncdiv(c1, c2) + truncdiv(y, c2),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2),
- min(x * truncdiv(c1, c2), truncdiv(y, c2)),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2),
- max(x * truncdiv(c1, c2), truncdiv(y, c2)),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2),
- truncdiv(y, c2) + x * truncdiv(c1, c2),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2),
- min(truncdiv(y, c2), x * truncdiv(c1, c2)),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2),
- max(truncdiv(y, c2), x * truncdiv(c1, c2)),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
+ TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), x * truncdiv(c1, c2) + truncdiv(y, c2),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), min(x * truncdiv(c1, c2), truncdiv(y, c2)),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), max(x * truncdiv(c1, c2), truncdiv(y, c2)),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), truncdiv(y, c2) + x * truncdiv(c1, c2),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), min(truncdiv(y, c2), x * truncdiv(c1, c2)),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), max(truncdiv(y, c2), x * truncdiv(c1, c2)),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
// Rules involving 3-operands.
- TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y + z, c2),
- x * truncdiv(c1, c2) + truncdiv(y + z, c2),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y + z).Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(x * c1 - y + z, c2),
- x * truncdiv(c1, c2) + truncdiv(z - y, c2),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((z - y).Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y - z, c2),
- x * truncdiv(c1, c2) + truncdiv(y - z, c2),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y - z).Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(y + x * c1 + z, c2),
- x * truncdiv(c1, c2) + truncdiv(y + z, c2),
- c1.Eval()->value > 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y + z).Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2),
- truncdiv(x, c2) + truncdiv(c1, c2),
- c1.Eval()->value > 0 &&
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0));
+ TVM_TRY_REWRITE_IF(
+ truncdiv(x * c1 + y + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(
+ truncdiv(x * c1 - y + z, c2), x * truncdiv(c1, c2) + truncdiv(z - y, c2),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(
+ truncdiv(x * c1 + y - z, c2), x * truncdiv(c1, c2) + truncdiv(y - z, c2),
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(
+ truncdiv(y + x * c1 + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2),
+ c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), truncdiv(x, c2) + truncdiv(c1, c2),
+ c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
-
- TVM_TRY_REWRITE_IF(truncdiv((x + y) + z, x),
- truncdiv(y + z, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y + z).Eval(), 0));
- TVM_TRY_REWRITE_IF(truncdiv((y + x) + z, x),
- truncdiv(y + z, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y + z).Eval(), 0));
- TVM_TRY_REWRITE_IF(truncdiv(y + (z + x), x),
- truncdiv(y + z, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y + z).Eval(), 0));
- TVM_TRY_REWRITE_IF(truncdiv(y + (x + z), x),
- truncdiv(y + z, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y + z).Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(
+ truncdiv((x + y) + z, x), truncdiv(y + z, x) + 1,
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
+ TVM_TRY_REWRITE_IF(
+ truncdiv((y + x) + z, x), truncdiv(y + z, x) + 1,
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
+ TVM_TRY_REWRITE_IF(
+ truncdiv(y + (z + x), x), truncdiv(y + z, x) + 1,
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
+ TVM_TRY_REWRITE_IF(
+ truncdiv(y + (x + z), x), truncdiv(y + z, x) + 1,
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z),
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0) &&
- CanProveGreaterEqual(z.Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
+ CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z),
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0) &&
- CanProveGreaterEqual(z.Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
+ CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0) &&
- CanProveGreaterEqual(z.Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
+ CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x,
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0) &&
- CanProveGreaterEqual(z.Eval(), 0));
+ CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
+ CanProveGreaterEqual(z.Eval(), 0));
}
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const ModNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<ModNode>();
PrimExpr const_res = TryConstFold<ModNode>(op->a, op->b);
if (ramp_min == ramp_max) {
return ramp(truncmod(bmod->base, c2), c1, lanes).Eval();
} else {
- return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes),
- broadcast(c2, lanes)).Eval();
+ return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
}
}
}
// We adopt the default C division uses truncation instead of floordiv.
// This means most rules need to check non-negativeness of the operands.
TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x),
- c2.Eval()->value != 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
+ c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual((x * c1).Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual((x * c1).Eval(), 0) &&
+ CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value >= 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0));
+ c2.Eval()->value > 0 && c1.Eval()->value >= 0 &&
+ c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0 &&
- CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual((y * c1).Eval(), 0));
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual(x.Eval(), 0) &&
+ CanProveGreaterEqual((y * c1).Eval(), 0));
// canonicalization: x % c == x % (-c) for truncated division
// NOTE: trunc div required
TVM_TRY_RECURSIVE_REWRITE_IF(
- truncmod(x, c1),
- truncmod(x, PConst<PrimExpr>(make_const(op->dtype, -c1.Eval()->value))),
+ truncmod(x, c1), truncmod(x, PConst<PrimExpr>(make_const(op->dtype, -c1.Eval()->value))),
c1.Eval()->value < 0);
// try modular analysis
if (truncmod(x, c1).Match(ret)) {
ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
- if (mod->coeff % c1val == 0 &&
- c1val > 0 &&
- CanProveGreaterEqual(x.Eval(), 0)) {
+ if (mod->coeff % c1val == 0 && c1val > 0 && CanProveGreaterEqual(x.Eval(), 0)) {
return truncmod(mod->base, c1).Eval();
}
}
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const FloorDivNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
PrimExpr const_res = TryConstFold<FloorDivNode>(op->a, op->b);
TVM_TRY_REWRITE(floordiv(c1 * x, x), c1);
// Rules involving 2-operands.
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2),
- x * floordiv(c1, c2) + floordiv(y, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2),
- min(x * floordiv(c1, c2), floordiv(y, c2)),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2),
- max(x * floordiv(c1, c2), floordiv(y, c2)),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2),
- floordiv(y, c2) + x * floordiv(c1, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2),
- min(floordiv(y, c2), x * floordiv(c1, c2)),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2),
- max(floordiv(y, c2), x * floordiv(c1, c2)),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
+ TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
// Rules involving 3-operands.
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2),
- x * floordiv(c1, c2) + floordiv(y + z, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2),
- x * floordiv(c1, c2) + floordiv(z - y, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2),
- x * floordiv(c1, c2) + floordiv(y - z, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2),
- x * floordiv(c1, c2) + floordiv(y + z, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(x + c1, c2),
- floordiv(x, c2) + floordiv(c1, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
-
- TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0));
+ TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
- TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1,
- CanProveGreaterEqual(x.Eval(), 0));
+ TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), x * floordiv(c1, c2) + floordiv(y - z, c2),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv((x + y) + z, x), floordiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y + (x + z), x), floordiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
- TVM_TRY_REWRITE_IF(floordiv(x * y, y), x,
- CanProveGreaterEqual(y.Eval(), 0));
- TVM_TRY_REWRITE_IF(floordiv(y * x, y), x,
- CanProveGreaterEqual(y.Eval(), 0));
+ TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, CanProveGreaterEqual(y.Eval(), 0));
+ TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(x * z + y, z), x + floordiv(y, z),
CanProveGreaterEqual(z.Eval(), 0));
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const FloorModNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorModNode>();
PrimExpr const_res = TryConstFold<FloorModNode>(op->a, op->b);
if (IsIndexType(op->dtype)) {
// Be-aware of the division rules: we use floordiv/floormod here
TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x),
- c2.Eval()->value != 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
+ c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value % c2.Eval()->value == 0);
+ c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
// try modular analysis
if (floormod(x, c1).Match(ret)) {
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const MinNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MinNode>();
PrimExpr const_res = TryConstFold<MinNode>(op->a, op->b);
// vector rule
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)),
- broadcast(min(x, y), lanes));
+ TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes));
TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
min(x, broadcast(min(y, z), lanes)));
}
return (x + c2).Eval();
}
}
- if (min(x + c1, x).Match(ret) ||
- min(x, x + c1).Match(ret)) {
+ if (min(x + c1, x).Match(ret) || min(x, x + c1).Match(ret)) {
if (c1.Eval()->value < 0) {
return (x + c1).Eval();
} else {
// Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y)
TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value &&
- CanProveGreaterEqual(x.Eval(), 0));
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
+ CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value &&
- CanProveGreaterEqual(x.Eval(), 0));
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
+ CanProveGreaterEqual(x.Eval(), 0));
// Divide up rounding: floor div
TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2),
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
- TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2,
- c2.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2,
- c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0);
TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y));
TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y));
// canonicalization
TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
- TVM_TRY_RECURSIVE_REWRITE_IF(
- min(c1 - x, c2), c1 - max(x, c1 - c2),
- c2.Eval()->value != 0);
+ TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0);
}
// condition rules.
- TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)),
- select(x, min(y, s1), min(z, s2)));
+ TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), select(x, min(y, s1), min(z, s2)));
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const MaxNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MaxNode>();
PrimExpr const_res = TryConstFold<MaxNode>(op->a, op->b);
// vector rule
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)),
- broadcast(max(x, y), lanes));
+ TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes));
TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
max(x, broadcast(max(y, z), lanes)));
}
return (x + c2).Eval();
}
}
- if (max(x + c1, x).Match(ret) ||
- max(x, x + c1).Match(ret)) {
+ if (max(x + c1, x).Match(ret) || max(x, x + c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (x + c1).Eval();
} else {
// DivMod rules
// Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y)
- TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x),
- truncdiv(x + c1, c2) * c2,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
- TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2),
- truncdiv(x + c1, c2) * c2,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), truncdiv(x + c1, c2) * c2,
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), truncdiv(x + c1, c2) * c2,
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
// Divide up rounding: floor div
TVM_TRY_REWRITE_IF(max(floordiv(x + c1, c2) * c2, x), floordiv(x + c1, c2) * c2,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, floordiv(x + c1, c2) * c2), floordiv(x + c1, c2) * c2,
- c2.Eval()->value > 0 &&
- c1.Eval()->value + 1 == c2.Eval()->value);
+ c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
- TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x,
- c2.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x,
- c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, c2.Eval()->value > 0);
TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y));
TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y));
// canonicalization
TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
- TVM_TRY_RECURSIVE_REWRITE_IF(
- max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
+ TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
}
// condition rules.
- TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)),
- select(x, max(y, s1), max(z, s2)));
+ TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), select(x, max(y, s1), max(z, s2)));
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const EQNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQNode>();
PrimExpr const_res = TryConstFold<EQNode>(op->a, op->b);
// vector rule
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes),
- broadcast(x == y, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes));
}
if (IsIndexType(op->a.dtype())) {
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const NENode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) {
return this->VisitExpr(NotNode::make(op->a == op->b));
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const LENode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) {
return this->VisitExpr(NotNode::make(op->b < op->a));
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const GTNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) {
return this->VisitExpr(op->b < op->a);
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const GENode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) {
return this->VisitExpr(NotNode::make(op->a < op->b));
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const LTNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LTNode>();
PrimExpr const_res = TryConstFold<LTNode>(op->a, op->b);
// vector rule
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes),
- broadcast(x < y, lanes));
- TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes),
- broadcast(x < y, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes));
+ TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes));
}
if (IsIndexType(op->a.dtype())) {
return make_const(op->dtype, false);
}
+ // clang-format off
TVM_TRY_REWRITE(x + y < x + z, y < z);
TVM_TRY_REWRITE(x + y < z + x, y < z);
TVM_TRY_REWRITE(y + x < x + z, y < z);
TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x);
TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1);
- TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y,
- c1.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x,
- c1.Eval()->value < 0);
+ TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0);
// constant cancelation: only need to make use of one mod
// truc div
- TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1 - 1, c2) + 1,
- c1.Eval()->value > 0 &&
- c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(x * c2 < c1,
+ x < truncdiv(c1 - 1, c2) + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0);
// NOTE: trunc div required
TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2),
- c1.Eval()->value <= 0 &&
- c2.Eval()->value > 0);
+ c1.Eval()->value <= 0 && c2.Eval()->value > 0);
// NOTE: trunc div required (euclidean is ok too, floored is not)
- TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x,
- c1.Eval()->value > 0 &&
+ TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, c1.Eval()->value > 0 &&
c2.Eval()->value < 0);
// NOTE: trunc div required (floored is ok too, euclidean is not)
TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x,
- c1.Eval()->value <= 0 &&
- c2.Eval()->value < 0);
+ c1.Eval()->value <= 0 && c2.Eval()->value < 0);
// NOTE: trunc div required
TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x,
- c1.Eval()->value < 0 &&
- c2.Eval()->value > 0);
+ c1.Eval()->value < 0 && c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x,
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0);
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0);
// NOTE: trunc div required (floored is ok too, euclidean is not)
TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1,
- c1.Eval()->value < 0 &&
- c2.Eval()->value < 0);
+ c1.Eval()->value < 0 && c2.Eval()->value < 0);
// NOTE: trunc div required (euclidean is ok too, floored is not)
TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2),
- c1.Eval()->value >= 0 &&
- c2.Eval()->value < 0);
+ c1.Eval()->value >= 0 && c2.Eval()->value < 0);
// DivMod rules
// trucdiv
- TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2,
- c1.Eval()->value > 0 &&
- c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2,
+ x<c1 * c2, c1.Eval()->value> 0 && c2.Eval()->value > 0);
// NOTE: trunc div required
- TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1,
- c1.Eval()->value > 0 &&
- c2.Eval()->value <= 0);
+ TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2,
+ x<c1*(c2 - 1) + 1, c1.Eval()->value> 0 && c2.Eval()->value <= 0);
TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x,
- c1.Eval()->value >= 0 &&
- c2.Eval()->value > 0);
+ c1.Eval()->value >= 0 && c2.Eval()->value > 0);
// NOTE: trunc div required
TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x,
- c1.Eval()->value < 0 &&
- c2.Eval()->value > 0);
+ c1.Eval()->value < 0 && c2.Eval()->value > 0);
// invariance for any div mod: x - (x / c1) * c1 == x % c1
- TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1),
- c1.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, 0 < truncmod(x, c1) + y,
- c1.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, y < truncmod(x, c1),
- c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y,
+ 0 < truncmod(x, c1) + y, c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y,
+ y < truncmod(x, c1), c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x,
- c2 < truncmod(x + c2, c1),
- c1.Eval()->value > 0);
+ c2 < truncmod(x + c2, c1), c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y,
- c2 < truncmod(x + c2, c1) + y,
- c1.Eval()->value > 0);
+ c2 < truncmod(x + c2, c1) + y, c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y,
- y < truncmod(x + c2, c1) + (0 - c2),
- c1.Eval()->value > 0);
+ y < truncmod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0);
// floordiv
- TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2,
- c1.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x,
- c2.Eval()->value > 0);
-
- TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1),
- c1.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, 0 < floormod(x, c1) + y,
- c1.Eval()->value > 0);
- TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, y < floormod(x, c1),
- c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, c2.Eval()->value > 0);
+
+ TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y,
+ 0 < floormod(x, c1) + y, c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y,
+ y < floormod(x, c1), c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x,
- c2 < floormod(x + c2, c1),
- c1.Eval()->value > 0);
+ c2 < floormod(x + c2, c1), c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x + y,
- c2 < floormod(x + c2, c1) + y,
- c1.Eval()->value > 0);
+ c2 < floormod(x + c2, c1) + y, c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x - y,
- y < floormod(x + c2, c1) + (0 - c2),
- c1.Eval()->value > 0);
+ y < floormod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0);
// canonicalization rule
TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z);
TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
TVM_TRY_REWRITE(x - c1 < 0, x < c1);
+ // clang-format on
}
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const NotNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NotNode>();
PrimExpr const_res = TryConstFold<NotNode>(op->a);
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const AndNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AndNode>();
PrimExpr const_res = TryConstFold<AndNode>(op->a, op->b);
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes),
- broadcast(x && y, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
}
auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
TVM_TRY_REWRITE(x <= y && y < x, cfalse);
TVM_TRY_REWRITE(y < x && x <= y, cfalse);
- TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse,
- c2.Eval()->value + 1 >= c1.Eval()->value);
- TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse,
- c2.Eval()->value + 1 >= c1.Eval()->value);
-
- TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse,
- c2.Eval()->value >= c1.Eval()->value);
- TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse,
- c2.Eval()->value >= c1.Eval()->value);
- TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse,
- c2.Eval()->value >= c1.Eval()->value);
- TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse,
- c2.Eval()->value >= c1.Eval()->value);
-
- TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse,
- c2.Eval()->value > c1.Eval()->value);
- TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse,
- c2.Eval()->value > c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value);
TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const OrNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<OrNode>();
PrimExpr const_res = TryConstFold<OrNode>(op->a, op->b);
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
- TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes),
- broadcast(x || y, lanes));
+ TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes));
}
auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true));
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);
- TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue,
- c2.Eval()->value < c1.Eval()->value);
- TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue,
- c2.Eval()->value < c1.Eval()->value);
-
- TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue,
- c2.Eval()->value <= c1.Eval()->value);
- TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue,
- c2.Eval()->value <= c1.Eval()->value);
- TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue,
- c2.Eval()->value <= c1.Eval()->value);
- TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue,
- c2.Eval()->value <= c1.Eval()->value);
-
- TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue,
- c2.Eval()->value <= c1.Eval()->value + 1);
- TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue,
- c2.Eval()->value <= c1.Eval()->value + 1);
+ TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
+ TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const SelectNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<SelectNode>();
if (op == nullptr) return ret;
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const CallNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<CallNode>();
return ret;
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const VarNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) {
Var var = GetRef<Var>(op);
auto it = var_map_.find(var);
if (it != var_map_.end()) {
return GetRef<PrimExpr>(op);
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const CastNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<CastNode>();
return cast(op->dtype, op->value);
}
-PrimExpr RewriteSimplifier::Impl::
-VisitExpr_(const LetNode* op) {
+PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (!tir::HasSideEffect(value)) {
// it is fine to discard the let binding
return this->VisitExpr(op->body);
}
PrimExpr body = this->VisitExpr(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
return res;
}
-void RewriteSimplifier::Update(const Var& var,
- const PrimExpr& info,
- bool override) {
+void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool override) {
impl_->Update(var, info, override);
}
return impl_->EnterConstraint(constraint);
}
-RewriteSimplifier::RewriteSimplifier(Analyzer* parent)
- : impl_(new Impl(parent)) {
-}
+RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}
-RewriteSimplifier::~RewriteSimplifier() {
- delete impl_;
-}
+RewriteSimplifier::~RewriteSimplifier() { delete impl_; }
} // namespace arith
} // namespace tvm
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
+
#include <unordered_map>
#include <vector>
+
#include "const_fold.h"
-#include "pattern_match.h"
#include "ir_mutator_with_analyzer.h"
+#include "pattern_match.h"
namespace tvm {
namespace arith {
public:
using IRMutatorWithAnalyzer::VisitExpr_;
- explicit Impl(Analyzer* parent)
- : IRMutatorWithAnalyzer(parent) {}
+ explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const PrimExpr& info, bool override_info);
PrimExpr VisitExpr_(const AddNode* op) override;
protected:
/*! \brief internal structure for comparison. */
- enum CompareResult {
- kUnknown,
- kEQ,
- kGT,
- kGE,
- kLT,
- kLE,
- kNE
- };
+ enum CompareResult { kUnknown, kEQ, kGT, kGE, kLT, kLE, kNE };
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
return res;
}
- template<typename TA>
+ template <typename TA>
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
}
- template<typename TA>
+ template <typename TA>
PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 1);
}
};
-
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_REWRITE_SIMPLIFY_H_
* \file tvm/arith/solve_linear_equation.cc
* \brief Solve linear equations.
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
-#include <tvm/arith/util.h>
#include <tvm/arith/pattern.h>
-
+#include <tvm/arith/util.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/runtime/data_type.h>
namespace tvm {
namespace arith {
using namespace tvm::runtime;
-void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
- std::vector<std::vector<int64_t> >* V,
- std::vector<PrimExpr>* x,
- std::vector<PrimExpr>* y) {
+void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V,
+ std::vector<PrimExpr>* x, std::vector<PrimExpr>* y) {
if (S->empty() || V->empty()) return;
size_t m = S->size();
size_t n = (*S)[0].size(); // n is # of variables
for (size_t j = index; j < (*S)[i].size(); ++j) {
// Multiply index-th row by a and add the i-th row multiplied by b
// This will make the index-th diagonal element equal to the gcd
- int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j];
+ int64_t new_index_j = a * (*S)[index][j] + b * (*S)[i][j];
// This transformation performs zeroing of matrix[i][index]
- int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j];
+ int64_t new_i_j = n_g * (*S)[index][j] - m_g * (*S)[i][j];
(*S)[index][j] = new_index_j;
(*S)[i][j] = new_i_j;
}
PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
- PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
- PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
+ PrimExpr new_index_rhs = ea * (*y)[index] + eb * (*y)[i];
+ PrimExpr new_i_rhs = e_n_g * (*y)[index] - e_m_g * (*y)[i];
(*y)[index] = new_index_rhs;
(*y)[i] = new_i_rhs;
}
int64_t n_g = (*S)[index][j] / g;
for (size_t i = index; i < m; ++i) {
- int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j];
- int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j];
+ int64_t new_i_index = a * (*S)[i][index] + b * (*S)[i][j];
+ int64_t new_i_j = n_g * (*S)[i][index] - m_g * (*S)[i][j];
(*S)[i][index] = new_i_index;
(*S)[i][j] = new_i_j;
}
// We do exactly the same transformations with V
for (size_t i = 0; i < n; ++i) {
- int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j];
- int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j];
+ int64_t new_i_index = a * (*V)[i][index] + b * (*V)[i][j];
+ int64_t new_i_j = n_g * (*V)[i][index] - m_g * (*V)[i][j];
(*V)[i][index] = new_i_index;
(*V)[i][j] = new_i_j;
}
PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
- PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
- PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
+ PrimExpr new_index = e_m_g * (*x)[index] + e_n_g * (*x)[j];
+ PrimExpr new_j = eb * (*x)[index] - ea * (*x)[j];
(*x)[index] = new_index;
(*x)[j] = new_j;
}
}
}
-Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer,
- const Array<Var>& ori_vars,
+Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer, const Array<Var>& ori_vars,
const Map<Var, Range>& ori_ranges) {
// The resulting ranges
Map<Var, Range> new_ranges;
// pretty print matrix equation
void DebugPrint(const std::vector<std::vector<int64_t>>& S,
- const std::vector<std::vector<int64_t>>& V,
- const std::vector<PrimExpr>& V_inv_x,
+ const std::vector<std::vector<int64_t>>& V, const std::vector<PrimExpr>& V_inv_x,
const std::vector<PrimExpr>& rhs) {
std::cout << "S:\n";
for (size_t i = 0; i < S.size(); ++i) {
std::cout << "\n" << std::endl;
}
-IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) {
+IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve) {
// m: # of equations
// n: # of variables
// we first construct A_{mxn} x_{nx1} = y_{mx1}
// S_{mxn} = U_{mxm} A_{mxn} V_{nxn}
// => U^{-1} S V^{-1} x = y
// S V^{-1} x = U y
- std::vector<PrimExpr> Uy; // mx1
+ std::vector<PrimExpr> Uy; // mx1
std::vector<std::vector<int64_t>> S; // mxn
std::vector<std::vector<int64_t>> V; // nxn
- std::vector<PrimExpr> V_inv_x; // V^{-1} x, nx1
+ std::vector<PrimExpr> V_inv_x; // V^{-1} x, nx1
// Conditions we don't know what to do with
std::vector<PrimExpr> rest;
for (const PrimExpr& equation : system_to_solve->relations) {
if (const tir::EQNode* eq = equation.as<tir::EQNode>()) {
// a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n]
- Array<PrimExpr> coeffs = arith::DetectLinearEquation(
- analyzer_problem.Simplify(eq->a - eq->b),
- system_to_solve->variables);
+ Array<PrimExpr> coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b),
+ system_to_solve->variables);
if (!coeffs.empty()) {
std::vector<int64_t> row;
for (size_t j = 0; j < coeffs.size() - 1; ++j) {
new_relation = analyzer_problem.Simplify(new_relation);
if (tir::is_const_int(new_relation, 0)) {
// unable to solve the system.
- return IntConstraintsTransform(
- system_to_solve,
- IntConstraints(
- /*variables=*/{},
- /*ranges=*/{},
- /*relations=*/{tir::make_zero(DataType::Bool())}),
- {}, {});
+ return IntConstraintsTransform(system_to_solve,
+ IntConstraints(
+ /*variables=*/{},
+ /*ranges=*/{},
+ /*relations=*/{tir::make_zero(DataType::Bool())}),
+ {}, {});
} else if (!tir::is_const_int(new_relation, 1)) {
new_relations.push_back(new_relation);
}
// S^{-1}_{nxm} Uy_{mxn}
if (S[j][j] >= 0) {
PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
- solution_for_V_inv_x.push_back(
- analyzer_problem.Simplify(floordiv(Uy[j], a)));
+ solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a)));
} else {
// This is required because some simplifiers
// have problems with dividing by negative numbers
PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
- solution_for_V_inv_x.push_back(
- analyzer_problem.Simplify(floordiv(-Uy[j], a)));
+ solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a)));
}
}
}
for (size_t i = 0; i < num_vars; ++i) {
PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
for (size_t j = 0; j < num_vars; ++j) {
- e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
+ e = e + tir::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j];
}
e = analyzer_problem.Simplify(e);
old_to_new_map.Set(system_to_solve->variables[i], e);
}
// The resulting ranges
- Map<Var, Range> new_ranges = InferRange(
- new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
+ Map<Var, Range> new_ranges =
+ InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
Analyzer analyzer_solution;
analyzer_solution.Bind(new_ranges);
const Range& old_range = p.second;
if (old_to_new_map.count(old_var)) {
PrimExpr express_by_new_vars = old_to_new_map[old_var];
- PrimExpr lower_cond = analyzer_solution.Simplify(
- old_range->min <= express_by_new_vars);
- PrimExpr upper_cond = analyzer_solution.Simplify(
- express_by_new_vars < old_range->min + old_range->extent);
+ PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars);
+ PrimExpr upper_cond =
+ analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent);
if (!tir::is_const_int(lower_cond, 1)) {
new_relations.push_back(lower_cond);
}
}
IntConstraints solution(new_vars, new_ranges, new_relations);
- IntConstraintsTransform transform(
- system_to_solve, solution, old_to_new_map, new_to_old_map);
+ IntConstraintsTransform transform(system_to_solve, solution, old_to_new_map, new_to_old_map);
return transform;
}
-TVM_REGISTER_GLOBAL("arith.SolveLinearEquations")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- if (args.size() == 1) {
- *ret = SolveLinearEquations(args[0]);
- } else if (args.size() == 3) {
- IntConstraints problem(args[0], args[1], args[2]);
- *ret = SolveLinearEquations(problem);
- } else {
- LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size();
- }
- });
+TVM_REGISTER_GLOBAL("arith.SolveLinearEquations").set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (args.size() == 1) {
+ *ret = SolveLinearEquations(args[0]);
+ } else if (args.size() == 3) {
+ IntConstraints problem(args[0], args[1], args[2]);
+ *ret = SolveLinearEquations(problem);
+ } else {
+ LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size();
+ }
+});
} // namespace arith
} // namespace tvm
* \file util.cc
* \brief The utils for arithmetic analysis.
*/
-#include <tvm/arith/util.h>
#include <dmlc/logging.h>
+#include <tvm/arith/util.h>
namespace tvm {
namespace arith {
CHECK_EQ(a % old_r, 0);
CHECK_EQ(b % old_r, 0);
- CHECK(old_r == old_s*a + old_t*b);
+ CHECK(old_r == old_s * a + old_t * b);
return std::make_tuple(old_r, old_s, old_t);
}
// for loop
void FeatureVisitor::VisitStmt_(const ForNode* op) {
- const auto *extent = op->extent.as<IntImmNode>();
+ const auto* extent = op->extent.as<IntImmNode>();
int64_t loop_extent = -1;
- if (extent != nullptr)
- loop_extent = extent->value;
+ if (extent != nullptr) loop_extent = extent->value;
AnnotationType ann = kSerial;
switch (op->for_type) {
case ForType ::Parallel:
// parallel axis, virtual thread
void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
Var var = op->node.as<tir::IterVarNode>()->var;
- const auto *extent = op->value.as<IntImmNode>();
+ const auto* extent = op->value.as<IntImmNode>();
CHECK(extent);
std::string name = var.get()->name_hint;
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
+
#include <string>
namespace tvm {
* \brief Type of for loop, used as one-hot encoding in features
*/
enum AnnotationType {
- kBlockX, kBlockY, kBlockZ, kThreadX, kThreadY, kThreadZ,
- kUnrolled, kVectorized, kParallel, kSerial, kVirtualThread,
+ kBlockX,
+ kBlockY,
+ kBlockZ,
+ kThreadX,
+ kThreadY,
+ kThreadZ,
+ kUnrolled,
+ kVectorized,
+ kParallel,
+ kSerial,
+ kVirtualThread,
kNum,
};
void VisitExpr_(const LoadNode* op) final;
void VisitStmt_(const StoreNode* op) final;
- using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_;
+ using StmtExprVisitor::VisitStmt_;
protected:
/*!
- * \brief Enter a for loop node
- * \param var The expression to be printed.
- * \param length The output stream
- * \param ann_type The type for the for loop
- * \return skip Whether skip this node
- */
+ * \brief Enter a for loop node
+ * \param var The expression to be printed.
+ * \param length The output stream
+ * \param ann_type The type for the for loop
+ * \return skip Whether skip this node
+ */
virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0;
/*! \brief Exit a for loop subtree */
virtual void ExitItervar_() = 0;
#include "touch_extractor.h"
-#include <set>
#include <algorithm>
#include <cmath>
+#include <set>
#include <unordered_map>
namespace tvm {
int ParallelLevel(AnnotationType ann) {
switch (ann) {
- case kBlockX: case kBlockY: case kBlockZ:
+ case kBlockX:
+ case kBlockY:
+ case kBlockZ:
return 2;
- case kThreadX: case kThreadY: case kThreadZ: case kParallel:
+ case kThreadX:
+ case kThreadY:
+ case kThreadZ:
+ case kParallel:
return 1;
default:
return 0;
}
// get touch pattern from index expression
-class IndexParser: public ExprVisitor {
+class IndexParser : public ExprVisitor {
public:
void Parse(PrimExpr expr) {
pattern_map.clear();
itervar_map.erase(var);
}
- itervar_map.insert({var, ItervarFeature(var, length,
- static_cast<int>(itervar_stack_.size()),
- ann_type,
- topdown_product_,
- static_cast<int>(itervar_counter_++))});
+ itervar_map.insert(
+ {var, ItervarFeature(var, length, static_cast<int>(itervar_stack_.size()), ann_type,
+ topdown_product_, static_cast<int>(itervar_counter_++))});
}
return true;
CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
touch_pattern->second.count *= itervar_map[var].length;
}
- } else { // multiply reuse ratio
+ } else { // multiply reuse ratio
for (auto stack_var : itervar_stack_) {
auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
itervar_stack_.pop_back();
int64_t length = itervar_map[var].length;
- if (length != 0)
- topdown_product_ /= length;
+ if (length != 0) topdown_product_ /= length;
int64_t bottomup_product = -1;
for (auto kv : itervar_map[var].touch_feature) {
bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse);
}
}
-void TouchExtractor::ExitMem_() {
-}
+void TouchExtractor::ExitMem_() {}
/*!
* \brief Get axis-based feature for all axes
* \note If you want to flatten these features as the input of your model,
* You can use the faster one GetItervarFeatureFlatten below.
*/
-void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > > *ret_feature) {
+void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >* ret_feature) {
// extract
TouchExtractor touch_analyzer;
touch_analyzer.Analyze(stmt);
for (auto kv : touch_analyzer.itervar_map) {
vars.push_back(kv.first);
}
- std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
+ std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool {
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
});
std::function<double(int64_t)> trans;
if (take_log) {
trans = [](int64_t x) {
- if (x < 0)
- return -std::log(-x+1) / std::log(2);
+ if (x < 0) return -std::log(-x + 1) / std::log(2);
x = x + 1;
return std::log(x) / std::log(2);
};
} else {
- trans = [](int64_t x) {
- return x;
- };
+ trans = [](int64_t x) { return x; };
}
// serialize for front end
for (auto var : vars) {
Array<Array<PrimExpr> > feature_row;
- ItervarFeature &fea = touch_analyzer.itervar_map[var];
+ ItervarFeature& fea = touch_analyzer.itervar_map[var];
feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});
- Array<PrimExpr> attr{tvm::tir::StringImmNode::make("_attr_"),
- FloatImm(DataType::Float(32), trans(fea.length)),
- IntImm(DataType::Int(32), fea.nest_level),
- FloatImm(DataType::Float(32), trans(fea.topdown_product)),
- FloatImm(DataType::Float(32), trans(fea.bottomup_product)),
+ Array<PrimExpr> attr{
+ tvm::tir::StringImmNode::make("_attr_"),
+ FloatImm(DataType::Float(32), trans(fea.length)),
+ IntImm(DataType::Int(32), fea.nest_level),
+ FloatImm(DataType::Float(32), trans(fea.topdown_product)),
+ FloatImm(DataType::Float(32), trans(fea.bottomup_product)),
};
// one hot annotation
for (int i = 0; i < kNum; i++) {
feature_row.push_back(attr);
// arithmetic
- feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_arith_"),
- FloatImm(DataType::Float(32), trans(fea.add_ct)),
- FloatImm(DataType::Float(32), trans(fea.mul_ct)),
- FloatImm(DataType::Float(32), trans(fea.div_ct)),
+ feature_row.push_back(Array<PrimExpr>{
+ tvm::tir::StringImmNode::make("_arith_"),
+ FloatImm(DataType::Float(32), trans(fea.add_ct)),
+ FloatImm(DataType::Float(32), trans(fea.mul_ct)),
+ FloatImm(DataType::Float(32), trans(fea.div_ct)),
});
// touch map
}
std::sort(bufs.begin(), bufs.end());
for (auto k : bufs) {
- TouchPattern &v = fea.touch_feature[k];
- feature_row.push_back(
- Array<PrimExpr>{tvm::tir::StringImmNode::make(k),
- FloatImm(DataType::Float(32), trans(v.stride)),
- FloatImm(DataType::Float(32), trans(v.mod)),
- FloatImm(DataType::Float(32), trans(v.count)),
- FloatImm(DataType::Float(32), trans(v.reuse)),
- FloatImm(DataType::Float(32), trans(v.thread_count)),
- FloatImm(DataType::Float(32), trans(v.thread_reuse)),
- });
+ TouchPattern& v = fea.touch_feature[k];
+ feature_row.push_back(Array<PrimExpr>{
+ tvm::tir::StringImmNode::make(k),
+ FloatImm(DataType::Float(32), trans(v.stride)),
+ FloatImm(DataType::Float(32), trans(v.mod)),
+ FloatImm(DataType::Float(32), trans(v.count)),
+ FloatImm(DataType::Float(32), trans(v.reuse)),
+ FloatImm(DataType::Float(32), trans(v.thread_count)),
+ FloatImm(DataType::Float(32), trans(v.thread_reuse)),
+ });
}
ret_feature->push_back(feature_row);
* \note See GetItervarFeature for more details about the return value.
* This is an optimized version of GetItervarFeature + Flatten. This runs much faster.
*/
-void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector<float> *ret_feature) {
+void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector<float>* ret_feature) {
// extract touch feature
TouchExtractor touch_analyzer;
touch_analyzer.Analyze(stmt);
for (auto kv : touch_analyzer.itervar_map) {
vars.push_back(kv.first);
}
- std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
+ std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool {
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
});
std::function<float(int64_t)> trans;
if (take_log) {
trans = [](int64_t x) {
- if (x < 0)
- return -std::log(-x+1) / std::log(2);
+ if (x < 0) return -std::log(-x + 1) / std::log(2);
x = x + 1;
return std::log(x) / std::log(2);
};
} else {
- trans = [](int64_t x) {
- return x;
- };
+ trans = [](int64_t x) { return x; };
}
// serialize for front end
for (auto var : vars) {
- ItervarFeature &fea = touch_analyzer.itervar_map[var];
+ ItervarFeature& fea = touch_analyzer.itervar_map[var];
ret_feature->push_back(trans(fea.length));
ret_feature->push_back(fea.nest_level);
}
std::sort(bufs.begin(), bufs.end());
for (auto k : bufs) {
- TouchPattern &v = fea.touch_feature[k];
+ TouchPattern& v = fea.touch_feature[k];
ret_feature->push_back(trans(v.stride));
ret_feature->push_back(trans(v.mod));
ret_feature->push_back(trans(v.count));
}
/*!
- * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional vector.
- * \param stmt The statement to be extracted
- * \param sample_n The number of points used for sampling a curve (along one dimension)
- * \param ret_feature The buffer where the return value is stored
+ * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional
+ * vector. \param stmt The statement to be extracted \param sample_n The number of points used for
+ * sampling a curve (along one dimension) \param ret_feature The buffer where the return value is
+ * stored
*/
-void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector<float> *ret_feature) {
+void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector<float>* ret_feature) {
// extract touch feature
TouchExtractor touch_ext;
touch_ext.Analyze(stmt);
for (auto kv : touch_ext.itervar_map) {
vars.push_back(kv.first);
}
- std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
+ std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool {
return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order;
});
// find maximum depth of loop nest
for (auto var : vars) {
- ItervarFeature &fea = touch_ext.itervar_map[var];
+ ItervarFeature& fea = touch_ext.itervar_map[var];
max_depth = std::max(max_depth, fea.nest_level);
}
// mark inner most buffer
for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) {
auto var = *iter;
- ItervarFeature &fea = touch_ext.itervar_map[var];
+ ItervarFeature& fea = touch_ext.itervar_map[var];
if (fea.nest_level == max_depth) {
for (auto kv : fea.touch_feature) {
// delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A')
// delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A')
size_t pos = raw_name.find(".");
- if (pos < kv.first.size())
- raw_name = raw_name.substr(0, pos);
+ if (pos < kv.first.size()) raw_name = raw_name.substr(0, pos);
// If there are multiple innermost buffers that are derived from a same raw buffer
// We only record the last occurrence (note the `iter` is in reverse order)
// extract curves
for (auto var : vars) {
- ItervarFeature &fea = touch_ext.itervar_map[var];
+ ItervarFeature& fea = touch_ext.itervar_map[var];
for (auto kv : fea.touch_feature) {
if (innermost_buffers.find(kv.first) != innermost_buffers.end()) {
reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2));
}
// sample relation in the curve
- auto sample_curve = [&](const std::vector<double> &x, const std::vector<double> &y,
+ auto sample_curve = [&](const std::vector<double>& x, const std::vector<double>& y,
double weight) {
for (int i = 0; i < sample_n; i++) {
double xx = i * weight;
// serialize to frontend
for (auto k : innermost_buffers) {
- std::vector<double> &count = count_curve[k];
- std::vector<double> &reuse = reuse_curve[k];
- std::vector<double> &top_down = topdown_curve[k];
+ std::vector<double>& count = count_curve[k];
+ std::vector<double>& reuse = reuse_curve[k];
+ std::vector<double>& top_down = topdown_curve[k];
std::sort(count.begin(), count.end());
std::sort(reuse.begin(), reuse.end());
}
}
-
// register API for front end
TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Stmt stmt = args[0];
- bool take_log = args[1];
- Array<Array<Array<PrimExpr > > > ret_feature;
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ Stmt stmt = args[0];
+ bool take_log = args[1];
+ Array<Array<Array<PrimExpr> > > ret_feature;
- GetItervarFeature(stmt, take_log, &ret_feature);
-
- *ret = ret_feature;
-});
+ GetItervarFeature(stmt, take_log, &ret_feature);
+ *ret = ret_feature;
+ });
TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeatureFlatten")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Stmt stmt = args[0];
- bool take_log = args[1];
- std::vector<float> ret_feature;
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ Stmt stmt = args[0];
+ bool take_log = args[1];
+ std::vector<float> ret_feature;
- GetItervarFeatureFlatten(stmt, take_log, &ret_feature);
-
- TVMByteArray arr;
- arr.size = sizeof(float) * ret_feature.size();
- arr.data = reinterpret_cast<char *>(ret_feature.data());
- *ret = arr;
-});
+ GetItervarFeatureFlatten(stmt, take_log, &ret_feature);
+ TVMByteArray arr;
+ arr.size = sizeof(float) * ret_feature.size();
+ arr.data = reinterpret_cast<char*>(ret_feature.data());
+ *ret = arr;
+ });
TVM_REGISTER_GLOBAL("autotvm.feature.GetCurveSampleFeatureFlatten")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Stmt stmt = args[0];
- int sample_n = args[1];
- std::vector<float> ret_feature;
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ Stmt stmt = args[0];
+ int sample_n = args[1];
+ std::vector<float> ret_feature;
- GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);
-
- TVMByteArray arr;
- arr.size = sizeof(float) * ret_feature.size();
- arr.data = reinterpret_cast<char *>(ret_feature.data());
- *ret = arr;
-});
+ GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);
+ TVMByteArray arr;
+ arr.size = sizeof(float) * ret_feature.size();
+ arr.data = reinterpret_cast<char*>(ret_feature.data());
+ *ret = arr;
+ });
} // namespace autotvm
} // namespace tvm
#ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
-#include <tvm/runtime/registry.h>
-#include <stack>
-#include <vector>
+#include <deque>
#include <map>
+#include <stack>
#include <string>
-#include <deque>
#include <unordered_map>
+#include <vector>
+
#include "feature_visitor.h"
namespace tvm {
// all the feature of an iter var
struct ItervarFeature {
- ItervarFeature(Var var,
- int64_t extent,
- int nest,
- AnnotationType ann_type,
- int64_t topdown,
+ ItervarFeature(Var var, int64_t extent, int nest, AnnotationType ann_type, int64_t topdown,
int counter)
: length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {}
ItervarFeature() {}
// Axis Attributes
int64_t length;
int nest_level;
- AnnotationType ann; // one-hot axis type
- int64_t topdown_product; // accumulative product of axis length, in top-down order
- int64_t bottomup_product; // accumulative product of axis length, in bottom-up order
+ AnnotationType ann; // one-hot axis type
+ int64_t topdown_product; // accumulative product of axis length, in top-down order
+ int64_t bottomup_product; // accumulative product of axis length, in bottom-up order
// bottomup_product = reuse * count for any touched buffer
int order; // used for soring axis
// extract iter vars and their touch pattern from ir
class TouchExtractor : public FeatureVisitor {
public:
- void Analyze(const Stmt& stmt) {
- operator()(stmt);
- }
+ void Analyze(const Stmt& stmt) { operator()(stmt); }
// arithmetic stats
void VisitExpr_(const AddNode* op) final {
- if (op->dtype.is_float())
- itervar_map[itervar_stack_.back()].add_ct++;
+ if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const SubNode* op) final {
- if (op->dtype.is_float())
- itervar_map[itervar_stack_.back()].add_ct++;
+ if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const MulNode* op) final {
- if (op->dtype.is_float())
- itervar_map[itervar_stack_.back()].mul_ct++;
+ if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const DivNode* op) final {
- if (op->dtype.is_float())
- itervar_map[itervar_stack_.back()].div_ct++;
+ if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const ModNode* op) final {
- if (op->dtype.is_float())
- itervar_map[itervar_stack_.back()].div_ct++;
+ if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
}
/*!
* \file codegen_hybrid.cc
*/
+#include "codegen_hybrid.h"
+
#include <tvm/runtime/registry.h>
-#include <iomanip>
+
#include <cctype>
-#include "codegen_hybrid.h"
+#include <iomanip>
namespace tvm {
namespace contrib {
using namespace tir;
std::string dot_to_underscore(std::string s) {
- for (auto &ch : s)
+ for (auto& ch : s)
if (ch == '.') ch = '_';
return s;
}
return prefix;
}
-std::string CodeGenHybrid::Finish() {
- return stream.str();
-}
+std::string CodeGenHybrid::Finish() { return stream.str(); }
-void CodeGenHybrid::PrintType(DataType t, std::ostream &os) {
+void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
if (t.is_float()) {
os << "float";
CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
os << op->value;
}
-void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << std::setprecision(20) << op->value << ")";
}
-void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
os << "'" << op->value << "'";
}
-template<typename T>
-inline void PrintBinaryExpr(const T* op,
- const char* opstr,
+template <typename T>
+inline void PrintBinaryExpr(const T* op, const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
- CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented";
+ CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented";
if (isalpha(opstr[0])) {
os << opstr << '(';
p->PrintExpr(op->a, os);
}
}
-inline void PrintBinaryIntrinsitc(const CallNode* op,
- const char* opstr,
+inline void PrintBinaryIntrinsitc(const CallNode* op, const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
- CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented";
+ CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented";
CHECK_EQ(op->args.size(), 2U);
os << '(';
p->PrintExpr(op->args[0], os);
LOG(FATAL) << "Phase 0 has no Load(s)!";
}
-void CodeGenHybrid::VisitStmt_(const StoreNode* op) {
- LOG(FATAL) << "Phase 0 has no Store(s)!";
-}
+void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; }
void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Let(s)!";
LOG(FATAL) << "Ramp to be supported yet";
}
-void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
CHECK(iter_var);
binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint);
PrintIndent();
- stream << "for " << binds_[iter_var->var.get()] << " in bind('"
- << iter_var->var->name_hint << "', ";
+ stream << "for " << binds_[iter_var->var.get()] << " in bind('" << iter_var->var->name_hint
+ << "', ";
PrintExpr(op->value, stream);
stream << "):\n";
indent_ += tab_;
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = GetVarID(op->loop_var.get());
- stream << "for " << vid << " in " << "range(" << extent << "):\n";
+ stream << "for " << vid << " in "
+ << "range(" << extent << "):\n";
indent_ += tab_;
PrintStmt(op->body);
indent_ -= tab_;
}
-bool is_noop(const Stmt &stmt) {
- if (!stmt.defined())
- return true;
- if (auto eval = stmt.as<EvaluateNode>())
- return is_const(eval->value);
+bool is_noop(const Stmt& stmt) {
+ if (!stmt.defined()) return true;
+ if (auto eval = stmt.as<EvaluateNode>()) return is_const(eval->value);
return false;
}
void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
if (is_const(op->value)) return;
std::string str = PrintExpr(op->value);
- if (!str.empty())
- stream << str << "\n";
+ if (!str.empty()) stream << str << "\n";
}
-void CodeGenHybrid::PrintIndent() {
- stream << std::string(indent_, ' ');
-}
+void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); }
-std::string CodeGenHybrid::GetVarID(const VarNode *v) {
- if (binds_.count(v))
- return binds_[v];
+std::string CodeGenHybrid::GetVarID(const VarNode* v) {
+ if (binds_.count(v)) return binds_[v];
auto key = std::make_pair(static_cast<const Object*>(v), 0);
if (id_map_.count(key)) {
return id_map_[key];
return id_map_[key] = GetUniqueName(v->name_hint);
}
-std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) {
+std::string CodeGenHybrid::GetTensorID(const FunctionRef& func, int value_index) {
auto key = std::make_pair(func.get(), value_index);
if (id_map_.count(key)) {
return id_map_[key];
GetUniqueName("max_num_threads");
}
-void CodeGenHybrid::DumpStmt(const Stmt &stmt,
- const Array<ObjectRef> &inputs,
- const Array<Tensor> &outputs,
- const std::string &name) {
+void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs,
+ const Array<Tensor>& outputs, const std::string& name) {
ReserveKeywords();
GetUniqueName(name);
indent_ += tab_;
for (size_t i = 0; i < outputs.size(); ++i) {
PrintIndent();
- stream << GetTensorID(outputs[i]->op, outputs[i]->value_index)
- << " = output_tensor((";
+ stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) << " = output_tensor((";
for (size_t j = 0; j < outputs[i]->shape.size(); ++j) {
if (j) stream << ", ";
PrintExpr(outputs[i]->shape[j], stream);
}
- if (outputs[i]->shape.size() == 1)
- stream << ", ";
+ if (outputs[i]->shape.size() == 1) stream << ", ";
stream << "), '" << outputs[i]->dtype << "')\n";
}
PrintStmt(stmt);
stream << "\n";
}
-TVM_REGISTER_GLOBAL("hybrid._Dump")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- CodeGenHybrid codegen;
- if (args.size() == 4)
- codegen.DumpStmt(args[0], args[1], args[2], args[3]);
- else
- codegen.DumpStmt(args[0], args[1], args[2]);
- *rv = codegen.Finish();
- });
+TVM_REGISTER_GLOBAL("hybrid._Dump").set_body([](TVMArgs args, TVMRetValue* rv) {
+ CodeGenHybrid codegen;
+ if (args.size() == 4)
+ codegen.DumpStmt(args[0], args[1], args[2], args[3]);
+ else
+ codegen.DumpStmt(args[0], args[1], args[2]);
+ *rv = codegen.Finish();
+});
} // namespace contrib
} // namespace tvm
#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
#include <tvm/te/schedule.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+
#include <map>
#include <string>
#include <unordered_map>
* **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3.
* For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``.
*/
-class CodeGenHybrid :
- public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
- public StmtFunctor<void(const Stmt&)> {
+class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
+ public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Dump the given function body to hybrid script.
* \param outputs Output tensors of this schedule.
* \param name The name of the function.
*/
- void DumpStmt(const Stmt &stmt, const Array<ObjectRef> &inputs, const Array<Tensor> &outputs,
- const std::string &name = "hybrid_func");
+ void DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs, const Array<Tensor>& outputs,
+ const std::string& name = "hybrid_func");
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
* \brief Print the Stmt n to CodeGenHybrid->stream
* \param n The statement to be printed.
*/
- void PrintStmt(const Stmt &n) {
- this->VisitStmt(n);
- }
+ void PrintStmt(const Stmt& n) { this->VisitStmt(n); }
/*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
* \param os The output stream
*/
- void PrintExpr(const PrimExpr &n, std::ostream &os) {
- this->VisitExpr(n, os);
- }
+ void PrintExpr(const PrimExpr& n, std::ostream& os) { this->VisitExpr(n, os); }
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
- std::string PrintExpr(const PrimExpr &n) {
+ std::string PrintExpr(const PrimExpr& n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
}
// expression
- void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmtNode* op) override;
* \param t The type representation.
* \param os The stream to print the ctype into
*/
- virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
+ virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
private:
/*! \brief The current indent of the code dump. */
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
- std::map<std::pair<const Object *, int>, std::string> id_map_;
+ std::map<std::pair<const Object*, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */
- std::map<const VarNode *, std::string> binds_;
+ std::map<const VarNode*, std::string> binds_;
/*!
* \brief Find an unallocated name for the given prefix.
* \param prefix The given prefix.
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
*/
- std::string GetVarID(const VarNode *v);
+ std::string GetVarID(const VarNode* v);
/*!
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
* \param value_index The value index of the given tensor.
*/
- std::string GetTensorID(const FunctionRef &func, int value_index);
+ std::string GetTensorID(const FunctionRef& func, int value_index);
/*! \brief the storage scope of allocation */
std::map<FunctionRef, std::string> alloc_storage_scope_;
};
*/
#include <dmlc/thread_local.h>
#include <tvm/driver/driver_api.h>
-#include <tvm/te/operation.h>
-
-#include <tvm/tir/transform.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/target/codegen.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
+#include <tvm/te/operation.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/transform.h>
#include <algorithm>
#include <mutex>
namespace tvm {
+using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
-using runtime::PackedFunc;
bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
}
}
-tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape,
- DataType dtype,
- std::string name,
- int data_alignment,
- int offset_factor,
- bool compact) {
+tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
+ int data_alignment, int offset_factor, bool compact) {
auto data = tir::Var(name, DataType::Handle());
bool has_any = false;
if (!compact) {
}
return tir::BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
- data_alignment, offset_factor, buffer_type);
+ data_alignment, offset_factor, buffer_type);
}
-void GetBinds(const Array<te::Tensor>& args,
- bool compact,
+void GetBinds(const Array<te::Tensor>& args, bool compact,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- Map<te::Tensor, tir::Buffer>* out_binds,
- Array<ObjectRef>* out_arg_list,
+ Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list,
const BuildConfig& config) {
*out_binds = binds;
- for (const auto &x : args) {
+ for (const auto& x : args) {
if (out_binds->find(x) == out_binds->end()) {
- auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name,
- config->data_alignment, config->offset_factor, compact);
+ auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, config->data_alignment,
+ config->offset_factor, compact);
out_binds->Set(x, buf);
out_arg_list->push_back(buf);
} else {
return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {});
}
-
-template<typename FCond>
+template <typename FCond>
transform::Pass Filter(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-
-IRModule lower(te::Schedule sch,
- const Array<te::Tensor>& args,
- const std::string& name,
+IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config) {
Array<ObjectRef> out_arg_list;
GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
// build the function
- tir::PrimFunc f = te::SchedulePostProcToPrimFunc(
- out_arg_list, std::move(stmt), out_binds);
+ tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
// Phase 0
pass_list.push_back(tir::transform::InjectPrefetch());
- pass_list.push_back(
- tir::transform::StorageFlatten(64, config->instrument_bound_checkers));
+ pass_list.push_back(tir::transform::StorageFlatten(64, config->instrument_bound_checkers));
// Phase 1
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop));
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(
- tir::transform::UnrollLoop(config->auto_unroll_max_step,
- config->auto_unroll_max_depth,
- config->auto_unroll_max_extent,
- config->unroll_explicit));
+ tir::transform::UnrollLoop(config->auto_unroll_max_step, config->auto_unroll_max_depth,
+ config->auto_unroll_max_extent, config->unroll_explicit));
// Phase 2
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RemoveNoOp());
return mod;
}
-
-std::pair<IRModule, IRModule>
-split_dev_host_funcs(IRModule mod_mixed,
- const Target& target,
- const Target& target_host,
- const BuildConfig& config) {
- Array<tvm::transform::Pass> mixed_pass_list = {
- BindTarget(target),
- tir::transform::VerifyMemory()
- };
+std::pair<IRModule, IRModule> split_dev_host_funcs(IRModule mod_mixed, const Target& target,
+ const Target& target_host,
+ const BuildConfig& config) {
+ Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
+ tir::transform::VerifyMemory()};
if (config->detect_global_barrier) {
mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
}
mod_mixed = opt_mixed(std::move(mod_mixed));
auto host_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(
- tvm::attr::kCallingConv,
- Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target_host),
- tir::transform::LowerTVMBuiltin(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- tir::transform::CombineContextCall(),
+ Filter([](const tir::PrimFunc& f) {
+ return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
+ CallingConv::kDeviceKernelLaunch;
+ }),
+ BindTarget(target_host),
+ tir::transform::LowerTVMBuiltin(),
+ tir::transform::LowerIntrin(),
+ tir::transform::LowerDeviceStorageAccessInfo(),
+ tir::transform::CombineContextCall(),
};
auto opt_host = transform::Sequential(host_pass_list);
auto mhost = opt_host(mod_mixed);
// device pipeline
auto device_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(
- tvm::attr::kCallingConv,
- Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target),
- tir::transform::LowerWarpMemory(),
- tir::transform::Simplify(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
+ Filter([](const tir::PrimFunc& f) {
+ return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
+ CallingConv::kDeviceKernelLaunch;
+ }),
+ BindTarget(target),
+ tir::transform::LowerWarpMemory(),
+ tir::transform::Simplify(),
+ tir::transform::LowerIntrin(),
+ tir::transform::LowerDeviceStorageAccessInfo(),
};
auto opt_device = transform::Sequential(device_pass_list);
auto mdevice = opt_device(mod_mixed);
auto keys = target->keys();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && mdevice->functions.size() == 0) {
- LOG(WARNING) << "Specified target "
- << target->str()
+ LOG(WARNING) << "Specified target " << target->str()
<< " but cannot find device code. Did you forget to bind?";
}
- if (target->device_type == target::llvm()->device_type &&
- target_host == target) {
- CHECK(mdevice->functions.empty())
- << "No device code should be generated when target "
- << "and host_target are both llvm target."
- << "\n";
+ if (target->device_type == target::llvm()->device_type && target_host == target) {
+ CHECK(mdevice->functions.empty()) << "No device code should be generated when target "
+ << "and host_target are both llvm target."
+ << "\n";
}
return {mhost, mdevice};
}
-
// Build for heterogeneous execution.
-runtime::Module build(const Map<Target, IRModule>& inputs,
- const Target& target_host,
+runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_host,
const BuildConfig& config) {
std::vector<runtime::Module> device_modules;
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());
for (const auto& it : inputs) {
- auto pair =
- split_dev_host_funcs(it.second, it.first, target_host_val, config);
+ auto pair = split_dev_host_funcs(it.second, it.first, target_host_val, config);
auto& mhost = pair.first;
auto& mdevice = pair.second;
}
// Build for heterogeneous execution when target is a string.
-runtime::Module build(const Map<std::string, IRModule>& inputs,
- const Target& target_host,
+runtime::Module build(const Map<std::string, IRModule>& inputs, const Target& target_host,
const BuildConfig& config) {
Map<Target, IRModule> updated_input;
for (const auto& it : inputs) {
}
// Build for homogeneous execution.
-runtime::Module build(const IRModule& funcs,
- const Target& target,
- const Target& target_host,
+runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host,
const BuildConfig& config) {
Map<Target, IRModule> inputs = {{target, funcs}};
return build(inputs, target_host, config);
* \file src/ir/adt.cc
* \brief ADT type definitions.
*/
-#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
+#include <tvm/relay/type.h>
namespace tvm {
-Constructor::Constructor(std::string name_hint,
- tvm::Array<Type> inputs,
- GlobalTypeVar belong_to) {
+Constructor::Constructor(std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("ir.Constructor")
-.set_body_typed([](std::string name_hint,
- tvm::Array<Type> inputs,
- GlobalTypeVar belong_to) {
- return Constructor(name_hint, inputs, belong_to);
-});
+ .set_body_typed([](std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
+ return Constructor(name_hint, inputs, belong_to);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const ConstructorNode*>(ref.get());
- p->stream << "ConstructorNode(" << node->name_hint << ", "
- << node->inputs << ", " << node->belong_to << ")";
-});
+ .set_dispatch<ConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const ConstructorNode*>(ref.get());
+ p->stream << "ConstructorNode(" << node->name_hint << ", " << node->inputs << ", "
+ << node->belong_to << ")";
+ });
-TypeData::TypeData(GlobalTypeVar header,
- tvm::Array<TypeVar> type_vars,
+TypeData::TypeData(GlobalTypeVar header, tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
ObjectPtr<TypeDataNode> n = make_object<TypeDataNode>();
n->header = std::move(header);
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("ir.TypeData")
-.set_body_typed([](GlobalTypeVar header,
- tvm::Array<TypeVar> type_vars,
- tvm::Array<Constructor> constructors) {
- return TypeData(header, type_vars, constructors);
-});
+ .set_body_typed([](GlobalTypeVar header, tvm::Array<TypeVar> type_vars,
+ tvm::Array<Constructor> constructors) {
+ return TypeData(header, type_vars, constructors);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TypeDataNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TypeDataNode*>(ref.get());
- p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
- << node->constructors << ")";
-});
+ .set_dispatch<TypeDataNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TypeDataNode*>(ref.get());
+ p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
+ << node->constructors << ")";
+ });
} // namespace tvm
#include <tvm/node/functor.h>
#include <tvm/tir/expr.h>
+
#include <utility>
namespace tvm {
template <typename FType>
class AttrFunctor;
-#define ATTR_FUNCTOR_DEFAULT \
+#define ATTR_FUNCTOR_DEFAULT \
{ return VisitAttrDefault_(op, std::forward<Args>(args)...); }
-
-#define ATTR_FUNCTOR_DISPATCH(OP) \
- vtable.template set_dispatch<OP>( \
- [](const ObjectRef& n, TSelf* self, Args... args) { \
- return self->VisitAttr_(static_cast<const OP*>(n.get()), \
- std::forward<Args>(args)...); \
- }); \
+#define ATTR_FUNCTOR_DISPATCH(OP) \
+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+ return self->VisitAttr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+ });
// A functor for common attribute information.
template <typename R, typename... Args>
*/
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
+
#include "attr_functor.h"
namespace tvm {
-void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
- v->Visit("__dict__", &dict);
-}
+void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); }
-void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
- v->Visit("__dict__", &dict);
-}
+void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); }
-void DictAttrsNode::InitByPackedArgs(
- const runtime::TVMArgs& args, bool allow_unknown) {
+void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) {
for (int i = 0; i < args.size(); i += 2) {
std::string key = args[i];
runtime::TVMArgValue val = args[i + 1];
}
}
-Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
- return {};
-}
+Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { return {}; }
DictAttrs::DictAttrs(Map<std::string, ObjectRef> dict) {
ObjectPtr<DictAttrsNode> n = make_object<DictAttrsNode>();
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const DictAttrsNode*>(node.get());
- p->stream << op->dict;
-});
+ .set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const DictAttrsNode*>(node.get());
+ p->stream << op->dict;
+ });
TVM_REGISTER_NODE_TYPE(DictAttrsNode);
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
-TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
-.set_body_typed([](DictAttrs attrs) {
+TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) {
return attrs->dict;
});
-TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
-.set_body_typed([](Attrs attrs) {
+TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) {
return attrs->ListFieldInfo();
});
namespace tvm {
-
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const EnvFuncNode*>(node.get());
- p->stream << "EnvFunc(" << op->name << ")";
-});
+ .set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const EnvFuncNode*>(node.get());
+ p->stream << "EnvFunc(" << op->name << ")";
+ });
ObjectPtr<Object> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name);
return n;
}
-EnvFunc EnvFunc::Get(const std::string& name) {
- return EnvFunc(CreateEnvNode(name));
-}
+EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); }
-TVM_REGISTER_GLOBAL("ir.EnvFuncGet")
-.set_body_typed(EnvFunc::Get);
+TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);
-TVM_REGISTER_GLOBAL("ir.EnvFuncCall")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- EnvFunc env = args[0];
- CHECK_GE(args.size(), 1);
- env->func.CallPacked(TVMArgs(args.values + 1,
- args.type_codes + 1,
- args.size() - 1), rv);
- });
+TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) {
+ EnvFunc env = args[0];
+ CHECK_GE(args.size(), 1);
+ env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv);
+});
-TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc")
-.set_body_typed([](const EnvFunc&n) {
- return n->func;
- });
+TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) {
+ return n->func;
+});
TVM_REGISTER_NODE_TYPE(EnvFuncNode)
-.set_creator(CreateEnvNode)
-.set_repr_bytes([](const Object* n) -> std::string {
- return static_cast<const EnvFuncNode*>(n)->name;
- });
+ .set_creator(CreateEnvNode)
+ .set_repr_bytes([](const Object* n) -> std::string {
+ return static_cast<const EnvFuncNode*>(n)->name;
+ });
} // namespace tvm
* \brief Utilities for error tracking and reporting.
*/
-#include <tvm/ir/module.h>
#include <tvm/ir/error.h>
+#include <tvm/ir/module.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
// Rationale: use relay's printer for astext.
#include <tvm/relay/expr.h>
+// clang-format off
#include <string>
#include <vector>
#include <rang.hpp>
+// clang-format on
namespace tvm {
-template<typename T, typename U>
+template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, ObjectHash, ObjectEqual>;
void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
// Setup error map.
auto it = error_maps.find(global);
if (it != error_maps.end()) {
- it->second.insert({ node, err_msg.str() });
+ it->second.insert({node, err_msg.str()});
} else {
- error_maps.insert({ global, { { node, err_msg.str() }}});
+ error_maps.insert({global, {{node, err_msg.str()}}});
}
}
std::stringstream annotated_prog;
// First we output a header for the errors.
- annotated_prog <<
- rang::style::bold << std::endl <<
- "Error(s) have occurred. The program has been annotated with them:"
- << std::endl << std::endl << rang::style::reset;
+ annotated_prog << rang::style::bold << std::endl
+ << "Error(s) have occurred. The program has been annotated with them:" << std::endl
+ << std::endl
+ << rang::style::reset;
// For each global function which contains errors, we will
// construct an annotated function.
// We output the name of the function before displaying
// the annotated program.
- annotated_prog <<
- rang::style::bold <<
- "In `" << global->name_hint << "`: " <<
- std::endl <<
- rang::style::reset;
+ annotated_prog << rang::style::bold << "In `" << global->name_hint << "`: " << std::endl
+ << rang::style::reset;
// We then call into the Relay printer to generate the program.
//
if (it != this->node_to_error_.end()) {
it->second.push_back(index_to_insert);
} else {
- this->node_to_error_.insert({ node, { index_to_insert }});
+ this->node_to_error_.insert({node, {index_to_insert}});
}
- this->node_to_gv_.insert({ node, global });
+ this->node_to_gv_.insert({node, global});
}
} // namespace tvm
* \file src/ir/expr.cc
* \brief The expression AST nodes for the common IR infra.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
+#include <tvm/runtime/registry.h>
// NOTE: reverse dependency on top/tir.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
namespace tvm {
-PrimExpr::PrimExpr(int32_t value)
- : PrimExpr(IntImm(DataType::Int(32), value)) {}
+PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {}
-PrimExpr::PrimExpr(float value)
- : PrimExpr(FloatImm(DataType::Float(32), value)) {}
+PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {}
PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
return tir::StringImmNode::make(GetRef<runtime::String>(ptr));
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ref.get()))
- << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
- << " but get " << ref->GetTypeKey();
+ << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName() << " but get "
+ << ref->GetTypeKey();
return Downcast<PrimExpr>(ref);
}
-
IntImm::IntImm(DataType dtype, int64_t value) {
- CHECK(dtype.is_scalar())
- << "ValueError: IntImm can only take scalar.";
- CHECK(dtype.is_int() || dtype.is_uint())
- << "ValueError: IntImm can only take scalar.";
+ CHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar.";
+ CHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm can only take scalar.";
if (dtype.is_uint()) {
CHECK_GE(value, 0U);
}
data_ = std::move(node);
}
-TVM_REGISTER_GLOBAL("ir.IntImm")
-.set_body_typed([](DataType dtype, int64_t value) {
+TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value) {
return IntImm(dtype, value);
});
TVM_REGISTER_NODE_TYPE(IntImmNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IntImmNode*>(node.get());
- if (op->dtype == DataType::Int(32)) {
- p->stream << op->value;
- } else {
- p->stream << "(" << op->dtype << ")" << op->value;
- }
- });
+ .set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IntImmNode*>(node.get());
+ if (op->dtype == DataType::Int(32)) {
+ p->stream << op->value;
+ } else {
+ p->stream << "(" << op->dtype << ")" << op->value;
+ }
+ });
FloatImm::FloatImm(DataType dtype, double value) {
- CHECK_EQ(dtype.lanes(), 1)
- << "ValueError: FloatImm can only take scalar.";
+ CHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = dtype;
node->value = value;
data_ = std::move(node);
}
-TVM_REGISTER_GLOBAL("ir.FloatImm")
-.set_body_typed([](DataType dtype, double value) {
+TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value) {
return FloatImm(dtype, value);
});
TVM_REGISTER_NODE_TYPE(FloatImmNode);
-
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FloatImmNode*>(node.get());
- auto& stream = p->stream;
- switch (op->dtype.bits()) {
- case 64:
- stream << op->value;
- break;
- case 32:
- stream << op->value << 'f';
- break;
- case 16:
- stream << op->value << 'h';
- break;
- default:
- LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
- }
- });
-
+ .set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const FloatImmNode*>(node.get());
+ auto& stream = p->stream;
+ switch (op->dtype.bits()) {
+ case 64:
+ stream << op->value;
+ break;
+ case 32:
+ stream << op->value << 'f';
+ break;
+ case 16:
+ stream << op->value << 'h';
+ break;
+ default:
+ LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
+ }
+ });
Range::Range(PrimExpr begin, PrimExpr end)
- : Range(make_object<RangeNode>(
- begin,
- tir::is_zero(begin) ? end : (end - begin))) {
-}
+ : Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin))) {}
Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
-TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
-.set_body_typed(Range::make_by_min_extent);
+TVM_REGISTER_GLOBAL("ir.range_by_min_extent").set_body_typed(Range::make_by_min_extent);
-TVM_REGISTER_GLOBAL("ir.Range")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Range(args[0], args[1]);
- });
+});
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const RangeNode*>(node.get());
- p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
- });
-
+ .set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const RangeNode*>(node.get());
+ p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
+ });
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
-TVM_REGISTER_GLOBAL("ir.GlobalVar")
-.set_body_typed([](std::string name){
+TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) {
return GlobalVar(name);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const GlobalVarNode*>(ref.get());
- p->stream << "GlobalVar(" << node->name_hint << ")";
- });
+ .set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const GlobalVarNode*>(ref.get());
+ p->stream << "GlobalVar(" << node->name_hint << ")";
+ });
// Container printer
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ArrayNode*>(node.get());
- p->stream << '[';
- for (size_t i = 0 ; i < op->data.size(); ++i) {
- if (i != 0) {
- p->stream << ", ";
+ .set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ArrayNode*>(node.get());
+ p->stream << '[';
+ for (size_t i = 0; i < op->data.size(); ++i) {
+ if (i != 0) {
+ p->stream << ", ";
+ }
+ p->Print(op->data[i]);
}
- p->Print(op->data[i]);
- }
- p->stream << ']';
-});
+ p->stream << ']';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MapNode*>(node.get());
- p->stream << '{';
- for (auto it = op->data.begin(); it != op->data.end(); ++it) {
- if (it != op->data.begin()) {
- p->stream << ", ";
+ .set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const MapNode*>(node.get());
+ p->stream << '{';
+ for (auto it = op->data.begin(); it != op->data.end(); ++it) {
+ if (it != op->data.begin()) {
+ p->stream << ", ";
+ }
+ p->Print(it->first);
+ p->stream << ": ";
+ p->Print(it->second);
}
- p->Print(it->first);
- p->stream << ": ";
- p->Print(it->second);
- }
- p->stream << '}';
- });
+ p->stream << '}';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const StrMapNode*>(node.get());
- p->stream << '{';
- for (auto it = op->data.begin(); it != op->data.end(); ++it) {
- if (it != op->data.begin()) {
- p->stream << ", ";
+ .set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const StrMapNode*>(node.get());
+ p->stream << '{';
+ for (auto it = op->data.begin(); it != op->data.end(); ++it) {
+ if (it != op->data.begin()) {
+ p->stream << ", ";
+ }
+ p->stream << '\"' << it->first << "\": ";
+ p->Print(it->second);
}
- p->stream << '\"' << it->first << "\": ";
- p->Print(it->second);
- }
- p->stream << '}';
- });
+ p->stream << '}';
+ });
} // namespace tvm
* \file src/ir/function.cc
* \brief The function data structure.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/ir/function.h>
+#include <tvm/runtime/registry.h>
// NOTE: reverse dependency on relay, tir/
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: We calls into the type specific WithAttr function
-#include <tvm/tir/function.h>
#include <tvm/relay/function.h>
-
+#include <tvm/tir/function.h>
namespace tvm {
-TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
-.set_body_typed([](BaseFunc func) {
- return func->attrs;
-});
+TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { return func->attrs; });
-TVM_REGISTER_GLOBAL("ir.BaseFuncCopy")
-.set_body_typed([](BaseFunc func) {
- return func;
-});
+TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; });
TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
-.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
- if (func->IsInstance<tir::PrimFuncNode>()) {
- return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
- } else if (func->IsInstance<relay::FunctionNode>()) {
- return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
- } else {
- LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
- return func;
- }
-});
-
+ .set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
+ if (func->IsInstance<tir::PrimFuncNode>()) {
+ return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
+ } else if (func->IsInstance<relay::FunctionNode>()) {
+ return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
+ } else {
+ LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
+ return func;
+ }
+ });
} // namespace tvm
* \file module.cc
* \brief The global module in Relay.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
+#include <tvm/runtime/registry.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
-#include <sstream>
#include <fstream>
+#include <sstream>
#include <unordered_set>
namespace tvm {
for (const auto& kv : n->functions) {
// set global var map
CHECK(n->global_var_map_.count(kv.first->name_hint) == 0)
- << "Duplicate global function name " << kv.first->name_hint;
+ << "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}
for (const auto& kv : n->type_definitions) {
// set global typevar map
CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
- << "Duplicate global type definition name " << kv.first->name_hint;
+ << "Duplicate global type definition name " << kv.first->name_hint;
n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second);
}
auto reduce_temp = [&]() {
// sort by the hash key of the keys.
- std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) {
- return lhs.first < rhs.first;
- });
+ std::sort(temp.begin(), temp.end(),
+ [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
hash_reduce(static_cast<uint64_t>(temp.size()));
// hash the content
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
- << "Cannot find global type var " << name << " in the Module";
+ << "Cannot find global type var " << name << " in the Module";
return (*it).second;
}
return tvm::Array<GlobalTypeVar>(global_type_vars);
}
-template<typename T>
+template <typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
tvm::Array<T> ret(l);
for (const T& t : r) {
}
// helper function to run type check
-relay::Function RunTypeCheck(const IRModule& mod,
- const GlobalVar& var,
- relay::Function f) {
+relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) {
auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
// Type check the item before we add it to the module.
auto fv = relay::FreeVars(func);
auto ftv = relay::FreeTypeVars(func, mod);
if (fv.size() != 0) {
- LOG(WARNING)
- << "There are free variables: "
- << fv
- << " in function: "
- << AsText(func, false)
- << std::endl;
+ LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false)
+ << std::endl;
}
if (ftv.size() != 0) {
- LOG(WARNING)
- << "There are free type variables: "
- << ftv
- << " in function: "
- << AsText(func, false)
- << std::endl;
+ LOG(WARNING) << "There are free type variables: " << ftv
+ << " in function: " << AsText(func, false) << std::endl;
}
- func = relay::Function(concat(func->params, fv),
- func->body,
- func->ret_type,
- concat(func->type_params, ftv),
- func->attrs);
+ func = relay::Function(concat(func->params, fv), func->body, func->ret_type,
+ concat(func->type_params, ftv), func->attrs);
// Type check the item before we add it to the module.
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
}
-void IRModuleNode::Add(const GlobalVar& var,
- const BaseFunc& f,
- bool update) {
+void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
BaseFunc checked_func = f;
if (auto* ptr = f.as<relay::FunctionNode>()) {
- checked_func = RunTypeCheck(GetRef<IRModule>(this),
- var,
- GetRef<relay::Function>(ptr));
+ checked_func = RunTypeCheck(GetRef<IRModule>(this), var, GetRef<relay::Function>(ptr));
}
Type type = checked_func->checked_type();
CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) {
- CHECK(update)
- << "Already have definition for " << var->name_hint;
+ CHECK(update) << "Already have definition for " << var->name_hint;
auto old_type = functions[var]->checked_type();
CHECK(tvm::StructuralEqual()(type, old_type))
<< "Module#update changes type, not possible in this mode.";
AddUnchecked(var, checked_func);
}
-void IRModuleNode::AddUnchecked(const GlobalVar& var,
- const BaseFunc& func) {
+void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint);
}
}
-void IRModuleNode::AddTypeDef(const GlobalTypeVar& var,
- const TypeData& type,
- bool update) {
+void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up
// a definition potentially
CHECK(relay::KindCheck(type, GetRef<IRModule>(this)) == TypeKind::kTypeData)
- << "Invalid or malformed typedata given to module: " << type;
+ << "Invalid or malformed typedata given to module: " << type;
}
-void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
- const TypeData& type,
+void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
bool update) {
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
CHECK(global_type_var_map_.count(var->name_hint) == 0)
- << "Duplicate global type definition name " << var->name_hint;
+ << "Duplicate global type definition name " << var->name_hint;
}
global_type_var_map_.Set(var->name_hint, var);
RegisterConstructors(var, type);
}
-void IRModuleNode::Update(const GlobalVar& var,
- const BaseFunc& func) {
+void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) {
this->Add(var, func, true);
}
-void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
- const TypeData& type) {
+void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddTypeDef(var, type, true);
}
BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
- CHECK(it != functions.end())
- << "There is no definition of " << var->name_hint;
+ CHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
return (*it).second;
}
TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
- CHECK(it != type_definitions.end())
- << "There is no definition of " << var->name_hint;
+ CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
return (*it).second;
}
Constructor IRModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
- CHECK(it != constructor_tag_map_.end())
- << "There is no constructor with the tag " << tag;
+ CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag;
return (*it).second;
}
}
}
-IRModule IRModule::FromExpr(
- const RelayExpr& expr,
- const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
- const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
+IRModule IRModule::FromExpr(const RelayExpr& expr,
+ const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
+ const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
std::string gv_name = "main";
}
} else {
- func = relay::Function(relay::FreeVars(expr), expr, Type(),
- relay::FreeTypeVars(expr, mod), {});
+ func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar(gv_name);
mod->Add(main_gv, func);
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in);
- std::string file_contents {
- std::istreambuf_iterator<char>(src_file),
- std::istreambuf_iterator<char>() };
+ std::string file_contents{std::istreambuf_iterator<char>(src_file),
+ std::istreambuf_iterator<char>()};
auto mod_to_import = IRModule::FromText(file_contents, path);
Update(mod_to_import);
}
this->Import(std_path + "/" + path.operator std::string());
}
-std::unordered_set<String> IRModuleNode::Imports() const {
- return this->import_set_;
-}
+std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
IRModule IRModule::FromText(const String& text, const String& source_path) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
TVM_REGISTER_NODE_TYPE(IRModuleNode);
TVM_REGISTER_GLOBAL("ir.IRModule")
-.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
- tvm::Map<GlobalTypeVar, TypeData> types) {
- return IRModule(funcs, types, {});
-});
+ .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
+ tvm::Map<GlobalTypeVar, TypeData> types) {
+ return IRModule(funcs, types, {});
+ });
-TVM_REGISTER_GLOBAL("ir.Module_Add")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) {
IRModule mod = args[0];
GlobalVar var = args[1];
ObjectRef val = args[2];
*ret = mod;
});
-TVM_REGISTER_GLOBAL("ir.Module_AddDef")
-.set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
+TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
-.set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
+ .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
-.set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
+ .set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
-.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
+ .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
-.set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
+ .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
-.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
+ .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
-TVM_REGISTER_GLOBAL("ir.Module_Lookup")
-.set_body_typed([](IRModule mod, GlobalVar var) {
+TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) {
return mod->Lookup(var);
});
-TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
-.set_body_typed([](IRModule mod, String var) {
+TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) {
return mod->Lookup(var);
});
-TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
-.set_body_typed([](IRModule mod, GlobalTypeVar var) {
+TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) {
return mod->LookupTypeDef(var);
});
-TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
-.set_body_typed([](IRModule mod, String var) {
+TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) {
return mod->LookupTypeDef(var);
});
-TVM_REGISTER_GLOBAL("ir.Module_LookupTag")
-.set_body_typed([](IRModule mod, int32_t tag) {
- return mod->LookupTag(tag);
- });
+TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) {
+ return mod->LookupTag(tag);
+});
TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
-.set_body_typed([](RelayExpr e,
- tvm::Map<GlobalVar, BaseFunc> funcs,
- tvm::Map<GlobalTypeVar, TypeData> type_defs) {
- return IRModule::FromExpr(e, funcs, type_defs);
-});
+ .set_body_typed([](RelayExpr e, tvm::Map<GlobalVar, BaseFunc> funcs,
+ tvm::Map<GlobalTypeVar, TypeData> type_defs) {
+ return IRModule::FromExpr(e, funcs, type_defs);
+ });
-TVM_REGISTER_GLOBAL("ir.Module_Update")
-.set_body_typed([](IRModule mod, IRModule from) {
+TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) {
mod->Update(from);
});
-TVM_REGISTER_GLOBAL("ir.Module_Import")
-.set_body_typed([](IRModule mod, String path) {
+TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) {
mod->Import(path);
});
-TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
-.set_body_typed([](IRModule mod, String path) {
+TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) {
mod->ImportFromStd(path);
-});;
+});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const IRModuleNode*>(ref.get());
- p->stream << "IRModuleNode( " << node->functions << ")";
-});
+ .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const IRModuleNode*>(ref.get());
+ p->stream << "IRModuleNode( " << node->functions << ")";
+ });
} // namespace tvm
*/
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
-#include <tvm/runtime/module.h>
#include <tvm/runtime/container.h>
+#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
namespace tvm {
-using runtime::TVMRetValue;
-using runtime::TVMArgs;
using runtime::PackedFunc;
+using runtime::TVMArgs;
+using runtime::TVMRetValue;
-::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
- return ::dmlc::Registry<OpRegistry>::Get();
-}
+::dmlc::Registry<OpRegistry>* OpRegistry::Registry() { return ::dmlc::Registry<OpRegistry>::Get(); }
// single manager of operator information.
struct OpManager {
}
}
-void OpRegistry::UpdateAttr(const std::string& key,
- TVMRetValue value,
- int plevel) {
+void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
}
std::pair<TVMRetValue, int>& p = op_map->data_[index];
- CHECK(p.second != plevel)
- << "Attribute " << key << " of operator " << this->name
- << " is already registered with same plevel=" << plevel;
+ CHECK(p.second != plevel) << "Attribute " << key << " of operator " << this->name
+ << " is already registered with same plevel=" << plevel;
CHECK(value.type_code() != kTVMNullptr)
- << "Registered packed_func is Null for " << key
- << " of operator " << this->name;
+ << "Registered packed_func is Null for " << key << " of operator " << this->name;
if (p.second < plevel && value.type_code() != kTVMNullptr) {
op_map->data_[index] = std::make_pair(value, plevel);
}
}
// Frontend APIs
-TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
-.set_body_typed([]() {
- Array<runtime::String> ret;
- for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
- ret.push_back(name);
- }
- return ret;
- });
+TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() {
+ Array<runtime::String> ret;
+ for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
+ ret.push_back(name);
+ }
+ return ret;
+});
-TVM_REGISTER_GLOBAL("relay.op._GetOp")
-.set_body_typed([](std::string name) -> Op {
+TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) -> Op {
return Op::Get(name);
});
-TVM_REGISTER_GLOBAL("relay.op._OpGetAttr")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- Op op = args[0];
- std::string attr_name = args[1];
- auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
- if (op_map.count(op)) {
- *rv = op_map[op];
- }
- });
-
-TVM_REGISTER_GLOBAL("relay.op._OpSetAttr")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- Op op = args[0];
- std::string attr_name = args[1];
- runtime::TVMArgValue value = args[2];
- int plevel = args[3];
- auto& reg =
- OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
- reg.set_attr(attr_name, value, plevel);
- });
-
-TVM_REGISTER_GLOBAL("relay.op._OpResetAttr")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- Op op = args[0];
- std::string attr_name = args[1];
- auto& reg =
- OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
- reg.reset_attr(attr_name);
- });
-
-TVM_REGISTER_GLOBAL("relay.op._Register")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- std::string op_name = args[0];
- std::string attr_key = args[1];
- runtime::TVMArgValue value = args[2];
- int plevel = args[3];
- auto& reg =
- OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
- // enable resgiteration and override of certain properties
- if (attr_key == "num_inputs" && plevel > 128) {
- reg.set_num_inputs(value);
- } else if (attr_key == "attrs_type_key" && plevel > 128) {
- LOG(FATAL) << "attrs type key no longer supported";
+TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
+ Op op = args[0];
+ std::string attr_name = args[1];
+ auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+ if (op_map.count(op)) {
+ *rv = op_map[op];
+ }
+});
+
+TVM_REGISTER_GLOBAL("relay.op._OpSetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
+ Op op = args[0];
+ std::string attr_name = args[1];
+ runtime::TVMArgValue value = args[2];
+ int plevel = args[3];
+ auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
+ reg.set_attr(attr_name, value, plevel);
+});
+
+TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
+ Op op = args[0];
+ std::string attr_name = args[1];
+ auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
+ reg.reset_attr(attr_name);
+});
+
+TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue* rv) {
+ std::string op_name = args[0];
+ std::string attr_key = args[1];
+ runtime::TVMArgValue value = args[2];
+ int plevel = args[3];
+ auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
+  // enable registration and override of certain properties
+ if (attr_key == "num_inputs" && plevel > 128) {
+ reg.set_num_inputs(value);
+ } else if (attr_key == "attrs_type_key" && plevel > 128) {
+ LOG(FATAL) << "attrs type key no longer supported";
+ } else {
+ // normal attr table override.
+ if (args[2].type_code() == kTVMPackedFuncHandle) {
+ // do an eager copy of the PackedFunc
+ PackedFunc f = args[2];
+ // If we get a function from frontend, avoid deleting it.
+ OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
+ reg.set_attr(attr_key, f, plevel);
} else {
- // normal attr table override.
- if (args[2].type_code() == kTVMPackedFuncHandle) {
- // do an eager copy of the PackedFunc
- PackedFunc f = args[2];
- // If we get a function from frontend, avoid deleting it.
- OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
- reg.set_attr(attr_key, f, plevel);
- } else {
- reg.set_attr(attr_key, args[2], plevel);
- }
+ reg.set_attr(attr_key, args[2], plevel);
}
- });
+ }
+});
// helper to get internal dev function in objectref.
struct Op2ObjectPtr : public ObjectRef {
- static ObjectPtr<Object> Get(const Op& op) {
- return GetDataPtr<Object>(op);
- }
+ static ObjectPtr<Object> Get(const Op& op) { return GetDataPtr<Object>(op); }
};
ObjectPtr<Object> CreateOp(const std::string& name) {
return Op2ObjectPtr::Get(op);
}
-TVM_REGISTER_NODE_TYPE(OpNode)
-.set_creator(CreateOp)
-.set_repr_bytes([](const Object* n) {
- return static_cast<const OpNode*>(n)->name;
- });
+TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes([](const Object* n) {
+ return static_cast<const OpNode*>(n)->name;
+});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const OpNode*>(ref.get());
- p->stream << "Op(" << node->name << ")";
- });
+ .set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const OpNode*>(ref.get());
+ p->stream << "Op(" << node->name << ")";
+ });
} // namespace tvm
return GetSourceNameNode(name);
}
-SourceName SourceName::Get(const String& name) {
- return SourceName(GetSourceNameNode(name));
-}
+SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); }
-TVM_REGISTER_GLOBAL("ir.SourceName")
-.set_body_typed(SourceName::Get);
+TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const SourceNameNode*>(ref.get());
- p->stream << "SourceName(" << node->name << ", " << node << ")";
- });
+ .set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const SourceNameNode*>(ref.get());
+ p->stream << "SourceName(" << node->name << ", " << node << ")";
+ });
TVM_REGISTER_NODE_TYPE(SourceNameNode)
-.set_creator(GetSourceNameNodeByStr)
-.set_repr_bytes([](const Object* n) -> std::string {
- return static_cast<const SourceNameNode*>(n)->name;
-});
+ .set_creator(GetSourceNameNodeByStr)
+ .set_repr_bytes([](const Object* n) -> std::string {
+ return static_cast<const SourceNameNode*>(n)->name;
+ });
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>();
TVM_REGISTER_NODE_TYPE(SpanNode);
-TVM_REGISTER_GLOBAL("ir.Span")
-.set_body_typed(SpanNode::make);
+TVM_REGISTER_GLOBAL("ir.Span").set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const SpanNode*>(ref.get());
- p->stream << "Span(" << node->source << ", " << node->lineno << ", "
- << node->col_offset << ")";
- });
+ .set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const SpanNode*>(ref.get());
+ p->stream << "Span(" << node->source << ", " << node->lineno << ", " << node->col_offset
+ << ")";
+ });
} // namespace tvm
* \file src/ir/tensor_type.cc
* \brief The type system AST nodes of Relay.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/ir/tensor_type.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
namespace tvm {
data_ = std::move(n);
}
-TensorType TensorType::Scalar(DataType dtype) {
- return TensorType({}, dtype);
-}
+TensorType TensorType::Scalar(DataType dtype) { return TensorType({}, dtype); }
PrimExpr TensorTypeNode::Size() const {
if (shape.size() == 0) {
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
-TVM_REGISTER_GLOBAL("ir.TensorType")
-.set_body_typed([](Array<PrimExpr> shape, DataType dtype) {
+TVM_REGISTER_GLOBAL("ir.TensorType").set_body_typed([](Array<PrimExpr> shape, DataType dtype) {
return TensorType(shape, dtype);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TensorTypeNode*>(ref.get());
- p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
-});
+ .set_dispatch<TensorTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TensorTypeNode*>(ref.get());
+ p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
+ });
} // namespace tvm
* \brief Infrastructure for transformation passes.
*/
#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
+#include <tvm/ir/transform.h>
+#include <tvm/node/repr_printer.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
-#include <tvm/node/repr_printer.h>
-#include <tvm/ir/transform.h>
+#include <tvm/runtime/registry.h>
// TODO(tqchen): Update to use String container after it is merged.
#include <tvm/tir/expr.h>
namespace tvm {
namespace transform {
+using tvm::ReprPrinter;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
-using tvm::ReprPrinter;
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
- PassContextThreadLocalEntry() {
- default_context = PassContext(make_object<PassContextNode>());
- }
+ PassContextThreadLocalEntry() { default_context = PassContext(make_object<PassContextNode>()); }
};
/*! \brief Thread local store to hold the pass context. */
-typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
- RelayPassContextThreadLocalStore;
+typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry> RelayPassContextThreadLocalStore;
void PassContext::EnterWithScope() {
- PassContextThreadLocalEntry* entry =
- RelayPassContextThreadLocalStore::Get();
+ PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void PassContext::ExitWithScope() {
- PassContextThreadLocalEntry* entry =
- RelayPassContextThreadLocalStore::Get();
+ PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
PassContext PassContext::Current() {
- PassContextThreadLocalEntry* entry =
- RelayPassContextThreadLocalStore::Get();
+ PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
if (!entry->context_stack.empty()) {
return entry->context_stack.top();
} else {
}
}
-PassContext PassContext::Create() {
- return PassContext(make_object<PassContextNode>());
-}
+PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }
void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
- auto pass_ctx_node = this->operator->();
- if (pass_ctx_node->trace_func != nullptr) {
- pass_ctx_node->trace_func(module, info, is_before);
- }
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->trace_func != nullptr) {
+ pass_ctx_node->trace_func(module, info, is_before);
+ }
}
class ModulePass;
ModulePassNode() = default;
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("pass_info", &pass_info);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
/*!
* \brief Run a module pass on given pass context.
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
-PassInfo::PassInfo(int opt_level,
- std::string name,
- tvm::Array<runtime::String> required) {
+PassInfo::PassInfo(int opt_level, std::string name, tvm::Array<runtime::String> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
data_ = std::move(pass_info);
}
-ModulePass::ModulePass(
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
- PassInfo pass_info) {
+ModulePass::ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
+ PassInfo pass_info) {
auto n = make_object<ModulePassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
}
// Module -> Module optimizations.
-IRModule ModulePassNode::operator()(IRModule mod,
- const PassContext& pass_ctx) const {
+IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
- DLOG(INFO) << "Executing module pass : "
- << pass_info->name
- << " with opt level: "
- << pass_info->opt_level;
+ DLOG(INFO) << "Executing module pass : " << pass_info->name
+ << " with opt level: " << pass_info->opt_level;
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
// pass
} else if ((f = Registry::Get("relay._transform." + pass_name))) {
}
- CHECK(f != nullptr) << "Cannot use " << pass_name
- << "to create the pass";
+ CHECK(f != nullptr) << "Cannot use " << pass_name << "to create the pass";
return (*f)();
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
-IRModule SequentialNode::operator()(IRModule mod,
- const PassContext& pass_ctx) const {
+IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
- if (!PassEnabled(pass_info)) continue;
+ if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(std::move(mod), pass_ctx);
return mod;
}
-Pass CreateModulePass(
- const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const tvm::Array<runtime::String>& required) {
+Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
+ int opt_level, const std::string& name,
+ const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
-.set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
- return PassInfo(opt_level, name, required);
-});
+ .set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
+ return PassInfo(opt_level, name, required);
+ });
-TVM_REGISTER_GLOBAL("transform.Info")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
*ret = pass->Info();
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
- auto* node = static_cast<const PassInfoNode*>(ref.get());
- p->stream << "The meta data of the pass: ";
- p->stream << "pass name: " << node->name;
- p->stream << "opt_level: " << node->opt_level;
- p->stream << "required passes: [" << "\n";
- for (const auto& it : node->required) {
- p->stream << it << ", ";
- }
- p->stream << "]\n";
-});
+ .set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
+ auto* node = static_cast<const PassInfoNode*>(ref.get());
+ p->stream << "The meta data of the pass: ";
+ p->stream << "pass name: " << node->name;
+ p->stream << "opt_level: " << node->opt_level;
+ p->stream << "required passes: ["
+ << "\n";
+ for (const auto& it : node->required) {
+ p->stream << it << ", ";
+ }
+ p->stream << "]\n";
+ });
TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_GLOBAL("transform.MakeModulePass")
-.set_body_typed(
- [](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
- PassInfo pass_info) {
- return ModulePass(pass_func, pass_info);
-});
+ .set_body_typed([](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
+ PassInfo pass_info) { return ModulePass(pass_func, pass_info); });
-TVM_REGISTER_GLOBAL("transform.RunPass")
-.set_body_typed([](Pass pass, IRModule mod) {
+TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) {
return pass(std::move(mod));
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const ModulePassNode*>(ref.get());
- const PassInfo info = node->Info();
- p->stream << "Run Module pass: " << info->name
- << " at the optimization level " << info->opt_level;
-});
+ .set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const ModulePassNode*>(ref.get());
+ const PassInfo info = node->Info();
+ p->stream << "Run Module pass: " << info->name << " at the optimization level "
+ << info->opt_level;
+ });
TVM_REGISTER_NODE_TYPE(SequentialNode);
-TVM_REGISTER_GLOBAL("transform.Sequential")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValue* ret) {
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
std::string name = args[2];
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const SequentialNode*>(ref.get());
- const PassInfo info = node->Info();
- p->stream << "Run Sequential pass: " << info->name
- << " at the optimization level " << info->opt_level << ". ";
- p->stream << "The passes will be executed are: [";
- for (const auto& it : node->passes) {
- const PassInfo pass_info = it->Info();
- p->stream << pass_info->name << " ";
- }
- p->stream << "]";
-});
+ .set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const SequentialNode*>(ref.get());
+ const PassInfo info = node->Info();
+ p->stream << "Run Sequential pass: " << info->name << " at the optimization level "
+ << info->opt_level << ". ";
+ p->stream << "The passes will be executed are: [";
+ for (const auto& it : node->passes) {
+ const PassInfo pass_info = it->Info();
+ p->stream << pass_info->name << " ";
+ }
+ p->stream << "]";
+ });
TVM_REGISTER_NODE_TYPE(PassContextNode);
-TVM_REGISTER_GLOBAL("transform.PassContext")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("transform.PassContext").set_body([](TVMArgs args, TVMRetValue* ret) {
auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PassContextNode*>(ref.get());
- p->stream << "Pass context information: " << "\n";
- p->stream << "\topt_level: " << node->opt_level << "\n";
- p->stream << "\tfallback device: "
- << runtime::DeviceName(node->fallback_device)
- << "\n";
-
- p->stream << "\trequired passes: [" << node->opt_level;
- for (const auto& it : node->required_pass) {
- p->stream << it << " ";
- }
- p->stream << "]\n";
-
- p->stream << "\tdisabled passes: [" << node->opt_level;
- for (const auto& it : node->disabled_pass) {
- p->stream << it << " ";
- }
- p->stream << "]";
-});
+ .set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PassContextNode*>(ref.get());
+ p->stream << "Pass context information: "
+ << "\n";
+ p->stream << "\topt_level: " << node->opt_level << "\n";
+ p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n";
+
+ p->stream << "\trequired passes: [" << node->opt_level;
+ for (const auto& it : node->required_pass) {
+ p->stream << it << " ";
+ }
+ p->stream << "]\n";
+
+ p->stream << "\tdisabled passes: [" << node->opt_level;
+ for (const auto& it : node->disabled_pass) {
+ p->stream << it << " ";
+ }
+ p->stream << "]";
+ });
class PassContext::Internal {
public:
- static void EnterScope(PassContext pass_ctx) {
- pass_ctx.EnterWithScope();
- }
+ static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); }
- static void ExitScope(PassContext pass_ctx) {
- pass_ctx.ExitWithScope();
- }
+ static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); }
};
-TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext")
-.set_body_typed(PassContext::Current);
-
-TVM_REGISTER_GLOBAL("transform.EnterPassContext")
-.set_body_typed(PassContext::Internal::EnterScope);
+TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current);
-TVM_REGISTER_GLOBAL("transform.ExitPassContext")
-.set_body_typed(PassContext::Internal::ExitScope);
+TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope);
+TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);
Pass PrintIR(std::string header, bool show_meta_data) {
- auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) {
- LOG(INFO) << "PrintIR(" << header << "):\n"
- << AsText(mod, show_meta_data);
+ auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
+ LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
return mod;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
}
-TVM_REGISTER_GLOBAL("transform.PrintIR")
-.set_body_typed(PrintIR);
+TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR);
} // namespace transform
} // namespace tvm
TVM_REGISTER_NODE_TYPE(PrimTypeNode);
-TVM_REGISTER_GLOBAL("ir.PrimType")
-.set_body_typed([](runtime::DataType dtype) {
+TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) {
return PrimType(dtype);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PrimTypeNode*>(ref.get());
- p->stream << node->dtype;
-});
-
+ .set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PrimTypeNode*>(ref.get());
+ p->stream << node->dtype;
+ });
PointerType::PointerType(Type element_type) {
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
TVM_REGISTER_NODE_TYPE(PointerTypeNode);
-TVM_REGISTER_GLOBAL("ir.PointerType")
-.set_body_typed([](Type element_type) {
+TVM_REGISTER_GLOBAL("ir.PointerType").set_body_typed([](Type element_type) {
return PointerType(element_type);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PointerTypeNode*>(ref.get());
- p->Print(node->element_type);
- p->stream << '*';
-});
-
+ .set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PointerTypeNode*>(ref.get());
+ p->Print(node->element_type);
+ p->stream << '*';
+ });
TypeVar::TypeVar(String name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
TVM_REGISTER_NODE_TYPE(TypeVarNode);
-TVM_REGISTER_GLOBAL("ir.TypeVar")
-.set_body_typed([](String name, int kind) {
+TVM_REGISTER_GLOBAL("ir.TypeVar").set_body_typed([](String name, int kind) {
return TypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TypeVarNode*>(ref.get());
- p->stream << "TypeVar(" << node->name_hint << ", "
- << node->kind << ")";
-});
-
+ .set_dispatch<TypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TypeVarNode*>(ref.get());
+ p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")";
+ });
GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
-TVM_REGISTER_GLOBAL("ir.GlobalTypeVar")
-.set_body_typed([](std::string name, int kind) {
+TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name, int kind) {
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
- p->stream << "GlobalTypeVar(" << node->name_hint << ", "
- << node->kind << ")";
-});
+ .set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
+ p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")";
+ });
-FuncType::FuncType(tvm::Array<Type> arg_types,
- Type ret_type,
- tvm::Array<TypeVar> type_params,
+FuncType::FuncType(tvm::Array<Type> arg_types, Type ret_type, tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
n->arg_types = std::move(arg_types);
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("ir.FuncType")
-.set_body_typed([](tvm::Array<Type> arg_types,
- Type ret_type,
- tvm::Array<TypeVar> type_params,
- tvm::Array<TypeConstraint> type_constraints) {
- return FuncType(arg_types, ret_type, type_params, type_constraints);
-});
+ .set_body_typed([](tvm::Array<Type> arg_types, Type ret_type, tvm::Array<TypeVar> type_params,
+ tvm::Array<TypeConstraint> type_constraints) {
+ return FuncType(arg_types, ret_type, type_params, type_constraints);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const FuncTypeNode*>(ref.get());
- p->stream << "FuncType(" << node->type_params << ", "
- << node->arg_types << ", " << node->ret_type << ", "
- << node->type_constraints << ")";
-});
-
+ .set_dispatch<FuncTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const FuncTypeNode*>(ref.get());
+ p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", "
+ << node->ret_type << ", " << node->type_constraints << ")";
+ });
TupleType::TupleType(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
data_ = std::move(n);
}
-TupleType TupleType::Empty() {
- return TupleType(Array<Type>());
-}
+TupleType TupleType::Empty() { return TupleType(Array<Type>()); }
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
-TVM_REGISTER_GLOBAL("ir.TupleType")
-.set_body_typed([](Array<Type> fields) {
+TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TupleTypeNode*>(ref.get());
- p->stream << "TupleTypeNode(" << node->fields << ")";
-});
-
+ .set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TupleTypeNode*>(ref.get());
+ p->stream << "TupleTypeNode(" << node->fields << ")";
+ });
IncompleteType::IncompleteType(TypeKind kind) {
auto n = make_object<IncompleteTypeNode>();
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
-TVM_REGISTER_GLOBAL("ir.IncompleteType")
-.set_body_typed([](int kind) {
- return IncompleteType(static_cast<TypeKind>(kind));
- });
+TVM_REGISTER_GLOBAL("ir.IncompleteType").set_body_typed([](int kind) {
+ return IncompleteType(static_cast<TypeKind>(kind));
+});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
- p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
- });
-
+ .set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
+ p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
+ });
RelayRefType::RelayRefType(Type value) {
ObjectPtr<RelayRefTypeNode> n = make_object<RelayRefTypeNode>();
data_ = std::move(n);
}
-TVM_REGISTER_GLOBAL("ir.RelayRefType")
-.set_body_typed([](Type value) {
+TVM_REGISTER_GLOBAL("ir.RelayRefType").set_body_typed([](Type value) {
return RelayRefType(value);
});
TVM_REGISTER_NODE_TYPE(RelayRefTypeNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
- p->stream << "RelayRefTypeNode(" << node->value << ")";
-});
+ .set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
+ p->stream << "RelayRefTypeNode(" << node->value << ")";
+ });
} // namespace tvm
* \brief Implementations of type functors.
*/
#include <tvm/ir/type_functor.h>
+
#include <utility>
namespace tvm {
-void TypeVisitor::VisitType_(const TypeVarNode* op) {
-}
+void TypeVisitor::VisitType_(const TypeVarNode* op) {}
-void TypeVisitor::VisitType_(const TensorTypeNode* op) {
-}
+void TypeVisitor::VisitType_(const TensorTypeNode* op) {}
-void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {
-}
+void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {}
void TypeVisitor::VisitType_(const FuncTypeNode* op) {
for (auto type_param : op->type_params) {
}
}
-void TypeVisitor::VisitType_(const RelayRefTypeNode* op) {
- this->VisitType(op->value);
-}
+void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { this->VisitType(op->value); }
void TypeVisitor::VisitType_(const TypeRelationNode* op) {
for (const Type& t : op->args) {
}
}
-void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {
-}
+void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {}
void TypeVisitor::VisitType_(const TypeCallNode* op) {
this->VisitType(op->func);
}
}
-void TypeVisitor::VisitType_(const PrimTypeNode* op) {
-}
+void TypeVisitor::VisitType_(const PrimTypeNode* op) {}
-void TypeVisitor::VisitType_(const PointerTypeNode* op) {
- this->VisitType(op->element_type);
-}
+void TypeVisitor::VisitType_(const PointerTypeNode* op) { this->VisitType(op->element_type); }
Type TypeMutator::VisitType(const Type& t) {
return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
return arr;
}
-Type TypeMutator::VisitType_(const TypeVarNode* op) {
- return GetRef<TypeVar>(op);
-}
+Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef<TypeVar>(op); }
Type TypeMutator::VisitType_(const TensorTypeNode* op) {
// TODO(tvm-team) recursively visit to replace Var
return GetRef<Type>(op);
}
-Type TypeMutator::VisitType_(const IncompleteTypeNode* op) {
- return GetRef<Type>(op);
-}
+Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { return GetRef<Type>(op); }
Type TypeMutator::VisitType_(const FuncTypeNode* op) {
bool changed = false;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
changed = changed || !new_type_cs.same_as(type_cs);
- if (const TypeConstraintNode* tin =
- new_type_cs.as<TypeConstraintNode>()) {
+ if (const TypeConstraintNode* tin = new_type_cs.as<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
LOG(FATAL) << new_type_cs;
changed = changed || !new_ret_type.same_as(op->ret_type);
if (!changed) return GetRef<Type>(op);
- return FuncType(new_args,
- new_ret_type,
- type_params,
- type_constraints);
+ return FuncType(new_args, new_ret_type, type_params, type_constraints);
}
Type TypeMutator::VisitType_(const TupleTypeNode* op) {
if (new_args.same_as(type_rel->args)) {
return GetRef<Type>(type_rel);
} else {
- return TypeRelation(type_rel->func,
- new_args,
- type_rel->num_inputs,
- type_rel->attrs);
+ return TypeRelation(type_rel->func, new_args, type_rel->num_inputs, type_rel->attrs);
}
}
-Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) {
- return GetRef<Type>(op);
-}
+Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { return GetRef<Type>(op); }
Type TypeMutator::VisitType_(const TypeCallNode* op) {
Type new_func = VisitType(op->func);
}
}
-Type TypeMutator::VisitType_(const TypeDataNode* op) {
- return GetRef<Type>(op);
-}
+Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef<Type>(op); }
-Type TypeMutator::VisitType_(const PrimTypeNode* op) {
- return GetRef<Type>(op);
-}
+Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef<Type>(op); }
Type TypeMutator::VisitType_(const PointerTypeNode* op) {
Type element_type = VisitType(op->element_type);
// Implements bind.
class TypeBinder : public TypeMutator {
public:
- explicit TypeBinder(const tvm::Map<TypeVar, Type>& args_map)
- : args_map_(args_map) {}
+ explicit TypeBinder(const tvm::Map<TypeVar, Type>& args_map) : args_map_(args_map) {}
Type VisitType_(const TypeVarNode* op) override {
auto id = GetRef<TypeVar>(op);
TVM_REGISTER_NODE_TYPE(TypeCallNode);
-TVM_REGISTER_GLOBAL("ir.TypeCall")
-.set_body_typed([](Type func, Array<Type> type) {
+TVM_REGISTER_GLOBAL("ir.TypeCall").set_body_typed([](Type func, Array<Type> type) {
return TypeCall(func, type);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TypeCallNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TypeCallNode*>(ref.get());
- p->stream << "TypeCallNode(" << node->func << ", "
- << node->args << ")";
-});
+ .set_dispatch<TypeCallNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TypeCallNode*>(ref.get());
+ p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")";
+ });
-TypeRelation::TypeRelation(TypeRelationFn func,
- Array<Type> args,
- int num_inputs,
- Attrs attrs) {
+TypeRelation::TypeRelation(TypeRelationFn func, Array<Type> args, int num_inputs, Attrs attrs) {
ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>();
n->func = std::move(func);
n->args = std::move(args);
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("ir.TypeRelation")
-.set_body_typed([](TypeRelationFn func,
- Array<Type> args,
- int num_inputs,
- Attrs attrs) {
- return TypeRelation(func, args, num_inputs, attrs);
-});
+ .set_body_typed([](TypeRelationFn func, Array<Type> args, int num_inputs, Attrs attrs) {
+ return TypeRelation(func, args, num_inputs, attrs);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TypeRelationNode*>(ref.get());
- p->stream << "TypeRelationNode("
- << node->func->name
- << ", " << node->args << ")";
-});
+ .set_dispatch<TypeRelationNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TypeRelationNode*>(ref.get());
+ p->stream << "TypeRelationNode(" << node->func->name << ", " << node->args << ")";
+ });
} // namespace tvm
* Expose container API to frontend.
* \file src/node/container.cc
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
+
#include "../support/str_escape.h"
namespace tvm {
struct StringObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
- static void SHashReduce(const runtime::StringObj* key,
- SHashReducer hash_reduce) {
- hash_reduce->SHashReduceHashedValue(
- runtime::String::HashBytes(key->data, key->size));
+ static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) {
+ hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size));
}
- static bool SEqualReduce(const runtime::StringObj* lhs,
- const runtime::StringObj* rhs,
+ static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->size != rhs->size) return false;
};
struct RefToObjectPtr : public ObjectRef {
- static ObjectPtr<Object> Get(const ObjectRef& ref) {
- return GetDataPtr<Object>(ref);
- }
+ static ObjectPtr<Object> Get(const ObjectRef& ref) { return GetDataPtr<Object>(ref); }
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
-.set_creator([](const std::string& bytes) {
- return RefToObjectPtr::Get(runtime::String(bytes));
-})
-.set_repr_bytes([](const Object* n) -> std::string {
- return GetRef<runtime::String>(
- static_cast<const runtime::StringObj*>(n)).operator std::string();
-});
+ .set_creator([](const std::string& bytes) {
+ return RefToObjectPtr::Get(runtime::String(bytes));
+ })
+ .set_repr_bytes([](const Object* n) -> std::string {
+ return GetRef<runtime::String>(static_cast<const runtime::StringObj*>(n))
+ .
+ operator std::string();
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const runtime::StringObj*>(node.get());
- p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
-});
-
+ .set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const runtime::StringObj*>(node.get());
+ p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
+ });
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
- static void SHashReduce(const runtime::ADTObj* key,
- SHashReducer hash_reduce) {
+ static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) {
hash_reduce(key->tag);
hash_reduce(static_cast<uint64_t>(key->size));
for (uint32_t i = 0; i < key->size; ++i) {
}
}
- static bool SEqualReduce(const runtime::ADTObj* lhs,
- const runtime::ADTObj* rhs,
+ static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->tag != rhs->tag) return false;
TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
-
struct NDArrayContainerTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
- static void SHashReduce(const runtime::NDArray::Container* key,
- SHashReducer hash_reduce) {
+ static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) {
CHECK_EQ(key->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
- CHECK(runtime::IsContiguous(key->dl_tensor))
- << "Can only hash contiguous tensor";
+ CHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor";
hash_reduce(runtime::DataType(key->dl_tensor.dtype));
hash_reduce(key->dl_tensor.ndim);
for (int i = 0; i < key->dl_tensor.ndim; ++i) {
hash_reduce(key->dl_tensor.shape[i]);
}
- hash_reduce->SHashReduceHashedValue(
- runtime::String::HashBytes(
- static_cast<const char*>(key->dl_tensor.data),
- runtime::GetDataSize(key->dl_tensor)));
+ hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(
+ static_cast<const char*>(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor)));
}
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
- const runtime::NDArray::Container* rhs,
- SEqualReducer equal) {
+ const runtime::NDArray::Container* rhs, SEqualReducer equal) {
if (lhs == rhs) return true;
auto ldt = lhs->dl_tensor.dtype;
auto rdt = rhs->dl_tensor.dtype;
CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
- CHECK(runtime::IsContiguous(lhs->dl_tensor))
- << "Can only compare contiguous tensor";
- CHECK(runtime::IsContiguous(rhs->dl_tensor))
- << "Can only compare contiguous tensor";
+ CHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor";
+ CHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor";
if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
-
struct ArrayNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
- static void SHashReduce(const ArrayNode* key,
- SHashReducer hash_reduce) {
+ static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) {
hash_reduce(static_cast<uint64_t>(key->data.size()));
for (size_t i = 0; i < key->data.size(); ++i) {
hash_reduce(key->data[i]);
}
}
- static bool SEqualReduce(const ArrayNode* lhs,
- const ArrayNode* rhs,
- SEqualReducer equal) {
+ static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) {
if (lhs->data.size() != rhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!equal(lhs->data[i], rhs->data[i])) return false;
TVM_REGISTER_OBJECT_TYPE(ArrayNode);
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
-.set_creator([](const std::string&) -> ObjectPtr<Object> {
- return ::tvm::runtime::make_object<ArrayNode>();
- });
-
-
-TVM_REGISTER_GLOBAL("node.Array")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- std::vector<ObjectRef> data;
- for (int i = 0; i < args.size(); ++i) {
- if (args[i].type_code() != kTVMNullptr) {
- data.push_back(args[i].operator ObjectRef());
- } else {
- data.push_back(ObjectRef(nullptr));
- }
+ .set_creator([](const std::string&) -> ObjectPtr<Object> {
+ return ::tvm::runtime::make_object<ArrayNode>();
+ });
+
+TVM_REGISTER_GLOBAL("node.Array").set_body([](TVMArgs args, TVMRetValue* ret) {
+ std::vector<ObjectRef> data;
+ for (int i = 0; i < args.size(); ++i) {
+ if (args[i].type_code() != kTVMNullptr) {
+ data.push_back(args[i].operator ObjectRef());
+ } else {
+ data.push_back(ObjectRef(nullptr));
}
- auto node = make_object<ArrayNode>();
- node->data = std::move(data);
- *ret = Array<ObjectRef>(node);
- });
+ }
+ auto node = make_object<ArrayNode>();
+ node->data = std::move(data);
+ *ret = Array<ObjectRef>(node);
+});
-TVM_REGISTER_GLOBAL("node.ArrayGetItem")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- int64_t i = args[1];
- CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
- Object* ptr = static_cast<Object*>(args[0].value().v_handle);
- CHECK(ptr->IsInstance<ArrayNode>());
- auto* n = static_cast<const ArrayNode*>(ptr);
- CHECK_LT(static_cast<size_t>(i), n->data.size())
- << "out of bound of array";
- *ret = n->data[static_cast<size_t>(i)];
- });
-
-TVM_REGISTER_GLOBAL("node.ArraySize")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
- Object* ptr = static_cast<Object*>(args[0].value().v_handle);
- CHECK(ptr->IsInstance<ArrayNode>());
- *ret = static_cast<int64_t>(
- static_cast<const ArrayNode*>(ptr)->data.size());
- });
+TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
+ int64_t i = args[1];
+ CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
+ Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+ CHECK(ptr->IsInstance<ArrayNode>());
+ auto* n = static_cast<const ArrayNode*>(ptr);
+ CHECK_LT(static_cast<size_t>(i), n->data.size()) << "out of bound of array";
+ *ret = n->data[static_cast<size_t>(i)];
+});
+TVM_REGISTER_GLOBAL("node.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
+ Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+ CHECK(ptr->IsInstance<ArrayNode>());
+ *ret = static_cast<int64_t>(static_cast<const ArrayNode*>(ptr)->data.size());
+});
struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
- static void SHashReduce(const MapNode* key,
- SHashReducer hash_reduce) {
+ static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) {
// SHash's var handling depends on the determinism of traversal.
// NOTE: only book-keep the mapped hash keys.
// This resolves common use cases where we want to store
}
}
// sort by the hash key of the keys.
- std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) {
- return lhs.first < rhs.first;
- });
+ std::sort(temp.begin(), temp.end(),
+ [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
// add size to the hash
hash_reduce(static_cast<uint64_t>(key->data.size()));
// hash the content
for (size_t i = 0; i < temp.size();) {
size_t k = i + 1;
- for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {}
+ for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {
+ }
// ties are rare, but we need to skip them to make the hash determinsitic
if (k == i + 1) {
hash_reduce->SHashReduceHashedValue(temp[i].first);
}
}
- static bool SEqualReduce(const MapNode* lhs,
- const MapNode* rhs,
- SEqualReducer equal) {
+ static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
TVM_REGISTER_OBJECT_TYPE(MapNode);
TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
-.set_creator([](const std::string&) -> ObjectPtr<Object> {
- return ::tvm::runtime::make_object<MapNode>();
- });
-
+ .set_creator([](const std::string&) -> ObjectPtr<Object> {
+ return ::tvm::runtime::make_object<MapNode>();
+ });
struct StrMapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
- static void SHashReduce(const StrMapNode* key,
- SHashReducer hash_reduce) {
+ static void SHashReduce(const StrMapNode* key, SHashReducer hash_reduce) {
// NOTE: only book-keep the mapped hash keys.
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
using KV = std::pair<std::string, ObjectRef>;
std::vector<KV> temp(key->data.begin(), key->data.end());
// sort by the hash key of the keys.
- std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) {
- return lhs.first < rhs.first;
- });
+ std::sort(temp.begin(), temp.end(),
+ [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
// NOTE: we won't have ties
// add size to the hash after sorting.
hash_reduce(static_cast<uint64_t>(key->data.size()));
}
}
- static bool SEqualReduce(const StrMapNode* lhs,
- const StrMapNode* rhs,
- SEqualReducer equal) {
+ static bool SEqualReduce(const StrMapNode* lhs, const StrMapNode* rhs, SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
TVM_REGISTER_OBJECT_TYPE(StrMapNode);
TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait)
-.set_creator([](const std::string&) -> ObjectPtr<Object> {
- return ::tvm::runtime::make_object<StrMapNode>();
- });
-
-
-TVM_REGISTER_GLOBAL("node.Map")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args.size() % 2, 0);
- if (args.size() != 0 && args[0].type_code() == kTVMStr) {
- // StrMap
- StrMapNode::ContainerType data;
- for (int i = 0; i < args.num_args; i += 2) {
- CHECK(args[i].type_code() == kTVMStr)
- << "key of str map need to be str";
- CHECK(args[i + 1].IsObjectRef<ObjectRef>())
- << "value of the map to be NodeRef";
- data.emplace(std::make_pair(args[i].operator std::string(),
- args[i + 1].operator ObjectRef()));
- }
- auto node = make_object<StrMapNode>();
- node->data = std::move(data);
- *ret = Map<ObjectRef, ObjectRef>(node);
- } else {
- // Container node.
- MapNode::ContainerType data;
- for (int i = 0; i < args.num_args; i += 2) {
- CHECK(args[i].IsObjectRef<ObjectRef>())
- << "key of str map need to be object";
- CHECK(args[i + 1].IsObjectRef<ObjectRef>())
- << "value of map to be NodeRef";
- data.emplace(std::make_pair(args[i].operator ObjectRef(),
- args[i + 1].operator ObjectRef()));
- }
- auto node = make_object<MapNode>();
- node->data = std::move(data);
- *ret = Map<ObjectRef, ObjectRef>(node);
+ .set_creator([](const std::string&) -> ObjectPtr<Object> {
+ return ::tvm::runtime::make_object<StrMapNode>();
+ });
+
+TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_EQ(args.size() % 2, 0);
+ if (args.size() != 0 && args[0].type_code() == kTVMStr) {
+ // StrMap
+ StrMapNode::ContainerType data;
+ for (int i = 0; i < args.num_args; i += 2) {
+ CHECK(args[i].type_code() == kTVMStr) << "key of str map need to be str";
+ CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of the map to be NodeRef";
+ data.emplace(
+ std::make_pair(args[i].operator std::string(), args[i + 1].operator ObjectRef()));
+ }
+ auto node = make_object<StrMapNode>();
+ node->data = std::move(data);
+ *ret = Map<ObjectRef, ObjectRef>(node);
+ } else {
+ // Container node.
+ MapNode::ContainerType data;
+ for (int i = 0; i < args.num_args; i += 2) {
+ CHECK(args[i].IsObjectRef<ObjectRef>()) << "key of str map need to be object";
+ CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of map to be NodeRef";
+ data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef()));
}
- });
+ auto node = make_object<MapNode>();
+ node->data = std::move(data);
+ *ret = Map<ObjectRef, ObjectRef>(node);
+ }
+});
+TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
+ Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+ if (ptr->IsInstance<MapNode>()) {
+ auto* n = static_cast<const MapNode*>(ptr);
+ *ret = static_cast<int64_t>(n->data.size());
+ } else {
+ CHECK(ptr->IsInstance<StrMapNode>());
+ auto* n = static_cast<const StrMapNode*>(ptr);
+ *ret = static_cast<int64_t>(n->data.size());
+ }
+});
-TVM_REGISTER_GLOBAL("node.MapSize")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
- Object* ptr = static_cast<Object*>(args[0].value().v_handle);
- if (ptr->IsInstance<MapNode>()) {
- auto* n = static_cast<const MapNode*>(ptr);
- *ret = static_cast<int64_t>(n->data.size());
- } else {
- CHECK(ptr->IsInstance<StrMapNode>());
- auto* n = static_cast<const StrMapNode*>(ptr);
- *ret = static_cast<int64_t>(n->data.size());
- }
- });
+TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
+ Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+
+ if (ptr->IsInstance<MapNode>()) {
+ auto* n = static_cast<const MapNode*>(ptr);
+ auto it = n->data.find(args[1].operator ObjectRef());
+ CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
+ *ret = (*it).second;
+ } else {
+ CHECK(ptr->IsInstance<StrMapNode>());
+ auto* n = static_cast<const StrMapNode*>(ptr);
+ auto it = n->data.find(args[1].operator std::string());
+ CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
+ *ret = (*it).second;
+ }
+});
-TVM_REGISTER_GLOBAL("node.MapGetItem")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
- Object* ptr = static_cast<Object*>(args[0].value().v_handle);
-
- if (ptr->IsInstance<MapNode>()) {
- auto* n = static_cast<const MapNode*>(ptr);
- auto it = n->data.find(args[1].operator ObjectRef());
- CHECK(it != n->data.end())
- << "cannot find the corresponding key in the Map";
- *ret = (*it).second;
- } else {
- CHECK(ptr->IsInstance<StrMapNode>());
- auto* n = static_cast<const StrMapNode*>(ptr);
- auto it = n->data.find(args[1].operator std::string());
- CHECK(it != n->data.end())
- << "cannot find the corresponding key in the Map";
- *ret = (*it).second;
- }
- });
+TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
+ Object* ptr = static_cast<Object*>(args[0].value().v_handle);
-TVM_REGISTER_GLOBAL("node.MapCount")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (ptr->IsInstance<MapNode>()) {
+ auto* n = static_cast<const MapNode*>(ptr);
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
- Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+ *ret = static_cast<int64_t>(n->data.count(args[1].operator ObjectRef()));
+ } else {
+ CHECK(ptr->IsInstance<StrMapNode>());
+ auto* n = static_cast<const StrMapNode*>(ptr);
+ *ret = static_cast<int64_t>(n->data.count(args[1].operator std::string()));
+ }
+});
- if (ptr->IsInstance<MapNode>()) {
- auto* n = static_cast<const MapNode*>(ptr);
- CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
- *ret = static_cast<int64_t>(
- n->data.count(args[1].operator ObjectRef()));
- } else {
- CHECK(ptr->IsInstance<StrMapNode>());
- auto* n = static_cast<const StrMapNode*>(ptr);
- *ret = static_cast<int64_t>(
- n->data.count(args[1].operator std::string()));
- }
- });
+TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
+ Object* ptr = static_cast<Object*>(args[0].value().v_handle);
-TVM_REGISTER_GLOBAL("node.MapItems")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
- Object* ptr = static_cast<Object*>(args[0].value().v_handle);
-
- if (ptr->IsInstance<MapNode>()) {
- auto* n = static_cast<const MapNode*>(ptr);
- auto rkvs = make_object<ArrayNode>();
- for (const auto& kv : n->data) {
- rkvs->data.push_back(kv.first);
- rkvs->data.push_back(kv.second);
- }
- *ret = Array<ObjectRef>(rkvs);
- } else {
- auto* n = static_cast<const StrMapNode*>(ptr);
- auto rkvs = make_object<ArrayNode>();
- for (const auto& kv : n->data) {
- rkvs->data.push_back(tir::StringImmNode::make(kv.first));
- rkvs->data.push_back(kv.second);
- }
- *ret = Array<ObjectRef>(rkvs);
+ if (ptr->IsInstance<MapNode>()) {
+ auto* n = static_cast<const MapNode*>(ptr);
+ auto rkvs = make_object<ArrayNode>();
+ for (const auto& kv : n->data) {
+ rkvs->data.push_back(kv.first);
+ rkvs->data.push_back(kv.second);
+ }
+ *ret = Array<ObjectRef>(rkvs);
+ } else {
+ auto* n = static_cast<const StrMapNode*>(ptr);
+ auto rkvs = make_object<ArrayNode>();
+ for (const auto& kv : n->data) {
+ rkvs->data.push_back(tir::StringImmNode::make(kv.first));
+ rkvs->data.push_back(kv.second);
}
- });
+ *ret = Array<ObjectRef>(rkvs);
+ }
+});
} // namespace tvm
* Reflection utilities.
* \file node/reflection.cc
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/node/node.h>
+#include <tvm/ir/attrs.h>
#include <tvm/node/container.h>
+#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
-#include <tvm/ir/attrs.h>
+#include <tvm/runtime/registry.h>
namespace tvm {
-using runtime::TVMRetValue;
-using runtime::TVMArgs;
using runtime::PackedFunc;
+using runtime::TVMArgs;
+using runtime::TVMRetValue;
// Attr getter.
class AttrGetter : public AttrVisitor {
const std::string& skey;
TVMRetValue* ret;
- AttrGetter(const std::string &skey,
- TVMRetValue* ret)
- : skey(skey), ret(ret) {}
+ AttrGetter(const std::string& skey, TVMRetValue* ret) : skey(skey), ret(ret) {}
bool found_ref_object{false};
}
};
-runtime::TVMRetValue ReflectionVTable::GetAttr(
- Object* self, const std::string& field_name) const {
+runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const std::string& field_name) const {
runtime::TVMRetValue ret;
AttrGetter getter(field_name, &ret);
}
}
if (!success) {
- LOG(FATAL) << "AttributeError: " << self->GetTypeKey()
- << " object has no attributed " << getter.skey;
+ LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attributed "
+ << getter.skey;
}
return ret;
}
public:
std::vector<std::string>* names;
- void Visit(const char* key, double* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, int64_t* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, uint64_t* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, bool* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, int* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, void** value) final {
- names->push_back(key);
- }
- void Visit(const char* key, DataType* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, std::string* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, runtime::NDArray* value) final {
- names->push_back(key);
- }
- void Visit(const char* key, runtime::ObjectRef* value) final {
- names->push_back(key);
- }
+ void Visit(const char* key, double* value) final { names->push_back(key); }
+ void Visit(const char* key, int64_t* value) final { names->push_back(key); }
+ void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
+ void Visit(const char* key, bool* value) final { names->push_back(key); }
+ void Visit(const char* key, int* value) final { names->push_back(key); }
+ void Visit(const char* key, void** value) final { names->push_back(key); }
+ void Visit(const char* key, DataType* value) final { names->push_back(key); }
+ void Visit(const char* key, std::string* value) final { names->push_back(key); }
+ void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); }
+ void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); }
};
-std::vector<std::string>
-ReflectionVTable::ListAttrNames(Object* self) const {
+std::vector<std::string> ReflectionVTable::ListAttrNames(Object* self) const {
std::vector<std::string> names;
AttrDir dir;
dir.names = &names;
return &inst;
}
-ObjectPtr<Object>
-ReflectionVTable::CreateInitObject(const std::string& type_key,
- const std::string& repr_bytes) const {
+ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& type_key,
+ const std::string& repr_bytes) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
- LOG(FATAL) << "TypeError: " << type_key
- << " is not registered via TVM_REGISTER_NODE_TYPE";
+ LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fcreate_[tindex](repr_bytes);
}
std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs;
- void Visit(const char* key, double* value) final {
- *value = GetAttr(key).operator double();
- }
- void Visit(const char* key, int64_t* value) final {
- *value = GetAttr(key).operator int64_t();
- }
- void Visit(const char* key, uint64_t* value) final {
- *value = GetAttr(key).operator uint64_t();
- }
- void Visit(const char* key, int* value) final {
- *value = GetAttr(key).operator int();
- }
- void Visit(const char* key, bool* value) final {
- *value = GetAttr(key).operator bool();
- }
+ void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); }
+ void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); }
+ void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); }
+ void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); }
+ void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); }
void Visit(const char* key, std::string* value) final {
*value = GetAttr(key).operator std::string();
}
- void Visit(const char* key, void** value) final {
- *value = GetAttr(key).operator void*();
- }
- void Visit(const char* key, DataType* value) final {
- *value = GetAttr(key).operator DataType();
- }
+ void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); }
+ void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); }
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
setter.type_key = n->GetTypeKey();
CHECK_EQ(args.size() % 2, 0);
for (int i = 0; i < args.size(); i += 2) {
- setter.attrs.emplace(args[i].operator std::string(),
- args[i + 1]);
+ setter.attrs.emplace(args[i].operator std::string(), args[i + 1]);
}
auto* reflection = ReflectionVTable::Global();
reflection->VisitAttrs(n, &setter);
if (setter.attrs.size() != 0) {
std::ostringstream os;
os << setter.type_key << " does not contain field ";
- for (const auto &kv : setter.attrs) {
+ for (const auto& kv : setter.attrs) {
os << " " << kv.first;
}
LOG(FATAL) << os.str();
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* self = static_cast<Object*>(args[0].value().v_handle);
- auto names = std::make_shared<std::vector<std::string> >(
- ReflectionVTable::Global()->ListAttrNames(self));
+ auto names =
+ std::make_shared<std::vector<std::string> >(ReflectionVTable::Global()->ListAttrNames(self));
- *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) {
- int64_t i = args[0];
- if (i == -1) {
- *rv = static_cast<int64_t>(names->size());
- } else {
- *rv = (*names)[i];
- }
- });
+ *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) {
+ int64_t i = args[0];
+ if (i == -1) {
+ *rv = static_cast<int64_t>(names->size());
+ } else {
+ *rv = (*names)[i];
+ }
+ });
}
// API function to make node.
*rv = ObjectRef(n);
}
+TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
-TVM_REGISTER_GLOBAL("node.NodeGetAttr")
-.set_body(NodeGetAttr);
-
-TVM_REGISTER_GLOBAL("node.NodeListAttrNames")
-.set_body(NodeListAttrNames);
+TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);
-TVM_REGISTER_GLOBAL("node.MakeNode")
-.set_body(MakeNode);
+TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);
} // namespace tvm
* Printer utilities
* \file node/repr_printer.cc
*/
-#include <tvm/runtime/registry.h>
#include <tvm/node/repr_printer.h>
+#include <tvm/runtime/registry.h>
namespace tvm {
return inst;
}
-void Dump(const runtime::ObjectRef& n) {
- std::cerr << n << "\n";
-}
+void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; }
-void Dump(const runtime::Object* n) {
- Dump(runtime::GetRef<runtime::ObjectRef>(n));
-}
+void Dump(const runtime::Object* n) { Dump(runtime::GetRef<runtime::ObjectRef>(n)); }
-TVM_REGISTER_GLOBAL("node.AsRepr")
-.set_body_typed([](runtime::ObjectRef obj) {
+TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](runtime::ObjectRef obj) {
std::ostringstream os;
os << obj;
return os.str();
*/
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/packed_func.h>
+#include <tvm/ir/attrs.h>
#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/node/serialization.h>
-#include <tvm/ir/attrs.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
-#include <string>
#include <cctype>
#include <map>
+#include <string>
#include "../support/base64.h"
namespace tvm {
-inline std::string Type2String(const DataType& t) {
- return runtime::DLDataType2String(t);
-}
+inline std::string Type2String(const DataType& t) { return runtime::DLDataType2String(t); }
-inline DataType String2Type(std::string s) {
- return DataType(runtime::String2DLDataType(s));
-}
+inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); }
inline std::string Base64Decode(std::string s) {
dmlc::MemoryStringStream mstrm(&s);
/*! \brief values of a map or array. */
std::vector<size_t> data;
- void Save(dmlc::JSONWriter *writer) const {
+ void Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key);
if (repr_bytes.size() != 0) {
writer->EndObject();
}
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
attrs.clear();
data.clear();
repr_bytes.clear();
s << (*value);
node_->attrs[key] = s.str();
}
- void Visit(const char* key, int64_t* value) final {
- node_->attrs[key] = std::to_string(*value);
- }
- void Visit(const char* key, uint64_t* value) final {
- node_->attrs[key] = std::to_string(*value);
- }
- void Visit(const char* key, int* value) final {
- node_->attrs[key] = std::to_string(*value);
- }
- void Visit(const char* key, bool* value) final {
- node_->attrs[key] = std::to_string(*value);
- }
- void Visit(const char* key, std::string* value) final {
- node_->attrs[key] = *value;
- }
+ void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); }
+ void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); }
+ void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); }
+ void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); }
+ void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; }
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to serialize a pointer";
}
- void Visit(const char* key, DataType* value) final {
- node_->attrs[key] = Type2String(*value);
- }
+ void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); }
void Visit(const char* key, runtime::NDArray* value) final {
- node_->attrs[key] = std::to_string(
- tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
+ node_->attrs[key] =
+ std::to_string(tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
void Visit(const char* key, ObjectRef* value) final {
- node_->attrs[key] = std::to_string(
- node_index_->at(const_cast<Object*>(value->get())));
+ node_->attrs[key] = std::to_string(node_index_->at(const_cast<Object*>(value->get())));
}
// Get the node
if (node->IsInstance<ArrayNode>()) {
ArrayNode* n = static_cast<ArrayNode*>(node);
for (size_t i = 0; i < n->data.size(); ++i) {
- node_->data.push_back(
- node_index_->at(const_cast<Object*>(n->data[i].get())));
+ node_->data.push_back(node_index_->at(const_cast<Object*>(n->data[i].get())));
}
} else if (node->IsInstance<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
- node_->data.push_back(
- node_index_->at(const_cast<Object*>(kv.first.get())));
- node_->data.push_back(
- node_index_->at(const_cast<Object*>(kv.second.get())));
+ node_->data.push_back(node_index_->at(const_cast<Object*>(kv.first.get())));
+ node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else if (node->IsInstance<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
for (const auto& kv : n->data) {
node_->keys.push_back(kv.first);
- node_->data.push_back(
- node_index_->at(const_cast<Object*>(kv.second.get())));
+ node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else {
// recursively index normal object.
}
return it->second;
}
- template<typename T>
+ template <typename T>
void ParseValue(const char* key, T* value) const {
std::istringstream is(GetValue(key));
is >> *value;
LOG(FATAL) << "Wrong value format for field " << key;
}
}
- void Visit(const char* key, double* value) final {
- ParseValue(key, value);
- }
- void Visit(const char* key, int64_t* value) final {
- ParseValue(key, value);
- }
- void Visit(const char* key, uint64_t* value) final {
- ParseValue(key, value);
- }
- void Visit(const char* key, int* value) final {
- ParseValue(key, value);
- }
- void Visit(const char* key, bool* value) final {
- ParseValue(key, value);
- }
- void Visit(const char* key, std::string* value) final {
- *value = GetValue(key);
- }
+ void Visit(const char* key, double* value) final { ParseValue(key, value); }
+ void Visit(const char* key, int64_t* value) final { ParseValue(key, value); }
+ void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); }
+ void Visit(const char* key, int* value) final { ParseValue(key, value); }
+ void Visit(const char* key, bool* value) final { ParseValue(key, value); }
+ void Visit(const char* key, std::string* value) final { *value = GetValue(key); }
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to deserialize a pointer";
}
MapNode* n = static_cast<MapNode*>(node);
CHECK_EQ(node_->data.size() % 2, 0U);
for (size_t i = 0; i < node_->data.size(); i += 2) {
- n->data[ObjectRef(node_list_->at(node_->data[i]))]
- = ObjectRef(node_list_->at(node_->data[i + 1]));
+ n->data[ObjectRef(node_list_->at(node_->data[i]))] =
+ ObjectRef(node_list_->at(node_->data[i + 1]));
}
} else if (node->IsInstance<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
CHECK_EQ(node_->data.size(), node_->keys.size());
for (size_t i = 0; i < node_->data.size(); ++i) {
- n->data[node_->keys[i]]
- = ObjectRef(node_list_->at(node_->data[i]));
+ n->data[node_->keys[i]] = ObjectRef(node_list_->at(node_->data[i]));
}
} else {
reflection_->VisitAttrs(node, this);
// global attributes
AttrMap attrs;
- void Save(dmlc::JSONWriter *writer) const {
+ void Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("root", root);
writer->WriteObjectKeyValue("nodes", nodes);
writer->EndObject();
}
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
attrs.clear();
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("root", &root);
for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) {
- ObjectPtr<Object> node =
- reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
+ ObjectPtr<Object> node = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
nodes.emplace_back(node);
} else {
nodes.emplace_back(ObjectPtr<Object>());
// Skip the nodes that has an repr bytes representation.
// NOTE: the second condition is used to guard the case
// where the repr bytes itself is an empty string "".
- if (setter.node_->repr_bytes.length() == 0 &&
- nodes[i] != nullptr &&
+ if (setter.node_->repr_bytes.length() == 0 && nodes[i] != nullptr &&
!reflection->GetReprBytes(nodes[i].get(), nullptr)) {
setter.Set(nodes[i].get());
}
return ObjectRef(nodes.at(jgraph.root));
}
-TVM_REGISTER_GLOBAL("node.SaveJSON")
-.set_body_typed(SaveJSON);
+TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON);
-TVM_REGISTER_GLOBAL("node.LoadJSON")
-.set_body_typed(LoadJSON);
+TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON);
} // namespace tvm
/*!
* \file src/node/structural_equal.cc
*/
-#include <tvm/node/structural_equal.h>
-#include <tvm/node/reflection.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
+#include <tvm/node/reflection.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
namespace tvm {
// Define the dispatch functio here since primary user is in this file.
-bool ReflectionVTable::
-SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const {
+bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
+ SEqualReducer equal) const {
uint32_t tindex = self->type_index();
if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
- << " is not registered via TVM_REGISTER_NODE_TYPE."
- << " Did you forget to set _type_has_method_sequal_reduce=true?";
+ << " is not registered via TVM_REGISTER_NODE_TYPE."
+ << " Did you forget to set _type_has_method_sequal_reduce=true?";
}
return fsequal_reduce_[tindex](self, other, equal);
}
* The order of SEqual being called is the same as the order as if we
* eagerly do recursive calls in SEqualReduce.
*/
-class RemapVarSEqualHandler :
- public SEqualReducer::Handler {
+class RemapVarSEqualHandler : public SEqualReducer::Handler {
public:
- explicit RemapVarSEqualHandler(bool assert_mode)
- : assert_mode_(assert_mode) {}
+ explicit RemapVarSEqualHandler(bool assert_mode) : assert_mode_(assert_mode) {}
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
// We cannot use check lhs.same_as(rhs) to check equality.
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_ && !result) {
- LOG(FATAL)
- << "ValueError: StructuralEqual check failed, caused by\n"
- << "lhs = " << lhs << "\nrhs = " << rhs;
+ LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n"
+ << "lhs = " << lhs << "\nrhs = " << rhs;
}
return result;
}
// The default equal as registered in the structural equal vtable.
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
auto compute = [=]() {
- CHECK(lhs.defined() &&
- rhs.defined() &&
- lhs->type_index() == rhs->type_index());
+ CHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index());
// skip entries that already have equality maps.
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
};
TVM_REGISTER_GLOBAL("node.StructuralEqual")
-.set_body_typed([](const ObjectRef& lhs,
- const ObjectRef& rhs,
- bool assert_mode,
- bool map_free_vars) {
- return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
-});
-
-bool StructuralEqual::operator()(const ObjectRef& lhs,
- const ObjectRef& rhs) const {
+ .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
+ bool map_free_vars) {
+ return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
+ });
+
+bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
}
/*!
* \file src/node/structural_hash.cc
*/
-#include <tvm/node/structural_hash.h>
-#include <tvm/node/reflection.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
+#include <tvm/node/reflection.h>
+#include <tvm/node/structural_hash.h>
#include <tvm/runtime/registry.h>
-#include <unordered_map>
#include <algorithm>
-
+#include <unordered_map>
namespace tvm {
// Define the dispatch function here since the primary user is in this file.
-void ReflectionVTable::
-SHashReduce(const Object* self, SHashReducer reducer) const {
+void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const {
uint32_t tindex = self->type_index();
if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey()
- << " is not registered via TVM_REGISTER_NODE_TYPE";
+ << " is not registered via TVM_REGISTER_NODE_TYPE";
}
fshash_reduce_[tindex](self, reducer);
}
// In particular, when we traverse unordered_map, we should first sort
// the entries by keys(or hash of keys) before traversing.
-class VarCountingSHashHandler :
- public SHashReducer::Handler {
+class VarCountingSHashHandler : public SHashReducer::Handler {
public:
/*! \brief Pending reduce tasks. */
struct Task {
: object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {}
};
-
VarCountingSHashHandler() {}
void MarkGraphNode() final {
}
void SHashReduceHashedValue(size_t hashed_value) final {
- pending_tasks_.emplace_back(
- Task(ObjectRef(nullptr), hashed_value, false));
+ pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false));
}
void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) final {
if (map_free_vars) {
// use counter value.
size_t value = std::hash<size_t>()(free_var_counter_++);
- pending_tasks_.emplace_back(
- Task(ObjectRef(nullptr), value, false));
+ pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false));
} else {
// use pointer hash
size_t value = std::hash<const runtime::Object*>()(var);
- pending_tasks_.emplace_back(
- Task(ObjectRef(nullptr), value, false));
+ pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false));
}
}
}
auto it = hash_memo_.find(object);
if (it != hash_memo_.end()) {
- pending_tasks_.emplace_back(
- Task(ObjectRef(nullptr), it->second, false));
+ pending_tasks_.emplace_back(Task(ObjectRef(nullptr), it->second, false));
} else {
// Push a pending task with initial value.
- pending_tasks_.emplace_back(
- Task(object, object->GetTypeKeyHash(), map_free_vars));
+ pending_tasks_.emplace_back(Task(object, object->GetTypeKeyHash(), map_free_vars));
}
}
// Append the graph node counter to the hash
// so that we can distinguish DAG from trees.
if (entry.graph_node_hash) {
- entry.reduced_hash = HashCombine(
- entry.reduced_hash,
- std::hash<size_t>()(graph_node_counter_++));
+ entry.reduced_hash =
+ HashCombine(entry.reduced_hash, std::hash<size_t>()(graph_node_counter_++));
}
hash_memo_[entry.object] = entry.reduced_hash;
}
std::unordered_map<ObjectRef, size_t, ObjectHash, ObjectEqual> hash_memo_;
};
-
TVM_REGISTER_GLOBAL("node.StructuralHash")
-.set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t {
- size_t hashed_value =
- VarCountingSHashHandler().Hash(object, map_free_vars);
- return static_cast<int64_t>(hashed_value);
-});
+ .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t {
+ size_t hashed_value = VarCountingSHashHandler().Hash(object, map_free_vars);
+ return static_cast<int64_t>(hashed_value);
+ });
size_t StructuralHash::operator()(const ObjectRef& object) const {
return VarCountingSHashHandler().Hash(object, false);
*
* Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
*/
+#include "doc.h"
+
#include <tvm/runtime/packed_func.h>
-#include <vector>
+
#include <sstream>
-#include "doc.h"
+#include <vector>
namespace tvm {
/*! \brief The str content in the text. */
std::string str;
- explicit DocTextNode(std::string str_val)
- : str(str_val) {
- }
+ explicit DocTextNode(std::string str_val) : str(str_val) {}
static constexpr const char* _type_key = "printer.DocText";
TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode);
/*! \brief The amount of indent in newline. */
int indent;
- explicit DocLineNode(int indent)
- : indent(indent) {}
+ explicit DocLineNode(int indent) : indent(indent) {}
static constexpr const char* _type_key = "printer.DocLine";
TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode);
class DocLine : public DocAtom {
public:
- explicit DocLine(int indent) {
- data_ = runtime::make_object<DocLineNode>(indent);
- }
+ explicit DocLine(int indent) { data_ = runtime::make_object<DocLineNode>(indent); }
TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode);
};
// DSL function implementations
Doc& Doc::operator<<(const Doc& right) {
CHECK(this != &right);
- this->stream_.insert(
- this->stream_.end(), right.stream_.begin(), right.stream_.end());
+ this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end());
return *this;
}
-Doc& Doc::operator<<(std::string right) {
- return *this << DocText(right);
-}
+Doc& Doc::operator<<(std::string right) { return *this << DocText(right); }
Doc& Doc::operator<<(const DocAtom& right) {
this->stream_.push_back(right);
return os.str();
}
-Doc Doc::NewLine(int indent) {
- return Doc() << DocLine(indent);
-}
+Doc Doc::NewLine(int indent) { return Doc() << DocLine(indent); }
-Doc Doc::Text(std::string text) {
- return Doc() << DocText(text);
-}
+Doc Doc::Text(std::string text) { return Doc() << DocText(text); }
Doc Doc::RawText(std::string text) {
return Doc() << DocAtom(runtime::make_object<DocTextNode>(text));
}
}
-Doc Doc::Brace(std::string open,
- const Doc& body,
- std::string close,
- int indent) {
+Doc Doc::Brace(std::string open, const Doc& body, std::string close, int indent) {
Doc doc;
doc << open;
doc << Indent(indent, NewLine() << body) << NewLine();
#ifndef TVM_PRINTER_DOC_H_
#define TVM_PRINTER_DOC_H_
+#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
-#include <tvm/node/node.h>
+
#include <string>
-#include <vector>
#include <type_traits>
+#include <vector>
namespace tvm {
/*!
* \brief Managed reference to DocAtomNode.
* \sa DocAtomNode.
-*/
+ */
class DocAtom : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode);
* \tparam T the type of the value.
* \return reference to self.
*/
- template<typename T,
- typename = typename std::enable_if<!std::is_class<T>::value>::type>
+ template <typename T, typename = typename std::enable_if<!std::is_class<T>::value>::type>
Doc& operator<<(const T& value) {
std::ostringstream os;
os << value;
* \param indent amount of indentation.
* \return The created doc.
*/
- static Doc Brace(std::string open,
- const Doc& body,
- std::string close,
- int indent = 2);
+ static Doc Brace(std::string open, const Doc& body, std::string close, int indent = 2);
/*!
* \brief Create a doc by concatenating together with separator.
* \param vec The docs to be concatenated.
#ifndef TVM_PRINTER_META_DATA_H_
#define TVM_PRINTER_META_DATA_H_
-#include <tvm/node/serialization.h>
#include <tvm/node/container.h>
+#include <tvm/node/serialization.h>
+
#include <string>
#include <unordered_map>
+
#include "doc.h"
namespace tvm {
}
std::string type_key = node->GetTypeKey();
CHECK(!type_key.empty());
- Array<ObjectRef>& mvector =
- meta_data_[type_key];
+ Array<ObjectRef>& mvector = meta_data_[type_key];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
Doc doc;
* \param node The query node
* \return whether the node has been put in meta
*/
- bool InMeta(const ObjectRef& node) {
- return meta_repr_.find(node) != meta_repr_.end();
- }
+ bool InMeta(const ObjectRef& node) { return meta_repr_.find(node) != meta_repr_.end(); }
/*!
* \brief Print a key value pair
}
/*! \return whether the meta data context is empty. */
- bool empty() const {
- return meta_data_.empty();
- }
+ bool empty() const { return meta_data_.empty(); }
private:
/*! \brief additional metadata stored in TVM json format */
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
-#include <tvm/ir/type_functor.h>
#include <tvm/ir/module.h>
-#include <tvm/tir/function.h>
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
+#include <tvm/tir/function.h>
+
+#include "../ir/attr_functor.h"
+#include "../relay/analysis/dependency_graph.h"
#include "doc.h"
#include "meta_data.h"
-#include "../relay/analysis/dependency_graph.h"
-#include "../ir/attr_functor.h"
#include "text_printer.h"
namespace tvm {
namespace relay {
/*!
- * \brief Print additional info about expr in comment.
- * \param expr The expression.
- */
+ * \brief Print additional info about expr in comment.
+ * \param expr The expression.
+ */
Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
Doc doc;
// default annotations
}
Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
- if (node->IsInstance<BaseFuncNode>() &&
- !node->IsInstance<relay::FunctionNode>()) {
+ if (node->IsInstance<BaseFuncNode>() && !node->IsInstance<relay::FunctionNode>()) {
// Temporarily skip non-relay functions.
// TODO(tvm-team) enhance the code to work for all functions
} else if (node.as<ExprNode>()) {
Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
bool is_non_relay_func =
- node->IsInstance<BaseFuncNode>() &&
- !node->IsInstance<relay::FunctionNode>();
+ node->IsInstance<BaseFuncNode>() && !node->IsInstance<relay::FunctionNode>();
if (node.as<ExprNode>() && !is_non_relay_func) {
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as<TypeNode>()) {
return doc << "%" << n;
}
-Doc RelayTextPrinter::AllocTemp() {
- return TempVar(temp_var_counter_++);
-}
+Doc RelayTextPrinter::AllocTemp() { return TempVar(temp_var_counter_++); }
/*!
- * \brief get a unique name with the corresponding prefix
- * \param prefix The prefix of the name
- * \return The returned name.
- */
+ * \brief get a unique name with the corresponding prefix
+ * \param prefix The prefix of the name
+ * \return The returned name.
+ */
Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) {
std::string unique_prefix = prefix;
auto it = name_alloc_map_.find(prefix);
Doc RelayTextPrinter::Print(Kind k) {
switch (k) {
- case kType:
- return Doc::Text("Type");
- case kShapeVar:
- return Doc::Text("Shape");
- case kBaseType:
- return Doc::Text("BaseType");
- case kConstraint:
- return Doc::Text("Constraint");
- case kAdtHandle:
- return Doc::Text("AdtHandle");
- case kTypeData:
- return Doc::Text("TypeData");
- default:
- LOG(ERROR) << "Unknown Kind";
- throw;
+ case kType:
+ return Doc::Text("Type");
+ case kShapeVar:
+ return Doc::Text("Shape");
+ case kBaseType:
+ return Doc::Text("BaseType");
+ case kConstraint:
+ return Doc::Text("Constraint");
+ case kAdtHandle:
+ return Doc::Text("AdtHandle");
+ case kTypeData:
+ return Doc::Text("TypeData");
+ default:
+ LOG(ERROR) << "Unknown Kind";
+ throw;
}
}
/*!
// Should only be triggered when op is a free variable being visited for the
// first time.
-Doc RelayTextPrinter::VisitExpr_(const VarNode* op) {
- return AllocVar(GetRef<Var>(op));
-}
+Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef<Var>(op)); }
/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param value The value to be printed.
*/
-template<typename T>
+template <typename T>
Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) {
std::ostringstream os;
if (dtype == DataType::Int(32)) {
Doc RelayTextPrinter::VisitExpr_(const LetNode* op) {
Doc doc;
- doc
- << "let "
- << AllocVar(op->var)
- << " = "
- << Print(op->value, false, true)
- << ";"
- << Doc::NewLine();
+ doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << ";"
+ << Doc::NewLine();
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body);
} else {
// def @xyz = meta['ExternalFunc'][id]
Doc doc;
- doc << prefix << " = " << meta_->GetMetaNode(base_func);
+ doc << prefix << " = " << meta_->GetMetaNode(base_func);
return doc;
}
}
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}
-Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
- return Doc::Text('@' + op->name_hint);
-}
+Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text('@' + op->name_hint); }
-Doc RelayTextPrinter::VisitExpr_(const OpNode* op) {
- return Doc::Text(op->name);
-}
+Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }
Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
Doc doc;
return doc;
}
-Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) {
- return Doc::Text("_");
-}
+Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { return Doc::Text("_"); }
-Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) {
- return AllocVar(pv->var);
-}
+Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { return AllocVar(pv->var); }
Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) {
Doc doc;
return Print(GetRef<ObjectRef>(node), true);
}
-Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) {
- return Doc::Text(node->name_hint);
-}
+Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { return Doc::Text(node->name_hint); }
Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) {
return Doc::Text(node->name_hint);
return Doc::StrLiteral(op->value);
}
-
/*!
* \brief Attribute printer which prints the attributes in the call.
*/
-class RelayTextPrinter::AttrPrinter :
- public AttrVisitor {
+class RelayTextPrinter::AttrPrinter : public AttrVisitor {
public:
- AttrPrinter(std::vector<Doc>* doc, RelayTextPrinter* parent)
- : docs(doc), parent_(parent) {}
+ AttrPrinter(std::vector<Doc>* doc, RelayTextPrinter* parent) : docs(doc), parent_(parent) {}
- template<typename T>
+ template <typename T>
void PrintKV(const char* key, const T& value) {
Doc doc;
doc << key << "=" << value;
doc << key << "=" << *value << "f";
docs->push_back(doc);
}
- void Visit(const char* key, int64_t* value) final {
- PrintKV(key, *value);
- }
- void Visit(const char* key, uint64_t* value) final {
- PrintKV(key, *value);
- }
- void Visit(const char* key, int* value) final {
- PrintKV(key, *value);
- }
- void Visit(const char* key, bool* value) final {
- PrintKV(key, Doc::PyBoolLiteral(*value));
- }
- void Visit(const char* key, std::string* value) final {
- PrintKV(key, Doc::StrLiteral(*value));
- }
- void Visit(const char* key, void** value) final {
- LOG(FATAL) << "do not allow void as argument";
- }
+ void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); }
+ void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); }
+ void Visit(const char* key, int* value) final { PrintKV(key, *value); }
+ void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); }
+ void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); }
+ void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; }
void Visit(const char* key, DataType* value) final {
PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value)));
}
RelayTextPrinter* parent_;
};
-std::vector<Doc> RelayTextPrinter::PrintCallAttrs(
- const Attrs& attrs, const Expr& op) {
+std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* op_node = op.as<OpNode>();
* that can be parsed by a parser.
*/
+#include "text_printer.h"
+
#include <tvm/tir/function.h>
+
#include <string>
-#include "text_printer.h"
namespace tvm {
return doc.str();
}
-String AsText(const ObjectRef& node,
- bool show_meta_data,
+String AsText(const ObjectRef& node, bool show_meta_data,
runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
Doc doc;
doc << kSemVer << Doc::NewLine();
runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
if (annotate != nullptr) {
ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
- [&annotate](const ObjectRef& expr) -> std::string {
- return annotate(expr);
- });
+ [&annotate](const ObjectRef& expr) -> std::string { return annotate(expr); });
}
doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node);
return doc.str();
}
-TVM_REGISTER_GLOBAL("ir.PrettyPrint")
-.set_body_typed(PrettyPrint);
+TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint);
-TVM_REGISTER_GLOBAL("ir.AsText")
-.set_body_typed(AsText);
+TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText);
} // namespace tvm
#ifndef TVM_PRINTER_TEXT_PRINTER_H_
#define TVM_PRINTER_TEXT_PRINTER_H_
+#include <tvm/ir/module.h>
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/tir/expr_functor.h>
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/op.h>
-#include <tvm/ir/type_functor.h>
-#include <tvm/ir/module.h>
#include <tvm/tir/function.h>
-#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pattern_functor.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <string>
#include <unordered_map>
#include <vector>
-#include <string>
-#include "../relay/analysis/dependency_graph.h"
-#include "../ir/attr_functor.h"
+#include "../ir/attr_functor.h"
+#include "../relay/analysis/dependency_graph.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
namespace tvm {
namespace relay {
-class RelayTextPrinter :
- public ExprFunctor<Doc(const Expr&)>,
- public PatternFunctor<Doc(const Pattern&)>,
- public TypeFunctor<Doc(const Type&)>,
- public AttrFunctor<Doc(const ObjectRef&)> {
+class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
+ public PatternFunctor<Doc(const Pattern&)>,
+ public TypeFunctor<Doc(const Type&)>,
+ public AttrFunctor<Doc(const ObjectRef&)> {
public:
- explicit RelayTextPrinter(bool show_meta_data,
- TextMetaDataContext* meta,
+ explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
: show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
/*!
- * \brief Print additional info about expr in comment.
- * \param expr The expression.
- */
+ * \brief Print additional info about expr in comment.
+ * \param expr The expression.
+ */
Doc PrintOptionalInfo(const Expr& expr);
// indent a new body
Doc PrintBody(const ObjectRef& node, int indent = 2);
Doc TempVar(int n);
Doc AllocTemp();
/*!
- * \brief get a unique name with the corresponding prefix
- * \param prefix The prefix of the name
- * \return The returned name.
- */
+ * \brief get a unique name with the corresponding prefix
+ * \param prefix The prefix of the name
+ * \return The returned name.
+ */
Doc GetUniqueName(const std::string& prefix);
Doc Print(Kind k);
/*!
void Collect(const ObjectRef& n) {
// these nodes can be print directly(StringLiteral or use identifier to identify)
- if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
- || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+ if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>() ||
+ n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
return;
}
if (n->IsInstance<StmtNode>()) {
public TypeFunctor<Doc(const Type&)> {
public:
explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
- : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}
+ : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}
/*! \brief Print the node */
Doc Print(const ObjectRef& node);
Doc PrintIterVar(const IterVarNode* op);
Doc PrintRange(const RangeNode* op);
Doc PrintBuffer(const BufferNode* op);
- Doc PrintString(const StringObj* op) {
- return Doc::StrLiteral(op->data);
- }
+ Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
/*!
* \brief special method to print out data type
public:
explicit TextPrinter(bool show_meta_data,
const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate)
- : show_meta_data_(show_meta_data), annotate_(annotate),
+ : show_meta_data_(show_meta_data),
+ annotate_(annotate),
relay_text_printer_(show_meta_data, &meta_, annotate),
tir_text_printer_(show_meta_data, &meta_) {}
Doc doc;
if (node->IsInstance<IRModuleNode>()) {
doc << PrintMod(Downcast<IRModule>(node));
- } else if (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>()
- || node->IsInstance<tir::StmtNode>()) {
+ } else if (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
+ node->IsInstance<tir::StmtNode>()) {
doc << tir_text_printer_.Print(node);
} else {
doc << relay_text_printer_.PrintFinal(node);
}
// print PrimFunc
Doc doc;
- doc << "primfn" << "(";
+ doc << "primfn"
+ << "(";
// print params and its type annotation
std::vector<Doc> params;
for (const auto& param : op->params) {
std::vector<Doc> buffer_docs;
for (const auto& it : memo_buf_) {
const auto& buf = it.first;
- buffer_docs.push_back(Print(buf)
- << Doc::Text(": Buffer(") << Print(buf->data) << ", "
- << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
- << Print(buf->strides));
+ buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+ << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
+ << Print(buf->strides));
if (!is_zero(buf->elem_offset)) {
buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
}
for (const auto& it : op->buffer_map) {
buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
}
- doc << Doc::Indent(2, Doc::NewLine()
- << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+ doc << Doc::Indent(
+ 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
doc << PrintBody(op->body);
return doc;
}
return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
}
-#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \
- Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \
- Doc doc; \
- doc << "(" << Print(op->a) << OpString; \
- doc << Print(op->b) << ")"; \
- return doc; \
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \
+ Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \
+ Doc doc; \
+ doc << "(" << Print(op->a) << OpString; \
+ doc << Print(op->b) << ")"; \
+ return doc; \
}
TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
Doc doc;
- doc << "(" << PrintDType(op->dtype) << "*)"
- << Print(op->buffer_var) << "[" << Print(op->index) << "])";
+ doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index)
+ << "])";
if (!is_one(op->predicate)) {
doc << " if " << Print(op->predicate);
}
inline const char* CallType2String(CallNode::CallType t) {
switch (t) {
- case CallNode::Extern:return "extern";
- case CallNode::ExternCPlusPlus:return "extern_cpp";
- case CallNode::PureExtern:return "pure_extern";
- case CallNode::Halide:return "halide";
- case CallNode::Intrinsic:return "intrin";
- case CallNode::PureIntrinsic:return "pure_intrin";
+ case CallNode::Extern:
+ return "extern";
+ case CallNode::ExternCPlusPlus:
+ return "extern_cpp";
+ case CallNode::PureExtern:
+ return "pure_extern";
+ case CallNode::Halide:
+ return "halide";
+ case CallNode::Intrinsic:
+ return "intrin";
+ case CallNode::PureIntrinsic:
+ return "pure_intrin";
}
LOG(FATAL) << "Unknown CallType";
return "Unknown";
for (const auto& arg : op->args) {
args.push_back(Print(arg));
}
- doc << PrintSep(args, Doc::Text(", "))
- << ", dtype=" << PrintDType(op->dtype)
+ doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype)
<< ", type=" << Doc::StrLiteral(CallType2String(op->call_type))
<< ", index=" << op->value_index << ")";
return doc;
inline const char* ForType2String(ForType t) {
switch (t) {
- case ForType::Serial:return "serial";
- case ForType::Parallel:return "parallel";
- case ForType::Vectorized:return "vectorized";
- case ForType::Unrolled:return "unroll";
+ case ForType::Serial:
+ return "serial";
+ case ForType::Parallel:
+ return "parallel";
+ case ForType::Vectorized:
+ return "vectorized";
+ case ForType::Unrolled:
+ return "unroll";
}
LOG(FATAL) << "Unknown ForType";
return "Unknown";
}
doc << Doc::Text(os.str());
switch (dtype.code()) {
- case kDLInt: doc << "i"; break;
- case kDLUInt: doc << "u"; break;
- case kDLFloat: doc << "f"; break;
+ case kDLInt:
+ doc << "i";
+ break;
+ case kDLUInt:
+ doc << "u";
+ break;
+ case kDLFloat:
+ doc << "f";
+ break;
}
doc << Doc::Text(std::to_string(dtype.bits()));
if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes()));
std::string unique_prefix = prefix;
auto it = name_alloc_map_.find(prefix);
if (it != name_alloc_map_.end()) {
- while (name_alloc_map_.count(
- unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {}
+ while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {
+ }
}
name_alloc_map_[unique_prefix] = 0;
return Doc::Text(unique_prefix);
#include "annotated_region_set.h"
-#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
+#include <tvm/relay/expr.h>
#include <tvm/runtime/container.h>
#include <unordered_map>
#include <vector>
-
namespace tvm {
namespace relay {
return AnnotatedRegion(nullptr);
}
-void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
- AnnotatedRegion dest) {
+void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, AnnotatedRegion dest) {
if (dest == src) {
return;
}
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
- continue;
+ continue;
}
region = region_set_->GetRegion(arg);
if (region.defined()) {
- break;
+ break;
}
}
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
- continue;
+ continue;
}
auto arg_region = region_set_->GetRegion(arg);
}
}
- void VisitExpr_(const TupleNode* op) {
- AddToArgRegion(GetRef<Tuple>(op), op->fields);
- }
+ void VisitExpr_(const TupleNode* op) { AddToArgRegion(GetRef<Tuple>(op), op->fields); }
void VisitExpr_(const TupleGetItemNode* g) {
Array<Expr> args = {g->tuple};
TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode);
TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet")
-.set_body_typed([](Expr expr, Op begin, Op end) {
- return AnnotatedRegionSet::Create(expr, begin, end);
-});
+ .set_body_typed([](Expr expr, Op begin, Op end) {
+ return AnnotatedRegionSet::Create(expr, begin, end);
+ });
TVM_REGISTER_GLOBAL("relay.analysis.GetRegion")
-.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) {
- return region_set->GetRegion(expr);
-});
-
+ .set_body_typed([](AnnotatedRegionSet region_set, Expr expr) {
+ return region_set->GetRegion(expr);
+ });
} // namespace relay
} // namespace tvm
#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
+#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
-#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
+#include <list>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
-#include <list>
namespace tvm {
namespace relay {
}
/*! \brief Get the region ID. */
- int GetID() const {
- return id_;
- }
+ int GetID() const { return id_; }
/*! \brief Get the region target. */
- std::string GetTarget() const {
- return target_;
- }
+ std::string GetTarget() const { return target_; }
/*! \brief Get the region's inputs. */
- std::list<Expr> GetInputs() const {
- return ins_;
- }
+ std::list<Expr> GetInputs() const { return ins_; }
/*! \brief Get the region's outputs. */
- std::list<Expr> GetOutputs() const {
- return outs_;
- }
+ std::list<Expr> GetOutputs() const { return outs_; }
/*! \brief Get the region's nodes. */
- std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
- return nodes_;
- }
+ std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const { return nodes_; }
static constexpr const char* _type_key = "relay.AnnotatedRegion";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);
/*!
* \brief An object to hold the properties of a region as used by the
* AnnotatedRegionSet class. This should be considered read-only.
-*/
+ */
class AnnotatedRegion : public ObjectRef {
public:
AnnotatedRegion() {
}
/*!
- * \brief Construct from an object pointer.
- * \param n The object pointer.
- */
+ * \brief Construct from an object pointer.
+ * \param n The object pointer.
+ */
explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return Mutable pointers to the node. */
};
class AnnotatedRegionSetNode : public Object {
- using UnorderedRegionSet =
- std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
+ using UnorderedRegionSet = std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
// Create iterator alias for a RegionSet object.
using iterator = UnorderedRegionSet::iterator;
using const_iterator = UnorderedRegionSet::const_iterator;
AnnotatedRegionSetNode() = default;
/*! \return The begin iterator */
- iterator begin() {
- return regions_.begin();
- }
+ iterator begin() { return regions_.begin(); }
/*! \return The end iterator */
- iterator end() {
- return regions_.end();
- }
+ iterator end() { return regions_.end(); }
/*! \return The const begin iterator */
- const_iterator begin() const {
- return regions_.begin();
- }
+ const_iterator begin() const { return regions_.begin(); }
/*! \return The const end iterator */
- const_iterator end() const {
- return regions_.end();
- }
+ const_iterator end() const { return regions_.end(); }
/*!
* \brief Get the region that an expression belongs to.
AnnotatedRegion GetRegion(const Expr& expr) const;
/*!
- * \brief Merge src region into dest region.
- *
- * \param src The region to merge - will be erased.
- * \param dest The region into which src will be merged.
- */
+ * \brief Merge src region into dest region.
+ *
+ * \param src The region to merge - will be erased.
+ * \param dest The region into which src will be merged.
+ */
void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest);
void VisitAttrs(AttrVisitor* v) {
* to update and query regions.
*/
class AnnotatedRegionSet : public ObjectRef {
- using UnorderedRegionSet =
- std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
+ using UnorderedRegionSet = std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
// Create iterator alias for a RegionSet object.
using iterator = UnorderedRegionSet::iterator;
using const_iterator = UnorderedRegionSet::const_iterator;
}
/*!
- * \brief Construct from an object pointer.
- *
- * \param n The object pointer.
- */
+ * \brief Construct from an object pointer.
+ *
+ * \param n The object pointer.
+ */
explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The begin iterator. */
}
/*! \return The end iterator. */
const_iterator end() const {
- const auto *n = operator->();
+ const auto* n = operator->();
CHECK(n);
return n->end();
}
/*! \return The region an expression belongs to. */
AnnotatedRegion operator[](const Expr& expr) {
- const auto *n = operator->();
+ const auto* n = operator->();
CHECK(n);
return n->GetRegion(expr);
}
*
* \return The created RegionSet for the expression.
*/
- static AnnotatedRegionSet Create(const Expr& expr,
- const Op& begin,
- const Op& end);
+ static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end);
private:
/*! \brief Helper class to construct a RegionSet from an expr.*/
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/object.h>
+
#include <algorithm>
#include <memory>
#include <sstream>
const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const {
const_iterator cit = call_graph_.find(gv);
- CHECK(cit != call_graph_.end())
- << "GlobalVar " << gv->name_hint << " not found in the call graph!";
+ CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint
+ << " not found in the call graph!";
return cit->second.get();
}
CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) {
const_iterator cit = call_graph_.find(gv);
- CHECK(cit != call_graph_.end())
- << "GlobalVar " << gv->name_hint << " not found in the call graph!";
+ CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint
+ << " not found in the call graph!";
return cit->second.get();
}
BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const {
CHECK(module->ContainGlobalVar(var->name_hint))
- << "GlobalVar " << var->name_hint
- << " not found in the current ir module";
+ << "GlobalVar " << var->name_hint << " not found in the current ir module";
return module->Lookup(var);
}
bool update_call_graph) {
CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1))
<< "Cannot remove global var " << cg_node->GetNameHint()
- << " from call graph, because it still calls "
- << cg_node->size() << " other global functions";
+ << " from call graph, because it still calls " << cg_node->size()
+ << " other global functions";
if (update_call_graph) {
// Update the call graph by removing all edges that point to the node
<< " with # refs = " << (*this)[it.first]->GetRefCount();
}
}
- LOG(FATAL) << "Expected " << module->functions.size()
- << " globals, but received "
+ LOG(FATAL) << "Expected " << module->functions.size() << " globals, but received "
<< ret.size();
}
// that are visited by previous CallGraphEntry entries can be memoized. This
// helps us to make sure no entry will be visited multiple times when collecting
// the nodes for an entire call graph.
-std::vector<CallGraphEntry*> CallGraphEntry::TopologicalOrder(
- CallGraphEntrySet* visited) const {
+std::vector<CallGraphEntry*> CallGraphEntry::TopologicalOrder(CallGraphEntrySet* visited) const {
std::vector<CallGraphEntry*> ret;
std::vector<CallGraphEntry*> current_nodes;
if (visited->find(this) == visited->end()) {
// Remove an edge from the current global function to the callee.
void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) {
for (auto it = begin();; ++it) {
- CHECK(it != end()) << "Cannot find global function "
- << callee->name_hint << " to remove!";
+ CHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!";
if (it->second->GetGlobalVar() == callee) {
// Only remove one occurrence of the call site.
it->second->DecRef();
}
// Make sure all references to the callee are removed.
CHECK_EQ(callee->GetRefCount(), 0U)
- << "All references to " << callee->GetNameHint()
- << " should have been removed";
+ << "All references to " << callee->GetNameHint() << " should have been removed";
}
void CallGraphEntry::Print(std::ostream& os) const {
TVM_REGISTER_NODE_TYPE(CallGraphNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<CallGraphNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const CallGraphNode*>(ref.get());
- CHECK(node);
- p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
-});
+ .set_dispatch<CallGraphNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const CallGraphNode*>(ref.get());
+ CHECK(node);
+ p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
+ });
-TVM_REGISTER_GLOBAL("relay.analysis.CallGraph")
-.set_body_typed([](IRModule module) {
+TVM_REGISTER_GLOBAL("relay.analysis.CallGraph").set_body_typed([](IRModule module) {
return CallGraph(module);
});
-TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph")
-.set_body_typed([](CallGraph call_graph) {
+TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph").set_body_typed([](CallGraph call_graph) {
std::stringstream ss;
ss << call_graph;
return ss.str();
});
-TVM_REGISTER_GLOBAL("relay.analysis.GetModule")
-.set_body_typed([](CallGraph call_graph) {
+TVM_REGISTER_GLOBAL("relay.analysis.GetModule").set_body_typed([](CallGraph call_graph) {
return call_graph->module;
});
TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar")
-.set_body_typed([](CallGraph call_graph, GlobalVar var) {
- const auto* entry_node = call_graph[var];
- std::stringstream ss;
- ss << *entry_node;
- return ss.str();
-});
+ .set_body_typed([](CallGraph call_graph, GlobalVar var) {
+ const auto* entry_node = call_graph[var];
+ std::stringstream ss;
+ ss << *entry_node;
+ return ss.str();
+ });
TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar")
-.set_body_typed([](CallGraph call_graph, GlobalVar var) {
- const auto* entry_node = call_graph[var];
- return static_cast<int>(entry_node->GetRefCount());
-});
+ .set_body_typed([](CallGraph call_graph, GlobalVar var) {
+ const auto* entry_node = call_graph[var];
+ return static_cast<int>(entry_node->GetRefCount());
+ });
TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount")
-.set_body_typed([](CallGraph call_graph, GlobalVar var) {
- const auto* entry_node = call_graph[var];
- return static_cast<int>(entry_node->size());
-});
+ .set_body_typed([](CallGraph call_graph, GlobalVar var) {
+ const auto* entry_node = call_graph[var];
+ return static_cast<int>(entry_node->size());
+ });
TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive")
-.set_body_typed([](CallGraph call_graph, GlobalVar var) {
- const auto* entry_node = call_graph[var];
- return entry_node->IsRecursive();
-});
+ .set_body_typed([](CallGraph call_graph, GlobalVar var) {
+ const auto* entry_node = call_graph[var];
+ return entry_node->IsRecursive();
+ });
} // namespace relay
} // namespace tvm
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/object.h>
+
#include <memory>
#include <string>
#include <unordered_map>
class CallGraphNode : public Object {
using CallGraphMap =
- std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash,
- ObjectEqual>;
+ std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash, ObjectEqual>;
// Create iterator alias for a CallGraphNode object.
using iterator = CallGraphMap::iterator;
using const_iterator = CallGraphMap::const_iterator;
/*! \brief Default constructor. */
CallGraphNode() {}
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("module", &module);
- }
+ void VisitAttrs(AttrVisitor* v) { v->Visit("module", &module); }
/*!
* \brief Print the call graph.
void Print(std::ostream& os) const;
/*! \return The begin iterator. */
- iterator begin() {
- return call_graph_.begin();
- }
+ iterator begin() { return call_graph_.begin(); }
/*! \return The end iterator. */
- iterator end() {
- return call_graph_.end();
- }
+ iterator end() { return call_graph_.end(); }
/*! \return The begin iterator. */
- const_iterator begin() const {
- return call_graph_.begin();
- }
+ const_iterator begin() const { return call_graph_.begin(); }
/*! \return The end iterator. */
- const_iterator end() const {
- return call_graph_.end();
- }
+ const_iterator end() const { return call_graph_.end(); }
/*!
* \brief Get an element from the CallGraphNode using a GlobalVar.
*
* \return The GlobalVar removed from the current module.
*/
- GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node,
- bool update_call_graph = false);
+ GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false);
/*!
* \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for
*/
class CallGraph : public ObjectRef {
using CallGraphMap =
- std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash,
- ObjectEqual>;
+ std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash, ObjectEqual>;
// Create iterator alias for a CallGraph object.
using iterator = CallGraphMap::iterator;
using const_iterator = CallGraphMap::const_iterator;
CallGraphEntry& operator=(const CallGraphEntry&) = delete;
/*! \return The begin iterator */
- iterator begin() {
- return called_globals_.begin();
- }
+ iterator begin() { return called_globals_.begin(); }
/*! \return The end iterator */
- iterator end() {
- return called_globals_.end();
- }
+ iterator end() { return called_globals_.end(); }
/*! \return The const begin iterator */
- const_iterator begin() const {
- return called_globals_.begin();
- }
+ const_iterator begin() const { return called_globals_.begin(); }
/*! \return The const end iterator */
- const_iterator end() const {
- return called_globals_.end();
- }
+ const_iterator end() const { return called_globals_.end(); }
/*!
* \brief Return if the list of called nodes is empty.
*
* \return true if the list is empty. Otherwise, false.
*/
- bool empty() const {
- return called_globals_.empty();
- }
+ bool empty() const { return called_globals_.empty(); }
/*!
* \brief Return the size of the list that represents the nodes are called by
*
* \return The number of called nodes.
*/
- uint32_t size() const {
- return static_cast<uint32_t>(called_globals_.size());
- }
+ uint32_t size() const { return static_cast<uint32_t>(called_globals_.size()); }
/*!
* \brief Fetch the i-th CallGraphEntry from the list of nodes that are called
*
* \return The count.
*/
- uint32_t GetRefCount() const {
- return ref_cnt_;
- }
+ uint32_t GetRefCount() const { return ref_cnt_; }
/*!
* \brief Return the GlobalVar stored in the current CallGraphEntry.
*
* \return The GlobalVar.
*/
- GlobalVar GetGlobalVar() const {
- return global_;
- }
+ GlobalVar GetGlobalVar() const { return global_; }
/*!
* \brief Return the name hint of the GlobalVar stored in the CallGraphEntry.
*
* \return The name hint of the global function.
*/
- std::string GetNameHint() const {
- return global_->name_hint;
- }
+ std::string GetNameHint() const { return global_->name_hint; }
/*!
* \brief Return if the global function corresponding to the current
*
* \return true if it is recursive. Otherwise, false.
*/
- bool IsRecursive() const {
- return is_recursive_;
- }
+ bool IsRecursive() const { return is_recursive_; }
/*!
* \brief Return if the global function corresponding to the current
*
* \return true if it is both a recursive function and an entry. Otherwise, false.
*/
- bool IsRecursiveEntry() const {
- return GetRefCount() == 1 && IsRecursive();
- }
+ bool IsRecursiveEntry() const { return GetRefCount() == 1 && IsRecursive(); }
/*!
* \brief Return the topological order of the CallGraphEntry.
* \brief Implementation of dependency graph APIs.
*/
#include "dependency_graph.h"
+
#include <tvm/relay/expr_functor.h>
+
#include <unordered_set>
#include <utility>
// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
public:
- explicit Creator(support::Arena* arena)
- : arena_(arena) {}
+ explicit Creator(support::Arena* arena) : arena_(arena) {}
DependencyGraph Create(const Expr& body) {
this->VisitExpr(body);
}
}
- void VisitExpr_(const VarNode* v) final { }
+ void VisitExpr_(const VarNode* v) final {}
- void VisitExpr_(const GlobalVarNode* v) final { }
+ void VisitExpr_(const GlobalVarNode* v) final {}
- void VisitExpr_(const ConstantNode* c) final { }
+ void VisitExpr_(const ConstantNode* c) final {}
- void VisitExpr_(const OpNode* o) final { }
+ void VisitExpr_(const OpNode* o) final {}
- void VisitExpr_(const ConstructorNode* c) final { }
+ void VisitExpr_(const ConstructorNode* c) final {}
};
DependencyGraph DependencyGraph::Create(support::Arena* arena, const Expr& body) {
#define TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_
#include <tvm/relay/expr.h>
+
#include <unordered_map>
#include <vector>
-#include "../transforms/let_list.h"
+
#include "../../support/arena.h"
+#include "../transforms/let_list.h"
namespace tvm {
namespace relay {
-using support::LinkNode;
using support::LinkedList;
+using support::LinkNode;
/* DependencyGraph track input and output of an Expr.
* Additionally, dummy scope is created to model scope.
* \file feature.cc
* \brief Detect features used in Expr/Module
*/
-#include <tvm/relay/feature.h>
+#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/ir/module.h>
+#include <tvm/relay/feature.h>
+
#include "../transforms/pass_util.h"
namespace tvm {
}
}
}
-#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \
- void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
- STMT \
- fs += f##CONSTRUCT_NAME; \
- }
-#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
- ExprVisitor::VisitExpr_(op); \
- })
+#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \
+ void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { STMT fs += f##CONSTRUCT_NAME; }
+#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) \
+ DETECT_CONSTRUCT(CONSTRUCT_NAME, { ExprVisitor::VisitExpr_(op); })
DETECT_DEFAULT_CONSTRUCT(Var)
DETECT_DEFAULT_CONSTRUCT(GlobalVar)
DETECT_DEFAULT_CONSTRUCT(Constant)
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
DETECT_CONSTRUCT(Function, {
- if (!op->HasNonzeroAttr(attr::kPrimitive)) {
- ExprVisitor::VisitExpr_(op);
- }
- })
+ if (!op->HasNonzeroAttr(attr::kPrimitive)) {
+ ExprVisitor::VisitExpr_(op);
+ }
+ })
DETECT_DEFAULT_CONSTRUCT(Op)
DETECT_DEFAULT_CONSTRUCT(Call)
DETECT_CONSTRUCT(Let, {
- for (const Var& v : FreeVars(op->value)) {
- if (op->var == v) {
- fs += fLetRec;
- }
+ for (const Var& v : FreeVars(op->value)) {
+ if (op->var == v) {
+ fs += fLetRec;
}
- ExprVisitor::VisitExpr_(op);
- })
+ }
+ ExprVisitor::VisitExpr_(op);
+ })
DETECT_DEFAULT_CONSTRUCT(If)
DETECT_DEFAULT_CONSTRUCT(RefCreate)
DETECT_DEFAULT_CONSTRUCT(RefRead)
return static_cast<Array<Integer>>(fs);
}
-TVM_REGISTER_GLOBAL("relay.analysis.detect_feature")
-.set_body_typed(PyDetectFeature);
+TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);
} // namespace relay
} // namespace tvm
* We check this by ensuring the `dtype` field of a Tensor always
* contains a data type such as `int`, `float`, `uint`.
*/
+#include <tvm/ir/error.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
-#include <tvm/ir/error.h>
namespace tvm {
namespace relay {
this->err_reporter.RenderErrors(mod);
}
- void CheckKindMatches(const Type& t, const Type& outer,
- Kind expected, const std::string& description) {
+ void CheckKindMatches(const Type& t, const Type& outer, Kind expected,
+ const std::string& description) {
Kind k = this->VisitType(t);
if (k != expected) {
ReportFatalError(ErrorBuilder()
- << "Incorrect kind for a " << description
- << ". Type " << t << " inside " << outer
- << " is of kind " << k
- << " but was expected to be "
- << expected);
+ << "Incorrect kind for a " << description << ". Type " << t << " inside "
+ << outer << " is of kind " << k << " but was expected to be " << expected);
}
}
- Kind VisitType_(const IncompleteTypeNode* op) override {
- return op->kind;
- }
+ Kind VisitType_(const IncompleteTypeNode* op) override { return op->kind; }
- Kind VisitType_(const TypeVarNode* op) override {
- return op->kind;
- }
+ Kind VisitType_(const TypeVarNode* op) override { return op->kind; }
- Kind VisitType_(const GlobalTypeVarNode* op) override {
- return op->kind;
- }
+ Kind VisitType_(const GlobalTypeVarNode* op) override { return op->kind; }
- Kind VisitType_(const TensorTypeNode* op) override {
- return Kind::kType;
- }
+ Kind VisitType_(const TensorTypeNode* op) override { return Kind::kType; }
Kind VisitType_(const TupleTypeNode* op) override {
// tuples should only contain normal types
for (const Type& t : op->fields) {
- CheckKindMatches(t, GetRef<TupleType>(op), Kind::kType,
- "tuple member");
+ CheckKindMatches(t, GetRef<TupleType>(op), Kind::kType, "tuple member");
}
return Kind::kType;
}
Kind VisitType_(const TypeRelationNode* op) override {
// arguments to type relation should be normal types
for (const Type& t : op->args) {
- CheckKindMatches(t, GetRef<TypeRelation>(op), Kind::kType,
- "argument to type relation");
+ CheckKindMatches(t, GetRef<TypeRelation>(op), Kind::kType, "argument to type relation");
}
return Kind::kConstraint;
}
TypeCall tc = GetRef<TypeCall>(op);
const auto* gtv = op->func.as<GlobalTypeVarNode>();
if (gtv == nullptr) {
- ReportFatalError(
- ErrorBuilder() <<"The callee in " << tc
- << " is not a global type var, but is " << op->func);
+ ReportFatalError(ErrorBuilder() << "The callee in " << tc
+ << " is not a global type var, but is " << op->func);
}
CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function");
auto var = GetRef<GlobalTypeVar>(gtv);
auto data = mod->LookupTypeDef(var);
if (data->type_vars.size() != op->args.size()) {
- ReportFatalError(ErrorBuilder()
- << "Expected " << data->type_vars.size() << "arguments for " << tc
- << "; got " << op->args.size());
+ ReportFatalError(ErrorBuilder() << "Expected " << data->type_vars.size() << " arguments for "
+ << tc << "; got " << op->args.size());
}
return Kind::kType;
}
for (const auto& con : op->constructors) {
if (!con->belong_to.same_as(op->header)) {
- ReportFatalError(ErrorBuilder()
- <<con << " has header " << con->belong_to
- << " but " << op << " has header " << op->header);
+ ReportFatalError(ErrorBuilder() << con << " has header " << con->belong_to << " but " << op
+ << " has header " << op->header);
}
for (const Type& t : con->inputs) {
return Kind::kTypeData;
}
- Kind Check(const Type& t) {
- return this->VisitType(t);
- }
+ Kind Check(const Type& t) { return this->VisitType(t); }
};
Kind KindCheck(const Type& t, const IRModule& mod) {
return kc.Check(t);
}
-TVM_REGISTER_GLOBAL("relay.analysis.check_kind")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- if (args.size() == 1) {
- *ret = KindCheck(args[0], IRModule({}, {}));
- } else {
- *ret = KindCheck(args[0], args[1]);
- }
- });
+TVM_REGISTER_GLOBAL("relay.analysis.check_kind").set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (args.size() == 1) {
+ *ret = KindCheck(args[0], IRModule({}, {}));
+ } else {
+ *ret = KindCheck(args[0], args[1]);
+ }
+});
} // namespace relay
} // namespace tvm
* otherwise the count is 0.
*/
-#include <tvm/relay/op.h>
+#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/analysis.h>
+#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
+
#include "../transforms/pattern_util.h"
namespace tvm {
* \param call_node The call node.
* \return The number of MACs.
*/
-using FMacCount = runtime::TypedPackedFunc<
- int64_t(const Call& call_node)>;
+using FMacCount = runtime::TypedPackedFunc<int64_t(const Call& call_node)>;
//----------------------------------------------
// Per operator defs for MAC count
return 0;
}
Array<Expr> args = call_node->args;
- CHECK_EQ(args.size(), 2)
- << "The number of input arguments of a CONV 2D node should be 2.";
+ CHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D node should be 2.";
const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_attr->data_layout;
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
- CHECK_NE(C_ind, -1)
- << "There is no input channel dimension.";
+ CHECK_NE(C_ind, -1) << "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
- if (c_ind != -1)
- input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
+ if (c_ind != -1) input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
- CHECK_EQ(kernel_size.size(), 2)
- << "The dimension of the kernel in Conv 2D should be 2.";
+ CHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D should be 2.";
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
Array<IndexExpr> output_tensor = expr->shape;
CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
- << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
+ << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
CHECK_EQ(input_channel % conv_2d_attr->groups, 0)
- << "The number of input channels is not divisble by groups.";
- count *= input_channel/conv_2d_attr->groups;
+ << "The number of input channels is not divisible by groups.";
+ count *= input_channel / conv_2d_attr->groups;
return count;
}
}
Array<Expr> args = call_node->args;
CHECK_EQ(args.size(), 2)
- << "The number of input arguments of a CONV 2D Transpose node should be 2.";
+ << "The number of input arguments of a CONV 2D Transpose node should be 2.";
const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>();
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_transpose_attr->data_layout;
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
- CHECK_NE(C_ind, -1)
- << "There is no input channel dimension.";
+ CHECK_NE(C_ind, -1) << "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
- if (c_ind != -1)
- input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
+ if (c_ind != -1) input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
CHECK_EQ(kernel_size.size(), 2)
- << "The dimension of the kernel in Conv 2D Transpose should be 2.";
+ << "The dimension of the kernel in Conv 2D Transpose should be 2.";
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
Array<IndexExpr> output_tensor = expr->shape;
CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
- << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5.";
+ << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5.";
int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0)
- << "The number of input channels is not divisble by groups.";
- count *= input_channel/conv_2d_transpose_attr->groups;
+ << "The number of input channels is not divisible by groups.";
+ count *= input_channel / conv_2d_transpose_attr->groups;
return count;
}
return 0;
}
Array<Expr> args = call_node->args;
- CHECK_EQ(args.size(), 2)
- << "The number of input arguments of a Dense node should be 2.";
+ CHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2.";
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
Array<IndexExpr> weight_shape = weight_type->shape;
CHECK(data_shape.size() == 2 && weight_shape.size() == 2)
- << "The dimension of an input tensor to Dense node should be 2.";
+ << "The dimension of an input tensor to Dense node should be 2.";
int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value);
int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value);
int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value);
int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value);
- CHECK_EQ(d2, d4)
- << "The dimensions of input arguments do not match.";
+ CHECK_EQ(d2, d4) << "The dimensions of input arguments do not match.";
int64_t count = d1 * d2 * d3;
return count;
}
return batch * m * k * n;
}
-RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FMacCount>("FMacCount", ConvMacCount);
+RELAY_REGISTER_OP("nn.conv2d").set_attr<FMacCount>("FMacCount", ConvMacCount);
-RELAY_REGISTER_OP("nn.conv2d_transpose")
-.set_attr<FMacCount>("FMacCount", Conv2dTransposeMacCount);
+RELAY_REGISTER_OP("nn.conv2d_transpose").set_attr<FMacCount>("FMacCount", Conv2dTransposeMacCount);
-RELAY_REGISTER_OP("nn.dense")
-.set_attr<FMacCount>("FMacCount", DenseMacCount);
+RELAY_REGISTER_OP("nn.dense").set_attr<FMacCount>("FMacCount", DenseMacCount);
-RELAY_REGISTER_OP("nn.batch_matmul")
-.set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
+RELAY_REGISTER_OP("nn.batch_matmul").set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
class MacCounter : private ExprVisitor {
public:
- MacCounter() {
- count_ = 0;
- }
+ MacCounter() { count_ = 0; }
static int64_t GetTotalMacNumber(const Expr& expr) {
LOG(INFO) << "This pass only counts MACs in direct conv2d, "
<< "conv2d_transpose, dense, and batch_matmul ops";
private:
void VisitExpr_(const CallNode* call_node) final {
- static const auto& fprep =
- Op::GetAttr<FMacCount>("FMacCount");
+ static const auto& fprep = Op::GetAttr<FMacCount>("FMacCount");
auto f = fprep.get(call_node->op, nullptr);
if (f != nullptr) count_ += f(GetRef<Call>(call_node));
ExprVisitor::VisitExpr_(call_node);
int64_t count_;
};
-int64_t GetTotalMacNumber(const Expr& expr) {
- return MacCounter::GetTotalMacNumber(expr);
-}
+int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); }
-TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber")
-.set_body_typed(GetTotalMacNumber);
+TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber").set_body_typed(GetTotalMacNumber);
} // namespace mac_count
} // namespace relay
* code correctness, since hitting an unmatched case results in a
* dynamic error unless exhaustiveness is checked in advance.
*/
-#include <tvm/relay/adt.h>
#include <tvm/ir/error.h>
+#include <tvm/relay/adt.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
+
#include <stack>
namespace tvm {
}
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
- const Pattern& cand,
- const IRModule& mod);
+ const Pattern& cand, const IRModule& mod);
-Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
- const Pattern& cand,
+Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand,
const IRModule& mod);
// Expands all wildcards in the candidate pattern once
// Returns a list of all possible expansions.
-Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
- const Pattern& cand,
+Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
const IRModule& mod) {
if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
// Use the pattern to decide which constructors to insert.
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
- const Pattern& cand,
- const IRModule& mod) {
+ const Pattern& cand, const IRModule& mod) {
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
// for a wildcard node, create constructor nodes with wildcards for all args.
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
- values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
- ctor_cand->patterns[i],
- mod));
+ values_by_field.push_back(
+ ExpandWildcards(clause_ctor->patterns[i], ctor_cand->patterns[i], mod));
}
// generate new candidates using a cartesian product.
// Expands all wildcards in the candidate pattern once.
// Returns a list of all possible expansions.
-Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
- const Pattern& cand,
+Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand,
const IRModule& mod) {
// for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) {
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
- values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
- tuple_cand->patterns[i],
- mod));
+ values_by_field.push_back(
+ ExpandWildcards(clause_tuple->patterns[i], tuple_cand->patterns[i], mod));
}
// generate new candidates using a cartesian product
// expose for testing only
TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
-.set_body_typed(
- [](const Match& match, const IRModule& mod_ref) {
- IRModule call_mod = mod_ref;
- if (!call_mod.defined()) {
- call_mod = IRModule({}, {});
- }
- return UnmatchedCases(match, call_mod);
- });
+ .set_body_typed([](const Match& match, const IRModule& mod_ref) {
+ IRModule call_mod = mod_ref;
+ if (!call_mod.defined()) {
+ call_mod = IRModule({}, {});
+ }
+ return UnmatchedCases(match, call_mod);
+ });
} // namespace relay
} // namespace tvm
* \file type_solver.cc
* \brief Type solver implementations.
*/
-#include <tvm/node/structural_equal.h>
+#include "type_solver.h"
+
#include <tvm/ir/type_functor.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/tir/op.h>
-#include <string>
+
#include <memory>
+#include <string>
#include <tuple>
#include <utility>
-#include "type_solver.h"
namespace tvm {
namespace relay {
class TypeSolver::Reporter : public TypeReporterNode {
public:
- explicit Reporter(TypeSolver* solver)
- : solver_(solver) {}
+ explicit Reporter(TypeSolver* solver) : solver_(solver) {}
- void Assign(const Type& dst, const Type& src) final {
- solver_->Unify(dst, src, location);
- }
+ void Assign(const Type& dst, const Type& src) final { solver_->Unify(dst, src, location); }
bool Assert(const IndexExpr& cond) final {
if (const int64_t* pdiff = tir::as_const_int(cond)) {
return true;
}
- TVM_DLL void SetLocation(const ObjectRef& ref) final {
- location = ref;
- }
+ TVM_DLL void SetLocation(const ObjectRef& ref) final { location = ref; }
- TVM_DLL IRModule GetModule() final {
- return this->solver_->module_;
- }
+ TVM_DLL IRModule GetModule() final { return this->solver_->module_; }
private:
/*! \brief The location to report unification errors at. */
class TypeSolver::OccursChecker : public TypeVisitor {
public:
explicit OccursChecker(TypeSolver* solver, TypeNode* var)
- : solver_(solver), var_(var), found_(false) {}
+ : solver_(solver), var_(var), found_(false) {}
bool Check(const Type& t) {
VisitType(t);
if (lhs->resolved_type.as<IncompleteTypeNode>()) {
CHECK(!OccursCheck(lhs, rhs->resolved_type))
- << "Incomplete type " << lhs->resolved_type << " occurs in "
- << rhs->resolved_type << ", cannot unify";
+ << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type
+ << ", cannot unify";
solver_->MergeFromTo(lhs, rhs);
return rhs->resolved_type;
} else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
CHECK(!OccursCheck(rhs, lhs->resolved_type))
- << "Incomplete type " << rhs->resolved_type << " occurs in "
- << lhs->resolved_type << ", cannot unify";
+ << "Incomplete type " << rhs->resolved_type << " occurs in " << lhs->resolved_type
+ << ", cannot unify";
solver_->MergeFromTo(rhs, lhs);
return lhs->resolved_type;
} else {
Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
if (!resolved.defined()) {
- solver_->ReportError(
- ErrorBuilder() << "unable to unify: "
- << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
- << PrettyPrint(rhs->resolved_type) << "`",
- this->loc);
+ solver_->ReportError(ErrorBuilder() << "unable to unify: "
+ << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
+ << PrettyPrint(rhs->resolved_type) << "`",
+ this->loc);
return lhs->resolved_type;
} else {
TypeNode* top = solver_->GetTypeNode(resolved);
tvm::Array<IndexExpr> shape;
if (tt1->shape.size() != tt2->shape.size()) {
- this->solver_->ReportError(
- ErrorBuilder() <<
- "tensor type `" << PrettyPrint(tt1) <<
- "` has " << tt1->shape.size() <<
- " dimensions, while `" <<
- PrettyPrint(tt2) <<
- "` has " << tt2->shape.size() <<
- " dimensions", this->loc);
+ this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has "
+ << tt1->shape.size() << " dimensions, while `"
+ << PrettyPrint(tt2) << "` has " << tt2->shape.size()
+ << " dimensions",
+ this->loc);
return Type(nullptr);
}
ErrorBuilder err;
err << "in particular ";
for (auto mismatch : mismatches) {
- err << "dimension "
- << std::get<0>(mismatch)
- << " conflicts "
- << std::get<1>(mismatch)
- << " does not match "
- << std::get<2>(mismatch);
+ err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch)
+ << " does not match " << std::get<2>(mismatch);
}
Error error(err);
this->solver_->ReportError(error, this->loc);
Type VisitType_(const FuncTypeNode* op, const Type& tn) final {
const auto* ftn = tn.as<FuncTypeNode>();
- if (!ftn
- || op->arg_types.size() != ftn->arg_types.size()
- || op->type_constraints.size() != ftn->type_constraints.size()) {
+ if (!ftn || op->arg_types.size() != ftn->arg_types.size() ||
+ op->type_constraints.size() != ftn->type_constraints.size()) {
return Type(nullptr);
}
subst_map.Set(op->type_params[i], IncompleteType(kType));
}
- FuncType ft = FuncType(op->arg_types,
- op->ret_type,
- ft_type_params,
- op->type_constraints);
+ FuncType ft = FuncType(op->arg_types, op->ret_type, ft_type_params, op->type_constraints);
auto ft1 = Downcast<FuncType>(Bind(ft, subst_map));
auto ft2 = GetRef<FuncType>(ftn);
std::vector<TypeConstraint> type_constraints;
for (size_t i = 0; i < ft1->type_constraints.size(); ++i) {
- Type unified_constraint = Unify(ft1->type_constraints[i],
- ft2->type_constraints[i]);
+ Type unified_constraint = Unify(ft1->type_constraints[i], ft2->type_constraints[i]);
const auto* tcn = unified_constraint.as<TypeConstraintNode>();
CHECK(tcn) << "Two type constraints unified into a non-constraint?"
<< ft1->type_constraints[i] << " and " << ft2->type_constraints[i];
class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> {
public:
explicit Propagator(TypeSolver* solver, const std::unordered_set<RelationNode*>* rels)
- : solver_(solver), rels_(rels) {}
+ : solver_(solver), rels_(rels) {}
// adds the relation node to t and all child types of t
- void Propagate(const Type& t) {
- VisitType(t);
- }
+ void Propagate(const Type& t) { VisitType(t); }
void UpdateRelSet(const Type& t) {
TypeNode* tnode = solver_->GetTypeNode(t);
};
// constructor
-TypeSolver::TypeSolver(
- const GlobalVar& current_func,
- const IRModule& module,
- ErrorReporter* err_reporter)
+TypeSolver::TypeSolver(const GlobalVar& current_func, const IRModule& module,
+ ErrorReporter* err_reporter)
: reporter_(make_object<Reporter>(this)),
current_func(current_func),
err_reporter_(err_reporter),
return unifier.Unify(dst, src);
}
-void TypeSolver::ReportError(const Error& err, const ObjectRef& location) {
+void TypeSolver::ReportError(const Error& err, const ObjectRef& location) {
CHECK(location.defined());
CHECK(current_func.defined());
err_reporter_->ReportAt(current_func, location, err);
// populate the type information.
for (size_t i = 0; i < op->args.size(); ++i) {
// insert link to the type list
- LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*> >();
+ LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*>>();
TypeNode* tnode = GetTypeNode(op->args[i]);
tlink->value = tnode;
rnode->type_list.Push(tlink);
// insert type->relation node
- std::unordered_set<RelationNode*> singleton { rnode };
+ std::unordered_set<RelationNode*> singleton{rnode};
Propagator prop(this, &singleton);
prop.Propagate(tnode->resolved_type);
}
// add the relation to the working queue.
this->AddToQueue(rnode);
} else {
- LOG(FATAL) << "Do not know how to handle constraint type"
- << constraint->GetTypeKey();
+ LOG(FATAL) << "Do not know how to handle constraint type" << constraint->GetTypeKey();
}
}
rnode->resolved = false;
} catch (const dmlc::Error& err) {
rnode->resolved = false;
- this->ReportError(
- ErrorBuilder() << "an internal invariant was violated while "
- << "typechecking your program "
- << err.what(),
- rnode->location);
+ this->ReportError(ErrorBuilder() << "an internal invariant was violated while "
+ << "typechecking your program " << err.what(),
+ rnode->location);
}
// Mark inqueue as false after the function call
// Expose type solver only for debugging purposes.
TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
-.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
- using runtime::PackedFunc;
- using runtime::TypedPackedFunc;
- ErrorReporter *err_reporter = new ErrorReporter();
- auto module = IRModule({}, {});
- auto dummy_fn_name = GlobalVar("test");
- module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}, {}));
- auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
-
- auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
- if (name == "Solve") {
- return TypedPackedFunc<bool()>([solver]() {
- return solver->Solve();
- });
- } else if (name == "Unify") {
- return TypedPackedFunc<Type(Type, Type)>(
- [module, solver, err_reporter](Type lhs, Type rhs) {
- auto res = solver->Unify(lhs, rhs, lhs);
- if (err_reporter->AnyErrors()) {
- err_reporter->RenderErrors(module, true);
- }
- return res;
- });
- } else if (name == "Resolve") {
- return TypedPackedFunc<Type(Type)>([solver](Type t) {
- return solver->Resolve(t);
- });
- } else if (name == "AddConstraint") {
- return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
- Expr e = Var("dummy_var",
- IncompleteType(Kind::kType));
+ .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
+ using runtime::PackedFunc;
+ using runtime::TypedPackedFunc;
+ ErrorReporter* err_reporter = new ErrorReporter();
+ auto module = IRModule({}, {});
+ auto dummy_fn_name = GlobalVar("test");
+ module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}, {}));
+ auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
+
+ auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
+ if (name == "Solve") {
+ return TypedPackedFunc<bool()>([solver]() { return solver->Solve(); });
+ } else if (name == "Unify") {
+ return TypedPackedFunc<Type(Type, Type)>(
+ [module, solver, err_reporter](Type lhs, Type rhs) {
+ auto res = solver->Unify(lhs, rhs, lhs);
+ if (err_reporter->AnyErrors()) {
+ err_reporter->RenderErrors(module, true);
+ }
+ return res;
+ });
+ } else if (name == "Resolve") {
+ return TypedPackedFunc<Type(Type)>([solver](Type t) { return solver->Resolve(t); });
+ } else if (name == "AddConstraint") {
+ return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
+ Expr e = Var("dummy_var", IncompleteType(Kind::kType));
return solver->AddConstraint(c, e);
});
- } else {
- return PackedFunc();
- }
- };
- *ret = runtime::TypedPackedFunc<runtime::PackedFunc(std::string)>(mod);
- });
+ } else {
+ return PackedFunc();
+ }
+ };
+ *ret = runtime::TypedPackedFunc<runtime::PackedFunc(std::string)>(mod);
+ });
} // namespace relay
} // namespace tvm
#ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
#define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
+#include <tvm/ir/error.h>
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/ir/error.h>
-#include <vector>
+
#include <queue>
#include <unordered_map>
#include <unordered_set>
+#include <vector>
+
#include "../../support/arena.h"
namespace tvm {
namespace relay {
-using support::LinkNode;
using support::LinkedList;
+using support::LinkNode;
/*!
* \brief Interface of type solver used in type inference.
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
+
#include "../transforms/pass_util.h"
namespace tvm {
namespace relay {
-template<typename T>
+template <typename T>
struct InsertionSet {
std::unordered_set<T, ObjectHash, ObjectEqual> set;
std::vector<T> data;
class TypeVarTVisitor : public TypeVisitor {
public:
- TypeVarTVisitor(
- InsertionSet<TypeVar>* type_vars,
- InsertionSet<TypeVar>* bound_type_vars)
- : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { }
+ TypeVarTVisitor(InsertionSet<TypeVar>* type_vars, InsertionSet<TypeVar>* bound_type_vars)
+ : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {}
void VisitType_(const TypeVarNode* tp) final {
TypeVar var = GetRef<TypeVar>(tp);
}
void VisitType(const Type& t) final {
- TypeVarTVisitor(&type_vars_, &bound_type_vars_)
- .VisitType(t);
+ TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t);
}
private:
vars_.Insert(v);
}
- void VisitExpr_(const VarNode* var) final {
- vars_.Insert(GetRef<Var>(var));
- }
+ void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) {
VisitExpr(op->body);
}
- void VisitPattern(const Pattern& p) final {
- PatternVisitor::VisitPattern(p);
- }
+ void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
- void VisitPattern_(const PatternVarNode* op) final {
- MarkBounded(op->var);
- }
+ void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); }
private:
InsertionSet<Var> vars_;
return TypeVarEVisitor(mod).All(type);
}
-tvm::Array<Var> FreeVars(const Expr& expr) {
- return VarVisitor().Free(expr);
-}
+tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
-tvm::Array<Var> BoundVars(const Expr& expr) {
- return VarVisitor().Bound(expr);
-}
+tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }
-tvm::Array<Var> BoundVars(const Pattern& pat) {
- return VarVisitor().Bound(pat);
-}
+tvm::Array<Var> BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); }
-tvm::Array<Var> AllVars(const Expr& expr) {
- return VarVisitor().All(expr);
-}
+tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
-TVM_REGISTER_GLOBAL("relay.analysis.free_vars")
-.set_body_typed(FreeVars);
+TVM_REGISTER_GLOBAL("relay.analysis.free_vars").set_body_typed(FreeVars);
-TVM_REGISTER_GLOBAL("relay.analysis.bound_vars")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- ObjectRef x = args[0];
- if (x.as<ExprNode>()) {
- *ret = BoundVars(Downcast<Expr>(x));
- } else {
- *ret = BoundVars(Downcast<Pattern>(x));
- }
- });
+TVM_REGISTER_GLOBAL("relay.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
+ ObjectRef x = args[0];
+ if (x.as<ExprNode>()) {
+ *ret = BoundVars(Downcast<Expr>(x));
+ } else {
+ *ret = BoundVars(Downcast<Pattern>(x));
+ }
+});
-TVM_REGISTER_GLOBAL("relay.analysis.all_vars")
-.set_body_typed(AllVars);
+TVM_REGISTER_GLOBAL("relay.analysis.all_vars").set_body_typed(AllVars);
-TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- ObjectRef x = args[0];
- IRModule mod = args[1];
- if (x.as<TypeNode>()) {
- *ret = FreeTypeVars(Downcast<Type>(x), mod);
- } else {
- *ret = FreeTypeVars(Downcast<Expr>(x), mod);
- }
- });
-
-TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- ObjectRef x = args[0];
- IRModule mod = args[1];
- if (x.as<TypeNode>()) {
- *ret = BoundTypeVars(Downcast<Type>(x), mod);
- } else {
- *ret = BoundTypeVars(Downcast<Expr>(x), mod);
- }
- });
-
-TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- ObjectRef x = args[0];
- IRModule mod = args[1];
- if (x.as<TypeNode>()) {
- *ret = AllTypeVars(Downcast<Type>(x), mod);
- } else {
- *ret = AllTypeVars(Downcast<Expr>(x), mod);
- }
- });
+TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
+ ObjectRef x = args[0];
+ IRModule mod = args[1];
+ if (x.as<TypeNode>()) {
+ *ret = FreeTypeVars(Downcast<Type>(x), mod);
+ } else {
+ *ret = FreeTypeVars(Downcast<Expr>(x), mod);
+ }
+});
+
+TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
+ ObjectRef x = args[0];
+ IRModule mod = args[1];
+ if (x.as<TypeNode>()) {
+ *ret = BoundTypeVars(Downcast<Type>(x), mod);
+ } else {
+ *ret = BoundTypeVars(Downcast<Expr>(x), mod);
+ }
+});
+
+TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
+ ObjectRef x = args[0];
+ IRModule mod = args[1];
+ if (x.as<TypeNode>()) {
+ *ret = AllTypeVars(Downcast<Type>(x), mod);
+ } else {
+ *ret = AllTypeVars(Downcast<Expr>(x), mod);
+ }
+});
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
-std::unordered_map<const Object*, size_t>
-GetExprRefCount(const Expr& body) {
+std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body) {
class ExprRefCounter : private MixedModeVisitor {
public:
- std::unordered_map<const Object*, size_t>
- Get(const Expr& body) {
+ std::unordered_map<const Object*, size_t> Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
}
} else if (const auto* op = expr.as<CallNode>()) {
// tail recursion.
- if (op->op == expand_dims_op ||
- op->op == reshape_op ||
- op->op == transpose_op ||
+ if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op ||
op->op == squeeze_op) {
return IsAllPositiveConstant(op->args[0]);
} else {
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
class TypeSubstMutator : public ExprMutator, public PatternMutator {
public:
- explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) { }
- Type VisitType(const Type& t) final {
- return TypeSubst(t, subst_map_);
- }
- Var VisitVar(const Var& v) final {
- return Downcast<Var>(VisitExpr(v));
- }
+ explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) {}
+ Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); }
+ Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); }
- Pattern VisitPattern(const Pattern& p) final {
- return PatternMutator::VisitPattern(p);
- }
+ Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
+
#include <unordered_set>
namespace tvm {
namespace relay {
-
//! brief make sure each Var is bound at most once in a scope.
class WellFormedChecker : private ExprVisitor, PatternVisitor {
bool well_formed = true;
struct Scope {
WellFormedChecker* wfc;
- explicit Scope(WellFormedChecker* wfc) : wfc(wfc) {
- wfc->scope.push_back({{}});
- }
+ explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { wfc->scope.push_back({{}}); }
~Scope() {
CHECK_GE(wfc->scope.size(), 0);
for (const Var& v : wfc->scope.back()) {
VisitExpr(c->rhs);
}
- void VisitPattern(const Pattern& p) final {
- PatternVisitor::VisitPattern(p);
- }
+ void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
- void VisitVar(const Var& v) final {
- Bound(v);
- }
+ void VisitVar(const Var& v) final { Bound(v); }
void VisitExpr(const Expr& e) final {
if (auto v = e.as<VarNode>()) {
}
};
-bool WellFormed(const Expr& e) {
- return WellFormedChecker().CheckWellFormed(e);
-}
+bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); }
-TVM_REGISTER_GLOBAL("relay.analysis.well_formed")
-.set_body_typed(WellFormed);
+TVM_REGISTER_GLOBAL("relay.analysis.well_formed").set_body_typed(WellFormed);
} // namespace relay
} // namespace tvm
* \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime.
*/
-#include <tvm/relay/analysis.h>
#include <tvm/driver/driver_api.h>
-#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/vm.h>
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
-#include <tvm/relay/transform.h>
#include <tvm/relay/qnn/transform.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/vm.h>
+
#include <memory>
#include "../../target/source/codegen_source_base.h"
namespace relay {
namespace backend {
-
using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;
}
~GraphCodegen() {}
- void Init(runtime::Module* m, TargetsMap targets) {
- CallFunc("init", m, targets);
- }
+ void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); }
- void Codegen(const Function& func) {
- CallFunc("codegen", func);
- }
+ void Codegen(const Function& func) { CallFunc("codegen", func); }
- std::string GetJSON() {
- return CallFunc<std::string>("get_graph_json", nullptr);
- }
+ std::string GetJSON() { return CallFunc<std::string>("get_graph_json", nullptr); }
Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
protected:
tvm::runtime::Module mod;
- template<typename R, typename ...Args>
- R CallFunc(const std::string &name, Args... args) {
+ template <typename R, typename... Args>
+ R CallFunc(const std::string& name, Args... args) {
auto pf = mod.GetFunction(name, false);
return pf(std::forward<Args>(args)...);
}
- template<typename ...Args>
- void CallFunc(const std::string &name, Args... args) {
+ template <typename... Args>
+ void CallFunc(const std::string& name, Args... args) {
auto pf = mod.GetFunction(name, false);
pf(std::forward<Args>(args)...);
return;
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "get_graph_json") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetGraphJSON();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); });
} else if (name == "get_module") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetModule();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); });
} else if (name == "build") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
this->Build(args[0], args[1], args[2]);
});
} else if (name == "list_params") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->ListParamNames();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->ListParamNames(); });
} else if (name == "get_params") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetParams();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); });
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
});
} else if (name == "get_irmodule") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->graph_codegen_->GetIRModule();
+ *rv = this->graph_codegen_->GetIRModule();
});
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->graph_codegen_->GetExternalModules();
+ *rv = this->graph_codegen_->GetExternalModules();
});
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*
* \return const std::string graph_json
*/
- const std::string& GetGraphJSON() {
- return ret_.graph_json;
- }
+ const std::string& GetGraphJSON() { return ret_.graph_json; }
/*!
* \brief Get the Module object
*
* \return runtime::Module
*/
- runtime::Module GetModule() {
- return ret_.mod;
- }
+ runtime::Module GetModule() { return ret_.mod; }
/*!
* \brief List all paramter names
* \param name name of parameter
* \param data_in input DLTensor
*/
- void SetParam(const std::string& name, runtime::NDArray data_in) {
- params_[name] = data_in;
- }
+ void SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; }
/*!
* \brief type key
*
* \return const char*
*/
- const char* type_key() const final {
- return "RelayBuildModule";
- }
+ const char* type_key() const final { return "RelayBuildModule"; }
/*!
* \brief Build relay IRModule for graph runtime
* \param target Target device
* \param target_host Host target device
*/
- void Build(IRModule mod,
- const TargetsMap& targets,
- const tvm::Target& target_host) {
+ void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
targets_ = targets;
target_host_ = target_host;
BuildRelay(mod, params_);
*
* \return relay::IRModule The updated Relay IR module after optimization.
*/
- IRModule Optimize(
- IRModule relay_module,
- const TargetsMap& targets,
- const std::unordered_map<std::string, runtime::NDArray>& params) {
+ IRModule Optimize(IRModule relay_module, const TargetsMap& targets,
+ const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) {
- CHECK(relay_module->ContainGlobalVar("main"))
- << "Missing the main entry function";
+ CHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function";
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params);
// Handle heterogeneous compilation.
transform::PassContext pass_ctx = PassContext::Current();
if (targets_.size() > 1) {
- relay_module =
- RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
+ relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
}
// Fuse the operations if it is needed.
*
* \return updated_module The updated module after device annotation.
*/
- IRModule RunDeviceAnnotationPass(const IRModule& relay_module,
- int fallback_device) {
+ IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) {
UpdateHeterogeneousInputs(fallback_device);
auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
auto updated_module = rewrite(relay_module);
break;
}
for (auto kv : annotation_map) {
- CHECK_EQ(kv.second->value, dev_type)
- << "Expressions in the function are "
- << "annotated with various device types,"
- << "but not device copy operators "
- << "found. Please check the "
- << "RewriteAnnotation pass.";
+ CHECK_EQ(kv.second->value, dev_type) << "Expressions in the function are "
+ << "annotated with various device types,"
+ << "but not device copy operators "
+ << "found. Please check the "
+ << "RewriteAnnotation pass.";
}
targets_.Set(0, CreateDefaultTarget(dev_type));
}
* \param relay_module The Relay IR module.
* \param params The parameters.
*/
- void BuildRelay(
- IRModule relay_module,
- const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
+ void BuildRelay(IRModule relay_module,
+ const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
// Relay IRModule -> IRModule optimizations.
relay_module = Optimize(relay_module, targets_, params);
// Get the updated function.
ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
}
} else {
- ret_.mod = tvm::build(
- lowered_funcs,
- target_host_,
- BuildConfig::Current());
+ ret_.mod = tvm::build(lowered_funcs, target_host_, BuildConfig::Current());
}
Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
// Import all external runtime modules.
- for (const auto& it : ext_mods)
- ret_.mod.Import(it);
+ for (const auto& it : ext_mods) ret_.mod.Import(it);
}
private:
Target GetTargetHost() {
Target target_host = target_host_;
if (!target_host_.defined()) {
- for (const auto &it : targets_) {
+ for (const auto& it : targets_) {
if (it.second->device_type == kDLCPU) {
target_host = it.second;
break;
return runtime::Module(exec);
}
-TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RelayBuildCreate();
});
TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- Map<std::string, Constant> params = args[1];
- std::unordered_map<std::string, runtime::NDArray> params_;
- for (const auto& kv : params) {
- params_[kv.first] = kv.second->data;
- }
- *rv = relay::backend::BindParamsByName(args[0], params_);
-});
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ Map<std::string, Constant> params = args[1];
+ std::unordered_map<std::string, runtime::NDArray> params_;
+ for (const auto& kv : params) {
+ params_[kv.first] = kv.second->data;
+ }
+ *rv = relay::backend::BindParamsByName(args[0], params_);
+ });
} // namespace backend
} // namespace relay
}
// TODO(@jroesch): MOVE ME
-TVM_REGISTER_GLOBAL("relay.ir.IsDynamic")
-.set_body_typed(IsDynamic);
+TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible
for (Var param : prim_func->params) {
Array<tvm::te::Tensor> inputs;
if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
- tvm::te::Tensor tensor = tvm::te::placeholder(
- GetShape(ttype->shape), ttype->dtype);
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
} else {
const auto* ttype = field.as<TensorTypeNode>();
// TODO(@icemelon): Allow recursive tuple
CHECK(ttype != nullptr);
- tvm::te::Tensor tensor = tvm::te::placeholder(
- GetShape(ttype->shape), ttype->dtype);
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
}
constexpr static size_t kMaxFuncNameLength = 80;
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
- truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}
CHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
- auto value = te::compute({}, [&](const Array<tvm::tir::Var>&) {
- if (dtype == DataType::Int(32)) {
- return make_const(dtype, static_cast<const int32_t*>(data)[0]);
- } else if (dtype == DataType::Int(64)) {
- return make_const(dtype, static_cast<const int64_t*>(data)[0]);
- } else if (dtype == DataType::Float(32)) {
- return make_const(dtype, static_cast<const float*>(data)[0]);
- } else if (dtype == DataType::Float(64)) {
- return make_const(dtype, static_cast<const double*>(data)[0]);
- } else if (dtype == DataType::Bool()) {
- return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
- } else {
- LOG(FATAL) << "not handled";
- return tvm::PrimExpr();
- }
- }, "compile_engine_const", topi::kBroadcast);
+ auto value = te::compute(
+ {},
+ [&](const Array<tvm::tir::Var>&) {
+ if (dtype == DataType::Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == DataType::Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == DataType::Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == DataType::Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::PrimExpr();
+ }
+ },
+ "compile_engine_const", topi::kBroadcast);
scalars_.push_back(value->op);
return {value};
}
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
- static auto fpattern =
- Op::GetAttr<TOpPattern>("TOpPattern");
+ static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
CHECK(flower_call) << "relay.backend.lower_call is not registered.";
}
}
if (count_tuple) {
- CHECK_EQ(call_node->args.size(), 1U)
- << "Only allow function with a single tuple input";
+ CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
}
- CHECK(call_node->op.as<OpNode>())
- << "Primitive function only allows call into primitive ops";
+ CHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
Array<te::Tensor> outputs;
// Skip fcompute for device copy operators as it is not registered.
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
- outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype,
- te::Operation(), 0));
+ outputs.push_back(
+ te::TensorNode::make(copy_input->shape, copy_input->dtype, te::Operation(), 0));
} else {
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
outputs = lowered_out->outputs;
int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) {
CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce)
- << "Two complicated op in a primitive function "
- << " master=" << master_op_ << " current=" << op;
+ << "Two complicated op in a primitive function "
+ << " master=" << master_op_ << " current=" << op;
}
if (op_pattern >= master_op_pattern_) {
master_op_ = op;
master_implementation_ = impl;
}
if (outputs.size() != 1) {
- const auto* tuple_type =
- call_node->checked_type().as<TupleTypeNode>();
+ const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
CHECK(tuple_type) << "Expect output to be a tuple type";
CHECK_EQ(tuple_type->fields.size(), outputs.size());
}
Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
Array<te::Tensor> fields;
for (Expr field : op->fields) {
- CHECK(field->checked_type().as<TensorTypeNode>())
- << "Only allow Tuple of Tensor";
+ CHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
Array<te::Tensor> res = VisitExpr(field);
CHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
shape_inputs.push_back(shape_tensor);
};
- if (const auto *ttype = param->checked_type().as<TensorTypeNode>()) {
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
add_placeholder(ttype);
} else {
// flatten tuple of tensor type.
- const auto *tuple_type = param->type_as<TupleTypeNode>();
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
// TODO(@icemelon): Support recursive tuple
CHECK(tuple_type);
for (Type field : tuple_type->fields) {
- const auto *ttype = field.as<TensorTypeNode>();
+ const auto* ttype = field.as<TensorTypeNode>();
CHECK(ttype);
add_placeholder(ttype);
}
constexpr static size_t kMaxFuncNameLength = 80;
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
- truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}
if (data_dependant) {
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
- auto value = tvm::te::compute({}, [&](const Array<tvm::tir::Var>&) {
- if (dtype == DataType::Int(32)) {
- return make_const(dtype, static_cast<const int32_t*>(data)[0]);
- } else if (dtype == DataType::Int(64)) {
- return make_const(dtype, static_cast<const int64_t*>(data)[0]);
- } else if (dtype == DataType::Float(32)) {
- return make_const(dtype, static_cast<const float*>(data)[0]);
- } else if (dtype == DataType::Float(64)) {
- return make_const(dtype, static_cast<const double*>(data)[0]);
- } else if (dtype == DataType::Bool()) {
- return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
- } else {
- LOG(FATAL) << "not handled";
- return tvm::PrimExpr();
- }
- }, "data_const", topi::kBroadcast);
+ auto value = tvm::te::compute(
+ {},
+ [&](const Array<tvm::tir::Var>&) {
+ if (dtype == DataType::Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == DataType::Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == DataType::Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == DataType::Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::PrimExpr();
+ }
+ },
+ "data_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
} else {
- auto value = tvm::te::compute({}, [&](const Array<tvm::tir::Var>&) {
- return tir::make_const(DataType::Int(64), 0);
- }, "shape_const", topi::kBroadcast);
+ auto value = tvm::te::compute(
+ {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
+ "shape_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
}
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
- static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>(
- "TShapeDataDependant");
- CHECK(call_node->op.as<OpNode>())
- << "Primitive function only allows call into primitive ops";
+ static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+ CHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
CHECK(data_dependants_.empty() || !data_dependants_.back())
- << "Error in op fusion: output of the shape func is fed to a "
- << "data-dependant shape func";
- CHECK_GT(fshape_func.count(op), 0)
- << "Internal error, cannot find ShapeFunc for " << op->name;
+ << "Error in op fusion: output of the shape func is fed to a "
+ << "data-dependant shape func";
+ CHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
CHECK_GT(tshape_data_dependant.count(op), 0)
- << "Internal error, cannot find TShapeDataDependant for " << op->name;
+ << "Internal error, cannot find TShapeDataDependant for " << op->name;
data_dependants_.push_back(tshape_data_dependant[op]);
// Visit all inputs
}
}
if (count_tuple) {
- CHECK_EQ(call_node->args.size(), 1U)
- << "Only allow function with a single tuple input";
+ CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
}
// Get output ndims
auto ret_type = call_node->checked_type();
Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
Array<te::Tensor> fields;
for (Expr field : op->fields) {
- CHECK(field->checked_type().as<TensorTypeNode>())
- << "Only allow Tuple of Tensor";
+ CHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
Array<te::Tensor> res = VisitExpr(field);
CHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
class CompileEngineImpl : public CompileEngineNode {
public:
// Lower the function.
- CachedFunc Lower(const CCacheKey& key) {
- return LowerInternal(key)->cached_func;
- }
+ CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
return ret;
}
- void Clear() final {
- cache_.clear();
- }
+ void Clear() final { cache_.clear(); }
// List all items in the cache.
Array<ObjectRef> ListItems() {
std::lock_guard<std::mutex> lock(mutex_);
private:
// implement lowered func
- CCacheValue LowerInternal(const CCacheKey& key) {
+ CCacheValue LowerInternal(const CCacheKey& key) {
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = cache_.find(key);
// codegen tool once and lower all functions together.
if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
- const auto name_node =
- key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- CHECK(name_node.defined())
- << "External function has not been attached a name yet.";
+ const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ CHECK(name_node.defined()) << "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node.value());
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
CHECK(!value->cached_func.defined());
auto cfunc = CreateSchedule(key->source_func, key->target);
- auto cache_node = make_object<CachedFuncNode>(
- *(cfunc.operator->()));
+ auto cache_node = make_object<CachedFuncNode>(*(cfunc.operator->()));
// Skip lowering for device copy node.
const Expr body = (key->source_func)->body;
}
// lower the function
if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
- cache_node->funcs = (*f)(
- cfunc->schedule, all_args, cache_node->func_name, key->source_func);
+ cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
} else {
tvm::BuildConfig bcfg = BuildConfig::Create();
std::unordered_map<te::Tensor, tir::Buffer> binds;
- cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name,
- binds, bcfg);
+ cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds, bcfg);
}
value->cached_func = CachedFunc(cache_node);
return value;
CHECK(!value->cached_func.defined());
auto spair = MakeShapeFunc().Create(key->source_func);
- auto cache_node = make_object<CachedFuncNode>(
- *(spair.second.operator->()));
+ auto cache_node = make_object<CachedFuncNode>(*(spair.second.operator->()));
cache_node->func_name = GetUniqueName(cache_node->func_name);
cache_node->target = key->target;
const CompileEngine& CompileEngine::Global() {
// intentionally allocate raw pointer to avoid
// free during destructuion.
- static CompileEngine* inst = new CompileEngine(
- make_object<CompileEngineImpl>());
+ static CompileEngine* inst = new CompileEngine(make_object<CompileEngineImpl>());
return *inst;
}
TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
-.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
- return LoweredOutput(outputs, impl);
-});
+ .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+ return LoweredOutput(outputs, impl);
+ });
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
-.set_body_typed([](Function source_func, Target target) {
- return CCacheKey(source_func, target);
-});
+ .set_body_typed([](Function source_func, Target target) {
+ return CCacheKey(source_func, target);
+ });
-TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
-.set_body_typed([]() {
+TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() {
return CompileEngine::Global();
});
-TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
-.set_body_typed([](CompileEngine self) {
+TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) {
self->Clear();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
-.set_body_typed(
- [](CompileEngine self, CCacheKey key) {
- return self->Lower(key);
-});
+ .set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
-.set_body_typed(
- [](CompileEngine self, CCacheKey key) {
- return self->LowerShapeFunc(key);
-});
+ .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); });
TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
-.set_body_typed([](CompileEngine self) {
- return self->LowerExternalFunctions();
-});
+ .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
-.set_body_typed(
- [](CompileEngine self, CCacheKey key) {
- return self->JIT(key);
-});
+ .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); });
-TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
-.set_body_typed(
- [](CompileEngine self){
+TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) {
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
} // namespace relay
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
-#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
-#include <tvm/relay/transform.h>
#include <tvm/relay/op_strategy.h>
-#include <string>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/module.h>
+
#include <functional>
+#include <string>
namespace tvm {
namespace relay {
*/
TVM_DLL CCacheKey(Function source_func, Target target);
- const CCacheKeyNode* operator->() const {
- return static_cast<const CCacheKeyNode*>(get());
- }
+ const CCacheKeyNode* operator->() const { return static_cast<const CCacheKeyNode*>(get()); }
// comparator
inline bool operator==(const CCacheKey& other) const {
CHECK(defined() && other.defined());
public:
CCacheValue() {}
explicit CCacheValue(ObjectPtr<Object> n) : ObjectRef(n) {}
- CCacheValueNode* operator->() {
- return static_cast<CCacheValueNode*>(get_mutable());
- }
- const CCacheValueNode* operator->() const {
- return static_cast<const CCacheValueNode*>(get());
- }
+ CCacheValueNode* operator->() { return static_cast<CCacheValueNode*>(get_mutable()); }
+ const CCacheValueNode* operator->() const { return static_cast<const CCacheValueNode*>(get()); }
using ContainerType = CCacheValueNode;
};
public:
CompileEngine() {}
explicit CompileEngine(ObjectPtr<Object> n) : ObjectRef(n) {}
- CompileEngineNode* operator->() {
- return static_cast<CompileEngineNode*>(get_mutable());
- }
+ CompileEngineNode* operator->() { return static_cast<CompileEngineNode*>(get_mutable()); }
using ContainerType = CompileEngineNode;
/*! \brief The global compile engine. */
TVM_DLL static const CompileEngine& Global();
if (hash_ != 0) return hash_;
// do structral hash, avoid 0.
hash_ = tvm::StructuralHash()(this->source_func);
- hash_ = dmlc::HashCombine(
- hash_, std::hash<std::string>()(target->str()));
+ hash_ = dmlc::HashCombine(hash_, std::hash<std::string>()(target->str()));
if (hash_ == 0) hash_ = 1;
return hash_;
}
-inline bool CCacheKeyNode::Equal(
- const CCacheKeyNode* other) const {
+inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
if (Hash() != other->Hash()) return false;
return this->target->str() == other->target->str() &&
- tvm::StructuralEqual()(this->source_func, other->source_func);
+ tvm::StructuralEqual()(this->source_func, other->source_func);
}
} // namespace relay
namespace std {
// overload hash
-template<>
+template <>
struct hash<::tvm::relay::CCacheKey> {
size_t operator()(const ::tvm::relay::CCacheKey& key) const {
CHECK(key.defined());
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
- buf_stream << dtype << "* " << out <<
- " = (" << dtype << "*)std::malloc(4 * " << out_size << ");";
+ buf_stream << dtype << "* " << out << " = (" << dtype << "*)std::malloc(4 * " << out_size
+ << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out << ");";
#define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_
#include <tvm/relay/expr.h>
-#include <tvm/relay/op.h>
#include <tvm/relay/function.h>
+#include <tvm/relay/op.h>
#include <tvm/runtime/container.h>
+
#include <sstream>
#include <string>
#include <utility>
* \return An external symbol.
*/
std::string GetExtSymbol(const Function& func) const {
- const auto name_node =
- func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node.value());
}
*
* \endcode
*/
- void GenerateBackendCFunc(const std::string& func_name,
- const Array<Var>& args,
+ void GenerateBackendCFunc(const std::string& func_name, const Array<Var>& args,
const Output& out) {
// Print signature
code_stream_ << "\n";
code_stream_ << "}\n\n";
// Generate the macro
- code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", "
- << func_name << "_wrapper_);\n\n";
+ code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name
+ << "_wrapper_);\n\n";
}
/*!
*/
std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
const std::vector<std::string>& buf_decl,
- const std::vector<std::string>& body,
- const std::vector<Output>& out) {
+ const std::vector<std::string>& body, const std::vector<Output>& out) {
// Create the signature. For example, it could be:
// extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
code_stream_ << "extern \"C\" void " << ext_func_id << "_(";
// Allocate large arrays on the static section to avoid stakc overflow.
// Note that this would probably increase compilation time as the source
// file could be really large.
- buf_stream << "static float " << output.name << "[" << num_elems <<"] = {";
+ buf_stream << "static float " << output.name << "[" << num_elems << "] = {";
for (int64_t i = 0; i < num_elems - 1; i++) {
buf_stream << ptr[i] << ",";
}
* \brief Memory index assignment pass for executing
* the program in the graph runtime.
*/
-#include <tvm/tir/op.h>
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/analysis.h>
+#include <tvm/tir/op.h>
+
#include "../../support/arena.h"
namespace tvm {
}
}
- void VisitExpr_(const ConstantNode* op) final {
- this->CreateToken(op, false);
- }
+ void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); }
void VisitExpr_(const VarNode* op) final {
// Do nothing.
token_map_[op] = {tok[op->index]};
}
- void VisitExpr_(const IfNode* op) final {
- LOG(FATAL) << "if is not supported.";
- }
+ void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }
void VisitExpr_(const LetNode* op) final {
auto token = GetToken(op->value);
class StorageAllocaInit : protected StorageAllocaBaseVisitor {
public:
- explicit StorageAllocaInit(support::Arena* arena)
- : arena_(arena) {}
+ explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {}
/*! \return The internal token map */
- std::unordered_map<const ExprNode*, std::vector<StorageToken*> >
- GetInitTokenMap(const Function& func) {
+ std::unordered_map<const ExprNode*, std::vector<StorageToken*> > GetInitTokenMap(
+ const Function& func) {
node_device_map_ = CollectDeviceInfo(func);
this->Run(func);
return std::move(token_map_);
protected:
using StorageAllocaBaseVisitor::VisitExpr_;
- void CreateToken(const ExprNode* op, bool can_realloc) final {
+ void CreateToken(const ExprNode* op, bool can_realloc) final {
CHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens;
- int device_type = node_device_map_.count(GetRef<Expr>(op))
- ? node_device_map_[GetRef<Expr>(op)]->value
- : 0;
+ int device_type =
+ node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[GetRef<Expr>(op)]->value : 0;
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) {
const auto* ttype = t.as<TensorTypeNode>();
}
// Either all or none of the nodes should be annotated.
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
- LOG(FATAL)
- << num_annotated_nodes << " out of " << num_nodes
- << "expressions are assigned with virtual device types. Either all "
- "or none of the expressions are expected to be annotated.";
+ LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
+ << "expressions are assigned with virtual device types. Either all "
+ "or none of the expressions are expected to be annotated.";
}
return smap;
}
size_t size = 1;
for (IndexExpr dim : ttype->shape) {
const int64_t* pval = tir::as_const_int(dim);
- CHECK(pval != nullptr)
- << "Cannot allocate memory symbolic tensor shape "
- << ttype->shape;
- CHECK_GE(*pval, 0)
- << "Cannot allocate memory for tensor with negative shape"
- << *pval;
+ CHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape;
+ CHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval;
size *= static_cast<size_t>(pval[0]);
}
size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
auto end = free_.upper_bound(size * match_range_);
// search for memory blocks larger than requested
for (auto it = mid; it != end; ++it) {
- StorageToken *tok = it->second;
+ StorageToken* tok = it->second;
if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy
// then search for memory blocks smaller than requested space
for (auto it = mid; it != begin;) {
--it;
- StorageToken *tok = it->second;
+ StorageToken* tok = it->second;
if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy
return StorageAllocator().Plan(func);
}
-TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
-.set_body_typed(GraphPlanMemory);
+TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory);
} // namespace relay
} // namespace tvm
class GraphOpNode;
using IntegerArray = Array<Integer>;
-using ShapeVector = std::vector<std::vector<int64_t> >;
+using ShapeVector = std::vector<std::vector<int64_t>>;
using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
using GraphObjectPtr = std::shared_ptr<GraphNode>;
using GraphInputObjectPtr = std::shared_ptr<GraphInputNode>;
public:
GraphNodeRef() {}
GraphNodeRef(int ident, int index, int version = 0)
- : ident_(ident), index_(index), version_(version) {}
-
+ : ident_(ident), index_(index), version_(version) {}
inline void Save(dmlc::JSONWriter* writer) const {
writer->BeginArray();
writer->EndArray();
}
- inline void Load(dmlc::JSONReader* reader) {
- LOG(FATAL) << "Not implemented.";
- }
+ inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; }
protected:
int ident_;
class GraphOpNode : public GraphNode {
public:
GraphOpNode() {}
- GraphOpNode(const std::string& name,
- const GraphAttrs& nd_attrs,
- const std::string& op_name,
- const std::vector<GraphNodeRef>& inputs,
- const GraphAttrs& attrs,
+ GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name,
+ const std::vector<GraphNodeRef>& inputs, const GraphAttrs& attrs,
size_t num_outputs = 1) {
name_ = name;
attrs_ = nd_attrs;
const GraphAttrs& nd_attrs,
const std::string& op_name,
const std::vector<GraphNodeRef>& inputs,
- const GraphAttrs& attrs,
- size_t num_outputs = 1) {
+ const GraphAttrs& attrs, size_t num_outputs = 1) {
auto ptr = std::make_shared<GraphOpNode>(name, nd_attrs, op_name, inputs, attrs, num_outputs);
return std::dynamic_pointer_cast<GraphNode>(ptr);
}
return fields;
}
- std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op,
- const std::string& op_name,
+ std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& op_name,
const std::string& func_name) {
std::vector<GraphNodeRef> inputs;
for (auto arg : op->args) {
inputs.push_back(nr);
}
}
- auto node = GraphOpNode::make_node_ptr(op_name,
- GraphAttrs(),
- func_name,
- inputs,
- GraphAttrs());
+ auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, GraphAttrs());
return AddNode(node, GetRef<Expr>(op));
}
}
CHECK_GE(storage_device_map_.count(expr), 0);
- auto &device_type = storage_device_map_[expr][1];
+ auto& device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value;
// Normal Relay Function
if (targets_.size() == 1) {
- // homogeneous execution.
+ // homogeneous execution.
const auto& it = targets_.begin();
target = (*it).second;
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_type) == 0) {
- LOG(FATAL) << "No target is provided for device "
- << call_dev_name;
+ LOG(FATAL) << "No target is provided for device " << call_dev_name;
}
target = targets_[call_dev_type];
}
lowered_funcs_[target->str()] = IRModule::Empty();
}
lowered_funcs_[target->str()]->Update(lowered_func->funcs);
- return GraphAddCallNode(op,
- _GetUniqueName(lowered_func->func_name),
- lowered_func->func_name);
+ return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name);
}
std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
class GraphRuntimeCodegenModule : public runtime::ModuleNode {
public:
GraphRuntimeCodegenModule() {}
- virtual PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
- if (name == "init") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- CHECK_EQ(args.num_args, 2)
- << "The expected of arguments are: "
- << "runtime::Module mod and Map<int, Target> targets";
- void* mod = args[0];
- Map<Integer, tvm::Target> tmp = args[1];
- TargetsMap targets;
- for (const auto& it : tmp) {
- auto dev_type = it.first.as<tir::IntImmNode>();
- CHECK(dev_type);
- targets[dev_type->value] = it.second;
- }
- codegen_ = std::make_shared<GraphRuntimeCodegen>(
- reinterpret_cast<runtime::Module*>(mod), targets);
- });
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
+ if (name == "init") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ CHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
+ << "runtime::Module mod and Map<int, Target> targets";
+ void* mod = args[0];
+ Map<Integer, tvm::Target> tmp = args[1];
+ TargetsMap targets;
+ for (const auto& it : tmp) {
+ auto dev_type = it.first.as<tir::IntImmNode>();
+ CHECK(dev_type);
+ targets[dev_type->value] = it.second;
+ }
+ codegen_ =
+ std::make_shared<GraphRuntimeCodegen>(reinterpret_cast<runtime::Module*>(mod), targets);
+ });
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Function func = args[0];
this->output_ = this->codegen_->Codegen(func);
});
} else if (name == "get_graph_json") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->output_.graph_json;
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; });
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Array<runtime::String> ret;
- for (const auto &kv : this->output_.params) {
+ for (const auto& kv : this->output_.params) {
ret.push_back(kv.first);
}
*rv = ret;
}
}
- const char* type_key() const final {
- return "RelayGraphRuntimeCodegenModule";
- }
+ const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; }
private:
std::shared_ptr<GraphRuntimeCodegen> codegen_;
}
TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = CreateGraphCodegenMod();
-});
+ .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });
} // namespace backend
} // namespace relay
* \file src/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
-#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/object.h>
-#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pattern_functor.h>
-#include <tvm/relay/interpreter.h>
-#include <tvm/relay/transform.h>
+#include <tvm/driver/driver_api.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
-#include <tvm/driver/driver_api.h>
+#include <tvm/relay/interpreter.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/object.h>
#include "compile_engine.h"
using namespace runtime;
-InterpreterClosure::InterpreterClosure(tvm::Map<Var, ObjectRef> env,
- Function func) {
+InterpreterClosure::InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func) {
ObjectPtr<InterpreterClosureObj> n = make_object<InterpreterClosureObj>();
n->env = std::move(env);
n->func = std::move(func);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<InterpreterClosureObj >([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const InterpreterClosureObj*>(ref.get());
- p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
-});
+ .set_dispatch<InterpreterClosureObj>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const InterpreterClosureObj*>(ref.get());
+ p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
+ });
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RecClosureObj>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const RecClosureObj*>(ref.get());
- p->stream << "RecClosureObj(" << node->clos << ")";
- });
+ .set_dispatch<RecClosureObj>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const RecClosureObj*>(ref.get());
+ p->stream << "RecClosureObj(" << node->clos << ")";
+ });
RefValue::RefValue(ObjectRef value) {
ObjectPtr<RefValueObj> n = make_object<RefValueObj>();
data_ = std::move(n);
}
-TVM_REGISTER_GLOBAL("relay._make.RefValue")
-.set_body_typed([](ObjectRef value){
+TVM_REGISTER_GLOBAL("relay._make.RefValue").set_body_typed([](ObjectRef value) {
return RefValue(value);
});
TVM_REGISTER_NODE_TYPE(RefValueObj);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RefValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const RefValueObj*>(ref.get());
- p->stream << "RefValueObj(" << node->value << ")";
- });
+ .set_dispatch<RefValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const RefValueObj*>(ref.get());
+ p->stream << "RefValueObj(" << node->value << ")";
+ });
-ConstructorValue::ConstructorValue(int32_t tag,
- tvm::Array<ObjectRef> fields,
+ConstructorValue::ConstructorValue(int32_t tag, tvm::Array<ObjectRef> fields,
Constructor constructor) {
ObjectPtr<ConstructorValueObj> n = make_object<ConstructorValueObj>();
n->tag = tag;
}
TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
-.set_body_typed([](int32_t tag, tvm::Array<ObjectRef> fields,
- Constructor constructor) {
- return ConstructorValue(tag, fields, constructor);
-});
+ .set_body_typed([](int32_t tag, tvm::Array<ObjectRef> fields, Constructor constructor) {
+ return ConstructorValue(tag, fields, constructor);
+ });
TVM_REGISTER_NODE_TYPE(ConstructorValueObj);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const ConstructorValueObj*>(ref.get());
- p->stream << "ConstructorValueObj(" << node->tag << ","
- << node->fields << ")";
-});
+ .set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const ConstructorValueObj*>(ref.get());
+ p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")";
+ });
/*!
* \brief A stack frame in the Relay interpreter.
*/
struct LocalFrame {
Stack& st;
- explicit LocalFrame(Stack& st, const Frame& fr) : st(st) {
- st.frames.push_back(fr);
- }
+ explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { st.frames.push_back(fr); }
~LocalFrame() { st.frames.pop_back(); }
};
};
// contains DAG in dataflow-form.
//
// Conversion to ANF is recommended before running the interpretation.
-class Interpreter :
- public ExprFunctor<ObjectRef(const Expr& n)>,
- PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
+class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
+ PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
Interpreter(IRModule mod, DLContext context, Target target)
: mod_(mod),
return f();
}
- void extend(const Var& id, ObjectRef v) {
- stack_.current_frame().locals.Set(id, v);
- }
+ void extend(const Var& id, ObjectRef v) { stack_.current_frame().locals.Set(id, v); }
- ObjectRef Lookup(const Var& local) {
- return stack_.Lookup(local);
- }
+ ObjectRef Lookup(const Var& local) { return stack_.Lookup(local); }
- ObjectRef Eval(const Expr& expr) {
- return VisitExpr(expr);
- }
+ ObjectRef Eval(const Expr& expr) { return VisitExpr(expr); }
- ObjectRef VisitExpr_(const VarNode* var_node) final {
- return Lookup(GetRef<Var>(var_node));
- }
+ ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef<Var>(var_node)); }
ObjectRef VisitExpr_(const GlobalVarNode* op) final {
return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
return ObjectRef();
}
- ObjectRef VisitExpr_(const ConstantNode* op) final {
- return op->data.CopyTo(context_);
- }
+ ObjectRef VisitExpr_(const ConstantNode* op) final { return op->data.CopyTo(context_); }
ObjectRef VisitExpr_(const TupleNode* op) final {
std::vector<ObjectRef> values;
return MakeClosure(func);
}
- Array<Shape> ComputeDynamicShape(const Function& func,
- const Array<ObjectRef>& args) {
+ Array<Shape> ComputeDynamicShape(const Function& func, const Array<ObjectRef>& args) {
CCacheKey key(func, Target::Create("llvm"));
auto cfunc = engine_->LowerShapeFunc(key);
size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
cpu_ctx.device_id = 0;
auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) {
- auto nd_array = Downcast<NDArray>(val);
- if (need_shape) {
- int64_t ndim = nd_array.Shape().size();
- NDArray shape_arr;
- if (ndim == 0) {
- shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx);
- } else {
- shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx);
- int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
- for (auto j = 0; j < ndim; ++j) {
- data[j] = nd_array.Shape()[j];
- }
- }
- inputs[i] = shape_arr;
- setter(i, shape_arr);
+ auto nd_array = Downcast<NDArray>(val);
+ if (need_shape) {
+ int64_t ndim = nd_array.Shape().size();
+ NDArray shape_arr;
+ if (ndim == 0) {
+ shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx);
} else {
- auto arr = nd_array.CopyTo(cpu_ctx);
- inputs[i] = arr;
- setter(i, arr);
+ shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx);
+ int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
+ for (auto j = 0; j < ndim; ++j) {
+ data[j] = nd_array.Shape()[j];
+ }
}
+ inputs[i] = shape_arr;
+ setter(i, shape_arr);
+ } else {
+ auto arr = nd_array.CopyTo(cpu_ctx);
+ inputs[i] = arr;
+ setter(i, arr);
+ }
};
size_t arg_counter = 0;
}
}
}
- CHECK_EQ(arg_counter, cfunc->inputs.size())
- << "Shape function input sizes mismatch";
+ CHECK_EQ(arg_counter, cfunc->inputs.size()) << "Shape function input sizes mismatch";
auto fset_shape_output = [&](size_t i, Type val_type) {
- // TODO(@icemelon): allow recursive tuple
- const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
- CHECK(rtype != nullptr);
- int64_t ndim = rtype->shape.size();
- auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx);
- outputs[i] = arr;
- setter(arg_counter + i, arr);
+ // TODO(@icemelon): allow recursive tuple
+ const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
+ CHECK(rtype != nullptr);
+ int64_t ndim = rtype->shape.size();
+ auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx);
+ outputs[i] = arr;
+ setter(arg_counter + i, arr);
};
auto ret_type = func->body->checked_type();
auto tt = Downcast<TensorType>(ret_type);
fset_shape_output(0, tt);
}
- CHECK_EQ(cfunc->outputs.size(), out_cnt)
- << "Shape function output sizes mismatch";
+ CHECK_EQ(cfunc->outputs.size(), out_cnt) << "Shape function output sizes mismatch";
PackedFunc shape_func;
Module m;
return out_shapes;
}
- ObjectRef InvokePrimitiveOp(const Function& func,
- const Array<ObjectRef>& args) {
+ ObjectRef InvokePrimitiveOp(const Function& func, const Array<ObjectRef>& args) {
const auto* call_node = func->body.as<CallNode>();
if (call_node && call_node->op == debug_op_) {
if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) {
arg_len += tuple_type->fields.size();
} else {
- CHECK(func->body->checked_type().as<TensorTypeNode>())
- << func->body->checked_type();
+ CHECK(func->body->checked_type().as<TensorTypeNode>()) << func->body->checked_type();
arg_len += 1;
}
std::vector<TVMValue> values(arg_len);
const auto nd_array = Downcast<NDArray>(val);
setter(i, nd_array);
DLContext arg_ctx = nd_array->ctx;
- CHECK(arg_ctx.device_type == context_.device_type &&
- arg_ctx.device_id == context_.device_id)
- << "Interpreter expect context to be "
- << context_ << ", but get " << arg_ctx;
+ CHECK(arg_ctx.device_type == context_.device_type && arg_ctx.device_id == context_.device_id)
+ << "Interpreter expect context to be " << context_ << ", but get " << arg_ctx;
};
int arg_counter = 0;
for (ObjectRef arg : args) {
if (arg->IsInstance<NDArray::ContainerType>()) {
- fset_input(arg_counter++, arg);
+ fset_input(arg_counter++, arg);
} else {
auto adt = Downcast<ADT>(arg);
for (size_t i = 0; i < adt.size(); ++i) {
}
// Invoke the closure
- ObjectRef Invoke(const InterpreterClosure& closure,
- const tvm::Array<ObjectRef>& args,
+ ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array<ObjectRef>& args,
const Var& bind = Var()) {
// Get a reference to the function inside the closure.
if (closure->func->HasNonzeroAttr(attr::kPrimitive)) {
ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
ObjectRef val = Eval(op->tuple);
const auto* adt_obj = val.as<ADTObj>();
- CHECK(adt_obj)
- << "interal error: when evaluating TupleGetItem expected an ADT value";
+ CHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value";
auto adt = GetRef<ADT>(adt_obj);
- CHECK_LT(static_cast<size_t>(op->index), adt.size())
- << "internal error: index out of bounds";
+ CHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
return adt[op->index];
}
}
}
- ObjectRef VisitExpr_(const RefCreateNode* op) final {
- return RefValue(Eval(op->value));
- }
+ ObjectRef VisitExpr_(const RefCreateNode* op) final { return RefValue(Eval(op->value)); }
ObjectRef VisitExpr_(const RefReadNode* op) final {
ObjectRef r = Eval(op->ref);
return true;
}
- bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final {
- return true;
- }
+ bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { return true; }
bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final {
extend(op->var, v);
const Op& shape_of_op_;
};
-
-TypedPackedFunc<ObjectRef(Expr)>
-CreateInterpreter(
- IRModule mod,
- DLContext context,
- Target target) {
+TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLContext context, Target target) {
if (mod.defined()) {
// eta expand to support constructors in argument position
- transform::Sequential seq({
- transform::EtaExpand(
- /* expand_constructor */ true, /* expand_global_var */ false)});
+ transform::Sequential seq({transform::EtaExpand(
+ /* expand_constructor */ true, /* expand_global_var */ false)});
transform::PassContext pass_ctx = transform::PassContext::Current();
tvm::With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);
return TypedPackedFunc<ObjectRef(Expr)>(packed);
}
-TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
-.set_body_typed(CreateInterpreter);
+TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter").set_body_typed(CreateInterpreter);
} // namespace relay
} // namespace tvm
* \brief Implementation and registration of parameter dictionary
* serializing/deserializing functions.
*/
-#include <tvm/runtime/registry.h>
+#include "param_dict.h"
+
#include <dmlc/memory_io.h>
+#include <tvm/runtime/registry.h>
#include <string>
-#include <vector>
#include <utility>
-
-#include "param_dict.h"
-
-
+#include <vector>
namespace tvm {
namespace relay {
using namespace runtime;
-TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- CHECK_EQ(args.size() % 2, 0u);
- // `args` is in the form "key, value, key, value, ..."
- size_t num_params = args.size() / 2;
- std::vector<std::string> names;
- names.reserve(num_params);
- std::vector<DLTensor*> arrays;
- arrays.reserve(num_params);
- for (size_t i = 0; i < num_params * 2; i += 2) {
- names.emplace_back(args[i].operator std::string());
- arrays.emplace_back(args[i + 1].operator DLTensor*());
- }
- std::string bytes;
- dmlc::MemoryStringStream strm(&bytes);
- dmlc::Stream* fo = &strm;
- uint64_t header = kTVMNDArrayListMagic, reserved = 0;
- fo->Write(header);
- fo->Write(reserved);
- fo->Write(names);
- {
- uint64_t sz = static_cast<uint64_t>(arrays.size());
- fo->Write(sz);
- for (size_t i = 0; i < sz; ++i) {
- tvm::runtime::SaveDLTensor(fo, arrays[i]);
- }
+TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) {
+ CHECK_EQ(args.size() % 2, 0u);
+ // `args` is in the form "key, value, key, value, ..."
+ size_t num_params = args.size() / 2;
+ std::vector<std::string> names;
+ names.reserve(num_params);
+ std::vector<DLTensor*> arrays;
+ arrays.reserve(num_params);
+ for (size_t i = 0; i < num_params * 2; i += 2) {
+ names.emplace_back(args[i].operator std::string());
+ arrays.emplace_back(args[i + 1].operator DLTensor*());
+ }
+ std::string bytes;
+ dmlc::MemoryStringStream strm(&bytes);
+ dmlc::Stream* fo = &strm;
+ uint64_t header = kTVMNDArrayListMagic, reserved = 0;
+ fo->Write(header);
+ fo->Write(reserved);
+ fo->Write(names);
+ {
+ uint64_t sz = static_cast<uint64_t>(arrays.size());
+ fo->Write(sz);
+ for (size_t i = 0; i < sz; ++i) {
+ tvm::runtime::SaveDLTensor(fo, arrays[i]);
}
- TVMByteArray arr;
- arr.data = bytes.c_str();
- arr.size = bytes.length();
- *rv = arr;
- });
+ }
+ TVMByteArray arr;
+ arr.data = bytes.c_str();
+ arr.size = bytes.length();
+ *rv = arr;
+});
-TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- std::string bytes = args[0];
- std::vector<std::string> names;
- dmlc::MemoryStringStream memstrm(&bytes);
- dmlc::Stream* strm = &memstrm;
- uint64_t header, reserved;
- CHECK(strm->Read(&header))
- << "Invalid parameters file format";
- CHECK(header == kTVMNDArrayListMagic)
- << "Invalid parameters file format";
- CHECK(strm->Read(&reserved))
- << "Invalid parameters file format";
- CHECK(strm->Read(&names))
- << "Invalid parameters file format";
- uint64_t sz;
- strm->Read(&sz, sizeof(sz));
- size_t size = static_cast<size_t>(sz);
- CHECK(size == names.size())
- << "Invalid parameters file format";
- tvm::Array<NamedNDArray> ret;
- for (size_t i = 0; i < size; ++i) {
- tvm::runtime::NDArray temp;
- temp.Load(strm);
- auto n = tvm::make_object<NamedNDArrayNode>();
- n->name = std::move(names[i]);
- n->array = temp;
- ret.push_back(NamedNDArray(n));
- }
- *rv = ret;
- });
+TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) {
+ std::string bytes = args[0];
+ std::vector<std::string> names;
+ dmlc::MemoryStringStream memstrm(&bytes);
+ dmlc::Stream* strm = &memstrm;
+ uint64_t header, reserved;
+ CHECK(strm->Read(&header)) << "Invalid parameters file format";
+ CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
+ CHECK(strm->Read(&reserved)) << "Invalid parameters file format";
+ CHECK(strm->Read(&names)) << "Invalid parameters file format";
+ uint64_t sz;
+ strm->Read(&sz, sizeof(sz));
+ size_t size = static_cast<size_t>(sz);
+ CHECK(size == names.size()) << "Invalid parameters file format";
+ tvm::Array<NamedNDArray> ret;
+ for (size_t i = 0; i < size; ++i) {
+ tvm::runtime::NDArray temp;
+ temp.Load(strm);
+ auto n = tvm::make_object<NamedNDArrayNode>();
+ n->name = std::move(names[i]);
+ n->array = temp;
+ ret.push_back(NamedNDArray(n));
+ }
+ *rv = ret;
+});
TVM_REGISTER_NODE_TYPE(NamedNDArrayNode);
#define TVM_RELAY_BACKEND_PARAM_DICT_H_
#include <tvm/node/node.h>
-#include <tvm/tir/expr.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
+#include <tvm/tir/expr.h>
#include <string>
} // namespace relay
} // namespace tvm
-
#endif // TVM_RELAY_BACKEND_UTILS_H_
* \brief A compiler from relay::Module to the VM byte code.
*/
-#include <tvm/te/operation.h>
+#include "compiler.h"
+
+#include <tvm/driver/driver_api.h>
#include <tvm/ir/error.h>
+#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
-#include <tvm/support/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
-#include <tvm/relay/attrs/memory.h>
-#include <tvm/driver/driver_api.h>
+#include <tvm/support/logging.h>
+#include <tvm/te/operation.h>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
-#include "../utils.h"
+
#include "../../backend/compile_engine.h"
-#include "../../transforms/pass_util.h"
#include "../../op/op_common.h"
-#include "compiler.h"
+#include "../../transforms/pass_util.h"
+#include "../utils.h"
namespace tvm {
namespace relay {
// Runtime register num after compiling the access field path
RegName reg{-1};
- AccessField(MatchValuePtr parent, size_t index)
- : parent(parent), index(index) {}
+ AccessField(MatchValuePtr parent, size_t index) : parent(parent), index(index) {}
~AccessField() {}
};
Var var;
MatchValuePtr val;
- VarBinding(Var var, MatchValuePtr val)
- : var(var), val(val) {}
+ VarBinding(Var var, MatchValuePtr val) : var(var), val(val) {}
~VarBinding() {}
};
/*! \brief The expected tag */
int target_tag;
- TagCompare(MatchValuePtr obj, size_t target)
- : obj(obj), target_tag(target) {
- }
+ TagCompare(MatchValuePtr obj, size_t target) : obj(obj), target_tag(target) {}
~TagCompare() {}
};
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionObjectPtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionObjectPtr>;
-TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
- Pattern pattern,
- TreeObjectPtr then_branch,
- TreeObjectPtr else_branch) {
+TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern,
+ TreeObjectPtr then_branch, TreeObjectPtr else_branch) {
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
}
}
-TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data,
- Clause clause,
- TreeObjectPtr else_branch) {
- return BuildDecisionTreeFromPattern(data, clause->lhs,
- TreeLeafNode::Make(clause->rhs), else_branch);
+TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause,
+ TreeObjectPtr else_branch) {
+ return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs),
+ else_branch);
}
TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
std::vector<int64_t> raw_shape;
CHECK_EQ(shape->ndim, 1u);
- CHECK_EQ(shape->dtype.code, 0U)
- << "The dtype of constant shape must be int32 or int64, but got "
- << DLDataType2String(shape->dtype);
+ CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got "
+ << DLDataType2String(shape->dtype);
CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
- << "The dtype of constant shape must be int32 or int64, but got"
- << DLDataType2String(shape->dtype);
+ << "The dtype of constant shape must be int32 or int64, but got"
+ << DLDataType2String(shape->dtype);
if (shape->dtype.bits == 64) {
int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
return raw_shape;
}
-
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
}
// TODO(@jroesch): use correct tag
- Emit(Instruction::AllocADT(
- 0,
- tuple->fields.size(),
- fields_registers,
- NewRegister()));
+ Emit(Instruction::AllocADT(0, tuple->fields.size(), fields_registers, NewRegister()));
}
void VisitExpr_(const MatchNode* match_node) {
for (auto input : inputs) {
auto reg = var_register_map_.find(Downcast<Var>(input));
CHECK(reg != var_register_map_.end())
- << "internal error: all variables should be in the register mapping";
+ << "internal error: all variables should be in the register mapping";
argument_registers.push_back(reg->second);
}
for (auto output : outputs) {
auto reg = var_register_map_.find(Downcast<Var>(output));
CHECK(reg != var_register_map_.end())
- << "internal error: all variables should be in the register mapping";
+ << "internal error: all variables should be in the register mapping";
argument_registers.push_back(reg->second);
}
- Emit(Instruction::InvokePacked(op_index,
- argument_registers.size(),
- outputs.size(),
- argument_registers));
+ Emit(Instruction::InvokePacked(op_index, argument_registers.size(), outputs.size(),
+ argument_registers));
}
- void EmitInvokeTVMOp(const Function& func,
- const Expr& inputs,
- const Expr& outputs) {
+ void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) {
std::vector<Index> argument_registers;
CHECK(func->GetAttr<Integer>(attr::kPrimitive, 0) != 0)
- << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
+ << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
auto input_tuple = inputs.as<TupleNode>();
- CHECK(input_tuple)
- << "internal error: invoke_tvm_op inputs must be a tuple,"
- << "please file a bug in the memory manifestation pass";
+ CHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple,"
+ << "please file a bug in the memory manifestation pass";
auto output_tuple = outputs.as<TupleNode>();
- CHECK(output_tuple)
- << "internal error: invoke_tvm_op outputs must be a tuple,"
- << "please file a bug in the memory manifestation pass";
+ CHECK(output_tuple) << "internal error: invoke_tvm_op outputs must be a tuple,"
+ << "please file a bug in the memory manifestation pass";
for (auto input : input_tuple->fields) {
auto reg = var_register_map_.find(Downcast<Var>(input));
CHECK(reg != var_register_map_.end())
- << "internal error: all variables should be in the register mapping";
+ << "internal error: all variables should be in the register mapping";
argument_registers.push_back(reg->second);
}
for (auto output : output_tuple->fields) {
auto reg = var_register_map_.find(Downcast<Var>(output));
CHECK(reg != var_register_map_.end())
- << "internal error: all variables should be in the register mapping";
+ << "internal error: all variables should be in the register mapping";
argument_registers.push_back(reg->second);
}
}
}
- Emit(Instruction::InvokePacked(op_index,
- argument_registers.size(),
- output_tuple->fields.size(),
- argument_registers));
+ Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(),
+ argument_registers));
}
void VisitExpr_(const CallNode* call_node) {
// allocation operations.
if (op.as<OpNode>()) {
OpMatch<void> matcher;
- matcher.Match("memory.invoke_tvm_op",
- [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
- CHECK_EQ(args.size(), 3);
- EmitInvokeTVMOp(Downcast<Function>(args[0]), args[1], args[2]);
- }).Match("memory.alloc_tensor",
- [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
- CHECK_EQ(args.size(), 2);
-
- // Get the attributes.
- auto alloc_attrs = attrs.as<AllocTensorAttrs>();
- CHECK(alloc_attrs != nullptr)
- << "must be the alloc tensor attrs";
- auto dtype = alloc_attrs->dtype;
-
- // The storage will be passed dynamically.
- this->VisitExpr(args[0]);
- auto storage_register = last_register_;
-
- // If the shape is constant then we will emit a static tensor allocation instruction.
- auto const_shape = args[1].as<ConstantNode>();
-
- if (const_shape) {
- NDArray shape = const_shape->data;
- // TODO(@jroesch): we need to get an RFC done to standarize shape dtype
- std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);
- // Add context field.
- Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
- } else {
- this->VisitExpr(args[1]);
- auto shape_register = last_register_;
- Emit(Instruction::AllocTensorReg(
- storage_register,
- shape_register,
- dtype,
- NewRegister()));
- }
- }).Match("memory.alloc_storage",
- [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
- CHECK_EQ(args.size(), 2);
- // Compute the size of the allocation.
- this->VisitExpr(args[0]);
- auto size_register = last_register_;
-
- this->VisitExpr(args[1]);
- auto alignment_register = last_register_;
-
- // Get the dtype hint from the attributes.
- auto alloc_attrs = attrs.as<AllocStorageAttrs>();
- CHECK(alloc_attrs != nullptr)
- << "must be the alloc tensor attrs";
- auto dtype = alloc_attrs->dtype;
-
- Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister()));
- }).Match("memory.shape_func",
- [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
- CHECK_EQ(args.size(), 3);
- auto shape_func = Downcast<Function>(args[0]);
- auto inputs = Downcast<Tuple>(args[1]);
- auto outputs = Downcast<Tuple>(args[2]);
- EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
- }).Match("memory.kill",
- [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
- LOG(FATAL) << "memory.kill is not yet supported";
- });
+ matcher
+ .Match("memory.invoke_tvm_op",
+ [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+ CHECK_EQ(args.size(), 3);
+ EmitInvokeTVMOp(Downcast<Function>(args[0]), args[1], args[2]);
+ })
+ .Match(
+ "memory.alloc_tensor",
+ [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+ CHECK_EQ(args.size(), 2);
+
+ // Get the attributes.
+ auto alloc_attrs = attrs.as<AllocTensorAttrs>();
+ CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs";
+ auto dtype = alloc_attrs->dtype;
+
+ // The storage will be passed dynamically.
+ this->VisitExpr(args[0]);
+ auto storage_register = last_register_;
+
+ // If the shape is constant then we will emit a static tensor allocation
+ // instruction.
+ auto const_shape = args[1].as<ConstantNode>();
+
+ if (const_shape) {
+ NDArray shape = const_shape->data;
+ // TODO(@jroesch): we need to get an RFC done to standarize shape dtype
+ std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);
+ // Add context field.
+ Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
+ } else {
+ this->VisitExpr(args[1]);
+ auto shape_register = last_register_;
+ Emit(Instruction::AllocTensorReg(storage_register, shape_register, dtype,
+ NewRegister()));
+ }
+ })
+ .Match("memory.alloc_storage",
+ [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+ CHECK_EQ(args.size(), 2);
+ // Compute the size of the allocation.
+ this->VisitExpr(args[0]);
+ auto size_register = last_register_;
+
+ this->VisitExpr(args[1]);
+ auto alignment_register = last_register_;
+
+ // Get the dtype hint from the attributes.
+ auto alloc_attrs = attrs.as<AllocStorageAttrs>();
+ CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs";
+ auto dtype = alloc_attrs->dtype;
+
+ Emit(Instruction::AllocStorage(size_register, alignment_register, dtype,
+ NewRegister()));
+ })
+ .Match("memory.shape_func",
+ [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+ CHECK_EQ(args.size(), 3);
+ auto shape_func = Downcast<Function>(args[0]);
+ auto inputs = Downcast<Tuple>(args[1]);
+ auto outputs = Downcast<Tuple>(args[2]);
+ EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
+ })
+ .Match("memory.kill",
+ [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+ LOG(FATAL) << "memory.kill is not yet supported";
+ });
matcher(GetRef<Call>(call_node));
return;
}
auto it = context_->global_map.find(global);
CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
- << " with func_index=" << it->second;
+ << " with func_index=" << it->second;
// TODO(tvm-team):
// Think about mixed call into global that is not a relay::Function
// perhaps establish as an invariance(all functions in mod must be relay::Function)
auto func = Downcast<Function>(context_->module->Lookup(global));
-
if (IsClosure(func)) {
auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
Target target_host_;
};
-
-PackedFunc VMCompiler::GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "lower") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
this->Codegen();
});
} else if (name == "get_executable") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = runtime::Module(exec_);
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = runtime::Module(exec_); });
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
params_[name] = data_in;
}
-void VMCompiler::Lower(IRModule mod,
- const TargetsMap& targets,
- const tvm::Target& target_host) {
- CHECK_EQ(targets.size(), 1)
- << "Currently VM compiler doesn't support heterogeneous compilation";
+void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
+ CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
// eta expand to support constructors in argument position
pass_seqs.push_back(transform::EtaExpand(
- /* expand_constructor */ true, /* expand_global_var */ false));
+ /* expand_constructor */ true, /* expand_global_var */ false));
pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
return;
}
- auto const &cached_funcs = context_.cached_funcs;
+ auto const& cached_funcs = context_.cached_funcs;
if (cached_funcs.size() == 0) {
return;
}
return runtime::Module(exec);
}
-TVM_REGISTER_GLOBAL("relay._vm._VMCompiler")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("relay._vm._VMCompiler").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateVMCompiler();
});
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
-#include <tvm/support/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
+#include <tvm/support/logging.h>
+#include <tvm/tir/function.h>
+
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
-#include "../../../runtime/vm/profiler/vm.h"
+
#include "../../../runtime/vm/naive_allocator.h"
+#include "../../../runtime/vm/profiler/vm.h"
#include "../../backend/compile_engine.h"
#include "../../transforms/pass_util.h"
std::unordered_map<tir::PrimFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
};
-
class VMCompiler : public runtime::ModuleNode {
public:
virtual ~VMCompiler() {}
- virtual PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self);
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
- const char* type_key() const {
- return "VMCompiler";
- }
+ const char* type_key() const { return "VMCompiler"; }
/*!
* \brief Set the parameters
to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device.
*/
- void Lower(IRModule mod,
- const TargetsMap& targets,
- const tvm::Target& target_host);
+ void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host);
/*! \brief Generate the machine code for lowered functions. */
void Codegen();
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/support/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
+#include <tvm/support/logging.h>
+
#include <iostream>
#include <vector>
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
- DLOG(INFO) << "Before inlining primitives: " << global
- << std::endl << AsText(func, false);
+ DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false);
- func = Function(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
+ func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Add(global, func, true);
- DLOG(INFO) << "After inlining primitives: " << global
- << std::endl << AsText(func, false);
+ DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false);
}
}
return module_;
Pass InlinePrimitives() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
- return relay::vm::PrimitiveInliner(m).Inline();
- };
+ [=](IRModule m, PassContext pc) { return relay::vm::PrimitiveInliner(m).Inline(); };
auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {});
// Eliminate dead code for each function after inlining.
return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives");
}
-TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives")
-.set_body_typed(InlinePrimitives);
+TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives").set_body_typed(InlinePrimitives);
} // namespace transform
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/support/logging.h>
-#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
+#include <tvm/support/logging.h>
+
#include <iostream>
#include <vector>
return std::string("lifted_name") + std::to_string(hash);
}
-bool IsClosure(const Function& func) {
- return func->GetAttr<Integer>(attr::kClosure, 0) != 0;
-}
+bool IsClosure(const Function& func) { return func->GetAttr<Integer>(attr::kClosure, 0) != 0; }
Function MarkClosure(Function func) {
return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1));
if (!letrec_.empty() && var == letrec_.back()) {
auto it = lambda_map_.find(var);
CHECK(it != lambda_map_.end());
- return Call(it->second, call->args, call_node->attrs,
- call_node->type_args);
+ return Call(it->second, call->args, call_node->attrs, call_node->type_args);
}
}
return std::move(call);
if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
lifted_func = Function(body->params, body->body, body->ret_type, body->type_params);
} else {
- lifted_func =
- Function(captured_vars, body, func->func_type_annotation(), free_type_vars);
+ lifted_func = Function(captured_vars, body, func->func_type_annotation(), free_type_vars);
lifted_func = MarkClosure(lifted_func);
}
CHECK(lifted_func.defined());
-
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
- CHECK(tvm::StructuralEqual()(lifted_func, existing_func))
- << "lifted function hash collision";
+ CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) << "lifted function hash collision";
// If an identical function already exists, use its global var.
global = module_->GetGlobalVar(name);
} else {
if (auto* n = pair.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
- func = Function(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
+ func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
}
Pass LambdaLift() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
- return relay::vm::LambdaLifter(m).Lift();
- };
+ [=](IRModule m, PassContext pc) { return relay::vm::LambdaLifter(m).Lift(); };
return CreateModulePass(pass_func, 1, "LambdaLift", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.LambdaLift")
-.set_body_typed(LambdaLift);
+TVM_REGISTER_GLOBAL("relay._transform.LambdaLift").set_body_typed(LambdaLift);
} // namespace transform
* \brief Remove unused global relay functions in a relay module.
*/
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/support/logging.h>
-#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
+#include <tvm/support/logging.h>
+
#include <iostream>
#include <unordered_set>
#include <vector>
// Record the expressions that are being visited
std::unordered_set<Expr, ObjectHash, ObjectEqual> visiting_;
- explicit CallTracer(const IRModule& module)
- : module_{module},
- called_funcs_{},
- visiting_{} {}
+ explicit CallTracer(const IRModule& module) : module_{module}, called_funcs_{}, visiting_{} {}
void VisitExpr_(const GlobalVarNode* op) final {
called_funcs_.insert(op->name_hint);
*
* \return The module with dead functions removed.
*/
-IRModule RemoveUnusedFunctions(const IRModule& module,
- Array<runtime::String> entry_funcs) {
+IRModule RemoveUnusedFunctions(const IRModule& module, Array<runtime::String> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
auto funcs = CallTracer(module).Trace(entry);
namespace transform {
Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
+ PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
};
return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions")
-.set_body_typed(RemoveUnusedFunctions);
+TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);
} // namespace transform
* \file src/ir/adt.cc
* \brief AST nodes for Relay algebraic data types (ADTs).
*/
-#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
+#include <tvm/relay/type.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
-TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard")
-.set_body_typed([]() {
- return PatternWildcard();
-});
+TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard").set_body_typed([]() { return PatternWildcard(); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) {
- p->stream << "PatternWildcardNode()";
-});
+ .set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ p->stream << "PatternWildcardNode()";
+ });
PatternVar::PatternVar(tvm::relay::Var var) {
ObjectPtr<PatternVarNode> n = make_object<PatternVarNode>();
TVM_REGISTER_NODE_TYPE(PatternVarNode);
-TVM_REGISTER_GLOBAL("relay.ir.PatternVar")
-.set_body_typed([](tvm::relay::Var var) {
+TVM_REGISTER_GLOBAL("relay.ir.PatternVar").set_body_typed([](tvm::relay::Var var) {
return PatternVar(var);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PatternVarNode*>(ref.get());
- p->stream << "PatternVarNode(" << node->var << ")";
-});
+ .set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PatternVarNode*>(ref.get());
+ p->stream << "PatternVarNode(" << node->var << ")";
+ });
-PatternConstructor::PatternConstructor(Constructor constructor,
- tvm::Array<Pattern> patterns) {
+PatternConstructor::PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns) {
ObjectPtr<PatternConstructorNode> n = make_object<PatternConstructorNode>();
n->constructor = std::move(constructor);
n->patterns = std::move(patterns);
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor")
-.set_body_typed([](Constructor constructor, tvm::Array<Pattern> patterns) {
- return PatternConstructor(constructor, patterns);
-});
+ .set_body_typed([](Constructor constructor, tvm::Array<Pattern> patterns) {
+ return PatternConstructor(constructor, patterns);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PatternConstructorNode*>(ref.get());
- p->stream << "PatternConstructorNode(" << node->constructor
- << ", " << node->patterns << ")";
-});
+ .set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PatternConstructorNode*>(ref.get());
+ p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")";
+ });
PatternTuple::PatternTuple(tvm::Array<Pattern> patterns) {
ObjectPtr<PatternTupleNode> n = make_object<PatternTupleNode>();
TVM_REGISTER_NODE_TYPE(PatternTupleNode);
-TVM_REGISTER_GLOBAL("relay.ir.PatternTuple")
-.set_body_typed([](tvm::Array<Pattern> patterns) {
+TVM_REGISTER_GLOBAL("relay.ir.PatternTuple").set_body_typed([](tvm::Array<Pattern> patterns) {
return PatternTuple(patterns);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PatternTupleNode*>(ref.get());
- p->stream << "PatternTupleNode(" << node->patterns << ")";
-});
+ .set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PatternTupleNode*>(ref.get());
+ p->stream << "PatternTupleNode(" << node->patterns << ")";
+ });
Clause::Clause(Pattern lhs, Expr rhs) {
ObjectPtr<ClauseNode> n = make_object<ClauseNode>();
TVM_REGISTER_NODE_TYPE(ClauseNode);
-TVM_REGISTER_GLOBAL("relay.ir.Clause")
-.set_body_typed([](Pattern lhs, Expr rhs) {
+TVM_REGISTER_GLOBAL("relay.ir.Clause").set_body_typed([](Pattern lhs, Expr rhs) {
return Clause(lhs, rhs);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const ClauseNode*>(ref.get());
- p->stream << "ClauseNode(" << node->lhs << ", "
- << node->rhs << ")";
- });
+ .set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const ClauseNode*>(ref.get());
+ p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")";
+ });
Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete) {
ObjectPtr<MatchNode> n = make_object<MatchNode>();
TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_GLOBAL("relay.ir.Match")
-.set_body_typed([](Expr data, tvm::Array<Clause> clauses, bool complete) {
- return Match(data, clauses, complete);
-});
+ .set_body_typed([](Expr data, tvm::Array<Clause> clauses, bool complete) {
+ return Match(data, clauses, complete);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const MatchNode*>(ref.get());
- p->stream << "MatchNode(" << node->data << ", "
- << node->clauses << ", " << node->complete << ")";
-});
+ .set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const MatchNode*>(ref.get());
+ p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete
+ << ")";
+ });
} // namespace relay
} // namespace tvm
*/
#include <tvm/ir/type.h>
-#include <tvm/runtime/registry.h>
#include <tvm/relay/base.h>
+#include <tvm/runtime/registry.h>
namespace tvm {
namespace relay {
data_ = std::move(n);
}
-TVM_REGISTER_GLOBAL("ir.NodeSetSpan")
-.set_body_typed([](ObjectRef node_ref, Span sp) {
+TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) {
rn->span = sp;
} else if (auto* rn = node_ref.as<RelayExprNode>()) {
TVM_REGISTER_NODE_TYPE(ConstantNode);
-TVM_REGISTER_GLOBAL("relay.ir.Constant")
-.set_body_typed([](runtime::NDArray data) {
+TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) {
return Constant(data);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const ConstantNode*>(ref.get());
- const PackedFunc* fprint = Registry::Get("relay._constant_repr");
- CHECK(fprint) << "unable to find printing function for constants";
- std::string data = (*fprint)(GetRef<Constant>(node));
- p->stream << "Constant(" << data << ")";
- });
+ .set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const ConstantNode*>(ref.get());
+ const PackedFunc* fprint = Registry::Get("relay._constant_repr");
+ CHECK(fprint) << "unable to find printing function for constants";
+ std::string data = (*fprint)(GetRef<Constant>(node));
+ p->stream << "Constant(" << data << ")";
+ });
TensorType ConstantNode::tensor_type() const {
auto dtype = DataType(data->dtype);
for (int i = 0; i < data->ndim; i++) {
CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
- shape.push_back(
- tvm::IntImm(DataType::Int(32), data->shape[i]));
+ shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i]));
}
return TensorType(shape, dtype);
TVM_REGISTER_NODE_TYPE(TupleNode);
-TVM_REGISTER_GLOBAL("relay.ir.Tuple")
-.set_body_typed([](tvm::Array<relay::Expr> fields) {
+TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields) {
return Tuple(fields);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TupleNode*>(ref.get());
- p->stream << "Tuple(" << node->fields << ")";
- });
-
+ .set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TupleNode*>(ref.get());
+ p->stream << "Tuple(" << node->fields << ")";
+ });
Var::Var(Id vid, Type type_annotation) {
ObjectPtr<VarNode> n = make_object<VarNode>();
TVM_REGISTER_NODE_TYPE(VarNode);
-TVM_REGISTER_GLOBAL("relay.ir.Var")
-.set_body_typed([](std::string str, Type type_annotation) {
+TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](std::string str, Type type_annotation) {
return Var(str, type_annotation);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const VarNode*>(ref.get());
- p->stream << "Var(" << node->name_hint();
- if (node->type_annotation.defined()) {
- p->stream << ", ty=";
- p->Print(node->type_annotation);
- }
- p->stream << ")";
- });
+ .set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const VarNode*>(ref.get());
+ p->stream << "Var(" << node->name_hint();
+ if (node->type_annotation.defined()) {
+ p->stream << ", ty=";
+ p->Print(node->type_annotation);
+ }
+ p->stream << ")";
+ });
Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) {
ObjectPtr<CallNode> n = make_object<CallNode>();
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_GLOBAL("relay.ir.Call")
-.set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) {
- return Call(op, args, attrs, type_args);
-});
+ .set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) {
+ return Call(op, args, attrs, type_args);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const CallNode*>(ref.get());
- p->stream << "CallNode(" << node->op << ", " << node->args << ", "
- << node->attrs << ", " << node->type_args << ")";
- });
+ .set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const CallNode*>(ref.get());
+ p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", "
+ << node->type_args << ")";
+ });
Let::Let(Var var, Expr value, Expr body) {
ObjectPtr<LetNode> n = make_object<LetNode>();
TVM_REGISTER_NODE_TYPE(LetNode);
-TVM_REGISTER_GLOBAL("relay.ir.Let")
-.set_body_typed([](Var var, Expr value, Expr body) {
+TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) {
return Let(var, value, body);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const LetNode*>(ref.get());
- p->stream << "LetNode(" << node->var << ", " << node->value
- << ", " << node->body << ")";
-});
+ .set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const LetNode*>(ref.get());
+ p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")";
+ });
If::If(Expr cond, Expr true_branch, Expr false_branch) {
ObjectPtr<IfNode> n = make_object<IfNode>();
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay.ir.If")
-.set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) {
- return If(cond, true_branch, false_branch);
-});
+ .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) {
+ return If(cond, true_branch, false_branch);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const IfNode*>(ref.get());
- p->stream << "IfNode(" << node->cond << ", " << node->true_branch
- << ", " << node->false_branch << ")";
-});
+ .set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const IfNode*>(ref.get());
+ p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", "
+ << node->false_branch << ")";
+ });
TupleGetItem::TupleGetItem(Expr tuple, int index) {
ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
-TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem")
-.set_body_typed([](Expr tuple, int index) {
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) {
return TupleGetItem(tuple, index);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TupleGetItemNode*>(ref.get());
- p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
-});
+ .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TupleGetItemNode*>(ref.get());
+ p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
+ });
RefCreate::RefCreate(Expr value) {
ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>();
TVM_REGISTER_NODE_TYPE(RefCreateNode);
-TVM_REGISTER_GLOBAL("relay.ir.RefCreate")
-.set_body_typed([](Expr value) {
+TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) {
return RefCreate(value);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const RefCreateNode*>(ref.get());
- p->stream << "RefCreateNode(" << node->value << ")";
-});
+ .set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const RefCreateNode*>(ref.get());
+ p->stream << "RefCreateNode(" << node->value << ")";
+ });
RefRead::RefRead(Expr ref) {
ObjectPtr<RefReadNode> n = make_object<RefReadNode>();
TVM_REGISTER_NODE_TYPE(RefReadNode);
-TVM_REGISTER_GLOBAL("relay.ir.RefRead")
-.set_body_typed([](Expr ref) {
- return RefRead(ref);
-});
+TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const RefReadNode*>(ref.get());
- p->stream << "RefReadNode(" << node->ref << ")";
-});
+ .set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const RefReadNode*>(ref.get());
+ p->stream << "RefReadNode(" << node->ref << ")";
+ });
RefWrite::RefWrite(Expr ref, Expr value) {
ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>();
TVM_REGISTER_NODE_TYPE(RefWriteNode);
-TVM_REGISTER_GLOBAL("relay.ir.RefWrite")
-.set_body_typed([](Expr ref, Expr value) {
+TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) {
return RefWrite(ref, value);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const RefWriteNode*>(ref.get());
- p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
-});
+ .set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const RefWriteNode*>(ref.get());
+ p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
+ });
-TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize")
-.set_body_typed([](TempExpr temp) {
+TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp) {
return temp->Realize();
});
-TVM_REGISTER_GLOBAL("relay.ir.Any")
-.set_body_typed([]() { return Any::make(); });
+TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any::make(); });
} // namespace relay
} // namespace tvm
}
}
-Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) {
- return ExprMutator::VisitExpr(expr);
-}
+Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { return ExprMutator::VisitExpr(expr); }
Expr MixedModeMutator::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto post = ExprFunctor::VisitExpr(expr);
return rewriter_->Rewrite(expr, post);
}
+
protected:
ExprRewriter* rewriter_;
};
return GetRef<Expr>(op);
}
-Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
- return GetRef<Expr>(op);
-}
+Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); }
-Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
- return GetRef<Expr>(op);
-}
+Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op); }
-Expr ExprMutator::VisitExpr_(const OpNode* op) {
- return GetRef<Expr>(op);
-}
+Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }
Expr ExprMutator::VisitExpr_(const TupleNode* op) {
tvm::Array<Expr> fields;
auto ret_type = this->VisitType(op->ret_type);
auto body = this->Mutate(op->body);
- if (all_ty_params_unchanged &&
- all_params_unchanged &&
- ret_type.same_as(op->ret_type) &&
+ if (all_ty_params_unchanged && all_params_unchanged && ret_type.same_as(op->ret_type) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
auto value = this->Mutate(op->value);
auto body = this->Mutate(op->body);
- if (var.same_as(op->var) &&
- value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let(var, value, body);
auto guard = this->Mutate(op->cond);
auto true_b = this->Mutate(op->true_branch);
auto false_b = this->Mutate(op->false_branch);
- if (op->cond.same_as(guard) &&
- op->true_branch.same_as(true_b) &&
+ if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
op->false_branch.same_as(false_b)) {
- return GetRef<Expr>(op);;
+ return GetRef<Expr>(op);
} else {
return If(guard, true_b, false_b);
}
}
}
-Expr ExprMutator::VisitExpr_(const ConstructorNode* c) {
- return GetRef<Expr>(c);
-}
+Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); }
Expr ExprMutator::VisitExpr_(const MatchNode* m) {
std::vector<Clause> clauses;
}
}
-void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
-}
+void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {}
-void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {
-}
+void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {}
void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
for (auto field : op->fields) {
void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
-void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
- this->VisitExpr(op->tuple);
-}
+void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); }
-void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
- this->VisitExpr(op->value);
-}
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); }
-void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
- this->VisitExpr(op->ref);
-}
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitExpr(op->ref); }
void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
this->VisitExpr(op->ref);
ExprApplyVisit(fvisit).VisitExpr(e);
}
-TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit")
-.set_body_typed([](Expr expr, PackedFunc f) {
- PostOrderVisit(expr, [f](const Expr& n) {
- f(n);
- });
- });
+TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) {
+ PostOrderVisit(expr, [f](const Expr& n) { f(n); });
+});
// Implement bind.
class ExprBinder : public ExprMutator, PatternMutator {
public:
- explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
- : args_map_(args_map) {
- }
+ explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {}
Expr VisitExpr_(const LetNode* op) final {
- CHECK(!args_map_.count(op->var))
- << "Cannot bind an internel variable in let";
+ CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const FunctionNode* op) final {
for (Var param : op->params) {
- CHECK(!args_map_.count(param))
- << "Cannnot bind an internal function parameter";
+ CHECK(!args_map_.count(param)) << "Cannnot bind an internal function parameter";
}
return ExprMutator::VisitExpr_(op);
}
}
}
- Pattern VisitPattern(const Pattern& p) final {
- return PatternMutator::VisitPattern(p);
- }
+ Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
}
Var VisitVar(const Var& v) final {
- CHECK(!args_map_.count(v))
- << "Cannnot bind an internal pattern variable";
+ CHECK(!args_map_.count(v)) << "Cannnot bind an internal pattern variable";
return v;
}
new_params.push_back(param);
}
}
- if (new_body.same_as(func->body) &&
- new_params.size() == func->params.size()) {
+ if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
return expr;
}
- auto ret = Function(new_params,
- new_body,
- func->ret_type,
- func->type_params,
- func->attrs);
+ auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs);
std::unordered_set<Var, ObjectHash, ObjectEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
new_params.push_back(v);
}
}
- ret = Function(new_params,
- new_body,
- func->ret_type,
- func->type_params,
- func->attrs);
+ ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return std::move(ret);
} else {
}
}
-TVM_REGISTER_GLOBAL("relay.ir.Bind")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- ObjectRef input = args[0];
- if (input->IsInstance<ExprNode>()) {
- *ret = Bind(Downcast<Expr>(input), args[1]);
- } else {
- CHECK(input->IsInstance<TypeNode>());
- *ret = Bind(Downcast<Type>(input), args[1]);
- }
- });
+TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) {
+ ObjectRef input = args[0];
+ if (input->IsInstance<ExprNode>()) {
+ *ret = Bind(Downcast<Expr>(input), args[1]);
+ } else {
+ CHECK(input->IsInstance<TypeNode>());
+ *ret = Bind(Downcast<Type>(input), args[1]);
+ }
+});
} // namespace relay
} // namespace tvm
namespace tvm {
namespace relay {
-Function::Function(tvm::Array<Var> params,
- Expr body,
- Type ret_type,
- tvm::Array<TypeVar> type_params,
- DictAttrs attrs) {
+Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
+ tvm::Array<TypeVar> type_params, DictAttrs attrs) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
- Type param_type = (param->type_annotation.defined()) ? param->type_annotation
- : IncompleteType(Kind::kType);
+ Type param_type =
+ (param->type_annotation.defined()) ? param->type_annotation : IncompleteType(Kind::kType);
param_types.push_back(param_type);
}
- Type ret_type = (this->ret_type.defined()) ? this->ret_type
- : IncompleteType(Kind::kType);
+ Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteType(Kind::kType);
return FuncType(param_types, ret_type, this->type_params, {});
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay.ir.Function")
-.set_body_typed([](tvm::Array<Var> params,
- Expr body,
- Type ret_type,
- tvm::Array<TypeVar> ty_params,
- tvm::DictAttrs attrs) {
- return Function(params, body, ret_type, ty_params, attrs);
-});
+ .set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,
+ tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
+ return Function(params, body, ret_type, ty_params, attrs);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const FunctionNode*>(ref.get());
- p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
- << ", " << node->body << ", " << node->type_params << ", "
- << node->attrs << ")";
-});
+ .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const FunctionNode*>(ref.get());
+ p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body
+ << ", " << node->type_params << ", " << node->attrs << ")";
+ });
} // namespace relay
} // namespace tvm
TVM_REGISTER_NODE_TYPE(OpSpecializationNode);
TVM_REGISTER_NODE_TYPE(OpStrategyNode);
-Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return (*this)->fcompute(attrs, inputs, out_type);
}
-te::Schedule OpImplementation::Schedule(const Attrs& attrs,
- const Array<te::Tensor> &outs,
+te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array<te::Tensor>& outs,
const Target& target) {
return (*this)->fschedule(attrs, outs, target);
}
void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
- tvm::relay::FTVMSchedule fschedule,
- std::string name,
+ tvm::relay::FTVMSchedule fschedule, std::string name,
int plevel) {
auto n = make_object<OpImplementationNode>();
n->fcompute = fcompute;
(*this)->implementations.push_back(OpImplementation(n));
}
-void OpStrategy::AddImplementation(FTVMCompute fcompute,
- FTVMSchedule fschedule,
- std::string name,
+void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name,
int plevel) {
auto curr_cond = te::SpecializedCondition::Current();
auto self = this->operator->();
}
TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- OpImplementation imp = args[0];
- Attrs attrs = args[1];
- Array<te::Tensor> inputs = args[2];
- Type out_type = args[3];
- *rv = imp.Compute(attrs, inputs, out_type);
-});
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ OpImplementation imp = args[0];
+ Attrs attrs = args[1];
+ Array<te::Tensor> inputs = args[2];
+ Type out_type = args[3];
+ *rv = imp.Compute(attrs, inputs, out_type);
+ });
TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- OpImplementation imp = args[0];
- Attrs attrs = args[1];
- Array<te::Tensor> outs = args[2];
- Target target = args[3];
- *rv = imp.Schedule(attrs, outs, target);
-});
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ OpImplementation imp = args[0];
+ Attrs attrs = args[1];
+ Array<te::Tensor> outs = args[2];
+ Target target = args[3];
+ *rv = imp.Schedule(attrs, outs, target);
+ });
-TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- ObjectPtr<OpStrategyNode> n = make_object<OpStrategyNode>();
- *rv = OpStrategy(n);
+TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy").set_body([](TVMArgs args, TVMRetValue* rv) {
+ ObjectPtr<OpStrategyNode> n = make_object<OpStrategyNode>();
+ *rv = OpStrategy(n);
});
TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- OpStrategy strategy = args[0];
- FTVMCompute compute = args[1];
- FTVMSchedule schedule = args[2];
- std::string name = args[3];
- int plevel = args[4];
- strategy.AddImplementation(compute, schedule, name, plevel);
-});
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ OpStrategy strategy = args[0];
+ FTVMCompute compute = args[1];
+ FTVMSchedule schedule = args[2];
+ std::string name = args[3];
+ int plevel = args[4];
+ strategy.AddImplementation(compute, schedule, name, plevel);
+ });
} // namespace relay
} // namespace tvm
namespace tvm {
namespace relay {
-Pattern PatternMutator::Mutate(const Pattern& pat) {
- return (*this)(pat);
-}
+Pattern PatternMutator::Mutate(const Pattern& pat) { return (*this)(pat); }
-Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) {
- return GetRef<Pattern>(op);
-}
+Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { return GetRef<Pattern>(op); }
Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) {
return PatternVar(VisitVar(op->var));
return PatternTuple(pat);
}
-Type PatternMutator::VisitType(const Type& t) {
- return t;
-}
+Type PatternMutator::VisitType(const Type& t) { return t; }
Var PatternMutator::VisitVar(const Var& v) {
if (var_map_.count(v) == 0) {
- var_map_.insert(std::pair<Var, Var>(v,
- Var(v->name_hint(),
- VisitType(v->type_annotation))));
+ var_map_.insert(std::pair<Var, Var>(v, Var(v->name_hint(), VisitType(v->type_annotation))));
}
return var_map_.at(v);
}
-Constructor PatternMutator::VisitConstructor(const Constructor& v) {
- return v;
-}
+Constructor PatternMutator::VisitConstructor(const Constructor& v) { return v; }
-void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { }
+void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) {}
-void PatternVisitor::VisitPattern_(const PatternVarNode* op) {
- VisitVar(op->var);
-}
+void PatternVisitor::VisitPattern_(const PatternVarNode* op) { VisitVar(op->var); }
void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
VisitConstructor(op->constructor);
}
}
-void PatternVisitor::VisitType(const Type& t) { }
+void PatternVisitor::VisitType(const Type& t) {}
-void PatternVisitor::VisitVar(const Var& v) {
- VisitType(v->type_annotation);
-}
+void PatternVisitor::VisitVar(const Var& v) { VisitType(v->type_annotation); }
void PatternVisitor::VisitConstructor(const Constructor& c) {
for (const auto& inp : c->inputs) {
* \brief Relay specific transformation passes.
*/
#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
#include <tvm/node/repr_printer.h>
#include <tvm/relay/transform.h>
-
+#include <tvm/runtime/registry.h>
namespace tvm {
namespace relay {
FunctionPassNode() = default;
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("pass_info", &pass_info);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
/*!
* \brief Run a function pass on given pass context.
}
// Perform Module -> Module optimizations at the Function level.
-IRModule FunctionPassNode::operator()(IRModule mod,
- const PassContext& pass_ctx) const {
+IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
- DLOG(INFO) << "Executing function pass : "
- << pass_info->name
- << " with opt level: "
- << pass_info->opt_level;
+ DLOG(INFO) << "Executing function pass : " << pass_info->name
+ << " with opt level: " << pass_info->opt_level;
pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
// only picks up relay::Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
- auto updated_func = SkipFunction(func)
- ? func
- : pass_func(func, updated_mod, pass_ctx);
+ auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
}
}
bool FunctionPassNode::SkipFunction(const Function& func) const {
return (func->GetAttr<String>(attr::kCompiler).defined()) ||
- func->GetAttr<Integer>(attr::kSkipOptimization, 0) != 0;
+ func->GetAttr<Integer>(attr::kSkipOptimization, 0) != 0;
}
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const tvm::Array<runtime::String>& required) {
+ int opt_level, const std::string& name, const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPass(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
-.set_body_typed([](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
- PassInfo pass_info) {
- return FunctionPass(pass_func, pass_info);
-});
+ .set_body_typed(
+ [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+ PassInfo pass_info) { return FunctionPass(pass_func, pass_info); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const FunctionPassNode*>(ref.get());
- const PassInfo info = node->Info();
- p->stream << "Run Function pass: " << info->name
- << " at the optimization level " << info->opt_level;
-});
+ .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const FunctionPassNode*>(ref.get());
+ const PassInfo info = node->Info();
+ p->stream << "Run Function pass: " << info->name << " at the optimization level "
+ << info->opt_level;
+ });
} // namespace transform
} // namespace relay
* \file argsort.cc
* \brief Argsort operators
*/
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
+#include <tvm/relay/op.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ArgsortAttrs);
-bool ArgsortRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool ArgsortRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const ArgsortAttrs* param = attrs.as<ArgsortAttrs>();
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "Argsort: expect input type to be TensorType but get "
- << types[0];
+ << "Argsort: expect input type to be TensorType but get " << types[0];
return false;
}
reporter->Assign(types[1], TensorType(data->shape, param->dtype));
return true;
}
-Expr MakeArgsort(Expr data,
- int axis,
- bool is_ascend,
- DataType dtype) {
+Expr MakeArgsort(Expr data, int axis, bool is_ascend, DataType dtype) {
auto attrs = make_object<ArgsortAttrs>();
attrs->axis = axis;
attrs->is_ascend = is_ascend;
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op._make.argsort")
-.set_body_typed(MakeArgsort);
+TVM_REGISTER_GLOBAL("relay.op._make.argsort").set_body_typed(MakeArgsort);
RELAY_REGISTER_OP("argsort")
-.describe(R"doc(Returns the indices that would sort an
+ .describe(R"doc(Returns the indices that would sort an
input array along the given axis.
)doc" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<ArgsortAttrs>()
-.add_argument("data", "Tensor", "Input data.")
-.set_support_level(6)
-.add_type_rel("Argsort", ArgsortRel);
+ .set_num_inputs(1)
+ .set_attrs_type<ArgsortAttrs>()
+ .add_argument("data", "Tensor", "Input data.")
+ .set_support_level(6)
+ .add_type_rel("Argsort", ArgsortRel);
} // namespace relay
} // namespace tvm
* \file topk.cc
* \brief TopK operators
*/
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
+#include <tvm/relay/op.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(TopKAttrs);
-bool TopKRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const TopKAttrs* param = attrs.as<TopKAttrs>();
return true;
}
-Expr MakeTopK(Expr data,
- int k,
- int axis,
- std::string ret_type,
- bool is_ascend,
- DataType dtype) {
+Expr MakeTopK(Expr data, int k, int axis, std::string ret_type, bool is_ascend, DataType dtype) {
auto attrs = make_object<TopKAttrs>();
attrs->k = k;
attrs->axis = axis;
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op._make.topk")
-.set_body_typed(MakeTopK);
+TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
RELAY_REGISTER_OP("topk")
-.describe(R"doc(Get the top k elements in an input tensor along the given axis.
+ .describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<TopKAttrs>()
-.add_argument("data", "Tensor", "Input data.")
-.set_support_level(6)
-.add_type_rel("TopK", TopKRel);
+ .set_num_inputs(1)
+ .set_attrs_type<TopKAttrs>()
+ .add_argument("data", "Tensor", "Input data.")
+ .set_support_level(6)
+ .add_type_rel("TopK", TopKRel);
} // namespace relay
} // namespace tvm
-
* \brief Registration of annotation operators.
*/
-#include <tvm/tir/expr.h>
+#include <topi/elemwise.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
-#include <topi/elemwise.h>
+#include <tvm/tir/expr.h>
#include "../../transforms/infer_layout_util.h"
#include "../type_relations.h"
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device")
-.set_body_typed([](Expr data, int device_type) {
- auto attrs = make_object<OnDeviceAttrs>();
- attrs->device_type = device_type;
- static const Op& op = Op::Get("on_device");
- return Call(op, {data}, Attrs(attrs), {});
-});
+ .set_body_typed([](Expr data, int device_type) {
+ auto attrs = make_object<OnDeviceAttrs>();
+ attrs->device_type = device_type;
+ static const Op& op = Op::Get("on_device");
+ return Call(op, {data}, Attrs(attrs), {});
+ });
RELAY_REGISTER_OP("on_device")
-.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_support_level(10)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout);
+ .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .set_support_level(10)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
Expr StopFusion(Expr data) {
static const Op& op = Op::Get("annotation.stop_fusion");
return Call(op, {data}, Attrs{}, {});
}
-TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion")
-.set_body_typed([](Expr data) {
- return StopFusion(data);
+TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion").set_body_typed([](Expr data) {
+ return StopFusion(data);
});
RELAY_REGISTER_OP("annotation.stop_fusion")
-.describe(R"code(Annotate an expression to prevent it being fused with previous expressions.)code"
-TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input data.")
-.add_type_rel("Identity", IdentityRel)
-.set_support_level(10)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_dtype) -> Array<te::Tensor> {
- return {topi::identity(inputs[0])};
- });
+ .describe(
+ R"code(Annotate an expression to prevent it being fused with previous expressions.)code" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input data.")
+ .add_type_rel("Identity", IdentityRel)
+ .set_support_level(10)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute",
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_dtype) -> Array<te::Tensor> {
+ return {topi::identity(inputs[0])};
+ });
// relay.annotation.cast_hint
TVM_REGISTER_NODE_TYPE(CastHintAttrs);
}
RELAY_REGISTER_OP("annotation.cast_hint")
-.describe(R"code(Annotate an expression to be cast into specific data type.)code"
-TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input data.")
-.add_type_rel("Identity", IdentityRel)
-.set_support_level(10)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_dtype) -> Array<te::Tensor> {
- return {topi::identity(inputs[0])};
- });
-
+ .describe(
+ R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input data.")
+ .add_type_rel("Identity", IdentityRel)
+ .set_support_level(10)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute",
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_dtype) -> Array<te::Tensor> {
+ return {topi::identity(inputs[0])};
+ });
RELAY_REGISTER_OP("annotation.bitpack_start")
-.describe(R"code(
+ .describe(R"code(
Mark the start of bitpacking.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_support_level(10)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_dtype) -> Array<te::Tensor> {
- return {topi::identity(inputs[0])};
- });
+ .set_num_inputs(1)
+ .set_support_level(10)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute",
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_dtype) -> Array<te::Tensor> {
+ return {topi::identity(inputs[0])};
+ });
RELAY_REGISTER_OP("annotation.bitpack_end")
-.describe(R"code(
+ .describe(R"code(
Mark the end of bitpacking.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_support_level(10)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_dtype) -> Array<te::Tensor> {
- return {topi::identity(inputs[0])};
- });
-
-TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint")
-.set_body_typed([](Expr data) {
+ .set_num_inputs(1)
+ .set_support_level(10)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute",
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_dtype) -> Array<te::Tensor> {
+ return {topi::identity(inputs[0])};
+ });
+
+TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint").set_body_typed([](Expr data) {
static const Op& op = Op::Get("annotation.checkpoint");
return Call(op, {data}, Attrs{}, {});
});
RELAY_REGISTER_OP("annotation.checkpoint")
-.describe(R"code(
+ .describe(R"code(
Mark a checkpoint for checkpointing memory optimization.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_support_level(10)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_dtype) -> Array<te::Tensor> {
- Array<te::Tensor> outputs;
- for (size_t i = 0; i < inputs.size(); ++i) {
- outputs.push_back(topi::identity(inputs[i]));
- }
- return outputs;
- });
+ .set_num_inputs(1)
+ .set_support_level(10)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute",
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_dtype) -> Array<te::Tensor> {
+ Array<te::Tensor> outputs;
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ outputs.push_back(topi::identity(inputs[i]));
+ }
+ return outputs;
+ });
TVM_REGISTER_NODE_TYPE(CompilerAttrs);
RELAY_REGISTER_OP("annotation.compiler_begin")
-.describe(R"code(
+ .describe(R"code(
Beginning of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_support_level(10)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_dtype) -> Array<te::Tensor> {
- return {topi::identity(inputs[0])};
- });
+ .set_num_inputs(1)
+ .set_support_level(10)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute",
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_dtype) -> Array<te::Tensor> {
+ return {topi::identity(inputs[0])};
+ });
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
-.set_body_typed([](Expr expr, std::string compiler) {
- auto attrs = make_object<CompilerAttrs>();
- attrs->compiler = compiler;
- static const Op& op = Op::Get("annotation.compiler_begin");
- return Call(op, {expr}, Attrs(attrs), {});
-});
+ .set_body_typed([](Expr expr, std::string compiler) {
+ auto attrs = make_object<CompilerAttrs>();
+ attrs->compiler = compiler;
+ static const Op& op = Op::Get("annotation.compiler_begin");
+ return Call(op, {expr}, Attrs(attrs), {});
+ });
RELAY_REGISTER_OP("annotation.compiler_end")
-.describe(R"code(
+ .describe(R"code(
End of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_support_level(10)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_dtype) -> Array<te::Tensor> {
- return {topi::identity(inputs[0])};
- });
+ .set_num_inputs(1)
+ .set_support_level(10)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute",
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_dtype) -> Array<te::Tensor> {
+ return {topi::identity(inputs[0])};
+ });
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
-.set_body_typed([](Expr expr, std::string compiler) {
- auto attrs = make_object<CompilerAttrs>();
- attrs->compiler = compiler;
- static const Op& op = Op::Get("annotation.compiler_end");
- return Call(op, {expr}, Attrs(attrs), {});
-});
+ .set_body_typed([](Expr expr, std::string compiler) {
+ auto attrs = make_object<CompilerAttrs>();
+ attrs->compiler = compiler;
+ static const Op& op = Op::Get("annotation.compiler_end");
+ return Call(op, {expr}, Attrs(attrs), {});
+ });
} // namespace relay
} // namespace tvm
* \brief Property def of nn operators.
*/
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
-#include <tvm/relay/attrs/debug.h>
#include <topi/elemwise.h>
+#include <tvm/relay/attrs/debug.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
#include <vector>
-#include "./type_relations.h"
+
#include "./op_common.h"
+#include "./type_relations.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(DebugAttrs);
-Array<te::Tensor> DebugCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> DebugCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- return Array<te::Tensor>{ topi::identity(inputs[0]) };
+ return Array<te::Tensor>{topi::identity(inputs[0])};
}
RELAY_REGISTER_OP("debug")
-.describe(R"code(Enter the interpreter's debugger.
+ .describe(R"code(Enter the interpreter's debugger.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("program", "Tuple", "The program to execute before debugging.")
-.set_support_level(1)
-.set_attrs_type<DebugAttrs>()
-.add_type_rel("Debug", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<FTVMCompute>("FTVMCompute", DebugCompute);
+ .set_num_inputs(1)
+ .add_argument("program", "Tuple", "The program to execute before debugging.")
+ .set_support_level(1)
+ .set_attrs_type<DebugAttrs>()
+ .add_type_rel("Debug", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<FTVMCompute>("FTVMCompute", DebugCompute);
Expr MakeDebug(Expr expr, std::string name) {
auto dattrs = make_object<DebugAttrs>();
return Call(op, {expr}, Attrs(dattrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.debug")
-.set_body_typed(MakeDebug);
+TVM_REGISTER_GLOBAL("relay.op._make.debug").set_body_typed(MakeDebug);
} // namespace relay
} // namespace tvm
-
* used as "barrier" to avoid fusing operators belonging to differen devices.
*/
-#include <tvm/tir/expr.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/expr.h>
-#include "type_relations.h"
#include "../transforms/infer_layout_util.h"
+#include "type_relations.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
TVM_REGISTER_GLOBAL("relay.op._make.device_copy")
-.set_body_typed([](Expr data, int src_dev_type,
- int dst_dev_type) {
- auto attrs = make_object<DeviceCopyAttrs>();
- attrs->src_dev_type = src_dev_type;
- attrs->dst_dev_type = dst_dev_type;
- static const Op& op = Op::Get("device_copy");
- return Call(op, {data}, Attrs(attrs), {});
-});
+ .set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) {
+ auto attrs = make_object<DeviceCopyAttrs>();
+ attrs->src_dev_type = src_dev_type;
+ attrs->dst_dev_type = dst_dev_type;
+ static const Op& op = Op::Get("device_copy");
+ return Call(op, {data}, Attrs(attrs), {});
+ });
RELAY_REGISTER_OP("device_copy")
-.describe(R"code(
+ .describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_support_level(10)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout);
+ .set_num_inputs(1)
+ .set_support_level(10)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
} // namespace relay
} // namespace tvm
* \file dilation2d.cc
* \brief Morphological dilation operator
*/
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
#include "../op_common.h"
namespace tvm {
// relay.image.dilation2d
TVM_REGISTER_NODE_TYPE(Dilation2DAttrs);
-template<typename T>
-Array<Array<Layout> > Dilation2DInferCorrectLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+template <typename T>
+Array<Array<Layout> > Dilation2DInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
const T* params = attrs.as<T>();
- return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
- {params->data_layout}};
+ return Array<Array<Layout> >{{params->data_layout, params->kernel_layout}, {params->data_layout}};
}
// Positional relay function to create dilation2d operator
// used by frontend FFI.
-Expr MakeDilation2D(Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilations,
- std::string data_layout,
- std::string kernel_layout,
+Expr MakeDilation2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilations, std::string data_layout, std::string kernel_layout,
DataType out_dtype) {
auto attrs = make_object<Dilation2DAttrs>();
attrs->strides = std::move(strides);
template <typename AttrType>
bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
- const TypeReporter& reporter) {
+ const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<tir::AnyNode>()) {
- oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
- param->strides[0]) + 1);
+ oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<tir::AnyNode>()) {
- oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
- param->strides[1]) + 1);
+ oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
}
return true;
}
-TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d")
-.set_body_typed(MakeDilation2D);
-
+TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d").set_body_typed(MakeDilation2D);
RELAY_REGISTER_OP("image.dilation2d")
-.describe(R"code(Computes grayscale dilation of 4D input and 3D filter.
+ .describe(R"code(Computes grayscale dilation of 4D input and 3D filter.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (in_channels, height, width)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Dilation2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.add_type_rel("Dilation2D", Dilation2DRel<Dilation2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- Dilation2DInferCorrectLayout<Dilation2DAttrs>);
+ .set_attrs_type<Dilation2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(2)
+ .add_type_rel("Dilation2D", Dilation2DRel<Dilation2DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ Dilation2DInferCorrectLayout<Dilation2DAttrs>);
} // namespace relay
} // namespace tvm
* \file resize.cc
* \brief Image resize operators
*/
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
#include "../op_common.h"
namespace tvm {
TVM_REGISTER_NODE_TYPE(ResizeAttrs);
-bool ResizeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
const Layout in_layout(param->layout);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
- << "Resize only support input layouts that are convertible from NCHW."
- << " But got " << in_layout;
+ << "Resize only support input layouts that are convertible from NCHW."
+ << " But got " << in_layout;
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, param->size[0]);
}
// assign output type
- reporter->Assign(types[1],
- TensorType(layout_converter.BackwardShape(oshape),
- out_dtype));
+ reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
return true;
}
// Positional relay function to create image operator
// used by frontend FFI.
-Expr MakeResize(Expr data,
- Array<IndexExpr> size,
- std::string layout,
- std::string method,
- std::string coordinate_transformation_mode,
- DataType out_dtype) {
+Expr MakeResize(Expr data, Array<IndexExpr> size, std::string layout, std::string method,
+ std::string coordinate_transformation_mode, DataType out_dtype) {
auto attrs = make_object<ResizeAttrs>();
attrs->size = std::move(size);
attrs->layout = std::move(layout);
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.image._make.resize")
-.set_body_typed(MakeResize);
-
+TVM_REGISTER_GLOBAL("relay.op.image._make.resize").set_body_typed(MakeResize);
RELAY_REGISTER_OP("image.resize")
-.describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation.
+ .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
for layout NHWC
(batch_size, size[0], size[1], channels)
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ResizeAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(5)
-.add_type_rel("Resize", ResizeRel)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
-
+ .set_attrs_type<ResizeAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(5)
+ .add_type_rel("Resize", ResizeRel)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs);
-bool CropAndResizeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool CropAndResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
const auto* boxes = types[1].as<TensorTypeNode>();
const auto* box_indices = types[2].as<TensorTypeNode>();
- if (data == nullptr || boxes == nullptr ||
- box_indices == nullptr) return false;
+ if (data == nullptr || boxes == nullptr || box_indices == nullptr) return false;
const CropAndResizeAttrs* param = attrs.as<CropAndResizeAttrs>();
CHECK(param != nullptr);
oshape.Set(3, crop_size[1]);
auto bshape = layout_converter.BackwardShape(oshape);
// assign output type
- reporter->Assign(types[3],
- TensorType(layout_converter.BackwardShape(oshape),
- out_dtype));
+ reporter->Assign(types[3], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
return true;
}
-Expr MakeCropAndResize(Expr data,
- Expr boxes,
- Expr box_indices,
- Array<IndexExpr> crop_size,
- std::string layout,
- std::string method,
- double extrapolation_value,
+Expr MakeCropAndResize(Expr data, Expr boxes, Expr box_indices, Array<IndexExpr> crop_size,
+ std::string layout, std::string method, double extrapolation_value,
DataType out_dtype) {
auto attrs = make_object<CropAndResizeAttrs>();
attrs->crop_size = std::move(crop_size);
return Call(op, {data, boxes, box_indices}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize")
-.set_body_typed(MakeCropAndResize);
-
+TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize").set_body_typed(MakeCropAndResize);
RELAY_REGISTER_OP("image.crop_and_resize")
- .describe(R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation.
+ .describe(
+ R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
for layout NHWC
(batch_size, crop_size[0], crop_size[1], channels)
)code" TVM_ADD_FILELINE)
-.set_num_inputs(3)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("boxes", "Tensor", "The boxes tensor.")
-.add_argument("box_indices", "Tensor", "The box indices tensor.")
-.set_attrs_type<CropAndResizeAttrs>()
-.set_support_level(5)
-.add_type_rel("CropAndResize", CropAndResizeRel)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("boxes", "Tensor", "The boxes tensor.")
+ .add_argument("box_indices", "Tensor", "The box indices tensor.")
+ .set_attrs_type<CropAndResizeAttrs>()
+ .set_support_level(5)
+ .add_type_rel("CropAndResize", CropAndResizeRel)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
*/
#include <topi/elemwise.h>
-#include <tvm/runtime/data_type.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/data_type.h>
#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
runtime::NDArray shape = konst->data;
std::vector<int64_t> raw_shape;
CHECK_EQ(shape->ndim, 1u);
- CHECK_EQ(shape->dtype.code, 0U)
- << "The dtype of constant shape must be int32 or int64, but got "
- << runtime::DLDataType2String(shape->dtype);
+ CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got "
+ << runtime::DLDataType2String(shape->dtype);
CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
- << "The dtype of constant shape must be int32 or int64, but got"
- << runtime::DLDataType2String(shape->dtype);
+ << "The dtype of constant shape must be int32 or int64, but got"
+ << runtime::DLDataType2String(shape->dtype);
if (shape->dtype.bits == 32) {
const int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
}
}
-TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType")
-.set_body_typed([](Type type) {
+TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType").set_body_typed([](Type type) {
auto types = FlattenTupleType(type);
return Array<Type>(types.begin(), types.end());
});
-TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType")
-.set_body_typed([](Type type, Expr expr) {
+TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType").set_body_typed([](Type type, Expr expr) {
auto exprs = FromTupleType(type, expr);
return Array<Expr>(exprs.begin(), exprs.end());
});
* \brief Property def of bitserial operators.
*/
-#include <tvm/tir/data_layout.h>
#include <tvm/relay/attrs/bitserial.h>
#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
-#include "../op_common.h"
#include "../../transforms/infer_layout_util.h"
+#include "../op_common.h"
namespace tvm {
namespace relay {
packed must be divisible by number of bits.
- **out**: Packed tensor with shape appropriately compressed.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<BitPackAttrs>()
-.add_argument("data", "Tensor", "Input data.")
-.set_support_level(2)
-.add_type_rel("BitPack", BitPackRel);
+ .set_num_inputs(1)
+ .set_attrs_type<BitPackAttrs>()
+ .add_argument("data", "Tensor", "Input data.")
+ .set_support_level(2)
+ .add_type_rel("BitPack", BitPackRel);
// relay.nn.bitserial_conv2d
TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs);
Array<IndexExpr> oshape({dshape_nchw[0], param->channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
- oshape.Set(
- 2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1);
- oshape.Set(
- 3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1);
+ oshape.Set(2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1);
+ oshape.Set(3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1);
DataType out_dtype = param->out_dtype;
oshape = trans_in_layout.BackwardShape(oshape);
// assign output type
- **out**: Output with same layout as input.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<BinaryConv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.add_type_rel("BinaryConv2D", BinaryConv2DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- BinaryConv2DInferCorrectLayout<BinaryConv2DAttrs>);
+ .set_attrs_type<BinaryConv2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(2)
+ .add_type_rel("BinaryConv2D", BinaryConv2DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ BinaryConv2DInferCorrectLayout<BinaryConv2DAttrs>);
// relay.nn.bitserial_dense
TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs);
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<BinaryDenseAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "2D Tensor", "Input data.")
-.add_argument("weight", "2D Tensor", "Weight matrix.")
-.set_support_level(1)
-.add_type_rel("BinaryDense", BinaryDenseRel);
+ .set_attrs_type<BinaryDenseAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "2D Tensor", "Input data.")
+ .add_argument("weight", "2D Tensor", "Weight matrix.")
+ .set_support_level(1)
+ .add_type_rel("BinaryDense", BinaryDenseRel);
} // namespace relay
} // namespace tvm
* \file convolution.cc
* \brief Convolution operators
*/
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
+#include "convolution.h"
+
#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
#include <vector>
#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
-#include "convolution.h"
namespace tvm {
namespace relay {
template <typename T>
-Expr MakeConv(Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype,
- std::string op_name) {
+Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout, std::string kernel_layout,
+ std::string out_layout, DataType out_dtype, std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
}
template <typename T>
-Expr MakeConvWinograd(Expr data,
- Expr weight,
- int tile_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype,
+Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
+ IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype,
std::string op_name) {
auto attrs = make_object<T>();
attrs->tile_size = tile_size;
return Call(op, {data, weight}, Attrs(attrs), {});
}
-Expr MakeConvWinogradWeightTransform(Expr weight,
- int tile_size,
- std::string op_name) {
+Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) {
auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
attrs->tile_size = tile_size;
const Op& op = Op::Get(op_name);
}
template <typename T>
-Expr MakeConvTranspose(Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- Array<IndexExpr> output_padding,
- DataType out_dtype,
- std::string op_name) {
+Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout,
+ Array<IndexExpr> output_padding, DataType out_dtype, std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
}
template <typename T>
-Expr MakeDeformableConv(Expr data,
- Expr offset,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int deformable_groups,
- int groups,
- int channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype,
- std::string op_name) {
+Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, Array<IndexExpr> dilation, int deformable_groups,
+ int groups, int channels, Array<IndexExpr> kernel_size,
+ std::string data_layout, std::string kernel_layout, std::string out_layout,
+ DataType out_dtype, std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = strides;
attrs->padding = padding;
return Call(op, {data, offset, weight}, Attrs{attrs}, {});
}
-
// relay.nn.conv1d
TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d")
-.set_body_typed([](Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeConv<Conv1DAttrs>(
- data, weight, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.conv1d");
-});
-
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConv<Conv1DAttrs>(data, weight, strides, padding, dilation, groups, channels,
+ kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
+ "nn.conv1d");
+ });
RELAY_REGISTER_OP("nn.conv1d")
-.describe(R"code(1D convolution layer (e.g. spatial convolution over sequences).
+ .describe(R"code(1D convolution layer (e.g. spatial convolution over sequences).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
(batch_size, channels, out_width) if `layout` is `NCW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv1DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.add_type_rel("Conv1D", Conv1DRel<Conv1DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv1DAttrs>);
-
+ .set_attrs_type<Conv1DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(2)
+ .add_type_rel("Conv1D", Conv1DRel<Conv1DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv1DAttrs>);
// relay.nn.conv2d
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")
-.set_body_typed([](Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeConv<Conv2DAttrs>(
- data, weight, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.conv2d");
-});
-
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
+ kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
+ "nn.conv2d");
+ });
RELAY_REGISTER_OP("nn.conv2d")
-.describe(R"code(2D convolution layer (e.g. spatial convolution over images).
+ .describe(R"code(2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
-
+ .set_attrs_type<Conv2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(2)
+ .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
// relay.nn.conv3d
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d")
-.set_body_typed([](Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeConv<Conv3DAttrs>(
- data, weight, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.conv3d");
-});
-
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConv<Conv3DAttrs>(data, weight, strides, padding, dilation, groups, channels,
+ kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
+ "nn.conv3d");
+ });
RELAY_REGISTER_OP("nn.conv3d")
-.describe(R"code(3D convolution layer (e.g. convolution over 3D image data,
+ .describe(R"code(3D convolution layer (e.g. convolution over 3D image data,
like Magnetic Resonance Imaging (MRI) data in medicine).
This layer creates a convolution kernel that is convolved
(batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv3DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
-
+ .set_attrs_type<Conv3DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(2)
+ .add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
// relay.nn.conv2d_transpose
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
-.set_body_typed([](Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- Array<IndexExpr> output_padding,
- DataType out_dtype) {
- return MakeConvTranspose<Conv2DTransposeAttrs>(
- data, weight, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose");
-});
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout,
+ Array<IndexExpr> output_padding, DataType out_dtype) {
+ return MakeConvTranspose<Conv2DTransposeAttrs>(
+ data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
+ kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose");
+ });
RELAY_REGISTER_OP("nn.conv2d_transpose")
-.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
+ .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DTransposeAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ConvInferCorrectLayout<Conv2DTransposeAttrs>)
-.add_type_rel("Conv2DTranspose", Conv2DTransposeRel<Conv2DTransposeAttrs>);
+ .set_attrs_type<Conv2DTransposeAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(2)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ ConvInferCorrectLayout<Conv2DTransposeAttrs>)
+ .add_type_rel("Conv2DTranspose", Conv2DTransposeRel<Conv2DTransposeAttrs>);
// relay.nn.conv1d_transpose
TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose")
-.set_body_typed([](Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- Array<IndexExpr> output_padding,
- DataType out_dtype) {
- return MakeConvTranspose<Conv1DTransposeAttrs>(
- data, weight, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose");
-});
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout,
+ Array<IndexExpr> output_padding, DataType out_dtype) {
+ return MakeConvTranspose<Conv1DTransposeAttrs>(
+ data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
+ kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose");
+ });
RELAY_REGISTER_OP("nn.conv1d_transpose")
-.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).
+ .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv1DTransposeAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.add_type_rel("Conv1DTranspose", Conv1DTransposeRel<Conv1DTransposeAttrs>);
+ .set_attrs_type<Conv1DTransposeAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(2)
+ .add_type_rel("Conv1DTranspose", Conv1DTransposeRel<Conv1DTransposeAttrs>);
// relay.nn.contrib_conv2d_winograd_without_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
-.set_body_typed([](Expr data,
- Expr weight,
- int tile_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeConvWinograd<Conv2DWinogradAttrs>(
- data, weight, tile_size, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_winograd_without_weight_transform");
-});
-
+ .set_body_typed([](Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
+ IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConvWinograd<Conv2DWinogradAttrs>(
+ data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size,
+ data_layout, kernel_layout, out_layout, out_dtype,
+ "nn.contrib_conv2d_winograd_without_weight_transform");
+ });
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
-.describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout.
+ .describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_winograd_weight_transform.
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DWinogradAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ConvInferCorrectLayout<Conv2DWinogradAttrs>);
+ .set_attrs_type<Conv2DWinogradAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ ConvInferCorrectLayout<Conv2DWinogradAttrs>);
// relay.nn.contrib_conv2d_winograd_weight_transform
TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed([](Expr weight,
- int tile_size) {
- return MakeConvWinogradWeightTransform(
- weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
-});
+ .set_body_typed([](Expr weight, int tile_size) {
+ return MakeConvWinogradWeightTransform(weight, tile_size,
+ "nn.contrib_conv2d_winograd_weight_transform");
+ });
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+ .describe(R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ConvWinogradWeightTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
+ .set_attrs_type<ConvWinogradWeightTransformAttrs>()
+ .set_num_inputs(1)
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
// relay.nn.contrib_conv3d_winograd_without_weight_transform
TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
-.set_body_typed([](Expr data,
- Expr weight,
- int tile_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeConvWinograd<Conv3DWinogradAttrs>(
- data, weight, tile_size, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform");
-});
+ .set_body_typed([](Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
+ IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConvWinograd<Conv3DWinogradAttrs>(
+ data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size,
+ data_layout, kernel_layout, out_layout, out_dtype,
+ "nn.contrib_conv3d_winograd_without_weight_transform");
+ });
RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
-.describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
+ .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv3d_winograd_weight_transform.
- **out**: Output is 5D array of shape (batch_size, channels, depth, out_height, out_width)
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv3DWinogradAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv3DWinograd", Conv3DWinogradRel<Conv3DWinogradAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ConvInferCorrectLayout<Conv3DWinogradAttrs>);
+ .set_attrs_type<Conv3DWinogradAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv3DWinograd", Conv3DWinogradRel<Conv3DWinogradAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ ConvInferCorrectLayout<Conv3DWinogradAttrs>);
// relay.nn.contrib_conv3d_winograd_weight_transform
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform")
-.set_body_typed([](Expr weight,
- int tile_size) {
- return MakeConvWinogradWeightTransform(
- weight, tile_size, "nn.contrib_conv3d_winograd_weight_transform");
-});
+ .set_body_typed([](Expr weight, int tile_size) {
+ return MakeConvWinogradWeightTransform(weight, tile_size,
+ "nn.contrib_conv3d_winograd_weight_transform");
+ });
RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform")
.describe(R"code(Weight transformation of winograd fast 3d convolution algorithm.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ConvWinogradWeightTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel);
-
+ .set_attrs_type<ConvWinogradWeightTransformAttrs>()
+ .set_num_inputs(1)
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel);
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
-Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
- int convolution_algorithm,
+Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm,
DataType out_dtype) {
auto attrs = make_object<Conv2DWinogradNNPACKWeightTransformAttrs>();
attrs->convolution_algorithm = convolution_algorithm;
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
-.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
+ .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
+ .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
Separate this into another symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DWinogradNNPACKWeightTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
-
+ .set_attrs_type<Conv2DWinogradNNPACKWeightTransformAttrs>()
+ .set_num_inputs(1)
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
-.set_body_typed([](Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeConv<Conv2DAttrs>(
- data, weight, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_NCHWc");
-});
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
+ kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
+ "nn.contrib_conv2d_NCHWc");
+ });
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
-.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
+ .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ConvInferCorrectLayout<Conv2DAttrs>);
-
+ .set_attrs_type<Conv2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
// Positional relay function to create depthwise conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
-.set_body_typed([](Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeConv<Conv2DAttrs>(
- data, weight, strides, padding, dilation,
- groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.contrib_depthwise_conv2d_NCHWc");
-});
-
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
+ kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
+ "nn.contrib_depthwise_conv2d_NCHWc");
+ });
RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
-.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
+ .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ConvInferCorrectLayout<Conv2DAttrs>);
-
+ .set_attrs_type<Conv2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs);
the convolution on the *i*-th part of the data with the *i*-th weight part. The output is obtained
by concating all the *g* results.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<DeformableConv2DAttrs>()
-.set_num_inputs(3)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("offset", "Tensor", "The offset tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(5)
-.add_type_rel("DeformableConv2D", DeformableConv2DRel<DeformableConv2DAttrs>);
+ .set_attrs_type<DeformableConv2DAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("offset", "Tensor", "The offset tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(5)
+ .add_type_rel("DeformableConv2D", DeformableConv2DRel<DeformableConv2DAttrs>);
// Positional relay function to create deformable_conv2d operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
-.set_body_typed([](Expr data,
- Expr offset,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int deformable_groups,
- int groups,
- int channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- DataType out_dtype) {
- return MakeDeformableConv<DeformableConv2DAttrs>(
- data, offset, weight, strides, padding, dilation,
- deformable_groups, groups, channels, kernel_size, data_layout,
- kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d");
-});
+ .set_body_typed([](Expr data, Expr offset, Expr weight, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, Array<IndexExpr> dilation, int deformable_groups,
+ int groups, int channels, Array<IndexExpr> kernel_size,
+ std::string data_layout, std::string kernel_layout, std::string out_layout,
+ DataType out_dtype) {
+ return MakeDeformableConv<DeformableConv2DAttrs>(
+ data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels,
+ kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d");
+ });
} // namespace relay
} // namespace tvm
namespace tvm {
namespace relay {
-
// Standard convolution operator shape relations
template <typename AttrType>
bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
// check the size
- CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) )
+ CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
<< "Conv1D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " wshape=" << wshape;
}
if (!dshape_ncw[2].as<tir::AnyNode>()) {
oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize,
- param->strides[0]) + 1);
+ param->strides[0]) +
+ 1);
} else {
oshape.Set(2, dshape_ncw[2]);
}
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
bool is_depthwise = false;
if (param->groups > 1) {
- CHECK(weight && weight->shape.defined()) <<
- "Weight shape must be specified when groups is greater than 1.";
+ CHECK(weight && weight->shape.defined())
+ << "Weight shape must be specified when groups is greater than 1.";
Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape);
if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) &&
tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) {
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<tir::AnyNode>()) {
- oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
- param->strides[0]) + 1);
+ oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<tir::AnyNode>()) {
- oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
- param->strides[1]) + 1);
+ oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
}
IndexExpr pad_d, pad_h, pad_w;
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
- oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z,
- param->strides[0]) + 1);
+ oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}
if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
- oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y,
- param->strides[1]) + 1);
+ oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}
if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
- oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x,
- param->strides[2]) + 1);
+ oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1);
} else {
oshape.Set(4, dshape_ncdhw[4]);
}
return true;
}
-
// Winograd convolution shape relations
inline bool Conv2DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs, const TypeReporter& reporter) {
CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
- std::vector<IndexExpr> oshape {
+ std::vector<IndexExpr> oshape{
param->tile_size + data->shape[2] - 1,
param->tile_size + data->shape[3] - 1,
data->shape[0],
data->shape[1],
};
- reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
- data->dtype));
+ reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
return true;
}
// Shape of packed weights depends on whether depth is being transformed or not.
Array<IndexExpr> oshape({0, 0, 0, data->shape[0], data->shape[1]});
auto* depth_imm = data->shape[2].as<IntImmNode>();
- bool transform_depth = (depth_imm->value > 2)&&(depth_imm->value < 8);
+ bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8);
if (transform_depth) {
oshape.Set(0, param->tile_size + data->shape[2] - 1);
oshape.Set(1, param->tile_size + data->shape[3] - 1);
return true;
}
-template<typename AttrType>
-bool Conv2DWinogradRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+template <typename AttrType>
+bool Conv2DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
- << "Conv only support input layouts that are convertible from NCHW."
- << " But got " << in_layout;
+ << "Conv only support input layouts that are convertible from NCHW."
+ << " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
- << "Conv only support kernel layouts that are convertible from OIHW."
- << " But got "<< kernel_layout;
+ << "Conv only support kernel layouts that are convertible from OIHW."
+ << " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<tir::AnyNode>()) {
- oshape.Set(2, (dshape_nchw[2] + pad_h
- - dilated_ksize_y) / param->strides[0] + 1);
+ oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<tir::AnyNode>()) {
- oshape.Set(3, (dshape_nchw[3] + pad_w
- - dilated_ksize_x) / param->strides[1] + 1);
+ oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
}
return true;
}
-
-template<typename AttrType>
-bool Conv3DWinogradRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+template <typename AttrType>
+bool Conv3DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
CHECK(trans_in_layout.defined())
- << "Conv only support input layouts that are convertible from NCDHW."
- << " But got " << in_layout;
+ << "Conv only support input layouts that are convertible from NCDHW."
+ << " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
CHECK(trans_kernel_layout.defined())
- << "Conv only support kernel layouts that are convertible from OIDHW."
- << " But got "<< kernel_layout;
+ << "Conv only support kernel layouts that are convertible from OIDHW."
+ << " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
IndexExpr pad_d, pad_h, pad_w;
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
- oshape.Set(2, (dshape_ncdhw[2] + pad_d
- - dilated_ksize_d) / param->strides[0] + 1);
+ oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}
if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
- oshape.Set(3, (dshape_ncdhw[3] + pad_h
- - dilated_ksize_y) / param->strides[1] + 1);
+ oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}
if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
- oshape.Set(4, (dshape_ncdhw[4] + pad_w
- - dilated_ksize_x) / param->strides[2] + 1);
+ oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1);
} else {
oshape.Set(4, dshape_ncdhw[4]);
}
return true;
}
-
// Transposed convolution shape relations
template <typename AttrType>
-bool Conv1DTransposeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
CHECK(trans_in_layout.defined())
- << "Conv only support input layouts that are convertible from NCW."
- << " But got " << in_layout;
+ << "Conv only support input layouts that are convertible from NCW."
+ << " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
CHECK(trans_kernel_layout.defined())
- << "Conv only support kernel layouts that are convertible from OIW."
- << " But got "<< kernel_layout;
+ << "Conv only support kernel layouts that are convertible from OIW."
+ << " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
CHECK(trans_out_layout.defined())
- << "Conv only support output layouts that are convertible from NCW."
- << " But got " << out_layout;
+ << "Conv only support output layouts that are convertible from NCW."
+ << " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
CHECK_EQ(param->kernel_size.size(), 1);
CHECK_EQ(param->dilation.size(), 1);
- Array<IndexExpr> wshape({dshape_ncw[1],
- indexdiv(param->channels, param->groups),
- param->kernel_size[0]});
+ Array<IndexExpr> wshape(
+ {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
<< "Conv1D: shape of weight is inconsistent with kernel_size, "
- << " kernel_size=" << param->kernel_size
- << " wshape=" << Array<IndexExpr>(wshape);
+ << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv1D: shape of weight is inconsistent with channels, "
- << " channels=" << param->channels
- << " wshape=" << Array<IndexExpr>(wshape);
+ << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
channels = wshape[1];
IndexExpr pad_w;
GetPaddingWidth(param->padding, &pad_w);
Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
- oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
- pad_w + param->output_padding[0]));
+ oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w +
+ param->output_padding[0]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
return true;
}
-
template <typename AttrType>
-bool Conv2DTransposeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
- << "Conv only support input layouts that are convertible from NCHW."
- << " But got " << in_layout;
+ << "Conv only support input layouts that are convertible from NCHW."
+ << " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
- << "Conv only support kernel layouts that are convertible from OIHW."
- << " But got "<< kernel_layout;
+ << "Conv only support kernel layouts that are convertible from OIHW."
+ << " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
- << "Conv only support output layouts that are convertible from NCHW."
- << " But got " << out_layout;
+ << "Conv only support output layouts that are convertible from NCHW."
+ << " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
- Array<IndexExpr> wshape({dshape_nchw[1],
- indexdiv(param->channels, param->groups),
- param->kernel_size[0],
- param->kernel_size[1]});
+ Array<IndexExpr> wshape({dshape_nchw[1], indexdiv(param->channels, param->groups),
+ param->kernel_size[0], param->kernel_size[1]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
- << " kernel_size=" << param->kernel_size
- << " wshape=" << Array<IndexExpr>(wshape);
+ << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv2D: shape of weight is inconsistent with channels, "
- << " channels=" << param->channels
- << " wshape=" << Array<IndexExpr>(wshape);
+ << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
channels = wshape[1];
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
- oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
- pad_h + param->output_padding[0]));
- oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
- pad_w + param->output_padding[1]));
+ oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h +
+ param->output_padding[0]));
+ oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w +
+ param->output_padding[1]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
return true;
}
-
// Deformable Convolution shape relations.
template <typename AttrType>
bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
- Array<IndexExpr> wshape(
- {param->channels,
- indexdiv(data->shape[1], param->groups),
- param->kernel_size[0],
- param->kernel_size[1]});
+ Array<IndexExpr> wshape({param->channels, indexdiv(data->shape[1], param->groups),
+ param->kernel_size[0], param->kernel_size[1]});
channels = param->channels;
ksize_y = param->kernel_size[0];
ksize_x = param->kernel_size[1];
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
- << " kernel_size=" << param->kernel_size
- << " wshape=" << wshape;
+ << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "DeformableConv2D: shape of weight is inconsistent with channels, "
- << " channels=" << param->channels
- << " wshape=" << wshape;
+ << " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
channels = wshape[0];
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
- oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y,
- param->strides[0]) + 1);
- oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x,
- param->strides[1]) + 1);
+ oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
+ oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;
// infer offset shape
- Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
- oshape[2], oshape[3]});
+ Array<IndexExpr> offset_shape(
+ {data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
return true;
}
-
-template<typename T>
-Array<Array<Layout> > ConvInferCorrectLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+template <typename T>
+Array<Array<Layout> > ConvInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
const T* params = attrs.as<T>();
// We always make other operators to fit the layouts of convolution layers
// So this inference ignores all inputs
- return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
- {params->out_layout == "" ?
- params->data_layout : params->out_layout}};
+ return Array<Array<Layout> >{
+ {params->data_layout, params->kernel_layout},
+ {params->out_layout == "" ? params->data_layout : params->out_layout}};
}
-
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_
* \brief Property def of nn operators.
*/
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
-#include <tvm/relay/attrs/nn.h>
-#include <tvm/relay/attrs/image.h>
+#include "nn.h"
+
#include <topi/nn.h>
#include <topi/nn/bias_add.h>
-#include <topi/nn/softmax.h>
#include <topi/nn/flatten.h>
-#include <vector>
+#include <topi/nn/softmax.h>
+#include <tvm/relay/attrs/image.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
#include <string>
-#include "../type_relations.h"
+#include <vector>
+
#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
-#include "nn.h"
+#include "../type_relations.h"
namespace tvm {
namespace relay {
// relay.nn.bias_add
TVM_REGISTER_NODE_TYPE(BiasAddAttrs);
-bool BiasAddRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BiasAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
<< "axis " << param->axis << " is out of range";
// assign output type
- reporter->Assign(types[1], TensorType(
- {data->shape[axis]}, data->dtype));
+ reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], types[0]);
return true;
}
-
// Positional relay function to create dense operator used by frontend FFI.
-Expr MakeBiasAdd(Expr data,
- Expr bias,
- int axis) {
+Expr MakeBiasAdd(Expr data, Expr bias, int axis) {
auto attrs = make_object<BiasAddAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.bias_add");
return Call(op, {data, bias}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add")
-.set_body_typed(MakeBiasAdd);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd);
RELAY_REGISTER_OP("nn.bias_add")
-.describe(R"code(Add bias to an axis of the input.
+ .describe(R"code(Add bias to an axis of the input.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<BiasAddAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "nD Tensor", "Input data.")
-.add_argument("bias", "1D Tensor", "Bias.")
-.set_support_level(1)
-.add_type_rel("BiasAdd", BiasAddRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
- const auto* param = attrs.as<BiasAddAttrs>();
- return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
-});
-
+ .set_attrs_type<BiasAddAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "nD Tensor", "Input data.")
+ .add_argument("bias", "1D Tensor", "Bias.")
+ .set_support_level(1)
+ .add_type_rel("BiasAdd", BiasAddRel)
+ .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<BiasAddAttrs>();
+ return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
+ });
// relay.nn.fifo_buffer
TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs);
return Call(op, {input, buffer}, Attrs(attrs), {});
}
-bool FIFOBufferRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool FIFOBufferRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* input = types[0].as<TensorTypeNode>();
CHECK(param != nullptr);
CHECK_EQ(input->shape.size(), buffer->shape.size());
- const size_t buffer_axis
- = static_cast<size_t>(param->axis < 0 ? static_cast<int>(buffer->shape.size()) + param->axis
- : param->axis);
+ const size_t buffer_axis = static_cast<size_t>(
+ param->axis < 0 ? static_cast<int>(buffer->shape.size()) + param->axis : param->axis);
reporter->Assert(buffer_axis < buffer->shape.size());
for (size_t i = 0; i < buffer->shape.size(); ++i) {
return true;
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer")
-.set_body_typed(MakeFIFOBuffer);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer").set_body_typed(MakeFIFOBuffer);
RELAY_REGISTER_OP("nn.fifo_buffer")
-.describe(R"code(FIFO buffer
+ .describe(R"code(FIFO buffer
Compute equivalent of
```
* Encoding explicit re-use of computation in convolution ops operated on a sliding window input
* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<FIFOBufferAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "Latest input")
-.add_argument("buffer", "Tensor",
- "Buffer storing latest [length_buffer] inputs")
-.set_support_level(3)
-.add_type_rel("FIFOBuffer", FIFOBufferRel);
-
+ .set_attrs_type<FIFOBufferAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "Latest input")
+ .add_argument("buffer", "Tensor", "Buffer storing latest [length_buffer] inputs")
+ .set_support_level(3)
+ .add_type_rel("FIFOBuffer", FIFOBufferRel);
// relay.nn.dense
TVM_REGISTER_NODE_TYPE(DenseAttrs);
// Positional relay function to create dense operator used by frontend FFI.
-Expr MakeDense(Expr data,
- Expr weight,
- IndexExpr units,
- DataType out_dtype) {
+Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
auto attrs = make_object<DenseAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
return Call(op, {data, weight}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.dense")
-.set_body_typed(MakeDense);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense);
RELAY_REGISTER_OP("nn.dense")
-.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
+ .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<DenseAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "nD Tensor", "Input data.")
-.add_argument("weight", "2D Tensor", "Weight matrix.")
-.set_support_level(1)
-.add_type_rel("Dense", DenseRel<DenseAttrs>);
+ .set_attrs_type<DenseAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "nD Tensor", "Input data.")
+ .add_argument("weight", "2D Tensor", "Weight matrix.")
+ .set_support_level(1)
+ .add_type_rel("Dense", DenseRel<DenseAttrs>);
// relay.leaky_relu
TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
// Positional relay function to create leaky relu operator used by frontend FFI.
-Expr MakeLeakyRelu(Expr data,
- double alpha) {
+Expr MakeLeakyRelu(Expr data, double alpha) {
auto attrs = make_object<LeakyReluAttrs>();
attrs->alpha = alpha;
static const Op& op = Op::Get("nn.leaky_relu");
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu")
-.set_body_typed(MakeLeakyRelu);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu").set_body_typed(MakeLeakyRelu);
RELAY_REGISTER_OP("nn.leaky_relu")
-.describe(R"code(Leaky version of a Rectified Linear Unit.
+ .describe(R"code(Leaky version of a Rectified Linear Unit.
`y = x > 0 ? x : alpha * x`
)code" TVM_ADD_FILELINE)
-.set_attrs_type<LeakyReluAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "Input data.")
-.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>(
- "FTVMCompute", [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
- const auto* param = attrs.as<LeakyReluAttrs>();
- return Array<te::Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
-});
-
+ .set_attrs_type<LeakyReluAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "Input data.")
+ .set_support_level(3)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<LeakyReluAttrs>();
+ return Array<te::Tensor>{topi::leaky_relu(inputs[0], param->alpha)};
+ });
// relay.prelu
TVM_REGISTER_NODE_TYPE(PReluAttrs);
-bool PReluRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(param != nullptr);
CHECK(param->axis < static_cast<int>(data->shape.size()))
- << "Wrong axis (" << param->axis << ")value.";
+ << "Wrong axis (" << param->axis << ")value.";
// assign alpha type
Array<IndexExpr> alpha_shape({data->shape[param->axis]});
return true;
}
-template<typename T>
-Array<Array<Layout> > PReluInferCorrectLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
-
+template <typename T>
+Array<Array<Layout>> PReluInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
CHECK_EQ(old_in_layouts.size(), 2U);
CHECK_EQ(old_in_types.size(), 2U);
Layout data_layout = old_in_layouts[0];
if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 2U);
}
- return Array<Array<Layout> >{{data_layout, Layout("C")},
- {data_layout}};
+ return Array<Array<Layout>>{{data_layout, Layout("C")}, {data_layout}};
}
// Positional relay function to create prelu operator used by frontend FFI.
-Expr MakePRelu(Expr data,
- Expr alpha,
- int axis) {
+Expr MakePRelu(Expr data, Expr alpha, int axis) {
auto attrs = make_object<PReluAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.prelu");
return Call(op, {data, alpha}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu")
-.set_body_typed(MakePRelu);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu").set_body_typed(MakePRelu);
RELAY_REGISTER_OP("nn.prelu")
-.describe(R"code(Parametric version of a Rectified Linear Unit.
+ .describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`,
where :math:`*` is an channelwise multiplication for each sample in the batch.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<PReluAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "Input data.")
-.add_argument("alpha", "Tensor", "Input channelwise alpha.")
-.set_support_level(3)
-.add_type_rel("PRelu", PReluRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
-.set_attr<FTVMCompute>(
- "FTVMCompute", [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
- const auto* param = attrs.as<PReluAttrs>();
- return Array<te::Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
-});
-
+ .set_attrs_type<PReluAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "Input data.")
+ .add_argument("alpha", "Tensor", "Input channelwise alpha.")
+ .set_support_level(3)
+ .add_type_rel("PRelu", PReluRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<PReluAttrs>();
+ return Array<te::Tensor>{topi::prelu(inputs[0], inputs[1], param->axis)};
+ });
// relay.softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
-TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax")
-.set_body_typed([](Expr data, int axis) {
+TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.softmax");
return Call(op, {data}, Attrs(attrs), {});
});
-
RELAY_REGISTER_OP("nn.softmax")
.describe(R"code(Softmax layer.
- **data**: The input data
)code" TVM_ADD_FILELINE)
-.set_attrs_type<SoftmaxAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
-
+ .set_attrs_type<SoftmaxAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(1)
+ .add_type_rel("Identity", IdentityRel);
// relay.nn.log_softmax
-TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax")
-.set_body_typed([](Expr data, int axis) {
+TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.log_softmax");
- **data**: The input data
)code" TVM_ADD_FILELINE)
-.set_attrs_type<SoftmaxAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
- const auto* param = attrs.as<SoftmaxAttrs>();
- CHECK(param != nullptr);
- CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
- << "log_softmax currently only works on last dimension";
- return Array<te::Tensor>{ topi::nn::log_softmax(inputs[0]) };
-});
-
+ .set_attrs_type<SoftmaxAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(1)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<SoftmaxAttrs>();
+ CHECK(param != nullptr);
+ CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
+ << "log_softmax currently only works on last dimension";
+ return Array<te::Tensor>{topi::nn::log_softmax(inputs[0])};
+ });
// relay.nn.batch_flatten
-bool BatchFlattenRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BatchFlattenRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
return Call(op, {data}, Attrs(), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten")
-.set_body_typed(MakeBatchFlatten);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten").set_body_typed(MakeBatchFlatten);
RELAY_REGISTER_OP("nn.batch_flatten")
-.describe(R"code(Flattens the input into a 2-D array.
+ .describe(R"code(Flattens the input into a 2-D array.
For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes
the input array into an output array of shape ``(d1, d2*...*dk)``.
[ 1., 2., 3., 4., 5., 6., 7., 8., 9.]]
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("BatchFlatten", BatchFlattenRel)
-.set_attr<FTVMCompute>(
- "FTVMCompute", [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
- return Array<te::Tensor>{ topi::nn::flatten(inputs[0]) };
-});
-
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("BatchFlatten", BatchFlattenRel)
+ .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ return Array<te::Tensor>{topi::nn::flatten(inputs[0])};
+ });
// relu
-TVM_REGISTER_GLOBAL("relay.op.nn._make.relu")
-.set_body_typed([](Expr data) {
- static const Op& op = Op::Get("nn.relu");
- return Call(op, {data}, Attrs(), {});
- });
+TVM_REGISTER_GLOBAL("relay.op.nn._make.relu").set_body_typed([](Expr data) {
+ static const Op& op = Op::Get("nn.relu");
+ return Call(op, {data}, Attrs(), {});
+});
RELAY_REGISTER_OP("nn.relu")
-.describe(R"code(Returns the relu input array, computed element-wise.
+ .describe(R"code(Returns the relu input array, computed element-wise.
.. math::
max(x, 0)
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
- return Array<te::Tensor>{ topi::relu(inputs[0], 0.0f) };
-});
-
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(1)
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ return Array<te::Tensor>{topi::relu(inputs[0], 0.0f)};
+ });
// Positional relay function to create LRN operator used by frontend FFI.
TVM_REGISTER_NODE_TYPE(LRNAttrs);
-Expr MakeLRN(Expr data,
- int size,
- int axis,
- double alpha,
- double beta,
- double bias) {
+Expr MakeLRN(Expr data, int size, int axis, double alpha, double beta, double bias) {
auto attrs = make_object<LRNAttrs>();
attrs->size = size;
attrs->axis = axis;
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn")
-.set_body_typed(MakeLRN);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn").set_body_typed(MakeLRN);
RELAY_REGISTER_OP("nn.lrn")
-.describe(R"code(LRN layer.
+ .describe(R"code(LRN layer.
Normalize the input in a local region across or within feature maps.
Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<LRNAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("Identity", IdentityRel);
-
+ .set_attrs_type<LRNAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("Identity", IdentityRel);
// Positional relay function to create L2Normalize operator used by frontend FFI.
TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);
-Expr MakeL2Normalize(Expr data,
- double eps,
- Array<Integer> axis) {
+Expr MakeL2Normalize(Expr data, double eps, Array<Integer> axis) {
auto attrs = make_object<L2NormalizeAttrs>();
attrs->eps = eps;
attrs->axis = std::move(axis);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize")
-.set_body_typed(MakeL2Normalize);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize").set_body_typed(MakeL2Normalize);
RELAY_REGISTER_OP("nn.l2_normalize")
-.describe(R"code(L2 Normalization layer.
+ .describe(R"code(L2 Normalization layer.
Normalizes along dimension axis using an L2 norm
- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<L2NormalizeAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
-.add_type_rel("Identity", IdentityRel);
+ .set_attrs_type<L2NormalizeAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .add_type_rel("Identity", IdentityRel);
// Dropout
TVM_REGISTER_NODE_TYPE(DropoutAttrs);
-bool DropoutRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool DropoutRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout")
-.set_body_typed(MakeDropout);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout").set_body_typed(MakeDropout);
RELAY_REGISTER_OP("nn.dropout")
-.describe(R"code(Applies the dropout operation to the input array.
+ .describe(R"code(Applies the dropout operation to the input array.
During training, each element of the input is set to zero with probability ``p``.
The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<DropoutAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "Input to which dropout will be applied.")
-.set_support_level(1)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
-.add_type_rel("Dropout", DropoutRel);
+ .set_attrs_type<DropoutAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "Input to which dropout will be applied.")
+ .set_support_level(1)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .add_type_rel("Dropout", DropoutRel);
// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
{ret, c_layout, c_layout}};
}
-bool BatchNormRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 6);
const auto* data = types[0].as<TensorTypeNode>();
// output is a tuple of the normed data (same shape as input), new running mean,
// and new running average (the latter two are both vectors of length dim)
std::vector<Type> fields;
- auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}),
- data->dtype);
+ auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}), data->dtype);
fields.push_back(TensorType(data->shape, data->dtype));
fields.push_back(vec_ty);
fields.push_back(vec_ty);
return true;
}
-Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,
- int axis, double epsilon, bool center, bool scale) {
+Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, int axis,
+ double epsilon, bool center, bool scale) {
auto attrs = make_object<BatchNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm")
-.set_body_typed(MakeBatchNorm);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm").set_body_typed(MakeBatchNorm);
RELAY_REGISTER_OP("nn.batch_norm")
-.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
+ .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
Normalizes the input at each batch, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation
standard deviation close to 1.
.. note::
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<BatchNormAttrs>()
-.set_num_inputs(5)
-.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
-.add_argument("gamma", "Tensor", "The gamma scale factor.")
-.add_argument("beta", "Tensor", "The beta offset factor.")
-.add_argument("moving_mean", "Tensor", "Running mean of input.")
-.add_argument("moving_var", "Tensor", "Running variance of input.")
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
-.set_support_level(1)
-.add_type_rel("BatchNorm", BatchNormRel);
-
+ .set_attrs_type<BatchNormAttrs>()
+ .set_num_inputs(5)
+ .add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
+ .add_argument("gamma", "Tensor", "The gamma scale factor.")
+ .add_argument("beta", "Tensor", "The beta offset factor.")
+ .add_argument("moving_mean", "Tensor", "Running mean of input.")
+ .add_argument("moving_var", "Tensor", "Running variance of input.")
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
+ .set_support_level(1)
+ .add_type_rel("BatchNorm", BatchNormRel);
// instance_norm
TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);
-bool InstanceNormRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool InstanceNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
return true;
}
-Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
- bool center, bool scale) {
+Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center,
+ bool scale) {
auto attrs = make_object<InstanceNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
- });
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
+ });
RELAY_REGISTER_OP("nn.instance_norm")
-.describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
+ .describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
Applies instance normalization to the n-dimensional input array.
.. math::
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<InstanceNormAttrs>()
-.set_num_inputs(3)
-.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
-.add_argument("gamma", "Tensor", "The gamma scale factor.")
-.add_argument("beta", "Tensor", "The beta offset factor.")
-.set_support_level(1)
-.add_type_rel("InstanceNorm", InstanceNormRel);
-
+ .set_attrs_type<InstanceNormAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
+ .add_argument("gamma", "Tensor", "The gamma scale factor.")
+ .add_argument("beta", "Tensor", "The beta offset factor.")
+ .set_support_level(1)
+ .add_type_rel("InstanceNorm", InstanceNormRel);
// layer_norm
TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
-bool LayerNormRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool LayerNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
return true;
}
-Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
- bool center, bool scale) {
+Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center,
+ bool scale) {
auto attrs = make_object<LayerNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 7>(MakeLayerNorm, args, rv);
- });
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 7>(MakeLayerNorm, args, rv);
+ });
RELAY_REGISTER_OP("nn.layer_norm")
-.describe(R"code(
+ .describe(R"code(
)code" TVM_ADD_FILELINE)
-.set_attrs_type<LayerNormAttrs>()
-.set_num_inputs(3)
-.add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
-.add_argument("gamma", "Tensor", "The gamma scale factor.")
-.add_argument("beta", "Tensor", "The beta offset factor.")
-.set_support_level(1)
-.add_type_rel("LayerNorm", LayerNormRel);
+ .set_attrs_type<LayerNormAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
+ .add_argument("gamma", "Tensor", "The gamma scale factor.")
+ .add_argument("beta", "Tensor", "The beta offset factor.")
+ .set_support_level(1)
+ .add_type_rel("LayerNorm", LayerNormRel);
// group_norm
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
-bool GroupNormRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool GroupNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
return true;
}
-Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups,
- int axis, double epsilon, bool center, bool scale) {
+Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis, double epsilon,
+ bool center, bool scale) {
auto attrs = make_object<GroupNormAttrs>();
- attrs->num_groups = num_groups;
+ attrs->num_groups = num_groups;
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 8>(MakeGroupNorm, args, rv);
- });
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 8>(MakeGroupNorm, args, rv);
+ });
RELAY_REGISTER_OP("nn.group_norm")
-.describe(R"code(
+ .describe(R"code(
Group normalization normalizes over group of channels for each training examples.
We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put
all the channels into a single group, group normalization becomes Layer normalization.
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<GroupNormAttrs>()
-.set_num_inputs(3)
-.add_argument("data", "Tensor", "Input to which group_norm will be applied.")
-.add_argument("gamma", "Tensor", "The gamma scale factor.")
-.add_argument("beta", "Tensor", "The beta offset factor.")
-.set_support_level(1)
-.add_type_rel("GroupNorm", GroupNormRel);
-
+ .set_attrs_type<GroupNormAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "Input to which group_norm will be applied.")
+ .add_argument("gamma", "Tensor", "The gamma scale factor.")
+ .add_argument("beta", "Tensor", "The beta offset factor.")
+ .set_support_level(1)
+ .add_type_rel("GroupNorm", GroupNormRel);
// relay.nn.batch_matmul
-bool BatchMatmulRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* x = types[0].as<TensorTypeNode>();
CHECK(x->shape.size() == 3 && y->shape.size() == 3);
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "BatchDot: batch dimension doesn't match, "
- << " x shape=" << x->shape
- << ", y shape=" << y->shape;
+ << " x shape=" << x->shape << ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
- << " x shape=" << x->shape
- << ", y shape=" << y->shape;
+ << " x shape=" << x->shape << ", y shape=" << y->shape;
Array<tvm::PrimExpr> oshape = x->shape;
oshape.Set(2, y->shape[1]);
return true;
}
-
// Positional relay function to create batch_matmul operator used by frontend FFI.
-Expr MakeBatchMatmul(Expr x,
- Expr y) {
+Expr MakeBatchMatmul(Expr x, Expr y) {
static const Op& op = Op::Get("nn.batch_matmul");
return Call(op, {x, y}, Attrs(), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul")
-.set_body_typed(MakeBatchMatmul);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul);
RELAY_REGISTER_OP("nn.batch_matmul")
-.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y`
+ .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y`
are data in batch.
.. math::
- **out**: `(b, m, n)`.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("x", "3D Tensor", "First input.")
-.add_argument("y", "3D Tensor", "Second input.")
-.set_support_level(10)
-.add_type_rel("BatchMatmul", BatchMatmulRel);
-
+ .set_num_inputs(2)
+ .add_argument("x", "3D Tensor", "First input.")
+ .add_argument("y", "3D Tensor", "Second input.")
+ .set_support_level(10)
+ .add_type_rel("BatchMatmul", BatchMatmulRel);
// relay.nn.cross_entropy
-bool CrossEntropyRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter) {
+bool CrossEntropyRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* x = types[0].as<TensorTypeNode>();
const auto* y = types[1].as<TensorTypeNode>();
if (x == nullptr || y == nullptr) return false;
CHECK(x->shape.size() == 2 && y->shape.size() == 2)
- << "CrossEntropy: shapes of x and y is inconsistent, "
- << "x shape = " << x->shape << ", "
- << "y shape = " << y->shape;
+ << "CrossEntropy: shapes of x and y is inconsistent, "
+ << "x shape = " << x->shape << ", "
+ << "y shape = " << y->shape;
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
- << "CrossEntropy: shapes of x and y is inconsistent, "
- << "x shape = " << x->shape << ", "
- << "y shape = " << y->shape;
+ << "CrossEntropy: shapes of x and y is inconsistent, "
+ << "x shape = " << x->shape << ", "
+ << "y shape = " << y->shape;
CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
- << "CrossEntropy: shapes of x and y is inconsistent, "
- << "x shape = " << x->shape << ", "
- << "y shape = " << y->shape;
+ << "CrossEntropy: shapes of x and y is inconsistent, "
+ << "x shape = " << x->shape << ", "
+ << "y shape = " << y->shape;
// assign output type
reporter->Assign(types[2], TensorType({}, x->dtype));
return true;
return Call(op, {predictions, targets}, Attrs(), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy")
-.set_body_typed(MakeCrossEntropy);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy").set_body_typed(MakeCrossEntropy);
RELAY_REGISTER_OP("nn.cross_entropy")
-.describe(R"code(
+ .describe(R"code(
Computes cross entropy given predictions and targets.
Do log on the data - do not accept logits.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("x", "1D Tensor", "Predictions.")
-.add_argument("y", "1D Tensor", "Targets.")
-.set_support_level(10)
-.add_type_rel("CrossEntropy", CrossEntropyRel);
-
+ .set_num_inputs(2)
+ .add_argument("x", "1D Tensor", "Predictions.")
+ .add_argument("y", "1D Tensor", "Targets.")
+ .set_support_level(10)
+ .add_type_rel("CrossEntropy", CrossEntropyRel);
// relay.nn.dilate
TVM_REGISTER_NODE_TYPE(DilateAttrs);
-bool DilateRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool DilateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* x = types[0].as<TensorTypeNode>();
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate")
-.set_body_typed(MakeDilate);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate);
RELAY_REGISTER_OP("nn.dilate")
-.describe(R"code(
+ .describe(R"code(
Dilate data with zeros.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("x", "1D Tensor", "Data to dilate.")
-.set_support_level(10)
-.add_type_rel("Dilate", DilateRel);
+ .set_num_inputs(1)
+ .add_argument("x", "1D Tensor", "Data to dilate.")
+ .set_support_level(10)
+ .add_type_rel("Dilate", DilateRel);
// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
return Call(op, {predictions, targets}, Attrs(), {});
}
-
TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy_with_logits")
-.set_body_typed(MakeCrossEntropyWithLogits);
-
+ .set_body_typed(MakeCrossEntropyWithLogits);
RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
-.describe(R"code(
+ .describe(R"code(
Computes cross entropy given predictions and targets.
Accept logits.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("x", "1D Tensor", "Predictions.")
-.add_argument("y", "1D Tensor", "Targets.")
-.set_support_level(10)
-.add_type_rel("CrossEntropy", CrossEntropyRel);
+ .set_num_inputs(2)
+ .add_argument("x", "1D Tensor", "Predictions.")
+ .add_argument("y", "1D Tensor", "Targets.")
+ .set_support_level(10)
+ .add_type_rel("CrossEntropy", CrossEntropyRel);
// Depth to space and space to depth
TVM_REGISTER_NODE_TYPE(SubPixelAttrs);
oshape.Set(3, oshape[3] * block_size);
// Assign output type
- reporter->Assign(types[1],
- TensorType(layout_converter.BackwardShape(oshape), data->dtype));
+ reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
oshape.Set(3, indexdiv(oshape[3], block_size));
// Assign output type
- reporter->Assign(types[1],
- TensorType(layout_converter.BackwardShape(oshape), data->dtype));
+ reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
#ifndef TVM_RELAY_OP_NN_NN_H_
#define TVM_RELAY_OP_NN_NN_H_
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/expr.h>
+#include <tvm/node/container.h>
+#include <tvm/relay/type.h>
+
#include <utility>
namespace tvm {
if (weight == nullptr) return false;
Array<tvm::PrimExpr> wshape = weight->shape;
CHECK(static_cast<int>(weight->shape.size()) == 2);
- CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1],
- weight->shape[1]))
+ CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
<< "DenseRel: input dimension doesn't match,"
<< " data shape=" << data->shape << ", weight shape=" << weight->shape;
oshape.Set((oshape.size() - 1), wshape[0]);
* \file pad.cc
* \brief Implementation of operator pad
*/
+#include <topi/nn.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/op.h>
-#include <tvm/relay/op.h>
-#include <tvm/relay/attrs/nn.h>
-#include <topi/nn.h>
+
#include <vector>
+
#include "../op_common.h"
namespace tvm {
// relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);
-Array<Array<Layout> > PadInferCorrectLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
- PadAttrs *params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
+ PadAttrs* params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
Layout ret;
// If new_in_layouts are defined, this code tries to modify the layout.
}
}
- return Array<Array<Layout> >{{ret}, {ret}};
+ return Array<Array<Layout>>{{ret}, {ret}};
}
-bool PadRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
// check that pad widths match lengths
CHECK(data->shape.size() == param->pad_width.size())
- << "There should be as many pad width pairs as shape dimensions "
- << "but the shape has " << data->shape.size() << " dimensions "
- << "and there are " << param->pad_width.size() << " pad width pairs.";
+ << "There should be as many pad width pairs as shape dimensions "
+ << "but the shape has " << data->shape.size() << " dimensions "
+ << "and there are " << param->pad_width.size() << " pad width pairs.";
// each pad width element should be a pair of positive integers
std::vector<IndexExpr> oshape;
for (size_t i = 0; i < param->pad_width.size(); i++) {
CHECK(param->pad_width[i].size() == 2)
- << "Each pad width element should be a pair but at index " << i
- << " there are " << param->pad_width[i].size() << " elements.";
+ << "Each pad width element should be a pair but at index " << i << " there are "
+ << param->pad_width[i].size() << " elements.";
auto width1 = tir::as_const_int(param->pad_width[i][0]);
auto width2 = tir::as_const_int(param->pad_width[i][1]);
CHECK(width1 != nullptr);
CHECK(width2 != nullptr);
- CHECK(*width1 >= 0)
- << "Param width elements should be positive but first pad width at "
- << "index " << i << " is " << *width1 << ".";
- CHECK(*width2 >= 0)
- << "Param width elements should be positive but first pad width at "
- << "index " << i << " is " << *width2 << ".";
+ CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at "
+ << "index " << i << " is " << *width1 << ".";
+ CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at "
+ << "index " << i << " is " << *width2 << ".";
if (!data->shape[i].as<tir::AnyNode>()) {
auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2);
}
}
- reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
- data->dtype));
+ reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
return true;
}
-Array<te::Tensor> PadCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<PadAttrs>();
CHECK(param != nullptr);
auto pad_width = param->pad_width;
- CHECK(pad_width.size() == inputs[0].ndim() &&
- pad_width[0].size() == 2)
- << "Illegal pad_width";
+ CHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width";
Array<IndexExpr> pad_before;
for (size_t i = 0; i < pad_width.size(); ++i) {
pad_before.push_back(pad_width[i][0]);
pad_after.push_back(pad_width[i][1]);
}
const auto* out_ttype = out_type.as<TensorTypeNode>();
- return Array<te::Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
- tvm::tir::make_const(out_ttype->dtype, param->pad_value),
- "T_pad",
- topi::kElementWise,
- param->pad_mode) };
+ return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after,
+ tvm::tir::make_const(out_ttype->dtype, param->pad_value),
+ "T_pad", topi::kElementWise, param->pad_mode)};
}
// Handler to create a call to the padding op used by front-end FFI
-Expr MakePad(Expr data,
- Array<Array<IndexExpr> > pad_width,
- double pad_value,
- std::string pad_mode) {
+Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, std::string pad_mode) {
auto attrs = make_object<PadAttrs>();
attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.pad")
-.set_body_typed(MakePad);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad);
RELAY_REGISTER_OP("nn.pad")
-.describe(R"code(Pad for n-D tensor.
+ .describe(R"code(Pad for n-D tensor.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<PadAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("Pad", PadRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
-.set_attr<FTVMCompute>("FTVMCompute", PadCompute);
-
+ .set_attrs_type<PadAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("Pad", PadRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
+ .set_attr<TOpPattern>("TOpPattern", kInjective)
+ .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
// relay.nn.mirror_pad
TVM_REGISTER_NODE_TYPE(MirrorPadAttrs);
-bool MirrorPadRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool MirrorPadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
// check that pad widths match lengths
CHECK(data->shape.size() == param->pad_width.size())
- << "There should be as many pad width pairs as shape dimensions "
- << "but the shape has " << data->shape.size() << " dimensions "
- << "and there are " << param->pad_width.size() << " pad width pairs.";
+ << "There should be as many pad width pairs as shape dimensions "
+ << "but the shape has " << data->shape.size() << " dimensions "
+ << "and there are " << param->pad_width.size() << " pad width pairs.";
// each pad width element should be a pair of positive integers
std::vector<IndexExpr> oshape;
for (size_t i = 0; i < param->pad_width.size(); i++) {
CHECK(param->pad_width[i].size() == 2)
- << "Each pad width element should be a pair but at index " << i
- << " there are " << param->pad_width[i].size() << " elements.";
+ << "Each pad width element should be a pair but at index " << i << " there are "
+ << param->pad_width[i].size() << " elements.";
auto width1 = tir::as_const_int(param->pad_width[i][0]);
auto width2 = tir::as_const_int(param->pad_width[i][1]);
CHECK(width1 != nullptr);
CHECK(width2 != nullptr);
- CHECK(*width1 >= 0)
- << "Param width elements should be positive but first pad width at "
- << "index " << i << " is " << *width1 << ".";
- CHECK(*width2 >= 0)
- << "Param width elements should be positive but first pad width at "
- << "index " << i << " is " << *width2 << ".";
+ CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at "
+ << "index " << i << " is " << *width1 << ".";
+ CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at "
+ << "index " << i << " is " << *width2 << ".";
auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding);
}
- reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
- data->dtype));
+ reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
return true;
}
// Handler to create a call to the padding op used by front-end FFI
-Expr MakeMirrorPad(Expr data, Array<Array<IndexExpr> > pad_width, std::string mode) {
+Expr MakeMirrorPad(Expr data, Array<Array<IndexExpr>> pad_width, std::string mode) {
auto attrs = make_object<MirrorPadAttrs>();
attrs->mode = mode;
attrs->pad_width = std::move(pad_width);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad")
-.set_body_typed(MakeMirrorPad);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad);
RELAY_REGISTER_OP("nn.mirror_pad")
-.describe(R"code(MirrorPad for n-D tensor.
+ .describe(R"code(MirrorPad for n-D tensor.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<MirrorPadAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("MirrorPad", MirrorPadRel)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_attrs_type<MirrorPadAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("MirrorPad", MirrorPadRel)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
* \file pooling.cc
* \brief Pooling operators
*/
-#include <tvm/tir/data_layout.h>
+#include <topi/nn/pooling.h>
+#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
-#include <tvm/relay/attrs/nn.h>
-#include <topi/nn/pooling.h>
+#include <tvm/tir/data_layout.h>
+
#include <vector>
+
#include "../../transforms/infer_layout_util.h"
namespace tvm {
TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);
template <typename T>
-Array<Array<Layout> > PoolInferCorrectLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+Array<Array<Layout>> PoolInferCorrectLayout(const Attrs& attrs,
+                                            const Array<Layout>& new_in_layouts,
+                                            const Array<Layout>& old_in_layouts,
+                                            const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
- T *params = const_cast<T*>(attrs.as<T>());
+ T* params = const_cast<T*>(attrs.as<T>());
if (new_in_layouts.defined()) {
// Set the pool with the new layout.
}
template <typename T>
-Expr MakeMaxPool(Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode,
+Expr MakeMaxPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode,
std::string op_name) {
auto attrs = make_object<T>();
attrs->pool_size = std::move(pool_size);
}
template <typename T>
-Expr MakeAvgPool(Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode,
- bool count_include_pad,
- std::string op_name) {
+Expr MakeAvgPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode,
+ bool count_include_pad, std::string op_name) {
auto attrs = make_object<T>();
attrs->pool_size = std::move(pool_size);
attrs->strides = std::move(strides);
}
template <typename AttrType>
-bool Pool2DRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool Pool2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
Layout layout(param->layout);
CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
- << "Invalid layout " << layout
- << ". Pool2D layout must have H and W, which cannot be split";
+ << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
oshape[hidx] = dshape[hidx];
} else {
if (param->ceil_mode) {
- oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
- param->strides[0] - 1) / param->strides[0]) + 1;
+ oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + param->strides[0] - 1) /
+ param->strides[0]) +
+ 1;
} else {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
}
oshape[widx] = dshape[widx];
} else {
if (param->ceil_mode) {
- oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
- param->strides[1] - 1) / param->strides[1]) + 1;
+ oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + param->strides[1] - 1) /
+ param->strides[1]) +
+ 1;
} else {
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
}
return true;
}
-template<typename AttrType, topi::nn::PoolType mode>
-Array<te::Tensor> Pool2DCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+template <typename AttrType, topi::nn::PoolType mode>
+Array<te::Tensor> Pool2DCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool2d does not support input split on width";
- CHECK(inputs[0].ndim() == 4U ||
- inputs[0].ndim() == 5U ||
- inputs[0].ndim() == 6U)
+ CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (e.g. NCHWc on for vector instructions)"
<< " or 6-D input (e.g. NCHWnc for tensor accelerators)";
}
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
- return Array<te::Tensor>{
- topi::nn::pool(inputs[0], pool_size, strides, padding,
- mode, ceil_mode, layout.name(), count_include_pad)};
+ return Array<te::Tensor>{topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode,
+ layout.name(), count_include_pad)};
} else {
return Array<te::Tensor>{
- topi::nn::pool(inputs[0], pool_size, strides, padding,
- mode, ceil_mode, layout.name())};
+ topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())};
}
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d")
-.set_body_typed([](Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode) {
- return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
- "nn.max_pool2d");
-});
-
+ .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode) {
+ return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+ "nn.max_pool2d");
+ });
RELAY_REGISTER_OP("nn.max_pool2d")
-.describe(R"code(Max pooling operation for two dimensional data.
+ .describe(R"code(Max pooling operation for two dimensional data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
equation.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<MaxPool2DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool2DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
-
+ .set_attrs_type<MaxPool2DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool2DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
// AvgPool2D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d")
-.set_body_typed([](Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode,
- bool count_include_pad) {
- return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
- count_include_pad, "nn.avg_pool2d");
-});
+ .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode,
+ bool count_include_pad) {
+ return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+ count_include_pad, "nn.avg_pool2d");
+ });
RELAY_REGISTER_OP("nn.avg_pool2d")
-.describe(R"code(
+ .describe(R"code(
Average pooling operation for one dimensional data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
equation.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<AvgPool2DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<AvgPool2DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
+ .set_attrs_type<AvgPool2DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<AvgPool2DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
// relay.nn.global_pool_2d & relay.nn.max_pool_2d
TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs);
-bool GlobalPool2DRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool GlobalPool2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
- if (data == nullptr) { return false; }
+ if (data == nullptr) {
+ return false;
+ }
const auto dshape = data->shape;
CHECK_GE(dshape.size(), 2U)
<< "Pool2D only support input >= 2-D: input must have height and width";
Layout layout(param->layout);
CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
- << "Invalid layout " << layout
- << ". Pool2D layout must have H and W, which cannot be split";
+ << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
return true;
}
-
-template<topi::nn::PoolType mode>
-Array<te::Tensor> GlobalPool2DCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+template <topi::nn::PoolType mode>
+Array<te::Tensor> GlobalPool2DCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(tir::BijectiveLayout(layout, kNCHW).defined())
- << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
+ << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
- << "global_avg_pool2d does not support input split on height";
+ << "global_avg_pool2d does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
- << "global_avg_pool2d does not support input split on width";
+ << "global_avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
- << "Pool2D only support 4-D input (e.g., NCHW)"
- << " or 5-D input (last dimension is a split of channel)";
- return Array<te::Tensor>{
- topi::nn::global_pool(inputs[0], mode, layout.name()) };
+ << "Pool2D only support 4-D input (e.g., NCHW)"
+ << " or 5-D input (last dimension is a split of channel)";
+ return Array<te::Tensor>{topi::nn::global_pool(inputs[0], mode, layout.name())};
}
-Expr MakeGlobalAvgPool2D(Expr data,
- std::string layout) {
+Expr MakeGlobalAvgPool2D(Expr data, std::string layout) {
auto attrs = make_object<GlobalPool2DAttrs>();
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.global_avg_pool2d");
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d")
-.set_body_typed(MakeGlobalAvgPool2D);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d").set_body_typed(MakeGlobalAvgPool2D);
// GlobalAvgPool
RELAY_REGISTER_OP("nn.global_avg_pool2d")
-.describe(R"code(Global average pooling operation for 2D data.
+ .describe(R"code(Global average pooling operation for 2D data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<GlobalPool2DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- PoolInferCorrectLayout<GlobalPool2DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kAvgPool>);
+ .set_attrs_type<GlobalPool2DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<GlobalPool2DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kAvgPool>);
// GlobalMaxPool
-Expr MakeGlobalMaxPool2D(Expr data,
- std::string layout) {
+Expr MakeGlobalMaxPool2D(Expr data, std::string layout) {
auto attrs = make_object<GlobalPool2DAttrs>();
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.global_max_pool2d");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d")
-.set_body_typed(MakeGlobalMaxPool2D);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d").set_body_typed(MakeGlobalMaxPool2D);
RELAY_REGISTER_OP("nn.global_max_pool2d")
-.describe(R"code(Global max pooling operation for 2D data.
+ .describe(R"code(Global max pooling operation for 2D data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<GlobalPool2DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- PoolInferCorrectLayout<GlobalPool2DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kMaxPool>);
-
+ .set_attrs_type<GlobalPool2DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<GlobalPool2DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kMaxPool>);
// relay.nn.adaptive_pool_2d
TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
-bool AdaptivePool2DRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool AdaptivePool2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
- if (data == nullptr) { return false; }
+ if (data == nullptr) {
+ return false;
+ }
const auto dshape = data->shape;
CHECK_GE(dshape.size(), 2U)
- << "Pool2D only support input >= 2-D: input must have height and width";
+ << "Pool2D only support input >= 2-D: input must have height and width";
const auto* param = attrs.as<AdaptivePool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
- << "Invalid layout " << layout
- << ". Pool2D layout must have H and W, which cannot be split";
+ << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
Array<IndexExpr> oshape(dshape);
auto output_size = param->output_size;
- CHECK_LE(output_size.size(), 2U)
- << "output_size can have up to 2 elements.";
+ CHECK_LE(output_size.size(), 2U) << "output_size can have up to 2 elements.";
IndexExpr output_height, output_width;
if (output_size.empty()) {
output_height = dshape[hidx];
return true;
}
-template<topi::nn::PoolType mode>
-Array<te::Tensor> AdaptivePool2DCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+template <topi::nn::PoolType mode>
+Array<te::Tensor> AdaptivePool2DCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AdaptivePool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(tir::BijectiveLayout(layout, kNCHW).defined())
- << "Adaptive pool2d currently only supports layouts that are convertible from NCHW";
+ << "Adaptive pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
- << "Adaptive pool2d does not support input split on height";
+ << "Adaptive pool2d does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
- << "Adaptive pool2d does not support input split on width";
+ << "Adaptive pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
- << "Pool2D only support 4-D input (e.g., NCHW)"
- << " or 5-D input (last dimension is a split of channel)";
+ << "Pool2D only support 4-D input (e.g., NCHW)"
+ << " or 5-D input (last dimension is a split of channel)";
auto output_size = param->output_size;
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
output_height = output_size[0];
output_width = output_size[1];
}
- return Array<te::Tensor>{
- topi::nn::adaptive_pool(inputs[0], Array<IndexExpr>{ output_height, output_width },
- mode, layout.name()) };
+ return Array<te::Tensor>{topi::nn::adaptive_pool(
+ inputs[0], Array<IndexExpr>{output_height, output_width}, mode, layout.name())};
}
// relay.nn.adaptive_avg_pool2d
-Expr MakeAdaptiveAvgPool2D(Expr data,
- Array<IndexExpr> output_size,
- std::string layout) {
+Expr MakeAdaptiveAvgPool2D(Expr data, Array<IndexExpr> output_size, std::string layout) {
auto attrs = make_object<AdaptivePool2DAttrs>();
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d")
-.set_body_typed(MakeAdaptiveAvgPool2D);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d").set_body_typed(MakeAdaptiveAvgPool2D);
RELAY_REGISTER_OP("nn.adaptive_avg_pool2d")
- .describe(R"code(Adaptive average pooling operation for 2D data.
+ .describe(R"code(Adaptive average pooling operation for 2D data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
(batch_size, channels, output_height, output_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<AdaptivePool2DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(10)
-.add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- PoolInferCorrectLayout<AdaptivePool2DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kAvgPool>);
+ .set_attrs_type<AdaptivePool2DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(10)
+ .add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ PoolInferCorrectLayout<AdaptivePool2DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kAvgPool>);
// relay.nn.adaptive_max_pool2d
-Expr MakeAdaptiveMaxPool2D(Expr data,
- Array<IndexExpr> output_size,
- std::string layout) {
+Expr MakeAdaptiveMaxPool2D(Expr data, Array<IndexExpr> output_size, std::string layout) {
auto attrs = make_object<AdaptivePool2DAttrs>();
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d")
-.set_body_typed(MakeAdaptiveMaxPool2D);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d").set_body_typed(MakeAdaptiveMaxPool2D);
RELAY_REGISTER_OP("nn.adaptive_max_pool2d")
- .describe(R"code(Adaptive max pooling operation for 2D data.
+ .describe(R"code(Adaptive max pooling operation for 2D data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
(batch_size, channels, output_height, output_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<AdaptivePool2DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(10)
-.add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- PoolInferCorrectLayout<AdaptivePool2DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kMaxPool>);
-
+ .set_attrs_type<AdaptivePool2DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(10)
+ .add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ PoolInferCorrectLayout<AdaptivePool2DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kMaxPool>);
TVM_REGISTER_NODE_TYPE(AdaptivePool3DAttrs);
-bool AdaptivePool3DRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool AdaptivePool3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
- if (data == nullptr) { return false; }
+ if (data == nullptr) {
+ return false;
+ }
const auto dshape = data->shape;
CHECK_GE(dshape.size(), 3U)
- << "Pool3D only support input >= 3-D: input must have depth, height and width";
+ << "Pool3D only support input >= 3-D: input must have depth, height and width";
const auto* param = attrs.as<AdaptivePool3DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) &&
layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) &&
- !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
- << "Invalid layout " << layout
- << ". Pool3D layout must have D, H and W, which cannot be split";
+ !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
+ << "Invalid layout " << layout
+ << ". Pool3D layout must have D, H and W, which cannot be split";
const auto didx = layout.IndexOf(LayoutAxis::Get('D'));
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
Array<IndexExpr> oshape(dshape);
auto output_size = param->output_size;
- CHECK_LE(output_size.size(), 3U)
- << "output_size can have up to 3 elements.";
+ CHECK_LE(output_size.size(), 3U) << "output_size can have up to 3 elements.";
IndexExpr output_depth, output_height, output_width;
if (output_size.empty()) {
output_depth = dshape[didx];
return true;
}
-template<topi::nn::PoolType mode>
-Array<te::Tensor> AdaptivePool3DCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+template <topi::nn::PoolType mode>
+Array<te::Tensor> AdaptivePool3DCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCDHW("NCDHW");
const auto* param = attrs.as<AdaptivePool3DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(tir::BijectiveLayout(layout, kNCDHW).defined())
- << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW";
+ << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1)
- << "Adaptive pool3d does not support input split on depth";
+ << "Adaptive pool3d does not support input split on depth";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
- << "Adaptive pool3d does not support input split on height";
+ << "Adaptive pool3d does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
- << "Adaptive pool3d does not support input split on width";
+ << "Adaptive pool3d does not support input split on width";
CHECK(inputs[0].ndim() == 5U || inputs[0].ndim() == 6U)
- << "Pool3D only support 5-D input (e.g., NCDHW)"
- << " or 6-D input (last dimension is a split of channel)";
+ << "Pool3D only support 5-D input (e.g., NCDHW)"
+ << " or 6-D input (last dimension is a split of channel)";
auto output_size = param->output_size;
const auto didx = layout.IndexOf(LayoutAxis::Get('D'));
output_width = output_size[2];
}
- auto osize = Array<IndexExpr>{ output_depth, output_height, output_width };
- return Array<te::Tensor> {
- topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name())
- };
+ auto osize = Array<IndexExpr>{output_depth, output_height, output_width};
+ return Array<te::Tensor>{topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name())};
}
// relay.nn.adaptive_max_pool3d
-Expr MakeAdaptiveMaxPool3D(Expr data,
- Array<IndexExpr> output_size,
- std::string layout) {
+Expr MakeAdaptiveMaxPool3D(Expr data, Array<IndexExpr> output_size, std::string layout) {
auto attrs = make_object<AdaptivePool3DAttrs>();
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d")
-.set_body_typed(MakeAdaptiveMaxPool3D);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d").set_body_typed(MakeAdaptiveMaxPool3D);
RELAY_REGISTER_OP("nn.adaptive_max_pool3d")
- .describe(R"code(Adaptive max pooling operation for 3D data.
+ .describe(R"code(Adaptive max pooling operation for 3D data.
- **data**: This depends on the `layout` parameter. Input is 5D array of shape
(batch_size, channels, depth, height, width) if `layout` is `NCDHW`.
(batch_size, channels, output_depth, output_height, output_width) if `layout` is `NCDHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<AdaptivePool3DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(10)
-.add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- PoolInferCorrectLayout<AdaptivePool3DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kMaxPool>);
+ .set_attrs_type<AdaptivePool3DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(10)
+ .add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ PoolInferCorrectLayout<AdaptivePool3DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kMaxPool>);
// relay.nn.adaptive_max_pool3d
-Expr MakeAdaptiveAvgPool3D(Expr data,
- Array<IndexExpr> output_size,
- std::string layout) {
+Expr MakeAdaptiveAvgPool3D(Expr data, Array<IndexExpr> output_size, std::string layout) {
auto attrs = make_object<AdaptivePool3DAttrs>();
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d")
-.set_body_typed(MakeAdaptiveAvgPool3D);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d").set_body_typed(MakeAdaptiveAvgPool3D);
RELAY_REGISTER_OP("nn.adaptive_avg_pool3d")
- .describe(R"code(Adaptive avg pooling operation for 3D data.
+ .describe(R"code(Adaptive avg pooling operation for 3D data.
- **data**: This depends on the `layout` parameter. Input is 5D array of shape
(batch_size, channels, depth, height, width) if `layout` is `NCDHW`.
- **output_size**: If this argument is not provided, input depth, height and width will be used
- **out**: This depends on the `layout` parameter. Output is 5D array of shape
(batch_size, channels, output_depth, output_height, output_width) if `layout` is `NCDHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<AdaptivePool3DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(10)
-.add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- PoolInferCorrectLayout<AdaptivePool3DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kAvgPool>);
-
+ .set_attrs_type<AdaptivePool3DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(10)
+ .add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ PoolInferCorrectLayout<AdaptivePool3DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kAvgPool>);
bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
}
template <typename AttrType, topi::nn::PoolType mode>
-Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
return Array<te::Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
- mode, ceil_mode, layout.name(), count_include_pad)};
+ mode, ceil_mode, layout.name(),
+ count_include_pad)};
} else {
return Array<te::Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
- mode, ceil_mode, layout.name())};
+ mode, ceil_mode, layout.name())};
}
}
-
// MaxPool2DGrad
Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
- Array<IndexExpr> strides, Array<IndexExpr> padding, std::string layout, bool ceil_mode) {
+ Array<IndexExpr> strides, Array<IndexExpr> padding, std::string layout,
+ bool ceil_mode) {
auto attrs = make_object<MaxPool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = std::move(strides);
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad);
-
RELAY_REGISTER_OP("nn.max_pool2d_grad")
.describe(R"code(Gradient of max pooling operation for two dimensional data.
(batch_size, channels, height, width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<MaxPool2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("MaxPool2DGrad", Pool2DGradRel)
-.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
-
+ .set_attrs_type<MaxPool2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("MaxPool2DGrad", Pool2DGradRel)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
// AvgPool2DGrad
Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
- Array<IndexExpr> strides, Array<IndexExpr> padding, std::string layout, bool ceil_mode,
- bool count_include_pad) {
+ Array<IndexExpr> strides, Array<IndexExpr> padding, std::string layout,
+ bool ceil_mode, bool count_include_pad) {
auto attrs = make_object<AvgPool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = std::move(strides);
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad);
-
RELAY_REGISTER_OP("nn.avg_pool2d_grad")
.describe(R"code(Gradient of average pooling operation for two dimensional data.
(batch_size, channels, height, width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<MaxPool2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("MaxPool2DGrad", Pool2DGradRel)
-.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
-
+ .set_attrs_type<MaxPool2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("MaxPool2DGrad", Pool2DGradRel)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
// relay.nn.max_pool1d & relay.nn.avg_pool1d
TVM_REGISTER_NODE_TYPE(MaxPool1DAttrs);
TVM_REGISTER_NODE_TYPE(AvgPool1DAttrs);
template <typename AttrType>
-bool Pool1DRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool Pool1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const auto dshape = data->shape;
- CHECK_GE(dshape.size(), 1U)
- << "Pool1D only support input >= 1-D: input must have width";
+ CHECK_GE(dshape.size(), 1U) << "Pool1D only support input >= 1-D: input must have width";
const auto param = attrs.as<AttrType>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w')))
- << "Invalid layout " << layout
- << ". Pool1D layout must have W, which cannot be split";
+ << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split";
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
oshape[widx] = dshape[widx];
} else {
if (param->ceil_mode) {
- oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] +
- param->strides[0] - 1) / param->strides[0]) + 1;
+ oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + param->strides[0] - 1) /
+ param->strides[0]) +
+ 1;
} else {
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0]) / param->strides[0]) + 1;
}
return true;
}
-
-template<typename AttrType, topi::nn::PoolType mode>
-Array<te::Tensor> Pool1DCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+template <typename AttrType, topi::nn::PoolType mode>
+Array<te::Tensor> Pool1DCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCW("NCW");
const auto* param = attrs.as<AttrType>();
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool1d does not support input split on width";
- CHECK(inputs[0].ndim() == 3U ||
- inputs[0].ndim() == 4U ||
- inputs[0].ndim() == 5U)
+ CHECK(inputs[0].ndim() == 3U || inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool1D only support 3-D input (e.g., NCW)"
<< " or 4-D input (e.g. NCWc on for vector instructions)"
<< " or 5-D input (e.g. NCWnc for tensor accelerators)";
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool1DAttrs*>(param)->count_include_pad;
- return Array<te::Tensor>{
- topi::nn::pool1d(inputs[0], pool_size, strides, padding,
- mode, ceil_mode, layout.name(), count_include_pad)};
+ return Array<te::Tensor>{topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode,
+ ceil_mode, layout.name(), count_include_pad)};
} else {
return Array<te::Tensor>{
- topi::nn::pool1d(inputs[0], pool_size, strides, padding,
- mode, ceil_mode, layout.name())};
+ topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())};
}
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d")
-.set_body_typed([](Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode) {
- return MakeMaxPool<MaxPool1DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
- "nn.max_pool1d");
-});
+ .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode) {
+ return MakeMaxPool<MaxPool1DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+ "nn.max_pool1d");
+ });
RELAY_REGISTER_OP("nn.max_pool1d")
-.describe(R"code(Max pooling operation for one dimensional data.
+ .describe(R"code(Max pooling operation for one dimensional data.
- **data**: This depends on the `layout` parameter. Input is 3D array of shape
(batch_size, channels, width) if `layout` is `NCW`.
equation.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<MaxPool1DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("MaxPool1D", Pool1DRel<MaxPool1DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool1DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", Pool1DCompute<MaxPool1DAttrs, topi::nn::kMaxPool>);
-
+ .set_attrs_type<MaxPool1DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("MaxPool1D", Pool1DRel<MaxPool1DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool1DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool1DCompute<MaxPool1DAttrs, topi::nn::kMaxPool>);
// AvgPool1D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d")
-.set_body_typed([](Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode,
- bool count_include_pad) {
- return MakeAvgPool<AvgPool1DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
- count_include_pad, "nn.avg_pool1d");
-});
+ .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode,
+ bool count_include_pad) {
+ return MakeAvgPool<AvgPool1DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+ count_include_pad, "nn.avg_pool1d");
+ });
RELAY_REGISTER_OP("nn.avg_pool1d")
-.describe(R"code(
+ .describe(R"code(
Average pooling operation for one dimensional data.
- **data**: This depends on the `layout` parameter. Input is 3D array of shape
equation.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<AvgPool1DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("AvgPool1D", Pool1DRel<AvgPool1DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<AvgPool1DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", Pool1DCompute<AvgPool1DAttrs, topi::nn::kAvgPool>);
-
+ .set_attrs_type<AvgPool1DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("AvgPool1D", Pool1DRel<AvgPool1DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<AvgPool1DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool1DCompute<AvgPool1DAttrs, topi::nn::kAvgPool>);
// relay.nn.max_pool3d & relay.nn.avg_pool3d
TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs);
TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs);
template <typename AttrType>
-bool Pool3DRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool Pool3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) &&
layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
- << "Invalid layout " << layout
- << ". Pool3D layout must have D, H and W, which cannot be split";
+ << "Invalid layout " << layout
+ << ". Pool3D layout must have D, H and W, which cannot be split";
const auto didx = layout.IndexOf(LayoutAxis::Get('D'));
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
oshape[ii] = dshape[ii];
} else {
if (param->ceil_mode) {
- oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] +
- param->strides[i] - 1) / param->strides[i]) + 1;
+ oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + param->strides[i] - 1) /
+ param->strides[i]) +
+ 1;
} else {
oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i]) / param->strides[i]) + 1;
}
return true;
}
-
-template<typename AttrType, topi::nn::PoolType mode>
-Array<te::Tensor> Pool3DCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+template <typename AttrType, topi::nn::PoolType mode>
+Array<te::Tensor> Pool3DCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCDHW("NCDHW");
const auto* param = attrs.as<AttrType>();
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool3d does not support input split on width";
- CHECK(inputs[0].ndim() == 4U ||
- inputs[0].ndim() == 5U ||
- inputs[0].ndim() == 6U)
+ CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U)
<< "Pool3D only support 5-D input (e.g., NCDHW)"
<< " or 6-D input (e.g. NCDHWc on for vector instructions)"
<< " or 7-D input (e.g. NCDHWnc for tensor accelerators)";
}
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool3DAttrs*>(param)->count_include_pad;
- return Array<te::Tensor>{
- topi::nn::pool3d(inputs[0], pool_size, strides, padding,
- mode, ceil_mode, layout.name(), count_include_pad)};
+ return Array<te::Tensor>{topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode,
+ ceil_mode, layout.name(), count_include_pad)};
} else {
return Array<te::Tensor>{
- topi::nn::pool3d(inputs[0], pool_size, strides, padding,
- mode, ceil_mode, layout.name())};
+ topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())};
}
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d")
-.set_body_typed([](Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode) {
- return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
- "nn.max_pool3d");
-});
+ .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode) {
+ return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+ "nn.max_pool3d");
+ });
RELAY_REGISTER_OP("nn.max_pool3d")
-.describe(R"code(Max pooling operation for three dimensional data.
+ .describe(R"code(Max pooling operation for three dimensional data.
- **data**: This depends on the `layout` parameter. Input is 5D array of shape
(batch_size, channels, depth, height, width) if `layout` is `NCDHW`.
equation.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<MaxPool3DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("MaxPool3D", Pool3DRel<MaxPool3DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool3DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", Pool3DCompute<MaxPool3DAttrs, topi::nn::kMaxPool>);
-
+ .set_attrs_type<MaxPool3DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("MaxPool3D", Pool3DRel<MaxPool3DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool3DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool3DCompute<MaxPool3DAttrs, topi::nn::kMaxPool>);
// AvgPool3D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d")
-.set_body_typed([](Expr data,
- Array<IndexExpr> pool_size,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- std::string layout,
- bool ceil_mode,
- bool count_include_pad) {
- return MakeAvgPool<AvgPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
- count_include_pad, "nn.avg_pool3d");
-});
+ .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+ Array<IndexExpr> padding, std::string layout, bool ceil_mode,
+ bool count_include_pad) {
+ return MakeAvgPool<AvgPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+ count_include_pad, "nn.avg_pool3d");
+ });
RELAY_REGISTER_OP("nn.avg_pool3d")
-.describe(R"code(
+ .describe(R"code(
Average pooling operation for three dimensional data.
- **data**: This depends on the `layout` parameter. Input is 5D array of shape
equation.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<AvgPool3DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("AvgPool3D", Pool3DRel<AvgPool3DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<AvgPool3DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", Pool3DCompute<AvgPool3DAttrs, topi::nn::kAvgPool>);
+ .set_attrs_type<AvgPool3DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("AvgPool3D", Pool3DRel<AvgPool3DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<AvgPool3DAttrs>)
+ .set_attr<FTVMCompute>("FTVMCompute", Pool3DCompute<AvgPool3DAttrs, topi::nn::kAvgPool>);
} // namespace relay
} // namespace tvm
* \brief Property def of nn.sparse_dense operator.
*/
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
#include <vector>
#include "../../transforms/infer_layout_util.h"
if (weight_data->shape.size() == 3) {
// BSR case.
- Array<IndexExpr> oshape({
- data->shape[0],
- (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
+ Array<IndexExpr> oshape(
+ {data->shape[0], (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
reporter->Assign(types[4], TensorType(oshape, data->dtype));
return true;
}
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
-});
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
+ });
RELAY_REGISTER_OP("nn.sparse_dense")
-.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
+ .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<SparseDenseAttrs>()
-.set_num_inputs(4)
-.add_argument("data", "nD Tensor", "Input data.")
-.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
-.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
-.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
-.set_support_level(1)
-.add_type_rel("SparseDense", SparseDenseRel);
+ .set_attrs_type<SparseDenseAttrs>()
+ .set_num_inputs(4)
+ .add_argument("data", "nD Tensor", "Input data.")
+ .add_argument("weight_data", "1D Tensor", "Weight data matrix.")
+ .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
+ .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
+ .set_support_level(1)
+ .add_type_rel("SparseDense", SparseDenseRel);
// relay.nn.sparse_transpose
TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs);
bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
- const TypeReporter& reporter) {
+ const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* sparse_data = types[0].as<TensorTypeNode>();
CHECK_EQ(sparse_data->shape.size(), 1);
return Call(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose")
-.set_body_typed(MakeSparseTranspose);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose").set_body_typed(MakeSparseTranspose);
RELAY_REGISTER_OP("nn.sparse_transpose")
-.describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix
+ .describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix
- **input**: `(N, N)`
- **out**: `(N, N)`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<SparseTransposeAttrs>()
-.set_num_inputs(3)
-.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
-.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
-.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.")
-.set_support_level(1)
-.add_type_rel("SparseTranspose", SparseTransposeRel);
+ .set_attrs_type<SparseTransposeAttrs>()
+ .set_num_inputs(3)
+ .add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
+ .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
+ .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.")
+ .set_support_level(1)
+ .add_type_rel("SparseTranspose", SparseTransposeRel);
} // namespace relay
} // namespace tvm
* \file upsampling.cc
* \brief upsampling operator
*/
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
#include <vector>
+
#include "../op_common.h"
namespace tvm {
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
template <typename T>
-Array<Array<Layout> > UpsamplingInferCorrectLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
- T *params = const_cast<T*>(attrs.as<T>());
+ T* params = const_cast<T*>(attrs.as<T>());
if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);
Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
- input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
- !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))&&
+ input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
+ !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
(input.IndexOf(LayoutAxis::Get('D')) == -1 ||
- (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
- !input.Contains(LayoutAxis::Get('d'))))) {
- params->layout = input.name(); // modify self to follow the input layout
+ (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
+ !input.Contains(LayoutAxis::Get('d'))))) {
+ params->layout = input.name(); // modify self to follow the input layout
}
}
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}
-bool UpSamplingRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
- << "UpSampling only support input layouts that are convertible from NCHW."
- << " But got " << in_layout;
+ << "UpSampling only support input layouts that are convertible from NCHW."
+ << " But got " << in_layout;
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h)));
oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w)));
// assign output type
- reporter->Assign(types[1],
- TensorType(layout_converter.BackwardShape(oshape),
- data->dtype));
+ reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
-
// Positional relay function to create upsampling operator
// used by frontend FFI.
-Expr MakeUpSampling(Expr data,
- double scale_h,
- double scale_w,
- std::string layout,
- std::string method,
- bool align_corners) {
+Expr MakeUpSampling(Expr data, double scale_h, double scale_w, std::string layout,
+ std::string method, bool align_corners) {
auto attrs = make_object<UpSamplingAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling")
-.set_body_typed(MakeUpSampling);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling").set_body_typed(MakeUpSampling);
RELAY_REGISTER_OP("nn.upsampling")
-.describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
+ .describe(
+ R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
-.set_attrs_type<UpSamplingAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("UpSampling", UpSamplingRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- UpsamplingInferCorrectLayout<UpSamplingAttrs>)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
-
+ .set_attrs_type<UpSamplingAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("UpSampling", UpSamplingRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ UpsamplingInferCorrectLayout<UpSamplingAttrs>)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// UpSampling3D
-bool UpSampling3DRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool UpSampling3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW);
CHECK(layout_converter.defined())
- << "UpSampling3D only support input layouts that are convertible from NCDHW."
- << " But got " << in_layout;
+ << "UpSampling3D only support input layouts that are convertible from NCDHW."
+ << " But got " << in_layout;
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
oshape.Set(4, tir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
// assign output type
- reporter->Assign(types[1],
- TensorType(layout_converter.BackwardShape(oshape),
- data->dtype));
+ reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
// Positional relay function to create upsampling3d operator
// used by frontend FFI.
-Expr MakeUpSampling3D(Expr data,
- double scale_d,
- double scale_h,
- double scale_w,
- std::string layout,
- std::string method,
- std::string coordinate_transformation_mode) {
+Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, std::string layout,
+ std::string method, std::string coordinate_transformation_mode) {
auto attrs = make_object<UpSampling3DAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d")
-.set_body_typed(MakeUpSampling3D);
-
+TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D);
RELAY_REGISTER_OP("nn.upsampling3d")
-.describe(R"code(Perform upsampling on input array with nearest neighbour or
+ .describe(R"code(Perform upsampling on input array with nearest neighbour or
bilinear interpolation.
- **data**: data is 5D array of shape
(batch_size, in_depth*scale, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
-.set_attrs_type<UpSampling3DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("UpSampling3D", UpSampling3DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- UpsamplingInferCorrectLayout<UpSampling3DAttrs>)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_attrs_type<UpSampling3DAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(2)
+ .add_type_rel("UpSampling3D", UpSampling3DRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ UpsamplingInferCorrectLayout<UpSampling3DAttrs>)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
-#include <vector>
+
#include <string>
#include <unordered_map>
-#include "type_relations.h"
+#include <vector>
+
#include "../transforms/infer_layout_util.h"
+#include "type_relations.h"
namespace tvm {
namespace relay {
* \param OpName the name of registry.
*/
-#define RELAY_REGISTER_UNARY_OP(OpName) \
- TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
- .set_body_typed([](Expr data) { \
- static const Op& op = Op::Get(OpName); \
- return Call(op, {data}, Attrs(), {}); \
- }); \
- RELAY_REGISTER_OP(OpName) \
- .set_num_inputs(1) \
- .add_argument("data", "Tensor", "The input tensor.") \
- .add_type_rel("Identity", IdentityRel) \
- .set_attr<TOpPattern>("TOpPattern", kElemWise) \
- .set_attr<TOpIsStateful>("TOpIsStateful", false) \
- .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
- ElemwiseArbitraryLayout) \
-
+#define RELAY_REGISTER_UNARY_OP(OpName) \
+ TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr data) { \
+ static const Op& op = Op::Get(OpName); \
+ return Call(op, {data}, Attrs(), {}); \
+ }); \
+ RELAY_REGISTER_OP(OpName) \
+ .set_num_inputs(1) \
+ .add_argument("data", "Tensor", "The input tensor.") \
+ .add_type_rel("Identity", IdentityRel) \
+ .set_attr<TOpPattern>("TOpPattern", kElemWise) \
+ .set_attr<TOpIsStateful>("TOpIsStateful", false) \
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
/*! Quick helper macro
* - Expose a positional make function to construct the node.
*
* \param OpName the name of registry.
*/
-#define RELAY_REGISTER_BINARY_OP(OpName) \
- TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
- .set_body_typed([](Expr lhs, Expr rhs) { \
- static const Op& op = Op::Get(OpName); \
- return Call(op, {lhs, rhs}, Attrs(), {}); \
- }); \
- RELAY_REGISTER_OP(OpName) \
- .set_num_inputs(2) \
- .add_argument("lhs", "Tensor", "The left hand side tensor.") \
- .add_argument("rhs", "Tensor", "The right hand side tensor.") \
- .add_type_rel("Broadcast", BroadcastRel) \
- .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
- .set_attr<TOpIsStateful>("TOpIsStateful", false) \
- .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
- BinaryBroadcastLayout)
+#define RELAY_REGISTER_BINARY_OP(OpName) \
+ TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \
+ static const Op& op = Op::Get(OpName); \
+ return Call(op, {lhs, rhs}, Attrs(), {}); \
+ }); \
+ RELAY_REGISTER_OP(OpName) \
+ .set_num_inputs(2) \
+ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
+ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
+ .add_type_rel("Broadcast", BroadcastRel) \
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
+ .set_attr<TOpIsStateful>("TOpIsStateful", false) \
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
// Comparisons
-#define RELAY_REGISTER_CMP_OP(OpName) \
- TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
- .set_body_typed([](Expr lhs, Expr rhs) { \
- static const Op& op = Op::Get(OpName); \
- return Call(op, {lhs, rhs}, Attrs(), {}); \
- }); \
- RELAY_REGISTER_OP(OpName) \
- .set_num_inputs(2) \
- .add_argument("lhs", "Tensor", "The left hand side tensor.") \
- .add_argument("rhs", "Tensor", "The right hand side tensor.") \
- .add_type_rel("BroadcastComp", BroadcastCompRel) \
- .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
- .set_attr<TOpIsStateful>("TOpIsStateful", false) \
- .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
- BinaryBroadcastLayout)
-
+#define RELAY_REGISTER_CMP_OP(OpName) \
+ TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \
+ static const Op& op = Op::Get(OpName); \
+ return Call(op, {lhs, rhs}, Attrs(), {}); \
+ }); \
+ RELAY_REGISTER_OP(OpName) \
+ .set_num_inputs(2) \
+ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
+ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
+ .add_type_rel("BroadcastComp", BroadcastCompRel) \
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
+ .set_attr<TOpIsStateful>("TOpIsStateful", false) \
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
/*! \brief A helper class for matching and rewriting operators. */
-template<typename R>
+template <typename R>
class OpMatch {
public:
using MatchFunc =
} else if (padding.size() == 2) {
*pad_w = padding[0] + padding[1];
} else {
- CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found "
- << padding.size();
+ CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size();
}
}
*pad_h = padding[0] + padding[2];
*pad_w = padding[1] + padding[3];
} else {
- CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got "
- << padding.size();
+ CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size();
}
}
*pad_h = padding[1] + padding[4];
*pad_w = padding[2] + padding[5];
} else {
- CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got "
- << padding.size();
+ CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size();
}
}
* \file binary.cc
* \brief binary broadcast operators.
*/
+#include <topi/broadcast.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
-#include <topi/broadcast.h>
-#include "../type_relations.h"
+
#include "../op_common.h"
+#include "../type_relations.h"
namespace tvm {
namespace relay {
-#define RELAY_BINARY_COMPUTE(FTOPI) \
- [] (const Attrs& attrs, \
- const Array<te::Tensor>& inputs, \
- const Type& out_type) -> Array<te::Tensor> { \
- CHECK_EQ(inputs.size(), 2U); \
- return {FTOPI(inputs[0], inputs[1])}; \
- } \
+#define RELAY_BINARY_COMPUTE(FTOPI) \
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs, \
+ const Type& out_type) -> Array<te::Tensor> { \
+ CHECK_EQ(inputs.size(), 2U); \
+ return {FTOPI(inputs[0], inputs[1])}; \
+ }
// Addition
RELAY_REGISTER_BINARY_OP("add")
-.describe("Elementwise add with with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));
+ .describe("Elementwise add with with broadcasting")
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));
// Subtraction
RELAY_REGISTER_BINARY_OP("subtract")
-.describe("Elementwise substract with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));
+ .describe("Elementwise substract with broadcasting")
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));
// Right shift
RELAY_REGISTER_BINARY_OP("right_shift")
-.describe("Elementwise right shift with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
-
+ .describe("Elementwise right shift with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
RELAY_REGISTER_BINARY_OP("left_shift")
-.describe("Elementwise left shift with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));
-
+ .describe("Elementwise left shift with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));
RELAY_REGISTER_BINARY_OP("maximum")
-.describe("Elementwise maximum of two tensors with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));
-
+ .describe("Elementwise maximum of two tensors with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));
RELAY_REGISTER_BINARY_OP("minimum")
-.describe("Elementwise minimum of two tensors with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));
-
+ .describe("Elementwise minimum of two tensors with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));
RELAY_REGISTER_BINARY_OP("divide")
-.describe("Elementwise divide with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));
-
+ .describe("Elementwise divide with broadcasting")
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));
RELAY_REGISTER_BINARY_OP("floor_divide")
-.describe("Elementwise floor divide with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide));
-
+ .describe("Elementwise floor divide with broadcasting")
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide));
RELAY_REGISTER_BINARY_OP("multiply")
-.describe("Elementwise multiply with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));
-
+ .describe("Elementwise multiply with broadcasting")
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));
RELAY_REGISTER_BINARY_OP("power")
-.describe("Elementwise power with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power));
-
+ .describe("Elementwise power with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power));
RELAY_REGISTER_BINARY_OP("mod")
-.describe("Elementwise mod with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
-
+ .describe("Elementwise mod with broadcasting")
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
RELAY_REGISTER_BINARY_OP("floor_mod")
- .describe("Elementwise floor mod with broadcasting")
- .set_support_level(1)
- .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod));
-
+ .describe("Elementwise floor mod with broadcasting")
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod));
RELAY_REGISTER_BINARY_OP("logical_and")
-.describe("Elementwise logical AND with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and));
-
+ .describe("Elementwise logical AND with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and));
RELAY_REGISTER_BINARY_OP("logical_or")
-.describe("Elementwise logical OR with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
-
+ .describe("Elementwise logical OR with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
RELAY_REGISTER_BINARY_OP("logical_xor")
-.describe("Elementwise logical XOR with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));
-
+ .describe("Elementwise logical XOR with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));
RELAY_REGISTER_BINARY_OP("bitwise_and")
-.describe("Elementwise bitwise AND with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and));
-
+ .describe("Elementwise bitwise AND with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and));
RELAY_REGISTER_BINARY_OP("bitwise_or")
-.describe("Elementwise bitwise OR with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or));
-
+ .describe("Elementwise bitwise OR with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or));
RELAY_REGISTER_BINARY_OP("bitwise_xor")
-.describe("Elementwise bitwise XOR with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor));
-
+ .describe("Elementwise bitwise XOR with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor));
RELAY_REGISTER_CMP_OP("equal")
-.describe("Elementwise equal compare with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal));
-
+ .describe("Elementwise equal compare with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal));
RELAY_REGISTER_CMP_OP("not_equal")
-.describe("Elementwise not equal with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal));
-
+ .describe("Elementwise not equal with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal));
RELAY_REGISTER_CMP_OP("less")
-.describe("Elementwise less than with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less));
-
+ .describe("Elementwise less than with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less));
RELAY_REGISTER_CMP_OP("less_equal")
-.describe("Elementwise less than or equal compare with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal));
-
+ .describe("Elementwise less than or equal compare with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal));
RELAY_REGISTER_CMP_OP("greater")
-.describe("Elementwise greater than compare with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater));
-
+ .describe("Elementwise greater than compare with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater));
RELAY_REGISTER_CMP_OP("greater_equal")
-.describe("Elementwise greater than or equal compare with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal));
+ .describe("Elementwise greater than or equal compare with broadcasting")
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal));
} // namespace relay
} // namespace tvm
* \file reduce.cc
* \brief Reduction operators.
*/
-#include <tvm/relay/expr.h>
-#include <tvm/relay/op.h>
-#include <tvm/relay/attrs/reduce.h>
#include <topi/elemwise.h>
#include <topi/reduction.h>
-#include <numeric>
+#include <tvm/relay/attrs/reduce.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+
#include <limits>
+#include <numeric>
+
#include "../op_common.h"
#include "../type_relations.h"
TVM_REGISTER_NODE_TYPE(ReduceAttrs);
/*!
-* \brief GetReduceAxes, get the new axis from indim and other arguments
-* \param indim Number of dimensions of input data.
-* \param axis The input axis vector.
-* \param exclude Whether 'axis' input given is the excluded axis.
-* \return r_axes The new reduced axes of the output.
-*/
-inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
- const Array<Integer>& inaxis,
+ * \brief GetReduceAxes, get the new axis from indim and other arguments
+ * \param indim Number of dimensions of input data.
+ * \param axis The input axis vector.
+ * \param exclude Whether 'axis' input given is the excluded axis.
+ * \return r_axes The new reduced axes of the output.
+ */
+inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, const Array<Integer>& inaxis,
bool exclude) {
if (!inaxis.defined()) {
std::vector<int64_t> r_axes(indim);
}
// Check out of bounds error
- CHECK(axis >= 0)
- << "Axis out of bounds in reduce operator.";
- CHECK(axis < indim)
- << "Axis out of bounds in reduce operator.";
+ CHECK(axis >= 0) << "Axis out of bounds in reduce operator.";
+ CHECK(axis < indim) << "Axis out of bounds in reduce operator.";
in_axes.push_back(axis);
}
CHECK(in_axes[in_axes.size() - 1] < indim)
- << "Reduction axis " << in_axes[in_axes.size() - 1]
- << " exceeds input dimensions " << indim;
+ << "Reduction axis " << in_axes[in_axes.size() - 1] << " exceeds input dimensions " << indim;
std::sort(in_axes.begin(), in_axes.end());
std::vector<int64_t> r_axes(r_size);
for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) {
if (j < in_axes.size() && in_axes[j] == i) {
- ++j;
- continue;
+ ++j;
+ continue;
}
r_axes[k++] = i;
}
return r_axes;
}
-
// Get axis under exclude condition.
-Array<Integer> GetExcludeAxes(size_t indim,
- const Array<Integer>& inaxis) {
+Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
CHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
std::vector<bool> axis_flag(indim, true);
for (auto i : inaxis) {
axis = axis + static_cast<int64_t>(indim);
}
// Check out of bounds error
- CHECK_GE(axis, 0)
- << "Axis out of bounds in reduce operator.";
- CHECK_LT(axis, static_cast<int64_t>(indim))
- << "Axis out of bounds in reduce operator.";
+ CHECK_GE(axis, 0) << "Axis out of bounds in reduce operator.";
+ CHECK_LT(axis, static_cast<int64_t>(indim)) << "Axis out of bounds in reduce operator.";
axis_flag[axis] = false;
}
return Array<Array<Layout>>{{ret}, {ret}};
}
-template<typename F>
-Array<te::Tensor> ReduceCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type,
- F f) {
+template <typename F>
+Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type, F f) {
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
if (inputs[0]->shape.size() == 0) {
- return { topi::identity(inputs[0]) };
+ return {topi::identity(inputs[0])};
}
auto axes = param->axis;
if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
if (axes.size() == 0) {
- return { topi::identity(inputs[0]) };
+ return {topi::identity(inputs[0])};
}
}
- return { f(inputs[0], axes, param->keepdims, false) };
+ return {f(inputs[0], axes, param->keepdims, false)};
}
/*!
-* \brief ReduceShapeImpl get the outshape for the reduction operator
-* \param in_shape Shape of input data.
-* \param param ReduceAttrs details.
-* \param reporter The reporter to report solution to.
-* \return oshape Output shape inferred.
-*/
-inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr> &in_shape,
+ * \brief ReduceShapeImpl get the outshape for the reduction operator
+ * \param in_shape Shape of input data.
+ * \param param ReduceAttrs details.
+ * \param reporter The reporter to report solution to.
+ * \return oshape Output shape inferred.
+ */
+inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
const ReduceAttrs* param,
const TypeReporter& reporter) {
uint32_t indim = in_shape.size();
}
if (is_dynamic_input) {
- CHECK(reporter->Assert(max_shape < tir::make_const(
- DataType::Int(64), std::numeric_limits<int32_t>::max())))
- << "The maximum possible index of reduced shape cannot be more than int32 max.";
+ CHECK(reporter->Assert(max_shape <
+ tir::make_const(DataType::Int(64), std::numeric_limits<int32_t>::max())))
+ << "The maximum possible index of reduced shape cannot be more than int32 max.";
}
if (param->keepdims) {
}
/*!
-* \brief ArgReduceRel Output type and shape relation evaluation function.
-* \param num_inputs Number of input types in the args.
-* \param attrs The additional attributes of the operator.
-* \param reporter The reporter to report solution to.
-* \return false if This relation cannot be resolved. true if this relation has been resolved.
-*/
-bool ArgReduceRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter) {
+ * \brief ArgReduceRel Output type and shape relation evaluation function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved. true if this relation has been resolved.
+ */
+bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
}
/*!
-* \brief ReduceRel Output type and shape relation evaluation function.
-* \param num_inputs Number of input types in the args.
-* \param attrs The additional attributes of the operator.
-* \param reporter The reporter to report solution to.
-* \return false if This relation cannot be resolved. true if this relation has been resolved.
-*/
-bool ReduceRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+ * \brief ReduceRel Output type and shape relation evaluation function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved. true if this relation has been resolved.
+ */
+bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
return true;
}
-#define RELAY_REGISTER_REDUCE_OP(OpName) \
- TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
- .set_body_typed([]( \
- Expr data, \
- Array<Integer> axis, \
- bool keepdims, \
- bool exclude) { \
- auto attrs = make_object<ReduceAttrs>(); \
- attrs->axis = std::move(axis); \
- attrs->keepdims = keepdims; \
- attrs->exclude = exclude; \
- static const Op& op = Op::Get(OpName); \
- return Call(op, {data}, Attrs(attrs), {}); \
- }); \
- RELAY_REGISTER_OP(OpName) \
- .set_num_inputs(1) \
- .add_argument("data", "Tensor", "The input tensor.")
-
-
-Array<te::Tensor> ArgMaxCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+#define RELAY_REGISTER_REDUCE_OP(OpName) \
+ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
+ .set_body_typed([](Expr data, Array<Integer> axis, bool keepdims, bool exclude) { \
+ auto attrs = make_object<ReduceAttrs>(); \
+ attrs->axis = std::move(axis); \
+ attrs->keepdims = keepdims; \
+ attrs->exclude = exclude; \
+ static const Op& op = Op::Get(OpName); \
+ return Call(op, {data}, Attrs(attrs), {}); \
+ }); \
+ RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.")
+
+Array<te::Tensor> ArgMaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::argmax);
}
-
RELAY_REGISTER_REDUCE_OP("argmax")
-.describe(R"code(Creates an operation that finds the indices of the maximum
+ .describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("ArgReduce", ArgReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("ArgReduce", ArgReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<te::Tensor> ArgMinCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::argmin);
}
RELAY_REGISTER_REDUCE_OP("argmin")
-.describe(R"code(Creates an operation that finds the indices of the minimum
+ .describe(R"code(Creates an operation that finds the indices of the minimum
values over a given axis.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("ArgReduce", ArgReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
-Array<te::Tensor> SumCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("ArgReduce", ArgReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
+
+Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::sum);
}
-
RELAY_REGISTER_REDUCE_OP("sum")
-.describe(R"code(Computes the sum of array elements over given axes.
+ .describe(R"code(Computes the sum of array elements over given axes.
Example::
[ 12. 19. 27.]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("Reduce", ReduceRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
-.set_attr<FTVMCompute>("FTVMCompute", SumCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
-
-Array<te::Tensor> AllCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("Reduce", ReduceRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
+ .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
+
+Array<te::Tensor> AllCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::all);
}
-
RELAY_REGISTER_REDUCE_OP("all")
-.describe(R"code(Computes the logical AND of boolean array elements over given axes.
+ .describe(R"code(Computes the logical AND of boolean array elements over given axes.
Example::
[False, True, False]]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("Reduce", ReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("Reduce", ReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", AllCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<te::Tensor> AnyCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::any);
}
-
RELAY_REGISTER_REDUCE_OP("any")
-.describe(R"code(Computes the logical OR of boolean array elements over given axes.
+ .describe(R"code(Computes the logical OR of boolean array elements over given axes.
Example::
[False, True, True]]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("Reduce", ReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("Reduce", ReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<te::Tensor> MaxCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> MaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::max);
}
RELAY_REGISTER_REDUCE_OP("max")
-.describe(R"code(Computes the max of array elements over given axes.
+ .describe(R"code(Computes the max of array elements over given axes.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("Reduce", ReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("Reduce", ReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<te::Tensor> MinCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::min);
}
-
RELAY_REGISTER_REDUCE_OP("min")
-.describe(R"code(Computes the min of array elements over given axes.
+ .describe(R"code(Computes the min of array elements over given axes.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("Reduce", ReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", MinCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("Reduce", ReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", MinCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<te::Tensor> ProdCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::prod);
}
RELAY_REGISTER_REDUCE_OP("prod")
-.describe(R"code(Computes the products of array elements over given axes.
+ .describe(R"code(Computes the products of array elements over given axes.
Example::
[ 36 480 2058]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("Reduce", ReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("Reduce", ReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
-Array<te::Tensor> MeanCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
+Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
auto axes = param->axis;
- for (int64_t i : GetReduceAxes(inputs[0]->shape.size(),
- param->axis,
- param->exclude)) {
+ for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) {
count *= inputs[0]->shape[i];
}
auto res = ReduceCompute(attrs, inputs, out_type, topi::sum);
return {topi::divide(res[0], count)};
}
-
RELAY_REGISTER_REDUCE_OP("mean")
-.describe(R"code(Computes the mean of array elements over given axes.
+ .describe(R"code(Computes the mean of array elements over given axes.
Example::
[ 2. 3.16666667 4.5]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.add_type_rel("Reduce", ReduceRel)
-.set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .add_type_rel("Reduce", ReduceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
-bool VarianceRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
return true;
}
-Array<te::Tensor> VarianceCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
auto axes = param->axis;
auto data = inputs[0];
auto mean = inputs[1];
- for (int64_t i : GetReduceAxes(data->shape.size(),
- param->axis,
- param->exclude)) {
+ for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) {
count *= data->shape[i];
}
std::vector<Integer> expand_shape;
return {var};
}
-Expr MakeVariance(Expr data,
- Expr mean,
- Array<Integer> axis,
- bool keepdims,
- bool exclude) {
+Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
auto attrs = make_object<ReduceAttrs>();
attrs->axis = std::move(axis);
attrs->keepdims = keepdims;
return Call(op, {data, mean}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make._variance")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeVariance, args, rv);
});
RELAY_REGISTER_OP("variance")
-.describe(R"code(Computes the variance of array elements over given axes.
+ .describe(R"code(Computes the variance of array elements over given axes.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ReduceAttrs>()
-.set_support_level(4)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("mean", "Tensor", "The mean tensor.")
-.add_type_rel("Variance", VarianceRel)
-.set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+ .set_attrs_type<ReduceAttrs>()
+ .set_support_level(4)
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("mean", "Tensor", "The mean tensor.")
+ .add_type_rel("Variance", VarianceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
} // namespace relay
} // namespace tvm
* \file transform.cc
* \brief Transform operators.
*/
-#include <tvm/relay/op.h>
+#include "transform.h"
+
+#include <topi/broadcast.h>
+#include <topi/elemwise.h>
+#include <topi/nn.h>
+#include <topi/reduction.h>
+#include <topi/transform.h>
#include <tvm/ir/error.h>
#include <tvm/relay/attrs/transform.h>
-#include <tvm/tir/op.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/data_layout.h>
+#include <tvm/relay/op.h>
#include <tvm/runtime/packed_func.h>
-#include <topi/transform.h>
-#include <topi/elemwise.h>
-#include <topi/broadcast.h>
-#include <topi/reduction.h>
-#include <topi/nn.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+
#include <vector>
-#include "../op_common.h"
+
#include "../../../arith/compute_expr.h"
#include "../../transforms/infer_layout_util.h"
#include "../../transforms/pattern_util.h"
-#include "transform.h"
+#include "../op_common.h"
namespace tvm {
namespace relay {
// relay.cast
TVM_REGISTER_NODE_TYPE(CastAttrs);
-bool CastRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool CastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "cast: expect input type to be TensorType but get "
- << types[0];
+ << "cast: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as<CastAttrs>();
- reporter->Assign(types[1], TensorType(
- data->shape, param->dtype));
+ reporter->Assign(types[1], TensorType(data->shape, param->dtype));
return true;
}
-Array<te::Tensor> CastCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> CastCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const CastAttrs *param = attrs.as<CastAttrs>();
+ const CastAttrs* param = attrs.as<CastAttrs>();
CHECK(param != nullptr);
DataType dtype = param->dtype;
- return { topi::cast(inputs[0], dtype) };
+ return {topi::cast(inputs[0], dtype)};
}
-Expr MakeCast(Expr data,
- DataType dtype) {
+Expr MakeCast(Expr data, DataType dtype) {
auto attrs = make_object<CastAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("cast");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.ir.cast")
-.set_body_typed(MakeCast);
+TVM_REGISTER_GLOBAL("relay.ir.cast").set_body_typed(MakeCast);
RELAY_REGISTER_OP("cast")
-.describe(R"code(Cast the data into a new data type.
+ .describe(R"code(Cast the data into a new data type.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<CastAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Cast", CastRel)
-.set_attr<FTVMCompute>("FTVMCompute", CastCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
-
+ .set_num_inputs(1)
+ .set_attrs_type<CastAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Cast", CastRel)
+ .set_attr<FTVMCompute>("FTVMCompute", CastCompute)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
// relay.cast_like
-bool CastLikeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool CastLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "cast: expect input type to be TensorType but get "
- << types[0];
+ << "cast: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* dtype_like = types[1].as<TensorTypeNode>();
if (dtype_like == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
- << "cast: expect input type to be TensorType but get "
- << types[1];
+ << "cast: expect input type to be TensorType but get " << types[1];
return false;
}
reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype));
return true;
}
-
-Array<te::Tensor> CastLikeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> CastLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- return { topi::cast(inputs[0], inputs[1]->dtype) };
+ return {topi::cast(inputs[0], inputs[1]->dtype)};
}
-
-Expr MakeCastLike(Expr data,
- Expr dtype_like) {
+Expr MakeCastLike(Expr data, Expr dtype_like) {
static const Op& op = Op::Get("cast_like");
return Call(op, {data, dtype_like}, Attrs(), {});
}
-
-TVM_REGISTER_GLOBAL("relay.ir.cast_like")
-.set_body_typed(MakeCastLike);
+TVM_REGISTER_GLOBAL("relay.ir.cast_like").set_body_typed(MakeCastLike);
RELAY_REGISTER_OP("cast_like")
-.describe(R"code(Cast the data into the type of another tensor.
+ .describe(R"code(Cast the data into the type of another tensor.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("dtype_like", "Tensor", "The tensor to cast to.")
-.set_support_level(3)
-.add_type_rel("CastLike", CastLikeRel)
-.set_attr<FTVMCompute>("FTVMCompute", CastLikeCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
-
-
-Array<te::Tensor> ReinterpretCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("dtype_like", "Tensor", "The tensor to cast to.")
+ .set_support_level(3)
+ .add_type_rel("CastLike", CastLikeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", CastLikeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+Array<te::Tensor> ReinterpretCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const CastAttrs* param = attrs.as<CastAttrs>();
CHECK(param != nullptr);
});
RELAY_REGISTER_OP("reinterpret")
-.describe(R"code(Reinterpret the data into a new data type.
+ .describe(R"code(Reinterpret the data into a new data type.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<CastAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Reinterpret", CastRel)
-.set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+ .set_num_inputs(1)
+ .set_attrs_type<CastAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Reinterpret", CastRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
// relay.expand_dims
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
-bool ExpandDimsRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool ExpandDimsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "expand_dims: expect input type to be TensorType but get "
- << types[0];
+ << "expand_dims: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as<ExpandDimsAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const int axis = param->axis;
const int num_newaxis = param->num_newaxis;
- CHECK(num_newaxis >= 0)
- << "expand_dims only accepts `num_newaxis >= 0`"
- << ", but got num_newaxis = " << num_newaxis;
+ CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
+ << ", but got num_newaxis = " << num_newaxis;
CHECK(-ndim - 1 <= axis && axis <= ndim)
- << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
- << ", but got axis = " << axis
- << ", and data.ndim = " << ndim;
+ << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
+ << ", but got axis = " << axis << ", and data.ndim = " << ndim;
const int pivot = axis < 0 ? ndim + axis + 1 : axis;
std::vector<IndexExpr> oshape;
oshape.reserve(ndim + num_newaxis);
return true;
}
-Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const ExpandDimsAttrs *param = attrs.as<ExpandDimsAttrs>();
+ const ExpandDimsAttrs* param = attrs.as<ExpandDimsAttrs>();
CHECK(param != nullptr);
- return { topi::expand_dims(inputs[0], param->axis, param->num_newaxis) };
+ return {topi::expand_dims(inputs[0], param->axis, param->num_newaxis)};
}
-Expr MakeExpandDims(Expr data,
- int axis,
- int num_newaxis) {
+Expr MakeExpandDims(Expr data, int axis, int num_newaxis) {
auto attrs = make_object<ExpandDimsAttrs>();
attrs->axis = axis;
attrs->num_newaxis = num_newaxis;
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.expand_dims")
-.set_body_typed(MakeExpandDims);
+TVM_REGISTER_GLOBAL("relay.op._make.expand_dims").set_body_typed(MakeExpandDims);
RELAY_REGISTER_OP("expand_dims")
-.describe(R"code(Insert `num_newaxis` axises at the position given by `axis`
+ .describe(R"code(Insert `num_newaxis` axises at the position given by `axis`
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<ExpandDimsAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(1)
-.add_type_rel("ExpandDims", ExpandDimsRel)
-.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
-.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+ .set_num_inputs(1)
+ .set_attrs_type<ExpandDimsAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(1)
+ .add_type_rel("ExpandDims", ExpandDimsRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// relay.concatenate
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
-Array<te::Tensor> ConcatenateCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ConcatenateCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const ConcatenateAttrs *param = attrs.as<ConcatenateAttrs>();
+ const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
CHECK(param != nullptr);
- return { topi::concatenate(inputs, param->axis) };
+ return {topi::concatenate(inputs, param->axis)};
}
-Expr MakeConcatenate(Expr data,
- int axis) {
+Expr MakeConcatenate(Expr data, int axis) {
auto attrs = make_object<ConcatenateAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("concatenate");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.concatenate")
-.set_body_typed(MakeConcatenate);
+TVM_REGISTER_GLOBAL("relay.op._make.concatenate").set_body_typed(MakeConcatenate);
RELAY_REGISTER_OP("concatenate")
-.describe(R"code(Concatenate the input tensors along the given axis.
+ .describe(R"code(Concatenate the input tensors along the given axis.
- **data** : A list of tensors.
- **axis** : The axis along which the tensors are concatenated.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ConcatenateAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input list of tensors.")
-.set_support_level(1)
-.add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
-.set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_attrs_type<ConcatenateAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input list of tensors.")
+ .set_support_level(1)
+ .add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
+ .set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
TVM_REGISTER_NODE_TYPE(StackAttrs);
-bool StackRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "cast: expect input type to be TupleType but get "
- << types[0];
+ << "cast: expect input type to be TupleType but get " << types[0];
return false;
}
const auto* param = attrs.as<StackAttrs>();
// Sanity check: axis
int axis = param->axis;
- CHECK(-ndim <= axis && axis < ndim)
- << "stack only accepts `axis` in [-ndim, ndim)"
- << ", but got axis = " << axis
- << ", and ndim = " << ndim;
- axis = axis < 0 ? ndim + axis + 1: axis;
+ CHECK(-ndim <= axis && axis < ndim) << "stack only accepts `axis` in [-ndim, ndim)"
+ << ", but got axis = " << axis << ", and ndim = " << ndim;
+ axis = axis < 0 ? ndim + axis + 1 : axis;
// Sanity check: ndim and dtype.
const DataType dtype = first->dtype;
for (size_t j = 0; j < first->shape.size(); ++j) {
if (j == static_cast<size_t>(axis)) continue;
if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
- throw Error("relay.stack requires all tensors have the same shape "
- "on non-stacking axes");
+ throw Error(
+ "relay.stack requires all tensors have the same shape "
+ "on non-stacking axes");
}
}
return true;
}
-Array<te::Tensor> StackCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> StackCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const StackAttrs *param = attrs.as<StackAttrs>();
+ const StackAttrs* param = attrs.as<StackAttrs>();
CHECK(param != nullptr);
- return { topi::stack(inputs, param->axis) };
+ return {topi::stack(inputs, param->axis)};
}
-Expr MakeStack(Expr data,
- int axis) {
+Expr MakeStack(Expr data, int axis) {
auto attrs = make_object<StackAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("stack");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.stack")
-.set_body_typed(MakeStack);
+TVM_REGISTER_GLOBAL("relay.op._make.stack").set_body_typed(MakeStack);
RELAY_REGISTER_OP("stack")
-.describe(R"code(Stack the input tensors along the given axis.
+ .describe(R"code(Stack the input tensors along the given axis.
- **data** : A list of tensors.
- **axis** : The axis along which the tensors are stacked.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<StackAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input list of tensors.")
-.set_support_level(3)
-.add_type_rel("Stack", StackRel)
-.set_attr<FTVMCompute>("FTVMCompute", StackCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_attrs_type<StackAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input list of tensors.")
+ .set_support_level(3)
+ .add_type_rel("Stack", StackRel)
+ .set_attr<FTVMCompute>("FTVMCompute", StackCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
/* relay.transpose */
TVM_REGISTER_NODE_TYPE(TransposeAttrs);
-bool TransposeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool TransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "transpose: expect input type to be TensorType but get "
- << types[0];
+ << "transpose: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as<TransposeAttrs>();
const Array<Integer>& axes = param->axes;
// check dimension match
CHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim)
- << "Dimension mismatch: axes has " << axes.size() << " elements"
- << ", but data.ndim = " << ndim;
+ << "Dimension mismatch: axes has " << axes.size() << " elements"
+ << ", but data.ndim = " << ndim;
// construct int_axes
std::vector<int> int_axes;
int_axes.reserve(ndim);
int64_t axis = e;
// sanity check for axis and ndim
CHECK(-ndim <= axis && axis < ndim)
- << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
- << ", but got axis = " << axis
- << ", and data.ndim = " << ndim;
+ << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
+ << ", but got axis = " << axis << ", and data.ndim = " << ndim;
axis = axis < 0 ? axis + ndim : axis;
// sanity check for duplication
CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
return true;
}
-Array<te::Tensor> TransposeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> TransposeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<TransposeAttrs>();
CHECK(param != nullptr);
- return Array<te::Tensor>{ topi::transpose(inputs[0], param->axes) };
+ return Array<te::Tensor>{topi::transpose(inputs[0], param->axes)};
}
-Expr MakeTranspose(Expr data,
- Array<Integer> axes) {
+Expr MakeTranspose(Expr data, Array<Integer> axes) {
auto attrs = make_object<TransposeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("transpose");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.transpose")
-.set_body_typed(MakeTranspose);
+TVM_REGISTER_GLOBAL("relay.op._make.transpose").set_body_typed(MakeTranspose);
RELAY_REGISTER_OP("transpose")
-.describe(R"code(Permutes the dimensions of an array.
+ .describe(R"code(Permutes the dimensions of an array.
- **data**: The input data to the operator.
- **axes**: The target axes order, reverse order if not specified.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<TransposeAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Transpose", TransposeRel)
-.set_attr<FTVMCompute>("FTVMCompute", TransposeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_num_inputs(1)
+ .set_attrs_type<TransposeAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Transpose", TransposeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", TransposeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
-bool ReshapeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "reshape: expect input type to be TensorType but get "
- << types[0];
+ << "reshape: expect input type to be TensorType but get " << types[0];
return false;
}
oshape.push_back(data_shape[src_idx++]);
} else if (svalue == -1) {
// inference based on rest
- CHECK_LT(infer_idx, 0)
- << "One and only one dim can be inferred";
+ CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred";
infer_idx = i;
oshape.push_back(1);
++src_idx;
Integer d1 = newshape[++i];
Integer d2 = newshape[++i];
if (d1->value == -1) {
- CHECK(d2->value != -1)
- << "Split dims cannot both be -1.";
+ CHECK(d2->value != -1) << "Split dims cannot both be -1.";
used_output_dims.insert(oshape.size());
if (d0.as<Any>()) {
oshape.push_back(Any::make());
}
if (param->reverse) {
- reporter->Assign(types[1], TensorType(
- Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
+ reporter->Assign(types[1],
+ TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
} else {
reporter->Assign(types[1], TensorType(oshape, data->dtype));
}
return true;
}
-Array<te::Tensor> ReshapeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
newshape.push_back(val);
}
}
- return { topi::reshape(inputs[0], newshape) };
+ return {topi::reshape(inputs[0], newshape)};
}
-Expr MakeReshape(Expr data,
- Array<Integer> newshape) {
+Expr MakeReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = false;
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.reshape")
-.set_body_typed(MakeReshape);
+TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape);
RELAY_REGISTER_OP("reshape")
-.describe(R"code(Reshapes the input array.
+ .describe(R"code(Reshapes the input array.
Example::
- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<ReshapeAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Reshape", ReshapeRel)
-.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
-
+ .set_num_inputs(1)
+ .set_attrs_type<ReshapeAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Reshape", ReshapeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
/*!
-* \brief ReshapeLikeRel User defined type constraint function.
-* \param num_inputs Number of input types in the args.
-* \param attrs The additional attributes of the operator.
-* \param reporter The reporter to report solution to.
-* \return False if the relation has not been resolved, it might be resolved later.
-* True if this relation has been resolved.
-*/
-bool ReshapeLikeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+ * \brief ReshapeLikeRel User defined type constraint function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return False if the relation has not been resolved, it might be resolved later.
+ * True if this relation has been resolved.
+ */
+bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
}
if (is_static_shape) {
CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
- << "Reshape inputs size should be compatible.";
+ << "Reshape inputs size should be compatible.";
}
reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
return true;
}
-
-Expr MakeReshapeLike(Expr data,
- Expr shape_like) {
+Expr MakeReshapeLike(Expr data, Expr shape_like) {
static const Op& op = Op::Get("reshape_like");
return Call(op, {data, shape_like}, Attrs(), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op._make.reshape_like")
-.set_body_typed(MakeReshapeLike);
-
+TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike);
RELAY_REGISTER_OP("reshape_like")
-.describe(R"code(Reshapes the input array by the size of another array.
+ .describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("shape_like", "Tensor", "Shape tensor.")
-.set_support_level(3)
-.add_type_rel("ReshapeLike", ReshapeLikeRel)
-.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("shape_like", "Tensor", "Shape tensor.")
+ .set_support_level(3)
+ .add_type_rel("ReshapeLike", ReshapeLikeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// ArgWhere
-bool ArgWhereRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool ArgWhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
return true;
}
-TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
-.set_body_typed([](Expr data) {
+TVM_REGISTER_GLOBAL("relay.op._make.argwhere").set_body_typed([](Expr data) {
static const Op& op = Op::Get("argwhere");
return Call(op, {data}, Attrs(), {});
});
RELAY_REGISTER_OP("argwhere")
-.describe(R"doc(Find the indices of elements of a tensor that are
+ .describe(R"doc(Find the indices of elements of a tensor that are
non-zero)doc" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("condition", "Tensor", "The input condition tensor.")
-.add_type_rel("ArgWhere", ArgWhereRel)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_support_level(10);
+ .set_num_inputs(1)
+ .add_argument("condition", "Tensor", "The input condition tensor.")
+ .add_type_rel("ArgWhere", ArgWhereRel)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_support_level(10);
// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
-bool TakeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto ndim_indices = static_cast<int>(indices->shape.size());
int axis = static_cast<int>(param->axis->value);
if (axis < 0) axis += ndim_data;
- CHECK_LE(axis, ndim_data)
- << "axis should be with in data shape"
- << ", but got = " << axis;
+ CHECK_LE(axis, ndim_data) << "axis should be with in data shape"
+ << ", but got = " << axis;
oshape.reserve(ndim_data - 1 + ndim_indices);
for (int i = 0; i < axis; ++i) {
for (int i = 0; i < ndim_indices; ++i) {
oshape.emplace_back(indices->shape[i]);
}
- for (int i = axis+1; i < ndim_data; ++i) {
+ for (int i = axis + 1; i < ndim_data; ++i) {
oshape.emplace_back(data->shape[i]);
}
return true;
}
-Array<te::Tensor> TakeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> TakeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
if (!param->axis.defined()) {
- return Array<te::Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
+ return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->mode)};
} else {
- return Array<te::Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
+ return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->axis, param->mode)};
}
}
-Expr MakeTake(Expr data,
- Expr indices,
- Integer axis,
- std::string mode) {
+Expr MakeTake(Expr data, Expr indices, Integer axis, std::string mode) {
auto attrs = make_object<TakeAttrs>();
attrs->axis = std::move(axis);
attrs->mode = std::move(mode);
return Call(op, {data, indices}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.take")
-.set_body_typed(MakeTake);
+TVM_REGISTER_GLOBAL("relay.op._make.take").set_body_typed(MakeTake);
RELAY_REGISTER_OP("take")
-.describe(R"code(Take elements from an array along an axis.
+ .describe(R"code(Take elements from an array along an axis.
When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
[ 4., 3.]]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<TakeAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("indices", "Tensor", "The indices tensor.")
-.set_support_level(3)
-.add_type_rel("Take", TakeRel)
-.set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
-
+ .set_attrs_type<TakeAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("indices", "Tensor", "The indices tensor.")
+ .set_support_level(3)
+ .add_type_rel("Take", TakeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// Init ops
TVM_REGISTER_NODE_TYPE(InitOpAttrs);
-bool FullRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
}
CHECK_EQ(fill_value->shape.size(), 0)
- << "Fill value should be a scalar but has dimension "
- << fill_value->shape.size() << ".";
+ << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";
reporter->Assign(types[1], TensorType(param->shape, out_dtype));
return true;
}
-Array<te::Tensor> FullCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
- return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) };
+ return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
}
-Expr MakeFull(Expr fill_value,
- Array<IndexExpr> shape,
- DataType dtype) {
+Expr MakeFull(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
return Call(op, {fill_value}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.full")
-.set_body_typed(MakeFull);
+TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);
RELAY_REGISTER_OP("full")
-.describe(R"code(Fill array with scalar value.
+ .describe(R"code(Fill array with scalar value.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<InitOpAttrs>()
-.set_num_inputs(1)
-.add_argument("fill_value", "double", "The value to fill.")
-.set_support_level(3)
-.add_type_rel("Full", FullRel)
-.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise);
-
-bool InitOpRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+ .set_attrs_type<InitOpAttrs>()
+ .set_num_inputs(1)
+ .add_argument("fill_value", "double", "The value to fill.")
+ .set_support_level(3)
+ .add_type_rel("Full", FullRel)
+ .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
+bool InitOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 1);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
return true;
}
-Expr MakeZeros(Array<IndexExpr> shape,
- DataType dtype) {
+Expr MakeZeros(Array<IndexExpr> shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
return Call(op, {}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.zeros")
-.set_body_typed(MakeZeros);
+TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros);
RELAY_REGISTER_OP("zeros")
-.describe(R"code(Fill array with zeros.
+ .describe(R"code(Fill array with zeros.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<InitOpAttrs>()
-.set_num_inputs(0)
-.set_support_level(3)
-.add_type_rel("InitOp", InitOpRel);
+ .set_attrs_type<InitOpAttrs>()
+ .set_num_inputs(0)
+ .set_support_level(3)
+ .add_type_rel("InitOp", InitOpRel);
-Expr MakeOnes(Array<IndexExpr> shape,
- DataType dtype) {
+Expr MakeOnes(Array<IndexExpr> shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
return Call(op, {}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.ones")
-.set_body_typed(MakeOnes);
+TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes);
RELAY_REGISTER_OP("ones")
-.describe(R"code(Fill array with ones.
+ .describe(R"code(Fill array with ones.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<InitOpAttrs>()
-.set_num_inputs(0)
-.set_support_level(3)
-.add_type_rel("InitOp", InitOpRel);
-
-bool FullLikeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+ .set_attrs_type<InitOpAttrs>()
+ .set_num_inputs(0)
+ .set_support_level(3)
+ .add_type_rel("InitOp", InitOpRel);
+
+bool FullLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
}
CHECK_EQ(fill_value->shape.size(), 0)
- << "The fill value should be a scalar but here it has dimension "
- << fill_value->shape.size() << ".";
+ << "The fill value should be a scalar but here it has dimension " << fill_value->shape.size()
+ << ".";
reporter->Assign(types[2], TensorType(data->shape, data->dtype));
return true;
}
-Array<te::Tensor> FullLikeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> FullLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- return { topi::full_like(inputs[0], inputs[1]()) };
+ return {topi::full_like(inputs[0], inputs[1]())};
}
-Expr MakeFullLike(Expr data,
- Expr fill_value) {
+Expr MakeFullLike(Expr data, Expr fill_value) {
static const Op& op = Op::Get("full_like");
return Call(op, {data, fill_value}, Attrs(), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.full_like")
-.set_body_typed(MakeFullLike);
+TVM_REGISTER_GLOBAL("relay.op._make.full_like").set_body_typed(MakeFullLike);
RELAY_REGISTER_OP("full_like")
-.describe(R"code(Return an scalar value array with the same shape
+ .describe(R"code(Return an scalar value array with the same shape
and type as the input array.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("fill_value", "double", "Scalar value to fill.")
-.set_support_level(3)
-.add_type_rel("FullLike", FullLikeRel)
-.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise);
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("fill_value", "double", "Scalar value to fill.")
+ .set_support_level(3)
+ .add_type_rel("FullLike", FullLikeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise);
// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);
return -std::numeric_limits<double>::infinity();
}
-bool ArangeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& raw_attrs,
+bool ArangeRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const ArangeAttrs* attrs = raw_attrs.as<ArangeAttrs>();
reporter->Assign(types[1], types[2]);
reporter->Assign(types[2], TensorType({}, attrs->dtype));
- if ((cstart = attrs->start.as<ConstantNode>()) &&
- (cstop = attrs->stop.as<ConstantNode>()) &&
+ if ((cstart = attrs->start.as<ConstantNode>()) && (cstop = attrs->stop.as<ConstantNode>()) &&
(cstep = attrs->step.as<ConstantNode>())) {
double start = ToScalar(cstart->data);
double stop = ToScalar(cstop->data);
double step = ToScalar(cstep->data);
int32_t num_elem = static_cast<int32_t>(std::ceil((stop - start) / step));
- CHECK_GT(num_elem, 0)
- << "Invalid arange attributes (start, stop, step): " << attrs->start
- << ", " << attrs->stop << ", " << attrs->step;
+ CHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start
+ << ", " << attrs->stop << ", " << attrs->step;
reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype));
return true;
} else {
}
}
-inline te::Tensor DynamicArange(const te::Tensor& start,
- const te::Tensor& stop,
- const te::Tensor& step,
- tvm::DataType dtype,
- std::string name = "tensor",
- std::string tag = topi::kInjective) {
+inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop,
+ const te::Tensor& step, tvm::DataType dtype,
+ std::string name = "tensor", std::string tag = topi::kInjective) {
tvm::PrimExpr num_elem = tvm::tir::Var("num_elem");
- return te::compute({num_elem}, [&](const Array<tvm::tir::Var>& indices) {
- return tvm::cast(dtype, start[0] + step[0] * indices[0]);
- }, name, tag);
+ return te::compute(
+ {num_elem},
+ [&](const Array<tvm::tir::Var>& indices) {
+ return tvm::cast(dtype, start[0] + step[0] * indices[0]);
+ },
+ name, tag);
}
-Array<te::Tensor> ArangeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ArangeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
te::Tensor start = inputs[0];
- te::Tensor stop = inputs[1];
+ te::Tensor stop = inputs[1];
te::Tensor step = inputs[2];
- return { DynamicArange(start, stop, step, param->dtype) };
+ return {DynamicArange(start, stop, step, param->dtype)};
}
-Expr MakeArange(Expr start,
- Expr stop,
- Expr step,
- DataType dtype) {
+Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) {
auto attrs = make_object<ArangeAttrs>();
attrs->start = start;
attrs->stop = stop;
return Call(op, {start, stop, step}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.arange")
-.set_body_typed(MakeArange);
+TVM_REGISTER_GLOBAL("relay.op._make.arange").set_body_typed(MakeArange);
// An issue with the existing design is that we require dependency
// to type the operator precisely.
// In general I think we should avoid this pattern, and introduce
// a secondary shape analysis to recover more precise information.
RELAY_REGISTER_OP("arange")
-.describe(R"code(Returns evenly spaced values within a given interval.
+ .describe(R"code(Returns evenly spaced values within a given interval.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ArangeAttrs>()
-.set_num_inputs(3)
-.set_support_level(3)
-.add_type_rel("Arange", ArangeRel)
-.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
-// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
+ .set_attrs_type<ArangeAttrs>()
+ .set_num_inputs(3)
+ .set_support_level(3)
+ .add_type_rel("Arange", ArangeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
+ // TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
// repeat operator
TVM_REGISTER_NODE_TYPE(RepeatAttrs);
-bool RepeatRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool RepeatRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "repeat: expect input type to be TensorType but get "
- << types[0];
+ << "repeat: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as<RepeatAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const int repeats = param->repeats;
const int axis = param->axis;
- CHECK(repeats >= 1)
- << "repeat only accepts `repeats >= 1`"
- << ", but got repeats = " << repeats;
+ CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
+ << ", but got repeats = " << repeats;
CHECK(-ndim - 1 <= axis && axis <= ndim)
- << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
- << ", but got axis = " << axis
- << ", and data.ndim = " << ndim;
+ << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
+ << ", but got axis = " << axis << ", and data.ndim = " << ndim;
const int pivot = axis < 0 ? ndim + axis : axis;
std::vector<IndexExpr> oshape;
oshape.reserve(ndim + repeats);
return true;
}
-Array<te::Tensor> RepeatCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> RepeatCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const RepeatAttrs *param = attrs.as<RepeatAttrs>();
+ const RepeatAttrs* param = attrs.as<RepeatAttrs>();
CHECK(param != nullptr);
- return { topi::repeat(inputs[0], param->repeats, param->axis) };
+ return {topi::repeat(inputs[0], param->repeats, param->axis)};
}
-Expr MakeRepeat(Expr data,
- int repeats,
- int axis) {
+Expr MakeRepeat(Expr data, int repeats, int axis) {
auto attrs = make_object<RepeatAttrs>();
attrs->repeats = repeats;
attrs->axis = axis;
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.repeat")
-.set_body_typed(MakeRepeat);
+TVM_REGISTER_GLOBAL("relay.op._make.repeat").set_body_typed(MakeRepeat);
RELAY_REGISTER_OP("repeat")
-.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
+ .describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<RepeatAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Repeat", RepeatRel)
-.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
-.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+ .set_num_inputs(1)
+ .set_attrs_type<RepeatAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Repeat", RepeatRel)
+ .set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// tile operator
TVM_REGISTER_NODE_TYPE(TileAttrs);
-bool TileRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "tile: expect input type to be TensorType but get "
- << types[0];
+ << "tile: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as<TileAttrs>();
const size_t ndim = data->shape.size();
const Array<Integer>& reps = param->reps;
// check dimension match
- CHECK(reps.defined())
- << "repetition array is not defined. data.ndim = " << ndim;
+ CHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim;
const size_t rndim = reps.size();
for (size_t i = 0; i < rndim; ++i) {
if (const tvm::tir::IntImmNode* val = reps[i].as<tvm::tir::IntImmNode>()) {
- CHECK_GT(val->value, 0)
- << "Tile reps value should always be larger than 0, but get: " << val->value;
+ CHECK_GT(val->value, 0) << "Tile reps value should always be larger than 0, but get: "
+ << val->value;
}
}
size_t tndim = (ndim > rndim) ? ndim : rndim;
return true;
}
-Array<te::Tensor> TileCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const TileAttrs *param = attrs.as<TileAttrs>();
+ const TileAttrs* param = attrs.as<TileAttrs>();
CHECK(param != nullptr);
- return { topi::tile(inputs[0], param->reps) };
+ return {topi::tile(inputs[0], param->reps)};
}
-Expr MakeTile(Expr data,
- Array<Integer> reps) {
+Expr MakeTile(Expr data, Array<Integer> reps) {
auto attrs = make_object<TileAttrs>();
attrs->reps = reps;
static const Op& op = Op::Get("tile");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.tile")
-.set_body_typed(MakeTile);
+TVM_REGISTER_GLOBAL("relay.op._make.tile").set_body_typed(MakeTile);
RELAY_REGISTER_OP("tile")
-.describe(R"code(Repeat the whole array multiple times.
+ .describe(R"code(Repeat the whole array multiple times.
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<TileAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Tile", TileRel)
-.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
-.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+ .set_num_inputs(1)
+ .set_attrs_type<TileAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Tile", TileRel)
+ .set_attr<FTVMCompute>("FTVMCompute", TileCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// reverse operator
TVM_REGISTER_NODE_TYPE(ReverseAttrs);
-bool ReverseRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter) {
+bool ReverseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "reverse: expect input type to be TensorType but get "
- << types[0];
+ << "reverse: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as<ReverseAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const int axis = param->axis;
CHECK(-ndim <= axis && axis < ndim)
- << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
- << ", but got axis = " << axis
- << ", and data.ndim = " << ndim;
+ << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
+ << ", but got axis = " << axis << ", and data.ndim = " << ndim;
reporter->Assign(types[1], types[0]);
return true;
}
-Array<te::Tensor> ReverseCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ReverseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const ReverseAttrs *param = attrs.as<ReverseAttrs>();
+ const ReverseAttrs* param = attrs.as<ReverseAttrs>();
CHECK(param != nullptr);
- return { topi::flip(inputs[0], param->axis) };
+ return {topi::flip(inputs[0], param->axis)};
}
-Expr MakeReverse(Expr data,
- int axis) {
+Expr MakeReverse(Expr data, int axis) {
auto attrs = make_object<ReverseAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("reverse");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.reverse")
-.set_body_typed(MakeReverse);
+TVM_REGISTER_GLOBAL("relay.op._make.reverse").set_body_typed(MakeReverse);
RELAY_REGISTER_OP("reverse")
-.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
+ .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<ReverseAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Reverse", ReverseRel)
-.set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_num_inputs(1)
+ .set_attrs_type<ReverseAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Reverse", ReverseRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// where operator
-bool WhereRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4U);
const auto* condition = types[0].as<TensorTypeNode>();
CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
if (cond_shape.size() != x_shape.size()) {
- CHECK_EQ(cond_shape.size(), 1)
- << "Shape of condition " << condition->shape
- << " must be either equal to x or has dimension of 1.";
+ CHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape
+ << " must be either equal to x or has dimension of 1.";
}
for (size_t i = 0; i < x_shape.size(); i++) {
CHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
<< "x and y must have the same shape: " << x_shape << " vs " << y_shape;
if (i < cond_shape.size()) {
- CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
- << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
+ CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
+ << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
}
}
reporter->Assign(types[3], TensorType(x_shape, x->dtype));
return Call(op, {condition, x, y});
}
-Array<te::Tensor> WhereCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> WhereCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- return { topi::where(inputs[0], inputs[1], inputs[2]) };
+ return {topi::where(inputs[0], inputs[1], inputs[2])};
}
-TVM_REGISTER_GLOBAL("relay.op._make.where")
-.set_body_typed(MakeWhere);
+TVM_REGISTER_GLOBAL("relay.op._make.where").set_body_typed(MakeWhere);
RELAY_REGISTER_OP("where")
-.describe(R"code(
+ .describe(R"code(
Return the elements, either from x or y, depending on the condition.
Given three ndarrays, condition, x, and y, return an ndarray with the elements
where(cond, x, y) = [[1, 2], [7, 8]]
)code" TVM_ADD_FILELINE)
-.add_argument("condition", "Tensor", "Condition array")
-.add_argument("x", "Tensor", "First array to be selected")
-.add_argument("y", "Tensor", "Second array to be selected")
-.set_num_inputs(3)
-.set_support_level(4)
-.add_type_rel("Where", WhereRel)
-.set_attr<FTVMCompute>("FTVMCompute", WhereCompute)
-.set_attr<TOpPattern>("TOpPattern", kBroadcast);
-
+ .add_argument("condition", "Tensor", "Condition array")
+ .add_argument("x", "Tensor", "First array to be selected")
+ .add_argument("y", "Tensor", "Second array to be selected")
+ .set_num_inputs(3)
+ .set_support_level(4)
+ .add_type_rel("Where", WhereRel)
+ .set_attr<FTVMCompute>("FTVMCompute", WhereCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// Squeeze
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
-Expr MakeSqueeze(Expr data,
- Array<Integer> axis) {
+Expr MakeSqueeze(Expr data, Array<Integer> axis) {
auto attrs = make_object<SqueezeAttrs>();
attrs->axis = std::move(axis);
static const Op& op = Op::Get("squeeze");
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.squeeze")
-.set_body_typed(MakeSqueeze);
+TVM_REGISTER_GLOBAL("relay.op._make.squeeze").set_body_typed(MakeSqueeze);
-
-bool SqueezeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool SqueezeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
}
} else {
// pair up original shape with a boolean which control whether it will be in the final shape.
- std::vector<std::pair<IndexExpr, bool> > original_shape;
+ std::vector<std::pair<IndexExpr, bool>> original_shape;
for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
}
return true;
}
-Array<te::Tensor> SqueezeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> SqueezeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const SqueezeAttrs *param = attrs.as<SqueezeAttrs>();
+ const SqueezeAttrs* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr);
- return { topi::squeeze(inputs[0], param->axis) };
+ return {topi::squeeze(inputs[0], param->axis)};
}
-
RELAY_REGISTER_OP("squeeze")
-.describe(R"code(Squeeze the input tensor at the dimensions given by axes
+ .describe(R"code(Squeeze the input tensor at the dimensions given by axes
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<SqueezeAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Squeeze", SqueezeRel)
-.set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
-
+ .set_num_inputs(1)
+ .set_attrs_type<SqueezeAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Squeeze", SqueezeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
-bool CollapseSumLikeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter);
}
-Expr MakeCollapseSumLike(Expr data,
- Expr collapse_type) {
+Expr MakeCollapseSumLike(Expr data, Expr collapse_type) {
static const Op& op = Op::Get("collapse_sum_like");
return Call(op, {data, collapse_type}, Attrs(), {});
}
-Array<te::Tensor> CollapseSumLikeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> CollapseSumLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
- return { topi::collapse_sum(inputs[0], out_ttype->shape) };
+ return {topi::collapse_sum(inputs[0], out_ttype->shape)};
}
-TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like")
-.set_body_typed(MakeCollapseSumLike);
+TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like").set_body_typed(MakeCollapseSumLike);
RELAY_REGISTER_OP("collapse_sum_like")
-.describe(R"code(Collapse the first input to match the shape of the second input.
+ .describe(R"code(Collapse the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
-.set_support_level(10)
-.add_type_rel("CollapseSumLike", CollapseSumLikeRel)
-.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
-.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
+ .set_support_level(10)
+ .add_type_rel("CollapseSumLike", CollapseSumLikeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kCommReduce);
// BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
-bool BroadCastToRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
auto ioattrs = attrs.as<InitOpAttrs>();
CHECK(ioattrs);
auto intt = types[0].as<TensorTypeNode>();
- if (intt == nullptr) { return false; }
+ if (intt == nullptr) {
+ return false;
+ }
auto type = TensorType(ioattrs->shape, intt->dtype);
reporter->Assign(types[1], type);
return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
return Call(op, {data}, Attrs(attrs), {});
}
-Array<te::Tensor> BroadCastToCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
auto ioattrs = attrs.as<InitOpAttrs>();
CHECK(ioattrs != nullptr);
- return { topi::broadcast_to(inputs[0], ioattrs->shape) };
+ return {topi::broadcast_to(inputs[0], ioattrs->shape)};
}
-TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to")
-.set_body_typed(MakeBroadCastTo);
+TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo);
RELAY_REGISTER_OP("broadcast_to")
-.describe(R"code(Broadcast the first input to match the shape argument.
+ .describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(4)
-.add_type_rel("BroadCastTo", BroadCastToRel)
-.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
-.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(4)
+ .add_type_rel("BroadCastTo", BroadCastToRel)
+ .set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
-bool BroadCastToLikeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BroadCastToLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
}
-Expr MakeBroadCastToLike(Expr data,
- Expr broadcast_type) {
+Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) {
static const Op& op = Op::Get("broadcast_to_like");
return Call(op, {data, broadcast_type}, Attrs(), {});
}
-Array<te::Tensor> BroadCastToLikeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> BroadCastToLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
- return { topi::broadcast_to(inputs[0], out_ttype->shape) };
+ return {topi::broadcast_to(inputs[0], out_ttype->shape)};
}
-TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like")
-.set_body_typed(MakeBroadCastToLike);
+TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like").set_body_typed(MakeBroadCastToLike);
RELAY_REGISTER_OP("broadcast_to_like")
-.describe(R"code(Broadcast the first input to match the shape of the second input.
+ .describe(R"code(Broadcast the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
-.set_support_level(10)
-.add_type_rel("BroadCastToLike", BroadCastToLikeRel)
-.set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
-.set_attr<TOpPattern>("TOpPattern", kBroadcast);
-
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
+ .set_support_level(10)
+ .add_type_rel("BroadCastToLike", BroadCastToLikeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// Adapter function to make int array.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
for (size_t i = 0; i < arr.size(); ++i) {
- CHECK(!arr[i].defined() || arr[i].as<IntImmNode>())
- << "Expect an int array";
+ CHECK(!arr[i].defined() || arr[i].as<IntImmNode>()) << "Expect an int array";
}
- return Downcast<Array<Integer> >(arr);
+ return Downcast<Array<Integer>>(arr);
}
-
// strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
-bool StridedSliceRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
- const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
+ const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
CHECK(param != nullptr);
auto dshape = data->shape;
int64_t begin_v = begin_vec[i];
int64_t end_v = end_vec[i];
- if ((stride_v == 1 &&
- begin_v == 0 &&
- end_v == max_range) ||
- (stride_v == -1 &&
- begin_v == max_range &&
- end_v == 0)) {
+ if ((stride_v == 1 && begin_v == 0 && end_v == max_range) ||
+ (stride_v == -1 && begin_v == max_range && end_v == 0)) {
// Quick path, do not slice this dimension.
oshape[i] = dshape[i];
continue;
// Require concrete integer as symbolic inference of min/max
// can get complicated and not very helpful.
const int64_t* p_dim_size = tir::as_const_int(dshape[i]);
- CHECK(p_dim_size)
- << "strided_slice requires sliced dimension to be concrete int";
+ CHECK(p_dim_size) << "strided_slice requires sliced dimension to be concrete int";
int64_t dim_size = p_dim_size[0];
begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
end_v = (end_v < 0) ? dim_size + end_v : end_v;
int64_t slice_range, step;
if (stride_v < 0) {
if (end_v < -1) end_v = -1;
- CHECK_LT(end_v, begin_v)
- << "strided_slice get empty slice at axis " << i;
+ CHECK_LT(end_v, begin_v) << "strided_slice get empty slice at axis " << i;
begin_v = std::min(dim_size - 1, begin_v);
slice_range = begin_v - end_v;
step = -stride_v;
} else {
if (begin_v < 0) begin_v = 0;
CHECK_GE(stride_v, 0);
- CHECK_LT(begin_v, end_v)
- << "strided_slice get empty slice at axis " << i;
+ CHECK_LT(begin_v, end_v) << "strided_slice get empty slice at axis " << i;
end_v = std::min(dim_size, end_v);
slice_range = end_v - begin_v;
step = stride_v;
return true;
}
-
-Array<Array<Layout> > StridedSliceInferCorrectLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type>& old_in_types) {
-
+Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
CHECK(old_in_t.as<TensorTypeNode>());
auto shape = old_in_shapes[0];
// NOTE: Discard "const" qualifier here.
- auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+ auto* params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
Array<Integer> new_begin, new_end;
}
}
int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
- int64_t end = params->end[i].defined() ? params->end[i]->value :
- shape[i].as<IntImmNode>()->value;
+ int64_t end =
+ params->end[i].defined() ? params->end[i]->value : shape[i].as<IntImmNode>()->value;
if (begin % factor || end % factor) {
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
return {{layout}, {layout}};
}
-
// Positional relay function to create StridedSlice operator used by frontend FFI.
-Expr MakeStridedSlice(Expr data,
- Array<Integer> begin,
- Array<Integer> end,
- Array<Integer> strides) {
+Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides) {
auto attrs = make_object<StridedSliceAttrs>();
attrs->begin = std::move(begin);
attrs->end = std::move(end);
return Call(op, {data}, Attrs(attrs), {});
}
-Array<te::Tensor> StridedSliceCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
+ const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
CHECK(param != nullptr);
return Array<te::Tensor>{
- topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
- };
+ topi::strided_slice(inputs[0], param->begin, param->end, param->strides)};
}
-
-TVM_REGISTER_GLOBAL("relay.op._make.strided_slice")
-.set_body_typed(MakeStridedSlice);
-
+TVM_REGISTER_GLOBAL("relay.op._make.strided_slice").set_body_typed(MakeStridedSlice);
RELAY_REGISTER_OP("strided_slice")
.describe(R"code(Strided slice of an array.
[[ 5., 6.],
[ 7., 8.]]]
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(4)
-.set_attrs_type<StridedSliceAttrs>()
-.add_type_rel("StridedSlice", StridedSliceRel)
-.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(4)
+ .set_attrs_type<StridedSliceAttrs>()
+ .add_type_rel("StridedSlice", StridedSliceRel)
+ .set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
// strided_set
-bool StridedSetRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool StridedSetRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 6);
reporter->Assign(types[5], types[0]);
return true;
}
-Expr MakeStridedSet(Expr data,
- Expr v,
- Expr begin,
- Expr end,
- Expr strides) {
+Expr MakeStridedSet(Expr data, Expr v, Expr begin, Expr end, Expr strides) {
static const Op& op = Op::Get("strided_set");
return Call(op, {data, v, begin, end, strides}, {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.strided_set")
-.set_body_typed(MakeStridedSet);
-
+TVM_REGISTER_GLOBAL("relay.op._make.strided_set").set_body_typed(MakeStridedSet);
RELAY_REGISTER_OP("strided_set")
- .describe(R"code(Strided set of an array.
+ .describe(R"code(Strided set of an array.
Example::
x = [[ 1., 4., 7., 10.],
[ 2., 44., 55., 66.],
[ 3., 6., 9., 12.]]
)code" TVM_ADD_FILELINE)
-.set_num_inputs(5)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("v", "Tensor", "The data to set.")
-.add_argument("begin", "Tensor", "Indices for the start of the slice.")
-.add_argument("end", "Tensor", "Indices indicating the end of the slice.")
-.add_argument("strides", "Tensor", "The strides values.")
-.set_support_level(4)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
-.add_type_rel("StridedSet", StridedSetRel);
+ .set_num_inputs(5)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("v", "Tensor", "The data to set.")
+ .add_argument("begin", "Tensor", "Indices for the start of the slice.")
+ .add_argument("end", "Tensor", "Indices indicating the end of the slice.")
+ .add_argument("strides", "Tensor", "The strides values.")
+ .set_support_level(4)
+ .set_attr<TOpPattern>("TOpPattern", kInjective)
+ .add_type_rel("StridedSet", StridedSetRel);
// relay.split
TVM_REGISTER_NODE_TYPE(SplitAttrs);
-bool SplitRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool SplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
if (axis < 0) {
axis += data->shape.size();
}
- CHECK_LT(axis, data->shape.size())
- << "axis should be within the input dimension range.";
- CHECK_GE(axis, 0)
- << "axis should be within the input dimension range.";
+ CHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range.";
+ CHECK_GE(axis, 0) << "axis should be within the input dimension range.";
if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
- CHECK(reporter->Assert(indexmod(data->shape[axis],
- sections->value) == tir::make_zero(DataType::Int(64))))
+ CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) ==
+ tir::make_zero(DataType::Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
std::vector<Type> fields;
for (int i = 0; i < sections->value; ++i) {
- std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
- oshape[axis] = indexdiv(oshape[axis], sections->value);
- auto vec_type = TensorType(oshape, data->dtype);
- fields.push_back(vec_type);
+ std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
+ oshape[axis] = indexdiv(oshape[axis], sections->value);
+ auto vec_type = TensorType(oshape, data->dtype);
+ fields.push_back(vec_type);
}
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
} else {
return true;
}
-Array<te::Tensor> SplitCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> SplitCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto param = attrs.as<SplitAttrs>();
CHECK(param != nullptr);
if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
int64_t num_sections = sections->value;
- return Array<te::Tensor>{
- topi::split_sections(inputs[0], num_sections, param->axis) };
+ return Array<te::Tensor>{topi::split_sections(inputs[0], num_sections, param->axis)};
} else {
- auto indices = Downcast<Array<Integer> >(param->indices_or_sections);
- return Array<te::Tensor>{ topi::split(inputs[0], indices, param->axis) };
+ auto indices = Downcast<Array<Integer>>(param->indices_or_sections);
+ return Array<te::Tensor>{topi::split(inputs[0], indices, param->axis)};
}
}
-Expr MakeSplit(Expr data,
- ObjectRef indices_or_sections,
- int axis) {
+Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) {
auto attrs = make_object<SplitAttrs>();
attrs->axis = axis;
attrs->indices_or_sections = std::move(indices_or_sections);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.split")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- if (args.type_codes[1] == kDLInt) {
- // Note: we change it from Int(64) to Int(32) for now as
- // combine_parallel_dense will transform the graph with Int(32).
- // More invetigation is needs to check which one we should use.
- *rv = MakeSplit(args[0],
- tir::make_const(DataType::Int(32), static_cast<int>(args[1])),
- args[2]);
- } else {
- *rv = MakeSplit(args[0], args[1], args[2]);
- }
+TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ if (args.type_codes[1] == kDLInt) {
+ // Note: we change it from Int(64) to Int(32) for now as
+ // combine_parallel_dense will transform the graph with Int(32).
+ // More invetigation is needs to check which one we should use.
+ *rv =
+ MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast<int>(args[1])), args[2]);
+ } else {
+ *rv = MakeSplit(args[0], args[1], args[2]);
+ }
});
RELAY_REGISTER_OP("split")
-.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
+ .describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
Indices or sections to split into. Accepts an int or a tuple
If indices_or_sections is an integer, the input will be divided equally
the entries indicate where along axis the array is split.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<SplitAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Split", SplitRel)
-.set_attr<FTVMCompute>("FTVMCompute", SplitCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
-
+ .set_attrs_type<SplitAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Split", SplitRel)
+ .set_attr<FTVMCompute>("FTVMCompute", SplitCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.slice_like
TVM_REGISTER_NODE_TYPE(SliceLikeAttrs);
/*!
-* \brief SliceLikeRel User defined type constraint function.
-* \param num_inputs Number of input types in the args.
-* \param attrs The additional attributes of the operator.
-* \param reporter The reporter to report solution to.
-* \return False if the relation has not been resolved, it might be resolved later.
-* True if this relation has been resolved.
-*/
-bool SliceLikeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+ * \brief SliceLikeRel User defined type constraint function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return False if the relation has not been resolved, it might be resolved later.
+ * True if this relation has been resolved.
+ */
+bool SliceLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (i < target_shape.size()) {
oshape[i] = target_shape[i];
CHECK(reporter->Assert(oshape[i] <= dshape[i]))
- << "End index of axis " << i << " exceeds input shape: "
- << oshape[i] << " vs " << dshape[i];
+ << "End index of axis " << i << " exceeds input shape: " << oshape[i] << " vs "
+ << dshape[i];
}
}
} else {
axis += dshape.size();
}
CHECK(axis < static_cast<int>(target_shape.size()))
- << "Axis " << axis << " exceeds dimension "
- << target_shape.size() << " of target_shape.";
+ << "Axis " << axis << " exceeds dimension " << target_shape.size() << " of target_shape.";
oshape[axis] = target_shape[axis];
CHECK(reporter->Assert(oshape[axis] <= dshape[axis]))
- << "End index of axis " << axis << " exceeds input shape: "
- << oshape[axis] << " vs " << dshape[axis];
+ << "End index of axis " << axis << " exceeds input shape: " << oshape[axis] << " vs "
+ << dshape[axis];
}
}
return true;
}
-
-Expr MakeSliceLike(Expr data,
- Expr shape_like,
- Array<Integer> axes) {
+Expr MakeSliceLike(Expr data, Expr shape_like, Array<Integer> axes) {
auto attrs = make_object<SliceLikeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("slice_like");
return Call(op, {data, shape_like}, Attrs(attrs), {});
}
-Array<te::Tensor> SliceLikeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<SliceLikeAttrs>();
CHECK(param != nullptr);
for (size_t i = 0; i < src_shape.size(); ++i) {
if (i < target_shape.size()) {
end_idx.Set(i, target_shape[i]);
- CHECK_LE(topi::GetConstInt(end_idx[i]),
- topi::GetConstInt(src_shape[i]))
- << "End index of axis " << i << " exceeds input shape: "
- << topi::GetConstInt(end_idx[i]) << " vs "
- << topi::GetConstInt(src_shape[i]);
+ CHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i]))
+ << "End index of axis " << i
+ << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs "
+ << topi::GetConstInt(src_shape[i]);
}
}
} else {
axis = static_cast<int>(src_shape.size()) + axis;
}
end_idx.Set(axis, target_shape[axis]);
- CHECK_LE(topi::GetConstInt(end_idx[axis]),
- topi::GetConstInt(src_shape[axis]))
- << "End index of axis " << axis << " exceeds input shape: "
- << topi::GetConstInt(end_idx[axis]) << " vs "
- << topi::GetConstInt(src_shape[axis]);
+ CHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis]))
+ << "End index of axis " << axis
+ << " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs "
+ << topi::GetConstInt(src_shape[axis]);
}
}
- return Array<te::Tensor>{
- topi::strided_slice(inputs[0],
- GetIntArray(begin_idx),
- GetIntArray(end_idx),
- GetIntArray(strides))
- };
+ return Array<te::Tensor>{topi::strided_slice(inputs[0], GetIntArray(begin_idx),
+ GetIntArray(end_idx), GetIntArray(strides))};
}
-
-TVM_REGISTER_GLOBAL("relay.op._make.slice_like")
-.set_body_typed(MakeSliceLike);
-
+TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike);
RELAY_REGISTER_OP("slice_like")
-.describe(R"code(Slice the first input respect to the second input.
+ .describe(R"code(Slice the first input respect to the second input.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<SliceLikeAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("shape_like", "Tensor", "Shape tensor.")
-.set_support_level(10)
-.add_type_rel("SliceLike", SliceLikeRel)
-.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_attrs_type<SliceLikeAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("shape_like", "Tensor", "Shape tensor.")
+ .set_support_level(10)
+ .add_type_rel("SliceLike", SliceLikeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.layout_transform
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
-Array<te::Tensor> LayoutTransformCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> LayoutTransformCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<LayoutTransformAttrs>();
CHECK(param != nullptr);
- return Array<te::Tensor>{
- topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
- };
+ return Array<te::Tensor>{topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)};
}
-bool LayoutTransformRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool LayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "LayoutTransform: expect input data type to be TensorType but get "
- << types[0];
+ << "LayoutTransform: expect input data type to be TensorType but get " << types[0];
return false;
}
const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>();
Layout src_layout(params->src_layout);
Layout dst_layout(params->dst_layout);
- CHECK(src_layout.defined() && dst_layout.defined())
- << "cannot convert from/to undefined layout";
+ CHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout";
auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout);
CHECK(layout_converter.defined())
- << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
+ << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
const auto& out_shape = layout_converter.ForwardShape(data->shape);
reporter->Assign(types[1], TensorType(out_shape, data->dtype));
return true;
}
-Expr MakeLayoutTransform(Expr data,
- std::string src_layout,
- std::string dst_layout) {
+Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout) {
auto attrs = make_object<LayoutTransformAttrs>();
attrs->src_layout = std::move(src_layout);
attrs->dst_layout = std::move(dst_layout);
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.layout_transform")
-.set_body_typed(MakeLayoutTransform);
+TVM_REGISTER_GLOBAL("relay.op._make.layout_transform").set_body_typed(MakeLayoutTransform);
RELAY_REGISTER_OP("layout_transform")
-.describe(R"code(Transform the input data layout.
+ .describe(R"code(Transform the input data layout.
For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes
the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<LayoutTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_type_rel("layout_transform", LayoutTransformRel)
-.set_support_level(5)
-.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
-
+ .set_attrs_type<LayoutTransformAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_type_rel("layout_transform", LayoutTransformRel)
+ .set_support_level(5)
+ .set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
/* relay._contrib_reverse_reshape */
-Expr MakeReverseReshape(Expr data,
- Array<Integer> newshape) {
+Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = true;
return Call(op, {data}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape")
-.set_body_typed(MakeReverseReshape);
+TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape);
RELAY_REGISTER_OP("_contrib_reverse_reshape")
-.describe(R"code(Reshapes the input array where the special values are inferred from
+ .describe(R"code(Reshapes the input array where the special values are inferred from
right to left.
Example::
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5)
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<ReshapeAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(10)
-.add_type_rel("Reshape", ReshapeRel)
-.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_num_inputs(1)
+ .set_attrs_type<ReshapeAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(10)
+ .add_type_rel("Reshape", ReshapeRel)
+ .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// gather_nd operator
-bool GatherNDRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* indices = types[1].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "GatherND: expect input data type to be TensorType but get "
- << types[0];
+ << "GatherND: expect input data type to be TensorType but get " << types[0];
return false;
}
if (indices == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
- << "GatherND: expect indices type to be TensorType but get "
- << types[1];
+ << "GatherND: expect indices type to be TensorType but get " << types[1];
return false;
}
const size_t ndim = data->shape.size();
const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
const size_t kdim = indices->shape.size() - 1;
- CHECK(size_t(mdim->value) <= ndim)
- << "GatherND: indices shape does satisfy.";
+ CHECK(size_t(mdim->value) <= ndim) << "GatherND: indices shape does satisfy.";
Array<IndexExpr> oshape;
- for (size_t i = 1; i < kdim + 1; ++i)
- oshape.push_back(indices->shape[i]);
- for (size_t i = mdim->value; i < ndim; ++i)
- oshape.push_back(data->shape[i]);
+ for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
+ for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
-Array<te::Tensor> GatherNDCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
- return { topi::gather_nd(inputs[0], inputs[1]) };
+ return {topi::gather_nd(inputs[0], inputs[1])};
}
-Expr MakeGatherND(Expr data,
- Expr indices) {
+Expr MakeGatherND(Expr data, Expr indices) {
static const Op& op = Op::Get("gather_nd");
return Call(op, {data, indices}, {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.gather_nd")
-.set_body_typed(MakeGatherND);
+TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND);
RELAY_REGISTER_OP("gather_nd")
-.describe(R"code(Gather elements or slices from data and store to
+ .describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.
Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("GatherND", GatherNDRel)
-.set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("GatherND", GatherNDRel)
+ .set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.sequence_mask
TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs);
-bool SequenceMaskRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool SequenceMaskRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, valid_length, result]
CHECK_EQ(types.size(), 3);
return true;
}
-Array<te::Tensor> SequenceMaskCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> SequenceMaskCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<SequenceMaskAttrs>();
CHECK(param != nullptr);
return Array<te::Tensor>{
- topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
+ topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis)};
}
-Expr MakeSequenceMask(Expr data,
- Expr valid_length,
- double mask_value,
- int axis) {
+Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) {
auto attrs = make_object<SequenceMaskAttrs>();
attrs->mask_value = std::move(mask_value);
attrs->axis = std::move(axis);
return Call(op, {data, valid_length}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask")
-.set_body_typed(MakeSequenceMask);
+TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask").set_body_typed(MakeSequenceMask);
RELAY_REGISTER_OP("sequence_mask")
-.describe(R"code(Sets all elements outside the expected length of the sequence to a constant value.
+ .describe(
+ R"code(Sets all elements outside the expected length of the sequence to a constant value.
This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
[batch_size, MAX_LENGTH, ...] and returns an array of the same shape.
[[ 0.1, 0.1, 0.1],
[ 16., 17., 18.]]]
)code" TVM_ADD_FILELINE)
-.set_attrs_type<SequenceMaskAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.")
-.set_support_level(10)
-.add_type_rel("SequenceMask", SequenceMaskRel)
-.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_attrs_type<SequenceMaskAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.")
+ .set_support_level(10)
+ .add_type_rel("SequenceMask", SequenceMaskRel)
+ .set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.one_hot
TVM_REGISTER_NODE_TYPE(OneHotAttrs);
-bool OneHotRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool OneHotRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [indices, on_value, off_value, result]
CHECK_EQ(types.size(), 4);
return true;
}
-Array<te::Tensor> OneHotCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> OneHotCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<OneHotAttrs>();
CHECK(param != nullptr);
- return Array<te::Tensor> {
- topi::one_hot(inputs[0],
- inputs[1](),
- inputs[2](),
- param->depth,
- param->axis,
- param->dtype)
- };
-}
-
-Expr MakeOneHot(Expr indices,
- Expr on_value,
- Expr off_value,
- int depth,
- int axis,
- DataType dtype) {
+ return Array<te::Tensor>{
+ topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype)};
+}
+
+Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype) {
auto attrs = make_object<OneHotAttrs>();
attrs->depth = std::move(depth);
attrs->axis = axis;
return Call(op, {indices, on_value, off_value}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.one_hot")
-.set_body_typed(MakeOneHot);
+TVM_REGISTER_GLOBAL("relay.op._make.one_hot").set_body_typed(MakeOneHot);
RELAY_REGISTER_OP("one_hot")
-.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1,
+ .describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1,
other locations take value 0. Final dimension is <indices dimensions> x depth.
**indices** Locations to set to 1.
**axis** Axis to fill.
**dtype**)code" TVM_ADD_FILELINE)
-.set_attrs_type<OneHotAttrs>()
-.set_num_inputs(3)
-.add_argument("indices", "Tensor", "Locations to set to on_value.")
-.add_argument("on_value", "Expr", "Value to fill at indices.")
-.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
-.set_support_level(10)
-.add_type_rel("OneHot", OneHotRel)
-.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
-.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
+ .set_attrs_type<OneHotAttrs>()
+ .set_num_inputs(3)
+ .add_argument("indices", "Tensor", "Locations to set to on_value.")
+ .add_argument("on_value", "Expr", "Value to fill at indices.")
+ .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
+ .set_support_level(10)
+ .add_type_rel("OneHot", OneHotRel)
+ .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
+ .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
/* relay.unravel_index */
-bool UnRavelIndexRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool UnRavelIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* indices = types[0].as<TensorTypeNode>();
if (indices == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
- << "unravel_index: expect input type to be TensorType but get "
- << types[0];
+ << "unravel_index: expect input type to be TensorType but get " << types[0];
return false;
}
- CHECK(indices->dtype.is_int())
- << "indices of unravel_index must be tensor of integer";
+ CHECK(indices->dtype.is_int()) << "indices of unravel_index must be tensor of integer";
const auto* shape = types[1].as<TensorTypeNode>();
if (shape == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
- << "unravel_index: expect input type to be TensorType but get "
- << types[1];
+ << "unravel_index: expect input type to be TensorType but get " << types[1];
return false;
}
- CHECK(indices->dtype.is_int())
- << "shape of unravel_index must be tensor of integer";
+ CHECK(indices->dtype.is_int()) << "shape of unravel_index must be tensor of integer";
Array<IndexExpr> indices_shape;
Array<IndexExpr> shape_shape;
return true;
}
-Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])};
}
-Expr MakeUnRavelIndex(Expr data,
- Expr shape) {
+Expr MakeUnRavelIndex(Expr data, Expr shape) {
static const Op& op = Op::Get("unravel_index");
return Call(op, {data, shape}, Attrs(), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.unravel_index")
-.set_body_typed(MakeUnRavelIndex);
+TVM_REGISTER_GLOBAL("relay.op._make.unravel_index").set_body_typed(MakeUnRavelIndex);
RELAY_REGISTER_OP("unravel_index")
-.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
+ .describe(
+ R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Example::
- unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
)code" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.set_support_level(3)
-.add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
-.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+ .set_num_inputs(2)
+ .set_support_level(3)
+ .add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
+ .set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
#include <tvm/ir/error.h>
#include <tvm/relay/attrs/transform.h>
-#include <vector>
+#include <tvm/relay/op_attr_types.h>
+
#include <algorithm>
#include <limits>
#include <string>
#include <unordered_set>
#include <utility>
+#include <vector>
namespace tvm {
namespace relay {
template <typename AttrType>
-bool ConcatenateRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
/* If we receive a tuple we can continue, if we receive
* anything but an incomplete type we should signal an
* error.
- */
+ */
const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) {
throw Error(
- ErrorBuilder()
- << "concatenate requires a tuple of tensors as the first argument, found "
- << PrettyPrint(types[0]));
+ ErrorBuilder() << "concatenate requires a tuple of tensors as the first argument, found "
+ << PrettyPrint(types[0]));
} else if (types[0].as<IncompleteTypeNode>() != nullptr) {
return false;
}
// Sanity check: axis
int axis = param->axis;
if (!(-ndim <= axis && axis < ndim)) {
- throw Error(ErrorBuilder() <<
- "concatenate only accepts `axis` in [-ndim, ndim)" <<
- ", but got axis = " << axis <<
- ", and ndim = " << ndim);
+ throw Error(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)"
+ << ", but got axis = " << axis << ", and ndim = " << ndim);
}
axis = axis < 0 ? ndim + axis : axis;
for (size_t j = 0; j < first->shape.size(); ++j) {
if (j == static_cast<size_t>(axis)) continue;
if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
- throw Error("relay.concatenate requires all tensors have the same shape "
- "on non-concatenating axes");
+ throw Error(
+ "relay.concatenate requires all tensors have the same shape "
+ "on non-concatenating axes");
}
}
// Calculate shape
std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
- IndexExpr &concat_dim = oshape[axis];
+ IndexExpr& concat_dim = oshape[axis];
bool has_any = false;
if (concat_dim.as<Any>()) {
has_any = true;
return true;
}
-static inline Array<Array<Layout>> ConcatenateLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
Array<Array<IndexExpr>> old_in_shapes;
}
}
- size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
- static_cast<size_t>(param->axis);
+ size_t axis =
+ param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
Layout ret;
bool is_new_layout_selected = false;
}
if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
- return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
+ return Array<Array<Layout>>{{Layout::Undef()}, {Layout::Undef()}};
}
}
- return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
+ return Array<Array<Layout>>{Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
} // namespace relay
* \file unary.cc
* \brief Unary operators.
*/
-#include <tvm/relay/expr.h>
-#include <tvm/relay/op.h>
-#include <tvm/relay/attrs/transform.h>
#include <topi/elemwise.h>
#include <topi/transform.h>
-#include "../type_relations.h"
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+
#include "../op_common.h"
+#include "../type_relations.h"
namespace tvm {
namespace relay {
-#define RELAY_UNARY_COMPUTE(FTOPI) \
- [] (const Attrs& attrs, \
- const Array<te::Tensor>& inputs, \
- const Type& out_type) -> Array<te::Tensor> { \
- return {FTOPI(inputs[0])}; \
- } \
-
+#define RELAY_UNARY_COMPUTE(FTOPI) \
+ [](const Attrs& attrs, const Array<te::Tensor>& inputs, \
+ const Type& out_type) -> Array<te::Tensor> { return {FTOPI(inputs[0])}; }
RELAY_REGISTER_UNARY_OP("log")
-.describe(R"code(Returns the log input array, computed element-wise.
+ .describe(R"code(Returns the log input array, computed element-wise.
.. math::
log(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
RELAY_REGISTER_UNARY_OP("log2")
-.describe(R"code(Returns the log to base 2 of input array, computed element-wise.
+ .describe(R"code(Returns the log to base 2 of input array, computed element-wise.
.. math::
log2(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));
RELAY_REGISTER_UNARY_OP("log10")
-.describe(R"code(Returns the log to base 10 of input array, computed element-wise.
+ .describe(R"code(Returns the log to base 10 of input array, computed element-wise.
.. math::
log10(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));
RELAY_REGISTER_UNARY_OP("tan")
-.describe(R"code(Returns the tan of input array, computed element-wise.
+ .describe(R"code(Returns the tan of input array, computed element-wise.
.. math::
Y = tan(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
RELAY_REGISTER_UNARY_OP("cos")
-.describe(R"code(Returns the cos of input array, computed element-wise.
+ .describe(R"code(Returns the cos of input array, computed element-wise.
.. math::
Y = cos(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));
RELAY_REGISTER_UNARY_OP("cosh")
-.describe(R"code(Returns the cosh of input array, computed element-wise.
+ .describe(R"code(Returns the cosh of input array, computed element-wise.
.. math::
Y = cosh(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh));
RELAY_REGISTER_UNARY_OP("sin")
-.describe(R"code(Returns the sin of input array, computed element-wise.
+ .describe(R"code(Returns the sin of input array, computed element-wise.
.. math::
Y = sin(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));
RELAY_REGISTER_UNARY_OP("sinh")
-.describe(R"code(Returns the sinh of input array, computed element-wise.
+ .describe(R"code(Returns the sinh of input array, computed element-wise.
.. math::
Y = sinh(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh));
RELAY_REGISTER_UNARY_OP("acos")
.describe(R"code(Returns the acos of input array, computed element-wise.
RELAY_REGISTER_UNARY_OP("atan")
-.describe(R"code(Returns the atan of input array, computed element-wise.
+ .describe(R"code(Returns the atan of input array, computed element-wise.
.. math::
Y = atan(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan));
RELAY_REGISTER_UNARY_OP("atanh")
.describe(R"code(Returns the atanh of input array, computed element-wise.
RELAY_REGISTER_UNARY_OP("exp")
-.describe(R"code(Returns the exp input array, computed element-wise.
+ .describe(R"code(Returns the exp input array, computed element-wise.
.. math::
\exp(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("fast_exp")
-.describe(R"code(Returns the fast_exp input array, computed element-wise.
+ .describe(R"code(Returns the fast_exp input array, computed element-wise.
.. math::
\fast_exp(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
RELAY_REGISTER_UNARY_OP("erf")
-.describe(R"code(Returns the error function value for input array, computed element-wise.
+ .describe(R"code(Returns the error function value for input array, computed element-wise.
.. math::
\erf(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));
RELAY_REGISTER_UNARY_OP("fast_erf")
-.describe(R"code(Returns the error function value for input array, computed element-wise.
+ .describe(R"code(Returns the error function value for input array, computed element-wise.
.. math::
\fast_erf(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf));
RELAY_REGISTER_UNARY_OP("sqrt")
-.describe(R"code(Returns the sqrt input array, computed element-wise.
+ .describe(R"code(Returns the sqrt input array, computed element-wise.
.. math::
sqrt(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));
RELAY_REGISTER_UNARY_OP("rsqrt")
-.describe(R"code(Returns the rsqrt input array, computed element-wise.
+ .describe(R"code(Returns the rsqrt input array, computed element-wise.
.. math::
1/sqrt(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt));
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt));
RELAY_REGISTER_UNARY_OP("zeros_like")
-.describe(R"code(Returns an array of zeros, with same type and shape as the input.
+ .describe(R"code(Returns an array of zeros, with same type and shape as the input.
)code" TVM_ADD_FILELINE)
-.set_support_level(4);
+ .set_support_level(4);
RELAY_REGISTER_UNARY_OP("ones_like")
-.describe(R"code(Returns an array of ones, with same type and shape as the input.
+ .describe(R"code(Returns an array of ones, with same type and shape as the input.
)code" TVM_ADD_FILELINE)
-.set_support_level(4);
+ .set_support_level(4);
RELAY_REGISTER_UNARY_OP("sigmoid")
-.describe(R"code(Returns the sigmoid input array, computed element-wise.
+ .describe(R"code(Returns the sigmoid input array, computed element-wise.
.. math::
sigmoid(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid));
RELAY_REGISTER_UNARY_OP("copy")
-.describe(R"code(Copy a tensor.
+ .describe(R"code(Copy a tensor.
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));
// relay.clip
TVM_REGISTER_NODE_TYPE(ClipAttrs);
-TVM_REGISTER_GLOBAL("relay.op._make.clip")
-.set_body_typed([](Expr a, double a_min, double a_max) {
- auto attrs = make_object<ClipAttrs>();
- attrs->a_min = a_min;
- attrs->a_max = a_max;
- static const Op& op = Op::Get("clip");
+TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed([](Expr a, double a_min, double a_max) {
+ auto attrs = make_object<ClipAttrs>();
+ attrs->a_min = a_min;
+ attrs->a_max = a_max;
+ static const Op& op = Op::Get("clip");
return Call(op, {a}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("clip")
-.describe(R"code(Clip tensor values.
+ .describe(R"code(Clip tensor values.
This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_type_rel("Identity", IdentityRel)
-.set_attr<TOpPattern>("TOpPattern", kElemWise)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
-.set_attrs_type<ClipAttrs>()
-.set_support_level(3);
-
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attrs_type<ClipAttrs>()
+ .set_support_level(3);
RELAY_REGISTER_UNARY_OP("floor")
-.describe(R"code(Returns the floor of input array, computed element-wise.
+ .describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor));
-
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor));
RELAY_REGISTER_UNARY_OP("ceil")
-.describe(R"code(Returns the ceil of input array, computed element-wise.
+ .describe(R"code(Returns the ceil of input array, computed element-wise.
.. math::
ceil(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil));
-
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil));
RELAY_REGISTER_UNARY_OP("trunc")
-.describe(R"code(Returns the trunc of input array, computed element-wise.
+ .describe(R"code(Returns the trunc of input array, computed element-wise.
.. math::
trunc(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc));
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc));
RELAY_REGISTER_UNARY_OP("round")
-.describe(R"code(Returns the round of input array, computed element-wise.
+ .describe(R"code(Returns the round of input array, computed element-wise.
.. math::
round(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));
RELAY_REGISTER_UNARY_OP("sign")
-.describe(R"code(Returns the sign of input array, computed element-wise.
+ .describe(R"code(Returns the sign of input array, computed element-wise.
.. numpy::
sign(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign));
-
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign));
RELAY_REGISTER_UNARY_OP("abs")
-.describe(R"code(Returns the abs of input array, computed element-wise.
+ .describe(R"code(Returns the abs of input array, computed element-wise.
.. math::
abs(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs));
-
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs));
RELAY_REGISTER_UNARY_OP("tanh")
-.describe(R"code(Returns the tanh of input array, computed element-wise.
+ .describe(R"code(Returns the tanh of input array, computed element-wise.
.. math::
Y = sinh(X) / cosh(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
RELAY_REGISTER_UNARY_OP("fast_tanh")
-.describe(R"code(Returns the fast_tanh of input array, computed element-wise.
+ .describe(R"code(Returns the fast_tanh of input array, computed element-wise.
.. math::
Y = sinh(X) / cosh(X)
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
-
+ .set_support_level(1)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
RELAY_REGISTER_UNARY_OP("negative")
-.describe(R"code(Returns the numeric negative of input array, computed element-wise.
+ .describe(R"code(Returns the numeric negative of input array, computed element-wise.
.. math::
-(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));
-
+ .set_support_level(3)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));
RELAY_REGISTER_UNARY_OP("logical_not")
-.describe(R"code(Returns the logical inverse of input array, computed element-wise.
+ .describe(R"code(Returns the logical inverse of input array, computed element-wise.
.. math::
!(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
-
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
RELAY_REGISTER_UNARY_OP("bitwise_not")
-.describe(R"code(Returns the bitwise inverse of input array, computed element-wise.
+ .describe(R"code(Returns the bitwise inverse of input array, computed element-wise.
.. math::
~(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not));
-
+ .set_support_level(4)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not));
// shape_of
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
-bool ShapeOfRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
return true;
}
-Array<te::Tensor> ShapeOfCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 1);
const auto* param = attrs.as<ShapeOfAttrs>();
return {topi::shape(inputs[0], param->dtype)};
}
-TVM_REGISTER_GLOBAL("relay.op._make.shape_of")
-.set_body_typed([](Expr data, DataType dtype) {
+TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("shape_of");
});
RELAY_REGISTER_OP("shape_of")
-.describe(R"code(Returns a tensor representing the shape of a tensor.
+ .describe(R"code(Returns a tensor representing the shape of a tensor.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<ShapeOfAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.add_type_rel("ShapeOf", ShapeOfRel)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-// Use kOpaque for shape_of op for now since it won't be performance critic,
-// and it makes things easier for dynamic shape func
-.set_attr<TOpPattern>("TOpPattern", kOpaque)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
- ElemwiseArbitraryLayout)
-.set_support_level(10)
-.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);
-
+ .set_num_inputs(1)
+ .set_attrs_type<ShapeOfAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_type_rel("ShapeOf", ShapeOfRel)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+  // Use kOpaque for shape_of op for now since it won't be performance critical,
+ // and it makes things easier for dynamic shape func
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_support_level(10)
+ .set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);
TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs);
-bool NdarraySizeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter) {
+bool NdarraySizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt != nullptr);
return true;
}
-Array<te::Tensor> NdarraySizeCompute(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
+Array<te::Tensor> NdarraySizeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 1);
const auto* param = attrs.as<NdarraySizeAttrs>();
return Array<te::Tensor>{topi::ndarray_size(inputs[0], param->dtype)};
}
-TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size")
-.set_body_typed([](Expr data, DataType dtype) {
+TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size").set_body_typed([](Expr data, DataType dtype) {
auto attrs = make_object<NdarraySizeAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("ndarray_size");
});
RELAY_REGISTER_OP("ndarray_size")
-.describe(R"code(Returns a tensor representing the number of elements of input tensor.
+ .describe(R"code(Returns a tensor representing the number of elements of input tensor.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type<NdarraySizeAttrs>()
-.add_argument("data", "Tensor", "The input tensor.")
-.add_type_rel("NdarraySize", NdarraySizeRel)
-.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-ElemwiseArbitraryLayout)
-.set_support_level(10)
-.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
+ .set_num_inputs(1)
+ .set_attrs_type<NdarraySizeAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_type_rel("NdarraySize", NdarraySizeRel)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<TOpPattern>("TOpPattern", kInjective)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_support_level(10)
+ .set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
RELAY_REGISTER_UNARY_OP("isnan")
-.describe(R"code(Returns whether the input contains any NaN, computed element-wise.
+ .describe(R"code(Returns whether the input contains any NaN, computed element-wise.
.. math::
isnan(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.add_type_rel("IdentityCompRel", IdentityCompRel)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan));
+ .set_support_level(3)
+ .add_type_rel("IdentityCompRel", IdentityCompRel)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan));
RELAY_REGISTER_UNARY_OP("isfinite")
-.describe(R"code(Returns the finiteness of input, computed element-wise.
+ .describe(R"code(Returns the finiteness of input, computed element-wise.
.. math::
isfinite(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.add_type_rel("IdentityCompRel", IdentityCompRel)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite));
+ .set_support_level(3)
+ .add_type_rel("IdentityCompRel", IdentityCompRel)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite));
RELAY_REGISTER_UNARY_OP("isinf")
-.describe(R"code(Returns the infiniteness of input, computed element-wise.
+ .describe(R"code(Returns the infiniteness of input, computed element-wise.
.. math::
isinf(x)
)code" TVM_ADD_FILELINE)
-.set_support_level(3)
-.add_type_rel("IdentityCompRel", IdentityCompRel)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf));
+ .set_support_level(3)
+ .add_type_rel("IdentityCompRel", IdentityCompRel)
+ .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf));
} // namespace relay
} // namespace tvm
* \brief A set of utilities and common functionality
* for type relations.
*/
+#include "./type_relations.h"
+
#include <tvm/arith/analyzer.h>
-#include <tvm/tir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
+#include <tvm/tir/op.h>
+
#include <numeric>
-#include "./type_relations.h"
namespace tvm {
namespace relay {
-bool IdentityRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
for (size_t i = 1; i < types.size(); ++i) {
reporter->Assign(types[i], types[0]);
return true;
}
-bool EqualCheck(const IndexExpr& lhs,
- const IndexExpr& rhs) {
+bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) {
IndexExpr diff = lhs - rhs;
if (const int64_t* pdiff = tir::as_const_int(diff)) {
return pdiff[0] == 0;
return false;
}
-Type ConcreteBroadcast(const TensorType& t1,
- const TensorType& t2,
- DataType output_dtype) {
+Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
std::vector<IndexExpr> oshape;
size_t ndim1 = t1->shape.size();
size_t ndim2 = t2->shape.size();
} else if (EqualCheck(s1, s2)) {
oshape.push_back(s1);
} else {
- throw Error(ErrorBuilder()
- << "Incompatible broadcast type "
- << t1 << " and " << t2);
+ throw Error(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2);
}
}
for (; i <= max_ndim; ++i) {
oshape.push_back(rshape[max_ndim - i]);
}
- return TensorType(Array<IndexExpr>(
- oshape.rbegin(), oshape.rend()), output_dtype);
+ return TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), output_dtype);
}
-bool BroadcastRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
CHECK_EQ(t0->dtype, t1->dtype);
- reporter->Assign(types[2],
- ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
+ reporter->Assign(
+ types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
return true;
}
}
return false;
}
-bool BroadcastCompRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
CHECK_EQ(t0->dtype, t1->dtype);
- reporter->Assign(types[2],
- ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), DataType::Bool()));
+ reporter->Assign(types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1),
+ DataType::Bool()));
return true;
}
}
return false;
}
-bool IdentityCompRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
if (auto* t0 = types[0].as<TensorTypeNode>()) {
Type out_type = TensorType(GetRef<TensorType>(t0)->shape, DataType::Bool());
if (shape.size() == 0) {
return {};
} else {
- return { tvm::Integer(shape.size()) };
+ return {tvm::Integer(shape.size())};
}
}
#include <tvm/ir/error.h>
#include <tvm/relay/type.h>
+
#include <string>
namespace tvm {
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/
-bool IdentityRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);
/*!
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/
-bool BroadcastRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);
/*!
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/
-bool BroadcastCompRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);
-bool IdentityCompRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter);
+bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter);
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
* \file multibox_op.cc
* \brief Multibox related operators
*/
-#include <tvm/tir/op.h>
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/op.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs);
-bool MultiboxPriorRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool MultiboxPriorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
const MultiBoxPriorAttrs* param = attrs.as<MultiBoxPriorAttrs>();
const auto& dshape = data->shape;
CHECK_EQ(dshape.size(), 4) << "Input data should be 4D: "
- "[batch, channel, height, width]";
+ "[batch, channel, height, width]";
IndexExpr in_height = dshape[2];
IndexExpr in_width = dshape[3];
int num_sizes = static_cast<int>(param->sizes.size());
int num_ratios = static_cast<int>(param->ratios.size());
// since input sizes are same in each batch, we could share MultiBoxPrior
- std::vector<IndexExpr> oshape(
- {1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
+ std::vector<IndexExpr> oshape({1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
// assign output type
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
-
-Expr MakeMultiBoxPrior(Expr data,
- Array<IndexExpr> sizes,
- Array<IndexExpr> ratios,
- Array<IndexExpr> steps,
- Array<IndexExpr> offsets,
- bool clip) {
+Expr MakeMultiBoxPrior(Expr data, Array<IndexExpr> sizes, Array<IndexExpr> ratios,
+ Array<IndexExpr> steps, Array<IndexExpr> offsets, bool clip) {
auto attrs = make_object<MultiBoxPriorAttrs>();
attrs->sizes = std::move(sizes);
attrs->ratios = std::move(ratios);
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior")
-.set_body_typed(MakeMultiBoxPrior);
-
+TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior").set_body_typed(MakeMultiBoxPrior);
RELAY_REGISTER_OP("vision.multibox_prior")
-.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
+ .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" TVM_ADD_FILELINE)
-.set_attrs_type<MultiBoxPriorAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(5)
-.add_type_rel("MultiBoxPrior", MultiboxPriorRel);
+ .set_attrs_type<MultiBoxPriorAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(5)
+ .add_type_rel("MultiBoxPrior", MultiboxPriorRel);
TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
-bool MultiBoxTransformLocRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto& loc_shape = loc_pred->shape;
const auto& anchor_shape = anchor->shape;
- CHECK_EQ(cls_shape.size(), 3U)
- << "The dimension of class probability should be 3, but received "
- << cls_shape.size();
+ CHECK_EQ(cls_shape.size(), 3U) << "The dimension of class probability should be 3, but received "
+ << cls_shape.size();
CHECK_EQ(loc_shape.size(), 2U)
- << "The dimension of location prediction should be 2, but received "
- << loc_shape.size();
+ << "The dimension of location prediction should be 2, but received " << loc_shape.size();
CHECK_EQ(anchor_shape.size(), 3U)
- << "The dimension of anchor should be 3, but received "
- << anchor_shape.size();
+ << "The dimension of anchor should be 3, but received " << anchor_shape.size();
- CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1]))
- << "Number of anchors mismatch found";
- CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1]))
- << "# anchors mismatch with # loc.";
+ CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) << "Number of anchors mismatch found";
+ CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) << "# anchors mismatch with # loc.";
CHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0.";
CHECK(reporter->AssertEQ(anchor_shape[2], 4));
return true;
}
-Expr MakeMultiBoxTransformLoc(Expr cls_prob,
- Expr loc_pred,
- Expr anchor,
- bool clip,
- double threshold,
- Array<IndexExpr> variances) {
+Expr MakeMultiBoxTransformLoc(Expr cls_prob, Expr loc_pred, Expr anchor, bool clip,
+ double threshold, Array<IndexExpr> variances) {
auto attrs = make_object<MultiBoxTransformLocAttrs>();
attrs->clip = std::move(clip);
attrs->threshold = std::move(threshold);
}
TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc")
-.set_body_typed(MakeMultiBoxTransformLoc);
+ .set_body_typed(MakeMultiBoxTransformLoc);
RELAY_REGISTER_OP("vision.multibox_transform_loc")
-.describe(R"doc("Location transformation for multibox detection."
+ .describe(R"doc("Location transformation for multibox detection."
)doc" TVM_ADD_FILELINE)
-.set_attrs_type<MultiBoxTransformLocAttrs>()
-.set_num_inputs(3)
-.add_argument("cls_prob", "Tensor", "Class probabilities.")
-.add_argument("loc_pred", "Tensor", "Location regression predictions.")
-.add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
-.add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel)
-.set_support_level(5);
+ .set_attrs_type<MultiBoxTransformLocAttrs>()
+ .set_num_inputs(3)
+ .add_argument("cls_prob", "Tensor", "Class probabilities.")
+ .add_argument("loc_pred", "Tensor", "Location regression predictions.")
+ .add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
+ .add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel)
+ .set_support_level(5);
} // namespace relay
} // namespace tvm
* \file nms.cc
* \brief Non-maximum suppression operators
*/
-#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
+#include <tvm/relay/op.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs);
-bool GetValidCountRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool GetValidCountRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
return true;
}
-Expr MakeGetValidCounts(Expr data,
- double score_threshold,
- int id_index,
- int score_index) {
+Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) {
auto attrs = make_object<GetValidCountsAttrs>();
attrs->score_threshold = score_threshold;
attrs->id_index = id_index;
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts")
-.set_body_typed(MakeGetValidCounts);
-
+TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts);
RELAY_REGISTER_OP("vision.get_valid_counts")
-.describe(R"doc(Get valid count of bounding boxes given
+ .describe(R"doc(Get valid count of bounding boxes given
a score threshold. Also moves valid boxes to the top of
input data.
)doc" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "Input data.")
-.set_support_level(5)
-.add_type_rel("GetValidCount", GetValidCountRel);
-
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "Input data.")
+ .set_support_level(5)
+ .add_type_rel("GetValidCount", GetValidCountRel);
TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);
-bool NMSRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* valid_count = types[1].as<TensorTypeNode>();
- const NonMaximumSuppressionAttrs* param =
- attrs.as<NonMaximumSuppressionAttrs>();
+ const NonMaximumSuppressionAttrs* param = attrs.as<NonMaximumSuppressionAttrs>();
const auto& dshape = data->shape;
const auto& vshape = valid_count->shape;
CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
return true;
}
-
-Expr MakeNMS(Expr data,
- Expr valid_count,
- int max_output_size,
- double iou_threshold,
- bool force_suppress,
- int top_k,
- int coord_start,
- int score_index,
- int id_index,
- bool return_indices,
- bool invalid_to_bottom) {
+Expr MakeNMS(Expr data, Expr valid_count, int max_output_size, double iou_threshold,
+ bool force_suppress, int top_k, int coord_start, int score_index, int id_index,
+ bool return_indices, bool invalid_to_bottom) {
auto attrs = make_object<NonMaximumSuppressionAttrs>();
attrs->max_output_size = max_output_size;
attrs->iou_threshold = iou_threshold;
return Call(op, {data, valid_count}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression")
-.set_body_typed(MakeNMS);
-
+TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS);
RELAY_REGISTER_OP("vision.non_max_suppression")
-.describe(R"doc(Non-maximum suppression. The input boxes should
+ .describe(R"doc(Non-maximum suppression. The input boxes should
be in the format of [class_id, score, left, top, right, bottom].
Set id_index to be -1 to ignore class_id axis.
)doc" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "Input data.")
-.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
-.set_support_level(5)
-.add_type_rel("NMS", NMSRel);
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "Input data.")
+ .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
+ .set_support_level(5)
+ .add_type_rel("NMS", NMSRel);
} // namespace relay
} // namespace tvm
* \file rcnn_op.cc
* \brief Faster RCNN and Mask RCNN operators
*/
+#include <tvm/relay/attrs/vision.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
-#include <tvm/relay/attrs/vision.h>
namespace tvm {
namespace relay {
return Call(op, {data, rois}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align")
-.set_body_typed(MakeROIAlign);
+TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align").set_body_typed(MakeROIAlign);
RELAY_REGISTER_OP("vision.roi_align")
.describe(R"doc(ROI Align operator.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`.
)doc" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("rois", "Tensor", "The input rois")
-.set_support_level(5)
-.add_type_rel("ROIAlign", ROIAlignRel);
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("rois", "Tensor", "The input rois")
+ .set_support_level(5)
+ .add_type_rel("ROIAlign", ROIAlignRel);
TVM_REGISTER_NODE_TYPE(ROIPoolAttrs);
bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
- const TypeReporter& reporter) {
+ const TypeReporter& reporter) {
auto roi_pool_attrs = attrs.as<ROIPoolAttrs>();
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
return Call(op, {data, rois}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool")
-.set_body_typed(MakeROIPool);
+TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool").set_body_typed(MakeROIPool);
RELAY_REGISTER_OP("vision.roi_pool")
.describe(R"doc(ROI Pool operator.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`.
)doc" TVM_ADD_FILELINE)
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("rois", "Tensor", "The input rois")
-.set_support_level(5)
-.add_type_rel("ROIPool", ROIPoolRel);
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("rois", "Tensor", "The input rois")
+ .set_support_level(5)
+ .add_type_rel("ROIPool", ROIPoolRel);
TVM_REGISTER_NODE_TYPE(ProposalAttrs);
auto batch = cls_prob->shape[0];
- std::vector<IndexExpr> oshape(
- {batch * proposal_attrs->rpn_post_nms_top_n, 5});
+ std::vector<IndexExpr> oshape({batch * proposal_attrs->rpn_post_nms_top_n, 5});
reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype));
return true;
}
Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr> scales,
Array<IndexExpr> ratios, int feature_stride, double threshold,
- int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size,
- bool iou_loss) {
+ int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, bool iou_loss) {
auto attrs = make_object<ProposalAttrs>();
attrs->scales = scales;
attrs->ratios = ratios;
return Call(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal")
-.set_body_typed(MakeProposal);
+TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal").set_body_typed(MakeProposal);
RELAY_REGISTER_OP("vision.proposal")
.describe(R"code(Generate region proposals via RPN.
- **im_info**: 2-D with shape [batch, 3].
- **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5].
)code" TVM_ADD_FILELINE)
-.set_num_inputs(3)
-.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object")
-.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals")
-.add_argument("im_info", "Tensor", "Image size and scale")
-.set_support_level(5)
-.add_type_rel("Proposal", ProposalRel);
+ .set_num_inputs(3)
+ .add_argument("cls_prob", "Tensor", "Score of how likely proposal is object")
+ .add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals")
+ .add_argument("im_info", "Tensor", "Image size and scale")
+ .set_support_level(5)
+ .add_type_rel("Proposal", ProposalRel);
} // namespace relay
} // namespace tvm
* \file yolo.cc
* \brief Yolo related operators
*/
-#include <tvm/relay/op.h>
-#include <tvm/relay/attrs/vision.h>
#include <topi/vision/reorg.h>
+#include <tvm/relay/attrs/vision.h>
+#include <tvm/relay/op.h>
+
#include <vector>
+
#include "../op_common.h"
#include "../type_relations.h"
TVM_REGISTER_NODE_TYPE(YoloReorgAttrs);
/*!
-* \brief YoloReorgRel Output type and shape relation evaluation function.
-* \param num_inputs Number of input types in the args.
-* \param attrs The additional attributes of the operator.
-* \param reporter The reporter to report solution to.
-* \return false if This relation cannot be resolved. true if this relation has been resolved.
-*/
-bool YoloReorgRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+ * \brief YoloReorgRel Output type and shape relation evaluation function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if this relation cannot be resolved, true if it has been resolved.
+ */
+bool YoloReorgRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
return true;
}
-Expr MakeYoloReorg(Expr data,
- Integer stride) {
+Expr MakeYoloReorg(Expr data, Integer stride) {
auto attrs = make_object<YoloReorgAttrs>();
attrs->stride = stride;
static const Op& op = Op::Get("vision.yolo_reorg");
return Call(op, {data}, Attrs(attrs), {});
}
-
-TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg")
-.set_body_typed(MakeYoloReorg);
-
+TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg").set_body_typed(MakeYoloReorg);
RELAY_REGISTER_OP("vision.yolo_reorg")
-.describe(R"doc("Yolo reorg operation. This layer reorganize the output.
+ .describe(R"doc("Yolo reorg operation. This layer reorganize the output.
Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_num_inputs(1)
-.set_support_level(5)
-.set_attrs_type<YoloReorgAttrs>()
-.add_type_rel("YoloReorg", YoloReorgRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) {
- const auto* params = attrs.as<YoloReorgAttrs>();
- CHECK(params != nullptr);
- return Array<te::Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
-});
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_num_inputs(1)
+ .set_support_level(5)
+ .set_attrs_type<YoloReorgAttrs>()
+ .add_type_rel("YoloReorg", YoloReorgRel)
+ .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ const auto* params = attrs.as<YoloReorgAttrs>();
+ CHECK(params != nullptr);
+ return Array<te::Tensor>{topi::vision::reorg(inputs[0], params->stride)};
+ });
} // namespace relay
} // namespace tvm
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
+
#include "op_common.h"
namespace tvm {
// Get the input dtype and shape.
QnnBinaryOpTensorType input_type(arg_types, 0);
-
// FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in
// the start, we can insert requantize at the end if both input tensors have same qnn params. In
// that case, we can first add the tensors, subtract the zero point, and requantize at the end.
// Q_c = Q_a' + Q_b' - zp_c
// The add op is done in int32 precision.
-
-
// Requantize LHS if necessary. Computes Q_a'
- auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale,
- args.lhs_zero_point,
- args.output_scale, args.output_zero_point,
- input_type.shape);
+ auto requantized_lhs =
+ RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale,
+ args.output_zero_point, input_type.shape);
// Requantize RHS if necessary. Computes Q_b'
- auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale,
- args.rhs_zero_point,
- args.output_scale, args.output_zero_point,
- input_type.shape);
+ auto requantized_rhs =
+ RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale,
+ args.output_zero_point, input_type.shape);
// Computes Q_a' + Q_b'
auto output = Add(requantized_lhs, requantized_rhs);
// QNN Addition operator.
QNN_REGISTER_BINARY_OP("add")
-.describe("Elementwise add with with broadcasting for quantized tensors.")
-.set_support_level(11)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
+ .describe("Elementwise add with with broadcasting for quantized tensors.")
+ .set_support_level(11)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
} // namespace qnn
} // namespace relay
* \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis.
*/
-#include <tvm/tir/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
+#include <tvm/tir/expr.h>
+
#include "../../op/tensor/transform.h"
-#include "../../transforms/pattern_util.h"
#include "../../transforms/infer_layout_util.h"
+#include "../../transforms/pattern_util.h"
#include "../util.h"
namespace tvm {
// Check the scale and zero point types
const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
if (input_scales_tuple == nullptr) {
- throw Error(
- ErrorBuilder()
- << "qnn concatenate requires a tuple of scales as the second argument, found "
- << PrettyPrint(types[1]));
+ throw Error(ErrorBuilder()
+ << "qnn concatenate requires a tuple of scales as the second argument, found "
+ << PrettyPrint(types[1]));
}
for (const auto& input_scale : input_scales_tuple->fields) {
CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx]
const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
if (input_zero_points_tuple == nullptr) {
- throw Error(
- ErrorBuilder()
- << "qnn concatenate requires a tuple of zero_points as the third argument, found "
- << PrettyPrint(types[2]));
+ throw Error(ErrorBuilder()
+ << "qnn concatenate requires a tuple of zero_points as the third argument, found "
+ << PrettyPrint(types[2]));
}
for (const auto& input_zero_point : input_zero_points_tuple->fields) {
CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx]
auto attrs = make_object<ConcatenateAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("qnn.concatenate");
- return Call(op,
- {data, input_scales, input_zero_points, output_scale, output_zero_point},
- Attrs(attrs), {});
+ return Call(op, {data, input_scales, input_zero_points, output_scale, output_zero_point},
+ Attrs(attrs), {});
}
/*
}
RELAY_REGISTER_OP("qnn.concatenate")
-.describe(R"code(Concatenate the quantized input tensors along the given axis.
+ .describe(R"code(Concatenate the quantized input tensors along the given axis.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<ConcatenateAttrs>()
-.set_num_inputs(5)
-.add_argument("data", "Tensor", "The tensor to concatenate.")
-.add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.")
-.add_argument("input_zero_points", "Tensor", "The quantization zero_points of the input tensors.")
-.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
-.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
-.set_support_level(11)
-.add_type_rel("QnnConcatenate", QnnConcatenateRel)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConcatenateLayout);
-
-TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate")
-.set_body_typed(MakeQnnConcatenate);
+ .set_attrs_type<ConcatenateAttrs>()
+ .set_num_inputs(5)
+ .add_argument("data", "Tensor", "The tensor to concatenate.")
+ .add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.")
+ .add_argument("input_zero_points", "Tensor",
+ "The quantization zero_points of the input tensors.")
+ .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+ .add_argument("output_zero_point", "Tensor",
+ "The quantization zero_point of the output tensor.")
+ .set_support_level(11)
+ .add_type_rel("QnnConcatenate", QnnConcatenateRel)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConcatenateLayout);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate").set_body_typed(MakeQnnConcatenate);
} // namespace qnn
} // namespace relay
* \file src/relay/qnn/op/convolution.cc
* \brief Property def of qnn convolution operator.
*/
-#include <tvm/tir/data_layout.h>
+#include "../../op/nn/convolution.h"
+
#include <tvm/relay/analysis.h>
#include <tvm/relay/base.h>
#include <tvm/relay/op.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/data_layout.h>
-#include "../../op/nn/convolution.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"
}
bool is_depthwise(const Conv2DAttrs* param) {
- return param->channels.defined() &&
- tvm::tir::ExprDeepEqual()(param->channels, param->groups) &&
- param->groups != 1;
+ return param->channels.defined() && tvm::tir::ExprDeepEqual()(param->channels, param->groups) &&
+ param->groups != 1;
}
// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier
auto pad_left_value = get_const_int(param->padding[1]);
auto pad_bottom_value = get_const_int(param->padding[2]);
auto pad_right_value = get_const_int(param->padding[3]);
- bool do_pad = pad_top_value != 0 || pad_left_value != 0 ||
- pad_bottom_value != 0 || pad_right_value != 0;
+ bool do_pad =
+ pad_top_value != 0 || pad_left_value != 0 || pad_bottom_value != 0 || pad_right_value != 0;
if (do_pad) {
Array<IndexExpr> pad_n({0, 0});
Array<IndexExpr> pad_c({0, 0});
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.conv2d");
- return Call(
- op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
- Attrs(attrs), {});
+ return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
+ Attrs(attrs), {});
}
RELAY_REGISTER_OP("qnn.conv2d")
-.describe(R"code(2D quantized convolution layer.
+ .describe(R"code(2D quantized convolution layer.
This operator convolves quantized weight with quantized data. The scale of the
output quantized tensor is the product of the weight_scale and input_scale of
the input quantized tensors. The zero point of the output quantized tensor is
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(6)
-.add_argument("data", "Tensor", "The quantized input data tensor.")
-.add_argument("weight", "Tensor", "The quantized weight tensor.")
-.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
-.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
-.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
-.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.")
-.set_support_level(11)
-.add_type_rel("QnnConv2D", QnnConv2DRel)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConvInferCorrectLayout);
+ .set_attrs_type<Conv2DAttrs>()
+ .set_num_inputs(6)
+ .add_argument("data", "Tensor", "The quantized input data tensor.")
+ .add_argument("weight", "Tensor", "The quantized weight tensor.")
+ .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+ .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+ .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
+ .add_argument("weight_zero_point", "Tensor",
+ "The quantization zero_point of the weight tensor.")
+ .set_support_level(11)
+ .add_type_rel("QnnConv2D", QnnConv2DRel)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConvInferCorrectLayout);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D);
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
+
#include "../../op/nn/nn.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"
attrs->units = std::move(units);
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("qnn.dense");
- return Call(
- op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
- Attrs(attrs), {});
+ return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
+ Attrs(attrs), {});
}
Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel,
}
RELAY_REGISTER_OP("qnn.dense")
-.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
+ .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)`
- **weight**: quantized(int8, unit8) `(units, input_dim)`
- **out**: quantized(int32) `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<DenseAttrs>()
-.set_num_inputs(6)
-.add_argument("data", "quantized nD Tensor", "Input data.")
-.add_argument("weight", "quantized 2D Tensor", "Weight matrix.")
-.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
-.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
-.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
-.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.")
-.set_support_level(11)
-.add_type_rel("QDense", QnnDenseRel)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);
-
-TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense")
-.set_body_typed(MakeQuantizedDense);
+ .set_attrs_type<DenseAttrs>()
+ .set_num_inputs(6)
+ .add_argument("data", "quantized nD Tensor", "Input data.")
+ .add_argument("weight", "quantized 2D Tensor", "Weight matrix.")
+ .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+ .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+ .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
+ .add_argument("weight_zero_point", "Tensor",
+ "The quantization zero_point of the weight tensor.")
+ .set_support_level(11)
+ .add_type_rel("QDense", QnnDenseRel)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense);
} // namespace qnn
} // namespace relay
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
+
#include "../../transforms/pattern_util.h"
#include "../util.h"
namespace relay {
namespace qnn {
-bool DequantizeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const auto input_dtype = data->dtype;
- CHECK(input_dtype == DataType::Int(8) ||
- input_dtype == DataType::UInt(8) ||
+ CHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
input_dtype == DataType::Int(32))
- << "Input type should be one of the quantized types [unit8, int8, int32] but was "
- << input_dtype;
+ << "Input type should be one of the quantized types [unit8, int8, int32] but was "
+ << input_dtype;
// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
}
RELAY_REGISTER_OP("qnn.dequantize")
-.describe(R"code(Dequantizes the input and produces float32 output.
+ .describe(R"code(Dequantizes the input and produces float32 output.
The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point.
- **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point
)code" TVM_ADD_FILELINE)
-.set_num_inputs(3)
-.add_argument("data", "Tensor", "The tensor to dequantize.")
-.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
-.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
-.set_support_level(11)
-.add_type_rel("Dequantize", DequantizeRel)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "The tensor to dequantize.")
+ .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+ .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+ .set_support_level(11)
+ .add_type_rel("Dequantize", DequantizeRel)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);
-TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize")
-.set_body_typed(MakeDequantize);
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize);
} // namespace qnn
} // namespace relay
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
+
#include "../../transforms/pattern_util.h"
#include "../util.h"
#include "op_common.h"
auto new_input_zero_point = zero_scalar;
// Requantize to get Q_c
- output = Requantize(output, input_type.shape,
- new_input_scale,
- new_input_zero_point,
- args.output_scale,
- args.output_zero_point,
- input_type.dtype);
+ output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
+ args.output_scale, args.output_zero_point, input_type.dtype);
return output;
}
// QNN Multiplication operator.
QNN_REGISTER_BINARY_OP("mul")
-.describe("Elementwise mul with with broadcasting for quantized tensors.")
-.set_support_level(11)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
+ .describe("Elementwise mul with with broadcasting for quantized tensors.")
+ .set_support_level(11)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
} // namespace qnn
} // namespace relay
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
+
#include <vector>
+
#include "../../op/type_relations.h"
#include "../../transforms/infer_layout_util.h"
#include "../util.h"
*/
struct QnnBinaryOpTensorType {
DataType dtype;
- Array <PrimExpr> shape;
+ Array<PrimExpr> shape;
- explicit QnnBinaryOpTensorType(const Array<tvm::relay::Type>& arg_types,
- const int32_t arg_idx) {
+ explicit QnnBinaryOpTensorType(const Array<tvm::relay::Type>& arg_types, const int32_t arg_idx) {
CHECK_EQ(arg_types.size(), kNumQnnBinaryOpArgTypes);
auto tensor_type = arg_types[arg_idx].as<TensorTypeNode>();
CHECK(tensor_type != nullptr);
* \return New expression with target dtype and possibly lower
* precision.
*/
-inline Expr ConvertDtype(const Expr& expr,
- const DataType& target_dtype) {
+inline Expr ConvertDtype(const Expr& expr, const DataType& target_dtype) {
auto q_min = GetQmin(target_dtype);
auto q_max = GetQmax(target_dtype);
auto output = Clip(expr, q_min, q_max);
* it simply casts the given expression to Int32 as no requantization is
* needed in this case.
*/
-inline Expr RequantizeOrUpcast(const Expr& expr,
- const Expr& expr_scale,
- const Expr& expr_zero_point,
- const Expr& target_scale,
- const Expr& target_zero_point,
- const Array<PrimExpr>& expr_shape,
+inline Expr RequantizeOrUpcast(const Expr& expr, const Expr& expr_scale,
+ const Expr& expr_zero_point, const Expr& target_scale,
+ const Expr& target_zero_point, const Array<PrimExpr>& expr_shape,
const DataType& target_dtype = DataType::Int(32)) {
auto result = expr;
if (!IsEqualScalar(expr_scale, target_scale) ||
!IsEqualScalar(expr_zero_point, target_zero_point)) {
- result = Requantize(expr, expr_shape, expr_scale, expr_zero_point,
- target_scale, target_zero_point, target_dtype);
+ result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, target_scale,
+ target_zero_point, target_dtype);
} else {
result = Cast(result, target_dtype);
}
}
/*! \brief Infer layout for QNN binary broadcast operators */
-inline Array<Array<Layout> > QnnBinaryBroadcastLayout(
- const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type>& old_in_types) {
+inline Array<Array<Layout> > QnnBinaryBroadcastLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
// Use Relay Binary Broadcast Infer correct layout.
auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types);
// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
// tensors can be treated as C.
Layout channel_layout = Layout("C");
- Array<Layout> input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout,
+ Array<Layout> input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout,
channel_layout, channel_layout, channel_layout, channel_layout};
Array<Layout> output_layouts = layouts[1];
return {input_layouts, output_layouts};
}
-
-static inline bool QnnBroadcastRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes);
*
* \param OpName the name of registry.
*/
-#define QNN_REGISTER_BINARY_OP(OpName) \
- TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \
- .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \
- Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \
- static const Op& op = Op::Get("qnn." OpName); \
- return Call(op, {lhs, rhs, \
- lhs_scale, lhs_zero_point, \
- rhs_scale, rhs_zero_point, \
- output_scale, output_zero_point}, Attrs(), {}); \
- }); \
- RELAY_REGISTER_OP("qnn." OpName) \
- .set_num_inputs(kNumQnnBinaryOpInputs) \
- .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \
- .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \
- .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \
- .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \
- .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \
- .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \
- .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \
- .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \
- .add_type_rel("QnnBroadcast", QnnBroadcastRel) \
- .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
+#define QNN_REGISTER_BINARY_OP(OpName) \
+ TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \
+ .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \
+ Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \
+ static const Op& op = Op::Get("qnn." OpName); \
+ return Call(op, \
+ {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, \
+ output_zero_point}, \
+ Attrs(), {}); \
+ }); \
+ RELAY_REGISTER_OP("qnn." OpName) \
+ .set_num_inputs(kNumQnnBinaryOpInputs) \
+ .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \
+ .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \
+ .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \
+ .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \
+ .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \
+ .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \
+ .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \
+ .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \
+ .add_type_rel("QnnBroadcast", QnnBroadcastRel) \
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
} // namespace qnn
} // namespace relay
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
+
#include "../../transforms/pattern_util.h"
#include "../util.h"
TVM_REGISTER_NODE_TYPE(QuantizeAttrs);
-bool QuantizeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const auto input_dtype = data->dtype;
CHECK(input_dtype == DataType::Float(32))
- << "Input type should be one of float32 but was " << input_dtype;
+ << "Input type should be one of float32 but was " << input_dtype;
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
int axis = quantize_attrs->axis;
- axis = (axis == -1) ? data->shape.size() - 1: axis;
+ axis = (axis == -1) ? data->shape.size() - 1 : axis;
CHECK_LT(axis, static_cast<int>(data->shape.size()))
<< "axis " << quantize_attrs->axis << " is out of range";
- CHECK_GE(axis, 0)
- << "axis " << quantize_attrs->axis << " is out of range";
+ CHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range";
// Check and assign types for scale and zero points.
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale
}
RELAY_REGISTER_OP("qnn.quantize")
-.describe(R"code(Quantizes the input and produces quantized output.
+ .describe(R"code(Quantizes the input and produces quantized output.
The input can be either float or quantized(int8, unit8). If the input is float,
this op takes scale and zero point and quantize the float value to
quantized output, in int8 or uint8 format. If the input is quantized value,
- **data**: Tensor of any shape to quantize. The input data can be of floating point
or quantized.
)code" TVM_ADD_FILELINE)
-.set_attrs_type<QuantizeAttrs>()
-.set_num_inputs(3)
-.add_argument("data", "Tensor", "The tensor to quantize.")
-.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
-.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
-.set_support_level(11)
-.add_type_rel("Quantize", QuantizeRel)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);
-
-TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize")
-.set_body_typed(MakeQuantize);
+ .set_attrs_type<QuantizeAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "The tensor to quantize.")
+ .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+ .add_argument("output_zero_point", "Tensor",
+ "The quantization zero_point of the output tensor.")
+ .set_support_level(11)
+ .add_type_rel("Quantize", QuantizeRel)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize);
} // namespace qnn
} // namespace relay
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
-#include "../../transforms/pattern_util.h"
+
#include "../../transforms/infer_layout_util.h"
+#include "../../transforms/pattern_util.h"
#include "../util.h"
namespace tvm {
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& layout_axis = LayoutAxis::Get(iter_var);
const std::string& layout_dim = layout_axis.name();
- if (old_dim == layout_dim) {
+ if (old_dim == layout_dim) {
new_axis = tvm::Integer(axis_index);
}
// Collect only the primal axis.
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const auto in_dtype = data->dtype;
- CHECK(in_dtype == DataType::Int(8) ||
- in_dtype == DataType::UInt(8) ||
+ CHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
in_dtype == DataType::Int(32))
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
int axis = requantize_attrs->axis;
- axis = (axis == -1) ? data->shape.size() - 1: axis;
+ axis = (axis == -1) ? data->shape.size() - 1 : axis;
CHECK_LT(axis, static_cast<int>(data->shape.size()))
<< "axis " << requantize_attrs->axis << " is out of range";
- CHECK_GE(axis, 0)
- << "axis " << requantize_attrs->axis << " is out of range";
+ CHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range";
// Check and assign types for scale and zero points.
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale
const Array<tvm::PrimExpr> oshape = data->shape;
// assign output type
auto out_dtype = requantize_attrs->out_dtype;
- CHECK(out_dtype == DataType::Int(8) ||
- out_dtype == DataType::UInt(8) ||
+ CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
reporter->Assign(types[5], TensorType(oshape, out_dtype));
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.requantize");
return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point},
- Attrs(attrs), {});
+ Attrs(attrs), {});
}
RELAY_REGISTER_OP("qnn.requantize")
-.describe(R"code(Requantize operator.
+ .describe(R"code(Requantize operator.
The requantize operator converts one quantized tensor to another quantized
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this
Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
)code" TVM_ADD_FILELINE)
-.set_attrs_type<RequantizeAttrs>()
-.set_num_inputs(5)
-.add_argument("data", "Tensor", "The quantized input tensor.")
-.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
-.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
-.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
-.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
-.set_support_level(11)
-.add_type_rel("Requantize", RequantizeRel)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout);
-
-TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize")
-.set_body_typed(MakeRequantize);
+ .set_attrs_type<RequantizeAttrs>()
+ .set_num_inputs(5)
+ .add_argument("data", "Tensor", "The quantized input tensor.")
+ .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+ .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+ .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+ .add_argument("output_zero_point", "Tensor",
+ "The quantization zero_point of the output tensor.")
+ .set_support_level(11)
+ .add_type_rel("Requantize", RequantizeRel)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize").set_body_typed(MakeRequantize);
} // namespace qnn
} // namespace relay
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
+
#include "op_common.h"
namespace tvm {
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for add op.
*/
-Expr QnnSubtractCanonicalize(const Attrs& attrs,
- const Array<Expr>& new_args,
+Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// Get the args.
QnnBinaryOpArguments args(new_args);
// The subtract op is done in int32 precision.
// Requantize LHS if necessary. Computes Q_a'
- auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale,
- args.lhs_zero_point,
- args.output_scale,
- args.output_zero_point,
- input_type.shape);
+ auto requantized_lhs =
+ RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale,
+ args.output_zero_point, input_type.shape);
// Requantize RHS if necessary. Computes Q_b'
- auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale,
- args.rhs_zero_point,
- args.output_scale,
- args.output_zero_point,
- input_type.shape);
+ auto requantized_rhs =
+ RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale,
+ args.output_zero_point, input_type.shape);
// Computes Q_a' - Q_b'
auto output = Subtract(requantized_lhs, requantized_rhs);
// QNN Subtraction operator.
QNN_REGISTER_BINARY_OP("subtract")
-.describe("Elementwise subtract with with broadcasting for quantized tensors.")
-.set_support_level(11)
-.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize);
-
+    .describe("Elementwise subtract with broadcasting for quantized tensors.")
+ .set_support_level(11)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize);
} // namespace qnn
} // namespace relay
*/
#include "util.h"
+
#include "../transforms/pattern_util.h"
namespace tvm {
*
* Credit to TFLite reference implementation.
*/
-std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
- double double_multiplier) {
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
int32_t significand, exponent;
if (double_multiplier == 0.) {
significand = 0;
// 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier, shift;
- std::tie(fixed_point_multiplier, shift) =
- GetFixedPointMultiplierShift(multiplier);
+ std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
int left_shift = shift > 0 ? shift : 0;
int right_shift = shift > 0 ? 0 : -shift;
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
auto zero_t = Zeros(input_shape, hp_dtype);
- round_scalar =
- Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+ round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
}
tensor = Add(tensor, round_scalar);
// 5) Simply right shift the result to get the final output.
- tensor =
- RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+ tensor = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
return Cast(tensor, DataType::Int(32));
#ifndef TVM_RELAY_QNN_UTIL_H_
#define TVM_RELAY_QNN_UTIL_H_
-#include <tvm/tir/expr.h>
-#include <tvm/tir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/qnn/attrs.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+
#include <limits>
#include <string>
-#include <vector>
#include <utility>
+#include <vector>
namespace tvm {
namespace relay {
}
static inline int32_t GetQmin(const DataType& dtype) {
- CHECK_LE(dtype.bits(), 32)
- << "QNN ops support int32 or lower precision";
+ CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision";
if (dtype.is_int() || dtype.is_uint()) {
auto* min_value = tir::as_const_int(tvm::min_value(dtype));
CHECK(min_value != nullptr);
}
static inline int32_t GetQmax(const DataType& dtype) {
- CHECK_LE(dtype.bits(), 32)
- << "QNN ops support int32 or lower precision";
+ CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision";
if (dtype.is_int() || dtype.is_uint()) {
auto* max_value = tir::as_const_int(tvm::max_value(dtype));
CHECK(max_value != nullptr);
const TypeReporter& reporter) {
// Scale/Zero_points can be either const scalar or a vector with C axis num elems.
const auto* tensor_type = expr_type.as<TensorTypeNode>();
- CHECK(tensor_type) << "Can assign type to Tensor type only. But got "
- << AsText(expr_type, false);
+ CHECK(tensor_type) << "Can assign type to Tensor type only. But got " << AsText(expr_type, false);
const auto tensor_dtype = tensor_type->dtype;
CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
if (tensor_type->shape.size() != 0) {
* \brief Annotating the graph with simulated quantize operators.
*/
-#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/transform.h>
+
#include "./quantize.h"
namespace tvm {
TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode);
};
-
-Expr QAnnotateExprNode::Realize() const {
- return expr;
-}
+Expr QAnnotateExprNode::Realize() const { return expr; }
QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) {
auto rnode = make_object<QAnnotateExprNode>();
data_ = std::move(rnode);
}
-TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr")
-.set_body_typed([](Expr expr, int kind) {
+TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr").set_body_typed([](Expr expr, int kind) {
return QAnnotateExpr(expr, static_cast<QAnnotateKind>(kind));
});
-
Pass QuantizeAnnotate() {
// TODO(tvm-teams): since partition has added cast_hint in different
// branches, try to remove this in the future.
if (e->IsInstance<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
- const PackedFunc* f =
- runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
+ const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExpr(ret, kQInput));
}
};
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
- auto new_params = func->params;
- for (const auto& x : FreeVars(func)) {
- new_params.push_back(x);
- }
- return Function(new_params,
- func->body,
- func->ret_type,
- func->type_params,
- func->attrs);
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
+ auto new_params = func->params;
+ for (const auto& x : FreeVars(func)) {
+ new_params.push_back(x);
+ }
+ return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs);
+ };
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
-TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate")
-.set_body_typed(QuantizeAnnotate);
+TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate").set_body_typed(QuantizeAnnotate);
TVM_REGISTER_NODE_TYPE(QAnnotateExprNode);
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
+
#include <numeric>
+
#include "./quantize.h"
namespace tvm {
}
static float ComputeEntropy(float* p, float* q, size_t size) {
- float p_sum = std::accumulate(p, p+size, 0.f);
- float q_sum = std::accumulate(q, q+size, 0.f);
+ float p_sum = std::accumulate(p, p + size, 0.f);
+ float q_sum = std::accumulate(q, q + size, 0.f);
float ret = 0;
for (size_t i = 0; i < size; i++) {
CHECK(p[i] > 0 && q[i] > 0);
return ret;
}
-float MinimizeKL(const std::vector<int>& hist,
- const std::vector<float>& hist_edges,
- int num_bins, int num_quantized_bins) {
+float MinimizeKL(const std::vector<int>& hist, const std::vector<float>& hist_edges, int num_bins,
+ int num_quantized_bins) {
const int zero_bin_idx = num_bins / 2;
const int num_half_quantized_bins = num_quantized_bins / 2;
std::vector<float> thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f);
divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size());
}
}
- auto min_divergence_idx = std::distance(divergence.begin(),
- std::min_element(divergence.begin(), divergence.end()));
- return thresholds[min_divergence_idx];;
+ auto min_divergence_idx =
+ std::distance(divergence.begin(), std::min_element(divergence.begin(), divergence.end()));
+ return thresholds[min_divergence_idx];
}
class StatsCollector : private ExprMutator {
CHECK(func) << "Input shoule be Function";
Expr new_body = Tuple(std::move(profile_data_));
return Function(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
- func->attrs);
+ func->attrs);
}
private:
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
// rewrite the annotation
auto new_attrs = make_object<SimulatedQuantizeAttrs>();
- const Expr& quantize_input = new_call->args[0]; // expression being quantized
+ const Expr& quantize_input = new_call->args[0]; // expression being quantized
auto placeholder = MakeConstantScalar(DataType::Float(32), 0.); // unused argument
Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
new_attrs->kind = QAnnotateKind::kQIdentity;
* \param expr The simulation graph after annotation.
* \return The profile graph.
*/
-Expr CreateStatsCollector(const Expr& expr) {
- return StatsCollector().Collect(expr);
-}
-
-TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector")
-.set_body_typed(CreateStatsCollector);
+Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); }
+TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector").set_body_typed(CreateStatsCollector);
TVM_REGISTER_GLOBAL("relay._quantize.FindScaleByKLMinimization")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- int* hist_ptr = static_cast<int*>(static_cast<void*>(args[0]));
- float* hist_edges_ptr = static_cast<float*>(static_cast<void*>(args[1]));
- int num_bins = args[2];
- int num_quantized_bins = args[3];
- std::vector<int> hist(hist_ptr, hist_ptr + num_bins);
- std::vector<float> hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1);
- ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins);
-});
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ int* hist_ptr = static_cast<int*>(static_cast<void*>(args[0]));
+ float* hist_edges_ptr = static_cast<float*>(static_cast<void*>(args[1]));
+ int num_bins = args[2];
+ int num_quantized_bins = args[3];
+ std::vector<int> hist(hist_ptr, hist_ptr + num_bins);
+ std::vector<float> hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1);
+ ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins);
+ });
} // namespace quantize
} // namespace relay
*/
#include <tvm/relay/transform.h>
+
#include "../transforms/pattern_util.h"
#include "./quantize.h"
using namespace relay::transform;
-
class QPartitionExpr;
class QPartitionExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("expr", &expr);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); }
Expr Realize() const final;
TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode);
};
-
Expr QPartitionExprNode::Realize() const {
// insert cast hint and stop fusion
const QConfig& cfg = QConfig::Current();
data_ = std::move(rnode);
}
-TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr")
-.set_body_typed([](Expr expr) {
+TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr").set_body_typed([](Expr expr) {
return QPartitionExpr(expr);
});
Pass QuantizePartition() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- auto ret = Downcast<Function>(
- ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
- return ret;
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ auto ret = Downcast<Function>(ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
+ return ret;
+ };
return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
}
-TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition")
-.set_body_typed(QuantizePartition);
+TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition").set_body_typed(QuantizePartition);
TVM_REGISTER_NODE_TYPE(QPartitionExprNode);
* \brief transform a graph to a low-bit graph
* for compression and acceleration.
*/
+#include "./quantize.h"
+
#include <dmlc/thread_local.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
-#include <stack>
-#include "./quantize.h"
+#include <stack>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
-bool SimulatedQuantizeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool SimulatedQuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 5);
const auto param = attrs.as<SimulatedQuantizeAttrs>();
CHECK(data != nullptr);
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
- reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale
- reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min
- reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max
- reporter->Assign(types[4], types[0]); // output
+ reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale
+ reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min
+ reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max
+ reporter->Assign(types[4], types[0]); // output
return true;
}
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
-.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE)
-.set_num_inputs(4)
-.add_argument("data", "Tensor", "The input data.")
-.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar")
-.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
-.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
-.set_attrs_type<SimulatedQuantizeAttrs>()
-.set_support_level(11)
-.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
+ .describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE)
+ .set_num_inputs(4)
+ .add_argument("data", "Tensor", "The input data.")
+ .add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar")
+ .add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
+ .add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
+ .set_attrs_type<SimulatedQuantizeAttrs>()
+ .set_support_level(11)
+ .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
-.set_body_typed(
- [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max,
- int kind, bool sign, std::string rounding) {
- auto attrs = make_object<SimulatedQuantizeAttrs>();
- attrs->kind = kind;
- attrs->sign = sign;
- attrs->rounding = rounding;
- static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
- return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
- });
-
+ .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign,
+ std::string rounding) {
+ auto attrs = make_object<SimulatedQuantizeAttrs>();
+ attrs->kind = kind;
+ attrs->sign = sign;
+ attrs->rounding = rounding;
+ static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
+ return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
+ });
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMQConfigThreadLocalEntry {
/*! \brief The current build config context */
std::stack<QConfig> context_stack;
- TVMQConfigThreadLocalEntry() :
- default_config(make_object<QConfigNode>()) {
- }
+ TVMQConfigThreadLocalEntry() : default_config(make_object<QConfigNode>()) {}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMQConfigThreadLocalEntry> TVMQConfigThreadLocalStore;
void QConfig::EnterQConfigScope(const QConfig& build_config) {
- TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
+ TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
entry->context_stack.push(build_config);
}
void QConfig::ExitQConfigScope() {
- TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
+ TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
entry->context_stack.pop();
}
QConfig& QConfig::Current() {
- TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
+ TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
TVM_REGISTER_NODE_TYPE(QConfigNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<QConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* op = static_cast<const QConfigNode*>(ref.get());
- p->stream << "qconfig(";
- p->stream << "nbit_input=" << op->nbit_input << ", ";
- p->stream << "nbit_weight=" << op->nbit_weight << ", ";
- p->stream << "nbit_activation=" << op->nbit_activation << ", ";
- p->stream << "calibrate_mode=" << op->calibrate_mode << ", ";
- p->stream << "global_scale=" << op->global_scale << ", ";
- p->stream << "weight_scale=" << op->weight_scale << ", ";
- p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
- p->stream << "do_simulation==" << op->do_simulation << ", ";
- p->stream << "round_for_shift==" << op->round_for_shift << ", ";
- p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", ";
- p->stream << "rounding==" << op->rounding;
- p->stream << ")";
-});
-
-TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig")
-.set_body_typed([]() -> QConfig {
+ .set_dispatch<QConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* op = static_cast<const QConfigNode*>(ref.get());
+ p->stream << "qconfig(";
+ p->stream << "nbit_input=" << op->nbit_input << ", ";
+ p->stream << "nbit_weight=" << op->nbit_weight << ", ";
+ p->stream << "nbit_activation=" << op->nbit_activation << ", ";
+ p->stream << "calibrate_mode=" << op->calibrate_mode << ", ";
+ p->stream << "global_scale=" << op->global_scale << ", ";
+ p->stream << "weight_scale=" << op->weight_scale << ", ";
+ p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
+ p->stream << "do_simulation==" << op->do_simulation << ", ";
+ p->stream << "round_for_shift==" << op->round_for_shift << ", ";
+ p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
+ p->stream << "rounding==" << op->rounding;
+ p->stream << ")";
+ });
+
+TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig").set_body_typed([]() -> QConfig {
return QConfig::Current();
});
TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope")
-.set_body_typed(QConfig::EnterQConfigScope);
+ .set_body_typed(QConfig::EnterQConfigScope);
-TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope")
-.set_body_typed(QConfig::ExitQConfigScope);
+TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope").set_body_typed(QConfig::ExitQConfigScope);
} // namespace quantize
} // namespace relay
#ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_
#define TVM_RELAY_QUANTIZE_QUANTIZE_H_
-#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+
#include <string>
+
#include "../transforms/pattern_util.h"
namespace tvm {
namespace quantize {
/*! \brief Kind of annotate field */
-enum QAnnotateKind : int {
- kQIdentity = 0,
- kQInput = 1,
- kQWeight = 2,
- kQActivation = 3
-};
+enum QAnnotateKind : int { kQIdentity = 0, kQInput = 1, kQWeight = 2, kQActivation = 3 };
/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
std::string rounding;
TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
- TVM_ATTR_FIELD(kind)
- .describe("kind of field, hint for nbit/dtype configuration.");
- TVM_ATTR_FIELD(sign).set_default(true)
- .describe("whether to use signed data type.");
- TVM_ATTR_FIELD(rounding).set_default("round")
- .describe("rounding mode. Can be 'floor', 'ceil', 'round'");
+ TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration.");
+ TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type.");
+ TVM_ATTR_FIELD(rounding).set_default("round").describe(
+ "rounding mode. Can be 'floor', 'ceil', 'round'");
}
};
-
class QConfig;
/*!
-* \brief Container for build configuration options
-*/
+ * \brief Container for build configuration options
+ */
class QConfigNode : public Object {
public:
int nbit_input = 8;
};
/*!
-* \brief Container for build configuration options
-*/
+ * \brief Container for build configuration options
+ */
class QConfig : public ObjectRef {
public:
QConfig() {}
explicit QConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
- const QConfigNode* operator->() const {
- return static_cast<const QConfigNode*>(get());
- }
+ const QConfigNode* operator->() const { return static_cast<const QConfigNode*>(get()); }
- QConfigNode* operator->() {
- return static_cast<QConfigNode*>(get_mutable());
- }
+ QConfigNode* operator->() { return static_cast<QConfigNode*>(get_mutable()); }
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/
- explicit QConfigContext(const QConfig& qconfig) {
- QConfig::EnterQConfigScope(qconfig);
- }
+ explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); }
/*! \brief Destructor. Pops the context off the thread local stack. */
- ~QConfigContext() {
- QConfig::ExitQConfigScope();
- }
+ ~QConfigContext() { QConfig::ExitQConfigScope(); }
};
} // namespace quantize
* graph.
*/
-#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
-#include "./quantize.h"
-#include "../transforms/pattern_util.h"
+#include <tvm/relay/transform.h>
+
#include "../qnn/util.h"
+#include "../transforms/pattern_util.h"
+#include "./quantize.h"
namespace tvm {
namespace relay {
TVM_DEFINE_OBJECT_REF_METHODS(QRealizeExpr, TempExpr, QRealizeExprNode);
};
-
class QRealizeIntExprNode : public QRealizeExprNode {
public:
Expr dom_scale;
Expr Realize() const final;
- static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
+ static constexpr const char* _type_key = "relay.quantize.QRealizeIntExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode);
};
TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode);
};
-
Expr QRealizeIntExprNode::Realize() const {
Expr data = this->data;
// dequantize
data_ = std::move(n);
}
-
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args);
}
-
/* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
- const Array<IndexExpr> &data_shape) {
+ const Array<IndexExpr>& data_shape) {
const QConfig& cfg = QConfig::Current();
// here we assume the dtype of data is dtype activation
if (s1 == s2) return data;
float shift_factor = std::log2(factor);
CHECK_GT(shift_factor, 0);
if (static_cast<int>(shift_factor) == shift_factor) {
- return LeftShift(data, MakeConstantScalar(dtype,
- static_cast<int>(shift_factor)));
+ return LeftShift(data, MakeConstantScalar(dtype, static_cast<int>(shift_factor)));
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
}
}
-Expr QuantizeRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
// do not handle data type cast
const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
// use right shift
if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1);
- data = Add(data, MakeConstantScalar(cfg->dtype_activation,
- static_cast<int>(round_bias)));
+ data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
}
- data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
- static_cast<int>(shift_nbit)));
+ data = RightShift(data,
+ MakeConstantScalar(cfg->dtype_activation, static_cast<int>(shift_nbit)));
} else {
- data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
- static_cast<int>(shift_nbit)));
+ data = LeftShift(data,
+ MakeConstantScalar(cfg->dtype_activation, static_cast<int>(shift_nbit)));
}
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExpr(data, dom_scale, n->dtype);
} else {
data = Cast(data, DataType::Int(64));
data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
- ref_call->type_as<TensorTypeNode>()->shape,
- cfg->rounding);
+ ref_call->type_as<TensorTypeNode>()->shape, cfg->rounding);
data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExpr(data, dom_scale, n->dtype);
}
}
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
-
+ .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
-Expr Conv2dRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;
- Expr ret = Call(ref_call->op,
- {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+ Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
Expr dom_scale = FoldConstantOpt(mul);
return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
-RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
-
+RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
-Expr DenseRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) {
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;
- Expr ret = Call(ref_call->op,
- {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+ Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
Expr dom_scale = FoldConstantOpt(mul);
return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
-RELAY_REGISTER_OP("nn.dense")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
+RELAY_REGISTER_OP("nn.dense").set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
-
-Expr MulRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
return Expr(nullptr);
}
-RELAY_REGISTER_OP("multiply")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
-
+RELAY_REGISTER_OP("multiply").set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
if (nptrs.size() == 2) {
}
}
-
/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
DataType* dtype_ptr, Expr* scale_ptr) {
return ret;
}
-Expr AddRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
CHECK_EQ(new_args.size(), 2);
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
DataType dtype;
return Expr(nullptr);
}
-RELAY_REGISTER_OP("add")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
+RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
-Expr ClipRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr ClipRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
attrs->a_min = ref_attrs->a_min / dom_scale;
attrs->a_max = ref_attrs->a_max / dom_scale;
- Expr ret = Call(ref_call->op,
- {n->data}, Attrs(attrs), ref_call->type_args);
+ Expr ret = Call(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args);
return QRealizeIntExpr(ret, n->dom_scale, n->dtype);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>());
return Expr(nullptr);
}
-RELAY_REGISTER_OP("clip")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
-
+RELAY_REGISTER_OP("clip").set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
-Expr ConcatenateRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr ConcatenateRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
CHECK_EQ(new_args.size(), 1);
CHECK_EQ(ref_call->args.size(), 1);
}
}
-RELAY_REGISTER_OP("concatenate")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
-
+RELAY_REGISTER_OP("concatenate").set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
/* \brief forward the original operator */
-Expr IdentityRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr ret = ForwardOp(ref_call, {n->data});
return Expr(nullptr);
}
-RELAY_REGISTER_OP("nn.relu")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
-RELAY_REGISTER_OP("strided_slice")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("annotation.stop_fusion")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+ .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
/* \brief for unary operators which requantize its input to dtype_nbit */
-Expr CastDtypeInputRealize(const Call& ref_call,
- const Array<Expr>& new_args,
+Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
}
RELAY_REGISTER_OP("nn.max_pool2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
-
+ .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
-Expr AvgPoolRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
return Expr(nullptr);
}
-RELAY_REGISTER_OP("nn.avg_pool2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
+RELAY_REGISTER_OP("nn.avg_pool2d").set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
RELAY_REGISTER_OP("nn.global_avg_pool2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
+ .set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
-Expr CastHintRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const ObjectRef& ctx) {
+Expr CastHintRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const auto param = ref_call->attrs.as<CastHintAttrs>();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
}
RELAY_REGISTER_OP("annotation.cast_hint")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
+ .set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
Pass QuantizeRealizePass() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(
- ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
+ };
return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}
-TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize")
-.set_body_typed(QuantizeRealizePass);
+TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize").set_body_typed(QuantizeRealizePass);
} // namespace quantize
} // namespace relay
custom layouts or other general weight pre-transformation.
*/
#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
-#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>
-#include <tuple>
-#include <vector>
+
#include <functional>
#include <string>
-#include <utility>
+#include <tuple>
#include <unordered_map>
+#include <utility>
+#include <vector>
-#include "transform_layout.h"
#include "pattern_util.h"
+#include "transform_layout.h"
namespace tvm {
namespace relay {
}
// TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes.
// Probably we need to disable the AlterOpLayout when compiling dynamic models.
- Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos,
- ref_call->checked_type());
+ Expr altered_value =
+ falter_layout[op](ref_call->attrs, new_args, tinfos, ref_call->checked_type());
if (altered_value.defined()) {
new_e = altered_value;
modified = true;
Pass AlterOpLayout() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
+ };
return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
-.set_body_typed(AlterOpLayout);
+TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout").set_body_typed(AlterOpLayout);
} // namespace transform
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
};
- auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
- {"InferType"});
+ auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}
* \brief Canonicalize cast expressions to make operator fusion more efficient.
*/
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
+
#include "pass_util.h"
#include "pattern_util.h"
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
CHECK(new_call->op == cast_op_);
- return Call(new_call->op, new_call->args, new_call->attrs,
- new_call->type_args);
+ return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args);
}
}
}
}
};
-Expr CanonicalizeCast(const Expr& e) {
- return CastCanonicalizer().Mutate(e);
-}
+Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); }
namespace transform {
Pass CanonicalizeCast() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(CanonicalizeCast(f));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(CanonicalizeCast(f));
+ };
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
-.set_body_typed(CanonicalizeCast);
+TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast);
} // namespace transform
This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.)
*/
#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
-#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
+
#include "pattern_util.h"
namespace tvm {
Pass CanonicalizeOps() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(CanonicalizeOps(f));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(CanonicalizeOps(f));
+ };
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
-.set_body_typed(CanonicalizeOps);
+TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps").set_body_typed(CanonicalizeOps);
} // namespace transform
*/
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+
#include <unordered_map>
#include <unordered_set>
-#include "./expr_subst.h"
+
#include "./combine_parallel_op.h"
+#include "./expr_subst.h"
#include "pattern_util.h"
namespace tvm {
class ParallelConv2DCombiner : public ParallelOpCombiner {
public:
explicit ParallelConv2DCombiner(uint64_t min_num_branches)
- : ParallelOpCombiner("nn.conv2d", min_num_branches) {
- }
+ : ParallelOpCombiner("nn.conv2d", min_num_branches) {}
protected:
- bool IsSupportedOp(const CallNode* n) {
- return n->attrs.as<Conv2DAttrs>()->groups == 1;
- }
+ bool IsSupportedOp(const CallNode* n) { return n->attrs.as<Conv2DAttrs>()->groups == 1; }
bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
StructuralEqual eq;
CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
- const auto shape_a = tir::BijectiveLayout(
- Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
- const auto shape_b = tir::BijectiveLayout(
- Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
+ const auto shape_a =
+ tir::BijectiveLayout(Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
+ const auto shape_b =
+ tir::BijectiveLayout(Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
auto toutput_a = a->type_as<TensorTypeNode>();
auto toutput_b = b->type_as<TensorTypeNode>();
- if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
- return false;
+ if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false;
// Position of the 'C' dimension in the argument
size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size();
for (size_t i = 0; i < ta->shape.size(); i++) {
if (i == arg_channel_pos) continue;
- if (!eq(ta->shape[i], tb->shape[i]))
- return false;
+ if (!eq(ta->shape[i], tb->shape[i])) return false;
}
return true;
}
- Call MakeCombinedCallFromFollowingOps(const Expr& data,
- const Group& branches,
- size_t depth,
+ Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
return Call(call->op, new_args, call->attrs, {});
}
- void UpdateGroupOutput(const Expr& data,
- const Group& branches,
- size_t depth,
+ void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
ExprSubstMap* subst_map) {
int64_t index = 0;
for (const auto& branch : branches) {
Pass CombineParallelConv2D(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
+ };
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
-.set_body_typed(CombineParallelConv2D);
+TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D").set_body_typed(CombineParallelConv2D);
} // namespace transform
*/
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+
#include <unordered_map>
#include <unordered_set>
+
+#include "./combine_parallel_op_batch.h"
#include "./expr_subst.h"
#include "pattern_util.h"
-#include "./combine_parallel_op_batch.h"
namespace tvm {
namespace relay {
class ParallelDenseCombiner : public ParallelOpBatchCombiner {
public:
explicit ParallelDenseCombiner(uint64_t min_num_branches)
- : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {
- }
+ : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}
protected:
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
return eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
- eq(weight_a->shape[0], weight_b->shape[0]) &&
- eq(weight_a->shape[1], weight_b->shape[1]);
+ eq(weight_a->shape[0], weight_b->shape[0]) && eq(weight_a->shape[1], weight_b->shape[1]);
}
};
Pass CombineParallelDense(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(CombineParallelDense(f, min_num_branches));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(CombineParallelDense(f, min_num_branches));
+ };
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
-.set_body_typed(CombineParallelDense);
+TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense").set_body_typed(CombineParallelDense);
} // namespace transform
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
*/
+#include "combine_parallel_op.h"
+
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+
#include <algorithm>
-#include <utility>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
+
#include "expr_subst.h"
#include "pattern_util.h"
-#include "combine_parallel_op.h"
-
namespace tvm {
namespace relay {
-BranchGroupFinder::BranchGroupFinder(const Op& op,
- FIsSupportedOp fis_supported_op,
+BranchGroupFinder::BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops)
- : cached_op_(op),
- fis_supported_op_(fis_supported_op),
- fare_compatible_ops_(fare_compatible_ops) {
-}
+ : cached_op_(op),
+ fis_supported_op_(fis_supported_op),
+ fare_compatible_ops_(fare_compatible_ops) {}
std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
this->VisitExpr(expr);
}
ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
- : cached_op_(Op::Get(op_name)),
- min_num_branches_(min_num_branches) {
-}
+ : cached_op_(Op::Get(op_name)), min_num_branches_(min_num_branches) {}
Expr ParallelOpCombiner::Combine(const Expr& expr) {
- auto groups = BranchGroupFinder(cached_op_,
- [&](const CallNode* n) {
- return IsSupportedOp(n);
- },
- [&](const CallNode* a, const CallNode* b) {
- return CanOpsBeCombined(a, b);
- }).Find(expr);
+ auto groups = BranchGroupFinder(
+ cached_op_, [&](const CallNode* n) { return IsSupportedOp(n); },
+ [&](const CallNode* a, const CallNode* b) { return CanOpsBeCombined(a, b); })
+ .Find(expr);
for (const Group& group : groups) {
if (group.size() < min_num_branches_) {
continue;
void ParallelOpCombiner::CombineBranches(const Group& branches) {
Call combined = MakeCombinedOp(branches);
auto it = std::min_element(branches.begin(), branches.end(),
- [](const Branch& branch_a,
- const Branch& branch_b) {
- return branch_a.size() < branch_b.size();
- });
+ [](const Branch& branch_a, const Branch& branch_b) {
+ return branch_a.size() < branch_b.size();
+ });
size_t depth = it->size();
size_t i;
// starting from 1 to skip the op
}
bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
- const CallNode* call = branches[0][depth];
- tvm::StructuralEqual attrs_equal;
- // check if all branches in current depth can be combined
- for (auto it = branches.begin() + 1; it != branches.end(); it++) {
- const Branch& branch = *it;
- if (!branch[depth]->op.same_as(call->op) ||
- !attrs_equal(branch[depth]->attrs, call->attrs) ||
- branch[depth]->args.size() != call->args.size()) {
- return false;
- }
+ const CallNode* call = branches[0][depth];
+ tvm::StructuralEqual attrs_equal;
+ // check if all branches in current depth can be combined
+ for (auto it = branches.begin() + 1; it != branches.end(); it++) {
+ const Branch& branch = *it;
+ if (!branch[depth]->op.same_as(call->op) || !attrs_equal(branch[depth]->attrs, call->attrs) ||
+ branch[depth]->args.size() != call->args.size()) {
+ return false;
+ }
- if (branch[depth]->args[parent_index].get() != branch[depth - 1])
- return false;
+ if (branch[depth]->args[parent_index].get() != branch[depth - 1]) return false;
- // Check args
- for (size_t i = 0; i < call->args.size(); i++) {
- if (i == parent_index) continue;
+ // Check args
+ for (size_t i = 0; i < call->args.size(); i++) {
+ if (i == parent_index) continue;
- if (!IsArgCompatible(call, branch[depth], i) ||
- !attrs_equal(call->attrs, branch[depth]->attrs)) {
- return false;
- }
+ if (!IsArgCompatible(call, branch[depth], i) ||
+ !attrs_equal(call->attrs, branch[depth]->attrs)) {
+ return false;
}
}
- return true;
}
+ return true;
+}
} // namespace relay
} // namespace tvm
#define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+
+#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include <string>
+
#include "./expr_subst.h"
#include "pattern_util.h"
-
namespace tvm {
namespace relay {
using Branch = std::vector<const CallNode*>;
using Group = std::vector<Branch>;
-using FIsSupportedOp = std::function<bool (const CallNode* n)>;
-using FAreCompatibleOps = std::function<bool (const CallNode* a, const CallNode* b)>;
+using FIsSupportedOp = std::function<bool(const CallNode* n)>;
+using FAreCompatibleOps = std::function<bool(const CallNode* a, const CallNode* b)>;
using ExprSubstMap = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>;
/*
* \param fare_compatible_ops function that returns true if
* two ops are compatible for combining
*/
- BranchGroupFinder(const Op& op,
- FIsSupportedOp fis_supported_op,
+ BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops);
/*
* all combined ops
* \return new combined call
*/
- virtual Call MakeCombinedCallFromFollowingOps(const Expr& data,
- const Group& branches,
- size_t depth,
- size_t parent_index) = 0;
+ virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches,
+ size_t depth, size_t parent_index) = 0;
/*
* \brief Updates map of expr to substitute with combined expr. This usually involves
* \param depth depth at which to substitute
* \param subst_map map of Expr to replace with Expr to replace it with
*/
- virtual void UpdateGroupOutput(const Expr& data,
- const Group& branches,
- size_t depth,
+ virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
ExprSubstMap* subst_map) = 0;
private:
*
*/
+#include "./combine_parallel_op_batch.h"
+
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+
#include <unordered_map>
#include <unordered_set>
-#include "./expr_subst.h"
+
#include "./combine_parallel_op.h"
-#include "./combine_parallel_op_batch.h"
+#include "./expr_subst.h"
#include "pattern_util.h"
namespace tvm {
ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches)
- : ParallelOpCombiner(op_name, min_num_branches),
- batch_op_name_(batch_op_name) {
-}
+ : ParallelOpCombiner(op_name, min_num_branches), batch_op_name_(batch_op_name) {}
-bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) {
- return true;
-}
+bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { return true; }
bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) {
if (a->args.size() != b->args.size()) {
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
- if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
- return false;
+ if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false;
for (size_t i = 0; i < ta->shape.size(); i++) {
- if (!eq(ta->shape[i], tb->shape[i]))
- return false;
+ if (!eq(ta->shape[i], tb->shape[i])) return false;
}
return true;
}
Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data,
- const Group& branches,
- size_t depth,
+ const Group& branches, size_t depth,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
return Call(call->op, new_args, call->attrs, {});
}
-void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data,
- const Group& branches,
- size_t depth,
- ExprSubstMap* subst_map) {
+void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches,
+ size_t depth, ExprSubstMap* subst_map) {
int index = 0;
auto split = MakeSplit(data, Integer(branches.size()), 0);
for (const auto& branch : branches) {
}
/*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */
-Expr CombineParallelOpBatch(const Expr& expr,
- const std::string& op_name,
- const std::string& batch_op_name,
- uint64_t min_num_branches) {
+Expr CombineParallelOpBatch(const Expr& expr, const std::string& op_name,
+ const std::string& batch_op_name, uint64_t min_num_branches) {
return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr);
}
namespace transform {
-Pass CombineParallelOpBatch(const std::string& op_name,
- const std::string& batch_op_name,
+Pass CombineParallelOpBatch(const std::string& op_name, const std::string& batch_op_name,
uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(CombineParallelOpBatch(f,
- op_name,
- batch_op_name,
- min_num_branches));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(
+ CombineParallelOpBatch(f, op_name, batch_op_name, min_num_branches));
+ };
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
-.set_body_typed(CombineParallelOpBatch);
+ .set_body_typed(CombineParallelOpBatch);
} // namespace transform
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+
+#include <string>
#include <unordered_map>
#include <unordered_set>
-#include <string>
-#include "./expr_subst.h"
+
#include "./combine_parallel_op.h"
+#include "./expr_subst.h"
#include "pattern_util.h"
namespace tvm {
* \param min_num_branches min number of parallel branches beginning with op
* to start combining
*/
- ParallelOpBatchCombiner(const std::string& op_name,
- const std::string& batch_op_name,
+ ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name,
uint64_t min_num_branches);
protected:
* all combined ops
* \return new combined call as batch op by stacking args
*/
- Call MakeCombinedCallFromFollowingOps(const Expr& data,
- const Group& branches,
- size_t depth,
+ Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
size_t parent_index) final;
/*
* \param depth depth at which to substitute
* \param subst_map map of Expr to replace with Expr to replace it with
*/
- void UpdateGroupOutput(const Expr& data,
- const Group& branches,
- size_t depth,
+ void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
ExprSubstMap* subst_map) final;
private:
/* \brief name of op to replace combined ops with. for example,
* for combining parallel dense, this will will be set to
- * nn.batch_matmul
+ * nn.batch_matmul
*/
std::string batch_op_name_;
};
custom layouts or other general weight pre-transformation.
*/
#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
-#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>
-#include <tuple>
-#include <vector>
+
#include <functional>
#include <string>
-#include <utility>
+#include <tuple>
#include <unordered_map>
+#include <utility>
+#include <vector>
-#include "transform_layout.h"
#include "pattern_util.h"
+#include "transform_layout.h"
namespace tvm {
namespace relay {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
- return CreateFunctionPass(
- pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
+ return CreateFunctionPass(pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
}
TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
// Remove FreeVar warnings
auto f0 = Downcast<Function>(DenseToSparse(f, weight_name, weight_shape));
Array<Var> sparse_params = FreeVars(f0);
- auto f1 = Function(sparse_params,
- f0->body,
- f0->ret_type,
- f0->type_params,
- f0->attrs);
+ auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
Array<Var> params = FreeVars(f1);
for (const auto& var : sparse_params) {
params.push_back(var);
}
- return Function(params,
- f1->body,
- f1->ret_type,
- f1->type_params,
- f1->attrs);
+ return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs);
};
return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"});
}
* \brief Use a fresh Id for every Var to make the result well-formed.
*/
#include <tvm/ir/type_functor.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
namespace tvm {
namespace relay {
Expr DeDup(const Expr& e) {
- class DeDupMutator : public TypeMutator,
- public ExprMutator,
- public PatternMutator {
+ class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVar(tv->name_hint, tv->kind);
return Let(v, VisitExpr(op->value), VisitExpr(op->body));
}
- Type VisitType(const Type& t) final {
- return t.defined() ? TypeMutator::VisitType(t) : t;
- }
+ Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; }
Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<TypeVar> type_params;
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
- return Function(params,
- VisitExpr(op->body),
- VisitType(op->ret_type),
- type_params,
- op->attrs);
+ return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs);
}
- Pattern VisitPattern(const Pattern& p) final {
- return PatternFunctor::VisitPattern(p);
- }
+ Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); }
- Pattern VisitPattern_(const PatternVarNode* op) final {
- return PatternVar(Fresh(op->var));
- }
+ Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(Fresh(op->var)); }
Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
}
- Var VisitVar(const Var& v) final {
- return Fresh(v);
- }
+ Var VisitVar(const Var& v) final { return Fresh(v); }
private:
std::unordered_map<Var, Var, ObjectHash, ObjectEqual> rename_;
return ret;
}
-TVM_REGISTER_GLOBAL("relay._transform.dedup")
-.set_body_typed(DeDup);
+TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup);
} // namespace relay
} // namespace tvm
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
+
#include "let_list.h"
namespace tvm {
namespace relay {
-template<typename X>
+template <typename X>
using VarMap = std::unordered_map<Var, X, ObjectHash, ObjectEqual>;
using VarSet = std::unordered_set<Var, ObjectHash, ObjectEqual>;
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
bool inline_once_;
- explicit Eliminator(const VarMap<Expr>& expr_map,
- const VarMap<size_t>& use_map,
- bool inline_once) :
- expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { }
+ explicit Eliminator(const VarMap<Expr>& expr_map, const VarMap<size_t>& use_map, bool inline_once)
+ : expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) {}
friend CalcDep;
bool HasLet(const Var& v) {
switch (use_map_[v]) {
- case 0:
- return false;
- case 1:
- return !inline_once_;
- default:
- return true;
+ case 0:
+ return false;
+ case 1:
+ return !inline_once_;
+ default:
+ return true;
}
}
}
private:
- explicit CalcDep(const VarMap<Expr>& expr_map)
- : MixedModeVisitor(2), expr_map_(expr_map) {}
+ explicit CalcDep(const VarMap<Expr>& expr_map) : MixedModeVisitor(2), expr_map_(expr_map) {}
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
}
}
- void VisitExpr_(const LetNode* l) final {
- VisitExpr(l->body);
- }
+ void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); }
void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(DeadCodeElimination(f, inline_once));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(DeadCodeElimination(f, inline_once));
+ };
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination")
-.set_body_typed(DeadCodeElimination);
+TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination").set_body_typed(DeadCodeElimination);
} // namespace transform
* 3. Collect the device allocation of each expression.
*/
-#include <tvm/tir/expr.h>
-#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
+#include <tvm/tir/expr.h>
#include <memory>
#include <unordered_map>
* \return The device type.
*/
int GetDeviceId(const CallNode* call_node) {
- CHECK(IsOnDeviceNode(call_node))
- << "The input call node must be on_device node.";
+ CHECK(IsOnDeviceNode(call_node)) << "The input call node must be on_device node.";
const OnDeviceAttrs* on_device_attr = call_node->attrs.as<OnDeviceAttrs>();
return on_device_attr->device_type;
}
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr tuple = op->tuple;
if (NeedDeviceCopy(tuple.operator->(), op)) {
- Expr new_expr =
- TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index);
+ Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index);
UpdateAnnotationMap(op, new_expr.operator->());
return this->VisitExpr(new_expr);
} else {
}
if (annotated) {
- Call new_call = Call(call_node->op, new_args, call_node->attrs,
- call_node->type_args);
+ Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args);
UpdateAnnotationMap(call_node, new_call.operator->());
return this->VisitExpr(new_call);
return CreateDeviceCopy(src, fallback_device_, dit->second);
} else {
const auto dit = annotation_map_.find(dst);
- int dst_dev_type =
- dit == annotation_map_.end() ? fallback_device_ : dit->second;
+ int dst_dev_type = dit == annotation_map_.end() ? fallback_device_ : dit->second;
return CreateDeviceCopy(src, sit->second, dst_dev_type);
}
}
visitor(expr);
return visitor.annotations_;
}
+
private:
void VisitExpr_(const CallNode* call_node) {
if (IsOnDeviceNode(call_node)) {
// TODO(zhiics) Skip annotation of tuple node for now.
}
- void VisitExpr_(const TupleGetItemNode* op) final {
- ExprVisitor::VisitExpr_(op);
- }
+ void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }
void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
post_dfs_order_.push_back(std::make_pair(in, has_copy_));
}
-
int num_device_copy_ops_{0};
bool has_copy_ = false;
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
- if (it->second) device_map_.Set(GetRef<Expr>(it->first),
- attrs->dst_dev_type);
+ if (it->second) device_map_.Set(GetRef<Expr>(it->first), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
if (it->second) device_map_.Set(expr, cur_dev_type);
}
}
- return out_dev_type;
+ return out_dev_type;
}
void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
- Expr expr = GetRef<Expr>(it.first);
- if (!it.second) device_map_.Set(expr, out_dev_type);
+ Expr expr = GetRef<Expr>(it.first);
+ if (!it.second) device_map_.Set(expr, out_dev_type);
}
}
-
PostDfsOrderVisitor post_visitor_;
Map<Expr, Integer> device_map_;
};
}
CHECK_GT(new_body.size(), 0U);
if (new_body.size() == 1) {
- return Function(params, new_body[0], Type(nullptr),
- fn->type_params, fn->attrs);
+ return Function(params, new_body[0], Type(nullptr), fn->type_params, fn->attrs);
} else if (tuple->fields.size() == new_body.size()) {
- return new_expr;
+ return new_expr;
} else {
Tuple tuple_body = Tuple(new_body);
- return Function(params, tuple_body, Type(nullptr),
- fn->type_params, fn->attrs);
+ return Function(params, tuple_body, Type(nullptr), fn->type_params, fn->attrs);
}
} else {
return new_expr;
if (tuple->fields.size() == new_fields.size()) {
return new_fields.size() == 1 ? new_fields[0] : new_expr;
} else {
- return new_fields.size() == 1 ? new_fields[0]
- : Tuple(new_fields);
+ return new_fields.size() == 1 ? new_fields[0] : Tuple(new_fields);
}
} else {
return new_expr;
}
}
-Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
- return DeviceInfo::GetDeviceMap(expr);
-}
+Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); }
Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
return AnnotatationVisitor::GetAnnotations(expr);
}
-TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo")
-.set_body_typed(CollectDeviceInfo);
+TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo").set_body_typed(CollectDeviceInfo);
TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps")
-.set_body_typed(CollectDeviceAnnotationOps);
+ .set_body_typed(CollectDeviceAnnotationOps);
namespace transform {
Pass RewriteAnnotatedOps(int fallback_device) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
+ };
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
-.set_body_typed(RewriteAnnotatedOps);
+TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation").set_body_typed(RewriteAnnotatedOps);
} // namespace transform
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
+
#include <unordered_map>
+
#include "pattern_util.h"
namespace tvm {
class CommonSubexprEliminator : public ExprMutator {
public:
- explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}
+ explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip) : fskip_(fskip) {}
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
Pass EliminateCommonSubexpr(PackedFunc fskip) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
+ };
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
-.set_body_typed(EliminateCommonSubexpr);
+ .set_body_typed(EliminateCommonSubexpr);
} // namespace transform
*
*/
#include <tvm/ir/type_functor.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
-#include <tvm/relay/expr_functor.h>
namespace tvm {
namespace relay {
type_var_replacer_(TypeVarReplacer()),
expand_constructor_(expand_constructor),
expand_global_var_(expand_global_var) {
- CHECK(expand_constructor || expand_global_var)
- << "must expand at least one language feature";
+ CHECK(expand_constructor || expand_global_var) << "must expand at least one language feature";
}
IRModule Expand() {
for (GlobalVar global_var : mod_->GetGlobalVars()) {
const BaseFunc base_func = mod_->Lookup(global_var);
if (auto* n = base_func.as<FunctionNode>()) {
- const Function new_func = Downcast<Function>(
- VisitExpr(GetRef<Function>(n)));
+ const Function new_func = Downcast<Function>(VisitExpr(GetRef<Function>(n)));
mod_->Update(global_var, new_func);
}
}
Expr body = Call(cons, params, Attrs());
Type ret_type = TypeCall(cons->belong_to, type_params);
- return Function(
- Downcast<tvm::Array<Var>>(params),
- body,
- ret_type,
- Downcast<tvm::Array<TypeVar>>(type_params));
+ return Function(Downcast<tvm::Array<Var>>(params), body, ret_type,
+ Downcast<tvm::Array<TypeVar>>(type_params));
}
Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
return std::move(gvar);
}
const auto base_func = mod_->Lookup(gvar);
- if (auto *ptr = base_func.as<FunctionNode>()) {
+ if (auto* ptr = base_func.as<FunctionNode>()) {
// handle relay function, skip external functions.
auto func = GetRef<Function>(ptr);
tvm::Array<Expr> params;
args.push_back(var);
}
- return Function(
- args,
- Call(gvar, params),
- func->ret_type,
- func->type_params);
+ return Function(args, Call(gvar, params), func->ret_type, func->type_params);
} else {
return std::move(gvar);
}
namespace transform {
Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule mod, PassContext pc) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+ PassContext pc) {
return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand();
};
return CreateModulePass(pass_func, 1, "EtaExpand", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.EtaExpand")
-.set_body_typed(EtaExpand);
+TVM_REGISTER_GLOBAL("relay._transform.EtaExpand").set_body_typed(EtaExpand);
} // namespace transform
* \brief Utility functions for substituting expressions.
*/
-#include <tvm/relay/expr_functor.h>
#include "./expr_subst.h"
+#include <tvm/relay/expr_functor.h>
+
namespace tvm {
namespace relay {
#ifndef TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_
#define TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_
#include <tvm/relay/expr.h>
+
#include <unordered_map>
namespace tvm {
namespace relay {
-Expr ExprSubst(const Expr& expr,
- std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> subst_map);
+Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> subst_map);
} // namespace relay
} // namespace tvm
* \brief Replaces non linear activation functions with their fast but approximate counterparts.
*/
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
-#include <tvm/relay/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
+#include <tvm/relay/transform.h>
+
#include "pattern_util.h"
namespace tvm {
class FastMathMutator : public ExprRewriter {
public:
- FastMathMutator()
- : exp_op_(Op::Get("exp")),
- erf_op_(Op::Get("erf")),
- tanh_op_(Op::Get("tanh")) {}
+ FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {}
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (pre->op == exp_op_) {
Pass FastMath() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(FastMath(f));
- };
+ [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(FastMath(f)); };
return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.FastMath")
-.set_body_typed(FastMath);
+TVM_REGISTER_GLOBAL("relay._transform.FastMath").set_body_typed(FastMath);
} // namespace transform
* \file constant_folding.cc
*/
#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/interpreter.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
-#include <tvm/relay/interpreter.h>
-#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
-#include <tvm/runtime/object.h>
-#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/container.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/object.h>
+
#include "pattern_util.h"
namespace tvm {
return true;
}
const auto it = memo_.find(expr);
- if (it != memo_.end())
- return it->second;
+ if (it != memo_.end()) return it->second;
VisitExpr(expr);
return memo_[expr]; // return memoized result or the default value false
}
}
};
-bool ConstantCheck(const Expr& e) {
- return ConstantChecker().Check(e);
-}
+bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); }
-TVM_REGISTER_GLOBAL("relay.analysis.check_constant")
-.set_body_typed(ConstantCheck);
+TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantCheck);
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
- if (var.same_as(op->var) &&
- value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let(var, value, body);
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
if (skip_list.count(op->name)) {
- return res;
+ return res;
}
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
}
// We should think about potentially constant evaluation over these ops too.
- if (call->op == invoke_tvm_op_ ||
- call->op == shape_func_op_ ||
- call->op == alloc_tensor_op_ ||
+ if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
call->op == alloc_storage_op_) {
return GetRef<Call>(call);
}
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(value);
for (auto dim : nd_array.Shape()) {
- CHECK_GT(dim, 0)
- << "invalid dimension after constant eval";
+ CHECK_GT(dim, 0) << "invalid dimension after constant eval";
}
return Constant(nd_array);
} else if (const auto* val = value.as<runtime::ADTObj>()) {
}
// Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) {
- std::vector<transform::Pass> passes = {transform::FuseOps(0),
- transform::ToANormalForm(),
+ std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::ToANormalForm(),
transform::InferType()};
Function func;
if (expr.as<FunctionNode>()) {
// TODO(@jroesch): fix this
func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
}
- auto mod = IRModule(
- {},
- module_->type_definitions,
- module_->Imports());
+ auto mod = IRModule({}, module_->type_definitions, module_->Imports());
auto global = GlobalVar("main");
mod->Add(global, func);
auto seq = transform::Sequential(passes);
value = runtime::NDArray::Empty({}, cdtype, ctx);
} else {
CHECK_NE(ishape.size(), 0);
- std::vector<int64_t> cshape = { static_cast<int64_t>(ishape.size()) };
+ std::vector<int64_t> cshape = {static_cast<int64_t>(ishape.size())};
value = runtime::NDArray::Empty(cshape, cdtype, ctx);
int32_t* dims = static_cast<int32_t*>(value->data);
using ::tvm::tir::IntImmNode;
// Cast the constant into correct dtype
auto cast_attrs = make_object<CastAttrs>();
cast_attrs->dtype = param->dtype;
- Expr ret = Call(cast_op_, { shape }, Attrs(cast_attrs), {});
+ Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {});
return ConstEvaluate(ret);
}
};
-
Expr FoldConstant(const Expr& expr, const IRModule& mod) {
DLContext ctx;
ctx.device_type = kDLCPU;
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(FoldConstant(f, m));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(FoldConstant(f, m));
+ };
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
-.set_body_typed(FoldConstant);
+TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant);
} // namespace transform
* \brief Fold axis scaling into weights of
* conv/dense operators.
*/
-#include <tvm/tir/data_layout.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
-#include "pattern_util.h"
-#include "pass_util.h"
+#include <tvm/tir/data_layout.h>
+#include "pass_util.h"
+#include "pattern_util.h"
namespace tvm {
namespace relay {
using runtime::TypedPackedFunc;
-
// FoldScaleAxis algorithm:
//
// The general idea is to transform Expr to tuple of
TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode);
};
-Message::Message(const AxesSet& axes, bool require_positive) {
+Message::Message(const AxesSet& axes, bool require_positive) {
auto n = make_object<MessageNode>();
n->axes = axes;
n->require_positive = require_positive;
++j;
} else {
ret.push_back(lhs[i]);
- ++i; ++j;
+ ++i;
+ ++j;
}
}
return ret;
* positive scale is required.
* \return The message containing the result scaling on axes of the input.
*/
-using FForwardPrep = runtime::TypedPackedFunc<
- Array<Message> (const Call& call, const Message& out_message)>;
+using FForwardPrep =
+ runtime::TypedPackedFunc<Array<Message>(const Call& call, const Message& out_message)>;
/*! \brief Axis scale tuple. */
class ScaledExprNode : public TempExprNode {
Expr scale = NullValue<Expr>();
Expr Realize() const final {
- CHECK(!axes.defined())
- << "outstanding scale";
+ CHECK(!axes.defined()) << "outstanding scale";
return value;
}
TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode);
};
-using FForwardRewrite = TypedPackedFunc<
- Expr(const Call& ref_call,
- const Array<Expr>& new_args,
- const Message& message)>;
+using FForwardRewrite = TypedPackedFunc<Expr(const Call& ref_call, const Array<Expr>& new_args,
+ const Message& message)>;
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class ForwardPrep : private ExprVisitor {
public:
- std::unordered_map<const Object*, Message>
- Prepare(const Expr& body) {
+ std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
this->Update(body, NullValue<Message>());
this->VisitExpr(body);
// flist is added in the Post-DFS order
private:
// The invoke list
- std::vector<std::function<void()> > flist_;
+ std::vector<std::function<void()>> flist_;
// The message on each node.
std::unordered_map<const Object*, Message> message_;
// Update the message stored at node.
}
}
// Visitor pattern override.
- void VisitExpr_(const LetNode* call) {
- LOG(FATAL) << "FoldScaleAxis only accept dataflow-form";
- }
+ void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; }
void VisitExpr_(const FunctionNode* op) {
ExprVisitor::VisitExpr_(op);
- auto flazy = [this, op] {
- this->Update(op->body, NullValue<Message>());
- };
+ auto flazy = [this, op] { this->Update(op->body, NullValue<Message>()); };
flist_.push_back(flazy);
}
ExprVisitor::VisitExpr_(call);
// function to be lazily invoked
auto flazy = [this, call]() {
- static const auto& fprep =
- Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
+ static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
// find the message send to this node.
auto it = message_.find(call);
Message out_message;
return {out_message};
}
-Expr ReluForwardRewrite(const Call& ref_call,
- const Array<Expr>& new_args,
- const Message& message) {
+Expr ReluForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const Message& message) {
const auto* input = new_args[0].as<ScaledExprNode>();
if (input == nullptr) return Expr(nullptr);
// return transformed conv2d
auto rnode = make_object<ScaledExprNode>();
- rnode->value = Call(
- ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
+ rnode->value = Call(ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = input->scale;
rnode->axes = input->axes;
return Expr(rnode);
}
-RELAY_REGISTER_OP("nn.relu")
-.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
+RELAY_REGISTER_OP("nn.relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
-RELAY_REGISTER_OP("nn.relu")
-.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
+RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
+ ReluForwardRewrite);
-RELAY_REGISTER_OP("nn.leaky_relu")
-.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
+RELAY_REGISTER_OP("nn.leaky_relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
RELAY_REGISTER_OP("nn.leaky_relu")
-.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
+ .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
// AddSub
Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
return {none, none};
}
-Expr AddSubForwardRewrite(const Call& ref_call,
- const Array<Expr>& new_args,
+Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
const auto* slhs = new_args[0].as<ScaledExprNode>();
const auto* srhs = new_args[1].as<ScaledExprNode>();
if (slhs != nullptr) {
CHECK(srhs == nullptr);
CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
- Expr scale = ExpandBiasToMatchAxis(
- slhs->scale, tlhs->shape.size(), slhs->axes);
+ Expr scale = ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes);
Expr rhs = Divide(new_args[1], scale);
- rnode->value = Call(ref_call->op, {slhs->value, rhs},
- ref_call->attrs, ref_call->type_args);
+ rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
rnode->scale = slhs->scale;
rnode->axes = slhs->axes;
} else {
CHECK(srhs != nullptr);
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
- Expr scale = ExpandBiasToMatchAxis(
- srhs->scale, trhs->shape.size(), srhs->axes);
+ Expr scale = ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes);
Expr lhs = Divide(new_args[0], scale);
- rnode->value = Call(ref_call->op, {lhs, srhs->value},
- ref_call->attrs, ref_call->type_args);
+ rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = srhs->scale;
rnode->axes = srhs->axes;
}
return Expr(rnode);
}
-RELAY_REGISTER_OP("add")
-.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
+RELAY_REGISTER_OP("add").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
-RELAY_REGISTER_OP("add")
-.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
+RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
+ AddSubForwardRewrite);
-RELAY_REGISTER_OP("subtract")
-.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
+RELAY_REGISTER_OP("subtract").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
RELAY_REGISTER_OP("subtract")
-.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
+ .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
// Producer operators
// Multiply produces the scale-axis pair.
-Expr MultiplyForwardRewrite(const Call& ref_call,
- const Array<Expr>& new_args,
+Expr MultiplyForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
if (!message.defined()) return Expr();
const auto& expected_out_axes = message->axes;
}
RELAY_REGISTER_OP("multiply")
-.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
+ .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
// Consumer operators
// Conv2D send out requirement of axis folding.
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
- if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
- c_small_axis < 0 &&
+ if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
}
}
// Conv2D consumes the scale axis during transformation.
-Expr Conv2DForwardRewrite(const Call& ref_call,
- const Array<Expr>& new_args,
+Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
// if data do not have scale, normal transform path.
const auto* sdata = new_args[0].as<ScaledExprNode>();
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
- CHECK(sdata->axes.size() == 1 &&
- c_big_axis == sdata->axes[0]->value);
+ CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value);
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
// match the ic_axis
if (is_depthwise_conv2d) {
- Expr scale = ExpandBiasToMatchAxis(
- sdata->scale, kernel_layout.ndim(), {big_oc_axis});
+ Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, scale);
} else {
- Expr scale = ExpandBiasToMatchAxis(
- sdata->scale, kernel_layout.ndim(), {big_ic_axis});
+ Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ic_axis});
weight = Multiply(weight, scale);
}
// return transformed conv2d
- return Call(
- ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
+ return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}
-RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
+RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
-
+ .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
Expr ForwardFoldScaleAxis(const Expr& data) {
auto message = ForwardPrep().Prepare(data);
- auto fcontext = [&](const Call& call) -> ObjectRef{
+ auto fcontext = [&](const Call& call) -> ObjectRef {
auto it = message.find(call.get());
if (it != message.end()) {
return it->second;
return ObjectRef(nullptr);
}
};
- return ForwardRewrite(
- data, "FScaleAxisForwardRewrite", fcontext);
+ return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
}
//----------------------------------------
* positive scale is required.
* \return Message containing the result scaling on axes of the input.
*/
-using FBackwardPrep = TypedPackedFunc<
- Message(const Call& call, const Array<Message>& in_messages)>;
+using FBackwardPrep = TypedPackedFunc<Message(const Call& call, const Array<Message>& in_messages)>;
-using FBackwardTransform = TypedPackedFunc<
- Expr(const Call& call,
- const Message& message,
- const Expr& scale,
- const BackwardTransformer& transformer)>;
+using FBackwardTransform =
+ TypedPackedFunc<Expr(const Call& call, const Message& message, const Expr& scale,
+ const BackwardTransformer& transformer)>;
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
class BackwardPrep : private ExprVisitor {
public:
// The message on each node.
- std::unordered_map<const Object*, Message>
- Prepare(const Expr& body) {
+ std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
ref_counter_ = GetExprRefCount(body);
this->VisitExpr(body);
return std::move(message_);
// Visit the expression.
void VisitExpr_(const CallNode* call) {
ExprVisitor::VisitExpr_(call);
- static const auto& fprep =
- Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
+ static const auto& fprep = Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
auto f = fprep.get(call->op, nullptr);
if (f == nullptr) return;
auto rit = ref_counter_.find(call);
}
};
-class BackwardTransformerNode :
- public Object,
- private ExprMutator {
+class BackwardTransformerNode : public Object, private ExprMutator {
public:
// Run forward transform.
Expr Fold(Expr expr) {
class BackwardTransformer : public ObjectRef {
public:
BackwardTransformer() {}
- explicit BackwardTransformer(
- ::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
- }
+ explicit BackwardTransformer(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
BackwardTransformerNode* operator->() const {
return static_cast<BackwardTransformerNode*>(get_mutable());
}
using ContainerType = BackwardTransformerNode;
};
-Expr BackwardTransformerNode::Transform(
- const CallNode* call_node, Message message, Expr scale) {
- static const auto& ftransform =
- Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
+Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) {
+ static const auto& ftransform = Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr);
if (f != nullptr) {
const Call call = GetRef<Call>(call_node);
if (it != memo_.end()) {
return it->second;
}
- Expr new_expr = f(GetRef<Call>(call_node),
- message,
- scale,
- GetRef<BackwardTransformer>(this));
+ Expr new_expr = f(GetRef<Call>(call_node), message, scale, GetRef<BackwardTransformer>(this));
memo_[call] = new_expr;
return new_expr;
} else {
}
}
-
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
return in_messages[0];
}
-Expr ReluBackwardTransform(const Call& call,
- const Message& message,
- const Expr& scale,
+Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
- Expr input = transformer->Transform(
- call->args[0], message, scale);
+ Expr input = transformer->Transform(call->args[0], message, scale);
return Call(call->op, {input}, call->attrs, call->type_args);
}
-RELAY_REGISTER_OP("nn.relu")
-.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
+RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
-RELAY_REGISTER_OP("nn.relu")
-.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
+RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
+ ReluBackwardTransform);
RELAY_REGISTER_OP("nn.leaky_relu")
-.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
+ .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
RELAY_REGISTER_OP("nn.leaky_relu")
-.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
+ .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
// AddSub
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
StructuralEqual equal;
- if (in_messages[0].defined() &&
- MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
+ if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
return in_messages[0];
} else if (in_messages[1].defined() &&
MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
return in_messages[1];
- } else if (in_messages[0].defined() &&
- in_messages[1].defined() &&
- equal(in_messages[0]->axes, in_messages[1]->axes) &&
- equal(tlhs->shape, trhs->shape)) {
+ } else if (in_messages[0].defined() && in_messages[1].defined() &&
+ equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) {
// add of two elements.
return in_messages[0];
} else {
}
}
-Expr AddSubBackwardTransform(const Call& call,
- const Message& message,
- const Expr& scale,
+Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
} else if (lhs_message.defined()) {
CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
- Expr rhs = transformer->Transform(
- call->args[1], NullValue<Message>(), NullValue<Expr>());
- Expr rhs_scale = ExpandBiasToMatchAxis(
- scale, tlhs->shape.size(), message->axes);
+ Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
+ Expr rhs_scale = ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes);
rhs = Multiply(rhs, rhs_scale);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (rhs_message.defined()) {
CHECK(equal(message->axes, rhs_message->axes));
- Expr lhs = transformer->Transform(
- call->args[0], NullValue<Message>(), NullValue<Expr>());
+ Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr rhs = transformer->Transform(call->args[1], message, scale);
- Expr lhs_scale = ExpandBiasToMatchAxis(
- scale, trhs->shape.size(), message->axes);
+ Expr lhs_scale = ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes);
lhs = Multiply(lhs, lhs_scale);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else {
}
}
-RELAY_REGISTER_OP("add")
-.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
+RELAY_REGISTER_OP("add").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
-RELAY_REGISTER_OP("add")
-.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
+RELAY_REGISTER_OP("add").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
+ AddSubBackwardTransform);
-RELAY_REGISTER_OP("subtract")
-.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
+RELAY_REGISTER_OP("subtract").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
RELAY_REGISTER_OP("subtract")
-.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
+ .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
// Producer operators
// Multiply produces the scale-axis pair.
-Expr MultiplyBackwardTransform(const Call& call,
- const Message& message,
- const Expr& scale,
+Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
CHECK(!message.defined()) << "outstanding scale";
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
}
RELAY_REGISTER_OP("multiply")
-.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
+ .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
// Consumer operators
// Conv2D send out requirement of axis folding.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
- kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
- c_small_axis < 0 &&
+ kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return Message({c_big_axis}, false);
} else {
}
// Conv2D consumes the scale axis during transformation.
-Expr Conv2DBackwardTransform(const Call& call,
- const Message& message,
- const Expr& scale,
+Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
- CHECK(message->axes.size() == 1 &&
- c_big_axis == message->axes[0]->value);
+ CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value);
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
- Expr data = transformer->Transform(
- call->args[0], NullValue<Message>(), NullValue<Expr>());
- Expr weight = transformer->Transform(
- call->args[1], NullValue<Message>(), NullValue<Expr>());
+ Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
+ Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
// scale on input for deptwise.
- Expr wscale = ExpandBiasToMatchAxis(
- scale, kernel_layout.ndim(), {big_oc_axis});
+ Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, wscale);
- return Call(
- call->op, {data, weight}, call->attrs, call->type_args);
+ return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
+ .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
+ .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
Expr BackwardFoldScaleAxis(const Expr& data) {
return make_object<BackwardTransformerNode>()->Fold(data);
Pass ForwardFoldScaleAxis() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(
- relay::fold_scale_axis::ForwardFoldScaleAxis(f));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(relay::fold_scale_axis::ForwardFoldScaleAxis(f));
+ };
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
-.set_body_typed(ForwardFoldScaleAxis);
+TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis);
Pass BackwardFoldScaleAxis() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(
- relay::fold_scale_axis::BackwardFoldScaleAxis(f));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(relay::fold_scale_axis::BackwardFoldScaleAxis(f));
+ };
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
-.set_body_typed(BackwardFoldScaleAxis);
+TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis").set_body_typed(BackwardFoldScaleAxis);
Pass FoldScaleAxis() {
// FoldScaleAxis pass contains the following three passes. Therefore, we can
// register it as a sequential pass.
- Pass pass = Sequential(
- {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
- "FoldScaleAxis");
+ Pass pass = Sequential({BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
+ "FoldScaleAxis");
return pass;
}
-TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis")
-.set_body_typed(FoldScaleAxis);
+TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis").set_body_typed(FoldScaleAxis);
} // namespace transform
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+
#include "pass_util.h"
namespace tvm {
// so that calling realize repeatively won't hurt perf.
class TempRealizer : private MixedModeMutator {
public:
- Expr Realize(Expr expr) {
- return Mutate(expr);
- }
+ Expr Realize(Expr expr) { return Mutate(expr); }
private:
Expr DispatchVisitExpr(const Expr& expr) final {
ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
std::function<ObjectRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger)
- : rewrite_map_(rewrite_map),
- fcontext_(fcontext),
- fmulti_ref_trigger_(fmulti_ref_trigger) {}
+ : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {}
ForwardRewriter(const FForwardRewrite* rewrite_func,
std::function<ObjectRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger)
- : rewrite_func_(rewrite_func),
- fcontext_(fcontext),
- fmulti_ref_trigger_(fmulti_ref_trigger) {}
-
+ : rewrite_func_(rewrite_func), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {}
// Transform expression.
Expr Rewrite(const Expr& expr) {
TempRealizer realizer_;
// Visit and allow non-realized version.
- Expr GetTempExpr(const Expr& expr, const Expr& post) {
+ Expr GetTempExpr(const Expr& expr, const Expr& post) {
if (fmulti_ref_trigger_ != nullptr) {
Expr ret = post;
auto it = ref_counter_.find(expr.get());
}
// try to rewrite.
if (frewrite != nullptr) {
- Expr res = frewrite(
- ref_call, call_args,
- fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr));
+ Expr res = frewrite(ref_call, call_args,
+ fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr));
if (res.defined()) return res;
// abort, use old rule
for (size_t i = 0; i < call_args.size(); ++i) {
}
}
if (unchanged) return ref_call;
- return Call(
- new_op, call_args, call_node->attrs, call_node->type_args);
+ return Call(new_op, call_args, call_node->attrs, call_node->type_args);
}
};
-Expr ForwardRewrite(const Expr& expr,
- const std::string& rewrite_map_name,
+Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name,
std::function<ObjectRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) {
auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
}
-Expr ForwardRewrite(const Expr& expr,
- const FForwardRewrite& rewrite_func,
+Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
std::function<ObjectRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) {
return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
* \brief This is a backend-aware optimization pass.
* Fuse necessary ops into a single one.
*/
-#include <tvm/tir/op.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
-#include "pattern_util.h"
+#include <tvm/tir/op.h>
+
#include "../../support/arena.h"
+#include "pattern_util.h"
namespace tvm {
namespace relay {
However, at the point of conv2d we do not necessarily know that all the future paths
will merge at the elemwise add. The fusion algorithm applies post-dominator analysis.
- The immediate post-dominator of a node defined by the closest node where all the future path goes into.
- In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows:
+  The immediate post-dominator of a node is defined by the closest node where all the future paths
+  go into. In the above case, the elemwise add is the post-dominator of conv2d. The general
+  algorithm is as follows:
- Construct a DAG of dataflow graph for dominator analysis
- Construct a post-dominator tree which gives immediate post dominator of each node.
- CommitFuse: mark all the nodes between source and post-dominator as the same group.
- We use an Union-Find data structure to manage the groups.
*/
-using support::LinkNode;
using support::LinkedList;
+using support::LinkNode;
constexpr uint32_t kMaxFusedOps = 256;
std::ostringstream os;
for (size_t i = 0; i < post_dfs_order.size(); ++i) {
Node* node = post_dfs_order[i];
- os << "node[" << i << "], "
- << GetRef<ObjectRef>(node->ref)
- << " outputs=[";
+ os << "node[" << i << "], " << GetRef<ObjectRef>(node->ref) << " outputs=[";
for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
os << link->value.node->index << ", ";
}
// Creator of post dominator tree of the dataflow
class IndexedForwardGraph::Creator : private ExprVisitor {
public:
- explicit Creator(support::Arena* arena)
- : arena_(arena) {}
+ explicit Creator(support::Arena* arena) : arena_(arena) {}
IndexedForwardGraph Prepare(const Expr& body) {
this->Update(body, nullptr, kOpaque);
// attribute equal comparator
StructuralEqual attr_equal_;
// Update the message stored at the node.
- void Update(const Expr& node,
- IndexedForwardGraph::Node* parent,
- OpPatternKind pattern) {
+ void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) {
const tvm::Object* key = node.get();
IndexedForwardGraph::Node* current;
auto it = graph_.node_map.find(key);
void AddNode(const tvm::Object* key) {
auto it = graph_.node_map.find(key);
- CHECK(it != graph_.node_map.end())
- << "Cannot find node " << GetRef<ObjectRef>(key);
+ CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<ObjectRef>(key);
IndexedForwardGraph::Node* node = it->second;
CHECK(node->ref == nullptr);
node->ref = key;
Node* node = graph_.node_map.at(op);
DataType dtype = DataType(op->data->dtype);
// This rule must be consistent with code generator.
- bool is_simple_const = (
- dtype == DataType::Int(32) ||
- dtype == DataType::Int(64) ||
- dtype == DataType::Float(32) ||
- dtype == DataType::Float(64) ||
- dtype == DataType::Bool());
+ bool is_simple_const =
+ (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) ||
+ dtype == DataType::Float(64) || dtype == DataType::Bool());
if (op->is_scalar() && is_simple_const) {
node->pattern = kElemWise;
} else {
void VisitExpr_(const CallNode* call) final {
CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call);
- static auto fpattern =
- Op::GetAttr<TOpPattern>("TOpPattern");
+ static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
// Now we set the pattern of this call.
//
// If we see a call mentioning an operator we should mark it with its
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the analysis back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
- const auto* arg_type =
- call->args[i]->checked_type().as<TensorTypeNode>();
+ const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
// specifically check if result type is the same as arguments type
OpPatternKind edge_pattern = op_pattern;
- if (edge_pattern == kBroadcast &&
- arg_type != nullptr &&
- rtype != nullptr &&
+ if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
attr_equal_(rtype->shape, arg_type->shape)) {
edge_pattern = kElemWise;
}
this->AddNode(op);
}
- void VisitExpr_(const VarNode* op) final {
- this->AddNode(op);
- }
+ void VisitExpr_(const VarNode* op) final { this->AddNode(op); }
void VisitExpr_(const LetNode* op) final {
// do not fuse through let.
}
};
-IndexedForwardGraph IndexedForwardGraph::Create(
- support::Arena* arena, const Expr& body) {
+IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) {
return Creator(arena).Prepare(body);
}
* \note This algorithm makes use of the fact that graph is DAG,
* and runs a single pass algorithm via LCA (Least Common Ancestor)
*/
- static DominatorTree PostDom(support::Arena* arena,
- const IndexedForwardGraph& graph);
+ static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph);
private:
// Combine pattern together.
- static OpPatternKind CombinePattern(
- OpPatternKind lhs, OpPatternKind rhs) {
+ static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
if (lhs > rhs) return lhs;
return rhs;
}
* The combined edge pattern across all the parents.
* \return The least common ancestor of the two.
*/
- static Node* LeastCommonAncestor(
- Node* lhs,
- Node* rhs,
- OpPatternKind* edge_pattern) {
+ static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) {
while (lhs != rhs) {
if (lhs == nullptr) return nullptr;
if (rhs == nullptr) return nullptr;
if (lhs->depth < rhs->depth) {
- edge_pattern[0] = CombinePattern(
- edge_pattern[0], rhs->pattern);
+ edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
rhs = rhs->parent;
} else if (rhs->depth < lhs->depth) {
- edge_pattern[0] = CombinePattern(
- edge_pattern[0], lhs->pattern);
+ edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
lhs = lhs->parent;
} else {
- edge_pattern[0] = CombinePattern(
- edge_pattern[0], lhs->pattern);
- edge_pattern[0] = CombinePattern(
- edge_pattern[0], rhs->pattern);
+ edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
+ edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
lhs = lhs->parent;
rhs = rhs->parent;
}
}
};
-
-DominatorTree DominatorTree::PostDom(support::Arena* arena,
- const IndexedForwardGraph& graph) {
+DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
DominatorTree tree;
tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
// reverse topo order
/*! \brief internal field used for deduplication */
std::unordered_set<IndexedForwardGraph::Node*> visited_;
// Internal implelementation of CheckPath
- template<typename F>
- bool CheckPath_(IndexedForwardGraph::Node* src,
- IndexedForwardGraph::Node* sink,
- F fcond) {
+ template <typename F>
+ bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) {
if (visited_.count(src)) return true;
visited_.insert(src);
- Group* gnode = groups_[src->index];
+ Group* gnode = groups_[src->index];
CHECK(gnode != nullptr);
gnode = gnode->FindRoot();
if (!fcond(gnode->pattern, src == sink)) return false;
* \tparam F the condition function, with signature
* \note sink must be a post-dominator of src.
*/
- template<typename F>
- bool CheckPath(IndexedForwardGraph::Node* src,
- IndexedForwardGraph::Node* sink,
- F fcond) {
+ template <typename F>
+ bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) {
CHECK(!src->extern_ref);
visited_.clear();
CHECK(src != sink);
return true;
}
// Combine two patterns together.
- static OpPatternKind CombinePattern(
- OpPatternKind lhs, OpPatternKind rhs) {
+ static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
if (lhs > kBroadcast && rhs > kBroadcast) {
LOG(FATAL) << "Cannot merge two complex group together";
}
if (child->master_ref != nullptr) {
CHECK(parent->master_ref == nullptr);
parent->master_ref = child->master_ref;
- parent->pattern = CombinePattern(
- child->pattern, parent->pattern);
+ parent->pattern = CombinePattern(child->pattern, parent->pattern);
}
}
// Internal implelementation of CommitFuse
- void CommitFuse_(IndexedForwardGraph::Node* src,
- IndexedForwardGraph::Node* sink,
- Group* target) {
+ void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) {
if (src == sink) return;
if (visited_.count(src)) return;
visited_.insert(src);
* \param sink The termination node.
* \note sink must be a post-dominator of src.
*/
- void CommitFuse(IndexedForwardGraph::Node* src,
- IndexedForwardGraph::Node* sink) {
+ void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
Group* target = groups_[sink->index];
visited_.clear();
CHECK(src != sink);
}
// execute the fusion algorithm.
- void RunFuse(const IndexedForwardGraph& graph,
- const DominatorTree& post_dom_tree,
- int phase) {
+ void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) {
for (size_t nid = 0; nid < groups_.size(); ++nid) {
// the group of current node has been specified already.
auto* graph_node = graph.post_dfs_order[nid];
size_t dom_parent_gindex = dom_node->parent->gnode->index;
// refuse the fusion if too many ops are going to be fused together
- if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps)
- continue;
+ if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue;
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (dom_root_group->pattern == kTuple) continue;
if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
// Now we know the tuple has been fused into subsequent injective ops
- auto fcond = [](OpPatternKind kind, bool is_sink) {
- return kind <= kInjective;
- };
+ auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
// dom_root_group can also be tuple, as in inception layers
// CheckPath is needed to avoid fusing two intermediate tuples
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
CHECK(dom_node->parent->gnode != nullptr);
// The fuse can be executed if all the intermediate ops are still broadcast.
- auto fcond = [](OpPatternKind kind, bool is_sink) {
- return kind <= kBroadcast;
- };
+ auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
} else if (group_node->pattern <= kBroadcast) {
// Pre-condition: can only be fused to parent which is injective or reduction.
if (dom_node->parent != nullptr &&
- (dom_node->pattern <= kInjective ||
- dom_node->pattern == kCommReduce)) {
+ (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
// Check if all the intermediate ops are still broadcast.
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
// are allowed be fused to the elemwise/broadcast master.
return kind <= kInjective;
} else {
- return (kind <= kBroadcast ||
- kind == kCommReduce ||
- kind == kInjective ||
+ return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
kind == kOutEWiseFusable);
}
};
// so conv2d always finishes fusing.
if (phase != 1) continue;
// Check if all path are injective.
- auto fcond = [](OpPatternKind kind, bool is_sink) {
- return kind <= kInjective;
- };
+ auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
};
-std::vector<GraphPartitioner::Group*>
-GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
+std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
+ const IndexedForwardGraph& graph) {
this->InitGroups(graph);
if (opt_level_ == 0) return std::move(groups_);
// get post dominator tree
Expr Transform(const Expr& body, int fuse_opt_level) {
// setup the group map.
auto graph = IndexedForwardGraph::Create(&arena_, body);
- auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(
- graph);
+ auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph);
for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
CHECK(graph.post_dfs_order[nid]->ref != nullptr);
gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
return this->Mutate(body);
}
-
private:
/*! \brief Temporary information from each group. */
struct GroupInfo {
// Transform calls.
Expr VisitExpr_(const CallNode* call) {
if (call->op.as<OpNode>()) {
- static auto fnoncomputational =
- Op::GetAttr<TNonComputational>("TNonComputational");
+ static auto fnoncomputational = Op::GetAttr<TNonComputational>("TNonComputational");
if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
return ExprMutator::VisitExpr_(call);
auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> new_args = GetNewArguments(call->args, ret_group);
- auto new_call = Call(
- call->op, new_args, call->attrs, call->type_args);
+ auto new_call = Call(call->op, new_args, call->attrs, call->type_args);
if (ret_group->root_ref == call) {
// This is the root of the group
// If the function has no call, it is not a primitive function.
struct HasCallVisitor : ExprVisitor {
bool has_call = false;
- void VisitExpr_(const CallNode* op) final {
- has_call = true;
- }
+ void VisitExpr_(const CallNode* op) final { has_call = true; }
} visitor;
visitor(body);
const GroupInfo& ginfo = ginfo_[group];
// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string {
- auto it = gmap_.find(expr.get());
- if (it == gmap_.end()) return "";
- std::ostringstream os;
- auto *group = it->second->FindRoot();
- os << " /* group=" << group << " */";
- return os.str();
- });
+ auto it = gmap_.find(expr.get());
+ if (it == gmap_.end()) return "";
+ std::ostringstream os;
+ auto* group = it->second->FindRoot();
+ os << " /* group=" << group << " */";
+ return os.str();
+ });
LOG(INFO) << "Dump of group info:\n" << text;
}
};
Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
- return Downcast<Function>(FuseOps(f, opt_level, m));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
+ return Downcast<Function>(FuseOps(f, opt_level, m));
+ };
return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
-.set_body_typed(FuseOps);
+TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps);
} // namespace transform
* \brief API for Automatic Differentiation for the Relay IR.
*/
#include <tvm/ir/type_functor.h>
-#include <tvm/te/operation.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
-#include "pattern_util.h"
-#include "pass_util.h"
+#include <tvm/te/operation.h>
+
#include "let_list.h"
+#include "pass_util.h"
+#include "pattern_util.h"
namespace tvm {
namespace relay {
* Formally speaking, such requirement mean that the input function is a closed expression -
* that is, it only refer to local variable that is it's parameter, or defined inside it.
* Every top level definition satisfy this criteria.
- * AD can also be run-time, which mean it is merely a function term of AD : (Float[] -> Float[]) -> (Float[] -> Float[]).
- * In relay we currently only support compile-time AD, but it should be enough for a lot of use case.
+ * AD can also be run-time, which means it is merely a function term of AD : (Float[] -> Float[]) ->
+ * (Float[] -> Float[]). In relay we currently only support compile-time AD, but it should be enough
+ * for a lot of use cases.
*
- * In deep learning, the most common way to train a deep neural network is by gradient descent or some of it's variant.
- * Such optimization method require us to input the gradient of neural network, which can be obtained easily using AD.
- * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD!
+ * In deep learning, the most common way to train a deep neural network is by gradient descent or
+ * some of its variants. Such optimization methods require us to input the gradient of the neural
+ * network, which can be obtained easily using AD. In fact, back propagation is essentially
+ * reverse-mode automatic differentiation, a kind of AD!
*/
/*! In relay, automatic differentiation(AD) is a macro,
* (x0, x1, x2, ...) -> Float[] to
* (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)),
* When x0, x1, x2... are Float of different shape.
- * the return value is a pair, with left hand side as the original value, and right hand side as gradient of the input.
- * WithGradientType will take the type of input, and produce the type of output.
- * There are multiple implementation of AD in relay, with different characteristic.
- * However, they all transform the input expr according to WithGradientType.
+ * the return value is a pair, with the left hand side as the original value, and the right hand
+ * side as the gradient of the input. WithGradientType will take the type of input, and produce the
+ * type of output. There are multiple implementations of AD in relay, with different
+ * characteristics. However, they all transform the input expr according to WithGradientType.
*/
Type WithGradientType(const Type&);
// TODO(M.K.): stricter checking
auto ty = t.as<FuncTypeNode>();
CHECK(ty) << "input should be a function";
- return FuncType(ty->arg_types,
- TupleType({
- ty->ret_type,
- TupleType(ty->arg_types)}), {}, {});
+ return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {});
}
//! \brief if the expression is a GlobalVar, transform to it's expression.
* pass.
*/
struct ADValueNode {
- virtual ~ADValueNode() { }
+ virtual ~ADValueNode() {}
template <typename T>
T& get() {
auto ret = dynamic_cast<T*>(this);
struct ADTensor : ADValueNode {
Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication
- ADTensor(LetList* ll, const Expr& forward) :
- forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
+ ADTensor(LetList* ll, const Expr& forward)
+ : forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
this->forward->checked_type_ = forward->checked_type();
}
};
* can compute away this function to obtain a reverse mode program.
*/
struct ADFunction : ADValueNode {
- std::function<ADValue(const Type&,
- const std::vector<ADValue>&,
- const Attrs&,
- const tvm::Array<Type>&)> func;
- explicit ADFunction(const std::function<ADValue(const Type&,
- const std::vector<ADValue>&,
- const Attrs&,
- const tvm::Array<Type>&)>& func) :
- func(func) { }
+ std::function<ADValue(const Type&, const std::vector<ADValue>&, const Attrs&,
+ const tvm::Array<Type>&)>
+ func;
+ explicit ADFunction(const std::function<ADValue(const Type&, const std::vector<ADValue>&,
+ const Attrs&, const tvm::Array<Type>&)>& func)
+ : func(func) {}
};
-struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
+struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Var, ADValue, ObjectHash, ObjectEqual> env;
LetList* ll;
- FirstOrderReverseAD(LetList* ll) : ll(ll) { }
+ FirstOrderReverseAD(LetList* ll) : ll(ll) {}
ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
- CHECK(rev_map.count(op_ref))
- << op->name << " does not have reverse mode defined";
- return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type,
- const std::vector<ADValue>& args,
- const Attrs& attrs,
- const tvm::Array<Type>& type_args) {
- std::vector<Expr> call_args;
- for (const ADValue& adval : args) {
- call_args.push_back(adval->get<ADTensor>().forward);
- }
- auto orig = Call(op_ref, call_args, attrs, type_args);
- orig->checked_type_ = orig_type;
- auto ret = std::make_shared<ADTensor>(ll, orig);
- backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
- tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
- CHECK(args.size() == rev.size());
- for (size_t i = 0; i < args.size(); ++i) {
- args[i]->get<ADTensor>().reverse =
- ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
- }
- });
- return ret;
- });
+ CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
+ return std::make_shared<ADFunction>(
+ [this, op_ref](const Type& orig_type, const std::vector<ADValue>& args, const Attrs& attrs,
+ const tvm::Array<Type>& type_args) {
+ std::vector<Expr> call_args;
+ for (const ADValue& adval : args) {
+ call_args.push_back(adval->get<ADTensor>().forward);
+ }
+ auto orig = Call(op_ref, call_args, attrs, type_args);
+ orig->checked_type_ = orig_type;
+ auto ret = std::make_shared<ADTensor>(ll, orig);
+ backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
+ tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
+ CHECK(args.size() == rev.size());
+ for (size_t i = 0; i < args.size(); ++i) {
+ args[i]->get<ADTensor>().reverse =
+ ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
+ }
+ });
+ return ret;
+ });
}
ADValue VisitExpr_(const ConstantNode* op) final {
ADValue VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
// todo: assert no closure
- return std::make_shared<ADFunction>([this, f](const Type& orig_type,
- const std::vector<ADValue>& args,
- const Attrs& attrs,
- const tvm::Array<Type>& type_args) {
- CHECK_EQ(f->params.size(), args.size());
- for (size_t i = 0; i < f->params.size(); ++i) {
- env[f->params[i]] = args[i];
- }
- return VisitExpr(f->body);
- });
+ return std::make_shared<ADFunction>(
+ [this, f](const Type& orig_type, const std::vector<ADValue>& args, const Attrs& attrs,
+ const tvm::Array<Type>& type_args) {
+ CHECK_EQ(f->params.size(), args.size());
+ for (size_t i = 0; i < f->params.size(); ++i) {
+ env[f->params[i]] = args[i];
+ }
+ return VisitExpr(f->body);
+ });
}
ADValue VisitExpr_(const VarNode* op) final {
const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) {
res.reverse = OnesLike(res.forward);
- for (auto it = reverse_ad.backprop_actions.rbegin();
- it != reverse_ad.backprop_actions.rend();
+ for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
}
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
-TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
-.set_body_typed(FirstOrderGradient);
+TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
}
};
-Type ReverseType(const Type& t) {
- return ReverseADType()(t);
-}
+Type ReverseType(const Type& t) { return ReverseADType()(t); }
/*! \brief Lift a function that transform Tensor to a function that also transform more type
* by doing a structure preserving map.
*/
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
- const std::function<Type(const Type&)>& tf,
- const Type& forward_type,
- const Expr& e,
+ const std::function<Type(const Type&)>& tf, const Type& forward_type, const Expr& e,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (forward_type.as<TensorTypeNode>()) {
tvm::Array<Expr> fields;
tvm::Array<Type> types;
for (size_t i = 0; i < tt->fields.size(); ++i) {
- auto field = LiftTensor(f,
- tf,
- tt->fields[i],
- ll->Push(GetField(e, i)),
- ll);
+ auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll);
fields.push_back(field);
types.push_back(field->checked_type_);
}
/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
* by stitching the references in the AD values.
*/
-void TransferGrads(const Type& forward_type,
- const Expr& from,
- const Expr& to,
- LetList* ll) {
+void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, LetList* ll) {
CHECK(IsAtomic(from)) << from;
CHECK(IsAtomic(to)) << to;
if (forward_type.as<TensorTypeNode>()) {
ll->Push(RefWrite(to_ref, RefRead(from_ref)));
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
- TransferGrads(tt->fields[i],
- ll->Push(TupleGetItem(from, i)),
- ll->Push(TupleGetItem(to, i)),
+ TransferGrads(tt->fields[i], ll->Push(TupleGetItem(from, i)), ll->Push(TupleGetItem(to, i)),
ll);
}
} else {
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
- auto rev = [&](const Expr& e) {
- return Pair(e, ll->Push(RefCreate(ZerosLike(e))));
- };
- auto rev_type = [&](const Type& forward_type) {
- return ReverseType(forward_type);
- };
+ auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); };
+ auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
return LiftTensor(rev, rev_type, forward_type, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
- auto val = [&](const Expr& e) {
- return GetField(e, 0);
- };
- auto val_type = [&](const Type& forward_type) {
- return forward_type;
- };
+ auto val = [&](const Expr& e) { return GetField(e, 0); };
+ auto val_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(val, val_type, forward_type, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
- auto grad = [&](const Expr& e) {
- return ll->Push(RefRead(GetField(e, 1)));
- };
- auto grad_type = [&](const Type& forward_type) {
- return forward_type;
- };
+ auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); };
+ auto grad_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(grad, grad_type, forward_type, e, ll);
}
void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
if (t.as<TensorTypeNode>()) {
- ll->Push(RefWrite(GetField(arg, 1),
- Add(ll->Push(RefRead(GetField(arg, 1))),
- grad)));
+ ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad)));
} else if (auto* tt = t.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
- UpdateGrad(tt->fields[i],
- ll->Push(GetField(arg, i)),
- ll->Push(GetField(grad, i)),
- ll);
+ UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
}
} else {
LOG(FATAL) << "unsupported arg type of operator: " << t;
std::shared_ptr<ADVarMap> ad_vars;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
- explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
- : bp(bp), ad_vars(ad_vars) { }
+ explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}
- Expr VisitCheckpoint(const CallNode *call) {
+ Expr VisitCheckpoint(const CallNode* call) {
const OpNode* op_node = call->op.as<OpNode>();
CHECK(op_node) << "expected op in call";
Op op_ref = GetRef<Op>(op_node);
auto x_var = ll->Push(x);
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefRead(bp));
- Expr nbp = Function(
- {},
- LetList::With([&](LetList* ll) {
- // we need a new ReverseAD visitor to avoid clobbering the bp local var
- auto dup_bp = ll->Push(BPEmpty());
- ReverseAD dup_diff(dup_bp, ad_vars);
- auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
-
- TransferGrads(call->checked_type(), ret, dup_ad, ll);
- ll->Push(Call(RefRead(dup_bp), {}));
- return Call(bpv, {});
- }),
- TupleType::Empty(),
- {});
+ Expr nbp = Function({}, LetList::With([&](LetList* ll) {
+ // we need a new ReverseAD visitor to avoid clobbering the bp local var
+ auto dup_bp = ll->Push(BPEmpty());
+ ReverseAD dup_diff(dup_bp, ad_vars);
+ auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+
+ TransferGrads(call->checked_type(), ret, dup_ad, ll);
+ ll->Push(Call(RefRead(dup_bp), {}));
+ return Call(bpv, {});
+ }),
+ TupleType::Empty(), {});
ll->Push(RefWrite(bp, nbp));
return ret;
});
return VisitCheckpoint(call);
}
- CHECK(rev_map.count(op_ref))
- << op_node->name << " does not have reverse mode defined";
+ CHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : call->args) {
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefRead(bp));
- Expr nbp = Function(
- {},
- LetList::With([&](LetList* ll) {
- tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
- CHECK(args.size() == rev.size());
- for (size_t i = 0; i < args.size(); ++i) {
- UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
- }
- return Call(bpv, {});
- }),
- TupleType::Empty(),
- {});
+ Expr nbp = Function({}, LetList::With([&](LetList* ll) {
+ tvm::Array<Expr> rev =
+ rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
+ CHECK(args.size() == rev.size());
+ for (size_t i = 0; i < args.size(); ++i) {
+ UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
+ }
+ return Call(bpv, {});
+ }),
+ TupleType::Empty(), {});
ll->Push(RefWrite(bp, nbp));
return ret;
});
}
Expr VisitExpr_(const IfNode* op) final {
- return If(TupleGetItem(VisitExpr(op->cond), 0),
- VisitExpr(op->true_branch),
- VisitExpr(op->false_branch));
+ return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch),
+ VisitExpr(op->false_branch));
}
Expr VisitExpr_(const VarNode* var) final {
return ad_vars->at(var_ref);
}
- Type VisitType(const Type& t) final {
- return t.defined() ? ReverseType(t) : t;
- }
+ Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
};
bool MissingGrad(const Expr& e) {
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
-TVM_REGISTER_GLOBAL("relay._transform.gradient")
-.set_body_typed(Gradient);
+TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
} // namespace relay
} // namespace tvm
#ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
#define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
-#include <tvm/tir/data_layout.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
#include <string>
#include <tuple>
+
#include "pattern_util.h"
namespace tvm {
* \return infered_layout An array of two elements that are inferred input layouts and
* inferred output layouts.
*/
-using FInferCorrectLayout = runtime::TypedPackedFunc<
- Array<Array<Layout>>(const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types)>;
+using FInferCorrectLayout = runtime::TypedPackedFunc<Array<Array<Layout>>(
+ const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types)>;
/*! \brief take arbitrary input layout and copy to output */
-inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+inline Array<Array<Layout>> ElemwiseArbitraryLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
Layout ret;
if (new_in_layouts.defined()) {
}
}
- return Array<Array<Layout> >{Array<Layout>(old_in_layouts.size(), ret), {ret}};
+ return Array<Array<Layout>>{Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
/*! \brief Infer layout for binary broadcast operators */
-inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type> &old_in_types) {
+inline Array<Array<Layout>> BinaryBroadcastLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
Array<Layout> layouts;
Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
if (!layouts[0].defined() && !layouts[1].defined()) {
// both undefined, infer fails
- return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
+ return Array<Array<Layout>>{{Layout::Undef()}, {Layout::Undef()}};
} else if (!layouts[0].defined() || !layouts[1].defined()) {
// only one is defined, use shape information to help infer
int defined_idx = layouts[0].defined() ? 0 : 1;
int undef_idx = 1 - defined_idx;
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
- layouts.Set(undef_idx,
- layouts[defined_idx].SubLayout(
- old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
- old_in_shapes[undef_idx].size()));
- return Array<Array<Layout> >{layouts, {layouts[defined_idx]}};
+ layouts.Set(undef_idx, layouts[defined_idx].SubLayout(old_in_shapes[defined_idx].size() -
+ old_in_shapes[undef_idx].size(),
+ old_in_shapes[undef_idx].size()));
+ return Array<Array<Layout>>{layouts, {layouts[defined_idx]}};
} else {
// only know the tensor with smaller dimensions,
// so we cannot infer the final broadcasted output.
// fails in this case.
- return Array<Array<Layout> >{{Layout::Undef()}, {Layout::Undef()}};
+ return Array<Array<Layout>>{{Layout::Undef()}, {Layout::Undef()}};
}
} else if (layouts[0].defined() && layouts[1].defined() &&
- (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) {
+ (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) {
int scalar = layouts[0].ndim() == 0 ? 0 : 1;
- return Array<Array<Layout> >{layouts, {layouts[1-scalar]}};
+ return Array<Array<Layout>>{layouts, {layouts[1 - scalar]}};
} else {
// Set the layout of the larger dimension. If one dimension size is lower, we call expand dims
// while transforming layout.
Op op = Downcast<Op>(call->op);
if (finfer_layout.count(op)) {
Array<Array<Layout>> inferred_layouts;
- inferred_layouts =
- finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types);
+ inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types);
CHECK_EQ(inferred_layouts.size(), 2)
<< "FInferCorrectLayout should return an array with size of 2";
for (auto x : inferred_layouts) {
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/support/logging.h>
#include <tvm/relay/transform.h>
+#include <tvm/support/logging.h>
+
#include <string>
#include <unordered_set>
}
Function Inline(const Function& func) {
- return Function(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
- func->attrs);
+ return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
+ func->attrs);
}
private:
}
// Make a new Relay expression to replace the callee.
- Expr MakeNewExpr(const GlobalVar& global,
- const Array<Expr>& args,
- const Expr& callee) {
- CHECK(callee->IsInstance<CallNode>() ||
- callee->IsInstance<GlobalVarNode>());
+ Expr MakeNewExpr(const GlobalVar& global, const Array<Expr>& args, const Expr& callee) {
+ CHECK(callee->IsInstance<CallNode>() || callee->IsInstance<GlobalVarNode>());
auto base_func = call_graph_->GetGlobalFunction(global);
const auto* fn = base_func.as<FunctionNode>();
CHECK(fn) << "Expected to work on a Relay function.";
- auto func = Function(fn->params,
- fn->body,
- fn->ret_type,
- fn->type_params,
- fn->attrs);
+ auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
// Cannot replace TensorType/TensorTupleType with FuncType. Therefore,
// we simply inline the function as a closure instead of directly using
// its body when the global var returns FuncType.
- return ret_type->IsInstance<FuncTypeNode>() ? std::move(func)
- : func->body;
+ return ret_type->IsInstance<FuncTypeNode>() ? std::move(func) : func->body;
} else {
CHECK(callee->IsInstance<CallNode>());
return Bind(func->body, bind_map);
}
} else if (const auto* call_node = callee.as<CallNode>()) {
- return Call(func, args, call_node->attrs, call_node->type_args);
+ return Call(func, args, call_node->attrs, call_node->type_args);
} else {
return std::move(func);
}
Pass Inline() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
- return relay::Inline(m);
- };
+ [=](IRModule m, PassContext pc) { return relay::Inline(m); };
return CreateModulePass(pass_func, 1, "InlineGlobals", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.Inline")
-.set_body_typed(Inline);
+TVM_REGISTER_GLOBAL("relay._transform.Inline").set_body_typed(Inline);
} // namespace transform
* \brief Lazily instantiate 0-filled or 1-filled tensors.
* This pass should be used after reverse-mode ad so that gradient tensors
* are not instantiated until after the forward pass.
- *
- * This pass delays or removes memory allocation by converting tensors into
+ *
+ * This pass delays or removes memory allocation by converting tensors into
* GradCell, an algebraic data type defined in gradient.rly.
- *
+ *
* This will delay or decrease memory usage. All calls to
* ones, ones_like, zeros, zeros_like will call the One or Zero constructor
* of GradCell, which will not instantiate in memory until needed. All other cases result
* in using the Raw constructor which means the tensor is instantiated in memory.
- *
+ *
* It also overloads + and * operation which can increase performance when doing
* operations involving tensors with values of only 0 or 1.
- *
+ *
* Note: this pass can only be used with functions where the input/output types are
* a combination of TupleTypes and TensorTypes
- *
+ *
* This pass optimizes 6 ops:
* - add
* - multiply
* - ones_like
* - zeros
* - zeros_like
- *
+ *
* This pass makes use of three visitor. The most important one visits the entire function,
* one is used for wrap inputs and one to unwrap outputs.
- *
+ *
* For example:
* fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32]
- *
+ *
* After this pass
* fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]]
- *
+ *
* Thus, it is necessary to wrap this outer function so that the input/output types remain the same
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/ir/type_functor.h>
#include <tvm/relay/transform.h>
+
#include "let_list.h"
namespace tvm {
namespace relay {
/*!
-* \brief Visitor appropriately wraps tensors with Raw constructor
-*
-* Recursively looks at the type of the expression (TensorType or TupleType are only supported for now)
-* and either call the GradCell constructor if TensorType
-* or unfold and recursively visit if TupleType
-*/
-class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
+ * \brief Visitor appropriately wraps tensors with Raw constructor
+ *
+ * Recursively looks at the type of the expression (TensorType or TupleType are only supported for
+ * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if
+ * TupleType
+ */
+class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
public:
- explicit InputVisitor(IRModule module): module_(module) {}
+ explicit InputVisitor(IRModule module) : module_(module) {}
Expr VisitExpr_(const VarNode* op, const Type& t) final {
std::cout << op->type_annotation << std::endl;
Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return WrapExpr(GetRef<TupleGetItem>(op), t);
}
+
private:
IRModule module_;
Expr WrapExpr(const Expr expr, const Type& type) {
if (type.as<TensorTypeNode>()) {
- return Call(module_->GetConstructor("GradCell", "Raw"),
- {expr}, Attrs(), {type});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
} else if (auto* type_anno = type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < type_anno->fields.size(); i++) {
};
/*!
-* \brief Visitor appropriately unwraps expressions with GradCell type into Tensors
-*
-* Recursively looks at the type of the expression
-* and either use the FromGradCell function if TypeCall to GradCell
-* or unfold and recursively visit if TupleType
-*/
-class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
+ * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors
+ *
+ * Recursively looks at the type of the expression
+ * and either use the FromGradCell function if TypeCall to GradCell
+ * or unfold and recursively visit if TupleType
+ */
+class OutputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
public:
- explicit OutputVisitor(IRModule module): module_(module) {}
+ explicit OutputVisitor(IRModule module) : module_(module) {}
Expr VisitExpr_(const CallNode* op, const Type& t) final {
return UnwrapExpr(GetRef<Call>(op), t);
Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return UnwrapExpr(GetRef<TupleGetItem>(op), t);
}
+
private:
IRModule module_;
}
};
-class LazyGradientInitializer: public ExprMutator, public TypeMutator {
+class LazyGradientInitializer : public ExprMutator, public TypeMutator {
public:
- explicit LazyGradientInitializer(IRModule module):
- module_(module) {
- module_->ImportFromStd("gradient.rly");
- }
+ explicit LazyGradientInitializer(IRModule module) : module_(module) {
+ module_->ImportFromStd("gradient.rly");
+ }
/*!
- * \brief apply LazyGradientInit transformation and wrap function
- * so that function type stays the same
- *
- * input/output types should only be a combination of TupleTypes and TensorTypes
- */
+ * \brief apply LazyGradientInit transformation and wrap function
+ * so that function type stays the same
+ *
+ * input/output types should only be a combination of TupleTypes and TensorTypes
+ */
Expr Transform(const Expr& e) {
auto* f = (e).as<FunctionNode>();
auto* transformed = this->Mutate(e).as<FunctionNode>();
}
Expr VisitExpr_(const ConstantNode* op) final {
- return Call(module_->GetConstructor("GradCell", "Raw"),
- {GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef<Constant>(op)}, Attrs(),
+ {op->checked_type()});
}
Expr VisitExpr_(const CallNode* call_node) final {
if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
// fn() -> T, function returns result of the operation
- Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
- {call_node->checked_type()}, {});
+ Expr func =
+ Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {});
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
- return Call(module_->GetConstructor("GradCell", constructor_name),
- {func}, Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
+ {call_node->checked_type()});
}
if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
- return Call(module_->GetConstructor("GradCell", "One"),
- {func}, Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", "One"), {func}, Attrs(),
+ {call_node->checked_type()});
}
// handle all other ops
Expr result = CallPrimitiveOp(call_node);
// wrap result with Raw constructor
- return Call(module_->GetConstructor("GradCell", "Raw"), {result},
- Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
+ {call_node->checked_type()});
}
// not an op
return ExprMutator::VisitExpr_(call_node);
}
- Type VisitType(const Type& t) final {
- return TypeMutator::VisitType(t);
- }
+ Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); }
Type VisitType_(const TensorTypeNode* op) {
GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
IRModule module_;
/*!
- * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type
- */
+ * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type
+ */
Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
// can only use overloaded functions if 2 arguments of same type
if (call_node->args.size() != 2 ||
!tvm::StructuralEqual()(call_node->args[0]->checked_type(),
call_node->args[1]->checked_type())) {
Expr result = CallPrimitiveOp(call_node);
- return Call(module_->GetConstructor("GradCell", "Raw"), {result},
- Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
+ {call_node->checked_type()});
}
tvm::Array<Expr> args;
// create "fallback" function for overloaded function
Type paramType = call_node->args[0]->checked_type();
- tvm::Array<Var> params = {Var("lhs", paramType),
- Var("rhs", paramType)};
+ tvm::Array<Var> params = {Var("lhs", paramType), Var("rhs", paramType)};
// use primitive op in this case
Expr callOp = Call(call_node->op, {params[0], params[1]});
Expr func = Function(params, callOp, paramType, Array<TypeVar>());
}
/*!
- * \brief Convert calls to other ops by converting args into TensorType
- * \return call expr returning result of op
- */
+ * \brief Convert calls to other ops by converting args into TensorType
+ * \return call expr returning result of op
+ */
Expr CallPrimitiveOp(const CallNode* call_node) {
const auto fromFunc = module_->GetGlobalVar("FromGradCell");
tvm::Array<Expr> args;
// use FromGradCell to convert args to Tensor
for (Expr expr : call_node->args) {
- args.push_back(Call(fromFunc,
- {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
+ args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
}
// result of operation
return Call(call_node->op, args);
namespace transform {
Pass LazyGradientInit() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(LazyGradientInit(f, m));
- };
- return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(LazyGradientInit(f, m));
+ };
+ return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit")
-.set_body_typed(LazyGradientInit);
+TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit);
} // namespace transform
* shape, dtype or layout to another op or a sequence of ops.
*/
-#include <tvm/te/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
namespace tvm {
namespace relay {
#ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_
#define TVM_RELAY_TRANSFORMS_LET_LIST_H_
-#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+
+#include <string>
+#include <tuple>
#include <utility>
#include <vector>
-#include <tuple>
-#include <string>
+
#include "tvm/relay/type.h"
namespace tvm {
*
* \return a Var that hold the inserted expr.
*/
- Var Push(Expr expr, Type ty) {
- return Push(Var("x", ty), expr);
- }
+ Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); }
/*!
* \brief insert a binding.
*
* \return a Var that hold the inserted expr.
*/
- Var Push(Expr expr) {
- return Push(expr, Type());
- }
+ Var Push(Expr expr) { return Push(expr, Type()); }
/*!
* \brief wrap an expr around the LetList.
*
* \return the wrapped Expr.
*/
- template<typename F>
+ template <typename F>
static Expr With(F&& f) {
LetList ll;
return ll.Get(f(&ll));
}
static Expr LetBind(const Expr& e, const std::function<Expr(const Var&)>& f) {
- return With([&](LetList* ll) {
- return f(ll->Push(e));
- });
+ return With([&](LetList* ll) { return f(ll->Push(e)); });
}
private:
return func_pass;
}
-TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
tvm::Array<runtime::String> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
std::vector<PackedFunc> checks;
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
-#include "pass_util.h"
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+
#include "let_list.h"
+#include "pass_util.h"
namespace tvm {
namespace relay {
* Use VarHash to hash Var by id.
*/
struct VarHash {
- size_t operator()(const Var& v) const {
- return ObjectHash()(v->vid);
- }
+ size_t operator()(const Var& v) const { return ObjectHash()(v->vid); }
};
/*! \brief Compare Var by it's id.
* Use VarEqual to compare Var by id.
*/
struct VarEqual {
- bool operator()(const Var& l, const Var& r) const {
- return l->vid.get() == r->vid.get();
- }
+ bool operator()(const Var& l, const Var& r) const { return l->vid.get() == r->vid.get(); }
};
Expr PostProcess(const Expr&);
public:
Static() {}
explicit Static(ObjectPtr<Object> n) : ObjectRef(n) {}
- const StaticNode* operator->() const {
- return static_cast<const StaticNode*>(get());
- }
+ const StaticNode* operator->() const { return static_cast<const StaticNode*>(get()); }
using ContainerType = StaticNode;
};
Static pstatic; // may be null
Expr dynamic;
Time created_time;
- PStaticNode(const Static& pstatic, const Expr& dynamic) :
- pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
- explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
+ PStaticNode(const Static& pstatic, const Expr& dynamic)
+ : pstatic(pstatic), dynamic(dynamic), created_time(time()) {}
+ explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) {}
static constexpr const char* _type_key = "relay.PStatic";
TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object);
};
struct STupleNode : StaticNode {
std::vector<PStatic> fields;
- explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { }
+ explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) {}
static constexpr const char* _type_key = "relay.STuple";
TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode);
};
struct STensorNode : StaticNode {
runtime::NDArray data;
- explicit STensorNode(const NDArray& data) : data(data) { }
+ explicit STensorNode(const NDArray& data) : data(data) {}
static constexpr const char* _type_key = "relay.STensor";
TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode);
};
TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode);
};
-Static MkSTensor(const NDArray& data) {
- return Static(make_object<STensorNode>(data));
-}
+Static MkSTensor(const NDArray& data) { return Static(make_object<STensorNode>(data)); }
struct SConstructorNode : StaticNode {
Constructor constructor;
std::vector<PStatic> fields;
- SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) :
- constructor(constructor), fields(fields) { }
+ SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields)
+ : constructor(constructor), fields(fields) {}
static constexpr const char* _type_key = "relay.SConstructor";
TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode);
};
TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode);
};
-Static MkSRef() {
- return Static(make_object<SRefNode>());
-}
+Static MkSRef() { return Static(make_object<SRefNode>()); }
-using Func = std::function<PStatic(const PStatic&,
- const std::vector<PStatic>&,
- const Attrs&,
- const Array<Type>&,
- LetList*)>;
+using Func = std::function<PStatic(const PStatic&, const std::vector<PStatic>&, const Attrs&,
+ const Array<Type>&, LetList*)>;
struct SFuncNode : StaticNode {
Func func;
- explicit SFuncNode(const Func& func) : func(func) { }
+ explicit SFuncNode(const Func& func) : func(func) {}
static constexpr const char* _type_key = "relay.SFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(SFuncNode, StaticNode);
};
TVM_DEFINE_OBJECT_REF_METHODS(SFunc, Static, SFuncNode);
};
-Static MkSFunc(const Func& func) {
- return Static(make_object<SFuncNode>(func));
-}
-
+Static MkSFunc(const Func& func) { return Static(make_object<SFuncNode>(func)); }
class FuelNode;
/*! \brief A meet-semilattice with finite descending chain.
* It means that we can meet two element to get an element,
- * and for every element, there is only a finite amount of meet before getting back the same element.
+ * and for every element, there is only a finite amount of meet before getting back the same
+ * element.
*
* Every time we recurse, we do a meet and require that progress must be made.
* This ensures we do not recurse infinitely in the Partial Evaluator.
TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode);
};
-const FuelNode* Fuel::operator->() const {
- return static_cast<const FuelNode*>(get());
-}
+const FuelNode* Fuel::operator->() const { return static_cast<const FuelNode*>(get()); }
Fuel MkFSeq(const std::vector<Fuel>& fuels);
struct FSeqNode : FuelNode {
}
return MkFSeq(new_fuels);
}
- explicit FSeqNode(const std::vector<Fuel>& fuels) : fuels(fuels) { }
+ explicit FSeqNode(const std::vector<Fuel>& fuels) : fuels(fuels) {}
static constexpr const char* _type_key = "relay.FSeq";
TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode);
};
TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode);
};
-Fuel MkFSeq(const std::vector<Fuel>& fuels) {
- return Fuel(make_object<FSeqNode>(fuels));
-}
+Fuel MkFSeq(const std::vector<Fuel>& fuels) { return Fuel(make_object<FSeqNode>(fuels)); }
Fuel MkFTime(Time time);
struct FTimeNode : FuelNode {
Time new_time = std::min(time, x->time);
return std::make_tuple(MkFTime(new_time), new_time < time);
}
- explicit FTimeNode(Time time) : time(time) { }
+ explicit FTimeNode(Time time) : time(time) {}
static constexpr const char* _type_key = "relay.FTime";
TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode);
};
TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode);
};
-Fuel MkFTime(Time time) {
- return Fuel(make_object<FTimeNode>(time));
-}
+Fuel MkFTime(Time time) { return Fuel(make_object<FTimeNode>(time)); }
Fuel MkFTValue(size_t tvalue);
/*! \brief If the pstatic is hold a positive integer scalar, that number, else 0. */
size_t new_tvalue = std::min(tvalue, x->tvalue);
return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue);
}
- explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { }
+ explicit FTValueNode(size_t tvalue) : tvalue(tvalue) {}
static constexpr const char* _type_key = "relay.FTValue";
TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode);
};
TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode);
};
-Fuel MkFTValue(size_t tvalue) {
- return Fuel(make_object<FTValueNode>(tvalue));
-}
+Fuel MkFTValue(size_t tvalue) { return Fuel(make_object<FTValueNode>(tvalue)); }
/*! \brief Initially every element has Fuel of FTop. It is the largest element.
*
TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode);
};
-Fuel MkFTop() {
- return Fuel(make_object<FTopNode>());
-}
+Fuel MkFTop() { return Fuel(make_object<FTopNode>()); }
/*!
* \brief A stack frame in the Relay interpreter.
class Environment {
public:
- Environment() : env_({Frame()}) { }
+ Environment() : env_({Frame()}) {}
Environment(const Environment&) = delete;
- template<typename T>
+ template <typename T>
T Extend(const std::function<T()>& body) {
FrameContext fc(this);
return body();
struct FrameContext {
Environment* env_;
- explicit FrameContext(Environment* env) : env_(env) {
- env_->env_.push_back(Frame());
- }
- ~FrameContext() {
- env_->env_.pop_back();
- }
+ explicit FrameContext(Environment* env) : env_(env) { env_->env_.push_back(Frame()); }
+ ~FrameContext() { env_->env_.pop_back(); }
};
};
* It only outdate the frame above it, but not the current frame.
*/
bool history_valid = true;
- explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { }
+ explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) {}
StoreFrame() = default;
};
class Store {
public:
- Store() : store_({StoreFrame()}) { }
+ Store() : store_({StoreFrame()}) {}
Store(const Store&) = delete;
- template<typename T>
+ template <typename T>
T Extend(const std::function<T()>& body) {
StoreFrameContext sfc(this);
return body();
return PStatic(make_object<PStaticNode>(stat, dynamic));
}
-PStatic NoStatic(const Expr& dynamic) {
- return PStatic(make_object<PStaticNode>(dynamic));
-}
+PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object<PStaticNode>(dynamic)); }
-enum struct MatchStatus {
- Match, NoMatch, Unknown
-};
+enum struct MatchStatus { Match, NoMatch, Unknown };
bool StatefulOp(const Expr& e) {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
FuncId fid;
TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
- TVM_ATTR_FIELD(fid)
- .describe("The FuncId that an function is annotated with.")
- .set_default(-1);
+ TVM_ATTR_FIELD(fid).describe("The FuncId that an function is annotated with.").set_default(-1);
}
};
TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
-
RELAY_REGISTER_OP("annotation.with_funcid")
-.describe(R"code(Annotate a function with a funcid.)code"
-TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("func", "Function", "The input data.");
+ .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .add_argument("func", "Function", "The input data.");
// Cache with_funcid op to reduce lookup overhead during traversal.
static const Op& with_funcid_op = Op::Get("annotation.with_funcid");
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public:
- PartialEvaluator(const IRModule& mod) : mod_(mod) { }
+ PartialEvaluator(const IRModule& mod) : mod_(mod) {}
PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
return VisitExpr(c->args[0], ll, name);
}
}
- PStatic ret = e.as<FunctionNode>() ?
- VisitFunc(Downcast<Function>(e), ll, name) :
- VisitExpr(e, ll);
+ PStatic ret =
+ e.as<FunctionNode>() ? VisitFunc(Downcast<Function>(e), ll, name) : VisitExpr(e, ll);
CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
return ret;
}
}
}
- PStatic VisitExpr_(const VarNode* op, LetList* ll) final {
- return env_.Lookup(GetRef<Var>(op));
- }
+ PStatic VisitExpr_(const VarNode* op, LetList* ll) final { return env_.Lookup(GetRef<Var>(op)); }
PStatic VisitGlobalVar(const GlobalVar& gv) {
CHECK(mod_.defined());
}
} else {
Expr t = store_.Extend<Expr>([&]() {
- return LetList::With([&](LetList* ll) {
- return VisitExpr(op->true_branch, ll)->dynamic;
- });
- });
+ return LetList::With([&](LetList* ll) { return VisitExpr(op->true_branch, ll)->dynamic; });
+ });
Expr f = store_.Extend<Expr>([&]() {
- return LetList::With([&](LetList* ll) {
- return VisitExpr(op->false_branch, ll)->dynamic;
- });
- });
+ return LetList::With([&](LetList* ll) { return VisitExpr(op->false_branch, ll)->dynamic; });
+ });
store_.Invalidate();
return NoStatic(ll->Push(If(c->dynamic, t, f)));
}
PartialEvaluator* pe_;
FuncId fid_;
Fuel old_fuel;
- FuelFrame(PartialEvaluator* pe,
- FuncId fid,
- const Fuel& new_fuel) : pe_(pe), fid_(fid) {
+ FuelFrame(PartialEvaluator* pe, FuncId fid, const Fuel& new_fuel) : pe_(pe), fid_(fid) {
CHECK_GT(pe_->fuel_map_.count(fid_), 0);
old_fuel = pe_->fuel_map_[fid_];
pe_->fuel_map_[fid_] = new_fuel;
}
- ~FuelFrame() {
- pe_->fuel_map_[fid_] = old_fuel;
- }
+ ~FuelFrame() { pe_->fuel_map_[fid_] = old_fuel; }
};
size_t GetFTValue(const PStatic& ps) {
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
}
}
- return [=](const PStatic& self,
- const std::vector<PStatic>& pv,
- const Attrs& attrs,
- const tvm::Array<Type>& type_args,
- LetList* ll) {
+ return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
+ const tvm::Array<Type>& type_args, LetList* ll) {
return env_.Extend<PStatic>([&]() {
- CHECK_EQ(pv.size(), func->params.size());
- CHECK_GT(func_map_.count(func), 0);
- FuncId fid = func_map_.at(func);
- if (fuel_map_.count(fid) == 0) {
- fuel_map_.insert({fid, MkFTop()});
+ CHECK_EQ(pv.size(), func->params.size());
+ CHECK_GT(func_map_.count(func), 0);
+ FuncId fid = func_map_.at(func);
+ if (fuel_map_.count(fid) == 0) {
+ fuel_map_.insert({fid, MkFTop()});
+ }
+ std::vector<Fuel> args_fuel;
+ for (const auto& v : pv) {
+ args_fuel.push_back(GetFuel(v));
+ }
+ auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
+ if (std::get<1>(meet_res)) {
+ FuelFrame tf(this, fid, std::get<0>(meet_res));
+ Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
+ Function func = AsFunc(dedup_func);
+ if (var.as<VarNode>()) {
+ env_.Insert(Downcast<Var>(var), self);
}
- std::vector<Fuel> args_fuel;
- for (const auto& v : pv) {
- args_fuel.push_back(GetFuel(v));
+ for (size_t i = 0; i < pv.size(); ++i) {
+ env_.Insert(func->params[i], pv[i]);
+ }
+ for (const auto& p : free_vars) {
+ env_.Insert(p.first, p.second);
+ }
+ tvm::Map<TypeVar, Type> subst;
+ for (size_t i = 0; i < type_args.size(); ++i) {
+ subst.Set(func->type_params[i], type_args[i]);
}
- auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
- if (std::get<1>(meet_res)) {
- FuelFrame tf(this, fid, std::get<0>(meet_res));
- Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
- Function func = AsFunc(dedup_func);
- if (var.as<VarNode>()) {
- env_.Insert(Downcast<Var>(var), self);
- }
- for (size_t i = 0; i < pv.size(); ++i) {
- env_.Insert(func->params[i], pv[i]);
- }
- for (const auto& p : free_vars) {
- env_.Insert(p.first, p.second);
- }
- tvm::Map<TypeVar, Type> subst;
- for (size_t i = 0; i < type_args.size(); ++i) {
- subst.Set(func->type_params[i], type_args[i]);
- }
- for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
- subst.Set(func->type_params[i], IncompleteType(kType));
- }
- return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
- } else {
- std::vector<Expr> dyn;
- for (const auto& v : pv) {
- dyn.push_back(v->dynamic);
- }
- return NoStatic(ll->Push(Call(var, dyn, attrs, type_args)));
+ for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
+ subst.Set(func->type_params[i], IncompleteType(kType));
}
- });
+ return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
+ } else {
+ std::vector<Expr> dyn;
+ for (const auto& v : pv) {
+ dyn.push_back(v->dynamic);
+ }
+ return NoStatic(ll->Push(Call(var, dyn, attrs, type_args)));
+ }
+ });
};
}
Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
return store_.Extend<Expr>([&]() {
store_.Invalidate();
- return Function(func->params,
- LetList::With([&](LetList* ll) {
- std::vector<PStatic> pv;
- for (const auto& v : func->params) {
- pv.push_back(NoStatic(v));
- }
- tvm::Array<Type> type_args;
- for (const auto& tp : func->type_params) {
- type_args.push_back(tp);
- }
- return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
- }), func->ret_type, func->type_params, func->attrs);
+ return Function(func->params, LetList::With([&](LetList* ll) {
+ std::vector<PStatic> pv;
+ for (const auto& v : func->params) {
+ pv.push_back(NoStatic(v));
+ }
+ tvm::Array<Type> type_args;
+ for (const auto& tp : func->type_params) {
+ type_args.push_back(tp);
+ }
+ return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
+ }),
+ func->ret_type, func->type_params, func->attrs);
});
}
- PStatic VisitFunc(const Function& func,
- LetList* ll,
- const Var& name = Var("x", Type())) {
+ PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) {
Func f = VisitFuncStatic(func, name);
Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
// TODO(@M.K.): we seems to reduce landin knot into letrec.
// restore letrec support across whole relay.
- return HasStatic(MkSFunc(f),
- ll->Push(name, VisitFuncDynamic(u_func, f, name)));
+ return HasStatic(MkSFunc(f), ll->Push(name, VisitFuncDynamic(u_func, f, name)));
}
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
}
struct ReflectError : dmlc::Error {
- ReflectError() : dmlc::Error("static value not found") { }
+ ReflectError() : dmlc::Error("static value not found") {}
};
Expr Reflect(const PStatic& st) {
// Constant evaluate a expression.
PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
- std::vector<transform::Pass> passes = {transform::FuseOps(0),
- transform::InferType()};
+ std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::InferType()};
auto mod = IRModule::FromExpr(expr);
auto seq = transform::Sequential(passes);
mod = seq(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main"));
- auto fused_infered =
- expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
+ auto fused_infered = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return Reify(executor_(fused_infered), ll);
}
Func ConstEvaluateFunc(const Expr& expr) {
CHECK_EQ(FreeVars(expr).size(), 0);
- return [=](const PStatic& self,
- const std::vector<PStatic>& pv,
- const Attrs& attrs,
- const tvm::Array<Type>& type_args,
- LetList* ll) {
+ return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
+ const tvm::Array<Type>& type_args, LetList* ll) {
tvm::Array<Expr> ns_args;
for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic);
}
- auto ns = [&]() {
- return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args)));
- };
+ auto ns = [&]() { return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); };
if (StatefulOp(expr)) {
return ns();
}
args.push_back(Reflect(ps));
}
return ConstEvaluate(Call(expr, args, attrs, type_args), ll);
- }
- catch (const ReflectError&) {
+ } catch (const ReflectError&) {
return ns();
}
};
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
Constructor c = GetRef<Constructor>(op);
- Func f = [=](const PStatic& self,
- const std::vector<PStatic>& pv,
- const Attrs& attrs,
- const tvm::Array<Type>& type_args,
- LetList* ll) {
+ Func f = [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
+ const tvm::Array<Type>& type_args, LetList* ll) {
tvm::Array<Expr> dyn;
for (const PStatic& ps : pv) {
dyn.push_back(ps->dynamic);
return env_.Extend<PStatic>([&]() {
for (const Clause& c : op->clauses) {
switch (VisitPattern(c->lhs, ps)) {
- case MatchStatus::Match:
- return VisitExpr(c->rhs, ll);
- case MatchStatus::NoMatch:
- continue;
- case MatchStatus::Unknown:
- return [&]() {
- tvm::Array<Clause> clauses;
- for (const Clause& c : op->clauses) {
- Expr expr = store_.Extend<Expr>([&]() {
- return LetList::With([&](LetList* ll) {
- for (const Var& v : BoundVars(c->lhs)) {
- env_.Insert(v, NoStatic(v));
- }
- return VisitExpr(c->rhs, ll)->dynamic;
+ case MatchStatus::Match:
+ return VisitExpr(c->rhs, ll);
+ case MatchStatus::NoMatch:
+ continue;
+ case MatchStatus::Unknown:
+ return [&]() {
+ tvm::Array<Clause> clauses;
+ for (const Clause& c : op->clauses) {
+ Expr expr = store_.Extend<Expr>([&]() {
+ return LetList::With([&](LetList* ll) {
+ for (const Var& v : BoundVars(c->lhs)) {
+ env_.Insert(v, NoStatic(v));
+ }
+ return VisitExpr(c->rhs, ll)->dynamic;
+ });
});
- });
- clauses.push_back(Clause(c->lhs, expr));
- }
- store_.Invalidate();
- return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete)));
- }();
- default:
- LOG(FATAL) << "Unknown MatchStatus";
- throw;
+ clauses.push_back(Clause(c->lhs, expr));
+ }
+ store_.Invalidate();
+ return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete)));
+ }();
+ default:
+ LOG(FATAL) << "Unknown MatchStatus";
+ throw;
}
}
LOG(FATAL) << "No case Match";
for (size_t i = 0; i < op->patterns.size(); ++i) {
MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]);
switch (ms) {
- case MatchStatus::Match:
- continue;
- case MatchStatus::NoMatch:
- return MatchStatus::NoMatch;
- case MatchStatus::Unknown:
- current_match_status = MatchStatus::Unknown;
+ case MatchStatus::Match:
+ continue;
+ case MatchStatus::NoMatch:
+ return MatchStatus::NoMatch;
+ case MatchStatus::Unknown:
+ current_match_status = MatchStatus::Unknown;
}
}
return current_match_status;
for (size_t i = 0; i < op->patterns.size(); ++i) {
MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]);
switch (ms) {
- case MatchStatus::Match:
- continue;
- case MatchStatus::NoMatch:
- return MatchStatus::NoMatch;
- case MatchStatus::Unknown:
- current_match_status = MatchStatus::Unknown;
+ case MatchStatus::Match:
+ continue;
+ case MatchStatus::NoMatch:
+ return MatchStatus::NoMatch;
+ case MatchStatus::Unknown:
+ current_match_status = MatchStatus::Unknown;
}
}
return current_match_status;
void InitializeFuncId(const Expr& e) {
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
- explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
+ explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {}
void VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
VisitExpr(f->body);
}
- void VisitPattern(const Pattern& p) final {
- PatternVisitor::VisitPattern(p);
- }
+ void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
};
InitializeFuncIdVisitor(this).VisitExpr(e);
}
Expr RegisterFuncId(const Expr& e) {
struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
- explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
+ explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {}
void VisitExpr_(const CallNode* op) final {
if (op->op == with_funcid_op) {
ExprVisitor::VisitExpr_(op);
}
- void VisitPattern(const Pattern& p) final {
- PatternVisitor::VisitPattern(p);
- }
+ void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
};
RegisterFuncIdVisitor(this).VisitExpr(e);
return e;
Expr AnnotateFuncId(const Expr& e) {
struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
PartialEvaluator* pe;
- explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { }
+ explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) {}
Expr VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
}
- Pattern VisitPattern(const Pattern& p) final {
- return PatternMutator::VisitPattern(p);
- }
+ Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
- Var VisitVar(const Var& v) final {
- return v;
- }
+ Var VisitVar(const Var& v) final { return v; }
};
return AnnotateFuncIdMutator(this).VisitExpr(e);
}
* If no progress is made, we do not inline.
* In both case, we remap the mapping to the new Fuel
* when we PE inside the Function body.
- * Termination is guaranteed because Fuel is finitely descending - there can only be so many meet.
+ * Termination is guaranteed because Fuel is finitely descending - there can only be so many
+ * meets.
*/
std::unordered_map<Function, FuncId, ObjectHash, ObjectEqual> func_map_;
std::unordered_map<FuncId, Fuel> fuel_map_;
return remap_.at(v);
}
- Var VisitVar(const Var& v) final {
- return Downcast<Var>(VisitExpr(v));
- }
+ Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); }
private:
std::unordered_map<Var, Var, VarHash, VarEqual> remap_;
}
}
- Pattern VisitPattern(const Pattern& p) final {
- return PatternMutator::VisitPattern(p);
- }
+ Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
- Var VisitVar(const Var& v) final {
- return v;
- }
+ Var VisitVar(const Var& v) final { return v; }
};
return StripWithFuncIdMutator().VisitExpr(e);
}
-Expr PostProcess(const Expr& e) {
- return StripWithFuncId(DeDup(Remap(e)));
-}
+Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); }
} // namespace partial_eval
Pass PartialEval() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
- return relay::PartialEval(m);
- };
+ [=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate")
-.set_body_typed(PartialEval);
+TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
} // namespace transform
#ifndef TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
#define TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
-#include <tvm/relay/op.h>
-#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+
#include <memory>
#include <unordered_map>
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
-template<typename ConditionObjectPtr>
+template <typename ConditionObjectPtr>
struct TreeNode {
typedef std::shared_ptr<TreeNode<ConditionObjectPtr>> pointer;
virtual ~TreeNode() {}
};
-template<typename ConditionObjectPtr>
+template <typename ConditionObjectPtr>
struct TreeLeafNode : TreeNode<ConditionObjectPtr> {
using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer;
Expr body;
- explicit TreeLeafNode(Expr body): body(body) {}
+ explicit TreeLeafNode(Expr body) : body(body) {}
- static TreeObjectPtr Make(Expr body) {
- return std::make_shared<TreeLeafNode>(body);
- }
+ static TreeObjectPtr Make(Expr body) { return std::make_shared<TreeLeafNode>(body); }
~TreeLeafNode() {}
};
-template<typename ConditionObjectPtr>
+template <typename ConditionObjectPtr>
struct TreeLeafFatalNode : TreeNode<ConditionObjectPtr> {
using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer;
TreeLeafFatalNode() = default;
- static TreeObjectPtr Make() {
- return std::make_shared<TreeLeafFatalNode>();
- }
+ static TreeObjectPtr Make() { return std::make_shared<TreeLeafFatalNode>(); }
~TreeLeafFatalNode() {}
};
-template<typename ConditionObjectPtr>
+template <typename ConditionObjectPtr>
struct TreeBranchNode : TreeNode<ConditionObjectPtr> {
using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer;
TreeObjectPtr then_branch;
TreeObjectPtr else_branch;
- TreeBranchNode(ConditionObjectPtr cond,
- TreeObjectPtr then_branch,
- TreeObjectPtr else_branch)
- : cond(cond), then_branch(then_branch), else_branch(else_branch) {}
-
+ TreeBranchNode(ConditionObjectPtr cond, TreeObjectPtr then_branch, TreeObjectPtr else_branch)
+ : cond(cond), then_branch(then_branch), else_branch(else_branch) {}
- static TreeObjectPtr Make(ConditionObjectPtr cond,
- TreeObjectPtr then_branch,
- TreeObjectPtr else_branch) {
+ static TreeObjectPtr Make(ConditionObjectPtr cond, TreeObjectPtr then_branch,
+ TreeObjectPtr else_branch) {
return std::make_shared<TreeBranchNode>(cond, then_branch, else_branch);
}
#include <builtin_fp16.h>
#include <tvm/node/structural_equal.h>
-#include <tvm/tir/data_layout.h>
-#include <tvm/relay/op.h>
-#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
-#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
#include <string>
-#include <vector>
#include <utility>
-
+#include <vector>
namespace tvm {
namespace relay {
* \brief Dispatch DataType to the C++ data type
* during runtime.
*/
-#define TVM_DTYPE_DISPATCH(type, DType, ...) \
- if (type == DataType::Float(64)) { \
- typedef double DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Float(32)) { \
- typedef float DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Float(16)) { \
- typedef uint16_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(64)) { \
- typedef int64_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(32)) { \
- typedef int32_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(16)) { \
- typedef int16_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(8)) { \
- typedef int8_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(64)) { \
- typedef uint64_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(32)) { \
- typedef uint32_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(16)) { \
- typedef uint16_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(8)) { \
- typedef uint8_t DType; \
- {__VA_ARGS__} \
- } else { \
- LOG(FATAL) << "unknown data type " << type; \
+#define TVM_DTYPE_DISPATCH(type, DType, ...) \
+ if (type == DataType::Float(64)) { \
+ typedef double DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Float(32)) { \
+ typedef float DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Float(16)) { \
+ typedef uint16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(64)) { \
+ typedef int64_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(32)) { \
+ typedef int32_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(16)) { \
+ typedef int16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(8)) { \
+ typedef int8_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(64)) { \
+ typedef uint64_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(32)) { \
+ typedef uint32_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(16)) { \
+ typedef uint16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(8)) { \
+ typedef uint8_t DType; \
+ { __VA_ARGS__ } \
+ } else { \
+ LOG(FATAL) << "unknown data type " << type; \
}
/*!
* \param rhs_value A squeezed version of rhs which only contains matched dimension.
* \return Whether match is successful.
*/
-inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
- const TensorTypeNode* trhs,
- const Array<Integer>& lhs_axes,
- Expr* rhs_value = nullptr) {
+inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTypeNode* trhs,
+ const Array<Integer>& lhs_axes, Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
StructuralEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
* \param target_ndim Target dimension.
* \param axes The axis on the output we want to match on.
*/
-inline Expr ExpandBiasToMatchAxis(Expr bias,
- int target_ndim,
- const Array<Integer>& axes) {
+inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Integer>& axes) {
static const Op& expand_dims = Op::Get("expand_dims");
for (size_t i = axes.size(); i != 0; --i) {
if (i == axes.size()) {
* \param param The conv2d attributes.
* \return Whether it is depthwise_conv2d.
*/
-inline bool IsDepthwiseConv2D(const Call& call,
- const Conv2DAttrs* param,
+inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param,
const Layout& kernel_layout) {
static const Layout kOIHW("OIHW");
const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW);
auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
- return tir::is_const_int(wshape[0], param->groups) &&
- tir::is_const_int(wshape[1], 1);
+ return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
}
/*!
* \return Super-dimension size of output channels of conv2d.
*/
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
- auto param = call->attrs.as<Conv2DAttrs>();
- auto tweight = call->args[1]->type_as<TensorTypeNode>();
- auto index = param->kernel_layout.find('O');
- CHECK_NE(index, std::string::npos);
- auto channels = tir::as_const_int(tweight->shape[index]);
- return *channels;
+ auto param = call->attrs.as<Conv2DAttrs>();
+ auto tweight = call->args[1]->type_as<TensorTypeNode>();
+ auto index = param->kernel_layout.find('O');
+ CHECK_NE(index, std::string::npos);
+ auto channels = tir::as_const_int(tweight->shape[index]);
+ return *channels;
}
/*!
return tvm::StructuralEqual()(a, b);
}
-inline Expr GetField(Expr t, size_t i) {
- return TupleGetItem(t, i);
-}
+inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); }
-inline Expr Pair(Expr l, Expr r) {
- return Tuple({l, r});
-}
+inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); }
inline Expr Exp(Expr e) {
static const Op& op = Op::Get("exp");
return Call(op, {x}, Attrs(), {});
}
-
inline Expr Sqrt(Expr x) {
static const Op& op = Op::Get("sqrt");
return Call(op, {x}, Attrs(), {});
}
-
inline Expr Relu(Expr x) {
static const Op& op = Op::Get("nn.relu");
return Call(op, {x}, Attrs(), {});
}
-
inline Expr Round(Expr x) {
static const Op& op = Op::Get("round");
return Call(op, {x}, Attrs(), {});
}
-
inline Expr Clip(Expr x, double a_min, double a_max) {
static const Op& op = Op::Get("clip");
auto attrs = make_object<ClipAttrs>();
return Call(op, {x}, Attrs(attrs), {});
}
-
inline Expr Add(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return Call(op, {lhs, rhs}, Attrs(), {});
}
-
inline Expr Subtract(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("subtract");
return Call(op, {lhs, rhs}, Attrs(), {});
}
-
inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return Call(op, {lhs, rhs}, Attrs(), {});
}
-
inline Expr Divide(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("divide");
return Call(op, {lhs, rhs}, Attrs(), {});
return Call(op, {lhs, rhs}, Attrs(), {});
}
-
inline Expr RightShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("right_shift");
return Call(op, {x, nbit}, Attrs(), {});
}
-
inline Expr LeftShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("left_shift");
return Call(op, {x, nbit}, Attrs(), {});
}
-
inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return Call(op, {lhs, rhs}, Attrs(), {});
}
-
inline Expr Copy(Expr data) {
static const Op& op = Op::Get("copy");
return Call(op, {data}, Attrs(), {});
}
-
inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
auto attrs = make_object<ReduceAttrs>();
attrs->axis = std::move(axis);
return Call(op, {data, mean}, Attrs(attrs), {});
}
-
static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
static const Op& op = Op::Get("where");
return Call(op, {condition, x, y});
return Call(op, {lhs, rhs}, Attrs(), {});
}
-static inline Expr Full(Expr fill_value,
- Array<IndexExpr> shape,
- DataType dtype) {
+static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
return Call(op, {data, weight}, Attrs(attrs), {});
}
-static inline Expr Dense(Expr data,
- Expr weight,
- IndexExpr units,
- DataType out_dtype) {
+static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
auto attrs = make_object<DenseAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
// Remove FreeVar warning
auto f0 = Downcast<Function>(SimplifyFCTranspose(f, target_weights));
Array<Var> wt_params = FreeVars(f0);
- auto f1 = Function(wt_params,
- f0->body,
- f0->ret_type,
- f0->type_params,
- f0->attrs);
+ auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
Array<Var> params = FreeVars(f1);
for (const auto& var : wt_params) {
params.push_back(var);
}
- return Function(params,
- f1->body,
- f1->ret_type,
- f1->type_params,
- f1->attrs);
+ return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs);
};
return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"});
}
* \file simplify_inference.cc
*/
#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
-#include <tvm/relay/transform.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
+#include <tvm/relay/transform.h>
+
#include "pattern_util.h"
namespace tvm {
namespace relay {
-Expr BatchNormToInferUnpack(const Attrs attrs,
- Expr data,
- Expr gamma,
- Expr beta,
- Expr moving_mean,
- Expr moving_var,
- Type tdata) {
+Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
+ Expr moving_var, Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<BatchNormAttrs>();
return out;
}
-
-Expr GroupNormToInferUnpack(const Attrs attrs,
- Expr data,
- Expr gamma,
- Expr beta,
- Type tdata) {
+Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<GroupNormAttrs>();
// new shape = N, num_groups, C/num_groups, H, W
// reduce_axes = axis of (C/num_groups, H, W)
for (int i = 0; i < ndim; ++i) {
- auto val = ttype->shape[i].as<IntImmNode>()->value;
-
- // Save the old shape to reshape later
- old_shape.push_back(val);
- if (i == axis) {
- new_shape.push_back(num_groups);
- new_shape.push_back(channel / num_groups);
- reduced_axes.push_back(i + 1);
- continue;
- }
- if (i >= axis) {
- reduced_axes.push_back(i + 1);
- }
- new_shape.push_back(val);
+ auto val = ttype->shape[i].as<IntImmNode>()->value;
+
+ // Save the old shape to reshape later
+ old_shape.push_back(val);
+ if (i == axis) {
+ new_shape.push_back(num_groups);
+ new_shape.push_back(channel / num_groups);
+ reduced_axes.push_back(i + 1);
+ continue;
+ }
+ if (i >= axis) {
+ reduced_axes.push_back(i + 1);
+ }
+ new_shape.push_back(val);
}
data = Reshape(data, new_shape);
return out;
}
-Expr LayerNormToInferUnpack(const Attrs attrs,
- Expr data,
- Expr gamma,
- Expr beta,
- Type tdata) {
+Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<LayerNormAttrs>();
return out;
}
-Expr InstanceNormToInferUnpack(const Attrs attrs,
- Expr data,
- Expr gamma,
- Expr beta,
- Type tdata) {
+Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<InstanceNormAttrs>();
int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
Array<Integer> reduced_axes;
for (int i = 1; i < ndim; ++i) {
- if (i != axis)
- reduced_axes.push_back(i);
+ if (i != axis) reduced_axes.push_back(i);
}
Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast<float>(param->epsilon));
std::unordered_map<Expr, Type, ObjectHash, ObjectEqual> ty_map_;
};
-Expr SimplifyInference(const Expr& e) {
- return InferenceSimplifier().Mutate(e);
-}
+Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); }
namespace transform {
Pass SimplifyInference() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(SimplifyInference(f));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(SimplifyInference(f));
+ };
return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}
-TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
-.set_body_typed(SimplifyInference);
+TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference);
} // namespace transform
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h>
-#include "let_list.h"
-#include "pass_util.h"
+
#include "../../support/arena.h"
#include "../analysis/dependency_graph.h"
+#include "let_list.h"
+#include "pass_util.h"
namespace tvm {
namespace relay {
size_t level;
Scope parent;
std::shared_ptr<LetList> ll = std::make_shared<LetList>();
- explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { }
- ScopeNode() : level(0) { }
+ explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
+ ScopeNode() : level(0) {}
};
-Scope ChildScope(const Scope& s) {
- return std::make_shared<ScopeNode>(s);
-}
+Scope ChildScope(const Scope& s) { return std::make_shared<ScopeNode>(s); }
Scope LCA(Scope lhs, Scope rhs) {
while (lhs != rhs) {
*/
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public:
- static Expr ToANormalForm(const Expr& e,
- const DependencyGraph& dg,
+ static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) {
Fill fi(dg, node_scope);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo;
- Fill(const DependencyGraph& dg,
- std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) :
- dg_(dg),
- node_scope_(node_scope) { }
+ Fill(const DependencyGraph& dg, std::unordered_map<DependencyGraph::Node*, Scope>* node_scope)
+ : dg_(dg), node_scope_(node_scope) {}
- Scope GetScope(const Expr& e) {
- return node_scope_->at(dg_.expr_node.at(e));
- }
+ Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }
Scope GetSubScope(const Expr& e, size_t i) {
DependencyGraph::Node* n = dg_.expr_node.at(e);
return ret;
}
- Expr VisitExpr(const Expr& e) {
- return this->VisitExpr(e, Var());
- }
+ Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }
- Expr Atomic(const Expr& e, const Var& v) {
- return v.defined() ? GetScope(e)->ll->Push(v, e) : e;
- }
+ Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? GetScope(e)->ll->Push(v, e) : e; }
Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
- Var var = v.defined() ?
- v :
- Var(std::string("x"), Type());
+ Var var = v.defined() ? v : Var(std::string("x"), Type());
return GetScope(orig)->ll->Push(var, now);
}
Expr VisitExpr_(const IfNode* i, const Var& v) final {
Expr e = GetRef<Expr>(i);
- Expr ret = If(VisitExpr(i->cond),
- GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
- GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
+ Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
+ GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
return Compound(e, ret, v);
}
if (f->HasNonzeroAttr(attr::kPrimitive)) {
ret = e;
} else {
- ret = Function(f->params,
- GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
- f->ret_type,
- f->type_params,
- f->attrs);
+ ret = Function(f->params, GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), f->ret_type,
+ f->type_params, f->attrs);
}
return Compound(e, ret, v);
}
Expr data = VisitExpr(m->data);
std::vector<Clause> clauses;
for (const Clause& c : m->clauses) {
- clauses.push_back(Clause(
- c->lhs,
- GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
+ clauses.push_back(
+ Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
}
return Compound(e, Match(data, clauses, m->complete), v);
}
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
- Expr ret =
- TransformF([&](const Expr& e) {
- return ToANormalFormAux(e);
- }, it.second);
+ Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second);
CHECK_EQ(FreeVars(ret).size(), 0)
- << AsText(ret)
- << "should not has free vars: "
- << FreeVars(ret);
+ << AsText(ret) << "should not has free vars: " << FreeVars(ret);
updates.Set(it.first, Downcast<Function>(ret));
}
Pass ToANormalForm() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
- return relay::ToANormalForm(m);
- };
+ [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); };
return CreateModulePass(pass_func, 1, "ToANormalForm", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm")
-.set_body_typed(ToANormalForm);
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm);
} // namespace transform
* wheter directly invoking it, or indirectly by recursion.
*/
#include <tvm/ir/type_functor.h>
-#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+
#include "let_list.h"
#include "pass_util.h"
// we assume the data type has no closure - no idea how to look into datatype right now.
-Type Arrow(const Type& l, const Type& r) {
- return FuncType({l}, r, {}, {});
-}
+Type Arrow(const Type& l, const Type& r) { return FuncType({l}, r, {}, {}); }
Type CPSType(const Type& t, const TypeVar& answer);
Type CPSType(const Type& t, const TypeVar& answer) {
struct CPSTypeMutator : TypeMutator {
- explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) { }
+ explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) {}
TypeVar answer;
Type VisitType_(const FuncTypeNode* t) final {
return CPSFuncType(GetRef<FuncType>(t), answer);
Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm);
-Function ToCPS(const Function& f,
- const IRModule& m,
- CPSMap* cm,
- VarMap* vm,
+Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
const TypeVar& answer) {
- std::function<Var(Var)> remap = [&](const Var& v) {
- return vm->count(v) == 0 ? v : vm->at(v);
- };
+ std::function<Var(Var)> remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); };
auto function_type = Downcast<FuncType>(f->checked_type());
// Each MCont can be used at most once.
struct CPSFunctor : ExprFunctor<Expr(const Expr&, const MCont&)>, PatternMutator {
- CPSFunctor(const std::function<Var(Var)>& remap,
- const TypeVar& answer,
- const IRModule& m,
- VarMap* vm,
- CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { }
+ CPSFunctor(const std::function<Var(Var)>& remap, const TypeVar& answer, const IRModule& m,
+ VarMap* vm, CPSMap* cm)
+ : remap(remap), answer(answer), m(m), vm(vm), cm(cm) {}
const std::function<Var(Var)>& remap;
TypeVar answer;
IRModule m;
CPSMap* cm;
Expr VisitExpr_(const LetNode* op, const MCont& k) final {
- return VisitExpr(op->value, [&](const Expr& v) {
- return Let(remap(op->var), v, VisitExpr(op->body, k));
- });
+ return VisitExpr(
+ op->value, [&](const Expr& v) { return Let(remap(op->var), v, VisitExpr(op->body, k)); });
}
Expr VisitExpr_(const FunctionNode* op, const MCont& k) final {
return k(GetRef<Constant>(op));
}
- Expr VisitExpr_(const VarNode* op, const MCont& k) final {
- return k(remap(GetRef<Var>(op)));
- }
+ Expr VisitExpr_(const VarNode* op, const MCont& k) final { return k(remap(GetRef<Var>(op))); }
- Pattern VisitPattern_(const PatternVarNode* op) final {
- return PatternVar(remap(op->var));
- }
+ Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(remap(op->var)); }
Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
auto gv = GetRef<GlobalVar>(op);
}
Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) {
- return LetList::LetBind(reify(k),
- [&](const Var& f) {
+ return LetList::LetBind(reify(k), [&](const Var& f) {
return cont([&](const Expr& e) { return Call(f, {e}); });
});
}
Expr VisitExpr_(const IfNode* op, const MCont& k) final {
return reify(k, [&](const MCont& kf) {
- return VisitExpr(op->cond,
- [&](const Expr& v) {
+ return VisitExpr(op->cond, [&](const Expr& v) {
return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf));
});
});
}
Expr VisitExpr_(const RefReadNode* op, const MCont& k) final {
- return VisitExpr(op->ref,
- [&](const Expr& r) {
- return LetList::LetBind(RefRead(r), k);
- });
+ return VisitExpr(op->ref, [&](const Expr& r) { return LetList::LetBind(RefRead(r), k); });
}
Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final {
- return VisitExpr(op->ref,
- [&](const Expr& r) {
+ return VisitExpr(op->ref, [&](const Expr& r) {
return VisitExpr(op->value,
- [&](const Expr& v) {
- return LetList::LetBind(RefWrite(r, v), k);
- });
+ [&](const Expr& v) { return LetList::LetBind(RefWrite(r, v), k); });
});
}
tvm::Array<Expr> fields;
std::function<Expr()> next;
next = [&]() {
- return (fields.size() == op->fields.size()) ?
- k(Tuple(fields)) :
- VisitExpr(op->fields[fields.size()], [&](const Expr& v) {
- fields.push_back(v);
- return next();
- });
+ return (fields.size() == op->fields.size())
+ ? k(Tuple(fields))
+ : VisitExpr(op->fields[fields.size()], [&](const Expr& v) {
+ fields.push_back(v);
+ return next();
+ });
};
return next();
}
Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final {
- return VisitExpr(op->tuple, [&](const Expr& v) {
- return k(TupleGetItem(v, op->index));
- });
+ return VisitExpr(op->tuple, [&](const Expr& v) { return k(TupleGetItem(v, op->index)); });
}
Expr VisitExpr_(const CallNode* op, const MCont& k) final {
return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k);
} else {
return VisitExpr(op->args[args.size()], [&](const Expr& v) {
- args.push_back(v);
- return next();
- });
+ args.push_back(v);
+ return next();
+ });
}
};
return next();
return next();
});
}
- };
+ };
return VisitExpr(op->op, [&](const Expr& v) {
f = v;
return next();
new_params.push_back(remap(v));
}
new_params.push_back(k);
- return Function(new_params,
- mut.VisitExpr(f->body,
- [&](const Expr& e) { return Call(k, {e}); }),
- answer,
- f->type_params,
- f->attrs);
+ return Function(new_params, mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }),
+ answer, f->type_params, f->attrs);
}
Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
TypeVar answer = TypeVar("answer", kType);
VarMap var;
struct Remapper : ExprVisitor, PatternVisitor {
- Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { }
+ Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) {}
TypeVar answer;
VarMap* vm;
void VisitExpr_(const VarNode* vn) final {
}
}
- void VisitPattern(const Pattern& p) final {
- PatternVisitor::VisitPattern(p);
- }
+ void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
- void VisitPattern_(const PatternVarNode* op) final {
- VisitExpr(op->var);
- }
+ void VisitPattern_(const PatternVarNode* op) final { VisitExpr(op->var); }
} remap(answer, &var);
remap.VisitExpr(f);
Function ret = ToCPS(f, m, cm, &var, answer);
type_args.push_back(tp);
}
type_args.push_back(new_ret_type);
- return Function(new_params,
- Call(f, args, {}, type_args),
- new_ret_type,
- new_type_params,
- f->attrs);
+ return Function(new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params,
+ f->attrs);
}
TVM_REGISTER_GLOBAL("relay._transform.to_cps")
-.set_body_typed(static_cast<Function (*)(const Function&, const IRModule&)>(ToCPS));
+ .set_body_typed(static_cast<Function (*)(const Function&, const IRModule&)>(ToCPS));
-TVM_REGISTER_GLOBAL("relay._transform.un_cps")
-.set_body_typed(UnCPS);
+TVM_REGISTER_GLOBAL("relay._transform.un_cps").set_body_typed(UnCPS);
namespace transform {
Pass ToCPS() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Function(ToCPS(f, m));
- };
+ [=](Function f, IRModule m, PassContext pc) { return Function(ToCPS(f, m)); };
return CreateFunctionPass(pass_func, 1, "ToCPS", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.ToCPS")
-.set_body_typed(ToCPS);
-
+TVM_REGISTER_GLOBAL("relay._transform.ToCPS").set_body_typed(ToCPS);
Pass UnCPS() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Function(UnCPS(f));
- };
+ [=](Function f, IRModule m, PassContext pc) { return Function(UnCPS(f)); };
return CreateFunctionPass(pass_func, 1, "UnCPS", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.UnCPS")
-.set_body_typed(UnCPS);
+TVM_REGISTER_GLOBAL("relay._transform.UnCPS").set_body_typed(UnCPS);
} // namespace transform
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
+
#include "let_list.h"
namespace tvm {
class UseVarVisitor : public ExprVisitor {
public:
- explicit UseVarVisitor(const Var& v) : v(v) { }
+ explicit UseVarVisitor(const Var& v) : v(v) {}
static bool UseVar(const Var& v, const Expr& e) {
UseVarVisitor uv(v);
bool use_var = false;
Var v;
- void VisitExpr_(const VarNode* vn) override {
- use_var = use_var || (v == GetRef<Var>(vn));
- }
+ void VisitExpr_(const VarNode* vn) override { use_var = use_var || (v == GetRef<Var>(vn)); }
};
class GNF : public ExprMutator {
return var_map_.count(v) == 0 ? v : var_map_.at(v);
}
- static bool UseVar(const Var& v, const Expr& e) {
- return UseVarVisitor::UseVar(v, e);
- }
+ static bool UseVar(const Var& v, const Expr& e) { return UseVarVisitor::UseVar(v, e); }
static Expr WrapRec(const Var& var, const Expr& val) {
return UseVar(var, val) ? Let(var, val, var) : val;
}
};
-Expr ToGraphNormalForm(const Expr& e) {
- return GNF()(e);
-}
+Expr ToGraphNormalForm(const Expr& e) { return GNF()(e); }
namespace transform {
Pass ToGraphNormalForm() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(ToGraphNormalForm(f));
- };
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(ToGraphNormalForm(f));
+ };
return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm")
-.set_body_typed(ToGraphNormalForm);
+TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm").set_body_typed(ToGraphNormalForm);
} // namespace transform
#ifndef TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_
#define TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_
-#include <tvm/tir/data_layout.h>
#include <tvm/relay/expr.h>
+#include <tvm/tir/data_layout.h>
+
#include <string>
-#include <unordered_map>
#include <tuple>
+#include <unordered_map>
#include <vector>
-#include "pattern_util.h"
+
#include "infer_layout_util.h"
+#include "pattern_util.h"
namespace tvm {
namespace relay {
struct key_hash : public std::function<std::size_t(TransformKey)> {
std::size_t operator()(const TransformKey& k) const {
return dmlc::HashCombine<std::string>(
- dmlc::HashCombine<std::string>(
- std::hash<const Object*>()(std::get<0>(k)), std::get<1>(k)),
+ dmlc::HashCombine<std::string>(std::hash<const Object*>()(std::get<0>(k)),
+ std::get<1>(k)),
(std::get<2>(k)));
}
};
// new_in2, new_out = op.infer(new_in)
if (new_call->op->IsInstance<OpNode>()) {
success = false;
- std::tie(new_in2, new_out, success) =
- InferCorrectLayouts(new_call, new_in, old_in, types);
+ std::tie(new_in2, new_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types);
if (!success) {
return Expr(nullptr);
}
* If we can not infer a type or there are conflicting typing
* constraints we will trigger an error.
*/
-#include <tvm/ir/type_functor.h>
#include <tvm/ir/error.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
-#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
-#include "pass_util.h"
+
#include "../analysis/type_solver.h"
+#include "pass_util.h"
namespace tvm {
namespace relay {
struct TupleGetItemAttrs : public tvm::AttrsNode<TupleGetItemAttrs> {
int index;
- TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") {
- TVM_ATTR_FIELD(index);
- }
+ TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { TVM_ATTR_FIELD(index); }
};
-bool TupleGetItemRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
+bool TupleGetItemRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
if (types[0].as<IncompleteTypeNode>()) return false;
const auto* data = types[0].as<TupleTypeNode>();
- CHECK(data != nullptr)
- << "TupleGetItem expect input type to be TupleType "
- << " get " << types[0] << " instead";
+ CHECK(data != nullptr) << "TupleGetItem expect input type to be TupleType "
+ << " get " << types[0] << " instead";
const auto* param = attrs.as<TupleGetItemAttrs>();
CHECK(param != nullptr);
CHECK_GE(param->index, 0);
}
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
-TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem")
-.set_body_typed(
- TupleGetItemRel);
+TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem").set_body_typed(TupleGetItemRel);
struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
// constructors
explicit TypeInferencer(IRModule mod, GlobalVar current_func)
- : mod_(mod), current_func_(current_func),
- err_reporter(), solver_(current_func, mod, &this->err_reporter) {
+ : mod_(mod),
+ current_func_(current_func),
+ err_reporter(),
+ solver_(current_func, mod, &this->err_reporter) {
CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
}
Type Unify(const Type& t1, const Type& t2, const ObjectRef& expr) {
try {
return solver_.Unify(t1, t2, expr);
- } catch (const dmlc::Error &e) {
+ } catch (const dmlc::Error& e) {
this->ReportFatalError(
- expr,
- ErrorBuilder()
- << "Error unifying `"
- << t1
- << "` and `"
- << t2
- << "`: " << e.what());
+ expr, ErrorBuilder() << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what());
return Type();
}
}
// Lazily get type for expr
// expression, we will populate it now, and return the result.
- Type GetType(const Expr &expr) {
+ Type GetType(const Expr& expr) {
auto it = type_map_.find(expr);
if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second.checked_type;
Type VisitExpr_(const GlobalVarNode* op) final {
GlobalVar var = GetRef<GlobalVar>(op);
if (!mod_.defined()) {
- this->ReportFatalError(
- GetRef<GlobalVar>(op),
- ErrorBuilder() <<
- "Cannot do type inference on global variables " \
- "without a module");
+ this->ReportFatalError(GetRef<GlobalVar>(op),
+ ErrorBuilder() << "Cannot do type inference on global variables "
+ "without a module");
}
Expr e = mod_->Lookup(var);
return e->checked_type();
}
- Type VisitExpr_(const ConstantNode* op) final {
- return op->tensor_type();
- }
+ Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); }
Type VisitExpr_(const TupleNode* op) final {
Array<Type> types;
}
Type VisitExpr_(const TupleGetItemNode* op) final {
- if (!tuple_getitem_rel_.defined()) {
- tuple_getitem_rel_ = Downcast<TypeRelationFn>(
- EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
+ if (!tuple_getitem_rel_.defined()) {
+ tuple_getitem_rel_ =
+ Downcast<TypeRelationFn>(EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
}
Type tuple_type = GetType(op->tuple);
Type rtype = IncompleteType(Kind::kType);
auto attrs = make_object<TupleGetItemAttrs>();
attrs->index = op->index;
- solver_.AddConstraint(TypeRelation(
- tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef<TupleGetItem>(op));
+ solver_.AddConstraint(TypeRelation(tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)),
+ GetRef<TupleGetItem>(op));
return rtype;
}
void VisitPattern_(const PatternConstructorNode* con, const Type& t) {
- CHECK(mod_.defined())
- << "Cannot do type inference without a environment:"
- << con->constructor->name_hint;
+ CHECK(mod_.defined()) << "Cannot do type inference without a environment:"
+ << con->constructor->name_hint;
TypeData td = mod_->type_definitions.at(con->constructor->belong_to);
auto pc = GetRef<PatternConstructor>(con);
this->ReportFatalError(pc, ErrorBuilder() << "Expected a type call, got " << unified);
}
if (td->header != tc->func) {
- this->ReportFatalError(pc,
- ErrorBuilder() << "ADT headers must match, but we have "
- << td->header << " and " << tc->func);
+ this->ReportFatalError(pc, ErrorBuilder() << "ADT headers must match, but we have "
+ << td->header << " and " << tc->func);
}
if (td->type_vars.size() != tc->args.size()) {
- this->ReportFatalError(pc,
- ErrorBuilder() << "The number of type args must match"
- << "the number of type vars in the type data: "
- << td->type_vars.size() << " != " << tc->args.size());
+ this->ReportFatalError(
+ pc, ErrorBuilder() << "The number of type args must match"
+ << "the number of type vars in the type data: " << td->type_vars.size()
+ << " != " << tc->args.size());
}
std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual> type_var_map_;
for (size_t i = 0; i < td->type_vars.size(); ++i) {
}
CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern";
if (con->constructor->inputs.size() != con->patterns.size()) {
- this->ReportFatalError(pc,
- ErrorBuilder() << "Not enough inputs for the constructor; "
- << "expected " << con->constructor->inputs.size()
- << ", got " << con->patterns.size());
+ this->ReportFatalError(pc, ErrorBuilder() << "Not enough inputs for the constructor; "
+ << "expected " << con->constructor->inputs.size()
+ << ", got " << con->patterns.size());
}
for (size_t i = 0; i < con->constructor->inputs.size(); ++i) {
VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_));
Unify(vt, t, pv->span);
}
- void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { }
+ void VisitPattern_(const PatternWildcardNode* wc, const Type& t) {}
Type VisitExpr_(const MatchNode* op) final {
Type dtype = GetType(op->data);
}
Type rtype = IncompleteType(Kind::kType);
for (const auto& c : op->clauses) {
- rtype = this->Unify(rtype,
- GetType(c->rhs),
- op->span);
+ rtype = this->Unify(rtype, GetType(c->rhs), op->span);
}
if (op->complete) {
for (auto cs : unmatched_cases) {
ss << "case " << i++ << ": \n" << PrettyPrint(cs);
}
- this->ReportFatalError(
- match,
- ss);
+ this->ReportFatalError(match, ss);
}
}
return rtype;
}
- Type VisitExpr_(const OpNode* op) final {
- return op->op_type;
- }
+ Type VisitExpr_(const OpNode* op) final { return op->op_type; }
Type VisitExpr_(const LetNode* let) final {
// if the definition is a function literal, permit recursion
type_map_[let->var].checked_type = let_type;
}
-
if (let->var->type_annotation.defined()) {
let_type = Unify(let_type, let->var->type_annotation, GetRef<Let>(let));
}
// Ensure the type of the guard is of Tensor[Bool, ()],
// that is a rank-0 boolean tensor.
Type cond_type = this->GetType(ite->cond);
- this->Unify(cond_type,
- TensorType::Scalar(tvm::DataType::Bool()),
- ite->cond);
+ this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond);
Type checked_true = this->GetType(ite->true_branch);
Type checked_false = this->GetType(ite->false_branch);
return this->Unify(checked_true, checked_false, GetRef<If>(ite));
// which are registered in the style defined in src/relay/op/*.
//
// The result will be the return type of the operator.
- Type PrimitiveCall(const FuncTypeNode* op,
- Array<Type> arg_types,
- const Attrs& attrs,
+ Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const Attrs& attrs,
const ObjectRef& loc) {
if (op->type_params.size() != arg_types.size() + 1) return Type();
if (op->type_constraints.size() != 1) return Type();
Type rtype = IncompleteType(Kind::kType);
arg_types.push_back(rtype);
// we can do simple replacement here
- solver_.AddConstraint(TypeRelation(
- rel->func, arg_types, arg_types.size() - 1, attrs), loc);
+ solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), loc);
return rtype;
}
ret_type = IncompleteType(Kind::kType);
}
- Type inst_ty = FuncType(fn_ty->arg_types,
- ret_type, {},
- fn_ty->type_constraints);
+ Type inst_ty = FuncType(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints);
inst_ty = Bind(inst_ty, subst_map);
return Downcast<FuncType>(inst_ty);
}
return InstantiateFuncType(fn_ty, type_args);
}
-
void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) {
if (fn_ty_node == nullptr && inc_ty_node == nullptr) {
this->ReportFatalError(
- GetRef<Call>(call),
- ErrorBuilder()
- << "only expressions with function types can be called, found "
- << ftype);
+ GetRef<Call>(call),
+ ErrorBuilder() << "only expressions with function types can be called, found " << ftype);
}
// incomplete type => it must be a function taking the arg types
Array<Type> type_args = call->type_args;
if (type_args.size() > fn_ty_node->type_params.size()) {
this->ReportFatalError(GetRef<Call>(call),
- ErrorBuilder()
- << "Incorrect number of type args in "
- << call->span << ": "
- << "Expected "
- << fn_ty_node->type_params.size()
- << "but got " << type_args.size());
+ ErrorBuilder()
+ << "Incorrect number of type args in " << call->span << ": "
+ << "Expected " << fn_ty_node->type_params.size() << "but got "
+ << type_args.size());
}
FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
if (type_arity != number_of_args) {
if (type_arity < number_of_args) {
- this->ReportFatalError(
- GetRef<Call>(call),
- ErrorBuilder()
- << "the function is provided too many arguments "
- << "expected " << type_arity << ", found " << number_of_args);
+ this->ReportFatalError(GetRef<Call>(call),
+ ErrorBuilder()
+ << "the function is provided too many arguments "
+ << "expected " << type_arity << ", found " << number_of_args);
} else {
- this->ReportFatalError(
- GetRef<Call>(call),
- ErrorBuilder()
- << "the function is provided too few arguments "
- << "expected " << type_arity << ", found " << number_of_args);
+ this->ReportFatalError(GetRef<Call>(call),
+ ErrorBuilder()
+ << "the function is provided too few arguments "
+ << "expected " << type_arity << ", found " << number_of_args);
}
}
for (auto cs : fn_ty->type_constraints) {
if (const auto* tr = cs.as<TypeRelationNode>()) {
- solver_.AddConstraint(
- TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs),
- GetRef<Call>(call));
+ solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs),
+ GetRef<Call>(call));
} else {
solver_.AddConstraint(cs, GetRef<Call>(call));
}
}
if (const OpNode* opnode = call->op.as<OpNode>()) {
- Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
- arg_types,
- call->attrs,
+ Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), arg_types, call->attrs,
GetRef<Call>(call));
if (rtype.defined()) {
AddTypeArgs(GetRef<Call>(call), arg_types);
return solver_.Resolve(ret);
}
- Type VisitExpr_(const RefCreateNode* op) final {
- return RelayRefType(GetType(op->value));
- }
+ Type VisitExpr_(const RefCreateNode* op) final { return RelayRefType(GetType(op->value)); }
Type VisitExpr_(const RefReadNode* op) final {
Type it = IncompleteType(Kind::kType);
}
Type VisitExpr_(const ConstructorNode* c) final {
- CHECK(mod_.defined())
- << "Cannot do type inference without a environment:"
- << c->name_hint;
+ CHECK(mod_.defined()) << "Cannot do type inference without a environment:" << c->name_hint;
TypeData td = mod_->LookupTypeDef(c->belong_to);
std::vector<Type> types;
- for (const auto & t : td->type_vars) {
+ for (const auto& t : td->type_vars) {
types.push_back(t);
}
- return FuncType(c->inputs, TypeCall(c->belong_to, types),
- td->type_vars, {});
+ return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {});
}
void Solve() {
public:
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectHash, ObjectEqual>& tmap,
TypeSolver* solver)
- : tmap_(tmap), solver_(solver) {
- }
+ : tmap_(tmap), solver_(solver) {}
- Expr VisitExpr_(const VarNode* op) final {
- return VisitVar(GetRef<Var>(op));
- }
+ Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }
- Expr VisitExpr_(const ConstantNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const GlobalVarNode* op) final {
- return GetRef<GlobalVar>(op);
- }
+ Expr VisitExpr_(const GlobalVarNode* op) final { return GetRef<GlobalVar>(op); }
- Expr VisitExpr_(const OpNode* op) final {
- return ExprMutator::VisitExpr_(op);
- }
+ Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); }
- Expr VisitExpr_(const TupleNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const TupleNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const TupleGetItemNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const TupleGetItemNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const FunctionNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const CallNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const CallNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const LetNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const IfNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const RefCreateNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const RefCreateNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const RefReadNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const RefReadNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const RefWriteNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const RefWriteNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const ConstructorNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const ConstructorNode* op) final { return AttachCheckedType(op); }
- Expr VisitExpr_(const MatchNode* op) final {
- return AttachCheckedType(op);
- }
+ Expr VisitExpr_(const MatchNode* op) final { return AttachCheckedType(op); }
- Pattern VisitPattern(const Pattern& p) final {
- return PatternMutator::VisitPattern(p);
- }
+ Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
Var VisitVar(const Var& v) final {
if (vmap_.count(v) == 0) {
}
// attach checked type to the mutated node.
- template<typename T>
+ template <typename T>
Expr AttachCheckedType(const T* op) {
auto it = tmap_.find(GetRef<Expr>(op));
CHECK(it != tmap_.end());
// TODO(@jroesch): it would be nice if we would report resolution
// errors directly on the program.
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
- << "Cannot resolve type of " << GetRef<Expr>(op)
- << " at " << op->span;
+ << "Cannot resolve type of " << GetRef<Expr>(op) << " at " << op->span;
Expr new_e = ExprMutator::VisitExpr_(op);
// new_call and new_var's code is only going to be valid for VarNode/CallNode.
// Compiler optimization will likely fold these away for other nodes.
- CallNode* new_call =(
- std::is_base_of<CallNode, T>::value ?
- const_cast<CallNode*>(static_cast<const CallNode*>(new_e.get())) : nullptr);
- VarNode* new_var =(
- std::is_base_of<VarNode, T>::value ?
- const_cast<VarNode*>(static_cast<const VarNode*>(new_e.get())) : nullptr);
- FunctionNode* new_fn =(
- std::is_base_of<FunctionNode, T>::value ?
- const_cast<FunctionNode*>(static_cast<const FunctionNode*>(new_e.get())) : nullptr);
+ CallNode* new_call = (std::is_base_of<CallNode, T>::value
+ ? const_cast<CallNode*>(static_cast<const CallNode*>(new_e.get()))
+ : nullptr);
+ VarNode* new_var = (std::is_base_of<VarNode, T>::value
+ ? const_cast<VarNode*>(static_cast<const VarNode*>(new_e.get()))
+ : nullptr);
+ FunctionNode* new_fn =
+ (std::is_base_of<FunctionNode, T>::value
+ ? const_cast<FunctionNode*>(static_cast<const FunctionNode*>(new_e.get()))
+ : nullptr);
// check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_);
- bool need_update_call = (
- std::is_base_of<CallNode, T>::value &&
- it->second.type_args.defined() &&
- !it->second.type_args.same_as(new_call->type_args));
- bool need_update_var = (
- std::is_base_of<VarNode, T>::value &&
- update_missing_type_annotation_ &&
- !new_var->type_annotation.defined());
-
- bool need_update_fn =(
- std::is_base_of<FunctionNode, T>::value &&
- update_missing_type_annotation_ &&
- !new_fn->ret_type.defined());
-
- if (!need_update_type &&
- !need_update_var &&
- !need_update_call &&
- !need_update_fn) {
+ bool need_update_call =
+ (std::is_base_of<CallNode, T>::value && it->second.type_args.defined() &&
+ !it->second.type_args.same_as(new_call->type_args));
+ bool need_update_var = (std::is_base_of<VarNode, T>::value && update_missing_type_annotation_ &&
+ !new_var->type_annotation.defined());
+
+ bool need_update_fn = (std::is_base_of<FunctionNode, T>::value &&
+ update_missing_type_annotation_ && !new_fn->ret_type.defined());
+
+ if (!need_update_type && !need_update_var && !need_update_call && !need_update_fn) {
return new_e;
}
// we make a copy mutating an existing reference.
ObjectPtr<ExprNode> ptr = make_object<T>(*new_e.as<T>());
new_e = Expr(ptr);
- new_call = (
- std::is_base_of<CallNode, T>::value ?
- static_cast<CallNode*>(ptr.get()) : nullptr);
- new_var = (
- std::is_base_of<VarNode, T>::value ?
- static_cast<VarNode*>(ptr.get()) : nullptr);
- new_fn = (
- std::is_base_of<FunctionNode, T>::value ?
- static_cast<FunctionNode*>(ptr.get()) : nullptr);
+ new_call =
+ (std::is_base_of<CallNode, T>::value ? static_cast<CallNode*>(ptr.get()) : nullptr);
+ new_var = (std::is_base_of<VarNode, T>::value ? static_cast<VarNode*>(ptr.get()) : nullptr);
+ new_fn = (std::is_base_of<FunctionNode, T>::value ? static_cast<FunctionNode*>(ptr.get())
+ : nullptr);
}
// attach the information.
return new_e;
}
- Type VisitType(const Type &t) final {
- return solver_->Resolve(t);
- }
+ Type VisitType(const Type& t) final { return solver_->Resolve(t); }
private:
std::unordered_map<Var, Var, ObjectHash, ObjectEqual> vmap_;
struct AllCheckTypePopulated : ExprVisitor {
void VisitExpr(const Expr& e) {
- if (e.as<OpNode>()) { return; }
- if (e.as<GlobalVarNode>()) { return; }
- if (e.as<ConstructorNode>()) { return; }
+ if (e.as<OpNode>()) {
+ return;
+ }
+ if (e.as<GlobalVarNode>()) {
+ return;
+ }
+ if (e.as<ConstructorNode>()) {
+ return;
+ }
CHECK(e->checked_type_.defined()) << "Expression: " << e;
return ExprVisitor::VisitExpr(e);
}
};
-void EnsureCheckedType(const Expr& e) {
- AllCheckTypePopulated().VisitExpr(e);
-}
+void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); }
Expr InferType(const Expr& expr, const IRModule& mod) {
auto main = mod->GetGlobalVar("main");
auto e = inferencer.Infer(expr);
CHECK(WellFormed(e));
auto free_tvars = FreeTypeVars(e, mod);
- CHECK(free_tvars.size() == 0)
- << "Found unbound type variables in " << e << ": " << free_tvars;
+ CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << e << ": " << free_tvars;
EnsureCheckedType(e);
return e;
}
-Function InferType(const Function& func,
- const IRModule& mod,
- const GlobalVar& var) {
+Function InferType(const Function& func, const IRModule& mod, const GlobalVar& var) {
CHECK(mod.defined()) << "internal error: module must be set for type inference";
Function func_copy = Function(make_object<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation();
mod->Remove(var);
CHECK(WellFormed(func_ret));
auto free_tvars = FreeTypeVars(func_ret, mod);
- CHECK(free_tvars.size() == 0)
- << "Found unbound type variables in: "
- << std::endl
- << AsText(func, true)
- << std::endl << free_tvars;
+ CHECK(free_tvars.size() == 0) << "Found unbound type variables in: " << std::endl
+ << AsText(func, true) << std::endl
+ << free_tvars;
return Downcast<Function>(func_ret);
}
Pass InferType() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(InferType(f, m));
- };
+ [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(InferType(f, m)); };
return CreateFunctionPass(pass_func, 0, "InferType", {});
}
-TVM_REGISTER_GLOBAL("relay._transform.InferType")
-.set_body_typed([]() {
- return InferType();
-});
+TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); });
} // namespace transform
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
/*!
* \file builtin_fp16.cc
* \brief Functions for conversion between fp32 and fp16
-*/
+ */
#include <builtin_fp16.h>
#include <tvm/runtime/c_runtime_api.h>
* \brief Device specific implementations
*/
#include <dmlc/thread_local.h>
-#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
-#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/device_api.h>
-#include <sstream>
-#include <array>
+
#include <algorithm>
-#include <string>
-#include <cstdlib>
+#include <array>
#include <cctype>
-#include "runtime_base.h"
+#include <cstdlib>
+#include <sstream>
+#include <string>
+
#include "object_internal.h"
+#include "runtime_base.h"
namespace tvm {
namespace runtime {
public:
static const int kMaxDeviceAPI = 32;
// Get API
- static DeviceAPI* Get(const TVMContext& ctx) {
- return Get(ctx.device_type);
- }
+ static DeviceAPI* Get(const TVMContext& ctx) { return Get(ctx.device_type); }
static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
return Global()->GetAPI(dev_type, allow_missing);
}
DeviceAPI* rpc_api_{nullptr};
std::mutex mutex_;
// constructor
- DeviceAPIManager() {
- std::fill(api_.begin(), api_.end(), nullptr);
- }
+ DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
std::string factory = "device_api." + name;
auto* f = Registry::Get(factory);
if (f == nullptr) {
- CHECK(allow_missing)
- << "Device API " << name << " is not enabled.";
+ CHECK(allow_missing) << "Device API " << name << " is not enabled.";
return nullptr;
}
void* ptr = (*f)();
};
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
- return DeviceAPIManager::Get(
- static_cast<int>(ctx.device_type), allow_missing);
+ return DeviceAPIManager::Get(static_cast<int>(ctx.device_type), allow_missing);
}
-void* DeviceAPI::AllocWorkspace(TVMContext ctx,
- size_t size,
- DLDataType type_hint) {
+void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}
-void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
- FreeDataSpace(ctx, ptr);
-}
+void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { FreeDataSpace(ctx, ptr); }
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
LOG(FATAL) << "Device does not support stream api.";
LOG(FATAL) << "Device does not support stream api.";
}
-void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
- TVMStreamHandle event_src,
+void DeviceAPI::SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src,
TVMStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
// Parse error type.
{
size_t start_pos = 0, end_pos;
- for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {}
+ for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {
+ }
for (end_pos = start_pos; end_pos < line.length(); ++end_pos) {
char ch = line[end_pos];
if (ch == ':') {
}
if (error_type.length() != 0) {
// if we successfully detected error_type: trim the following space.
- for (start_pos = end_pos + 1;
- start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {}
+ for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' ';
+ ++start_pos) {
+ }
line = line.substr(start_pos);
} else {
// did not detect error_type, use default value.
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
-const char *TVMGetLastError() {
- return TVMAPIRuntimeStore::Get()->last_error.c_str();
-}
+const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); }
-int TVMAPIHandleException(const std::runtime_error &e) {
+int TVMAPIHandleException(const std::runtime_error& e) {
TVMAPISetLastError(NormalizeError(e.what()).c_str());
return -1;
}
-void TVMAPISetLastError(const char* msg) {
- TVMAPIRuntimeStore::Get()->last_error = msg;
-}
+void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; }
-int TVMModLoadFromFile(const char* file_name,
- const char* format,
- TVMModuleHandle* out) {
+int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) {
API_BEGIN();
TVMRetValue ret;
ret = Module::LoadFromFile(file_name, format);
API_END();
}
-int TVMModImport(TVMModuleHandle mod,
- TVMModuleHandle dep) {
+int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) {
API_BEGIN();
- ObjectInternal::GetModuleNode(mod)->Import(
- GetRef<Module>(ObjectInternal::GetModuleNode(dep)));
+ ObjectInternal::GetModuleNode(mod)->Import(GetRef<Module>(ObjectInternal::GetModuleNode(dep)));
API_END();
}
-int TVMModGetFunction(TVMModuleHandle mod,
- const char* func_name,
- int query_imports,
- TVMFunctionHandle *func) {
+int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
+ TVMFunctionHandle* func) {
API_BEGIN();
- PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(
- func_name, query_imports != 0);
+ PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0);
if (pf != nullptr) {
*func = new PackedFunc(pf);
} else {
API_END();
}
-int TVMModFree(TVMModuleHandle mod) {
- return TVMObjectFree(mod);
-}
+int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); }
-int TVMBackendGetFuncFromEnv(void* mod_node,
- const char* func_name,
- TVMFunctionHandle *func) {
+int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) {
API_BEGIN();
- *func = (TVMFunctionHandle)(
- static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
+ *func = (TVMFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
API_END();
}
-void* TVMBackendAllocWorkspace(int device_type,
- int device_id,
- uint64_t size,
- int dtype_code_hint,
+void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint,
int dtype_bits_hint) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;
- return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
- static_cast<size_t>(size),
- type_hint);
+ return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast<size_t>(size), type_hint);
}
-int TVMBackendFreeWorkspace(int device_type,
- int device_id,
- void* ptr) {
+int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
return 0;
}
-int TVMBackendRunOnce(void** handle,
- int (*f)(void*),
- void* cdata,
- int nbytes) {
+int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
if (*handle == nullptr) {
*handle = reinterpret_cast<void*>(1);
return (*f)(cdata);
API_END();
}
-int TVMFuncCall(TVMFunctionHandle func,
- TVMValue* args,
- int* arg_type_codes,
- int num_args,
- TVMValue* ret_val,
- int* ret_type_code) {
+int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
+ TVMValue* ret_val, int* ret_type_code) {
API_BEGIN();
TVMRetValue rv;
- (*static_cast<const PackedFunc*>(func)).CallPacked(
- TVMArgs(args, arg_type_codes, num_args), &rv);
+ (*static_cast<const PackedFunc*>(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
- if (rv.type_code() == kTVMStr ||
- rv.type_code() == kTVMDataType ||
- rv.type_code() == kTVMBytes) {
+ if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
if (rv.type_code() != kTVMDataType) {
e->ret_str = *rv.ptr<std::string>();
API_END();
}
-int TVMCFuncSetReturn(TVMRetValueHandle ret,
- TVMValue* value,
- int* type_code,
- int num_ret) {
+int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) {
API_BEGIN();
CHECK_EQ(num_ret, 1);
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
API_END();
}
-int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
- void* resource_handle,
- TVMPackedCFuncFinalizer fin,
- TVMFunctionHandle *out) {
+int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin,
+ TVMFunctionHandle* out) {
API_BEGIN();
if (fin == nullptr) {
- *out = new PackedFunc(
- [func, resource_handle](TVMArgs args, TVMRetValue* rv) {
- int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
- args.num_args, rv, resource_handle);
- if (ret != 0) {
- throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
- }
- });
+ *out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) {
+ int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
+ args.num_args, rv, resource_handle);
+ if (ret != 0) {
+ throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
+ }
+ });
} else {
// wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin);
- *out = new PackedFunc(
- [func, rpack](TVMArgs args, TVMRetValue* rv) {
- int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
- args.num_args, rv, rpack.get());
- if (ret != 0) {
- throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
- }
- });
+ *out = new PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) {
+ int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
+ args.num_args, rv, rpack.get());
+ if (ret != 0) {
+ throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
+ }
+ });
}
API_END();
}
API_END();
}
-int TVMStreamStreamSynchronize(int device_type,
- int device_id,
- TVMStreamHandle src,
+int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src,
TVMStreamHandle dst) {
API_BEGIN();
TVMContext ctx;
API_END();
}
-
-int TVMDeviceAllocDataSpace(DLContext ctx,
- size_t nbytes,
- size_t alignment,
- DLDataType type_hint,
+int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint,
void** out_data) {
API_BEGIN();
- out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace(
- ctx, nbytes, alignment, type_hint);
+ out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint);
API_END();
}
API_END();
}
-int TVMDeviceCopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t num_bytes,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
- TVMStreamHandle stream) {
+int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to,
+ DLDataType type_hint, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to;
- DeviceAPIManager::Get(ctx)->CopyDataFromTo(
- from, from_offset,
- to, to_offset,
- num_bytes, ctx_from, ctx_to, type_hint, stream);
+ DeviceAPIManager::Get(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, num_bytes, ctx_from,
+ ctx_to, type_hint, stream);
API_END();
}
// set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- TVMContext ctx;
- ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
- ctx.device_id = args[1];
- DeviceAPIManager::Get(ctx)->SetDevice(ctx);
- });
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ TVMContext ctx;
+ ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
+ ctx.device_id = args[1];
+ DeviceAPIManager::Get(ctx)->SetDevice(ctx);
+ });
// set device api
-TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- TVMContext ctx;
- ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
- ctx.device_id = args[1];
-
- DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
- if (kind == kExist) {
- DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
- if (api != nullptr) {
- api->GetAttr(ctx, kind, ret);
- } else {
- *ret = 0;
- }
+TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr").set_body([](TVMArgs args, TVMRetValue* ret) {
+ TVMContext ctx;
+ ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
+ ctx.device_id = args[1];
+
+ DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
+ if (kind == kExist) {
+ DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
+ if (api != nullptr) {
+ api->GetAttr(ctx, kind, ret);
} else {
- DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
+ *ret = 0;
}
- });
-
+ } else {
+ DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
+ }
+});
-TVM_REGISTER_GLOBAL("runtime.TVMSetStream")
-.set_body_typed(TVMSetStream);
+TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream);
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
-#include <tvm/runtime/vm.h>
#include <tvm/runtime/registry.h>
+#include <tvm/runtime/vm.h>
namespace tvm {
namespace runtime {
using namespace vm;
-TVM_REGISTER_GLOBAL("runtime.GetADTTag")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});
-TVM_REGISTER_GLOBAL("runtime.GetADTSize")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.GetADTSize").set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.size());
});
-
-TVM_REGISTER_GLOBAL("runtime.GetADTFields")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.GetADTFields").set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
const auto& adt = Downcast<ADT>(obj);
*rv = adt[idx];
});
-TVM_REGISTER_GLOBAL("runtime.Tuple")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.Tuple").set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
*rv = ADT::Tuple(fields);
});
-TVM_REGISTER_GLOBAL("runtime.ADT")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<ObjectRef> fields;
*rv = ADT(tag, fields);
});
-TVM_REGISTER_GLOBAL("runtime.String")
-.set_body_typed([](std::string str) {
+TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) {
return String(std::move(str));
});
-TVM_REGISTER_GLOBAL("runtime.GetFFIString")
-.set_body_typed([](String str) {
+TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) {
return std::string(str);
});
* \file Use external cblas library call.
*/
#include <dmlc/logging.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+
#include "gemm_common.h"
extern "C" {
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
int ldb, float beta, float* C, int ldc) {
#if USE_DNNL == 1
- dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B,
- ldb, A, lda, beta, C, ldc);
+ dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, ldb, A,
+ lda, beta, C, ldc);
#else
cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
};
// matrix multiplication for row major
-TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
CallGemm(args, ret, CblasDgemmOp());
});
-TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 32)) {
});
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- DLTensor* A = args[0];
- CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
- if (TypeMatch(A->dtype, kDLFloat, 32)) {
- CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
- } else {
- CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
- }
-});
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+ if (TypeMatch(A->dtype, kDLFloat, 32)) {
+ CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
+ } else {
+ CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
+ }
+ });
} // namespace contrib
} // namespace tvm
*/
#pragma once
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+
#include <algorithm>
namespace tvm {
namespace contrib {
using namespace runtime;
-inline int ColumnStride(DLTensor *tensor) {
+inline int ColumnStride(DLTensor* tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
// (the other stride is 1) is the column stride.
}
}
-inline int ElementStride(DLTensor *tensor) {
+inline int ElementStride(DLTensor* tensor) {
if (tensor->strides) {
return std::min(tensor->strides[0], tensor->strides[1]);
} else {
}
// Reversed strides indicates an in-place transpose operation.
-inline bool IsInPlaceTransposed(DLTensor *tensor) {
+inline bool IsInPlaceTransposed(DLTensor* tensor) {
return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
}
-inline int RowCount(DLTensor *tensor, bool trans) {
- return tensor->shape[trans ? 1 : 0];
-}
+inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; }
-inline int ColumnCount(DLTensor *tensor, bool trans) {
- return tensor->shape[trans ? 0 : 1];
-}
+inline int ColumnCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; }
// Call a column major blas. Note that data is stored in tvm as row
// major, so this we switch the arguments.
template <typename TGemmOp>
-inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
- DLTensor *A = args[0];
- DLTensor *B = args[1];
- DLTensor *C = args[2];
+inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
- op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
- ColumnCount(A, transa), static_cast<typename TGemmOp::TDatatype>(alpha),
- reinterpret_cast<typename TGemmOp::TDatatype *>(
- static_cast<char *>(B->data) + B->byte_offset),
+ op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa),
+ static_cast<typename TGemmOp::TDatatype>(alpha),
+ reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(B->data) + B->byte_offset),
ColumnStride(B),
- reinterpret_cast<typename TGemmOp::TDatatype *>(
- static_cast<char *>(A->data) + A->byte_offset),
+ reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(A->data) + A->byte_offset),
ColumnStride(A), static_cast<typename TGemmOp::TDatatype>(beta),
- reinterpret_cast<typename TGemmOp::TDatatype *>(
- static_cast<char *>(C->data) + C->byte_offset),
+ reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(C->data) + C->byte_offset),
ColumnStride(C));
}
-inline int ColumnStride3D(DLTensor *tensor) {
+inline int ColumnStride3D(DLTensor* tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
// (the other stride is 1) is the column stride.
return tensor->shape[2];
}
}
-inline int ElementStride3D(DLTensor *tensor) {
+inline int ElementStride3D(DLTensor* tensor) {
if (tensor->strides) {
return std::min(tensor->strides[1], tensor->strides[2]);
} else {
}
}
// Reversed strides indicates an in-place transpose operation.
-inline bool IsInPlaceTransposed3D(DLTensor *tensor) {
+inline bool IsInPlaceTransposed3D(DLTensor* tensor) {
return tensor->strides && (tensor->strides[2] > tensor->strides[1]);
}
-inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; }
-inline int RowCount3D(DLTensor *tensor, bool trans) {
- return tensor->shape[trans ? 2 : 1];
-}
-inline int ColumnCount3D(DLTensor *tensor, bool trans) {
- return tensor->shape[trans ? 1 : 2];
-}
+inline int BatchCount3D(DLTensor* tensor) { return tensor->shape[0]; }
+inline int RowCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 2 : 1]; }
+inline int ColumnCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 2]; }
template <typename TBatchGemmOp>
-inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
+inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) {
using DType = typename TBatchGemmOp::TDatatype;
- DLTensor *A = args[0];
- DLTensor *B = args[1];
- DLTensor *C = args[2];
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
int bit_depth = sizeof(DType) * 8;
const int A_size = A->shape[1] * A->shape[2];
const int B_size = B->shape[1] * B->shape[2];
const int C_size = C->shape[1] * C->shape[2];
- DType *A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
- static_cast<char *>(A->data) + A->byte_offset);
- DType *B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
- static_cast<char *>(B->data) + B->byte_offset);
- DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
- static_cast<char *>(C->data) + C->byte_offset);
- op(batch_size, transb, transa, ColumnCount3D(B, transb),
- RowCount3D(A, transa), ColumnCount3D(A, transa),
- static_cast<typename TBatchGemmOp::TDatatype>(alpha),
- B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
+ DType* A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(static_cast<char*>(A->data) +
+ A->byte_offset);
+ DType* B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(static_cast<char*>(B->data) +
+ B->byte_offset);
+ DType* C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(static_cast<char*>(C->data) +
+ C->byte_offset);
+ op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa),
+ ColumnCount3D(A, transa), static_cast<typename TBatchGemmOp::TDatatype>(alpha), B_data, B_size,
+ ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
static_cast<typename TBatchGemmOp::TDatatype>(beta), C_data, C_size, ColumnStride3D(C));
}
#ifndef TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_
-#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
+#import <Foundation/Foundation.h>
#include <dlpack/dlpack.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
-#include <vector>
-#include <string>
#include <memory>
+#include <string>
+#include <vector>
namespace tvm {
namespace runtime {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
- virtual PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self);
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
/*!
* \return The type key of the executor.
*/
- const char* type_key() const {
- return "CoreMLRuntime";
- }
+ const char* type_key() const { return "CoreMLRuntime"; }
/*!
* \brief Invoke the coreml prediction.
* \param ctx The context where the coreml model will be executed on.
* \param output_names The output names of the model.
*/
- void Init(const std::string& model_path,
- TVMContext ctx,
- const std::vector<NSString *>& output_names);
+ void Init(const std::string& model_path, TVMContext ctx,
+ const std::vector<NSString*>& output_names);
/*!
* \brief set input to the model.
int GetNumOutputs() const;
// CoreML model
- MLModel *model_;
+ MLModel* model_;
// CoreML model input dictionary
- NSMutableDictionary<NSString *, id> *input_dict_;
+ NSMutableDictionary<NSString*, id>* input_dict_;
// CoreML model output
id<MLFeatureProvider> output_;
// List of output names
- std::vector<NSString *> output_names_;
+ std::vector<NSString*> output_names_;
// TVM context
TVMContext ctx_;
};
namespace tvm {
namespace runtime {
-MLModel *load_coreml_model(const std::string& model_path) {
+MLModel* load_coreml_model(const std::string& model_path) {
NSBundle* bundle = [NSBundle mainBundle];
NSString* base = [bundle privateFrameworksPath];
NSString* fname = [NSString stringWithUTF8String:("tvm/" + model_path).c_str()];
- NSString* assetPath = [base stringByAppendingPathComponent: fname];
+ NSString* assetPath = [base stringByAppendingPathComponent:fname];
if (![[NSFileManager defaultManager] fileExistsAtPath:assetPath]) {
- assetPath = [NSString stringWithCString: model_path.c_str() encoding:NSUTF8StringEncoding];
+ assetPath = [NSString stringWithCString:model_path.c_str() encoding:NSUTF8StringEncoding];
}
- NSURL *url = [NSURL fileURLWithPath:assetPath];
+ NSURL* url = [NSURL fileURLWithPath:assetPath];
- MLModel *model = [MLModel modelWithContentsOfURL:url error:nil];
+ MLModel* model = [MLModel modelWithContentsOfURL:url error:nil];
if (model == nil) {
NSLog(@"modelc %@ not found", url);
}
return model;
}
-void CoreMLRuntime::Init(const std::string& model_path,
- TVMContext ctx,
- const std::vector<NSString *>& output_names) {
+void CoreMLRuntime::Init(const std::string& model_path, TVMContext ctx,
+ const std::vector<NSString*>& output_names) {
model_ = load_coreml_model(model_path);
ctx_ = ctx;
input_dict_ = [NSMutableDictionary dictionary];
}
void CoreMLRuntime::Invoke() {
- id<MLFeatureProvider> input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_ error:nil];
+ id<MLFeatureProvider> input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_
+ error:nil];
output_ = [model_ predictionFromFeatures:input error:nil];
}
void CoreMLRuntime::SetInput(const std::string& key, DLTensor* data_in) {
int64_t size = 1;
- NSMutableArray *shape = [[NSMutableArray alloc] init];
+ NSMutableArray* shape = [[NSMutableArray alloc] init];
for (int64_t i = 0; i < data_in->ndim; ++i) {
size *= data_in->shape[i];
[shape addObject:[NSNumber numberWithInteger:data_in->shape[i]]];
return;
}
- MLMultiArray *dest = [[MLMultiArray alloc] initWithShape:shape
- dataType:dataType error:nil];
+ MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil];
CHECK(data_in->strides == NULL);
memcpy(dest.dataPointer, data_in->data, size);
- NSString *nsKey = [NSString stringWithUTF8String:key.c_str()];
+ NSString* nsKey = [NSString stringWithUTF8String:key.c_str()];
[input_dict_ setObject:dest forKey:nsKey];
}
NDArray CoreMLRuntime::GetOutput(int index) const {
- NSString *name = output_names_[index];
- MLModelDescription *model_desc = model_.modelDescription;
- MLFeatureDescription *output_desc = model_desc.outputDescriptionsByName[name];
- MLMultiArrayConstraint *data_desc = output_desc.multiArrayConstraint;
+ NSString* name = output_names_[index];
+ MLModelDescription* model_desc = model_.modelDescription;
+ MLFeatureDescription* output_desc = model_desc.outputDescriptionsByName[name];
+ MLMultiArrayConstraint* data_desc = output_desc.multiArrayConstraint;
std::vector<int64_t> shape;
int64_t size = 1;
for (int64_t i = 0; i < data_desc.shape.count; ++i) {
} else {
LOG(FATAL) << "unexpected data type " << data_desc.dataType;
}
- MLMultiArray *src = [output_ featureValueForName:name].multiArrayValue;
+ MLMultiArray* src = [output_ featureValueForName:name].multiArrayValue;
NDArray ret = NDArray::Empty(shape, dtype, ctx_);
ret.CopyFromBytes(src.dataPointer, size);
return ret;
}
-int CoreMLRuntime::GetNumOutputs() const {
- return output_names_.size();
-}
+int CoreMLRuntime::GetNumOutputs() const { return output_names_.size(); }
-PackedFunc CoreMLRuntime::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc CoreMLRuntime::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "invoke") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- this->Invoke();
- });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); });
} else if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- const auto& input_name = args[0].operator std::string();
- this->SetInput(input_name, args[1]);
- });
+ const auto& input_name = args[0].operator std::string();
+ this->SetInput(input_name, args[1]);
+ });
} else if (name == "get_output") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetOutput(args[0]);
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); });
} else if (name == "get_num_outputs") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetNumOutputs();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetNumOutputs(); });
} else {
return PackedFunc();
}
}
-Module CoreMLRuntimeCreate(const std::string& model_path,
- TVMContext ctx,
- const std::vector<NSString *>& output_names) {
+Module CoreMLRuntimeCreate(const std::string& model_path, TVMContext ctx,
+ const std::vector<NSString*>& output_names) {
auto exec = make_object<CoreMLRuntime>();
exec->Init(model_path, ctx, output_names);
return Module(exec);
}
-TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- std::vector<NSString *> output_names;
- for (size_t i = 2; i < args.size(); i++) {
- const std::string& name = args[i];
- output_names.push_back([NSString stringWithUTF8String:name.c_str()]);
- }
- *rv = CoreMLRuntimeCreate(args[0], args[1], output_names);
- });
+TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
+ std::vector<NSString*> output_names;
+ for (size_t i = 2; i < args.size(); i++) {
+ const std::string& name = args[i];
+ output_names.push_back([NSString stringWithUTF8String:name.c_str()]);
+ }
+ *rv = CoreMLRuntimeCreate(args[0], args[1], output_names);
+});
} // namespace runtime
} // namespace tvm
/*!
* \file Use external cblas library call.
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+
#include "../cblas/gemm_common.h"
#include "cublas_utils.h"
-
namespace tvm {
namespace contrib {
using namespace runtime;
-inline cublasOperation_t BooleanToTranspose(bool item) {
- return item ? CUBLAS_OP_T : CUBLAS_OP_N;
-}
+inline cublasOperation_t BooleanToTranspose(bool item) { return item ? CUBLAS_OP_T : CUBLAS_OP_N; }
inline void TryEnableTensorCore(cublasHandle_t hdl) {
// TensorCores are only supported in cublas 9.0 or higher
int version;
CHECK_CUBLAS_ERROR(cublasGetVersion(hdl, &version));
- if (version >= 9000)
- CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH));
+ if (version >= 9000) CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH));
}
struct CublasHgemmOp {
typedef half TDatatype;
cublasHandle_t handle;
- explicit CublasHgemmOp(cublasHandle_t hdl)
- : handle(hdl) {}
-
- void operator()(bool ta, bool tb,
- int M, int N, int K,
- half alpha, half* A, int lda,
- half* B, int ldb,
- half beta, half* C, int ldc) {
- CHECK_CUBLAS_ERROR(cublasHgemm(handle,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- &alpha, A, lda,
- B, ldb,
- &beta, C, ldc));
+ explicit CublasHgemmOp(cublasHandle_t hdl) : handle(hdl) {}
+
+ void operator()(bool ta, bool tb, int M, int N, int K, half alpha, half* A, int lda, half* B,
+ int ldb, half beta, half* C, int ldc) {
+ CHECK_CUBLAS_ERROR(cublasHgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K,
+ &alpha, A, lda, B, ldb, &beta, C, ldc));
}
};
struct CublasSgemmOp {
typedef float TDatatype;
cublasHandle_t handle;
- explicit CublasSgemmOp(cublasHandle_t hdl)
- : handle(hdl) {}
-
- void operator()(bool ta, bool tb,
- int M, int N, int K,
- float alpha, float* A, int lda,
- float* B, int ldb,
- float beta, float* C, int ldc) {
- CHECK_CUBLAS_ERROR(cublasSgemm(handle,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- &alpha, A, lda,
- B, ldb,
- &beta, C, ldc));
+ explicit CublasSgemmOp(cublasHandle_t hdl) : handle(hdl) {}
+
+ void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
+ int ldb, float beta, float* C, int ldc) {
+ CHECK_CUBLAS_ERROR(cublasSgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K,
+ &alpha, A, lda, B, ldb, &beta, C, ldc));
}
};
struct CublasDgemmOp {
typedef double TDatatype;
cublasHandle_t handle;
- explicit CublasDgemmOp(cublasHandle_t hdl)
- : handle(hdl) {}
- void operator()(bool ta, bool tb,
- int M, int N, int K,
- double alpha, double* A, int lda,
- double* B, int ldb,
- double beta, double* C, int ldc) {
- CHECK_CUBLAS_ERROR(cublasDgemm(handle,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- &alpha, A, lda,
- B, ldb,
- &beta, C, ldc));
+ explicit CublasDgemmOp(cublasHandle_t hdl) : handle(hdl) {}
+ void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
+ double* B, int ldb, double beta, double* C, int ldc) {
+ CHECK_CUBLAS_ERROR(cublasDgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K,
+ &alpha, A, lda, B, ldb, &beta, C, ldc));
}
};
struct CublasHgemmBatchOp {
typedef half TDatatype;
cublasHandle_t handle;
- explicit CublasHgemmBatchOp(cublasHandle_t hdl)
- : handle(hdl) {}
+ explicit CublasHgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, half alpha, half* A,
int a_stride, int lda, half* B, int b_stride, int ldb, half beta, half* C,
int c_stride, int ldc) {
- CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched(handle,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- &alpha,
- A, lda, a_stride,
- B, ldb, b_stride,
- &beta,
- C, ldc, c_stride,
- batch_size));
+ CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched(
+ handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride,
+ B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size));
}
};
struct CublasSgemmBatchOp {
typedef float TDatatype;
cublasHandle_t handle;
- explicit CublasSgemmBatchOp(cublasHandle_t hdl)
- : handle(hdl) {}
+ explicit CublasSgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
int c_stride, int ldc) {
- CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- &alpha,
- A, lda, a_stride,
- B, ldb, b_stride,
- &beta,
- C, ldc, c_stride,
- batch_size));
+ CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(
+ handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride,
+ B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size));
}
};
struct CublasDgemmBatchOp {
typedef double TDatatype;
cublasHandle_t handle;
- explicit CublasDgemmBatchOp(cublasHandle_t hdl)
- : handle(hdl) {}
+ explicit CublasDgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) {
- CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- &alpha,
- A, lda, a_stride,
- B, ldb, b_stride,
- &beta,
- C, ldc, c_stride,
- batch_size));
+ CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(
+ handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride,
+ B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size));
}
};
if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
return TypeMatch(in_dtype, kDLInt, 8);
} else if (TypeMatch(out_dtype, kDLFloat, 32)) {
- return TypeMatch(in_dtype, kDLInt, 8) ||
- TypeMatch(in_dtype, kDLFloat, 16);
+ return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16);
} else {
return false;
}
}
-int roundoff(int v, int d) {
- return (v + d - 1) / d * d;
-}
+int roundoff(int v, int d) { return (v + d - 1) / d * d; }
#if CUDART_VERSION >= 10010
-inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) {
- DLTensor *A = args[0];
- DLTensor *B = args[1];
- DLTensor *C = args[2];
+inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// Reversed strides indicates an in-place transpose operation.
cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
cublasLtMatmulDesc_t operationDesc = nullptr;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
- CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
- operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose)));
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
+ &opTranspose, sizeof(opTranspose)));
cublasOperation_t opTransA = BooleanToTranspose(transa);
cublasOperation_t opTransB = BooleanToTranspose(transb);
- CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
- operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA)));
- CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
- operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB)));
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
+ &opTransA, sizeof(opTransA)));
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
+ &opTransB, sizeof(opTransB)));
// Create descriptors for the original matrices
- CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
- &Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k ,
- opTransA == CUBLAS_OP_N ? k : m, lda));
- CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
- &Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n ,
- opTransB == CUBLAS_OP_N ? n : k, ldb));
+ CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k,
+ opTransA == CUBLAS_OP_N ? k : m, lda));
+ CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n,
+ opTransB == CUBLAS_OP_N ? n : k, ldb));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
+ CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &order_COL32, sizeof(order_COL32)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
- Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));
- CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
- Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C)));
- CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
- Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));
-
- CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl,
- operationDesc,
- &alpha,
- B_data,
- Adesc,
- A_data,
- Bdesc,
- &beta,
- C_data,
- Cdesc,
- C_data,
- Cdesc,
- NULL,
- NULL,
- 0,
- 0));
+ Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C)));
+ CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &order_COL32, sizeof(order_COL32)));
+
+ CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta,
+ C_data, Cdesc, C_data, Cdesc, NULL, NULL, 0, 0));
}
#endif
-inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
- DLTensor *A = args[0];
- DLTensor *B = args[1];
- DLTensor *C = args[2];
+inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
CHECK_EQ(A->ndim, 2);
transb = IsInPlaceTransposed(B) ? !transb : transb;
CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
- CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
- ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
- CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
- ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
+ CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0)
+ << "leading dimension must divide 4 for int8 gemm";
+ CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)
+ << "leading dimension must divide 4 for int8 gemm";
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
beta_ptr = &beta_float;
}
- auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + A->byte_offset);
- auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + B->byte_offset);
- auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + C->byte_offset);
-
- CHECK_CUBLAS_ERROR(cublasGemmEx(hdl,
- BooleanToTranspose(transb),
- BooleanToTranspose(transa),
- ColumnCount(B, transb),
- RowCount(A, transa),
- ColumnCount(A, transa),
- alpha_ptr,
- B_data, cuda_in_type, ColumnStride(B),
- A_data, cuda_in_type, ColumnStride(A),
- beta_ptr,
- C_data, cuda_out_type, ColumnStride(C),
- cuda_out_type, algo));
+ auto A_data = reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset);
+ auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
+ auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);
+
+ CHECK_CUBLAS_ERROR(cublasGemmEx(hdl, BooleanToTranspose(transb), BooleanToTranspose(transa),
+ ColumnCount(B, transb), RowCount(A, transa),
+ ColumnCount(A, transa), alpha_ptr, B_data, cuda_in_type,
+ ColumnStride(B), A_data, cuda_in_type, ColumnStride(A), beta_ptr,
+ C_data, cuda_out_type, ColumnStride(C), cuda_out_type, algo));
}
-inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
- DLTensor *A = args[0];
- DLTensor *B = args[1];
- DLTensor *C = args[2];
+inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
CHECK_EQ(A->ndim, 3);
transb = IsInPlaceTransposed(B) ? !transb : transb;
CHECK(CheckMixPrecisionType(A->dtype, C->dtype, false)) << "Unsupported data type";
- CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
- ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
- CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
- ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
+ CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0)
+ << "leading dimension must divide 4 for int8 gemm";
+ CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)
+ << "leading dimension must divide 4 for int8 gemm";
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
beta_ptr = &beta_float;
}
- auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + A->byte_offset);
- auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + B->byte_offset);
- auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + C->byte_offset);
- CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(hdl,
- BooleanToTranspose(transb),
- BooleanToTranspose(transa),
- ColumnCount3D(B, transb),
- RowCount3D(A, transa),
- ColumnCount3D(A, transa),
- alpha_ptr,
- B_data, cuda_in_type, ColumnStride3D(B), B_size,
- A_data, cuda_in_type, ColumnStride3D(A), A_size,
- beta_ptr,
- C_data, cuda_out_type, ColumnStride3D(C), C_size,
- batch_size, cuda_out_type, algo));
+ auto A_data = reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset);
+ auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
+ auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);
+ CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(
+ hdl, BooleanToTranspose(transb), BooleanToTranspose(transa), ColumnCount3D(B, transb),
+ RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, cuda_in_type,
+ ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size, beta_ptr, C_data,
+ cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo));
}
// matrix multiplication for row major
-TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor* A = args[0];
- DLTensor* C = args[2];
+TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ DLTensor* C = args[2];
- CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+ CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
- TryEnableTensorCore(entry_ptr->handle);
+ TryEnableTensorCore(entry_ptr->handle);
- if (TypeEqual(A->dtype, C->dtype)) {
- CHECK(TypeMatch(A->dtype, kDLFloat, 16) ||
- TypeMatch(A->dtype, kDLFloat, 32) ||
+ if (TypeEqual(A->dtype, C->dtype)) {
+ CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
- if (TypeMatch(A->dtype, kDLFloat, 16))
- CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle));
- else if (TypeMatch(A->dtype, kDLFloat, 32))
- CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle));
- else
- CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
- } else {
- CallGemmEx(args, ret, entry_ptr->handle);
- }
+ if (TypeMatch(A->dtype, kDLFloat, 16))
+ CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle));
+ else if (TypeMatch(A->dtype, kDLFloat, 32))
+ CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle));
+ else
+ CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
+ } else {
+ CallGemmEx(args, ret, entry_ptr->handle);
+ }
});
#if CUDART_VERSION >= 10010
-TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- DLTensor* A = args[0];
+TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
- CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+ CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
- TryEnableTensorCore(entry_ptr->handle);
+ TryEnableTensorCore(entry_ptr->handle);
- CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n";
- cublasLtHandle_t ltHandle;
- CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle));
- CallLtIgemm(args, ret, ltHandle);
- CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle));
+ CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n";
+ cublasLtHandle_t ltHandle;
+ CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle));
+ CallLtIgemm(args, ret, ltHandle);
+ CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle));
});
#endif // CUDART_VERSION >= 10010
-TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- DLTensor* A = args[0];
- DLTensor* C = args[2];
+TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ DLTensor* C = args[2];
- CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+ CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
- TryEnableTensorCore(entry_ptr->handle);
- if (TypeEqual(A->dtype, C->dtype)) {
- CHECK(TypeMatch(A->dtype, kDLFloat, 16) ||
- TypeMatch(A->dtype, kDLFloat, 32) ||
+ TryEnableTensorCore(entry_ptr->handle);
+ if (TypeEqual(A->dtype, C->dtype)) {
+ CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
- if (TypeMatch(A->dtype, kDLFloat, 16))
- CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle));
- else if (TypeMatch(A->dtype, kDLFloat, 32))
- CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
- else
- CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
- } else {
- CallBatchGemmEx(args, ret, entry_ptr->handle);
- }
+ if (TypeMatch(A->dtype, kDLFloat, 16))
+ CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle));
+ else if (TypeMatch(A->dtype, kDLFloat, 32))
+ CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
+ else
+ CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
+ } else {
+ CallBatchGemmEx(args, ret, entry_ptr->handle);
+ }
});
} // namespace contrib
* \file Use external cudnn utils function
*/
#include "cublas_utils.h"
+
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
+
#include "../../cuda/cuda_common.h"
namespace tvm {
namespace contrib {
-
-CuBlasThreadEntry::CuBlasThreadEntry() {
- CHECK_CUBLAS_ERROR(cublasCreate(&handle));
-}
-
+CuBlasThreadEntry::CuBlasThreadEntry() { CHECK_CUBLAS_ERROR(cublasCreate(&handle)); }
CuBlasThreadEntry::~CuBlasThreadEntry() {
if (handle) {
}
}
-
typedef dmlc::ThreadLocalStore<CuBlasThreadEntry> CuBlasThreadStore;
-
CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
CuBlasThreadEntry* retval = CuBlasThreadStore::Get();
return retval;
}
-
} // namespace contrib
} // namespace tvm
#ifndef TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
#define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
-#include <dmlc/logging.h>
-#include <dlpack/dlpack.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
+#include <dlpack/dlpack.h>
+#include <dmlc/logging.h>
+
#include <cstdint>
#if CUDART_VERSION >= 10010
#include <cublasLt.h>
inline const char* GetCublasErrorString(int error) {
switch (error) {
- case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
- case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
- case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
- case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
- case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
- case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
- case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
- case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
- case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
+ case CUBLAS_STATUS_NOT_INITIALIZED:
+ return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED:
+ return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE:
+ return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH:
+ return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR:
+ return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED:
+ return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR:
+ return "CUBLAS_STATUS_INTERNAL_ERROR";
+ case CUBLAS_STATUS_NOT_SUPPORTED:
+ return "CUBLAS_STATUS_NOT_SUPPORTED";
+ case CUBLAS_STATUS_LICENSE_ERROR:
+ return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "Unrecognized error";
}
#ifndef CHECK_CUBLAS_ERROR
-#define CHECK_CUBLAS_ERROR(fn) \
- do { \
- int error = static_cast<int>(fn); \
+#define CHECK_CUBLAS_ERROR(fn) \
+ do { \
+ int error = static_cast<int>(fn); \
CHECK_EQ(error, CUBLAS_STATUS_SUCCESS) << "CUBLAS: " << GetCublasErrorString(error); \
} while (0) // ; intentionally left off.
-#endif // CHECK_CUBLAS_ERROR
-
+#endif // CHECK_CUBLAS_ERROR
struct CuBlasThreadEntry {
CuBlasThreadEntry();
inline cudaDataType_t GetCudaDataType(DLDataType type) {
if (type.code == kDLInt) {
switch (type.bits) {
- case 8: return CUDA_R_8I;
- case 32: return CUDA_R_32I;
+ case 8:
+ return CUDA_R_8I;
+ case 32:
+ return CUDA_R_32I;
}
} else if (type.code == kDLUInt) {
switch (type.bits) {
- case 8: return CUDA_R_8U;
- case 32: return CUDA_R_32U;
+ case 8:
+ return CUDA_R_8U;
+ case 32:
+ return CUDA_R_32U;
}
} else if (type.code == kDLFloat) {
switch (type.bits) {
- case 16: return CUDA_R_16F;
- case 32: return CUDA_R_32F;
- case 64: return CUDA_R_64F;
+ case 16:
+ return CUDA_R_16F;
+ case 32:
+ return CUDA_R_32F;
+ case 64:
+ return CUDA_R_64F;
}
}
LOG(FATAL) << "Unsupported cuda type";
/*!
* \file Use external cudnn utils function
*/
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include "cudnn_utils.h"
namespace tvm {
using namespace runtime;
-void ConvolutionForward(
- int mode,
- int format,
- int algo,
- int dims,
- int groups,
- const int pad[],
- const int stride[],
- const int dilation[],
- DLTensor* x,
- DLTensor* w,
- DLTensor* y,
- const std::string& conv_dtype) {
+void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[],
+ const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
+ DLTensor* y, const std::string& conv_dtype) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
if (dims == 2) {
// Set Desc
- CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
- pad[0],
- pad[1],
- stride[0],
- stride[1],
- dilation[0],
- dilation[1],
- entry_ptr->conv_entry.mode,
- entry_ptr->conv_entry.data_type));
+ CUDNN_CALL(cudnnSetConvolution2dDescriptor(
+ entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
+ dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
int ni, ci, hi, wi;
- if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
+ if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
ni = 0;
ci = 3;
hi = 1;
}
// Set Filter
- CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
- data_type,
- entry_ptr->conv_entry.tensor_format,
- static_cast<int>(w->shape[ni]),
- static_cast<int>(w->shape[ci]),
- static_cast<int>(w->shape[hi]),
- static_cast<int>(w->shape[wi])));
+ CUDNN_CALL(cudnnSetFilter4dDescriptor(
+ entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
+ static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
+ static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
// Set Input
- CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.tensor_format,
- data_type,
- static_cast<int>(x->shape[ni]),
- static_cast<int>(x->shape[ci]),
- static_cast<int>(x->shape[hi]),
- static_cast<int>(x->shape[wi])));
+ CUDNN_CALL(cudnnSetTensor4dDescriptor(
+ entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
+ static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
+ static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
// Set Output
- CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
- entry_ptr->conv_entry.tensor_format,
- data_type,
- static_cast<int>(y->shape[ni]),
- static_cast<int>(y->shape[ci]),
- static_cast<int>(y->shape[hi]),
- static_cast<int>(y->shape[wi])));
+ CUDNN_CALL(cudnnSetTensor4dDescriptor(
+ entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
+ static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
+ static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
} else {
- CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
- dims,
- pad,
- stride,
- dilation,
- entry_ptr->conv_entry.mode,
+ CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
+ dilation, entry_ptr->conv_entry.mode,
entry_ptr->conv_entry.data_type));
// Set Filter
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(w->shape[i]);
}
- CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
- data_type,
- entry_ptr->conv_entry.tensor_format,
- full_dims,
+ CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
+ entry_ptr->conv_entry.tensor_format, full_dims,
dim.data()));
// Set Input
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(x->shape[i]);
}
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
- CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
- data_type,
- full_dims,
- dim.data(),
- tensor_stride.data()));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
+ dim.data(), tensor_stride.data()));
// Set Output
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(y->shape[i]);
}
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
- CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc,
- data_type,
- full_dims,
- dim.data(),
- tensor_stride.data()));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
+ dim.data(), tensor_stride.data()));
}
if (cudnnGetVersion() > 7000) {
// Set workspace
size_t workspace_size = 0;
- CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle,
- entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.filter_desc,
- entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.output_desc,
- entry_ptr->conv_entry.fwd_algo,
- &workspace_size));
+ CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
+ entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
+ entry_ptr->conv_entry.fwd_algo, &workspace_size));
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
- CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle,
- CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
- entry_ptr->conv_entry.input_desc,
- x->data,
- entry_ptr->conv_entry.filter_desc,
- w->data,
- entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.fwd_algo,
- entry_ptr->conv_entry.workspace,
- workspace_size,
- CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
- entry_ptr->conv_entry.output_desc,
- y->data));
+ CUDNN_CALL(cudnnConvolutionForward(
+ entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+ entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
+ entry_ptr->conv_entry.workspace, workspace_size,
+ CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
+ entry_ptr->conv_entry.output_desc, y->data));
}
-
-void OutputShape(
- int format,
- int dims,
- int groups,
- const int pad[],
- const int stride[],
- const int dilation[],
- const int x_dim[],
- const int w_dim[],
- void *out_shape,
- const std::string& data_dtype,
- const std::string& conv_dtype) {
+void OutputShape(int format, int dims, int groups, const int pad[], const int stride[],
+ const int dilation[], const int x_dim[], const int w_dim[], void* out_shape,
+ const std::string& data_dtype, const std::string& conv_dtype) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Data Type
// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
- CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
- dims,
- pad,
- stride,
- dilation,
- CUDNN_CROSS_CORRELATION,
+ CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
+ dilation, CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));
- if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
+ if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.tensor_format,
- data_type,
- x_dim[0],
- x_dim[3],
- x_dim[1],
- x_dim[2]));
+ entry_ptr->conv_entry.tensor_format, data_type, x_dim[0],
+ x_dim[3], x_dim[1], x_dim[2]));
// filter desc
- CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
- data_type,
- entry_ptr->conv_entry.tensor_format,
- w_dim[0],
- w_dim[3],
- w_dim[1],
- w_dim[2]));
-
- CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.filter_desc,
- static_cast<int*>(out_shape),
- static_cast<int*>(out_shape) + 3,
- static_cast<int*>(out_shape) + 1,
- static_cast<int*>(out_shape) + 2));
+ CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
+ entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3],
+ w_dim[1], w_dim[2]));
+
+ CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+ entry_ptr->conv_entry.filter_desc, static_cast<int*>(out_shape),
+ static_cast<int*>(out_shape) + 3, static_cast<int*>(out_shape) + 1,
+ static_cast<int*>(out_shape) + 2));
} else {
// Set Input
std::vector<int> tensor_stride(full_dims);
GetCudnnStride(full_dims, x_dim, tensor_stride.data());
- CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
- data_type,
- full_dims,
- x_dim,
- tensor_stride.data()));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
+ x_dim, tensor_stride.data()));
// filter desc
- CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
- data_type,
- entry_ptr->conv_entry.tensor_format,
- full_dims,
- w_dim));
-
- CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.filter_desc,
- full_dims,
- static_cast<int*>(out_shape)));
+ CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
+ entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
+
+ CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+ entry_ptr->conv_entry.filter_desc, full_dims, static_cast<int*>(out_shape)));
}
}
-
-void FindAlgo(
- int format,
- int dims,
- int groups,
- const int pad[],
- const int stride[],
- const int dilation[],
- const int x_dim[],
- const int w_dim[],
- const int y_dim[],
- const std::string& data_dtype,
- const std::string& conv_dtype,
- TVMRetValue *ret) {
+void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
+ const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
+ const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Data Type
// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
- CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
- dims,
- pad,
- stride,
- dilation,
- CUDNN_CROSS_CORRELATION,
+ CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
+ dilation, CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));
std::vector<int> tensor_stride(full_dims);
// input desc
GetCudnnStride(full_dims, x_dim, tensor_stride.data());
- CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
- data_type,
- full_dims,
- x_dim,
- tensor_stride.data()));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
+ x_dim, tensor_stride.data()));
// filter desc
- CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
- data_type,
- entry_ptr->conv_entry.tensor_format,
- full_dims,
- w_dim));
+ CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
+ entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
// output desc
GetCudnnStride(full_dims, y_dim, tensor_stride.data());
- CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc,
- data_type,
- full_dims,
- y_dim,
- tensor_stride.data()));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
+ y_dim, tensor_stride.data()));
if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
}
int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
- CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle,
- entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.filter_desc,
- entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.output_desc,
- CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
- &returned_algo_count,
- perf_results));
-
- const std::vector<std::string> fwd_algo_names{
- "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
- "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
- "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
- "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
- "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
- "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
- "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
- "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
- };
+ CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
+ entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
+ CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results));
+
+ const std::vector<std::string> fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
+ "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
+ "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
+ "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
+ "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
+ "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
+ "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
+ "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"};
auto best_algo = perf_results[0].algo;
- LOG(INFO) << "\tCUDNN Found " << returned_algo_count
- << " fwd algorithms, choosing " << fwd_algo_names[best_algo];
+ LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
+ << fwd_algo_names[best_algo];
for (int i = 0; i < returned_algo_count; ++i) {
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
<< " - time: " << perf_results[i].time << " ms"
ret[0] = best_algo;
}
-
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- int mode = args[0];
- int format = args[1];
- int algo = args[2];
- int pad_v[2], stride_v[2], dilation_v[2];
- for (int i = 0; i < 2; i++) {
- pad_v[i] = args[3 + i];
- stride_v[i] = args[5 + i];
- dilation_v[i] = args[7 + i];
- }
- DLTensor* x = args[9];
- DLTensor* w = args[10];
- DLTensor* y = args[11];
- std::string conv_dtype = args[12];
- int groups = args[13];
-
- ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v,
- dilation_v, x, w, y, conv_dtype);
-});
-
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ int mode = args[0];
+ int format = args[1];
+ int algo = args[2];
+ int pad_v[2], stride_v[2], dilation_v[2];
+ for (int i = 0; i < 2; i++) {
+ pad_v[i] = args[3 + i];
+ stride_v[i] = args[5 + i];
+ dilation_v[i] = args[7 + i];
+ }
+ DLTensor* x = args[9];
+ DLTensor* w = args[10];
+ DLTensor* y = args[11];
+ std::string conv_dtype = args[12];
+ int groups = args[13];
+
+ ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y,
+ conv_dtype);
+ });
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- int mode = args[0];
- int format = args[1];
- int algo = args[2];
- int pad_v[3], stride_v[3], dilation_v[3];
- for (int i = 0; i < 3; i++) {
- pad_v[i] = args[3 + i];
- stride_v[i] = args[6 + i];
- dilation_v[i] = args[9 + i];
- }
- DLTensor *x = args[12];
- DLTensor *w = args[13];
- DLTensor *y = args[14];
- std::string conv_dtype = args[15];
- int groups = args[16];
-
- ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v,
- dilation_v, x, w, y, conv_dtype);
-});
-
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ int mode = args[0];
+ int format = args[1];
+ int algo = args[2];
+ int pad_v[3], stride_v[3], dilation_v[3];
+ for (int i = 0; i < 3; i++) {
+ pad_v[i] = args[3 + i];
+ stride_v[i] = args[6 + i];
+ dilation_v[i] = args[9 + i];
+ }
+ DLTensor* x = args[12];
+ DLTensor* w = args[13];
+ DLTensor* y = args[14];
+ std::string conv_dtype = args[15];
+ int groups = args[16];
+
+ ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, x, w, y,
+ conv_dtype);
+ });
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- int format = args[0];
- int dims = args[1];
- int* pad = static_cast<int*>(static_cast<void*>(args[2]));
- int* stride = static_cast<int*>(static_cast<void*>(args[3]));
- int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
- int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
- int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
- void* out_shape = args[7];
- std::string data_dtype = args[8];
- std::string conv_dtype = args[9];
- int groups = args[10];
-
- OutputShape(format, dims, groups, pad, stride, dilation, x_dim,
- w_dim, out_shape, data_dtype, conv_dtype);
-});
-
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ int format = args[0];
+ int dims = args[1];
+ int* pad = static_cast<int*>(static_cast<void*>(args[2]));
+ int* stride = static_cast<int*>(static_cast<void*>(args[3]));
+ int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
+ int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
+ int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
+ void* out_shape = args[7];
+ std::string data_dtype = args[8];
+ std::string conv_dtype = args[9];
+ int groups = args[10];
+
+ OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype,
+ conv_dtype);
+ });
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- int format = args[0];
- int dims = args[1];
- int* pad = static_cast<int*>(static_cast<void*>(args[2]));
- int* stride = static_cast<int*>(static_cast<void*>(args[3]));
- int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
- int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
- int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
- int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
- std::string data_dtype = args[8];
- std::string conv_dtype = args[9];
- int groups = args[10];
-
- FindAlgo(format, dims, groups, pad, stride, dilation, x_dim,
- w_dim, y_dim, data_dtype, conv_dtype, ret);
-});
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ int format = args[0];
+ int dims = args[1];
+ int* pad = static_cast<int*>(static_cast<void*>(args[2]));
+ int* stride = static_cast<int*>(static_cast<void*>(args[3]));
+ int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
+ int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
+ int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
+ int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
+ std::string data_dtype = args[8];
+ std::string conv_dtype = args[9];
+ int groups = args[10];
+
+ FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype,
+ conv_dtype, ret);
+ });
} // namespace contrib
} // namespace tvm
* \file Use external cudnn utils function
*/
#include "cudnn_utils.h"
+
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
-
namespace tvm {
namespace contrib {
// CuDNN Data Type
-cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) {
+cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) {
switch (dtype.code) {
- case kDLInt:
- if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8;
- else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32;
- else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4;
- else
- LOG(FATAL) << "Unsupported type";
- break;
- case kDLUInt:
+ case kDLInt:
+ if (dtype.bits == 8 && dtype.lanes == 1)
+ return CUDNN_DATA_INT8;
+ else if (dtype.bits == 32 && dtype.lanes == 1)
+ return CUDNN_DATA_INT32;
+ else if (dtype.bits == 8 && dtype.lanes == 4)
+ return CUDNN_DATA_INT8x4;
+ else
LOG(FATAL) << "Unsupported type";
- break;
- case kDLFloat:
- if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT;
- else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE;
- else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF;
- else
- LOG(FATAL) << "Unsupported type";
- break;
- }
- return CUDNN_DATA_FLOAT;
+ break;
+ case kDLUInt:
+ LOG(FATAL) << "Unsupported type";
+ break;
+ case kDLFloat:
+ if (dtype.bits == 32 && dtype.lanes == 1)
+ return CUDNN_DATA_FLOAT;
+ else if (dtype.bits == 64 && dtype.lanes == 1)
+ return CUDNN_DATA_DOUBLE;
+ else if (dtype.bits == 16 && dtype.lanes == 1)
+ return CUDNN_DATA_HALF;
+ else
+ LOG(FATAL) << "Unsupported type";
+ break;
+ }
+ return CUDNN_DATA_FLOAT;
}
-template<>
+template <>
const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) {
static const int int_v = 0;
static const float float_v = 0;
return nullptr;
}
-template<>
+template <>
const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) {
static const int int_v = 1;
static const float float_v = 1.f;
CuDNNThreadEntry::CuDNNThreadEntry() {
auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
auto func = runtime::Registry::Get("device_api.gpu");
- void *ret = (*func)();
+ void* ret = (*func)();
cuda_api = static_cast<runtime::DeviceAPI*>(ret);
CUDNN_CALL(cudnnCreate(&handle));
CUDNN_CALL(cudnnSetStream(handle, stream));
conv_entry.cuda_api = cuda_api;
}
-CuDNNThreadEntry::~CuDNNThreadEntry() {
- CUDNN_CALL(cudnnDestroy(handle));
-}
+CuDNNThreadEntry::~CuDNNThreadEntry() { CUDNN_CALL(cudnnDestroy(handle)); }
typedef dmlc::ThreadLocalStore<CuDNNThreadEntry> CuDNNThreadStore;
-CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() {
- return CuDNNThreadStore::Get();
-}
+CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { return CuDNNThreadStore::Get(); }
// ConvEntry
// SoftmaxEntry
-SoftmaxEntry::SoftmaxEntry() {
- CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc));
-}
+SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); }
-SoftmaxEntry::~SoftmaxEntry() {
- CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc));
-}
+SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); }
} // namespace contrib
} // namespace tvm
#ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_
#define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_
-#include <dmlc/logging.h>
#include <cudnn.h>
+#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
-#include "../../cuda/cuda_common.h"
+#include "../../cuda/cuda_common.h"
namespace tvm {
namespace contrib {
/*! breif Convert DLTensor type to CuDNN type */
struct CuDNNDataType {
- static cudnnDataType_t DLTypeToCuDNNType(const DLDataType &dtype);
- template<int v>
+ static cudnnDataType_t DLTypeToCuDNNType(const DLDataType& dtype);
+ template <int v>
static const void* GetConst(cudnnDataType_t type);
}; // struct CuDNNDataType
-inline void GetStride(int nbdim, const int *dims, int *strides) {
+inline void GetStride(int nbdim, const int* dims, int* strides) {
int mul = 1;
- for (int i = nbdim - 1; i >=0; --i) {
+ for (int i = nbdim - 1; i >= 0; --i) {
mul *= dims[i];
strides[i] = mul;
}
}
-inline void GetCudnnStride(int nbdim,
- const int* dims,
- int* strides) {
+inline void GetCudnnStride(int nbdim, const int* dims, int* strides) {
int mul = 1;
- for (int i = nbdim - 1; i >=0; --i) {
+ for (int i = nbdim - 1; i >= 0; --i) {
strides[i] = mul;
mul *= dims[i];
}
cudnnConvolutionFwdAlgo_t fwd_algo;
// cudnnMathType_t math_type;
TVMContext ctx;
- runtime::DeviceAPI *cuda_api;
- void *workspace{nullptr};
+ runtime::DeviceAPI* cuda_api;
+ void* workspace{nullptr};
size_t workspace_size{0};
ConvEntry();
~ConvEntry();
cudnnHandle_t handle{nullptr};
ConvEntry conv_entry;
SoftmaxEntry softmax_entry;
- runtime::DeviceAPI *cuda_api{nullptr};
+ runtime::DeviceAPI* cuda_api{nullptr};
static CuDNNThreadEntry* ThreadLocal();
}; // CuDNNThreadEntry
* \file src/runtime/contrib/cudnn/softmax.cc
* \brief Use external cudnn softmax function
*/
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include "cudnn_utils.h"
namespace tvm {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor* x = args[0];
- DLTensor* y = args[1];
- int axis = args[2];
- int ndim = x->ndim;
- int64_t* shape = x->shape;
- if (axis < 0) axis += ndim;
- CHECK(axis >= 0 && axis < ndim);
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* x = args[0];
+ DLTensor* y = args[1];
+ int axis = args[2];
+ int ndim = x->ndim;
+ int64_t* shape = x->shape;
+ if (axis < 0) axis += ndim;
+ CHECK(axis >= 0 && axis < ndim);
- CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
- entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+ CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+ entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
- // Set mode and shape descriptor
- if (axis == ndim - 1) {
- int64_t N = 1;
- for (int i = 0; i < ndim - 1; ++i) {
- N *= shape[i];
- }
- entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
- CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
- CUDNN_TENSOR_NCHW,
- entry_ptr->softmax_entry.data_type,
- static_cast<int>(N),
- static_cast<int>(shape[ndim - 1]),
- 1,
- 1));
- } else {
- int64_t pre_axis_dim = 1;
- int64_t post_axis_dim = 1;
- for (int i = 0; i < ndim; ++i) {
- if (i < axis) {
- pre_axis_dim *= shape[i];
- } else if (i > axis) {
- post_axis_dim *= shape[i];
+ // Set mode and shape descriptor
+ if (axis == ndim - 1) {
+ int64_t N = 1;
+ for (int i = 0; i < ndim - 1; ++i) {
+ N *= shape[i];
+ }
+ entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+ CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+ CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
+ static_cast<int>(N),
+ static_cast<int>(shape[ndim - 1]), 1, 1));
+ } else {
+ int64_t pre_axis_dim = 1;
+ int64_t post_axis_dim = 1;
+ for (int i = 0; i < ndim; ++i) {
+ if (i < axis) {
+ pre_axis_dim *= shape[i];
+ } else if (i > axis) {
+ post_axis_dim *= shape[i];
+ }
+ }
+ entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
+ CUDNN_CALL(cudnnSetTensor4dDescriptor(
+ entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
+ entry_ptr->softmax_entry.data_type, static_cast<int>(pre_axis_dim),
+ static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
}
- }
- entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
- CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
- CUDNN_TENSOR_NCHW,
- entry_ptr->softmax_entry.data_type,
- static_cast<int>(pre_axis_dim),
- static_cast<int>(shape[axis]),
- static_cast<int>(post_axis_dim),
- 1));
- }
- auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
- auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
- CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
- CUDNN_SOFTMAX_ACCURATE,
- entry_ptr->softmax_entry.mode,
- alpha,
- entry_ptr->softmax_entry.shape_desc,
- x->data,
- beta,
- entry_ptr->softmax_entry.shape_desc,
- y->data));
-});
+ auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
+ auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
+ CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, CUDNN_SOFTMAX_ACCURATE,
+ entry_ptr->softmax_entry.mode, alpha,
+ entry_ptr->softmax_entry.shape_desc, x->data, beta,
+ entry_ptr->softmax_entry.shape_desc, y->data));
+ });
} // namespace contrib
} // namespace tvm
* \brief TVM compatible wrappers for dnnl kernels.
*/
-#include "dnnl_kernel.h"
-
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <vector>
+#include "dnnl_kernel.h"
+
namespace tvm {
namespace runtime {
namespace contrib {
p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op());
}
-extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
- int p_I_, int p_O_) {
+extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) {
using tag = memory::format_tag;
using dt = memory::data_type;
auto bias_memory = memory(bias_md, eng, bias.data());
auto dst_memory = memory(dst_md, eng);
- auto dense_desc = inner_product_forward::desc(
- prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md);
+ auto dense_desc = inner_product_forward::desc(prop_kind::forward_inference, data_md, weight_md,
+ bias_md, dst_md);
auto dense_prim_desc = inner_product_forward::primitive_desc(dense_desc, eng);
assert(dst_md == dense_prim_desc.dst_desc());
read_from_dnnl_memory(out, dst_memory);
}
-extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
- int p_W_) {
+extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_) {
using tag = memory::format_tag;
using dt = memory::data_type;
auto data_memory = memory(data_md, eng, data);
auto dst_memory = memory(data_md, eng);
- auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference,
- algorithm::eltwise_relu, data_md, 0);
+ auto relu_desc =
+ eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_relu, data_md, 0);
auto relu_prim_desc = eltwise_forward::primitive_desc(relu_desc, eng);
assert(data_md == relu_prim_desc.dst_desc());
auto bn_desc = batch_normalization_forward::desc(
prop_kind::forward_inference, data_md, p_E_,
- normalization_flags::use_global_stats |
- normalization_flags::use_scale_shift);
+ normalization_flags::use_global_stats | normalization_flags::use_scale_shift);
auto bn_prim_desc = batch_normalization_forward::primitive_desc(bn_desc, eng);
assert(data_md == bn_prim_desc.dst_desc());
free(weight);
}
-extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_,
- int p_C_, int p_H_, int p_W_) {
+extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_,
+ int p_W_) {
using tag = memory::format_tag;
using dt = memory::data_type;
auto weight_memory = memory(weight_md, eng, weight);
auto dst_memory = memory(dst_md, eng);
- auto add_desc =
- binary::desc(algorithm::binary_add, data_md, weight_md, dst_md);
+ auto add_desc = binary::desc(algorithm::binary_add, data_md, weight_md, dst_md);
auto add_prim_desc = binary::primitive_desc(add_desc, eng);
assert(dst_md == add_prim_desc.dst_desc());
auto add = binary(add_prim_desc);
- add.execute(s, {{DNNL_ARG_SRC_0, data_memory},
- {DNNL_ARG_SRC_1, weight_memory},
- {DNNL_ARG_DST, dst_memory}});
+ add.execute(
+ s,
+ {{DNNL_ARG_SRC_0, data_memory}, {DNNL_ARG_SRC_1, weight_memory}, {DNNL_ARG_DST, dst_memory}});
s.wait();
read_from_dnnl_memory(out, dst_memory);
}
#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_
#include <tvm/runtime/c_runtime_api.h>
+
#include "dnnl.hpp"
namespace tvm {
/*!
* \file edgetpu_runtime.cc
*/
-#include <tvm/runtime/registry.h>
+#include "edgetpu_runtime.h"
+
+#include <edgetpu.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
-#include <edgetpu.h>
-
-
-#include "edgetpu_runtime.h"
+#include <tvm/runtime/registry.h>
namespace tvm {
namespace runtime {
-void EdgeTPURuntime::Init(const std::string& tflite_model_bytes,
- TVMContext ctx) {
+void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str();
size_t buffer_size = tflite_model_bytes.size();
// Load compiled model as a FlatBufferModel
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+ tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
// Build resolver
tflite::ops::builtin::BuiltinOpResolver resolver;
// Init EdgeTPUContext object
ctx_ = ctx;
}
-Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes,
- TVMContext ctx) {
+Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) {
auto exec = make_object<EdgeTPURuntime>();
exec->Init(tflite_model_bytes, ctx);
return Module(exec);
}
-TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = EdgeTPURuntimeCreate(args[0], args[1]);
- });
+TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = EdgeTPURuntimeCreate(args[0], args[1]);
+});
} // namespace runtime
} // namespace tvm
#ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
-#include <string>
#include <memory>
+#include <string>
#include "../tflite/tflite_runtime.h"
/*!
* \return The type key of the executor.
*/
- const char* type_key() const final {
- return "EdgeTPURuntime";
- }
+ const char* type_key() const final { return "EdgeTPURuntime"; }
/*!
* \brief Initialize the edge TPU tflite runtime with tflite model and context.
* \param tflite_model_bytes The tflite model.
* \param ctx The context where the tflite model will be executed on.
*/
- void Init(const std::string& tflite_model_bytes,
- TVMContext ctx);
+ void Init(const std::string& tflite_model_bytes, TVMContext ctx);
private:
std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_;
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <fstream>
#include <cmath>
+#include <fstream>
#include <map>
#include <sstream>
#include <string>
DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
- Add_(static_cast<float*>(arg0->data), arg0->shape[0],
- static_cast<float*>(arg1->data), arg1->shape[0],
- static_cast<float*>(out->data));
+ Add_(static_cast<float*>(arg0->data), arg0->shape[0], static_cast<float*>(arg1->data),
+ arg1->shape[0], static_cast<float*>(out->data));
return 0;
}
DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
- Sub_(static_cast<float*>(arg0->data), arg0->shape[0],
- static_cast<float*>(arg1->data), arg1->shape[0],
- static_cast<float*>(out->data));
+ Sub_(static_cast<float*>(arg0->data), arg0->shape[0], static_cast<float*>(arg1->data),
+ arg1->shape[0], static_cast<float*>(out->data));
return 0;
}
DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
- Mul_(static_cast<float*>(arg0->data), arg0->shape[0],
- static_cast<float*>(arg1->data), arg1->shape[0],
- static_cast<float*>(out->data));
+ Mul_(static_cast<float*>(arg0->data), arg0->shape[0], static_cast<float*>(arg1->data),
+ arg1->shape[0], static_cast<float*>(out->data));
return 0;
}
*
* \return The function pointer when it is found, otherwise, PackedFunc(nullptr).
*/
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (this->graph_.find(name) != this->graph_.end()) {
this->curr_subgraph_ = name;
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*
* \param stream. The stream to save the binary.
*/
- void SaveToBinary(dmlc::Stream* stream) final {
- stream->Write(this->graph_json_);
- }
+ void SaveToBinary(dmlc::Stream* stream) final { stream->Write(this->graph_json_); }
/*!
* \brief Parse the example json string.
};
TVM_REGISTER_GLOBAL("runtime.module.loadfile_examplejson")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = ExampleJsonModule::Create(args[0]);
-});
+ .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ExampleJsonModule::Create(args[0]); });
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_examplejson")
-.set_body_typed(ExampleJsonModule::LoadFromBinary);
+ .set_body_typed(ExampleJsonModule::LoadFromBinary);
} // namespace runtime
} // namespace tvm
/*!
* \file Use external miopen utils function
*/
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include "miopen_utils.h"
namespace tvm {
using namespace runtime;
-TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
+TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup").set_body([](TVMArgs args, TVMRetValue* ret) {
const int mode = args[0];
const int dtype = args[1];
const int pad_h = args[2];
const int w_dim2 = args[14];
const int w_dim3 = args[15];
const int n_group = args[16];
- void *out_shape = args[17];
+ void* out_shape = args[17];
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
assert(n_group > 0 && "Group Size > 0 is expected");
- if (n_group > 1)
- assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
+ if (n_group > 1) assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
// Set Mode
entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
// Set Ctx
entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
// Set Data Type
- entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
- dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
- // this moment.
+ entry_ptr->conv_entry.data_type =
+ static_cast<miopenDataType_t>(dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf),
+ // int32, int8 at this moment.
// Set Desc
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.mode,
- pad_h,
- pad_w,
- stride_h,
- stride_w,
- dilation_h,
- dilation_w));
+ entry_ptr->conv_entry.mode, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w));
if (n_group > 1)
MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group));
// Set Filter
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
- entry_ptr->conv_entry.data_type,
- w_dim0,
- w_dim1/n_group,
- w_dim2,
- w_dim3));
+ entry_ptr->conv_entry.data_type, w_dim0, w_dim1 / n_group,
+ w_dim2, w_dim3));
// Set Input
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.data_type,
- x_dim0,
- x_dim1,
- x_dim2,
+ entry_ptr->conv_entry.data_type, x_dim0, x_dim1, x_dim2,
x_dim3));
// Set Output shape
- MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.filter_desc,
- static_cast<int*>(out_shape),
- static_cast<int*>(out_shape) + 1,
- static_cast<int*>(out_shape) + 2,
- static_cast<int*>(out_shape) + 3));
-
- const int *oshape = static_cast<int*>(out_shape);
+ MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+ entry_ptr->conv_entry.filter_desc, static_cast<int*>(out_shape),
+ static_cast<int*>(out_shape) + 1, static_cast<int*>(out_shape) + 2,
+ static_cast<int*>(out_shape) + 3));
+
+ const int* oshape = static_cast<int*>(out_shape);
// Set Output
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
- entry_ptr->conv_entry.data_type,
- oshape[0],
- oshape[1],
- oshape[2],
- oshape[3]));
+ entry_ptr->conv_entry.data_type, oshape[0], oshape[1],
+ oshape[2], oshape[3]));
// Set workspace
size_t workspace_size = 0;
- MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(entry_ptr->handle,
- entry_ptr->conv_entry.filter_desc,
- entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.output_desc,
- &workspace_size));
+ MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(
+ entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc,
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size));
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3;
const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3];
runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api;
- float* input_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
- input_size * sizeof(float)));
- float* filter_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
- filter_size * sizeof(float)));
- float* output_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
- output_size * sizeof(float)));
+ float* input_buf = static_cast<float*>(
+ rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, input_size * sizeof(float)));
+ float* filter_buf = static_cast<float*>(
+ rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, filter_size * sizeof(float)));
+ float* output_buf = static_cast<float*>(
+ rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, output_size * sizeof(float)));
const int request_algo_count = 4;
const bool exhaustive_search = false;
int returned_algo_count = 0;
miopenConvAlgoPerf_t perfs[4];
- MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm(entry_ptr->handle,
- entry_ptr->conv_entry.input_desc,
- input_buf,
- entry_ptr->conv_entry.filter_desc,
- filter_buf,
- entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.output_desc,
- output_buf,
- request_algo_count,
- &returned_algo_count,
- perfs,
- workspace,
- workspace_size,
- exhaustive_search));
+ MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm(
+ entry_ptr->handle, entry_ptr->conv_entry.input_desc, input_buf,
+ entry_ptr->conv_entry.filter_desc, filter_buf, entry_ptr->conv_entry.conv_desc,
+ entry_ptr->conv_entry.output_desc, output_buf, request_algo_count, &returned_algo_count,
+ perfs, workspace, workspace_size, exhaustive_search));
rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, input_buf);
rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, filter_buf);
"miopenConvolutionFwdAlgoWinograd",
};
const auto best_algo = perfs[0].fwd_algo;
- LOG(INFO) << "\tMIOpen Found " << returned_algo_count
- << " fwd algorithms, choosing " << fwd_algo_names[best_algo];
+ LOG(INFO) << "\tMIOpen Found " << returned_algo_count << " fwd algorithms, choosing "
+ << fwd_algo_names[best_algo];
for (int i = 0; i < returned_algo_count; ++i) {
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo]
<< " - time: " << perfs[i].time << " ms"
ret[0] = static_cast<int>(best_algo);
});
-
TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- const int mode = args[0];
- const int dtype = args[1];
- const int pad_h = args[2];
- const int pad_w = args[3];
- const int stride_h = args[4];
- const int stride_w = args[5];
- const int dilation_h = args[6];
- const int dilation_w = args[7];
- const int algo = args[8];
- const DLTensor *x = args[9];
- const DLTensor *w = args[10];
- const DLTensor *y = args[11];
-
- MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
- entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo);
- // Set Mode
- entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
- // Set Ctx
- entry_ptr->conv_entry.ctx = x->ctx;
- // Set Data Type
- entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
- dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
- // this moment.
- // Set Desc
- MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.mode,
- pad_h,
- pad_w,
- stride_h,
- stride_w,
- dilation_h,
- dilation_w));
- // Set Filter
- MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
- entry_ptr->conv_entry.data_type,
- w->shape[0],
- w->shape[1],
- w->shape[2],
- w->shape[3]));
- // Set Input
- MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
- entry_ptr->conv_entry.data_type,
- x->shape[0],
- x->shape[1],
- x->shape[2],
- x->shape[3]));
- // Set Output
- MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
- entry_ptr->conv_entry.data_type,
- y->shape[0],
- y->shape[1],
- y->shape[2],
- y->shape[3]));
-
- const float alpha = 1.f;
- const float beta = 0.f;
- MIOPEN_CALL(miopenConvolutionForward(entry_ptr->handle,
- &alpha,
- entry_ptr->conv_entry.input_desc,
- x->data,
- entry_ptr->conv_entry.filter_desc,
- w->data,
- entry_ptr->conv_entry.conv_desc,
- entry_ptr->conv_entry.fwd_algo,
- &beta,
- entry_ptr->conv_entry.output_desc,
- y->data,
- entry_ptr->conv_entry.workspace,
- entry_ptr->conv_entry.workspace_size));
-});
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ const int mode = args[0];
+ const int dtype = args[1];
+ const int pad_h = args[2];
+ const int pad_w = args[3];
+ const int stride_h = args[4];
+ const int stride_w = args[5];
+ const int dilation_h = args[6];
+ const int dilation_w = args[7];
+ const int algo = args[8];
+ const DLTensor* x = args[9];
+ const DLTensor* w = args[10];
+ const DLTensor* y = args[11];
+
+ MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
+ entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo);
+ // Set Mode
+ entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
+ // Set Ctx
+ entry_ptr->conv_entry.ctx = x->ctx;
+ // Set Data Type
+ entry_ptr->conv_entry.data_type =
+ static_cast<miopenDataType_t>(dtype); // MIOpen supports fp32(miopenFloat),
+ // fp16(miopenHalf) at this moment.
+ // Set Desc
+ MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
+ entry_ptr->conv_entry.mode, pad_h, pad_w,
+ stride_h, stride_w, dilation_h, dilation_w));
+ // Set Filter
+ MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
+ entry_ptr->conv_entry.data_type, w->shape[0],
+ w->shape[1], w->shape[2], w->shape[3]));
+ // Set Input
+ MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
+ entry_ptr->conv_entry.data_type, x->shape[0],
+ x->shape[1], x->shape[2], x->shape[3]));
+ // Set Output
+ MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
+ entry_ptr->conv_entry.data_type, y->shape[0],
+ y->shape[1], y->shape[2], y->shape[3]));
+
+ const float alpha = 1.f;
+ const float beta = 0.f;
+ MIOPEN_CALL(miopenConvolutionForward(
+ entry_ptr->handle, &alpha, entry_ptr->conv_entry.input_desc, x->data,
+ entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc,
+ entry_ptr->conv_entry.fwd_algo, &beta, entry_ptr->conv_entry.output_desc, y->data,
+ entry_ptr->conv_entry.workspace, entry_ptr->conv_entry.workspace_size));
+ });
} // namespace miopen
} // namespace contrib
 * \file Use external miopen utility functions.
*/
#include "miopen_utils.h"
+
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
-#include <vector>
+
#include <string>
+#include <vector>
namespace tvm {
namespace contrib {
namespace miopen {
std::string miopenGetErrorString(int error_code) {
- const std::vector<std::string> mio_err{
- "StatusSuccess ", "StatusNotInitialized ", "StatusInvalidValue ",
- "StatusBadParm ", "StatusAllocFailed ", "StatusInternalError ",
- "StatusNotImplemented ", "StatusUnknownError "};
+ const std::vector<std::string> mio_err{"StatusSuccess ", "StatusNotInitialized ",
+ "StatusInvalidValue ", "StatusBadParm ",
+ "StatusAllocFailed ", "StatusInternalError ",
+ "StatusNotImplemented ", "StatusUnknownError "};
return mio_err[error_code];
}
MIOpenThreadEntry::MIOpenThreadEntry() {
auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
auto func = runtime::Registry::Get("device_api.rocm");
- void *ret = (*func)();
+ void* ret = (*func)();
rocm_api = static_cast<runtime::DeviceAPI*>(ret);
MIOPEN_CALL(miopenCreate(&handle));
MIOPEN_CALL(miopenSetStream(handle, stream));
conv_entry.rocm_api = rocm_api;
}
-MIOpenThreadEntry::~MIOpenThreadEntry() {
- MIOPEN_CALL(miopenDestroy(handle));
-}
+MIOpenThreadEntry::~MIOpenThreadEntry() { MIOPEN_CALL(miopenDestroy(handle)); }
typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;
-MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
- return MIOpenThreadStore::Get();
-}
+MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { return MIOpenThreadStore::Get(); }
// ConvEntry
#include <dmlc/logging.h>
#include <miopen/miopen.h>
#include <tvm/runtime/device_api.h>
+
#include <string>
+
#include "../../rocm/rocm_common.h"
namespace tvm {
std::string miopenGetErrorString(int error_code);
-#define MIOPEN_CALL(func) \
- { \
- miopenStatus_t e = (func); \
- CHECK_EQ(e, miopenStatusSuccess) \
- << "miopen error: " << miopenGetErrorString(e); \
+#define MIOPEN_CALL(func) \
+ { \
+ miopenStatus_t e = (func); \
+ CHECK_EQ(e, miopenStatusSuccess) << "miopen error: " << miopenGetErrorString(e); \
}
struct ConvEntry {
miopenTensorDescriptor_t output_desc;
miopenConvFwdAlgorithm_t fwd_algo;
TVMContext ctx;
- runtime::DeviceAPI *rocm_api;
- void *workspace{nullptr};
+ runtime::DeviceAPI* rocm_api;
+ void* workspace{nullptr};
size_t workspace_size{0};
ConvEntry();
~ConvEntry();
~MIOpenThreadEntry();
miopenHandle_t handle{nullptr};
ConvEntry conv_entry;
- runtime::DeviceAPI *rocm_api{nullptr};
- static MIOpenThreadEntry *ThreadLocal();
+ runtime::DeviceAPI* rocm_api{nullptr};
+ static MIOpenThreadEntry* ThreadLocal();
}; // MIOpenThreadEntry
} // namespace miopen
using namespace runtime;
-TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor *buf = args[0];
- DLTensor *img = args[1];
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* buf = args[0];
+ DLTensor* img = args[1];
// copy to temp
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
- MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
- runtime::metal::MetalThreadEntry *rt =
- runtime::metal::MetalThreadEntry::ThreadLocal();
+ MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
+ runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
- entry_ptr->metal_api->CopyDataFromTo(
- (__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
- buf->ctx, buf->ctx, nullptr
- );
+ entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0,
+ [mtlbuf length], buf -> ctx, buf -> ctx, nullptr);
- MPSImageDescriptor *desc = [MPSImageDescriptor
- imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
- width:buf->shape[2]
- height:buf->shape[1]
- featureChannels:buf->shape[3]];
+ MPSImageDescriptor* desc =
+ [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
+ width:buf->shape[2]
+ height:buf->shape[1]
+ featureChannels:buf->shape[3]];
- MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);
+ MPSImage* mpsimg = entry_ptr->AllocMPSImage(dev, desc);
[mpsimg writeBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
- img->data = (__bridge void *)mpsimg;
+ img->data = (__bridge void*)mpsimg;
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
+});
- });
-
-TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor *img = args[0];
- DLTensor *buf = args[1];
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* img = args[0];
+ DLTensor* buf = args[1];
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
- MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
- MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
- runtime::metal::MetalThreadEntry *rt =
- runtime::metal::MetalThreadEntry::ThreadLocal();
+ MPSImage* mpsimg = (__bridge MPSImage*)(img->data);
+ MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
+ runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
- entry_ptr->metal_api->CopyDataFromTo(
- (__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
- buf->ctx, buf->ctx, nullptr);
-
- });
+ entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0,
+ [mtlbuf length], buf -> ctx, buf -> ctx, nullptr);
+});
-TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) {
// MPS-NHWC
- DLTensor *data = args[0];
- DLTensor *weight = args[1];
- DLTensor *output = args[2];
+ DLTensor* data = args[0];
+ DLTensor* weight = args[1];
+ DLTensor* output = args[2];
int pad = args[3];
int stride = args[4];
auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
// Get Metal device API
- MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
- runtime::metal::MetalThreadEntry *rt =
- runtime::metal::MetalThreadEntry::ThreadLocal();
+ MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
+ runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
- id<MTLCommandQueue> queue =
- entry_ptr->metal_api->GetCommandQueue(data->ctx);
+ id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(data->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
// data to MPSImage
DLTensor tmp_in;
(*f_buf2img)(data, &tmp_in);
- MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;
+ MPSImage* tempA = (__bridge MPSImage*)tmp_in.data;
// weight to temp memory
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
- entry_ptr->metal_api->CopyDataFromTo(
- (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
- weight->ctx, weight->ctx, nullptr);
- float *ptr_w = (float *)[tempB contents];
+ entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0,
+ [bufB length], weight -> ctx, weight -> ctx, nullptr);
+ float* ptr_w = (float*)[tempB contents];
// output to MPSImage
DLTensor tmp_out;
(*f_buf2img)(output, &tmp_out);
- MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;
+ MPSImage* tempC = (__bridge MPSImage*)tmp_out.data;
// conv desc
- MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
- cnnConvolutionDescriptorWithKernelWidth:kW
- kernelHeight:kH
- inputFeatureChannels:iCh
- outputFeatureChannels:oCh];
+ MPSCNNConvolutionDescriptor* conv_desc =
+ [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW
+ kernelHeight:kH
+ inputFeatureChannels:iCh
+ outputFeatureChannels:oCh];
[conv_desc setStrideInPixelsX:stride];
[conv_desc setStrideInPixelsY:stride];
- MPSCNNConvolution *conv =
- [[MPSCNNConvolution alloc] initWithDevice:dev
- convolutionDescriptor:conv_desc
- kernelWeights:ptr_w
- biasTerms:nil
- flags:MPSCNNConvolutionFlagsNone];
+ MPSCNNConvolution* conv = [[MPSCNNConvolution alloc] initWithDevice:dev
+ convolutionDescriptor:conv_desc
+ kernelWeights:ptr_w
+ biasTerms:nil
+ flags:MPSCNNConvolutionFlagsNone];
if (pad == 0) {
- conv.padding = [MPSNNDefaultPadding
- paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
- MPSNNPaddingMethodAlignCentered |
- MPSNNPaddingMethodSizeSame];
+ conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
+ MPSNNPaddingMethodAlignCentered |
+ MPSNNPaddingMethodSizeSame];
} else if (pad == 1) {
- conv.padding = [MPSNNDefaultPadding
- paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
- MPSNNPaddingMethodAlignCentered |
- MPSNNPaddingMethodSizeValidOnly];
+ conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
+ MPSNNPaddingMethodAlignCentered |
+ MPSNNPaddingMethodSizeValidOnly];
}
[conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];
[cb waitUntilCompleted];
(*f_img2buf)(&tmp_out, output);
+});
- });
-
-} // namespace contrib
-} // namespace tvm
+} // namespace contrib
+} // namespace tvm
using namespace runtime;
-TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor *A = args[0];
- DLTensor *B = args[1];
- DLTensor *C = args[2];
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// Get Metal device API
- MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+ MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
// CHECK_EQ(A->ctx, B->ctx);
// CHECK_EQ(A->ctx, C->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K);
// mps a
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
- MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
- matrixDescriptorWithDimensions:M
- columns:K
- rowBytes:K * sizeof(MPSDataTypeFloat32)
- dataType:MPSDataTypeFloat32];
+ MPSMatrixDescriptor* descA =
+ [MPSMatrixDescriptor matrixDescriptorWithDimensions:M
+ columns:K
+ rowBytes:K * sizeof(MPSDataTypeFloat32)
+ dataType:MPSDataTypeFloat32];
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
- MPSMatrix *matrixA =
- [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
+ MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
// mps b
- MPSMatrixDescriptor *descB =
- [MPSMatrixDescriptor matrixDescriptorWithDimensions:K
- columns:N
- rowBytes:N * sizeof(dtype)
- dataType:dtype];
+ MPSMatrixDescriptor* descB = [MPSMatrixDescriptor matrixDescriptorWithDimensions:K
+ columns:N
+ rowBytes:N * sizeof(dtype)
+ dataType:dtype];
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
- MPSMatrix *matrixB =
- [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
+ MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
// mps c
- MPSMatrixDescriptor *descC =
- [MPSMatrixDescriptor matrixDescriptorWithDimensions:M
- columns:N
- rowBytes:N * sizeof(dtype)
- dataType:dtype];
+ MPSMatrixDescriptor* descC = [MPSMatrixDescriptor matrixDescriptorWithDimensions:M
+ columns:N
+ rowBytes:N * sizeof(dtype)
+ dataType:dtype];
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
- MPSMatrix *matrixC =
- [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
+ MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
// kernel
- MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
- MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
+ MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init];
+ MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev
transposeLeft:transa
transposeRight:transb
resultRows:M
alpha:1.0f
beta:0.0f];
CHECK(sgemm != nil);
- [sgemm encodeToCommandBuffer:cb
- leftMatrix:matrixA
- rightMatrix:matrixB
- resultMatrix:matrixC];
+ [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC];
[cb commit];
+});
- });
-
-} // namespace contrib
-} // namespace tvm
+} // namespace contrib
+} // namespace tvm
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
+#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/data_type.h>
+
#include <vector>
+
#include "../../metal/metal_common.h"
namespace tvm {
/*! \brief Convert DLTensor type to MPS type */
struct MPSType {
- static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
+ static MPSDataType DLTypeToMPSType(const DLDataType& dtype);
}; // struct MPSType
struct MetalThreadEntry {
MetalThreadEntry();
~MetalThreadEntry();
- MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc);
- MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb,
- MPSImageDescriptor *desc);
- runtime::metal::MetalWorkspace *metal_api{nullptr};
- static MetalThreadEntry *ThreadLocal();
- std::vector<MPSImage *> img_table;
+ MPSImage* AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor* desc);
+ MPSTemporaryImage* AllocTempImage(id<MTLCommandBuffer> cb, MPSImageDescriptor* desc);
+ runtime::metal::MetalWorkspace* metal_api{nullptr};
+ static MetalThreadEntry* ThreadLocal();
+ std::vector<MPSImage*> img_table;
}; // MetalThreadEntry
} // namespace contrib
namespace contrib {
// MPS Data Type
-MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
+MPSDataType MPSType::DLTypeToMPSType(const DLDataType& dtype) {
switch (dtype.code) {
- case kDLInt:
- if (dtype.bits == 8 && dtype.lanes == 1)
- return MPSDataTypeInt8;
- else if (dtype.bits == 16 && dtype.lanes == 1)
- return MPSDataTypeInt16;
- else
+ case kDLInt:
+ if (dtype.bits == 8 && dtype.lanes == 1)
+ return MPSDataTypeInt8;
+ else if (dtype.bits == 16 && dtype.lanes == 1)
+ return MPSDataTypeInt16;
+ else
+ LOG(FATAL) << "Unsupported type";
+ break;
+ case kDLUInt:
+ if (dtype.bits == 8 && dtype.lanes == 1)
+ return MPSDataTypeUInt8;
+ else if (dtype.bits == 16 && dtype.lanes == 1)
+ return MPSDataTypeUInt16;
+ else if (dtype.bits == 32 && dtype.lanes == 1)
+ return MPSDataTypeUInt32;
LOG(FATAL) << "Unsupported type";
- break;
- case kDLUInt:
- if (dtype.bits == 8 && dtype.lanes == 1)
- return MPSDataTypeUInt8;
- else if (dtype.bits == 16 && dtype.lanes == 1)
- return MPSDataTypeUInt16;
- else if (dtype.bits == 32 && dtype.lanes == 1)
- return MPSDataTypeUInt32;
- LOG(FATAL) << "Unsupported type";
- break;
- case kDLFloat:
- if (dtype.bits == 16 && dtype.lanes == 1)
- return MPSDataTypeFloat16;
- else if (dtype.bits == 32 && dtype.lanes == 1)
- return MPSDataTypeFloat32;
- else
+ break;
+ case kDLFloat:
+ if (dtype.bits == 16 && dtype.lanes == 1)
+ return MPSDataTypeFloat16;
+ else if (dtype.bits == 32 && dtype.lanes == 1)
+ return MPSDataTypeFloat32;
+ else
+ LOG(FATAL) << "Unsupported type";
+ break;
+ default:
LOG(FATAL) << "Unsupported type";
- break;
- default:
- LOG(FATAL) << "Unsupported type";
}
return MPSDataTypeFloat32;
}
// MetalThreadEntry
-MPSImage *MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev,
- MPSImageDescriptor *desc) {
- MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc];
+MPSImage* MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor* desc) {
+ MPSImage* mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc];
img_table.push_back(mpsimg);
return mpsimg;
}
-MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
- MPSImageDescriptor *desc) {
- MPSTemporaryImage *mpsimg =
- [MPSTemporaryImage temporaryImageWithCommandBuffer:cb
- imageDescriptor:desc];
+MPSTemporaryImage* MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
+ MPSImageDescriptor* desc) {
+ MPSTemporaryImage* mpsimg = [MPSTemporaryImage temporaryImageWithCommandBuffer:cb
+ imageDescriptor:desc];
return mpsimg;
}
MetalThreadEntry::MetalThreadEntry() {
auto func = runtime::Registry::Get("device_api.metal");
- void *ret = (*func)();
- metal_api = static_cast<runtime::metal::MetalWorkspace *>(ret);
+ void* ret = (*func)();
+ metal_api = static_cast<runtime::metal::MetalWorkspace*>(ret);
}
MetalThreadEntry::~MetalThreadEntry() {
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
-MetalThreadEntry *MetalThreadEntry::ThreadLocal() {
- return MetalThreadStore::Get();
-}
+MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); }
-} // namespace contrib
-} // namespace tvm
+} // namespace contrib
+} // namespace tvm
/*!
* \file Use external nnpack library call.
*/
-#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include <nnpack.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include "nnpack_utils.h"
namespace tvm {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal();
static std::once_flag flag;
- std::call_once(flag,
- []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
- DLTensor *input = args[0];
- DLTensor *kernel = args[1];
- DLTensor *bias = nullptr;
+ std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+ DLTensor* input = args[0];
+ DLTensor* kernel = args[1];
+ DLTensor* bias = nullptr;
if (args[2].type_code() == kTVMDLTensorHandle) {
bias = args[2];
}
- DLTensor *output = args[3];
- uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
- pad_left = args[7];
+ DLTensor* output = args[3];
+ uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
uint64_t stride_width = args[8], stride_height = args[9];
nnp_size stride_size{stride_width, stride_height};
NNPackConfig(args[10]);
uint64_t algo_ = args[11];
- nnp_convolution_algorithm algo =
- static_cast<nnp_convolution_algorithm>(algo_);
+ nnp_convolution_algorithm algo = static_cast<nnp_convolution_algorithm>(algo_);
CHECK_EQ(input->ndim, 4);
CHECK_EQ(kernel->ndim, 4);
if (bias) {
size_t workspace_size = 0;
nnp_status status = nnp_convolution_inference(
- algo, nnp_convolution_transform_strategy_compute, input_channels,
- output_channels, input_size, input_padding, kernel_size, stride_size,
- nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
- nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+ algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels,
+ input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr,
+ nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
// Division with rounding up, in case size is not multiple of sizeof(float)
DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
void* workspace_buffer =
- cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+ cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
CHECK(workspace_buffer != nullptr);
for (auto n = 0; n < input->shape[0]; ++n) {
nnp_status status = nnp_convolution_inference(
- algo, nnp_convolution_transform_strategy_compute, input_channels,
- output_channels, input_size, input_padding, kernel_size,
- stride_size,
- static_cast<float *>(input->data) + n * input->shape[1] *
- input->shape[2] *
- input->shape[3],
- static_cast<float *>(kernel->data),
- bias ? static_cast<float *>(bias->data) : zero_bias->data(),
- static_cast<float *>(output->data) + n * output->shape[1] *
- output->shape[2] *
- output->shape[3],
- workspace_buffer, &workspace_size,
- nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+ algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels,
+ input_size, input_padding, kernel_size, stride_size,
+ static_cast<float*>(input->data) +
+ n * input->shape[1] * input->shape[2] * input->shape[3],
+ static_cast<float*>(kernel->data),
+ bias ? static_cast<float*>(bias->data) : zero_bias->data(),
+ static_cast<float*>(output->data) +
+ n * output->shape[1] * output->shape[2] * output->shape[3],
+ workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool,
+ nullptr);
CHECK_EQ(status, nnp_status_success);
}
});
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal();
static std::once_flag flag;
- std::call_once(flag,
- []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
- DLTensor *input = args[0];
- DLTensor *transformed_kernel = args[1];
- DLTensor *bias = nullptr;
+ std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+ DLTensor* input = args[0];
+ DLTensor* transformed_kernel = args[1];
+ DLTensor* bias = nullptr;
if (args[2].type_code() == kTVMDLTensorHandle) {
bias = args[2];
}
- DLTensor *output = args[3];
- uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
- pad_left = args[7];
+ DLTensor* output = args[3];
+ uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
uint64_t stride_width = args[8], stride_height = args[9];
nnp_size stride_size{stride_width, stride_height};
NNPackConfig(args[10]);
uint64_t algo_ = args[11];
- nnp_convolution_algorithm algo =
- static_cast<nnp_convolution_algorithm>(algo_);
+ nnp_convolution_algorithm algo = static_cast<nnp_convolution_algorithm>(algo_);
CHECK_EQ(input->ndim, 4);
if (bias) {
CHECK_EQ(bias->ndim, 1);
size_t workspace_size = 0;
nnp_status status = nnp_convolution_inference(
- algo, nnp_convolution_transform_strategy_reuse, input_channels,
- output_channels, input_size, input_padding, kernel_size, stride_size,
- nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
- nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+ algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
+ input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr,
+ nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
// Division with rounding up, in case size is not multiple of sizeof(float)
DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
void* workspace_buffer =
- cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+ cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
CHECK(workspace_buffer != nullptr);
for (auto n = 0; n < input->shape[0]; ++n) {
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
input_size, input_padding, kernel_size, stride_size,
- static_cast<float *>(input->data) + n * input->shape[1] *
- input->shape[2] *
- input->shape[3],
- static_cast<float *>(transformed_kernel->data),
- bias ? static_cast<float *>(bias->data) : zero_bias->data(),
- static_cast<float *>(output->data) + n * output->shape[1] *
- output->shape[2] *
- output->shape[3],
- workspace_buffer, &workspace_size,
- nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+ static_cast<float*>(input->data) +
+ n * input->shape[1] * input->shape[2] * input->shape[3],
+ static_cast<float*>(transformed_kernel->data),
+ bias ? static_cast<float*>(bias->data) : zero_bias->data(),
+ static_cast<float*>(output->data) +
+ n * output->shape[1] * output->shape[2] * output->shape[3],
+ workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool,
+ nullptr);
CHECK_EQ(status, nnp_status_success);
}
cpu_api->FreeWorkspace(ctx, workspace_buffer);
});
-TVM_REGISTER_GLOBAL(
- "tvm.contrib.nnpack.convolution_inference_weight_transform")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_weight_transform")
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal();
static std::once_flag flag;
- std::call_once(flag,
- []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
- DLTensor *kernel = args[0];
- DLTensor *transformed_kernel = args[1];
+ std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+ DLTensor* kernel = args[0];
+ DLTensor* transformed_kernel = args[1];
// Dummy sizes
nnp_padding input_padding{1, 1, 1, 1};
nnp_size stride_size{1, 1};
NNPackConfig(args[2]);
uint64_t algo_ = args[3];
- nnp_convolution_algorithm algo =
- static_cast<nnp_convolution_algorithm>(algo_);
+ nnp_convolution_algorithm algo = static_cast<nnp_convolution_algorithm>(algo_);
CHECK_EQ(kernel->ndim, 4);
size_t input_channels = kernel->shape[1];
size_t output_channels = kernel->shape[0];
size_t transformed_kernel_size = 0;
nnp_status status;
status = nnp_convolution_inference(
- algo, nnp_convolution_transform_strategy_precompute, input_channels,
- output_channels, input_size, input_padding, kernel_size, stride_size,
- nullptr, nullptr, nullptr, nullptr, nullptr, &transformed_kernel_size,
- nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+ algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels,
+ input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr,
+ nullptr, &transformed_kernel_size, nnp_activation_identity, nullptr, entry->threadpool,
+ nullptr);
CHECK_EQ(status, nnp_status_success);
CHECK_LE(transformed_kernel_size, GetDataSize(*transformed_kernel));
status = nnp_convolution_inference(
- algo, nnp_convolution_transform_strategy_precompute, input_channels,
- output_channels, input_size, input_padding, kernel_size, stride_size,
- nullptr, static_cast<float *>(kernel->data), nullptr, nullptr,
- static_cast<float *>(transformed_kernel->data),
- &transformed_kernel_size, nnp_activation_identity, nullptr,
- entry->threadpool, nullptr);
+ algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels,
+ input_size, input_padding, kernel_size, stride_size, nullptr,
+ static_cast<float*>(kernel->data), nullptr, nullptr,
+ static_cast<float*>(transformed_kernel->data), &transformed_kernel_size,
+ nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
});
} // namespace contrib
/*!
* \file Use external nnpack library call.
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include <nnpack.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+
#include "nnpack_utils.h"
namespace tvm {
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
- nnp_initialize();
- DLTensor* A = args[0];
- DLTensor* B = args[1];
- DLTensor* C = args[2];
- NNPackConfig(args[3]);
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal();
+ nnp_initialize();
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
+ NNPackConfig(args[3]);
- CHECK_EQ(A->ndim, 1);
- CHECK_EQ(B->ndim, 2);
- CHECK_EQ(C->ndim, 1);
- CHECK_EQ(B->shape[0], C->shape[0]);
- CHECK_EQ(B->shape[1], A->shape[0]);
- CHECK(C->strides == nullptr);
- CHECK(B->strides == nullptr);
- CHECK(A->strides == nullptr);
- CHECK(TypeMatch(A->dtype, kDLFloat, 32));
- CHECK(TypeMatch(B->dtype, kDLFloat, 32));
- CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+ CHECK_EQ(A->ndim, 1);
+ CHECK_EQ(B->ndim, 2);
+ CHECK_EQ(C->ndim, 1);
+ CHECK_EQ(B->shape[0], C->shape[0]);
+ CHECK_EQ(B->shape[1], A->shape[0]);
+ CHECK(C->strides == nullptr);
+ CHECK(B->strides == nullptr);
+ CHECK(A->strides == nullptr);
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+ CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+ CHECK(TypeMatch(C->dtype, kDLFloat, 32));
- nnp_fully_connected_inference(B->shape[1],
- B->shape[0],
- static_cast<float*>(A->data),
- static_cast<float*>(B->data),
- static_cast<float*>(C->data),
- entry->threadpool);
- });
+ nnp_fully_connected_inference(B->shape[1], B->shape[0], static_cast<float*>(A->data),
+ static_cast<float*>(B->data), static_cast<float*>(C->data),
+ entry->threadpool);
+ });
} // namespace contrib
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
typedef dmlc::ThreadLocalStore<NNPackThreadLocalEntry> NNPackThreadLocalStore;
-
NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
return NNPackThreadLocalStore::Get();
}
bool NNPackConfig(uint64_t nthreads) {
- NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+ NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal();
if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) {
CHECK_NE(nthreads, 1);
return true;
return true;
}
-
-TVM_REGISTER_GLOBAL("contrib.nnpack._initialize")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = nnp_initialize();
- });
+TVM_REGISTER_GLOBAL("contrib.nnpack._initialize").set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = nnp_initialize();
+});
} // namespace contrib
} // namespace tvm
*/
#ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
#define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/data_type.h>
-#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
+#include <dmlc/thread_local.h>
#include <nnpack.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
namespace tvm {
namespace contrib {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \brief mt19937 random engine
*/
#include <dmlc/logging.h>
+
#include <algorithm>
#include <ctime>
#include <random>
*/
class RandomEngine {
public:
- /*!
- * \brief Creates a RandomEngine using a default seed.
- */
- RandomEngine() {
- this->Seed(time(0));
- }
-
- /*!
- * \brief Creates a RandomEngine, suggesting the use of a provided seed.
- */
- explicit RandomEngine(unsigned seed) {
- this->Seed(seed);
- }
-
- /*!
- * \brief Seeds the underlying RNG, if possible.
- */
+ /*!
+ * \brief Creates a RandomEngine using a default seed.
+ */
+ RandomEngine() { this->Seed(time(0)); }
+
+ /*!
+ * \brief Creates a RandomEngine, suggesting the use of a provided seed.
+ */
+ explicit RandomEngine(unsigned seed) { this->Seed(seed); }
+
+ /*!
+ * \brief Seeds the underlying RNG, if possible.
+ */
inline void Seed(unsigned seed) {
rnd_engine_.seed(seed);
this->rseed_ = static_cast<unsigned>(seed);
}
- /*!
- * \return the seed associated with the underlying RNG.
- */
- inline unsigned GetSeed() const {
- return rseed_;
- }
+ /*!
+ * \return the seed associated with the underlying RNG.
+ */
+ inline unsigned GetSeed() const { return rseed_; }
- /*!
- * \return a random integer sampled from the RNG.
- */
- inline unsigned GetRandInt() {
- return rnd_engine_();
- }
+ /*!
+ * \return a random integer sampled from the RNG.
+ */
+ inline unsigned GetRandInt() { return rnd_engine_(); }
- /*!
- * \brief Fills a tensor with values drawn from Unif(low, high)
- */
+ /*!
+ * \brief Fills a tensor with values drawn from Unif(low, high)
+ */
void SampleUniform(DLTensor* data, float low, float high) {
CHECK_GT(high, low) << "high must be bigger than low";
CHECK(data->strides == nullptr);
if (data->ctx.device_type == kDLCPU) {
std::uniform_real_distribution<float> uniform_dist(low, high);
- std::generate_n(static_cast<float*>(data->data), size, [&] () {
- return uniform_dist(rnd_engine_);
- });
+ std::generate_n(static_cast<float*>(data->data), size,
+ [&]() { return uniform_dist(rnd_engine_); });
} else {
LOG(FATAL) << "Do not support random.uniform on this device yet";
}
}
- /*!
- * \brief Fills a tensor with values drawn from Normal(loc, scale**2)
- */
+ /*!
+ * \brief Fills a tensor with values drawn from Normal(loc, scale**2)
+ */
void SampleNormal(DLTensor* data, float loc, float scale) {
CHECK_GT(scale, 0) << "standard deviation must be positive";
CHECK(data->strides == nullptr);
if (data->ctx.device_type == kDLCPU) {
std::normal_distribution<float> normal_dist(loc, scale);
- std::generate_n(static_cast<float*>(data->data), size, [&] () {
- return normal_dist(rnd_engine_);
- });
+ std::generate_n(static_cast<float*>(data->data), size,
+ [&]() { return normal_dist(rnd_engine_); });
} else {
LOG(FATAL) << "Do not support random.normal on this device yet";
}
/*!
* \file External random functions for tensor.
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+
#include <algorithm>
+
#include "mt_random_engine.cc"
#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \
if (type.code == kDLInt && type.bits == 32) { \
typedef int32_t DType; \
- {__VA_ARGS__} \
+ { __VA_ARGS__ } \
} else if (type.code == kDLInt && type.bits == 16) { \
typedef int16_t DType; \
- {__VA_ARGS__} \
+ { __VA_ARGS__ } \
} else if (type.code == kDLInt && type.bits == 8) { \
typedef int8_t DType; \
- {__VA_ARGS__} \
+ { __VA_ARGS__ } \
} else if (type.code == kDLUInt && type.bits == 32) { \
typedef uint32_t DType; \
- {__VA_ARGS__} \
+ { __VA_ARGS__ } \
} else if (type.code == kDLUInt && type.bits == 16) { \
typedef uint16_t DType; \
- {__VA_ARGS__} \
+ { __VA_ARGS__ } \
} else if (type.code == kDLUInt && type.bits == 8) { \
typedef uint8_t DType; \
- {__VA_ARGS__} \
+ { __VA_ARGS__ } \
} else { \
LOG(FATAL) << "unknown data type"; \
}
return RandomThreadLocalStore::Get();
}
+TVM_REGISTER_GLOBAL("tvm.contrib.random.randint").set_body([](TVMArgs args, TVMRetValue* ret) {
+ RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
+ int64_t low = args[0];
+ int64_t high = args[1];
+ DLTensor* out = args[2];
+ CHECK_GT(high, low) << "high must be bigger than low";
+ CHECK(out->strides == nullptr);
+
+ DLDataType dtype = out->dtype;
+ int64_t size = 1;
+ for (int i = 0; i < out->ndim; ++i) {
+ size *= out->shape[i];
+ }
-TVM_REGISTER_GLOBAL("tvm.contrib.random.randint")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
- int64_t low = args[0];
- int64_t high = args[1];
- DLTensor* out = args[2];
- CHECK_GT(high, low) << "high must be bigger than low";
- CHECK(out->strides == nullptr);
-
- DLDataType dtype = out->dtype;
- int64_t size = 1;
- for (int i = 0; i < out->ndim; ++i) {
- size *= out->shape[i];
+ DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, {
+ int64_t numeric_low = std::numeric_limits<DType>::min();
+ int64_t numeric_high = std::numeric_limits<DType>::max();
+ numeric_high += 1; // exclusive upper bound
+ low = std::max(low, numeric_low);
+ high = std::min(high, numeric_high);
+
+ if (out->ctx.device_type == kDLCPU) {
+      // fill the data with random bytes
+ std::generate_n(static_cast<DType*>(out->data), size, [&]() {
+ unsigned rint = entry->random_engine.GetRandInt();
+ return low + rint % (high - low);
+ });
+ } else {
+ LOG(FATAL) << "Do not support random.randint on this device yet";
}
-
- DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, {
- int64_t numeric_low = std::numeric_limits<DType>::min();
- int64_t numeric_high = std::numeric_limits<DType>::max();
- numeric_high += 1; // exclusive upper bound
- low = std::max(low, numeric_low);
- high = std::min(high, numeric_high);
-
- if (out->ctx.device_type == kDLCPU) {
- // file the data with random byte
- std::generate_n(static_cast<DType*>(out->data), size, [&] () {
- unsigned rint = entry->random_engine.GetRandInt();
- return low + rint % (high - low);
- });
- } else {
- LOG(FATAL) << "Do not support random.randint on this device yet";
- }
- })
- });
-
-
-TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
- double low = args[0];
- double high = args[1];
- DLTensor* out = args[2];
- entry->random_engine.SampleUniform(out, low, high);
- });
-
-
-TVM_REGISTER_GLOBAL("tvm.contrib.random.normal")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
- double loc = args[0];
- double scale = args[1];
- DLTensor* out = args[2];
- entry->random_engine.SampleNormal(out, loc, scale);
- });
-
+ })
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform").set_body([](TVMArgs args, TVMRetValue* ret) {
+ RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
+ double low = args[0];
+ double high = args[1];
+ DLTensor* out = args[2];
+ entry->random_engine.SampleUniform(out, low, high);
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRetValue* ret) {
+ RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
+ double loc = args[0];
+ double scale = args[1];
+ DLTensor* out = args[2];
+ entry->random_engine.SampleNormal(out, loc, scale);
+});
} // namespace contrib
} // namespace tvm
/*!
* \file Use external rocblas library call.
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/data_type.h>
-#include <dmlc/logging.h>
#include "rocblas.h"
+#include <dmlc/logging.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+
namespace tvm {
namespace contrib {
using namespace runtime;
#ifndef CHECK_ROCBLAS_ERROR
-#define CHECK_ROCBLAS_ERROR(error) \
-if (error != rocblas_status_success) { \
- fprintf(stderr, "rocBLAS error: "); \
- if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \
- if (error == rocblas_status_not_implemented) fprintf(stderr, " rocblas_status_not_implemented"); \
- if (error == rocblas_status_invalid_pointer) fprintf(stderr, "rocblas_status_invalid_pointer"); \
- if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \
- if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \
- if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \
- fprintf(stderr, "\n"); \
- exit(EXIT_FAILURE); \
-}
+#define CHECK_ROCBLAS_ERROR(error) \
+ if (error != rocblas_status_success) { \
+ fprintf(stderr, "rocBLAS error: "); \
+ if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \
+ if (error == rocblas_status_not_implemented) \
+ fprintf(stderr, " rocblas_status_not_implemented"); \
+ if (error == rocblas_status_invalid_pointer) \
+ fprintf(stderr, "rocblas_status_invalid_pointer"); \
+ if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \
+ if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \
+ if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \
+ fprintf(stderr, "\n"); \
+ exit(EXIT_FAILURE); \
+ }
#endif
-
// matrix multiplication for row major
-TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor* A = args[0];
- DLTensor* B = args[1];
- DLTensor* C = args[2];
- bool transa = args[3];
- bool transb = args[4];
- // call gemm for simple compact code.
- CHECK_EQ(A->ndim, 2);
- CHECK_EQ(B->ndim, 2);
- CHECK_EQ(C->ndim, 2);
- CHECK(C->strides == nullptr);
- CHECK(B->strides == nullptr);
- CHECK(A->strides == nullptr);
- CHECK(TypeMatch(A->dtype, kDLFloat, 32));
- CHECK(TypeMatch(B->dtype, kDLFloat, 32));
- CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
+ bool transa = args[3];
+ bool transb = args[4];
+ // call gemm for simple compact code.
+ CHECK_EQ(A->ndim, 2);
+ CHECK_EQ(B->ndim, 2);
+ CHECK_EQ(C->ndim, 2);
+ CHECK(C->strides == nullptr);
+ CHECK(B->strides == nullptr);
+ CHECK(A->strides == nullptr);
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+ CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+ CHECK(TypeMatch(C->dtype, kDLFloat, 32));
- rocblas_handle handle;
- CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
- float alpha = 1.0;
- float beta = 0.0;
- float *A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
- float *B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
- float *C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
+ rocblas_handle handle;
+ CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
+ float alpha = 1.0;
+ float beta = 0.0;
+ float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
+ float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+ float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
- CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle,
- transb ? rocblas_operation_transpose : rocblas_operation_none,
- transa ? rocblas_operation_transpose : rocblas_operation_none,
- transb ? B->shape[0] : B->shape[1],
- transa ? A->shape[1] : A->shape[0],
- transb ? B->shape[1] : B->shape[0],
- &alpha,
- A_ptr,
- B->shape[1],
- B_ptr,
- A->shape[1],
- &beta,
- C_ptr,
- C->shape[1]));
+ CHECK_ROCBLAS_ERROR(
+ rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none,
+ transa ? rocblas_operation_transpose : rocblas_operation_none,
+ transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0],
+ transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr,
+ A->shape[1], &beta, C_ptr, C->shape[1]));
- CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
+ CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
});
} // namespace contrib
} // namespace tvm
* \file Use standard C library call.
*/
-#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
+#include <tvm/runtime/registry.h>
+
#include <algorithm>
#include <vector>
using namespace runtime;
-template<typename DType>
-bool CompareAscend(const std::pair<int64_t, DType>& lhs,
- const std::pair<int64_t, DType>& rhs) {
+template <typename DType>
+bool CompareAscend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
return lhs.second < rhs.second;
}
-template<typename DType>
-bool CompareDescend(const std::pair<int64_t, DType>& lhs,
- const std::pair<int64_t, DType>& rhs) {
+template <typename DType>
+bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
return lhs.second > rhs.second;
}
-
// Argsort implemented C library sort for nms.
// Return indices of sorted tensor.
// By default, the last axis will be used to sort.
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
-TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor *input = args[0];
- DLTensor *sort_num = args[1];
- DLTensor *output = args[2];
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* input = args[0];
+ DLTensor* sort_num = args[1];
+ DLTensor* output = args[2];
int32_t axis = args[3];
bool is_ascend = args[4];
auto dtype = input->dtype;
- auto data_ptr = static_cast<float *>(input->data);
- auto sort_num_ptr = static_cast<int32_t *>(sort_num->data);
+ auto data_ptr = static_cast<float*>(input->data);
+ auto sort_num_ptr = static_cast<int32_t*>(sort_num->data);
std::vector<std::pair<int32_t, float>> sorter;
int64_t axis_mul_before = 1;
int64_t axis_mul_after = 1;
// Currently only supports input dtype to be float32.
CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
- "to be float.";
+ "to be float.";
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1)
CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
- "to be float32.";
+ "to be float32.";
#endif
CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
- "input ndim " << input->ndim;
+ "input ndim "
+ << input->ndim;
for (int i = 0; i < input->ndim; ++i) {
if (i < axis) {
}
}
- for (int64_t i = 0 ; i < axis_mul_before; ++i) {
- for (int64_t j = 0 ; j < axis_mul_after; ++j) {
+ for (int64_t i = 0; i < axis_mul_before; ++i) {
+ for (int64_t j = 0; j < axis_mul_after; ++j) {
sorter.clear();
int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j);
int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>);
} else {
#endif
- std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+ std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
}
#endif
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>);
} else {
#endif
- std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
+ std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
}
#endif
}
for (int32_t k = 0; k < input->shape[axis]; ++k) {
- *(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
- = k < static_cast<int32_t>(sorter.size()) ? sorter[k].first : k;
+ *(static_cast<int32_t*>(output->data) + base_idx + k * axis_mul_after) =
+ k < static_cast<int32_t>(sorter.size()) ? sorter[k].first : k;
}
}
}
});
-template<typename DataType, typename OutType>
+template <typename DataType, typename OutType>
void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
- auto data_ptr = static_cast<DataType *>(input->data);
- auto out_ptr = static_cast<OutType *>(output->data);
- std::vector<std::pair<int64_t, DataType> > sorter;
+ auto data_ptr = static_cast<DataType*>(input->data);
+ auto out_ptr = static_cast<OutType*>(output->data);
+ std::vector<std::pair<int64_t, DataType>> sorter;
int axis_mul_before = 1;
int axis_mul_after = 1;
}
}
- for (int i = 0 ; i < axis_mul_before; ++i) {
- for (int j = 0 ; j < axis_mul_after; ++j) {
+ for (int i = 0; i < axis_mul_before; ++i) {
+ for (int j = 0; j < axis_mul_after; ++j) {
sorter.clear();
int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
for (int64_t k = 0; k < input->shape[axis]; ++k) {
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
-TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor *input = args[0];
- DLTensor *output = args[1];
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* input = args[0];
+ DLTensor* output = args[1];
int32_t axis = args[2];
bool is_ascend = args[3];
if (axis < 0) {
axis = input->ndim + axis;
}
CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
- "input ndim " << input->ndim;
+ "input ndim "
+ << input->ndim;
auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(output->dtype);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
- } else if (data_dtype == "int64") {
+ } else if (data_dtype == "int64") {
if (out_dtype == "int32") {
argsort<int64_t, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
}
});
-template<typename DataType, typename IndicesType>
-void topk(DLTensor* input,
- DLTensor* out_values,
- DLTensor* out_indices,
- int k,
- int axis,
+template <typename DataType, typename IndicesType>
+void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis,
bool is_ascend) {
- DataType* data_ptr = static_cast<DataType *>(input->data);
- DataType* values_ptr = (out_values == nullptr) ? nullptr :
- static_cast<DataType *>(out_values->data);
- IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr :
- static_cast<IndicesType *>(out_indices->data);
- std::vector<std::pair<int64_t, DataType> > sorter;
+ DataType* data_ptr = static_cast<DataType*>(input->data);
+ DataType* values_ptr =
+ (out_values == nullptr) ? nullptr : static_cast<DataType*>(out_values->data);
+ IndicesType* indices_ptr =
+ (out_indices == nullptr) ? nullptr : static_cast<IndicesType*>(out_indices->data);
+ std::vector<std::pair<int64_t, DataType>> sorter;
int axis_mul_before = 1;
int axis_mul_after = 1;
k = input->shape[axis];
}
- for (int i = 0 ; i < axis_mul_before; ++i) {
- for (int j = 0 ; j < axis_mul_after; ++j) {
+ for (int i = 0; i < axis_mul_before; ++i) {
+ for (int j = 0; j < axis_mul_after; ++j) {
sorter.clear();
int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j;
int64_t dst_base_idx = i * k * axis_mul_after + j;
for (int64_t kk = 0; kk < cnt; ++kk) {
if (indices_ptr != nullptr) {
indices_ptr[dst_base_idx + kk * axis_mul_after] =
- static_cast<IndicesType>(sorter[kk].first);
+ static_cast<IndicesType>(sorter[kk].first);
}
if (values_ptr != nullptr) {
- values_ptr[dst_base_idx + kk * axis_mul_after] =
- static_cast<DataType>(sorter[kk].second);
+ values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast<DataType>(sorter[kk].second);
}
}
}
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
-TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* input = args[0];
DLTensor* values_out = nullptr;
DLTensor* indices_out = nullptr;
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
- } else if (data_dtype == "int64") {
+ } else if (data_dtype == "int64") {
if (out_dtype == "int32") {
topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
/*!
* \file tflite_runtime.cc
*/
-#include <tvm/runtime/registry.h>
+#include "tflite_runtime.h"
+
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
-
-
-#include "tflite_runtime.h"
+#include <tvm/runtime/registry.h>
namespace tvm {
namespace runtime {
-#define TVM_DTYPE_DISPATCH(type, DType, ...) \
- if (type == DataType::Float(64)) { \
- typedef double DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Float(32)) { \
- typedef float DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Float(16)) { \
- typedef uint16_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(64)) { \
- typedef int64_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(32)) { \
- typedef int32_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(16)) { \
- typedef int16_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::Int(8)) { \
- typedef int8_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(64)) { \
- typedef uint64_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(32)) { \
- typedef uint32_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(16)) { \
- typedef uint16_t DType; \
- {__VA_ARGS__} \
- } else if (type == DataType::UInt(8)) { \
- typedef uint8_t DType; \
- {__VA_ARGS__} \
- } else { \
- LOG(FATAL) << "unknown data type " << type; \
+#define TVM_DTYPE_DISPATCH(type, DType, ...) \
+ if (type == DataType::Float(64)) { \
+ typedef double DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Float(32)) { \
+ typedef float DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Float(16)) { \
+ typedef uint16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(64)) { \
+ typedef int64_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(32)) { \
+ typedef int32_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(16)) { \
+ typedef int16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(8)) { \
+ typedef int8_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(64)) { \
+ typedef uint64_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(32)) { \
+ typedef uint32_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(16)) { \
+ typedef uint16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(8)) { \
+ typedef uint8_t DType; \
+ { __VA_ARGS__ } \
+ } else { \
+ LOG(FATAL) << "unknown data type " << type; \
}
DataType TfLiteDType2TVMDType(TfLiteType dtype) {
}
}
-void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
- TVMContext ctx) {
+void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str();
size_t buffer_size = tflite_model_bytes.size();
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+ tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
tflite::ops::builtin::BuiltinOpResolver resolver;
// Build interpreter
TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
ctx_ = ctx;
}
-void TFLiteRuntime::Invoke() {
- interpreter_->Invoke();
-}
+void TFLiteRuntime::Invoke() { interpreter_->Invoke(); }
void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
DataType dtype(data_in->dtype);
TVM_DTYPE_DISPATCH(dtype, DType, {
- DType* dest = interpreter_->typed_input_tensor<DType>(index);
- DType* src = static_cast<DType*>(data_in->data);
- CHECK(data_in->strides == NULL);
- int64_t size = 1;
- for (int64_t i = 0; i < data_in->ndim; ++i) {
- size *= data_in->shape[i];
- }
- for (int64_t i = 0; i < size; ++i) {
- dest[i] = src[i];
- }
- });
+ DType* dest = interpreter_->typed_input_tensor<DType>(index);
+ DType* src = static_cast<DType*>(data_in->data);
+ CHECK(data_in->strides == NULL);
+ int64_t size = 1;
+ for (int64_t i = 0; i < data_in->ndim; ++i) {
+ size *= data_in->shape[i];
+ }
+ for (int64_t i = 0; i < size; ++i) {
+ dest[i] = src[i];
+ }
+ });
}
NDArray TFLiteRuntime::GetOutput(int index) const {
}
NDArray ret = NDArray::Empty(shape, dtype, ctx_);
TVM_DTYPE_DISPATCH(dtype, DType, {
- DType* dest = static_cast<DType*>(ret->data);
- DType* src = interpreter_->typed_output_tensor<DType>(index);
- for (int64_t i = 0; i < size; ++i) {
- dest[i] = src[i];
- }
- });
+ DType* dest = static_cast<DType*>(ret->data);
+ DType* src = interpreter_->typed_output_tensor<DType>(index);
+ for (int64_t i = 0; i < size; ++i) {
+ dest[i] = src[i];
+ }
+ });
return ret;
}
-PackedFunc TFLiteRuntime::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc TFLiteRuntime::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- int in_idx = args[0];
- CHECK_GE(in_idx, 0);
- this->SetInput(in_idx, args[1]);
- });
+ int in_idx = args[0];
+ CHECK_GE(in_idx, 0);
+ this->SetInput(in_idx, args[1]);
+ });
} else if (name == "get_output") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetOutput(args[0]);
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); });
} else if (name == "invoke") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- this->Invoke();
- });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); });
} else {
return PackedFunc();
}
}
-Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes,
- TVMContext ctx) {
+Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) {
auto exec = make_object<TFLiteRuntime>();
exec->Init(tflite_model_bytes, ctx);
return Module(exec);
}
-TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = TFLiteRuntimeCreate(args[0], args[1]);
- });
+TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = TFLiteRuntimeCreate(args[0], args[1]);
+});
} // namespace runtime
} // namespace tvm
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
-#include <vector>
-#include <string>
#include <memory>
+#include <string>
+#include <vector>
namespace tvm {
namespace runtime {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
- virtual PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self);
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
/*!
* \return The type key of the executor.
*/
- const char* type_key() const {
- return "TFLiteRuntime";
- }
+ const char* type_key() const { return "TFLiteRuntime"; }
/*!
- * \brief Invoke the internal tflite interpreter and run the whole model in
+ * \brief Invoke the internal tflite interpreter and run the whole model in
* dependency order.
*/
void Invoke();
* \param tflite_model_bytes The tflite model.
* \param ctx The context where the tflite model will be executed on.
*/
- void Init(const std::string& tflite_model_bytes,
- TVMContext ctx);
+ void Init(const std::string& tflite_model_bytes, TVMContext ctx);
/*!
* \brief set index-th input to the model.
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include <cstdlib>
#include <cstring>
+
#include "workspace_pool.h"
#ifdef __ANDROID__
*rv = 1;
}
}
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
void* ptr;
#if _MSC_VER
#endif
}
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
- memcpy(static_cast<char*>(to) + to_offset,
- static_cast<const char*>(from) + from_offset,
- size);
+ memcpy(static_cast<char*>(to) + to_offset, static_cast<const char*>(from) + from_offset, size);
}
- void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
- }
+ void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {}
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() {
- static std::shared_ptr<CPUDeviceAPI> inst =
- std::make_shared<CPUDeviceAPI>();
+ static std::shared_ptr<CPUDeviceAPI> inst = std::make_shared<CPUDeviceAPI>();
return inst;
}
};
struct CPUWorkspacePool : public WorkspacePool {
- CPUWorkspacePool() :
- WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
+ CPUWorkspacePool() : WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
};
-void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
- size_t size,
- DLDataType type_hint) {
- return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
- ->AllocWorkspace(ctx, size);
+void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
+ return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size);
}
void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
}
-TVM_REGISTER_GLOBAL("device_api.cpu")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = CPUDeviceAPI::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = CPUDeviceAPI::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace runtime
} // namespace tvm
#include <dlpack/dlpack.h>
#include "load_json.h"
+#include "module.h"
#include "ndarray.h"
#include "packed_func.h"
-#include "module.h"
/*! \brief operator attributes about tvm op */
typedef struct TVMOpParam {
uint32_t index;
uint32_t version;
// JSON Loader
- void (*Load)(JSONReader *reader);
+ void (*Load)(JSONReader* reader);
} TVMGraphRuntimeNodeEntry;
// Node
// parameters
TVMOpParam param;
// inputs
- TVMGraphRuntimeNodeEntry * inputs;
+ TVMGraphRuntimeNodeEntry* inputs;
// number of inputs
size_t inputs_count;
// control deps
uint32_t control_deps[20];
// JSON Loader
- void (*LoadAttrs)(struct TVMGraphRuntimeNode * node, JSONReader *reader, TVMOpParam* param);
+ void (*LoadAttrs)(struct TVMGraphRuntimeNode* node, JSONReader* reader, TVMOpParam* param);
// JSON Loader
- int (*Load)(struct TVMGraphRuntimeNode * node, JSONReader *reader);
+ int (*Load)(struct TVMGraphRuntimeNode* node, JSONReader* reader);
} TVMGraphRuntimeNode;
// Graph attribute
typedef struct TVMGraphRuntimeGraphAttr {
uint32_t storage_num_not_alloctaed;
- uint32_t * storage_id;
- uint32_t * device_index;
- char * dltype; // "int8", "int16", "float32"
+ uint32_t* storage_id;
+ uint32_t* device_index;
+ char* dltype; // "int8", "int16", "float32"
uint32_t dltype_count;
- int64_t * shape;
- uint32_t * ndim;
+ int64_t* shape;
+ uint32_t* ndim;
uint32_t shape_count;
} TVMGraphRuntimeGraphAttr;
*/
/* class GraphRuntime : public ModuleNode { */
typedef struct TVMGraphRuntime {
- void (*Run)(struct TVMGraphRuntime * runtime);
+ void (*Run)(struct TVMGraphRuntime* runtime);
/*!
* \brief Initialize the graph executor with graph and context.
* \param ctxs The context of the host and devices where graph nodes will be
* executed on.
*/
- void (*Init)(struct TVMGraphRuntime * runtime,
- const char * graph_json,
- const TVMModule * module,
- const TVMContext * ctxs);
+ void (*Init)(struct TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module,
+ const TVMContext* ctxs);
/*!
* \brief Get the input index given the name of input.
* \param name The name of the input.
* \return The index of input.
*/
- int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name);
+ int (*GetInputIndex)(struct TVMGraphRuntime* runtime, const char* name);
/*!
* \brief set input to the graph based on name.
* \param name The name of the input.
* \param data_in The input data.
*/
- void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
+ void (*SetInput)(struct TVMGraphRuntime* runtime, const char* name, DLTensor* data_in);
/*!
* \brief Return NDArray for given output index.
* \param out The DLTensor corresponding to given output node index.
* \return The result of this function execution.
*/
- int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out);
+ int (*GetOutput)(struct TVMGraphRuntime* runtime, const int32_t index, DLTensor* out);
/*!
* \brief Load parameters from parameter blob.
* \param runtime The graph runtime.
* \param param_size The parameter size.
* \return The result of this function execution.
*/
- int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob,
+ int (*LoadParams)(struct TVMGraphRuntime* runtime, const char* param_blob,
const uint32_t param_size);
// The graph attribute fields.
- int (*Load)(struct TVMGraphRuntime * runtime, JSONReader *reader);
+ int (*Load)(struct TVMGraphRuntime* runtime, JSONReader* reader);
/*! \brief Setup the temporal storage */
- void (*SetupStorage)(struct TVMGraphRuntime * runtime);
+ void (*SetupStorage)(struct TVMGraphRuntime* runtime);
/*! \brief Setup the executors. */
- int (*SetupOpExecs)(struct TVMGraphRuntime * runtime);
+ int (*SetupOpExecs)(struct TVMGraphRuntime* runtime);
/*!
* \brief Create an execution function given input.
* \param pf The created executor.
* \return The result of this function execution.
*/
- int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs,
- DLTensorPtr * args, const uint32_t args_count,
- uint32_t num_inputs, TVMPackedFunc * pf);
+ int32_t (*CreateTVMOp)(struct TVMGraphRuntime* runtime, const TVMOpParam* attrs,
+ DLTensorPtr* args, const uint32_t args_count, uint32_t num_inputs,
+ TVMPackedFunc* pf);
// Get node entry index.
- uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index);
+ uint32_t (*GetEntryId)(struct TVMGraphRuntime* runtime, uint32_t nid, uint32_t index);
/*! \brief The graph nodes. */
- TVMGraphRuntimeNode * nodes;
+ TVMGraphRuntimeNode* nodes;
/*! \brief The graph nodes counter. */
uint32_t nodes_count;
/*! \brief The argument nodes. */
- uint32_t * input_nodes;
+ uint32_t* input_nodes;
uint32_t input_nodes_count;
/*! \brief Used for quick entry indexing. */
- uint32_t * node_row_ptr;
+ uint32_t* node_row_ptr;
uint32_t node_row_ptr_count;
/*! \brief Output entries. */
- TVMGraphRuntimeNodeEntry * outputs;
+ TVMGraphRuntimeNodeEntry* outputs;
/*! \brief Output entries counter. */
uint32_t outputs_count;
/*! \brief Additional graph attributes. */
TVMModule module;
/*! \brief Execution context of all devices including the host. */
TVMContext ctxs[1];
- uint32_t ctxs_count;
+ uint32_t ctxs_count;
/*! \brief Common storage pool for all devices. */
- TVMNDArray * storage_pool;
+ TVMNDArray* storage_pool;
uint32_t storage_pool_count;
/*! \brief Data entry of each node. */
- TVMNDArray * data_entry;
+ TVMNDArray* data_entry;
uint32_t data_entry_count;
/*! \brief Operator on each node. */
- TVMPackedFunc * op_execs;
+ TVMPackedFunc* op_execs;
uint32_t op_execs_count;
} TVMGraphRuntime;
// public functions
-TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, const TVMModule * m,
- const TVMContext * ctxs);
-void TVMGraphRuntimeRelease(TVMGraphRuntime ** runtime);
+TVMGraphRuntime* TVMGraphRuntimeCreate(const char* sym_json, const TVMModule* m,
+ const TVMContext* ctxs);
+void TVMGraphRuntimeRelease(TVMGraphRuntime** runtime);
// private functions
-void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
-int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob,
+void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in);
+int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob,
const uint32_t param_size);
-void TVMGraphRuntime_Run(TVMGraphRuntime * runtime);
-int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out);
+void TVMGraphRuntime_Run(TVMGraphRuntime* runtime);
+int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out);
#endif // TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_
#ifndef TVM_RUNTIME_CRT_LOAD_JSON_H_
#define TVM_RUNTIME_CRT_LOAD_JSON_H_
-#include <stdio.h>
#include <ctype.h>
+#include <stdio.h>
enum {
JSON_READ_TYPE_U8 = 1,
};
typedef struct Seq {
- uint32_t * data;
+ uint32_t* data;
uint64_t allocated;
uint32_t size;
- void (*push_back)(struct Seq * seq, uint32_t src);
- uint32_t * (*back)(struct Seq * seq);
- void (*pop_back)(struct Seq * seq);
+ void (*push_back)(struct Seq* seq, uint32_t src);
+ uint32_t* (*back)(struct Seq* seq);
+ void (*pop_back)(struct Seq* seq);
} Seq;
/*!
*/
typedef struct JSONReader {
/*! \brief internal reader string */
- char * is_;
- char * isptr;
+ char* is_;
+ char* isptr;
/*! \brief "\\r" counter */
size_t line_count_r_;
/*! \brief "\\n" counter */
* \brief record how many element processed in
* current array/object scope.
*/
- Seq * scope_counter_;
+ Seq* scope_counter_;
- char (*NextChar)(struct JSONReader * reader);
- char (*NextNonSpace)(struct JSONReader * reader);
- char (*PeekNextChar)(struct JSONReader * reader);
- char (*PeekNextNonSpace)(struct JSONReader * reader);
- int (*ReadUnsignedInteger)(struct JSONReader * reader, unsigned int * out_value);
- int (*ReadInteger)(struct JSONReader * reader, int64_t * out_value);
- int (*ReadString)(struct JSONReader * reader, char * out_value);
- void (*BeginArray)(struct JSONReader * reader);
- void (*BeginObject)(struct JSONReader * reader);
- uint8_t (*NextObjectItem)(struct JSONReader * reader, char * out_key);
- uint8_t (*NextArrayItem)(struct JSONReader * reader);
+ char (*NextChar)(struct JSONReader* reader);
+ char (*NextNonSpace)(struct JSONReader* reader);
+ char (*PeekNextChar)(struct JSONReader* reader);
+ char (*PeekNextNonSpace)(struct JSONReader* reader);
+ int (*ReadUnsignedInteger)(struct JSONReader* reader, unsigned int* out_value);
+ int (*ReadInteger)(struct JSONReader* reader, int64_t* out_value);
+ int (*ReadString)(struct JSONReader* reader, char* out_value);
+ void (*BeginArray)(struct JSONReader* reader);
+ void (*BeginObject)(struct JSONReader* reader);
+ uint8_t (*NextObjectItem)(struct JSONReader* reader, char* out_key);
+ uint8_t (*NextArrayItem)(struct JSONReader* reader);
} JSONReader;
/*!
* \brief Constructor of JSONReader class
* \param is the input source.
*/
-JSONReader JSONReader_Create(const char * is);
+JSONReader JSONReader_Create(const char* is);
-void JSONReader_Release(JSONReader * reader);
+void JSONReader_Release(JSONReader* reader);
#endif // TVM_RUNTIME_CRT_LOAD_JSON_H_
#define TVM_RUNTIME_CRT_LOGGING_H_
#ifndef CHECK
-#define CHECK(x) \
- do { \
- if (!(x)) { \
- fprintf(stderr, "Check failed: %s\n", #x); \
- exit(-1); \
- } \
- }while(0)
+#define CHECK(x) \
+ do { \
+ if (!(x)) { \
+ fprintf(stderr, "Check failed: %s\n", #x); \
+ exit(-1); \
+ } \
+ } while (0)
#endif
#ifndef CHECK_BINARY_OP
-#define CHECK_BINARY_OP(op, x, y, fmt, ...) \
- do { \
- if (!(x op y)) { \
+#define CHECK_BINARY_OP(op, x, y, fmt, ...) \
+ do { \
+ if (!(x op y)) { \
fprintf(stderr, "Check failed: %s %s %s: " fmt "\n", #x, #op, #y, ##__VA_ARGS__); \
- exit(-1); \
- } \
- }while(0)
+ exit(-1); \
+ } \
+ } while (0)
#endif
#ifndef CHECK_LT
-#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__)
+#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__)
#endif
#ifndef CHECK_GT
-#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__)
+#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__)
#endif
#ifndef CHECK_LE
#ifndef TVM_RUNTIME_CRT_MODULE_H_
#define TVM_RUNTIME_CRT_MODULE_H_
-#include <tvm/runtime/c_runtime_api.h>
#include <string.h>
+#include <tvm/runtime/c_runtime_api.h>
struct TVMPackedFunc;
*
* This function will return PackedFunc(nullptr) if function do not exist.
*/
- void (*GetFunction)(struct TVMModule * mod, const char * name, struct TVMPackedFunc * pf);
+ void (*GetFunction)(struct TVMModule* mod, const char* name, struct TVMPackedFunc* pf);
} TVMModule;
#endif // TVM_RUNTIME_CRT_MODULE_H_
#ifndef TVM_RUNTIME_CRT_NDARRAY_H_
#define TVM_RUNTIME_CRT_NDARRAY_H_
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/c_backend_api.h>
#include <dlpack/dlpack.h>
-
-#include <string.h>
#include <stdio.h>
#include <stdlib.h>
+#include <string.h>
+#include <tvm/runtime/c_backend_api.h>
+#include <tvm/runtime/c_runtime_api.h>
/*! \brief Magic number for NDArray file */
static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
DLTensor dl_tensor;
} TVMNDArray;
-TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape,
- DLDataType dtype, DLContext ctx);
+TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype,
+ DLContext ctx);
-TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape,
- DLDataType dtype, DLContext ctx);
+TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype,
+ DLContext ctx);
-int TVMNDArray_Load(TVMNDArray * ret, const char ** strm);
+int TVMNDArray_Load(TVMNDArray* ret, const char** strm);
-TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape,
- uint32_t ndim, DLDataType dtype);
+TVMNDArray TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, uint32_t ndim,
+ DLDataType dtype);
-int TVMNDArray_Release(TVMNDArray * arr);
+int TVMNDArray_Release(TVMNDArray* arr);
#endif // TVM_RUNTIME_CRT_NDARRAY_H_
#ifndef TVM_RUNTIME_CRT_PACKED_FUNC_H_
#define TVM_RUNTIME_CRT_PACKED_FUNC_H_
-#include <tvm/runtime/c_runtime_api.h>
-
+#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
-#include <assert.h>
+#include <tvm/runtime/c_runtime_api.h>
#include "module.h"
-static inline DLDataType String2DLDataType(const char * s) {
+static inline DLDataType String2DLDataType(const char* s) {
DLDataType t;
// handle None type
if (strlen(s) == 0) {
- t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
+ t.bits = 0;
+ t.lanes = 0;
+ t.code = kTVMOpaqueHandle;
return t;
}
- t.bits = 32; t.lanes = 1;
+ t.bits = 32;
+ t.lanes = 1;
const char* scan;
if (!strncmp(s, "int", 3)) {
- t.code = kDLInt; scan = s + 3;
+ t.code = kDLInt;
+ scan = s + 3;
} else if (!strncmp(s, "uint", 4)) {
- t.code = kDLUInt; scan = s + 4;
+ t.code = kDLUInt;
+ scan = s + 4;
} else if (!strncmp(s, "float", 5)) {
- t.code = kDLFloat; scan = s + 5;
+ t.code = kDLFloat;
+ scan = s + 5;
} else if (!strncmp(s, "handle", 6)) {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
typedef struct TVMArgs {
TVMValue values[TVM_CRT_MAX_ARGS];
- int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */
+ int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */
uint32_t values_count;
} TVMArgs;
-static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint32_t values_count) {
+static inline TVMArgs TVMArgs_Create(TVMValue* values, uint32_t* tcodes, uint32_t values_count) {
uint32_t idx;
TVMArgs args;
memset(&args, 0, sizeof(args));
return args;
}
-static inline int TVMNoOperation(TVMValue * args, int * type_codes, int num_args,
- TVMRetValueHandle ret, void * res) {
+static inline int TVMNoOperation(TVMValue* args, int* type_codes, int num_args,
+ TVMRetValueHandle ret, void* res) {
return 0;
}
char name[200];
TVMPackedCFunc fexec;
TVMArgs args;
- void (*Call)(struct TVMPackedFunc * pf);
- void (*SetArgs)(struct TVMPackedFunc * pf, const struct TVMArgs * args);
+ void (*Call)(struct TVMPackedFunc* pf);
+ void (*SetArgs)(struct TVMPackedFunc* pf, const struct TVMArgs* args);
} TVMPackedFunc;
-static inline void TVMPackedFunc_Call(TVMPackedFunc * pf) {
+static inline void TVMPackedFunc_Call(TVMPackedFunc* pf) {
pf->fexec(pf->args.values, pf->args.tcodes, pf->args.values_count, 0, 0);
}
-static inline void TVMPackedFunc_SetArgs(TVMPackedFunc * pf, const TVMArgs * args) {
+static inline void TVMPackedFunc_SetArgs(TVMPackedFunc* pf, const TVMArgs* args) {
memcpy(&(pf->args), args, sizeof(TVMArgs));
}
-TVMPackedFunc * g_fexecs = 0;
+TVMPackedFunc* g_fexecs = 0;
uint32_t g_fexecs_count = 0;
// Implement TVMModule::GetFunction
// Put implementation in this file so we have seen the TVMPackedFunc
-static inline void TVMModule_GetFunction(TVMModule * mod, const char * name, TVMPackedFunc * pf) {
+static inline void TVMModule_GetFunction(TVMModule* mod, const char* name, TVMPackedFunc* pf) {
int idx;
memset(pf, 0, sizeof(TVMPackedFunc));
assert(strlen(name) <= sizeof(pf->name));
#include <cuda_runtime.h>
#include <tvm/runtime/packed_func.h>
+
#include <string>
+
#include "../workspace_pool.h"
namespace tvm {
{ \
CUresult result = x; \
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
- const char *msg; \
+ const char* msg; \
cuGetErrorName(result, &msg); \
- LOG(FATAL) \
- << "CUDAError: " #x " failed with error: " << msg; \
+ LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \
} \
}
-#define CUDA_CALL(func) \
- { \
- cudaError_t e = (func); \
- CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
- << "CUDA: " << cudaGetErrorString(e); \
+#define CUDA_CALL(func) \
+ { \
+ cudaError_t e = (func); \
+ CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \
}
/*! \brief Thread local workspace */
* \file cuda_device_api.cc
* \brief GPU specific API
*/
-#include <tvm/runtime/device_api.h>
-
-#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
#include <cuda.h>
#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include <cstring>
+
#include "cuda_common.h"
namespace tvm {
class CUDADeviceAPI final : public DeviceAPI {
public:
- void SetDevice(TVMContext ctx) final {
- CUDA_CALL(cudaSetDevice(ctx.device_id));
- }
+ void SetDevice(TVMContext ctx) final { CUDA_CALL(cudaSetDevice(ctx.device_id)); }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
int value = 0;
switch (kind) {
case kExist:
- value = (
- cudaDeviceGetAttribute(
- &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
- == cudaSuccess);
+ value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) ==
+ cudaSuccess);
break;
case kMaxThreadsPerBlock: {
- CUDA_CALL(cudaDeviceGetAttribute(
- &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
break;
}
case kWarpSize: {
- CUDA_CALL(cudaDeviceGetAttribute(
- &value, cudaDevAttrWarpSize, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: {
- CUDA_CALL(cudaDeviceGetAttribute(
- &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
+ CUDA_CALL(
+ cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
break;
}
case kComputeVersion: {
std::ostringstream os;
- CUDA_CALL(cudaDeviceGetAttribute(
- &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
os << value << ".";
- CUDA_CALL(cudaDeviceGetAttribute(
- &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
os << value;
*rv = os.str();
return;
return;
}
case kMaxClockRate: {
- CUDA_CALL(cudaDeviceGetAttribute(
- &value, cudaDevAttrClockRate, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrClockRate, ctx.device_id));
break;
}
case kMultiProcessorCount: {
- CUDA_CALL(cudaDeviceGetAttribute(
- &value, cudaDevAttrMultiProcessorCount, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMultiProcessorCount, ctx.device_id));
break;
}
case kMaxThreadDimensions: {
int dims[3];
- CUDA_CALL(cudaDeviceGetAttribute(
- &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
- CUDA_CALL(cudaDeviceGetAttribute(
- &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
- CUDA_CALL(cudaDeviceGetAttribute(
- &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
+ CUDA_CALL(cudaDeviceGetAttribute(&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
std::stringstream ss; // use json string to return multiple int values;
- ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
+ ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
- case kGcnArch: return;
+ case kGcnArch:
+ return;
}
*rv = value;
}
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
- CHECK_EQ(256 % alignment, 0U)
- << "CUDA space is aligned at 256 bytes";
- void *ret;
+ CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
+ void* ret;
if (ctx.device_type == kDLCPUPinned) {
CUDA_CALL(cudaMallocHost(&ret, nbytes));
} else {
}
}
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
// In case there is a copy from host mem to host mem */
if (ctx_to.device_type == kDLCPU && ctx_from.device_type == kDLCPU) {
- memcpy(to, from, size);
- return;
+ memcpy(to, from, size);
+ return;
}
if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else {
- cudaMemcpyPeerAsync(to, ctx_to.device_id,
- from, ctx_from.device_id,
- size, cu_stream);
+ cudaMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream);
}
} else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
}
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
- CUDAThreadEntry::ThreadLocal()
- ->stream = static_cast<cudaStream_t>(stream);
+ CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
}
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
}
static const std::shared_ptr<CUDADeviceAPI>& Global() {
- static std::shared_ptr<CUDADeviceAPI> inst =
- std::make_shared<CUDADeviceAPI>();
+ static std::shared_ptr<CUDADeviceAPI> inst = std::make_shared<CUDADeviceAPI>();
return inst;
}
private:
- static void GPUCopy(const void* from,
- void* to,
- size_t size,
- cudaMemcpyKind kind,
+ static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind kind,
cudaStream_t stream) {
if (stream != 0) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
-CUDAThreadEntry::CUDAThreadEntry()
- : pool(kDLGPU, CUDADeviceAPI::Global()) {
-}
+CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}
-CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
- return CUDAThreadStore::Get();
-}
+CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); }
-TVM_REGISTER_GLOBAL("device_api.gpu")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = CUDADeviceAPI::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = CUDADeviceAPI::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
-TVM_REGISTER_GLOBAL("device_api.cpu_pinned")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = CUDADeviceAPI::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = CUDADeviceAPI::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace runtime
} // namespace tvm
*/
#include "cuda_module.h"
-#include <tvm/runtime/registry.h>
#include <cuda.h>
#include <cuda_runtime.h>
-#include <vector>
+#include <tvm/runtime/registry.h>
+
#include <array>
-#include <string>
#include <mutex>
+#include <string>
#include <unordered_map>
-#include "cuda_common.h"
+#include <vector>
+
+#include "../file_util.h"
+#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
-#include "../meta_data.h"
-#include "../file_util.h"
+#include "cuda_common.h"
namespace tvm {
namespace runtime {
// The modules will be lazily loaded
class CUDAModuleNode : public runtime::ModuleNode {
public:
- explicit CUDAModuleNode(std::string data,
- std::string fmt,
+ explicit CUDAModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source)
: data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) {
}
}
- const char* type_key() const final {
- return "cuda";
- }
+ const char* type_key() const final { return "cuda"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
if (fmt == "cu") {
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, cuda_source_);
} else {
- CHECK_EQ(fmt, fmt_)
- << "Can only save to format=" << fmt_;
+ CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
CUfunction func;
CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str());
if (result != CUDA_SUCCESS) {
- const char *msg;
+ const char* msg;
cuGetErrorName(result, &msg);
- LOG(FATAL)
- << "CUDAError: cuModuleGetFunction " << func_name
- << " failed with error: " << msg;
+ LOG(FATAL) << "CUDAError: cuModuleGetFunction " << func_name << " failed with error: " << msg;
}
return func;
}
// get a global var from primary context in device_id
- CUdeviceptr GetGlobal(int device_id,
- const std::string& global_name,
- size_t expect_nbytes) {
+ CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
CUdeviceptr global;
size_t nbytes;
- CUresult result = cuModuleGetGlobal(&global, &nbytes,
- module_[device_id], global_name.c_str());
+ CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str());
CHECK_EQ(nbytes, expect_nbytes);
if (result != CUDA_SUCCESS) {
- const char *msg;
+ const char* msg;
cuGetErrorName(result, &msg);
- LOG(FATAL)
- << "CUDAError: cuModuleGetGlobal " << global_name
- << " failed with error: " << msg;
+ LOG(FATAL) << "CUDAError: cuModuleGetGlobal " << global_name << " failed with error: " << msg;
}
return global;
}
class CUDAWrappedFunc {
public:
// initialize the CUDA function.
- void Init(CUDAModuleNode* m,
- ObjectPtr<Object> sptr,
- const std::string& func_name,
- size_t num_void_args,
- const std::vector<std::string>& thread_axis_tags) {
+ void Init(CUDAModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
+ size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
}
// invoke the function with void arguments
- void operator()(TVMArgs args,
- TVMRetValue* rv,
- void** void_args) const {
+ void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
if (fcache_[device_id] == nullptr) {
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
- CUresult result = cuLaunchKernel(
- fcache_[device_id],
- wl.grid_dim(0),
- wl.grid_dim(1),
- wl.grid_dim(2),
- wl.block_dim(0),
- wl.block_dim(1),
- wl.block_dim(2),
- 0, strm, void_args, 0);
+ CUresult result =
+ cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2),
+ wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), 0, strm, void_args, 0);
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
- const char *msg;
+ const char* msg;
cuGetErrorName(result, &msg);
std::ostringstream os;
os << "CUDALaunch Error: " << msg << "\n"
- << " grid=(" << wl.grid_dim(0) << ","
- << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), "
- << " block=(" << wl.block_dim(0) << ","
- << wl.block_dim(1) << "," << wl.block_dim(2) << ")\n";
+ << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), "
+ << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2)
+ << ")\n";
std::string cuda = m_->GetSource("");
if (cuda.length() != 0) {
os << "// func_name=" << func_name_ << "\n"
class CUDAPrepGlobalBarrier {
public:
- CUDAPrepGlobalBarrier(CUDAModuleNode* m,
- ObjectPtr<Object> sptr)
- : m_(m), sptr_(sptr) {
+ CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr<Object> sptr) : m_(m), sptr_(sptr) {
std::fill(pcache_.begin(), pcache_.end(), 0);
}
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
if (pcache_[device_id] == 0) {
- pcache_[device_id] = m_->GetGlobal(
- device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
+ pcache_[device_id] =
+ m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
}
CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1));
}
mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
};
-PackedFunc CUDAModuleNode::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc CUDAModuleNode::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
- CHECK_NE(name, symbol::tvm_module_main)
- << "Device function do not have main";
+ CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
if (name == symbol::tvm_prepare_global_barrier) {
return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self));
}
return PackFuncVoidAddr(f, info.arg_types);
}
-Module CUDAModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string cuda_source) {
+Module CUDAModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap,
+ std::string cuda_source) {
auto n = make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source);
return Module(n);
}
// Load module from module.
-Module CUDAModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module CUDAModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
return CUDAModuleCreate(data, fmt, fmap, std::string());
}
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin")
-.set_body_typed(CUDAModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile);
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx")
-.set_body_typed(CUDAModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile);
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda")
-.set_body_typed(CUDAModuleLoadBinary);
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary);
} // namespace runtime
} // namespace tvm
#define TVM_RUNTIME_CUDA_CUDA_MODULE_H_
#include <tvm/runtime/module.h>
+
#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../meta_data.h"
namespace tvm {
* \param fmap The map function information map of each function.
* \param cuda_source Optional, cuda source file
*/
-Module CUDAModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string cuda_source);
+Module CUDAModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap,
+ std::string cuda_source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_
* \file dso_libary.cc
* \brief Create library module to load from dynamic shared library.
*/
-#include <tvm/runtime/module.h>
#include <tvm/runtime/memory.h>
-#include <tvm/runtime/registry.h>
+#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
#include "library_module.h"
#if defined(_WIN32)
~DSOLibrary() {
if (lib_handle_) Unload();
}
- void Init(const std::string& name) {
- Load(name);
- }
+ void Init(const std::string& name) { Load(name); }
- void* GetSymbol(const char* name) final {
- return GetSymbol_(name);
- }
+ void* GetSymbol(const char* name) final { return GetSymbol_(name); }
private:
// Platform dependent handling.
HMODULE lib_handle_{nullptr};
void* GetSymbol_(const char* name) {
- return reinterpret_cast<void*>(
- GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
+ return reinterpret_cast<void*>(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
// Load the library
// use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
- CHECK(lib_handle_ != nullptr)
- << "Failed to load dynamic shared library " << name;
+ CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name;
}
void Unload() {
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
- CHECK(lib_handle_ != nullptr)
- << "Failed to load dynamic shared library " << name
- << " " << dlerror();
+ CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " "
+ << dlerror();
}
- void* GetSymbol_(const char* name) {
- return dlsym(lib_handle_, name);
- }
+ void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); }
void Unload() {
dlclose(lib_handle_);
#endif
};
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_so")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- auto n = make_object<DSOLibrary>();
- n->Init(args[0]);
- *rv = CreateModuleFromLibrary(n);
- });
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) {
+ auto n = make_object<DSOLibrary>();
+ n->Init(args[0]);
+ *rv = CreateModuleFromLibrary(n);
+});
} // namespace runtime
} // namespace tvm
/*!
* \file file_util.cc
*/
+#include "file_util.h"
+
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
+
#include <fstream>
-#include <vector>
#include <unordered_map>
-#include "file_util.h"
+#include <vector>
namespace tvm {
namespace runtime {
return true;
}
-std::string GetFileFormat(const std::string& file_name,
- const std::string& format) {
+std::string GetFileFormat(const std::string& file_name, const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
size_t pos = file_name.find_last_of(".");
}
std::string GetMetaFilePath(const std::string& file_name) {
- size_t pos = file_name.find_last_of(".");
+ size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".tvm_meta.json";
} else {
}
}
-void LoadBinaryFromFile(const std::string& file_name,
- std::string* data) {
+void LoadBinaryFromFile(const std::string& file_name, std::string* data) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
// get its size:
fs.read(&(*data)[0], size);
}
-void SaveBinaryToFile(
- const std::string& file_name,
- const std::string& data) {
+void SaveBinaryToFile(const std::string& file_name, const std::string& data) {
std::ofstream fs(file_name, std::ios::out | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
fs.write(&data[0], data.length());
}
-void SaveMetaDataToFile(
- const std::string& file_name,
- const std::unordered_map<std::string, FunctionInfo>& fmap) {
+void SaveMetaDataToFile(const std::string& file_name,
+ const std::unordered_map<std::string, FunctionInfo>& fmap) {
std::string version = "0.1.0";
std::ofstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
fs.close();
}
-void LoadMetaDataFromFile(
- const std::string& file_name,
- std::unordered_map<std::string, FunctionInfo>* fmap) {
+void LoadMetaDataFromFile(const std::string& file_name,
+ std::unordered_map<std::string, FunctionInfo>* fmap) {
std::ifstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
std::string version;
fs.close();
}
-void RemoveFile(const std::string& file_name) {
- std::remove(file_name.c_str());
-}
+void RemoveFile(const std::string& file_name) { std::remove(file_name.c_str()); }
} // namespace runtime
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <string>
#include <unordered_map>
+
#include "meta_data.h"
namespace tvm {
* \param file_name The name of the file.
* \param format The format of the file.
*/
-std::string GetFileFormat(const std::string& file_name,
- const std::string& format);
+std::string GetFileFormat(const std::string& file_name, const std::string& format);
/*!
* \return the directory in which TVM stores cached files.
* \param file_name The name of the file.
* \param data The data to be loaded.
*/
-void LoadBinaryFromFile(const std::string& file_name,
- std::string* data);
+void LoadBinaryFromFile(const std::string& file_name, std::string* data);
/*!
* \brief Load binary file into a in-memory buffer.
* \param file_name The name of the file.
* \param data The binary data to be saved.
*/
-void SaveBinaryToFile(const std::string& file_name,
- const std::string& data);
+void SaveBinaryToFile(const std::string& file_name, const std::string& data);
/*!
* \brief Save meta data to file.
* \param file_name The name of the file.
* \param fmap The function info map.
*/
-void SaveMetaDataToFile(
- const std::string& file_name,
- const std::unordered_map<std::string, FunctionInfo>& fmap);
+void SaveMetaDataToFile(const std::string& file_name,
+ const std::unordered_map<std::string, FunctionInfo>& fmap);
/*!
* \brief Load meta data to file.
* \param file_name The name of the file.
* \param fmap The function info map.
*/
-void LoadMetaDataFromFile(
- const std::string& file_name,
- std::unordered_map<std::string, FunctionInfo>* fmap);
+void LoadMetaDataFromFile(const std::string& file_name,
+ std::unordered_map<std::string, FunctionInfo>* fmap);
/*!
* \brief Remove (unlink) a file.
/*!
* \file graph_runtime_debug.cc
*/
+#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/ndarray.h>
#include <chrono>
#include <sstream>
+
#include "../graph_runtime.h"
namespace tvm {
std::ostringstream os;
std::vector<double> time_per_op(op_execs_.size(), 0);
for (int i = 0; i < repeat; ++i) {
- std::chrono::time_point<
- std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
+ std::chrono::time_point<std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin,
+ tend;
double duration_ms = 0.0;
do {
std::fill(time_per_op.begin(), time_per_op.end(), 0);
if (duration_ms > 0.0) {
- number = static_cast<int>(
- std::max((min_repeat_ms / (duration_ms / number) + 1),
- number * 1.618)); // 1.618 is chosen by random
+ number = static_cast<int>(std::max((min_repeat_ms / (duration_ms / number) + 1),
+ number * 1.618)); // 1.618 is chosen by random
}
tbegin = std::chrono::high_resolution_clock::now();
for (int k = 0; k < number; k++) {
op_execs_[index]();
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_tend = std::chrono::high_resolution_clock::now();
- double op_duration = std::chrono::duration_cast<
- std::chrono::duration<double> >(op_tend - op_tbegin).count();
+ double op_duration =
+ std::chrono::duration_cast<std::chrono::duration<double> >(op_tend - op_tbegin)
+ .count();
time_per_op[index] += op_duration * 1e6; // us
}
}
}
tend = std::chrono::high_resolution_clock::now();
- duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
- (tend - tbegin).count() * 1000;
+ duration_ms =
+ std::chrono::duration_cast<std::chrono::duration<double> >(tend - tbegin).count() *
+ 1000;
} while (duration_ms < min_repeat_ms);
LOG(INFO) << "Iteration: " << i;
for (size_t index = 0; index < time_per_op.size(); index++) {
if (op_execs_[index]) {
time_per_op[index] /= number;
- LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": "
- << time_per_op[index] << " us/iter";
+ LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " << time_per_op[index]
+ << " us/iter";
}
}
}
* \param index The index of op which needs to be returned.
* \param eid The Entry id of the op.
*/
- NDArray GetOutputByLayer(int index, int eid) {
- return data_entry_[entry_id(index, eid)];
- }
+ NDArray GetOutputByLayer(int index, int eid) { return data_entry_[entry_id(index, eid)]; }
/*!
* \brief GetFunction Get the function based on input.
* \param name The function which needs to be invoked.
* \param sptr_to_self Packed function pointer.
*/
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self);
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
/*!
* \brief Get the node index given the name of node.
}
LOG(FATAL) << "cannot find " << name << " among nodex";
return -1;
-}
+ }
-/*!
- * \brief Copy index-th node to data_out.
- *
- * This method will do a partial run of the the graph
- * from begining upto the index-th node and return output of index-th node.
- * This is costly operation and suggest to use only for debug porpose.
- *
- * \param index: The index of the node.
- * \param data_out the node data.
- */
-void DebugGetNodeOutput(int index, DLTensor* data_out) {
- CHECK_LT(static_cast<size_t>(index), op_execs_.size());
- uint32_t eid = index;
+ /*!
+ * \brief Copy index-th node to data_out.
+ *
+ * This method will do a partial run of the the graph
+ * from begining upto the index-th node and return output of index-th node.
+ * This is costly operation and suggest to use only for debug porpose.
+ *
+ * \param index: The index of the node.
+ * \param data_out the node data.
+ */
+ void DebugGetNodeOutput(int index, DLTensor* data_out) {
+ CHECK_LT(static_cast<size_t>(index), op_execs_.size());
+ uint32_t eid = index;
- for (size_t i = 0; i < op_execs_.size(); ++i) {
- if (op_execs_[i]) op_execs_[i]();
- if (static_cast<int>(i) == index) break;
- }
+ for (size_t i = 0; i < op_execs_.size(); ++i) {
+ if (op_execs_[i]) op_execs_[i]();
+ if (static_cast<int>(i) == index) break;
+ }
- data_entry_[eid].CopyTo(data_out);
-}
+ data_entry_[eid].CopyTo(data_out);
+ }
};
-
/*!
* \brief GetFunction Get the function based on input.
* \param name The function which needs to be invoked.
* \param sptr_to_self Packed function pointer.
*/
-PackedFunc GraphRuntimeDebug::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
// return member functions during query.
if (name == "get_output_by_layer") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetOutputByLayer(args[0], args[1]);
- });
+ *rv = this->GetOutputByLayer(args[0], args[1]);
+ });
} else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- if (args[0].type_code() == kTVMStr) {
- this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
- } else {
- this->DebugGetNodeOutput(args[0], args[1]);
- }
- });
+ if (args[0].type_code() == kTVMStr) {
+ this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
+ } else {
+ this->DebugGetNodeOutput(args[0], args[1]);
+ }
+ });
} else if (name == "run_individual") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int number = args[0];
* \param m Compiled module which will be loaded.
* \param ctxs All devices contexts.
*/
-Module GraphRuntimeDebugCreate(const std::string& sym_json,
- const tvm::runtime::Module& m,
+Module GraphRuntimeDebugCreate(const std::string& sym_json, const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) {
auto exec = make_object<GraphRuntimeDebug>();
exec->Init(sym_json, m, ctxs);
return Module(exec);
}
-TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK_GE(args.num_args, 4)
- << "The expected number of arguments for graph_runtime.create is "
- "at least 4, but it has "
- << args.num_args;
- *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
- });
+TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create").set_body([](TVMArgs args, TVMRetValue* rv) {
+ CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is "
+ "at least 4, but it has "
+ << args.num_args;
+ *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
+});
} // namespace runtime
} // namespace tvm
/*!
* \file graph_runtime.cc
*/
+#include "graph_runtime.h"
+
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <utility>
#include <vector>
-#include "graph_runtime.h"
-
namespace tvm {
namespace runtime {
namespace details {
* \param ctxs The context of the host and devices where graph nodes will be
* executed on.
*/
-void GraphRuntime::Init(const std::string& graph_json,
- tvm::runtime::Module module,
+void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module,
const std::vector<TVMContext>& ctxs) {
std::istringstream is(graph_json);
dmlc::JSONReader reader(&is);
*
* \return The number of outputs from graph.
*/
-int GraphRuntime::NumOutputs() const {
- return outputs_.size();
-}
+int GraphRuntime::NumOutputs() const { return outputs_.size(); }
/*!
* \brief Return NDArray for given input index.
* \param index The input index.
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
uint64_t header, reserved;
- CHECK(strm->Read(&header))
- << "Invalid parameters file format";
- CHECK(header == kTVMNDArrayListMagic)
- << "Invalid parameters file format";
- CHECK(strm->Read(&reserved))
- << "Invalid parameters file format";
+ CHECK(strm->Read(&header)) << "Invalid parameters file format";
+ CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
+ CHECK(strm->Read(&reserved)) << "Invalid parameters file format";
std::vector<std::string> names;
- CHECK(strm->Read(&names))
- << "Invalid parameters file format";
+ CHECK(strm->Read(&names)) << "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz);
size_t size = static_cast<size_t>(sz);
- CHECK(size == names.size())
- << "Invalid parameters file format";
+ CHECK(size == names.size()) << "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
int in_idx = GetInputIndex(names[i]);
CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
}
void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
- uint64_t header, reserved;
- CHECK(strm->Read(&header))
- << "Invalid parameters file format";
- CHECK(header == kTVMNDArrayListMagic)
- << "Invalid parameters file format";
- CHECK(strm->Read(&reserved))
- << "Invalid parameters file format";
+ uint64_t header, reserved;
+ CHECK(strm->Read(&header)) << "Invalid parameters file format";
+ CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
+ CHECK(strm->Read(&reserved)) << "Invalid parameters file format";
std::vector<std::string> names;
CHECK(strm->Read(&names)) << "Invalid parameters file format";
uint64_t sz;
CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
DLDataType t = vtype[i];
size_t bits = t.bits * t.lanes;
- CHECK(bits % 8U == 0U || bits ==1U);
+ CHECK(bits % 8U == 0U || bits == 1U);
size_t bytes = ((bits + 7U) / 8U) * size;
uint32_t sid = static_cast<uint32_t>(storage_id);
if (sid >= pool_entry.size()) {
pool_entry.resize(sid + 1, {0, -1});
} else {
- CHECK(pool_entry[sid].device_type == -1 ||
- pool_entry[sid].device_type == device_type)
+ CHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type)
<< "The same pool entry cannot be assigned to multiple devices";
}
pool_entry[sid].size = std::max(pool_entry[sid].size, bytes);
std::vector<int64_t> shape;
// This for loop is very fast since there are usually only a couple of
// devices available on the same hardware.
- const auto& cit =
- std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) {
- return pit.device_type == static_cast<int>(c.device_type);
- });
+ const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) {
+ return pit.device_type == static_cast<int>(c.device_type);
+ });
TVMContext ctx = cit == ctxs_.end() ? ctxs_[0] : *cit;
shape.push_back(static_cast<int64_t>(pit.size + 3) / 4);
- storage_pool_.push_back(
- NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx));
+ storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx));
}
// Assign the pooled entries. A unified memory pool is used to simplifiy
for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = attrs_.storage_id[i];
CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
- data_entry_[i] =
- storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
+ data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
const DLTensor* tmp = data_entry_[i].operator->();
data_alignment_[i] = details::GetDataAlignment(*tmp);
}
CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
std::shared_ptr<OpArgs> op_args = nullptr;
- std::tie(op_execs_[nid], op_args) =
- CreateTVMOp(inode.param, args, inode.inputs.size());
+ std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args, inode.inputs.size());
for (size_t i = 0; i < inode.inputs.size(); i++) {
uint32_t eid = this->entry_id(inode.inputs[i]);
// check if op input is model input
if (input_node_eids.count(eid) > 0) {
- input_dltensors_[eid].push_back(
- static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
+ input_dltensors_[eid].push_back(static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
}
}
}
}
std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRuntime::CreateTVMOp(
- const TVMOpParam& param,
- const std::vector<DLTensor>& args,
- size_t num_inputs) {
+ const TVMOpParam& param, const std::vector<DLTensor>& args, size_t num_inputs) {
std::shared_ptr<GraphRuntime::OpArgs> arg_ptr = std::make_shared<GraphRuntime::OpArgs>();
// setup address.
arg_ptr->args = args;
arg_ptr->arg_values.push_back(v);
arg_ptr->arg_tcodes.push_back(kTVMDLTensorHandle);
if (param.flatten_data) {
- arg_ptr->shape_data[i] = std::accumulate(
- t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
+ arg_ptr->shape_data[i] =
+ std::accumulate(t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
t->ndim = 1;
t->shape = &(arg_ptr->shape_data[i]);
}
}
if (param.func_name == "__nop") {
- return {[](){}, arg_ptr};
+ return {[]() {}, arg_ptr};
} else if (param.func_name == "__copy") {
// Perform cross device data copy.
// Directly copy data from the input to the output.
auto fexec = [arg_ptr, pf]() {
TVMRetValue rv;
- TVMArgs targs(arg_ptr->arg_values.data(),
- arg_ptr->arg_tcodes.data(),
+ TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(),
static_cast<int>(arg_ptr->arg_values.size()));
pf.CallPacked(targs, &rv);
};
return {fexec, arg_ptr};
}
-PackedFunc GraphRuntime::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc GraphRuntime::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- if (args[0].type_code() == kTVMStr) {
- int in_idx = this->GetInputIndex(args[0]);
- if (in_idx >= 0) this->SetInput(in_idx, args[1]);
- } else {
- this->SetInput(args[0], args[1]);
- }
- });
+ if (args[0].type_code() == kTVMStr) {
+ int in_idx = this->GetInputIndex(args[0]);
+ if (in_idx >= 0) this->SetInput(in_idx, args[1]);
+ } else {
+ this->SetInput(args[0], args[1]);
+ }
+ });
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kTVMStr) {
});
} else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- int in_idx = 0;
- if (args[0].type_code() == kTVMStr) {
- in_idx = this->GetInputIndex(args[0]);
- } else {
- in_idx = args[0];
- }
- CHECK_GE(in_idx, 0);
- *rv = this->GetInput(in_idx);
- });
+ int in_idx = 0;
+ if (args[0].type_code() == kTVMStr) {
+ in_idx = this->GetInputIndex(args[0]);
+ } else {
+ in_idx = args[0];
+ }
+ CHECK_GE(in_idx, 0);
+ *rv = this->GetInput(in_idx);
+ });
} else if (name == "get_num_outputs") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->NumOutputs();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
} else if (name == "run") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- this->Run();
- });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- this->LoadParams(args[0].operator std::string());
- });
+ this->LoadParams(args[0].operator std::string());
+ });
} else if (name == "share_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- const auto& module = args[0].operator Module();
- CHECK_EQ(module.operator->()->type_key(), "GraphRuntime");
- const auto& param_blob = args[1].operator std::string();
- dmlc::MemoryStringStream strm(const_cast<std::string*>(¶m_blob));
- this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm);
- });
+ const auto& module = args[0].operator Module();
+ CHECK_EQ(module.operator->()->type_key(), "GraphRuntime");
+ const auto& param_blob = args[1].operator std::string();
+ dmlc::MemoryStringStream strm(const_cast<std::string*>(¶m_blob));
+ this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm);
+ });
} else {
return PackedFunc();
}
}
-Module GraphRuntimeCreate(const std::string& sym_json,
- const tvm::runtime::Module& m,
+Module GraphRuntimeCreate(const std::string& sym_json, const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) {
auto exec = make_object<GraphRuntime>();
exec->Init(sym_json, m, ctxs);
// execution support yet. For heterogenenous execution, at least 5 arguments will
// be passed in. The third one is the number of devices.
// Eventually, we will only probably pass TVMContext for all the languages.
-TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK_GE(args.num_args, 4)
- << "The expected number of arguments for graph_runtime.create is "
- "at least 4, but it has "
- << args.num_args;
- const auto& contexts = GetAllContext(args);
- *rv = GraphRuntimeCreate(args[0], args[1], contexts);
- });
+TVM_REGISTER_GLOBAL("tvm.graph_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
+ CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is "
+ "at least 4, but it has "
+ << args.num_args;
+ const auto& contexts = GetAllContext(args);
+ *rv = GraphRuntimeCreate(args[0], args[1], contexts);
+});
} // namespace runtime
} // namespace tvm
#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#include <dlpack/dlpack.h>
-#include <dmlc/memory_io.h>
#include <dmlc/json.h>
+#include <dmlc/memory_io.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
+#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
-#include <string>
namespace tvm {
namespace runtime {
/*! \brief macro to do C API call */
-#define TVM_CCALL(func) \
- { \
- int ret = (func); \
- CHECK_EQ(ret, 0) \
- << TVMGetLastError(); \
+#define TVM_CCALL(func) \
+ { \
+ int ret = (func); \
+ CHECK_EQ(ret, 0) << TVMGetLastError(); \
}
/*! \brief Magic number for NDArray list file */
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
- virtual PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self);
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
/*!
* \return The type key of the executor.
*/
- const char* type_key() const final {
- return "GraphRuntime";
- }
+ const char* type_key() const final { return "GraphRuntime"; }
void Run();
/*!
* executed on.
*/
- void Init(const std::string& graph_json,
- tvm::runtime::Module module,
+ void Init(const std::string& graph_json, tvm::runtime::Module module,
const std::vector<TVMContext>& ctxs);
/*!
* \brief Get total number of nodes.
* \return Total number of nodes.
*/
- uint32_t GetNumOfNodes() const {
- return static_cast<uint32_t>(nodes_.size());
- }
-
- std::string GetNodeName(uint32_t nid) const {
- return nodes_[nid].name;
- }
+ uint32_t GetNumOfNodes() const { return static_cast<uint32_t>(nodes_.size()); }
+ std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; }
protected:
// Memory pool entry.
uint32_t index;
uint32_t version;
// JSON Loader
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
// control deps
std::vector<uint32_t> control_deps;
// JSON Loader
- void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) {
+ void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) {
int bitmask = 0;
std::string key, value;
reader->BeginObject();
bitmask |= 8;
}
}
- CHECK_EQ(bitmask, 1|2|4|8) << "invalid format";
+ CHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "invalid format";
}
// JSON Loader
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
reader->BeginObject();
int bitmask = 0;
std::string key;
LOG(FATAL) << "do not support key " << key;
}
}
- CHECK_EQ(bitmask, 1|2|4) << "invalid format";
+ CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
}
};
struct GraphAttr {
std::vector<int> storage_id;
std::vector<int> device_index;
std::vector<std::string> dltype;
- std::vector<std::vector<int64_t> > shape;
+ std::vector<std::vector<int64_t>> shape;
// The graph attribute fields.
- void Load(dmlc::JSONReader *reader) {
+ void Load(dmlc::JSONReader* reader) {
reader->BeginObject();
int bitmask = 0;
std::string key, type;
CHECK(!reader->NextArrayItem());
}
}
- CHECK_EQ(bitmask, 1|2|4) << "invalid format";
+ CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
}
};
// The graph attribute fields.
- void Load(dmlc::JSONReader *reader) {
- reader->BeginObject();
- int bitmask = 0;
- std::string key;
- while (reader->NextObjectItem(&key)) {
- if (key == "nodes") {
- reader->Read(&nodes_);
- bitmask |= 1;
- } else if (key == "arg_nodes") {
- reader->Read(&input_nodes_);
- bitmask |= 2;
- } else if (key == "node_row_ptr") {
- reader->Read(&node_row_ptr_);
- bitmask |= 4;
- } else if (key == "heads") {
- reader->Read(&outputs_);
- bitmask |= 8;
- } else if (key == "attrs") {
- reader->Read(&attrs_);
- bitmask |= 16;
- } else if (key == "metadata") {
- break;
- } else {
- LOG(FATAL) << "key " << key << " is not supported";
- }
+ void Load(dmlc::JSONReader* reader) {
+ reader->BeginObject();
+ int bitmask = 0;
+ std::string key;
+ while (reader->NextObjectItem(&key)) {
+ if (key == "nodes") {
+ reader->Read(&nodes_);
+ bitmask |= 1;
+ } else if (key == "arg_nodes") {
+ reader->Read(&input_nodes_);
+ bitmask |= 2;
+ } else if (key == "node_row_ptr") {
+ reader->Read(&node_row_ptr_);
+ bitmask |= 4;
+ } else if (key == "heads") {
+ reader->Read(&outputs_);
+ bitmask |= 8;
+ } else if (key == "attrs") {
+ reader->Read(&attrs_);
+ bitmask |= 16;
+ } else if (key == "metadata") {
+ break;
+ } else {
+ LOG(FATAL) << "key " << key << " is not supported";
}
- CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
+ }
+ CHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format";
}
/*! \brief Setup the temporal storage */
void SetupStorage();
* \param num_inputs Number of inputs.
* \return The created executor.
*/
- std::pair<std::function<void()>, std::shared_ptr<OpArgs> > CreateTVMOp(
- const TVMOpParam& attrs, const std::vector<DLTensor>& args,
- size_t num_inputs);
+ std::pair<std::function<void()>, std::shared_ptr<OpArgs>> CreateTVMOp(
+ const TVMOpParam& attrs, const std::vector<DLTensor>& args, size_t num_inputs);
// Get node entry index.
- uint32_t entry_id(uint32_t nid, uint32_t index) const {
- return node_row_ptr_[nid] + index;
- }
+ uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; }
// Get node entry index.
- uint32_t entry_id(const NodeEntry& e) const {
- return entry_id(e.node_id, e.index);
- }
+ uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); }
// Number of node entries.
- uint32_t num_node_entries() const {
- return node_row_ptr_.back();
- }
+ uint32_t num_node_entries() const { return node_row_ptr_.back(); }
/*! \brief The graph nodes. */
std::vector<Node> nodes_;
/*! \brief The argument nodes. */
/*! \brief Data alignment of each node. */
std::vector<size_t> data_alignment_;
/*! \brief Operator on each node. */
- std::vector<std::function<void()> > op_execs_;
+ std::vector<std::function<void()>> op_execs_;
};
std::vector<TVMContext> GetAllContext(const TVMArgs& args);
public:
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
- void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
- DLDataType type_hint) final;
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
- void CopyDataFromTo(const void* from, size_t from_offset, void* to,
- size_t to_offset, size_t num_bytes, TVMContext ctx_from,
- TVMContext ctx_to, DLDataType type_hint,
- TVMStreamHandle stream) final;
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to,
+ DLDataType type_hint, TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
- void* AllocWorkspace(TVMContext ctx, size_t nbytes,
- DLDataType type_hint = {}) final;
+ void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final;
void FreeWorkspace(TVMContext ctx, void* ptr) final;
static const std::shared_ptr<HexagonDeviceAPI>& Global() {
- static std::shared_ptr<HexagonDeviceAPI> inst =
- std::make_shared<HexagonDeviceAPI>();
+ static std::shared_ptr<HexagonDeviceAPI> inst = std::make_shared<HexagonDeviceAPI>();
return inst;
}
};
inline void HexagonDeviceAPI::SetDevice(TVMContext ctx) {}
-inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind,
- TVMRetValue* rv) {
+inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
if (kind == kExist) *rv = 1;
}
-inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes,
- size_t alignment,
+inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) {
CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id));
return hexagon::Device::Global()->Alloc(nbytes, alignment);
hexagon::Device::Global()->Free(ptr);
}
-inline void HexagonDeviceAPI::CopyDataFromTo(
- const void* from, size_t from_offset, void* to, size_t to_offset,
- size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to,
- DLDataType type_hint, TVMStreamHandle stream) {
+inline void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to,
+ size_t to_offset, size_t num_bytes,
+ TVMContext ctx_from, TVMContext ctx_to,
+ DLDataType type_hint, TVMStreamHandle stream) {
const char* src = static_cast<const char*>(from) + from_offset;
char* dst = static_cast<char*>(to) + to_offset;
}
}
-inline void HexagonDeviceAPI::StreamSync(TVMContext ctx,
- TVMStreamHandle stream) {}
+inline void HexagonDeviceAPI::StreamSync(TVMContext ctx, TVMStreamHandle stream) {}
-inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes,
- DLDataType type_hint) {
+inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint) {
CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id));
if (type_hint.code == 100) {
size_t align = std::min(nbytes, 2048lu);
DeviceAPI::FreeWorkspace(ctx, ptr);
}
-TVM_REGISTER_GLOBAL("device_api.hexagon")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = HexagonDeviceAPI::Global().get();
- *rv = ptr;
- });
+TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = HexagonDeviceAPI::Global().get();
+ *rv = ptr;
+});
} // namespace runtime
} // namespace tvm
if (!InReg) {
// Allocate on stack.
- CHECK_EQ((t_align & (t_align - 1)), 0)
- << "Alignment should be a power of 2";
+ CHECK_EQ((t_align & (t_align - 1)), 0) << "Alignment should be a power of 2";
CHECK_GE(t_align, 4) << "Alignment should be at least 4";
// Round t_size up to a multiple of 4.
unsigned s_size = Stack.size();
class HexagonModuleNode final : public runtime::ModuleNode {
public:
HexagonModuleNode(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string asm_str, std::string obj_str,
- std::string ir_str, std::string bc_str,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string asm_str,
+ std::string obj_str, std::string ir_str, std::string bc_str,
const std::set<std::string>& packed_c_abi)
: hexagon_device_(hexagon::Device::Global()),
data_(data),
}
}
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
const char* type_key() const final { return "hexagon"; }
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = runtime::GetFileFormat(file_name, format);
if (fmt == "so" || fmt == "dll" || fmt == "hexagon") {
std::string meta_file = GetMetaFilePath(file_name);
CHECK(!bc_.empty()) << "LLVM IR bitcode not available";
SaveBinaryToFile(file_name, bc_);
} else {
- LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt
- << "'";
+ LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt << "'";
}
}
void SaveToBinary(dmlc::Stream* stream) final {
}
private:
- void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args,
- TVMRetValue* rv) const;
- void CallRemoteDirect(void* func_ptr, const TVMArgs& args,
- TVMRetValue* rv) const;
+ void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const;
+ void CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const;
void RemapArgs(const TVMArgs& args,
std::vector<TVMValue>& values, // NOLINT(*)
std::vector<int>& type_codes, // NOLINT(*)
std::set<std::string> packed_c_abi_funcs_;
};
-void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr,
- const TVMArgs& args,
+void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, const TVMArgs& args,
TVMRetValue* rv) const {
// Remap all arguments, creating remote DLTensors.
std::vector<TVMValue> values;
int num_args = args.size();
int values_size = num_args * sizeof(TVMValue);
int codes_size = num_args * sizeof(int);
- void* remote = hexagon_device_->Alloc(
- values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8);
+ void* remote =
+ hexagon_device_->Alloc(values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8);
// Copy all argument TVMValues to the remote space.
void* remote_values = remote;
temp_values[2].v_int64 = num_args;
temp_values[3].v_handle = remote_ret_value;
temp_values[4].v_handle = remote_ret_code;
- int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt,
- kTVMOpaqueHandle, kTVMOpaqueHandle};
+ int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, kTVMOpaqueHandle,
+ kTVMOpaqueHandle};
TVMArgs temp_args(temp_values, temp_codes, 5);
hexagon::ArgLayout as = BuildArgLayout(temp_args);
- hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(),
- as.Stack.data(), as.Stack.size());
+ hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(),
+ as.Stack.size());
// TODO(kparzysz-quic): copy return value back
std::for_each(remote_tensors.begin(), remote_tensors.end(),
void HexagonModuleNode::CallRemoteDirect(void* func_ptr, const TVMArgs& args,
TVMRetValue* rv) const {
hexagon::ArgLayout as = BuildArgLayout(args);
- hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(),
- as.Stack.data(), as.Stack.size());
+ hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(),
+ as.Stack.size());
}
-PackedFunc HexagonModuleNode::GetFunction(
- const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc HexagonModuleNode::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
auto f = fmap_.find(name);
if (f == fmap_.end()) return PackedFunc(nullptr);
}
}
-void HexagonModuleNode::RemapArgs(const TVMArgs& args,
- std::vector<TVMValue>& values,
+void HexagonModuleNode::RemapArgs(const TVMArgs& args, std::vector<TVMValue>& values,
std::vector<int>& type_codes,
std::vector<void*>& remote_tensors) const {
for (unsigned i = 0, e = args.size(); i != e; ++i) {
uint32_t remote_as_int = reinterpret_cast<uintptr_t>(remote);
void* remote_ss = reinterpret_cast<void*>(remote_as_int + size_ht);
- HexagonDLTensor local = {
- .data = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(t->data)),
- .ctx_device_type = uint8_t(t->ctx.device_type),
- .pad0 = {0, 0, 0},
- .ctx_device_id = t->ctx.device_id,
- .ndim = t->ndim,
- .dtype_code = t->dtype.code,
- .dtype_bits = t->dtype.bits,
- .dtype_lanes = t->dtype.lanes,
- .shape = remote_as_int + size_ht,
- .strides = t->strides ? remote_as_int + size_ht + size_s : 0u,
- .byte_offset = t->byte_offset};
+ HexagonDLTensor local = {.data = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(t->data)),
+ .ctx_device_type = uint8_t(t->ctx.device_type),
+ .pad0 = {0, 0, 0},
+ .ctx_device_id = t->ctx.device_id,
+ .ndim = t->ndim,
+ .dtype_code = t->dtype.code,
+ .dtype_bits = t->dtype.bits,
+ .dtype_lanes = t->dtype.lanes,
+ .shape = remote_as_int + size_ht,
+ .strides = t->strides ? remote_as_int + size_ht + size_s : 0u,
+ .byte_offset = t->byte_offset};
std::vector<uint64_t> local_ss(size_ss / 8);
for (int i = 0; i != ndim; ++i) local_ss[i] = t->shape[i];
}
Module HexagonModuleCreate(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string asm_str, std::string obj_str,
- std::string ir_str, std::string bc_str,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string asm_str,
+ std::string obj_str, std::string ir_str, std::string bc_str,
const std::set<std::string>& packed_c_abi) {
- auto n = make_object<HexagonModuleNode>(data, fmt, fmap, asm_str, obj_str,
- ir_str, bc_str, packed_c_abi);
+ auto n = make_object<HexagonModuleNode>(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str,
+ packed_c_abi);
return Module(n);
}
// Load module from file.
-Module HexagonModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module HexagonModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data = file_name;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
} // namespace hexagon
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = HexagonModuleLoadFile(args[0], args[1]);
- });
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = HexagonModuleLoadFile(args[0], args[1]);
+});
} // namespace runtime
} // namespace tvm
* convention.
*/
Module HexagonModuleCreate(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string asm_str, std::string obj_str,
- std::string ir_str, std::string bc_str,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string asm_str,
+ std::string obj_str, std::string ir_str, std::string bc_str,
const std::set<std::string>& packed_c_abi);
namespace hexagon {
* \param src Pointer (local to device) of the source buffer.
* \param len Number of bytes to copy.
*/
- virtual void CopyDeviceToDevice(void* dst, const void* src,
- unsigned len) = 0;
+ virtual void CopyDeviceToDevice(void* dst, const void* src, unsigned len) = 0;
/*!
* \brief Copy a block of data from device to host.
* \param host_dst Pointer (local to host) to the destination buffer.
* \param src Pointer (local to device) to the source buffer.
* \param len Number of bytes to copy.
*/
- virtual void CopyDeviceToHost(void* host_dst, const void* src,
- unsigned len) = 0;
+ virtual void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) = 0;
/*!
* \brief Copy a block of data from host to device.
* \param dst Pointer (local to device) to the destination buffer.
* \param host_src Pointer (local to host) to the source buffer.
* \param len Number of bytes to copy.
*/
- virtual void CopyHostToDevice(void* dst, const void* host_src,
- unsigned len) = 0;
+ virtual void CopyHostToDevice(void* dst, const void* host_src, unsigned len) = 0;
/*!
* \brief Load a module (typically a shared library) into device.
* \param data Name of the shared library.
* for padding.
* \param st_num Number of values in the "stack" array.
*/
- virtual void Call(void* func, uint32_t* scalar, unsigned sc_num,
- uint32_t* stack, unsigned st_num) = 0;
+ virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack,
+ unsigned st_num) = 0;
virtual ~Device() = 0;
#include <stdlib.h>
extern "C" {
-int posix_memalign(void** memptr, size_t alignment, size_t size)
- __attribute__((nothrow));
+int posix_memalign(void** memptr, size_t alignment, size_t size) __attribute__((nothrow));
}
-__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment,
- size_t size) {
+__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, size_t size) {
if (void* p = memalign(alignment, size)) {
*memptr = p;
return 0;
namespace runtime {
namespace hexagon {
-static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t),
- "Hexagon VA must be uint32");
+static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), "Hexagon VA must be uint32");
template <typename T>
struct unalign {
// user from memory reallocation and copying.
struct non_const_str {
non_const_str() {}
- explicit non_const_str(const std::string& str)
- : non_const_str(std::vector<std::string>{str}) {}
+ explicit non_const_str(const std::string& str) : non_const_str(std::vector<std::string>{str}) {}
explicit non_const_str(const std::vector<std::string>& vec) {
for (const std::string& s : vec) {
auto c = detail::make_unique<char[]>(s.size() + 1);
void* Load(const std::string& data, const std::string& fmt) final;
void Unload(void* mod) final;
void* Resolve(const std::string& sym) final;
- void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack,
- unsigned st_num) final;
+ void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, unsigned st_num) final;
static std::string to_string(HEXAPI_Status status);
bool should_parse_next(const string_list& rest);
llvm::Optional<HEXAPI_Interval> to_interval(const detail::MaybeString& str);
- llvm::Optional<HEXAPI_TimingMode> to_timingmode(
- const detail::MaybeString& str);
- llvm::Optional<HEXAPI_VerboseMode> to_verbosemode(
- const detail::MaybeString& str);
+ llvm::Optional<HEXAPI_TimingMode> to_timingmode(const detail::MaybeString& str);
+ llvm::Optional<HEXAPI_VerboseMode> to_verbosemode(const detail::MaybeString& str);
llvm::Optional<HEXAPI_Nullptr> to_nullptr(const detail::MaybeString& str);
MaybeUIntRange ahb_, axi2_;
{"--verbose", &HexagonSimulator::HandleVerbose},
};
-#define CHECKED_CALL(func, ...) \
- do { \
- HEXAPI_Status s = sim_->func(__VA_ARGS__); \
- CHECK_EQ(s, HEX_STAT_SUCCESS) \
- << "HexagonSimulator: " #func " failed with code " \
- << HexagonSimulator::to_string(s); \
+#define CHECKED_CALL(func, ...) \
+ do { \
+ HEXAPI_Status s = sim_->func(__VA_ARGS__); \
+ CHECK_EQ(s, HEX_STAT_SUCCESS) << "HexagonSimulator: " #func " failed with code " \
+ << HexagonSimulator::to_string(s); \
} while (false)
inline HEX_VA_t HexagonSimulator::p2va(const void* p) {
pd->value = v;
}
-void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src,
- unsigned len) {
+void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, unsigned len) {
const uint8_t* src = static_cast<const uint8_t*>(host_src);
while (len >= 8) {
using iterator = std::istream_iterator<std::string>;
auto sim_args = string_list(iterator(sim_args_iss), iterator());
- std::string target_str =
- !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66");
+ std::string target_str = !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66");
arch_ = target_str;
- sim_ =
- detail::make_unique<HexagonWrapper>(detail::non_const_str(target_str));
+ sim_ = detail::make_unique<HexagonWrapper>(detail::non_const_str(target_str));
LOG(INFO) << "HexagonSimulator: Core version: " << arch_;
// Locate the sim_dev binary in PATH, or in the current working directory.
llvm::StringRef sim_dev = "sim_dev";
- detail::MaybeString path_sim_dev =
- llvm::sys::Process::FindInEnvPath("PATH", sim_dev);
+ detail::MaybeString path_sim_dev = llvm::sys::Process::FindInEnvPath("PATH", sim_dev);
if (!path_sim_dev) {
if (!llvm::sys::fs::exists(sim_dev)) {
LOG(FATAL) << "Cannot find sim_dev in PATH.";
}
void* HexagonSimulator::Alloc(unsigned size, unsigned align) {
- LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align
- << ')';
+ LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align << ')';
Message m = {kAlloc, sizeof(MsgAlloc), 0u};
MsgAlloc ma = {size, align};
SendMsg(m, &ma, true);
}
void HexagonSimulator::Free(void* ptr) {
- LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec
- << ')';
+ LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec << ')';
if (task_queuing_) {
Message mf = {kFlush, 0, 0};
SendMsg(mf, 0, true);
}
void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) {
- LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size
- << ", align=" << align << ')';
+ LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size << ", align=" << align << ')';
Message m = {kAllocVtcm, sizeof(MsgAlloc), 0u};
MsgAlloc ma = {size, align};
SendMsg(m, &ma, true);
MsgPointer mp;
CopyFromV(&mp, m.va, m.len);
- LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va
- << std::dec;
+ LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va << std::dec;
CHECK_NE(mp.va, 0);
return va2p(mp.va);
}
void HexagonSimulator::FreeVtcm(void* ptr) {}
-void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src,
- unsigned len) {
- LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst
- << ", src=" << src << ", len=" << std::dec << len << ')';
+void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, unsigned len) {
+ LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst << ", src=" << src
+ << ", len=" << std::dec << len << ')';
CHECK(dst != nullptr && src != nullptr);
Message m = {kCopy, sizeof(MsgCopy), 0u};
MsgCopy mc = {p2va(dst), p2va(src), len};
SendMsg(m, &mc, true);
}
-void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src,
- unsigned len) {
- LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst
- << ", src=" << src << ", len=" << len << ')';
+void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) {
+ LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst << ", src=" << src
+ << ", len=" << len << ')';
if (task_queuing_) {
Message mf = {kFlush, 0, 0};
SendMsg(mf, 0, true);
CopyFromV(host_dst, p2va(src), len);
}
-void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src,
- unsigned len) {
- LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst
- << ", host_src=" << host_src << ", len=" << len << ')';
+void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, unsigned len) {
+ LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst << ", host_src=" << host_src
+ << ", len=" << len << ')';
CopyToV(p2va(dst), host_src, len);
}
MsgPointer mp;
CopyFromV(&mp, m.va, sizeof(mp));
- LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va
- << std::dec;
+ LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va << std::dec;
return va2p(mp.va);
}
-void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num,
- uint32_t* stack, unsigned st_num) {
- LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func
- << ", scalar=" << scalar << ", sc_num=" << std::dec
+void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack,
+ unsigned st_num) {
+ LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func << ", scalar=" << scalar
+ << ", sc_num=" << std::dec
<< sc_num
// NOLINTNEXTLINE(build/include_what_you_use)
- << ", stack=" << std::hex << stack << ", st_num=" << std::dec
- << st_num;
+ << ", stack=" << std::hex << stack << ", st_num=" << std::dec << st_num;
std::vector<uint32_t> data;
log_data << std::dec << " }" << std::flush;
LOG(INFO) << log_data.str();
- Message m = {kCall, static_cast<uint32_t>(data.size() * sizeof(uint32_t)),
- 0u};
+ Message m = {kCall, static_cast<uint32_t>(data.size() * sizeof(uint32_t)), 0u};
SendMsg(m, data.data(), true);
if (!task_queuing_) {
std::ostringstream log_rv;
log_rv << "HexagonSimulator::Call -> {" << std::hex;
for (unsigned i = 0, e = std::min<unsigned>(rv.size(), 4u); i != e; ++i) {
- log_rv << ' ' << std::setw(2) << std::setfill('0')
- << static_cast<uint32_t>(rv[i]);
+ log_rv << ' ' << std::setw(2) << std::setfill('0') << static_cast<uint32_t>(rv[i]);
}
if (rv.size() > 4) log_rv << "...";
log_rv << std::dec << " }";
}
bool HexagonSimulator::HandlePCFilter(string_list& rest) {
- auto range =
- detail::to_range<uint64_t, detail::to_uint>(detail::pop_front(rest));
+ auto range = detail::to_range<uint64_t, detail::to_uint>(detail::pop_front(rest));
if (range) {
CHECKED_CALL(ConfigurePCRangeFilter, range->first, range->second);
}
}
bool HexagonSimulator::HandleTimeFilterNS(string_list& rest) {
- auto range =
- detail::to_range<uint64_t, detail::to_uint>(detail::pop_front(rest));
+ auto range = detail::to_range<uint64_t, detail::to_uint>(detail::pop_front(rest));
if (range) {
- CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC,
- range->second, HEX_NANOSEC);
+ CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, range->second, HEX_NANOSEC);
}
return static_cast<bool>(range);
}
return false;
}
-llvm::Optional<HEXAPI_Interval> HexagonSimulator::to_interval(
- const detail::MaybeString& str) {
+llvm::Optional<HEXAPI_Interval> HexagonSimulator::to_interval(const detail::MaybeString& str) {
auto none = llvm::Optional<HEXAPI_Interval>();
if (!str) return none;
.Default(none);
}
-llvm::Optional<HEXAPI_TimingMode> HexagonSimulator::to_timingmode(
- const detail::MaybeString& str) {
+llvm::Optional<HEXAPI_TimingMode> HexagonSimulator::to_timingmode(const detail::MaybeString& str) {
auto none = llvm::Optional<HEXAPI_TimingMode>();
if (!str) return none;
.Default(none);
}
-llvm::Optional<HEXAPI_Nullptr> HexagonSimulator::to_nullptr(
- const detail::MaybeString& str) {
+llvm::Optional<HEXAPI_Nullptr> HexagonSimulator::to_nullptr(const detail::MaybeString& str) {
auto none = llvm::Optional<HEXAPI_Nullptr>();
if (!str) return none;
// Stub functions for targets that don't support VTCM.
static void* HAP_request_VTCM(int a, int b) { return 0; }
static int HAP_release_VTCM(void* a) { return 0; }
-static int HAP_query_avail_VTCM(unsigned* avail_block_size,
- unsigned* max_page_size, unsigned* num_pages) {
+static int HAP_query_avail_VTCM(unsigned* avail_block_size, unsigned* max_page_size,
+ unsigned* num_pages) {
FARF(ALWAYS, "%s: running on architecture V62 or less", __func__);
return AEE_ENOMEMORY;
}
return rc;
}
- *handle_ptr =
- static_cast<remote_handle64>(reinterpret_cast<uintptr_t>(malloc(1)));
+ *handle_ptr = static_cast<remote_handle64>(reinterpret_cast<uintptr_t>(malloc(1)));
if (!*handle_ptr) {
FARF(ERROR, "%s: cannot allocate memory", __func__);
return AEE_ENOMEMORY;
* This function is present as a workaround. See comment at the call site
* in hexagon_device_target.cc.
*/
-int tvm_remote_call_mmap64(remote_handle64 handle) {
- return AEE_SUCCESS;
-}
+int tvm_remote_call_mmap64(remote_handle64 handle) { return AEE_SUCCESS; }
/*!
* \brief Load a shared library.
*
* \return 0 on success, negative value on error.
*/
-int tvm_remote_load_library(remote_handle64 handle, const char* soname,
- int soname_len, tvm_remote_handle_t* lib_ptr) {
+int tvm_remote_load_library(remote_handle64 handle, const char* soname, int soname_len,
+ tvm_remote_handle_t* lib_ptr) {
return tvm_remote_nd_load_library(soname, soname_len, lib_ptr);
}
*
* \return 0 on success, negative value on error.
*/
-int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib,
- const char* name, int name_len,
- tvm_remote_handle_t* sym_ptr) {
+int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib, const char* name,
+ int name_len, tvm_remote_handle_t* sym_ptr) {
return tvm_remote_nd_get_symbol(lib, name, name_len, sym_ptr);
}
* The 8 "octet" arguments in this function are used for cache operations
* only. They are not used for procesing.
*/
-int tvm_remote_kernel(
- remote_handle64 handle, tvm_remote_handle_t lib,
- tvm_remote_handle_t symbol, const int* scalar, int scalar_len,
- const int* stack, int stack_len, const tvm_remote_buffer* scalar_in_octet,
- int scalar_in_octet_len, tvm_remote_buffer* scalar_out_octet,
- int scalar_out_octet_len, const tvm_remote_buffer* stack_in_octet,
- int stack_in_octet_len, tvm_remote_buffer* stack_out_octet,
- int stack_out_octet_len, uint64* pcycles, uint64* time_usec) {
+int tvm_remote_kernel(remote_handle64 handle, tvm_remote_handle_t lib, tvm_remote_handle_t symbol,
+ const int* scalar, int scalar_len, const int* stack, int stack_len,
+ const tvm_remote_buffer* scalar_in_octet, int scalar_in_octet_len,
+ tvm_remote_buffer* scalar_out_octet, int scalar_out_octet_len,
+ const tvm_remote_buffer* stack_in_octet, int stack_in_octet_len,
+ tvm_remote_buffer* stack_out_octet, int stack_out_octet_len, uint64* pcycles,
+ uint64* time_usec) {
return tvm_remote_nd_kernel(
lib, symbol, scalar, scalar_len, stack, stack_len,
- reinterpret_cast<const tvm_remote_nd_buffer*>(scalar_in_octet),
- scalar_in_octet_len,
- reinterpret_cast<tvm_remote_nd_buffer*>(scalar_out_octet),
- scalar_out_octet_len,
- reinterpret_cast<const tvm_remote_nd_buffer*>(stack_in_octet),
- stack_in_octet_len,
- reinterpret_cast<tvm_remote_nd_buffer*>(stack_out_octet),
- stack_out_octet_len, pcycles, time_usec);
+ reinterpret_cast<const tvm_remote_nd_buffer*>(scalar_in_octet), scalar_in_octet_len,
+ reinterpret_cast<tvm_remote_nd_buffer*>(scalar_out_octet), scalar_out_octet_len,
+ reinterpret_cast<const tvm_remote_nd_buffer*>(stack_in_octet), stack_in_octet_len,
+ reinterpret_cast<tvm_remote_nd_buffer*>(stack_out_octet), stack_out_octet_len, pcycles,
+ time_usec);
}
/*!
*
* \return 0 on success, negative value on error.
*/
-int tvm_remote_release_library(remote_handle64 handle,
- tvm_remote_handle_t lib) {
+int tvm_remote_release_library(remote_handle64 handle, tvm_remote_handle_t lib) {
// FARF(ALWAYS, "tvm_remote_release_library begin ");
return tvm_remote_nd_release_library(lib);
}
*
* \return 0 on success, negative value on error.
*/
-int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size,
- unsigned align, unsigned* dsp_va) {
+int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, unsigned align, unsigned* dsp_va) {
FARF(ALWAYS, "%s: size=%u, align=%u", __func__, size, align);
unsigned avail_block_size, max_page_size, num_pages;
int rc = HAP_query_avail_VTCM(&avail_block_size, &max_page_size, &num_pages);
FARF(ERROR, "%s: HAP_query_avail_VTCM failed, rc=%08x", __func__, rc);
return rc;
}
- FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u",
- __func__, avail_block_size, max_page_size, num_pages);
+ FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u", __func__,
+ avail_block_size, max_page_size, num_pages);
if (max_page_size < MIN_VTCM_SZ) {
- FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__,
- MIN_VTCM_SZ / 1024);
+ FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__, MIN_VTCM_SZ / 1024);
return AEE_ENOMEMORY;
}
uint32_t data[];
} __attribute__((packed));
-__attribute__((naked)) uint32_t launcher(volatile msg_call* mc,
- uint64_t* pcc) {
+__attribute__((naked)) uint32_t launcher(volatile msg_call* mc, uint64_t* pcc) {
__asm__(
"// This function is intentionally written to be readable, \n"
"// rather than fast. \n"
extern "C" {
#pragma weak __wrap_pthread_create
-int __wrap_pthread_create(pthread_t* restrict thread,
- const pthread_attr_t* restrict attr,
+int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr,
void* (*start)(void*), void* restrict arg) {
FARF(ERROR, "Wrong %s called", __func__);
abort();
int tvm_remote_nd_open() {
lib_thread = dlopen("libtvm_wrap_pthread.so", RTLD_NOW | RTLD_GLOBAL);
if (lib_thread == nullptr) {
- FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__,
- dlerror());
+ FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__, dlerror());
return AEE_EUNABLETOLOAD;
}
lib_rt = dlopen("libtvm_runtime.so", RTLD_NOW | RTLD_GLOBAL);
if (lib_rt == nullptr) {
- FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__,
- dlerror());
+ FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__, dlerror());
return AEE_EUNABLETOLOAD;
}
return AEE_SUCCESS;
* This function is present as a workaround. See comment at the call site
* in hexagon_device_target.cc.
*/
-int tvm_remote_nd_call_mmap64() {
- return AEE_SUCCESS;
-}
+int tvm_remote_nd_call_mmap64() { return AEE_SUCCESS; }
/*!
* \brief Load a shared library.
*
* \return 0 on success, negative value on error.
*/
-int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name,
- int name_len, tvm_remote_nd_handle_t* sym_ptr) {
+int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, int name_len,
+ tvm_remote_nd_handle_t* sym_ptr) {
FARF(ALWAYS, "%s: name=%s", __func__, name);
if (void* p = dlsym(reinterpret_cast<void*>(lib), name)) {
*sym_ptr = reinterpret_cast<tvm_remote_nd_handle_t>(p);
}
static void print_msg_call(const msg_call& mc) {
- FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va,
- mc.scalar_num, mc.stack_num);
+ FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va, mc.scalar_num,
+ mc.stack_num);
for (unsigned i = 0; i != mc.scalar_num; ++i) {
FARF(ALWAYS, "scalar_data[%d] %x", i, mc.data[i]);
}
* The 8 "octet" arguments in this function are used for cache operations
* only. They are not used for procesing.
*/
-int tvm_remote_nd_kernel(
- tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol,
- const int* scalar, int scalar_len, const int* stack, int stack_len,
- const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len,
- tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len,
- const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len,
- tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len,
- uint64* pcycles, uint64* time_usec) {
+int tvm_remote_nd_kernel(tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol,
+ const int* scalar, int scalar_len, const int* stack, int stack_len,
+ const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len,
+ tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len,
+ const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len,
+ tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len,
+ uint64* pcycles, uint64* time_usec) {
hvx::config_t hvx_info = {0};
hvx::prepare_mt_job(&hvx_info);
if (hvx_info.num_reserved > 0) {
lock_result = hvx::lock(hvx::MODE_128B);
if (lock_result < 0) {
- FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d",
- __func__, lock_result, hvx_info.num_reserved);
+ FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d", __func__, lock_result,
+ hvx_info.num_reserved);
} else {
- FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__,
- lock_result);
+ FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__, lock_result);
}
} else {
FARF(ERROR, "%s: there are no HVX units available", __func__);
}
- struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) *
- (3 + scalar_len + stack_len));
+ struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) * (3 + scalar_len + stack_len));
if (mc == nullptr) {
FARF(ERROR, "%s: failed to allocate memory for mc", __func__);
return AEE_ENOMEMORY;
uint64_t start_time = HAP_perf_get_time_us();
int result = launcher(mc, pcycles);
*time_usec = HAP_perf_get_time_us() - start_time;
- FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles,
- *time_usec);
+ FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles, *time_usec);
if (lock_result > 0) hvx::unlock();
hvx::cleanup_mt_job(&hvx_info);
if (mc) free(mc);
// Make sure the function has C linkage.
extern "C" {
-int __wrap_pthread_create(pthread_t* restrict thread,
- const pthread_attr_t* restrict attr,
+int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr,
void* (*start)(void*), void* restrict arg);
}
-int __wrap_pthread_create(pthread_t* restrict thread,
- const pthread_attr_t* restrict attr,
+int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr,
void* (*start)(void*), void* restrict arg) {
pthread_attr_t def_attr;
if (attr == nullptr) {
FARF(ALWAYS, "launching thread with stack_size=%zu", stack_size);
int t = pthread_create(thread, attr, start, arg);
if (int rc = pthread_attr_destroy(&def_attr)) {
- FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x",
- rc);
+ FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x", rc);
}
return t;
}
// The downside is that the format string must be given as a string literal,
// but it seems to be a minor issue.
#define VA_EXPANDER(...) , ##__VA_ARGS__
-#define TVM_LOGD_HT(fmt, ...) \
- TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__))
-#define TVM_LOGE_HT(fmt, ...) \
- TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__))
+#define TVM_LOGD_HT(fmt, ...) TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__))
+#define TVM_LOGE_HT(fmt, ...) TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__))
namespace tvm {
namespace runtime {
unsigned stack_num) final;
private:
- std::pair<void*, size_t> AddAddrMapping(const void* dsp_addr,
- void* apps_addr, size_t size);
+ std::pair<void*, size_t> AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size);
std::pair<void*, size_t> GetAppsAddr(const void* dsp_addr, bool exact) const;
void RemoveAddrMapping(const void* dsp_addr);
int OpenDomainChannel(bool set_unsigned_pd);
void* const HexagonTarget::vtcm_mark_ = reinterpret_cast<void*>(~0);
-std::shared_ptr<Device> CreateHexagonTarget() {
- return std::make_shared<HexagonTarget>();
-}
+std::shared_ptr<Device> CreateHexagonTarget() { return std::make_shared<HexagonTarget>(); }
-std::pair<void*, size_t> HexagonTarget::AddAddrMapping(const void* dsp_addr,
- void* apps_addr,
+std::pair<void*, size_t> HexagonTarget::AddAddrMapping(const void* dsp_addr, void* apps_addr,
size_t size) {
crit_section_.lock();
auto p = dsp_to_apps_.insert({dsp_addr, {apps_addr, size}});
crit_section_.unlock();
if (!p.second) {
- TVM_LOGE_HT(
- "failed to insert address mapping: dsp:%p -> apps:%p, size:%zu",
- dsp_addr, apps_addr, size);
+ TVM_LOGE_HT("failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr,
+ apps_addr, size);
return std::make_pair(nullptr, 0);
}
- TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr,
- apps_addr, size);
+ TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, apps_addr, size);
return p.first->second;
}
crit_section_.unlock();
}
-std::pair<void*, size_t> HexagonTarget::GetAppsAddr(const void* dsp_addr,
- bool exact) const {
+std::pair<void*, size_t> HexagonTarget::GetAppsAddr(const void* dsp_addr, bool exact) const {
struct AutoUnlock {
explicit AutoUnlock(std::mutex& m) : m(m) {}
~AutoUnlock() { m.unlock(); }
data.domain = CDSP_DOMAIN_ID;
int rc = rsc_ptr(DSPRPC_CONTROL_UNSIGNED_MODULE, &data, sizeof(data));
if (rc != AEE_SUCCESS) {
- TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD",
- rc);
+ TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", rc);
}
}
} else {
TVM_LOGD_HT("remote_session_control not available");
}
- int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp",
- &domain_channel_handle_);
+ int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp", &domain_channel_handle_);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT("failed to open channel rc=0x%x", rc);
} else {
crit_section_.lock();
if (module_pointer_ != AEE_EUNKNOWN) {
const StubAPI* stub_api = StubAPI::Global();
- int rc = stub_api->tvm_remote_release_library(domain_channel_handle_,
- module_pointer_);
+ int rc = stub_api->tvm_remote_release_library(domain_channel_handle_, module_pointer_);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT("failed to unload device library rc=0x%x", rc);
} else {
// thread then remote_mmap64 fails. FastRPC expects one call to be made to
// DSP before calling remote_map64. Hence this call is needed for now untill
// FastRPC comes up with a fix.
- int rc_call_mmap_64 =
- stub_api->tvm_remote_call_mmap64(domain_channel_handle_);
+ int rc_call_mmap_64 = stub_api->tvm_remote_call_mmap64(domain_channel_handle_);
if (rc_call_mmap_64 != AEE_SUCCESS) {
- TVM_LOGE_HT("mmap64 failed for domain channel %lu",
- domain_channel_handle_);
+ TVM_LOGE_HT("mmap64 failed for domain channel %lu", domain_channel_handle_);
return nullptr;
}
- void* mem =
- stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size);
+ void* mem = stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size);
if (mem == nullptr) {
TVM_LOGE_HT("mem alloc failed for size=0x%x alignment=0x%x", size, align);
return nullptr;
}
int mem_fd = stub_api->rpcmem_to_fd_ptr()(mem);
uintptr_t dsp_va = 0;
- int rc = dsp_api->remote_mmap64_ptr()(
- mem_fd, 0, reinterpret_cast<uintptr_t>(mem), size, &dsp_va);
+ int rc = dsp_api->remote_mmap64_ptr()(mem_fd, 0, reinterpret_cast<uintptr_t>(mem), size, &dsp_va);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT(
"buffer mapping failed for remote_map64 fd=0x%x rc=0x%x "
auto aa = GetAppsAddr(ptr, true);
if (aa.first == nullptr) return;
- int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast<uintptr_t>(ptr),
- aa.second);
+ int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast<uintptr_t>(ptr), aa.second);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT("buffer unmapping failed rc=0x%x", rc);
}
const StubAPI* stub_api = StubAPI::Global();
unsigned int dsp_va = 0;
- int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align,
- &dsp_va);
+ int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align, &dsp_va);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT("VTCM allocation failed size=%u, align=%u", size, align);
return nullptr;
TVM_LOGD_HT("Done VTCM free from HexagonTarget::FreeVtcm");
}
-void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src,
- unsigned len) {
+void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, unsigned len) {
auto aa_src = GetAppsAddr(src, false);
auto aa_dst = GetAppsAddr(dst, false);
if (aa_src.first == vtcm_mark_ || aa_dst.first == vtcm_mark_) {
len, aa_dst.second);
}
len = std::min({size_t(len), aa_src.second, aa_dst.second});
- TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src,
- aa_src.first, dst, aa_dst.first, len);
+ TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src, aa_src.first, dst,
+ aa_dst.first, len);
std::memcpy(aa_dst.first, aa_src.first, len);
}
-void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src,
- unsigned len) {
+void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) {
auto aa = GetAppsAddr(src, false);
if (aa.first == vtcm_mark_) {
TVM_LOGE_HT("VTCM address. Copy operation not supported");
return;
}
if (aa.second < len) {
- TVM_LOGD_HT(
- "specified length:%u larger than buffer size:%zu, copy truncated", len,
- aa.second);
+ TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second);
len = aa.second;
}
- TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first,
- host_dst, len);
+ TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, host_dst, len);
std::memcpy(host_dst, aa.first, len);
}
-void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src,
- unsigned len) {
+void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, unsigned len) {
auto aa = GetAppsAddr(dst, false);
if (aa.first == vtcm_mark_) {
TVM_LOGE_HT("VTCM address. Copy operation not supported");
return;
}
if (aa.second < len) {
- TVM_LOGD_HT(
- "specified length:%u larger than buffer size:%zu, copy truncated", len,
- aa.second);
+ TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second);
len = aa.second;
}
- TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first,
- host_src, len);
+ TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, host_src, len);
std::memcpy(aa.first, host_src, len);
}
int rc_oc = OpenDomainChannel(/*use_unsigned_pd*/ unsigned_pd);
crit_section_.unlock();
if (rc_oc != AEE_SUCCESS) {
- TVM_LOGE_HT("loading of %s failed: unable to open domain channel",
- data.c_str());
+ TVM_LOGE_HT("loading of %s failed: unable to open domain channel", data.c_str());
return nullptr;
}
crit_section_.lock();
TVM_LOGD_HT("loading library %s ", data.c_str());
const StubAPI* stub_api = StubAPI::Global();
- int rc = stub_api->tvm_remote_load_library(
- domain_channel_handle_, data.c_str(), data.size() + 1, &module_pointer_);
+ int rc = stub_api->tvm_remote_load_library(domain_channel_handle_, data.c_str(), data.size() + 1,
+ &module_pointer_);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT("failed to load device library rc=0x%x", rc);
}
tvm_remote_handle_t pf;
TVM_LOGD_HT("resolving symbol %s", sym.c_str());
- int rc =
- stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_,
- sym.c_str(), sym.size() + 1, &pf);
+ int rc = stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_, sym.c_str(),
+ sym.size() + 1, &pf);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT("failed to get symbol from CDSP rc=0x%x", rc);
return nullptr;
return addr;
}
-void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num,
- uint32_t* stack, unsigned stack_num) {
+void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, uint32_t* stack,
+ unsigned stack_num) {
uint64 pcycles = 0, execution_time_usec = 0;
- auto scalar_octet =
- std::unique_ptr<tvm_remote_buffer[]>(new tvm_remote_buffer[scalar_num]);
- auto stack_octet =
- std::unique_ptr<tvm_remote_buffer[]>(new tvm_remote_buffer[stack_num]);
+ auto scalar_octet = std::unique_ptr<tvm_remote_buffer[]>(new tvm_remote_buffer[scalar_num]);
+ auto stack_octet = std::unique_ptr<tvm_remote_buffer[]>(new tvm_remote_buffer[stack_num]);
TVM_LOGD_HT("scalars=%p, stack=%p", scalar, stack);
if (scalar_octet == nullptr || stack_octet == nullptr) {
std::memset(scalar_octet.get(), 0, scalar_num * sizeof(tvm_remote_buffer));
std::memset(stack_octet.get(), 0, stack_num * sizeof(tvm_remote_buffer));
- auto ProcessInputs = [this](uint32_t* inputs, tvm_remote_buffer* buffers,
- unsigned num) {
+ auto ProcessInputs = [this](uint32_t* inputs, tvm_remote_buffer* buffers, unsigned num) {
for (unsigned i = 0; i != num; ++i) {
void* ptr = reinterpret_cast<void*>(static_cast<uintptr_t>(inputs[i]));
auto aa = GetAppsAddr(ptr, false);
int rc = stub_api->tvm_remote_kernel(
domain_channel_handle_, module_pointer_,
static_cast<tvm_remote_handle_t>(reinterpret_cast<uintptr_t>(func)),
- reinterpret_cast<int*>(scalar), scalar_num,
- reinterpret_cast<int*>(stack), stack_num, scalar_octet.get(), scalar_num,
- scalar_octet.get(), scalar_num, stack_octet.get(), stack_num,
+ reinterpret_cast<int*>(scalar), scalar_num, reinterpret_cast<int*>(stack), stack_num,
+ scalar_octet.get(), scalar_num, scalar_octet.get(), scalar_num, stack_octet.get(), stack_num,
stack_octet.get(), stack_num, &pcycles, &execution_time_usec);
if (rc != AEE_SUCCESS) {
TVM_LOGE_HT("failed to run kernel on CDSP rc=0x%x", rc);
} else {
- TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d",
- pcycles, execution_time_usec, scalar_num);
+ TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", pcycles,
+ execution_time_usec, scalar_num);
}
}
constexpr auto domain_lib_name = "libtvm_remote_stub.so";
constexpr auto nondomain_lib_name = "libtvm_remote_nd_stub.so";
- const char* lib_name =
- enable_domains_ ? domain_lib_name : nondomain_lib_name;
+ const char* lib_name = enable_domains_ ? domain_lib_name : nondomain_lib_name;
CHECK(lib_handle_ = dlopen(lib_name, RTLD_LAZY | RTLD_LOCAL));
#define RESOLVE(fn) p##fn##_ = GetSymbol<fn##_t*>(#fn)
// two types identical in the function types created below.
// For example, int foo(tvm_remote_buffer*) and
// int bar(tvm_remote_nd_buffer*) should both have the same type.
-#define MAPTYPE(fn, ty) \
- using fn##_t = typename map_func_type<ty, void, decltype(::fn)>::type;
+#define MAPTYPE(fn, ty) using fn##_t = typename map_func_type<ty, void, decltype(::fn)>::type;
MAPTYPE(tvm_remote_load_library, tvm_remote_buffer)
MAPTYPE(tvm_remote_release_library, tvm_remote_buffer)
MAPTYPE(tvm_remote_get_symbol, tvm_remote_buffer)
public:
template <typename Fd, typename Fnd, typename... Ts>
- int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle,
- Ts... args) const {
+ int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, Ts... args) const {
if (enable_domains_) {
return func_d(handle, args...);
}
#define FUNC_ND(name) CONCAT_STR(tvm_remote_nd_, name)
#define PTRNAME(fn) CONCAT_STR(p, CONCAT_STR(fn, _))
-#define DECLFUNC(name) \
- template <typename... Ts> \
- int FUNC(name)(remote_handle64 handle, Ts... args) const { \
- return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, \
- args...); \
+#define DECLFUNC(name) \
+ template <typename... Ts> \
+ int FUNC(name)(remote_handle64 handle, Ts... args) const { \
+ return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, args...); \
}
#define DECLFUNC_D(name) \
#include <android/log.h>
-#define TVM_LOGV(...) \
- __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__)
-#define TVM_LOGD(...) \
- __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__)
-#define TVM_LOGI(...) \
- __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__)
-#define TVM_LOGW(...) \
- __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__)
-#define TVM_LOGE(...) \
- __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__)
-#define TVM_LOGF(...) \
- __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__)
+#define TVM_LOGV(...) __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__)
+#define TVM_LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__)
+#define TVM_LOGI(...) __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__)
+#define TVM_LOGW(...) __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__)
+#define TVM_LOGE(...) __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__)
+#define TVM_LOGF(...) __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__)
#endif // __ANDROID__
#endif // TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_TARGET_LOG_H_
* \file module_util.cc
* \brief Utilities for module.
*/
+#include "library_module.h"
+
#include <dmlc/memory_io.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
+
#include <string>
-#include <vector>
#include <utility>
-#include "library_module.h"
+#include <vector>
namespace tvm {
namespace runtime {
// Library module that exposes symbols from a library.
class LibraryModuleNode final : public ModuleNode {
public:
- explicit LibraryModuleNode(ObjectPtr<Library> lib)
- : lib_(lib) {
- }
+ explicit LibraryModuleNode(ObjectPtr<Library> lib) : lib_(lib) {}
- const char* type_key() const final {
- return "library";
- }
+ const char* type_key() const final { return "library"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
TVMBackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
- const char* entry_name = reinterpret_cast<const char*>(
- lib_->GetSymbol(runtime::symbol::tvm_module_main));
- CHECK(entry_name!= nullptr)
+ const char* entry_name =
+ reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
+ CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(entry_name));
} else {
class ModuleInternal {
public:
// Get mutable reference of imports.
- static std::vector<Module>* GetImportsAddr(ModuleNode* node) {
- return &(node->imports_);
- }
+ static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
};
-PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
- TVMValue ret_value;
- int ret_type_code = kTVMNullptr;
- int ret = (*faddr)(
- const_cast<TVMValue*>(args.values),
- const_cast<int*>(args.type_codes),
- args.num_args,
- &ret_value,
- &ret_type_code);
- CHECK_EQ(ret, 0) << TVMGetLastError();
- if (ret_type_code != kTVMNullptr) {
- *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
- }
- });
+ TVMValue ret_value;
+ int ret_type_code = kTVMNullptr;
+ int ret = (*faddr)(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
+ args.num_args, &ret_value, &ret_type_code);
+ CHECK_EQ(ret, 0) << TVMGetLastError();
+ if (ret_type_code != kTVMNullptr) {
+ *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
+ }
+ });
}
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
- #define TVM_INIT_CONTEXT_FUNC(FuncName) \
- if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \
- (fgetsymbol("__" #FuncName))) { \
- *fp = FuncName; \
- }
+#define TVM_INIT_CONTEXT_FUNC(FuncName) \
+ if (auto* fp = reinterpret_cast<decltype(&FuncName)*>(fgetsymbol("__" #FuncName))) { \
+ *fp = FuncName; \
+ }
// Initialize the functions
TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
- #undef TVM_INIT_CONTEXT_FUNC
+#undef TVM_INIT_CONTEXT_FUNC
}
/*!
uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
uint64_t c = mblob[i];
- nbytes |= (c & 0xffUL) << (i * 8);
+ nbytes |= (c & 0xffUL) << (i * 8);
}
- dmlc::MemoryFixedSizeStream fs(
- const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes));
+ dmlc::MemoryFixedSizeStream fs(const_cast<char*>(mblob + sizeof(nbytes)),
+ static_cast<size_t>(nbytes));
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
} else {
std::string fkey = "runtime.module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
- CHECK(f != nullptr)
- << "Loader of " << tkey << "("
- << fkey << ") is not presented.";
+ CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
modules.emplace_back(m);
}
}
Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
- InitContextFunctions([lib](const char* fname) {
- return lib->GetSymbol(fname);
- });
+ InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); });
auto n = make_object<LibraryModuleNode>(lib);
// Load the imported modules
const char* dev_mblob =
- reinterpret_cast<const char*>(
- lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
+ reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
Module root_mod;
if (dev_mblob != nullptr) {
root_mod = ProcessModuleBlob(dev_mblob, lib);
}
// allow lookup of symbol from root (so all symbols are visible).
- if (auto *ctx_addr =
- reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
+ if (auto* ctx_addr = reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = root_mod.operator->();
}
#ifndef TVM_RUNTIME_LIBRARY_MODULE_H_
#define TVM_RUNTIME_LIBRARY_MODULE_H_
-#include <tvm/runtime/module.h>
-#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/module.h>
+
#include <functional>
namespace tvm {
* \param name The name of the symbol.
* \return The symbol.
*/
- virtual void *GetSymbol(const char* name) = 0;
+ virtual void* GetSymbol(const char* name) = 0;
// NOTE: we do not explicitly create an type index and type_key here for libary.
// This is because we do not need dynamic type downcasting.
};
Module CreateModuleFromLibrary(ObjectPtr<Library> lib);
} // namespace runtime
} // namespace tvm
-#endif // TVM_RUNTIME_LIBRARY_MODULE_H_
+#endif // TVM_RUNTIME_LIBRARY_MODULE_H_
#ifndef TVM_RUNTIME_META_DATA_H_
#define TVM_RUNTIME_META_DATA_H_
-#include <dmlc/json.h>
#include <dmlc/io.h>
+#include <dmlc/json.h>
#include <tvm/runtime/packed_func.h>
+
#include <string>
#include <vector>
+
#include "runtime_base.h"
namespace tvm {
std::vector<DLDataType> arg_types;
std::vector<std::string> thread_axis_tags;
- void Save(dmlc::JSONWriter *writer) const;
- void Load(dmlc::JSONReader *reader);
- void Save(dmlc::Stream *writer) const;
- bool Load(dmlc::Stream *reader);
+ void Save(dmlc::JSONWriter* writer) const;
+ void Load(dmlc::JSONReader* reader);
+ void Save(dmlc::Stream* writer) const;
+ bool Load(dmlc::Stream* reader);
};
} // namespace runtime
} // namespace tvm
#ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_
#define TVM_RUNTIME_METAL_METAL_COMMON_H_
+#import <Metal/MTLBlitCommandEncoder.h>
#import <Metal/MTLBuffer.h>
-#import <Metal/MTLCommandQueue.h>
#import <Metal/MTLCommandBuffer.h>
-#import <Metal/MTLBlitCommandEncoder.h>
+#import <Metal/MTLCommandQueue.h>
#import <Metal/MTLDevice.h>
#import <Metal/MTLLibrary.h>
-
+#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
-#include <dmlc/logging.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <memory>
#include <mutex>
#include <string>
#include <vector>
-#include <memory>
+
#include "../workspace_pool.h"
namespace tvm {
// Get command queue for given context.
id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
CHECK_EQ(ctx.device_type, kDLMetal);
- CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
+ CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
<< "Invalid Metal device_id=" << ctx.device_id;
return queues[ctx.device_id];
}
// Get device for given context
id<MTLDevice> GetDevice(TVMContext ctx) {
CHECK_EQ(ctx.device_type, kDLMetal);
- CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
+ CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
<< "Invalid Metal device_id=" << ctx.device_id;
return devices[ctx.device_id];
}
// override device API
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
- DLDataType type_hint) final;
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
- void CopyDataFromTo(const void* from,
- size_t from_size,
- void* to,
- size_t to_size,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
/*! \brief workspace pool */
WorkspacePool pool;
// constructor
- MetalThreadEntry()
- : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
+ MetalThreadEntry() : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kDLMetal);
}
/*!
* \file metal_device_api.mm
*/
-#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
+#include <tvm/runtime/registry.h>
#include "metal_common.h"
namespace tvm {
namespace metal {
const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
- static std::shared_ptr<MetalWorkspace> inst =
- std::make_shared<MetalWorkspace>();
+ static std::shared_ptr<MetalWorkspace> inst = std::make_shared<MetalWorkspace>();
return inst;
}
-void MetalWorkspace::GetAttr(
- TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
+void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
- *rv = int(index< devices.size());
+ *rv = int(index < devices.size());
return;
}
- CHECK_LT(index, devices.size())
- << "Invalid device id " << index;
+ CHECK_LT(index, devices.size()) << "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
- *rv = static_cast<int>(
- [devices[ctx.device_id] maxThreadsPerThreadgroup].width);
+ *rv = static_cast<int>([devices[ctx.device_id] maxThreadsPerThreadgroup].width);
break;
}
case kWarpSize: {
*rv = 1;
break;
}
- case kMaxSharedMemoryPerBlock: return;
- case kComputeVersion: return;
- case kDeviceName: return;
- case kMaxClockRate: return;
- case kMultiProcessorCount: return;
- case kMaxThreadDimensions: return;
- case kExist: break;
- case kGcnArch: return;
+ case kMaxSharedMemoryPerBlock:
+ return;
+ case kComputeVersion:
+ return;
+ case kDeviceName:
+ return;
+ case kMaxClockRate:
+ return;
+ case kMultiProcessorCount:
+ return;
+ case kMaxThreadDimensions:
+ return;
+ case kExist:
+ break;
+ case kGcnArch:
+ return;
}
}
// But we keep this code.
int GetWarpSize(id<MTLDevice> dev) {
NSError* error_msg = nil;
- id<MTLLibrary> lib =
- [dev
- newLibraryWithSource:
- [NSString stringWithUTF8String:kDummyKernel]
- options:nil
- error:&error_msg];
+ id<MTLLibrary> lib = [dev newLibraryWithSource:[NSString stringWithUTF8String:kDummyKernel]
+ options:nil
+ error:&error_msg];
CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
- id<MTLFunction> f =
- [lib
- newFunctionWithName:
- [NSString stringWithUTF8String:"CopyKernel"]];
- CHECK(f!= nil);
- id<MTLComputePipelineState> state =
- [dev
- newComputePipelineStateWithFunction:f
- error:&error_msg];
+ id<MTLFunction> f = [lib newFunctionWithName:[NSString stringWithUTF8String:"CopyKernel"]];
+ CHECK(f != nil);
+ id<MTLComputePipelineState> state = [dev newComputePipelineStateWithFunction:f error:&error_msg];
CHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
return static_cast<int>(state.threadExecutionWidth);
}
initialized_ = true;
if (devices.size() != 0) return;
#if TARGET_OS_IPHONE
- // on iPhone
- id<MTLDevice> d = MTLCreateSystemDefaultDevice();
+ // on iPhone
+ id<MTLDevice> d = MTLCreateSystemDefaultDevice();
+ devices.push_back([d retain]);
+ queues.push_back([[d newCommandQueue] retain]);
+#else
+ NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
+ for (size_t i = 0; i < devs.count; ++i) {
+ id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back([d retain]);
queues.push_back([[d newCommandQueue] retain]);
-#else
- NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
- for (size_t i = 0; i < devs.count; ++i) {
- id<MTLDevice> d = [devs objectAtIndex:i];
- devices.push_back([d retain]);
- queues.push_back([[d newCommandQueue] retain]);
- LOG(INFO) << "Intializing Metal device " << i
- << ", name=" << [d.name UTF8String];
- warp_size.push_back(GetWarpSize(d));
- }
+ LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
+ warp_size.push_back(GetWarpSize(d));
+ }
#endif
}
MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
}
-void* MetalWorkspace::AllocDataSpace(
- TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) {
+void* MetalWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
+ DLDataType type_hint) {
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// GPU memory only
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
- id<MTLBuffer> buf = [
- dev newBufferWithLength:nbytes
- options:storage_mode];
+ id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
CHECK(buf != nil);
return (__bridge void*)([buf retain]);
}
CFRelease(ptr);
}
-void MetalWorkspace::CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
+ size_t to_offset, size_t size, TVMContext ctx_from,
+ TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) {
this->Init();
CHECK(stream == nullptr);
int to_dev_type = static_cast<int>(ctx_to.device_type);
if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
- CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
- << "Metal disallow cross device copy.";
+ CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy.";
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
- sourceOffset:from_offset
- toBuffer:(__bridge id<MTLBuffer>)(to)
- destinationOffset:to_offset
- size:size];
+ sourceOffset:from_offset
+ toBuffer:(__bridge id<MTLBuffer>)(to)destinationOffset:to_offset
+ size:size];
[encoder endEncoding];
[cb commit];
} else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
// copy to a local buffer before get into global buffer.
id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
- id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
- ->GetTempBuffer(ctx_from, size);
+ id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:from_buf
- sourceOffset:from_offset
- toBuffer:temp
- destinationOffset:0
- size:size];
+ sourceOffset:from_offset
+ toBuffer:temp
+ destinationOffset:0
+ size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
- memcpy(static_cast<char*>(to) + to_offset,
- static_cast<char*>([temp contents]),
- size);
+ memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>([temp contents]), size);
} else {
memcpy(static_cast<char*>(to) + to_offset,
- static_cast<char*>([from_buf contents]) + from_offset,
- size);
+ static_cast<char*>([from_buf contents]) + from_offset, size);
}
} else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
- id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
- ->GetTempBuffer(ctx_to, size);
- memcpy([temp contents],
- static_cast<const char*>(from) + from_offset,
- size);
+ id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size);
+ memcpy([temp contents], static_cast<const char*>(from) + from_offset, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:temp
- sourceOffset:0
- toBuffer:to_buf
- destinationOffset:to_offset
- size:size];
+ sourceOffset:0
+ toBuffer:to_buf
+ destinationOffset:to_offset
+ size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
} else {
memcpy(static_cast<char*>([to_buf contents]) + to_offset,
- static_cast<const char*>(from) + from_offset,
- size);
+ static_cast<const char*>(from) + from_offset, size);
}
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
- << ", from=" << from_dev_type
- << ", to=" << to_dev_type;
+ << ", from=" << from_dev_type << ", to=" << to_dev_type;
}
}
[cb waitUntilCompleted];
}
-void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
- size_t size,
- DLDataType type_hint) {
+void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
temp_buffer_.resize(ctx.device_id + 1, nil);
}
- if (temp_buffer_[ctx.device_id] == nil ||
- temp_buffer_[ctx.device_id].length < size) {
+ if (temp_buffer_[ctx.device_id] == nil || temp_buffer_[ctx.device_id].length < size) {
id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
if (temp_buffer_[ctx.device_id] != nil) {
[temp_buffer_[ctx.device_id] release];
}
- temp_buffer_[ctx.device_id] = [
- [dev newBufferWithLength:size
- options:MTLStorageModeShared] retain];
+ temp_buffer_[ctx.device_id] = [[dev newBufferWithLength:size
+ options:MTLStorageModeShared] retain];
}
return temp_buffer_[ctx.device_id];
}
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
-MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
- return MetalThreadStore::Get();
-}
+MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); }
-TVM_REGISTER_GLOBAL("device_api.metal")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = MetalWorkspace::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = MetalWorkspace::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace metal
} // namespace runtime
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RUNTIME_METAL_METAL_MODULE_H_
#include <tvm/runtime/packed_func.h>
+
#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../meta_data.h"
namespace tvm {
* \param fmap The map function information map of each function.
* \param source Optional, source file
*/
-Module MetalModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source);
+Module MetalModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_
/*!
* \file metal_module.cc
*/
+#include "metal_module.h"
#include <dmlc/memory_io.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
#include <array>
-#include <string>
#include <mutex>
-#include "metal_module.h"
-#include "metal_common.h"
+#include <string>
+#include "../file_util.h"
+#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
-#include "../meta_data.h"
-#include "../file_util.h"
+#include "metal_common.h"
namespace tvm {
namespace runtime {
// Module to support thread-safe multi-GPU execution.
// The runtime will contain a per-device module table
// The modules will be lazily loaded
-class MetalModuleNode final :public runtime::ModuleNode {
+class MetalModuleNode final : public runtime::ModuleNode {
public:
- explicit MetalModuleNode(std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source)
- : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
- }
- const char* type_key() const final {
- return "metal";
- }
+ explicit MetalModuleNode(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
+ : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
+ const char* type_key() const final { return "metal"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
- CHECK_EQ(fmt, fmt_)
- << "Can only save to format=" << fmt_;
+ CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
}
// get a from primary context in device_id
- id<MTLComputePipelineState> GetPipelineState(
- size_t device_id, const std::string& func_name) {
+ id<MTLComputePipelineState> GetPipelineState(size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get();
CHECK_LT(device_id, w->devices.size());
// start lock scope.
NSError* err_msg = nil;
if (e.lib == nil) {
if (fmt_ == "metal") {
- MTLCompileOptions *opts = [MTLCompileOptions alloc];
+ MTLCompileOptions* opts = [MTLCompileOptions alloc];
// Use the Metal 1.2 for now.
opts.languageVersion = MTLLanguageVersion1_2;
opts.fastMathEnabled = YES;
// opts = nil;
- e.lib = [
- w->devices[device_id]
- newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
- options:opts
- error:&err_msg];
+ e.lib = [w->devices[device_id]
+ newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
+ options:opts
+ error:&err_msg];
[opts dealloc];
if (e.lib == nil) {
- LOG(FATAL) << "Fail to compile metal lib:"
- << [[err_msg localizedDescription] UTF8String];
+ LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
}
if (err_msg != nil) {
- LOG(INFO) << "Warning: "
- << [[err_msg localizedDescription] UTF8String];
+ LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String];
}
} else {
// Build from library.
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
- auto data = dispatch_data_create(
- data_.c_str(), data_.length(), q, ^{});
- e.lib = [
- w->devices[device_id]
- newLibraryWithData:data
- error:&err_msg];
+ auto data = dispatch_data_create(data_.c_str(), data_.length(), q,
+ ^{
+ });
+ e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg];
if (err_msg != nil || e.lib == nil) {
- LOG(FATAL) << "Fail to compile metal lib:"
- << [[err_msg localizedDescription] UTF8String];
+ LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
}
}
[e.lib retain];
}
- id<MTLFunction> f = [
- e.lib
- newFunctionWithName:
- [NSString stringWithUTF8String:func_name.c_str()]];
+ id<MTLFunction> f =
+ [e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
CHECK(f != nil) << "cannot find function " << func_name;
id<MTLComputePipelineState> state =
- [w->devices[device_id]
- newComputePipelineStateWithFunction:f
- error:&err_msg];
- CHECK(state != nil)
- << "cannot get state:" << " for function " << func_name
- << [[err_msg localizedDescription] UTF8String];
+ [w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg];
+ CHECK(state != nil) << "cannot get state:"
+ << " for function " << func_name
+ << [[err_msg localizedDescription] UTF8String];
// The state.threadExecutionWidth can change dynamically according
// to the resource constraint in kernel, so it is not strictly hold
// Turn of warp aware optimziation for now.
~DeviceEntry() {
if (lib != nil) [lib release];
- for (auto &&kv : smap) {
+ for (auto&& kv : smap) {
[kv.second release];
}
}
class MetalWrappedFunc {
public:
// initialize the METAL function.
- void Init(MetalModuleNode* m,
- ObjectPtr<Object> sptr,
- const std::string& func_name,
- size_t num_buffer_args,
- size_t num_pack_args,
+ void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
+ size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
w_ = metal::MetalWorkspace::Global().get();
m_ = m;
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
}
// invoke the function with void arguments
- void operator()(TVMArgs args,
- TVMRetValue* rv,
- const ArgUnion* pack_args) const {
+ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->context.device_id;
if (scache_[device_id] == nil) {
}
if (num_pack_args_ != 0) {
[encoder setBytes:pack_args
- length:num_pack_args_ * sizeof(ArgUnion)
- atIndex:num_buffer_args_];
+ length:num_pack_args_ * sizeof(ArgUnion)
+ atIndex:num_buffer_args_];
}
// launch
- MTLSize dimGrid = MTLSizeMake(
- wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
- MTLSize dimBlock = MTLSizeMake(
- wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
- [encoder dispatchThreadgroups: dimGrid
- threadsPerThreadgroup: dimBlock];
+ MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
+ MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
+ [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
[encoder endEncoding];
[cb commit];
}
ThreadAxisConfig thread_axis_cfg_;
};
-PackedFunc MetalModuleNode::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc MetalModuleNode::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
- CHECK_NE(name, symbol::tvm_module_main)
- << "Device function do not have main";
+ CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
MetalWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
- f.Init(this, sptr_to_self, name,
- num_buffer_args, info.arg_types.size() - num_buffer_args,
+ f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
return PackFuncNonBufferArg(f, info.arg_types);
}
-Module MetalModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module MetalModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
metal::MetalWorkspace::Global()->Init();
auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
return Module(n);
}
// Load module from module.
-Module MetalModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
return MetalModuleCreate(data, fmt, fmap, "");
}
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal")
-.set_body_typed(MetalModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile);
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal")
-.set_body_typed(MetalModuleLoadBinary);
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary);
} // namespace runtime
} // namespace tvm
#endif
#include <stdint.h>
-#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
+#include <tvm/runtime/c_runtime_api.h>
/*!
* \brief TODO
// GCC -O3 begins to inject memset and memmove calls, so we provide impls in
// the runtime for this case and for general usage.
-void *memset(void *s, int c, size_t n);
+void* memset(void* s, int c, size_t n);
-void *memmove(void *to, const void *from, size_t n);
+void* memmove(void* to, const void* from, size_t n);
#ifdef __cplusplus
} // TVM_EXTERN_C
*/
#include <sys/mman.h>
+
#include <cstring>
#include <memory>
-#include "micro_common.h"
+
#include "low_level_device.h"
+#include "micro_common.h"
namespace tvm {
namespace runtime {
int mmap_prot = PROT_READ | PROT_WRITE | PROT_EXEC;
int mmap_flags = MAP_ANONYMOUS | MAP_PRIVATE;
base_addr_ = mmap(nullptr, size_in_pages * kPageSize, mmap_prot, mmap_flags, -1, 0);
- *base_addr = TargetPtr(TargetWordSize(sizeof(size_t) * 8),
- reinterpret_cast<uint64_t>(base_addr_));
+ *base_addr =
+ TargetPtr(TargetWordSize(sizeof(size_t) * 8), reinterpret_cast<uint64_t>(base_addr_));
}
/*!
* \brief destructor to deallocate on-host device region
*/
- virtual ~HostLowLevelDevice() {
- munmap(base_addr_, size_);
- }
+ virtual ~HostLowLevelDevice() { munmap(base_addr_, size_); }
void Read(TargetPtr addr, void* buf, size_t num_bytes) {
std::memcpy(buf, addr.cast_to<void*>(), num_bytes);
reinterpret_cast<void (*)(void)>(func_addr.value().uint64())();
}
- const char* device_type() const final {
- return "host";
- }
+ const char* device_type() const final { return "host"; }
private:
/*! \brief base address of the micro device memory region */
* \param buffer on-host buffer to be read into
* \param num_bytes number of bytes to read
*/
- virtual void Read(TargetPtr addr,
- void* buffer,
- size_t num_bytes) = 0;
+ virtual void Read(TargetPtr addr, void* buffer, size_t num_bytes) = 0;
/*!
* \brief writes num_bytes from buffer to device memory at addr
* \param buffer host buffer to write from
* \param num_bytes number of bytes to write
*/
- virtual void Write(TargetPtr addr,
- const void* buffer,
- size_t num_bytes) = 0;
+ virtual void Write(TargetPtr addr, const void* buffer, size_t num_bytes) = 0;
/*!
* \brief starts execution of device at func_addr
* \brief common utilties for uTVM
*/
+#include "micro_common.h"
+
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
+
+#include <cstdint>
#include <cstdio>
-#include <string>
#include <sstream>
-#include <cstdint>
-#include "micro_session.h"
-#include "micro_common.h"
+#include <string>
+
#include "low_level_device.h"
+#include "micro_session.h"
namespace tvm {
namespace runtime {
const char* SectionToString(SectionKind section) {
switch (section) {
- case SectionKind::kText: return "text";
- case SectionKind::kRodata: return "rodata";
- case SectionKind::kData: return "data";
- case SectionKind::kBss: return "bss";
- case SectionKind::kArgs: return "args";
- case SectionKind::kHeap: return "heap";
- case SectionKind::kWorkspace: return "workspace";
- case SectionKind::kStack: return "stack";
- default: return "";
+ case SectionKind::kText:
+ return "text";
+ case SectionKind::kRodata:
+ return "rodata";
+ case SectionKind::kData:
+ return "data";
+ case SectionKind::kBss:
+ return "bss";
+ case SectionKind::kArgs:
+ return "args";
+ case SectionKind::kHeap:
+ return "heap";
+ case SectionKind::kWorkspace:
+ return "workspace";
+ case SectionKind::kStack:
+ return "stack";
+ default:
+ return "";
}
}
-std::string RelocateBinarySections(
- const std::string& binary_path,
- TargetWordSize word_size,
- TargetPtr text_start,
- TargetPtr rodata_start,
- TargetPtr data_start,
- TargetPtr bss_start,
- TargetPtr stack_end,
- const std::string& toolchain_prefix) {
+std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size,
+ TargetPtr text_start, TargetPtr rodata_start,
+ TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end,
+ const std::string& toolchain_prefix) {
const auto* f = Registry::Get("tvm_callback_relocate_binary");
- CHECK(f != nullptr)
- << "Require tvm_callback_relocate_binary to exist in registry";
- std::string relocated_bin = (*f)(binary_path,
- word_size.bytes(),
- text_start.cast_to<uint64_t>(),
- rodata_start.cast_to<uint64_t>(),
- data_start.cast_to<uint64_t>(),
- bss_start.cast_to<uint64_t>(),
- stack_end.cast_to<uint64_t>(),
- toolchain_prefix);
+ CHECK(f != nullptr) << "Require tvm_callback_relocate_binary to exist in registry";
+ std::string relocated_bin =
+ (*f)(binary_path, word_size.bytes(), text_start.cast_to<uint64_t>(),
+ rodata_start.cast_to<uint64_t>(), data_start.cast_to<uint64_t>(),
+ bss_start.cast_to<uint64_t>(), stack_end.cast_to<uint64_t>(), toolchain_prefix);
return relocated_bin;
}
-std::string ReadSection(const std::string& binary,
- SectionKind section,
+std::string ReadSection(const std::string& binary, SectionKind section,
const std::string& toolchain_prefix) {
CHECK(section == SectionKind::kText || section == SectionKind::kRodata ||
section == SectionKind::kData || section == SectionKind::kBss)
<< "ReadSection requires section to be one of text, rodata, data, or bss.";
const auto* f = Registry::Get("tvm_callback_read_binary_section");
- CHECK(f != nullptr)
- << "Require tvm_callback_read_binary_section to exist in registry";
+ CHECK(f != nullptr) << "Require tvm_callback_read_binary_section to exist in registry";
TVMByteArray arr;
arr.data = &binary[0];
arr.size = binary.length();
return section_contents;
}
-size_t GetSectionSize(const std::string& binary_path,
- SectionKind section,
- const std::string& toolchain_prefix,
- TargetWordSize word_size) {
+size_t GetSectionSize(const std::string& binary_path, SectionKind section,
+ const std::string& toolchain_prefix, TargetWordSize word_size) {
CHECK(section == SectionKind::kText || section == SectionKind::kRodata ||
section == SectionKind::kData || section == SectionKind::kBss)
<< "GetSectionSize requires section to be one of text, rodata, data, or bss.";
const auto* f = Registry::Get("tvm_callback_get_section_size");
- CHECK(f != nullptr)
- << "Require tvm_callback_get_section_size to exist in registry";
+ CHECK(f != nullptr) << "Require tvm_callback_get_section_size to exist in registry";
int size = (*f)(binary_path, SectionToString(section), toolchain_prefix);
return UpperAlignValue(size, word_size.bytes());
}
#define TVM_RUNTIME_MICRO_MICRO_COMMON_H_
#include <stdio.h>
-
#include <tvm/runtime/registry.h>
#include <sstream>
public:
explicit TargetWordSize(size_t word_size_bits) : word_size_bits_{word_size_bits} {
CHECK(word_size_bits == 32 || word_size_bits == 64)
- << "only 32-bit and 64-bit are supported now";
+ << "only 32-bit and 64-bit are supported now";
}
- size_t bytes() const {
- return word_size_bits_ / 8;
- }
+ size_t bytes() const { return word_size_bits_ / 8; }
- size_t bits() const {
- return word_size_bits_;
- }
+ size_t bits() const { return word_size_bits_; }
private:
size_t word_size_bits_;
};
-
/*! \brief class for storing values on varying target word sizes */
class TargetVal {
private:
public:
/*! \brief construct a TargetVal matching the size of the given integral argument */
- template<typename T, typename U = typename std::enable_if<std::is_integral<T>::value, T>::type>
+ template <typename T, typename U = typename std::enable_if<std::is_integral<T>::value, T>::type>
explicit constexpr TargetVal(T value) : TargetVal(sizeof(T) * 8, value) {}
/*! \brief construct an uninitialized value */
/*! \brief construct a TargetVal with explicit size and value */
TargetVal(size_t width_bits, uint64_t value) : width_bits_{width_bits} {
- CHECK(width_bits >= 8 &&
- width_bits <= 64 &&
- (width_bits & (width_bits - 1)) == 0)
- << "width_bits must be a power of 2 in [8, 64], got " << width_bits;
+ CHECK(width_bits >= 8 && width_bits <= 64 && (width_bits & (width_bits - 1)) == 0)
+ << "width_bits must be a power of 2 in [8, 64], got " << width_bits;
value_ = value & Bitmask();
}
}
CHECK(width_bits_ >= other.width_bits_)
- << "Cannot assign TargetVal with width " << other.width_bits_
- << "bits to TargetVal with width " << width_bits_ << "bits";
+ << "Cannot assign TargetVal with width " << other.width_bits_
+ << "bits to TargetVal with width " << width_bits_ << "bits";
value_ = other.value_ & Bitmask();
return *this;
class TargetPtr {
public:
/*! \brief construct a device address with variable-length value `value` */
- TargetPtr(TargetWordSize word_size, std::uint64_t value) :
- value_(TargetVal(word_size.bits(), value)) {}
+ TargetPtr(TargetWordSize word_size, std::uint64_t value)
+ : value_(TargetVal(word_size.bits(), value)) {}
/*! \brief construct a null address */
- TargetPtr(TargetWordSize word_size, std::nullptr_t value) :
- value_{TargetVal(word_size.bits(), 0)} {}
+ TargetPtr(TargetWordSize word_size, std::nullptr_t value)
+ : value_{TargetVal(word_size.bits(), 0)} {}
/*! \brief construct an uninitialized pointer whose word_size can be changed once */
TargetPtr() = default;
* \return casted result
*/
template <typename T>
- T cast_to() const { return reinterpret_cast<T>(value_.uint64()); }
+ T cast_to() const {
+ return reinterpret_cast<T>(value_.uint64());
+ }
/*! \brief check if location is null */
bool operator==(std::nullptr_t) const { return value_.uint64() == 0; }
* \param binary contents of binary object file
* \param toolchain_prefix prefix of compiler toolchain to use
*/
- SymbolMap(const std::string& binary,
- const std::string& toolchain_prefix,
+ SymbolMap(const std::string& binary, const std::string& toolchain_prefix,
TargetWordSize word_size) {
const auto* f = Registry::Get("tvm_callback_get_symbol_map");
CHECK(f != nullptr) << "require tvm_callback_get_symbol_map to exist in registry";
return result->second;
}
- bool HasSymbol(const std::string& name) const {
- return map_.find(name) != map_.end();
- }
+ bool HasSymbol(const std::string& name) const { return map_.find(name) != map_.end(); }
void Dump(std::ostream& stream) const {
for (auto e : map_) {
* \param toolchain_prefix prefix of compiler toolchain to use
* \return relocated binary file contents
*/
-std::string RelocateBinarySections(
- const std::string& binary_path,
- TargetWordSize word_size,
- TargetPtr text_start,
- TargetPtr rodata_start,
- TargetPtr data_start,
- TargetPtr bss_start,
- TargetPtr stack_end,
- const std::string& toolchain_prefix);
+std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size,
+ TargetPtr text_start, TargetPtr rodata_start,
+ TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end,
+ const std::string& toolchain_prefix);
/*!
* \brief reads section from binary
* \param toolchain_prefix prefix of compiler toolchain to use
* \return contents of the section
*/
-std::string ReadSection(const std::string& binary,
- SectionKind section,
+std::string ReadSection(const std::string& binary, SectionKind section,
const std::string& toolchain_prefix);
/*!
* \param word_size word size of the target, for alignment
* \return size of the section if it exists, 0 otherwise
*/
-size_t GetSectionSize(const std::string& binary_name,
- SectionKind section,
- const std::string& toolchain_prefix,
- TargetWordSize word_size);
+size_t GetSectionSize(const std::string& binary_name, SectionKind section,
+ const std::string& toolchain_prefix, TargetWordSize word_size);
} // namespace runtime
} // namespace tvm
* \file micro_device_api.cc
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include "../workspace_pool.h"
#include "micro_session.h"
class MicroDeviceAPI final : public DeviceAPI {
public:
/*! \brief constructor */
- MicroDeviceAPI() { }
+ MicroDeviceAPI() {}
void SetDevice(TVMContext ctx) final {}
}
}
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
ObjectPtr<MicroSession>& session = MicroSession::Current();
TargetPtr data = session->AllocateInSection(SectionKind::kHeap, nbytes);
delete dev_space;
}
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
std::tuple<int, int> type_from_to(ctx_from.device_type, ctx_to.device_type);
if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) {
MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
CHECK(from_space->session == to_space->session)
- << "attempt to copy data between different micro sessions ("
- << from_space->session.get()
+ << "attempt to copy data between different micro sessions (" << from_space->session.get()
<< " != " << to_space->session.get() << ")";
CHECK(ctx_from.device_id == ctx_to.device_id)
- << "can only copy between the same micro device";
+ << "can only copy between the same micro device";
ObjectPtr<MicroSession>& session = from_space->session;
// flush all pending tasks to ensure data is consistent
session->FlushTaskQueue();
TargetPtr data = session->AllocateInSection(SectionKind::kWorkspace, size);
CHECK(data.value().uint64() != 0)
- << "unable to allocate " << size << " bytes on device workspace";
+ << "unable to allocate " << size << " bytes on device workspace";
return static_cast<void*>(new MicroDevSpace{data, session});
}
}
private:
- TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) {
- return dev_space->data + offset;
- }
+ TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { return dev_space->data + offset; }
void* GetHostLoc(const void* ptr, size_t offset) {
return reinterpret_cast<void*>(reinterpret_cast<std::uintptr_t>(ptr) + offset);
};
// register device that can be obtained from Python frontend
-TVM_REGISTER_GLOBAL("device_api.micro_dev")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = MicroDeviceAPI::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = MicroDeviceAPI::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace runtime
} // namespace tvm
* \file micro_module.cc
*/
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/module.h>
-#include <unordered_map>
+#include <tvm/runtime/registry.h>
+
#include <string>
-#include "micro_session.h"
+#include <unordered_map>
+
+#include "../pack_args.h"
#include "low_level_device.h"
#include "micro_common.h"
-#include "../pack_args.h"
+#include "micro_session.h"
namespace tvm {
namespace runtime {
~MicroModuleNode() {}
- const char* type_key() const final {
- return "micro";
- }
+ const char* type_key() const final { return "micro"; }
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
/*!
* \brief initializes module by establishing device connection and loads binary
class MicroWrappedFunc {
public:
- MicroWrappedFunc(ObjectPtr<MicroSession> session,
- TargetPtr func_ptr) {
+ MicroWrappedFunc(ObjectPtr<MicroSession> session, TargetPtr func_ptr) {
session_ = session;
func_ptr_ = func_ptr;
}
TargetPtr func_ptr_;
};
-PackedFunc MicroModuleNode::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc MicroModuleNode::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
TargetPtr func_ptr;
if (name == tvm::runtime::symbol::tvm_module_main) {
if (symbol_map_.HasSymbol(tvm::runtime::symbol::tvm_module_main)) {
// register loadfile function to load module from Python frontend
TVM_REGISTER_GLOBAL("runtime.module.loadfile_micro_dev")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- auto n = make_object<MicroModuleNode>();
- n->InitMicroModule(args[0]);
- *rv = runtime::Module(n);
- });
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ auto n = make_object<MicroModuleNode>();
+ n->InitMicroModule(args[0]);
+ *rv = runtime::Module(n);
+ });
} // namespace runtime
} // namespace tvm
#include <string>
#include <unordered_map>
+
#include "micro_common.h"
namespace tvm {
* \brief constructor that specifies section boundaries
* \param region location and size of the section on the device
*/
- explicit MicroSectionAllocator(std::string section_name,
- DevMemRegion region,
+ explicit MicroSectionAllocator(std::string section_name, DevMemRegion region,
TargetWordSize word_size)
- : section_name_(section_name),
- start_addr_(region.start),
- size_(0),
- capacity_(region.size),
- word_size_(word_size) {
- CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0)
+ : section_name_(section_name),
+ start_addr_(region.start),
+ size_(0),
+ capacity_(region.size),
+ word_size_(word_size) {
+ CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0)
<< "micro section start not aligned to " << word_size.bytes() << " bytes";
- CHECK_EQ(capacity_ % word_size.bytes(), 0)
+ CHECK_EQ(capacity_ % word_size.bytes(), 0)
<< "micro section end not aligned to " << word_size.bytes() << " bytes";
- }
+ }
/*!
* \brief destructor
TargetPtr Allocate(size_t size) {
size_ = UpperAlignValue(size_, word_size_.bytes());
CHECK(size_ + size < capacity_)
- << "cannot alloc " << size << " bytes in section \""
- << section_name_ << "\" (start_addr=" << start_addr_.cast_to<void*>()
- << ", used=" << size_ << ", capacity=" << capacity_ << ")";
+ << "cannot alloc " << size << " bytes in section \"" << section_name_
+ << "\" (start_addr=" << start_addr_.cast_to<void*>() << ", used=" << size_
+ << ", capacity=" << capacity_ << ")";
TargetPtr alloc_addr = start_addr_ + size_;
size_ += size;
alloc_map_[alloc_addr.value().uint64()] = size;
*/
void Free(TargetPtr addr) {
CHECK(alloc_map_.find(addr.value().uint64()) != alloc_map_.end())
- << "freed pointer was never allocated";
+ << "freed pointer was never allocated";
alloc_map_.erase(addr.value().uint64());
if (alloc_map_.empty()) {
size_ = 0;
* \file micro_session.cc
*/
+#include "micro_session.h"
+
#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include <chrono>
-#include <memory>
#include <locale>
+#include <memory>
#include <stack>
#include <tuple>
#include <vector>
-#include "micro_session.h"
+
#include "low_level_device.h"
#include "target_data_layout_encoder.h"
typedef dmlc::ThreadLocalStore<TVMMicroSessionThreadLocalEntry> TVMMicroSessionThreadLocalStore;
ObjectPtr<MicroSession>& MicroSession::Current() {
- TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
+ TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get();
CHECK_GT(entry->session_stack.size(), 0) << "No current session";
return entry->session_stack.top();
}
void MicroSession::EnterWithScope(ObjectPtr<MicroSession> session) {
- TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
+ TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get();
entry->session_stack.push(session);
}
void MicroSession::ExitWithScope() {
- TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
+ TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get();
CHECK(!entry->session_stack.empty());
entry->session_stack.pop();
}
-MicroSession::MicroSession(
- const std::string& comms_method,
- const std::string& binary_path,
- const std::string& toolchain_prefix,
- uint64_t text_start,
- size_t text_size,
- uint64_t rodata_start,
- size_t rodata_size,
- uint64_t data_start,
- size_t data_size,
- uint64_t bss_start,
- size_t bss_size,
- uint64_t args_start,
- size_t args_size,
- uint64_t heap_start,
- size_t heap_size,
- uint64_t workspace_start,
- size_t workspace_size,
- uint64_t stack_start,
- size_t stack_size,
- TargetWordSize word_size,
- bool thumb_mode,
- bool use_device_timer,
- const std::string& server_addr,
- int port)
+MicroSession::MicroSession(const std::string& comms_method, const std::string& binary_path,
+ const std::string& toolchain_prefix, uint64_t text_start,
+ size_t text_size, uint64_t rodata_start, size_t rodata_size,
+ uint64_t data_start, size_t data_size, uint64_t bss_start,
+ size_t bss_size, uint64_t args_start, size_t args_size,
+ uint64_t heap_start, size_t heap_size, uint64_t workspace_start,
+ size_t workspace_size, uint64_t stack_start, size_t stack_size,
+ TargetWordSize word_size, bool thumb_mode, bool use_device_timer,
+ const std::string& server_addr, int port)
: toolchain_prefix_(toolchain_prefix),
word_size_(word_size),
thumb_mode_(thumb_mode),
batch_args_encoder_(args_size, word_size) {
if (comms_method == "host") {
// TODO(weberlo): move checks to python
- CHECK(
- text_start == 0 &&
- rodata_start == 0 &&
- data_start == 0 &&
- bss_start == 0 &&
- args_start == 0 &&
- heap_start == 0 &&
- workspace_start == 0 &&
- stack_start == 0) << "unable to specify section addresses for host device";
- size_t memory_size =
- text_size + rodata_size + data_size + bss_size +
- args_size + heap_size + workspace_size + stack_size;
+ CHECK(text_start == 0 && rodata_start == 0 && data_start == 0 && bss_start == 0 &&
+ args_start == 0 && heap_start == 0 && workspace_start == 0 && stack_start == 0)
+ << "unable to specify section addresses for host device";
+ size_t memory_size = text_size + rodata_size + data_size + bss_size + args_size + heap_size +
+ workspace_size + stack_size;
TargetPtr base_addr;
low_level_device_ = HostLowLevelDeviceCreate(memory_size, &base_addr);
CHECK_EQ(base_addr.value().uint64() % word_size.bytes(), 0)
- << "base address not aligned to " << word_size.bytes() << " bytes";
+ << "base address not aligned to " << word_size.bytes() << " bytes";
TargetPtr curr_addr = base_addr;
- section_allocators_[0] = std::make_shared<MicroSectionAllocator>(
- "text",
- DevMemRegion {
- .start = curr_addr,
- .size = text_size,
- }, word_size_);
+ section_allocators_[0] = std::make_shared<MicroSectionAllocator>("text",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = text_size,
+ },
+ word_size_);
curr_addr += text_size;
- section_allocators_[1] = std::make_shared<MicroSectionAllocator>(
- "rodata",
- DevMemRegion {
- .start = curr_addr,
- .size = rodata_size,
- }, word_size_);
+ section_allocators_[1] = std::make_shared<MicroSectionAllocator>("rodata",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = rodata_size,
+ },
+ word_size_);
curr_addr += rodata_size;
- section_allocators_[2] = std::make_shared<MicroSectionAllocator>(
- "data",
- DevMemRegion {
- .start = curr_addr,
- .size = data_size,
- }, word_size_);
+ section_allocators_[2] = std::make_shared<MicroSectionAllocator>("data",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = data_size,
+ },
+ word_size_);
curr_addr += data_size;
- section_allocators_[3] = std::make_shared<MicroSectionAllocator>(
- "bss",
- DevMemRegion {
- .start = curr_addr,
- .size = bss_size,
- }, word_size_);
+ section_allocators_[3] = std::make_shared<MicroSectionAllocator>("bss",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = bss_size,
+ },
+ word_size_);
curr_addr += bss_size;
- section_allocators_[4] = std::make_shared<MicroSectionAllocator>(
- "args",
- DevMemRegion {
- .start = curr_addr,
- .size = args_size,
- }, word_size_);
+ section_allocators_[4] = std::make_shared<MicroSectionAllocator>("args",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = args_size,
+ },
+ word_size_);
curr_addr += args_size;
- section_allocators_[5] = std::make_shared<MicroSectionAllocator>(
- "heap",
- DevMemRegion {
- .start = curr_addr,
- .size = heap_size,
- }, word_size_);
+ section_allocators_[5] = std::make_shared<MicroSectionAllocator>("heap",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = heap_size,
+ },
+ word_size_);
curr_addr += heap_size;
- section_allocators_[6] = std::make_shared<MicroSectionAllocator>(
- "workspace",
- DevMemRegion {
- .start = curr_addr,
- .size = workspace_size,
- }, word_size_);
+ section_allocators_[6] = std::make_shared<MicroSectionAllocator>("workspace",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = workspace_size,
+ },
+ word_size_);
curr_addr += workspace_size;
- section_allocators_[7] = std::make_shared<MicroSectionAllocator>(
- "stack",
- DevMemRegion {
- .start = curr_addr,
- .size = stack_size,
- }, word_size_);
+ section_allocators_[7] = std::make_shared<MicroSectionAllocator>("stack",
+ DevMemRegion{
+ .start = curr_addr,
+ .size = stack_size,
+ },
+ word_size_);
curr_addr += stack_size;
} else if (comms_method == "openocd") {
low_level_device_ = OpenOCDLowLevelDeviceCreate(server_addr, port);
- section_allocators_[0] = std::make_shared<MicroSectionAllocator>(
- "text",
- DevMemRegion {
- .start = TargetPtr(word_size_, text_start),
- .size = text_size,
- }, word_size_);
- section_allocators_[1] = std::make_shared<MicroSectionAllocator>(
- "rodata",
- DevMemRegion {
- .start = TargetPtr(word_size_, rodata_start),
- .size = rodata_size,
- }, word_size_);
- section_allocators_[2] = std::make_shared<MicroSectionAllocator>(
- "data",
- DevMemRegion {
- .start = TargetPtr(word_size_, data_start),
- .size = data_size,
- }, word_size_);
- section_allocators_[3] = std::make_shared<MicroSectionAllocator>(
- "bss",
- DevMemRegion {
- .start = TargetPtr(word_size_, bss_start),
- .size = bss_size,
- }, word_size_);
- section_allocators_[4] = std::make_shared<MicroSectionAllocator>(
- "args",
- DevMemRegion {
- .start = TargetPtr(word_size_, args_start),
- .size = args_size,
- }, word_size_);
- section_allocators_[5] = std::make_shared<MicroSectionAllocator>(
- "heap",
- DevMemRegion {
- .start = TargetPtr(word_size_, heap_start),
- .size = heap_size,
- }, word_size_);
- section_allocators_[6] = std::make_shared<MicroSectionAllocator>(
- "workspace",
- DevMemRegion {
- .start = TargetPtr(word_size_, workspace_start),
- .size = workspace_size,
- }, word_size_);
- section_allocators_[7] = std::make_shared<MicroSectionAllocator>(
- "stack",
- DevMemRegion {
- .start = TargetPtr(word_size_, stack_start),
- .size = stack_size,
- }, word_size_);
+ section_allocators_[0] =
+ std::make_shared<MicroSectionAllocator>("text",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, text_start),
+ .size = text_size,
+ },
+ word_size_);
+ section_allocators_[1] =
+ std::make_shared<MicroSectionAllocator>("rodata",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, rodata_start),
+ .size = rodata_size,
+ },
+ word_size_);
+ section_allocators_[2] =
+ std::make_shared<MicroSectionAllocator>("data",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, data_start),
+ .size = data_size,
+ },
+ word_size_);
+ section_allocators_[3] =
+ std::make_shared<MicroSectionAllocator>("bss",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, bss_start),
+ .size = bss_size,
+ },
+ word_size_);
+ section_allocators_[4] =
+ std::make_shared<MicroSectionAllocator>("args",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, args_start),
+ .size = args_size,
+ },
+ word_size_);
+ section_allocators_[5] =
+ std::make_shared<MicroSectionAllocator>("heap",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, heap_start),
+ .size = heap_size,
+ },
+ word_size_);
+ section_allocators_[6] =
+ std::make_shared<MicroSectionAllocator>("workspace",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, workspace_start),
+ .size = workspace_size,
+ },
+ word_size_);
+ section_allocators_[7] =
+ std::make_shared<MicroSectionAllocator>("stack",
+ DevMemRegion{
+ .start = TargetPtr(word_size_, stack_start),
+ .size = stack_size,
+ },
+ word_size_);
} else {
LOG(FATAL) << "unsupported micro low-level device";
}
TargetVal arg_values_dev_addr{std::get<0>(arg_field_addrs).value()};
TargetVal arg_type_codes_dev_addr{std::get<1>(arg_field_addrs).value()};
- task_queue_.push_back(
- DevTask {
- .func = func_dev_addr,
- .arg_values = arg_values_dev_addr,
- .arg_type_codes = arg_type_codes_dev_addr,
- .num_args = args.num_args
- });
+ task_queue_.push_back(DevTask{.func = func_dev_addr,
+ .arg_values = arg_values_dev_addr,
+ .arg_type_codes = arg_type_codes_dev_addr,
+ .num_args = args.num_args});
if (task_queue_.size() == MicroSession::kTaskQueueCapacity) {
FlushTaskQueue();
}
// Flush `args` to device memory.
- low_level_device()->Write(
- batch_args_encoder_.start_addr(),
- reinterpret_cast<void*>(batch_args_encoder_.data()),
- batch_args_encoder_.buf_size());
+ low_level_device()->Write(batch_args_encoder_.start_addr(),
+ reinterpret_cast<void*>(batch_args_encoder_.data()),
+ batch_args_encoder_.buf_size());
// Flush `tasks` to device memory.
TargetPtr dev_tasks_addr = runtime_symbol_map_["utvm_tasks"];
- low_level_device()->Write(
- dev_tasks_addr,
- reinterpret_cast<void*>(prepped_tasks.data()),
- prepped_tasks.size() * sizeof(T));
+ low_level_device()->Write(dev_tasks_addr, reinterpret_cast<void*>(prepped_tasks.data()),
+ prepped_tasks.size() * sizeof(T));
DevSymbolWrite<uint32_t>(runtime_symbol_map_, "utvm_num_tasks", prepped_tasks.size());
TargetPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"];
utvm_init_addr += 1;
}
- std::chrono::time_point<
- std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
+ std::chrono::time_point<std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin,
+ tend;
tbegin = std::chrono::high_resolution_clock::now();
// std::string tmp;
// while (tmp[0] != 'd' && tmp[0] != 'e') {
uint64_t sum = 0;
std::vector<uint32_t> times;
times.resize(task_queue_.size());
- low_level_device()->Read(runtime_symbol_map_["utvm_task_times"],
- times.data(),
+ low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(),
task_queue_.size() * sizeof(uint32_t));
int i = 0;
for (uint32_t time : times) {
}
last_batch_time_ += static_cast<double>(sum) / 1e3;
} else {
- last_batch_time_ += std::chrono::duration_cast<std::chrono::duration<double> >
- (tend - tbegin).count() * 1000;
+ last_batch_time_ +=
+ std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count() * 1000;
// TODO(weberlo): Reading internal data structure is hacky.
uint64_t sum = 0;
std::vector<uint32_t> times;
times.resize(task_queue_.size());
- low_level_device()->Read(runtime_symbol_map_["utvm_task_times"],
- times.data(),
+ low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(),
task_queue_.size() * sizeof(uint32_t));
for (uint32_t time : times) {
sum += time;
DevMemRegion data_section;
DevMemRegion bss_section;
- text_section.size = GetSectionSize(
- binary_path, SectionKind::kText, toolchain_prefix_, word_size_);
- rodata_section.size = GetSectionSize(
- binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_);
- data_section.size = GetSectionSize(
- binary_path, SectionKind::kData, toolchain_prefix_, word_size_);
- bss_section.size = GetSectionSize(
- binary_path, SectionKind::kBss, toolchain_prefix_, word_size_);
+ text_section.size =
+ GetSectionSize(binary_path, SectionKind::kText, toolchain_prefix_, word_size_);
+ rodata_section.size =
+ GetSectionSize(binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_);
+ data_section.size =
+ GetSectionSize(binary_path, SectionKind::kData, toolchain_prefix_, word_size_);
+ bss_section.size = GetSectionSize(binary_path, SectionKind::kBss, toolchain_prefix_, word_size_);
text_section.start = AllocateInSection(SectionKind::kText, text_section.size);
rodata_section.start = AllocateInSection(SectionKind::kRodata, rodata_section.size);
bss_section.start = AllocateInSection(SectionKind::kBss, bss_section.size);
std::string relocated_bin = RelocateBinarySections(
- binary_path,
- word_size_,
- text_section.start,
- rodata_section.start,
- data_section.start,
- bss_section.start,
- GetAllocator(SectionKind::kStack)->max_addr(),
- toolchain_prefix_);
+ binary_path, word_size_, text_section.start, rodata_section.start, data_section.start,
+ bss_section.start, GetAllocator(SectionKind::kStack)->max_addr(), toolchain_prefix_);
std::string text_contents = ReadSection(relocated_bin, SectionKind::kText, toolchain_prefix_);
std::string rodata_contents = ReadSection(relocated_bin, SectionKind::kRodata, toolchain_prefix_);
std::string data_contents = ReadSection(relocated_bin, SectionKind::kData, toolchain_prefix_);
low_level_device_->Write(rodata_section.start, &rodata_contents[0], rodata_section.size);
low_level_device_->Write(data_section.start, &data_contents[0], data_section.size);
low_level_device_->Write(bss_section.start, &bss_contents[0], bss_section.size);
- SymbolMap symbol_map {relocated_bin, toolchain_prefix_, word_size_};
+ SymbolMap symbol_map{relocated_bin, toolchain_prefix_, word_size_};
if (patch_dylib_pointers) {
// Patch device lib pointers.
PatchImplHole(symbol_map, "TVMAPISetLastError");
}
- return BinaryInfo {
+ return BinaryInfo{
.text_section = text_section,
.rodata_section = rodata_section,
.data_section = data_section,
};
}
-std::tuple<TargetPtr, TargetPtr> MicroSession::EncoderAppend(
- TargetDataLayoutEncoder* encoder, const TVMArgs& args) {
+std::tuple<TargetPtr, TargetPtr> MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder,
+ const TVMArgs& args) {
const int* type_codes = args.type_codes;
int num_args = args.num_args;
strides_dev_addr = stride_slot.start_addr();
}
- T dev_arr(
- TargetVal { word_size_.bits(), reinterpret_cast<uint64_t>(arr.data) },
- arr.ctx,
- arr.ndim,
- arr.dtype,
- shape_dev_addr.value(),
- strides_dev_addr.value(),
- TargetVal { word_size_.bits(), arr.byte_offset });
+ T dev_arr(TargetVal{word_size_.bits(), reinterpret_cast<uint64_t>(arr.data)}, arr.ctx, arr.ndim,
+ arr.dtype, shape_dev_addr.value(), strides_dev_addr.value(),
+ TargetVal{word_size_.bits(), arr.byte_offset});
CHECK(dev_arr.ctx.device_type == static_cast<DLDeviceType>(kDLMicroDev))
- << "attempt to write DLTensor with non-micro device type";
+ << "attempt to write DLTensor with non-micro device type";
// Update the device type to CPU, because from the microcontroller's
// perspective, it is.
dev_arr.ctx.device_type = DLDeviceType::kDLCPU;
if (last_error) {
if (!use_device_timer_ &&
- (last_error == UTVM_ERR_TIMER_OVERFLOW ||
- last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) {
+ (last_error == UTVM_ERR_TIMER_OVERFLOW || last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) {
// these errors don't matter if we're not using the on-device timer
return;
}
return result;
}
-void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map,
- const std::string& symbol,
+void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol,
const TargetPtr& ptr) {
if (word_size_.bytes() == 4) {
DevSymbolWrite(symbol_map, symbol, ptr.value().uint32());
}
template <typename T>
-void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map,
- const std::string& symbol,
+void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol,
const T& value) {
TargetPtr sym_addr = symbol_map[symbol];
low_level_device()->Write(sym_addr, &value, sizeof(T));
}
-PackedFunc MicroSession::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc MicroSession::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
if (name == "enter") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
MicroSession::EnterWithScope(GetObjectPtr<MicroSession>(this));
});
} else if (name == "exit") {
- return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {
- MicroSession::ExitWithScope();
- });
+ return PackedFunc(
+ [sptr_to_self](TVMArgs args, TVMRetValue* rv) { MicroSession::ExitWithScope(); });
// TODO(weberlo): add a `clear_batch_timer` func
} else if (name == "get_last_batch_time") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetLastBatchTime();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchTime(); });
// TODO(weberlo): remove this func
} else if (name == "get_last_batch_cycles") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetLastBatchCycles();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchCycles(); });
} else {
return PackedFunc();
}
}
-TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator").set_body([](TVMArgs args, TVMRetValue* rv) {
PackedFunc pf = args[0];
TVMContext ctx = args[1];
uint64_t number = args[2];
uint64_t repeat = args[3];
- auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue *rv) mutable {
+ auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue* rv) mutable {
TVMRetValue temp;
std::ostringstream os;
for (unsigned int i = 0; i < repeat; ++i) {
// start timing
CHECK(number < MicroSession::kTaskQueueCapacity)
- << "`number` must be less than uTVM task queue capacity";
+ << "`number` must be less than uTVM task queue capacity";
for (unsigned int j = 0; j < number; ++j) {
pf.CallPacked(args, &temp);
}
*rv = PackedFunc(ftimer);
});
-
// create micro session and low-level device from Python frontend
-TVM_REGISTER_GLOBAL("micro._CreateSession")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- const std::string& comms_method = args[0];
- const std::string& binary_path = args[1];
- const std::string& toolchain_prefix = args[2];
- uint64_t text_start = args[3];
- size_t text_size = uint64_t(args[4]);
- uint64_t rodata_start = args[5];
- size_t rodata_size = uint64_t(args[6]);
- uint64_t data_start = args[7];
- size_t data_size = uint64_t(args[8]);
- uint64_t bss_start = args[9];
- size_t bss_size = uint64_t(args[10]);
- uint64_t args_start = args[11];
- size_t args_size = uint64_t(args[12]);
- uint64_t heap_start = args[13];
- size_t heap_size = uint64_t(args[14]);
- uint64_t workspace_start = args[15];
- size_t workspace_size = uint64_t(args[16]);
- uint64_t stack_start = args[17];
- size_t stack_size = uint64_t(args[18]);
- TargetWordSize word_size{uint64_t(args[19])};
- bool thumb_mode = args[20];
- bool use_device_timer = args[21];
- const std::string& server_addr = args[22];
- int port = args[23];
- ObjectPtr<MicroSession> session = make_object<MicroSession>(
- comms_method,
- binary_path,
- toolchain_prefix,
- text_start,
- text_size,
- rodata_start,
- rodata_size,
- data_start,
- data_size,
- bss_start,
- bss_size,
- args_start,
- args_size,
- heap_start,
- heap_size,
- workspace_start,
- workspace_size,
- stack_start,
- stack_size,
- word_size,
- thumb_mode,
- use_device_timer,
- server_addr,
- port);
- *rv = Module(session);
- });
+TVM_REGISTER_GLOBAL("micro._CreateSession").set_body([](TVMArgs args, TVMRetValue* rv) {
+ const std::string& comms_method = args[0];
+ const std::string& binary_path = args[1];
+ const std::string& toolchain_prefix = args[2];
+ uint64_t text_start = args[3];
+ size_t text_size = uint64_t(args[4]);
+ uint64_t rodata_start = args[5];
+ size_t rodata_size = uint64_t(args[6]);
+ uint64_t data_start = args[7];
+ size_t data_size = uint64_t(args[8]);
+ uint64_t bss_start = args[9];
+ size_t bss_size = uint64_t(args[10]);
+ uint64_t args_start = args[11];
+ size_t args_size = uint64_t(args[12]);
+ uint64_t heap_start = args[13];
+ size_t heap_size = uint64_t(args[14]);
+ uint64_t workspace_start = args[15];
+ size_t workspace_size = uint64_t(args[16]);
+ uint64_t stack_start = args[17];
+ size_t stack_size = uint64_t(args[18]);
+ TargetWordSize word_size{uint64_t(args[19])};
+ bool thumb_mode = args[20];
+ bool use_device_timer = args[21];
+ const std::string& server_addr = args[22];
+ int port = args[23];
+ ObjectPtr<MicroSession> session = make_object<MicroSession>(
+ comms_method, binary_path, toolchain_prefix, text_start, text_size, rodata_start, rodata_size,
+ data_start, data_size, bss_start, bss_size, args_start, args_size, heap_start, heap_size,
+ workspace_start, workspace_size, stack_start, stack_size, word_size, thumb_mode,
+ use_device_timer, server_addr, port);
+ *rv = Module(session);
+});
} // namespace runtime
} // namespace tvm
#ifndef TVM_RUNTIME_MICRO_MICRO_SESSION_H_
#define TVM_RUNTIME_MICRO_MICRO_SESSION_H_
-#include "micro_common.h"
-#include "micro_section_allocator.h"
-
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
+#include <tuple>
#include <unordered_map>
#include <vector>
-#include <tuple>
#include "low_level_device.h"
+#include "micro_common.h"
+#include "micro_section_allocator.h"
#include "target_data_layout_encoder.h"
namespace tvm {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
- virtual PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self);
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
// todo having this decoupled from the value in utvm_runtime.c gives me stress dreams
static const size_t kTaskQueueCapacity = 20;
/*!
* \return The type key of the executor.
*/
- const char* type_key() const final {
- return "MicroSession";
- }
+ const char* type_key() const final { return "MicroSession"; }
/*!
* \brief creates session by setting up a low-level device and initting allocators for it
* \param server_addr address of the OpenOCD server to connect to (if `comms_method == "openocd"`)
* \param port port of the OpenOCD server to connect to (if `comms_method == "openocd"`)
*/
- MicroSession(
- const std::string& comms_method,
- const std::string& binary_path,
- const std::string& toolchain_prefix,
- uint64_t text_start,
- size_t text_size,
- uint64_t rodata_start,
- size_t rodata_size,
- uint64_t data_start,
- size_t data_size,
- uint64_t bss_start,
- size_t bss_size,
- uint64_t args_start,
- size_t args_size,
- uint64_t heap_start,
- size_t heap_size,
- uint64_t workspace_start,
- size_t workspace_size,
- uint64_t stack_start,
- size_t stack_size,
- TargetWordSize word_size,
- bool thumb_mode,
- bool use_device_timer,
- const std::string& server_addr,
- int port);
+ MicroSession(const std::string& comms_method, const std::string& binary_path,
+ const std::string& toolchain_prefix, uint64_t text_start, size_t text_size,
+ uint64_t rodata_start, size_t rodata_size, uint64_t data_start, size_t data_size,
+ uint64_t bss_start, size_t bss_size, uint64_t args_start, size_t args_size,
+ uint64_t heap_start, size_t heap_size, uint64_t workspace_start,
+ size_t workspace_size, uint64_t stack_start, size_t stack_size,
+ TargetWordSize word_size, bool thumb_mode, bool use_device_timer,
+ const std::string& server_addr, int port);
/*!
* \brief destructor
std::string ReadString(TargetPtr str_addr);
/*!
- * \brief read value of symbol from device memory
- * \param symbol_map symbol map to read location of symbol from
- * \param symbol name of symbol being read from
- * \return value at symbol in memory
- */
+ * \brief read value of symbol from device memory
+ * \param symbol_map symbol map to read location of symbol from
+ * \param symbol name of symbol being read from
+ * \return value at symbol in memory
+ */
template <typename T>
T DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol);
/*!
* \brief write pointer value into device memory corresponding to symbol
- * \param symbol_map symbol map to read location of symbol from
- * \param symbol name of symbol being written to
- * \param ptr pointer value to write into symbol
+ * \param symbol_map symbol map to read location of symbol from
+ * \param symbol name of symbol being written to
+ * \param ptr pointer value to write into symbol
*/
- void DevSymbolWrite(const SymbolMap& symbol_map,
- const std::string& symbol,
- const TargetPtr& ptr);
+ void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const TargetPtr& ptr);
/*!
- * \brief write value into device memory corresponding to symbol
- * \param symbol_map symbol map to read location of symbol from
- * \param symbol name of symbol being written to
- * \param value value being written into symbol
+ * \brief write value into device memory corresponding to symbol
+ * \param symbol_map symbol map to read location of symbol from
+ * \param symbol name of symbol being written to
+ * \param value value being written into symbol
*/
template <typename T>
void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value);
}
/*!
- * \brief Push a new session context onto the thread-local stack.
- * The session on top of the stack is used as the current global session.
- */
+ * \brief Push a new session context onto the thread-local stack.
+ * The session on top of the stack is used as the current global session.
+ */
static void EnterWithScope(ObjectPtr<MicroSession> session);
/*!
- * \brief Pop a session off the thread-local context stack,
- * restoring the previous session as the current context.
- */
+ * \brief Pop a session off the thread-local context stack,
+ * restoring the previous session as the current context.
+ */
static void ExitWithScope();
};
/*! \brief TVM array for serialization to 32-bit devices */
struct TVMArray32 {
- TVMArray32(
- TargetVal data,
- DLContext ctx,
- int32_t ndim,
- DLDataType dtype,
- TargetVal shape,
- TargetVal strides,
- TargetVal byte_offset)
- : data(data.uint32()),
- ctx(ctx),
- ndim(ndim),
- pad0(0),
- dtype(dtype),
- shape(shape.uint32()),
- strides(strides.uint32()),
- pad1(0),
- byte_offset(byte_offset.uint32()),
- pad2(0) { }
+ TVMArray32(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape,
+ TargetVal strides, TargetVal byte_offset)
+ : data(data.uint32()),
+ ctx(ctx),
+ ndim(ndim),
+ pad0(0),
+ dtype(dtype),
+ shape(shape.uint32()),
+ strides(strides.uint32()),
+ pad1(0),
+ byte_offset(byte_offset.uint32()),
+ pad2(0) {}
/*!
* \brief The opaque data pointer points to the allocated data.
/*! \brief TVM array for serialization to 64-bit devices */
struct TVMArray64 {
- TVMArray64(
- TargetVal data,
- DLContext ctx,
- int32_t ndim,
- DLDataType dtype,
- TargetVal shape,
- TargetVal strides,
- TargetVal byte_offset)
- : data(data.uint64()),
- ctx(ctx),
- ndim(ndim),
- pad0(0),
- dtype(dtype),
- shape(shape.uint64()),
- strides(strides.uint64()),
- byte_offset(byte_offset.uint64()) { }
+ TVMArray64(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape,
+ TargetVal strides, TargetVal byte_offset)
+ : data(data.uint64()),
+ ctx(ctx),
+ ndim(ndim),
+ pad0(0),
+ dtype(dtype),
+ shape(shape.uint64()),
+ strides(strides.uint64()),
+ byte_offset(byte_offset.uint64()) {}
/*!
* \brief The opaque data pointer points to the allocated data.
* This will be CUDA device pointer or cl_mem handle in OpenCL.
/*! \brief MicroTVM task for serialization to 32-bit devices */
typedef struct StructUTVMTask32 {
StructUTVMTask32(DevTask task)
- : func(task.func.uint32()),
- arg_values(task.arg_values.uint32()),
- arg_type_codes(task.arg_type_codes.uint32()),
- num_args(task.num_args) { }
+ : func(task.func.uint32()),
+ arg_values(task.arg_values.uint32()),
+ arg_type_codes(task.arg_type_codes.uint32()),
+ num_args(task.num_args) {}
/*! \brief Pointer to function to call for this task */
uint32_t func;
/*! \brief MicroTVM task for serialization to 64-bit devices */
typedef struct StructUTVMTask64 {
StructUTVMTask64(DevTask task)
- : func(task.func.uint64()),
- arg_values(task.arg_values.uint64()),
- arg_type_codes(task.arg_type_codes.uint64()),
- num_args(task.num_args) { }
+ : func(task.func.uint64()),
+ arg_values(task.arg_values.uint64()),
+ arg_type_codes(task.arg_type_codes.uint64()),
+ num_args(task.num_args) {}
/*! \brief Pointer to function to call for this task */
uint64_t func;
#include <iomanip>
#include <sstream>
-#include "micro_common.h"
#include "low_level_device.h"
+#include "micro_common.h"
#include "tcl_socket.h"
namespace tvm {
* \param server_addr address of the OpenOCD server to connect to
* \param port port of the OpenOCD server to connect to
*/
- explicit OpenOCDLowLevelDevice(const std::string& server_addr,
- int port) : socket_() {
+ explicit OpenOCDLowLevelDevice(const std::string& server_addr, int port) : socket_() {
server_addr_ = server_addr;
port_ = port;
socket_.cmd_builder() << "array unset output";
socket_.SendCommand();
- socket_.cmd_builder()
- << "mem2array output"
- << " " << std::dec << kWordSize
- << " " << addr.cast_to<void*>()
- // Round up any request sizes under a byte, since OpenOCD doesn't support
- // sub-byte-sized transfers.
- << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes);
+ socket_.cmd_builder() << "mem2array output"
+ << " " << std::dec << kWordSize << " "
+ << addr.cast_to<void*>()
+ // Round up any request sizes under a byte, since OpenOCD doesn't
+ // support sub-byte-sized transfers.
+ << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes);
socket_.SendCommand();
}
// The response from this command pairs indices with the contents of the
// memory at that index.
values >> index;
- CHECK(index < num_bytes)
- << "index " << index <<
- " out of bounds (length " << num_bytes << ")";
+ CHECK(index < num_bytes) << "index " << index << " out of bounds (length " << num_bytes
+ << ")";
// Read the value into `curr_val`, instead of reading directly into
// `buf_iter`, because otherwise it's interpreted as the ASCII value and
// not the integral value.
socket_.SendCommand();
}
{
- socket_.cmd_builder()
- << "array2mem input"
- << " " << std::dec << kWordSize
- << " " << addr.cast_to<void*>()
- << " " << std::dec << num_bytes;
+ socket_.cmd_builder() << "array2mem input"
+ << " " << std::dec << kWordSize << " " << addr.cast_to<void*>() << " "
+ << std::dec << num_bytes;
socket_.SendCommand();
}
}
socket_.SendCommand();
}
- const char* device_type() const final {
- return "openocd";
- }
+ const char* device_type() const final { return "openocd"; }
private:
/*! \brief socket used to communicate with the device through Tcl */
const std::shared_ptr<LowLevelDevice> OpenOCDLowLevelDeviceCreate(const std::string& server_addr,
int port) {
- std::shared_ptr<LowLevelDevice> lld =
- std::make_shared<OpenOCDLowLevelDevice>(server_addr, port);
+ std::shared_ptr<LowLevelDevice> lld = std::make_shared<OpenOCDLowLevelDevice>(server_addr, port);
return lld;
}
namespace tvm {
namespace micro {
-
// A minimal wrapper, derived from https://github.com/Robbepop/dynarray/, that
// supports a minimal subset of the std::vector API with a minimized code size.
template <typename T>
#include "utvm_graph_runtime.h"
#include <dlfcn.h>
+
#include <cassert>
#include <string>
+
#include "picojson.h"
namespace tvm {
* specific language governing permissions and limitations
* under the License.
*/
+#include "tvm/runtime/micro/standalone/utvm_runtime.h"
+
#include <cassert>
-#include "tvm/runtime/micro/standalone/utvm_runtime.h"
#include "utvm_graph_runtime.h"
void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module) {
- return new tvm::micro::MicroGraphRuntime(
- std::string(json, json + json_len),
- reinterpret_cast<tvm::micro::DSOModule*>(module));
+ return new tvm::micro::MicroGraphRuntime(std::string(json, json + json_len),
+ reinterpret_cast<tvm::micro::DSOModule*>(module));
}
void UTVMRuntimeDestroy(void* handle) {
#include "utvm_runtime_api.h"
#include <stdlib.h>
+
#include <cassert>
#include <string>
#include <stdint.h>
#include <stdlib.h>
+
#include <cassert>
// The subset of the TVM runtime API that is implemented by the minimal runtime API.
#define TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_
#include <vector>
+
#include "host_driven/utvm_runtime.h"
namespace tvm {
* \param start_addr start address of the encoder in device memory
*/
explicit TargetDataLayoutEncoder(size_t capacity, TargetWordSize word_size)
- : buf_(std::vector<uint8_t>()), curr_offset_(0),
+ : buf_(std::vector<uint8_t>()),
+ curr_offset_(0),
start_addr_(word_size, nullptr),
- capacity_(capacity), word_size_(word_size) {
- }
+ capacity_(capacity),
+ word_size_(word_size) {}
/*!
* \brief allocates a slot for `sizeof(T) * num_elems` bytes of data
* \brief returns the array backing the encoder's buffer
* \return array backing the encoder's buffer
*/
- uint8_t* data() {
- return buf_.data();
- }
+ uint8_t* data() { return buf_.data(); }
/*!
* \brief returns current size of the encoder's buffer
* \return buffer size
*/
- size_t buf_size() const {
- return buf_.size();
- }
+ size_t buf_size() const { return buf_.size(); }
TargetPtr start_addr() const {
CHECK_NE(start_addr_.value().uint64(), 0) << "start addr uninitialized";
void set_start_addr(TargetPtr start_addr) {
CHECK_EQ(buf_.size(), 0) << "cannot change encoder start addr unless empty";
- start_addr_ = TargetPtr(word_size_,
- UpperAlignValue(start_addr.value().uint64(), word_size_.bytes()));
+ start_addr_ =
+ TargetPtr(word_size_, UpperAlignValue(start_addr.value().uint64(), word_size_.bytes()));
}
private:
};
template <typename T>
-TargetDataLayoutEncoder::Slot<T>::Slot(TargetDataLayoutEncoder* parent,
- size_t start_offset,
- size_t size,
- TargetPtr start_addr)
+TargetDataLayoutEncoder::Slot<T>::Slot(TargetDataLayoutEncoder* parent, size_t start_offset,
+ size_t size, TargetPtr start_addr)
: parent_(parent),
start_offset_(start_offset),
curr_offset_(0),
TargetDataLayoutEncoder::Slot<T>::~Slot() {
// TODO(weberlo, areusch): this can mask the exception thrown by slot allocation... even though
// that doesn't make sense.
- CHECK(curr_offset_ == size_) << "unwritten space in slot; curr_offset="
- << curr_offset_ << ", size=" << size_;
+ CHECK(curr_offset_ == size_) << "unwritten space in slot; curr_offset=" << curr_offset_
+ << ", size=" << size_;
}
template <typename T>
/*!
* \file tcl_socket.cc
*/
-#include <string>
-
#include "tcl_socket.h"
+#include <string>
+
namespace tvm {
namespace runtime {
reply_buf_.reserve(kReplyBufSize);
}
-TclSocket::~TclSocket() {
- tcp_socket_.Close();
-}
+TclSocket::~TclSocket() { tcp_socket_.Close(); }
void TclSocket::Connect(tvm::support::SockAddr addr) {
CHECK(tcp_socket_.Connect(addr)) << "failed to connect";
cmd_builder_ << terminate_token;
std::string full_cmd = cmd_builder_.str();
- CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1)
- << "failed to send command";
+ CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) << "failed to send command";
cmd_builder_.str(std::string());
reply_builder_.str(std::string());
CHECK(bytes_read != -1) << "failed to read command reply";
} while (last_read != terminate_token);
last_reply_ = reply_builder_.str();
- CHECK_EQ(last_reply_[last_reply_.length()-1], terminate_token)
- << "missing command terminator";
+ CHECK_EQ(last_reply_[last_reply_.length() - 1], terminate_token) << "missing command terminator";
}
} // namespace runtime
/*
* \return string stream for current command being built
- */
+ */
std::ostringstream& cmd_builder() { return cmd_builder_; }
/*
* \return reply from most recently sent command
- */
+ */
const std::string& last_reply() { return last_reply_; }
private:
* \brief TVM module system
*/
#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
-#include <unordered_set>
+#include <tvm/runtime/registry.h>
+
#include <cstring>
+#include <unordered_set>
+
#include "file_util.h"
namespace tvm {
stack.push_back(next);
}
}
- CHECK(!visited.count(this))
- << "Cyclic dependency detected during import";
+ CHECK(!visited.count(this)) << "Cyclic dependency detected during import";
this->imports_.emplace_back(std::move(other));
}
return pf;
}
-Module Module::LoadFromFile(const std::string& file_name,
- const std::string& format) {
+Module Module::LoadFromFile(const std::string& file_name, const std::string& format) {
std::string fmt = GetFileFormat(file_name, format);
- CHECK(fmt.length() != 0)
- << "Cannot deduce format of file " << file_name;
+ CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so";
}
std::string load_f_name = "runtime.module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name);
- CHECK(f != nullptr)
- << "Loader of " << format << "("
- << load_f_name << ") is not presented.";
+ CHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name << ") is not presented.";
Module m = (*f)(file_name, format);
return m;
}
-void ModuleNode::SaveToFile(const std::string& file_name,
- const std::string& format) {
+void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}
}
if (pf == nullptr) {
const PackedFunc* f = Registry::Get(name);
- CHECK(f != nullptr)
- << "Cannot find function " << name
- << " in the imported modules or global registry";
+ CHECK(f != nullptr) << "Cannot find function " << name
+ << " in the imported modules or global registry";
return f;
} else {
import_cache_.insert(std::make_pair(name, std::make_shared<PackedFunc>(pf)));
return runtime::Registry::Get(f_name) != nullptr;
}
-TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled")
-.set_body_typed(RuntimeEnabled);
+TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled);
-TVM_REGISTER_GLOBAL("runtime.ModuleGetSource")
-.set_body_typed([](Module mod, std::string fmt) {
+TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) {
return mod->GetSource(fmt);
});
-TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize")
-.set_body_typed([](Module mod) {
+TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) {
return static_cast<int64_t>(mod->imports().size());
});
-TVM_REGISTER_GLOBAL("runtime.ModuleGetImport")
-.set_body_typed([](Module mod, int index) {
+TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) {
return mod->imports().at(index);
});
-TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey")
-.set_body_typed([](Module mod) {
+TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
return std::string(mod->type_key());
});
-TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile")
-.set_body_typed(Module::LoadFromFile);
+TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);
TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
-.set_body_typed([](Module mod, std::string name, std::string fmt) {
- mod->SaveToFile(name, fmt);
-});
+ .set_body_typed([](Module mod, std::string name, std::string fmt) {
+ mod->SaveToFile(name, fmt);
+ });
TVM_REGISTER_OBJECT_TYPE(ModuleNode);
} // namespace runtime
 * \brief NDArray container infrastructure.
*/
#include <dmlc/logging.h>
-#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/ndarray.h>
+
#include "runtime_base.h"
extern "C" {
// allow uint1 as a special flag for bool.
if (dtype.bits == 1 && dtype.code == kDLUInt) return;
// allow int1/uint4/int4
- else if (dtype.bits == 1 && dtype.code == kDLInt) return;
- else if (dtype.bits == 4 && dtype.code == kDLUInt) return;
- else if (dtype.bits == 4 && dtype.code == kDLInt) return;
+ else if (dtype.bits == 1 && dtype.code == kDLInt)
+ return;
+ else if (dtype.bits == 4 && dtype.code == kDLUInt)
+ return;
+ else if (dtype.bits == 4 && dtype.code == kDLInt)
+ return;
else
CHECK_EQ(dtype.bits % 8, 0);
}
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
- CHECK_EQ(arr_size, nbytes)
- << "ArrayCopyFromBytes: size mismatch";
- DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
- data, 0,
- handle->data, static_cast<size_t>(handle->byte_offset),
- nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr);
+ CHECK_EQ(arr_size, nbytes) << "ArrayCopyFromBytes: size mismatch";
+ DeviceAPI::Get(handle->ctx)
+ ->CopyDataFromTo(data, 0, handle->data, static_cast<size_t>(handle->byte_offset), nbytes,
+ cpu_ctx, handle->ctx, handle->dtype, nullptr);
}
void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) {
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
- CHECK_EQ(arr_size, nbytes)
- << "ArrayCopyToBytes: size mismatch";
- DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
- handle->data, static_cast<size_t>(handle->byte_offset),
- data, 0,
- nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr);
+ CHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch";
+ DeviceAPI::Get(handle->ctx)
+ ->CopyDataFromTo(handle->data, static_cast<size_t>(handle->byte_offset), data, 0, nbytes,
+ handle->ctx, cpu_ctx, handle->dtype, nullptr);
}
struct NDArray::Internal {
if (ptr->manager_ctx != nullptr) {
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
} else if (ptr->dl_tensor.data != nullptr) {
- tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
- ptr->dl_tensor.ctx, ptr->dl_tensor.data);
+ tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)
+ ->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data);
}
delete ptr;
}
}
// Local create function which allocates tensor metadata
// but does not allocate space for the data.
- static NDArray Create(std::vector<int64_t> shape,
- DLDataType dtype,
- DLContext ctx) {
+ static NDArray Create(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
VerifyDataType(dtype);
// critical zone: construct header
ObjectRef::FFIClearAfterMove(&arr);
return handle;
}
- static void FFIDecRef(TVMArrayHandle tensor) {
- NDArray::FFIDecRef(tensor);
- }
+ static void FFIDecRef(TVMArrayHandle tensor) { NDArray::FFIDecRef(tensor); }
// Container to DLManagedTensor
static DLManagedTensor* ToDLPack(TVMArrayHandle handle) {
- auto* from = static_cast<NDArray::Container*>(
- reinterpret_cast<NDArray::ContainerBase*>(handle));
+ auto* from =
+ static_cast<NDArray::Container*>(reinterpret_cast<NDArray::ContainerBase*>(handle));
return ToDLPack(from);
}
NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
CHECK(data_ != nullptr);
- CHECK(get_mutable()->dl_tensor.strides == nullptr)
- << "Can only create view for compact tensor";
+ CHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx);
- ret.get_mutable()->dl_tensor.byte_offset =
- this->get_mutable()->dl_tensor.byte_offset;
+ ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor);
CHECK_LE(view_size, curr_size)
return ret;
}
-DLManagedTensor* NDArray::ToDLPack() const {
- return Internal::ToDLPack(get_mutable());
-}
+DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); }
-NDArray NDArray::Empty(std::vector<int64_t> shape,
- DLDataType dtype,
- DLContext ctx) {
+NDArray NDArray::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
size_t size = GetDataSize(ret.get_mutable()->dl_tensor);
size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor);
ret.get_mutable()->dl_tensor.data =
- DeviceAPI::Get(ret->ctx)->AllocDataSpace(
- ret->ctx, size, alignment, ret->dtype);
+ DeviceAPI::Get(ret->ctx)->AllocDataSpace(ret->ctx, size, alignment, ret->dtype);
return ret;
}
ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes);
}
-void NDArray::CopyFromTo(const DLTensor* from,
- DLTensor* to,
- TVMStreamHandle stream) {
+void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
size_t from_size = GetDataSize(*from);
size_t to_size = GetDataSize(*to);
- CHECK_EQ(from_size, to_size)
- << "TVMArrayCopyFromTo: The size must exactly match";
+ CHECK_EQ(from_size, to_size) << "TVMArrayCopyFromTo: The size must exactly match";
- CHECK(from->ctx.device_type == to->ctx.device_type
- || from->ctx.device_type == kDLCPU
- || to->ctx.device_type == kDLCPU
- || from->ctx.device_type == kDLCPUPinned
- || to->ctx.device_type == kDLCPUPinned)
- << "Can not copy across different ctx types directly";
+ CHECK(from->ctx.device_type == to->ctx.device_type || from->ctx.device_type == kDLCPU ||
+ to->ctx.device_type == kDLCPU || from->ctx.device_type == kDLCPUPinned ||
+ to->ctx.device_type == kDLCPUPinned)
+ << "Can not copy across different ctx types directly";
// Use the context that is *not* a cpu context to get the correct device
// api manager.
TVMContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx;
- DeviceAPI::Get(ctx)->CopyDataFromTo(
- from->data, static_cast<size_t>(from->byte_offset),
- to->data, static_cast<size_t>(to->byte_offset),
- from_size, from->ctx, to->ctx, from->dtype, stream);
+ DeviceAPI::Get(ctx)->CopyDataFromTo(from->data, static_cast<size_t>(from->byte_offset), to->data,
+ static_cast<size_t>(to->byte_offset), from_size, from->ctx,
+ to->ctx, from->dtype, stream);
}
-std::vector<int64_t> NDArray::Shape() const {
- return get_mutable()->shape_;
-}
+std::vector<int64_t> NDArray::Shape() const { return get_mutable()->shape_; }
TVM_REGISTER_OBJECT_TYPE(NDArray::Container);
API_END();
}
-int TVMArrayAlloc(const tvm_index_t* shape,
- int ndim,
- int dtype_code,
- int dtype_bits,
- int dtype_lanes,
- int device_type,
- int device_id,
- TVMArrayHandle* out) {
+int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits,
+ int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) {
API_BEGIN();
DLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code);
API_END();
}
-int TVMArrayCopyFromTo(TVMArrayHandle from,
- TVMArrayHandle to,
- TVMStreamHandle stream) {
+int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream) {
API_BEGIN();
NDArray::CopyFromTo(from, to, stream);
API_END();
}
-int TVMArrayFromDLPack(DLManagedTensor* from,
- TVMArrayHandle* out) {
+int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from));
API_END();
}
-int TVMArrayToDLPack(TVMArrayHandle from,
- DLManagedTensor** out) {
+int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) {
API_BEGIN();
*out = NDArray::Internal::ToDLPack(from);
API_END();
}
-void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) {
- (*(dltensor->deleter))(dltensor);
-}
+void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) { (*(dltensor->deleter))(dltensor); }
-int TVMArrayCopyFromBytes(TVMArrayHandle handle,
- void* data,
- size_t nbytes) {
+int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) {
API_BEGIN();
ArrayCopyFromBytes(handle, data, nbytes);
API_END();
}
-int TVMArrayCopyToBytes(TVMArrayHandle handle,
- void* data,
- size_t nbytes) {
+int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes) {
API_BEGIN();
ArrayCopyToBytes(handle, data, nbytes);
API_END();
* \brief Object type management system.
*/
#include <dmlc/logging.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/registry.h>
+
#include <mutex>
#include <string>
-#include <vector>
-#include <utility>
#include <unordered_map>
+#include <utility>
+#include <vector>
+
#include "object_internal.h"
#include "runtime_base.h"
return child_tindex == parent_tindex;
}
- uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey,
- uint32_t static_tindex,
- uint32_t parent_tindex,
- uint32_t num_child_slots,
+ uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex,
+ uint32_t parent_tindex, uint32_t num_child_slots,
bool child_slots_can_overflow) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = type_key2index_.find(skey);
allocated_tindex = static_tindex;
CHECK_LT(static_tindex, type_table_.size());
CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
- << "Conflicting static index " << static_tindex
- << " between " << type_table_[allocated_tindex].name
- << " and "
- << skey;
+ << "Conflicting static index " << static_tindex << " between "
+ << type_table_[allocated_tindex].name << " and " << skey;
} else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
type_table_[allocated_tindex].parent_index = parent_tindex;
type_table_[allocated_tindex].num_slots = num_slots;
type_table_[allocated_tindex].allocated_slots = 1;
- type_table_[allocated_tindex].child_slots_can_overflow =
- child_slots_can_overflow;
+ type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow;
type_table_[allocated_tindex].name = skey;
type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey);
// update the key2index mapping.
std::string TypeIndex2Key(uint32_t tindex) {
std::lock_guard<std::mutex> lock(mutex_);
- CHECK(tindex < type_table_.size() &&
- type_table_[tindex].allocated_slots != 0)
+ CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
<< "Unknown type index " << tindex;
return type_table_[tindex].name;
}
size_t TypeIndex2KeyHash(uint32_t tindex) {
std::lock_guard<std::mutex> lock(mutex_);
- CHECK(tindex < type_table_.size() &&
- type_table_[tindex].allocated_slots != 0)
+ CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
<< "Unknown type index " << tindex;
return type_table_[tindex].name_hash;
}
for (const auto& info : type_table_) {
if (info.index != 0 && num_children[info.index] >= min_children_count) {
- std::cerr <<'[' << info.index << "] "<< info.name
+ std::cerr << '[' << info.index << "] " << info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
<< "\tnum_children=" << num_children[info.index] << std::endl;
std::unordered_map<std::string, uint32_t> type_key2index_;
};
-uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key,
- uint32_t static_tindex,
- uint32_t parent_tindex,
- uint32_t num_child_slots,
+uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex,
+ uint32_t parent_tindex, uint32_t num_child_slots,
bool child_slots_can_overflow) {
return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
}
bool Object::DerivedFrom(uint32_t parent_tindex) const {
- return TypeContext::Global()->DerivedFrom(
- this->type_index_, parent_tindex);
+ return TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex);
}
std::string Object::TypeIndex2Key(uint32_t tindex) {
return TypeContext::Global()->TypeKey2Index(key);
}
-
-TVM_REGISTER_GLOBAL("runtime.ObjectHash")
-.set_body_typed([](ObjectRef obj) {
+TVM_REGISTER_GLOBAL("runtime.ObjectHash").set_body_typed([](ObjectRef obj) {
return static_cast<int64_t>(ObjectHash()(obj));
});
-TVM_REGISTER_GLOBAL("runtime.DumpTypeTable")
-.set_body_typed([](int min_child_count) {
+TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) {
TypeContext::Global()->Dump(min_child_count);
});
} // namespace runtime
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
API_BEGIN();
- out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(
- type_key);
+ out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);
API_END();
}
#ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_
#define TVM_RUNTIME_OBJECT_INTERNAL_H_
-#include <tvm/runtime/object.h>
#include <tvm/runtime/module.h>
+#include <tvm/runtime/object.h>
+
#include <string>
namespace tvm {
} // namespace runtime
} // namespace tvm
-#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_
+#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_
#define TVM_RUNTIME_OPENCL_AOCL_AOCL_COMMON_H_
#include <memory>
+
#include "../opencl_common.h"
namespace tvm {
static const std::shared_ptr<OpenCLWorkspace>& Global();
};
-
/*! \brief Thread local workspace for AOCL */
class AOCLThreadEntry : public OpenCLThreadEntry {
public:
/*!
* \file aocl_device_api.cc
*/
-#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
+#include <tvm/runtime/registry.h>
+
#include "aocl_common.h"
namespace tvm {
namespace runtime {
namespace cl {
-OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() {
- return AOCLThreadEntry::ThreadLocal();
-}
+OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); }
const std::shared_ptr<OpenCLWorkspace>& AOCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<AOCLWorkspace>();
typedef dmlc::ThreadLocalStore<AOCLThreadEntry> AOCLThreadStore;
-AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() {
- return AOCLThreadStore::Get();
-}
+AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); }
-TVM_REGISTER_GLOBAL("device_api.aocl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = AOCLWorkspace::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = AOCLWorkspace::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace cl
} // namespace runtime
/*!
* \file aocl_module.cc
*/
+#include "aocl_module.h"
+
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
-#include <vector>
+
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "aocl_common.h"
-#include "aocl_module.h"
namespace tvm {
namespace runtime {
class AOCLModuleNode : public OpenCLModuleNode {
public:
- explicit AOCLModuleNode(std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source)
+ explicit AOCLModuleNode(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
};
return cl::AOCLWorkspace::Global();
}
-Module AOCLModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module AOCLModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<AOCLModuleNode>(data, fmt, fmap, source);
n->Init();
return Module(n);
}
-Module AOCLModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module AOCLModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
return AOCLModuleCreate(data, fmt, fmap, std::string());
}
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx")
-.set_body_typed(AOCLModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx").set_body_typed(AOCLModuleLoadFile);
} // namespace runtime
} // namespace tvm
#define TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_
#include <tvm/runtime/packed_func.h>
+
#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../../meta_data.h"
namespace tvm {
* \param fmt The format of the data, can be "aocx"
* \param fmap The map function information map of each function.
*/
-Module AOCLModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source);
+Module AOCLModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_
#ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_
#define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_
+#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
-#include <dmlc/logging.h>
+#include <tvm/runtime/packed_func.h>
/* There are many OpenCL platforms that do not yet support OpenCL 2.0,
* hence we use 1.2 APIs, some of which are now deprecated. In order
#include <CL/opencl.h>
#endif
+#include <memory>
#include <mutex>
#include <string>
-#include <vector>
-#include <memory>
#include <unordered_map>
-#include "../workspace_pool.h"
+#include <vector>
+
+#include "../file_util.h"
+#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
-#include "../meta_data.h"
-#include "../file_util.h"
+#include "../workspace_pool.h"
namespace tvm {
namespace runtime {
namespace cl {
-static_assert(sizeof(cl_mem) ==sizeof(void*),
- "Required to store cl_mem inside void*");
+static_assert(sizeof(cl_mem) == sizeof(void*), "Required to store cl_mem inside void*");
inline const char* CLGetErrorString(cl_int error) {
switch (error) {
- case CL_SUCCESS: return "CL_SUCCESS";
- case CL_DEVICE_NOT_FOUND: return "CL_DEVICE_NOT_FOUND";
- case CL_DEVICE_NOT_AVAILABLE: return "CL_DEVICE_NOT_AVAILABLE";
- case CL_COMPILER_NOT_AVAILABLE: return "CL_COMPILER_NOT_AVAILABLE";
- case CL_MEM_OBJECT_ALLOCATION_FAILURE: return "CL_MEM_OBJECT_ALLOCATION_FAILURE";
- case CL_OUT_OF_RESOURCES: return "CL_OUT_OF_RESOURCES";
- case CL_OUT_OF_HOST_MEMORY: return "CL_OUT_OF_HOST_MEMORY";
- case CL_PROFILING_INFO_NOT_AVAILABLE: return "CL_PROFILING_INFO_NOT_AVAILABLE";
- case CL_MEM_COPY_OVERLAP: return "CL_MEM_COPY_OVERLAP";
- case CL_IMAGE_FORMAT_MISMATCH: return "CL_IMAGE_FORMAT_MISMATCH";
- case CL_IMAGE_FORMAT_NOT_SUPPORTED: return "CL_IMAGE_FORMAT_NOT_SUPPORTED";
- case CL_BUILD_PROGRAM_FAILURE: return "CL_BUILD_PROGRAM_FAILURE";
- case CL_MAP_FAILURE: return "CL_MAP_FAILURE";
- case CL_INVALID_VALUE: return "CL_INVALID_VALUE";
- case CL_INVALID_DEVICE_TYPE: return "CL_INVALID_DEVICE_TYPE";
- case CL_INVALID_PLATFORM: return "CL_INVALID_PLATFORM";
- case CL_INVALID_DEVICE: return "CL_INVALID_DEVICE";
- case CL_INVALID_CONTEXT: return "CL_INVALID_CONTEXT";
- case CL_INVALID_QUEUE_PROPERTIES: return "CL_INVALID_QUEUE_PROPERTIES";
- case CL_INVALID_COMMAND_QUEUE: return "CL_INVALID_COMMAND_QUEUE";
- case CL_INVALID_HOST_PTR: return "CL_INVALID_HOST_PTR";
- case CL_INVALID_MEM_OBJECT: return "CL_INVALID_MEM_OBJECT";
- case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR";
- case CL_INVALID_IMAGE_SIZE: return "CL_INVALID_IMAGE_SIZE";
- case CL_INVALID_SAMPLER: return "CL_INVALID_SAMPLER";
- case CL_INVALID_BINARY: return "CL_INVALID_BINARY";
- case CL_INVALID_BUILD_OPTIONS: return "CL_INVALID_BUILD_OPTIONS";
- case CL_INVALID_PROGRAM: return "CL_INVALID_PROGRAM";
- case CL_INVALID_PROGRAM_EXECUTABLE: return "CL_INVALID_PROGRAM_EXECUTABLE";
- case CL_INVALID_KERNEL_NAME: return "CL_INVALID_KERNEL_NAME";
- case CL_INVALID_KERNEL_DEFINITION: return "CL_INVALID_KERNEL_DEFINITION";
- case CL_INVALID_KERNEL: return "CL_INVALID_KERNEL";
- case CL_INVALID_ARG_INDEX: return "CL_INVALID_ARG_INDEX";
- case CL_INVALID_ARG_VALUE: return "CL_INVALID_ARG_VALUE";
- case CL_INVALID_ARG_SIZE: return "CL_INVALID_ARG_SIZE";
- case CL_INVALID_KERNEL_ARGS: return "CL_INVALID_KERNEL_ARGS";
- case CL_INVALID_WORK_DIMENSION: return "CL_INVALID_WORK_DIMENSION";
- case CL_INVALID_WORK_GROUP_SIZE: return "CL_INVALID_WORK_GROUP_SIZE";
- case CL_INVALID_WORK_ITEM_SIZE: return "CL_INVALID_WORK_ITEM_SIZE";
- case CL_INVALID_GLOBAL_OFFSET: return "CL_INVALID_GLOBAL_OFFSET";
- case CL_INVALID_EVENT_WAIT_LIST: return "CL_INVALID_EVENT_WAIT_LIST";
- case CL_INVALID_EVENT: return "CL_INVALID_EVENT";
- case CL_INVALID_OPERATION: return "CL_INVALID_OPERATION";
- case CL_INVALID_GL_OBJECT: return "CL_INVALID_GL_OBJECT";
- case CL_INVALID_BUFFER_SIZE: return "CL_INVALID_BUFFER_SIZE";
- case CL_INVALID_MIP_LEVEL: return "CL_INVALID_MIP_LEVEL";
- default: return "Unknown OpenCL error code";
+ case CL_SUCCESS:
+ return "CL_SUCCESS";
+ case CL_DEVICE_NOT_FOUND:
+ return "CL_DEVICE_NOT_FOUND";
+ case CL_DEVICE_NOT_AVAILABLE:
+ return "CL_DEVICE_NOT_AVAILABLE";
+ case CL_COMPILER_NOT_AVAILABLE:
+ return "CL_COMPILER_NOT_AVAILABLE";
+ case CL_MEM_OBJECT_ALLOCATION_FAILURE:
+ return "CL_MEM_OBJECT_ALLOCATION_FAILURE";
+ case CL_OUT_OF_RESOURCES:
+ return "CL_OUT_OF_RESOURCES";
+ case CL_OUT_OF_HOST_MEMORY:
+ return "CL_OUT_OF_HOST_MEMORY";
+ case CL_PROFILING_INFO_NOT_AVAILABLE:
+ return "CL_PROFILING_INFO_NOT_AVAILABLE";
+ case CL_MEM_COPY_OVERLAP:
+ return "CL_MEM_COPY_OVERLAP";
+ case CL_IMAGE_FORMAT_MISMATCH:
+ return "CL_IMAGE_FORMAT_MISMATCH";
+ case CL_IMAGE_FORMAT_NOT_SUPPORTED:
+ return "CL_IMAGE_FORMAT_NOT_SUPPORTED";
+ case CL_BUILD_PROGRAM_FAILURE:
+ return "CL_BUILD_PROGRAM_FAILURE";
+ case CL_MAP_FAILURE:
+ return "CL_MAP_FAILURE";
+ case CL_INVALID_VALUE:
+ return "CL_INVALID_VALUE";
+ case CL_INVALID_DEVICE_TYPE:
+ return "CL_INVALID_DEVICE_TYPE";
+ case CL_INVALID_PLATFORM:
+ return "CL_INVALID_PLATFORM";
+ case CL_INVALID_DEVICE:
+ return "CL_INVALID_DEVICE";
+ case CL_INVALID_CONTEXT:
+ return "CL_INVALID_CONTEXT";
+ case CL_INVALID_QUEUE_PROPERTIES:
+ return "CL_INVALID_QUEUE_PROPERTIES";
+ case CL_INVALID_COMMAND_QUEUE:
+ return "CL_INVALID_COMMAND_QUEUE";
+ case CL_INVALID_HOST_PTR:
+ return "CL_INVALID_HOST_PTR";
+ case CL_INVALID_MEM_OBJECT:
+ return "CL_INVALID_MEM_OBJECT";
+ case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR:
+ return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR";
+ case CL_INVALID_IMAGE_SIZE:
+ return "CL_INVALID_IMAGE_SIZE";
+ case CL_INVALID_SAMPLER:
+ return "CL_INVALID_SAMPLER";
+ case CL_INVALID_BINARY:
+ return "CL_INVALID_BINARY";
+ case CL_INVALID_BUILD_OPTIONS:
+ return "CL_INVALID_BUILD_OPTIONS";
+ case CL_INVALID_PROGRAM:
+ return "CL_INVALID_PROGRAM";
+ case CL_INVALID_PROGRAM_EXECUTABLE:
+ return "CL_INVALID_PROGRAM_EXECUTABLE";
+ case CL_INVALID_KERNEL_NAME:
+ return "CL_INVALID_KERNEL_NAME";
+ case CL_INVALID_KERNEL_DEFINITION:
+ return "CL_INVALID_KERNEL_DEFINITION";
+ case CL_INVALID_KERNEL:
+ return "CL_INVALID_KERNEL";
+ case CL_INVALID_ARG_INDEX:
+ return "CL_INVALID_ARG_INDEX";
+ case CL_INVALID_ARG_VALUE:
+ return "CL_INVALID_ARG_VALUE";
+ case CL_INVALID_ARG_SIZE:
+ return "CL_INVALID_ARG_SIZE";
+ case CL_INVALID_KERNEL_ARGS:
+ return "CL_INVALID_KERNEL_ARGS";
+ case CL_INVALID_WORK_DIMENSION:
+ return "CL_INVALID_WORK_DIMENSION";
+ case CL_INVALID_WORK_GROUP_SIZE:
+ return "CL_INVALID_WORK_GROUP_SIZE";
+ case CL_INVALID_WORK_ITEM_SIZE:
+ return "CL_INVALID_WORK_ITEM_SIZE";
+ case CL_INVALID_GLOBAL_OFFSET:
+ return "CL_INVALID_GLOBAL_OFFSET";
+ case CL_INVALID_EVENT_WAIT_LIST:
+ return "CL_INVALID_EVENT_WAIT_LIST";
+ case CL_INVALID_EVENT:
+ return "CL_INVALID_EVENT";
+ case CL_INVALID_OPERATION:
+ return "CL_INVALID_OPERATION";
+ case CL_INVALID_GL_OBJECT:
+ return "CL_INVALID_GL_OBJECT";
+ case CL_INVALID_BUFFER_SIZE:
+ return "CL_INVALID_BUFFER_SIZE";
+ case CL_INVALID_MIP_LEVEL:
+ return "CL_INVALID_MIP_LEVEL";
+ default:
+ return "Unknown OpenCL error code";
}
}
* \brief Protected OpenCL call
* \param func Expression to call.
*/
-#define OPENCL_CHECK_ERROR(e) \
- { \
- CHECK(e == CL_SUCCESS) \
- << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); \
- }
+#define OPENCL_CHECK_ERROR(e) \
+ { CHECK(e == CL_SUCCESS) << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); }
-#define OPENCL_CALL(func) \
- { \
- cl_int e = (func); \
- OPENCL_CHECK_ERROR(e); \
+#define OPENCL_CALL(func) \
+ { \
+ cl_int e = (func); \
+ OPENCL_CHECK_ERROR(e); \
}
class OpenCLThreadEntry;
// Initialzie the device.
void Init(const std::string& type_key, const std::string& device_type,
const std::string& platform_name = "");
- virtual void Init() {
- Init("opencl", "gpu");
- }
+ virtual void Init() { Init("opencl", "gpu"); }
// Check whether the context is OpenCL or not.
- virtual bool IsOpenCLDevice(TVMContext ctx) {
- return ctx.device_type == kDLOpenCL;
- }
+ virtual bool IsOpenCLDevice(TVMContext ctx) { return ctx.device_type == kDLOpenCL; }
// get the queue of the context
cl_command_queue GetQueue(TVMContext ctx) {
CHECK(IsOpenCLDevice(ctx));
this->Init();
- CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
+ CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
<< "Invalid OpenCL device_id=" << ctx.device_id;
return queues[ctx.device_id];
}
// override device API
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
- void* AllocDataSpace(TVMContext ctx,
- size_t size,
- size_t alignment,
- DLDataType type_hint) final;
+ void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
static const std::shared_ptr<OpenCLWorkspace>& Global();
};
-
/*! \brief Thread local workspace */
class OpenCLThreadEntry {
public:
context.device_id = 0;
context.device_type = device_type;
}
- OpenCLThreadEntry()
- : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {}
+ OpenCLThreadEntry() : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {}
// get the global workspace
static OpenCLThreadEntry* ThreadLocal();
size_t kernel_id;
size_t version;
};
- explicit OpenCLModuleNode(std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source)
+ explicit OpenCLModuleNode(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
// destructor
~OpenCLModuleNode();
const char* type_key() const final { return workspace_->type_key.c_str(); }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
- void SaveToFile(const std::string& file_name,
- const std::string& format) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
+ void SaveToFile(const std::string& file_name, const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
std::string GetSource(const std::string& format) final;
// Initialize the programs
void Init();
// install a new kernel to thread local entry
- cl_kernel InstallKernel(cl::OpenCLWorkspace* w,
- cl::OpenCLThreadEntry* t,
- const std::string& func_name,
- const KTRefEntry& e);
+ cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
+ const std::string& func_name, const KTRefEntry& e);
private:
// The workspace, need to keep reference to use it in destructor.
/*!
* \file opencl_device_api.cc
*/
-#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
+#include <tvm/runtime/registry.h>
+
#include "opencl_common.h"
namespace tvm {
namespace runtime {
namespace cl {
-OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() {
- return OpenCLThreadEntry::ThreadLocal();
-}
+OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); }
const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
GetThreadEntry()->context.device_id = ctx.device_id;
}
-void OpenCLWorkspace::GetAttr(
- TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
+void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
- *rv = static_cast<int>(index< devices.size());
+ *rv = static_cast<int>(index < devices.size());
return;
}
- CHECK_LT(index, devices.size())
- << "Invalid device id " << index;
+ CHECK_LT(index, devices.size()) << "Invalid device id " << index;
switch (kind) {
- case kExist: break;
+ case kExist:
+ break;
case kMaxThreadsPerBlock: {
size_t value;
- OPENCL_CALL(clGetDeviceInfo(
- devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE,
- sizeof(size_t), &value, nullptr));
+ OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t),
+ &value, nullptr));
*rv = static_cast<int64_t>(value);
break;
}
}
case kMaxSharedMemoryPerBlock: {
cl_ulong value;
- OPENCL_CALL(clGetDeviceInfo(
- devices[index], CL_DEVICE_LOCAL_MEM_SIZE,
- sizeof(cl_ulong), &value, nullptr));
+ OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong),
+ &value, nullptr));
*rv = static_cast<int64_t>(value);
break;
}
- case kComputeVersion: return;
+ case kComputeVersion:
+ return;
case kDeviceName: {
char value[128] = {0};
- OPENCL_CALL(clGetDeviceInfo(
- devices[index], CL_DEVICE_NAME,
- sizeof(value) - 1, value, nullptr));
+ OPENCL_CALL(
+ clGetDeviceInfo(devices[index], CL_DEVICE_NAME, sizeof(value) - 1, value, nullptr));
*rv = std::string(value);
break;
}
case kMaxClockRate: {
cl_uint value;
- OPENCL_CALL(clGetDeviceInfo(
- devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY,
- sizeof(cl_uint), &value, nullptr));
+ OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint),
+ &value, nullptr));
*rv = static_cast<int32_t>(value);
break;
}
case kMultiProcessorCount: {
cl_uint value;
- OPENCL_CALL(clGetDeviceInfo(
- devices[index], CL_DEVICE_MAX_COMPUTE_UNITS,
- sizeof(cl_uint), &value, nullptr));
+ OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint),
+ &value, nullptr));
*rv = static_cast<int32_t>(value);
break;
}
case kMaxThreadDimensions: {
size_t dims[3];
- OPENCL_CALL(clGetDeviceInfo(
- devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr));
+ OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims,
+ nullptr));
std::stringstream ss; // use json string to return multiple int values;
- ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
+ ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
break;
}
- case kGcnArch: return;
+ case kGcnArch:
+ return;
}
}
-void* OpenCLWorkspace::AllocDataSpace(
- TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) {
+void* OpenCLWorkspace::AllocDataSpace(TVMContext ctx, size_t size, size_t alignment,
+ DLDataType type_hint) {
this->Init();
CHECK(context != nullptr) << "No OpenCL device";
cl_int err_code;
- cl_mem mptr = clCreateBuffer(
- this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
+ cl_mem mptr = clCreateBuffer(this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
return mptr;
}
OPENCL_CALL(clReleaseMemObject(mptr));
}
-void OpenCLWorkspace::CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+void OpenCLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
+ size_t to_offset, size_t size, TVMContext ctx_from,
+ TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) {
this->Init();
CHECK(stream == nullptr);
if (IsOpenCLDevice(ctx_from) && IsOpenCLDevice(ctx_to)) {
- OPENCL_CALL(clEnqueueCopyBuffer(
- this->GetQueue(ctx_to),
- static_cast<cl_mem>((void*)from), // NOLINT(*)
- static_cast<cl_mem>(to),
- from_offset, to_offset, size, 0, nullptr, nullptr));
+ OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(ctx_to),
+ static_cast<cl_mem>((void*)from), // NOLINT(*)
+ static_cast<cl_mem>(to), from_offset, to_offset, size, 0,
+ nullptr, nullptr));
} else if (IsOpenCLDevice(ctx_from) && ctx_to.device_type == kDLCPU) {
- OPENCL_CALL(clEnqueueReadBuffer(
- this->GetQueue(ctx_from),
- static_cast<cl_mem>((void*)from), // NOLINT(*)
- CL_FALSE, from_offset, size,
- static_cast<char*>(to) + to_offset,
- 0, nullptr, nullptr));
+ OPENCL_CALL(clEnqueueReadBuffer(this->GetQueue(ctx_from),
+ static_cast<cl_mem>((void*)from), // NOLINT(*)
+ CL_FALSE, from_offset, size, static_cast<char*>(to) + to_offset,
+ 0, nullptr, nullptr));
OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
} else if (ctx_from.device_type == kDLCPU && IsOpenCLDevice(ctx_to)) {
- OPENCL_CALL(clEnqueueWriteBuffer(
- this->GetQueue(ctx_to),
- static_cast<cl_mem>(to),
- CL_FALSE, to_offset, size,
- static_cast<const char*>(from) + from_offset,
- 0, nullptr, nullptr));
+ OPENCL_CALL(clEnqueueWriteBuffer(this->GetQueue(ctx_to), static_cast<cl_mem>(to), CL_FALSE,
+ to_offset, size, static_cast<const char*>(from) + from_offset,
+ 0, nullptr, nullptr));
OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
} else {
LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL";
OPENCL_CALL(clFinish(this->GetQueue(ctx)));
}
-void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
- size_t size,
- DLDataType type_hint) {
+void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
return GetThreadEntry()->pool.AllocWorkspace(ctx, size);
}
typedef dmlc::ThreadLocalStore<OpenCLThreadEntry> OpenCLThreadStore;
-OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
- return OpenCLThreadStore::Get();
-}
+OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { return OpenCLThreadStore::Get(); }
-std::string GetPlatformInfo(
- cl_platform_id pid, cl_platform_info param_name) {
+std::string GetPlatformInfo(cl_platform_id pid, cl_platform_info param_name) {
size_t ret_size;
OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size));
std::string ret;
return ret;
}
-std::string GetDeviceInfo(
- cl_device_id pid, cl_device_info param_name) {
+std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name) {
size_t ret_size;
OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size));
std::string ret;
return ret;
}
-std::vector<cl_device_id> GetDeviceIDs(
- cl_platform_id pid, std::string device_type) {
+std::vector<cl_device_id> GetDeviceIDs(cl_platform_id pid, std::string device_type) {
cl_device_type dtype = CL_DEVICE_TYPE_ALL;
if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU;
if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU;
return ret;
}
-bool MatchPlatformInfo(
- cl_platform_id pid,
- cl_platform_info param_name,
- std::string value) {
+bool MatchPlatformInfo(cl_platform_id pid, cl_platform_info param_name, std::string value) {
if (value.length() == 0) return true;
std::string param_value = GetPlatformInfo(pid, param_name);
return param_value.find(value) != std::string::npos;
return;
}
cl_int err_code;
- this->context = clCreateContext(
- nullptr, this->devices.size(), &(this->devices[0]),
- nullptr, nullptr, &err_code);
+ this->context = clCreateContext(nullptr, this->devices.size(), &(this->devices[0]), nullptr,
+ nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
CHECK_EQ(this->queues.size(), 0U);
for (size_t i = 0; i < this->devices.size(); ++i) {
cl_device_id did = this->devices[i];
- this->queues.push_back(
- clCreateCommandQueue(this->context, did, 0, &err_code));
+ this->queues.push_back(clCreateCommandQueue(this->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
}
initialized_ = true;
}
-TVM_REGISTER_GLOBAL("device_api.opencl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = OpenCLWorkspace::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = OpenCLWorkspace::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace cl
} // namespace runtime
/*!
* \file opencl_module.cc
*/
+#include "opencl_module.h"
+
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
-#include <vector>
+
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "opencl_common.h"
-#include "opencl_module.h"
namespace tvm {
namespace runtime {
class OpenCLWrappedFunc {
public:
// initialize the OpenCL function.
- void Init(OpenCLModuleNode* m,
- ObjectPtr<Object> sptr,
- OpenCLModuleNode::KTRefEntry entry,
- std::string func_name,
- std::vector<size_t> arg_size,
- const std::vector<std::string>& thread_axis_tags) {
+ void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
+ std::string func_name, std::vector<size_t> arg_size,
+ const std::vector<std::string>& thread_axis_tags) {
w_ = m->GetGlobalWorkspace().get();
m_ = m;
sptr_ = sptr;
thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags);
}
// invoke the function with void arguments
- void operator()(TVMArgs args,
- TVMRetValue* rv,
- void** void_args) const {
+ void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
CHECK(w_->context != nullptr) << "No OpenCL device";
cl::OpenCLThreadEntry* t = w_->GetThreadEntry();
// get the kernel from thread local kernel table.
wl.work_size[i] *= wl.work_size[i + 3];
}
// launch kernel
- OPENCL_CALL(clEnqueueNDRangeKernel(
- queue, kernel, work_dim, nullptr,
- wl.work_size,
- wl.work_size + 3,
- 0, nullptr, nullptr));
+ OPENCL_CALL(clEnqueueNDRangeKernel(queue, kernel, work_dim, nullptr, wl.work_size,
+ wl.work_size + 3, 0, nullptr, nullptr));
}
private:
return cl::OpenCLWorkspace::Global();
}
-PackedFunc OpenCLModuleNode::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc OpenCLModuleNode::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
- CHECK_NE(name, symbol::tvm_module_main)
- << "Device function do not have main";
+ CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
}
}
// initialize the wrapped func.
- f.Init(this, sptr_to_self, kid_map_.at(name),
- name, arg_size, info.thread_axis_tags);
+ f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags);
return PackFuncVoidAddr(f, info.arg_types);
}
-void OpenCLModuleNode::SaveToFile(const std::string& file_name,
- const std::string& format) {
+void OpenCLModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
std::string fmt = GetFileFormat(file_name, format);
- CHECK_EQ(fmt, fmt_)
- << "Can only save to format=" << fmt_;
+ CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
}
-cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w,
- cl::OpenCLThreadEntry* t,
- const std::string& func_name,
- const KTRefEntry& e) {
+cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
+ const std::string& func_name, const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int device_id = t->context.device_id;
if (!device_built_flag_[device_id]) {
OPENCL_CHECK_ERROR(err);
}
} else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") {
- const unsigned char* s = (const unsigned char *)data_.c_str();
+ const unsigned char* s = (const unsigned char*)data_.c_str();
size_t len = data_.length();
cl_int err;
cl_device_id dev = w->devices[device_id];
if (err != CL_SUCCESS) {
size_t len;
std::string log;
- clGetProgramBuildInfo(
- program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
+ clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
log.resize(len);
- clGetProgramBuildInfo(
- program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
+ clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << log;
}
device_built_flag_[device_id] = true;
return kernel;
}
-Module OpenCLModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module OpenCLModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
n->Init();
return Module(n);
}
// Load module from module.
-Module OpenCLModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module OpenCLModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
return OpenCLModuleCreate(data, fmt, fmap, std::string());
}
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl")
-.set_body_typed(OpenCLModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile);
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin")
-.set_body_typed(OpenCLModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile);
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl")
-.set_body_typed(OpenCLModuleLoadBinary);
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary);
} // namespace runtime
} // namespace tvm
#define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
#include <tvm/runtime/packed_func.h>
+
#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../meta_data.h"
namespace tvm {
* \param fmt The format of the data, can be "clbin", "cl"
* \param fmap The map function information map of each function.
*/
-Module OpenCLModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source);
+Module OpenCLModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_COMMON_H_
#include <memory>
+
#include "../opencl_common.h"
namespace tvm {
static const std::shared_ptr<OpenCLWorkspace>& Global();
};
-
/*! \brief Thread local workspace for SDAccel*/
class SDAccelThreadEntry : public OpenCLThreadEntry {
public:
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
/*!
* \file sdaccel_device_api.cc
*/
-#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
+#include <tvm/runtime/registry.h>
+
#include "sdaccel_common.h"
namespace tvm {
namespace runtime {
namespace cl {
-OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() {
- return SDAccelThreadEntry::ThreadLocal();
-}
+OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); }
const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<SDAccelWorkspace>();
return inst;
}
-void SDAccelWorkspace::Init() {
- OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx");
-}
+void SDAccelWorkspace::Init() { OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx"); }
bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) {
return ctx.device_type == static_cast<DLDeviceType>(kDLSDAccel);
typedef dmlc::ThreadLocalStore<SDAccelThreadEntry> SDAccelThreadStore;
-SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() {
- return SDAccelThreadStore::Get();
-}
+SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); }
-TVM_REGISTER_GLOBAL("device_api.sdaccel")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = SDAccelWorkspace::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = SDAccelWorkspace::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace cl
} // namespace runtime
/*!
* \file sdaccel_module.cc
*/
+#include "sdaccel_module.h"
+
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
-#include <vector>
+
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "sdaccel_common.h"
-#include "sdaccel_module.h"
namespace tvm {
namespace runtime {
class SDAccelModuleNode : public OpenCLModuleNode {
public:
- explicit SDAccelModuleNode(std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source)
+ explicit SDAccelModuleNode(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
};
return cl::SDAccelWorkspace::Global();
}
-Module SDAccelModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module SDAccelModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<SDAccelModuleNode>(data, fmt, fmap, source);
n->Init();
return Module(n);
}
-Module SDAccelModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module SDAccelModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
return SDAccelModuleCreate(data, fmt, fmap, std::string());
}
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin")
-.set_body_typed(SDAccelModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin").set_body_typed(SDAccelModuleLoadFile);
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin")
-.set_body_typed(SDAccelModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin").set_body_typed(SDAccelModuleLoadFile);
} // namespace runtime
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_
#include <tvm/runtime/packed_func.h>
+
#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../../meta_data.h"
namespace tvm {
* \param fmt The format of the data, can be "xclbin", "awsxclbin"
* \param fmap The map function information map of each function.
*/
-Module SDAccelModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source);
+Module SDAccelModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_
#ifndef TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_
#define TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_
+#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
-#include <dmlc/logging.h>
+#include <tvm/runtime/packed_func.h>
#if defined(__APPLE__)
#define GLFW_INCLUDE_GLCOREARB
#endif
#include <GLFW/glfw3.h>
+
+#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
-#include <memory>
namespace tvm {
namespace runtime {
return proc;
}
-#define SetGLFunctionPointer(NAME) \
- NAME(decltype(NAME)(GetProcAddress("gl" #NAME)))
+#define SetGLFunctionPointer(NAME) NAME(decltype(NAME)(GetProcAddress("gl" #NAME)))
/*!
* \brief The function pointers of all OpenGL APIs that are used.
void (*BindFramebuffer)(GLenum target, GLuint framebuffer);
void (*BindTexture)(GLenum target, GLuint texture);
void (*BindVertexArray)(GLuint array);
- void (*BufferData)(GLenum target, GLsizeiptr size, const GLvoid* data,
- GLenum usage);
+ void (*BufferData)(GLenum target, GLsizeiptr size, const GLvoid* data, GLenum usage);
GLenum (*CheckFramebufferStatus)(GLenum target);
void (*Clear)(GLbitfield mask);
void (*CompileShader)(GLuint shader);
void (*DrawBuffers)(GLsizei n, const GLenum* bufs);
void (*EnableVertexAttribArray)(GLuint index);
void (*Finish)();
- void (*FramebufferTexture2D)(GLenum target, GLenum attachment,
- GLenum textarget, GLuint texture, GLint level);
+ void (*FramebufferTexture2D)(GLenum target, GLenum attachment, GLenum textarget, GLuint texture,
+ GLint level);
void (*GenBuffers)(GLsizei n, GLuint* buffers);
void (*GenFramebuffers)(GLsizei n, GLuint* ids);
void (*GenTextures)(GLsizei n, GLuint* textures);
GLint (*GetAttribLocation)(GLuint program, const GLchar* name);
GLenum (*GetError)();
void (*GetIntegerv)(GLenum pname, GLint* data);
- void (*GetProgramInfoLog)(GLuint program, GLsizei maxLength, GLsizei* length,
- GLchar* info_log);
+ void (*GetProgramInfoLog)(GLuint program, GLsizei maxLength, GLsizei* length, GLchar* info_log);
void (*GetProgramiv)(GLuint program, GLenum pname, GLint* params);
- void (*GetShaderInfoLog)(GLuint shader, GLsizei max_length, GLsizei* length,
- GLchar* info_log);
+ void (*GetShaderInfoLog)(GLuint shader, GLsizei max_length, GLsizei* length, GLchar* info_log);
void (*GetShaderiv)(GLuint shader, GLenum pname, GLint* params);
- const GLubyte *(*GetString)(GLenum name);
+ const GLubyte* (*GetString)(GLenum name);
GLint (*GetUniformLocation)(GLuint program, const GLchar* name);
void (*LinkProgram)(GLuint program);
- void (*ReadPixels)(GLint x, GLint y, GLsizei width, GLsizei height,
- GLenum format, GLenum type, GLvoid* data);
- void (*ShaderSource)(GLuint shader, GLsizei count, const GLchar** string,
- const GLint* length);
- void (*TexImage2D)(GLenum target, GLint level, GLint internal_format,
- GLsizei width, GLsizei height, GLint border, GLenum format,
- GLenum type, const GLvoid* data);
+ void (*ReadPixels)(GLint x, GLint y, GLsizei width, GLsizei height, GLenum format, GLenum type,
+ GLvoid* data);
+ void (*ShaderSource)(GLuint shader, GLsizei count, const GLchar** string, const GLint* length);
+ void (*TexImage2D)(GLenum target, GLint level, GLint internal_format, GLsizei width,
+ GLsizei height, GLint border, GLenum format, GLenum type, const GLvoid* data);
void (*TexParameteri)(GLenum target, GLenum pname, GLint param);
- void (*TexSubImage2D)(GLenum target, GLint level, GLint xoffset,
- GLint yoffset, GLsizei width, GLsizei height,
- GLenum format, GLenum type, const GLvoid* data);
+ void (*TexSubImage2D)(GLenum target, GLint level, GLint xoffset, GLint yoffset, GLsizei width,
+ GLsizei height, GLenum format, GLenum type, const GLvoid* data);
void (*Uniform1f)(GLint location, GLfloat v0);
void (*Uniform1i)(GLint location, GLint v0);
void (*UseProgram)(GLuint program);
- void (*VertexAttribPointer)(GLuint index, GLint size, GLenum type,
- GLboolean normalized, GLsizei stride,
- const GLvoid* pointer);
+ void (*VertexAttribPointer)(GLuint index, GLint size, GLenum type, GLboolean normalized,
+ GLsizei stride, const GLvoid* pointer);
void (*Viewport)(GLint x, GLint y, GLsizei width, GLsizei height);
};
// override device API
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
- DLDataType type_hint) final;
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
* \param nelems The number of elements to be written to.
* \param data The user data.
*/
- void PutTextureData(Texture* texture,
- GLint begin,
- GLsizei nelems,
- const GLvoid* data);
+ void PutTextureData(Texture* texture, GLint begin, GLsizei nelems, const GLvoid* data);
/*!
* \brief Download a sub-region of an OpenGL texture.
* \param texture The texture to download from.
* \param nelems The number of elements to download from.
* \param data The user buffer.
*/
- void GetTextureData(const Texture* texture,
- GLint begin,
- GLsizei nelems,
- GLvoid* data);
+ void GetTextureData(const Texture* texture, GLint begin, GLsizei nelems, GLvoid* data);
/*!
* \brief Set currently used OpenGL program.
* \param type The type of the uniform.
* \param value The value to pass in.
*/
- void SetUniform(const Program& program,
- const std::string& name,
- DLDataType type,
- void* value);
+ void SetUniform(const Program& program, const std::string& name, DLDataType type, void* value);
/*!
* \brief Set input texture for an OpenGL program.
* different unit.
* \param texture The OpenGL texture to pass in.
*/
- void SetInputTexture(const Program& program,
- const std::string& name,
- GLuint unit,
+ void SetInputTexture(const Program& program, const std::string& name, GLuint unit,
Texture* texture);
/*!
class Program {
public:
// Move constructor.
- Program(Program&& other) noexcept
- : workspace_(other.workspace_), program_(other.program_) {
+ Program(Program&& other) noexcept : workspace_(other.workspace_), program_(other.program_) {
other.program_ = kInvalidProgram;
}
GLsizei elemsz() const {
switch (type) {
- case GL_BYTE: case GL_UNSIGNED_BYTE:
+ case GL_BYTE:
+ case GL_UNSIGNED_BYTE:
return 1;
- case GL_SHORT: case GL_UNSIGNED_SHORT:
+ case GL_SHORT:
+ case GL_UNSIGNED_SHORT:
return 2;
- case GL_INT: case GL_UNSIGNED_INT:
+ case GL_INT:
+ case GL_UNSIGNED_INT:
return 4;
case GL_FLOAT:
return 4;
bool operator==(const TextureFormat& other) const {
return std::make_tuple(internal_format, format, type) ==
- std::make_tuple(other.internal_format, other.format, other.type);
+ std::make_tuple(other.internal_format, other.format, other.type);
}
GLint internal_format; // OpenGL says this is GLint, not GLenum.
public:
// Move constructor.
Texture(Texture&& other) noexcept
- : workspace_(other.workspace_), texture_(other.texture_),
- format_(other.format_), width_(other.width_), height_(other.height_) {
+ : workspace_(other.workspace_),
+ texture_(other.texture_),
+ format_(other.format_),
+ width_(other.width_),
+ height_(other.height_) {
other.texture_ = kInvalidTexture;
}
// We enforce this to make sure OpenGL is initialized.
// Always only use the first dimension of a 2D texture.
// The reason is that texelFetch only supports 2D textures.
- explicit Texture(OpenGLWorkspace* workspace, GLuint texture,
- TextureFormat format,
- GLsizei width, GLsizei height)
- : workspace_(workspace), texture_(texture), format_(format),
- width_(width), height_(height) {}
+ explicit Texture(OpenGLWorkspace* workspace, GLuint texture, TextureFormat format, GLsizei width,
+ GLsizei height)
+ : workspace_(workspace), texture_(texture), format_(format), width_(width), height_(height) {}
// The internal texture ID.
GLuint texture() const { return texture_; }
* \file opengl_device_api.cc
*/
#include <tvm/runtime/registry.h>
+
#include <cstring>
+
#include "opengl_common.h"
#include "opengl_module.h"
*/
void OpenGLWorkspace::CheckOpenGLError() {
GLenum err = gl->GetError();
- CHECK_EQ(err, GL_NO_ERROR) << "OpenGL error, code=" << err << ": "
- << gl::GLGetErrorString(err);
+ CHECK_EQ(err, GL_NO_ERROR) << "OpenGL error, code=" << err << ": " << gl::GLGetErrorString(err);
}
/*!
* \brief Protected OpenGL call.
* \param func Expression to call.
*/
-#define OPENGL_CALL(func) \
- { \
- (func); \
- CheckOpenGLError(); \
+#define OPENGL_CALL(func) \
+ { \
+ (func); \
+ CheckOpenGLError(); \
}
/*!
* \brief The error handling callback passed to GLFW.
*/
-void GlfwErrorCallback(int err, const char* str) {
- LOG(FATAL) << "Error: [" << err << "] " << str;
-}
+void GlfwErrorCallback(int err, const char* str) { LOG(FATAL) << "Error: [" << err << "] " << str; }
const std::shared_ptr<OpenGLWorkspace>& OpenGLWorkspace::Global() {
static std::shared_ptr<OpenGLWorkspace> inst(new OpenGLWorkspace);
}
void OpenGLWorkspace::SetDevice(TVMContext ctx) {
- CHECK_EQ(ctx.device_type, static_cast<int>(kOpenGL))
- << "Device type must be OpenGL.";
+ CHECK_EQ(ctx.device_type, static_cast<int>(kOpenGL)) << "Device type must be OpenGL.";
CHECK_EQ(ctx.device_id, 0) << "Only support 1 OpenGL \"device\".";
}
-void OpenGLWorkspace::GetAttr(
- TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
+void OpenGLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
switch (kind) {
case kExist: {
*rv = static_cast<int>(ctx.device_id == 0);
*rv = 1;
break;
}
- case kMaxSharedMemoryPerBlock: return;
+ case kMaxSharedMemoryPerBlock:
+ return;
case kComputeVersion: {
break;
}
- case kDeviceName: return;
- case kMaxClockRate: return;
- case kMultiProcessorCount: return;
- case kMaxThreadDimensions: return;
- case kGcnArch: return;
+ case kDeviceName:
+ return;
+ case kMaxClockRate:
+ return;
+ case kMultiProcessorCount:
+ return;
+ case kMaxThreadDimensions:
+ return;
+ case kGcnArch:
+ return;
}
}
-void* OpenGLWorkspace::AllocDataSpace(
- TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) {
+void* OpenGLWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
+ DLDataType type_hint) {
return reinterpret_cast<void*>(new Texture(CreateTexture(type_hint, nbytes)));
}
delete reinterpret_cast<Texture*>(ptr);
}
-void OpenGLWorkspace::CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+void OpenGLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
+ size_t to_offset, size_t size, TVMContext ctx_from,
+ TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) {
CHECK(stream == nullptr);
} else if (type_from_to == std::make_tuple(gl_devtype, kDLCPU)) {
auto texture = static_cast<const Texture*>(from);
- void *data = static_cast<char *>(to) + to_offset;
+ void* data = static_cast<char*>(to) + to_offset;
auto elemsz = texture->elemsz();
auto begin = static_cast<GLint>(from_offset / elemsz);
auto nelems = static_cast<GLsizei>(size / elemsz);
GLuint vertex_buffer;
OPENGL_CALL(gl->GenBuffers(1, &vertex_buffer));
OPENGL_CALL(gl->BindBuffer(GL_ARRAY_BUFFER, vertex_buffer));
- OPENGL_CALL(gl->BufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices,
- GL_STATIC_DRAW));
+ OPENGL_CALL(gl->BufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, GL_STATIC_DRAW));
GLuint vertex_array;
OPENGL_CALL(gl->GenVertexArrays(1, &vertex_array));
OPENGL_CALL(gl->DeleteTextures(1, &texture));
}
-void OpenGLWorkspace::OnDeleteProgram(GLuint program) {
- OPENGL_CALL(gl->DeleteProgram(program));
-}
+void OpenGLWorkspace::OnDeleteProgram(GLuint program) { OPENGL_CALL(gl->DeleteProgram(program)); }
GLuint OpenGLWorkspace::NumTextureUnits() {
GLint num_units;
}
const OpenGLWorkspace::Vertex OpenGLWorkspace::vertices[OpenGLWorkspace::kNumVertices] = {
- {-1.f, -1.f},
- {1.0f, -1.f},
- {1.0f, 1.0f},
- {-1.f, -1.f},
- {-1.f, 1.0f},
- {1.0f, 1.0f},
+ {-1.f, -1.f}, {1.0f, -1.f}, {1.0f, 1.0f}, {-1.f, -1.f}, {-1.f, 1.0f}, {1.0f, 1.0f},
};
// Don't need to change this.
// The vertex shader only needs to take in the triangle points.
// No need for point transformations.
-const char* OpenGLWorkspace::vertex_shader_text_ = "#version 300 es\n"
+const char* OpenGLWorkspace::vertex_shader_text_ =
+ "#version 300 es\n"
"in vec2 point; // input to vertex shader\n"
"void main() {\n"
" gl_Position = vec4(point, 0.0, 1.0);\n"
"}\n";
-Program OpenGLWorkspace::CreateProgram(
- const char* fragment_shader_src) {
+Program OpenGLWorkspace::CreateProgram(const char* fragment_shader_src) {
// Create and compile the shaders.
- GLuint fragment_shader = CreateShader(GL_FRAGMENT_SHADER,
- fragment_shader_src);
+ GLuint fragment_shader = CreateShader(GL_FRAGMENT_SHADER, fragment_shader_src);
// Link the shaders and create the program.
Program program = CreateProgram(fragment_shader);
return program;
}
-GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind,
- const char* shader_src) {
+GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind, const char* shader_src) {
// Create the shader.
GLuint shader = gl->CreateShader(shader_kind);
gl->ShaderSource(shader, 1, &shader_src, nullptr);
auto nelems = static_cast<GLsizei>(nbytes / (type.bits / 8));
auto height = (nelems + kTextureRowSize - 1) / kTextureRowSize;
auto width = (height == 1) ? nelems : kTextureRowSize;
- OPENGL_CALL(gl->TexImage2D(GL_TEXTURE_2D, /*level=*/0,
- texture_format.internal_format,
- width, height, /*border=*/0,
- texture_format.format, texture_format.type,
+ OPENGL_CALL(gl->TexImage2D(GL_TEXTURE_2D, /*level=*/0, texture_format.internal_format, width,
+ height, /*border=*/0, texture_format.format, texture_format.type,
/*data=*/nullptr));
- OPENGL_CALL(
- gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE));
- OPENGL_CALL(
- gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE));
- OPENGL_CALL(
- gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST));
- OPENGL_CALL(
- gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST));
+ OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE));
+ OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE));
+ OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST));
+ OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST));
return Texture(this, texture, texture_format, width, height);
}
auto point_attrib = GLuint(gl->GetAttribLocation(program, "point"));
OPENGL_CALL(gl->EnableVertexAttribArray(point_attrib));
- OPENGL_CALL(gl->VertexAttribPointer(point_attrib, 2, GL_FLOAT, GL_FALSE,
- sizeof(Vertex), nullptr));
+ OPENGL_CALL(
+ gl->VertexAttribPointer(point_attrib, 2, GL_FLOAT, GL_FALSE, sizeof(Vertex), nullptr));
return Program(this, program);
}
on_2d_block(0, ylast, xlast + 1, 1);
}
-void OpenGLWorkspace::PutTextureData(Texture *texture,
- GLint begin,
- GLsizei nelems,
+void OpenGLWorkspace::PutTextureData(Texture* texture, GLint begin, GLsizei nelems,
const GLvoid* data) {
// Bind to temporary unit.
BindTextureUnit(NumTextureUnits() - 1, texture->texture());
- Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg,
- GLsizei width, GLsizei height) {
+ Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) {
auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz();
const GLvoid* ptr = static_cast<const char*>(data) + offset;
// Similar to cudaMemcpy.
- OPENGL_CALL(gl->TexSubImage2D(GL_TEXTURE_2D, /*level=*/0,
- xbeg, ybeg, width, height,
- texture->format_.format,
- texture->format_.type, ptr));
+ OPENGL_CALL(gl->TexSubImage2D(GL_TEXTURE_2D, /*level=*/0, xbeg, ybeg, width, height,
+ texture->format_.format, texture->format_.type, ptr));
});
}
-void OpenGLWorkspace::GetTextureData(const Texture *texture,
- GLint begin,
- GLsizei nelems,
+void OpenGLWorkspace::GetTextureData(const Texture* texture, GLint begin, GLsizei nelems,
GLvoid* data) {
BindTextureUnit(NumTextureUnits() - 1, texture->texture());
OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer));
// Bind texture to framebuffer's attachment 0.
- OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
- GL_TEXTURE_2D, texture->texture(), 0));
+ OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
+ texture->texture(), 0));
// Always check that our framebuffer is okay.
if (gl->CheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) {
auto nchannels = 4;
auto padded_data_size = nchannels * nelems * elemsz;
auto padded_data = std::unique_ptr<char[]>(new char[padded_data_size]);
- Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg,
- GLsizei width, GLsizei height) {
+ Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) {
auto data_offset = (ybeg * kTextureRowSize + xbeg - begin) * elemsz;
auto padded_data_offset = data_offset * nchannels;
- OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height,
- GL_RGBA, GL_FLOAT,
+ OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, GL_RGBA, GL_FLOAT,
padded_data.get() + padded_data_offset));
});
for (GLsizei i = 0; i != nelems; ++i) {
- auto dst = reinterpret_cast<char *>(data) + i * elemsz;
+ auto dst = reinterpret_cast<char*>(data) + i * elemsz;
auto src = padded_data.get() + nchannels * i * elemsz;
std::memcpy(dst, src, elemsz);
}
#else
- Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg,
- GLsizei width, GLsizei height) {
+ Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) {
auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz();
GLvoid* ptr = static_cast<char*>(data) + offset;
- OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height,
- texture->format_.format, texture->format_.type,
- ptr));
+ OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, texture->format_.format,
+ texture->format_.type, ptr));
});
#endif
OPENGL_CALL(gl->UseProgram(program.program()));
}
-void OpenGLWorkspace::SetUniform(const Program& program,
- const std::string& name,
- DLDataType type,
+void OpenGLWorkspace::SetUniform(const Program& program, const std::string& name, DLDataType type,
void* value) {
GLint location = gl->GetUniformLocation(program.program(), name.c_str());
switch (type.code) {
}
}
-void OpenGLWorkspace::SetInputTexture(const Program& program,
- const std::string& name,
- GLuint unit,
+void OpenGLWorkspace::SetInputTexture(const Program& program, const std::string& name, GLuint unit,
Texture* texture) {
// We always use the last texture unit as temporary.
// Therefore, we can have "NumTextureUnits() - 1" input textures.
OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer));
// Set "renderedTexture" as our colour attachement 0.
- OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
- GL_TEXTURE_2D, output->texture(), 0));
+ OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
+ output->texture(), 0));
// Specify that we will render to color attachment 0.
GLenum DrawBuffers[1] = {GL_COLOR_ATTACHMENT0};
OPENGL_CALL(gl->DeleteFramebuffers(1, &frame_buffer));
}
-TVM_REGISTER_GLOBAL("device_api.opengl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("device_api.opengl").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenGLWorkspace::Global().get();
*rv = static_cast<void*>(ptr);
});
/*!
* \file opengl_module.cc
*/
+#include "opengl_module.h"
+
#include <tvm/runtime/registry.h>
-#include <utility>
+
#include <unordered_map>
-#include "opengl_common.h"
-#include "opengl_module.h"
+#include <utility>
+
+#include "../file_util.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
-#include "../file_util.h"
+#include "opengl_common.h"
namespace tvm {
namespace runtime {
class OpenGLModuleNode final : public ModuleNode {
public:
- OpenGLModuleNode(std::unordered_map<std::string, OpenGLShader> shaders,
- std::string fmt,
+ OpenGLModuleNode(std::unordered_map<std::string, OpenGLShader> shaders, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap);
~OpenGLModuleNode() override = default;
const char* type_key() const final { return "opengl"; }
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
std::string GetSource(const std::string& format) final;
- void SaveToFile(const std::string& file_name,
- const std::string& format) final;
+ void SaveToFile(const std::string& file_name, const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
class OpenGLWrappedFunc {
public:
- OpenGLWrappedFunc(OpenGLModuleNode* m,
- ObjectPtr<Object> sptr,
- std::string func_name,
- std::vector<size_t> arg_size,
- const std::vector<std::string>& thread_axis_tags);
+ OpenGLWrappedFunc(OpenGLModuleNode* m, ObjectPtr<Object> sptr, std::string func_name,
+ std::vector<size_t> arg_size, const std::vector<std::string>& thread_axis_tags);
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const;
ThreadAxisConfig thread_axis_cfg_;
};
-OpenGLModuleNode::OpenGLModuleNode(
- std::unordered_map<std::string, OpenGLShader> shaders,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap)
- : workspace_(gl::OpenGLWorkspace::Global()), shaders_(std::move(shaders)),
- fmt_(std::move(fmt)), fmap_(std::move(fmap)), programs_() {
+OpenGLModuleNode::OpenGLModuleNode(std::unordered_map<std::string, OpenGLShader> shaders,
+ std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap)
+ : workspace_(gl::OpenGLWorkspace::Global()),
+ shaders_(std::move(shaders)),
+ fmt_(std::move(fmt)),
+ fmap_(std::move(fmap)),
+ programs_() {
CHECK_EQ(fmt_, "gl") << "Unknown OpenGL format " << fmt_;
- for (auto &pair : shaders_) {
- auto &func_name = pair.first;
- auto &shader = pair.second;
- programs_.emplace(func_name,
- workspace_->CreateProgram(shader.source.c_str()));
+ for (auto& pair : shaders_) {
+ auto& func_name = pair.first;
+ auto& shader = pair.second;
+ programs_.emplace(func_name, workspace_->CreateProgram(shader.source.c_str()));
}
}
-PackedFunc OpenGLModuleNode::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc OpenGLModuleNode::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto func_info_it = fmap_.find(name);
- if (func_info_it == fmap_.end()) { return PackedFunc(); }
- auto &func_info = func_info_it->second;
+ if (func_info_it == fmap_.end()) {
+ return PackedFunc();
+ }
+ auto& func_info = func_info_it->second;
std::vector<size_t> arg_size(func_info.arg_types.size());
for (size_t i = 0; i < func_info.arg_types.size(); ++i) {
}
// Initialize the wrapped func.
- OpenGLWrappedFunc f(this, sptr_to_self, name, arg_size,
- func_info.thread_axis_tags);
+ OpenGLWrappedFunc f(this, sptr_to_self, name, arg_size, func_info.thread_axis_tags);
return PackFuncVoidAddr(f, func_info.arg_types);
}
std::string OpenGLModuleNode::GetSource(const std::string& format) {
- if (format != fmt_ && fmt_ != "gl") { return ""; }
+ if (format != fmt_ && fmt_ != "gl") {
+ return "";
+ }
std::ostringstream os;
- for (auto &pair : shaders_) {
- auto &name = pair.first;
- auto &shader = pair.second;
- os << "[" << name << "]" << "\n";
- os << shader.source <<"\n";
+ for (auto& pair : shaders_) {
+ auto& name = pair.first;
+ auto& shader = pair.second;
+ os << "[" << name << "]"
+ << "\n";
+ os << shader.source << "\n";
}
return os.str();
}
-void OpenGLModuleNode::SaveToFile(const std::string& file_name,
- const std::string& format) {
+void OpenGLModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
stream->Write(ToJSON(shaders_));
}
-const gl::Program& OpenGLModuleNode::GetProgram(
- const std::string& func_name) const {
+const gl::Program& OpenGLModuleNode::GetProgram(const std::string& func_name) const {
auto it = programs_.find(func_name);
if (it == programs_.end()) {
LOG(FATAL) << "Cannot find program";
return it->second;
}
-const OpenGLShader& OpenGLModuleNode::GetShader(
- const std::string& func_name) const {
+const OpenGLShader& OpenGLModuleNode::GetShader(const std::string& func_name) const {
auto it = shaders_.find(func_name);
if (it == shaders_.end()) {
LOG(FATAL) << "Cannot find shader";
return it->second;
}
-const FunctionInfo& OpenGLModuleNode::GetFunctionInfo(
- const std::string& func_name) const {
+const FunctionInfo& OpenGLModuleNode::GetFunctionInfo(const std::string& func_name) const {
auto it = fmap_.find(func_name);
if (it == fmap_.end()) {
LOG(FATAL) << "Cannot find shader";
return it->second;
}
-OpenGLWrappedFunc::OpenGLWrappedFunc(
- OpenGLModuleNode* m,
- ObjectPtr<Object> sptr,
- std::string func_name,
- std::vector<size_t> arg_size,
- const std::vector<std::string>& thread_axis_tags)
- : m_(m), sptr_(std::move(sptr)), func_name_(std::move(func_name)),
+OpenGLWrappedFunc::OpenGLWrappedFunc(OpenGLModuleNode* m, ObjectPtr<Object> sptr,
+ std::string func_name, std::vector<size_t> arg_size,
+ const std::vector<std::string>& thread_axis_tags)
+ : m_(m),
+ sptr_(std::move(sptr)),
+ func_name_(std::move(func_name)),
arg_size_(std::move(arg_size)) {
thread_axis_cfg_.Init(arg_size_.size(), thread_axis_tags);
}
-void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
- void** void_args) const {
- auto &shader = m_->GetShader(func_name_);
- auto &program = m_->GetProgram(func_name_);
- auto &func_info = m_->GetFunctionInfo(func_name_);
+void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
+ auto& shader = m_->GetShader(func_name_);
+ auto& program = m_->GetProgram(func_name_);
+ auto& func_info = m_->GetFunctionInfo(func_name_);
size_t nargs = shader.arg_kinds.size();
// Must call this function before setting uniforms & input textures.
GLuint texture_unit = 0;
gl::Texture* output = nullptr;
for (size_t i = 0; i != nargs; ++i) {
- auto &name = shader.arg_names.at(i);
+ auto& name = shader.arg_names.at(i);
auto kind = shader.arg_kinds.at(i);
auto type = func_info.arg_types.at(i);
switch (kind) {
// Set "thread_extent" uniform.
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
std::unique_ptr<GLint> thread_extent(new GLint(wl.block_dim(0)));
- m_->workspace().SetUniform(program, shader.thread_extent_var,
- DLDataType{kDLInt, 32, 1},
+ m_->workspace().SetUniform(program, shader.thread_extent_var, DLDataType{kDLInt, 32, 1},
static_cast<void*>(thread_extent.get()));
m_->workspace().Render(output);
}
-Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders,
- std::string fmt,
+Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap) {
- auto n = make_object<OpenGLModuleNode>(std::move(shaders),
- std::move(fmt),
- std::move(fmap));
+ auto n = make_object<OpenGLModuleNode>(std::move(shaders), std::move(fmt), std::move(fmap));
return Module(n);
}
-Module OpenGLModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module OpenGLModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
return OpenGLModuleCreate(FromJSON(data), fmt, fmap);
}
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = OpenGLModuleLoadFile(args[0], args[1]);
- });
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = OpenGLModuleLoadFile(args[0], args[1]);
+});
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = OpenGLModuleLoadFile(args[0], args[1]);
- });
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = OpenGLModuleLoadFile(args[0], args[1]);
+});
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = OpenGLModuleLoadBinary(args[0]);
- });
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = OpenGLModuleLoadBinary(args[0]);
+});
} // namespace runtime
} // namespace tvm
#define TVM_RUNTIME_OPENGL_OPENGL_MODULE_H_
#include <tvm/runtime/packed_func.h>
+
#include <algorithm>
#include <memory>
#include <string>
-#include <vector>
-#include <utility>
#include <unordered_map>
+#include <utility>
+#include <vector>
+
#include "../meta_data.h"
namespace tvm {
*/
struct OpenGLShader {
OpenGLShader() = default;
- OpenGLShader(std::string source,
- std::vector<std::string> arg_names,
- std::vector<OpenGLArgKind> arg_kinds,
- std::string thread_extent_var)
- : source(std::move(source)), arg_names(std::move(arg_names)),
+ OpenGLShader(std::string source, std::vector<std::string> arg_names,
+ std::vector<OpenGLArgKind> arg_kinds, std::string thread_extent_var)
+ : source(std::move(source)),
+ arg_names(std::move(arg_names)),
arg_kinds(std::move(arg_kinds)),
thread_extent_var(std::move(thread_extent_var)) {
CHECK_EQ(this->arg_names.size(), this->arg_kinds.size()) << "Invalid input";
* \param fmt The format of the data,
* \param fmap The map function information map of each function.
*/
-Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders,
- std::string fmt,
+Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap);
inline std::string OpenGLArgKind2String(OpenGLArgKind kind) {
}
}
-inline std::string ToJSON(
- const std::unordered_map<std::string, OpenGLShader>& shaders) {
+inline std::string ToJSON(const std::unordered_map<std::string, OpenGLShader>& shaders) {
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.BeginObject();
return os.str();
}
-inline std::unordered_map<std::string, OpenGLShader> FromJSON(
- const std::string& str) {
+inline std::unordered_map<std::string, OpenGLShader> FromJSON(const std::string& str) {
std::unordered_map<std::string, OpenGLShader> shaders;
std::istringstream is(str);
dmlc::JSONReader reader(&is);
#define TVM_RUNTIME_PACK_ARGS_H_
#include <tvm/runtime/c_runtime_api.h>
-#include <vector>
+
#include <cstring>
+#include <vector>
namespace tvm {
namespace runtime {
*
* \return The wrapped packed function.
*/
-template<typename F>
+template <typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types);
/*!
* \brief Create a packed function that from function only packs buffer arguments.
*
* \return The wrapped packed function.
*/
-template<typename F>
+template <typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types);
/*!
* \brief Create a packed function that from function that takes a packed arguments.
*
* \return The wrapped packed function.
*/
-template<typename F>
+template <typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types);
/*!
* \brief Extract number of buffer argument from the argument types.
// implementations details
namespace detail {
-template<typename T, int kSize>
+template <typename T, int kSize>
class TempArray {
public:
explicit TempArray(int size) {}
- T* data() {
- return data_;
- }
+ T* data() { return data_; }
+
private:
T data_[kSize];
};
-template<typename T>
+template <typename T>
class TempArray<T, 0> {
public:
explicit TempArray(int size) : data_(size) {}
- T* data() {
- return data_.data();
- }
+ T* data() { return data_.data(); }
+
private:
std::vector<T> data_;
};
};
inline ArgConvertCode GetArgConvertCode(DLDataType t) {
- CHECK_EQ(t.lanes, 1U)
- << "Cannot pass vector type argument to devic function for now";
+ CHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now";
if (t.code == kDLInt) {
if (t.bits == 64U) return INT64_TO_INT64;
if (t.bits == 32U) return INT64_TO_INT32;
return HANDLE_TO_HANDLE;
}
-template<int N, typename F>
+template <int N, typename F>
inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
addr[i] = &(holder[i]);
break;
}
- case INT64_TO_UINT32 : {
+ case INT64_TO_UINT32: {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
addr[i] = &(holder[i]);
break;
return PackedFunc(ret);
}
-template<int N, typename F>
-inline PackedFunc PackFuncNonBufferArg_(
- F f, int base, const std::vector<ArgConvertCode>& codes) {
+template <int N, typename F>
+inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<ArgUnion, N> holder_(num_args);
switch (codes[i]) {
case INT64_TO_INT64:
case FLOAT64_TO_FLOAT64: {
- LOG(FATAL) << "Do not support 64bit argument to device function"; break;
+ LOG(FATAL) << "Do not support 64bit argument to device function";
+ break;
}
case INT64_TO_INT32: {
holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
break;
}
- case INT64_TO_UINT32 : {
+ case INT64_TO_UINT32: {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
break;
}
break;
}
case HANDLE_TO_HANDLE: {
- LOG(FATAL) << "not reached"; break;
+ LOG(FATAL) << "not reached";
+ break;
}
}
}
return PackedFunc(ret);
}
-template<int N, typename F>
-inline PackedFunc PackFuncPackedArg_(
- F f, const std::vector<ArgConvertCode>& codes) {
+template <int N, typename F>
+inline PackedFunc PackFuncPackedArg_(F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<uint64_t, N> pack_(num_args);
++ptr;
break;
}
- case INT64_TO_UINT32 : {
- *reinterpret_cast<uint32_t*>(ptr) =
- static_cast<uint32_t>(args.values[i].v_int64);
+ case INT64_TO_UINT32: {
+ *reinterpret_cast<uint32_t*>(ptr) = static_cast<uint32_t>(args.values[i].v_int64);
++ptr;
break;
}
case FLOAT64_TO_FLOAT32: {
- *reinterpret_cast<float*>(ptr) =
- static_cast<float>(args.values[i].v_float64);
+ *reinterpret_cast<float*>(ptr) = static_cast<float>(args.values[i].v_float64);
++ptr;
break;
}
default: {
- LOG(FATAL) << "not reached"; break;
+ LOG(FATAL) << "not reached";
+ break;
}
}
}
}
} // namespace detail
-template<typename F>
+template <typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
size_t base = arg_types.size();
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i].code != kTVMOpaqueHandle) {
- base = i; break;
+ base = i;
+ break;
}
}
for (size_t i = base; i < arg_types.size(); ++i) {
- CHECK(arg_types[i].code != kTVMOpaqueHandle)
- << "Device function need to be organized";
+ CHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function need to be organized";
}
return base;
}
-template<typename F>
+template <typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types) {
size_t num_buffer = NumBufferArgs(arg_types);
std::vector<detail::ArgConvertCode> codes;
}
}
-template<typename F>
+template <typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes;
for (size_t i = 0; i < arg_types.size(); ++i) {
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
-#include <unordered_map>
-#include <mutex>
-#include <memory>
+
#include <array>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+
#include "runtime_base.h"
namespace tvm {
// mutex
std::mutex mutex;
- Manager() {
- }
+ Manager() {}
static Manager* Global() {
// We deliberately leak the Manager instance, to avoid leak sanitizers
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
if (m->fmap.count(name)) {
- CHECK(can_override)
- << "Global PackedFunc " << name << " is already registered";
+ CHECK(can_override) << "Global PackedFunc " << name << " is already registered";
}
Registry* r = new Registry();
std::lock_guard<std::mutex> lock(m->mutex);
std::vector<std::string> keys;
keys.reserve(m->fmap.size());
- for (const auto &kv : m->fmap) {
+ for (const auto& kv : m->fmap) {
keys.push_back(kv.first);
}
return keys;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
- std::vector<const char *> ret_vec_charp;
+ std::vector<const char*> ret_vec_charp;
};
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
-int TVMFuncRegisterGlobal(
- const char* name, TVMFunctionHandle f, int override) {
+int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {
API_BEGIN();
tvm::runtime::Registry::Register(name, override != 0)
.set_body(*static_cast<tvm::runtime::PackedFunc*>(f));
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
API_BEGIN();
- const tvm::runtime::PackedFunc* fp =
- tvm::runtime::Registry::Get(name);
+ const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name);
if (fp != nullptr) {
*out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*)
} else {
API_END();
}
-int TVMFuncListGlobalNames(int *out_size,
- const char*** out_array) {
+int TVMFuncListGlobalNames(int* out_size, const char*** out_array) {
API_BEGIN();
- TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get();
+ TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get();
ret->ret_vec_str = tvm::runtime::Registry::ListNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#ifndef TVM_RUNTIME_ROCM_ROCM_COMMON_H_
#define TVM_RUNTIME_ROCM_ROCM_COMMON_H_
-#include <tvm/runtime/packed_func.h>
#include <hip/hip_runtime_api.h>
+#include <tvm/runtime/packed_func.h>
+
#include <string>
+
#include "../workspace_pool.h"
namespace tvm {
namespace runtime {
-#define ROCM_DRIVER_CALL(x) \
- { \
- hipError_t result = x; \
- if (result != hipSuccess && result != hipErrorDeinitialized) { \
- LOG(FATAL) \
- << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \
- } \
+#define ROCM_DRIVER_CALL(x) \
+ { \
+ hipError_t result = x; \
+ if (result != hipSuccess && result != hipErrorDeinitialized) { \
+ LOG(FATAL) << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \
+ } \
}
-#define ROCM_CALL(func) \
- { \
- hipError_t e = (func); \
- CHECK(e == hipSuccess) \
- << "ROCM HIP: " << hipGetErrorString(e); \
+#define ROCM_CALL(func) \
+ { \
+ hipError_t e = (func); \
+ CHECK(e == hipSuccess) << "ROCM HIP: " << hipGetErrorString(e); \
}
/*! \brief Thread local workspace */
class ROCMDeviceAPI final : public DeviceAPI {
public:
- void SetDevice(TVMContext ctx) final {
- ROCM_CALL(hipSetDevice(ctx.device_id));
- }
+ void SetDevice(TVMContext ctx) final { ROCM_CALL(hipSetDevice(ctx.device_id)); }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
int value = 0;
switch (kind) {
break;
}
case kMaxThreadsPerBlock: {
- ROCM_CALL(hipDeviceGetAttribute(
- &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
+ ROCM_CALL(
+ hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
break;
}
case kWarpSize: {
- ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize,
- ctx.device_id));
+ ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: {
- ROCM_CALL(hipDeviceGetAttribute(
- &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id));
+ ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock,
+ ctx.device_id));
break;
}
case kComputeVersion: {
std::ostringstream os;
- ROCM_CALL(hipDeviceGetAttribute(
- &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
+ ROCM_CALL(
+ hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
os << value << ".";
- ROCM_CALL(hipDeviceGetAttribute(
- &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
+ ROCM_CALL(
+ hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
os << value;
*rv = os.str();
return;
return;
}
case kMaxClockRate: {
- ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate,
- ctx.device_id));
+ ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, ctx.device_id));
break;
}
case kMultiProcessorCount: {
- ROCM_CALL(hipDeviceGetAttribute(
- &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
+ ROCM_CALL(
+ hipDeviceGetAttribute(&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
break;
}
case kMaxThreadDimensions: {
int dims[3];
- ROCM_CALL(hipDeviceGetAttribute(
- &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
- ROCM_CALL(hipDeviceGetAttribute(
- &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
- ROCM_CALL(hipDeviceGetAttribute(
- &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));
+ ROCM_CALL(hipDeviceGetAttribute(&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
+ ROCM_CALL(hipDeviceGetAttribute(&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
+ ROCM_CALL(hipDeviceGetAttribute(&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));
std::stringstream ss;
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
ROCM_CALL(hipFree(ptr));
}
- void CopyDataFromTo(const void* from, size_t from_offset, void* to,
- size_t to_offset, size_t size, TVMContext ctx_from,
- TVMContext ctx_to, DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
} else {
- hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size,
- hip_stream);
+ hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, hip_stream);
}
- } else if (ctx_from.device_type == kDLROCM &&
- ctx_to.device_type == kDLCPU) {
+ } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
ROCM_CALL(hipSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
- } else if (ctx_from.device_type == kDLCPU &&
- ctx_to.device_type == kDLROCM) {
+ } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
ROCM_CALL(hipSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
} else {
}
static const std::shared_ptr<ROCMDeviceAPI>& Global() {
- static std::shared_ptr<ROCMDeviceAPI> inst =
- std::make_shared<ROCMDeviceAPI>();
+ static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
return inst;
}
private:
- static void GPUCopy(const void* from, void* to, size_t size,
- hipMemcpyKind kind, hipStream_t stream) {
+ static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind,
+ hipStream_t stream) {
if (stream != 0) {
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
} else {
ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
-ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
- return ROCMThreadStore::Get();
-}
+ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); }
-TVM_REGISTER_GLOBAL("device_api.rocm")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
} // namespace runtime
} // namespace tvm
/*!
* \file rocm_module.cc
*/
-#include <tvm/runtime/registry.h>
+#include "rocm_module.h"
+
#include <hip/hip_runtime_api.h>
-#include <vector>
+#include <tvm/runtime/registry.h>
+
#include <array>
-#include <string>
#include <mutex>
+#include <string>
#include <unordered_map>
-#include "rocm_module.h"
-#include "rocm_common.h"
+#include <vector>
+
+#include "../file_util.h"
+#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
-#include "../meta_data.h"
-#include "../file_util.h"
+#include "rocm_common.h"
namespace tvm {
namespace runtime {
// The modules will be lazily loaded
class ROCMModuleNode : public runtime::ModuleNode {
public:
- explicit ROCMModuleNode(std::string data,
- std::string fmt,
+ explicit ROCMModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
- std::string hip_source,
- std::string assembly)
- : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) {
+ std::string hip_source, std::string assembly)
+ : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) {
std::fill(module_.begin(), module_.end(), nullptr);
}
// destructor
}
}
- const char* type_key() const final {
- return "hip";
- }
-
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ const char* type_key() const final { return "hip"; }
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
// note: llvm and asm formats are not laodable, so we don't save them
}
std::string GetSource(const std::string& format) final {
- if (format == fmt_) { return data_; }
- if (format == "llvm" || format == "") { return hip_source_; }
- if (format == "asm") { return assembly_; }
+ if (format == fmt_) {
+ return data_;
+ }
+ if (format == "llvm" || format == "") {
+ return hip_source_;
+ }
+ if (format == "asm") {
+ return assembly_;
+ }
return "";
}
hipFunction_t func;
hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str());
if (result != hipSuccess) {
- LOG(FATAL)
- << "ROCMError: hipModuleGetFunction " << func_name
- << " failed with error: " << hipGetErrorString(result);
+ LOG(FATAL) << "ROCMError: hipModuleGetFunction " << func_name
+ << " failed with error: " << hipGetErrorString(result);
}
return func;
}
// get a global var from primary context in device_id
- hipDeviceptr_t GetGlobal(int device_id,
- const std::string& global_name,
- size_t expect_nbytes) {
+ hipDeviceptr_t GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
hipDeviceptr_t global = nullptr;
size_t nbytes = 0;
- ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes,
- module_[device_id], global_name.c_str()));
+ ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str()));
CHECK_EQ(nbytes, expect_nbytes);
return global;
}
class ROCMWrappedFunc {
public:
// initialize the ROCM function.
- void Init(ROCMModuleNode* m,
- ObjectPtr<Object> sptr,
- const std::string& func_name,
- size_t num_void_args,
- const std::vector<std::string>& thread_axis_tags) {
+ void Init(ROCMModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
+ size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
}
// invoke the function with void arguments
- void operator()(TVMArgs args,
- TVMRetValue* rv,
- void* packed_args,
- size_t packed_nbytes) const {
+ void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const {
int device_id;
ROCM_CALL(hipGetDevice(&device_id));
if (fcache_[device_id] == nullptr) {
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
- void* config[] = {
- HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args,
- HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes,
- HIP_LAUNCH_PARAM_END
- };
+ void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE,
+ &packed_nbytes, HIP_LAUNCH_PARAM_END};
// HIP supports only extra_args.
ROCM_DRIVER_CALL(hipModuleLaunchKernel(
- fcache_[device_id],
- wl.grid_dim(0),
- wl.grid_dim(1),
- wl.grid_dim(2),
- wl.block_dim(0),
- wl.block_dim(1),
- wl.block_dim(2),
- 0, strm, nullptr,
- reinterpret_cast<void**>(&config)));
+ fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0),
+ wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast<void**>(&config)));
}
private:
ThreadAxisConfig thread_axis_cfg_;
};
-
-PackedFunc ROCMModuleNode::GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
- CHECK_NE(name, symbol::tvm_module_main)
- << "Device function do not have main";
+ CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
return PackFuncPackedArg(f, info.arg_types);
}
-Module ROCMModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string hip_source,
- std::string assembly) {
+Module ROCMModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string hip_source,
+ std::string assembly) {
auto n = make_object<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
return Module(n);
}
-Module ROCMModuleLoadFile(const std::string& file_name,
- const std::string& format) {
+Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
}
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary);
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco")
-.set_body_typed(ROCMModuleLoadBinary);
-
-
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip")
-.set_body_typed(ROCMModuleLoadBinary);
-
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary);
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco")
-.set_body_typed(ROCMModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile);
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip")
-.set_body_typed(ROCMModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile);
} // namespace runtime
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RUNTIME_ROCM_ROCM_MODULE_H_
#include <tvm/runtime/module.h>
+
#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../meta_data.h"
namespace tvm {
* \param fmap The map function information map of each function.
* \param rocm_source Optional, rocm source file
*/
-Module ROCMModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string rocm_source,
- std::string assembly);
+Module ROCMModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string rocm_source,
+ std::string assembly);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_
#include <dmlc/endian.h>
#include <tvm/runtime/c_runtime_api.h>
-#include "../rpc_protocol.h"
+
#include "../../../support/arena.h"
+#include "../rpc_protocol.h"
/*! \brief Whether or not to enable glog style DLOG */
#ifndef TVM_MINRPC_ENABLE_LOGGING
#endif
#ifndef MINRPC_CHECK
-#define MINRPC_CHECK(cond) \
+#define MINRPC_CHECK(cond) \
if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError);
#endif
#include <dmlc/logging.h>
#endif
-
namespace tvm {
namespace runtime {
* - PosixWrite, PosixRead, Close: posix style, read, write, close API.
* - Exit: exit with status code.
*/
-template<typename TIOHandler>
+template <typename TIOHandler>
class MinRPCServer {
public:
/*!
* \brief Constructor.
* \param io The IO handler.
*/
- explicit MinRPCServer(TIOHandler io)
- : io_(io), arena_(PageAllocator(io)) {}
+ explicit MinRPCServer(TIOHandler io) : io_(io), arena_(PageAllocator(io)) {}
/*! \brief Run the server loop until shutdown signal is received. */
void ServerLoop() {
this->Read(&call_handle);
RecvPackedSeq(&values, &tcodes, &num_args);
- int call_ecode = TVMFuncCall(
- reinterpret_cast<void*>(call_handle),
- values, tcodes, num_args,
- &(ret_value[1]), &(ret_tcode[1]));
+ int call_ecode = TVMFuncCall(reinterpret_cast<void*>(call_handle), values, tcodes, num_args,
+ &(ret_value[1]), &(ret_tcode[1]));
if (call_ecode == 0) {
// Return value encoding as in LocalSession
ret_value[2].v_handle = ret_value[1].v_handle;
ret_tcode[2] = kTVMOpaqueHandle;
this->ReturnPackedSeq(ret_value, ret_tcode, 3);
- } else if (rv_tcode == kTVMPackedFuncHandle ||
- rv_tcode == kTVMModuleHandle) {
+ } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
ret_tcode[1] = kTVMOpaqueHandle;
this->ReturnPackedSeq(ret_value, ret_tcode, 2);
} else {
data_ptr = reinterpret_cast<uint8_t*>(handle) + offset;
} else {
data_ptr = this->ArenaAlloc<uint8_t>(num_bytes);
- call_ecode = TVMDeviceCopyDataFromTo(
- reinterpret_cast<void*>(handle), offset,
- data_ptr, 0, num_bytes,
- ctx, DLContext{kDLCPU, 0},
- type_hint, nullptr);
+ call_ecode =
+ TVMDeviceCopyDataFromTo(reinterpret_cast<void*>(handle), offset, data_ptr, 0, num_bytes,
+ ctx, DLContext{kDLCPU, 0}, type_hint, nullptr);
// need sync to make sure that the copy is completed.
if (call_ecode == 0) {
- call_ecode = TVMSynchronize(
- ctx.device_type, ctx.device_id, nullptr);
+ call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
}
}
uint8_t* temp_data = this->ArenaAlloc<uint8_t>(num_bytes);
this->ReadArray(temp_data, num_bytes);
- call_ecode = TVMDeviceCopyDataFromTo(
- temp_data, 0,
- reinterpret_cast<void*>(handle), offset,
- num_bytes,
- DLContext{kDLCPU, 0}, ctx,
- type_hint, nullptr);
+ call_ecode =
+ TVMDeviceCopyDataFromTo(temp_data, 0, reinterpret_cast<void*>(handle), offset, num_bytes,
+ DLContext{kDLCPU, 0}, ctx, type_hint, nullptr);
// need sync to make sure that the copy is completed.
if (call_ecode == 0) {
- call_ecode = TVMSynchronize(
- ctx.device_type, ctx.device_id, nullptr);
+ call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
}
}
DLDataType type_hint = values[7].v_type;
TVMStreamHandle stream = values[8].v_handle;
- int call_ecode = TVMDeviceCopyDataFromTo(
- from, from_offset,
- to, to_offset, size,
- ctx_from, ctx_to, type_hint, stream);
+ int call_ecode = TVMDeviceCopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from,
+ ctx_to, type_hint, stream);
if (call_ecode == 0) {
this->ReturnVoid();
DLDataType type_hint = values[3].v_type;
void* handle;
- int call_ecode = TVMDeviceAllocDataSpace(
- ctx, nbytes, alignment, type_hint, &handle);
+ int call_ecode = TVMDeviceAllocDataSpace(ctx, nbytes, alignment, type_hint, &handle);
if (call_ecode == 0) {
this->ReturnHandle(handle);
io_.Exit(static_cast<int>(code));
}
- template<typename T>
+ template <typename T>
T* ArenaAlloc(int count) {
static_assert(std::is_pod<T>::value, "need to be trival");
return arena_.template allocate_<T>(count);
}
- template<typename T>
+ template <typename T>
void Read(T* data) {
static_assert(std::is_pod<T>::value, "need to be trival");
this->ReadRawBytes(data, sizeof(T));
}
- template<typename T>
+ template <typename T>
void ReadArray(T* data, size_t count) {
static_assert(std::is_pod<T>::value, "need to be trival");
return this->ReadRawBytes(data, sizeof(T) * count);
}
- template<typename T>
+ template <typename T>
void Write(const T& data) {
static_assert(std::is_pod<T>::value, "need to be trival");
return this->WriteRawBytes(&data, sizeof(T));
}
- template<typename T>
+ template <typename T>
void WriteArray(T* data, size_t count) {
static_assert(std::is_pod<T>::value, "need to be trival");
return this->WriteRawBytes(data, sizeof(T) * count);
public:
using ArenaPageHeader = tvm::support::ArenaPageHeader;
- explicit PageAllocator(TIOHandler io)
- : io_(io) {}
+ explicit PageAllocator(TIOHandler io) : io_(io) {}
ArenaPageHeader* allocate(size_t min_size) {
size_t npages = ((min_size + kPageSize - 1) / kPageSize);
void* data;
- if (TVMDeviceAllocDataSpace(
- DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign,
- DLDataType{kDLInt, 1, 1}, &data) != 0) {
+ if (TVMDeviceAllocDataSpace(DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign,
+ DLDataType{kDLInt, 1, 1}, &data) != 0) {
io_.Exit(static_cast<int>(RPCServerStatus::kAllocError));
}
TIOHandler io_;
};
- void RecvPackedSeq(TVMValue** out_values,
- int** out_tcodes,
- int* out_num_args) {
- RPCReference::RecvPackedSeq(
- out_values, out_tcodes, out_num_args, this);
+ void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) {
+ RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this);
}
void ReturnVoid() {
int32_t tcode = kTVMNullptr;
RPCCode code = RPCCode::kReturn;
- uint64_t packet_nbytes =
- sizeof(code) + sizeof(num_args) + sizeof(tcode);
+ uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
this->Write(packet_nbytes);
this->Write(code);
uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
uint64_t packet_nbytes =
- sizeof(code) + sizeof(num_args) +
- sizeof(tcode) + sizeof(encode_handle);
+ sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle);
this->Write(packet_nbytes);
this->Write(code);
this->Write(encode_handle);
}
- void ReturnException(const char* msg) {
- RPCReference::ReturnException(msg, this);
- }
+ void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); }
- void ReturnPackedSeq(const TVMValue* arg_values,
- const int* type_codes,
- int num_args) {
+ void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {
RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
}
- void ReturnLastTVMError() {
- this->ReturnException(TVMGetLastError());
- }
+ void ReturnLastTVMError() { this->ReturnException(TVMGetLastError()); }
void ReadRawBytes(void* data, size_t size) {
uint8_t* buf = reinterpret_cast<uint8_t*>(data);
size_t ndone = 0;
- while (ndone < size) {
+ while (ndone < size) {
ssize_t ret = io_.PosixRead(buf, size - ndone);
if (ret == 0) {
if (allow_clean_shutdown_) {
}
void WriteRawBytes(const void* data, size_t size) {
- const uint8_t *buf = reinterpret_cast<const uint8_t*>(data);
+ const uint8_t* buf = reinterpret_cast<const uint8_t*>(data);
size_t ndone = 0;
- while (ndone < size) {
+ while (ndone < size) {
ssize_t ret = io_.PosixWrite(buf, size - ndone);
if (ret == 0 || ret == -1) {
this->ThrowError(RPCServerStatus::kWriteError);
#define TVM_ARENA_HAS_DESTRUCTOR 0
#include <unistd.h>
+
#include <cstdlib>
+
#include "minrpc_server.h"
namespace tvm {
class PosixIOHandler {
public:
explicit PosixIOHandler(int read_fd = 0, int write_fd = 1)
- : read_fd_(read_fd), write_fd_(write_fd) {
- }
+ : read_fd_(read_fd), write_fd_(write_fd) {}
- ssize_t PosixRead(void* data, size_t size) {
- return read(read_fd_, data, size);
- }
+ ssize_t PosixRead(void* data, size_t size) { return read(read_fd_, data, size); }
- ssize_t PosixWrite(const void* data, size_t size) {
- return write(write_fd_, data, size);
- }
+ ssize_t PosixWrite(const void* data, size_t size) { return write(write_fd_, data, size); }
- void Exit(int code) {
- exit(code);
- }
+ void Exit(int code) { exit(code); }
void Close() {
if (read_fd_ != 0) close(read_fd_);
/*!
* \file rpc_channel.cc
*/
-#include <string>
#include "rpc_channel.h"
+#include <string>
+
namespace tvm {
namespace runtime {
#define TVM_RUNTIME_RPC_RPC_CHANNEL_H_
#include <tvm/runtime/packed_func.h>
+
#include <utility>
namespace tvm {
* \file rpc_device_api.cc
*/
#include <dmlc/logging.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include <utility>
+
#include "rpc_session.h"
namespace tvm {
GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv);
}
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
auto sess = GetSess(ctx);
auto remote_ctx = RemoveSessMask(ctx);
- void *data = sess->GetDeviceAPI(remote_ctx)->AllocDataSpace(
- remote_ctx, nbytes, alignment, type_hint);
+ void* data =
+ sess->GetDeviceAPI(remote_ctx)->AllocDataSpace(remote_ctx, nbytes, alignment, type_hint);
RemoteSpace* space = new RemoteSpace();
space->data = data;
RemoteSpace* space = static_cast<RemoteSpace*>(ptr);
auto remote_ctx = RemoveSessMask(ctx);
try {
- GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(
- remote_ctx, space->data);
+ GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(remote_ctx, space->data);
} catch (const dmlc::Error& e) {
// fault tolerance to remote close.
}
delete space;
}
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
int from_dev_type = ctx_from.device_type;
int to_dev_type = ctx_to.device_type;
- if (from_dev_type > kRPCSessMask &&
- to_dev_type > kRPCSessMask) {
+ if (from_dev_type > kRPCSessMask && to_dev_type > kRPCSessMask) {
CHECK(ctx_from.device_type == ctx_to.device_type)
<< "Cannot copy across two different remote session";
auto remote_ctx_from = RemoveSessMask(ctx_from);
auto remote_ctx_to = RemoveSessMask(ctx_to);
auto remote_ctx = remote_ctx_from;
if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to;
- GetSess(ctx_from)->GetDeviceAPI(remote_ctx)
+ GetSess(ctx_from)
+ ->GetDeviceAPI(remote_ctx)
->CopyDataFromTo(static_cast<const RemoteSpace*>(from)->data, from_offset,
- static_cast<const RemoteSpace*>(to)->data, to_offset,
- size, remote_ctx_from, remote_ctx_to, type_hint, stream);
- } else if (from_dev_type > kRPCSessMask &&
- to_dev_type == kDLCPU) {
+ static_cast<const RemoteSpace*>(to)->data, to_offset, size,
+ remote_ctx_from, remote_ctx_to, type_hint, stream);
+ } else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) {
auto remote_ctx_from = RemoveSessMask(ctx_from);
- GetSess(ctx_from)->CopyFromRemote(
- static_cast<const RemoteSpace*>(from)->data, from_offset,
- to, to_offset, size, remote_ctx_from, type_hint);
- } else if (from_dev_type == kDLCPU &&
- to_dev_type > kRPCSessMask) {
+ GetSess(ctx_from)->CopyFromRemote(static_cast<const RemoteSpace*>(from)->data, from_offset,
+ to, to_offset, size, remote_ctx_from, type_hint);
+ } else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) {
auto remote_ctx_to = RemoveSessMask(ctx_to);
- GetSess(ctx_to)->CopyToRemote(
- const_cast<void*>(from), from_offset,
- static_cast<const RemoteSpace*>(to)->data, to_offset,
- size, remote_ctx_to, type_hint);
+ GetSess(ctx_to)->CopyToRemote(const_cast<void*>(from), from_offset,
+ static_cast<const RemoteSpace*>(to)->data, to_offset, size,
+ remote_ctx_to, type_hint);
} else {
LOG(FATAL) << "expect copy from/to remote or between remote";
}
std::shared_ptr<RPCSession> GetSess(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_GE(dev_type, kRPCSessMask);
- int tbl_index = dev_type / kRPCSessMask - 1;
+ int tbl_index = dev_type / kRPCSessMask - 1;
return RPCSession::Get(tbl_index);
}
}
};
-TVM_REGISTER_GLOBAL("device_api.rpc")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- static RPCDeviceAPI inst;
- DeviceAPI* ptr = &inst;
- *rv = static_cast<void*>(ptr);
- });
+TVM_REGISTER_GLOBAL("device_api.rpc").set_body([](TVMArgs args, TVMRetValue* rv) {
+ static RPCDeviceAPI inst;
+ DeviceAPI* ptr = &inst;
+ *rv = static_cast<void*>(ptr);
+});
} // namespace runtime
} // namespace tvm
* \file rpc_session.cc
* \brief RPC session for remote function call.
*/
+#include "rpc_endpoint.h"
+
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
-#include <memory>
+
+#include <algorithm>
#include <array>
-#include <string>
#include <chrono>
-#include <vector>
-#include <utility>
#include <cmath>
-#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
-#include "rpc_endpoint.h"
-#include "rpc_local_session.h"
-#include "../object_internal.h"
-#include "../../support/ring_buffer.h"
#include "../../support/arena.h"
+#include "../../support/ring_buffer.h"
+#include "../object_internal.h"
+#include "rpc_local_session.h"
namespace tvm {
namespace runtime {
*/
class RPCEndpoint::EventHandler : public dmlc::Stream {
public:
- EventHandler(support::RingBuffer* reader,
- support::RingBuffer* writer,
- std::string name,
- std::string* remote_key,
- std::function<void()> flush_writer)
+ EventHandler(support::RingBuffer* reader, support::RingBuffer* writer, std::string name,
+ std::string* remote_key, std::function<void()> flush_writer)
: reader_(reader),
writer_(writer),
name_(name),
}
/*! \return Whether we are ready to handle next request. */
- bool Ready() const {
- return reader_->bytes_available() >= pending_request_bytes_;
- }
+ bool Ready() const { return reader_->bytes_available() >= pending_request_bytes_; }
/*! \return Whether we can perform a clean shutdown */
- bool CanCleanShutdown() const {
- return state_ == kRecvPacketNumBytes;
- }
+ bool CanCleanShutdown() const { return state_ == kRecvPacketNumBytes; }
/*! \brief Finish the copy ack stage. */
- void FinishCopyAck() {
- this->SwitchToState(kRecvPacketNumBytes);
- }
+ void FinishCopyAck() { this->SwitchToState(kRecvPacketNumBytes); }
/*!
* \brief Enter the io loop until the next event.
* \param setreturn The function to set the return value encoding.
* \return The function to set return values when there is a return event.
*/
- RPCCode HandleNextEvent(bool client_mode,
- bool async_server_mode,
+ RPCCode HandleNextEvent(bool client_mode, bool async_server_mode,
RPCSession::FEncodeReturn setreturn) {
std::swap(client_mode_, client_mode);
std::swap(async_server_mode_, async_server_mode);
RPCCode status = RPCCode::kNone;
- while (status == RPCCode::kNone &&
- state_ != kWaitForAsyncCallback &&
- this->Ready()) {
+ while (status == RPCCode::kNone && state_ != kWaitForAsyncCallback && this->Ready()) {
switch (state_) {
- case kInitHeader: HandleInitHeader(); break;
+ case kInitHeader:
+ HandleInitHeader();
+ break;
case kRecvPacketNumBytes: {
uint64_t packet_nbytes;
CHECK(this->Read(&packet_nbytes));
* \param arg_values The argument values.
* \param type_codes The type codes.
*/
- void ValidateArguments(const TVMValue* arg_values,
- const int* type_codes,
- int num_args) {
+ void ValidateArguments(const TVMValue* arg_values, const int* type_codes, int num_args) {
TVMArgs args(arg_values, type_codes, num_args);
for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i];
if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) {
- LOG(FATAL) << "ValueError: Cannot pass argument " << i
- << ", type " << args[i].AsObjectRef<ObjectRef>()->GetTypeKey()
- << " is not supported by RPC";
+ LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
+ << args[i].AsObjectRef<ObjectRef>()->GetTypeKey() << " is not supported by RPC";
} else if (tcode == kTVMContext) {
DLContext ctx = args[i];
CHECK_LT(static_cast<int>(ctx.device_type), kRPCSessMask)
LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code);
}
- uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values,
- const int* type_codes,
- int num_args,
+ uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, int num_args,
bool client_mode) {
- return RPCReference::PackedSeqGetNumBytes(
- arg_values, type_codes, num_args, client_mode, this);
+ return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes, num_args, client_mode, this);
}
- void SendPackedSeq(const TVMValue* arg_values,
- const int* type_codes,
- int num_args,
+ void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
bool client_mode) {
- RPCReference::SendPackedSeq(
- arg_values, type_codes, num_args, client_mode, this);
+ RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode, this);
}
// Endian aware IO handling
using Stream::Read;
- using Stream::Write;
using Stream::ReadArray;
+ using Stream::Write;
using Stream::WriteArray;
bool Read(RPCCode* code) {
this->Write(cdata);
}
- template<typename T>
+ template <typename T>
T* ArenaAlloc(int count) {
static_assert(std::is_pod<T>::value, "need to be trival");
return arena_.template allocate_<T>(count);
void SwitchToState(State state) {
// invariant
if (state != kCopyAckReceived) {
- CHECK_EQ(pending_request_bytes_, 0U)
- << "state=" << state;
+ CHECK_EQ(pending_request_bytes_, 0U) << "state=" << state;
}
// need to actively flush the writer
// so the data get pushed out.
flush_writer_();
}
state_ = state;
- CHECK(state != kInitHeader)
- << "cannot switch to init header";
+ CHECK(state != kInitHeader) << "cannot switch to init header";
if (state == kRecvPacketNumBytes) {
this->RequestBytes(sizeof(uint64_t));
// recycle arena for the next session.
if (code >= RPCCode::kSyscallCodeStart) {
this->HandleSyscall(code);
} else {
- switch (code) {
- case RPCCode::kInitServer: {
- this->HandleInitServer();
- break;
- }
- case RPCCode::kCallFunc: {
- this->HandleNormalCallFunc();
- break;
- }
- case RPCCode::kCopyFromRemote: {
- this->HandleCopyFromRemote();
- break;
- }
- case RPCCode::kCopyToRemote: {
- this->HandleCopyToRemote();
- break;
- }
- case RPCCode::kException:
- case RPCCode::kReturn: {
- this->HandleReturn(code, setreturn);
- break;
- }
- case RPCCode::kCopyAck: {
- this->SwitchToState(kCopyAckReceived);
- break;
- }
- case RPCCode::kShutdown: {
- this->SwitchToState(kShutdownReceived);
- break;
- }
- default: LOG(FATAL) << "Unknown event " << static_cast<int>(code);
+ switch (code) {
+ case RPCCode::kInitServer: {
+ this->HandleInitServer();
+ break;
}
+ case RPCCode::kCallFunc: {
+ this->HandleNormalCallFunc();
+ break;
+ }
+ case RPCCode::kCopyFromRemote: {
+ this->HandleCopyFromRemote();
+ break;
+ }
+ case RPCCode::kCopyToRemote: {
+ this->HandleCopyToRemote();
+ break;
+ }
+ case RPCCode::kException:
+ case RPCCode::kReturn: {
+ this->HandleReturn(code, setreturn);
+ break;
+ }
+ case RPCCode::kCopyAck: {
+ this->SwitchToState(kCopyAckReceived);
+ break;
+ }
+ case RPCCode::kShutdown: {
+ this->SwitchToState(kShutdownReceived);
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown event " << static_cast<int>(code);
+ }
}
}
* \brief Return exception to the remote.
* \param err_msg The error message.
*/
- void ReturnException(const char* err_msg) {
- RPCReference::ReturnException(err_msg, this);
- }
+ void ReturnException(const char* err_msg) { RPCReference::ReturnException(err_msg, this); }
/*!
* \brief Return nullptr to the remote.
* \param err_msg The error message.
*/
- void ReturnVoid() {
- RPCReference::ReturnVoid(this);
- }
+ void ReturnVoid() { RPCReference::ReturnVoid(this); }
/*!
* \brief Return a packed sequence to the remote.
// switch to the state before sending exception.
this->SwitchToState(kRecvPacketNumBytes);
std::string msg = args[0];
- LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg;
+ LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg;
}
CHECK(setreturn != nullptr) << "fsetreturn not available";
// When session is local, we can directly treat handle
// as the cpu pointer without allocating a temp space.
- if (ctx.device_type == kDLCPU &&
- sess->IsLocalSession() &&
- DMLC_IO_NO_ENDIAN_SWAP) {
+ if (ctx.device_type == kDLCPU && sess->IsLocalSession() && DMLC_IO_NO_ENDIAN_SWAP) {
char* data_ptr = reinterpret_cast<char*>(handle) + offset;
fcopyack(data_ptr, num_bytes);
} else {
char* data_ptr = this->ArenaAlloc<char>(num_bytes);
- auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack](
- RPCCode status, TVMArgs args) {
+ auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack](RPCCode status,
+ TVMArgs args) {
if (status == RPCCode::kException) {
this->ReturnException(args.values[0].v_str);
this->SwitchToState(kRecvPacketNumBytes);
};
this->SwitchToState(kWaitForAsyncCallback);
- sess->AsyncCopyFromRemote(
- reinterpret_cast<void*>(handle), offset,
- data_ptr, 0,
- num_bytes, ctx, type_hint,
- on_copy_complete);
+ sess->AsyncCopyFromRemote(reinterpret_cast<void*>(handle), offset, data_ptr, 0, num_bytes,
+ ctx, type_hint, on_copy_complete);
}
}
// When session is local, we can directly treat handle
// as the cpu pointer without allocating a temp space.
if (ctx.device_type == kDLCPU && sess->IsLocalSession()) {
- char* dptr = reinterpret_cast<char*>(handle) + offset;
- this->ReadArray(dptr, num_bytes);
-
- if (!DMLC_IO_NO_ENDIAN_SWAP) {
- dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes);
- }
- this->ReturnVoid();
- this->SwitchToState(kRecvPacketNumBytes);
+ char* dptr = reinterpret_cast<char*>(handle) + offset;
+ this->ReadArray(dptr, num_bytes);
+
+ if (!DMLC_IO_NO_ENDIAN_SWAP) {
+ dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes);
+ }
+ this->ReturnVoid();
+ this->SwitchToState(kRecvPacketNumBytes);
} else {
char* temp_data = this->ArenaAlloc<char>(num_bytes);
this->ReadArray(temp_data, num_bytes);
};
this->SwitchToState(kWaitForAsyncCallback);
- sess->AsyncCopyToRemote(
- temp_data, 0,
- reinterpret_cast<void*>(handle), offset,
- num_bytes, ctx, type_hint,
- on_copy_complete);
+ sess->AsyncCopyToRemote(temp_data, 0, reinterpret_cast<void*>(handle), offset, num_bytes, ctx,
+ type_hint, on_copy_complete);
}
}
TVMArgs args = RecvPackedSeq();
this->SwitchToState(kWaitForAsyncCallback);
- GetServingSession()->AsyncCallFunc(
- reinterpret_cast<void*>(call_handle),
- args.values, args.type_codes, args.size(),
- [this](RPCCode status, TVMArgs args) {
- if (status == RPCCode::kException) {
- this->ReturnException(args.values[0].v_str);
- } else {
- this->ReturnPackedSeq(args);
- }
- this->SwitchToState(kRecvPacketNumBytes);
- });
+ GetServingSession()->AsyncCallFunc(reinterpret_cast<void*>(call_handle), args.values,
+ args.type_codes, args.size(),
+ [this](RPCCode status, TVMArgs args) {
+ if (status == RPCCode::kException) {
+ this->ReturnException(args.values[0].v_str);
+ } else {
+ this->ReturnPackedSeq(args);
+ }
+ this->SwitchToState(kRecvPacketNumBytes);
+ });
}
void HandleInitServer() {
TVMArgs args = RecvPackedSeq();
try {
- CHECK(serving_session_ == nullptr)
- << "Server has already been initialized";
+ CHECK(serving_session_ == nullptr) << "Server has already been initialized";
std::string server_protocol_ver = kRPCProtocolVer;
CHECK_EQ(client_protocol_ver, server_protocol_ver)
}
auto* fconstructor = Registry::Get(constructor_name);
- CHECK(fconstructor != nullptr)
- << " Cannot find session constructor " << constructor_name;
+ CHECK(fconstructor != nullptr) << " Cannot find session constructor " << constructor_name;
TVMRetValue con_ret;
try {
fconstructor->CallPacked(constructor_args, &con_ret);
} catch (const dmlc::Error& e) {
LOG(FATAL) << "Server[" << name_ << "]:"
- << " Error caught from session constructor " << constructor_name
- << ":\n" << e.what();
+ << " Error caught from session constructor " << constructor_name << ":\n"
+ << e.what();
}
CHECK_EQ(con_ret.type_code(), kTVMModuleHandle)
<< "Server[" << name_ << "]:"
- << " Constructor " << constructor_name
- << " need to return an RPCModule";
+ << " Constructor " << constructor_name << " need to return an RPCModule";
Module mod = con_ret;
std::string tkey = mod->type_key();
- CHECK_EQ(tkey, "rpc")
- << "Constructor " << constructor_name << " to return an RPCModule";
+ CHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule";
serving_session_ = RPCModuleGetSession(mod);
this->ReturnVoid();
- } catch (const std::runtime_error &e) {
+ } catch (const std::runtime_error& e) {
this->ReturnException(e.what());
}
TVMStreamHandle handle = args[1];
this->SwitchToState(kWaitForAsyncCallback);
- GetServingSession()->AsyncStreamWait(
- ctx, handle, [this](RPCCode status, TVMArgs args) {
- if (status == RPCCode::kException) {
- this->ReturnException(args.values[0].v_str);
- } else {
- this->ReturnVoid();
- }
- this->SwitchToState(kRecvPacketNumBytes);
- });
+ GetServingSession()->AsyncStreamWait(ctx, handle, [this](RPCCode status, TVMArgs args) {
+ if (status == RPCCode::kException) {
+ this->ReturnException(args.values[0].v_str);
+ } else {
+ this->ReturnVoid();
+ }
+ this->SwitchToState(kRecvPacketNumBytes);
+ });
} catch (const std::runtime_error& e) {
this->ReturnException(e.what());
this->SwitchToState(kRecvPacketNumBytes);
}
// Handler for special syscalls that have a specific RPCCode.
- template<typename F>
+ template <typename F>
void SysCallHandler(F f) {
TVMArgs args = RecvPackedSeq();
try {
return size;
}
// wriite the data to the channel.
- void Write(const void* data, size_t size) final {
- writer_->Write(data, size);
- }
+ void Write(const void* data, size_t size) final { writer_->Write(data, size); }
// Number of pending bytes requests
size_t pending_request_bytes_{0};
// The ring buffer to read data from.
std::function<void()> flush_writer_;
};
-RPCCode RPCEndpoint::HandleUntilReturnEvent(
- bool client_mode,
- RPCSession::FEncodeReturn setreturn) {
+RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) {
RPCCode code = RPCCode::kCallFunc;
- while (code != RPCCode::kReturn &&
- code != RPCCode::kShutdown &&
- code != RPCCode::kCopyAck) {
+ while (code != RPCCode::kReturn && code != RPCCode::kShutdown && code != RPCCode::kCopyAck) {
while (writer_.bytes_available() != 0) {
- writer_.ReadWithCallback([this](const void *data, size_t size) {
- return channel_->Send(data, size);
- }, writer_.bytes_available());
+ writer_.ReadWithCallback(
+ [this](const void* data, size_t size) { return channel_->Send(data, size); },
+ writer_.bytes_available());
}
size_t bytes_needed = handler_->BytesNeeded();
if (bytes_needed != 0) {
- size_t n = reader_.WriteWithCallback([this](void* data, size_t size) {
- return channel_->Recv(data, size);
- }, bytes_needed);
+ size_t n = reader_.WriteWithCallback(
+ [this](void* data, size_t size) { return channel_->Recv(data, size); }, bytes_needed);
if (n == 0) {
if (handler_->CanCleanShutdown()) {
return RPCCode::kShutdown;
// callback to flush the writer.
auto flush_writer = [this]() {
while (writer_.bytes_available() != 0) {
- size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
- return channel_->Send(data, size);
- }, writer_.bytes_available());
+ size_t n = writer_.ReadWithCallback(
+ [this](const void* data, size_t size) { return channel_->Send(data, size); },
+ writer_.bytes_available());
if (n == 0) break;
}
};
// Event handler
- handler_ = std::make_shared<EventHandler>(
- &reader_, &writer_, name_, &remote_key_, flush_writer);
+ handler_ = std::make_shared<EventHandler>(&reader_, &writer_, name_, &remote_key_, flush_writer);
// Quick function to for syscall remote.
syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) {
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = static_cast<RPCCode>(all_args[0].operator int());
- TVMArgs args(all_args.values + 1, all_args.type_codes +1, all_args.num_args -1);
+ TVMArgs args(all_args.values + 1, all_args.type_codes + 1, all_args.num_args - 1);
- uint64_t packet_nbytes =
- sizeof(code) +
- handler_->PackedSeqGetNumBytes(
- args.values, args.type_codes, args.num_args, true);
+ uint64_t packet_nbytes = sizeof(code) + handler_->PackedSeqGetNumBytes(
+ args.values, args.type_codes, args.num_args, true);
// All packet begins with packet nbytes
handler_->Write(packet_nbytes);
});
}
-std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(
- std::unique_ptr<RPCChannel> channel,
- std::string name,
- std::string remote_key) {
+std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(std::unique_ptr<RPCChannel> channel,
+ std::string name, std::string remote_key) {
std::shared_ptr<RPCEndpoint> endpt = std::make_shared<RPCEndpoint>();
endpt->channel_ = std::move(channel);
endpt->name_ = std::move(name);
return endpt;
}
-RPCEndpoint::~RPCEndpoint() {
- this->Shutdown();
-}
+RPCEndpoint::~RPCEndpoint() { this->Shutdown(); }
void RPCEndpoint::Shutdown() {
if (channel_ != nullptr) {
// flush all writing buffer to output channel.
try {
while (writer_.bytes_available() != 0) {
- size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
- return channel_->Send(data, size);
- }, writer_.bytes_available());
+ size_t n = writer_.ReadWithCallback(
+ [this](const void* data, size_t size) { return channel_->Send(data, size); },
+ writer_.bytes_available());
if (n == 0) break;
}
} catch (const dmlc::Error& e) {
code = handler_->HandleNextEvent(false, true, [](TVMArgs) {});
}
if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
- writer_.ReadWithCallback([this](const void *data, size_t size) {
- return channel_->Send(data, size);
- }, writer_.bytes_available());
+ writer_.ReadWithCallback(
+ [this](const void* data, size_t size) { return channel_->Send(data, size); },
+ writer_.bytes_available());
}
CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
if (code == RPCCode::kShutdown) return 0;
uint64_t length = protocol_ver.length();
uint64_t packet_nbytes =
- sizeof(code) +
- sizeof(length) +
- length +
- handler_->PackedSeqGetNumBytes(
- args.values, args.type_codes, args.num_args, true);
+ sizeof(code) + sizeof(length) + length +
+ handler_->PackedSeqGetNumBytes(args.values, args.type_codes, args.num_args, true);
// All packet begins with packet nbytes
handler_->Write(packet_nbytes);
}
// Get remote function with name
-void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
+void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_values,
+ const int* arg_type_codes, int num_args,
RPCSession::FEncodeReturn encode_return) {
std::lock_guard<std::mutex> lock(mutex_);
uint64_t handle = reinterpret_cast<uint64_t>(h);
uint64_t packet_nbytes =
- sizeof(code) +
- sizeof(handle) +
- handler_->PackedSeqGetNumBytes(
- arg_values, arg_type_codes, num_args, true);
+ sizeof(code) + sizeof(handle) +
+ handler_->PackedSeqGetNumBytes(arg_values, arg_type_codes, num_args, true);
handler_->Write(packet_nbytes);
handler_->Write(code);
handler_->Write(handle);
- handler_->SendPackedSeq(
- arg_values, arg_type_codes, num_args, true);
+ handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true);
code = HandleUntilReturnEvent(true, encode_return);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
}
-void RPCEndpoint::CopyToRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t data_size,
- TVMContext ctx_to,
- DLDataType type_hint) {
+void RPCEndpoint::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t data_size, TVMContext ctx_to, DLDataType type_hint) {
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyToRemote;
uint64_t handle = reinterpret_cast<uint64_t>(to);
uint64_t offset = static_cast<uint64_t>(to_offset);
uint64_t size = static_cast<uint64_t>(data_size);
- uint64_t packet_nbytes =
- sizeof(code) +
- sizeof(handle) +
- sizeof(offset) +
- sizeof(size) +
- sizeof(ctx_to) +
- sizeof(type_hint) +
- data_size;
+ uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) +
+ sizeof(ctx_to) + sizeof(type_hint) + data_size;
handler_->Write(packet_nbytes);
handler_->Write(code);
handler_->Write(type_hint);
handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
- CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kReturn);
+ CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kReturn);
}
-void RPCEndpoint::CopyFromRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t data_size,
- TVMContext ctx_from,
- DLDataType type_hint) {
+void RPCEndpoint::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t data_size, TVMContext ctx_from, DLDataType type_hint) {
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyFromRemote;
uint64_t handle = reinterpret_cast<uint64_t>(from);
uint64_t offset = static_cast<uint64_t>(from_offset);
uint64_t size = static_cast<uint64_t>(data_size);
- uint64_t packet_nbytes =
- sizeof(code) +
- sizeof(handle) +
- sizeof(offset) +
- sizeof(size) +
- sizeof(ctx_from) +
- sizeof(type_hint);
+ uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) +
+ sizeof(ctx_from) + sizeof(type_hint);
handler_->Write(packet_nbytes);
handler_->Write(code);
handler_->Write(type_hint);
TVMRetValue rv;
- CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kCopyAck);
+ CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kCopyAck);
handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
handler_->FinishCopyAck();
}
*rv = handler->GetFunction(name);
}
-void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
void* handle = args[0];
int type_code = args[1];
handler->FreeHandle(handle, type_code);
}
-void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
TVMContext ctx = args[0];
handler->GetDeviceAPI(ctx)->SetDevice(ctx);
}
-void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
TVMContext ctx = args[0];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
if (kind == kExist) {
*rv = 0;
}
} else {
- handler->GetDeviceAPI(ctx)->GetAttr(
- ctx, static_cast<DeviceAttrKind>(kind), rv);
+ handler->GetDeviceAPI(ctx)->GetAttr(ctx, static_cast<DeviceAttrKind>(kind), rv);
}
}
-void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
TVMContext ctx = args[0];
uint64_t nbytes = args[1];
uint64_t alignment = args[2];
DLDataType type_hint = args[3];
- void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace(
- ctx, nbytes, alignment, type_hint);
+ void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint);
*rv = data;
}
-void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
TVMContext ctx = args[0];
void* ptr = args[1];
handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr);
}
-void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
void* from = args[0];
uint64_t from_offset = args[1];
void* to = args[2];
if (ctx.device_type == kDLCPU) {
ctx = ctx_to;
} else {
- CHECK(ctx_to.device_type == kDLCPU ||
- ctx_to.device_type == ctx_from.device_type)
+ CHECK(ctx_to.device_type == kDLCPU || ctx_to.device_type == ctx_from.device_type)
<< "Can not copy across different ctx types directly";
}
- handler->GetDeviceAPI(ctx)->CopyDataFromTo(
- from, from_offset,
- to, to_offset,
- size, ctx_from, ctx_to, type_hint, stream);
+ handler->GetDeviceAPI(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from,
+ ctx_to, type_hint, stream);
}
void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
// Event handler sit at clean state at this point.
switch (code) {
// system functions
- case RPCCode::kFreeHandle: SysCallHandler(RPCFreeHandle); break;
- case RPCCode::kGetGlobalFunc: SysCallHandler(RPCGetGlobalFunc); break;
- case RPCCode::kDevSetDevice: SysCallHandler(RPCDevSetDevice); break;
- case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break;
- case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break;
- case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break;
- case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break;
- case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break;
- default: LOG(FATAL) << "Unknown event " << static_cast<int>(code);
+ case RPCCode::kFreeHandle:
+ SysCallHandler(RPCFreeHandle);
+ break;
+ case RPCCode::kGetGlobalFunc:
+ SysCallHandler(RPCGetGlobalFunc);
+ break;
+ case RPCCode::kDevSetDevice:
+ SysCallHandler(RPCDevSetDevice);
+ break;
+ case RPCCode::kDevGetAttr:
+ SysCallHandler(RPCDevGetAttr);
+ break;
+ case RPCCode::kDevAllocData:
+ SysCallHandler(RPCDevAllocData);
+ break;
+ case RPCCode::kDevFreeData:
+ SysCallHandler(RPCDevFreeData);
+ break;
+ case RPCCode::kDevStreamSync:
+ this->HandleSyscallStreamSync();
+ break;
+ case RPCCode::kCopyAmongRemote:
+ SysCallHandler(RPCCopyAmongRemote);
+ break;
+ default:
+ LOG(FATAL) << "Unknown event " << static_cast<int>(code);
}
if (state_ != kWaitForAsyncCallback) {
/*!
* \brief RPC client session that proxies all calls to an endpoint.
*/
-class RPCClientSession : public RPCSession,
- public DeviceAPI {
+class RPCClientSession : public RPCSession, public DeviceAPI {
public:
/*!
* \brief param endpoint The client endpoint of the session.
*/
- explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint)
- : endpoint_(endpoint) {}
+ explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {}
// function overrides
PackedFuncHandle GetFunction(const std::string& name) final {
return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name);
}
- void CallFunc(PackedFuncHandle func,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
- const FEncodeReturn& fencode_return) final {
- endpoint_->CallFunc(
- func, arg_values, arg_type_codes, num_args, fencode_return);
+ void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes,
+ int num_args, const FEncodeReturn& fencode_return) final {
+ endpoint_->CallFunc(func, arg_values, arg_type_codes, num_args, fencode_return);
}
- void CopyToRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_to,
- DLDataType type_hint) final {
- endpoint_->CopyToRemote(
- from, from_offset, to, to_offset, nbytes, ctx_to, type_hint);
+ void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+ TVMContext ctx_to, DLDataType type_hint) final {
+ endpoint_->CopyToRemote(from, from_offset, to, to_offset, nbytes, ctx_to, type_hint);
}
- void CopyFromRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_from,
- DLDataType type_hint) final {
- endpoint_->CopyFromRemote(
- from, from_offset, to, to_offset, nbytes, ctx_from, type_hint);
+ void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+ TVMContext ctx_from, DLDataType type_hint) final {
+ endpoint_->CopyFromRemote(from, from_offset, to, to_offset, nbytes, ctx_from, type_hint);
}
void FreeHandle(void* handle, int type_code) final {
endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code);
}
-
- void SetDevice(TVMContext ctx) final {
- endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx);
- }
+ void SetDevice(TVMContext ctx) final { endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
if (ctx.device_type == kDLCPU && kind == kExist) {
}
}
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
- return endpoint_->SysCallRemote(
- RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint);
+ return endpoint_->SysCallRemote(RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint);
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
endpoint_->SysCallRemote(RPCCode::kDevFreeData, ctx, ptr);
}
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
- endpoint_->SysCallRemote(
- RPCCode::kCopyAmongRemote,
- const_cast<void*>(from), from_offset,
- to, to_offset,
- size,
- ctx_from, ctx_to,
- type_hint, stream);
+ endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, const_cast<void*>(from), from_offset, to,
+ to_offset, size, ctx_from, ctx_to, type_hint, stream);
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream);
}
- DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final {
- return this;
- }
+ DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { return this; }
- bool IsLocalSession() const final {
- return false;
- }
+ bool IsLocalSession() const final { return false; }
private:
std::shared_ptr<RPCEndpoint> endpoint_;
};
-std::shared_ptr<RPCSession>
-CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
+std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
return std::make_shared<RPCClientSession>(endpoint);
}
#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
#include <tvm/runtime/packed_func.h>
+
+#include <memory>
#include <mutex>
#include <string>
-#include <memory>
#include <utility>
-#include "rpc_session.h"
+
+#include "../../support/ring_buffer.h"
#include "rpc_channel.h"
#include "rpc_protocol.h"
-#include "../../support/ring_buffer.h"
+#include "rpc_session.h"
namespace tvm {
namespace runtime {
kGetPendingMatchKeys = 7
};
-
/*!
* \brief Communication endpoints to connect local and remote RPC sessions.
* An endpoint can either be a client or a server.
* \param num_args Number of arguments.
* \param fencode_return The function to receive return value encodings.
*/
- void CallFunc(RPCSession::PackedFuncHandle handle,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
- RPCSession::FEncodeReturn encode_return);
+ void CallFunc(RPCSession::PackedFuncHandle handle, const TVMValue* arg_values,
+ const int* arg_type_codes, int num_args, RPCSession::FEncodeReturn encode_return);
/*!
* \brief Copy bytes into remote array content.
* \param from The source host data.
* \param ctx_to The target context.
* \param type_hint Hint of content data type.
*/
- void CopyToRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_to,
- DLDataType type_hint);
+ void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+ TVMContext ctx_to, DLDataType type_hint);
/*!
* \brief Copy bytes from remote array content.
* \param from The source host data.
* \param ctx_from The source context.
* \param type_hint Hint of content data type.
*/
- void CopyFromRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_from,
- DLDataType type_hint);
+ void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+ TVMContext ctx_from, DLDataType type_hint);
/*!
* \brief Call a remote defined system function with arguments.
* \param args The arguments
* \return The returned remote value.
*/
- template<typename... Args>
- inline TVMRetValue SysCallRemote(RPCCode fcode, Args&& ...args);
+ template <typename... Args>
+ inline TVMRetValue SysCallRemote(RPCCode fcode, Args&&... args);
/*!
* \brief Create a RPC session with given channel.
* \param channel The communication channel.
* if remote_key equals "%toinit", we need to re-intialize
* it by event handler.
*/
- static std::shared_ptr<RPCEndpoint> Create(
- std::unique_ptr<RPCChannel> channel,
- std::string name,
- std::string remote_key);
+ static std::shared_ptr<RPCEndpoint> Create(std::unique_ptr<RPCChannel> channel, std::string name,
+ std::string remote_key);
private:
class EventHandler;
* \param endpoint The endpoint.
* \return The created session.
*/
-std::shared_ptr<RPCSession>
-CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint);
+std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint);
// implementation of inline functions
-template<typename... Args>
-inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&& ...args) {
+template <typename... Args>
+inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) {
return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...);
}
} // namespace runtime
* \brief Event driven RPC server implementation.
*/
#include <tvm/runtime/registry.h>
+
#include <memory>
+
#include "rpc_endpoint.h"
#include "rpc_local_session.h"
namespace tvm {
namespace runtime {
-PackedFunc CreateEventDrivenServer(PackedFunc fsend,
- std::string name,
- std::string remote_key) {
+PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name, std::string remote_key) {
static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) {
LOG(FATAL) << "Do not allow explicit receive";
return 0;
});
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend, frecv));
- std::shared_ptr<RPCEndpoint> sess =
- RPCEndpoint::Create(std::move(ch), name, remote_key);
+ std::shared_ptr<RPCEndpoint> sess = RPCEndpoint::Create(std::move(ch), name, remote_key);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
- int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]);
- *rv = ret;
- });
+ int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]);
+ *rv = ret;
+ });
}
-TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer")
-.set_body_typed(CreateEventDrivenServer);
+TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer);
} // namespace runtime
} // namespace tvm
* \file local_session.cc
* \brief Local session that directs requests to local API.
*/
-#include <tvm/runtime/registry.h>
+#include "rpc_local_session.h"
+
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
#include <memory>
-#include "rpc_local_session.h"
namespace tvm {
namespace runtime {
-RPCSession::PackedFuncHandle
-LocalSession::GetFunction(const std::string& name) {
+RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) {
if (auto* fp = tvm::runtime::Registry::Get(name)) {
// return raw handle because the remote need to explicitly manage it.
return new PackedFunc(*fp);
ret_value_pack[2].v_handle = ret_value_pack[1].v_handle;
ret_tcode_pack[2] = kTVMOpaqueHandle;
encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3));
- } else if (rv_tcode == kTVMPackedFuncHandle ||
- rv_tcode == kTVMModuleHandle) {
+ } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
// MoveToCHost means rv no longer manages the object.
// return handle instead.
rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
}
}
-void LocalSession::CallFunc(RPCSession::PackedFuncHandle func,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
+void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values,
+ const int* arg_type_codes, int num_args,
const FEncodeReturn& encode_return) {
auto* pf = static_cast<PackedFunc*>(func);
TVMRetValue rv;
this->EncodeReturn(std::move(rv), encode_return);
}
-void LocalSession::CopyToRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_to,
- DLDataType type_hint) {
+void LocalSession::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t nbytes, TVMContext ctx_to, DLDataType type_hint) {
TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
- this->GetDeviceAPI(ctx_to)->CopyDataFromTo(
- from, from_offset,
- to, to_offset,
- nbytes, cpu_ctx, ctx_to, type_hint, nullptr);
+ this->GetDeviceAPI(ctx_to)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, cpu_ctx,
+ ctx_to, type_hint, nullptr);
// Copy can happen asynchrously
// synchronize to make sure that copy is completed
this->GetDeviceAPI(ctx_to)->StreamSync(ctx_to, nullptr);
}
-void LocalSession::CopyFromRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_from,
- DLDataType type_hint) {
+void LocalSession::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t nbytes, TVMContext ctx_from, DLDataType type_hint) {
TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
- this->GetDeviceAPI(ctx_from)->CopyDataFromTo(
- from, from_offset,
- to, to_offset,
- nbytes, ctx_from, cpu_ctx, type_hint, nullptr);
+ this->GetDeviceAPI(ctx_from)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, ctx_from,
+ cpu_ctx, type_hint, nullptr);
// Copy can happen asynchrously
// synchronize to make sure that copy is completed
this->GetDeviceAPI(ctx_from)->StreamSync(ctx_from, nullptr);
return DeviceAPI::Get(ctx, allow_missing);
}
-TVM_REGISTER_GLOBAL("rpc.LocalSession")
-.set_body_typed([]() {
+TVM_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() {
return CreateRPCSessionModule(std::make_shared<LocalSession>());
});
#ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
#define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/packed_func.h>
+
#include <functional>
#include <string>
#include <utility>
+
#include "rpc_session.h"
namespace tvm {
// function overrides
PackedFuncHandle GetFunction(const std::string& name) override;
- void CallFunc(PackedFuncHandle func,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
- const FEncodeReturn& fencode_return) override;
+ void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes,
+ int num_args, const FEncodeReturn& fencode_return) override;
- void CopyToRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_to,
- DLDataType type_hint) override;
+ void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+ TVMContext ctx_to, DLDataType type_hint) override;
- void CopyFromRemote(void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t nbytes,
- TVMContext ctx_from,
- DLDataType type_hint) override;
+ void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+ TVMContext ctx_from, DLDataType type_hint) override;
void FreeHandle(void* handle, int type_code) override;
DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) override;
- bool IsLocalSession() const override {
- return true;
- }
+ bool IsLocalSession() const override { return true; }
protected:
/*!
* \file rpc_module.cc
* \brief RPC runtime module.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
-#include <memory>
+#include <tvm/runtime/registry.h>
+
#include <cstring>
+#include <memory>
+
#include "rpc_endpoint.h"
#include "rpc_session.h"
*/
class RPCWrappedFunc : public Object {
public:
- RPCWrappedFunc(void* handle,
- std::shared_ptr<RPCSession> sess)
- : handle_(handle), sess_(sess) {
- }
+ RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess) : handle_(handle), sess_(sess) {}
void operator()(TVMArgs args, TVMRetValue* rv) const {
std::vector<TVMValue> values(args.values, args.values + args.size());
// are compatible to each other, just need to change the index.
type_codes[i] = kTVMDLTensorHandle;
// translate to a remote view of DLTensor
- auto dptr = std::make_unique<DLTensor>(
- *static_cast<DLTensor*>(values[i].v_handle));
+ auto dptr = std::make_unique<DLTensor>(*static_cast<DLTensor*>(values[i].v_handle));
dptr->ctx = RemoveSessMask(dptr->ctx);
dptr->data = static_cast<RemoteSpace*>(dptr->data)->data;
values[i].v_handle = dptr.get();
}
case kTVMPackedFuncHandle:
case kTVMModuleHandle: {
- values[i].v_handle = UnwrapRemoteValueToHandle(
- TVMArgValue(values[i], tcode));
+ values[i].v_handle = UnwrapRemoteValueToHandle(TVMArgValue(values[i], tcode));
break;
}
}
}
- auto set_return = [this, rv](TVMArgs args) {
- this->WrapRemoteReturnToValue(args, rv);
- };
- sess_->CallFunc(handle_, values.data(), type_codes.data(),
- args.size(), set_return);
+ auto set_return = [this, rv](TVMArgs args) { this->WrapRemoteReturnToValue(args, rv); };
+ sess_->CallFunc(handle_, values.data(), type_codes.data(), args.size(), set_return);
}
~RPCWrappedFunc() {
data->dl_tensor.data = space;
NDArray ret(GetObjectPtr<Object>(data));
// RAII now in effect
- data->shape_ = std::vector<int64_t>(
- tensor->shape, tensor->shape + tensor->ndim);
+ data->shape_ = std::vector<int64_t>(tensor->shape, tensor->shape + tensor->ndim);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
// setup dtype
// setup ctx, encode as remote session
data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
- static_cast<int>(tensor->ctx.device_type) +
- kRPCSessMask * (sess_->table_index() + 1));
+ static_cast<int>(tensor->ctx.device_type) + kRPCSessMask * (sess_->table_index() + 1));
// check strides.
CHECK(tensor->strides == nullptr);
// setup byteoffset
class RPCModuleNode final : public ModuleNode {
public:
RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
- : module_handle_(module_handle), sess_(sess) {
- }
+ : module_handle_(module_handle), sess_(sess) {}
~RPCModuleNode() {
if (module_handle_ != nullptr) {
}
}
- const char* type_key() const final {
- return "rpc";
- }
+ const char* type_key() const final { return "rpc"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (module_handle_ == nullptr) {
return WrapRemoteFunc(sess_->GetFunction(name));
} else {
return "";
}
- PackedFunc GetTimeEvaluator(const std::string& name,
- TVMContext ctx,
- int number,
- int repeat,
+ PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat,
int min_repeat_ms) {
InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator");
// Remove session mask because we pass ctx by parts.
ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
if (module_handle_ != nullptr) {
- return remote_get_time_evaluator_(
- GetRef<Module>(this), name,
- static_cast<int>(ctx.device_type), ctx.device_id,
- number, repeat, min_repeat_ms);
+ return remote_get_time_evaluator_(GetRef<Module>(this), name,
+ static_cast<int>(ctx.device_type), ctx.device_id, number,
+ repeat, min_repeat_ms);
} else {
- return remote_get_time_evaluator_(
- Optional<Module>(nullptr), name,
- static_cast<int>(ctx.device_type), ctx.device_id,
- number, repeat, min_repeat_ms);
+ return remote_get_time_evaluator_(Optional<Module>(nullptr), name,
+ static_cast<int>(ctx.device_type), ctx.device_id, number,
+ repeat, min_repeat_ms);
}
}
remote_import_module_(GetRef<Module>(this), other);
}
- const std::shared_ptr<RPCSession>& sess() {
- return sess_;
- }
+ const std::shared_ptr<RPCSession>& sess() { return sess_; }
- void* module_handle() const {
- return module_handle_;
- }
+ void* module_handle() const { return module_handle_; }
private:
- template<typename FType>
+ template <typename FType>
void InitRemoteFunc(FType* func, const std::string& name) {
if (*func != nullptr) return;
RPCSession::PackedFuncHandle handle = sess_->GetFunction(name);
PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) {
if (handle == nullptr) return PackedFunc();
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
- return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
- return wf->operator()(args, rv);
- });
+ return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); });
}
// The module handle
std::shared_ptr<RPCSession> sess_;
// remote function to get time evaluator
TypedPackedFunc<PackedFunc(Optional<Module>, std::string, int, int, int, int, int)>
- remote_get_time_evaluator_;
+ remote_get_time_evaluator_;
// remote function getter for modules.
TypedPackedFunc<PackedFunc(Module, std::string, bool)> remote_mod_get_function_;
// remote function getter for load module
TypedPackedFunc<void(Module, Module)> remote_import_module_;
};
-
void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const {
if (arg.type_code() == kTVMModuleHandle) {
Module mod = arg;
std::string tkey = mod->type_key();
- CHECK_EQ(tkey, "rpc")
- << "ValueError: Cannot pass a non-RPC module to remote";
+ CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote";
auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
CHECK(rmod->sess() == sess_)
<< "ValueError: Cannot pass in module into a different remote session";
return rmod->module_handle();
} else {
- LOG(FATAL) << "ValueError: Cannot pass type "
- << runtime::TypeCode2Str(arg.type_code())
+ LOG(FATAL) << "ValueError: Cannot pass type " << runtime::TypeCode2Str(arg.type_code())
<< " as an argument to the remote";
return nullptr;
}
}
-void RPCWrappedFunc::WrapRemoteReturnToValue(
- TVMArgs args,
- TVMRetValue *rv) const {
+void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const {
int tcode = args[0];
if (tcode == kTVMNullptr) return;
CHECK_EQ(args.size(), 2);
void* handle = args[1];
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
- *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
- return wf->operator()(args, rv);
- });
+ *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); });
} else if (tcode == kTVMModuleHandle) {
CHECK_EQ(args.size(), 2);
void* handle = args[1];
std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod) {
std::string tkey = mod->type_key();
- CHECK_EQ(tkey, "rpc")
- << "ValueError: Cannot pass a non-RPC module to remote";
+ CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote";
auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
return rmod->sess();
}
-PackedFunc WrapTimeEvaluator(PackedFunc pf,
- TVMContext ctx,
- int number,
- int repeat,
+PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat,
int min_repeat_ms) {
CHECK(pf != nullptr);
return (*get_micro_time_evaluator)(pf, ctx, number, repeat);
}
- auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv)
- mutable {
+ auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable {
TVMRetValue temp;
std::ostringstream os;
// skip first time call, to activate lazy compilation components.
DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
for (int i = 0; i < repeat; ++i) {
- std::chrono::time_point<
- std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
+ std::chrono::time_point<std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin,
+ tend;
double duration_ms = 0.0;
do {
if (duration_ms > 0.0) {
- number = static_cast<int>(
- std::max((min_repeat_ms / (duration_ms / number) + 1),
- number * 1.618)); // 1.618 is chosen by random
+          number = static_cast<int>(std::max((min_repeat_ms / (duration_ms / number) + 1),
+                                             number * 1.618));  // 1.618 is chosen at random
}
tbegin = std::chrono::high_resolution_clock::now();
DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
tend = std::chrono::high_resolution_clock::now();
- duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
- (tend - tbegin).count() * 1000;
+ duration_ms =
+ std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count() * 1000;
} while (duration_ms < min_repeat_ms);
- double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
- tend - tbegin).count() / number;
+ double speed =
+ std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count() / number;
os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
}
return PackedFunc(ftimer);
}
-
TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
-.set_body_typed([](Optional<Module> opt_mod,
- std::string name,
- int device_type,
- int device_id,
- int number,
- int repeat,
- int min_repeat_ms) {
- TVMContext ctx;
- ctx.device_type = static_cast<DLDeviceType>(device_type);
- ctx.device_id = device_id;
- if (opt_mod.defined()) {
- Module m = opt_mod.value();
- std::string tkey = m->type_key();
- if (tkey == "rpc") {
- return static_cast<RPCModuleNode*>(m.operator->())
- ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms);
- } else {
- return WrapTimeEvaluator(
- m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
- }
- } else {
- auto* pf = runtime::Registry::Get(name);
- CHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
- return WrapTimeEvaluator(
- *pf, ctx, number, repeat, min_repeat_ms);
- }
-});
+ .set_body_typed([](Optional<Module> opt_mod, std::string name, int device_type, int device_id,
+ int number, int repeat, int min_repeat_ms) {
+ TVMContext ctx;
+ ctx.device_type = static_cast<DLDeviceType>(device_type);
+ ctx.device_id = device_id;
+ if (opt_mod.defined()) {
+ Module m = opt_mod.value();
+ std::string tkey = m->type_key();
+ if (tkey == "rpc") {
+ return static_cast<RPCModuleNode*>(m.operator->())
+ ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms);
+ } else {
+ return WrapTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
+ }
+ } else {
+ auto* pf = runtime::Registry::Get(name);
+ CHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
+ return WrapTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms);
+ }
+ });
// server function registration.
-TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule")
-.set_body_typed([](Module parent, Module child) {
+TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) {
parent->Import(child);
});
TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction")
-.set_body_typed([](Module parent, std::string name, bool query_imports) {
- return parent->GetFunction(name, query_imports);
-});
+ .set_body_typed([](Module parent, std::string name, bool query_imports) {
+ return parent->GetFunction(name, query_imports);
+ });
// functions to access an RPC module.
-TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule")
-.set_body_typed([](Module sess, std::string name) {
+TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) {
std::string tkey = sess->type_key();
CHECK_EQ(tkey, "rpc");
return static_cast<RPCModuleNode*>(sess.operator->())->LoadModule(name);
});
-TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule")
-.set_body_typed([](Module parent, Module child) {
+TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) {
std::string tkey = parent->type_key();
CHECK_EQ(tkey, "rpc");
static_cast<RPCModuleNode*>(parent.operator->())->ImportModule(child);
});
-TVM_REGISTER_GLOBAL("rpc.SessTableIndex")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
CHECK_EQ(tkey, "rpc");
// Linux only for now, as linux is the most common usecase.
#if defined(__linux__) || defined(__ANDROID__)
-#include <sys/types.h>
-#include <unistd.h>
#include <errno.h>
#include <signal.h>
-
+#include <sys/types.h>
#include <tvm/runtime/registry.h>
-#include <memory>
+#include <unistd.h>
+
#include <cstdlib>
+#include <memory>
+#include "../../support/pipe.h"
#include "rpc_endpoint.h"
#include "rpc_local_session.h"
-#include "../../support/pipe.h"
namespace tvm {
namespace runtime {
class PipeChannel final : public RPCChannel {
public:
explicit PipeChannel(int readfd, int writefd, pid_t child_pid)
- : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {
- }
+ : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {}
- ~PipeChannel() {
- Close();
- }
+ ~PipeChannel() { Close(); }
size_t Send(const void* data, size_t size) final {
ssize_t n = write(writefd_, data, size);
pid_t child_pid_;
};
-
Module CreatePipeClient(std::vector<std::string> cmd) {
int parent2child[2];
int child2parent[2];
close(child_write);
auto endpt = RPCEndpoint::Create(
- std::unique_ptr<PipeChannel>(
- new PipeChannel(parent_read, parent_write, pid)),
- "pipe", "pipe");
+ std::unique_ptr<PipeChannel>(new PipeChannel(parent_read, parent_write, pid)), "pipe",
+ "pipe");
endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0));
return CreateRPCSessionModule(CreateClientSession(endpt));
}
-TVM_REGISTER_GLOBAL("rpc.CreatePipeClient")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<std::string> cmd;
for (int i = 0; i < args.size(); ++i) {
cmd.push_back(args[i].operator std::string());
*rv = CreatePipeClient(cmd);
});
-
} // namespace runtime
} // namespace tvm
#endif
*/
inline const char* RPCServerStatusToString(RPCServerStatus status) {
switch (status) {
- case RPCServerStatus::kSuccess: return "kSuccess";
- case RPCServerStatus::kInvalidTypeCodeObject: return "kInvalidTypeCodeObject";
- case RPCServerStatus::kInvalidTypeCodeNDArray: return "kInvalidTypeCodeNDArray";
- case RPCServerStatus::kInvalidDLTensorFieldStride: return "kInvalidDLTensorFieldStride";
+ case RPCServerStatus::kSuccess:
+ return "kSuccess";
+ case RPCServerStatus::kInvalidTypeCodeObject:
+ return "kInvalidTypeCodeObject";
+ case RPCServerStatus::kInvalidTypeCodeNDArray:
+ return "kInvalidTypeCodeNDArray";
+ case RPCServerStatus::kInvalidDLTensorFieldStride:
+ return "kInvalidDLTensorFieldStride";
case RPCServerStatus::kInvalidDLTensorFieldByteOffset: {
return "kInvalidDLTensorFieldByteOffset";
}
- case RPCServerStatus::kUnknownTypeCode: return "kUnknownTypeCode";
- case RPCServerStatus::kUnknownRPCCode: return "kUnknownRPCCode";
- case RPCServerStatus::kRPCCodeNotSupported: return "RPCCodeNotSupported";
- case RPCServerStatus::kUnknownRPCSyscall: return "kUnknownRPCSyscall";
- case RPCServerStatus::kCheckError: return "kCheckError";
- case RPCServerStatus::kReadError: return "kReadError";
- case RPCServerStatus::kWriteError: return "kWriteError";
- case RPCServerStatus::kAllocError: return "kAllocError";
- default: return "";
+ case RPCServerStatus::kUnknownTypeCode:
+ return "kUnknownTypeCode";
+ case RPCServerStatus::kUnknownRPCCode:
+ return "kUnknownRPCCode";
+ case RPCServerStatus::kRPCCodeNotSupported:
+ return "RPCCodeNotSupported";
+ case RPCServerStatus::kUnknownRPCSyscall:
+ return "kUnknownRPCSyscall";
+ case RPCServerStatus::kCheckError:
+ return "kCheckError";
+ case RPCServerStatus::kReadError:
+ return "kReadError";
+ case RPCServerStatus::kWriteError:
+ return "kWriteError";
+ case RPCServerStatus::kAllocError:
+ return "kAllocError";
+ default:
+ return "";
}
}
* \brief Auxiliary class to get the packed sequence.
* \tparam TChannel The channel to throw errror.
*/
- template<typename TChannel>
+ template <typename TChannel>
struct PackedSeqNumBytesGetter {
public:
- explicit PackedSeqNumBytesGetter(TChannel* channel)
- : channel_(channel) {}
+ explicit PackedSeqNumBytesGetter(TChannel* channel) : channel_(channel) {}
template <typename T>
void Write(const T& value) {
num_bytes_ += sizeof(T) * num;
}
- void ThrowError(RPCServerStatus status) {
- channel_->ThrowError(status);
- }
+ void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); }
- uint64_t num_bytes() const {
- return num_bytes_;
- }
+ uint64_t num_bytes() const { return num_bytes_; }
private:
TChannel* channel_;
* \tparam TChannel The type of the communication channel.
* \return The total number of bytes.
*/
- template<typename TChannel>
- static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values,
- const int* type_codes,
- int num_args,
- bool client_mode,
- TChannel* channel) {
+ template <typename TChannel>
+ static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes,
+ int num_args, bool client_mode, TChannel* channel) {
PackedSeqNumBytesGetter<TChannel> getter(channel);
SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter);
return getter.num_bytes();
* \param channel The communication channel handler.
* \tparam TChannel The type of the communication channel.
*/
- template<typename TChannel>
- static void SendPackedSeq(const TVMValue* arg_values,
- const int* type_codes,
- int num_args,
- bool client_mode,
- TChannel* channel) {
+ template <typename TChannel>
+ static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
+ bool client_mode, TChannel* channel) {
channel->Write(num_args);
channel->WriteArray(type_codes, num_args);
}
break;
}
- case kTVMNullptr: break;
+ case kTVMNullptr:
+ break;
case kTVMStr: {
const char* s = value.v_str;
uint64_t len = StrLength(s);
* \tparam TChannel The type of the communication channel.
* \note The temporary space are populated via an arena inside channel.
*/
- template<typename TChannel>
- static void RecvPackedSeq(TVMValue** out_values,
- int** out_tcodes,
- int* out_num_args,
+ template <typename TChannel>
+ static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args,
TChannel* channel) {
// receive number of args
int num_args;
* \param channel The communication channel handler.
* \tparam TChannel The type of the communication channel.
*/
- template<typename TChannel>
+ template <typename TChannel>
static void ReturnException(const char* msg, TChannel* channel) {
RPCCode code = RPCCode::kException;
int32_t num_args = 1;
int32_t tcode = kTVMStr;
uint64_t len = StrLength(msg);
- uint64_t packet_nbytes =
- sizeof(code) +
- sizeof(num_args) +
- sizeof(tcode) +
- sizeof(len) +
- len;
+ uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len;
channel->Write(packet_nbytes);
channel->Write(code);
* \param channel The communication channel handler.
* \tparam TChannel The type of the communication channel.
*/
- template<typename TChannel>
- static void ReturnPackedSeq(const TVMValue* arg_values,
- const int* type_codes,
- int num_args,
+ template <typename TChannel>
+ static void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
TChannel* channel) {
RPCCode code = RPCCode::kReturn;
uint64_t packet_nbytes =
- sizeof(code) +
- PackedSeqGetNumBytes(
- arg_values, type_codes, num_args, false, channel);
+ sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel);
channel->Write(packet_nbytes);
channel->Write(code);
- SendPackedSeq(
- arg_values, type_codes, num_args, false, channel);
+ SendPackedSeq(arg_values, type_codes, num_args, false, channel);
}
/*!
* \param channel The communication channel handler.
* \tparam TChannel The type of the communication channel.
*/
- template<typename TChannel>
+ template <typename TChannel>
static void ReturnVoid(TChannel* channel) {
int32_t num_args = 1;
int32_t tcode = kTVMNullptr;
RPCCode code = RPCCode::kReturn;
- uint64_t packet_nbytes =
- sizeof(code) +
- sizeof(num_args) +
- sizeof(tcode);
+ uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
channel->Write(packet_nbytes);
channel->Write(code);
* \brief Server environment of the RPC.
*/
#include <tvm/runtime/registry.h>
+
#include "../file_util.h"
namespace tvm {
std::string RPCGetPath(const std::string& name) {
// do live lookup everytime as workpath can change.
- const PackedFunc* f =
- runtime::Registry::Get("tvm.rpc.server.workpath");
+ const PackedFunc* f = runtime::Registry::Get("tvm.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.rpc.server.workpath";
return (*f)(name);
}
-TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").
-set_body([](TVMArgs args, TVMRetValue *rv) {
- std::string file_name = RPCGetPath(args[0]);
- std::string data = args[1];
- SaveBinaryToFile(file_name, data);
- });
+TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").set_body([](TVMArgs args, TVMRetValue* rv) {
+ std::string file_name = RPCGetPath(args[0]);
+ std::string data = args[1];
+ SaveBinaryToFile(file_name, data);
+});
-TVM_REGISTER_GLOBAL("tvm.rpc.server.download")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- std::string file_name = RPCGetPath(args[0]);
- std::string data;
- LoadBinaryFromFile(file_name, &data);
- TVMByteArray arr;
- arr.data = data.c_str();
- arr.size = data.length();
- LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size;
- *rv = arr;
- });
+TVM_REGISTER_GLOBAL("tvm.rpc.server.download").set_body([](TVMArgs args, TVMRetValue* rv) {
+ std::string file_name = RPCGetPath(args[0]);
+ std::string data;
+ LoadBinaryFromFile(file_name, &data);
+ TVMByteArray arr;
+ arr.data = data.c_str();
+ arr.size = data.length();
+ LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size;
+ *rv = arr;
+});
-TVM_REGISTER_GLOBAL("tvm.rpc.server.remove")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- std::string file_name = RPCGetPath(args[0]);
- RemoveFile(file_name);
- });
+TVM_REGISTER_GLOBAL("tvm.rpc.server.remove").set_body([](TVMArgs args, TVMRetValue* rv) {
+ std::string file_name = RPCGetPath(args[0]);
+ RemoveFile(file_name);
+});
} // namespace runtime
} // namespace tvm
* \file rpc_session.cc
* \brief RPC session for remote function call.
*/
-#include <tvm/runtime/packed_func.h>
+#include "rpc_session.h"
+
#include <tvm/runtime/device_api.h>
-#include <mutex>
+#include <tvm/runtime/packed_func.h>
+
#include <array>
-#include "rpc_session.h"
+#include <mutex>
namespace tvm {
namespace runtime {
-bool RPCSession::IsAsync() const {
- return false;
-}
+bool RPCSession::IsAsync() const { return false; }
void RPCSession::SendException(FAsyncCallback callback, const char* msg) {
TVMValue value;
callback(RPCCode::kException, TVMArgs(&value, &tcode, 1));
}
-void RPCSession::AsyncCallFunc(PackedFuncHandle func,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
- FAsyncCallback callback) {
+void RPCSession::AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values,
+ const int* arg_type_codes, int num_args, FAsyncCallback callback) {
try {
this->CallFunc(func, arg_values, arg_type_codes, num_args,
- [&callback](TVMArgs args) {
- callback(RPCCode::kReturn, args);
- });
+ [&callback](TVMArgs args) { callback(RPCCode::kReturn, args); });
} catch (const std::runtime_error& e) {
this->SendException(callback, e.what());
}
}
-
-void RPCSession::AsyncCopyToRemote(void* local_from,
- size_t local_from_offset,
- void* remote_to,
- size_t remote_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_to,
- DLDataType type_hint,
- RPCSession::FAsyncCallback callback) {
+void RPCSession::AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
+ size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
+ DLDataType type_hint, RPCSession::FAsyncCallback callback) {
TVMValue value;
int32_t tcode = kTVMNullptr;
value.v_handle = nullptr;
try {
- this->CopyToRemote(local_from, local_from_offset,
- remote_to, remote_to_offset,
- nbytes, remote_ctx_to, type_hint);
+ this->CopyToRemote(local_from, local_from_offset, remote_to, remote_to_offset, nbytes,
+ remote_ctx_to, type_hint);
callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
} catch (const std::runtime_error& e) {
this->SendException(callback, e.what());
}
}
-void RPCSession::AsyncCopyFromRemote(void* remote_from,
- size_t remote_from_offset,
- void* local_to,
- size_t local_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_from,
- DLDataType type_hint,
+void RPCSession::AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
+ size_t local_to_offset, size_t nbytes,
+ TVMContext remote_ctx_from, DLDataType type_hint,
RPCSession::FAsyncCallback callback) {
TVMValue value;
int32_t tcode = kTVMNullptr;
value.v_handle = nullptr;
try {
- this->CopyFromRemote(remote_from, remote_from_offset,
- local_to, local_to_offset,
- nbytes, remote_ctx_from, type_hint);
+ this->CopyFromRemote(remote_from, remote_from_offset, local_to, local_to_offset, nbytes,
+ remote_ctx_from, type_hint);
callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
} catch (const std::runtime_error& e) {
this->SendException(callback, e.what());
}
}
-void RPCSession::AsyncStreamWait(TVMContext ctx,
- TVMStreamHandle stream,
+void RPCSession::AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream,
RPCSession::FAsyncCallback callback) {
TVMValue value;
int32_t tcode = kTVMNullptr;
}
}
-
class RPCSessTable {
public:
static constexpr int kMaxRPCSession = 32;
std::lock_guard<std::mutex> lock(mutex_);
for (int i = 0; i < kMaxRPCSession; ++i) {
if (tbl_[i].lock() == nullptr) {
- tbl_[i] = ptr; return i;
+ tbl_[i] = ptr;
+ return i;
}
}
LOG(FATAL) << "maximum number of RPC session reached";
#ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
#define TVM_RUNTIME_RPC_RPC_SESSION_H_
-
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/packed_func.h>
+
#include <functional>
#include <memory>
#include <string>
+
#include "rpc_protocol.h"
namespace tvm {
* \param fencode_return The function to set the return value,
* if not called, return value is null.
*/
- virtual void CallFunc(PackedFuncHandle func,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
+ virtual void CallFunc(PackedFuncHandle func, const TVMValue* arg_values,
+ const int* arg_type_codes, int num_args,
const FEncodeReturn& fencode_return) = 0;
/*!
* \param remote_ctx_to The target context.
* \param type_hint Hint of content data type.
*/
- virtual void CopyToRemote(void* local_from,
- size_t local_from_offset,
- void* remote_to,
- size_t remote_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_to,
+ virtual void CopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
+ size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
DLDataType type_hint) = 0;
/*!
* \brief Copy bytes from remote array content.
* \param remote_ctx_from The source context in the remote.
* \param type_hint Hint of content data type.
*/
- virtual void CopyFromRemote(void* remote_from,
- size_t remote_from_offset,
- void* local_to,
- size_t local_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_from,
+ virtual void CopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
+ size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from,
DLDataType type_hint) = 0;
/*!
*
* \param callback The callback to pass the return value or exception.
*/
- virtual void AsyncCallFunc(PackedFuncHandle func,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
- FAsyncCallback callback);
+ virtual void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values,
+ const int* arg_type_codes, int num_args, FAsyncCallback callback);
/*!
* \brief Asynchrous version of CopyToRemote.
* \note All the allocated memory in local_from, and remote_to
* must stay alive until on_compelete is called.
*/
- virtual void AsyncCopyToRemote(void* local_from,
- size_t local_from_offset,
- void* remote_to,
- size_t remote_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_to,
- DLDataType type_hint,
- FAsyncCallback on_complete);
+ virtual void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
+ size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
+ DLDataType type_hint, FAsyncCallback on_complete);
/*!
* \brief Asynchrous version of CopyFromRemote.
* \note All the allocated memory in remote_from, and local_to
* must stay alive until on_compelete is called.
*/
- virtual void AsyncCopyFromRemote(void* remote_from,
- size_t remote_from_offset,
- void* local_to,
- size_t local_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_from,
- DLDataType type_hint,
+ virtual void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
+ size_t local_to_offset, size_t nbytes,
+ TVMContext remote_ctx_from, DLDataType type_hint,
FAsyncCallback on_complete);
/*!
* \brief Asynchrously wait for all events in ctx, stream compeletes.
* \param stream The stream to wait on.
* \param on_complete The callback to signal copy complete.
*/
-  virtual void AsyncStreamWait(TVMContext ctx,
-                               TVMStreamHandle stream,
-                               FAsyncCallback on_compelte);
+  virtual void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete);
/*!
* \return The session table index of the session.
*/
- int table_index() const {
- return table_index_;
- }
+ int table_index() const { return table_index_; }
/*!
* \brief Try get session from the global session table by table index.
* the `number` parameter will be automatically increased.
* \return f_timer A timer function.
*/
-PackedFunc WrapTimeEvaluator(PackedFunc f,
- TVMContext ctx,
- int number,
- int repeat,
+PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int number, int repeat,
int min_repeat_ms);
/*!
* \file rpc_socket_impl.cc
* \brief Socket based RPC implementation.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
+
#include <memory>
+
+#include "../../support/socket.h"
#include "rpc_endpoint.h"
-#include "rpc_session.h"
#include "rpc_local_session.h"
-#include "../../support/socket.h"
+#include "rpc_session.h"
namespace tvm {
namespace runtime {
class SockChannel final : public RPCChannel {
public:
- explicit SockChannel(support::TCPSocket sock)
- : sock_(sock) {}
+ explicit SockChannel(support::TCPSocket sock) : sock_(sock) {}
~SockChannel() {
try {
// BadSocket can throw
support::TCPSocket sock_;
};
-std::shared_ptr<RPCEndpoint>
-RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) {
+std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int port, std::string key,
+ TVMArgs init_seq) {
support::TCPSocket sock;
support::SockAddr addr(url.c_str(), port);
sock.Create(addr.ss_family());
- CHECK(sock.Connect(addr))
- << "Connect to " << addr.AsString() << " failed";
+ CHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed";
// hand shake
std::ostringstream os;
int code = kRPCMagic;
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code == kRPCMagic + 2) {
sock.Close();
- LOG(FATAL) << "URL " << url << ":" << port
- << " cannot find server that matches key=" << key;
+ LOG(FATAL) << "URL " << url << ":" << port << " cannot find server that matches key=" << key;
} else if (code == kRPCMagic + 1) {
sock.Close();
- LOG(FATAL) << "URL " << url << ":" << port
- << " server already have key=" << key;
+ LOG(FATAL) << "URL " << url << ":" << port << " server already have key=" << key;
} else if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
remote_key.resize(keylen);
CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
}
- auto endpt = RPCEndpoint::Create(
- std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
+ auto endpt =
+ RPCEndpoint::Create(std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
endpt->InitRemoteSession(init_seq);
return endpt;
}
-Module RPCClientConnect(std::string url,
- int port,
- std::string key,
- TVMArgs init_seq) {
+Module RPCClientConnect(std::string url, int port, std::string key, TVMArgs init_seq) {
auto endpt = RPCConnect(url, port, "client:" + key, init_seq);
return CreateRPCSessionModule(CreateClientSession(endpt));
}
// TVM_DLL needed for MSVC
TVM_DLL void RPCServerLoop(int sockfd) {
- support::TCPSocket sock(
- static_cast<support::TCPSocket::SockType>(sockfd));
- RPCEndpoint::Create(
- std::unique_ptr<SockChannel>(new SockChannel(sock)),
- "SockServerLoop", "")->ServerLoop();
+ support::TCPSocket sock(static_cast<support::TCPSocket::SockType>(sockfd));
+ RPCEndpoint::Create(std::unique_ptr<SockChannel>(new SockChannel(sock)), "SockServerLoop", "")
+ ->ServerLoop();
}
-void RPCServerLoop(PackedFunc fsend,
- PackedFunc frecv) {
- RPCEndpoint::Create(
- std::unique_ptr<CallbackChannel>(new CallbackChannel(fsend, frecv)),
- "SockServerLoop", "")->ServerLoop();
+void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
+ RPCEndpoint::Create(std::unique_ptr<CallbackChannel>(new CallbackChannel(fsend, frecv)),
+ "SockServerLoop", "")
+ ->ServerLoop();
}
-TVM_REGISTER_GLOBAL("rpc.Connect")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("rpc.Connect").set_body([](TVMArgs args, TVMRetValue* rv) {
std::string url = args[0];
int port = args[1];
std::string key = args[2];
- *rv = RPCClientConnect(
- url, port, key,
- TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3));
+ *rv = RPCClientConnect(url, port, key,
+ TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3));
});
-TVM_REGISTER_GLOBAL("rpc.ServerLoop")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kDLInt) {
RPCServerLoop(args[0]);
} else {
- RPCServerLoop(
- args[0].operator tvm::runtime::PackedFunc(),
- args[1].operator tvm::runtime::PackedFunc());
+ RPCServerLoop(args[0].operator tvm::runtime::PackedFunc(),
+ args[1].operator tvm::runtime::PackedFunc());
}
});
#define TVM_RUNTIME_RUNTIME_BASE_H_
#include <tvm/runtime/c_runtime_api.h>
+
#include <stdexcept>
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
-#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
+#define API_END() \
+ } \
+ catch (std::runtime_error & _except_) { \
+ return TVMAPIHandleException(_except_); \
+ } \
+ return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
-#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
+#define API_END_HANDLE_ERROR(Finalize) \
+ } \
+ catch (std::runtime_error & _except_) { \
+ Finalize; \
+ return TVMAPIHandleException(_except_); \
+ } \
+ return 0; // NOLINT(*)
/*!
* \brief handle exception throwed out
* \param e the exception
* \return the return value of API after exception is handled
*/
-int TVMAPIHandleException(const std::runtime_error &e);
+int TVMAPIHandleException(const std::runtime_error& e);
#endif // TVM_RUNTIME_RUNTIME_BASE_H_
* Implementation stack VM.
* \file stackvm.cc
*/
+#include "stackvm.h"
+
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_backend_api.h>
+
#include <algorithm>
-#include "stackvm.h"
namespace tvm {
namespace runtime {
typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore;
-StackVM::State* StackVM::ThreadLocalState() {
- return StackVMStateStore::Get();
-}
+StackVM::State* StackVM::ThreadLocalState() { return StackVMStateStore::Get(); }
#define STACK_VM_BINOP(OP, FIELD) \
{ \
stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \
- sp -= 1; pc += 1; \
+ sp -= 1; \
+ pc += 1; \
}
#define STACK_VM_CMPOP(OP, FIELD) \
{ \
stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
- sp -= 1; pc += 1; \
+ sp -= 1; \
+ pc += 1; \
}
-#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \
- { \
- int index = code[pc + 1].v_int; \
- stack[sp]FIELD = static_cast<DST_TYPE>( \
- static_cast<SRC_TYPE*>(stack[sp].v_handle)[index]); \
- pc += 2; \
+#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \
+ { \
+ int index = code[pc + 1].v_int; \
+ stack[sp] FIELD = static_cast<DST_TYPE>(static_cast<SRC_TYPE*>(stack[sp].v_handle)[index]); \
+ pc += 2; \
}
-#define STACK_VM_STORE(FIELD, DST_TYPE) \
- { \
- int index = code[pc + 1].v_int; \
- static_cast<DST_TYPE*>(stack[sp - 1].v_handle)[index] = \
- static_cast<DST_TYPE>(stack[sp]FIELD); \
- sp -= 2; pc += 2; \
+#define STACK_VM_STORE(FIELD, DST_TYPE) \
+ { \
+ int index = code[pc + 1].v_int; \
+ static_cast<DST_TYPE*>(stack[sp - 1].v_handle)[index] = \
+ static_cast<DST_TYPE>(stack[sp] FIELD); \
+ sp -= 2; \
+ pc += 2; \
}
-#define STACK_VM_PRINT_CODE0(CODE) \
- case CODE: { \
- os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \
+#define STACK_VM_PRINT_CODE0(CODE) \
+ case CODE: { \
+ os << "[" << pc << "]\t" << #CODE << std::endl; \
+ return pc + 1; \
}
-#define STACK_VM_PRINT_CODE1(CODE) \
- case CODE: { \
+#define STACK_VM_PRINT_CODE1(CODE) \
+ case CODE: { \
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
- << "[" << pc + 1 << "]" << std::endl; \
- return pc + 2; \
+ << "[" << pc + 1 << "]" << std::endl; \
+ return pc + 2; \
}
-#define STACK_VM_PRINT_CODE2(CODE) \
- case CODE: { \
- os << "[" << pc << "]\t" << #CODE \
- << " " << code[pc + 1].v_int \
- << " " << code[pc + 2].v_int << "\n" \
- << "[" << pc + 1 << "]" << std::endl \
- << "[" << pc + 2 << "]" << std::endl; \
- return pc + 3; \
+#define STACK_VM_PRINT_CODE2(CODE) \
+ case CODE: { \
+ os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " << code[pc + 2].v_int \
+ << "\n" \
+ << "[" << pc + 1 << "]" << std::endl \
+ << "[" << pc + 2 << "]" << std::endl; \
+ return pc + 3; \
}
-#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \
- case CODE: { \
- os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \
- << " " << heap_id_name[code[pc + 1].v_int] << "\n" \
- << "[" << pc + 1 << "]" << std::endl; \
- return pc + 2; \
+#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \
+ case CODE: { \
+ os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " \
+ << heap_id_name[code[pc + 1].v_int] << "\n" \
+ << "[" << pc + 1 << "]" << std::endl; \
+ return pc + 2; \
}
-#define STACK_VM_PRINT_JUMP(CODE) \
- case CODE: { \
- os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int \
- << " to " << pc + code[pc + 1].v_int << '\n' \
- << "[" << pc + 1 << "]" << std::endl; \
- return pc + 2; \
+#define STACK_VM_PRINT_JUMP(CODE) \
+ case CODE: { \
+ os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int << " to " \
+ << pc + code[pc + 1].v_int << '\n' \
+ << "[" << pc + 1 << "]" << std::endl; \
+ return pc + 2; \
}
-
int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
switch (code[pc].op_code) {
// int
int begin = code[pc + 2].v_int;
int end = code[pc + 3].v_int;
os << "[" << pc << "]\tCALL_PACKED_FUNC "
- << " fid=" << call_fid
- << " begin=" << begin
- << " end=" << end;
+ << " fid=" << call_fid << " begin=" << begin << " end=" << end;
os << '\n';
for (int i = 0; i < 3; ++i) {
os << "[" << pc + 1 + i << "]" << std::endl;
std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*)
int64_t pc = 0;
const int64_t code_size = static_cast<int64_t>(vm.code.size());
- os << "Program dump: code-size=" << code_size << '\n'
- << "----------begin-----------------\n";
+ os << "Program dump: code-size=" << code_size << '\n' << "----------begin-----------------\n";
while (pc < code_size) {
pc = vm.PrintCode(os, pc);
}
return os;
}
-void StackVM::Run(const runtime::TVMArgs& args,
- runtime::ModuleNode* mod_ctx) const {
+void StackVM::Run(const runtime::TVMArgs& args, runtime::ModuleNode* mod_ctx) const {
StackVM::State* s = StackVM::ThreadLocalState();
if (s->heap.size() < heap_size) {
s->heap.resize(heap_size);
s->sp = 0;
s->pc = 0;
s->mod_ctx = mod_ctx;
- s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
+ s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*)
s->heap[2].v_int64 = args.num_args;
this->Run(s);
void StackVM::InitCache() {
extern_func_cache_.clear();
- extern_func_cache_.resize(
- extern_func_name.size(), PackedFunc(nullptr));
+ extern_func_cache_.resize(extern_func_name.size(), PackedFunc(nullptr));
}
void StackVM::Save(dmlc::Stream* strm) const {
// to be endian invariant.
std::vector<int32_t> code_copy(code.size());
- std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) {
- return c.v_int;
- });
+ std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { return c.v_int; });
strm->Write(code_copy);
strm->Write(str_data);
strm->Write(extern_func_name);
strm->Write(stack_size);
}
-bool StackVM::Load(dmlc::Stream* strm) {
+bool StackVM::Load(dmlc::Stream* strm) {
// to be endian invariant.
std::vector<int32_t> code_copy;
if (!strm->Read(&code_copy)) return false;
code.resize(code_copy.size());
std::transform(code_copy.begin(), code_copy.end(), code.begin(), [](int v) {
- Code code; code.v_int = v; return code;
- });
+ Code code;
+ code.v_int = v;
+ return code;
+ });
if (!strm->Read(&str_data)) return false;
if (!strm->Read(&extern_func_name)) return false;
if (!strm->Read(&heap_id_name)) return false;
const int64_t code_size = static_cast<int64_t>(code.size());
while (pc < code_size) {
switch (code[pc].op_code) {
- case ADD_I64: STACK_VM_BINOP(+, v_int64); break;
- case SUB_I64: STACK_VM_BINOP(-, v_int64); break;
- case MUL_I64: STACK_VM_BINOP(*, v_int64); break;
- case DIV_I64: STACK_VM_BINOP(/, v_int64); break;
- case MOD_I64: STACK_VM_BINOP(%, v_int64); break;
- case EQ_I64: STACK_VM_CMPOP(==, v_int64); break;
- case LT_I64: STACK_VM_CMPOP(<, v_int64); break;
- case LE_I64: STACK_VM_CMPOP(<=, v_int64); break;
- case ADD_F64: STACK_VM_BINOP(+, v_float64); break;
- case SUB_F64: STACK_VM_BINOP(-, v_float64); break;
- case MUL_F64: STACK_VM_BINOP(*, v_float64); break;
- case DIV_F64: STACK_VM_BINOP(/, v_float64); break;
- case EQ_F64: STACK_VM_CMPOP(==, v_float64); break;
- case LT_F64: STACK_VM_CMPOP(<, v_float64); break;
- case LE_F64: STACK_VM_CMPOP(<=, v_float64); break;
- case EQ_HANDLE: STACK_VM_CMPOP(==, v_handle); break;
+ case ADD_I64:
+ STACK_VM_BINOP(+, v_int64);
+ break;
+ case SUB_I64:
+ STACK_VM_BINOP(-, v_int64);
+ break;
+ case MUL_I64:
+ STACK_VM_BINOP(*, v_int64);
+ break;
+ case DIV_I64:
+ STACK_VM_BINOP(/, v_int64);
+ break;
+ case MOD_I64:
+ STACK_VM_BINOP(%, v_int64);
+ break;
+ case EQ_I64:
+ STACK_VM_CMPOP(==, v_int64);
+ break;
+ case LT_I64:
+ STACK_VM_CMPOP(<, v_int64);
+ break;
+ case LE_I64:
+ STACK_VM_CMPOP(<=, v_int64);
+ break;
+ case ADD_F64:
+ STACK_VM_BINOP(+, v_float64);
+ break;
+ case SUB_F64:
+ STACK_VM_BINOP(-, v_float64);
+ break;
+ case MUL_F64:
+ STACK_VM_BINOP(*, v_float64);
+ break;
+ case DIV_F64:
+ STACK_VM_BINOP(/, v_float64);
+ break;
+ case EQ_F64:
+ STACK_VM_CMPOP(==, v_float64);
+ break;
+ case LT_F64:
+ STACK_VM_CMPOP(<, v_float64);
+ break;
+ case LE_F64:
+ STACK_VM_CMPOP(<=, v_float64);
+ break;
+ case EQ_HANDLE:
+ STACK_VM_CMPOP(==, v_handle);
+ break;
// addressing
- case ARRAY_LOAD_UINT32: STACK_VM_LOAD(.v_int64, int64_t, uint32_t); break;
- case ARRAY_LOAD_INT32: STACK_VM_LOAD(.v_int64, int64_t, int32_t); break;
- case ARRAY_LOAD_INT64: STACK_VM_LOAD(.v_int64, int64_t, int64_t); break;
- case ARRAY_LOAD_FP64: STACK_VM_LOAD(.v_float64, double, double); break;
- case ARRAY_LOAD_HANDLE: STACK_VM_LOAD(.v_handle, void*, void*); break;
- case ARRAY_LOAD_TVMVALUE: STACK_VM_LOAD(, TVMValue, TVMValue); break;
+ case ARRAY_LOAD_UINT32:
+ STACK_VM_LOAD(.v_int64, int64_t, uint32_t);
+ break;
+ case ARRAY_LOAD_INT32:
+ STACK_VM_LOAD(.v_int64, int64_t, int32_t);
+ break;
+ case ARRAY_LOAD_INT64:
+ STACK_VM_LOAD(.v_int64, int64_t, int64_t);
+ break;
+ case ARRAY_LOAD_FP64:
+ STACK_VM_LOAD(.v_float64, double, double);
+ break;
+ case ARRAY_LOAD_HANDLE:
+ STACK_VM_LOAD(.v_handle, void*, void*);
+ break;
+ case ARRAY_LOAD_TVMVALUE:
+ STACK_VM_LOAD(, TVMValue, TVMValue);
+ break;
// store
- case ARRAY_STORE_UINT32: STACK_VM_STORE(.v_int64, uint32_t); break;
- case ARRAY_STORE_INT32: STACK_VM_STORE(.v_int64, int32_t); break;
- case ARRAY_STORE_INT64: STACK_VM_STORE(.v_int64, int64_t); break;
- case ARRAY_STORE_FP64: STACK_VM_STORE(.v_float64, double); break;
- case ARRAY_STORE_HANDLE: STACK_VM_STORE(.v_handle, void*); break;
- case ARRAY_STORE_TVMVALUE: STACK_VM_STORE(, TVMValue); break;
+ case ARRAY_STORE_UINT32:
+ STACK_VM_STORE(.v_int64, uint32_t);
+ break;
+ case ARRAY_STORE_INT32:
+ STACK_VM_STORE(.v_int64, int32_t);
+ break;
+ case ARRAY_STORE_INT64:
+ STACK_VM_STORE(.v_int64, int64_t);
+ break;
+ case ARRAY_STORE_FP64:
+ STACK_VM_STORE(.v_float64, double);
+ break;
+ case ARRAY_STORE_HANDLE:
+ STACK_VM_STORE(.v_handle, void*);
+ break;
+ case ARRAY_STORE_TVMVALUE:
+ STACK_VM_STORE(, TVMValue);
+ break;
// add
case ADDR_ADD: {
stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*)
}
case ASSERT_SP: {
int64_t expected = code[pc + 1].v_int;
- CHECK_EQ(sp, expected)
- << "sp assertion failed, expected="
- << expected << " now=" << sp << ", pc=" << pc;
+ CHECK_EQ(sp, expected) << "sp assertion failed, expected=" << expected << " now=" << sp
+ << ", pc=" << pc;
pc += 2;
break;
}
int begin = code[pc + 2].v_int;
int end = code[pc + 3].v_int;
int num_args = end - begin;
- static_assert(sizeof(Code) == sizeof(int) &&
- alignof(Code) == alignof(int), "asusmption");
+ static_assert(sizeof(Code) == sizeof(int) && alignof(Code) == alignof(int), "assumption");
runtime::TVMRetValue rv;
- GetExtern(s, call_fid).CallPacked(
- runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv);
+ GetExtern(s, call_fid)
+ .CallPacked(runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv);
sp = sp - 1;
stack[sp] = rv.value();
pc += 4;
DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle);
switch (kind) {
case StackVM::kArrData: {
- stack[sp].v_handle = arr[index].data; break;
+ stack[sp].v_handle = arr[index].data;
+ break;
}
case StackVM::kArrShape: {
- stack[sp].v_handle = arr[index].shape; break;
+ stack[sp].v_handle = arr[index].shape;
+ break;
}
case StackVM::kArrStrides: {
- stack[sp].v_handle = arr[index].strides; break;
+ stack[sp].v_handle = arr[index].strides;
+ break;
}
case StackVM::kArrNDim: {
- stack[sp].v_int64 = arr[index].ndim; break;
+ stack[sp].v_int64 = arr[index].ndim;
+ break;
}
case StackVM::kArrTypeCode: {
- stack[sp].v_int64 = static_cast<int64_t>(
- arr[index].dtype.code); break;
+ stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.code);
+ break;
}
case StackVM::kArrTypeBits: {
- stack[sp].v_int64 = static_cast<int64_t>(
- arr[index].dtype.bits); break;
+ stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.bits);
+ break;
}
case StackVM::kArrTypeLanes: {
- stack[sp].v_int64 = static_cast<int64_t>(
- arr[index].dtype.lanes); break;
+ stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.lanes);
+ break;
}
case StackVM::kArrByteOffset: {
- stack[sp].v_int64 = static_cast<int64_t>(
- arr[index].byte_offset); break;
+ stack[sp].v_int64 = static_cast<int64_t>(arr[index].byte_offset);
+ break;
}
case StackVM::kArrDeviceId: {
- stack[sp].v_int64 = arr[index].ctx.device_id; break;
+ stack[sp].v_int64 = arr[index].ctx.device_id;
+ break;
}
case StackVM::kArrDeviceType: {
- stack[sp].v_int64 = static_cast<int64_t>(
- arr[index].ctx.device_type); break;
+ stack[sp].v_int64 = static_cast<int64_t>(arr[index].ctx.device_type);
+ break;
}
case StackVM::kArrAddr: {
- stack[sp].v_handle = arr + index; break;
+ stack[sp].v_handle = arr + index;
+ break;
}
case StackVM::kTVMValueContent: {
- stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break;
+ stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index];
+ break;
}
- default: LOG(FATAL) << "unhandled get " << kind;
+ default:
+ LOG(FATAL) << "unhandled get " << kind;
}
pc = pc + 3;
break;
DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle);
switch (kind) {
case StackVM::kArrData: {
- arr[index].data = stack[sp].v_handle; break;
+ arr[index].data = stack[sp].v_handle;
+ break;
}
case StackVM::kArrShape: {
arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
break;
}
case StackVM::kTVMValueContent: {
- static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break;
+ static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp];
+ break;
}
- default: LOG(FATAL) << "unhandled tvm_struct_set " << kind;
+ default:
+ LOG(FATAL) << "unhandled tvm_struct_set " << kind;
}
sp -= 2;
pc += 3;
size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
- void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes,
- dtype_code_hint, dtype_bits_hint);
+ void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint,
+ dtype_bits_hint);
stack[sp - 4].v_handle = ptr;
sp = sp - 4;
pc = pc + 1;
// allow race write in this, since write is idempotent
PackedFunc& f = extern_func_cache_[fid];
if (f == nullptr) {
- CHECK(s->mod_ctx != nullptr)
- << "No local context is set in stackvm";
+ CHECK(s->mod_ctx != nullptr) << "No local context is set in stackvm";
const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
CHECK(pf != nullptr);
f = *pf;
#define TVM_RUNTIME_STACKVM_STACKVM_H_
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+
#include <string>
#include <vector>
* \param pc The pc
* \return the pc to next instruction.
*/
- int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*)
+ int64_t PrintCode(std::ostream& os, int64_t pc) const; // NOLINT(*)
/*! \brief Get thread local state of the stack VM */
static State* ThreadLocalState();
// The code below are programs
*/
static OpCode CodeI64ToF64(OpCode code) {
switch (code) {
- case ADD_I64: return ADD_F64;
- case SUB_I64: return SUB_F64;
- case MUL_I64: return MUL_F64;
- case DIV_I64: return DIV_F64;
- case EQ_I64: return EQ_F64;
- case LT_I64: return LT_F64;
- case LE_I64: return LE_F64;
- case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64;
- default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
+ case ADD_I64:
+ return ADD_F64;
+ case SUB_I64:
+ return SUB_F64;
+ case MUL_I64:
+ return MUL_F64;
+ case DIV_I64:
+ return DIV_F64;
+ case EQ_I64:
+ return EQ_F64;
+ case LT_I64:
+ return LT_F64;
+ case LE_I64:
+ return LE_F64;
+ case MOD_I64:
+ LOG(FATAL) << "cannot handle mod for float";
+ return ADD_F64;
+ default:
+ LOG(FATAL) << "cannot handle op " << code;
+ return ADD_F64;
}
}
/*!
if (t.code == kTVMOpaqueHandle) return ARRAY_LOAD_HANDLE;
if (t.code == kDLInt) {
switch (t.bits) {
- case 32 : return ARRAY_LOAD_INT32;
- case 64 : return ARRAY_LOAD_INT64;
+ case 32:
+ return ARRAY_LOAD_INT32;
+ case 64:
+ return ARRAY_LOAD_INT64;
}
} else if (t.code == kDLUInt) {
switch (t.bits) {
- case 32 : return ARRAY_LOAD_UINT32;
+ case 32:
+ return ARRAY_LOAD_UINT32;
}
} else if (t.code == kDLFloat) {
switch (t.bits) {
- case 64 : return ARRAY_LOAD_FP64;
+ case 64:
+ return ARRAY_LOAD_FP64;
}
}
LOG(FATAL) << "Cannot load type " << t;
if (t.code == kTVMOpaqueHandle) return ARRAY_STORE_HANDLE;
if (t.code == kDLInt) {
switch (t.bits) {
- case 32 : return ARRAY_STORE_INT32;
- case 64 : return ARRAY_STORE_INT64;
+ case 32:
+ return ARRAY_STORE_INT32;
+ case 64:
+ return ARRAY_STORE_INT64;
}
} else if (t.code == kDLUInt) {
switch (t.bits) {
- case 32 : return ARRAY_STORE_UINT32;
+ case 32:
+ return ARRAY_STORE_UINT32;
}
} else if (t.code == kDLFloat) {
switch (t.bits) {
- case 64 : return ARRAY_STORE_FP64;
+ case 64:
+ return ARRAY_STORE_FP64;
}
}
LOG(FATAL) << "Cannot store type " << t;
/*!
* \file stackvm_module.cc
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/module.h>
+#include "stackvm_module.h"
+
#include <dmlc/memory_io.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+
#include <memory>
-#include <utility>
#include <unordered_map>
-#include "stackvm_module.h"
+#include <utility>
+
#include "../file_util.h"
namespace tvm {
class StackVMModuleNode : public runtime::ModuleNode {
public:
- const char* type_key() const {
- return "stackvm";
- }
+ const char* type_key() const { return "stackvm"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == runtime::symbol::tvm_module_main) {
return GetFunction(entry_func_, sptr_to_self);
}
if (it == fmap_.end()) return PackedFunc();
const StackVM& vm = it->second;
// capture sptr_to_self to keep module node alive.
- return PackedFunc([vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- vm.Run(args, this);
- });
+ return PackedFunc(
+ [vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { vm.Run(args, this); });
}
std::string GetSource(const std::string& format) final {
return os.str();
}
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string data, mblob;
dmlc::MemoryStringStream writer(&data);
dmlc::Stream* strm = &writer;
strm->Write(num_imports);
for (runtime::Module im : imports_) {
- CHECK_EQ(im->imports().size(), 0U)
- << "Only support simply one-level hierarchy";
+ CHECK_EQ(im->imports().size(), 0U) << "Only support simple one-level hierarchy";
std::string tkey = im->type_key();
strm->Write(tkey);
LOG(INFO) << "save " << tkey;
SaveBinaryToFile(file_name, data);
}
- static Module Create(std::unordered_map<std::string, StackVM> fmap,
- std::string entry_func) {
+ static Module Create(std::unordered_map<std::string, StackVM> fmap, std::string entry_func) {
auto n = make_object<StackVMModuleNode>();
n->fmap_ = std::move(fmap);
n->entry_func_ = std::move(entry_func);
CHECK(strm->Read(&tkey));
std::string fkey = "runtime.module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
- CHECK(f != nullptr)
- << "Loader of " << tkey << "("
- << fkey << ") is not presented.";
+ CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(strm));
n->imports_.emplace_back(std::move(m));
}
return Module(n);
}
- static Module LoadFromFile(std::string file_name,
- std::string format) {
+ static Module LoadFromFile(std::string file_name, std::string format) {
std::string data;
LoadBinaryFromFile(file_name, &data);
dmlc::MemoryStringStream reader(&data);
std::string entry_func_;
};
-Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
- std::string entry_func) {
+Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap, std::string entry_func) {
return StackVMModuleNode::Create(fmap, entry_func);
}
TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm")
-.set_body_typed(StackVMModuleNode::LoadFromFile);
+ .set_body_typed(StackVMModuleNode::LoadFromFile);
} // namespace runtime
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RUNTIME_STACKVM_STACKVM_MODULE_H_
#include <tvm/runtime/packed_func.h>
+
#include <string>
#include <unordered_map>
+
#include "stackvm.h"
namespace tvm {
* \param entry_func The entry function name.
* \return The created module
*/
-Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
- std::string entry_func);
+Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap, std::string entry_func);
} // namespace runtime
} // namespace tvm
* \file system_library.cc
* \brief Create library module that directly get symbol from the system lib.
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/memory.h>
#include <tvm/runtime/c_backend_api.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/registry.h>
+
#include <mutex>
+
#include "library_module.h"
namespace tvm {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
if (it != tbl_.end() && ptr != it->second) {
- LOG(WARNING)
- << "SystemLib symbol " << name
- << " get overriden to a different address "
- << ptr << "->" << it->second;
+ LOG(WARNING) << "SystemLib symbol " << name << " get overridden to a different address " << ptr
+ << "->" << it->second;
}
tbl_[name] = ptr;
}
std::unordered_map<std::string, void*> tbl_;
};
-TVM_REGISTER_GLOBAL("runtime.SystemLib")
-.set_body_typed([]() {
- static auto mod = CreateModuleFromLibrary(
- SystemLibrary::Global());
- return mod;
+TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_typed([]() {
+ static auto mod = CreateModuleFromLibrary(SystemLibrary::Global());
+ return mod;
});
} // namespace runtime
} // namespace tvm
* \file thread_pool.cc
* \brief Threadpool for multi-threading runtime.
*/
-#include <tvm/runtime/c_runtime_api.h>
+#include <dmlc/logging.h>
+#include <dmlc/thread_local.h>
#include <tvm/runtime/c_backend_api.h>
-#include <tvm/runtime/registry.h>
+#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
-#include <dmlc/thread_local.h>
-#include <dmlc/logging.h>
#if TVM_THREADPOOL_USE_OPENMP
#include <omp.h>
#endif
-#include <thread>
-#include <condition_variable>
-#include <mutex>
-#include <atomic>
#include <algorithm>
-#include <vector>
-#include <string>
+#include <atomic>
+#include <condition_variable>
#include <cstring>
#include <memory>
+#include <mutex>
#include <sstream>
+#include <string>
+#include <thread>
+#include <vector>
const constexpr int kL1CacheBytes = 64;
class ParallelLauncher {
public:
// Reset the the task request.
- void Init(FTVMParallelLambda flambda,
- void* cdata,
- int num_task,
- bool need_sync) {
+ void Init(FTVMParallelLambda flambda, void* cdata, int num_task, bool need_sync) {
num_pending_.store(num_task);
this->cdata = cdata;
this->flambda = flambda;
}
if (need_sync) {
for (int i = 0; i < num_task; ++i) {
- sync_counter_[i * kSyncStride].store(
- 0, std::memory_order_relaxed);
+ sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed);
}
this->env.sync_handle = sync_counter_;
} else {
this->env.sync_handle = nullptr;
}
}
- ~ParallelLauncher() {
- delete[] sync_counter_;
- }
+ ~ParallelLauncher() { delete[] sync_counter_; }
// Wait n jobs to finish
int WaitForJobs() {
while (num_pending_.load() != 0) {
has_error_.store(true);
}
// Signal that one job has finished.
- void SignalJobFinish() {
- num_pending_.fetch_sub(1);
- }
+ void SignalJobFinish() { num_pending_.fetch_sub(1); }
// Get thread local version of the store.
- static ParallelLauncher* ThreadLocal() {
- return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
- }
+ static ParallelLauncher* ThreadLocal() { return dmlc::ThreadLocalStore<ParallelLauncher>::Get(); }
// The parallel lambda
FTVMParallelLambda flambda;
// The closure data
int32_t task_id;
};
- SpscTaskQueue() :
- buffer_(new Task[kRingSize]),
- head_(0),
- tail_(0) {
- }
+ SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {}
- ~SpscTaskQueue() {
- delete[] buffer_;
- }
+ ~SpscTaskQueue() { delete[] buffer_; }
/*!
* \brief Push a task into the queue and notify the comsumer if it is on wait.
}
if (pending_.fetch_sub(1) == 0) {
std::unique_lock<std::mutex> lock(mutex_);
- cv_.wait(lock, [this] {
- return pending_.load() >= 0 || exit_now_.load();
- });
+ cv_.wait(lock, [this] { return pending_.load() >= 0 || exit_now_.load(); });
}
if (exit_now_.load(std::memory_order_relaxed)) {
return false;
// The thread pool
class ThreadPool {
public:
- ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
+ ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) {
for (int i = 0; i < num_workers_; ++i) {
// The SpscTaskQueue only hosts ONE item at a time
queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
}
threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
new tvm::runtime::threading::ThreadGroup(
- num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
- exclude_worker0_ /* include_main_thread */));
+ num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
+ exclude_worker0_ /* include_main_thread */));
num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
}
~ThreadPool() {
}
threads_.reset();
}
- int Launch(FTVMParallelLambda flambda,
- void* cdata,
- int num_task,
- int need_sync) {
+ int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) {
ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
CHECK(!launcher->is_worker)
<< "Cannot launch parallel job inside worker, consider fuse then parallel";
return res;
}
- static ThreadPool* ThreadLocal() {
- return dmlc::ThreadLocalStore<ThreadPool>::Get();
- }
+ static ThreadPool* ThreadLocal() { return dmlc::ThreadLocalStore<ThreadPool>::Get(); }
void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
// this will also reset the affinity of the ThreadGroup
// may use less than the MaxConcurrency number of workers
- num_workers_used_ = threads_->Configure(mode, nthreads,
- exclude_worker0_);
+ num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_);
// if MaxConcurrency restricted the number of workers (e.g., due to
// hyperthreading), respect the restriction
num_workers_used_ = std::min(num_workers_, num_workers_used_);
std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};
-TVM_REGISTER_GLOBAL("runtime.config_threadpool")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- threading::ThreadGroup::AffinityMode mode =\
- static_cast<threading::ThreadGroup::AffinityMode>(\
- static_cast<int>(args[0]));
- int nthreads = args[1];
- ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
+TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRetValue* rv) {
+ threading::ThreadGroup::AffinityMode mode =
+ static_cast<threading::ThreadGroup::AffinityMode>(static_cast<int>(args[0]));
+ int nthreads = args[1];
+ ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
});
-
} // namespace runtime
} // namespace tvm
-
-int TVMBackendParallelLaunch(
- FTVMParallelLambda flambda,
- void* cdata,
- int num_task) {
+int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
#if !TVM_THREADPOOL_USE_OPENMP
- int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(
- flambda, cdata, num_task, 1);
+ int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1);
return res;
#else
int num_workers = tvm::runtime::threading::MaxConcurrency();
if (num_task == 0) num_task = num_workers;
omp_set_num_threads(num_workers);
- #pragma omp parallel num_threads(num_workers)
+#pragma omp parallel num_threads(num_workers)
{
TVMParallelGroupEnv env;
env.num_task = num_task;
int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
#if TVM_THREADPOOL_USE_OPENMP
- #pragma omp barrier
+#pragma omp barrier
#else
using tvm::runtime::kSyncStride;
int num_task = penv->num_task;
- std::atomic<int>* sync_counter =
- reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
- int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
- 1, std::memory_order_release);
+ std::atomic<int>* sync_counter = reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
+ int old_counter = sync_counter[task_id * kSyncStride].fetch_add(1, std::memory_order_release);
for (int i = 0; i < num_task; ++i) {
if (i != task_id) {
- while (sync_counter[i * kSyncStride].load(
- std::memory_order_relaxed) <= old_counter) {
+ while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <= old_counter) {
tvm::runtime::threading::Yield();
}
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#include <tvm/runtime/packed_func.h>
+
#include <string>
#include <vector>
*/
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
switch (thread_scope_rank) {
- case -1: return StorageRank::kGlobal;
- case 0: return StorageRank::kShared;
- case 1: return StorageRank::kLocal;
+ case -1:
+ return StorageRank::kGlobal;
+ case 0:
+ return StorageRank::kShared;
+ case 1:
+ return StorageRank::kLocal;
default: {
LOG(FATAL) << "unknown rank";
return StorageRank::kGlobal;
inline bool operator==(const StorageScope& other) const {
return rank == other.rank && tag == other.tag;
}
- inline bool operator!=(const StorageScope& other) const {
- return !(*this == other);
- }
+ inline bool operator!=(const StorageScope& other) const { return !(*this == other); }
inline std::string to_string() const {
std::string ret;
switch (rank) {
- case StorageRank::kGlobal: return "global" + tag;
- case StorageRank::kShared: return "shared" + tag;
- case StorageRank::kWarp: return "warp" + tag;
- case StorageRank::kLocal: return "local" + tag;
- case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag;
- case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag;
- case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag;
- default: LOG(FATAL) << "unknown storage scope"; return "";
+ case StorageRank::kGlobal:
+ return "global" + tag;
+ case StorageRank::kShared:
+ return "shared" + tag;
+ case StorageRank::kWarp:
+ return "warp" + tag;
+ case StorageRank::kLocal:
+ return "local" + tag;
+ case StorageRank::kWMMAMatrixA:
+ return "wmma.matrix_a" + tag;
+ case StorageRank::kWMMAMatrixB:
+ return "wmma.matrix_b" + tag;
+ case StorageRank::kWMMAAccumulator:
+ return "wmma.accumulator" + tag;
+ default:
+ LOG(FATAL) << "unknown storage scope";
+ return "";
}
}
/*!
*/
static StorageScope make(const std::string& s) {
StorageScope r;
- if (s.compare(0, 6, "global") == 0) {
+ if (s.compare(0, 6, "global") == 0) {
r.rank = StorageRank::kGlobal;
r.tag = s.substr(6, std::string::npos);
} else if (s.compare(0, 6, "shared") == 0) {
}
};
-
/*! \brief workload specification */
struct ThreadWorkLoad {
// array, first three are thread configuration.
* \param i The block dimension.
* \return i-th block dim
*/
- inline size_t block_dim(size_t i) const {
- return work_size[i + 3];
- }
+ inline size_t block_dim(size_t i) const { return work_size[i + 3]; }
/*!
* \param i The grid dimension.
* \return i-th grid dim
*/
- inline size_t grid_dim(size_t i) const {
- return work_size[i];
- }
+ inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
/*! \brief Thread axis configuration */
class ThreadAxisConfig {
public:
- void Init(size_t base,
- const std::vector<std::string>& thread_axis_tags) {
+ void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {
base_ = base;
std::vector<bool> filled(6, false);
for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
ThreadWorkLoad w;
std::fill(w.work_size, w.work_size + 6, 1);
for (size_t i = 0; i < arg_index_map_.size(); ++i) {
- w.work_size[arg_index_map_[i]] =
- static_cast<size_t>(x.values[base_ + i].v_int64);
+ w.work_size[arg_index_map_[i]] = static_cast<size_t>(x.values[base_ + i].v_int64);
}
return w;
}
// return the work dim
- size_t work_dim() const {
- return work_dim_;
- }
+ size_t work_dim() const { return work_dim_; }
private:
/*! \brief base axis */
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \file threading_backend.cc
* \brief Native threading backend
*/
-#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
-#include <thread>
+#include <tvm/runtime/threading_backend.h>
+
#include <algorithm>
+#include <thread>
#if defined(__linux__) || defined(__ANDROID__)
#include <fstream>
#include <sstream>
class ThreadGroup::Impl {
public:
- Impl(int num_workers,
- std::function<void(int)> worker_callback,
- bool exclude_worker0)
+ Impl(int num_workers, std::function<void(int)> worker_callback, bool exclude_worker0)
: num_workers_(num_workers) {
- CHECK_GE(num_workers, 1)
- << "Requested a non-positive number of worker threads.";
+ CHECK_GE(num_workers, 1) << "Requested a non-positive number of worker threads.";
for (int i = exclude_worker0; i < num_workers_; ++i) {
threads_.emplace_back([worker_callback, i] { worker_callback(i); });
}
// ones.
num_workers_used = std::min(num_workers_, num_workers_used);
- const char *val = getenv("TVM_BIND_THREADS");
+ const char* val = getenv("TVM_BIND_THREADS");
if (val == nullptr || atoi(val) == 1) {
// Do not set affinity if there are more workers than found cores
if (sorted_order_.size() >= static_cast<unsigned int>(num_workers_)) {
- SetAffinity(exclude_worker0, mode == kLittle);
+ SetAffinity(exclude_worker0, mode == kLittle);
} else {
- LOG(WARNING)
- << "The thread affinity cannot be set when the number of workers"
- << "is larger than the number of available cores in the system.";
+ LOG(WARNING) << "The thread affinity cannot be set when the number of workers"
+ << "is larger than the number of available cores in the system.";
}
}
return num_workers_used;
#if defined(__ANDROID__)
#ifndef CPU_SET
#define CPU_SETSIZE 1024
-#define __NCPUBITS (8 * sizeof (uint64_t))
+#define __NCPUBITS (8 * sizeof(uint64_t))
typedef struct {
uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
} cpu_set_t;
#define CPU_SET(cpu, cpusetp) \
- ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
-#define CPU_ZERO(cpusetp) \
- memset((cpusetp), 0, sizeof(cpu_set_t))
+ ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
+#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t))
#endif
#endif
#if defined(__linux__) || defined(__ANDROID__)
#if defined(__ANDROID__)
sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
#else
- pthread_setaffinity_np(threads_[i].native_handle(),
- sizeof(cpu_set_t), &cpuset);
+ pthread_setaffinity_np(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
#endif
}
if (exclude_worker0) { // master thread run task
void InitSortedOrder() {
unsigned int threads = std::thread::hardware_concurrency();
- std::vector<std::pair <unsigned int, int64_t> > max_freqs;
+ std::vector<std::pair<unsigned int, int64_t> > max_freqs;
for (unsigned int i = 0; i < threads; ++i) {
int64_t cur_freq = 0;
- #if defined(__linux__) || defined(__ANDROID__)
- std::ostringstream filepath;
- filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq";
- std::ifstream ifs(filepath.str());
- if (!ifs.fail()) {
- if (!(ifs >> cur_freq)) {
- cur_freq = -1;
- }
- ifs.close();
+#if defined(__linux__) || defined(__ANDROID__)
+ std::ostringstream filepath;
+ filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq";
+ std::ifstream ifs(filepath.str());
+ if (!ifs.fail()) {
+ if (!(ifs >> cur_freq)) {
+ cur_freq = -1;
}
- #endif
+ ifs.close();
+ }
+#endif
max_freqs.push_back(std::make_pair(i, cur_freq));
}
- auto fcmpbyfreq = [] (const std::pair<unsigned int, int64_t> &a,
- const std::pair<unsigned int, int64_t> &b) {
- return a.second == b.second ? a.first < b.first : a.second > b.second;
+ auto fcmpbyfreq = [](const std::pair<unsigned int, int64_t>& a,
+ const std::pair<unsigned int, int64_t>& b) {
+ return a.second == b.second ? a.first < b.first : a.second > b.second;
};
std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);
int64_t big_freq = max_freqs.begin()->second;
int little_count_ = 0;
};
-ThreadGroup::ThreadGroup(int num_workers,
- std::function<void(int)> worker_callback,
+ThreadGroup::ThreadGroup(int num_workers, std::function<void(int)> worker_callback,
bool exclude_worker0)
- : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
+ : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
ThreadGroup::~ThreadGroup() { delete impl_; }
void ThreadGroup::Join() { impl_->Join(); }
return impl_->Configure(mode, nthreads, exclude_worker0);
}
-void Yield() {
- std::this_thread::yield();
-}
+void Yield() { std::this_thread::yield(); }
int MaxConcurrency() {
int max_concurrency = 1;
- const char *val = getenv("TVM_NUM_THREADS");
+ const char* val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
return std::max(max_concurrency, 1);
}
-
} // namespace threading
} // namespace runtime
} // namespace tvm
#include <tvm/runtime/vm.h>
#include <algorithm>
-#include <memory>
-#include <iostream>
#include <iomanip>
+#include <iostream>
+#include <memory>
#include <sstream>
#include <utility>
#include <vector>
// Helper to deserialize a serialized vm instruction.
Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
-PackedFunc Executable::GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_lib") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetLib();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); });
} else if (name == "get_bytecode") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->GetBytecode();
- });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); });
} else if (name == "get_stats") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->Stats();
- });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); });
} else if (name == "save") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->Save();
- });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); });
} else if (name == "get_function_arity") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
// Get the number of globals and the name of each of them.
oss << " Globals (#" << global_map.size() << "): [";
for (const auto& it : global_map) {
- oss << "(\"" << it.first << "\", " << it.second << ")" << ", ";
+ oss << "(\"" << it.first << "\", " << it.second << ")"
+ << ", ";
}
if (!global_map.empty()) oss.seekp(-2, oss.cur);
oss << "]" << std::endl;
void Executable::SaveGlobalSection(dmlc::Stream* strm) {
std::vector<std::pair<std::string, Index> > globals(this->global_map.begin(),
this->global_map.end());
- auto comp = [](const std::pair<std::string, Index>& a,
- const std::pair<std::string, Index>& b) {
+ auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) {
return a.second < b.second;
};
std::sort(globals.begin(), globals.end(), comp);
fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});
// Save the fields.
- fields.insert(fields.end(), instr.datatype_fields,
- instr.datatype_fields + instr.num_fields);
+ fields.insert(fields.end(), instr.datatype_fields, instr.datatype_fields + instr.num_fields);
break;
}
case Opcode::AllocClosure: {
fields.assign({instr.clo_index, instr.num_freevar, instr.dst});
// Save the free vars.
- fields.insert(fields.end(), instr.free_vars,
- instr.free_vars + instr.num_freevar);
+ fields.insert(fields.end(), instr.free_vars, instr.free_vars + instr.num_freevar);
break;
}
case Opcode::If: {
// Number of fields = 4
- fields.assign({instr.if_op.test,
- instr.if_op.target,
- instr.if_op.true_offset,
+ fields.assign({instr.if_op.test, instr.if_op.target, instr.if_op.true_offset,
instr.if_op.false_offset});
break;
}
fields.assign({instr.closure, instr.num_closure_args, instr.dst});
// Save the args.
- fields.insert(fields.end(), instr.closure_args,
- instr.closure_args + instr.num_closure_args);
+ fields.insert(fields.end(), instr.closure_args, instr.closure_args + instr.num_closure_args);
break;
}
case Opcode::LoadConst: {
strm->Write(static_cast<uint64_t>(this->functions.size()));
for (const auto& func : this->functions) {
// Save the function info.
- VMFunctionSerializer func_format(func.name,
- func.register_file_size,
- func.instructions.size(),
+ VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(),
func.params);
func_format.Save(strm);
// Extract the `cnt` number of fields started at `start` from the list
// `instr_fields`.
-inline std::vector<Index> ExtractFields(const std::vector<Index>& instr_fields,
- Index start,
+inline std::vector<Index> ExtractFields(const std::vector<Index>& instr_fields, Index start,
Index cnt) {
CHECK_LE(static_cast<size_t>(start + cnt), instr_fields.size());
std::vector<Index> ret;
RegName dst = instr.fields[5];
- return Instruction::AllocStorage(
- allocation_size,
- alignment,
- dtype,
- dst);
+ return Instruction::AllocStorage(allocation_size, alignment, dtype, dst);
}
case Opcode::If: {
// Number of fields = 4
}
// Create the VM function.
- VMFunction vm_func = VMFunction(loaded_func.name,
- loaded_func.params,
- instructions,
+ VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions,
loaded_func.register_file_size);
auto it = this->global_map.find(loaded_func.name);
CHECK(it != this->global_map.end());
}
}
-TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
*rv = static_cast<int>(exec->global_map.size());
});
-TVM_REGISTER_GLOBAL("runtime.GetGlobalFields")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
int idx = args[1];
std::vector<std::pair<std::string, Index> > globals(exec->global_map.begin(),
exec->global_map.end());
- auto comp = [](const std::pair<std::string, Index>& a,
- const std::pair<std::string, Index>& b) {
+ auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) {
return a.second < b.second;
};
std::sort(globals.begin(), globals.end(), comp);
*rv = globals[idx].first;
});
-TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
*rv = static_cast<int>(exec->primitive_map.size());
});
-
-TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
});
TVM_REGISTER_GLOBAL("runtime.Load_Executable")
-.set_body_typed([](
- std::string code,
- runtime::Module lib) {
- return Executable::Load(code, lib);
-});
+ .set_body_typed([](std::string code, runtime::Module lib) {
+ return Executable::Load(code, lib);
+ });
} // namespace vm
} // namespace runtime
* \file tvm/runtime/vm/memory_manager.cc
* \brief Allocate and manage memory for the runtime.
*/
-#include <utility>
-#include <memory>
#include "memory_manager.h"
+
+#include <memory>
+#include <utility>
+
#include "naive_allocator.h"
#include "pooled_allocator.h"
auto* ptr = static_cast<NDArray::Container*>(obj);
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
- MemoryManager::Global()->GetAllocator(buffer->ctx)->
- Free(*(buffer));
+ MemoryManager::Global()->GetAllocator(buffer->ctx)->Free(*(buffer));
delete buffer;
delete ptr;
}
// RAII in effect, now run the check.
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
- << "size mistmatch required " << needed_size << " found " << this->buffer.size;
+ << "size mistmatch required " << needed_size << " found " << this->buffer.size;
return ret;
}
Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
std::lock_guard<std::mutex> lock(mu_);
if (allocators_.find(ctx) == allocators_.end()) {
- DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "("
- << ctx.device_id << ")";
+ DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" << ctx.device_id
+ << ")";
std::unique_ptr<Allocator> alloc(new NaiveAllocator(ctx));
allocators_.emplace(ctx, std::move(alloc));
}
container->SetDeleter(BufferDeleter);
size_t size = GetDataSize(container->dl_tensor);
size_t alignment = GetDataAlignment(container->dl_tensor);
- Buffer *buffer = new Buffer;
+ Buffer* buffer = new Buffer;
*buffer = this->Alloc(size, alignment, dtype);
container->manager_ctx = reinterpret_cast<void*>(buffer);
container->dl_tensor.data = buffer->data;
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>
+
#include <functional>
#include <memory>
#include <mutex>
* \param ctx The context where the array is allocated.
* \return The empty NDArray.
*/
- NDArray Empty(std::vector<int64_t> shape,
- DLDataType dtype,
- DLContext ctx);
+ NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
/*! \brief Allocate a buffer given a size, alignment and type.
* \param nbytes The size of the buffer.
* \param alignment The alignment of the buffer.
* \param type_hint A type hint to the allocator.
* \return A sized allocation in the form of a buffer.
- */
+ */
virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0;
/*! \brief Free a buffer allocated by the allocator.
* \param buffer The buffer to free.
Buffer buffer;
/*! \brief Allocate an NDArray from a given piece of storage. */
- NDArray AllocNDArray(size_t offset,
- std::vector<int64_t> shape,
- DLDataType dtype);
+ NDArray AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDataType dtype);
/*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(Object* ptr);
#define TVM_RUNTIME_VM_NAIVE_ALLOCATOR_H_
#include <tvm/runtime/device_api.h>
+
#include <atomic>
#include "memory_manager.h"
DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B";
}
- size_t UsedMemory() const override {
- return used_memory_.load(std::memory_order_relaxed);
- }
+ size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); }
private:
std::atomic<size_t> used_memory_;
#define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_
#include <tvm/runtime/device_api.h>
+
#include <atomic>
#include <mutex>
#include <unordered_map>
* \brief The Relay debug virtual machine.
*/
+#include "vm.h"
+
#include <tvm/runtime/registry.h>
#include <tvm/runtime/vm.h>
#include <utility>
#include <vector>
-#include "vm.h"
-
namespace tvm {
namespace runtime {
namespace vm {
-PackedFunc VirtualMachineDebug::GetFunction(
- const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc VirtualMachineDebug::GetFunction(const std::string& name,
+ const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_stat") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1U);
std::vector<std::pair<Index, double>> op_acc_time;
for (auto kv : op_durations_) {
- auto val = std::make_pair(
- kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0));
+ auto val =
+ std::make_pair(kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0));
op_acc_time.push_back(val);
}
bool sort_by_time = args[0];
if (sort_by_time) {
- auto comp = [](const std::pair<Index, double>& lhs,
- const std::pair<Index, double>& rhs) {
+ auto comp = [](const std::pair<Index, double>& lhs, const std::pair<Index, double>& rhs) {
return lhs.second > rhs.second;
};
std::sort(op_acc_time.begin(), op_acc_time.end(), comp);
auto min_value = *std::min_element(vals.begin(), vals.end());
auto max_value = *std::max_element(vals.begin(), vals.end());
- os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t"
- << std::setw(10) << std::left << op_invokes_[kv.first] << "\t"
- << sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;
+ os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t" << std::setw(10)
+ << std::left << op_invokes_[kv.first] << "\t" << sum << "/" << mean << "/" << min_value
+ << "/" << max_value << std::endl;
total_duration += sum;
total_packed_funcs += op_invokes_[kv.first];
}
}
-void VirtualMachineDebug::InvokePacked(Index packed_index,
- const PackedFunc& func, Index arg_count,
- Index output_size,
- const std::vector<ObjectRef>& args) {
+void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
+ Index output_size, const std::vector<ObjectRef>& args) {
CHECK(exec_);
auto ctx = this->GetParamsContext();
// warmup
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_end = std::chrono::high_resolution_clock::now();
double op_duration =
- std::chrono::duration_cast<std::chrono::duration<double> >(op_end -
- op_begin)
- .count();
+ std::chrono::duration_cast<std::chrono::duration<double>>(op_end - op_begin).count();
op_durations_[packed_index].push_back(op_duration * 1e6);
op_invokes_[packed_index] += 1;
return runtime::Module(vm);
}
-TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec) << "Virtual machine has not been defined yet."
public:
VirtualMachineDebug() : VirtualMachine() {}
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final;
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
void LoadExecutable(const Executable* exec) final;
~VirtualMachineDebug() {}
private:
- void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
- Index output_size, const std::vector<ObjectRef>& args) final;
+ void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size,
+ const std::vector<ObjectRef>& args) final;
std::unordered_map<Index, std::string> packed_index_map_;
std::unordered_map<Index, std::vector<double>> op_durations_;
VMFunctionSerializer() = default;
- VMFunctionSerializer(const std::string& name,
- Index register_file_size,
- size_t num_instructions,
+ VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions,
const std::vector<std::string>& params)
: name(name),
register_file_size(register_file_size),
}
/*!
- * \brief Save the VM function header into the serialized form.
+ * \brief Save the VM function header into the serialized form.
* \param strm The stream used to save data.
*/
void Save(dmlc::Stream* strm) const {
VMInstructionSerializer() = default;
- VMInstructionSerializer(Index opcode, const std::vector<Index>& fields) :
- opcode(opcode), fields(fields) {}
+ VMInstructionSerializer(Index opcode, const std::vector<Index>& fields)
+ : opcode(opcode), fields(fields) {}
/*!
- * \brief Compute the hash of the serialized instruction.
+ * \brief Compute the hash of the serialized instruction.
* \return The hash that combines the opcode and all fields of the VM
* instruction.
*/
}
Index hash = Hash();
- CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: "
- << opcode << "\n";
+ CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " << opcode << "\n";
return true;
}
/*!
- * \brief Save the instruction into the serialized form.
+ * \brief Save the instruction into the serialized form.
* \param strm The stream used to save data.
*/
void Save(dmlc::Stream* strm) const {
*/
#include <dmlc/memory_io.h>
-#include <tvm/support/logging.h>
#include <tvm/runtime/container.h>
-#include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/vm.h>
+#include <tvm/support/logging.h>
#include <algorithm>
#include <chrono>
// We could put cache in here, from ctx to storage allocator.
auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
auto alloc = MemoryManager::Global()->GetAllocator(ctx);
- DCHECK(alloc != nullptr)
- << "allocator must not null";
+ DCHECK(alloc != nullptr) << "allocator must not null";
storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint);
return Storage(storage_obj);
}
case Opcode::AllocTensor:
this->alloc_tensor.storage = instr.alloc_tensor.storage;
this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
- this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
- instr.alloc_tensor.ndim);
+ this->alloc_tensor.shape =
+ Duplicate<int64_t>(instr.alloc_tensor.shape, instr.alloc_tensor.ndim);
this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
return;
case Opcode::AllocTensorReg:
}
}
-template<typename T>
+template <typename T>
static inline void FreeIf(T* t) {
if (t != nullptr) {
delete t;
case Opcode::AllocTensor:
this->alloc_tensor.storage = instr.alloc_tensor.storage;
this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
- this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
- instr.alloc_tensor.ndim);
+ this->alloc_tensor.shape =
+ Duplicate<int64_t>(instr.alloc_tensor.shape, instr.alloc_tensor.ndim);
this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
return *this;
case Opcode::AllocTensorReg:
return instr;
}
-Instruction Instruction::InvokePacked(Index packed_index,
- Index arity,
- Index output_size,
+Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args) {
Instruction instr;
instr.op = Opcode::InvokePacked;
return instr;
}
-Instruction Instruction::AllocTensor(
- RegName storage,
- const std::vector<int64_t>& shape,
- DLDataType dtype, Index dst) {
+Instruction Instruction::AllocTensor(RegName storage, const std::vector<int64_t>& shape,
+ DLDataType dtype, Index dst) {
Instruction instr;
instr.op = Opcode::AllocTensor;
instr.dst = dst;
return instr;
}
-Instruction Instruction::AllocTensorReg(
- RegName storage,
- RegName shape_register,
- DLDataType dtype, Index dst) {
+Instruction Instruction::AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype,
+ Index dst) {
Instruction instr;
instr.op = Opcode::AllocTensorReg;
instr.dst = dst;
return instr;
}
-Instruction Instruction::AllocStorage(RegName size,
- Index alignment,
- DLDataType dtype_hint,
+Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
Index dst) {
Instruction instr;
instr.op = Opcode::AllocStorage;
}
Instruction Instruction::AllocADT(Index tag, Index num_fields,
- const std::vector<RegName>& datatype_fields, Index dst) {
+ const std::vector<RegName>& datatype_fields, Index dst) {
Instruction instr;
instr.op = Opcode::AllocADT;
instr.dst = dst;
}
}
-template<typename T>
+template <typename T>
std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
if (cnt == 0) {
return "";
}
case Opcode::InvokePacked: {
os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $"
- << StrJoin<RegName>(instr.packed_args, 0,
- instr.arity - instr.output_size, ", $")
+ << StrJoin<RegName>(instr.packed_args, 0, instr.arity - instr.output_size, ", $")
<< ", out: $"
- << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
- instr.output_size, ", $")
+ << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size, instr.output_size,
+ ", $")
<< ")";
break;
}
case Opcode::AllocTensor: {
- os << "alloc_tensor $" << instr.dst << " $"
- << instr.alloc_tensor.storage << " ["
- << StrJoin<int64_t>(instr.alloc_tensor.shape, 0,
- instr.alloc_tensor.ndim)
- << "] ";
+ os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " ["
+ << StrJoin<int64_t>(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] ";
DLDatatypePrint(os, instr.alloc_tensor.dtype);
break;
}
case Opcode::AllocTensorReg: {
- os << "alloc_tensor_reg $" << instr.dst << " $"
- << instr.alloc_tensor_reg.storage << " $"
+ os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $"
<< instr.alloc_tensor_reg.shape_register << " ";
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
break;
break;
}
case Opcode::AllocClosure: {
- os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
- << "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
- << ")";
+ os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($"
+ << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$") << ")";
break;
}
case Opcode::If: {
- os << "if " << "$" << instr.if_op.test << " $" << instr.if_op.target << " "
- << instr.if_op.true_offset << " " << instr.if_op.false_offset;
+ os << "if "
+ << "$" << instr.if_op.test << " $" << instr.if_op.target << " " << instr.if_op.true_offset
+ << " " << instr.if_op.false_offset;
break;
}
case Opcode::Invoke: {
os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
- << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
- << ")";
+ << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$") << ")";
break;
}
case Opcode::InvokeClosure: {
os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
- << StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$")
- << ")";
+ << StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$") << ")";
break;
}
case Opcode::LoadConst: {
break;
}
case Opcode::GetField: {
- os << "get_field $" << instr.dst << " $" << instr.object << "["
- << instr.field_index << "]";
+ os << "get_field $" << instr.dst << " $" << instr.object << "[" << instr.field_index << "]";
break;
}
case Opcode::GetTag: {
break;
}
case Opcode::AllocStorage: {
- os << "alloc_storage $" <<
- instr.dst << " $" <<
- instr.alloc_storage.allocation_size << " $" <<
- instr.alloc_storage.alignment << " " <<
- DLDataType2String(instr.alloc_storage.dtype_hint);
+ os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " $"
+ << instr.alloc_storage.alignment << " "
+ << DLDataType2String(instr.alloc_storage.dtype_hint);
break;
}
default:
std::string func_name = args[0];
auto git = exec_->global_map.find(func_name);
CHECK(git != exec_->global_map.end())
- << "Cannot find function " << func_name << " in the executable";
+ << "Cannot find function " << func_name << " in the executable";
auto func = exec_->functions[git->second];
if (func.params.empty()) {
*rv = Invoke(func, {});
} else {
auto it = inputs_.find(func_name);
CHECK(it != inputs_.end()) << "Input has not been set for function " << func_name;
- const std::vector<ObjectRef> &func_args = it->second;
+ const std::vector<ObjectRef>& func_args = it->second;
*rv = Invoke(func, func_args);
}
});
const auto& param_names = vm_func.params;
// TODO(icemelon9): For heterogeneous execution, get input device information
TVMContext ctx = ctxs_[0];
- CHECK_EQ(args.size() - 1, param_names.size()) <<
- "The number of provided parameters doesn't match the number of arguments";
+ CHECK_EQ(args.size() - 1, param_names.size())
+ << "The number of provided parameters doesn't match the number of arguments";
std::vector<ObjectRef> func_args(param_names.size());
for (int i = 1; i < args.size(); ++i) {
ObjectRef obj = CopyTo(args[i], ctx);
ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
CHECK(exec_) << "The executable has not been created yet.";
auto it = exec_->global_map.find(name);
- CHECK(it != exec_->global_map.end())
- << "Cannot find function " << name << " in the executable";
+ CHECK(it != exec_->global_map.end()) << "Cannot find function " << name << " in the executable";
auto func_index_ = it->second;
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_;
return Invoke(exec_->functions[func_index_], args);
}
-void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
- Index arg_count, Index output_size,
- const std::vector<ObjectRef>& args) {
+void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
+ Index output_size, const std::vector<ObjectRef>& args) {
size_t arity = 0;
for (Index i = 0; i < arg_count; i++) {
if (const auto* obj = args[i].as<ADTObj>()) {
}
}
-
-void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
- ctxs_ = ctxs;
-}
+void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { ctxs_ = ctxs; }
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
frames_.back().register_file[r] = val;
goto main_loop;
}
case Opcode::InvokePacked: {
- DLOG(INFO) << "InvokedPacked " << "arity=" << instr.arity;
+ DLOG(INFO) << "InvokedPacked "
+ << "arity=" << instr.arity;
const auto& func = packed_funcs_[instr.packed_index];
const auto& arity = instr.arity;
std::vector<ObjectRef> args;
for (Index i = 0; i < arity; ++i) {
- DLOG(INFO) <<
- "arg" << i << " $" << instr.packed_args[i];
+ DLOG(INFO) << "arg" << i << " $" << instr.packed_args[i];
auto arg = ReadRegister(instr.packed_args[i]);
args.push_back(arg);
}
auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
auto alignment = LoadScalarInt(instr.alloc_storage.alignment);
- DLOG(INFO) <<
- "AllocStorage: allocation_size=" << size <<
- "alignment=" << alignment <<
- "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint);
+ DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment
+ << "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint);
auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]);
WriteRegister(instr.dst, storage);
return runtime::Module(vm);
}
-TVM_REGISTER_GLOBAL("runtime._VirtualMachine")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("runtime._VirtualMachine").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec) << "The virtual machine executable has not been defined yet.";
* under the License.
*/
-#include <vulkan/vulkan.h>
#include <dmlc/memory_io.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
+#include <vulkan/vulkan.h>
#include <array>
#include <cstring>
-
#include "../file_util.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../workspace_pool.h"
-
#include "vulkan_common.h"
#include "vulkan_module.h"
#include "vulkan_shader.h"
}
void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
const auto& vctx = context(ctx.device_id);
VkBufferCreateInfo info;
#ifdef USE_VULKAN_IMMEDIATE_MODE
if (has_extension("VK_KHR_push_descriptor") &&
has_extension("VK_KHR_descriptor_update_template")) {
- ctx.descriptor_template_khr_functions =
- std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
- new VulkanDescriptorTemplateKHRFunctions());
+ ctx.descriptor_template_khr_functions = std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
+ new VulkanDescriptorTemplateKHRFunctions());
ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
// a wrapped function class to get packed func.
class VulkanWrappedFunc {
public:
- void Init(VulkanModuleNode* m,
- ObjectPtr<Object> sptr,
- const std::string& func_name,
+ void Init(VulkanModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
m_ = m;
class VulkanModuleNode final : public runtime::ModuleNode {
public:
explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
- std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: smap_(smap), fmap_(fmap), source_(source) {}
const char* type_key() const final { return "vulkan"; }
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
}
}
- std::shared_ptr<VulkanPipeline> GetPipeline(
- size_t device_id,
- const std::string& func_name,
- size_t num_pack_args) {
+ std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
+ size_t num_pack_args) {
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
std::lock_guard<std::mutex> lock(mutex_);
const auto& cp = ecache_[device_id][func_name];
return streams_[device_id].get();
}
-void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
- const ArgUnion* pack_args) const {
+void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
CHECK_LT(device_id, kVulkanMaxNumDevice);
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/packed_func.h>
-
#include <vulkan/vulkan.h>
+
#include <memory>
#include <mutex>
#include <string>
bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
};
-
} // namespace vulkan
} // namespace runtime
} // namespace tvm
*/
#pragma once
-
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <functional>
#include <memory>
-#include <vector>
#include <unordered_map>
+#include <vector>
#include "vulkan_common.h"
-
namespace tvm {
namespace runtime {
namespace vulkan {
class VulkanStream {
public:
- explicit VulkanStream(const VulkanContext* vctx)
- : vctx_(vctx), state_(new VulkanStreamState()) {
+ explicit VulkanStream(const VulkanContext* vctx) : vctx_(vctx), state_(new VulkanStreamState()) {
// create command pool
VkCommandPoolCreateInfo cmd_pool_cinfo;
cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
* \file workspace_pool.h
* \brief Workspace pool utility.
*/
-#include <memory>
#include "workspace_pool.h"
+#include <memory>
+
namespace tvm {
namespace runtime {
if (free_list_.back().size >= nbytes) {
// find smallest fit
auto it = free_list_.end() - 2;
- for (; it->size >= nbytes; --it) {}
+ for (; it->size >= nbytes; --it) {
+ }
e = *(it + 1);
free_list_.erase(it + 1);
} else {
allocated_.pop_back();
} else {
int index = static_cast<int>(allocated_.size()) - 2;
- for (; index > 0 && allocated_[index].data != data; --index) {}
+ for (; index > 0 && allocated_[index].data != data; --index) {
+ }
CHECK_GT(index, 0) << "trying to free things that has not been allocated";
e = allocated_[index];
allocated_.erase(allocated_.begin() + index);
};
WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
- : device_type_(device_type), device_(device) {
-}
+ : device_type_(device_type), device_(device) {}
WorkspacePool::~WorkspacePool() {
for (size_t i = 0; i < array_.size(); ++i) {
}
void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) {
- CHECK(static_cast<size_t>(ctx.device_id) < array_.size() &&
- array_[ctx.device_id] != nullptr);
+ CHECK(static_cast<size_t>(ctx.device_id) < array_.size() && array_[ctx.device_id] != nullptr);
array_[ctx.device_id]->Free(ptr);
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_RUNTIME_WORKSPACE_POOL_H_
#include <tvm/runtime/device_api.h>
-#include <vector>
+
#include <memory>
+#include <vector>
namespace tvm {
namespace runtime {
#endif
#include <cstddef>
-#include <utility>
#include <type_traits>
-
+#include <utility>
namespace tvm {
namespace support {
* \brief De-allocate an allocate page.
* \param page The page to be de-allocated.
*/
- void deallocate(ArenaPageHeader* page) {
- delete [] reinterpret_cast<Page*>(page);
- }
+ void deallocate(ArenaPageHeader* page) { delete[] reinterpret_cast<Page*>(page); }
static const constexpr int kPageSize = 16 << 10;
static const constexpr int kPageAlign = 1024;
* \brief Arena allocator that allocates memory from continuous
* chunk and frees them all only during destruction.
*/
-template<typename PageAllocator>
+template <typename PageAllocator>
class GenericArena {
public:
- explicit GenericArena(PageAllocator alloc = PageAllocator())
- : alloc_(alloc) {
+ explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) {
// eagerly allocate the first page.
head_ = tail_ = alloc_.allocate(1);
head_->next = nullptr;
}
#if TVM_ARENA_HAS_DESTRUCTOR
- ~GenericArena() {
- this->FreeAll();
- }
+ ~GenericArena() { this->FreeAll(); }
#endif
/*! \brief Free all pages. */
* \param count Numberof elements
* \note The space of T is not initialized.
*/
- template<typename T>
+ template <typename T>
T* allocate_(int count = 1) {
- static_assert(PageAllocator::kPageAlign % alignof(T) == 0,
- "To large alignment");
+ static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "To large alignment");
return static_cast<T*>(Alloc(sizeof(T) * count, alignof(T)));
}
/*!
* memory allocated from the same arena.
* Otherwise the destructor needs to be called explicitly.
*/
- template<typename T, typename... Args>
+ template <typename T, typename... Args>
T* make(Args&&... args) {
T* ptr = allocate_<T>();
new (ptr) T(std::forward<Args>(args)...);
} else {
ArenaPageHeader* new_head;
offset = UpperAlign(sizeof(ArenaPageHeader), align);
- if (free_list_ != nullptr && offset + size <= free_list_-> size) {
+ if (free_list_ != nullptr && offset + size <= free_list_->size) {
new_head = free_list_;
free_list_ = free_list_->next;
} else {
* \brief Link list node
* \tparam T the content data type
*/
-template<typename T>
+template <typename T>
struct LinkNode {
/*! \brief The content value */
T value;
* \note This is a simple data structure that can be used together with the arena.
* \sa LinkNode
*/
-template<typename T>
+template <typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
#define TVM_SUPPORT_BASE64_H_
#include <dmlc/logging.h>
-#include <dmlc/logging.h>
+
#include <cctype>
#include <cstdio>
#include <string>
namespace base64 {
// decoding table
const char DecodeTable[] = {
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 62, // '+'
- 0, 0, 0,
- 63, // '/'
- 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
- 0, 0, 0, 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
- 0, 0, 0, 0, 0, 0,
- 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
- 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 62, // '+'
+ 0, 0, 0,
+ 63, // '/'
+ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
+ 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
+ 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
// encoding table
static const char EncodeTable[] =
*/
class StreamBufferReader {
public:
- explicit StreamBufferReader(size_t buffer_size) {
- buffer_.resize(buffer_size);
- }
+ explicit StreamBufferReader(size_t buffer_size) { buffer_.resize(buffer_size); }
/*!
* \brief set input stream
* \param stream The stream to be set
*/
- void set_stream(dmlc::Stream *stream) {
+ void set_stream(dmlc::Stream* stream) {
stream_ = stream;
read_len_ = read_ptr_ = 1;
}
}
}
/*! \return whether we are reaching the end of file */
- bool AtEnd() const {
- return read_len_ == 0;
- }
+ bool AtEnd() const { return read_len_ == 0; }
private:
/*! \brief the underlying stream */
- dmlc::Stream *stream_{nullptr};
+ dmlc::Stream* stream_{nullptr};
/*! \brief buffer to hold data */
std::string buffer_;
/*! \brief length of valid data in buffer */
/*!
* \brief Input stream from base64 encoding
*/
-class Base64InStream: public dmlc::Stream {
+class Base64InStream : public dmlc::Stream {
public:
- explicit Base64InStream(dmlc::Stream *fs) : reader_(256) {
- reader_.set_stream(fs);
- }
+ explicit Base64InStream(dmlc::Stream* fs) : reader_(256) { reader_.set_stream(fs); }
/*!
* \brief initialize the stream position to beginning of next base64 stream
* \note call this function before actually start read
} while (isspace(temp_ch_));
}
/*! \brief whether current position is end of a base64 stream */
- bool IsEOF(void) const {
- return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_));
- }
+ bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); }
// override read function.
- virtual size_t Read(void *ptr, size_t size) {
+ virtual size_t Read(void* ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
- unsigned char *cptr = static_cast<unsigned char*>(ptr);
+ unsigned char* cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev_ != 0) {
if (num_prev_ == 2) {
num_prev_ = 0;
} else {
// assert tlen == 1
- *cptr++ = buf_prev[0]; --tlen;
+ *cptr++ = buf_prev[0];
+ --tlen;
buf_prev[0] = buf_prev[1];
num_prev_ = 1;
}
} else {
// assert num_prev_ == 1
- *cptr++ = buf_prev[0]; --tlen; num_prev_ = 0;
+ *cptr++ = buf_prev[0];
+ --tlen;
+ num_prev_ = 0;
}
}
if (tlen == 0) return size;
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
nvalue |= DecodeTable[temp_ch_] << 12;
- *cptr++ = (nvalue >> 16) & 0xFF; --tlen;
- }
+ *cptr++ = (nvalue >> 16) & 0xFF;
+ --tlen;
+ }
{
// third byte
temp_ch_ = reader_.GetChar();
temp_ch_ = reader_.GetChar();
CHECK(temp_ch_ == '=') << "invalid base64 format";
temp_ch_ = reader_.GetChar();
- CHECK(temp_ch_ == EOF || isspace(temp_ch_))
- << "invalid base64 format";
+ CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format";
break;
}
nvalue |= DecodeTable[temp_ch_] << 6;
if (tlen) {
- *cptr++ = (nvalue >> 8) & 0xFF; --tlen;
+ *cptr++ = (nvalue >> 8) & 0xFF;
+ --tlen;
} else {
buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF;
}
{
// fourth byte
temp_ch_ = reader_.GetChar();
- CHECK(temp_ch_ != EOF && !isspace(temp_ch_))
- << "invalid base64 format";
+ CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
if (temp_ch_ == '=') {
temp_ch_ = reader_.GetChar();
- CHECK(temp_ch_ == EOF || isspace(temp_ch_))
- << "invalid base64 format";
+ CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format";
break;
}
nvalue |= DecodeTable[temp_ch_];
if (tlen) {
- *cptr++ = nvalue & 0xFF; --tlen;
+ *cptr++ = nvalue & 0xFF;
+ --tlen;
} else {
- buf_prev[num_prev_ ++] = nvalue & 0xFF;
+ buf_prev[num_prev_++] = nvalue & 0xFF;
}
}
// get next char
}
return size - tlen;
}
- virtual void Write(const void *ptr, size_t size) {
+ virtual void Write(const void* ptr, size_t size) {
LOG(FATAL) << "Base64InStream do not support write";
}
/*!
* \brief Stream to write to base64 format.
*/
-class Base64OutStream: public dmlc::Stream {
+class Base64OutStream : public dmlc::Stream {
public:
- explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) {
- }
- virtual void Write(const void *ptr, size_t size) {
+ explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {}
+ virtual void Write(const void* ptr, size_t size) {
using base64::EncodeTable;
size_t tlen = size;
- const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
+ const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
while (tlen) {
- while (buf__top_ < 3 && tlen != 0) {
- buf_[++buf__top_] = *cptr++; --tlen;
+ while (buf__top_ < 3 && tlen != 0) {
+ buf_[++buf__top_] = *cptr++;
+ --tlen;
}
if (buf__top_ == 3) {
// flush 4 bytes out
}
}
}
- virtual size_t Read(void *ptr, size_t size) {
+ virtual size_t Read(void* ptr, size_t size) {
LOG(FATAL) << "Base64OutStream do not support read";
return 0;
}
private:
static constexpr size_t kBufferSize = 256;
- dmlc::Stream *fp_{nullptr};
+ dmlc::Stream* fp_{nullptr};
int buf__top_{0};
unsigned char buf_[4];
std::string out_buf_;
-
void PutChar(char ch) {
out_buf_ += ch;
if (out_buf_.length() >= kBufferSize) Flush();
* under the License.
*/
- /*!
+/*!
* FFI registration code used for frontend testing purposes.
* \file ffi_testing.cc
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/tensor.h>
+#include <tvm/tir/expr.h>
namespace tvm {
// Attrs used to python API
TypedEnvFunc<int(int)> func;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
- TVM_ATTR_FIELD(axis)
- .set_default(10)
- .set_lower_bound(1)
- .set_upper_bound(10)
- .describe("axis field");
- TVM_ATTR_FIELD(name)
- .describe("name");
- TVM_ATTR_FIELD(padding)
- .describe("padding of input")
- .set_default(Array<PrimExpr>({0, 0}));
+ TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe(
+ "axis field");
+ TVM_ATTR_FIELD(name).describe("name");
+ TVM_ATTR_FIELD(padding).describe("padding of input").set_default(Array<PrimExpr>({0, 0}));
TVM_ATTR_FIELD(func)
.describe("some random env function")
.set_default(TypedEnvFunc<int(int)>(nullptr));
TVM_REGISTER_NODE_TYPE(TestAttrs);
-TVM_REGISTER_GLOBAL("testing.nop")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- });
+TVM_REGISTER_GLOBAL("testing.nop").set_body([](TVMArgs args, TVMRetValue* ret) {});
-TVM_REGISTER_GLOBAL("testing.echo")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
+TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0];
- });
+});
-TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- PackedFunc pf = args[0];
- *ret = runtime::TypedPackedFunc<void()>([pf](){
- pf();
- });
- });
+TVM_REGISTER_GLOBAL("testing.test_wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) {
+ PackedFunc pf = args[0];
+ *ret = runtime::TypedPackedFunc<void()>([pf]() { pf(); });
+});
TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- std::string msg = args[0];
- *ret = runtime::TypedPackedFunc<void()>([msg](){
- LOG(FATAL) << msg;
- });
- });
-
-TVM_REGISTER_GLOBAL("testing.test_check_eq_callback")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- std::string msg = args[0];
- *ret = runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y){
- CHECK_EQ(x, y) << msg;
- });
- });
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ std::string msg = args[0];
+ *ret = runtime::TypedPackedFunc<void()>([msg]() { LOG(FATAL) << msg; });
+ });
-TVM_REGISTER_GLOBAL("testing.context_test")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLContext ctx = args[0];
- int dtype = args[1];
- int did = args[2];
- CHECK_EQ(static_cast<int>(ctx.device_type), dtype);
- CHECK_EQ(static_cast<int>(ctx.device_id), did);
- *ret = ctx;
- });
+TVM_REGISTER_GLOBAL("testing.test_check_eq_callback").set_body([](TVMArgs args, TVMRetValue* ret) {
+ std::string msg = args[0];
+ *ret =
+ runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y) { CHECK_EQ(x, y) << msg; });
+});
+TVM_REGISTER_GLOBAL("testing.context_test").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLContext ctx = args[0];
+ int dtype = args[1];
+ int did = args[2];
+ CHECK_EQ(static_cast<int>(ctx.device_type), dtype);
+ CHECK_EQ(static_cast<int>(ctx.device_id), did);
+ *ret = ctx;
+});
// in src/api_test.cc
void ErrorTest(int x, int y) {
}
}
-TVM_REGISTER_GLOBAL("testing.ErrorTest")
-.set_body_typed(ErrorTest);
+TVM_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest);
// internal function used for debug and testing purposes
-TVM_REGISTER_GLOBAL("testing.object_use_count")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- runtime::ObjectRef obj = args[0];
- // substract the current one because we always copy
- // and get another value.
- *ret = (obj.use_count() - 1);
- });
+TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRetValue* ret) {
+ runtime::ObjectRef obj = args[0];
+ // substract the current one because we always copy
+ // and get another value.
+ *ret = (obj.use_count() - 1);
+});
} // namespace tvm
#ifndef TVM_SUPPORT_PIPE_H_
#define TVM_SUPPORT_PIPE_H_
-#include <dmlc/logging.h>
#include <dmlc/io.h>
+#include <dmlc/logging.h>
#ifdef _WIN32
#include <windows.h>
#else
-#include <unistd.h>
#include <errno.h>
-#include <cstring>
+#include <unistd.h>
+
#include <cstdlib>
+#include <cstring>
#endif
namespace tvm {
using PipeHandle = int;
#endif
/*! \brief Construct a pipe from system handle. */
- explicit Pipe(int64_t handle)
- : handle_(static_cast<PipeHandle>(handle)) {}
+ explicit Pipe(int64_t handle) : handle_(static_cast<PipeHandle>(handle)) {}
/*! \brief destructor */
- ~Pipe() {
- Flush();
- }
+ ~Pipe() { Flush(); }
using Stream::Read;
using Stream::Write;
/*!
* \param size block size
* \return the size of data read
*/
- size_t Read(void *ptr, size_t size) final {
+ size_t Read(void* ptr, size_t size) final {
if (size == 0) return 0;
#ifdef _WIN32
DWORD nread;
- CHECK(ReadFile(handle_, static_cast<TCHAR*>(ptr),
- &nread, nullptr))
+ CHECK(ReadFile(handle_, static_cast<TCHAR*>(ptr), &nread, nullptr))
<< "Read Error: " << GetLastError();
#else
ssize_t nread;
nread = read(handle_, ptr, size);
- CHECK_GE(nread, 0)
- << "Write Error: " << strerror(errno);
+ CHECK_GE(nread, 0) << "Write Error: " << strerror(errno);
#endif
return static_cast<size_t>(nread);
}
* \param size block size
* \return the size of data read
*/
- void Write(const void *ptr, size_t size) final {
+ void Write(const void* ptr, size_t size) final {
if (size == 0) return;
#ifdef _WIN32
DWORD nwrite;
- CHECK(WriteFile(handle_, static_cast<const TCHAR*>(ptr),
- &nwrite, nullptr) &&
+ CHECK(WriteFile(handle_, static_cast<const TCHAR*>(ptr), &nwrite, nullptr) &&
static_cast<size_t>(nwrite) == size)
<< "Write Error: " << GetLastError();
#else
ssize_t nwrite;
nwrite = write(handle_, ptr, size);
- CHECK_EQ(static_cast<size_t>(nwrite), size)
- << "Write Error: " << strerror(errno);
+ CHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << strerror(errno);
#endif
}
/*!
#ifndef TVM_SUPPORT_RING_BUFFER_H_
#define TVM_SUPPORT_RING_BUFFER_H_
-#include <vector>
-#include <cstring>
#include <algorithm>
+#include <cstring>
+#include <vector>
namespace tvm {
namespace support {
/*! \brief constructor */
RingBuffer() : ring_(kInitCapacity) {}
/*! \return number of bytes available in buffer. */
- size_t bytes_available() const {
- return bytes_available_;
- }
+ size_t bytes_available() const { return bytes_available_; }
/*! \return Current capacity of buffer. */
- size_t capacity() const {
- return ring_.size();
- }
+ size_t capacity() const { return ring_.size(); }
/*!
* Reserve capacity to be at least n.
* Will only increase capacity if n is bigger than current capacity.
*/
void Reserve(size_t n) {
if (ring_.size() < n) {
- size_t old_size = ring_.size();
- size_t new_size = static_cast<size_t>(n * 1.2);
- ring_.resize(new_size);
- if (head_ptr_ + bytes_available_ > old_size) {
- // copy the ring overflow part into the tail.
- size_t ncopy = head_ptr_ + bytes_available_ - old_size;
- memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
- }
- } else if (ring_.size() > n * 8 &&
- ring_.size() > kInitCapacity) {
+ size_t old_size = ring_.size();
+ size_t new_size = static_cast<size_t>(n * 1.2);
+ ring_.resize(new_size);
+ if (head_ptr_ + bytes_available_ > old_size) {
+ // copy the ring overflow part into the tail.
+ size_t ncopy = head_ptr_ + bytes_available_ - old_size;
+ memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
+ }
+ } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) {
// shrink too large temporary buffer to
// avoid out of memory on some embedded devices
if (bytes_available_ != 0) {
bytes_available_ = old_bytes;
}
// shrink the ring.
- size_t new_size = kInitCapacity;
+ size_t new_size = kInitCapacity;
new_size = std::max(new_size, n);
new_size = std::max(new_size, bytes_available_);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
memcpy(data, &ring_[0] + head_ptr_, ncopy);
if (ncopy < size) {
- memcpy(reinterpret_cast<char*>(data) + ncopy,
- &ring_[0], size - ncopy);
+ memcpy(reinterpret_cast<char*>(data) + ncopy, &ring_[0], size - ncopy);
}
head_ptr_ = (head_ptr_ + size) % ring_.size();
bytes_available_ -= size;
* \param max_nbytes Maximum number of bytes can to read.
* \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size);
*/
- template<typename FSend>
+ template <typename FSend>
size_t ReadWithCallback(FSend fsend, size_t max_nbytes) {
size_t size = std::min(max_nbytes, bytes_available_);
CHECK_NE(size, 0U);
* \param max_nbytes Maximum number of bytes can write.
* \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size);
*/
- template<typename FRecv>
+ template <typename FRecv>
size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) {
this->Reserve(bytes_available_ + max_nbytes);
size_t nbytes = max_nbytes;
#pragma comment(lib, "Ws2_32.lib")
#endif
#else
+#include <arpa/inet.h>
+#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
-#include <errno.h>
-#include <unistd.h>
-#include <arpa/inet.h>
#include <netinet/in.h>
-#include <sys/socket.h>
-#include <sys/select.h>
#include <sys/ioctl.h>
+#include <sys/select.h>
+#include <sys/socket.h>
+#include <unistd.h>
#endif
#include <dmlc/logging.h>
-#include <string>
+
#include <cstring>
-#include <vector>
+#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../support/util.h"
#if defined(_WIN32)
-static inline int poll(struct pollfd *pfd, int nfds,
- int timeout) {
+static inline int poll(struct pollfd* pfd, int nfds, int timeout) {
return WSAPoll(pfd, nfds, timeout);
}
#else
* \return The hostname.
*/
inline std::string GetHostName() {
- std::string buf; buf.resize(256);
+ std::string buf;
+ buf.resize(256);
CHECK_NE(gethostname(&buf[0], 256), -1);
return std::string(buf.c_str());
}
* \param url The url of the address
* \param port The port of the address.
*/
- SockAddr(const char *url, int port) {
- this->Set(url, port);
- }
+ SockAddr(const char* url, int port) { this->Set(url, port); }
/*!
- * \brief SockAddr Get the socket address from tracker.
- * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090)
- * \return SockAddr parsed from url.
- */
- explicit SockAddr(const std::string &url) {
+ * \brief SockAddr Get the socket address from tracker.
+ * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090)
+ * \return SockAddr parsed from url.
+ */
+ explicit SockAddr(const std::string& url) {
size_t sep = url.find(",");
std::string host = url.substr(2, sep - 3);
std::string port = url.substr(sep + 1, url.length() - 1);
* \param host the url of the address
* \param port the port of address
*/
- void Set(const char *host, int port) {
+ void Set(const char* host, int port) {
addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = PF_UNSPEC;
hints.ai_flags = AI_PASSIVE;
hints.ai_socktype = SOCK_STREAM;
- addrinfo *res = NULL;
+ addrinfo* res = NULL;
int sig = getaddrinfo(host, NULL, &hints, &res);
- CHECK(sig == 0 && res != NULL)
- << "cannot obtain address of " << host;
+ CHECK(sig == 0 && res != NULL) << "cannot obtain address of " << host;
switch (res->ai_family) {
case AF_INET: {
- sockaddr_in *addr4 = reinterpret_cast<sockaddr_in *>(&addr);
- memcpy(addr4, res->ai_addr, res->ai_addrlen);
- addr4->sin_port = htons(port);
- addr4->sin_family = AF_INET;
- }
- break;
+ sockaddr_in* addr4 = reinterpret_cast<sockaddr_in*>(&addr);
+ memcpy(addr4, res->ai_addr, res->ai_addrlen);
+ addr4->sin_port = htons(port);
+ addr4->sin_family = AF_INET;
+ } break;
case AF_INET6: {
- sockaddr_in6 *addr6 = reinterpret_cast<sockaddr_in6 *>(&addr);
- memcpy(addr6, res->ai_addr, res->ai_addrlen);
- addr6->sin6_port = htons(port);
- addr6->sin6_family = AF_INET6;
- }
- break;
+ sockaddr_in6* addr6 = reinterpret_cast<sockaddr_in6*>(&addr);
+ memcpy(addr6, res->ai_addr, res->ai_addrlen);
+ addr6->sin6_port = htons(port);
+ addr6->sin6_family = AF_INET6;
+ } break;
default:
CHECK(false) << "cannot decode address";
}
}
/*! \brief return port of the address */
int port() const {
- return ntohs((addr.ss_family == AF_INET6)? \
- reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_port : \
- reinterpret_cast<const sockaddr_in *>(&addr)->sin_port);
+ return ntohs((addr.ss_family == AF_INET6)
+ ? reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_port
+ : reinterpret_cast<const sockaddr_in*>(&addr)->sin_port);
}
/*! \brief return the ip address family */
- int ss_family() const {
- return addr.ss_family;
- }
+ int ss_family() const { return addr.ss_family; }
/*! \return a string representation of the address */
std::string AsString() const {
- std::string buf; buf.resize(256);
+ std::string buf;
+ buf.resize(256);
- const void *sinx_addr = nullptr;
- if (addr.ss_family == AF_INET6) {
- const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_addr;
- sinx_addr = reinterpret_cast<const void *>(&addr6);
- } else if (addr.ss_family == AF_INET) {
- const in_addr& addr4 = reinterpret_cast<const sockaddr_in *>(&addr)->sin_addr;
- sinx_addr = reinterpret_cast<const void *>(&addr4);
- } else {
- CHECK(false) << "illegal address";
- }
+ const void* sinx_addr = nullptr;
+ if (addr.ss_family == AF_INET6) {
+ const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_addr;
+ sinx_addr = reinterpret_cast<const void*>(&addr6);
+ } else if (addr.ss_family == AF_INET) {
+ const in_addr& addr4 = reinterpret_cast<const sockaddr_in*>(&addr)->sin_addr;
+ sinx_addr = reinterpret_cast<const void*>(&addr4);
+ } else {
+ CHECK(false) << "illegal address";
+ }
#ifdef _WIN32
- const char *s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*)
+ const char* s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*)
&buf[0], buf.length());
#else
- const char *s = inet_ntop(addr.ss_family, sinx_addr,
- &buf[0], static_cast<socklen_t>(buf.length()));
+ const char* s =
+ inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast<socklen_t>(buf.length()));
#endif
CHECK(s != nullptr) << "cannot decode address";
std::ostringstream os;
* \brief bind the socket to an address
* \param addr The address to be binded
*/
- void Bind(const SockAddr &addr) {
+ void Bind(const SockAddr& addr) {
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
- (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
- sizeof(sockaddr_in))) == -1) {
+ (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) ==
+ -1) {
Socket::Error("Bind");
}
}
for (int port = start_port; port < end_port; ++port) {
SockAddr addr(host.c_str(), port);
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
- (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
- sizeof(sockaddr_in))) == 0) {
+ (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) ==
+ 0) {
return port;
} else {
LOG(WARNING) << "Bind failed to " << host << ":" << port;
int GetSockError() const {
int error = 0;
socklen_t len = sizeof(error);
- if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
+ if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
Error("GetSockError");
}
return error;
return false;
}
/*! \brief check if socket is already closed */
- bool IsClosed() const {
- return sockfd == INVALID_SOCKET;
- }
+ bool IsClosed() const { return sockfd == INVALID_SOCKET; }
/*! \brief close the socket */
void Close() {
if (sockfd != INVALID_SOCKET) {
* \brief Report an socket error.
* \param msg The error message.
*/
- static void Error(const char *msg) {
+ static void Error(const char* msg) {
int errsv = GetLastError();
#ifdef _WIN32
LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
}
protected:
- explicit Socket(SockType sockfd) : sockfd(sockfd) {
- }
+ explicit Socket(SockType sockfd) : sockfd(sockfd) {}
};
/*!
*/
class TCPSocket : public Socket {
public:
- TCPSocket() : Socket(INVALID_SOCKET) {
- }
+ TCPSocket() : Socket(INVALID_SOCKET) {}
/*!
* \brief construct a TCP socket from existing descriptor
* \param sockfd The descriptor
*/
- explicit TCPSocket(SockType sockfd) : Socket(sockfd) {
- }
+ explicit TCPSocket(SockType sockfd) : Socket(sockfd) {}
/*!
* \brief enable/disable TCP keepalive
* \param keepalive whether to set the keep alive option on
*/
void SetKeepAlive(bool keepalive) {
int opt = static_cast<int>(keepalive);
- if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
- reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
+ if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) <
+ 0) {
Socket::Error("SetKeepAlive");
}
}
* \brief perform listen of the socket
* \param backlog backlog parameter
*/
- void Listen(int backlog = 16) {
- listen(sockfd, backlog);
- }
+ void Listen(int backlog = 16) { listen(sockfd, backlog); }
/*!
* \brief get a new connection
* \return The accepted socket connection.
return TCPSocket(newfd);
}
/*!
- * \brief get a new connection
- * \param addr client address from which connection accepted
- * \return The accepted socket connection.
- */
- TCPSocket Accept(SockAddr *addr) {
+ * \brief get a new connection
+ * \param addr client address from which connection accepted
+ * \return The accepted socket connection.
+ */
+ TCPSocket Accept(SockAddr* addr) {
socklen_t addrlen = sizeof(addr->addr);
- SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr),
- &addrlen);
+ SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
* \param addr the address to connect to
* \return whether connect is successful
*/
- bool Connect(const SockAddr &addr) {
- return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
- (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
- sizeof(sockaddr_in))) == 0;
+ bool Connect(const SockAddr& addr) {
+ return connect(
+ sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
+ (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0;
}
/*!
* \brief send data using the socket
* \return size of data actually sent
* return -1 if error occurs
*/
- ssize_t Send(const void *buf_, size_t len, int flag = 0) {
- const char *buf = reinterpret_cast<const char*>(buf_);
+ ssize_t Send(const void* buf_, size_t len, int flag = 0) {
+ const char* buf = reinterpret_cast<const char*>(buf_);
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
}
/*!
* \return size of data actually received
* return -1 if error occurs
*/
- ssize_t Recv(void *buf_, size_t len, int flags = 0) {
- char *buf = reinterpret_cast<char*>(buf_);
+ ssize_t Recv(void* buf_, size_t len, int flags = 0) {
+ char* buf = reinterpret_cast<char*>(buf_);
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
}
/*!
* \param len the size of the buffer
* \return size of data actually sent
*/
- size_t SendAll(const void *buf_, size_t len) {
- const char *buf = reinterpret_cast<const char*>(buf_);
+ size_t SendAll(const void* buf_, size_t len) {
+ const char* buf = reinterpret_cast<const char*>(buf_);
size_t ndone = 0;
- while (ndone < len) {
+ while (ndone < len) {
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
if (ret == -1) {
if (LastErrorWouldBlock()) return ndone;
* \param len length of data to recv
* \return size of data actually sent
*/
- size_t RecvAll(void *buf_, size_t len) {
- char *buf = reinterpret_cast<char*>(buf_);
+ size_t RecvAll(void* buf_, size_t len) {
+ char* buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
- while (ndone < len) {
- ssize_t ret = recv(sockfd, buf,
- static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
+ while (ndone < len) {
+ ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
if (ret == -1) {
- if (LastErrorWouldBlock()) {
+ if (LastErrorWouldBlock()) {
LOG(FATAL) << "would block";
return ndone;
}
* \param timeout the timeout counter, can be negative, which means wait until the event happen
* \return 1 if success, 0 if timeout, and -1 if error occurs
*/
- inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*)
+ inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*)
pollfd pfd;
pfd.fd = fd;
pfd.events = POLLPRI;
#ifndef TVM_SUPPORT_STR_ESCAPE_H_
#define TVM_SUPPORT_STR_ESCAPE_H_
-#include <string>
#include <sstream>
+#include <string>
namespace tvm {
namespace support {
* \param size The size of the string.
* \return the Result string.
*/
-inline std::string StrEscape(const std::string& val) {
- return StrEscape(val.data(), val.length());
-}
+inline std::string StrEscape(const std::string& val) { return StrEscape(val.data(), val.length()); }
} // namespace support
} // namespace tvm
#include <stdio.h>
#ifndef _WIN32
-#include <sys/wait.h>
#include <sys/types.h>
+#include <sys/wait.h>
#endif
-#include <vector>
-#include <string>
-#include <sstream>
#include <algorithm>
#include <array>
#include <cctype>
#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
namespace tvm {
namespace support {
#endif
}
-
/*!
* \brief IsNumber check whether string is a number.
* \param str input string
* \return result of operation.
*/
inline bool IsNumber(const std::string& str) {
- return !str.empty() && std::find_if(str.begin(),
- str.end(), [](char c) { return !std::isdigit(c); }) == str.end();
+ return !str.empty() &&
+ std::find_if(str.begin(), str.end(), [](char c) { return !std::isdigit(c); }) == str.end();
}
/*!
#ifndef TVM_TARGET_BUILD_COMMON_H_
#define TVM_TARGET_BUILD_COMMON_H_
-#include <tvm/target/codegen.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
#include <tvm/ir/module.h>
-#include <tvm/tir/function.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
-#include <unordered_map>
+
#include <string>
+#include <unordered_map>
+
#include "../runtime/meta_data.h"
namespace tvm {
namespace codegen {
-inline std::unordered_map<std::string, runtime::FunctionInfo>
-ExtractFuncInfo(const IRModule& mod) {
+inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<tir::PrimFuncNode>())
- << "Can only lower IR Module with PrimFuncs";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
* \file codegen.cc
* \brief Common utilities to generated C style code.
*/
+#include <dmlc/memory_io.h>
+#include <tvm/ir/module.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
#include <tvm/target/codegen.h>
#include <tvm/target/target.h>
-
-#include <tvm/ir/module.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/transform.h>
-#include <tvm/runtime/container.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/module.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include <dmlc/memory_io.h>
-#include <sstream>
-#include <vector>
#include <cstdint>
-#include <unordered_set>
#include <cstring>
+#include <sstream>
+#include <unordered_set>
+#include <vector>
namespace tvm {
namespace codegen {
std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
- CHECK(bf != nullptr)
- << "target.build." << target << " is not enabled";
+ CHECK(bf != nullptr) << "target.build." << target << " is not enabled";
return (*bf)(mod, target->str());
}
/*! \brief Helper class to serialize module */
class ModuleSerializer {
public:
- explicit ModuleSerializer(runtime::Module mod) : mod_(mod) {
- Init();
- }
+ explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); }
void SerializeModule(dmlc::Stream* stream) {
// Only have one DSO module and it is in the root, then
// invariance: root module is always at location 0.
// The module order is collected via DFS
void CreateModuleIndex() {
- std::unordered_set<const runtime::ModuleNode*> visited {mod_.operator->()};
- std::vector<runtime::ModuleNode*> stack {mod_.operator->()};
+ std::unordered_set<const runtime::ModuleNode*> visited{mod_.operator->()};
+ std::vector<runtime::ModuleNode*> stack{mod_.operator->()};
uint64_t module_index = 0;
while (!stack.empty()) {
}
bool DSOExportable(const runtime::ModuleNode* mod) {
- return !std::strcmp(mod->type_key(), "llvm") ||
- !std::strcmp(mod->type_key(), "c");
+ return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c");
}
runtime::Module mod_;
std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
// index -> module
std::vector<runtime::ModuleNode*> mod_vec_;
- std::vector<uint64_t> import_tree_row_ptr_ {0};
+ std::vector<uint64_t> import_tree_row_ptr_{0};
std::vector<uint64_t> import_tree_child_indices_;
};
namespace {
- std::string SerializeModule(const runtime::Module& mod) {
- std::string bin;
- dmlc::MemoryStringStream ms(&bin);
- dmlc::Stream* stream = &ms;
+std::string SerializeModule(const runtime::Module& mod) {
+ std::string bin;
+ dmlc::MemoryStringStream ms(&bin);
+ dmlc::Stream* stream = &ms;
- ModuleSerializer module_serializer(mod);
- module_serializer.SerializeModule(stream);
+ ModuleSerializer module_serializer(mod);
+ module_serializer.SerializeModule(stream);
- return bin;
- }
+ return bin;
+}
} // namespace
std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "#endif\n";
os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n";
uint64_t nbytes = bin.length();
- os << "const unsigned char " << runtime::symbol::tvm_dev_mblob
- << "[" << bin.length() + sizeof(nbytes) << "] = {\n ";
+ os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "["
+ << bin.length() + sizeof(nbytes) << "] = {\n ";
os << std::hex;
size_t nunit = 80 / 4;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
return os.str();
}
-runtime::Module PackImportsToLLVM(const runtime::Module& mod,
- bool system_lib,
+runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
const std::string& target_triple) {
std::string bin = SerializeModule(mod);
std::string codegen_f_name = "codegen.codegen_blob";
// the codegen function.
const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
- CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
+ CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
return (*codegen_f)(blob_byte_array, system_lib, target_triple);
}
-TVM_REGISTER_GLOBAL("target.Build")
-.set_body_typed(Build);
+TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);
// Export two auxiliary function to the runtime namespace.
-TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
-.set_body_typed(PackImportsToC);
+TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC);
-TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM")
-.set_body_typed(PackImportsToLLVM);
+TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM);
} // namespace codegen
} // namespace tvm
* specific language governing permissions and limitations
* under the License.
*/
-#include <tvm/runtime/registry.h>
#include "registry.h"
+#include <tvm/runtime/registry.h>
+
namespace tvm {
namespace datatype {
using runtime::TVMArgs;
using runtime::TVMRetValue;
-TVM_REGISTER_GLOBAL("runtime._datatype_register")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("runtime._datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) {
datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
});
-TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = datatype::Registry::Global()->GetTypeCode(args[0]);
});
-TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Registry::Global()->GetTypeName(args[0].operator int());
});
TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = Registry::Global()->GetTypeRegistered(args[0].operator int());
-});
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = Registry::Global()->GetTypeRegistered(args[0].operator int());
+ });
Registry* Registry::Global() {
static Registry inst;
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
+
#include <string>
#include <unordered_map>
* \param type_name The type name
* \return The type code
*/
- uint8_t GetTypeCode(const std::string &type_name);
+ uint8_t GetTypeCode(const std::string& type_name);
/*!
* \brief Get type name from type code
* \file src/target/generic_func.cc
*/
#include <dmlc/thread_local.h>
-
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
#include <tvm/node/node.h>
#include <tvm/node/repr_printer.h>
-#include <tvm/target/target.h>
-#include <tvm/target/generic_func.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
+#include <tvm/target/generic_func.h>
+#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <algorithm>
// mutex
std::mutex mutex;
- Manager() {
- }
+ Manager() {}
static Manager* Global() {
static Manager inst;
m->fmap[name] = func;
}
-GenericFunc& GenericFunc::set_default(const PackedFunc value,
- bool allow_override) {
+GenericFunc& GenericFunc::set_default(const PackedFunc value, bool allow_override) {
auto node = static_cast<GenericFuncNode*>(operator->());
if (!allow_override) {
CHECK(node->generic_func_ == nullptr)
- << "Generic function already registered for " << node->name_;
+ << "Generic function already registered for " << node->name_;
}
node->generic_func_ = value;
return *this;
}
GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
- const PackedFunc value,
- bool allow_override) {
- for (auto &t : tags) {
+ const PackedFunc value, bool allow_override) {
+ for (auto& t : tags) {
if (!allow_override) {
auto iter = (*this)->dispatch_dict_.find(t);
CHECK(iter == (*this)->dispatch_dict_.end())
- << "Tag " << t << " already registered for schedule factory " << (*this)->name_;
+ << "Tag " << t << " already registered for schedule factory " << (*this)->name_;
}
(*this)->dispatch_dict_[t] = value;
}
PackedFunc func;
if (target.defined()) {
- for (auto &k : target->keys()) {
+ for (auto& k : target->keys()) {
auto iter = node->dispatch_dict_.find(k);
if (iter != node->dispatch_dict_.end()) {
func = iter->second;
func.CallPacked(args, ret);
}
-TVM_REGISTER_GLOBAL("target.GenericFuncCreate")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.GenericFuncCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(make_object<GenericFuncNode>());
- });
+});
-TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal").set_body([](TVMArgs args, TVMRetValue* ret) {
std::string func_name = args[0];
*ret = GenericFunc::Get(func_name);
- });
+});
-TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault").set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
bool allow_override = args[2];
- generic_func
- .set_default(*func, allow_override);
- });
+ generic_func.set_default(*func, allow_override);
+});
-TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
tags_vector.push_back(tag);
}
- generic_func
- .register_func(tags_vector, *func, allow_override);
- });
+ generic_func.register_func(tags_vector, *func, allow_override);
+});
-TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1);
- generic_func
- .CallPacked(func_args, ret);
- });
+ generic_func.CallPacked(func_args, ret);
+});
} // namespace tvm
* \file intrin_rule_default.cc
* \brief Default intrinsic rules.
*/
-#include <tvm/tir/op.h>
#include "intrin_rule.h"
+#include <tvm/tir/op.h>
+
namespace tvm {
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
-.set_body([](const TVMArgs& args, TVMRetValue* rv){
- PrimExpr e = args[0];
- const CallNode* call = e.as<CallNode>();
- CHECK(call != nullptr);
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ PrimExpr e = args[0];
+ const CallNode* call = e.as<CallNode>();
+ CHECK(call != nullptr);
- auto one = make_const(call->args[0].dtype(), 1);
- *rv = one / sqrt(call->args[0]);
- });
+ auto one = make_const(call->args[0].dtype(), 1);
+ *rv = one / sqrt(call->args[0]);
+ });
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
-.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
-.set_body([](const TVMArgs& args, TVMRetValue* rv){
- PrimExpr e = args[0];
- const CallNode* call = e.as<CallNode>();
- CHECK(call != nullptr);
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ PrimExpr e = args[0];
+ const CallNode* call = e.as<CallNode>();
+ CHECK(call != nullptr);
- auto one = make_const(call->args[0].dtype(), 1);
- *rv = one / (one + exp(-call->args[0]));
- });
+ auto one = make_const(call->args[0].dtype(), 1);
+ *rv = one / (one + exp(-call->args[0]));
+ });
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite")
-.set_body([](const TVMArgs& args, TVMRetValue* rv){
- PrimExpr e = args[0];
- const CallNode* call = e.as<CallNode>();
- CHECK(call != nullptr);
- *rv = isfinite(call->args[0]);
- });
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ PrimExpr e = args[0];
+ const CallNode* call = e.as<CallNode>();
+ CHECK(call != nullptr);
+ *rv = isfinite(call->args[0]);
+ });
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf")
-.set_body([](const TVMArgs& args, TVMRetValue* rv){
- PrimExpr e = args[0];
- const CallNode* call = e.as<CallNode>();
- CHECK(call != nullptr);
- *rv = isinf(call->args[0]);
- });
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ PrimExpr e = args[0];
+ const CallNode* call = e.as<CallNode>();
+ CHECK(call != nullptr);
+ *rv = isinf(call->args[0]);
+ });
} // namespace intrin
} // namespace codegen
#ifndef TVM_TARGET_INTRIN_RULE_H_
#define TVM_TARGET_INTRIN_RULE_H_
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+
#include <string>
namespace tvm {
// Return the intrinsic name
struct Direct {
- std::string operator()(DataType t, std::string name) const {
- return name;
- }
+ std::string operator()(DataType t, std::string name) const { return name; }
};
// Call pure extern function.
-template<typename T>
+template <typename T>
inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
std::string name = T()(call->dtype, call->name);
if (name.length() != 0) {
- *rv = CallNode::make(
- call->dtype, name, call->args, CallNode::PureExtern);
+ *rv = CallNode::make(call->dtype, name, call->args, CallNode::PureExtern);
} else {
*rv = e;
}
*/
#ifdef TVM_LLVM_VERSION
-#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
-#include "codegen_llvm.h"
-#include "../build_common.h"
+
#include "../../runtime/rocm/rocm_module.h"
+#include "../build_common.h"
+#include "codegen_llvm.h"
namespace tvm {
namespace codegen {
TVMRetValue val;
api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) {
- tvm::runtime::DeviceAPI::Get(tvm_ctx)->
- GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, &val);
+ tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock,
+ &val);
return val.operator int();
}
}
llvm::Value* buf = nullptr;
int32_t constant_size = op->constant_allocation_size();
- CHECK_GT(constant_size, 0)
- << "Can only handle constant size stack allocation in GPU";
+ CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
- return builder_->CreateAlloca(
- DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
- });
+ return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
+ });
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
alloca->setAlignment(llvm::Align(info.alignment));
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
- llvm::Type* type = llvm::ArrayType::get(
- DTypeToLLVMType(op->dtype), constant_size);
+ llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
- llvm::GlobalVariable *global = new llvm::GlobalVariable(
- *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
- nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
+ llvm::GlobalVariable* global = new llvm::GlobalVariable(
+ *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr,
+ llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(info.alignment));
#else
}
buf = builder_->CreatePointerCast(
- buf, DTypeToLLVMType(op->dtype)->getPointerTo(
- buf->getType()->getPointerAddressSpace()));
+ buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
if (ts.rank == 1) {
switch (ts.dim_index) {
- case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break;
- case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break;
- case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break;
- default: LOG(FATAL) << "unknown workitem idx";
+ case 0:
+ intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
+ break;
+ case 1:
+ intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y;
+ break;
+ case 2:
+ intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z;
+ break;
+ default:
+ LOG(FATAL) << "unknown workitem idx";
}
} else {
CHECK_EQ(ts.rank, 0);
switch (ts.dim_index) {
- case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break;
- case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break;
- case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break;
- default: LOG(FATAL) << "unknown workgroup idx";
+ case 0:
+ intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x;
+ break;
+ case 1:
+ intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y;
+ break;
+ case 2:
+ intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z;
+ break;
+ default:
+ LOG(FATAL) << "unknown workgroup idx";
}
}
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
if (sync == "warp") {
return nullptr;
} else if (sync == "shared") {
- llvm::Function* f = llvm::Intrinsic::getDeclaration(
- module_.get(),
- ::llvm::Intrinsic::amdgcn_s_barrier);
+ llvm::Function* f =
+ llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::amdgcn_s_barrier);
return builder_->CreateCall(f, {});
} else {
LOG(FATAL) << "Do not support sync " << sync;
// Additional optimization hook to tweak the builder.
}
- unsigned GetGlobalAddressSpace() const final {
- return 1;
- }
+ unsigned GetGlobalAddressSpace() const final { return 1; }
protected:
void InitTarget(llvm::TargetMachine* tm) final {
// issue #4087 for a discussion
#endif
InitializeLLVM();
- CHECK(target.length() >= 4 &&
- target.substr(0, 4) == "rocm");
+ CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm");
std::ostringstream config;
- config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
- << DetectROCMComputeVersion(target)
- << " -mattr=-code-object-v3 "
- << target.substr(4, target.length() - 4);
+ config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target)
+ << " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "Can only lower IR Module with PrimFuncs";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f);
}
- const auto *find_rocm_bitcodes =
- tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
+ const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
Array<runtime::String> bitcode_files = (*find_rocm_bitcodes)();
- for (auto &bitcode_path : bitcode_files) {
+ for (auto& bitcode_path : bitcode_files) {
std::string path = bitcode_path;
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
}
mlib->setTargetTriple(tm->getTargetTriple().str());
mlib->setDataLayout(tm->createDataLayout());
- for (llvm::Function &f : mlib->functions()) {
+ for (llvm::Function& f : mlib->functions()) {
f.addFnAttr(llvm::Attribute::AlwaysInline);
}
cg->AddLinkModule(std::move(mlib));
llvm::legacy::PassManager pass;
#if TVM_LLVM_VERSION <= 60
- CHECK(tm->addPassesToEmitFile(
- pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
- << "Cannot emit target CGFT_ObjectFile";
+ CHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
+ << "Cannot emit target CGFT_ObjectFile";
#elif TVM_LLVM_VERSION <= 90
- CHECK(tm->addPassesToEmitFile(
- pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
- << "Cannot emit target CGFT_ObjectFile";
+ CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
+ << "Cannot emit target CGFT_ObjectFile";
#else
- CHECK(tm->addPassesToEmitFile(
- pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0)
- << "Cannot emit target CGFT_ObjectFile";
+ CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0)
+ << "Cannot emit target CGFT_ObjectFile";
#endif
pass.run(*mObj);
std::string obj(dataObj.begin(), dataObj.end());
llvm::legacy::PassManager passAsm;
#if TVM_LLVM_VERSION <= 60
- CHECK(tm->addPassesToEmitFile(passAsm, destAsm,
- llvm::TargetMachine::CGFT_AssemblyFile) == 0)
+ CHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#elif TVM_LLVM_VERSION <= 90
CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr,
llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#else
- CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr,
- llvm::CGFT_AssemblyFile) == 0)
+ CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#endif
passAsm.run(*mAsm);
return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly);
}
-TVM_REGISTER_GLOBAL("target.build.rocm")
-.set_body_typed(BuildAMDGPU);
+TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU);
} // namespace codegen
} // namespace tvm
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
- llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
- Downcast<IntImm>(op->args[0])->value);
+ llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
if (id == ::llvm::Intrinsic::ctpop) {
PrimExpr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
return CodeGenCPU::CreateIntrinsic(op);
}
-PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
+PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
using namespace tir;
const PrimExpr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;
// Fallback to default llvm lowering rule if input type not a full vector or half vector length
- int total_size = call->dtype.bits() * call->dtype.lanes();
+ int total_size = call->dtype.bits() * call->dtype.lanes();
if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
- (total_size != 128 && total_size != 64)) {
+ (total_size != 128 && total_size != 64)) {
Array<PrimExpr> vcnt_args;
vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt_args.push_back(e);
- return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
+ return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
}
// Popcount lowering rule:
// to return back to original input type
// Dvisions are always divisible (number of bits = 64 or 128)
- DataType uint8_type = DataType(
- e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8);
- DataType uint16_type = DataType(
- uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
- DataType uint32_type = DataType(
- uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
+ DataType uint8_type = DataType(e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8);
+ DataType uint16_type =
+ DataType(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
+ DataType uint32_type =
+ DataType(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
// Interpret input as vector of 8bit values
PrimExpr input8 = reinterpret(uint8_type, e);
vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
- PrimExpr vcnt8 = tir::CallNode::make(
- uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt8 =
+ tir::CallNode::make(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit
Array<PrimExpr> vcnt16_args;
vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
- PrimExpr vcnt16 = tir::CallNode::make(
- uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt16 =
+ tir::CallNode::make(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 16) {
return vcnt16;
}
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
- PrimExpr vcnt32 = tir::CallNode::make(
- uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt32 =
+ tir::CallNode::make(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 32) {
return vcnt32;
}
vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
- return tir::CallNode::make(
- call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
+ return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
-.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
- CodeGenLLVM* cg = new CodeGenARM();
- *rv = static_cast<void*>(cg);
- });
+ .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+ CodeGenLLVM* cg = new CodeGenARM();
+ *rv = static_cast<void*>(cg);
+ });
} // namespace codegen
} // namespace tvm
* \file codegen_blob.cc
*/
#ifdef TVM_LLVM_VERSION
+#include "codegen_blob.h"
+
#include <tvm/runtime/module.h>
+
#include <cstring>
-#include "codegen_blob.h"
namespace tvm {
namespace codegen {
-std::pair<std::unique_ptr<llvm::Module>,
- std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(const std::string& data,
- bool system_lib,
- const std::string& target_triple) {
+std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
+ const std::string& data, bool system_lib, const std::string& target_triple) {
InitializeLLVM();
auto tm = GetLLVMTargetMachine(std::string("-target ") + target_triple);
auto triple = tm->getTargetTriple();
module->setTargetTriple(triple.str());
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
- auto* tvm_dev_mblob = new llvm::GlobalVariable(*module, blob_value->getType(), true,
- llvm::GlobalValue::ExternalLinkage, blob_value,
- runtime::symbol::tvm_dev_mblob, nullptr,
- llvm::GlobalVariable::NotThreadLocal, 0);
+ auto* tvm_dev_mblob = new llvm::GlobalVariable(
+ *module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value,
+ runtime::symbol::tvm_dev_mblob, nullptr, llvm::GlobalVariable::NotThreadLocal, 0);
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob->setAlignment(llvm::Align(1));
auto int8_ptr_ty = int8_ty->getPointerTo(0);
llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty);
- auto* tvm_dev_mblob_reg =
- new llvm::GlobalVariable(*module, int32_ty,
- false, llvm::GlobalValue::InternalLinkage,
- constant_zero,
- std::string(runtime::symbol::tvm_dev_mblob) + "_reg_");
+ auto* tvm_dev_mblob_reg = new llvm::GlobalVariable(
+ *module, int32_ty, false, llvm::GlobalValue::InternalLinkage, constant_zero,
+ std::string(runtime::symbol::tvm_dev_mblob) + "_reg_");
auto tvm_dev_mblob_reg_alignment = module->getDataLayout().getABITypeAlignment(int32_ty);
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment));
llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1);
auto* tvm_dev_mblob_string_value =
llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true);
- auto* tvm_dev_mblob_string =
- new llvm::GlobalVariable(*module, tvm_dev_mblob_string_ty,
- true, llvm::GlobalValue::PrivateLinkage,
- tvm_dev_mblob_string_value,
- std::string(runtime::symbol::tvm_dev_mblob) + ".str");
+ auto* tvm_dev_mblob_string = new llvm::GlobalVariable(
+ *module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage,
+ tvm_dev_mblob_string_value, std::string(runtime::symbol::tvm_dev_mblob) + ".str");
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_string->setAlignment(llvm::Align(1));
#else
#endif
// Global init function
- llvm::Function* init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false),
- llvm::GlobalValue::InternalLinkage,
- llvm::Twine("_GLOBAL__sub_I_", module_name),
- module.get());
+ llvm::Function* init_fn = llvm::Function::Create(
+ llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage,
+ llvm::Twine("_GLOBAL__sub_I_", module_name), module.get());
// Create variable initialization function.
- llvm::Function* var_init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false),
- llvm::GlobalValue::InternalLinkage,
- llvm::Twine("__cxx_global_var_init"),
- module.get());
+ llvm::Function* var_init_fn = llvm::Function::Create(
+ llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage,
+ llvm::Twine("__cxx_global_var_init"), module.get());
// Create TVMBackendRegisterSystemLibSymbol function
llvm::Function* tvm_backend_fn =
llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false),
llvm::GlobalValue::ExternalLinkage,
- llvm::Twine("TVMBackendRegisterSystemLibSymbol"),
- module.get());
+ llvm::Twine("TVMBackendRegisterSystemLibSymbol"), module.get());
// Set necessary fn sections
auto get_static_init_section_specifier = [&triple]() -> std::string {
- if (triple.isOSLinux()) {
- return ".text.startup";
- } else if (triple.isOSDarwin()) {
- return "__TEXT,__StaticInit,regular,pure_instructions";
- } else {
- return "";
- }
+ if (triple.isOSLinux()) {
+ return ".text.startup";
+ } else if (triple.isOSDarwin()) {
+ return "__TEXT,__StaticInit,regular,pure_instructions";
+ } else {
+ return "";
+ }
};
auto static_init_section_specifier = get_static_init_section_specifier();
llvm::Constant* indices[] = {constant_zero, constant_zero};
llvm::SmallVector<llvm::Value*, 2> args;
args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty,
- tvm_dev_mblob_string,
- indices));
- args.push_back(llvm::ConstantExpr::getGetElementPtr(blob_value->getType(),
- tvm_dev_mblob,
- indices));
+ tvm_dev_mblob_string, indices));
+ args.push_back(
+ llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_dev_mblob, indices));
auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args);
ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg);
ir_builder.CreateRetVoid();
#ifndef TVM_TARGET_LLVM_CODEGEN_BLOB_H_
#define TVM_TARGET_LLVM_CODEGEN_BLOB_H_
#ifdef TVM_LLVM_VERSION
-#include <utility>
#include <memory>
#include <string>
+#include <utility>
+
#include "llvm_common.h"
namespace tvm {
*
* \return LLVM module and LLVM context
*/
-std::pair<std::unique_ptr<llvm::Module>,
- std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(const std::string& data,
- bool system_lib,
- const std::string& target_triple);
+std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
+ const std::string& data, bool system_lib, const std::string& target_triple);
} // namespace codegen
} // namespace tvm
*/
#ifdef TVM_LLVM_VERSION
+#include "codegen_cpu.h"
+
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/tir/analysis.h>
+
#include <memory>
#include <unordered_map>
-#include "codegen_cpu.h"
namespace tvm {
namespace codegen {
-void CodeGenCPU::Init(const std::string& module_name,
- llvm::TargetMachine* tm,
- llvm::LLVMContext* ctx,
- bool system_lib,
- bool dynamic_lookup) {
+void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm,
+ llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) {
CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup);
dbg_info_ = CreateDebugInfo(module_.get());
static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
t_tvm_func_handle_ = t_void_p_;
- t_tvm_array_ = llvm::StructType::create(
- {t_void_p_,
- t_tvm_context_,
- t_int_,
- t_tvm_type_,
- t_tvm_shape_index_->getPointerTo(),
- t_tvm_shape_index_->getPointerTo(),
- t_int64_});
+ t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_context_, t_int_, t_tvm_type_,
+ t_tvm_shape_index_->getPointerTo(),
+ t_tvm_shape_index_->getPointerTo(), t_int64_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
- t_tvm_parallel_group_env_ = llvm::StructType::create({
- t_int32_->getPointerTo(), t_int32_});
+ t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_});
ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
- t_int_,
- {t_int_,
- t_tvm_parallel_group_env_->getPointerTo(),
- t_void_p_}, false);
+ t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false);
md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_);
// Runtime functions.
- ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, {
- t_tvm_func_handle_,
- t_tvm_value_->getPointerTo(),
- t_int_->getPointerTo(),
+ ftype_tvm_func_call_ = llvm::FunctionType::get(
t_int_,
- t_tvm_value_->getPointerTo(),
- t_int_->getPointerTo()}, false);
- ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, {
- t_void_p_,
- t_char_->getPointerTo(),
- t_tvm_func_handle_->getPointerTo()}, false);
- ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
- t_void_, {t_char_->getPointerTo()}, false);
- ftype_tvm_parallel_launch_ =
- llvm::FunctionType::get(t_int_, {
- ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}
- , false);
+ {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_,
+ t_tvm_value_->getPointerTo(), t_int_->getPointerTo()},
+ false);
+ ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(
+ t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false);
+ ftype_tvm_api_set_last_error_ =
+ llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false);
+ ftype_tvm_parallel_launch_ = llvm::FunctionType::get(
+ t_int_, {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}, false);
ftype_tvm_parallel_barrier_ =
- llvm::FunctionType::get(t_int_, {
- t_int_, t_tvm_parallel_group_env_->getPointerTo()}
- , false);
- ftype_tvm_static_init_callback_ =
- llvm::FunctionType::get(t_int_, {t_void_p_}, false);
+ llvm::FunctionType::get(t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo()}, false);
+ ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false);
ftype_tvm_static_init_ =
- llvm::FunctionType::get(t_int_, {
- t_void_p_->getPointerTo(),
- ftype_tvm_static_init_callback_->getPointerTo(),
- t_void_p_, t_int_}
- , false);
+ llvm::FunctionType::get(t_int_,
+ {t_void_p_->getPointerTo(),
+ ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_},
+ false);
// initialize TVM runtime API
if (system_lib) {
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_ = nullptr;
}
if (dynamic_lookup || system_lib) {
- f_tvm_func_call_ = llvm::Function::Create(
- ftype_tvm_func_call_,
- llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
- f_tvm_get_func_from_env_ = llvm::Function::Create(
- ftype_tvm_get_func_from_env_,
- llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
- f_tvm_api_set_last_error_ = llvm::Function::Create(
- ftype_tvm_api_set_last_error_,
- llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
- f_tvm_parallel_launch_ = llvm::Function::Create(
- ftype_tvm_parallel_launch_,
- llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
- f_tvm_parallel_barrier_ = llvm::Function::Create(
- ftype_tvm_parallel_barrier_,
- llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
+ f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage,
+ "TVMFuncCall", module_.get());
+ f_tvm_get_func_from_env_ =
+ llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage,
+ "TVMBackendGetFuncFromEnv", module_.get());
+ f_tvm_api_set_last_error_ =
+ llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage,
+ "TVMAPISetLastError", module_.get());
+ f_tvm_parallel_launch_ =
+ llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage,
+ "TVMBackendParallelLaunch", module_.get());
+ f_tvm_parallel_barrier_ =
+ llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage,
+ "TVMBackendParallelBarrier", module_.get());
}
this->InitGlobalContext(dynamic_lookup);
}
#if TVM_LLVM_VERSION >= 80
auto* DIFunction = dbg_info_->di_builder_->createFunction(
- dbg_info_->file_, function->getName(), "",
- dbg_info_->file_,
- 0 /* line number */,
- DIFunctionTy,
- false /* internal linkage */);
+ dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */,
+ DIFunctionTy, false /* internal linkage */);
#else
auto* DIFunction = dbg_info_->di_builder_->createFunction(
- dbg_info_->file_, function->getName(), "",
- dbg_info_->file_,
- 0 /* line number */,
- DIFunctionTy,
- false, /* internal linkage */
- true,
- 0 /* line number */,
- llvm::DINode::FlagPrototyped,
- true /* isOptimized */);
+ dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */,
+ DIFunctionTy, false, /* internal linkage */
+ true, 0 /* line number */, llvm::DINode::FlagPrototyped, true /* isOptimized */);
#endif
CHECK(DIFunction);
llvm::Function* f = module_->getFunction(entry_func_name);
CHECK(f) << "Function " << entry_func_name << "does not in module";
llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
- llvm::GlobalVariable *global = new llvm::GlobalVariable(
- *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0,
- runtime::symbol::tvm_module_main);
+ llvm::GlobalVariable* global = new llvm::GlobalVariable(
+ *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0, runtime::symbol::tvm_module_main);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(1));
#else
}
return CodeGenLLVM::Finish();
}
-llvm::Value* CodeGenCPU::CreateStructRefPtr(
- DataType t, llvm::Value* buf, llvm::Value* index, int kind) {
+llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index,
+ int kind) {
if (kind < intrinsic::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)});
}
case intrinsic::kArrTypeCode: {
- return builder_->CreateInBoundsGEP(
- buf, {index, ConstInt32(3), ConstInt32(0)});
+ return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)});
}
case intrinsic::kArrTypeBits: {
- return builder_->CreateInBoundsGEP(
- buf, {index, ConstInt32(3), ConstInt32(1)});
+ return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)});
}
case intrinsic::kArrTypeLanes: {
- return builder_->CreateInBoundsGEP(
- buf, {index, ConstInt32(3), ConstInt32(2)});
+ return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)});
}
case intrinsic::kArrByteOffset: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)});
}
case intrinsic::kArrDeviceId: {
- return builder_->CreateInBoundsGEP(
- buf, {index, ConstInt32(1), ConstInt32(1)});
+ return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)});
}
case intrinsic::kArrDeviceType: {
- return builder_->CreateInBoundsGEP(
- buf, {index, ConstInt32(1), ConstInt32(0)});
+ return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)});
}
case intrinsic::kTVMValueContent: {
CHECK_EQ(t.lanes(), 1);
return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo());
}
}
- default: LOG(FATAL) << "unknown field code"; return nullptr;
+ default:
+ LOG(FATAL) << "unknown field code";
+ return nullptr;
}
}
for (llvm::Value* v : arg_values) {
arg_types.push_back(v->getType());
}
- llvm::FunctionType* ftype = llvm::FunctionType::get(
- GetLLVMType(GetRef<PrimExpr>(op)), arg_types, false);
+ llvm::FunctionType* ftype =
+ llvm::FunctionType::get(GetLLVMType(GetRef<PrimExpr>(op)), arg_types, false);
// Check if it is available in global function table as injected function.
auto it = gv_func_map_.find(op->name);
if (it != gv_func_map_.end()) {
} else {
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
- f = llvm::Function::Create(
- ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
+ f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
}
#if TVM_LLVM_VERSION >= 90
auto ext_callee = llvm::FunctionCallee(f);
}
}
-llvm::GlobalVariable* CodeGenCPU::InitContextPtr(
- llvm::Type* p_type, std::string name) {
+llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) {
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
- *module_, p_type, false,
- llvm::GlobalValue::LinkOnceAnyLinkage, 0,
- name);
+ *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, 0, name);
#if TVM_LLVM_VERSION >= 100
gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type)));
#else
#else
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
#endif
- faddr->setMetadata(
- "tbaa",
- md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
+ faddr->setMetadata("tbaa",
+ md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
return faddr;
}
std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
} else {
if (!dynamic_lookup) {
- gv_tvm_func_call_ = InitContextPtr(
- ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
- gv_tvm_get_func_from_env_ = InitContextPtr(
- ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
- gv_tvm_api_set_last_error_ = InitContextPtr(
- ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
- gv_tvm_parallel_launch_ = InitContextPtr(
- ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
- gv_tvm_parallel_barrier_ = InitContextPtr(
- ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier");
+ gv_tvm_func_call_ = InitContextPtr(ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
+ gv_tvm_get_func_from_env_ = InitContextPtr(ftype_tvm_get_func_from_env_->getPointerTo(),
+ "__TVMBackendGetFuncFromEnv");
+ gv_tvm_api_set_last_error_ =
+ InitContextPtr(ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
+ gv_tvm_parallel_launch_ =
+ InitContextPtr(ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
+ gv_tvm_parallel_barrier_ = InitContextPtr(ftype_tvm_parallel_barrier_->getPointerTo(),
+ "__TVMBackendParallelBarrier");
// Mark as context functions
gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function.
using llvm::BasicBlock;
- BasicBlock* fail_block = BasicBlock::Create(
- *ctx_, "call_fail", function_);
- BasicBlock* end_block = BasicBlock::Create(
- *ctx_, "call_end", function_);
- llvm::Value* succ = builder_->CreateICmpEQ(
- retcode, llvm::ConstantInt::get(t_int_, 0));
+ BasicBlock* fail_block = BasicBlock::Create(*ctx_, "call_fail", function_);
+ BasicBlock* end_block = BasicBlock::Create(*ctx_, "call_end", function_);
+ llvm::Value* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0));
builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
builder_->SetInsertPoint(fail_block);
// return the code.
arg_values.push_back(value);
arg_types.push_back(value->getType());
}
- llvm::FunctionType* ftype =
- llvm::FunctionType::get(t_int_, arg_types, false);
- llvm::Function* fcompute =
- llvm::Function::Create(ftype,
- llvm::Function::PrivateLinkage,
- op->value.as<StringImmNode>()->value,
- module_.get());
- BasicBlock* compute_call_end = CheckCallSuccess(
- builder_->CreateCall(fcompute, arg_values));
+ llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false);
+ llvm::Function* fcompute = llvm::Function::Create(
+ ftype, llvm::Function::PrivateLinkage, op->value.as<StringImmNode>()->value, module_.get());
+ BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values));
// setup compute fuinction.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
size_t idx = 0;
- for (auto it = fcompute->arg_begin();
- it != fcompute->arg_end(); ++it, ++idx) {
+ for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
const Var& var = vargs[idx];
new_vmap[var.get()] = v;
}
std::swap(function_, fcompute);
std::swap(new_vmap, var_map_);
- BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
+ BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(compute_entry);
this->VisitStmt(op->body);
builder_->CreateRet(ConstInt32(0));
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
- builder_->CreateStore(
- var_map_.at(vfields[i].get()),
- builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
+ builder_->CreateStore(var_map_.at(vfields[i].get()),
+ builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
*num_bytes = data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType());
return cdata;
}
-void CodeGenCPU::UnpackClosureData(llvm::Value* cdata,
- const Array<Var>& vfields,
+void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array<Var>& vfields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) {
(*vmap)[vfields[i].get()] =
- builder_->CreateLoad(builder_->CreateInBoundsGEP(
- cdata, {ConstInt32(0), ConstInt32(i)}));
+ builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)}));
}
}
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
using llvm::BasicBlock;
// closure data
- llvm::Function* f = llvm::Function::Create(
- ftype_tvm_parallel_lambda_,
- llvm::Function::PrivateLinkage,
- "__tvm_parallel_lambda", module_.get());
+ llvm::Function* f =
+ llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage,
+ "__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure.
Array<Var> vfields = tir::UndefinedVars(body, {});
uint64_t nbytes;
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
#if TVM_LLVM_VERSION >= 90
- auto launch_callee = llvm::FunctionCallee(
- ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
+ auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
auto launch_callee = RuntimeTVMParallelLaunch();
#endif
- BasicBlock* par_launch_end = CheckCallSuccess(
- builder_->CreateCall(
- launch_callee,
- {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
+ BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
+ launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
- BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
+ BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* task_id = &(*it++);
par_env.task_id = Var("task_id", DataType::Int(32));
par_env.num_task = Var("num_task", DataType::Int(32));
new_vmap[par_env.task_id.get()] = task_id;
- new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
- builder_->CreateInBoundsGEP(
- penv, {ConstInt32(0), ConstInt32(1)}));
+ new_vmap[par_env.num_task.get()] =
+ builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)}));
par_env.penv = penv;
std::swap(function_, f);
std::swap(parallel_env_, par_env);
std::swap(var_map_, new_vmap);
std::swap(parallel_env_, par_env);
std::swap(function_, f);
- CHECK_NE(par_env.parallel_loop_count, 0)
- << "Cannot find parallel loop within parallel launch";
+ CHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch";
builder_->SetInsertPoint(par_launch_end);
}
llvm::Value* CodeGenCPU::CreateStaticHandle() {
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
- *module_, t_void_p_, false,
- llvm::GlobalValue::PrivateLinkage, 0,
- "__tvm_static_handle");
+ *module_, t_void_p_, false, llvm::GlobalValue::PrivateLinkage, 0, "__tvm_static_handle");
#if TVM_LLVM_VERSION >= 100
gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(t_void_p_)));
#else
void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
using llvm::BasicBlock;
// closure data
- llvm::Function* f = llvm::Function::Create(
- ftype_tvm_static_init_callback_,
- llvm::Function::PrivateLinkage,
- "__tvm_static_init_lambda", module_.get());
+ llvm::Function* f =
+ llvm::Function::Create(ftype_tvm_static_init_callback_, llvm::Function::PrivateLinkage,
+ "__tvm_static_init_lambda", module_.get());
llvm::Value* gv = CreateStaticHandle();
llvm::Function* finit = module_->getFunction(init_fname);
if (finit == nullptr) {
- finit = llvm::Function::Create(
- ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get());
+ finit = llvm::Function::Create(ftype_tvm_static_init_, llvm::Function::ExternalLinkage,
+ init_fname, module_.get());
}
// allocate and setup the closure, call the closure.
uint64_t nbytes;
Array<Var> vfields = tir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
- BasicBlock* init_end = CheckCallSuccess(
- builder_->CreateCall(
- finit,
- {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)}));
+ BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall(
+ finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)}));
// Setup the closure function.
- BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
+ BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
if (it == func_handle_map_.end()) {
// create global location for the handle
// create the function handle
- hptr = new llvm::GlobalVariable(
- *module_, t_tvm_func_handle_, false,
- llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname);
+ hptr =
+ new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false,
+ llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname);
#if TVM_LLVM_VERSION >= 100
hptr->setAlignment(llvm::Align(align));
#else
}
// create emit codes that checks and load the function.
BasicBlock* pre_block = builder_->GetInsertBlock();
- BasicBlock* init_block = BasicBlock::Create(
- *ctx_, "handle_init", function_);
- BasicBlock* end_block = BasicBlock::Create(
- *ctx_, "handle_init_end", function_);
+ BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_);
+ BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align));
#else
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
#endif
- llvm::Value* handle_not_null = builder_->CreateICmpNE(
- handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
- builder_->CreateCondBr(
- handle_not_null, end_block, init_block, md_very_likely_branch_);
+ llvm::Value* handle_not_null =
+ builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
+ builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_);
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
- llvm::Value* out = WithFunctionEntry([&]() {
- return builder_->CreateAlloca(t_tvm_func_handle_);
- });
+ llvm::Value* out =
+ WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); });
#if TVM_LLVM_VERSION >= 110
- llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
- gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment()));
+ llvm::LoadInst* ctx =
+ builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment()));
#else
- llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
- gv_mod_ctx_, gv_mod_ctx_->getAlignment());
+ llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment());
#endif
- ctx->setMetadata(
- "tbaa",
- md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
+ ctx->setMetadata("tbaa",
+ md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
#if TVM_LLVM_VERSION >= 90
- auto env_callee = llvm::FunctionCallee(
- ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv());
+ auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv());
#else
auto env_callee = RuntimeTVMGetFuncFromEnv();
#endif
- llvm::Value* retcode = builder_->CreateCall(
- env_callee, {ctx, GetConstString(fname), out});
+ llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align));
return phi;
}
-llvm::BasicBlock *
-CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
- llvm::Value **ret_tcode, const DataType &r_type,
- const int64_t begin, const int64_t end) {
+llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
+ llvm::Value** ret_tcode, const DataType& r_type,
+ const int64_t begin, const int64_t end) {
using llvm::BasicBlock;
std::string func_name = args[0].as<StringImmNode>()->value;
- llvm::Value *handle = GetPackedFuncHandle(func_name);
+ llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
int64_t nargs = end - begin;
CHECK_GE(nargs, 0);
- llvm::Value *stack_value = MakeValue(args[1]);
- llvm::Value *stack_tcode = MakeValue(args[2]);
- llvm::Value *arg_value = builder_->CreateInBoundsGEP(
- builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
- ConstInt32(begin));
- llvm::Value *arg_tcode =
- CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
- llvm::Value *ret_value = builder_->CreateInBoundsGEP(
- builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
- ConstInt32(end));
+ llvm::Value* stack_value = MakeValue(args[1]);
+ llvm::Value* stack_tcode = MakeValue(args[2]);
+ llvm::Value* arg_value = builder_->CreateInBoundsGEP(
+ builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin));
+ llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
+ llvm::Value* ret_value = builder_->CreateInBoundsGEP(
+ builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
*ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
#else
auto call_callee = RuntimeTVMFuncCall();
#endif
- BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall(
- call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs),
- ret_value, *ret_tcode}));
+ BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall(
+ call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode}));
DataType r_api_type = tir::APIType(r_type);
- llvm::Value* load_ptr = builder_->CreatePointerCast(
- ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
+ llvm::Value* load_ptr =
+ builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
#if TVM_LLVM_VERSION >= 110
*rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8));
#else
return end_block;
}
-llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) {
+llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) {
CHECK_EQ(op->args.size(), 5U);
- llvm::Value *rvalue = nullptr;
- llvm::Value *ret_tcode = nullptr;
- MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype,
- op->args[3].as<IntImmNode>()->value,
+ llvm::Value* rvalue = nullptr;
+ llvm::Value* ret_tcode = nullptr;
+ MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
return rvalue;
}
-llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) {
+llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) {
using llvm::BasicBlock;
CHECK_EQ(op->args.size(), 6U);
- llvm::Value *rvalue = nullptr;
- llvm::Value *ret_tcode = nullptr;
- BasicBlock *end_block = MakeCallPacked(
- op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
- op->args[4].as<IntImmNode>()->value);
+ llvm::Value* rvalue = nullptr;
+ llvm::Value* ret_tcode = nullptr;
+ BasicBlock* end_block =
+ MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
+ op->args[4].as<IntImmNode>()->value);
// Get traced value.
- llvm::Value *traced_value = MakeValue(op->args[5]);
+ llvm::Value* traced_value = MakeValue(op->args[5]);
// The update_block handles case when we need to update the return value.
- BasicBlock *update_block =
- BasicBlock::Create(*ctx_, "update_block", function_);
+ BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_);
// The continue_block handles case when we need to return original
// traced value.
- BasicBlock *continue_block =
- BasicBlock::Create(*ctx_, "continue_block", function_);
+ BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_);
#if TVM_LLVM_VERSION >= 110
- llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
+ llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
#else
- llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8);
+ llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8);
#endif
// Check the ret_type_code and create cmp instruction.
- llvm::Value *cmp = builder_->CreateICmpNE(
- ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr));
+ llvm::Value* cmp =
+ builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr));
builder_->CreateCondBr(cmp, update_block, continue_block);
builder_->SetInsertPoint(update_block);
builder_->CreateBr(continue_block);
builder_->SetInsertPoint(continue_block);
// The return value depends on from what bb we come from.
- llvm::PHINode *phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2);
+ llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2);
phi_rvalue->addIncoming(rvalue, update_block);
phi_rvalue->addIncoming(traced_value, end_block);
return phi_rvalue;
void CodeGenCPU::AddStartupFunction() {
if (export_system_symbols_.size() != 0) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
- function_ = llvm::Function::Create(
- ftype,
- llvm::Function::InternalLinkage,
- "__tvm_module_startup", module_.get());
+ function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
+ "__tvm_module_startup", module_.get());
llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(startup_entry);
for (const auto& kv : export_system_symbols_) {
llvm::Value* name = GetConstString(kv.first);
- builder_->CreateCall(
- f_tvm_register_system_symbol_, {
- name, builder_->CreateBitCast(kv.second, t_void_p_)});
+ builder_->CreateCall(f_tvm_register_system_symbol_,
+ {name, builder_->CreateBitCast(kv.second, t_void_p_)});
}
llvm::appendToGlobalCtors(*module_, function_, 65535);
builder_->CreateRet(nullptr);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImmNode>()->value;
- llvm::Value* ref = this->CreateStructRefPtr(
- op->dtype, MakeValue(op->args[0]),
- MakeValue(op->args[1]), kind);
+ llvm::Value* ref =
+ this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
if (kind == intrinsic::kArrAddr) {
return builder_->CreatePointerCast(ref, t_void_p_);
} else {
CHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* value = MakeValue(op->args[3]);
- llvm::Value* ref = this->CreateStructRefPtr(
- op->args[3].dtype(), MakeValue(op->args[0]),
- MakeValue(op->args[1]), kind);
+ llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
+ MakeValue(op->args[1]), kind);
CHECK(kind != intrinsic::kArrAddr);
if (value->getType()->isPointerTy()) {
- value = builder_->CreatePointerCast(
- value, ref->getType()->getPointerElementType());
+ value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType());
}
builder_->CreateStore(value, ref);
return ConstInt32(0);
CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImmNode>()->value;
return WithFunctionEntry([&]() -> llvm::AllocaInst* {
- const int64_t* pval = as_const_int(op->args[1]);
- CHECK(pval) << "require stack alloca to contain constant value";
- llvm::Value* num = ConstInt32(pval[0]);
- if (type == "shape") {
- return builder_->CreateAlloca(t_tvm_shape_index_, num);
- } else if (type == "arg_value") {
- return builder_->CreateAlloca(t_tvm_value_, num);
- } else if (type == "arg_tcode") {
- return builder_->CreateAlloca(t_int_, num);
- } else if (type == "array") {
- return builder_->CreateAlloca(t_tvm_array_, num);
- } else {
- LOG(FATAL) << "Unknown stack alloca type " << type;
- return nullptr;
- }
- });
+ const int64_t* pval = as_const_int(op->args[1]);
+ CHECK(pval) << "require stack alloca to contain constant value";
+ llvm::Value* num = ConstInt32(pval[0]);
+ if (type == "shape") {
+ return builder_->CreateAlloca(t_tvm_shape_index_, num);
+ } else if (type == "arg_value") {
+ return builder_->CreateAlloca(t_tvm_value_, num);
+ } else if (type == "arg_tcode") {
+ return builder_->CreateAlloca(t_int_, num);
+ } else if (type == "array") {
+ return builder_->CreateAlloca(t_tvm_array_, num);
+ } else {
+ LOG(FATAL) << "Unknown stack alloca type " << type;
+ return nullptr;
+ }
+ });
} else {
return CodeGenLLVM::CreateIntrinsic(op);
}
os << ", " << op->message.as<StringImmNode>()->value;
}
llvm::Value* msg = GetConstString(os.str());
- BasicBlock* fail_block = BasicBlock::Create(
- *ctx_, "assert_fail", function_);
- BasicBlock* end_block = BasicBlock::Create(
- *ctx_, "assert_end", function_);
+ BasicBlock* fail_block = BasicBlock::Create(*ctx_, "assert_fail", function_);
+ BasicBlock* end_block = BasicBlock::Create(*ctx_, "assert_end", function_);
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition.
builder_->SetInsertPoint(fail_block);
#if TVM_LLVM_VERSION >= 90
- auto err_callee = llvm::FunctionCallee(
- ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError());
+ auto err_callee =
+ llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError());
#else
auto err_callee = RuntimeTVMAPISetLastError();
#endif
void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::coproc_uop_scope) {
this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
- } else if (op->attr_key == tir::attr::compute_scope) {
+ } else if (op->attr_key == tir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (tir::attr::IsPragmaKey(op->attr_key)) {
if (op->attr_key == "pragma_parallel_stride_pattern") {
} else if (op->attr_key == "pragma_parallel_launch_point") {
CreateParallelLaunch(op->body, 0);
} else if (op->attr_key == "pragma_parallel_barrier_when_finish") {
- CHECK(parallel_env_.penv != nullptr)
- << "Cannot run barrier without parallel environment";
+ CHECK(parallel_env_.penv != nullptr) << "Cannot run barrier without parallel environment";
CHECK(!parallel_env_.in_parallel_loop)
<< "Cannot not place within parallel loop as the workload may differ, "
<< " place it between parallel and parallel_launch_point";
this->VisitStmt(op->body);
#if TVM_LLVM_VERSION >= 90
- auto bar_callee = llvm::FunctionCallee(
- ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier());
+ auto bar_callee =
+ llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier());
#else
auto bar_callee = RuntimeTVMParallelBarrier();
#endif
- builder_->CreateCall(
- bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv});
+ builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == tir::attr::pragma_import_llvm) {
const StringImmNode* value = op->value.as<StringImmNode>();
CHECK(value != nullptr);
void CodeGenCPU::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min));
- if (op->for_type == ForType::Serial ||
- op->for_type == ForType::Unrolled) {
+ if (op->for_type == ForType::Serial || op->for_type == ForType::Unrolled) {
CodeGenLLVM::VisitStmt_(op);
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
- ForNode::make(
- op->loop_var, op->min, op->extent,
- op->for_type, op->device_api, op->body), 0);
+ ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body),
+ 0);
} else {
// already in parallel env.
CHECK(parallel_env_.task_id.defined());
<< "Nested parallel loop is not supported by threadpool, try fuse them instead";
parallel_env_.in_parallel_loop = true;
if (parallel_env_.stride_pattern) {
- CreateSerialFor(MakeValue(task_id),
- MakeValue(op->extent),
- MakeValue(num_task),
- op->loop_var,
- op->body);
+ CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
+ op->loop_var, op->body);
} else {
PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
PrimExpr begin = MinNode::make(task_id * step, op->extent);
PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
- CreateSerialFor(MakeValue(begin),
- MakeValue(end),
- llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
- op->loop_var,
- op->body);
+ CreateSerialFor(MakeValue(begin), MakeValue(end),
+ llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
}
parallel_env_.in_parallel_loop = false;
++parallel_env_.parallel_loop_count;
#ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_
#define TVM_TARGET_LLVM_CODEGEN_CPU_H_
-#include <utility>
-#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
+#include <utility>
+#include <vector>
+
#include "codegen_llvm.h"
namespace tvm {
// CPU host code generation
class CodeGenCPU : public CodeGenLLVM {
public:
- void Init(const std::string& module_name,
- llvm::TargetMachine* tm,
- llvm::LLVMContext* ctx,
- bool system_lib,
- bool dynamic_lookup) override;
+ void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx,
+ bool system_lib, bool dynamic_lookup) override;
void AddFunction(const PrimFunc& f) override;
void AddMainFunction(const std::string& entry_func_name) override;
std::unique_ptr<llvm::Module> Finish() override;
llvm::Value* RuntimeTVMParallelBarrier();
llvm::Value* CreateStaticHandle();
llvm::Value* GetPackedFuncHandle(const std::string& str);
- llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t *num_bytes);
+ llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t* num_bytes);
llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
- void UnpackClosureData(llvm::Value*cdata,
- const Array<Var>& fields,
+ void UnpackClosureData(llvm::Value* cdata, const Array<Var>& fields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call.
- llvm::BasicBlock *MakeCallPacked(const Array<PrimExpr> &args,
- llvm::Value **rvalue,
- llvm::Value **ret_tcode, const DataType &r_type,
+ llvm::BasicBlock* MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
+ llvm::Value** ret_tcode, const DataType& r_type,
const int64_t begin, const int64_t end);
// create call into tvm packed function.
llvm::Value* CreateCallPacked(const CallNode* op);
// Create trace call into tvm packed function.
- llvm::Value* CreateCallTracePacked(const CallNode *op);
+ llvm::Value* CreateCallTracePacked(const CallNode* op);
// Create static initialization
void CreateStaticInit(const std::string& init_fname, const Stmt& body);
// Create parallel launch
*/
#ifdef TVM_LLVM_VERSION
// Part of the code are adapted from Halide's CodeGen_LLVM
-#include <tvm/runtime/device_api.h>
+#include "codegen_llvm.h"
+
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/tir/op.h>
#include <algorithm>
-#include "codegen_llvm.h"
-#include "codegen_cpu.h"
#include "../../arith/pattern_match.h"
#include "../build_common.h"
+#include "codegen_cpu.h"
namespace tvm {
namespace codegen {
-std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
+std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine* tm) {
std::string target = tm->getTarget().getName();
std::string factory_name = "tvm.codegen.llvm.target_" + target;
const PackedFunc* f = runtime::Registry::Get(factory_name);
}
}
-void CodeGenLLVM::Init(const std::string& module_name,
- llvm::TargetMachine* tm,
- llvm::LLVMContext* ctx,
- bool system_lib,
- bool dynamic_lookup) {
+void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
+ llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) {
InitializeLLVM();
ctx_ = ctx;
builder_.reset(new IRBuilder(*ctx_));
t_int64_ = llvm::Type::getInt64Ty(*ctx_);
t_float64_ = llvm::Type::getDoubleTy(*ctx_);
// meta data
- md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1);
+ md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
this->InitTarget(tm);
}
}
-void CodeGenLLVM::AddFunction(const PrimFunc& f) {
- this->AddFunctionInternal(f, false);
-}
+void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }
void CodeGenLLVM::InitFuncState() {
var_map_.clear();
analyzer_.reset(new arith::Analyzer());
}
-
void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
this->InitFuncState();
// TODO(tvm-team):
// Update the function type to respect the ret_type field of f.
// Once we allow more flexibility in the PrimFunc.
- llvm::FunctionType* ftype = llvm::FunctionType::get(
- ret_void ? t_void_ : t_int_, param_types, false);
+ llvm::FunctionType* ftype =
+ llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
CHECK(module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr)
<< "Function " << global_symbol << " already exist in module";
- function_ = llvm::Function::Create(
- ftype, llvm::Function::ExternalLinkage,
- global_symbol.value().operator std::string(), module_.get());
+ function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+ global_symbol.value().operator std::string(), module_.get());
function_->setCallingConv(llvm::CallingConv::C);
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
}
}
-
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
for (size_t i = 0; i < link_modules_.size(); ++i) {
return std::move(module_);
}
-
void CodeGenLLVM::HandleImport(const std::string& code) {
std::unique_ptr<llvm::Module> mlib;
llvm::SMDiagnostic err;
if (code.length() >= 3 &&
- (code.substr(code.length() - 3) == ".ll" ||
- code.substr(code.length() - 3) == ".bc")) {
+ (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) {
mlib = llvm::parseIRFile(code, err, *ctx_);
if (mlib.get() == nullptr) {
std::string msg = std::string(err.getMessage());
<< "line " << err.getLineNo() << ":" << msg;
}
} else {
- std::unique_ptr<llvm::MemoryBuffer> buf =
- llvm::MemoryBuffer::getMemBuffer(code);
+ std::unique_ptr<llvm::MemoryBuffer> buf = llvm::MemoryBuffer::getMemBuffer(code);
mlib = llvm::parseIR(*buf, err, *ctx_);
if (mlib.get() == nullptr) {
std::string msg = std::string(err.getMessage());
LOG(FATAL) << "Fail to load llvm ir "
- << "line " << err.getLineNo() << ":" << msg
- << "\ncontent:\n" << code;
+ << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n"
+ << code;
}
}
mlib->setTargetTriple(target_machine_->getTargetTriple().str());
mlib->setDataLayout(target_machine_->createDataLayout());
// mark all the functions as force inline
- for (llvm::Function &f : mlib->functions()) {
+ for (llvm::Function& f : mlib->functions()) {
f.removeFnAttr(llvm::Attribute::NoInline);
f.addFnAttr(llvm::Attribute::AlwaysInline);
f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
class FPassManager : public llvm::legacy::FunctionPassManager {
public:
- explicit FPassManager(llvm::Module* m)
- : llvm::legacy::FunctionPassManager(m) {}
+ explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {}
// override add to allow messaging
- void add(llvm::Pass* p) final {
- llvm::legacy::FunctionPassManager::add(p);
- }
+ void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); }
};
class MPassManager : public llvm::legacy::PassManager {
public:
// override add to allow messaging
- void add(llvm::Pass* p) final {
- llvm::legacy::PassManager::add(p);
- }
+ void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); }
};
-void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {
-}
+void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {}
void CodeGenLLVM::Optimize() {
// pass manager
FPassManager fpass(module_.get());
MPassManager mpass;
mpass.add(llvm::createTargetTransformInfoWrapperPass(
- target_machine_ ? target_machine_->getTargetIRAnalysis() :
- llvm::TargetIRAnalysis()));
+ target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
fpass.add(llvm::createTargetTransformInfoWrapperPass(
- target_machine_ ? target_machine_->getTargetIRAnalysis() :
- llvm::TargetIRAnalysis()));
+ target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
// place optimization pass
llvm::PassManagerBuilder builder;
return native_vector_bits_;
}
-unsigned CodeGenLLVM::GetGlobalAddressSpace() const {
- return 0;
-}
+unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; }
llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
if (dtype.is_handle()) {
etype = llvm::Type::getIntNTy(*ctx_, dtype.bits());
} else if (dtype.is_float()) {
switch (dtype.bits()) {
- case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
- case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
- case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
- default: LOG(FATAL) << "do not support " << dtype;
+ case 16:
+ etype = llvm::Type::getHalfTy(*ctx_);
+ break;
+ case 32:
+ etype = llvm::Type::getFloatTy(*ctx_);
+ break;
+ case 64:
+ etype = llvm::Type::getDoubleTy(*ctx_);
+ break;
+ default:
+ LOG(FATAL) << "do not support " << dtype;
}
}
if (dtype.lanes() != 1) {
//
// This trick comes from Halide's CodeGen_LLVM
//
-void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
- const VarNode* buffer,
- PrimExpr index,
+void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index,
DataType type) {
if (alias_var_set_.count(buffer) != 0) {
// Mark all possibly aliased pointer as same type.
llvm::MDNode* meta = md_tbaa_alias_set_;
- inst->setMetadata(
- "tbaa",
- md_builder_->createTBAAStructTagNode(meta, meta, 0));
+ inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
return;
}
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
}
}
- inst->setMetadata(
- "tbaa",
- md_builder_->createTBAAStructTagNode(meta, meta, 0));
+ inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
-void CodeGenLLVM::GetAlignment(DataType t,
- const VarNode* buf_var,
- const PrimExpr& index,
- int* p_alignment,
- int* p_native_bits) {
+void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index,
+ int* p_alignment, int* p_native_bits) {
int max_align_bits = t.bits();
auto it = alloc_storage_info_.find(buf_var);
if (it != alloc_storage_info_.end()) {
int64_t coeff = me->coeff;
int align_bits = t.bits();
- while (align_bits < max_align_bits &&
- base % 2 == 0 &&
- coeff % 2 == 0) {
- base = base / 2;
- coeff = coeff / 2;
+ while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) {
+ base = base / 2;
+ coeff = coeff / 2;
align_bits *= 2;
}
if (align_bits < 8) {
*p_alignment = align_bits / 8;
}
-std::unique_ptr<CodeGenLLVM::DebugInfo>
-CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
+std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
#if TVM_LLVM_VERSION >= 100
auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
}
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
- llvm::Constant* undef = llvm::UndefValue::get(
- llvm::VectorType::get(value->getType(), lanes));
+ llvm::Constant* undef = llvm::UndefValue::get(llvm::VectorType::get(value->getType(), lanes));
llvm::Constant* zero = ConstInt32(0);
value = builder_->CreateInsertElement(undef, value, zero);
#if TVM_LLVM_VERSION >= 110
}
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
- llvm::Value* mask = llvm::UndefValue::get(
- DTypeToLLVMType(DataType::Int(32, target_lanes)));
+ llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
return CreateVecSlice(vecs[0], 0, total_lanes);
}
-
-void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
- llvm::Value* end,
- llvm::Value* stride,
- const Var& loop_var,
- const Stmt& body) {
+void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
+ const Var& loop_var, const Stmt& body) {
using llvm::BasicBlock;
BasicBlock* pre_block = builder_->GetInsertBlock();
- BasicBlock* for_begin = BasicBlock::Create(
- *ctx_, "for_begin", function_);
- BasicBlock* for_body = BasicBlock::Create(
- *ctx_, "for_body", function_);
- BasicBlock* for_end = BasicBlock::Create(
- *ctx_, "for_end", function_);
+ BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_);
+ BasicBlock* for_body = BasicBlock::Create(*ctx_, "for_body", function_);
+ BasicBlock* for_end = BasicBlock::Create(*ctx_, "for_end", function_);
builder_->CreateBr(for_begin);
builder_->SetInsertPoint(for_begin);
llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
loop_value->addIncoming(begin, pre_block);
CHECK(!var_map_.count(loop_var.get()));
var_map_[loop_var.get()] = loop_value;
- builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end),
- for_body, for_end, md_very_likely_branch_);
+ builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end,
+ md_very_likely_branch_);
builder_->SetInsertPoint(for_body);
this->VisitStmt(body);
var_map_.erase(loop_var.get());
// cast operatpr
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
- llvm::Type * target = DTypeToLLVMType(to);
+ llvm::Type* target = DTypeToLLVMType(to);
if (value->getType() == target) return value;
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
auto it = str_map_.find(str);
if (it != str_map_.end()) return it->second;
llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
- llvm::GlobalVariable *global = new llvm::GlobalVariable(
- *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
+ llvm::GlobalVariable* global =
+ new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(1));
#else
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
- llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(
- type, global, indices);
+ llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices);
str_map_[str] = ptr;
return ptr;
}
-llvm::Value* CodeGenLLVM::CreateBufferPtr(
- DataType t, llvm::Value* buffer, llvm::Value* index) {
+llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) {
CHECK_EQ(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
return builder_->CreateInBoundsGEP(buffer, index);
}
-llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
- DataType t, llvm::Value* buffer, llvm::Value* index) {
+llvm::Value* CodeGenLLVM::CreateBufferVecPtr(DataType t, llvm::Value* buffer, llvm::Value* index) {
CHECK_GT(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
- llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(
- btype->getAddressSpace());
+ llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace());
if (btype != ptype) {
buffer = builder_->CreatePointerCast(buffer, ptype);
}
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
- llvm::FunctionType* ftype = llvm::FunctionType::get(
- GetLLVMType(GetRef<PrimExpr>(op)), arg_type, false);
+ llvm::FunctionType* ftype =
+ llvm::FunctionType::get(GetLLVMType(GetRef<PrimExpr>(op)), arg_type, false);
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
- f = llvm::Function::Create(
- ftype, llvm::Function::ExternalLinkage,
- op->name, module_.get());
+ f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
}
llvm::CallInst* call = builder_->CreateCall(f, arg_value);
return call;
}
-llvm::Function* CodeGenLLVM::GetIntrinsicDecl(
- llvm::Intrinsic::ID id, llvm::Type* ret_type,
- llvm::ArrayRef<llvm::Type*> arg_types) {
+llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
+ llvm::ArrayRef<llvm::Type*> arg_types) {
llvm::Module* module = module_.get();
if (!llvm::Intrinsic::isOverloaded(id)) {
auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
overload_types.clear();
llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
- auto match =
- llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
+ auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
if (error) {
// Failed to identify the type.
return nullptr;
-#else // TVM_LLVM_VERSION
+#else // TVM_LLVM_VERSION
llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
// matchIntrinsicType returns true on error.
if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
- llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
- Downcast<IntImm>(op->args[0])->value);
- int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
+ llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
+ int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
for (size_t i = 2; i < op->args.size(); ++i) {
// mismatch will have to be treated specially here.
// TODO(kparzysz-quic): fix this once TVM prefetch uses the same
// type as LLVM.
- llvm::Type *return_type = (id != llvm::Intrinsic::prefetch)
- ? GetLLVMType(GetRef<PrimExpr>(op))
- : llvm::Type::getVoidTy(*ctx_);
+ llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op))
+ : llvm::Type::getVoidTy(*ctx_);
llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const LoadNode *l = op->args[0].as<LoadNode>();
+ const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
- const RampNode *r = l->index.as<RampNode>();
+ const RampNode* r = l->index.as<RampNode>();
llvm::Value* ptr;
unsigned addrspace;
if (!r) {
- ptr = CreateBufferPtr(
- l->dtype, MakeValue(l->buffer_var), MakeValue(l->index));
- addrspace = llvm::dyn_cast<llvm::PointerType>(
- ptr->getType())->getAddressSpace();
+ ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index));
+ addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace();
} else {
- PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes);
- ptr = CreateBufferVecPtr(
- l->dtype, MakeValue(l->buffer_var), MakeValue(index));
- addrspace = llvm::dyn_cast<llvm::PointerType>(
- ptr->getType())->getAddressSpace();
+ PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes);
+ ptr = CreateBufferVecPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index));
+ addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace();
}
return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
} else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) {
uint64_t val = (high << 32U) | low;
return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
- CHECK_EQ(op->args[0].dtype().lanes(), 1)
- << "if_then_else can only take scalar condition";
+ CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition";
using llvm::BasicBlock;
- BasicBlock* then_block = BasicBlock::Create(
- *ctx_, "if_then", function_);
- BasicBlock* else_block = BasicBlock::Create(
- *ctx_, "if_else", function_);
- BasicBlock* end_block = BasicBlock::Create(
- *ctx_, "if_end", function_);
+ BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_);
+ BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_);
+ BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_);
builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
builder_->SetInsertPoint(then_block);
llvm::Value* then_value = MakeValue(op->args[1]);
value->addIncoming(else_value, else_value_block);
return value;
} else if (op->is_intrinsic(CallNode::reinterpret)) {
- llvm::Type * target = DTypeToLLVMType(op->dtype);
+ llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic(CallNode::isnan)) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
return builder_->CreateFCmpUNO(a, a);
} else if (op->is_intrinsic("vectorlow")) {
- llvm::Value *v = MakeValue(op->args[0]);
+ llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
- return CreateVecSlice(v, 0, l/2);
+ return CreateVecSlice(v, 0, l / 2);
} else if (op->is_intrinsic("vectorhigh")) {
- llvm::Value *v = MakeValue(op->args[0]);
+ llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
- return CreateVecSlice(v, l/2, l/2);
+ return CreateVecSlice(v, l / 2, l / 2);
} else if (op->is_intrinsic("vectorcombine")) {
- llvm::Value *v0 = MakeValue(op->args[0]);
- llvm::Value *v1 = MakeValue(op->args[1]);
+ llvm::Value* v0 = MakeValue(op->args[0]);
+ llvm::Value* v1 = MakeValue(op->args[1]);
int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
}
}
-void CodeGenLLVM::Scalarize(const PrimExpr& e,
- std::function<void(int i, llvm::Value* v)> f) {
+void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f) {
if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
PrimExpr offset = ramp->base + (ramp->stride * i);
}
}
-
// Visitors
-llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) {
- return GetVarValue(op);
-}
+llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }
llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
- return GetConstString(op->value);
-}
-
-#define DEFINE_CODEGEN_BINARY_OP(Op) \
- llvm::Value* CodeGenLLVM::Create ## Op( \
- DataType t, llvm::Value* a, llvm::Value *b) { \
- if (t.is_int()) { \
- if (t.bits() >= 32) { \
- return builder_->CreateNSW ## Op (a, b); \
- } else { \
- return builder_->Create ## Op (a, b); \
- } \
- } else if (t.is_uint()) { \
- if (t.bits() >= 32) { \
- return builder_->CreateNUW ## Op (a, b); \
- } else { \
- return builder_->Create ## Op (a, b); \
- } \
- } else { \
- CHECK(t.is_float()); \
- return builder_->CreateF ## Op (a, b); \
- } \
- } \
- llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \
- return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
+llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }
+
+#define DEFINE_CODEGEN_BINARY_OP(Op) \
+ llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
+ if (t.is_int()) { \
+ if (t.bits() >= 32) { \
+ return builder_->CreateNSW##Op(a, b); \
+ } else { \
+ return builder_->Create##Op(a, b); \
+ } \
+ } else if (t.is_uint()) { \
+ if (t.bits() >= 32) { \
+ return builder_->CreateNUW##Op(a, b); \
+ } else { \
+ return builder_->Create##Op(a, b); \
+ } \
+ } else { \
+ CHECK(t.is_float()); \
+ return builder_->CreateF##Op(a, b); \
+ } \
+ } \
+ llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
+ return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
-#define DEFINE_CODEGEN_CMP_OP(Op) \
- llvm::Value* CodeGenLLVM::Create ## Op( \
- DataType t, llvm::Value* a, llvm::Value* b) { \
- if (t.is_int()) { \
- return builder_->CreateICmpS ## Op (a, b); \
- } else if (t.is_uint()) { \
- return builder_->CreateICmpU ## Op (a, b); \
- } else { \
- CHECK(t.is_float()); \
- return builder_->CreateFCmpO ## Op (a, b); \
- } \
-} \
- llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \
- return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
+#define DEFINE_CODEGEN_CMP_OP(Op) \
+ llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
+ if (t.is_int()) { \
+ return builder_->CreateICmpS##Op(a, b); \
+ } else if (t.is_uint()) { \
+ return builder_->CreateICmpU##Op(a, b); \
+ } else { \
+ CHECK(t.is_float()); \
+ return builder_->CreateFCmpO##Op(a, b); \
+ } \
+ } \
+ llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
+ return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_CMP_OP(LT);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
- return builder_->CreateSelect(
- MakeValue(op->condition),
- MakeValue(op->true_value),
- MakeValue(op->false_value));
+ return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
+ MakeValue(op->false_value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
#if TVM_LLVM_VERSION >= 110
- llvm::LoadInst* load =
- builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
+ llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
return load;
} else {
// vector load
- unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
- buffer->getType())->getAddressSpace();
+ unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
CHECK_EQ(ramp->lanes, t.lanes());
- llvm::Value* ptr = CreateBufferPtr(
- t.element_of(), buffer, MakeValue(ramp->base));
- ptr = builder_->CreatePointerCast(
- ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
+ llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
+ ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
- llvm::LoadInst* load = builder_->CreateAlignedLoad(
- ptr, llvm::Align(alignment), is_volatile);
+ llvm::LoadInst* load =
+ builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
auto f = [&](int i, llvm::Value* index) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
- llvm::LoadInst* load = builder_->CreateAlignedLoad(
- ptr, llvm::Align(basic_align), is_volatile);
+ llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile);
#else
- llvm::LoadInst* load = builder_->CreateAlignedLoad(
- ptr, basic_align, is_volatile);
+ llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile);
#endif
ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
- if (op->call_type == CallNode::Intrinsic ||
- op->call_type == CallNode::PureIntrinsic) {
+ if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
return CreateIntrinsic(op);
- } else if (op->call_type == CallNode::Extern ||
- op->call_type == CallNode::PureExtern) {
+ } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
return CreateCallExtern(op);
} else {
- LOG(FATAL) << "Unknown call type " <<
- "name= " << op->name <<
- " call_type= " << op->call_type;
+ LOG(FATAL) << "Unknown call type "
+ << "name= " << op->name << " call_type= " << op->call_type;
return nullptr;
}
}
llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
for (int i = 0; i < op->lanes; ++i) {
vec = builder_->CreateInsertElement(
- vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)),
- ConstInt32(i));
+ vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i));
}
return vec;
}
llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
- std::vector<llvm::Value *> vecs(op->vectors.size());
+ std::vector<llvm::Value*> vecs(op->vectors.size());
int total_lanes = 0;
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
vecs[i] = VisitExpr(op->vectors[i]);
llvm::Value* v0 = CreateVecConcat(vecs);
std::vector<uint32_t> idx(op->indices.size());
for (int i = 0, e = op->indices.size(); i < e; ++i) {
- const int64_t *val = as_const_int(op->indices[i]);
- CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
- << "but get " << op->indices[i] << "\n";
+ const int64_t* val = as_const_int(op->indices[i]);
+ CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
+ << "but get " << op->indices[i] << "\n";
idx[i] = *val;
}
llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
return;
} else {
// vector store
- unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
- buffer->getType())->getAddressSpace();
+ unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
CHECK_EQ(ramp->lanes, t.lanes());
- llvm::Value* ptr = CreateBufferPtr(
- t.element_of(), buffer, MakeValue(ramp->base));
+ llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store =
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store = builder_->CreateAlignedStore(
- builder_->CreateExtractElement(value, i),
- ptr, llvm::Align(basic_align), is_volatile);
+ builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile);
#else
- llvm::StoreInst* store = builder_->CreateAlignedStore(
- builder_->CreateExtractElement(value, i),
- ptr, basic_align, is_volatile);
+ llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i),
+ ptr, basic_align, is_volatile);
#endif
AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype());
};
CHECK(op->for_type == ForType::Serial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
- llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
- op->loop_var, op->body);
+ llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
}
-
void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
- BasicBlock* then_block = BasicBlock::Create(
- *ctx_, "if_then", function_);
- BasicBlock* end_block = BasicBlock::Create(
- *ctx_, "if_end", function_);
+ BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_);
+ BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_);
if (op->else_case.defined()) {
- BasicBlock* else_block = BasicBlock::Create(
- *ctx_, "if_else", function_);
+ BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_);
builder_->CreateCondBr(cond, then_block, else_block);
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->SetInsertPoint(end_block);
}
-
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
- int32_t constant_size = op->constant_allocation_size();
- CHECK_GT(constant_size, 0)
- << "Can only handle constant size stack allocation";
- StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
- if (constant_size % 4 == 0 && info.alignment == 0) {
- info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
- }
- // maximum necessary alignment in the NV devices
- if (info.alignment > 16) {
- info.alignment = 16;
- }
- llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
- return builder_->CreateAlloca(
- DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
- });
- if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
+ int32_t constant_size = op->constant_allocation_size();
+ CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation";
+ StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
+ if (constant_size % 4 == 0 && info.alignment == 0) {
+ info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
+ }
+ // maximum necessary alignment in the NV devices
+ if (info.alignment > 16) {
+ info.alignment = 16;
+ }
+ llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
+ return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
+ });
+ if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
- alloca->setAlignment(llvm::Align(info.alignment));
+ alloca->setAlignment(llvm::Align(info.alignment));
#else
- alloca->setAlignment(info.alignment);
+ alloca->setAlignment(info.alignment);
#endif
- }
- info.alignment = alloca->getAlignment();
- buf = alloca;
+ }
+ info.alignment = alloca->getAlignment();
+ buf = alloca;
buf = builder_->CreatePointerCast(
- buf, DTypeToLLVMType(op->dtype)->getPointerTo(
- buf->getType()->getPointerAddressSpace()));
+ buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
} else if (op->attr_key == tir::attr::storage_alignment) {
const VarNode* v = op->node.as<VarNode>();
CHECK(v);
- alloc_storage_info_[v].alignment =
- static_cast<int>(op->value.as<IntImmNode>()->value);
+ alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
} else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>();
CHECK(v);
}
}
-void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
- MakeValue(op->value);
-}
+void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }
} // namespace codegen
} // namespace tvm
#define TVM_TARGET_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
-#include <tvm/arith/analyzer.h>
+#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/op.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/target/codegen.h>
-
#include <memory>
-#include <utility>
-#include <vector>
#include <string>
#include <unordered_map>
#include <unordered_set>
-#include "llvm_common.h"
-#include "../../runtime/thread_storage_scope.h"
+#include <utility>
+#include <vector>
+
#include "../../arith/compute_expr.h"
+#include "../../runtime/thread_storage_scope.h"
#include "../../tir/transforms/ir_util.h"
+#include "llvm_common.h"
namespace tvm {
namespace codegen {
/*!
* \brief A base class to generate a LLVM.
*/
-class CodeGenLLVM :
- public ExprFunctor<llvm::Value* (const PrimExpr&)>,
- public StmtFunctor<void(const Stmt&)> {
+class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
+ public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Create new code generator based on target machine.
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/
- virtual void Init(const std::string& module_name,
- llvm::TargetMachine* tm,
- llvm::LLVMContext* ctx,
- bool system_lib,
- bool dynamic_lookup);
+ virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx,
+ bool system_lib, bool dynamic_lookup);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
* \param e The expression to be created value for.
* \return created value.
*/
- llvm::Value* MakeValue(const PrimExpr& e) {
- return VisitExpr(e);
- }
+ llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); }
// Short hande code to get a constant int 32
llvm::Constant* ConstInt32(int64_t value) const {
return llvm::ConstantInt::getSigned(t_int32_, value);
* \tparam F The function to be executed.
* \return The result.
*/
- template<typename F>
+ template <typename F>
llvm::AllocaInst* WithFunctionEntry(F falloca) {
llvm::BasicBlock* current = builder_->GetInsertBlock();
llvm::BasicBlock* entry = &(function_->getEntryBlock());
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
// Scalarize by iterating elements of e.
// f is a callback that takes index and v.
- virtual void Scalarize(const PrimExpr& e,
- std::function<void(int i, llvm::Value* v)> f);
+ virtual void Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f);
// Initialize target
virtual void InitTarget(llvm::TargetMachine* tm);
// Add module startup function if needed.
virtual unsigned GetGlobalAddressSpace() const;
void AddFunctionInternal(const PrimFunc& f, bool ret_void);
// Create extern call
- llvm::CallInst* CreateCallExtern(llvm::Type* ret,
- const std::string& name,
+ llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name,
const std::vector<llvm::Value*>& value);
/*!
* \brief Get the LLVM Type for a given runtime type.
* could not be generated (e.g. if the argument/return types do not
* match).
*/
- llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id,
- llvm::Type* ret_type,
+ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types);
// initialize the function state.
void InitFuncState();
// Get alignment given index.
- void GetAlignment(
- DataType t, const VarNode* buf_var, const PrimExpr& index,
- int* p_alignment, int* p_native_bits);
+ void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment,
+ int* p_native_bits);
// Get constant string
llvm::Value* GetConstString(const std::string& str);
// do a scalarize call with f
- llvm::Value* CreateScalarizedCall(
- const CallNode* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
+ llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f,
+ const std::vector<llvm::Value*>& args);
// handle module import
void HandleImport(const std::string& code);
// cast operatpr
llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
// Create serial for
- void CreateSerialFor(llvm::Value* begin,
- llvm::Value* end,
- llvm::Value* stride,
+ void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type);
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/device_api.h>
-#include "codegen_llvm.h"
-#include "../build_common.h"
+
#include "../../runtime/cuda/cuda_module.h"
+#include "../build_common.h"
+#include "codegen_llvm.h"
namespace tvm {
namespace codegen {
CodeGenLLVM::AddFunctionInternal(f, true);
// annotate as kernel function
module_->getOrInsertNamedMetadata("nvvm.annotations")
- ->addOperand(llvm::MDNode::get(*ctx_, {
- llvm::ValueAsMetadata::get(function_),
- llvm::MDString::get(*ctx_, "kernel"),
- llvm::ValueAsMetadata::get(ConstInt32(1)) }));
+ ->addOperand(llvm::MDNode::get(
+ *ctx_, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"),
+ llvm::ValueAsMetadata::get(ConstInt32(1))}));
}
void VisitStmt_(const AllocateNode* op) final {
llvm::Value* buf = nullptr;
int32_t constant_size = op->constant_allocation_size();
- CHECK_GT(constant_size, 0)
- << "Can only handle constant size stack allocation in GPU";
+ CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
- return builder_->CreateAlloca(
- DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
- });
+ return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
+ });
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
alloca->setAlignment(llvm::Align(info.alignment));
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
- llvm::Type* type = llvm::ArrayType::get(
- DTypeToLLVMType(op->dtype), constant_size);
+ llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
- llvm::GlobalVariable *global = new llvm::GlobalVariable(
- *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
- nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
+ llvm::GlobalVariable* global = new llvm::GlobalVariable(
+ *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr,
+ llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(info.alignment));
#else
}
buf = builder_->CreatePointerCast(
- buf, DTypeToLLVMType(op->dtype)->getPointerTo(
- buf->getType()->getPointerAddressSpace()));
+ buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
if (ts.rank == 1) {
switch (ts.dim_index) {
- case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break;
- case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break;
- case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break;
- default: LOG(FATAL) << "unknown thread idx";
+ case 0:
+ intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
+ break;
+ case 1:
+ intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y;
+ break;
+ case 2:
+ intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z;
+ break;
+ default:
+ LOG(FATAL) << "unknown thread idx";
}
} else {
CHECK_EQ(ts.rank, 0);
switch (ts.dim_index) {
- case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break;
- case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break;
- case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break;
- default: LOG(FATAL) << "unknown thread idx";
+ case 0:
+ intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x;
+ break;
+ case 1:
+ intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y;
+ break;
+ case 2:
+ intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z;
+ break;
+ default:
+ LOG(FATAL) << "unknown thread idx";
}
}
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
// TODO(tqchen) warp sync in CUDA9
return nullptr;
} else if (sync == "shared") {
- llvm::Function* f = llvm::Intrinsic::getDeclaration(
- module_.get(),
- ::llvm::Intrinsic::nvvm_barrier0);
+ llvm::Function* f =
+ llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::nvvm_barrier0);
return builder_->CreateCall(f, {});
} else {
LOG(FATAL) << "Do not support sync " << sync;
tvm_ctx.device_type = kDLGPU;
tvm_ctx.device_id = 0;
TVMRetValue val;
- tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
- tvm_ctx, tvm::runtime::kExist, &val);
+ tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) {
- tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
- tvm_ctx, tvm::runtime::kComputeVersion, &val);
+ tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
std::string version = val;
std::istringstream is(version);
double ver;
runtime::Module BuildNVPTX(IRModule mod, std::string target) {
InitializeLLVM();
- CHECK(target.length() >= 5 &&
- target.substr(0, 5) == "nvptx");
+ CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx");
int compute_ver = DetectCUDAComputeVersion();
std::ostringstream config;
- config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
- << compute_ver
+ config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
<< target.substr(5, target.length() - 5);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "Can only lower IR Module with PrimFuncs";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f);
}
- const auto* flibdevice_path =
- tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
+ const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
if (flibdevice_path != nullptr) {
std::string path = (*flibdevice_path)(compute_ver);
if (path.length() != 0) {
// emit ptx
llvm::legacy::PassManager pass;
#if TVM_LLVM_VERSION <= 60
- CHECK(tm->addPassesToEmitFile(
- pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
+ CHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
#elif TVM_LLVM_VERSION <= 90
- CHECK(tm->addPassesToEmitFile(
- pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
+ CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) ==
+ 0)
<< "Cannot emit target CGFT_ObjectFile";
#else
- CHECK(tm->addPassesToEmitFile(
- pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0)
+ CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
#endif
pass.run(*module);
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll);
}
-TVM_REGISTER_GLOBAL("target.build.nvptx")
-.set_body_typed(BuildNVPTX);
+TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX);
} // namespace codegen
} // namespace tvm
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/registry.h>
-#include "codegen_cpu.h"
+#include "codegen_cpu.h"
#include "llvm/MC/MCSubtargetInfo.h"
namespace tvm {
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
DTypeToLLVMType(DataType::Float(32, from.lanes())),
{
- MakeValue(tir::CallNode::make(
- DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
- tir::CallNode::PureIntrinsic)),
- MakeValue(
- tir::BroadcastNode::make(
- FloatImm(DataType::Float(32), 0), from.lanes())),
- /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
- /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
+ MakeValue(tir::CallNode::make(DataType::Int(16, from.lanes()),
+ tir::CallNode::reinterpret, {op->value},
+ tir::CallNode::PureIntrinsic)),
+ MakeValue(tir::BroadcastNode::make(FloatImm(DataType::Float(32), 0), from.lanes())),
+ /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
+ /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
});
}
const auto has_f16c = TargetHasFeature(*target_machine_, "f16c");
if (from.lanes() >= 8 && has_f16c) {
- return CallVectorIntrin(
- ::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
- DTypeToLLVMType(DataType::Float(32, from.lanes())),
- {MakeValue(tir::CallNode::make(
- DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
- tir::CallNode::PureIntrinsic))});
+ return CallVectorIntrin(::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
+ DTypeToLLVMType(DataType::Float(32, from.lanes())),
+ {MakeValue(tir::CallNode::make(
+ DataType::Int(16, from.lanes()), tir::CallNode::reinterpret,
+ {op->value}, tir::CallNode::PureIntrinsic))});
}
#endif
}
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
-.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
- CodeGenLLVM* cg = new CodeGenX86_64();
- *rv = static_cast<void*>(cg);
- });
+ .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+ CodeGenLLVM* cg = new CodeGenX86_64();
+ *rv = static_cast<void*>(cg);
+ });
} // namespace codegen
} // namespace tvm
*/
#ifdef TVM_LLVM_VERSION
-#include <tvm/tir/op.h>
#include "intrin_rule_llvm.h"
+#include <tvm/tir/op.h>
+
namespace tvm {
namespace codegen {
namespace llvm {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
-.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);
+ .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
-.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
- using tir::make_const;
- using tir::make_zero;
- PrimExpr e = targs[0];
- const tir::CallNode* call = e.as<tir::CallNode>();
- CHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
- PrimExpr ret = tir::CallNode::make(
- x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic);
- *rv = ret;
-});
+ .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+ using tir::make_const;
+ using tir::make_zero;
+ PrimExpr e = targs[0];
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ CHECK(call != nullptr);
+ const PrimExpr& x = call->args[0];
+ PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
+ PrimExpr ret =
+ tir::CallNode::make(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic);
+ *rv = ret;
+ });
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
-.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
- using tir::make_const;
- using tir::make_zero;
- PrimExpr e = targs[0];
- const tir::CallNode* call = e.as<tir::CallNode>();
- CHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr one = make_const(x.dtype(), 1);
- PrimExpr two = make_const(x.dtype(), 2);
- PrimExpr neg_two = make_const(x.dtype(), -2);
-
- PrimExpr exp_neg2x = tir::CallNode::make(
- x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic);
- PrimExpr exp_pos2x = tir::CallNode::make(
- x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic);
-
- PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
- PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
- *rv = tir::SelectNode::make(
- x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
-});
+ .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+ using tir::make_const;
+ using tir::make_zero;
+ PrimExpr e = targs[0];
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ CHECK(call != nullptr);
+ const PrimExpr& x = call->args[0];
+ PrimExpr one = make_const(x.dtype(), 1);
+ PrimExpr two = make_const(x.dtype(), 2);
+ PrimExpr neg_two = make_const(x.dtype(), -2);
+
+ PrimExpr exp_neg2x =
+ tir::CallNode::make(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic);
+ PrimExpr exp_pos2x =
+ tir::CallNode::make(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic);
+
+ PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
+ PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
+ *rv = tir::SelectNode::make(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
+ });
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
-.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
- PrimExpr sin_x = tir::CallNode::make(
- x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
- PrimExpr cos_x = tir::CallNode::make(
- x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
+ PrimExpr sin_x = tir::CallNode::make(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
+ PrimExpr cos_x = tir::CallNode::make(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
PrimExpr tan_x = sin_x / cos_x;
*rv = tan_x;
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh")
-.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
- using tir::make_const;
- using tir::make_zero;
- PrimExpr e = targs[0];
- const tir::CallNode* call = e.as<tir::CallNode>();
- CHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr two = make_const(x.dtype(), 2);
- PrimExpr neg_one = make_const(x.dtype(), -1);
- PrimExpr exp_negx = tir::CallNode::make(
- x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
- PrimExpr exp_posx = tir::CallNode::make(
- x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
- PrimExpr ret = (exp_posx + exp_negx) / two;
- *rv = ret;
-});
+ .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+ using tir::make_const;
+ using tir::make_zero;
+ PrimExpr e = targs[0];
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ CHECK(call != nullptr);
+ const PrimExpr& x = call->args[0];
+ PrimExpr two = make_const(x.dtype(), 2);
+ PrimExpr neg_one = make_const(x.dtype(), -1);
+ PrimExpr exp_negx =
+ tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
+ PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
+ PrimExpr ret = (exp_posx + exp_negx) / two;
+ *rv = ret;
+ });
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
+ .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
-.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
- using tir::make_const;
- using tir::make_zero;
- PrimExpr e = targs[0];
- const tir::CallNode* call = e.as<tir::CallNode>();
- CHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr two = make_const(x.dtype(), 2);
- PrimExpr neg_one = make_const(x.dtype(), -1);
- PrimExpr exp_negx = tir::CallNode::make(
- x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
- PrimExpr exp_posx = tir::CallNode::make(
- x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
- PrimExpr ret = (exp_posx - exp_negx) / two;
- *rv = ret;
-});
+ .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+ using tir::make_const;
+ using tir::make_zero;
+ PrimExpr e = targs[0];
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ CHECK(call != nullptr);
+ const PrimExpr& x = call->args[0];
+ PrimExpr two = make_const(x.dtype(), 2);
+ PrimExpr neg_one = make_const(x.dtype(), -1);
+ PrimExpr exp_negx =
+ tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
+ PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
+ PrimExpr ret = (exp_posx - exp_negx) / two;
+ *rv = ret;
+ });
} // namespace llvm
} // namespace codegen
#define TVM_TARGET_LLVM_INTRIN_RULE_LLVM_H_
#ifdef TVM_LLVM_VERSION
-#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
-
#include <tvm/target/codegen.h>
+#include <tvm/tir/expr.h>
+
#include <string>
+
#include "llvm_common.h"
namespace tvm {
namespace codegen {
// num_signature means number of arguments used to query signature
-template<unsigned id, int num_signature>
+template <unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::CallNode::make(
- call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic);
+ *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic);
}
-template<unsigned id, int num_signature>
+template <unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::CallNode::make(
- call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic);
+ *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic);
}
} // namespace codegen
*/
#ifdef TVM_LLVM_VERSION
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+
#include <sstream>
namespace tvm {
std::ostringstream intrinsic_name;
intrinsic_name << "__nv_" << call->name;
if (call->dtype.bits() == 32) intrinsic_name << "f";
- *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
- CallNode::PureExtern);
+ *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern);
}
namespace llvm {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan")
-.set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchExternLibDevice);
} // namespace llvm
} // namespace codegen
*/
#ifdef TVM_LLVM_VERSION
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
#include <sstream>
CHECK(call != nullptr);
std::ostringstream intrinsic_name;
intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits();
- *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
- CallNode::PureExtern);
+ *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern);
}
namespace llvm {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan")
-.set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchExternOCML);
} // namespace llvm
} // namespace codegen
*/
#ifdef TVM_LLVM_VERSION
+#include "llvm_common.h"
+
#include <dmlc/logging.h>
+
#include <atomic>
-#include <mutex>
#include <memory>
-#include "llvm_common.h"
+#include <mutex>
namespace tvm {
namespace codegen {
}
}
-void ParseLLVMTargetOptions(const std::string& target_str,
- std::string* triple,
- std::string* mcpu,
- std::string* mattr,
- llvm::TargetOptions* options) {
+void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu,
+ std::string* mattr, llvm::TargetOptions* options) {
// setup target triple
size_t start = 0;
- if (target_str.length() >= 4 &&
- target_str.substr(0, 4) == "llvm") {
+ if (target_str.length() >= 4 && target_str.substr(0, 4) == "llvm") {
start = 4;
}
// simple parser
}
size_t pos = key.find('=');
if (pos != std::string::npos) {
- CHECK_GE(key.length(), pos + 1)
- << "invalid argument " << key;
+ CHECK_GE(key.length(), pos + 1) << "invalid argument " << key;
value = key.substr(pos + 1, key.length() - 1);
key = key.substr(0, pos);
} else {
- CHECK(is >> value)
- << "Unspecified value for option " << key;
+ CHECK(is >> value) << "Unspecified value for option " << key;
}
- if (key == "-target" ||
- key == "-mtriple") {
+ if (key == "-target" || key == "-mtriple") {
*triple = value;
} else if (key == "-mcpu") {
*mcpu = value;
}
}
- if (triple->length() == 0 ||
- *triple == "default") {
+ if (triple->length() == 0 || *triple == "default") {
*triple = llvm::sys::getDefaultTargetTriple();
}
// set target option
llvm::TargetOptions& opt = *options;
opt = llvm::TargetOptions();
- #if TVM_LLVM_VERSION < 50
+#if TVM_LLVM_VERSION < 50
opt.LessPreciseFPMADOption = true;
- #endif
+#endif
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
}
}
-
-std::unique_ptr<llvm::TargetMachine>
-GetLLVMTargetMachine(const std::string& target_str,
- bool allow_null) {
+std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const std::string& target_str,
+ bool allow_null) {
std::string target_triple, mcpu, mattr;
llvm::TargetOptions opt;
- ParseLLVMTargetOptions(target_str,
- &target_triple,
- &mcpu,
- &mattr,
- &opt);
+ ParseLLVMTargetOptions(target_str, &target_triple, &mcpu, &mattr, &opt);
- if (target_triple.length() == 0 ||
- target_triple == "default") {
+ if (target_triple.length() == 0 || target_triple == "default") {
target_triple = llvm::sys::getDefaultTargetTriple();
}
if (mcpu.length() == 0) {
}
std::string err;
- const llvm::Target* target =
- llvm::TargetRegistry::lookupTarget(target_triple, err);
+ const llvm::Target* target = llvm::TargetRegistry::lookupTarget(target_triple, err);
if (target == nullptr) {
CHECK(allow_null) << err << " target_triple=" << target_triple;
return nullptr;
}
- llvm::TargetMachine* tm = target->createTargetMachine(
- target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
+ llvm::TargetMachine* tm =
+ target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
return std::unique_ptr<llvm::TargetMachine>(tm);
}
#define TVM_TARGET_LLVM_LLVM_COMMON_H_
#ifdef TVM_LLVM_VERSION
-#include <llvm/ExecutionEngine/MCJIT.h>
-
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/Bitcode/BitcodeWriter.h>
-#include <llvm/Support/SourceMgr.h>
-
-#include <llvm/IR/Value.h>
+#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/Intrinsics.h>
+#include <llvm/IR/Value.h>
+#include <llvm/Support/SourceMgr.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/IR/IntrinsicsAMDGPU.h>
#include <llvm/IR/IntrinsicsARM.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Constants.h>
-#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/DIBuilder.h>
+#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/LegacyPassManager.h>
+#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
-#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Verifier.h>
-
-#include <llvm/IR/LegacyPassManager.h>
+#include <llvm/Transforms/IPO.h>
+#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
-#include <llvm/Transforms/IPO/PassManagerBuilder.h>
-#include <llvm/Transforms/IPO.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
+#include <llvm/CodeGen/TargetLoweringObjectFileImpl.h>
+#include <llvm/IRReader/IRReader.h>
+#include <llvm/Linker/Linker.h>
+#include <llvm/Support/Casting.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/MemoryBuffer.h>
-#include <llvm/Support/raw_ostream.h>
-#include <llvm/Support/Casting.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
+#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
-#include <llvm/IRReader/IRReader.h>
-#include <llvm/CodeGen/TargetLoweringObjectFileImpl.h>
-
-#include <llvm/Linker/Linker.h>
-#include <utility>
-#include <string>
#include <memory>
+#include <string>
+#include <utility>
namespace tvm {
namespace codegen {
* \param options the options
* \param mattr The attributes
*/
-void ParseLLVMTargetOptions(const std::string& target_str,
- std::string* triple,
- std::string* mcpu,
- std::string* mattr,
- llvm::TargetOptions* options);
+void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu,
+ std::string* mattr, llvm::TargetOptions* options);
/*!
* \brief Get target machine from target_str string.
* \param allow_null Whether allow null to be returned.
* \return target machine
*/
-std::unique_ptr<llvm::TargetMachine>
-GetLLVMTargetMachine(const std::string& target_str, bool allow_null = false);
+std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const std::string& target_str,
+ bool allow_null = false);
} // namespace codegen
} // namespace tvm
*/
#ifdef TVM_LLVM_VERSION
+#include <tvm/ir/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <tvm/ir/module.h>
#include <tvm/target/codegen.h>
+
#include <mutex>
-#include "llvm_common.h"
-#include "codegen_llvm.h"
-#include "codegen_blob.h"
+
#include "../../runtime/file_util.h"
#include "../../runtime/library_module.h"
+#include "codegen_blob.h"
+#include "codegen_llvm.h"
+#include "llvm_common.h"
namespace tvm {
namespace codegen {
+using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
-using runtime::PackedFunc;
class LLVMModuleNode final : public runtime::ModuleNode {
public:
}
}
- const char* type_key() const {
- return "llvm";
- }
+ const char* type_key() const { return "llvm"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "__tvm_is_system_module") {
- bool flag =
- (mptr_->getFunction("__tvm_module_startup") != nullptr);
- return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) {
- * rv = flag;
- });
+ bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr);
+ return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; });
} else if (name == "_get_target_triple") {
std::string target_triple = tm_->getTargetTriple().str();
- return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) {
- *rv = target_triple;
- });
+ return PackedFunc([target_triple](TVMArgs args, TVMRetValue* rv) { *rv = target_triple; });
}
if (ee_ == nullptr) LazyInitJIT();
TVMBackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
- const char* entry_name = reinterpret_cast<const char*>(
- GetGlobalAddr(runtime::symbol::tvm_module_main));
+ const char* entry_name =
+ reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(entry_name));
return WrapPackedFunc(faddr, sptr_to_self);
}
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = runtime::GetFileFormat(file_name, format);
std::error_code ecode;
llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None);
- CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name
- << " " << ecode.message();
+ CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message();
if (fmt == "o" || fmt == "obj") {
#if TVM_LLVM_VERSION <= 60
std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
llvm::legacy::PassManager pass;
CHECK(tm_);
#if TVM_LLVM_VERSION <= 60
- CHECK(tm_->addPassesToEmitFile(
- pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0)
+ CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
#elif TVM_LLVM_VERSION <= 90
- CHECK(tm_->addPassesToEmitFile(
- pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
+ CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) ==
+ 0)
<< "Cannot emit target CGFT_ObjectFile";
#else
- CHECK(tm_->addPassesToEmitFile(
- pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0)
+ CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
#endif
pass.run(*m);
llvm::legacy::PassManager pass;
CHECK(tm_);
#if TVM_LLVM_VERSION <= 60
- CHECK(tm_->addPassesToEmitFile(
- pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
+ CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#elif TVM_LLVM_VERSION <= 90
- CHECK(tm_->addPassesToEmitFile(
- pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
+ CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) ==
+ 0)
<< "Cannot emit target CGFT_AssemblyFile";
#else
- CHECK(tm_->addPassesToEmitFile(
- pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0)
+ CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#endif
pass.run(*m);
llvm::WriteBitcodeToFile(*mptr_, dest);
#endif
} else {
- LOG(FATAL) << "Do not know how to save file "
- << file_name << " with format=\'"<< format << "\'";
+ LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format
+ << "\'";
}
dest.close();
}
llvm::raw_svector_ostream rso(str);
if (fmt == "s" || fmt == "asm") {
- #if TVM_LLVM_VERSION <= 60
- std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
- #else
- std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
- #endif
- llvm::legacy::PassManager pass;
- CHECK(tm_);
- #if TVM_LLVM_VERSION <= 60
- CHECK(tm_->addPassesToEmitFile(
- pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
- << "Cannot emit target CGFT_AssemblyFile";
- #elif TVM_LLVM_VERSION <= 90
- CHECK(tm_->addPassesToEmitFile(
- pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
- << "Cannot emit target CGFT_AssemblyFile";
- #else
- CHECK(tm_->addPassesToEmitFile(
- pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0)
- << "Cannot emit target CGFT_AssemblyFile";
- #endif
- pass.run(*m);
- return rso.str().str();
+#if TVM_LLVM_VERSION <= 60
+ std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
+#else
+ std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
+#endif
+ llvm::legacy::PassManager pass;
+ CHECK(tm_);
+#if TVM_LLVM_VERSION <= 60
+ CHECK(tm_->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
+ << "Cannot emit target CGFT_AssemblyFile";
+#elif TVM_LLVM_VERSION <= 90
+ CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) ==
+ 0)
+ << "Cannot emit target CGFT_AssemblyFile";
+#else
+ CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0)
+ << "Cannot emit target CGFT_AssemblyFile";
+#endif
+ pass.run(*m);
+ return rso.str().str();
} else if (fmt == "" || fmt == "ll") {
std::string type_str;
llvm::raw_string_ostream rso(type_str);
mptr_->print(rso, nullptr);
return rso.str();
} else {
- LOG(FATAL) << "Do not know how to get source code with format: "
- << format << "\'";
+ LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'";
}
return "";
}
std::vector<PrimFunc> funcs;
std::string entry_func;
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "Can only lower IR Module with PrimFuncs";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
mptr_ = module_.get();
}
- void Init(std::unique_ptr<llvm::Module> module,
- std::shared_ptr<llvm::LLVMContext> ctx) {
+ void Init(std::unique_ptr<llvm::Module> module, std::shared_ptr<llvm::LLVMContext> ctx) {
InitializeLLVM();
ctx_ = ctx;
llvm::SMDiagnostic err;
CHECK(layout == mptr_->getDataLayout())
<< "Data layout mismatch between module("
<< mptr_->getDataLayout().getStringRepresentation() << ")"
- << " and ExecutionEngine ("
- << layout.getStringRepresentation() << ")";
+ << " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
ee_ = builder.create(tm.release());
- CHECK(ee_ != nullptr)
- << "Failed to initialize jit engine for " << mptr_->getTargetTriple();
+ CHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);
- if (void** ctx_addr = reinterpret_cast<void**>(
- GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
+ if (void** ctx_addr =
+ reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
- runtime::InitContextFunctions([this](const char *name) {
- return reinterpret_cast<void*>(GetGlobalAddr(name));
- });
+ runtime::InitContextFunctions(
+ [this](const char* name) { return reinterpret_cast<void*>(GetGlobalAddr(name)); });
}
// Get global address from execution engine.
uint64_t GetGlobalAddr(const std::string& name) const {
// JIT lock
std::mutex mutex_;
// execution engine
- llvm::ExecutionEngine *ee_{nullptr};
+ llvm::ExecutionEngine* ee_{nullptr};
// The raw pointer to the module.
llvm::Module* mptr_{nullptr};
// The target machine
return llvm::Function::lookupIntrinsicID(name);
}
-
-TVM_REGISTER_GLOBAL("target.build.llvm")
-.set_body_typed([](IRModule mod, std::string target) {
+TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed([](IRModule mod, std::string target) {
auto n = make_object<LLVMModuleNode>();
n->Init(mod, target);
return runtime::Module(n);
});
-
-TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<LLVMModuleNode>();
auto target = args[0].operator std::string();
auto module_name = args[1].operator std::string();
*rv = runtime::Module(n);
});
-TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
- });
-
-TVM_REGISTER_GLOBAL("target.llvm_version_major")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- int major = TVM_LLVM_VERSION / 10;
- *rv = major;
- });
-
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- auto n = make_object<LLVMModuleNode>();
- n->LoadIR(args[0]);
- *rv = runtime::Module(n);
- });
-
-TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- InitializeLLVM();
- *rv = (GetLLVMTargetMachine(args[0], true) != nullptr);
- });
+TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
+});
+
+TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body([](TVMArgs args, TVMRetValue* rv) {
+ int major = TVM_LLVM_VERSION / 10;
+ *rv = major;
+});
+
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll").set_body([](TVMArgs args, TVMRetValue* rv) {
+ auto n = make_object<LLVMModuleNode>();
+ n->LoadIR(args[0]);
+ *rv = runtime::Module(n);
+});
+
+TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled").set_body([](TVMArgs args, TVMRetValue* rv) {
+ InitializeLLVM();
+ *rv = (GetLLVMTargetMachine(args[0], true) != nullptr);
+});
-TVM_REGISTER_GLOBAL("codegen.codegen_blob")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("codegen.codegen_blob").set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<LLVMModuleNode>();
- auto p = CodeGenBlob(args[0].operator std::string(),
- args[1].operator bool(),
+ auto p = CodeGenBlob(args[0].operator std::string(), args[1].operator bool(),
args[2].operator std::string());
n->Init(std::move(p.first), p.second);
*rv = runtime::Module(n);
/*!
* Optional module when build aocl is switched to off
*/
-#include "../source/codegen_source_base.h"
#include "../../runtime/opencl/opencl_module.h"
+#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-Module AOCLModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module AOCLModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
LOG(WARNING) << "AOCL runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "aocl");
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
namespace tvm {
namespace runtime {
-Module CUDAModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string cuda_source) {
+Module CUDAModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap,
+ std::string cuda_source) {
LOG(FATAL) << "CUDA is not enabled";
return Module();
}
#include <sys/stat.h>
#endif
#include <cuda_runtime.h>
-
#include <nvrtc.h>
+
#include <cstdlib>
-#include "../build_common.h"
-#include "../source/codegen_cuda.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/cuda/cuda_module.h"
-
+#include "../build_common.h"
+#include "../source/codegen_cuda.h"
namespace tvm {
namespace codegen {
-#define NVRTC_CALL(x) \
- { \
- nvrtcResult result = x; \
- if (result != NVRTC_SUCCESS) { \
- LOG(FATAL) \
- << "NvrtcError: " #x " failed with error: " \
- << nvrtcGetErrorString(result); \
- } \
+#define NVRTC_CALL(x) \
+ { \
+ nvrtcResult result = x; \
+ if (result != NVRTC_SUCCESS) { \
+ LOG(FATAL) << "NvrtcError: " #x " failed with error: " << nvrtcGetErrorString(result); \
+ } \
}
-
std::string FindCUDAIncludePath() {
#if defined(_WIN32)
const std::string delimiter = "\\";
return cuda_include_path;
}
-
std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
}
for (const auto& string : compile_params) {
- param_cstrings.push_back(string.c_str());
+ param_cstrings.push_back(string.c_str());
}
- NVRTC_CALL(nvrtcCreateProgram(
- &prog, code.c_str(), nullptr, 0, nullptr, nullptr));
- nvrtcResult compile_res =
- nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
+ NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
+ nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
- std::string log; log.resize(log_size);
+ std::string log;
+ log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
CodeGenCUDA cg;
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenCUDA: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
-TVM_REGISTER_GLOBAL("target.build.cuda")
-.set_body_typed(BuildCUDA);
+TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
} // namespace codegen
} // namespace tvm
namespace runtime {
Module HexagonModuleCreate(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string asm_str, std::string obj_str,
- std::string ir_str, std::string bc_str,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string asm_str,
+ std::string obj_str, std::string ir_str, std::string bc_str,
const std::set<std::string>& packed_c_abi) {
LOG(WARNING) << "Hexagon runtime is not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex");
/*!
* Optional module when build metal is switched to off
*/
-#include "../source/codegen_source_base.h"
#include "../../runtime/metal/metal_module.h"
+#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-Module MetalModuleCreate(std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module MetalModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
LOG(WARNING) << "Metal runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal");
}
/*!
* Optional module when build opencl is switched to off
*/
-#include "../source/codegen_source_base.h"
#include "../../runtime/opencl/opencl_module.h"
+#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-Module OpenCLModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module OpenCLModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl");
}
/*!
* Optional module when build opencl is switched to off
*/
-#include "../source/codegen_source_base.h"
#include "../../runtime/opengl/opengl_module.h"
+#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders,
- std::string fmt,
+Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap) {
LOG(WARNING) << "OpenGL runtime not enabled, return a source module...";
auto data = ToJSON(shaders);
/*!
* Optional module when build rocm is switched to off
*/
-#include "../source/codegen_source_base.h"
#include "../../runtime/rocm/rocm_module.h"
+#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-Module ROCMModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string rocm_source,
- std::string assembly) {
-
+Module ROCMModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string rocm_source,
+ std::string assembly) {
LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
auto fget_source = [rocm_source, assembly](const std::string& format) {
if (format.length() == 0) return assembly;
if (format == "asm") return assembly;
return std::string("");
};
- return codegen::DeviceSourceModuleCreate(
- data, fmt, fmap, "hsaco", fget_source);
+ return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hsaco", fget_source);
}
} // namespace runtime
/*!
* Optional module when build opencl is switched to off
*/
-#include "../source/codegen_source_base.h"
#include "../../runtime/opencl/opencl_module.h"
+#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-Module SDAccelModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
+Module SDAccelModuleCreate(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "sdaccel");
}
* \file codegen_aocl.cc
*/
#include <tvm/target/target.h>
-#include <vector>
+
#include <string>
-#include "codegen_opencl.h"
-#include "../build_common.h"
-#include "../../runtime/opencl/aocl/aocl_module.h"
+#include <vector>
+
#include "../../runtime/file_util.h"
+#include "../../runtime/opencl/aocl/aocl_module.h"
+#include "../build_common.h"
+#include "codegen_opencl.h"
namespace tvm {
namespace codegen {
-runtime::Module BuildAOCL(IRModule mod,
- std::string target_str,
- bool emulation) {
+runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) {
// Get code.
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodegenOpenCL: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code);
}
-TVM_REGISTER_GLOBAL("target.build.aocl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildAOCL(args[0], args[1], false);
- });
+TVM_REGISTER_GLOBAL("target.build.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = BuildAOCL(args[0], args[1], false);
+});
-TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildAOCL(args[0], args[1], true);
- });
+TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = BuildAOCL(args[0], args[1], true);
+});
} // namespace codegen
} // namespace tvm
/*!
* \file codegen_c.cc
*/
-#include <iomanip>
-#include <cctype>
#include "codegen_c.h"
-#include "../../arith/pattern_match.h"
+
+#include <cctype>
+#include <iomanip>
+
#include "../../arith/compute_expr.h"
+#include "../../arith/pattern_match.h"
namespace tvm {
namespace codegen {
using namespace tir;
-void CodeGenC::Init(bool output_ssa) {
- print_ssa_form_ = output_ssa;
-}
+void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; }
void CodeGenC::InitFuncState(const PrimFunc& f) {
alloc_storage_scope_.clear();
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- CHECK(global_symbol.defined())
- << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
+ CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix();
this->stream << "}\n\n";
}
-void CodeGenC::PrintFuncPrefix() {
- stream << "void";
-}
+void CodeGenC::PrintFuncPrefix() { stream << "void"; }
-void CodeGenC::PrintFinalReturn() {
-}
+void CodeGenC::PrintFinalReturn() {}
-std::string CodeGenC::Finish() {
- return decl_stream.str() + stream.str();
-}
+std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*)
if (print_ssa_form_) {
}
}
-void CodeGenC::PrintSSAAssign(
- const std::string& target, const std::string& src, DataType t) {
+void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) {
PrintType(t, stream);
stream << ' ' << target << " = ";
- if (src.length() > 3 &&
- src[0] == '(' && src[src.length() - 1] == ')') {
+ if (src.length() > 3 && src[0] == '(' && src[src.length() - 1] == ')') {
stream << src.substr(1, src.length() - 2);
} else {
stream << src;
}
// Print a reference expression to a buffer.
-std::string CodeGenC::GetBufferRef(
- DataType t, const VarNode* buffer, PrimExpr index) {
+std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) {
std::ostringstream os;
std::string vid = GetVarID(buffer);
std::string scope;
os << "[(";
PrintExpr(index, os);
os << ")";
- if (t.bits() == 4 ||
- (t.bits() == 1 && t.is_int())) {
+ if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
os << " / " << (32 / t.bits());
}
os << ']';
// optimize for constant access
if (auto* ptr = index.as<tir::IntImmNode>()) {
int64_t offset = ptr->value;
- CHECK_EQ(offset % t.lanes(), 0)
- << "Find unaligned vector load to a vector type";
+ CHECK_EQ(offset % t.lanes(), 0) << "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
return os.str();
}
os << vid << " + (";
PrintExpr(index, os);
os << ")";
- if (t.bits() == 4 ||
- (t.bits() == 1 && t.is_int())) {
+ if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
os << " / " << (32 / t.bits());
}
os << "))[0]";
}
// Print a reference expression to a buffer.
-std::string CodeGenC::GetStructRef(
- DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) {
+std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index,
+ int kind) {
if (kind < intrinsic::kArrKindBound_) {
std::ostringstream os;
os << "(((DLTensor*)";
os << "].";
// other case: get fields.
switch (kind) {
- case intrinsic::kArrData: os << "data"; break;
- case intrinsic::kArrShape: os << "shape"; break;
- case intrinsic::kArrStrides: os << "strides"; break;
- case intrinsic::kArrNDim: os << "ndim"; break;
- case intrinsic::kArrTypeCode: os << "dtype.code"; break;
- case intrinsic::kArrTypeBits: os << "dtype.bits"; break;
- case intrinsic::kArrByteOffset: os << "byte_offset"; break;
- case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break;
- case intrinsic::kArrDeviceId: os << "ctx.device_id"; break;
- case intrinsic::kArrDeviceType: os << "ctx.device_type"; break;
- default: LOG(FATAL) << "unknown field code";
+ case intrinsic::kArrData:
+ os << "data";
+ break;
+ case intrinsic::kArrShape:
+ os << "shape";
+ break;
+ case intrinsic::kArrStrides:
+ os << "strides";
+ break;
+ case intrinsic::kArrNDim:
+ os << "ndim";
+ break;
+ case intrinsic::kArrTypeCode:
+ os << "dtype.code";
+ break;
+ case intrinsic::kArrTypeBits:
+ os << "dtype.bits";
+ break;
+ case intrinsic::kArrByteOffset:
+ os << "byte_offset";
+ break;
+ case intrinsic::kArrTypeLanes:
+ os << "dtype.lanes";
+ break;
+ case intrinsic::kArrDeviceId:
+ os << "ctx.device_id";
+ break;
+ case intrinsic::kArrDeviceType:
+ os << "ctx.device_type";
+ break;
+ default:
+ LOG(FATAL) << "unknown field code";
}
os << ')';
return os.str();
if (it == handle_data_type_.end()) {
handle_data_type_[buf_var] = t;
} else {
- CHECK(it->second == t)
- << "conflicting buf var type";
+ CHECK(it->second == t) << "conflicting buf var type";
}
}
-void CodeGenC::PrintVecElemLoad(const std::string& vec,
- DataType t, int i,
+void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i << std::dec;
}
-void CodeGenC::PrintVecElemStore(const std::string& vec,
- DataType t, int i,
+void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
this->PrintIndent();
- stream << vec << ".s" << std::hex << i
- << " = " << value << ";\n" << std::dec;
+ stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec;
}
-std::string CodeGenC::GetVecLoad(
- DataType t, const VarNode* buffer, PrimExpr base) {
+std::string CodeGenC::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) {
return GetBufferRef(t, buffer, base);
}
-void CodeGenC::PrintVecStore(const VarNode* buffer,
- DataType t, PrimExpr base,
+void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
const std::string& value) {
std::string ref = GetBufferRef(t, buffer, base);
this->PrintIndent();
return os.str();
}
-void CodeGenC::BindThreadIndex(const IterVar& iv) {
- LOG(FATAL) << "not implemented";
-}
+void CodeGenC::BindThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; }
-void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*)
+void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*)
}
-void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
+void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_EQ(scope, "global");
}
void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
- CHECK_EQ(t.lanes(), 1)
- << "do not yet support vector types";
+ CHECK_EQ(t.lanes(), 1) << "do not yet support vector types";
if (t.is_handle()) {
- os << "void*"; return;
+ os << "void*";
+ return;
}
if (t.is_float()) {
if (t.bits() == 32) {
- os << "float"; return;
+ os << "float";
+ return;
}
if (t.bits() == 64) {
- os << "double"; return;
+ os << "double";
+ return;
}
} else if (t.is_uint()) {
switch (t.bits()) {
- case 8: case 16: case 32: case 64: {
- os << "uint" << t.bits() << "_t"; return;
+ case 8:
+ case 16:
+ case 32:
+ case 64: {
+ os << "uint" << t.bits() << "_t";
+ return;
}
- case 1: os << "int"; return;
+ case 1:
+ os << "int";
+ return;
}
} else if (t.is_int()) {
switch (t.bits()) {
- case 8: case 16: case 32: case 64: {
- os << "int" << t.bits() << "_t"; return;
+ case 8:
+ case 16:
+ case 32:
+ case 64: {
+ os << "int" << t.bits() << "_t";
+ return;
}
}
}
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
-
-void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*)
+void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*)
if (auto* ptr = type.as<PrimTypeNode>()) {
return PrintType(ptr->dtype, os);
} else if (auto* ptr = type.as<PointerTypeNode>()) {
}
}
-
-inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->dtype == DataType::Int(32)) {
std::ostringstream temp;
temp << op->value;
}
}
-
-inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os,
+ CodeGenC* p) { // NOLINT(*)
if (dtype == DataType::UInt(32)) {
std::ostringstream temp;
temp << val << "U";
}
}
-inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
switch (op->dtype.bits()) {
- case 64: case 32: {
+ case 64:
+ case 32: {
std::ostringstream temp;
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) temp << 'f';
case 16: {
os << '(';
p->PrintType(op->dtype, os);
- os << ')' << std::scientific <<op->value << 'f';
+ os << ')' << std::scientific << op->value << 'f';
break;
}
- default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
+ default:
+ LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
PrintConst(op, os, this);
}
-void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
-void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
os << "\"" << op->value << "\"";
}
-template<typename T>
-inline void PrintBinaryExpr(const T* op,
- const char* opstr,
+template <typename T>
+inline void PrintBinaryExpr(const T* op, const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
if (op->dtype.lanes() == 1) {
}
}
-inline void PrintBinaryIntrinsic(const CallNode* op,
- const char* opstr,
- std::ostream& os, // NOLINT(*)
- CodeGenC* p) {
+inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr,
+ std::ostream& os, // NOLINT(*)
+ CodeGenC* p) {
if (op->dtype.lanes() == 1) {
CHECK_EQ(op->args.size(), 2U);
os << '(';
}
void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
- if (op->call_type == CallNode::Extern ||
- op->call_type == CallNode::PureExtern) {
+ if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
this->PrintExpr(op->args[i], os);
PrintExpr(op->args[2], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const LoadNode *l = op->args[0].as<LoadNode>();
+ const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
os << "((";
this->PrintType(l->dtype.element_of(), os);
- os << " *)" << this->GetVarID(l->buffer_var.get())
- << " + ";
+ os << " *)" << this->GetVarID(l->buffer_var.get()) << " + ";
this->PrintExpr(l->index, os);
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
- os << GetStructRef(
- op->dtype, op->args[0], op->args[1],
- op->args[2].as<IntImmNode>()->value);
+ os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
os << "(";
this->PrintExpr(op->args[0], os);
os << ")";
} else {
- if (op->call_type == CallNode::Intrinsic ||
- op->call_type == CallNode::PureIntrinsic) {
- LOG(FATAL) << "Unresolved intrinsic " << op->name
- << " with return type " << op->dtype;
+ if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
+ LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype;
} else {
LOG(FATAL) << "Unresolved call type " << op->call_type;
}
}
}
-void CodeGenC::PrintVecBinaryOp(
- const std::string& op, DataType t,
- PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*)
+void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
+ std::ostream& os) { // NOLINT(*)
if (isalpha(op[0])) {
os << op << "(";
this->PrintExpr(lhs, os);
this->PrintExpr(rhs, os);
os << ")";
} else {
- os <<"(";
+ os << "(";
this->PrintExpr(lhs, os);
os << ' ' << op << ' ';
this->PrintExpr(rhs, os);
std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index);
HandleVolatileLoads(ref, op, os);
} else {
- CHECK(is_one(op->predicate))
- << "predicated load is not supported";
+ CHECK(is_one(op->predicate)) << "predicated load is not supported";
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
DataType t = op->value.dtype();
if (t.lanes() == 1) {
std::string value = this->PrintExpr(op->value);
- std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
+ std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
} else {
- CHECK(is_one(op->predicate))
- << "Predicated store is not supported";
+ CHECK(is_one(op->predicate)) << "Predicated store is not supported";
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
CHECK_EQ(op->base.dtype(), DataType::Int(32));
os << "((int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
- os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
- if (i != op->lanes - 1)
- os << ", ";
+ os << "(" << PrintExpr(op->base) << ")"
+ << "+(" << PrintExpr(op->stride) << "*" << i << ")";
+ if (i != op->lanes - 1) os << ", ";
}
os << "))";
}
LOG(FATAL) << "Shuffle: not supported ";
}
-void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
var_idmap_[op->var.get()] = value;
} else {
PrintIndent();
- if (op->var.dtype() == DataType::Handle() &&
- handle_data_type_.count(op->var.get())) {
+ if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) {
PrintType(handle_data_type_.at(op->var.get()), stream);
- stream << "* "
- << AllocVarID(op->var.get())
- << " = (";
+ stream << "* " << AllocVarID(op->var.get()) << " = (";
PrintType(handle_data_type_.at(op->var.get()), stream);
- stream << "*)" << value << ";\n";
+ stream << "*)" << value << ";\n";
} else {
PrintType(op->var.dtype(), this->stream);
- this->stream << ' '
- << AllocVarID(op->var.get())
- << " = " << value << ";\n";
+ this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n";
}
}
PrintStmt(op->body);
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
- this->PrintIndent();
- int32_t constant_size = op->constant_allocation_size();
- CHECK_GT(constant_size, 0)
- << "Can only handle constant size stack allocation for now";
- const VarNode* buffer = op->buffer_var.as<VarNode>();
- std::string scope = alloc_storage_scope_.at(buffer);
- PrintStorageScope(scope, stream);
- PrintType(op->dtype, stream);
- stream << ' ' << vid << '[' << constant_size << "];\n";
+ this->PrintIndent();
+ int32_t constant_size = op->constant_allocation_size();
+ CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
+ const VarNode* buffer = op->buffer_var.as<VarNode>();
+ std::string scope = alloc_storage_scope_.at(buffer);
+ PrintStorageScope(scope, stream);
+ PrintType(op->dtype, stream);
+ stream << ' ' << vid << '[' << constant_size << "];\n";
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
CHECK(is_zero(op->min));
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
- stream << ' ' << vid << " = 0; "
- << vid << " < " << extent
- << "; ++" << vid << ") {\n";
+ stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
const CallNode* call = op->value.as<CallNode>();
if (call) {
if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
- this->PrintStorageSync(call); return;
+ this->PrintStorageSync(call);
+ return;
} else if (call->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(call->args.size(), 4);
std::string value = PrintExpr(call->args[3]);
- std::string ref = GetStructRef(
- call->args[3].dtype(),
- call->args[0],
- call->args[1],
- call->args[2].as<IntImmNode>()->value);
+ std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1],
+ call->args[2].as<IntImmNode>()->value);
this->PrintIndent();
this->stream << ref << " = " << value << ";\n";
return;
}
}
-void CodeGenC::PrintVecElemLoadExpr(
- DataType t, int i, const std::string& value, std::ostream& os) {
+void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (i != 0) {
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_H_
+#include <tvm/runtime/container.h>
+#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/target/codegen.h>
-#include <tvm/runtime/container.h>
+
#include <string>
-#include <vector>
#include <unordered_map>
#include <unordered_set>
+#include <vector>
+
#include "codegen_source_base.h"
namespace tvm {
* and OpenCL-C. You might find some odd variant features, e.g., type `int3` for
* a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`.
*/
-class CodeGenC :
- public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
- public StmtFunctor<void(const Stmt&)>,
- public CodeGenSourceBase {
+class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
+ public StmtFunctor<void(const Stmt&)>,
+ public CodeGenSourceBase {
public:
/*!
* \brief Initialize the code generator.
* \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed.
*/
- void PrintStmt(const Stmt& n) {
- VisitStmt(n);
- }
+ void PrintStmt(const Stmt& n) { VisitStmt(n); }
/*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
*
* Example: stream << "void";
*/
- virtual void PrintFuncPrefix(); // NOLINT(*)
+ virtual void PrintFuncPrefix(); // NOLINT(*)
/*!
* \brief Print the final return at the end the function.
*/
- virtual void PrintFinalReturn(); // NOLINT(*)
+ virtual void PrintFinalReturn(); // NOLINT(*)
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
*/
virtual void InitFuncState(const PrimFunc& f);
// expression
- void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmtNode* op) override;
* \param t The type representation.
* \param os The stream to print the ctype into
*/
- virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
+ virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
/*!
* Print Type represetnation of type type.
* \param type The type representation.
* \param os The stream to print the ctype into
*/
- virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*)
+ virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*)
/*!
* \brief Print expr representing the thread tag
* \param IterVar iv The thread index to be binded;
*/
- virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
- virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
- virtual void PrintStorageSync(const CallNode* op); // NOLINT(*)
+ virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
+ virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
+ virtual void PrintStorageSync(const CallNode* op); // NOLINT(*)
// Binary vector op.
- virtual void PrintVecBinaryOp(
- const std::string&op, DataType op_type,
- PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*)
+ virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs,
+ std::ostream& os); // NOLINT(*)
// print vector load
virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base);
// print vector store
- virtual void PrintVecStore(const VarNode* buffer,
- DataType t, PrimExpr base,
+ virtual void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
const std::string& value); // NOLINT(*)
// print load of single element
- virtual void PrintVecElemLoad(
- const std::string& vec, DataType t, int i, std::ostream& os); // NOLINT(*)
+ virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i,
+ std::ostream& os); // NOLINT(*)
// print store of single element.
- virtual void PrintVecElemStore(
- const std::string& vec, DataType t, int i, const std::string& value);
+ virtual void PrintVecElemStore(const std::string& vec, DataType t, int i,
+ const std::string& value);
// Get a cast type from to
virtual std::string CastFromTo(std::string value, DataType from, DataType target);
// Get load of single element with expression
protected:
// Print reference to struct location
- std::string GetStructRef(
- DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
+ std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
// Print reference to a buffer as type t in index.
- virtual std::string GetBufferRef(
- DataType t, const VarNode* buffer, PrimExpr index);
+ virtual std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index);
/*!
* \brief Handle volatile loads.
* does not implement volatile member functions. CUDA codegen will cast
* away volatile qualifier from CUDA __half types.
*/
- virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op,
- std::ostream& os) {
+ virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) {
// By default, do nothing but print the loaded value.
os << value;
}
* or "__constant__" is not part of type but a storage class (like
* C/C++ static).
*/
- virtual bool IsScopePartOfType() const {
- return true;
- }
+ virtual bool IsScopePartOfType() const { return true; }
/*!
* \brief If buffer is allocated as type t.
*/
void RegisterHandleType(const VarNode* buf_var, DataType t);
// override
- void PrintSSAAssign(
- const std::string& target, const std::string& src, DataType t) final;
+ void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();
/*! \brief Check if buf_var is volatile or not. */
- bool IsVolatile(const VarNode *buf_var) const {
- return volatile_buf_.count(buf_var) != 0;
- }
+ bool IsVolatile(const VarNode* buf_var) const { return volatile_buf_.count(buf_var) != 0; }
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*!
* \file codegen_c_host.cc
*/
+#include "codegen_c_host.h"
+
#include <tvm/target/codegen.h>
-#include <vector>
+
#include <string>
+#include <vector>
+
#include "../build_common.h"
-#include "codegen_c_host.h"
namespace tvm {
namespace codegen {
-CodeGenCHost::CodeGenCHost() {
- module_name_ = GetUniqueName("__tvm_module_ctx");
-}
+CodeGenCHost::CodeGenCHost() { module_name_ = GetUniqueName("__tvm_module_ctx"); }
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) {
emit_asserts_ = emit_asserts;
void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
- CHECK_EQ(lanes, 1)
- << "does not support vector types";
- os << "void*"; return;
+ CHECK_EQ(lanes, 1) << "does not support vector types";
+ os << "void*";
+ return;
}
if (t == DataType::Bool()) {
- os << "bool"; return;
+ os << "bool";
+ return;
}
bool fail = false;
if (t.is_float()) {
case 16:
os << "half";
break;
- case 32: os << "float"; break;
+ case 32:
+ os << "float";
+ break;
case 64:
os << "double";
break;
- default: fail = true; break;
+ default:
+ fail = true;
+ break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
}
switch (t.bits()) {
- case 8: os << "int8_t"; break;
- case 16: os << "int16_t"; break;
- case 32: os << "int32_t"; break;
- case 64: os << "int64_t"; break;
- case 1: os << "int32_t"; break;
- default: fail = true; break;
+ case 8:
+ os << "int8_t";
+ break;
+ case 16:
+ os << "int16_t";
+ break;
+ case 32:
+ os << "int32_t";
+ break;
+ case 64:
+ os << "int64_t";
+ break;
+ case 1:
+ os << "int32_t";
+ break;
+ default:
+ fail = true;
+ break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
-void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
PrintType(op->dtype, os);
this->stream << "if (" << packed_func_name << " == NULL) {\n";
int packed_func_if_scope = this->BeginScope();
this->PrintIndent();
- this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_
- << ", \"" << func_name << "\""
- << ", &" << packed_func_name << ") != 0) {\n";
+ this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" << func_name << "\""
+ << ", &" << packed_func_name << ") != 0) {\n";
int get_func_env_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->stream << "int " << ret_type_code << ";\n";
this->PrintIndent();
this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
- << "(TVMValue*) stack_value" << ", " << "(int*) stack_tcode" << ", "
- << num_args << ", " << "&" << ret_val << ", " << "&"
- << ret_type_code << ") != 0) {\n";
+ << "(TVMValue*) stack_value"
+ << ", "
+ << "(int*) stack_tcode"
+ << ", " << num_args << ", "
+ << "&" << ret_val << ", "
+ << "&" << ret_type_code << ") != 0) {\n";
int func_call_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->stream << "}\n";
}
-void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
std::string stack_name = GetUniqueName("stack");
const std::string& type = op->args[0].as<StringImmNode>()->value;
std::string packed_func_name = func_name + "_packed";
if (declared_globals_.insert(packed_func_name).second) {
// Still reserve the name among unique names.
- CHECK(GetUniqueName(packed_func_name) == packed_func_name) <<
- "Expected name " << packed_func_name << " to not be taken";
+ CHECK(GetUniqueName(packed_func_name) == packed_func_name)
+ << "Expected name " << packed_func_name << " to not be taken";
decl_stream << "static void* " << packed_func_name << " = NULL;\n";
}
this->PrintGetFuncFromBackend(func_name, packed_func_name);
}
}
-void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*)
+void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*)
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
this->PrintStmt(op->body);
}
-void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
-void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
template <typename T>
-inline void CodeGenCHost::PrintTernaryCondExpr(const T* op,
- const char* compare,
- std::ostream& os) { // NOLINT(*)
+inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare,
+ std::ostream& os) { // NOLINT(*)
std::ostringstream temp_a;
VisitExpr(op->a, temp_a);
std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
cg.Init(output_ssa, emit_asserts);
for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodegenCHost: Can only take PrimFunc";
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
cg.AddFunction(f);
}
return CSourceModuleCreate(code, "c");
}
-TVM_REGISTER_GLOBAL("target.build.c")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCHost(args[0]);
});
} // namespace codegen
#include <set>
#include <string>
+
+#include "codegen_c.h"
#include "tvm/target/codegen.h"
#include "tvm/tir/expr.h"
-#include "codegen_c.h"
namespace tvm {
namespace codegen {
CodeGenCHost();
void Init(bool output_ssa, bool emit_asserts);
- void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
- void PrintFuncPrefix() final; // NOLINT(*)
- void PrintFinalReturn() final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
+ void PrintFuncPrefix() final; // NOLINT(*)
+ void PrintFinalReturn() final; // NOLINT(*)
// overload visitor functions
- void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations
- void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const MinNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const MaxNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
+ void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*)
private:
std::string module_name_;
* \param os stream reference to print into
*/
template <typename T>
- inline void PrintTernaryCondExpr(const T* op,
- const char* compare,
+ inline void PrintTernaryCondExpr(const T* op, const char* compare,
std::ostream& os); // NOLINT(*)
};
* \file codegen_cuda.cc
*/
+#include "codegen_cuda.h"
+
#include <tvm/runtime/registry.h>
#include <cmath>
+#include <string>
#include <utility>
#include <vector>
-#include <string>
+
#include "literal/cuda_half_t.h"
-#include "codegen_cuda.h"
namespace tvm {
namespace codegen {
-CodeGenCUDA::CodeGenCUDA() {
- restrict_keyword_ = "__restrict__";
-}
+CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
void CodeGenCUDA::Init(bool output_ssa) {
CodeGenC::Init(output_ssa);
CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}
-
-void CodeGenCUDA::PrintFuncPrefix() {
- stream << "extern \"C\" __global__ void";
-}
+void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }
std::string CodeGenCUDA::Finish() {
if (enable_fp16_) {
void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
- var_idmap_[iv->var.get()] =
- CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
+ var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}
void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
- CHECK_EQ(lanes, 1)
- << "do not yet support vector types";
- os << "void*"; return;
+ CHECK_EQ(lanes, 1) << "do not yet support vector types";
+ os << "void*";
+ return;
}
bool fail = false;
if (t.is_float()) {
fail = true;
}
break;
- case 32: os << "float"; break;
- case 64: os << "double"; break;
- default: fail = true; break;
+ case 32:
+ os << "float";
+ break;
+ case 64:
+ os << "double";
+ break;
+ default:
+ fail = true;
+ break;
}
if (!fail && (lanes == 1 || t.bits() == 16)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
} else if (t == DataType::Bool()) {
- os << "bool"; return;
+ os << "bool";
+ return;
} else if (t.is_vector_bool()) {
// CUDA does not support bool vectors.
// Use ushort vectors to represent instead.
int n = t.lanes();
if (n <= 4) {
- os << "ushort" << n; return;
+ os << "ushort" << n;
+ return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
switch (t.bits()) {
case 1: {
if (t.lanes() == 1) {
- os << "int"; return;
+ os << "int";
+ return;
} else if (t.lanes() == 8) {
- os << "int8_t"; return;
+ os << "int8_t";
+ return;
} else if (t.lanes() == 16) {
- os << "int16_t"; return;
+ os << "int16_t";
+ return;
} else if (t.lanes() == 32) {
- os << "int"; return;
+ os << "int";
+ return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 4: {
if (t.lanes() == 1) {
- os << "int"; return;
+ os << "int";
+ return;
} else if (t.lanes() == 4) {
- os << "int16_t"; return;
+ os << "int16_t";
+ return;
} else if (t.lanes() == 8) {
// directly 8 4-bit int in integer.
- os << "int"; return;
+ os << "int";
+ return;
} else if (t.lanes() == 16) {
- os << "int2"; return;
+ os << "int2";
+ return;
} else if (t.lanes() == 32) {
- os << "int4"; return;
+ os << "int4";
+ return;
} else if (t.lanes() == 64) {
- os << "int8"; return;
+ os << "int8";
+ return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
// We use int for int8x4 instead of char4 because using char4 is
// likely to produce extra instructions to pack four int8 elements
// into 32-bit data.
- os << "int"; return;
+ os << "int";
+ return;
} else if (t.lanes() == 8) {
enable_int8_ = true;
- os << "int2"; return;
+ os << "int2";
+ return;
} else if (t.lanes() == 16) {
enable_int8_ = true;
- os << "int4"; return;
+ os << "int4";
+ return;
} else if (!t.is_uint() && t.lanes() == 1) {
- os << "signed char"; break;
+ os << "signed char";
+ break;
} else {
- os << "char"; break;
+ os << "char";
+ break;
}
}
- case 16: os << "short"; break;
- case 32: os << "int"; break;
+ case 16:
+ os << "short";
+ break;
+ case 32:
+ os << "int";
+ break;
case 64: {
- if (sizeof(long) != 8) { // NOLINT(*)
+ if (sizeof(long) != 8) { // NOLINT(*)
if (t.lanes() == 1) {
- os << "long long"; break;
+ os << "long long";
+ break;
} else if (t.lanes() == 2) {
- os << "longlong"; break;
+ os << "longlong";
+ break;
} else {
// No longlong3, longlong4
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type on a L32 platform";
break;
}
} else {
- os << "long"; break;
+ os << "long";
+ break;
}
}
- default: fail = true; break;
+ default:
+ fail = true;
+ break;
}
if (!fail && lanes == 1) {
return;
}
if (!fail && (lanes >= 2 && lanes <= 4)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
-void CodeGenCUDA::PrintVecBinaryOp(
- const std::string& op, DataType t,
- PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
+ std::ostream& os) { // NOLINT(*)
// Delcare the result.
std::string sret = GetUniqueName("_");
this->PrintIndent();
os << sret;
}
-void CodeGenCUDA::PrintVecElemLoad(
- const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
+ std::ostream& os) { // NOLINT(*)
if (t.is_scalar()) {
os << vec;
return;
os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
}
} else if (t.is_float16()) {
- os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
- << access[i % 2];
+ os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else {
os << vec << "." << access[i];
}
}
-void CodeGenCUDA::PrintVecElemStore(
- const std::string& vec, DataType t, int i, const std::string& value) {
+void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
+ const std::string& value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
stream << "(" << value << " << " << i * 8 << ");\n";
}
} else if (t.is_float16()) {
- stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
- << access[i % 2] << " = " << value << ";\n";
+ stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
+ << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
} else if (sync == "global") {
if (!need_global_barrier_) {
need_global_barrier_ = true;
- this->decl_stream << "extern \"C\" __device__ unsigned "
- << vid_global_barrier_state_ << ";\n";
+ this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_
+ << ";\n";
}
// global synchronizer
std::string is_load = PrintExpr(op->args[1]);
this->PrintIndent();
// In theory only threadfence is needed
// but we observed problems with only threadfence
- this->stream <<"__threadfence_system();\n";
+ this->stream << "__threadfence_system();\n";
this->PrintIndent();
- this->stream <<"if (" << is_load << ") {\n";
+ this->stream << "if (" << is_load << ") {\n";
int wb = this->BeginScope();
this->PrintIndent();
this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
this->PrintIndent();
std::string ptr = GetUniqueName("pf");
- this->stream << "volatile unsigned* "
- << ptr << " = &" << vid_global_barrier_state_<< ";\n";
+ this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n";
this->PrintIndent();
this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
this->PrintIndent();
- this->stream <<"while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n";
+ this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n";
this->EndScope(wb);
this->PrintIndent();
- this->stream <<"}\n";
+ this->stream << "}\n";
this->PrintIndent();
- this->stream <<"__syncthreads();\n";
+ this->stream << "__syncthreads();\n";
}
}
-void CodeGenCUDA::PrintStorageScope(
- const std::string& scope, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global");
if (scope == "shared") {
os << "__shared__ ";
CHECK_EQ(target_ty.lanes(), from_ty.lanes());
// Emit simple C-style type conversion.
- if (from_ty.is_scalar())
- return CodeGenC::VisitExpr_(op, os);
+ if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// This is only for backward compatibility with __shfl_{up/down}.
// A macro will be used to replace *_sync calls to legacy ones.
- if (op->is_intrinsic("__shfl_sync") ||
- op->is_intrinsic("__shfl_up_sync") ||
+ if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") ||
op->is_intrinsic("__shfl_down_sync")) {
enable_warp_shuffle_ = true;
}
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
- if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
+ if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
- os << "]" << ((i < 3) ? ", ": ")");
+ os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
need_mma_h_ = true;
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
- os << "]" << ((i < 3) ? ", ": ")");
+ os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) {
//
std::ostringstream scall;
scall << op->name << "(";
for (size_t j = 0; j < op->args.size(); ++j) {
- if (j > 0)
- scall << ", ";
+ if (j > 0) scall << ", ";
PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall);
}
scall << ")";
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
- CHECK_GT(constant_size, 0)
- << "Can only handle constant size stack allocation for now";
+ CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
- CHECK(op->dtype == DataType::Float(16) ||
- op->dtype == DataType::Int(8) ||
- op->dtype == DataType::UInt(8) ||
- op->dtype == DataType::Int(4) ||
- op->dtype == DataType::UInt(4) ||
- op->dtype == DataType::Int(1))
- << "Matrix_a and matrix_b only support half or char or unsigned char "
- << "or uint4 or int4 or int1 type for now";
+ CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
+ op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
+ op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1))
+ << "Matrix_a and matrix_b only support half or char or unsigned char "
+ << "or uint4 or int4 or int1 type for now";
} else {
- CHECK(op->dtype == DataType::Float(16) ||
- op->dtype == DataType::Float(32) ||
+ CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
op->dtype == DataType::Int(32))
- << "Accumulator only support half, float and int type for now";
+ << "Accumulator only support half, float and int type for now";
}
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->dtype, buffer, stream);
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
}
- if ((op->dtype == DataType::Int(4) ||
- op->dtype == DataType::UInt(4) ||
- op->dtype == DataType::Int(1)) && scope == "shared") {
+ if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
+ op->dtype == DataType::Int(1)) &&
+ scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
- stream << ' '<< vid << '['
- << constant_size << "];\n";
+ stream << ' ' << vid << '[' << constant_size << "];\n";
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
-void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) {
+void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
if (is_const(op->value)) return;
const CallNode* call = op->value.as<CallNode>();
if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) {
void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
os << "((make_int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
- os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
- if (i != op->lanes - 1)
- os << ", ";
+ os << "(" << PrintExpr(op->base) << ")"
+ << "+(" << PrintExpr(op->stride) << "*" << i << ")";
+ if (i != op->lanes - 1) os << ", ";
}
os << "))";
}
-void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) {
// make_int8x4
- const int64_t *p = as_const_int(op->value);
+ const int64_t* p = as_const_int(op->value);
CHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
os << '(';
for (int i = 0; i < op->lanes / 2; ++i) {
if (i != 0) os << ", ";
- os << "__pack_half2(" << v << ", " << v << ")";
+ os << "__pack_half2(" << v << ", " << v << ")";
}
os << ')';
return;
os << ')';
}
-void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
+void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
PrintType(op->dtype, os);
os << '(';
for (int i = 0, e = op->indices.size(); i < e; ++i) {
- const int64_t *val = as_const_int(op->indices[i]);
- CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size());
+ const int64_t* val = as_const_int(op->indices[i]);
+ CHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
if (i != 0) os << ", ";
os << to_shuffle[*val];
}
os << ')';
}
-void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
+void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
// Non-vector cases.
if (!op->dtype.is_vector()) {
CodeGenC::VisitExpr_(op, os);
}
// Codegen vector condition case by serializing the select op.
- CHECK(op->false_value->dtype == op->dtype &&
- op->true_value->dtype == op->dtype &&
+ CHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype &&
op->dtype.lanes() == op->condition.dtype().lanes());
std::string r_var = GetUniqueName("_");
os << r_var;
}
-inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
+inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
switch (op->dtype.bits()) {
- case 64: case 32: {
+ case 64:
+ case 32: {
std::ostringstream temp;
if (std::isinf(op->value)) {
if (op->value < 0) {
os << '(' << std::scientific << op->value << 'f' << ')';
break;
}
- default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
+ default:
+ LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
-
-void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
-void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
- const VarNode* variable, std::ostream &os) {
+void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
+ std::ostream& os) {
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes[variable];
if (scope == "wmma.matrix_a") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
- os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, "
- << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
+ os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
+ << ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.matrix_b") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
- os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, "
- << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
+ os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
+ << ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.accumulator") {
need_mma_h_ = true;
- os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
- << shape_str << ", "<< type.str() << ">";
+ os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str()
+ << ">";
}
}
-int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
- const VarNode* variable, int32_t size) {
+int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
+ int32_t size) {
std::string shape_str = fragment_shapes[variable];
size_t m, n, k;
size_t last_pos = 0, pos = 0;
return 0;
}
-void CodeGenCUDA::HandleVolatileLoads(const std::string& value,
- const LoadNode* op, std::ostream& os) {
+void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op,
+ std::ostream& os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
}
}
-void CodeGenCUDA::PrintVecElemLoadExpr(
- DataType t, int i, const std::string& value, std::ostream& os) {
+void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
+ std::ostream& os) {
CHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) {
os << "|";
}
- os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
+ os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
}
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
+
#include <string>
#include <unordered_map>
+
#include "codegen_c.h"
namespace tvm {
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
- void PrintVecBinaryOp(
- const std::string& op, DataType t,
- PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*)
- void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
- void PrintVecElemLoad(
- const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*)
- void PrintVecElemStore(
- const std::string& vec, DataType t, int i, const std::string& value) final;
+ void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
+ std::ostream& os) final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
+ void PrintVecElemLoad(const std::string& vec, DataType t, int i,
+ std::ostream& os) final; // NOLINT(*)
+ void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
// overload visitor
- void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
- void VisitExpr_(const CallNode *op, std::ostream& os) final;
+ void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
+ void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
- void VisitStmt_(const EvaluateNode *op) final;
- void VisitStmt_(const AllocateNode *op) final;
- void VisitStmt_(const AttrStmtNode *op) final;
+ void VisitStmt_(const EvaluateNode* op) final;
+ void VisitStmt_(const AllocateNode* op) final;
+ void VisitStmt_(const AttrStmtNode* op) final;
private:
// Handle volatile loads
- void HandleVolatileLoads(const std::string& value, const LoadNode* op,
- std::ostream& os) final;
+ void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type.
- bool IsScopePartOfType() const final {
- return false;
- }
+ bool IsScopePartOfType() const final { return false; }
// Whether global barrier is needed.
bool need_global_barrier_{false};
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
- void PrintWmmaScope(
- const std::string& scope, DataType t, const VarNode* variable, std::ostream& os);
- int32_t GetWmmaFragmentSize(
- const std::string &scope, const VarNode* variable, int32_t size);
+ void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
+ std::ostream& os);
+ int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
};
} // namespace codegen
/*!
* \file codegen_metal.cc
*/
-#include <vector>
-#include <string>
-#include <algorithm>
#include "codegen_metal.h"
-#include "../build_common.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
#include "../../runtime/metal/metal_module.h"
#include "../../runtime/thread_storage_scope.h"
+#include "../build_common.h"
namespace tvm {
namespace codegen {
// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- CHECK(global_symbol.defined())
- << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
+ CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
// Function header.
this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
size_t num_buffer = 0;
for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
Var v = f->params[i];
- if (!v.dtype().is_handle()) break;
+ if (!v.dtype().is_handle()) break;
stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
RegisterHandleType(v.get(), prim->dtype);
}
}
- stream << ' ' << vid
- << " [[ buffer(" << i << ") ]],\n";
+ stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
size_t nargs = f->params.size() - num_buffer;
std::string varg = GetUniqueName("arg");
if (nargs != 0) {
- std::string arg_buf_type =
- static_cast<std::string>(global_symbol.value()) + "_args_t";
- stream << " constant " << arg_buf_type << "& " << varg
- << " [[ buffer(" << num_buffer << ") ]],\n";
+ std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
+ stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
+ << ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < f->params.size(); ++i) {
CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
int work_dim = 0;
- auto thread_axis = f->GetAttr<Array<tir::IterVar>>(
- tir::attr::kDeviceThreadAxis).value();
+ auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis).value();
for (IterVar iv : thread_axis) {
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
- CHECK_EQ(lanes, 1)
- << "do not yet support vector types";
- os << "void*"; return;
+ CHECK_EQ(lanes, 1) << "do not yet support vector types";
+ os << "void*";
+ return;
}
if (t == DataType::Bool()) {
- os << "bool"; return;
+ os << "bool";
+ return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
- case 16: os << "half"; break;
- case 32: os << "float"; break;
- default: fail = true; break;
+ case 16:
+ os << "half";
+ break;
+ case 32:
+ os << "float";
+ break;
+ default:
+ fail = true;
+ break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
}
if (t.bits() == 8 && t.lanes() == 4) {
// directly 4 8 bit int in integer.
- os << "int"; return;
+ os << "int";
+ return;
}
switch (t.bits()) {
- case 8: os << "char"; break;
- case 16: os << "short"; break;
- case 32: os << "int"; break;
- case 1: os << "bool"; break;
- default: fail = true; break;
+ case 8:
+ os << "char";
+ break;
+ case 16:
+ os << "short";
+ break;
+ case 32:
+ os << "int";
+ break;
+ case 1:
+ os << "bool";
+ break;
+ default:
+ fail = true;
+ break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
}
}
-void CodeGenMetal::PrintVecElemLoad(const std::string& vec,
- DataType t, int i,
+void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << "[" << i << "]";
}
-void CodeGenMetal::PrintVecElemStore(const std::string& vec,
- DataType t, int i,
+void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << "[" << i << "]"
<< " = " << value << ";\n";
}
-void CodeGenMetal::PrintStorageScope(
- const std::string& scope, std::ostream& os) { // NOLINT(*)
+void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "device ";
} else if (scope == "shared") {
}
}
-void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
PrintType(op->dtype, os);
os << "(";
CodeGenMetal cg;
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenMetal: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source);
}
-TVM_REGISTER_GLOBAL("target.build.metal")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildMetal(args[0]);
- });
+TVM_REGISTER_GLOBAL("target.build.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = BuildMetal(args[0]);
+});
} // namespace codegen
} // namespace tvm
#define TVM_TARGET_SOURCE_CODEGEN_METAL_H_
#include <tvm/target/codegen.h>
+
#include <string>
+
#include "codegen_c.h"
namespace tvm {
CodeGenMetal();
// override print thread tag.
void PrintArgUnionDecl();
- void AddFunction(const PrimFunc& f); // NOLINT(*)
+ void AddFunction(const PrimFunc& f); // NOLINT(*)
void InitFuncState(const PrimFunc& f) final;
- void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
- void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
- void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
- void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
+ void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
+ void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
+ void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// print load of single element
- void PrintVecElemLoad(
- const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*)
+ void PrintVecElemLoad(const std::string& vec, DataType t, int i,
+ std::ostream& os) final; // NOLINT(*)
// print store of single element.
- void PrintVecElemStore(
- const std::string& vec, DataType t, int i, const std::string& value) final;
+ void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
- void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// overload visitor
- void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
// reuse parent's function.
using CodeGenC::PrintType;
/*!
* \file codegen_opencl.cc
*/
+#include "codegen_opencl.h"
+
#include <cmath>
-#include <vector>
#include <string>
-#include "codegen_opencl.h"
-#include "../build_common.h"
-#include "../../runtime/thread_storage_scope.h"
+#include <vector>
+
#include "../../runtime/opencl/opencl_module.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "../build_common.h"
namespace tvm {
namespace codegen {
-CodeGenOpenCL::CodeGenOpenCL() {
- restrict_keyword_ = "restrict";
-}
+CodeGenOpenCL::CodeGenOpenCL() { restrict_keyword_ = "restrict"; }
void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
}
}
-void CodeGenOpenCL::PrintFuncPrefix() {
- stream << "__kernel void";
-}
+void CodeGenOpenCL::PrintFuncPrefix() { stream << "__kernel void"; }
std::string CodeGenOpenCL::Finish() {
// inject extension enable pragma for fp16 and fp64
if (enable_fp16_) {
- decl_stream
- << "#ifdef cl_khr_fp16\n"
- "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
- "#elif defined(cl_amd_fp16)\n"
- "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
- "#else\n"
- "#error \"Half precision floating point not supported"
- "by OpenCL implementation on your device.\" \n"
- "#endif\n\n";
+ decl_stream << "#ifdef cl_khr_fp16\n"
+ "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
+ "#elif defined(cl_amd_fp16)\n"
+ "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
+ "#else\n"
+                   "#error \"Half precision floating point not supported "
+ "by OpenCL implementation on your device.\" \n"
+ "#endif\n\n";
}
if (enable_fp64_) {
- decl_stream
- << "#ifdef cl_khr_fp64\n"
- "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
- "#elif defined(cl_amd_fp64)\n"
- "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
- "#else\n"
- "#error \"Double precision floating point not supported"
- "by OpenCL implementation on your device.\" \n"
- "#endif\n\n";
+ decl_stream << "#ifdef cl_khr_fp64\n"
+ "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
+ "#elif defined(cl_amd_fp64)\n"
+ "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
+ "#else\n"
+                   "#error \"Double precision floating point not supported "
+ "by OpenCL implementation on your device.\" \n"
+ "#endif\n\n";
}
return CodeGenC::Finish();
} else {
os << "get_group_id(" << ts.dim_index << ")";
}
- var_idmap_[iv->var.get()] =
- CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype());
+ var_idmap_[iv->var.get()] = CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype());
}
void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
- CHECK_EQ(lanes, 1)
- << "do not yet support vector types";
- os << "void*"; return;
+ CHECK_EQ(lanes, 1) << "do not yet support vector types";
+ os << "void*";
+ return;
}
if (t == DataType::Bool()) {
- os << "bool"; return;
+ os << "bool";
+ return;
}
bool fail = false;
if (t.is_float()) {
os << "half";
enable_fp16_ = true;
break;
- case 32: os << "float"; break;
+ case 32:
+ os << "float";
+ break;
case 64:
os << "double";
enable_fp64_ = true;
break;
- default: fail = true; break;
+ default:
+ fail = true;
+ break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
}
if (t.bits() == 8 && t.lanes() == 4) {
// directly 4 8 bit int in integer.
- os << "int"; return;
+ os << "int";
+ return;
}
switch (t.bits()) {
- case 8: os << "char"; break;
- case 16: os << "short"; break;
- case 32: os << "int"; break;
- case 64: os << "long"; break;
- case 1: os << "int"; break;
- default: fail = true; break;
+ case 8:
+ os << "char";
+ break;
+ case 16:
+ os << "short";
+ break;
+ case 32:
+ os << "int";
+ break;
+ case 64:
+ os << "long";
+ break;
+ case 1:
+ os << "int";
+ break;
+ default:
+ fail = true;
+ break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
- os << lanes; return;
+ os << lanes;
+ return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}
-void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
- PrimExpr base, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base,
+ std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
auto it = alloc_storage_scope_.find(buffer);
os << GetVarID(buffer) << " + ";
PrintExpr(base, os);
}
-std::string CodeGenOpenCL::GetVecLoad(
- DataType t, const VarNode* buffer, PrimExpr base) {
+std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
return os.str();
}
-void CodeGenOpenCL::PrintVecStore(const VarNode* buffer,
- DataType t, PrimExpr base,
+void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
const std::string& value) {
this->PrintIndent();
stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
}
}
-void CodeGenOpenCL::PrintStorageScope(
- const std::string& scope, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "__global ";
} else if (scope == "shared") {
return os.str();
}
-void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
PrintType(op->dtype, os);
os << "))";
}
-void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
os << "-";
CodeGenOpenCL cg;
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenOpenCL: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code);
}
-TVM_REGISTER_GLOBAL("target.build.opencl")
-.set_body_typed(BuildOpenCL);
+TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL);
} // namespace codegen
} // namespace tvm
#define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
#include <tvm/target/codegen.h>
+
#include <string>
+
#include "codegen_c.h"
namespace tvm {
// override print thread tag.
void InitFuncState(const PrimFunc& f) final;
- void PrintFuncPrefix() final; // NOLINT(*)
- void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
- void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
- void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
- void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
- std::string GetVecLoad(DataType t, const VarNode* buffer,
- PrimExpr base) final;
- void PrintVecStore(const VarNode* buffer,
- DataType t, PrimExpr base,
+ void PrintFuncPrefix() final; // NOLINT(*)
+ void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
+ void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
+ void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
+ std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final;
+ void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
- void PrintVecAddr(const VarNode* buffer, DataType t,
- PrimExpr base, std::ostream& os); // NOLINT(*)
- std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
+ void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base,
+ std::ostream& os); // NOLINT(*)
+ std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
// overload visitor
- void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
private:
// whether enable fp16 and fp64 extension
* We are targeting OpenGL 3.3. The reason of not targeting a recent version
* of OpenGL is to have better compatibility of WebGL 2.
*/
-#include <vector>
+#include "codegen_opengl.h"
+
#include <string>
-#include <utility>
#include <unordered_map>
-#include "codegen_opengl.h"
-#include "../build_common.h"
+#include <utility>
+#include <vector>
+
#include "../../runtime/thread_storage_scope.h"
+#include "../build_common.h"
namespace tvm {
namespace codegen {
-CodeGenOpenGL::CodeGenOpenGL()
- : output_(nullptr), output_iter_var_(nullptr) {}
+CodeGenOpenGL::CodeGenOpenGL() : output_(nullptr), output_iter_var_(nullptr) {}
void CodeGenOpenGL::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
- shaders_[static_cast<std::string>(global_symbol.value())] = runtime::OpenGLShader(
- this->decl_stream.str() + this->stream.str(),
- std::move(arg_names), std::move(arg_kinds),
- this->thread_extent_var_);
+ shaders_[static_cast<std::string>(global_symbol.value())] =
+ runtime::OpenGLShader(this->decl_stream.str() + this->stream.str(), std::move(arg_names),
+ std::move(arg_kinds), this->thread_extent_var_);
}
-std::unordered_map<std::string, runtime::OpenGLShader> CodeGenOpenGL::Finish() {
- return shaders_;
-}
+std::unordered_map<std::string, runtime::OpenGLShader> CodeGenOpenGL::Finish() { return shaders_; }
void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
CHECK_EQ(iv->thread_tag, "threadIdx.x") << "Must be threadIdx.x";
- CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end())
- << "Only support one thread iter var";
+ CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end()) << "Only support one thread iter var";
CHECK(output_iter_var_ == nullptr) << "Only support one thread iter var";
var_idmap_[iv->var.get()] = iv->thread_tag;
// Print a reference expression to a buffer.
// Format: texelFetch(buffer, index, 0).r
-std::string CodeGenOpenGL::GetBufferRef(
- DataType t, const VarNode* buffer, PrimExpr index) {
+std::string CodeGenOpenGL::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) {
CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
// Doesn't support store to vector.
auto type = value.dtype();
- CHECK_EQ(type.lanes(), 1)
- << "Vectorized store not implemented, type = " << type;
+ CHECK_EQ(type.lanes(), 1) << "Vectorized store not implemented, type = " << type;
CHECK(inputs_.find(buffer) == inputs_.cend())
- << "Texture has been read from before. Must not store to it.";
+ << "Texture has been read from before. Must not store to it.";
if (output_ == nullptr) {
output_ = buffer; // Record that this texture is the output.
} else {
CodeGenOpenGL cg;
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenOpenGL: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenGL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(mod));
}
-TVM_REGISTER_GLOBAL("target.build.opengl")
-.set_body_typed(BuildOpenGL);
+TVM_REGISTER_GLOBAL("target.build.opengl").set_body_typed(BuildOpenGL);
} // namespace codegen
} // namespace tvm
#define TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_
#include <tvm/target/codegen.h>
+
#include <string>
-#include <unordered_set>
#include <unordered_map>
-#include "codegen_c.h"
+#include <unordered_set>
+
#include "../../runtime/opengl/opengl_module.h"
+#include "codegen_c.h"
namespace tvm {
namespace codegen {
void VisitStmt_(const StoreNode* op) final;
std::string TexelFetch(const VarNode* buffer, PrimExpr index);
std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) final;
- void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// Codegen for immediate values
- void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*)
// Match glsl_texture_store Call.
}
std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) {
- CHECK(!var_idmap_.count(v))
- << "Need input to be in SSA form dup " << v->name_hint;
+ CHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint;
std::string vid = GetUniqueName(key);
var_idmap_[v] = vid;
std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const {
auto it = var_idmap_.find(v);
- CHECK(it != var_idmap_.end())
- << "Find undefined Variable " << v->name_hint;
+ CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint;
return it->second;
}
#ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
#define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
+#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <tvm/target/codegen.h>
-#include <string>
-#include <vector>
+
#include <functional>
+#include <string>
#include <unordered_map>
+#include <vector>
+
#include "../../runtime/meta_data.h"
namespace tvm {
* \param src The source expression.
* \param t The type of target.
*/
- virtual void PrintSSAAssign(
- const std::string& target, const std::string& src, DataType t) = 0;
+ virtual void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) = 0;
/*! \brief the declaration stream */
std::ostringstream decl_stream;
* \param fget_source a closure to replace default get source behavior.
*/
runtime::Module DeviceSourceModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, runtime::FunctionInfo> fmap,
- std::string type_key,
- std::function<std::string(const std::string&)> fget_source = nullptr);
+ std::string data, std::string fmt, std::unordered_map<std::string, runtime::FunctionInfo> fmap,
+ std::string type_key, std::function<std::string(const std::string&)> fget_source = nullptr);
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
/*!
* \file codegen_vhls.cc
*/
-#include <vector>
-#include <string>
#include "codegen_vhls.h"
-#include "../build_common.h"
+
+#include <string>
+#include <vector>
+
#include "../../runtime/opencl/sdaccel/sdaccel_module.h"
+#include "../build_common.h"
namespace tvm {
namespace codegen {
if (t.is_uint()) {
switch (t.bits()) {
case 8:
- os << "unsigned char"; break;
+ os << "unsigned char";
+ break;
case 16:
- os << "unsigned short"; break;
+ os << "unsigned short";
+ break;
case 32:
- os << "unsigned int"; break;
+ os << "unsigned int";
+ break;
case 64:
- os << "unsigned long long"; break;
+ os << "unsigned long long";
+ break;
default:
- os << "ap_uint<" << t.bits() << ">"; break;
+ os << "ap_uint<" << t.bits() << ">";
+ break;
}
} else if (t.is_int()) {
switch (t.bits()) {
case 8:
- os << "char"; break;
+ os << "char";
+ break;
case 16:
- os << "short"; break;
+ os << "short";
+ break;
case 32:
- os << "int"; break;
+ os << "int";
+ break;
case 64:
- os << "long long"; break;
+ os << "long long";
+ break;
default:
- os << "ap_int<" << t.bits() << ">"; break;
+ os << "ap_int<" << t.bits() << ">";
+ break;
}
} else {
CodeGenC::PrintType(t, os);
}
}
-void CodeGenVivadoHLS::PrintFuncPrefix() {
- stream << "extern \"C\" void";
-}
+void CodeGenVivadoHLS::PrintFuncPrefix() { stream << "extern \"C\" void"; }
void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n";
}
-template<typename T>
-inline void PrintBinaryExpr(const T* op,
- const char *opstr,
+template <typename T>
+inline void PrintBinaryExpr(const T* op, const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenVivadoHLS* p) {
os << opstr << '(';
os << ')';
}
-void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
- const char *opstr = "std::min";
+void CodeGenVivadoHLS::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
+ const char* opstr = "std::min";
if (op->dtype.is_float()) {
switch (op->dtype.bits()) {
case 32:
- opstr = "fminf"; break;
+ opstr = "fminf";
+ break;
case 64:
- opstr = "fmin"; break;
+ opstr = "fmin";
+ break;
}
}
PrintBinaryExpr(op, opstr, os, this);
}
-void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
- const char *opstr = "std::max";
+void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
+ const char* opstr = "std::max";
if (op->dtype.is_float()) {
switch (op->dtype.bits()) {
case 32:
- opstr = "fmaxf"; break;
+ opstr = "fmaxf";
+ break;
case 64:
- opstr = "fmax"; break;
+ opstr = "fmax";
+ break;
}
}
PrintBinaryExpr(op, opstr, os, this);
}
-
runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
using tvm::runtime::Registry;
bool output_ssa = false;
// Generate source code for get_source().
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenVHLS: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
// Generate source code for compilation.
Array<Array<runtime::String> > kernel_info;
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenOpenCL: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
CodeGenVivadoHLS cg;
cg.Init(output_ssa);
return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code);
}
-TVM_REGISTER_GLOBAL("target.build.sdaccel")
-.set_body_typed(BuildSDAccel);
+TVM_REGISTER_GLOBAL("target.build.sdaccel").set_body_typed(BuildSDAccel);
} // namespace codegen
} // namespace tvm
#include <tvm/target/codegen.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
+
#include <string>
+
#include "codegen_c.h"
namespace tvm {
void PrintFuncPrefix() final;
void PreFunctionBody(const PrimFunc& f) final;
- void VisitExpr_(const MinNode *op, std::ostream& os) final;
- void VisitExpr_(const MaxNode *op, std::ostream& os) final;
+ void VisitExpr_(const MinNode* op, std::ostream& os) final;
+ void VisitExpr_(const MaxNode* op, std::ostream& os) final;
};
} // namespace codegen
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt")
-.set_body(DispatchExtern<Direct>);
-
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow")
-.set_body(DispatchExtern<Direct>);
-
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
std::string operator()(DataType t, std::string name) const {
if (t.is_float()) {
switch (t.bits()) {
- case 64: return name;
- case 32: return name + 'f';
- case 16: return 'h' + name;
- default: return "";
+ case 64:
+ return name;
+ case 32:
+ return name + 'f';
+ case 16:
+ return 'h' + name;
+ default:
+ return "";
}
}
return "";
struct CUDAFastMathTan : public CUDAMath {
std::string operator()(DataType t, std::string name) const {
if (t.is_float()) {
- switch (t.bits()) {
- case 64: return name;
- // `__tanf` seems to produce some values too deviant from numpy tan version.
- // So, let's use just `tanf` instead.
- case 32: return name + 'f';
- case 16: LOG(FATAL) << "cuda tan unsupported for float16";
- default: return "";
- }
+ switch (t.bits()) {
+ case 64:
+ return name;
+ // `__tanf` seems to produce some values too deviant from numpy tan version.
+ // So, let's use just `tanf` instead.
+ case 32:
+ return name + 'f';
+ case 16:
+ LOG(FATAL) << "cuda tan unsupported for float16";
+ default:
+ return "";
+ }
}
return "";
}
std::string operator()(DataType t, std::string name) const {
if (t.is_uint()) {
switch (t.bits()) {
- case 32: return "__popc";
- case 64: return "__popcll";
- default: return "";
+ case 32:
+ return "__popc";
+ case 64:
+ return "__popcll";
+ default:
+ return "";
}
}
return "";
}
};
-
struct CUDAWarpIntrinsic {
const char* operator()(DataType t, const std::string& name) const {
if (name == intrinsic::tvm_warp_shuffle) {
*rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
}
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
-.set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10")
-.set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
-.set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2")
-.set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10")
-.set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
-.set_body(DispatchExtern<CUDAFastMathTan>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchExtern<CUDAFastMathTan>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
-.set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
-.set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
-.set_body(DispatchExtern<CUDAPopcount>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchExtern<CUDAPopcount>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
-.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+ .set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up")
-.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+ .set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down")
-.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+ .set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask")
-.set_body(DispatchExtern<CUDAWarpIntrinsic>);
+ .set_body(DispatchExtern<CUDAWarpIntrinsic>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
-.set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern<CUDAMath>);
} // namespace intrin
} // namespace codegen
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
* \brief OpenCL intrinsic rules.
*/
#include <tvm/arith/analyzer.h>
+
#include "../intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchExtern<Direct>);
// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
arith::Analyzer analyzer;
CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
- << "Intel warp shuffle dose not support width != warp_size";
+        << "Intel warp shuffle does not support width != warp_size";
Array<PrimExpr> opencl_args{{call->args[1], call->args[2]}};
- *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle",
- opencl_args, CallNode::PureExtern);
+ *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern);
}
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
-.set_body(DispatchIntelShuffle);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle);
} // namespace intrin
} // namespace codegen
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh").set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh")
-.set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include "codegen_source_base.h"
+
#include "../../runtime/file_util.h"
#include "../../runtime/meta_data.h"
+#include "codegen_source_base.h"
namespace tvm {
namespace codegen {
+using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
-using runtime::PackedFunc;
+using runtime::FunctionInfo;
using runtime::GetFileFormat;
using runtime::GetMetaFilePath;
-using runtime::FunctionInfo;
using runtime::SaveBinaryToFile;
// Simulator function
class SourceModuleNode : public runtime::ModuleNode {
public:
- SourceModuleNode(std::string code,
- std::string fmt)
- : code_(code), fmt_(fmt) {}
- const char* type_key() const {
- return "source";
- }
+ SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {}
+ const char* type_key() const { return "source"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
- std::string GetSource(const std::string& format) final {
- return code_;
- }
+ std::string GetSource(const std::string& format) final { return code_; }
protected:
std::string code_;
// Simulator function
class CSourceModuleNode : public runtime::ModuleNode {
public:
- CSourceModuleNode(std::string code,
- std::string fmt)
- : code_(code), fmt_(fmt) {}
- const char* type_key() const {
- return "c";
- }
+ CSourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {}
+ const char* type_key() const { return "c"; }
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "C Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
- std::string GetSource(const std::string& format) final {
- return code_;
- }
+ std::string GetSource(const std::string& format) final { return code_; }
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
if (fmt == "cc") {
CHECK_NE(code_.length(), 0);
SaveBinaryToFile(file_name, code_);
} else {
- CHECK_EQ(fmt, fmt_)
- << "Can only save to format=" << fmt_;
+ CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
}
}
// supports limited save without cross compile
class DeviceSourceModuleNode final : public runtime::ModuleNode {
public:
- DeviceSourceModuleNode(std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string type_key,
+ DeviceSourceModuleNode(std::string data, std::string fmt,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string type_key,
std::function<std::string(const std::string&)> fget_source)
- : data_(data),
- fmt_(fmt),
- fmap_(fmap),
- type_key_(type_key),
- fget_source_(fget_source) {}
-
- PackedFunc GetFunction(
- const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {}
+
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
}
- const char* type_key() const {
- return type_key_.c_str();
- }
+ const char* type_key() const { return type_key_.c_str(); }
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
- CHECK_EQ(fmt, fmt_)
- << "Can only save to format=" << fmt_;
+ CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
};
runtime::Module DeviceSourceModuleCreate(
- std::string data,
- std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string type_key,
- std::function<std::string(const std::string&)> fget_source) {
+ std::string data, std::string fmt, std::unordered_map<std::string, FunctionInfo> fmap,
+ std::string type_key, std::function<std::string(const std::string&)> fget_source) {
auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
return runtime::Module(n);
}
-TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate")
-.set_body_typed(SourceModuleCreate);
+TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate);
-TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate")
-.set_body_typed(CSourceModuleCreate);
+TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed(CSourceModuleCreate);
} // namespace codegen
} // namespace tvm
* \brief Build SPIRV block
*/
// Use libspirv for parsing and validating code.
-#include <libspirv.h>
#include <dmlc/memory_io.h>
+#include <libspirv.h>
#include <tvm/tir/transform.h>
-#include "codegen_spirv.h"
-#include "../build_common.h"
-
-#include "../../runtime/vulkan/vulkan_shader.h"
#include "../../runtime/vulkan/vulkan_module.h"
+#include "../../runtime/vulkan/vulkan_shader.h"
+#include "../build_common.h"
+#include "codegen_spirv.h"
namespace tvm {
namespace codegen {
class SPIRVTools {
public:
- SPIRVTools() {
- ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0);
- }
- ~SPIRVTools() {
- spvContextDestroy(ctx_);
- }
+ SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); }
+ ~SPIRVTools() { spvContextDestroy(ctx_); }
std::string BinaryToText(const std::vector<uint32_t>& bin) {
spv_text text = nullptr;
spv_diagnostic diagnostic;
spv_const_binary_t spv_bin{bin.data(), bin.size()};
spv_result_t res;
- res = spvBinaryToText(
- ctx_, spv_bin.code, spv_bin.wordCount,
- SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
- SPV_BINARY_TO_TEXT_OPTION_INDENT,
- &text, &diagnostic);
+ res =
+ spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount,
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT,
+ &text, &diagnostic);
- CHECK_EQ(res, SPV_SUCCESS)
- << " line=" << diagnostic->position.line
- << " column=" << diagnostic->position.column
- << " index=" << diagnostic->position.index
- << " error:" << diagnostic->error;
+ CHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line
+ << " column=" << diagnostic->position.column
+ << " index=" << diagnostic->position.index
+ << " error:" << diagnostic->error;
std::string ret(text->str);
spvTextDestroy(text);
CodeGenSPIRV cg;
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenSPIRV: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenSPIRV: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
if (webgpu_restriction) {
for (auto param : f->params) {
- CHECK(param.dtype().is_handle())
- << "WebGPU does not yet support non-buffer arguments";
+ CHECK(param.dtype().is_handle()) << "WebGPU does not yet support non-buffer arguments";
}
}
smap[f_name] = std::move(shader);
}
- return runtime::VulkanModuleCreate(
- smap, ExtractFuncInfo(mod), code_data.str());
+ return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str());
}
-TVM_REGISTER_GLOBAL("target.build.vulkan")
-.set_body_typed([](IRModule mod, std::string target) {
+TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, std::string target) {
return BuildSPIRV(mod, target, false);
});
-TVM_REGISTER_GLOBAL("target.build.webgpu")
-.set_body_typed([](IRModule mod, std::string target) {
+TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, std::string target) {
return BuildSPIRV(mod, target, true);
});
* \file codegen_spirv.cc
* \brief Generate SPIRV block
*/
-#include <tvm/tir/expr.h>
+#include "codegen_spirv.h"
+
#include <tvm/runtime/container.h>
+#include <tvm/tir/expr.h>
+
#include <string>
-#include "codegen_spirv.h"
+
#include "../../arith/compute_expr.h"
namespace tvm {
namespace codegen {
-std::vector<uint32_t> CodeGenSPIRV::BuildFunction(
- const PrimFunc& f,
- const std::string& name) {
+std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
this->InitFuncState();
- CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias))
- << "SPIRV only takes restricted memory model";
+ CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model";
std::vector<Var> pod_args;
uint32_t num_buffer = 0;
auto* prim = ptr->element_type.as<PrimTypeNode>();
CHECK(prim);
DataType value_type = prim->dtype;
- spirv::Value arg_value = builder_->BufferArgument(
- builder_->GetSType(value_type), 0, num_buffer);
+ spirv::Value arg_value =
+ builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer);
storage_info_[arg.get()].UpdateContentType(value_type);
var_map_[arg.get()] = arg_value;
} else {
}
spirv::Value ptr = builder_->DeclarePushConstant(value_types);
for (size_t i = 0; i < pod_args.size(); ++i) {
- spirv::Value value = builder_->GetPushConstant(
- ptr, value_types[i], static_cast<uint32_t>(i));
+ spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
}
}
builder_->InitHeader();
}
-spirv::Value CodeGenSPIRV::GetThreadIndex(
- const IterVar& iv, const PrimExpr& extent) {
+spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
spirv::Value v;
if (ts.rank == 1) {
v = builder_->GetLocalID(ts.dim_index);
auto* sizeptr = extent.as<tir::IntImmNode>();
- CHECK(sizeptr)
- << "SPIRV only allows constant thread group size " << " get " << extent;
+ CHECK(sizeptr) << "SPIRV only allows constant thread group size "
+ << " get " << extent;
CHECK_LT(ts.dim_index, 3);
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
} else {
} else if (sync == "shared") {
auto type_int = builder_->GetSType(DataType::Int(32));
builder_->MakeInst(
- spv::OpControlBarrier,
- builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
- builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
- builder_->IntImm(type_int, static_cast<int64_t>(
- spv::MemorySemanticsSequentiallyConsistentMask |
- spv::MemorySemanticsWorkgroupMemoryMask)));
+ spv::OpControlBarrier,
+ builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
+ builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
+ builder_->IntImm(type_int,
+ static_cast<int64_t>(spv::MemorySemanticsSequentiallyConsistentMask |
+ spv::MemorySemanticsWorkgroupMemoryMask)));
} else {
LOG(FATAL) << "Do not support sync " << sync;
}
}
spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
- return builder_->Select(MakeValue(op->condition),
- MakeValue(op->true_value),
+ return builder_->Select(MakeValue(op->condition), MakeValue(op->true_value),
MakeValue(op->false_value));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
if (op->is_intrinsic("spirv_glsl450")) {
CHECK_GE(op->args.size(), 2U);
- uint32_t inst_id = static_cast<uint32_t>(
- op->args[0].as<IntImmNode>()->value);
+ uint32_t inst_id = static_cast<uint32_t>(op->args[0].as<IntImmNode>()->value);
std::vector<spirv::Value> values;
for (size_t i = 1; i < op->args.size(); ++i) {
values.push_back(MakeValue(op->args[i]));
}
- return builder_->CallGLSL450(
- builder_->GetSType(op->dtype), inst_id, values);
+ return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values);
} else if (op->is_intrinsic(CallNode::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Label then_label = builder_->NewLabel();
spirv::Label else_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
- builder_->MakeInst(
- spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
- builder_->MakeInst(
- spv::OpBranchConditional, cond, then_label, else_label);
+ builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
+ builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
// then block, must get label after we see the value
builder_->StartLabel(then_label);
spirv::Value then_value = MakeValue(op->args[1]);
phi.SetIncoming(1, else_value, else_value_label);
return phi;
} else if (op->is_intrinsic("popcount")) {
- return builder_->MakeValue(
- spv::OpBitCount,
- builder_->GetSType(op->dtype),
- MakeValue(op->args[0]));
+ return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype),
+ MakeValue(op->args[0]));
} else {
- if (op->call_type == CallNode::Intrinsic ||
- op->call_type == CallNode::PureIntrinsic) {
- LOG(FATAL) << "Unresolved intrinsic " << op->name
- << " with return type " << op->dtype;
- } else if (op->call_type == CallNode::Extern ||
- op->call_type == CallNode::PureExtern) {
- LOG(FATAL) << "Unresolved extern " << op->name
- << " with return type " << op->dtype;
+ if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
+ LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype;
+ } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
+ LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype;
} else {
LOG(FATAL) << "Unresolved call type " << op->call_type;
}
for (int i = 0; i < op->lanes; ++i) {
spirv::Value v = base;
if (i != 0) {
- spirv::Value offset = MakeValue(
- make_const(op->stride.dtype(), i) * op->stride);
+ spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride);
v = builder_->Add(v, offset);
}
values.push_back(v);
spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::Value buffer = MakeValue(op->buffer_var);
- spirv::SType ptr_type = builder_->GetPointerType(
- content_type, buffer.stype.storage_class);
+ spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);
uint32_t mask = spv::MemoryAccessMaskNone;
if (info.is_volatile) {
CHECK_EQ(info.content_type, op->dtype)
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
- spirv::Value ptr = builder_->StructArrayAccess(
- ptr_type, buffer, index);
+ spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
} else {
if (op->dtype.element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
std::vector<spirv::Value> values;
auto f = [&](int i, spirv::Value index) {
- spirv::Value ptr = builder_->StructArrayAccess(
- ptr_type, buffer, index);
- values.emplace_back(
- builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
+ spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
+ values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
};
this->Scalarize(op->index, f);
return builder_->Concat(values);
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->dtype.lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
- CHECK((me->coeff % ramp->lanes) == 0 &&
- (me->base % ramp->lanes) == 0)
+ CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
- PrimExpr vec_index = analyzer_->Simplify(
- ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
- spirv::Value ptr = builder_->StructArrayAccess(
- ptr_type, buffer, MakeValue(vec_index));
+ PrimExpr vec_index =
+ analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
+ spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
}
}
return spirv::Value();
}
-void CodeGenSPIRV::Scalarize(const PrimExpr& e,
- std::function<void(int i, spirv::Value v)> f) {
+void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f) {
if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
PrimExpr offset = ramp->base + ramp->stride * i;
spirv::SType etype = builder_->GetSType(e.dtype().element_of());
spirv::Value value = MakeValue(e);
for (int i = 0; i < e.dtype().lanes(); ++i) {
- f(i, builder_->MakeValue(
- spv::OpCompositeExtract, etype, value, i));
+ f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i));
}
}
}
spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::Value buffer = MakeValue(op->buffer_var);
spirv::Value value = MakeValue(op->value);
- spirv::SType ptr_type = builder_->GetPointerType(
- content_type, buffer.stype.storage_class);
+ spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);
uint32_t mask = spv::MemoryAccessMaskNone;
if (info.is_volatile) {
CHECK_EQ(info.content_type, op->value.dtype())
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
- spirv::Value ptr = builder_->StructArrayAccess(
- ptr_type, buffer, index);
+ spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, value, mask);
} else {
if (op->value.dtype().element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
auto f = [&](int i, spirv::Value index) {
- spirv::Value elem = builder_->MakeValue(
- spv::OpCompositeExtract, content_type, value, i);
- spirv::Value ptr = builder_->StructArrayAccess(
- ptr_type, buffer, index);
+ spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
+ spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, elem, mask);
};
this->Scalarize(op->index, f);
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->value.dtype().lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
- CHECK((me->coeff % ramp->lanes) == 0 &&
- (me->base % ramp->lanes) == 0)
+ CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
- PrimExpr vec_index = analyzer_->Simplify(
- ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
- spirv::Value ptr = builder_->StructArrayAccess(
- ptr_type, buffer, MakeValue(vec_index));
+ PrimExpr vec_index =
+ analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
+ spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
builder_->MakeInst(spv::OpStore, ptr, value, mask);
return;
}
spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
loop_var.SetIncoming(0, init_value, init_label);
spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
- uint32_t control = (
- op->for_type == ForType::Unrolled ?
- spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
- builder_->MakeInst(
- spv::OpLoopMerge, merge_label, continue_label, control);
- builder_->MakeInst(
- spv::OpBranchConditional, loop_cond, body_label, merge_label,
- weight_likely_branch_, 1);
+ uint32_t control =
+ (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
+ builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
+ builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
+ weight_likely_branch_, 1);
// loop body
builder_->StartLabel(body_label);
// loop continue
builder_->StartLabel(continue_label);
- spirv::Value one =
- op->loop_var.dtype().is_int() ?
- builder_->IntImm(loop_var.stype, 1) :
- builder_->UIntImm(loop_var.stype, 1);
+ spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
+ : builder_->UIntImm(loop_var.stype, 1);
spirv::Value next_value = builder_->Add(loop_var, one);
loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
builder_->MakeInst(spv::OpBranch, head_label);
spirv::Label merge_label = builder_->NewLabel();
if (op->else_case.defined()) {
spirv::Label else_label = builder_->NewLabel();
- builder_->MakeInst(
- spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
- builder_->MakeInst(
- spv::OpBranchConditional, cond, then_label, else_label);
+ builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
+ builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
// then block
builder_->StartLabel(then_label);
this->VisitStmt(op->then_case);
this->VisitStmt(op->else_case);
builder_->MakeInst(spv::OpBranch, merge_label);
} else {
- builder_->MakeInst(
- spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
- builder_->MakeInst(
- spv::OpBranchConditional, cond, then_label, merge_label,
- weight_likely_branch_, 1);
+ builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
+ builder_->MakeInst(spv::OpBranchConditional, cond, then_label, merge_label,
+ weight_likely_branch_, 1);
// then block
builder_->StartLabel(then_label);
this->VisitStmt(op->then_case);
CHECK(!is_zero(op->condition));
CHECK(!op->dtype.is_handle());
int32_t constant_size = op->constant_allocation_size();
- CHECK_GT(constant_size, 0)
- << "Can only handle constant size stack allocation in GPU";
+ CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
spirv::Value buf;
StorageInfo& info = storage_info_[op->buffer_var.get()];
spirv::SType etype = builder_->GetSType(op->dtype);
if (info.scope.rank == runtime::StorageRank::kLocal) {
- buf = builder_->Allocate(
- etype, static_cast<uint32_t>(constant_size),
- spv::StorageClassFunction);
+ buf =
+ builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
} else {
// shared memory
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory
- buf = builder_->Allocate(
- etype, static_cast<uint32_t>(constant_size),
- spv::StorageClassWorkgroup);
+ buf =
+ builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassWorkgroup);
}
CHECK(!info.content_fixed);
info.UpdateContentType(op->dtype);
} else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>();
CHECK(v);
- storage_info_[v].scope =
- runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
+ storage_info_[v].scope = runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>();
CHECK(v);
}
}
-void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) {
- MakeValue(op->value);
-}
+void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }
} // namespace codegen
} // namespace tvm
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
-#include <vector>
#include <memory>
-#include <unordered_map>
#include <string>
+#include <unordered_map>
+#include <vector>
-#include "ir_builder.h"
#include "../../runtime/thread_storage_scope.h"
+#include "ir_builder.h"
namespace tvm {
namespace codegen {
/*!
* \brief Code generator into SPIRV
*/
-class CodeGenSPIRV:
- public ExprFunctor<spirv::Value(const PrimExpr&)>,
- public StmtFunctor<void(const Stmt&)> {
+class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
+ public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Compile and add function f to the current module.
* \param name The name of the target function.
* \return The final spirv module.
*/
- virtual std::vector<uint32_t> BuildFunction(const PrimFunc& f,
- const std::string& name);
+ virtual std::vector<uint32_t> BuildFunction(const PrimFunc& f, const std::string& name);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
* \return created value.
*/
- spirv::Value MakeValue(const PrimExpr& e) {
- return VisitExpr(e);
- }
+ spirv::Value MakeValue(const PrimExpr& e) { return VisitExpr(e); }
// override codegen
spirv::Value VisitExpr_(const VarNode* op) override;
spirv::Value VisitExpr_(const CastNode* op) override;
// Update content type if it hasn't beenupdated.
void UpdateContentType(DataType type) {
if (content_fixed) {
- CHECK_EQ(type, content_type)
- << "Cannot use two different content type in GLSL model";
+ CHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model";
} else {
this->content_type = type;
content_fixed = true;
// Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);
spirv::Value CreateStorageSync(const CallNode* op);
- void Scalarize(const PrimExpr& e,
- std::function<void(int i, spirv::Value v)> f);
+ void Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f);
// The builder
std::unique_ptr<spirv::IRBuilder> builder_;
// Work group size of three
} // namespace codegen
} // namespace tvm
-
#endif // TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
/*!
* \file intrin_rule_spirv.cc
*/
+#include <GLSL.std.450.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <GLSL.std.450.h>
namespace tvm {
namespace codegen {
using namespace runtime;
// num_signature means number of arguments used to query signature
-template<unsigned id>
+template <unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::CallNode::make(
- call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic);
+ *rv = tir::CallNode::make(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
+ .set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil").set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Round>);
+ .set_body(DispatchGLSLPureIntrin<GLSLstd450Round>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);
-
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);
+ .set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
// WebGPU rules.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
+ .set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil").set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.round")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Round>);
+ .set_body(DispatchGLSLPureIntrin<GLSLstd450Round>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.trunc")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);
+ .set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs").set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp").set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log").set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt").set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow").set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh")
-.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh").set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
} // namespace spirv
} // namespace codegen
// shader
ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_);
// memory model
- ib_.Begin(spv::OpMemoryModel).AddSeq(
- spv::AddressingModelLogical,
- spv::MemoryModelGLSL450).Commit(&entry_);
+ ib_.Begin(spv::OpMemoryModel)
+ .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450)
+ .Commit(&entry_);
this->InitPreDefs();
}
t_void_.id = id_counter_++;
ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_);
t_void_func_.id = id_counter_++;
- ib_.Begin(spv::OpTypeFunction)
- .AddSeq(t_void_func_, t_void_).Commit(&global_);
+ ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_);
}
SType IRBuilder::GetSType(const DataType& dtype) {
return t;
}
-SType IRBuilder::GetPointerType(const SType& value_type,
- spv::StorageClass storage_class) {
+SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass storage_class) {
CHECK_NE(storage_class, spv::StorageClassMax);
auto key = std::make_pair(value_type.id, storage_class);
auto it = pointer_type_tbl_.find(key);
t.type = DataType::Handle();
t.element_type_id = value_type.id;
t.storage_class = storage_class;
- ib_.Begin(spv::OpTypePointer)
- .AddSeq(t, storage_class, value_type).Commit(&global_);
+ ib_.Begin(spv::OpTypePointer).AddSeq(t, storage_class, value_type).Commit(&global_);
pointer_type_tbl_[key] = t;
return t;
}
-SType IRBuilder::GetStructArrayType(const SType& value_type,
- uint32_t num_elems) {
+SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems) {
auto key = std::make_pair(value_type.id, num_elems);
auto it = struct_array_type_tbl_.find(key);
if (it != struct_array_type_tbl_.end()) {
if (num_elems != 0) {
Value length = UIntImm(GetSType(DataType::UInt(32)), num_elems);
- ib_.Begin(spv::OpTypeArray)
- .AddSeq(arr_type, value_type, length).Commit(&global_);
+ ib_.Begin(spv::OpTypeArray).AddSeq(arr_type, value_type, length).Commit(&global_);
} else {
- ib_.Begin(spv::OpTypeRuntimeArray)
- .AddSeq(arr_type, value_type).Commit(&global_);
+ ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_);
}
int nbits = value_type.type.bits() * value_type.type.lanes();
CHECK_EQ(nbits % 8, 0);
uint32_t nbytes = static_cast<uint32_t>(nbits) / 8;
// decorate the array type.
- this->Decorate(spv::OpDecorate,
- arr_type, spv::DecorationArrayStride, nbytes);
+ this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes);
// declare struct of array
SType struct_type;
struct_type.id = id_counter_++;
struct_type.type = DataType::Handle();
struct_type.element_type_id = value_type.id;
- ib_.Begin(spv::OpTypeStruct)
- .AddSeq(struct_type, arr_type).Commit(&global_);
+ ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_);
// decorate the array type.
ib_.Begin(spv::OpMemberDecorate)
.AddSeq(struct_type, 0, spv::DecorationOffset, 0)
.Commit(&decorate_);
-
#if SPV_VERSION < 0x10300
// NOTE: BufferBlock was deprecated in SPIRV 1.3
// use StorageClassStorageBuffer instead.
// runtime array are always decorated as BufferBlock(shader storage buffer)
if (num_elems == 0) {
- this->Decorate(spv::OpDecorate,
- struct_type, spv::DecorationBufferBlock);
+ this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
}
#else
- this->Decorate(spv::OpDecorate,
- struct_type, spv::DecorationBlock);
+ this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
#endif
struct_array_type_tbl_[key] = struct_type;
return struct_type;
}
-Value IRBuilder::StructArrayAccess(const SType& res_type,
- Value buffer,
- Value index) {
+Value IRBuilder::StructArrayAccess(const SType& res_type, Value buffer, Value index) {
CHECK(buffer.flag == kStructArrayPtr);
- return MakeValue(spv::OpInBoundsAccessChain,
- res_type, buffer,
- const_i32_zero_, index);
+ return MakeValue(spv::OpInBoundsAccessChain, res_type, buffer, const_i32_zero_, index);
}
Value IRBuilder::IntImm(const SType& dtype, int64_t value) {
return GetConst_(dtype, reinterpret_cast<uint64_t*>(&value));
}
-Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) {
- return GetConst_(dtype, &value);
-}
+Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) { return GetConst_(dtype, &value); }
Value IRBuilder::FloatImm(const SType& dtype, double value) {
if (dtype.type.bits() == 64) {
return GetConst_(dtype, &data);
} else {
CHECK_EQ(dtype.type.bits(), 16);
- return Cast(dtype,
- FloatImm(GetSType(DataType::Float(32)), value));
+ return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value));
}
}
-Value IRBuilder::BufferArgument(const SType& value_type,
- uint32_t descriptor_set,
+Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set,
uint32_t binding) {
// NOTE: BufferBlock was deprecated in SPIRV 1.3
// use StorageClassStorageBuffer instead.
SType ptr_type = GetPointerType(sarr_type, storage_class);
Value val = NewValue(ptr_type, kStructArrayPtr);
- ib_.Begin(spv::OpVariable)
- .AddSeq(ptr_type, val, storage_class).Commit(&global_);
+ ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);
- this->Decorate(spv::OpDecorate,
- val, spv::DecorationDescriptorSet, descriptor_set);
- this->Decorate(spv::OpDecorate,
- val, spv::DecorationBinding, binding);
+ this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set);
+ this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
return val;
}
.Commit(&decorate_);
DataType t = value_types[i].type;
uint32_t nbits = t.bits() * t.lanes();
- CHECK_EQ(nbits % 8 , 0);
+ CHECK_EQ(nbits % 8, 0);
offset += nbits / 8;
}
// Decorate push constants as UBO
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
- SType ptr_type = GetPointerType(
- struct_type, spv::StorageClassPushConstant);
+ SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant);
Value val = NewValue(ptr_type, kPushConstantPtr);
- ib_.Begin(spv::OpVariable)
- .AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
+ ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
return val;
}
-Value IRBuilder::GetPushConstant(
- Value ptr_push_const, const SType& v_type, uint32_t index) {
+Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) {
SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant);
- Value ptr = this->MakeValue(
- spv::OpAccessChain, ptr_vtype, ptr_push_const,
- IntImm(t_int32_, static_cast<int64_t>(index)));
+ Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
+ IntImm(t_int32_, static_cast<int64_t>(index)));
return this->MakeValue(spv::OpLoad, v_type, ptr);
}
-Value IRBuilder::NewFunction() {
- return NewValue(t_void_func_, kFunction);
-}
+Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); }
void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) {
CHECK_EQ(func.flag, kFunction);
- ib_.Begin(spv::OpEntryPoint)
- .AddSeq(spv::ExecutionModelGLCompute, func, name);
+ ib_.Begin(spv::OpEntryPoint).AddSeq(spv::ExecutionModelGLCompute, func, name);
if (workgroup_id_.id != 0) {
ib_.Add(workgroup_id_);
}
void IRBuilder::StartFunction(const Value& func) {
CHECK_EQ(func.flag, kFunction);
// add function declaration to the header.
- ib_.Begin(spv::OpFunction).AddSeq(
- t_void_, func, 0, t_void_func_).Commit(&func_header_);
+ ib_.Begin(spv::OpFunction).AddSeq(t_void_, func, 0, t_void_func_).Commit(&func_header_);
spirv::Label start_label = this->NewLabel();
ib_.Begin(spv::OpLabel).AddSeq(start_label).Commit(&func_header_);
curr_label_ = start_label;
}
-void IRBuilder::SetLocalSize(const Value& func,
- uint32_t local_size[3]) {
+void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) {
CHECK_EQ(func.flag, kFunction);
ib_.Begin(spv::OpExecutionMode)
- .AddSeq(func, spv::ExecutionModeLocalSize,
- local_size[0], local_size[1], local_size[2])
+ .AddSeq(func, spv::ExecutionModeLocalSize, local_size[0], local_size[1], local_size[2])
.Commit(&exec_mode_);
}
-Value IRBuilder::Allocate(const SType& value_type,
- uint32_t num_elems,
+Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems,
spv::StorageClass storage_class) {
CHECK_NE(num_elems, 0U);
SType sarr_type = GetStructArrayType(value_type, num_elems);
SType ptr_type = GetPointerType(sarr_type, storage_class);
Value val = NewValue(ptr_type, kStructArrayPtr);
if (storage_class == spv::StorageClassFunction) {
- ib_.Begin(spv::OpVariable)
- .AddSeq(ptr_type, val, storage_class).Commit(&func_header_);
+ ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&func_header_);
} else {
- ib_.Begin(spv::OpVariable)
- .AddSeq(ptr_type, val, storage_class).Commit(&global_);
+ ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);
}
return val;
}
Value IRBuilder::GetWorkgroupID(uint32_t dim_index) {
if (workgroup_id_.id == 0) {
SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3));
- SType ptr_type = this->GetPointerType(
- vec3_type, spv::StorageClassInput);
+ SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput);
workgroup_id_ = NewValue(ptr_type, kVectorPtr);
ib_.Begin(spv::OpVariable)
.AddSeq(ptr_type, workgroup_id_, spv::StorageClassInput)
.Commit(&global_);
- this->Decorate(spv::OpDecorate, workgroup_id_,
- spv::DecorationBuiltIn, spv::BuiltInWorkgroupId);
+ this->Decorate(spv::OpDecorate, workgroup_id_, spv::DecorationBuiltIn, spv::BuiltInWorkgroupId);
}
SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput);
- Value ptr = this->MakeValue(
- spv::OpAccessChain, pint_type, workgroup_id_,
- IntImm(t_int32_, static_cast<int64_t>(dim_index)));
+ Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, workgroup_id_,
+ IntImm(t_int32_, static_cast<int64_t>(dim_index)));
return this->MakeValue(spv::OpLoad, t_int32_, ptr);
}
SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3));
SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput);
local_id_ = NewValue(ptr_type, kVectorPtr);
- ib_.Begin(spv::OpVariable)
- .AddSeq(ptr_type, local_id_, spv::StorageClassInput)
- .Commit(&global_);
- this->Decorate(spv::OpDecorate, local_id_,
- spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId);
+ ib_.Begin(spv::OpVariable).AddSeq(ptr_type, local_id_, spv::StorageClassInput).Commit(&global_);
+ this->Decorate(spv::OpDecorate, local_id_, spv::DecorationBuiltIn,
+ spv::BuiltInLocalInvocationId);
}
SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput);
- Value ptr = this->MakeValue(
- spv::OpAccessChain, pint_type, local_id_,
- UIntImm(t_int32_, static_cast<int64_t>(dim_index)));
+ Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, local_id_,
+ UIntImm(t_int32_, static_cast<int64_t>(dim_index)));
return this->MakeValue(spv::OpLoad, t_int32_, ptr);
}
if (dtype.type.bits() > 32) {
if (dtype.type.is_int()) {
int64_t sign_mask = 0xFFFFFFFFL;
- const int64_t* sign_ptr =
- reinterpret_cast<const int64_t*>(pvalue);
- ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
+ const int64_t* sign_ptr = reinterpret_cast<const int64_t*>(pvalue);
+ ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
} else {
ib_.Add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask));
}
t.id = id_counter_++;
t.type = dtype;
SType base_type = GetSType(dtype.element_of());
- ib_.Begin(spv::OpTypeVector).AddSeq(
- t, base_type, dtype.lanes()).Commit(&global_);
+ ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_);
return t;
}
}
return phi;
}
-Value IRBuilder::CallGLSL450(const SType& ret_type,
- uint32_t inst_id,
+Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id,
const std::vector<Value>& args) {
Value val = NewValue(ret_type, kNormal);
- ib_.Begin(spv::OpExtInst)
- .AddSeq(ret_type, val, ext_glsl450_, inst_id);
+ ib_.Begin(spv::OpExtInst).AddSeq(ret_type, val, ext_glsl450_, inst_id);
for (const Value& v : args) {
ib_.Add(v);
}
return MakeValue(spv::OpUConvert, dst_type, value);
} else if (from.is_uint() && to.is_int()) {
if (from.bits() != to.bits()) {
- value = MakeValue(
- spv::OpUConvert, GetSType(from.with_bits(to.bits())), value);
+ value = MakeValue(spv::OpUConvert, GetSType(from.with_bits(to.bits())), value);
}
return MakeValue(spv::OpBitcast, dst_type, value);
} else if (from.is_int() && to.is_uint()) {
if (from.bits() != to.bits()) {
- value = MakeValue(
- spv::OpSConvert, GetSType(from.with_bits(to.bits())), value);
+ value = MakeValue(spv::OpSConvert, GetSType(from.with_bits(to.bits())), value);
}
return MakeValue(spv::OpBitcast, dst_type, value);
} else if (from.is_float() && to.is_int()) {
} else if (from.is_float() && to.is_float()) {
return MakeValue(spv::OpFConvert, dst_type, value);
} else {
- LOG(FATAL) << "do not support type cast from "
- << from << " to " << to;
+ LOG(FATAL) << "do not support type cast from " << from << " to " << to;
return Value();
}
}
-#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \
- Value IRBuilder::_OpName(Value a, Value b) { \
- CHECK_EQ(a.stype.id, b.stype.id); \
- if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
- return MakeValue(spv::OpI ## _Op, a.stype, a, b); \
- } else { \
- CHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
- } \
+#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ CHECK_EQ(a.stype.id, b.stype.id); \
+ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpI##_Op, a.stype, a, b); \
+ } else { \
+ CHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpF##_Op, a.stype, a, b); \
+ } \
}
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
}
}
-#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
- Value IRBuilder::_OpName(Value a, Value b) { \
- CHECK_EQ(a.stype.id, b.stype.id); \
- CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
+#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ CHECK_EQ(a.stype.id, b.stype.id); \
+ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
- if (a.stype.type.is_int()) { \
- return MakeValue(spv::OpS##_Op, bool_type, a, b); \
- } else if (a.stype.type.is_uint()) { \
- return MakeValue(spv::OpU##_Op, bool_type, a, b); \
- } else { \
- CHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
- } \
+ if (a.stype.type.is_int()) { \
+ return MakeValue(spv::OpS##_Op, bool_type, a, b); \
+ } else if (a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpU##_Op, bool_type, a, b); \
+ } else { \
+ CHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
+ } \
}
DEFINE_BUILDER_CMP_OP(LT, LessThan);
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
-#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
- Value IRBuilder::_OpName(Value a, Value b) { \
- CHECK_EQ(a.stype.id, b.stype.id); \
- CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
+#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ CHECK_EQ(a.stype.id, b.stype.id); \
+ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
- if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
- return MakeValue(spv::OpI##_Op, bool_type, a, b); \
- } else { \
- CHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
- } \
+ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpI##_Op, bool_type, a, b); \
+ } else { \
+ CHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
+ } \
}
DEFINE_BUILDER_CMP_UOP(EQ, Equal);
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/expr.h>
+// clang-format off
#include <algorithm>
-#include <utility>
-#include <vector>
-#include <string>
#include <map>
+#include <string>
#include <unordered_map>
-
+#include <utility>
+#include <vector>
#include <spirv.hpp>
+// clang-format on
namespace tvm {
namespace codegen {
namespace spirv {
-
/*! \brief Represent the SPIRV Type */
struct SType {
/*! \brief The Id to represent type */
class Instr {
public:
/*! \return the word count */
- uint32_t WordCount() const {
- return word_count_;
- }
+ uint32_t WordCount() const { return word_count_; }
/*!
* \brief Access idx-th word of instruction
* \param idx The index
* \param value The value to come
* \param parent The parent label.
*/
- void SetIncoming(uint32_t index,
- const Value& value,
- const Label& parent) {
+ void SetIncoming(uint32_t index, const Value& value, const Label& parent) {
CHECK_EQ(this->stype.id, value.stype.id);
instr[3 + index * 2] = value.id;
instr[3 + index * 2 + 1] = parent.id;
*/
InstrBuilder& Add(const std::string& v) {
const uint32_t kWordSize = sizeof(uint32_t);
- uint32_t nwords =
- (static_cast<uint32_t>(v.length()) + kWordSize) / kWordSize;
+ uint32_t nwords = (static_cast<uint32_t>(v.length()) + kWordSize) / kWordSize;
size_t begin = data_.size();
data_.resize(begin + nwords, 0U);
- std::copy(v.begin(), v.end(),
- reinterpret_cast<char*>(&data_[begin]));
+ std::copy(v.begin(), v.end(), reinterpret_cast<char*>(&data_[begin]));
return *this;
}
/*!
* \return reference to self.
* \tparams Args The positional arguments
*/
- template<typename... Args>
- InstrBuilder& AddSeq(Args&& ...args) {
+ template <typename... Args>
+ InstrBuilder& AddSeq(Args&&... args) {
AddSeqHelper helper;
helper.builder = this;
runtime::detail::for_each(helper, std::forward<Args>(args)...);
// The reference to builder
InstrBuilder* builder;
// invoke function
- template<typename T>
+ template <typename T>
void operator()(size_t, const T& v) const {
builder->Add(v);
}
curr_label_ = label;
}
/*! \return The current label */
- Label CurrentLabel() const {
- return curr_label_;
- }
+ Label CurrentLabel() const { return curr_label_; }
/*!
* \brief Add code to debug segment.
* \param op The operator
* \param args The instruction sequence
* \tparams Args The positional arguments
*/
- template<typename... Args>
- void Debug(spv::Op op, Args&& ...args) {
+ template <typename... Args>
+ void Debug(spv::Op op, Args&&... args) {
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&debug_);
}
/*!
* \param args The instruction sequence
* \tparams Args The positional arguments
*/
- template<typename... Args>
- void ExecutionMode(Value func, Args&& ...args) {
- ib_.Begin(spv::OpExecutionMode).AddSeq(
- func, std::forward<Args>(args)...).Commit(&exec_mode_);
+ template <typename... Args>
+ void ExecutionMode(Value func, Args&&... args) {
+ ib_.Begin(spv::OpExecutionMode).AddSeq(func, std::forward<Args>(args)...).Commit(&exec_mode_);
}
/*!
* \brief Add code to decorate segment.
* \param args The instruction sequence
* \tparams Args The positional arguments
*/
- template<typename... Args>
- void Decorate(spv::Op op, Args&& ...args) {
+ template <typename... Args>
+ void Decorate(spv::Op op, Args&&... args) {
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&decorate_);
}
/*!
* \param args The instruction sequence
* \tparams Args The positional arguments
*/
- template<typename... Args>
- void DeclareGlobal(spv::Op op, Args&& ...args) {
+ template <typename... Args>
+ void DeclareGlobal(spv::Op op, Args&&... args) {
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&decorate_);
}
/*!
* \return The result SSA value.
* \tparams Args The positional arguments
*/
- template<typename... Args>
- Instr MakeInst(spv::Op op, Args&& ...args) {
+ template <typename... Args>
+ Instr MakeInst(spv::Op op, Args&&... args) {
return ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&function_);
}
/*!
* \return The result SSA value.
* \tparams Args The positional arguments
*/
- template<typename... Args>
- Value MakeValue(spv::Op op, const SType& out_type, Args&& ...args) {
+ template <typename... Args>
+ Value MakeValue(spv::Op op, const SType& out_type, Args&&... args) {
Value val = NewValue(out_type, kNormal);
MakeInst(op, out_type, val, std::forward<Args>(args)...);
return val;
* \param args The arguments
* \return The result value.
*/
- Value CallGLSL450(const SType& ret_type,
- uint32_t inst_id,
- const std::vector<Value>& args);
+ Value CallGLSL450(const SType& ret_type, uint32_t inst_id, const std::vector<Value>& args);
/*!
* \brief Build vector by concatenating components
*
* \param storage_class The storage class
* \return The corresponding spirv type.
*/
- SType GetPointerType(const SType& value_type,
- spv::StorageClass storage_class);
+ SType GetPointerType(const SType& value_type, spv::StorageClass storage_class);
/*!
* \brief Get a struct{ value_type[num_elems] } type.
* \param value_type the content value type.
*
* \return The corresponding spirv type.
*/
- SType GetStructArrayType(const SType& value_type,
- uint32_t num_elems);
+ SType GetStructArrayType(const SType& value_type, uint32_t num_elems);
/*!
* \brief Get a struct array access with a given index.
* \param ptr_type The pointer type.
* \param buffer The buffer ptr to struct array
* \param index The array index.
*/
- Value StructArrayAccess(const SType& ptr_type,
- Value buffer,
- Value index);
+ Value StructArrayAccess(const SType& ptr_type, Value buffer, Value index);
/*!
* \brief Create a cast that cast value to dst_type
* \param dst_type The target type.
* \param binding The binding locaiton in descriptor set.
* \param The argument type.
*/
- Value BufferArgument(const SType& value_type,
- uint32_t descriptor_set,
- uint32_t binding);
+ Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
/*!
* \brief Declare POD arguments through push constants.
*
* \param num_elems Number of elements to allocate.
* \param storage_class The storage class we want to store to.
*/
- Value Allocate(const SType& value_type,
- uint32_t num_elems,
- spv::StorageClass storage_class);
+ Value Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class);
/*
* \brief Get the i-th workgroup id.
* \return The value representing the workgroup id.
std::vector<uint32_t> debug_;
/*! \brief Annotation segment */
std::vector<uint32_t> decorate_;
- /*! \brief Global segment: types, variables, types */
+ /*! \brief Global segment: types, variables, types */
std::vector<uint32_t> global_;
/*! \brief Function header segment */
std::vector<uint32_t> func_header_;
/*!
* \file codegen_stackvm.cc
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
+#include "codegen_stackvm.h"
+
#include <tvm/ir/module.h>
-#include <tvm/tir/op.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+
#include <limits>
#include <utility>
-#include "codegen_stackvm.h"
+
#include "../../runtime/stackvm/stackvm_module.h"
namespace tvm {
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
switch (val) {
- case intrinsic::kArrData: return StackVM::kArrData;
- case intrinsic::kArrShape: return StackVM::kArrShape;
- case intrinsic::kArrAddr: return StackVM::kArrAddr;
- case intrinsic::kArrStrides: return StackVM::kArrStrides;
- case intrinsic::kArrNDim: return StackVM::kArrNDim;
- case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode;
- case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits;
- case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes;
- case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset;
- case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId;
- case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType;
- case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent;
- default: LOG(FATAL) << "Do not know how to map field " << kind;
+ case intrinsic::kArrData:
+ return StackVM::kArrData;
+ case intrinsic::kArrShape:
+ return StackVM::kArrShape;
+ case intrinsic::kArrAddr:
+ return StackVM::kArrAddr;
+ case intrinsic::kArrStrides:
+ return StackVM::kArrStrides;
+ case intrinsic::kArrNDim:
+ return StackVM::kArrNDim;
+ case intrinsic::kArrTypeCode:
+ return StackVM::kArrTypeCode;
+ case intrinsic::kArrTypeBits:
+ return StackVM::kArrTypeBits;
+ case intrinsic::kArrTypeLanes:
+ return StackVM::kArrTypeLanes;
+ case intrinsic::kArrByteOffset:
+ return StackVM::kArrByteOffset;
+ case intrinsic::kArrDeviceId:
+ return StackVM::kArrDeviceId;
+ case intrinsic::kArrDeviceType:
+ return StackVM::kArrDeviceType;
+ case intrinsic::kTVMValueContent:
+ return StackVM::kTVMValueContent;
+ default:
+ LOG(FATAL) << "Do not know how to map field " << kind;
}
return StackVM::kArrData;
}
}
void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
- CHECK(operand >= std::numeric_limits<int>::min() &&
- operand <= std::numeric_limits<int>::max());
+ CHECK(operand >= std::numeric_limits<int>::min() && operand <= std::numeric_limits<int>::max());
vm_.code.at(operand_index).v_int = static_cast<int>(operand);
}
int CodeGenStackVM::GetVarID(const VarNode* v) const {
auto it = var_idmap_.find(v);
- CHECK(it != var_idmap_.end())
- << "Find undefined Variable " << v->name_hint;
+ CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint;
return it->second;
}
void CodeGenStackVM::VisitExpr_(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const LoadNode *l = op->args[0].as<LoadNode>();
+ const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
this->Push(l->index);
}
}
-void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
- const PrimExpr& a,
- const PrimExpr& b) {
+void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b) {
this->Push(a);
this->Push(b);
DataType t = a.dtype();
CHECK(op->value >= std::numeric_limits<int>::min() &&
op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
- this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
+ this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
PushCast(op->dtype, op->value.dtype());
}
-void CodeGenStackVM::VisitExpr_(const AddNode* op) {
- PushBinary(StackVM::ADD_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const AddNode* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); }
-void CodeGenStackVM::VisitExpr_(const SubNode* op) {
- PushBinary(StackVM::SUB_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const SubNode* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); }
-void CodeGenStackVM::VisitExpr_(const MulNode* op) {
- PushBinary(StackVM::MUL_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const MulNode* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); }
-void CodeGenStackVM::VisitExpr_(const DivNode* op) {
- PushBinary(StackVM::DIV_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const DivNode* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); }
-void CodeGenStackVM::VisitExpr_(const ModNode* op) {
- PushBinary(StackVM::MOD_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); }
void CodeGenStackVM::VisitExpr_(const MinNode* op) {
this->Push(op->a);
this->PushOp(StackVM::SELECT);
}
-void CodeGenStackVM::VisitExpr_(const EQNode* op) {
- PushBinary(StackVM::EQ_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const EQNode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); }
-void CodeGenStackVM::VisitExpr_(const LENode* op) {
- PushBinary(StackVM::LE_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const LENode* op) { PushBinary(StackVM::LE_I64, op->a, op->b); }
void CodeGenStackVM::VisitExpr_(const NENode* op) {
PushBinary(StackVM::EQ_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
-void CodeGenStackVM::VisitExpr_(const LTNode* op) {
- PushBinary(StackVM::LT_I64, op->a, op->b);
-}
+void CodeGenStackVM::VisitExpr_(const LTNode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); }
void CodeGenStackVM::VisitExpr_(const GENode* op) {
PushBinary(StackVM::LT_I64, op->a, op->b);
}
}
-void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
+void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) {
if (is_const(ev->value)) return;
const CallNode* op = ev->value.as<CallNode>();
if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
this->Push(op->body);
}
-void CodeGenStackVM::VisitExpr_(const RampNode* op) {
- LOG(FATAL) << "Ramp is not supported";
-}
+void CodeGenStackVM::VisitExpr_(const RampNode* op) { LOG(FATAL) << "Ramp is not supported"; }
void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) {
LOG(FATAL) << "Broadcast is not supported";
this->Push(op->body);
}
-void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) {
- this->Push(op->body);
-}
+void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { this->Push(op->body); }
void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->value);
std::unordered_map<std::string, StackVM> fmap;
std::string entry_func;
- for (auto kv : mod->functions) {
- CHECK(kv.second->IsInstance<PrimFuncNode>())
- << "CodeGenStackVM: Can only take PrimFunc";
+ for (auto kv : mod->functions) {
+ CHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenStackVM: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol.value();
StackVM vm = codegen::CodeGenStackVM().Compile(f);
- CHECK(!fmap.count(f_name))
- << "Function name " << f_name << "already exist in list";
+ CHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list";
fmap[f_name] = std::move(vm);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
return runtime::StackVMModuleCreate(fmap, entry_func);
}
-TVM_REGISTER_GLOBAL("target.build.stackvm")
-.set_body_typed(BuildStackVM);
+TVM_REGISTER_GLOBAL("target.build.stackvm").set_body_typed(BuildStackVM);
} // namespace codegen
} // namespace tvm
#ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
#define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
+#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/target/codegen.h>
+
#include <string>
-#include <vector>
#include <unordered_map>
+#include <vector>
#include "../../runtime/stackvm/stackvm.h"
* This module is used to generate host wrapper
* into device function when only device JIT is available.
*/
-class CodeGenStackVM
- : public ExprFunctor<void(const PrimExpr&)>,
- public StmtFunctor<void(const Stmt&)> {
+class CodeGenStackVM : public ExprFunctor<void(const PrimExpr&)>,
+ public StmtFunctor<void(const Stmt&)> {
public:
- /*!
+ /*!
* \brief Generate a stack VM representing
* \param f The function to be compiled
* \param device_funcs The extern device functions to be linked.
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
- void Push(const PrimExpr& n) {
- VisitExpr(n);
- }
+ void Push(const PrimExpr& n) { VisitExpr(n); }
/*!
* \brief Push the opcode to the code.
* \param opcode The code to be pushed.
*/
void SetOperand(int64_t operand_index, int64_t operand);
/*! \return The current program pointer */
- int64_t GetPC() const {
- return static_cast<int64_t>(vm_.code.size());
- }
+ int64_t GetPC() const { return static_cast<int64_t>(vm_.code.size()); }
/*!
* \brief Get string id in vm
* \param key The string to get id.
*/
int GetVarID(const VarNode* v) const;
// Push binary operator
- void PushBinary(StackVM::OpCode op_int64,
- const PrimExpr& a,
- const PrimExpr& b);
+ void PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b);
// push cast;
void PushCast(DataType dst, DataType src);
// overloadable functions
* \file src/target/target.cc
*/
#include <dmlc/thread_local.h>
-
-#include <tvm/runtime/registry.h>
#include <tvm/node/repr_printer.h>
+#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
-
#include <tvm/tir/expr.h>
#include <algorithm>
namespace tvm {
+using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
-using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const TargetNode*>(node.get());
- p->stream << op->str();
- });
+ .set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const TargetNode*>(node.get());
+ p->stream << op->str();
+ });
/*!
-* \brief Construct a Target node from the given name and options.
-* \param target_name The major target name. Should be one of
-* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm",
-* "metal", "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"}
-* \param options Additional options appended to the target
-* \return The constructed Target
-*/
-Target CreateTarget(const std::string& target_name,
- const std::vector<std::string>& options) {
+ * \brief Construct a Target node from the given name and options.
+ * \param target_name The major target name. Should be one of
+ * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm",
+ * "metal", "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"}
+ * \param options Additional options appended to the target
+ * \return The constructed Target
+ */
+Target CreateTarget(const std::string& target_name, const std::vector<std::string>& options) {
auto t = make_object<TargetNode>();
t->target_name = target_name;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
}
- } else if (target_name == "metal" ||
- target_name == "vulkan" ||
- target_name == "webgpu") {
+ } else if (target_name == "metal" || target_name == "vulkan" || target_name == "webgpu") {
if (target_name == "metal") {
t->device_type = kDLMetal;
} else if (target_name == "vulkan") {
return Target(t);
}
-TVM_REGISTER_GLOBAL("target.TargetCreate")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_name = args[0];
std::vector<std::string> options;
for (int i = 1; i < args.num_args; ++i) {
}
*ret = CreateTarget(target_name, options);
- });
+});
-TVM_REGISTER_GLOBAL("target.TargetFromString")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_str = args[0];
*ret = Target::Create(target_str);
- });
+});
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
if (str_repr_.length() != 0) return str_repr_;
std::ostringstream result;
result << target_name;
- for (const auto &x : options()) {
+ for (const auto& x : options()) {
result << " " << x;
}
str_repr_ = result.str();
return str_repr_;
}
-
bool StartsWith(const std::string& str, const std::string& pattern) {
return str.compare(0, pattern.length(), pattern) == 0;
}
typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
void Target::EnterWithScope() {
- TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
+ TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void Target::ExitWithScope() {
- TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
+ TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
tvm::Target Target::Current(bool allow_not_defined) {
- TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
+ TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
CHECK(allow_not_defined)
- << "Target context required. Please set it by constructing a TargetContext";
+ << "Target context required. Please set it by constructing a TargetContext";
return Target();
}
-TVM_REGISTER_GLOBAL("target.GetCurrentTarget")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body([](TVMArgs args, TVMRetValue* ret) {
bool allow_not_defined = args[0];
*ret = Target::Current(allow_not_defined);
- });
+});
class Target::Internal {
public:
- static void EnterScope(Target target) {
- target.EnterWithScope();
- }
- static void ExitScope(Target target) {
- target.ExitWithScope();
- }
+ static void EnterScope(Target target) { target.EnterWithScope(); }
+ static void ExitScope(Target target) { target.ExitWithScope(); }
};
-TVM_REGISTER_GLOBAL("target.EnterTargetScope")
-.set_body_typed(Target::Internal::EnterScope);
+TVM_REGISTER_GLOBAL("target.EnterTargetScope").set_body_typed(Target::Internal::EnterScope);
-TVM_REGISTER_GLOBAL("target.ExitTargetScope")
-.set_body_typed(Target::Internal::ExitScope);
+TVM_REGISTER_GLOBAL("target.ExitTargetScope").set_body_typed(Target::Internal::ExitScope);
namespace target {
std::vector<std::string> MergeOptions(std::vector<std::string> opts,
- const std::vector<std::string>& new_opts) {
+ const std::vector<std::string>& new_opts) {
opts.insert(opts.end(), new_opts.begin(), new_opts.end());
return opts;
}
-Target llvm(const std::vector<std::string>& options) {
- return CreateTarget("llvm", options);
-}
+Target llvm(const std::vector<std::string>& options) { return CreateTarget("llvm", options); }
-Target cuda(const std::vector<std::string>& options) {
- return CreateTarget("cuda", options);
-}
+Target cuda(const std::vector<std::string>& options) { return CreateTarget("cuda", options); }
-Target rocm(const std::vector<std::string>& options) {
- return CreateTarget("rocm", options);
-}
+Target rocm(const std::vector<std::string>& options) { return CreateTarget("rocm", options); }
-Target opencl(const std::vector<std::string>& options) {
- return CreateTarget("opencl", options);
-}
+Target opencl(const std::vector<std::string>& options) { return CreateTarget("opencl", options); }
-Target metal(const std::vector<std::string>& options) {
- return CreateTarget("metal", options);
-}
+Target metal(const std::vector<std::string>& options) { return CreateTarget("metal", options); }
Target mali(const std::vector<std::string>& options) {
- return CreateTarget("opencl", MergeOptions(options, {
- "-device=mali"
- }));
+ return CreateTarget("opencl", MergeOptions(options, {"-device=mali"}));
}
Target intel_graphics(const std::vector<std::string>& options) {
- return CreateTarget("opencl", MergeOptions(options, {
- "-device=intel_graphics"
- }));
+ return CreateTarget("opencl", MergeOptions(options, {"-device=intel_graphics"}));
}
-Target stackvm(const std::vector<std::string>& options) {
- return CreateTarget("stackvm", options);
-}
+Target stackvm(const std::vector<std::string>& options) { return CreateTarget("stackvm", options); }
-Target ext_dev(const std::vector<std::string>& options) {
- return CreateTarget("ext_dev", options);
-}
+Target ext_dev(const std::vector<std::string>& options) { return CreateTarget("ext_dev", options); }
-Target hexagon(const std::vector<std::string>& options) {
- return CreateTarget("hexagon", options);
-}
+Target hexagon(const std::vector<std::string>& options) { return CreateTarget("hexagon", options); }
} // namespace target
-BuildConfig BuildConfig::Create() {
- return BuildConfig(make_object<BuildConfigNode>());
-}
+BuildConfig BuildConfig::Create() { return BuildConfig(make_object<BuildConfigNode>()); }
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMBuildConfigThreadLocalEntry {
/*! \brief The current build config context */
std::stack<BuildConfig> context_stack;
- TVMBuildConfigThreadLocalEntry() :
- default_config(BuildConfig::Create()) {
- }
+ TVMBuildConfigThreadLocalEntry() : default_config(BuildConfig::Create()) {}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
void BuildConfig::EnterWithScope() {
- TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
+ TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void BuildConfig::ExitWithScope() {
- TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
+ TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
tvm::BuildConfig BuildConfig::Current() {
- TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
+ TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BuildConfigNode*>(node.get());
- p->stream << "build_config(";
- p->stream << "data_alignment=" << op->data_alignment << ", ";
- p->stream << "offset_factor=" << op->offset_factor << ", ";
- p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
- p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
- p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
- p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", ";
- p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
- p->stream << "restricted_func=" << op->restricted_func << ", ";
- p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
- p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
- p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
- p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
- p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
- p->stream << "disable_vectorize=" << op->disable_vectorize;
- p->stream << "disable_assert=" << op->disable_assert;
- p->stream << ")";
-});
-
-TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+ .set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BuildConfigNode*>(node.get());
+ p->stream << "build_config(";
+ p->stream << "data_alignment=" << op->data_alignment << ", ";
+ p->stream << "offset_factor=" << op->offset_factor << ", ";
+ p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
+ p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
+ p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
+ p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", ";
+ p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
+ p->stream << "restricted_func=" << op->restricted_func << ", ";
+ p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
+ p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
+ p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
+ p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
+ p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
+ p->stream << "disable_vectorize=" << op->disable_vectorize;
+ p->stream << "disable_assert=" << op->disable_assert;
+ p->stream << ")";
+ });
+
+TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BuildConfig::Current();
- });
+});
class BuildConfig::Internal {
public:
- static void EnterScope(BuildConfig target) {
- target.EnterWithScope();
- }
- static void ExitScope(BuildConfig target) {
- target.ExitWithScope();
- }
+ static void EnterScope(BuildConfig target) { target.EnterWithScope(); }
+ static void ExitScope(BuildConfig target) { target.ExitWithScope(); }
};
TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope")
-.set_body_typed(BuildConfig::Internal::EnterScope);
+ .set_body_typed(BuildConfig::Internal::EnterScope);
-TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
-.set_body_typed(BuildConfig::Internal::ExitScope);
+TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope").set_body_typed(BuildConfig::Internal::ExitScope);
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- BuildConfig cfg = args[0];
- std::vector<std::pair<int, transform::Pass>> add_lower_pass;
- CHECK_EQ(args.size() % 2, 1);
- for (int i = 1; i < args.size(); i += 2) {
- add_lower_pass.push_back(std::make_pair(
- args[i].operator int(),
- args[i + 1].operator transform::Pass()));
- }
- cfg->add_lower_pass = add_lower_pass;
- });
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ BuildConfig cfg = args[0];
+ std::vector<std::pair<int, transform::Pass>> add_lower_pass;
+ CHECK_EQ(args.size() % 2, 1);
+ for (int i = 1; i < args.size(); i += 2) {
+ add_lower_pass.push_back(
+ std::make_pair(args[i].operator int(), args[i + 1].operator transform::Pass()));
+ }
+ cfg->add_lower_pass = add_lower_pass;
+ });
TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- // Return one of the following:
- // * Size of add_lower_pass if num_args == 1
- // * Phase index of pass if args are (config, index, true)
- // * Function of pass if args are (config, index, false)
- BuildConfig cfg = args[0];
- if (args.num_args == 1) {
- *ret = static_cast<int64_t>(cfg->add_lower_pass.size());
- } else {
- int index = args[1];
- bool get_phase = args[2];
- auto item = cfg->add_lower_pass[index];
- if (get_phase) {
- *ret = item.first;
- } else {
- *ret = item.second;
- }
- }
-});
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ // Return one of the following:
+ // * Size of add_lower_pass if num_args == 1
+ // * Phase index of pass if args are (config, index, true)
+ // * Function of pass if args are (config, index, false)
+ BuildConfig cfg = args[0];
+ if (args.num_args == 1) {
+ *ret = static_cast<int64_t>(cfg->add_lower_pass.size());
+ } else {
+ int index = args[1];
+ bool get_phase = args[2];
+ auto item = cfg->add_lower_pass[index];
+ if (get_phase) {
+ *ret = item.first;
+ } else {
+ *ret = item.second;
+ }
+ }
+ });
} // namespace tvm
/*!
* \file target/target_info.cc
*/
-#include <tvm/runtime/registry.h>
#include <tvm/node/repr_printer.h>
+#include <tvm/runtime/registry.h>
#include <tvm/target/target_info.h>
namespace tvm {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MemoryInfoNode*>(node.get());
- p->stream << "mem-info("
- << "unit_bits=" << op->unit_bits << ", "
- << "max_num_bits=" << op->max_num_bits << ", "
- << "max_simd_bits=" << op->max_simd_bits << ", "
- << "head_address=" << op->head_address << ")";
-});
+ .set_dispatch<MemoryInfoNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const MemoryInfoNode*>(node.get());
+ p->stream << "mem-info("
+ << "unit_bits=" << op->unit_bits << ", "
+ << "max_num_bits=" << op->max_num_bits << ", "
+ << "max_simd_bits=" << op->max_simd_bits << ", "
+ << "head_address=" << op->head_address << ")";
+ });
TVM_REGISTER_NODE_TYPE(MemoryInfoNode);
* \file ad_util.cc
* \brief Utility for tensor-level auto-differentiation.
*/
+#include "ad_util.h"
+
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+
#include <string>
-#include "ad_util.h"
namespace tvm {
namespace te {
Map<Var, PrimExpr> vmap;
for (const IterVar& iv : vars) {
IterVar new_v =
- IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""),
- iv->iter_type, iv->thread_tag);
+ IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), iv->iter_type, iv->thread_tag);
new_vars.push_back(new_v);
vmap.Set(iv->var, new_v->var);
}
src_with_newaxis.push_back(tir::Substitute(src, vmap));
}
- return ReduceNode::make(red->combiner, src_with_newaxis,
- new_axis, tir::Substitute(red->condition, vmap), red->value_index);
+ return ReduceNode::make(red->combiner, src_with_newaxis, new_axis,
+ tir::Substitute(red->condition, vmap), red->value_index);
} else {
return expr;
}
#ifndef TVM_TE_AUTODIFF_AD_UTIL_H_
#define TVM_TE_AUTODIFF_AD_UTIL_H_
-#include <tvm/tir/expr.h>
#include <tvm/te/operation.h>
-#include <vector>
+#include <tvm/tir/expr.h>
+
#include <unordered_map>
#include <utility>
+#include <vector>
namespace tvm {
namespace te {
* (3) and sum them together to get the adjoint of the input itself.
* The three steps are computed recursively.
*/
+#include <topi/elemwise.h>
+#include <topi/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/autodiff.h>
#include <tvm/tir/stmt_functor.h>
-#include <topi/transform.h>
-#include <topi/elemwise.h>
+
#include <memory>
#include <vector>
// add extra dimension for Jacobian
shape.push_back(e);
}
- auto func =
- [&output](const Array<Var>& input_indices) {
- PrimExpr res = const_true();
- for (size_t i = 0; i < output->shape.size(); ++i) {
- res = res && (PrimExpr(input_indices[i]) ==
- PrimExpr(input_indices[output->shape.size() + i]));
- }
- return CastNode::make(output->dtype, res);
- };
+ auto func = [&output](const Array<Var>& input_indices) {
+ PrimExpr res = const_true();
+ for (size_t i = 0; i < output->shape.size(); ++i) {
+ res =
+ res && (PrimExpr(input_indices[i]) == PrimExpr(input_indices[output->shape.size() + i]));
+ }
+ return CastNode::make(output->dtype, res);
+ };
return te::compute(shape, func, "identity");
}
-Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head) {
+Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head) {
Tensor jac = Jacobian(output, input);
Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(),
output->op->name + "." + input->op->name + ".grad");
return result;
}
-Array<Tensor> Gradient(const Tensor& output,
- const Array<Tensor>& inputs,
+Array<Tensor> Gradient(const Tensor& output, const Array<Tensor>& inputs,
const Tensor& head_or_null) {
// Diagonal identity tensor
Tensor head = head_or_null.get() ? head_or_null : Identity(output);
// This is a recursive function that does all the work. It computes the adjoint for a given
// tensor, adds it to the map, and returns it
std::function<Tensor(const Tensor&)> compute_adjoint;
- compute_adjoint =
- [&compute_adjoint, &adjoints, &reverse_dependencies, &head, &output]
- (const Tensor& tensor) {
- if (!adjoints.count(tensor)) {
- // Here the adjoint hasn't been computed yet
- Tensor res_adjoint;
- std::vector<Tensor> direct_consumers = reverse_dependencies[tensor];
- if (direct_consumers.empty()) {
- // No reverse dependencies means that the output does not depend on this tensor,
- // return a zero tensor of the appropriate shape
- // (i.e., output shape + tensor shape, aka shape of Jacobian)
- Array<PrimExpr> result_shape(head->shape.begin(),
- head->shape.end() + (-output->shape.size()));
- for (auto e : tensor->shape) {
- result_shape.push_back(e);
- }
- res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype));
- } else {
- // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied
- // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian
- // and the multiplication is done in the function VectorJacobianProduct
- for (const Tensor& direct_consumer : direct_consumers) {
- // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor)
- Tensor part = VectorJacobianProduct(
- direct_consumer, tensor, compute_adjoint(direct_consumer));
- res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part;
- }
+ compute_adjoint = [&compute_adjoint, &adjoints, &reverse_dependencies, &head,
+ &output](const Tensor& tensor) {
+ if (!adjoints.count(tensor)) {
+ // Here the adjoint hasn't been computed yet
+ Tensor res_adjoint;
+ std::vector<Tensor> direct_consumers = reverse_dependencies[tensor];
+ if (direct_consumers.empty()) {
+ // No reverse dependencies means that the output does not depend on this tensor,
+ // return a zero tensor of the appropriate shape
+ // (i.e., output shape + tensor shape, aka shape of Jacobian)
+ Array<PrimExpr> result_shape(head->shape.begin(),
+ head->shape.end() + (-output->shape.size()));
+ for (auto e : tensor->shape) {
+ result_shape.push_back(e);
}
-
- adjoints[tensor] = res_adjoint;
- return res_adjoint;
+ res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype));
} else {
- return adjoints[tensor];
+ // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied
+ // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian
+ // and the multiplication is done in the function VectorJacobianProduct
+ for (const Tensor& direct_consumer : direct_consumers) {
+ // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor)
+ Tensor part =
+ VectorJacobianProduct(direct_consumer, tensor, compute_adjoint(direct_consumer));
+ res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part;
+ }
}
- };
+
+ adjoints[tensor] = res_adjoint;
+ return res_adjoint;
+ } else {
+ return adjoints[tensor];
+ }
+ };
// Adjoints corresponding to inputs
Array<Tensor> result;
return result;
}
-TVM_REGISTER_GLOBAL("te.Gradient")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- LOG(WARNING) << "te.Gradient is an experimental feature.";
- if (args.size() == 2) {
- *ret = Gradient(args[0], args[1]);
- } else if (args.size() == 3) {
- *ret = Gradient(args[0], args[1], args[2]);
- }
- });
+TVM_REGISTER_GLOBAL("te.Gradient").set_body([](TVMArgs args, TVMRetValue* ret) {
+ LOG(WARNING) << "te.Gradient is an experimental feature.";
+ if (args.size() == 2) {
+ *ret = Gradient(args[0], args[1]);
+ } else if (args.size() == 3) {
+ *ret = Gradient(args[0], args[1], args[2]);
+ }
+});
} // namespace te
} // namespace tvm
* X must be direct input tensor of Y.
* The result Jacobian shape will be (Y.shape, X.shape)
*/
-#include <tvm/te/autodiff.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
#include <tvm/tir/stmt_functor.h>
#include <memory>
+
#include "ad_util.h"
namespace tvm {
namespace te {
-#define NOT_IMPLEMENTED \
- { LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef<PrimExpr>(op); throw; }
+#define NOT_IMPLEMENTED \
+ { \
+ LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef<PrimExpr>(op); \
+ throw; \
+ }
/*! \brief Differentiate an expression wrt a variable or a tensor element */
class JacobianMutator : public ExprMutator {
* \param indices The indices of the element with respect to which to differentiate.
*/
explicit JacobianMutator(Tensor input, Array<PrimExpr> indices)
- : input_(input), indices_(indices) {}
+ : input_(input), indices_(indices) {}
/*!
* \brief Differentiate wrt the input variable.
* \param input The input variable.
}
}
- PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED
+ PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED;
PrimExpr VisitExpr_(const CallNode* op) {
PrimExpr expr = GetRef<PrimExpr>(op);
if (op->call_type == CallNode::CallType::Halide) {
- if (input_.get() && op->func.same_as(input_->op) &&
- op->value_index == input_->value_index) {
+ if (input_.get() && op->func.same_as(input_->op) && op->value_index == input_->value_index) {
// Tensor(indices)
CHECK_EQ(indices_.size(), op->args.size());
PrimExpr condition = const_true();
return MulNode::make(Mutate(op->args[0]),
MulNode::make(expr, SubNode::make(FloatImm(expr.dtype(), 1.0), expr)));
} else if (op->name == "sqrt") {
- return DivNode::make(Mutate(op->args[0]),
- MulNode::make(expr, FloatImm(expr.dtype(), 2.0)));
+ return DivNode::make(Mutate(op->args[0]), MulNode::make(expr, FloatImm(expr.dtype(), 2.0)));
} else if (op->name == "tanh") {
return MulNode::make(Mutate(op->args[0]),
SubNode::make(FloatImm(expr.dtype(), 1.0), MulNode::make(expr, expr)));
} else if (op->name == "pow") {
auto x = op->args[0], y = op->args[1];
- return expr * (Mutate(y)*log(x) + Mutate(x)*y/x);
+ return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
} else if (op->name == "fabs") {
auto type = op->args[0].dtype();
return MulNode::make(Mutate(op->args[0]),
SelectNode::make(GENode::make(op->args[0], make_zero(type)),
FloatImm(type, 1.0), FloatImm(type, -1.0)));
} else if (op->name == intrinsic::tvm_if_then_else) {
- Array<PrimExpr> new_args = {op->args[0],
- Mutate(op->args[1]),
- Mutate(op->args[2])};
- return CallNode::make(op->dtype, op->name, new_args,
- op->call_type, op->func, op->value_index);
+ Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
+ return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func,
+ op->value_index);
} else if (piecewise_const.count(op->name)) {
return FloatImm(expr.dtype(), 0.0);
} else {
throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name);
}
}
- NOT_IMPLEMENTED
+ NOT_IMPLEMENTED;
}
- PrimExpr VisitExpr_(const AddNode* op) {
- return AddNode::make(Mutate(op->a), Mutate(op->b));
- }
+ PrimExpr VisitExpr_(const AddNode* op) { return AddNode::make(Mutate(op->a), Mutate(op->b)); }
- PrimExpr VisitExpr_(const SubNode* op) {
- return SubNode::make(Mutate(op->a), Mutate(op->b));
- }
+ PrimExpr VisitExpr_(const SubNode* op) { return SubNode::make(Mutate(op->a), Mutate(op->b)); }
PrimExpr VisitExpr_(const MulNode* op) {
- return AddNode::make(
- MulNode::make(Mutate(op->a), op->b),
- MulNode::make(op->a, Mutate(op->b)));
+ return AddNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b)));
}
PrimExpr VisitExpr_(const DivNode* op) {
return DivNode::make(
- SubNode::make(
- MulNode::make(Mutate(op->a), op->b),
- MulNode::make(op->a, Mutate(op->b))),
+ SubNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))),
MulNode::make(op->b, op->b));
}
- PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED
+ PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED;
PrimExpr VisitExpr_(const FloorDivNode* op) {
return FloorDivNode::make(
- SubNode::make(
- MulNode::make(Mutate(op->a), op->b),
- MulNode::make(op->a, Mutate(op->b))),
+ SubNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))),
MulNode::make(op->b, op->b));
}
- PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED
+ PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED;
PrimExpr VisitExpr_(const MinNode* op) {
- return SelectNode::make(LENode::make(op->a, op->b),
- Mutate(op->a), Mutate(op->b));
+ return SelectNode::make(LENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b));
}
PrimExpr VisitExpr_(const MaxNode* op) {
- return SelectNode::make(GENode::make(op->a, op->b),
- Mutate(op->a), Mutate(op->b));
+ return SelectNode::make(GENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b));
}
- PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED
+ PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED;
PrimExpr VisitExpr_(const ReduceNode* op) {
// This case is relatively difficult because a reduction expression
CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
// Also simplify the resulting combiner
// (mostly to get rid of unused components, e.g., the original expressions)
- return analyzer_.Simplify(
- ReduceNode::make(new_combiner, new_source, new_op->axis,
- new_op->condition, new_op->value_index));
+ return analyzer_.Simplify(ReduceNode::make(new_combiner, new_source, new_op->axis,
+ new_op->condition, new_op->value_index));
}
PrimExpr VisitExpr_(const CastNode* op) {
}
}
- PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED
+ PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED;
PrimExpr VisitExpr_(const SelectNode* op) {
- return SelectNode::make(op->condition,
- Mutate(op->true_value), Mutate(op->false_value));
+ return SelectNode::make(op->condition, Mutate(op->true_value), Mutate(op->false_value));
}
- PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED
- PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED
+ PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED;
+ PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED;
- PrimExpr VisitExpr_(const IntImmNode* op) {
- return IntImm(op->dtype, 0);
- }
+ PrimExpr VisitExpr_(const IntImmNode* op) { return IntImm(op->dtype, 0); }
- PrimExpr VisitExpr_(const FloatImmNode* op) {
- return FloatImm(op->dtype, 0);
- }
+ PrimExpr VisitExpr_(const FloatImmNode* op) { return FloatImm(op->dtype, 0); }
- PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED
+ PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED;
private:
Tensor input_;
Array<PrimExpr> input_indices;
size_t i = 0;
for (PrimExpr ext : input->shape) {
- IterVar new_v = IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)),
- IterVarType::kDataPar);
+ IterVar new_v =
+ IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar);
// Append jacobian iter to new_axis
new_axis.push_back(new_v);
// Differentiate wrt input[input_indices]
}
arith::Analyzer analzyer;
// Compute Jacobian
- PrimExpr new_body = Jacobian(
- Substitute(op->body[output->value_index], vmap), input, input_indices);
+ PrimExpr new_body =
+ Jacobian(Substitute(op->body[output->value_index], vmap), input, input_indices);
new_body = analzyer.Simplify(new_body);
int value_index = 0;
value_index = red->value_index;
for (size_t idx = 0; idx < red->source.size(); ++idx) {
new_bodies.push_back(
- ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx));
+ ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx));
}
} else {
new_bodies.push_back(new_body);
}
- auto new_op = ComputeOpNode::make(
- op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies);
+ auto new_op =
+ ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies);
// Jacobian shape = output.shape + input.shape
Array<PrimExpr> new_shape = output->shape;
* \brief Compute Op.
* \file compute_op.cc
*/
+#include "compute_op.h"
+
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <unordered_set>
+
#include <string>
+#include <unordered_set>
#include <utility>
-#include "compute_op.h"
-#include "op_util.h"
-#include "../schedule/message_passing.h"
+
#include "../../arith/compute_expr.h"
#include "../../arith/interval_set.h"
+#include "../schedule/message_passing.h"
+#include "op_util.h"
namespace tvm {
namespace te {
using namespace tir;
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ComputeOpNode*>(node.get());
- p->stream << "compute(" << op->name << ", " << op << ")";
-});
+ .set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ComputeOpNode*>(node.get());
+ p->stream << "compute(" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
/// Verify if ComputeOp is valid with respect to Reduce operations.
-static void VerifyComputeOp(const ComputeOpNode *op);
+static void VerifyComputeOp(const ComputeOpNode* op);
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
- return (a->combiner.same_as(b->combiner)) &&
- (a->source.same_as(b->source)) &&
- (a->axis.same_as(b->axis)) &&
- (a->condition.same_as(b->condition));
+ return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
+ (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition));
}
-int ComputeOpNode::num_outputs() const {
- return body.size();
-}
+int ComputeOpNode::num_outputs() const { return body.size(); }
Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
if (reduce_axis.size() == 0) return axis;
return shape;
}
-Tensor compute(Array<PrimExpr> shape,
- FCompute fcompute,
- std::string name,
- std::string tag,
+Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::string tag,
Map<std::string, ObjectRef> attrs) {
auto op_node = make_object<ComputeOpNode>();
// compute dimension.
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
- axis.emplace_back(IterVarNode::make(
- Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+ axis.emplace_back(
+ IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
- return ComputeOpNode::make(
- name, tag, attrs, axis, {fcompute(args)}).output(0);
+ return ComputeOpNode::make(name, tag, attrs, axis, {fcompute(args)}).output(0);
}
-Array<Tensor> compute(Array<PrimExpr> shape,
- FBatchCompute fcompute,
- std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs) {
+Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string name,
+ std::string tag, Map<std::string, ObjectRef> attrs) {
auto op_node = make_object<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
- axis.emplace_back(IterVarNode::make(
- Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+ axis.emplace_back(
+ IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
return outputs;
}
-Operation ComputeOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<IterVar> axis,
- Array<PrimExpr> body) {
+Operation ComputeOpNode::make(std::string name, std::string tag, Map<std::string, ObjectRef> attrs,
+ Array<IterVar> axis, Array<PrimExpr> body) {
if (!attrs.defined()) {
attrs = Map<std::string, ObjectRef>();
}
return Operation(n);
}
-TVM_REGISTER_GLOBAL("te.ComputeOp")
-.set_body_typed(ComputeOpNode::make);
-
+TVM_REGISTER_GLOBAL("te.ComputeOp").set_body_typed(ComputeOpNode::make);
// The schedule related logics
Array<Tensor> ComputeOpNode::InputTensors() const {
std::unordered_set<Tensor> visited;
for (auto& e : body) {
tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
- const tir::CallNode *call = n.as<tir::CallNode>();
- if (call != nullptr && call->func.defined()) {
- Tensor t = Downcast<Operation>(call->func).output(call->value_index);
- if (!visited.count(t)) {
- ret.push_back(t);
- visited.insert(t);
- }
+ const tir::CallNode* call = n.as<tir::CallNode>();
+ if (call != nullptr && call->func.defined()) {
+ Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+ if (!visited.count(t)) {
+ ret.push_back(t);
+ visited.insert(t);
}
- });
+ }
+ });
}
return ret;
}
-Operation ComputeOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
+Operation ComputeOpNode::ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
VerifyComputeOp(this);
Array<PrimExpr> arr;
arr = this->body;
}
} else {
- arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) {
- return te::ReplaceTensor(e, rmap);
- });
+ arr =
+ UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); });
}
if (!arr.same_as(this->body)) {
- return ComputeOpNode::make(
- this->name, this->tag, this->attrs, this->axis, arr);
+ return ComputeOpNode::make(this->name, this->tag, this->attrs, this->axis, arr);
} else {
return self;
}
}
-void ComputeOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
- auto *call = n.as<tir::CallNode>();
+ auto* call = n.as<tir::CallNode>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
if (t->op.defined() && out_dom_map->count(t)) {
for (auto& e : body) tir::PostOrderVisit(e, fvisit);
}
-void BaseComputeOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
+void BaseComputeOpNode::GatherBound(const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
}
}
-Stmt BaseComputeOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const {
+Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
Region bounds;
for (IterVar iv : this->axis) {
}
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
- Tensor t = stage->op.output(i-1);
- realize = tir::RealizeNode::make(t->op, t->value_index,
- t->dtype, bounds, const_true(), realize);
+ Tensor t = stage->op.output(i - 1);
+ realize =
+ tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
if (it != stage->iter_var_attrs.end()) {
IterVarAttr attr = (*it).second;
if (attr->dim_align_factor != 0) {
- Array<PrimExpr> tuple = {static_cast<int>(i),
- attr->dim_align_factor,
- attr->dim_align_offset};
- realize = tir::AttrStmtNode::make(
- t, tir::attr::buffer_dim_align,
- CallNode::make(DataType::Handle(),
- tir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic),
- realize);
+ Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
+ attr->dim_align_offset};
+ realize =
+ tir::AttrStmtNode::make(t, tir::attr::buffer_dim_align,
+ CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic),
+ realize);
}
}
}
return realize;
}
-size_t ComputeOpNode::num_schedulable_dims() const {
- return axis.size();
-}
+size_t ComputeOpNode::num_schedulable_dims() const { return axis.size(); }
// Build a reduction body.
-void MakeReduction(const ComputeOpNode* op,
- const Array<Tensor>& tensors,
- Stmt* init,
+void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt* init,
Stmt* provide) {
- Array<PrimExpr> args;
+ Array<PrimExpr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
- inits.emplace_back(ProvideNode::make(
- t->op, t->value_index, init_value[i], args));
- provides.emplace_back(ProvideNode::make(
- t->op, t->value_index, update_value[i], args));
+ inits.emplace_back(ProvideNode::make(t->op, t->value_index, init_value[i], args));
+ provides.emplace_back(ProvideNode::make(t->op, t->value_index, update_value[i], args));
}
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
}
// Normal computation.
-Stmt MakeProvide(const ComputeOpNode* op,
- const Tensor& t) {
+Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) {
Array<PrimExpr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args);
}
-Stmt MakeComputeStmt(const ComputeOpNode* self,
- const Stage& stage,
+Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
// grab the nest structure
init = MergeNest(n.init_nest, init);
init = Substitute(init, n.init_vmap);
// common nest
- std::vector<std::vector<Stmt> > common(
- n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
- std::vector<std::vector<Stmt> > reduce(
- n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
+ std::vector<std::vector<Stmt> > common(n.main_nest.begin(),
+ n.main_nest.begin() + n.num_common_loop + 1);
+ std::vector<std::vector<Stmt> > reduce(n.main_nest.begin() + n.num_common_loop + 1,
+ n.main_nest.end());
provide = MergeNest(reduce, provide);
if (debug_keep_trivial_loop) {
provide = MergeNest(common, provide);
}
}
-enum class ComputeType {
- kNormal,
- kCrossThreadReduction,
- kTensorize
-};
+enum class ComputeType { kNormal, kCrossThreadReduction, kTensorize };
-ComputeType DetectComputeType(const ComputeOpNode* self,
- const Stage& stage) {
+ComputeType DetectComputeType(const ComputeOpNode* self, const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0, tensorize = 0;
++normal_red;
}
} else {
- CHECK_EQ(thread_red, 0)
- << "Cross thread reduce cannot swap with normal data axis";
+ CHECK_EQ(thread_red, 0) << "Cross thread reduce cannot swap with normal data axis";
}
}
if (tensorize != 0) {
- CHECK(thread_red == 0)
- << "Cannot mix cross thread reduction with Tensorize";
+ CHECK(thread_red == 0) << "Cannot mix cross thread reduction with Tensorize";
return ComputeType::kTensorize;
}
if (thread_red != 0) {
}
// implement the provide utility.
-Stmt ComputeOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
+Stmt ComputeOpNode::BuildProvide(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) {
}
}
-ComputeLoopNest ComputeLoopNest::make(
- const BaseComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) {
+ComputeLoopNest ComputeLoopNest::make(const BaseComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
- ret.main_nest = MakeLoopNest(
- stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
- debug_keep_trivial_loop);
- ret.main_predicates = MakeBoundCheck(
- stage, dom_map, ret.main_vmap, false,
- std::unordered_set<IterVar>());
+ ret.main_nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(),
+ &ret.main_vmap, debug_keep_trivial_loop);
+ ret.main_predicates =
+ MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set<IterVar>());
for (auto& e : ret.main_predicates) {
e = likely(e);
}
auto iv = leaf_iter_vars[i];
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
- begin_loop = i; break;
+ begin_loop = i;
+ break;
}
ret.init_vmap[iv] = ret.main_vmap.at(iv);
}
int flag = kv.second;
if (flag == 2) skip_iter.insert(kv.first);
}
- ret.init_nest = MakeLoopNest(
- stage, dom_map, begin_loop, true,
- skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
- ret.init_predicates = MakeBoundCheck(
- stage, dom_map, ret.init_vmap, true, skip_iter);
+ ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
+ debug_keep_trivial_loop);
+ ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
for (const PrimExpr e : compute_->body) {
// Check for consistency of top level reductions
const tir::ReduceNode* reduce = e.as<tir::ReduceNode>();
- CHECK((reduce && reduce_) || (!reduce && !reduce_))
- << "All ComputeOp should be consistent "
- << "with being Reduce operation or not.";
+ CHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent "
+ << "with being Reduce operation or not.";
if (reduce && reduce_) {
- CHECK(ReduceEqual(reduce, reduce_))
- << "The Reduce inputs of ComputeOp should "
- << "have the same attribute except value_index";
+ CHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should "
+ << "have the same attribute except value_index";
}
level_ = 0;
void VisitExpr_(const tir::ReduceNode* op) final {
// Check for non top level reductions
- CHECK(0 == level_)
- << "Reductions are only allowed at the top level of compute. "
- << "Please create another tensor for further composition.";
+ CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. "
+ << "Please create another tensor for further composition.";
}
//@}
private:
- const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
- const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation
- int level_{0}; ///< Level of op being processed
+ const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
+ const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation
+ int level_{0}; ///< Level of op being processed
};
} // namespace
v.Run();
}
-Stmt TransformUpdate(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const ComputeLoopNest& n,
- Stmt body,
- Stmt update) {
+Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ const ComputeLoopNest& n, Stmt body, Stmt update) {
Array<PrimExpr> conds;
std::unordered_set<const VarNode*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
}
}
- auto fbanned = [&](const VarNode* node) {
- return banned.count(node);
- };
+ auto fbanned = [&](const VarNode* node) { return banned.count(node); };
for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
- LOG(FATAL) << "Tensorize update transform failed, the condition "
- << pred << " has a conflict with the reset condition";
+ LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
+ << " has a conflict with the reset condition";
}
}
- return IfThenElseNode::make(arith::ComputeReduce<tir::OrNode>(conds, const_true(1)),
- update, body);
+ return IfThenElseNode::make(arith::ComputeReduce<tir::OrNode>(conds, const_true(1)), update,
+ body);
}
} // namespace te
#ifndef TVM_TE_OPERATION_COMPUTE_OP_H_
#define TVM_TE_OPERATION_COMPUTE_OP_H_
-#include <tvm/tir/expr.h>
#include <tvm/te/operation.h>
-#include <vector>
+#include <tvm/tir/expr.h>
+
#include <unordered_map>
+#include <vector>
namespace tvm {
namespace te {
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return The constructed loop nest
*/
- static ComputeLoopNest make(
- const BaseComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop);
+ static ComputeLoopNest make(const BaseComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop);
};
/*!
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return The created statement.
*/
-Stmt MakeCrossThreadReduction(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop);
+Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop);
/*!
* \brief Build body of compute for tensorization.
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return The created statement.
*/
-Stmt MakeTensorize(const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop);
+Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map, bool debug_keep_trivial_loop);
/*!
* \brief Transform the update part when there is no init func in tensorizing
* \param update The update func in tensorize intrin
* \return Transformed result.
*/
-Stmt TransformUpdate(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const ComputeLoopNest& n,
- Stmt body,
- Stmt update);
+Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ const ComputeLoopNest& n, Stmt body, Stmt update);
} // namespace te
} // namespace tvm
namespace te {
using namespace tir;
-Stmt MakeCrossThreadReduction(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) {
- Array<PrimExpr> args;
+Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) {
+ Array<PrimExpr> args;
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
std::unordered_map<IterVar, PrimExpr> value_map;
- auto nest = MakeLoopNest(
- stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
- auto conds = MakeBoundCheck(
- stage, dom_map, value_map, false,
- std::unordered_set<IterVar>());
+ auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map,
+ debug_keep_trivial_loop);
+ auto conds = MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>());
size_t size = self->body.size();
CHECK_GT(size, 0);
Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
for (size_t i = 0; i < size; ++i) {
DataType t = reduces[i]->dtype;
- normal_init.emplace_back(StoreNode::make(
- normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
- normal_update.emplace_back(StoreNode::make(
- normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
+ normal_init.emplace_back(
+ StoreNode::make(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
+ normal_update.emplace_back(
+ StoreNode::make(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
}
}
for (size_t i = 0; i < size; ++i) {
if (!normal_red.empty()) {
DataType t = reduces[i]->dtype;
- freduce_args.push_back(LoadNode::make(
- t, normal_res_handles[i], 0, const_true(t.lanes())));
+ freduce_args.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes())));
} else {
freduce_args.push_back(reduces[0]->source[i]);
}
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
- if (it != stage->iter_var_attrs.end() &&
- (*it).second->bind_thread.defined()) {
+ if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
IterVar tv = (*it).second->bind_thread;
freduce_args.push_back(tv->var);
}
}
Stmt reduce_body = EvaluateNode::make(CallNode::make(
- DataType::Handle(),
- tir::intrinsic::tvm_thread_allreduce,
- freduce_args, CallNode::Intrinsic));
- reduce_body = AttrStmtNode::make(
- reduces[0]->combiner,
- tir::attr::reduce_scope,
- make_zero(DataType::Handle()),
- reduce_body);
+ DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic));
+ reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope,
+ make_zero(DataType::Handle()), reduce_body);
if (!normal_red.empty()) {
Stmt init_body = SeqStmt::Flatten(normal_init);
for (size_t idx = 0; idx < size; ++idx) {
DataType t = reduces[idx]->dtype;
assigns[idx] = ProvideNode::make(
- stage->op, idx,
- LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
+ stage->op, idx, LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = SeqStmt::Flatten(assigns);
assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(MakeIfNest(conds), assign_body);
Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
- body = AllocateNode::make(
- res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
- body = AttrStmtNode::make(
- res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+ body =
+ AllocateNode::make(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+ body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope,
+ StringImmNode::make("local"), body);
if (!normal_red.empty()) {
- body = AllocateNode::make(
- normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
- body = AttrStmtNode::make(
- normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+ body = AllocateNode::make(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1},
+ const_true(), body);
+ body = AttrStmtNode::make(normal_res_handles[idx - 1], tir::attr::storage_scope,
+ StringImmNode::make("local"), body);
}
}
body = Substitute(body, value_map);
* \brief External computation rule.
* \file extern_op.cc
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
-#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
+
#include <unordered_set>
+
#include "op_util.h"
namespace tvm {
using namespace tir;
// ExternOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ExternOpNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ExternOpNode*>(node.get());
- p->stream << "extern(" << op->name << ", " << op << ")";
- });
+ .set_dispatch<ExternOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ExternOpNode*>(node.get());
+ p->stream << "extern(" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(ExternOpNode);
-int ExternOpNode::num_outputs() const {
- return static_cast<int>(output_placeholders.size());
-}
-
-Array<IterVar> ExternOpNode::root_iter_vars() const {
- return {};
-}
+int ExternOpNode::num_outputs() const { return static_cast<int>(output_placeholders.size()); }
-DataType ExternOpNode::output_dtype(size_t i) const {
- return output_placeholders[i]->dtype;
-}
+Array<IterVar> ExternOpNode::root_iter_vars() const { return {}; }
-Array<PrimExpr> ExternOpNode::output_shape(size_t i) const {
- return output_placeholders[i]->shape;
-}
+DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; }
+Array<PrimExpr> ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; }
-Operation ExternOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Buffer> input_placeholders,
- Array<Buffer> output_placeholders,
- Stmt body) {
+Operation ExternOpNode::make(std::string name, std::string tag, Map<std::string, ObjectRef> attrs,
+ Array<Tensor> inputs, Array<Buffer> input_placeholders,
+ Array<Buffer> output_placeholders, Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, ObjectRef>();
}
CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
- CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
+ CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
}
CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
}
return Operation(n);
}
-TVM_REGISTER_GLOBAL("te.ExternOp")
-.set_body_typed(ExternOpNode::make);
+TVM_REGISTER_GLOBAL("te.ExternOp").set_body_typed(ExternOpNode::make);
+Array<Tensor> ExternOpNode::InputTensors() const { return inputs; }
-Array<Tensor> ExternOpNode::InputTensors() const {
- return inputs;
-}
-
-Operation ExternOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
+Operation ExternOpNode::ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_object<ExternOpNode>(*this);
n->body = ReplaceTensor(this->body, rmap);
}
}
- if (body.same_as(n->body) &&
- inputs.same_as(n->inputs)) {
+ if (body.same_as(n->body) && inputs.same_as(n->inputs)) {
return self;
} else {
return Operation(n);
}
}
-void ExternOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+void ExternOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i])));
+ Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
}
}
}
-void ExternOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
-}
+void ExternOpNode::GatherBound(const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {}
-Stmt ExternOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const {
+Stmt ExternOpNode::BuildRealize(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
- bounds.push_back(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i]));
+ bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
- realize_body = tir::RealizeNode::make(
- t->op, t->value_index, t->dtype,
- bounds, const_true(), realize_body);
+ realize_body =
+ tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body);
}
return realize_body;
}
-Stmt ExternOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
+Stmt ExternOpNode::BuildProvide(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmtNode::make(
- make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
+ Stmt ret =
+ AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<ObjectRef> bind_spec;
Array<PrimExpr> tuple;
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
+#include "hybrid_op.h"
+
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <unordered_set>
+#include <tvm/tir/stmt_functor.h>
+
#include <string>
+#include <unordered_set>
#include <utility>
+
#include "op_util.h"
-#include "hybrid_op.h"
namespace tvm {
namespace te {
using namespace tir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const HybridOpNode*>(node.get());
- p->stream << "hybrid(" << op->name << ", " << op << ")";
- });
+ .set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const HybridOpNode*>(node.get());
+ p->stream << "hybrid(" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(HybridOpNode);
-int HybridOpNode::num_outputs() const {
- return static_cast<int>(outputs.size());
-}
-
-Array<IterVar> HybridOpNode::root_iter_vars() const {
- return this->axis;
-}
+int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); }
-DataType HybridOpNode::output_dtype(size_t i) const {
- return outputs[i]->dtype;
-}
+Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; }
-Array<PrimExpr> HybridOpNode::output_shape(size_t i) const {
- return outputs[i]->shape;
-}
+DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; }
+Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }
-Operation HybridOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Tensor> outputs,
- Stmt body) {
+Operation HybridOpNode::make(std::string name, std::string tag, Map<std::string, ObjectRef> attrs,
+ Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, ObjectRef>();
}
return res;
}
-TVM_REGISTER_GLOBAL("te.HybridOp")
-.set_body_typed(HybridOpNode::make);
-
+TVM_REGISTER_GLOBAL("te.HybridOp").set_body_typed(HybridOpNode::make);
Array<Tensor> HybridOpNode::InputTensors() const {
// Because input tensors could be potentially inlined into hybrid scripts,
std::unordered_set<Tensor> visited;
Array<Tensor> curr_inputs;
tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
- const tir::CallNode *call = n.as<tir::CallNode>();
- if (call != nullptr && call->func.defined()) {
- Tensor t = Downcast<Operation>(call->func).output(call->value_index);
- if (orig_inputs.count(t) && !visited.count(t)) {
- curr_inputs.push_back(t);
- visited.insert(t);
- }
+ const tir::CallNode* call = n.as<tir::CallNode>();
+ if (call != nullptr && call->func.defined()) {
+ Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+ if (orig_inputs.count(t) && !visited.count(t)) {
+ curr_inputs.push_back(t);
+ visited.insert(t);
}
+ }
});
return curr_inputs;
}
-Operation HybridOpNode::ReplaceInputs(
- const Operation &self,
- const std::unordered_map<Tensor, Tensor> &rmap) const {
+Operation HybridOpNode::ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_object<HybridOpNode>(*this);
n->body = te::ReplaceTensor(this->body, rmap);
}
}
- if (body.same_as(n->body) &&
- inputs.same_as(n->inputs)) {
+ if (body.same_as(n->body) && inputs.same_as(n->inputs)) {
return self;
} else {
return Operation(n);
}
}
-void HybridOpNode::PropBoundToInputs(
- const Operation &self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet> &dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
auto curr_inputs = InputTensors();
for (Tensor t : curr_inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
- TensorDom &dom = it->second;
+ TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i])));
+ Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
}
}
}
-void HybridOpNode::GatherBound(
- const Operation &self,
- const std::unordered_map<Tensor, TensorDom> &tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
+void HybridOpNode::GatherBound(const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
for (auto iter_var : axis) {
CHECK(!out_dom_map->count(iter_var));
out_dom_map->operator[](iter_var) = iter_var->dom;
}
}
-Stmt HybridOpNode::BuildRealize(
- const Stage &stage,
- const std::unordered_map<IterVar, Range> &realize_map,
- const Stmt &body) const {
+Stmt HybridOpNode::BuildRealize(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const {
// TODO(@were): Add attribute inject here and remove it from hybrid parser.
CHECK_EQ(stage->op.get(), this);
Stmt realize_body = body;
Tensor t = stage->op.output(k);
Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
- bounds.push_back(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i]));
+ bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
- realize_body = tir::RealizeNode::make(
- t->op, t->value_index, t->dtype,
- bounds, const_true(), realize_body);
+ realize_body =
+ tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body);
}
return realize_body;
}
-Stmt HybridOpNode::BuildProvide(
- const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- bool debug_keep_trivial_loop) const {
+Stmt HybridOpNode::BuildProvide(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmtNode::make(
- make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
+ Stmt ret =
+ AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
return ret;
}
-Stmt ApplyLoopShapes(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
+Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ Stmt stmt) {
class LoopSpliter : public StmtExprMutator {
PrimExpr factor;
- const VarNode *parent;
+ const VarNode* parent;
IterVar inner, outer;
public:
bool splitted;
- LoopSpliter(const SplitNode *split,
- const std::unordered_map<IterVar, Range> &dom_map) :
- factor(split->factor), splitted(false) {
+ LoopSpliter(const SplitNode* split, const std::unordered_map<IterVar, Range>& dom_map)
+ : factor(split->factor), splitted(false) {
parent = split->parent->var.get();
- auto &inner_ = split->inner;
+ auto& inner_ = split->inner;
CHECK(dom_map.count(inner_));
- auto &inner_dom = dom_map.find(inner_)->second;
+ auto& inner_dom = dom_map.find(inner_)->second;
CHECK(is_const_int(inner_dom->min, 0));
- auto &outer_ = split->outer;
+ auto& outer_ = split->outer;
CHECK(dom_map.count(outer_));
- auto &outer_dom = dom_map.find(outer_)->second;
+ auto& outer_dom = dom_map.find(outer_)->second;
CHECK(is_const_int(outer_dom->min, 0));
inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
}
- Stmt VisitStmt_(const ForNode *op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
if (op->loop_var.get() == parent) {
- std::unordered_map<const VarNode *, PrimExpr> rmap;
+ std::unordered_map<const VarNode*, PrimExpr> rmap;
rmap[op->loop_var.get()] = inner + outer * factor;
Stmt ret = tir::Substitute(op->body, rmap);
PrimExpr cond = likely(outer * factor < (op->extent - inner));
ret = IfThenElseNode::make(cond, ret);
ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
- IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+ IterVarTypeToForType(inner->iter_type), op->device_api, ret);
ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
- IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+ IterVarTypeToForType(outer->iter_type), op->device_api, ret);
splitted = true;
return ret;
}
};
class LoopFuser : public StmtExprMutator {
- const IterVar &parent;
- const VarNode *inner;
- const VarNode *outer;
+ const IterVar& parent;
+ const VarNode* inner;
+ const VarNode* outer;
bool under_outer;
PrimExpr extent;
public:
bool fused;
- explicit LoopFuser(const FuseNode *fuse_)
- : parent(fuse_->fused), inner(fuse_->inner->var.get()),
- outer(fuse_->outer->var.get()), under_outer(false),
- extent(0), fused(false) {}
+ explicit LoopFuser(const FuseNode* fuse_)
+ : parent(fuse_->fused),
+ inner(fuse_->inner->var.get()),
+ outer(fuse_->outer->var.get()),
+ under_outer(false),
+ extent(0),
+ fused(false) {}
// TODO(@were): Handle imperfect loops
Stmt VisitStmt_(const ForNode* op) final {
if (op->loop_var.get() == inner) {
CHECK(under_outer);
- std::unordered_map<const VarNode *, PrimExpr> rmap;
+ std::unordered_map<const VarNode*, PrimExpr> rmap;
rmap[op->loop_var.get()] = indexmod(parent, op->extent);
extent = op->extent;
fused = true;
} else if (op->loop_var.get() == outer) {
under_outer = true;
Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const VarNode *, PrimExpr> rmap;
+ std::unordered_map<const VarNode*, PrimExpr> rmap;
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = tir::Substitute(body, rmap);
under_outer = false;
- return ForNode::make(parent->var, PrimExpr(0), extent * op->extent,
- op->for_type, op->device_api, body);
+ return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type,
+ op->device_api, body);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const VarNode *, PrimExpr> rmap;
+ std::unordered_map<const VarNode*, PrimExpr> rmap;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = tir::Substitute(body, rmap);
extent = extent * op->extent;
}
};
- for (auto &rel : stage->relations) {
- if (const SplitNode *split = rel.as<SplitNode>()) {
+ for (auto& rel : stage->relations) {
+ if (const SplitNode* split = rel.as<SplitNode>()) {
LoopSpliter Spliter(split, dom_map);
stmt = Spliter(stmt);
CHECK(Spliter.splitted);
- } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
+ } else if (const FuseNode* fuse = rel.as<FuseNode>()) {
LoopFuser Fuser(fuse);
stmt = Fuser(stmt);
CHECK(Fuser.fused);
return stmt;
}
-Stmt ApplyLoopAnnotations(const Stage &stage,
- const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar, IterVar>& rebased,
+ Stmt stmt) {
class LoopAnnotator : public StmtMutator {
- const VarNode *var;
- const IterVarAttr &attr;
+ const VarNode* var;
+ const IterVarAttr& attr;
public:
- LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
+ LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {}
- Stmt VisitStmt_(const ForNode *op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
tir::ExprDeepEqual expr_equal;
if (op->loop_var.get() == var) {
if (attr->bind_thread.defined()) {
- const auto &iter_var = attr->bind_thread;
+ const auto& iter_var = attr->bind_thread;
if (iter_var->dom.defined()) {
CHECK(is_const_int(iter_var->dom->min, 0));
CHECK(expr_equal(iter_var->dom->extent, op->extent))
- << "Thread extent and loop extent mismatch!\n";
+ << "Thread extent and loop extent mismatch!\n";
}
- std::unordered_map<const VarNode *, PrimExpr> rmap;
+ std::unordered_map<const VarNode*, PrimExpr> rmap;
rmap[op->loop_var.get()] = iter_var;
Stmt body = tir::Substitute(op->body, rmap);
return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
} else {
return ForNode::make(op->loop_var, op->min, op->extent,
- IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
+ IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
}
}
return StmtMutator::VisitStmt_(op);
}
};
- for (auto &iter_var : stage->leaf_iter_vars) {
+ for (auto& iter_var : stage->leaf_iter_vars) {
bool need_change = false;
int found = 0;
- const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
- const VarNode *var = actual->var.get();
+ const IterVar& actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+ const VarNode* var = actual->var.get();
ForType expected = IterVarTypeToForType(iter_var->iter_type);
IterVarAttr attr;
if (stage->iter_var_attrs.count(iter_var)) {
expected = IterVarTypeToForType(attr->iter_type);
}
- PostOrderVisit(stmt,
- [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
- if (const ForNode *op = node.as<ForNode>()) {
+ PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
+ if (const ForNode* op = node.as<ForNode>()) {
if (op->loop_var.get() == var) {
++found;
need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
return stmt;
}
-Stmt ApplyLoopOrder(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt) {
std::vector<const VarNode*> current_order;
PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
- if (const ForNode *op = node.as<ForNode>())
- current_order.push_back(op->loop_var.get());
+ if (const ForNode* op = node.as<ForNode>()) current_order.push_back(op->loop_var.get());
});
std::reverse(current_order.begin(), current_order.end());
- auto &required_ord = stage->leaf_iter_vars;
+ auto& required_ord = stage->leaf_iter_vars;
CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
- std::unordered_map<const VarNode *, IterVar> reorder;
+ std::unordered_map<const VarNode*, IterVar> reorder;
bool need_reorder = false;
for (size_t i = 0; i < current_order.size(); ++i) {
- auto &current = current_order[i];
- const IterVar &iter_var = required_ord[i];
- const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+ auto& current = current_order[i];
+ const IterVar& iter_var = required_ord[i];
+ const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
reorder[current] = required;
if (current != required->var.get()) {
}
class LoopReorder : public StmtMutator {
- const Stage &stage;
- const std::unordered_map<IterVar, Range> &dom_map;
- const std::unordered_map<const VarNode *, IterVar> &reorder;
+ const Stage& stage;
+ const std::unordered_map<IterVar, Range>& dom_map;
+ const std::unordered_map<const VarNode*, IterVar>& reorder;
public:
- LoopReorder(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- const std::unordered_map<const VarNode*, IterVar> &reorder)
- : stage(stage), dom_map(dom_map), reorder(reorder) {}
+ LoopReorder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<const VarNode*, IterVar>& reorder)
+ : stage(stage), dom_map(dom_map), reorder(reorder) {}
Stmt VisitStmt_(const ForNode* op) final {
// Reorder from in to out
auto target = reorder.find(op->loop_var.get())->second;
if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
return GetRef<Stmt>(op);
- const Stmt &body = op->body.same_as(body_) ? op->body : body_;
+ const Stmt& body = op->body.same_as(body_) ? op->body : body_;
ForType for_type = IterVarTypeToForType(target->iter_type);
if (stage->iter_var_attrs.count(target)) {
for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
}
- const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
- return ForNode::make(target->var, range->min, range->extent,
- for_type, DeviceAPI::None, body);
+ const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
+ return ForNode::make(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
}
};
- if (need_reorder)
- return LoopReorder(stage, dom_map, reorder)(stmt);
+ if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt);
return stmt;
}
-Stmt ApplySchedule(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
+Stmt ApplySchedule(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ Stmt stmt) {
// TODO(@were): Eliminate loop rebase in script parser and move the burden here
// Gather rebased variables
std::unordered_map<IterVar, IterVar> rebased;
// TODO(@were): Write a comprehensive pass to analyze iter var types
std::vector<IterVar> res_;
PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
- if (const ForNode *op = node.as<ForNode>()) {
+ if (const ForNode* op = node.as<ForNode>()) {
Var loop_var(op->loop_var);
Range dom = Range::make_by_min_extent(op->min, op->extent);
res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public tir::StmtMutator {
public:
- explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
- : vmap_(vmap) {}
+ explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {}
Stmt VisitStmt_(const tir::ProvideNode* op) final {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
- Stmt ret = tir::ProvideNode::make(
- it->second->op, it->second->value_index, op->value, op->args);
+ Stmt ret =
+ tir::ProvideNode::make(it->second->op, it->second->value_index, op->value, op->args);
found = true;
return this->VisitStmt(ret);
}
bool found{false};
private:
- const std::unordered_map<Tensor, Tensor> &vmap_;
+ const std::unordered_map<Tensor, Tensor>& vmap_;
};
-Stmt ReplaceProvideTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor> &replace) {
+Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) {
ProviderReplacer repl(replace);
Stmt ret = repl(stmt);
return repl.found ? ret : stmt;
#ifndef TVM_TE_OPERATION_HYBRID_OP_H_
#define TVM_TE_OPERATION_HYBRID_OP_H_
-#include <tvm/tir/expr.h>
#include <tvm/te/schedule.h>
+#include <tvm/tir/expr.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include "../schedule/message_passing.h"
-#include "../../tir/transforms/ir_util.h"
#include "../../tir/transforms/arg_binder.h"
+#include "../../tir/transforms/ir_util.h"
+#include "../schedule/message_passing.h"
namespace tvm {
namespace te {
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
-Stmt ReplaceProvideTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor>& replace);
+Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Apply the schedule manipulation on the function body.
* \param dom_map The extents of the iterative variables may be used.
* \param stage The schedule information to be applied.
*/
-Stmt ApplySchedule(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
+Stmt ApplySchedule(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ Stmt stmt);
/*!
* \brief Apply loop splits and fuses in the schedule on the function body.
* \param dom_map The extents of the iterative variables may be used.
* \param stmt The statement to be processed.
*/
-Stmt ApplyLoopShapes(const Stage &stage,
- const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
-
+Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ Stmt stmt);
/*!
* \brief Apply loop annotation in the schedule on the function body.
* \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables.
* \param stmt The statement to be processed.
*/
-Stmt ApplyLoopAnnotations(const Stage &stage,
- const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
+Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar, IterVar>& rebased,
+ Stmt stmt);
/*!
* \brief Apply loop order in the schedule on the function body.
* \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables.
* \param stmt The statement to be processed.
*/
-Stmt ApplyLoopOrder(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
+Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
} // namespace te
} // namespace tvm
* \brief Utility to make loop nest.
* \file op_util.cc
*/
+#include "op_util.h"
+
+#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/te/operation.h>
+
#include <string>
-#include "op_util.h"
-#include "../schedule/message_passing.h"
+
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
+#include "../schedule/message_passing.h"
namespace tvm {
namespace te {
using namespace arith;
using namespace tir;
-std::vector<std::vector<Stmt> >
-MakeLoopNest(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- size_t begin_iter_pos,
- bool new_loop_var,
- const std::unordered_set<IterVar>& skip_iter,
- std::unordered_map<IterVar, PrimExpr>* p_value_map,
- bool debug_keep_trivial_loop) {
+std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ size_t begin_iter_pos, bool new_loop_var,
+ const std::unordered_set<IterVar>& skip_iter,
+ std::unordered_map<IterVar, PrimExpr>* p_value_map,
+ bool debug_keep_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars;
Stmt no_op = EvaluateNode::make(0);
// create the loop nest
}
if (it_attr.defined()) {
switch (it_attr->iter_type) {
- case kUnrolled: for_type = ForType::Unrolled; break;
- case kVectorized: for_type = ForType::Vectorized; break;
- case kParallelized: for_type = ForType::Parallel; break;
- case kDataPar: break;
- case kTensorized: break;
- default: LOG(FATAL) << "Unknown iter type"
- << it_attr->iter_type
- << " in the iter_var_attrs";
+ case kUnrolled:
+ for_type = ForType::Unrolled;
+ break;
+ case kVectorized:
+ for_type = ForType::Vectorized;
+ break;
+ case kParallelized:
+ for_type = ForType::Parallel;
+ break;
+ case kDataPar:
+ break;
+ case kTensorized:
+ break;
+ default:
+ LOG(FATAL) << "Unknown iter type" << it_attr->iter_type << " in the iter_var_attrs";
}
CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
}
}
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
- nest[i + 1].emplace_back(
- LetStmtNode::make(var, dom->min, no_op));
+ nest[i + 1].emplace_back(LetStmtNode::make(var, dom->min, no_op));
value_map[iv] = dom->min;
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(
- ForNode::make(var, 0, dom->extent,
- for_type, DeviceAPI::None, no_op));
+ ForNode::make(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
nest[i + 1].emplace_back(
- ForNode::make(idx, 0, dom->extent,
- for_type, DeviceAPI::None, no_op));
+ ForNode::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op));
PrimExpr new_value = dom->min + idx;
value_map[iv] = new_value;
- nest[i + 1].emplace_back(
- LetStmtNode::make(var, new_value, no_op));
+ nest[i + 1].emplace_back(LetStmtNode::make(var, new_value, no_op));
}
if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
- CHECK(!is_one(dom->extent))
- << "Cannot prefetch on trivial loop with extent=1";
- CHECK_EQ(it_attr->prefetch_data.size(),
- it_attr->prefetch_offset.size());
+ CHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1";
+ CHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size());
for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
- nest[i + 1].emplace_back(
- AttrStmtNode::make(it_attr->prefetch_data[j],
- tir::attr::prefetch_scope,
- it_attr->prefetch_offset[j], no_op));
+ nest[i + 1].emplace_back(AttrStmtNode::make(it_attr->prefetch_data[j],
+ tir::attr::prefetch_scope,
+ it_attr->prefetch_offset[j], no_op));
}
}
- } else if (bind_iv->thread_tag == "vthread" ||
- bind_iv->thread_tag == "cthread") {
+ } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
value_map[iv] = var;
} else {
LOG(WARNING)
- << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
- << "TVM assumes only threadIdx.x indicates threads inside a warp, "
- << "while threadIdx.y and threadIdx.z indicates different warps.";
+ << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
+ << "TVM assumes only threadIdx.x indicates threads inside a warp, "
+ << "while threadIdx.y and threadIdx.z indicates different warps.";
value_map[iv] = dom->min;
}
} else {
}
// annotate the extent of the IterVar
if (!new_loop_var) {
- nest[i + 1].emplace_back(
- AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op));
+ nest[i + 1].emplace_back(AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
// replacer to replace tensors
class TensorReplacer : public tir::StmtExprMutator {
public:
- explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
- : vmap_(vmap) {}
+ explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {}
PrimExpr VisitExpr_(const tir::CallNode* op) final {
if (op->call_type == tir::CallNode::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
- PrimExpr ret = tir::CallNode::make(
- op->dtype, it->second->op->name, op->args,
- op->call_type, it->second->op, it->second->value_index);
+ PrimExpr ret = tir::CallNode::make(op->dtype, it->second->op->name, op->args, op->call_type,
+ it->second->op, it->second->value_index);
found = true;
return this->VisitExpr(ret);
}
const std::unordered_map<Tensor, Tensor>& vmap_;
};
-Stmt ReplaceTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor>& replace) {
+Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace);
Stmt ret = repl(stmt);
return repl.found ? ret : stmt;
}
-PrimExpr ReplaceTensor(PrimExpr expr,
- const std::unordered_map<Tensor, Tensor>& replace) {
+PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace);
PrimExpr ret = repl(expr);
return repl.found ? ret : expr;
}
-
-Stmt Substitute(Stmt s,
- const std::unordered_map<IterVar, PrimExpr>& value_map) {
+Stmt Substitute(Stmt s, const std::unordered_map<IterVar, PrimExpr>& value_map) {
std::unordered_map<const VarNode*, PrimExpr> init;
for (const auto& kv : value_map) {
init[kv.first->var.get()] = kv.second;
IterVarType ForTypeToIterVarType(tir::ForType for_type) {
switch (for_type) {
- case ForType::Serial:
- return kDataPar;
- case ForType::Parallel:
- return kParallelized;
- case ForType::Vectorized:
- return kVectorized;
- case ForType::Unrolled:
- return kUnrolled;
- default:
- return kDataPar;
+ case ForType::Serial:
+ return kDataPar;
+ case ForType::Parallel:
+ return kParallelized;
+ case ForType::Vectorized:
+ return kVectorized;
+ case ForType::Unrolled:
+ return kUnrolled;
+ default:
+ return kDataPar;
}
}
tir::ForType IterVarTypeToForType(IterVarType iter_type) {
switch (iter_type) {
- case kDataPar:
- return ForType::Serial;
- case kParallelized:
- return ForType::Parallel;
- case kVectorized:
- return ForType::Vectorized;
- case kUnrolled:
- return ForType::Unrolled;
- default:
- return ForType::Serial;
+ case kDataPar:
+ return ForType::Serial;
+ case kParallelized:
+ return ForType::Parallel;
+ case kVectorized:
+ return ForType::Vectorized;
+ case kUnrolled:
+ return ForType::Unrolled;
+ default:
+ return ForType::Serial;
}
}
#ifndef TVM_TE_OPERATION_OP_UTIL_H_
#define TVM_TE_OPERATION_OP_UTIL_H_
-#include <tvm/tir/expr.h>
#include <tvm/te/schedule.h>
+#include <tvm/tir/expr.h>
+
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include "../../tir/transforms/ir_util.h"
+
#include "../../tir/transforms/arg_binder.h"
+#include "../../tir/transforms/ir_util.h"
#include "../schedule/message_passing.h"
namespace tvm {
* \param p_value_map The result value of each IterVar.
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
*/
-std::vector<std::vector<Stmt> >
-MakeLoopNest(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- size_t begin_iter_pos,
- bool new_loop_var,
- const std::unordered_set<IterVar>& skip_iter,
- std::unordered_map<IterVar, PrimExpr>* p_value_map,
- bool debug_keep_trivial_loop);
+std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ size_t begin_iter_pos, bool new_loop_var,
+ const std::unordered_set<IterVar>& skip_iter,
+ std::unordered_map<IterVar, PrimExpr>* p_value_map,
+ bool debug_keep_trivial_loop);
/*!
* \brief Create a nest of if checking the predicates.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
-Stmt ReplaceTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor>& replace);
+Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
* \param expr The expression to be processed.
* \param replace The replacement rule.
*/
-PrimExpr ReplaceTensor(PrimExpr expr,
- const std::unordered_map<Tensor, Tensor>& replace);
+PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Substitute the variables of stmt by value map.
* \param value_map The value map.
* \return Substituted result.
*/
-Stmt Substitute(Stmt stmt,
- const std::unordered_map<IterVar, PrimExpr>& value_map);
+Stmt Substitute(Stmt stmt, const std::unordered_map<IterVar, PrimExpr>& value_map);
/*!
* \brief Converts Halide ForType to its corresponding IterVarType
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const PlaceholderOpNode*>(node.get());
- p->stream << "placeholder(" << op->name << ", " << op << ")";
-});
+ .set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const PlaceholderOpNode*>(node.get());
+ p->stream << "placeholder(" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
-int PlaceholderOpNode::num_outputs() const {
- return 1;
-}
+int PlaceholderOpNode::num_outputs() const { return 1; }
-Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
- return {};
-}
+Array<IterVar> PlaceholderOpNode::root_iter_vars() const { return {}; }
DataType PlaceholderOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return shape;
}
-Operation PlaceholderOpNode::make(std::string name,
- Array<PrimExpr> shape,
- DataType dtype) {
+Operation PlaceholderOpNode::make(std::string name, Array<PrimExpr> shape, DataType dtype) {
auto n = make_object<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
}
TVM_REGISTER_GLOBAL("te.Placeholder")
-.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
- return placeholder(shape, dtype, name);
-});
+ .set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
+ return placeholder(shape, dtype, name);
+ });
-Array<Tensor> PlaceholderOpNode::InputTensors() const {
- return {};
-}
+Array<Tensor> PlaceholderOpNode::InputTensors() const { return {}; }
-Operation PlaceholderOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
+Operation PlaceholderOpNode::ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
return self;
}
void PlaceholderOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
+ const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-}
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {}
-void PlaceholderOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
-}
+void PlaceholderOpNode::GatherBound(const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {}
-Stmt PlaceholderOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const {
+Stmt PlaceholderOpNode::BuildRealize(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const {
return body;
}
-Stmt PlaceholderOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
+Stmt PlaceholderOpNode::BuildProvide(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
return Stmt();
}
} // namespace te
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
-#include "op_util.h"
+
#include "../schedule/graph.h"
+#include "op_util.h"
namespace tvm {
namespace te {
using namespace tir;
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ScanOpNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ScanOpNode*>(node.get());
- p->stream << "scan(" << op->name << ", " << op << ")";
-});
+ .set_dispatch<ScanOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ScanOpNode*>(node.get());
+ p->stream << "scan(" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(ScanOpNode);
-int ScanOpNode::num_outputs() const {
- return static_cast<int>(update.size());
-}
+int ScanOpNode::num_outputs() const { return static_cast<int>(update.size()); }
Array<IterVar> ScanOpNode::root_iter_vars() const {
Array<IterVar> ret{scan_axis};
for (IterVar iv : spatial_axis_) {
return ret;
}
-DataType ScanOpNode::output_dtype(size_t i) const {
- return update[i]->dtype;
-}
+DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; }
Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}
-Operation ScanOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- IterVar axis,
- Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> inputs) {
+Operation ScanOpNode::make(std::string name, std::string tag, Map<std::string, ObjectRef> attrs,
+ IterVar axis, Array<Tensor> init, Array<Tensor> update,
+ Array<Tensor> state_placeholder, Array<Tensor> inputs) {
if (!attrs.defined()) {
attrs = Map<std::string, ObjectRef>();
}
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min";
- CHECK(prove_equal(
- state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
+ CHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "state_placeholder.shape[0] need to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
<< "The update.ndim need to be state_placeholder.ndim - 1";
- for (size_t k = 0; k < update[i].ndim(); ++k) {
- CHECK(prove_equal(
- update[i]->shape[k], state_placeholder[i]->shape[k]));
+ for (size_t k = 0; k < update[i].ndim(); ++k) {
+ CHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k]));
if (k != 0) {
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k;
- n->spatial_axis_.push_back(
- IterVarNode::make(
- Range::make_by_min_extent(0, update[i]->shape[k]),
- Var(spatial_name.str()), kOpaque));
+ n->spatial_axis_.push_back(IterVarNode::make(
+ Range::make_by_min_extent(0, update[i]->shape[k]), Var(spatial_name.str()), kOpaque));
}
}
- for (size_t k = 1; k < init[i].ndim(); ++k) {
- CHECK(prove_equal(
- init[i]->shape[k], state_placeholder[i]->shape[k]));
+ for (size_t k = 1; k < init[i].ndim(); ++k) {
+ CHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = std::move(name);
return Operation(n);
}
-TVM_REGISTER_GLOBAL("te.ScanOp")
-.set_body_typed(ScanOpNode::make);
-
+TVM_REGISTER_GLOBAL("te.ScanOp").set_body_typed(ScanOpNode::make);
-Array<Tensor> scan(Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> inputs,
- std::string name,
- std::string tag,
+Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
+ Array<Tensor> inputs, std::string name, std::string tag,
Map<std::string, ObjectRef> attrs) {
- IterVar scan_axis =
- IterVarNode::make(
- Range::make_by_min_extent(
- init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
- Var(name + ".idx"), kOrdered);
- Operation op = ScanOpNode::make(
- name, tag, attrs, scan_axis,
- init, update, state_placeholder, inputs);
+ IterVar scan_axis = IterVarNode::make(
+ Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
+ Var(name + ".idx"), kOrdered);
+ Operation op =
+ ScanOpNode::make(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
return ret;
}
-Operation ScanOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
+Operation ScanOpNode::ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_object<ScanOpNode>(*this);
for (size_t i = 0; i < n->init.size(); ++i) {
n->update.Set(i, rmap.at(n->update[i]));
}
}
- if (!n->init.same_as(init) ||
- !n->update.same_as(update)) {
+ if (!n->init.same_as(init) || !n->update.same_as(update)) {
return Operation(n);
} else {
return self;
}
}
-void ScanOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+void ScanOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
TensorDom* init_dom = nullptr;
}
// first dimension, always needed.
if (init_dom) {
- init_dom->data[0].push_back(IntSet::range(
- Range::make_by_min_extent(0, this->init[i]->shape[0])));
+ init_dom->data[0].push_back(
+ IntSet::range(Range::make_by_min_extent(0, this->init[i]->shape[0])));
}
if (update_dom) {
update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
}
}
-void ScanOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
+void ScanOpNode::GatherBound(const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
CHECK(!out_dom_map->count(this->scan_axis));
std::vector<Tensor> output(this->num_outputs());
arith::Analyzer analyzer;
Range sdom = this->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
- (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
- sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min));
+ (*out_dom_map)[this->scan_axis] =
+ Range::make_by_min_extent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min));
Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
// Update for spatial axis.
size_t sp_idx = 0;
}
}
-Stmt ScanOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const Stmt& body) const {
+Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ const Stmt& body) const {
arith::Analyzer analyzer;
CHECK_EQ(stage->op.get(), this);
Range sdom = dom_map.at(this->scan_axis);
- Range tdom = Range::make_by_min_extent(
- 0, analyzer.Simplify(sdom->extent + sdom->min));
+ Range tdom = Range::make_by_min_extent(0, analyzer.Simplify(sdom->extent + sdom->min));
Stmt ret = body;
size_t sp_idx = 0;
for (size_t i = 0; i < update.size(); ++i) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
- ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype,
- bounds, const_true(), ret);
+ ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), ret);
}
return ret;
}
-Stmt ScanOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
+Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt provide = AttrStmtNode::make(
- stage->op, tir::attr::scan_update_scope, this->scan_axis->var,
- EvaluateNode::make(0));
- Stmt init = AttrStmtNode::make(
- stage->op, tir::attr::scan_init_scope, 0,
- EvaluateNode::make(0));
+ Stmt provide = AttrStmtNode::make(stage->op, tir::attr::scan_update_scope, this->scan_axis->var,
+ EvaluateNode::make(0));
+ Stmt init = AttrStmtNode::make(stage->op, tir::attr::scan_init_scope, 0, EvaluateNode::make(0));
size_t begin_scan = 0;
- for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+ for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
CHECK_EQ(begin_scan, i);
begin_scan = i + 1;
}
std::unordered_map<IterVar, PrimExpr> vmap;
std::unordered_set<IterVar> empty;
- auto nest = MakeLoopNest(
- stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
+ auto nest = MakeLoopNest(stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
nest[begin_scan].push_back(init);
- nest.push_back(
- MakeIfNest(
- MakeBoundCheck(stage, dom_map, vmap, false, empty)));
+ nest.push_back(MakeIfNest(MakeBoundCheck(stage, dom_map, vmap, false, empty)));
return MergeNest(nest, provide);
}
} // namespace te
* \brief Tensor Compute Op.
* \file tensor_compute_op.cc
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
-#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+
#include <unordered_set>
-#include "./op_util.h"
-#include "./compute_op.h"
#include "../../arith/compute_expr.h"
+#include "./compute_op.h"
+#include "./op_util.h"
namespace tvm {
namespace te {
using namespace tir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const TensorComputeOpNode*>(node.get());
- p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
- });
+ .set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const TensorComputeOpNode*>(node.get());
+ p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
return this->intrin->buffers[this->inputs.size() + i]->dtype;
}
-Operation TensorComputeOpNode::make(std::string name,
- std::string tag,
- Array<IterVar> axis,
- Array<IterVar> reduce_axis,
- int schedulable_ndim,
- TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
- Array<PrimExpr> scalar_inputs) {
+Operation TensorComputeOpNode::make(std::string name, std::string tag, Array<IterVar> axis,
+ Array<IterVar> reduce_axis, int schedulable_ndim,
+ TensorIntrin intrin, Array<Tensor> tensors,
+ Array<Region> regions, Array<PrimExpr> scalar_inputs) {
auto n = make_object<TensorComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
return Operation(n);
}
-TVM_REGISTER_GLOBAL("te.TensorComputeOp")
-.set_body_typed(TensorComputeOpNode::make);
+TVM_REGISTER_GLOBAL("te.TensorComputeOp").set_body_typed(TensorComputeOpNode::make);
+Array<Tensor> TensorComputeOpNode::InputTensors() const { return inputs; }
-Array<Tensor> TensorComputeOpNode::InputTensors() const {
- return inputs;
-}
-
-Operation TensorComputeOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
+Operation TensorComputeOpNode::ReplaceInputs(const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_object<TensorComputeOpNode>(*this);
auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
if (intrin->body.same_as(n->intrin->body) &&
intrin->reduce_init.same_as(n->intrin->reduce_init) &&
- intrin->reduce_update.same_as(n->intrin->reduce_update) &&
- inputs.same_as(n->inputs)) {
+ intrin->reduce_update.same_as(n->intrin->reduce_update) && inputs.same_as(n->inputs)) {
return self;
} else {
n->intrin = TensorIntrin(intrin);
}
void TensorComputeOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
+ const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (size_t i = 0; i < this->inputs.size(); ++i) {
}
}
-size_t TensorComputeOpNode::num_schedulable_dims() const {
- return schedulable_ndim;
-}
+size_t TensorComputeOpNode::num_schedulable_dims() const { return schedulable_ndim; }
-Stmt TensorComputeOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
+Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
// Start bind data.
}
input_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, tir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- tir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
+ CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
+ nop));
}
// output binding
output_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, tir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- tir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
+ CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
+ nop));
}
// Check variable remap
ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
if (this->reduce_axis.size() == 0) {
- std::vector<std::vector<Stmt> > nest(
- n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+ std::vector<std::vector<Stmt> > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
nest.emplace_back(MakeIfNest(n.main_predicates));
CHECK_EQ(n.init_predicates.size(), 0U);
CHECK(this->intrin->body.defined())
body = tir::Substitute(body, vmap);
body = MergeNest(binder.asserts(), body);
body = te::Substitute(body, n.main_vmap);
- Stmt ret = MergeNest(nest, body);
+ Stmt ret = MergeNest(nest, body);
return ret;
} else {
// Need to split reduction
- CHECK(this->intrin->reduce_update.defined())
- << "Reduction update op is not defined";
+ CHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined";
// Need init and update steps
CHECK_NE(this->reduce_axis.size(), 0U);
- std::vector<std::vector<Stmt> > common(
- n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
- std::vector<std::vector<Stmt> > update_nest(
- n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+ std::vector<std::vector<Stmt> > common(n.main_nest.begin(),
+ n.main_nest.begin() + n.num_common_loop + 1);
+ std::vector<std::vector<Stmt> > update_nest(n.main_nest.begin() + n.num_common_loop + 1,
+ n.main_nest.begin() + tloc + 1);
update_nest.emplace_back(MakeIfNest(n.main_predicates));
if (this->intrin->reduce_init.defined()) {
// init nest
- std::vector<std::vector<Stmt> > init_nest(
- n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+ std::vector<std::vector<Stmt> > init_nest(n.init_nest.begin(),
+ n.init_nest.begin() + tloc + 1);
init_nest.emplace_back(MakeIfNest(n.init_predicates));
Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
init = te::Substitute(init, n.init_vmap);
return MergeNest(common, SeqStmt::Flatten(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
- CHECK(this->intrin->body.defined())
- << "Normal body op is not defined";
- Stmt update = TransformUpdate(stage, dom_map, n,
- this->intrin->body,
- this->intrin->reduce_update);
+ CHECK(this->intrin->body.defined()) << "Normal body op is not defined";
+ Stmt update =
+ TransformUpdate(stage, dom_map, n, this->intrin->body, this->intrin->reduce_update);
update = MergeNest(output_bind_nest, update);
update = MergeNest(input_bind_nest, update);
update = tir::Substitute(update, vmap);
* \brief Logics related to tensorize, used by ComputeOpNode.
* \file tensorize.cc
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/runtime/registry.h>
-#include "op_util.h"
-#include "compute_op.h"
#include "../schedule/message_passing.h"
+#include "compute_op.h"
+#include "op_util.h"
namespace tvm {
namespace te {
// out_dom: the domain of root iter vars in output op
// in_region: region of each input tensor.
// return The location of the tensorized scope start.
-size_t InferTensorizeRegion(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, Range>* out_dom,
- std::unordered_map<Tensor, Array<Range> >* in_region) {
+size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, Range>* out_dom,
+ std::unordered_map<Tensor, Array<Range> >* in_region) {
// Get the bound of the tensorized scope.
bool found_point = false;
size_t loc_scope = 0;
// Loop over the leafs
for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = stage->leaf_iter_vars[i - 1];
- CHECK(iv->iter_type == kDataPar ||
- iv->iter_type == kCommReduce);
+ CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce);
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (!found_point) {
- CHECK(!attr->bind_thread.defined())
- << "Do not allow thread in tensorize scope";
+ CHECK(!attr->bind_thread.defined()) << "Do not allow thread in tensorize scope";
}
if (attr->iter_type == kTensorized) {
CHECK(!found_point) << "Do not allow two tensorized point";
return loc_scope;
}
-void VerifyTensorizeLoopNest(const ComputeOpNode* self,
- const Stage& stage,
- const ComputeLoopNest& n,
- size_t tloc) {
+void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
+ const ComputeLoopNest& n, size_t tloc) {
// Veirfication step.
std::unordered_set<const VarNode*> banned;
CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
- CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
- n.init_nest.size() == 0);
+ CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || n.init_nest.size() == 0);
auto f_push_banned = [&banned](const Stmt& s) {
if (const ForNode* op = s.as<ForNode>()) {
- banned.insert(op->loop_var.get());
+ banned.insert(op->loop_var.get());
} else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
banned.insert(iv->var.get());
}
}
- auto fbanned = [&](const VarNode* node) {
- return banned.count(node);
- };
+ auto fbanned = [&](const VarNode* node) { return banned.count(node); };
for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
- LOG(FATAL) << "Tensorize failed, split condition "
- << pred << " relies on var defined inside tensorize scope";
+ LOG(FATAL) << "Tensorize failed, split condition " << pred
+ << " relies on var defined inside tensorize scope";
}
}
for (const PrimExpr& pred : n.init_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
- LOG(FATAL) << "Tensorize failed, split condition "
- << pred << " relies on var defined inside tensorize scope";
+ LOG(FATAL) << "Tensorize failed, split condition " << pred
+ << " relies on var defined inside tensorize scope";
}
}
}
for (size_t i = e.start; i < e.region.size(); ++i) {
args.push_back(op->args[i] - e.region[i]->min);
}
- return CallNode::make(
- op->dtype, e.tensor->op->name, args,
- op->call_type, e.tensor->op, e.tensor->value_index);
+ return CallNode::make(op->dtype, e.tensor->op->name, args, op->call_type, e.tensor->op,
+ e.tensor->value_index);
}
}
return expr;
axis.push_back(it->second);
}
}
- return ReduceNode::make(
- op->combiner, op->source, axis, op->condition, op->value_index);
+ return ReduceNode::make(op->combiner, op->source, axis, op->condition, op->value_index);
}
- void Init(const ComputeOpNode* self,
- const Stage& stage,
+ void Init(const ComputeOpNode* self, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const std::unordered_map<IterVar, Range>& out_dom,
- const std::unordered_map<Tensor, Array<Range> >& in_region,
- const TensorIntrin& intrin,
+ const std::unordered_map<Tensor, Array<Range> >& in_region, const TensorIntrin& intrin,
Map<Var, Range>* compute_intrin_iter_space) {
CHECK(self == stage->op.get());
CHECK(is_one(canonical_extent))
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
- << " expected shape=" << e.tensor->shape
- << ", given region=" << e.region;
+ << " expected shape=" << e.tensor->shape << ", given region=" << e.region;
}
in_remap_[inputs[i]] = e;
}
size_t axis_start = self->axis.size() - intrin_compute->axis.size();
for (size_t i = 0; i < axis_start; ++i) {
Range r = out_dom.at(self->axis[i]);
- CHECK(is_one(r->extent))
- << "Tensorize: Output mismatch with tensor intrin "
- << " intrin-dim=" << intrin_compute->axis.size()
- << ", tensorize-dim=" << self->axis.size();
+ CHECK(is_one(r->extent)) << "Tensorize: Output mismatch with tensor intrin "
+ << " intrin-dim=" << intrin_compute->axis.size()
+ << ", tensorize-dim=" << self->axis.size();
var_remap_[self->axis[i]->var.get()] = r->min;
}
// Assume we tensorize at regin axis i [min, min + extent)
axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
for (size_t i = 0; i < axis_start; ++i) {
Range r = out_dom.at(self->reduce_axis[i]);
- CHECK(is_one(r->extent))
- << "Tensorize: Reduction mismatch with tensor intrin "
- << " intrin-dim=" << intrin_compute->reduce_axis.size()
- << ", tensorize-dim=" << self->reduce_axis.size();
+ CHECK(is_one(r->extent)) << "Tensorize: Reduction mismatch with tensor intrin "
+ << " intrin-dim=" << intrin_compute->reduce_axis.size()
+ << ", tensorize-dim=" << self->reduce_axis.size();
var_remap_[self->reduce_axis[i]->var.get()] = r->min;
}
for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
};
// Try to match tensor dataflow of the stage with the intrinsic
-Array<PrimExpr> MatchTensorizeBody(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, Range>& out_dom,
- const std::unordered_map<Tensor, Array<Range> >& in_region,
- const TensorIntrin& intrin,
- Map<Var, Range>* compute_intrin_iter_space) {
+Array<PrimExpr> MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, Range>& out_dom,
+ const std::unordered_map<Tensor, Array<Range> >& in_region,
+ const TensorIntrin& intrin,
+ Map<Var, Range>* compute_intrin_iter_space) {
TensorIntrinMatcher matcher;
matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
Array<PrimExpr> ret;
return ret;
}
-void VerifyTensorizeBody(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, Range>& out_dom,
- const std::unordered_map<Tensor, Array<Range> >& in_region,
- const TensorIntrin& intrin) {
+void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, Range>& out_dom,
+ const std::unordered_map<Tensor, Array<Range> >& in_region,
+ const TensorIntrin& intrin) {
StructuralEqual expr_equal;
Map<Var, Range> compute_intrin_iter_space;
Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
- &compute_intrin_iter_space);
+ &compute_intrin_iter_space);
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
CHECK(intrin_compute) << "Only support compute intrinsic for now";
- CHECK_EQ(body.size(), intrin_compute->body.size())
- << "Tensorize failed: body size mismatch";
+ CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch";
arith::Analyzer ana;
ana.Bind(compute_intrin_iter_space);
PrimExpr lhs = ana.Simplify(body[i]);
PrimExpr rhs = ana.Simplify(intrin_compute->body[i]);
if (lhs.dtype() != rhs.dtype()) {
- LOG(FATAL)
- << "Failed to match the data type with TensorIntrin "
- << intrin->name << "'s declaration "
- << " provided=" << lhs.dtype()
- << ", intrin=" << rhs.dtype();
+ LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name
+ << "'s declaration "
+ << " provided=" << lhs.dtype() << ", intrin=" << rhs.dtype();
}
- CHECK(expr_equal(lhs, rhs))
- << "Failed to match the compute with TensorIntrin "
- << intrin->name << "'s declaration "
- << " provided= " << lhs
- << ", intrin= " << rhs;
+ CHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name
+ << "'s declaration "
+ << " provided= " << lhs << ", intrin= " << rhs;
}
}
-Stmt MakeTensorize(const ComputeOpNode* self,
- const Stage& stage,
+Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
std::unordered_map<IterVar, Range> out_dom;
std::unordered_map<Tensor, Array<Range> > in_region;
size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
- TensorIntrin intrin = stage->iter_var_attrs.at(
- stage->leaf_iter_vars[tloc])->tensor_intrin;
+ TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin;
CHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
VerifyTensorizeLoopNest(self, stage, n, tloc);
Stmt nop = EvaluateNode::make(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = self->InputTensors();
- CHECK_EQ(inputs.size(), intrin->inputs.size())
- << "Tensorize failed: input size mismatch ";
+ CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch ";
// input binding
for (size_t i = 0; i < intrin->inputs.size(); ++i) {
Tensor tensor = inputs[i];
}
input_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, tir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- tir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
+ CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
+ nop));
}
// output binding
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
Array<ObjectRef> bind_spec{buffer, tensor};
output_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, tir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- tir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
+ CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
+ nop));
}
// Check variable remap
std::unordered_map<const VarNode*, PrimExpr> vmap;
IterVar iv = self->reduce_axis[i];
auto it = out_dom.find(iv);
CHECK(it != out_dom.end());
- CHECK(is_one(it->second->extent))
- << "Tensorization fail: reduction axis size do not match";
+ CHECK(is_one(it->second->extent)) << "Tensorization fail: reduction axis size do not match";
}
for (size_t i = start; i < self->reduce_axis.size(); ++i) {
IterVar iv = self->reduce_axis[i];
CHECK(it != out_dom.end());
binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
"tensir_intrin.reduction.min");
- binder.Bind(target->dom->extent, it->second->extent,
- "tensir_intrin.reduction.extent");
+ binder.Bind(target->dom->extent, it->second->extent, "tensir_intrin.reduction.extent");
}
if (tloc <= n.num_common_loop) {
// Do no need to split reduction
- std::vector<std::vector<Stmt> > nest(
- n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+ std::vector<std::vector<Stmt> > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
nest.emplace_back(MakeIfNest(n.main_predicates));
CHECK_EQ(n.init_predicates.size(), 0U);
- CHECK(intrin->body.defined())
- << "Normal store op for intrin " << intrin << " is not defined";
+ CHECK(intrin->body.defined()) << "Normal store op for intrin " << intrin << " is not defined";
Stmt body = MergeNest(output_bind_nest, intrin->body);
body = MergeNest(input_bind_nest, body);
body = tir::Substitute(body, vmap);
<< "Reduction update op for intrin " << intrin << " is not defined";
// Need init and update steps
CHECK_NE(self->reduce_axis.size(), 0U);
- std::vector<std::vector<Stmt> > common(
- n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
- std::vector<std::vector<Stmt> > update_nest(
- n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+ std::vector<std::vector<Stmt> > common(n.main_nest.begin(),
+ n.main_nest.begin() + n.num_common_loop + 1);
+ std::vector<std::vector<Stmt> > update_nest(n.main_nest.begin() + n.num_common_loop + 1,
+ n.main_nest.begin() + tloc + 1);
update_nest.emplace_back(MakeIfNest(n.main_predicates));
if (intrin->reduce_init.defined()) {
// init nest
- std::vector<std::vector<Stmt> > init_nest(
- n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+ std::vector<std::vector<Stmt> > init_nest(n.init_nest.begin(),
+ n.init_nest.begin() + tloc + 1);
init_nest.emplace_back(MakeIfNest(n.init_predicates));
Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
init = te::Substitute(init, n.init_vmap);
return MergeNest(common, SeqStmt::Flatten(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
- CHECK(intrin->body.defined())
- << "Normal body op for intrin " << intrin << " is not defined";
- Stmt update = TransformUpdate(stage, dom_map, n,
- intrin->body,
- intrin->reduce_update);
+ CHECK(intrin->body.defined()) << "Normal body op for intrin " << intrin << " is not defined";
+ Stmt update = TransformUpdate(stage, dom_map, n, intrin->body, intrin->reduce_update);
update = MergeNest(output_bind_nest, update);
update = MergeNest(input_bind_nest, update);
update = tir::Substitute(update, vmap);
}
// Register functions for unittests
-TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Stage stage = args[0];
- Map<IterVar, Range> dmap = args[1];
- std::unordered_map<IterVar, Range> out_dom;
- std::unordered_map<Tensor, Array<Range> > in_region;
- CHECK(stage->op.as<ComputeOpNode>());
- InferTensorizeRegion(stage->op.as<ComputeOpNode>(),
- stage,
- as_unordered_map(dmap),
- &out_dom, &in_region);
- *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom),
- Map<Tensor, Array<Range> >(in_region)};
- });
+TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion").set_body([](TVMArgs args, TVMRetValue* ret) {
+ Stage stage = args[0];
+ Map<IterVar, Range> dmap = args[1];
+ std::unordered_map<IterVar, Range> out_dom;
+ std::unordered_map<Tensor, Array<Range> > in_region;
+ CHECK(stage->op.as<ComputeOpNode>());
+ InferTensorizeRegion(stage->op.as<ComputeOpNode>(), stage, as_unordered_map(dmap), &out_dom,
+ &in_region);
+ *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom), Map<Tensor, Array<Range> >(in_region)};
+});
-TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Stage stage = args[0];
- Map<IterVar, Range> out_dom = args[1];
- Map<Tensor, Array<Range> > in_region = args[2];
- TensorIntrin intrin = args[3];
- Map<Var, Range> vrange;
- CHECK(stage->op.as<ComputeOpNode>());
- *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
- stage,
- {{}},
- as_unordered_map(out_dom),
- as_unordered_map(in_region),
- intrin,
- &vrange);
- });
+TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody").set_body([](TVMArgs args, TVMRetValue* ret) {
+ Stage stage = args[0];
+ Map<IterVar, Range> out_dom = args[1];
+ Map<Tensor, Array<Range> > in_region = args[2];
+ TensorIntrin intrin = args[3];
+ Map<Var, Range> vrange;
+ CHECK(stage->op.as<ComputeOpNode>());
+ *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(), stage, {{}}, as_unordered_map(out_dom),
+ as_unordered_map(in_region), intrin, &vrange);
+});
} // namespace te
} // namespace tvm
* \file auto_inline_elem_wise.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/te/schedule_pass.h>
#include <tvm/te/operation.h>
+#include <tvm/te/schedule_pass.h>
#include <tvm/tir/expr_functor.h>
namespace tvm {
Array<IterVar> axis_;
};
-
bool IsElemWise(const Operation& op) {
if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
ElemWiseDetector v = ElemWiseDetector(compute->axis);
}
}
-TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
-.set_body_typed(AutoInlineElemWise);
-
+TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise").set_body_typed(AutoInlineElemWise);
-TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
-.set_body_typed(AutoInlineInjective);
+TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective").set_body_typed(AutoInlineInjective);
} // namespace te
} // namespace tvm
* \brief The bound inference logic.
*/
#include <tvm/runtime/registry.h>
-#include <tvm/te/schedule_pass.h>
#include <tvm/te/operation.h>
+#include <tvm/te/schedule_pass.h>
+
#include <unordered_map>
#include <unordered_set>
+
+#include "../../runtime/thread_storage_scope.h"
#include "graph.h"
#include "message_passing.h"
-#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace te {
std::unordered_map<const Object*, Stage> op2stage_;
};
-bool NeedRelax(const IterVar& iv,
- bool found_attach,
+bool NeedRelax(const IterVar& iv, bool found_attach,
const std::unordered_map<IterVar, IterVar>& bind_map,
const runtime::StorageScope& scope) {
auto it = bind_map.find(iv);
- const std::string& tag = (
- it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+ const std::string& tag = (it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
if (tag.length() == 0 || tag == "pipeline") {
return !found_attach;
}
// When there is warp memory
// threadIdx.x must be set to be warp index.
- if (scope.rank == StorageRank::kWarp &&
- ts.rank == 1 &&
- ts.dim_index == 0) {
+ if (scope.rank == StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) {
return true;
}
return static_cast<int>(scope.rank) <= ts.rank;
}
// infer storage scope, if not given
-StorageScope InferStorageScope(
- const Stage& stage, const GraphContext& ctx) {
+StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) {
if (stage->scope.length() != 0) {
return StorageScope::make(stage->scope);
}
int max_rank = -1;
for (IterVar iv : ctx.attach_path.at(stage->op)) {
auto it = ctx.bind_map.find(iv);
- const std::string& tag = (
- it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+ const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
if (tag != "pipeline" && tag.length() != 0) {
max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
}
return s;
}
-
-void InferRootBound(const Stage& stage,
- const GraphContext& ctx,
+void InferRootBound(const Stage& stage, const GraphContext& ctx,
std::unordered_map<IterVar, Range>* rmap) {
- CHECK_NE(stage->attach_type, kInline)
- << "call schedule.normalize before scheduleops";
+ CHECK_NE(stage->attach_type, kInline) << "call schedule.normalize before scheduleops";
if (stage->attach_type == kInlinedAlready) return;
if (stage->is_output) {
// verify correctness.
- CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
- << "Output must be attached at root";
+ CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) << "Output must be attached at root";
}
if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
- for (auto iv : stage->op->root_iter_vars()) {
+ for (auto iv : stage->op->root_iter_vars()) {
CHECK(iv->dom.defined());
CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom;
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
} else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
- CHECK(is_zero(vrange->min))
- << "InferBound requires every leaf iter var's min equals 0, "
- << " call schedule.normalize to achieve this. ";
+ CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
+ << " call schedule.normalize to achieve this. ";
if (ctx.bind_map.count(iv)) {
up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
} else {
found_attach = true;
}
Range vrange = rmap->at(iv);
- CHECK(is_zero(vrange->min))
- << "InferBound requires every leaf iter var's min equals 0, "
- << "call schedule.normalize to achieve this.";
+ CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
+ << "call schedule.normalize to achieve this.";
if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
relax_set.Set(iv->var, IntSet::range(vrange));
if (ctx.bind_map.count(iv)) {
r = iv->dom;
}
if (relax_set.size() != 0) {
- dom_map[iv->var.get()] = IntSet::interval(
- analyzer.int_set(r->min, relax_set).min(),
- analyzer.int_set(r->min + r->extent - 1, relax_set).max());
+ dom_map[iv->var.get()] =
+ IntSet::interval(analyzer.int_set(r->min, relax_set).min(),
+ analyzer.int_set(r->min + r->extent - 1, relax_set).max());
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
}
}
for (auto& p : ret) {
- ret[p.first] = Range::make_by_min_extent(
- analyzer.Simplify(p.second->min),
- analyzer.Simplify(p.second->extent));
+ ret[p.first] = Range::make_by_min_extent(analyzer.Simplify(p.second->min),
+ analyzer.Simplify(p.second->extent));
}
return Map<IterVar, Range>(ret.begin(), ret.end());
}
-TVM_REGISTER_GLOBAL("schedule.InferBound")
-.set_body_typed(InferBound);
+TVM_REGISTER_GLOBAL("schedule.InferBound").set_body_typed(InferBound);
} // namespace te
} // namespace tvm
* \file graph.cc
* \brief Utilities to get information about schedule graph.
*/
+#include "graph.h"
+
#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/te/operation.h>
-#include <utility>
-#include <unordered_set>
+
#include <unordered_map>
-#include "graph.h"
+#include <unordered_set>
+#include <utility>
namespace tvm {
namespace te {
int dim;
TensorDimKey() {}
TensorDimKey(const tir::CallNode* op, int dim)
- : f(op->func), value_index(op->value_index), dim(dim) {
- }
- TensorDimKey(const Tensor& t, int dim)
- : f(t->op), value_index(t->value_index), dim(dim) {
- }
+ : f(op->func), value_index(op->value_index), dim(dim) {}
+ TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {}
TensorDimKey(const Tensor& t, size_t dim)
- : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
- }
+ : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
inline bool operator==(const TensorDimKey& other) const {
- return f == other.f &&
- value_index == other.value_index &&
- dim == other.dim;
- }
- inline bool operator!=(const TensorDimKey& other) const {
- return !operator==(other);
+ return f == other.f && value_index == other.value_index && dim == other.dim;
}
+ inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); }
};
} // namespace te
} // namespace tvm
struct hash<::tvm::te::TensorDimKey> {
std::size_t operator()(const ::tvm::te::TensorDimKey& k) const {
size_t lhs = ::tvm::ObjectHash()(k.f);
- size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
- static_cast<size_t>(k.dim);
+ size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
-
namespace tvm {
namespace te {
// Do DFS visit to get the subgraph.
// Return if op is inside the subgraph.
-bool GetSubGraphByPostDFS_(
- const Operation& op,
- const std::unordered_set<const Object*>& boundary,
- bool include_bounary,
- std::unordered_map<const Object*, bool>* visited,
- Array<Operation>* result) {
+bool GetSubGraphByPostDFS_(const Operation& op, const std::unordered_set<const Object*>& boundary,
+ bool include_bounary, std::unordered_map<const Object*, bool>* visited,
+ Array<Operation>* result) {
if (visited->count(op.get())) {
return visited->at(op.get());
}
// check if we can reach boundary.
bool reach_boundary = false;
for (Tensor t : op->InputTensors()) {
- if (GetSubGraphByPostDFS_(t->op, boundary,
- include_bounary,
- visited, result)) {
+ if (GetSubGraphByPostDFS_(t->op, boundary, include_bounary, visited, result)) {
reach_boundary = true;
}
}
return reach_boundary;
}
-Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
+Array<Operation> GetSubGraph(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
bool include_inputs) {
Array<Operation> result;
std::unordered_set<const Object*> boundary;
}
std::unordered_map<const Object*, bool> visited;
for (Tensor t : outputs) {
- GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
- &visited, &result);
+ GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result);
}
return result;
}
-
-void PostDFSOrder(const Operation& op,
- const ReadGraph& g,
- std::unordered_set<Operation>* visited,
+void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set<Operation>* visited,
Array<Operation>* post_order) {
if (visited->count(op)) return;
visited->insert(op);
post_order->push_back(op);
}
-Array<Operation> PostDFSOrder(
- const Array<Operation>& roots,
- const ReadGraph& g) {
+Array<Operation> PostDFSOrder(const Array<Operation>& roots, const ReadGraph& g) {
std::unordered_set<Operation> visited;
Array<Operation> post_order;
for (Operation op : roots) {
std::unordered_set<const Object*> visited;
Array<IterVar> path;
for (Stage s = stage; s.defined();) {
- CHECK(!visited.count(s.get()))
- << "Find loop in compute_at attach group";
+ CHECK(!visited.count(s.get())) << "Find loop in compute_at attach group";
visited.insert(s.get());
Stage spec = s.GetAttachSpec();
bool start_attach;
}
if (start_attach) path.push_back(iv);
}
- CHECK(start_attach)
- << "Invalid Schedule: cannot find attach point " << attach_ivar
- << " in the schedule of " << s->op;
+ CHECK(start_attach) << "Invalid Schedule: cannot find attach point " << attach_ivar
+ << " in the schedule of " << s->op;
}
if (!ret.count(stage->op)) {
ret.Set(stage->op, path);
}
// graph of push reach relation of tensor dimensions
-using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
+using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey>>;
ReachGraph GetReachGraph(const Array<Operation>& ops) {
ReachGraph reach;
for (size_t i = 0; i < update.size(); ++i) {
Tensor t = op.output(i);
for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
- reach[TensorDimKey(t, k)].emplace_back(
- TensorDimKey(update[i], k));
- reach[TensorDimKey(t, k)].emplace_back(
- TensorDimKey(init[i], k));
+ reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(update[i], k));
+ reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(init[i], k));
}
}
} else if (const auto* compute_op = op.as<ComputeOpNode>()) {
reach[TensorDimKey(t, i)] = {};
}
auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
- const tir::CallNode *call = n.as<tir::CallNode>();
+ const tir::CallNode* call = n.as<tir::CallNode>();
if (call != nullptr && call->func.defined()) {
if (!bset.count(call->func.get())) return;
for (size_t i = 0; i < call->args.size(); ++i) {
TensorDimKey dkey(call, static_cast<int>(i));
auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
- const VarNode *v = node.as<VarNode>();
+ const VarNode* v = node.as<VarNode>();
auto it = vmap.find(v);
if (it != vmap.end()) {
reach[it->second].push_back(dkey);
}
}
// merge exact reach
- auto f_merge_key = [&exact_reach, &fail_set](
- const TensorDimKey& dst, const TensorDimKey& src) {
+ auto f_merge_key = [&exact_reach, &fail_set](const TensorDimKey& dst, const TensorDimKey& src) {
auto sit = exact_reach.find(src);
if (sit == exact_reach.end()) return;
auto dit = exact_reach.find(dst);
}
}
} else if (const auto* compute_op = op.as<ComputeOpNode>()) {
- std::unordered_map<const Object*, std::vector<TensorDimKey> > vmap;
+ std::unordered_map<const Object*, std::vector<TensorDimKey>> vmap;
const auto& axis = compute_op->axis;
for (size_t i = 0; i < axis.size(); ++i) {
std::vector<TensorDimKey> keys;
}
vmap[axis[i]->var.get()] = std::move(keys);
}
- auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
- const ObjectRef& n) {
- const tir::CallNode *call = n.as<tir::CallNode>();
+ auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) {
+ const tir::CallNode* call = n.as<tir::CallNode>();
if (call != nullptr && call->func.defined()) {
for (size_t i = 0; i < call->args.size(); ++i) {
auto it = vmap.find(call->args[i].get());
TensorDimKey key(scan->update[i], k);
TensorDimKey target(scan->state_placeholder[i], k);
IterVar sp_iv = scan->spatial_axis_[sp_idx];
- if (fail_set.count(sp_iv.get()) ||
- !exact_reach.count(key) ||
+ if (fail_set.count(sp_iv.get()) || !exact_reach.count(key) ||
exact_reach.at(key) != sp_iv.get()) {
ret.Set(sp_iv, make_const(DataType::Int(32), 0));
} else {
return ret;
}
-
-TVM_REGISTER_GLOBAL("schedule.CreateReadGraph")
-.set_body_typed(CreateReadGraph);
+TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph);
TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
-.set_body_typed([](const Array<Operation>& roots,
- const ReadGraph& g) {
- return PostDFSOrder(roots, g);
-});
+ .set_body_typed([](const Array<Operation>& roots, const ReadGraph& g) {
+ return PostDFSOrder(roots, g);
+ });
-TVM_REGISTER_GLOBAL("schedule.CreateAttachPath")
-.set_body_typed(CreateAttachPath);
+TVM_REGISTER_GLOBAL("schedule.CreateAttachPath").set_body_typed(CreateAttachPath);
-TVM_REGISTER_GLOBAL("schedule.ScanGetBody")
-.set_body_typed(ScanGetBody);
+TVM_REGISTER_GLOBAL("schedule.ScanGetBody").set_body_typed(ScanGetBody);
-TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis")
-.set_body_typed(ScanFixPointAnalysis);
+TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis").set_body_typed(ScanFixPointAnalysis);
} // namespace te
} // namespace tvm
#ifndef TVM_TE_SCHEDULE_GRAPH_H_
#define TVM_TE_SCHEDULE_GRAPH_H_
-#include <tvm/tir/expr.h>
-#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/tir/expr.h>
+
#include <unordered_map>
#include <unordered_set>
#include <vector>
*
* \return The subgraph.
*/
-Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
+Array<Operation> GetSubGraph(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
bool include_inputs);
/*!
* \note PostDFSOrder is a special case of Topoligical order,
* and can be used when topoligical order is needed.
*/
-Array<Operation> PostDFSOrder(
- const Array<Operation>& roots, const ReadGraph& g);
+Array<Operation> PostDFSOrder(const Array<Operation>& roots, const ReadGraph& g);
/*!
* \brief Create feedgraph for given Schedule
* \file message_passing.cc
* \brief The message passing domain.
*/
+#include "message_passing.h"
+
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
-#include "message_passing.h"
+
#include "../../arith/compute_expr.h"
namespace tvm {
using namespace tir;
-void Update(std::unordered_map<IterVar, Range>* p_state,
- const IterVar& iv,
- Range r,
+void Update(std::unordered_map<IterVar, Range>* p_state, const IterVar& iv, Range r,
arith::Analyzer* analyzer) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
analyzer->Bind(iv->var, r);
} else {
- bool match = is_zero(it->second->min) &&
- analyzer->CanProve(r->extent - it->second->extent == 0);
- CHECK(match)
- << iv
- << " domain already inferred,"
- << " cannot prove their extents are the same "
- << it->second->extent << " vs " << r->extent;
+ bool match =
+ is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0);
+ CHECK(match) << iv << " domain already inferred,"
+ << " cannot prove their extents are the same " << it->second->extent << " vs "
+ << r->extent;
}
}
}
}
-void PassDownDomain(const Stage& stage,
- std::unordered_map<IterVar, Range>* p_state,
- arith::Analyzer* actx,
- bool allow_missing) {
+void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state,
+ arith::Analyzer* actx, bool allow_missing) {
auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
return actx->Simplify(indexdiv(a + (b - 1), b));
};
- auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
+ auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(a < b)) {
return actx->Simplify(a);
}
};
if (r->factor.defined()) {
Update(p_state, r->inner,
- Range::make_by_min_extent(
- 0, resolve_min_extent_for_split(r->inner, r->factor)),
+ Range::make_by_min_extent(0, resolve_min_extent_for_split(r->inner, r->factor)),
actx);
Update(p_state, r->outer,
- Range::make_by_min_extent(
- 0, ceil_div(range_parent->extent, r->factor)), actx);
+ Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->factor)), actx);
} else {
Update(p_state, r->outer,
- Range::make_by_min_extent(
- 0, resolve_min_extent_for_split(r->outer, r->nparts)),
+ Range::make_by_min_extent(0, resolve_min_extent_for_split(r->outer, r->nparts)),
actx);
Update(p_state, r->inner,
- Range::make_by_min_extent(
- 0, ceil_div(range_parent->extent, r->nparts)), actx);
+ Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->nparts)), actx);
}
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (!state.count(r->outer) || !state.count(r->inner)) {
}
const Range& range_outer = state.at(r->outer);
const Range& range_inner = state.at(r->inner);
- state[r->fused] = Range::make_by_min_extent(
- 0, range_outer->extent * range_inner->extent);
+ state[r->fused] = Range::make_by_min_extent(0, range_outer->extent * range_inner->extent);
} else if (const RebaseNode* r = rel.as<RebaseNode>()) {
if (!state.count(r->parent)) {
CHECK(allow_missing);
continue;
}
- Update(p_state, r->rebased,
- Range::make_by_min_extent(
- 0, state.at(r->parent)->extent), actx);
+ Update(p_state, r->rebased, Range::make_by_min_extent(0, state.at(r->parent)->extent), actx);
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
} else {
}
}
-void PassUpIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing) {
+void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
auto& state = *p_state;
for (size_t i = stage->relations.size(); i != 0; --i) {
IterVarRelation rel = stage->relations[i - 1];
}
}
-void PassDownIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing) {
+void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
auto& state = *p_state;
for (IterVarRelation rel : stage->relations) {
if (const SplitNode* s = rel.as<SplitNode>()) {
}
// Domain message passing.
-void PassUpDomain(const SplitNode* s,
- const std::unordered_map<IterVar, Range>& dom_map,
- const IntSet& outer,
- const IntSet& inner,
- IntSet* parent) {
- if (dom_map.count(s->outer) &&
- dom_map.count(s->inner) &&
- dom_map.count(s->parent) &&
- outer.match_range(dom_map.at(s->outer)) &&
- inner.match_range(dom_map.at(s->inner))) {
+void PassUpDomain(const SplitNode* s, const std::unordered_map<IterVar, Range>& dom_map,
+ const IntSet& outer, const IntSet& inner, IntSet* parent) {
+ if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) &&
+ outer.match_range(dom_map.at(s->outer)) && inner.match_range(dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
- *parent = arith::EvalSet(
- s->outer->var * factor + s->inner->var + parent_min,
- {{s->outer, outer}, {s->inner, inner}});
+ *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min,
+ {{s->outer, outer}, {s->inner, inner}});
}
-void PassUpDomain(const FuseNode* s,
- const std::unordered_map<IterVar, Range>& dom_map,
- const IntSet& fused,
- IntSet* outer,
- IntSet* inner) {
+void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
+ const IntSet& fused, IntSet* outer, IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (fused.is_single_point()) {
PrimExpr value = fused.point_value();
PrimExpr factor = dom_map.at(s->inner)->extent;
- PrimExpr v_outer = indexdiv(value, factor);
- PrimExpr v_inner = indexmod(value, factor);
+ PrimExpr v_outer = indexdiv(value, factor);
+ PrimExpr v_inner = indexmod(value, factor);
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
} else {
PrimExpr fused_extent = (fused.max() - fused.min() + 1);
PrimExpr inner_extent = dom_map.at(s->inner)->extent;
- *outer = IntSet::interval(
- outer_min + indexdiv(fused.min(), inner_extent),
- outer_min + indexdiv(fused.max(), inner_extent));
+ *outer = IntSet::interval(outer_min + indexdiv(fused.min(), inner_extent),
+ outer_min + indexdiv(fused.max(), inner_extent));
if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) &&
is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) {
// fused never spans multiple rows, make a tight bounding box
} else { // fused may span multiple rows, use full row widths
if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) ||
!is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) {
- LOG(WARNING) <<
- "fused and original axes are not aligned, this may cause redundant computations";
+ LOG(WARNING)
+ << "fused and original axes are not aligned, this may cause redundant computations";
}
*inner = IntSet::range(dom_map.at(s->inner));
}
}
}
-void PassUpDomain(const RebaseNode* s,
- const std::unordered_map<IterVar, Range>& dom_map,
- const IntSet& rebased,
- IntSet* parent) {
+void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
+ const IntSet& rebased, IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (rebased.match_range(dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
PrimExpr parent_min = dom_map.at(s->parent)->min;
- *parent = arith::EvalSet(s->rebased->var + parent_min,
- {{s->rebased, rebased}});
+ *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}});
}
-void PassUpDomain(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
+void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
for (size_t i = stage->relations.size(); i != 0; --i) {
IterVarRelation rel = stage->relations[i - 1];
if (const SplitNode* r = rel.as<SplitNode>()) {
IntSet parent;
- PassUpDomain(r, dom_map,
- state.at(r->outer), state.at(r->inner),
- &parent);
+ PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent);
state[r->parent] = parent;
} else if (const FuseNode* r = rel.as<FuseNode>()) {
IntSet outer, inner;
- PassUpDomain(r, dom_map,
- state.at(r->fused),
- &outer, &inner);
+ PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner);
state[r->outer] = outer;
state[r->inner] = inner;
} else if (const RebaseNode* r = rel.as<RebaseNode>()) {
IntSet parent;
- PassUpDomain(r, dom_map,
- state.at(r->rebased),
- &parent);
+ PassUpDomain(r, dom_map, state.at(r->rebased), &parent);
state[r->parent] = parent;
} else if (rel.as<SingletonNode>()) {
} else {
}
// Pass up bit mask with or relation.
-void PassUpBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
+void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
bool allow_missing) {
auto& state = *p_state;
for (size_t i = stage->relations.size(); i != 0; --i) {
}
}
-void PassDownBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
+void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
bool allow_missing) {
auto& state = *p_state;
for (IterVarRelation rel : stage->relations) {
}
}
-
/*!
* \brief message passing to find if boundary checking on IterVar is needed.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
-void PassUpBoundCheck(const Stage& s,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, bool>* p_state,
- arith::Analyzer* analyzer) {
+void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, bool>* p_state, arith::Analyzer* analyzer) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
arith::Analyzer analyzer;
if (input_1.same_as(input_2)) return true;
- return (analyzer.CanProve(input_1->min == input_2->min)
- && analyzer.CanProve(input_1->extent == input_2->extent));
+ return (analyzer.CanProve(input_1->min == input_2->min) &&
+ analyzer.CanProve(input_1->extent == input_2->extent));
}
-std::vector<PrimExpr> MakeBoundCheck(
- const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, PrimExpr>& value_map,
- bool skip_ivar_domain,
- const std::unordered_set<IterVar>& skip_iter) {
+std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, PrimExpr>& value_map,
+ bool skip_ivar_domain,
+ const std::unordered_set<IterVar>& skip_iter) {
arith::Analyzer analyzer;
std::unordered_map<IterVar, bool> bound_state;
#ifndef TVM_TE_SCHEDULE_MESSAGE_PASSING_H_
#define TVM_TE_SCHEDULE_MESSAGE_PASSING_H_
-#include <tvm/tir/expr.h>
-#include <tvm/te/schedule.h>
-#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/tir/expr.h>
+
#include <unordered_map>
#include <unordered_set>
#include <vector>
* \param analyzer Analyzer context, storing information about bounds in p_state.
* \param allow_missing Whether allow missing value.
*/
-void PassDownDomain(
- const Stage& stage,
- std::unordered_map<IterVar, Range>* p_state,
- arith::Analyzer* analyzer,
- bool allow_missing = false);
+void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state,
+ arith::Analyzer* analyzer, bool allow_missing = false);
/*!
* \param Upward inference of index of each IterVar.
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
-void PassUpIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing = false);
+void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing = false);
/*!
* \param Downward inference of index of each IterVar.
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
-void PassDownIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing = false);
+void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing = false);
/*!
* \param Upward inference of domain set of each IterVar.
* \param dom_map The domain map of each iteration variable's maximum domain.
* \param p_state The index state of each IterVar.
*/
-void PassUpDomain(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
+void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state);
/*!
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
-void PassUpBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
+void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
bool allow_missing = false);
/*!
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
-void PassDownBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
+void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
bool allow_missing = false);
/*!
* \param skip_iter The set of variables to skip bound condition.
* \return List of predicates that we need to check.
*/
-std::vector<PrimExpr>
-MakeBoundCheck(
- const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, PrimExpr>& value_map,
- bool skip_ivar_domain,
- const std::unordered_set<IterVar>& skip_iter);
+std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, PrimExpr>& value_map,
+ bool skip_ivar_domain,
+ const std::unordered_set<IterVar>& skip_iter);
} // namespace te
} // namespace tvm
/*!
* \file operation_inline.cc
*/
+#include "operation_inline.h"
+
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
-#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
+
#include <utility>
-#include "operation_inline.h"
-#include "../../tir/transforms/ir_util.h"
+#include "../../tir/transforms/ir_util.h"
namespace tvm {
namespace te {
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->args[i]);
}
- expr = Substitute(
- EvaluateNode::make(expr), vmap).as<EvaluateNode>()->value;
+ expr = Substitute(EvaluateNode::make(expr), vmap).as<EvaluateNode>()->value;
}
return expr;
} else {
PrimExpr body_;
};
-Stmt Inline(Stmt stmt,
- Operation f,
- Array<Var> args,
- PrimExpr body) {
- CHECK_EQ(f->num_outputs(), 1)
- << "can only inline output single value operation";
+Stmt Inline(Stmt stmt, Operation f, Array<Var> args, PrimExpr body) {
+ CHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation";
Stmt ret = OperationInliner(f, args, body)(std::move(stmt));
if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret);
#ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_
#define TVM_TE_SCHEDULE_OPERATION_INLINE_H_
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt.h>
namespace tvm {
namespace te {
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
-Stmt Inline(Stmt stmt,
- Operation op,
- Array<Var> args,
- PrimExpr body);
+Stmt Inline(Stmt stmt, Operation op, Array<Var> args, PrimExpr body);
} // namespace te
} // namespace tvm
/*!
* \file schedule_dataflow_rewrite.cc
*/
-#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
#include <tvm/tir/stmt_functor.h>
+
#include <unordered_set>
-#include "message_passing.h"
-#include "operation_inline.h"
-#include "../../tir/transforms/ir_util.h"
#include "../../arith/compute_expr.h"
+#include "../../tir/transforms/ir_util.h"
+#include "message_passing.h"
+#include "operation_inline.h"
namespace tvm {
namespace te {
// find first occurance location in leaf
-template<typename T>
+template <typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
const Object* n = v.get();
for (size_t i = 0; i < array_node->data.size(); ++i) {
// The replacer of cache.
class VarReplacer : public tir::StmtExprMutator {
public:
- explicit VarReplacer(
- const std::unordered_map<const VarNode*, PrimExpr>& vsub)
- : vsub_(vsub) {}
+ explicit VarReplacer(const std::unordered_map<const VarNode*, PrimExpr>& vsub) : vsub_(vsub) {}
PrimExpr VisitExpr_(const VarNode* op) final {
auto it = vsub_.find(op);
if (it != vsub_.end()) return it->second;
tir::CommReducer MutateCommReducer(tir::CommReducer combiner) {
// Replace free variables in combiner
- auto new_identity = tir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) {
- return this->VisitExpr(e);
- });
- auto new_result = tir::UpdateArray(combiner->result, [this] (const PrimExpr& e) {
- return this->VisitExpr(e);
- });
+ auto new_identity = tir::UpdateArray(combiner->identity_element,
+ [this](const PrimExpr& e) { return this->VisitExpr(e); });
+ auto new_result = tir::UpdateArray(combiner->result,
+ [this](const PrimExpr& e) { return this->VisitExpr(e); });
if (combiner->identity_element.same_as(new_identity) &&
combiner->identity_element.same_as(new_result)) {
return combiner;
} else {
- return tir::CommReducerNode::make(
- combiner->lhs, combiner->rhs, new_result, new_identity);
+ return tir::CommReducerNode::make(combiner->lhs, combiner->rhs, new_result, new_identity);
}
}
if (op->combiner.same_as(new_combiner)) {
return new_e;
} else {
- return tir::ReduceNode::make(
- new_combiner,
- new_reduce->source,
- new_reduce->axis,
- new_reduce->condition,
- new_reduce->value_index);
+ return tir::ReduceNode::make(new_combiner, new_reduce->source, new_reduce->axis,
+ new_reduce->condition, new_reduce->value_index);
}
}
const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
};
-PrimExpr InjectPredicate(const Array<PrimExpr>& predicates,
- PrimExpr body) {
+PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) {
using tir::ReduceNode;
using tir::SelectNode;
if (predicates.size() == 0) return body;
n->condition = n->condition && arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr());
return PrimExpr(n);
}
- return SelectNode::make(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()),
- body,
- make_zero(body.dtype()));
+ return SelectNode::make(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()), body,
+ make_zero(body.dtype()));
}
// Replace data flow appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow need to be replaced.
// Need to keep an update to the date transitive closure property on the vmap by a reverse map.
-void ReplaceDataFlow(const Array<Stage>& stages,
- std::unordered_map<Tensor, Tensor>* vmap,
+void ReplaceDataFlow(const Array<Stage>& stages, std::unordered_map<Tensor, Tensor>* vmap,
std::unordered_map<Tensor, Tensor>* rvmap) {
for (Stage s : stages) {
Operation op = s->op->ReplaceInputs(s->op, *vmap);
}
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
- return (a->combiner.same_as(b->combiner)) &&
- (a->source.same_as(b->source)) &&
- (a->axis.same_as(b->axis)) &&
- (a->condition.same_as(b->condition));
+ return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
+ (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition));
}
-Tensor Schedule::cache_read(const Tensor& tensor,
- const std::string& scope,
+Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
const Array<Operation>& readers) {
(*this)->InvalidateCache();
// create identity mapping.
std::unordered_map<Tensor, Tensor> vsub;
Stage s = operator[](tensor->op);
Tensor sugar_tensor = s->op.output(tensor->value_index);
- Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
- return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
- }, os.str());
+ Tensor cache = compute(
+ sugar_tensor->shape,
+ [&sugar_tensor](const Array<Var>& i) {
+ return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
+ },
+ os.str());
vsub[sugar_tensor] = cache;
std::unordered_map<Tensor, Tensor> vmap;
for (Operation op : readers) {
Stage s = operator[](op);
Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
- CHECK(!repl_op.same_as(s->op))
- << "Cannot find " << tensor
- << " in the inputs of " << s->op;
+ CHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op;
vmap[s->op.output(0)] = repl_op.output(0);
rvmap[repl_op.output(0)] = s->op.output(0);
s->op = repl_op;
Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
- stages->data.insert(stages->data.begin() + pos + 1,
- cache_stage);
+ stages->data.insert(stages->data.begin() + pos + 1, cache_stage);
(*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
return cache;
}
-template<typename OpType>
-void PrepareAxisMapping(Stage orig_stage,
- OpType* op,
- std::unordered_set<IterVar>* p_red_axis,
- Array<IterVar>* p_new_axis,
- std::unordered_map<IterVar, Range>* p_dom_map,
+template <typename OpType>
+void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set<IterVar>* p_red_axis,
+ Array<IterVar>* p_new_axis, std::unordered_map<IterVar, Range>* p_dom_map,
std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
std::vector<PrimExpr>* p_predicates) {
std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : orig_stage->leaf_iter_vars) {
if (red_axis.count(iv)) continue;
- CHECK_EQ(iv->iter_type, kDataPar)
- << "Can only relayout with in data parallel dimensions";
+ CHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout with in data parallel dimensions";
Range dom = dom_map.at(iv);
- IterVar new_iv = IterVarNode::make(
- dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+ IterVar new_iv = IterVarNode::make(dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
new_axis.push_back(new_iv);
if (is_one(dom->min)) {
value_map[iv] = dom->min;
skip_bound_check.insert(iv);
}
PassUpIndex(orig_stage, dom_map, &value_map, true);
- predicates = MakeBoundCheck(
- orig_stage, dom_map, value_map, true, skip_bound_check);
+ predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check);
// The root axis
for (IterVar iv : op->axis) {
if (value_map.count(iv)) {
}
}
-Array<Tensor> ReplaceOriginalOp(Schedule sch,
- Stage orig_stage,
- const std::string& scope,
- Operation cache_op,
- Operation orig_new_op,
- size_t tensor_size) {
+Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope,
+ Operation cache_op, Operation orig_new_op, size_t tensor_size) {
Array<Tensor> cache_tensor_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
Stage cache_stage = Stage(cache_op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
- stages->data.insert(stages->data.begin() + pos,
- cache_stage);
+ stages->data.insert(stages->data.begin() + pos, cache_stage);
sch->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
return cache_tensor_list;
}
-
// Cache write and relayout the data according to loop pattern
-Array<Tensor> CacheWriteWithReLayout(Schedule sch,
- const Array<Tensor>& tensor_array,
+Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_array,
const std::string& scope) {
size_t tensor_size = tensor_array.size();
sch->InvalidateCache();
std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
std::vector<PrimExpr> predicates;
- PrepareAxisMapping(orig_stage, compute,
- &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+ PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
+ &predicates);
PrimExpr body;
Array<PrimExpr> body_list;
const tir::ReduceNode* reduce_body = body.as<tir::ReduceNode>();
if (first_reduce != nullptr) {
CHECK(ReduceEqual(reduce_body, first_reduce));
- body = tir::ReduceNode::make(first_reduce->combiner,
- first_reduce->source,
- first_reduce->axis,
- first_reduce->condition,
- reduce_body->value_index);
+ body =
+ tir::ReduceNode::make(first_reduce->combiner, first_reduce->source, first_reduce->axis,
+ first_reduce->condition, reduce_body->value_index);
} else {
first_reduce = reduce_body;
}
} else {
- CHECK(first_reduce == nullptr)
- << "cannot mix reduce and other node in ONE compute bodys";
+ CHECK(first_reduce == nullptr) << "cannot mix reduce and other node in ONE compute bodys";
}
body_list.push_back(body);
}
args.push_back(value_map.at(iv));
}
}
- Operation cache_op = ComputeOpNode::make(
- compute->name + "." + scope, compute->tag, compute->attrs,
- new_axis, body_list);
+ Operation cache_op = ComputeOpNode::make(compute->name + "." + scope, compute->tag,
+ compute->attrs, new_axis, body_list);
Array<PrimExpr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_expr_list.push_back(cache_tensor(args));
}
- Operation orig_new_op = ComputeOpNode::make(
- compute->name, compute->tag, compute->attrs,
- compute->axis, cache_expr_list);
- return ReplaceOriginalOp(sch, orig_stage, scope,
- cache_op, orig_new_op, tensor_size);
+ Operation orig_new_op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs,
+ compute->axis, cache_expr_list);
+ return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}
-
// for tensor compute op
-Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
- const Array<Tensor>& tensor_array,
+Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& tensor_array,
const std::string& scope) {
size_t tensor_size = tensor_array.size();
sch->InvalidateCache();
std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
std::vector<PrimExpr> predicates;
- PrepareAxisMapping(orig_stage, tensor_op,
- &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
-
+ PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
+ &predicates);
for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
IterVar iv = tensor_op->axis[i];
- IterVar new_iv = IterVarNode::make(
- iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+ IterVar new_iv = IterVarNode::make(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
new_axis.push_back(new_iv);
}
Array<Region> new_regions;
new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
}
- Operation cache_op = TensorComputeOpNode::make(
- tensor_op->name + "." + scope, tensor_op->tag, new_axis,
- tensor_op->reduce_axis, tensor_op->schedulable_ndim,
- tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
+ Operation cache_op = TensorComputeOpNode::make(tensor_op->name + "." + scope, tensor_op->tag,
+ new_axis, tensor_op->reduce_axis,
+ tensor_op->schedulable_ndim, tensor_op->intrin,
+ tensor_op->inputs, new_regions, new_scalar_inputs);
// axis will be used in generating compute op
Array<IterVar> compute_axis = tensor_op->axis;
Tensor cache_tensor = cache_op.output(i);
cache_expr_list.push_back(cache_tensor(args));
}
- Operation orig_new_op = ComputeOpNode::make(
- tensor_op->name, tensor_op->tag, {},
- compute_axis, cache_expr_list);
- return ReplaceOriginalOp(sch, orig_stage, scope,
- cache_op, orig_new_op, tensor_size);
+ Operation orig_new_op =
+ ComputeOpNode::make(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list);
+ return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}
-
-Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
- const std::string& scope) {
+Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, const std::string& scope) {
(*this)->InvalidateCache();
- CHECK(tensor_array.size() > 0)
- << "size of tensor_array must be greater than 0";
+ CHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0";
Tensor tensor = tensor_array[0];
Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
<< "size of input tensor list must be same as number of stage outputs";
for (size_t i = 1; i < tensor_array.size(); i++) {
Stage tmp_stage = operator[](tensor_array[i]->op);
- CHECK(orig_stage.same_as(tmp_stage))
- << "Input tensor list must be generated by ONE computeOp";
+ CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp";
}
return CacheWriteWithReLayout(*this, tensor_array, scope);
}
-
-Tensor Schedule::cache_write(const Tensor& tensor,
- const std::string& scope) {
+Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
// support original compute and tensor compute both
(*this)->InvalidateCache();
if (tensor->op.as<ComputeOpNode>()) {
}
}
-
void RebaseNonZeroMinLoop(const Schedule& sch) {
std::unordered_map<IterVar, IterVar> rebase_map;
for (Stage s : sch->stages) {
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindNodeRef(leaf_vars, iv);
- auto it = s->iter_var_attrs.find(iv);
+ auto it = s->iter_var_attrs.find(iv);
// don;t need to rebase path that are binded.
- if (it != s->iter_var_attrs.end() &&
- (*it).second->bind_thread.defined()) {
+ if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
continue;
}
if (idx < leaf_vars->data.size()) {
// insert rebase
- IterVar rebased = IterVarNode::make(
- Range(), iv->var.copy_with_suffix(""), iv->iter_type);
+ IterVar rebased = IterVarNode::make(Range(), iv->var.copy_with_suffix(""), iv->iter_type);
s->relations.push_back(RebaseNode::make(iv, rebased));
if (s->iter_var_attrs.count(iv)) {
s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
{
// setup args
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
- CHECK(compute)
- << "can only inline compute op";
+ CHECK(compute) << "can only inline compute op";
for (auto iv : compute->axis) {
args.push_back(iv->var);
}
- CHECK_EQ(compute->body.size(), 1U)
- << "can only inline compute op with 1 output";
+ CHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
body = compute->body[0];
}
for (size_t j = i; j < sch->stages.size(); ++j) {
for (size_t k = 1; k < new_body[j].size(); ++k) {
const tir::ReduceNode* reduce_ = new_body[j][k].as<tir::ReduceNode>();
CHECK(reduce_);
- CHECK(ReduceEqual(reduce_, reduce))
- << "The Reduce inputs of ComputeOp should "
- << "have the same attribute except value_index";
+ CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
+ << "have the same attribute except value_index";
}
- PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]),
- stage->op, args, body).as<tir::EvaluateNode>()->value;
+ PrimExpr new_value =
+ Inline(tir::EvaluateNode::make(new_body[j][0]), stage->op, args, body)
+ .as<tir::EvaluateNode>()
+ ->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
const tir::ReduceNode* r = new_value.as<tir::ReduceNode>();
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
- PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]),
- stage->op, args, body).as<tir::EvaluateNode>()->value;
+ PrimExpr new_value =
+ Inline(tir::EvaluateNode::make(new_body[j][k]), stage->op, args, body)
+ .as<tir::EvaluateNode>()
+ ->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
changed[j] = true;
CHECK(compute);
Operation op = s->op;
if (changed[i]) {
- op = ComputeOpNode::make(
- compute->name, compute->tag, compute->attrs,
- compute->axis, new_body[i]);
+ op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, compute->axis,
+ new_body[i]);
}
op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) {
} else if (hybrid_changed[i]) {
const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
CHECK(hybrid);
- Operation op = HybridOpNode::make(
- hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
- hybrid->outputs, new_hybrid_body[i]);
+ Operation op = HybridOpNode::make(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
+ hybrid->outputs, new_hybrid_body[i]);
op = op->ReplaceInputs(op, repl);
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
}
// Handle reduction factor.
-Array<Tensor> Schedule::rfactor(const Tensor& tensor,
- const IterVar& axis,
- int factor_axis) {
+Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) {
(*this)->InvalidateCache();
using tir::ReduceNode;
- CHECK_EQ(axis->iter_type, kCommReduce)
- << "Can only factor reduction axis";
+ CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
Stage reduce_stage = operator[](tensor->op);
const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
CHECK(compute_op) << "Can only factor ComputeOp";
std::unordered_set<IterVar> skip_bound_check;
// Verify normal axis are not touched.
for (IterVar iv : compute_op->axis) {
- CHECK(!touch_map.count(iv))
- << "Factor axis touches normal axis.";
+ CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
skip_bound_check.insert(iv);
}
// get analyzer.
}
}
te::PassUpIndex(reduce_stage, dom_map, &value_map, true);
- std::vector<PrimExpr> predicates = MakeBoundCheck(
- reduce_stage, dom_map, value_map, true, skip_bound_check);
+ std::vector<PrimExpr> predicates =
+ MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check);
// Get the factored op node.
- const int factor_axis_pos = \
+ const int factor_axis_pos =
factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
CHECK_LE(factor_axis_pos, compute_op->axis.size());
auto n = make_object<ComputeOpNode>();
// axis relacement.
auto iv_node = make_object<IterVarNode>();
iv_node->dom = dom_map.at(axis);
- CHECK(is_zero(iv_node->dom->min))
- << "Can only factor reduction domain starting from 0";
+ CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0";
iv_node->var = axis->var;
iv_node->iter_type = kDataPar;
}
}
VarReplacer replacer(vsub);
- Array<PrimExpr> new_source = tir::UpdateArray(reduce->source,
- [&replacer] (const PrimExpr& e) { return replacer(e); });
+ Array<PrimExpr> new_source =
+ tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); });
PrimExpr new_pred = replacer(predicate);
std::vector<PrimExpr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
- body.emplace_back(ReduceNode::make(reduce->combiner,
- new_source,
- n->reduce_axis,
- new_pred,
- idx));
+ body.emplace_back(
+ ReduceNode::make(reduce->combiner, new_source, n->reduce_axis, new_pred, idx));
}
n->body = Array<PrimExpr>(body);
// refresh relations, keep the un-touched relations.
Stage factor_stage = Stage(factor_op);
factor_stage->relations = rels;
CHECK_LT(stage_pos, stages->data.size());
- stages->data.insert(stages->data.begin() + stage_pos,
- factor_stage);
+ stages->data.insert(stages->data.begin() + stage_pos, factor_stage);
(*this)->stage_map.Set(factor_op, factor_stage);
factor_stage->group = reduce_stage->group;
if (factor_stage->group.defined()) {
++factor_stage->group->num_child_stages;
}
// Replace the old reduction.
- IterVar repl_red_axis = reduce_axis(
- dom_map.at(axis), axis->var->name_hint + ".v");
+ IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
Array<Tensor> factor_tensors;
Array<Tensor> old_tensors;
int size = factor_op->num_outputs();
factor_tensors.push_back(factor_op.output(idx));
old_tensors.push_back(reduce_stage->op.output(idx));
}
- Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
- [&](const Array<Var>& i) {
- Array<PrimExpr> indices;
- const int idx_size = static_cast<int>(i.size());
- for (int idx = 0; idx < idx_size; ++idx) {
- if (factor_axis_pos == idx) {
- indices.push_back(repl_red_axis->var);
+ Array<Tensor> repl_tensors = compute(
+ old_tensors[0]->shape,
+ [&](const Array<Var>& i) {
+ Array<PrimExpr> indices;
+ const int idx_size = static_cast<int>(i.size());
+ for (int idx = 0; idx < idx_size; ++idx) {
+ if (factor_axis_pos == idx) {
+ indices.push_back(repl_red_axis->var);
+ }
+ indices.push_back(i[idx]);
}
- indices.push_back(i[idx]);
- }
- if (factor_axis_pos == idx_size) {
+ if (factor_axis_pos == idx_size) {
indices.push_back(repl_red_axis->var);
- }
- Array<PrimExpr> factor_exprs;
- for (int idx = 0; idx < size; ++idx) {
- factor_exprs.push_back(factor_tensors[idx](indices));
- }
- Array<PrimExpr> reductions;
- Array<IterVar> axis = {repl_red_axis};
- PrimExpr cond = const_true();
- for (int idx = 0; idx < size; ++idx) {
- reductions.push_back(ReduceNode::make(reduce->combiner,
- factor_exprs, axis, cond, idx));
- }
- return reductions;
- }, reduce_stage->op->name + ".repl");
+ }
+ Array<PrimExpr> factor_exprs;
+ for (int idx = 0; idx < size; ++idx) {
+ factor_exprs.push_back(factor_tensors[idx](indices));
+ }
+ Array<PrimExpr> reductions;
+ Array<IterVar> axis = {repl_red_axis};
+ PrimExpr cond = const_true();
+ for (int idx = 0; idx < size; ++idx) {
+ reductions.push_back(ReduceNode::make(reduce->combiner, factor_exprs, axis, cond, idx));
+ }
+ return reductions;
+ },
+ reduce_stage->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
-#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+
#include <stack>
#include <unordered_set>
+
#include "graph.h"
namespace tvm {
namespace te {
// find first occurance location in leaf
-template<typename T>
+template <typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
const Object* n = v.get();
for (size_t i = 0; i < array_node->data.size(); ++i) {
if (pos < leaf_vars->data.size()) return pos;
if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
- LOG(FATAL) << "Operate on iter var " << v
- << "that has already been split";
+ LOG(FATAL) << "Operate on iter var " << v << "that has already been split";
} else {
- LOG(FATAL) << "Operate on iter var " << v
- << "that is not part of the schedule";
+ LOG(FATAL) << "Operate on iter var " << v << "that is not part of the schedule";
}
return 0;
}
-void Split(StageNode* self,
- IterVar parent,
- PrimExpr factor,
- PrimExpr nparts,
- IterVar* p_outer,
+void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, IterVar* p_outer,
IterVar* p_inner) {
// Check if split is valid.
- CHECK(parent->iter_type == kDataPar ||
- parent->iter_type == kCommReduce ||
+ CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce ||
parent->iter_type == kOrdered)
<< "Cannot split on " << IterVarType2String(parent->iter_type);
- IterVar outer = IterVarNode::make(
- Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
- IterVar inner = IterVarNode::make(
- Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
+ IterVar outer =
+ IterVarNode::make(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
+ IterVar inner =
+ IterVarNode::make(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
*p_outer = outer;
*p_inner = inner;
// The splits
Stage Stage::GetAttachSpec() const {
Stage attach_spec = *this;
- while (attach_spec->attach_type == kGroupRoot &&
- attach_spec->group.defined()) {
+ while (attach_spec->attach_type == kGroupRoot && attach_spec->group.defined()) {
attach_spec = attach_spec->group;
}
return attach_spec;
return *this;
}
-Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
- CHECK_NE((*this)->attach_type, kScanUpdate)
- << "Cannot specify compute_at for scan updates";
+Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
+ CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates";
// Group constraint checking.
Stage group = (*this)->group;
if (group.defined()) {
while (pg.defined() && !pg.same_as(group)) {
pg = pg->group;
}
- CHECK(pg.same_as(group))
- << "Can only assign compute_at to stages within the same group";
+ CHECK(pg.same_as(group)) << "Can only assign compute_at to stages within the same group";
}
(*this)->attach_type = kScope;
bool found = false;
for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
if (scope == parent->leaf_iter_vars[i]) {
- found = true; break;
+ found = true;
+ break;
}
}
- CHECK(found)
- << "Cannot find the axis " << scope
- << " in parent's leaf_iter_vars"
- << " parent=" << parent;
+ CHECK(found) << "Cannot find the axis " << scope << " in parent's leaf_iter_vars"
+ << " parent=" << parent;
return *this;
}
-Stage& Stage::compute_inline() { // NOLINT(*)
- CHECK_NE((*this)->attach_type, kScanUpdate)
- << "Cannot specify compute_at for scan updates";
+Stage& Stage::compute_inline() { // NOLINT(*)
+ CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates";
(*this)->attach_type = kInline;
return *this;
}
-Stage& Stage::compute_root() { // NOLINT(*)
- CHECK_NE((*this)->attach_type, kScanUpdate)
- << "Cannot specify compute_at for scan updates";
+Stage& Stage::compute_root() { // NOLINT(*)
+ CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates";
(*this)->attach_type = kGroupRoot;
return *this;
}
-Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
+Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
StageNode* self = operator->();
- CHECK(ivar->iter_type == kDataPar ||
- ivar->iter_type == kCommReduce)
+ CHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce)
<< "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
CHECK(thread_ivar->iter_type == kThreadIndex)
<< "Cannot rebase by " << IterVarType2String(ivar->iter_type)
ObjectPtr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) {
n = make_object<IterVarAttrNode>(*(*it).second.operator->());
- if (n->bind_thread.defined() &&
- !n->bind_thread.same_as(thread_ivar)) {
- LOG(WARNING) << "Axis " << ivar
- << " is already bind to another thread " << n->bind_thread;
+ if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) {
+ LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread;
}
} else {
n = make_object<IterVarAttrNode>();
StageNode* self = operator->();
CHECK(self->op.defined() && self->op.as<ScanOpNode>())
<< "env_threads is only valid for composite ops such as ScanOp";
- CHECK_EQ(self->env_threads.size(), 0U)
- << "Already set env_threads";
+ CHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads";
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
std::vector<ObjectRef> temp;
for (IterVar iv : threads) {
temp.push_back(iv);
}
- leaf_vars->data.insert(
- leaf_vars->data.begin(), temp.begin(), temp.end());
- all_vars->data.insert(
- all_vars->data.end(), temp.begin(), temp.end());
+ leaf_vars->data.insert(leaf_vars->data.begin(), temp.begin(), temp.end());
+ all_vars->data.insert(all_vars->data.end(), temp.begin(), temp.end());
self->env_threads = threads;
return *this;
}
return *this;
}
-Stage& Stage::split(
- IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer,
+ IterVar* p_inner) { // NOLINT(*)
Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
return *this;
}
-Stage& Stage::split_by_nparts(
- IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
+ IterVar* p_inner) { // NOLINT(*)
Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
return *this;
}
Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*)
StageNode* self = operator->();
- CHECK(outer->iter_type == kDataPar ||
- outer->iter_type == kCommReduce ||
+ CHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce ||
outer->iter_type == kOrdered)
<< "Cannot fuse " << IterVarType2String(outer->iter_type);
- CHECK(inner->iter_type == kDataPar ||
- inner->iter_type == kCommReduce ||
+ CHECK(inner->iter_type == kDataPar || inner->iter_type == kCommReduce ||
inner->iter_type == kOrdered)
<< "Cannot fuse " << IterVarType2String(inner->iter_type);
IterVarType iter_type = outer->iter_type;
if (inner->iter_type > iter_type) iter_type = inner->iter_type;
- std::string fused_name =
- outer->var->name_hint + "." + inner->var->name_hint + ".fused";
+ std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused";
- IterVar fused = IterVarNode::make(
- Range(), Var(fused_name, outer->var.dtype()), iter_type);
+ IterVar fused = IterVarNode::make(Range(), Var(fused_name, outer->var.dtype()), iter_type);
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
all_vars->data.push_back(fused);
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
leaf_vars->data.begin() + pos_inner + 1);
- leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
- fused);
+ leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, fused);
*p_target = fused;
return *this;
}
StageNode* self = operator->();
// special handle fuse empty array.
// insert at the outer most loop
- IterVar singleton = IterVarNode::make(
- Range::make_by_min_extent(0, 1),
- Var("singleton", DataType::Int(32)), kDataPar);
+ IterVar singleton = IterVarNode::make(Range::make_by_min_extent(0, 1),
+ Var("singleton", DataType::Int(32)), kDataPar);
self->relations.push_back(SingletonNode::make(singleton));
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::unordered_set<IterVar> seen_var;
StageNode* self = operator->();
for (IterVar iv : order) {
- CHECK(iv->iter_type == kDataPar ||
- iv->iter_type == kCommReduce ||
+ CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce ||
iv->iter_type == kThreadIndex)
- << "Cannot reorder IterVar("
- << IterVarType2String(iv->iter_type) << ")";
+ << "Cannot reorder IterVar(" << IterVarType2String(iv->iter_type) << ")";
- CHECK_EQ(seen_var.count(iv), 0)
- << "Same axis can not appear more than once " << iv;
+ CHECK_EQ(seen_var.count(iv), 0) << "Same axis can not appear more than once " << iv;
seen_var.insert(iv);
}
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
return *this;
}
-Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
- PrimExpr x_factor, PrimExpr y_factor,
- IterVar* p_x_outer, IterVar* p_y_outer,
- IterVar* p_x_inner, IterVar* p_y_inner) {
+Stage& Stage::tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor,
+ IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) {
split(x_parent, x_factor, p_x_outer, p_x_inner);
split(y_parent, y_factor, p_y_outer, p_y_inner);
reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
return *this;
}
-template<typename FUpdate>
-inline void UpdateIterVarAttr(StageNode* self,
- IterVar var,
- FUpdate fupdate,
+template <typename FUpdate>
+inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate,
bool need_leaf = true) {
if (need_leaf) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
}
inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
- UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) {
- n->iter_type = iter_type;
- });
+ UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { n->iter_type = iter_type; });
}
-Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
- CHECK(var->iter_type == kDataPar ||
- var->iter_type == kOpaque ||
- var->iter_type == kUnrolled ||
- var->iter_type == kVectorized ||
- var->iter_type == kTensorized ||
+Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
+ CHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled ||
+ var->iter_type == kVectorized || var->iter_type == kTensorized ||
var->iter_type == kParallelized)
<< "Cannot vectorize on " << IterVarType2String(var->iter_type);
SetAttrIterType(operator->(), var, kVectorized);
return *this;
}
-Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*)
+Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*)
UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
- n->iter_type = kTensorized;
- n->tensor_intrin = f;
- });
+ n->iter_type = kTensorized;
+ n->tensor_intrin = f;
+ });
return *this;
}
-Stage& Stage::unroll(IterVar var) { // NOLINT(*)
+Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttrIterType(operator->(), var, kUnrolled);
return *this;
}
-Stage& Stage::parallel(IterVar var) { // NOLINT(*)
+Stage& Stage::parallel(IterVar var) { // NOLINT(*)
SetAttrIterType(operator->(), var, kParallelized);
return *this;
}
-Stage& Stage::pragma(IterVar var,
- const std::string& pragma_type,
- const PrimExpr& pragma_value) { // NOLINT(*)
+Stage& Stage::pragma(IterVar var, const std::string& pragma_type,
+ const PrimExpr& pragma_value) { // NOLINT(*)
if (pragma_type == "unroll") {
this->unroll(var);
} else if (pragma_type == "vectorize") {
this->vectorize(var);
} else {
- UpdateIterVarAttr(
- operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
- n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type));
- n->pragma_values.push_back(pragma_value);
- });
+ UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
+ n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type));
+ n->pragma_values.push_back(pragma_value);
+ });
}
return *this;
}
-Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) {
- StageNode *self = operator->();
+Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) {
+ StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
}
Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
- StageNode *self = operator->();
- UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
- n->dim_align_factor = factor;
- n->dim_align_offset = offset;
- }, false);
+ StageNode* self = operator->();
+ UpdateIterVarAttr(
+ self, axis,
+ [factor, offset](IterVarAttrNode* n) {
+ n->dim_align_factor = factor;
+ n->dim_align_offset = offset;
+ },
+ false);
return *this;
}
Stage& Stage::double_buffer() {
- StageNode *self = operator->();
+ StageNode* self = operator->();
CHECK(!self->is_output) << "Cannot apply double buffer on output";
self->double_buffer = true;
return *this;
Stage& Stage::opengl() {
CHECK(!is_scheduled()) << "Must be a fresh schedule";
- StageNode *self = operator->();
+ StageNode* self = operator->();
auto all_iter_vars = self->all_iter_vars; // curr version of all_iter_vars
CHECK(!all_iter_vars.empty()) << "At least one iter var";
break;
}
default: {
- LOG(ERROR) << "Invalid iter var type "
- << IterVarType2String(iter_var->iter_type);
+ LOG(ERROR) << "Invalid iter var type " << IterVarType2String(iter_var->iter_type);
break;
}
}
}
Stage CopyStage(const Stage& s) {
- ObjectPtr<StageNode> n =
- make_object<StageNode>(*s.operator->());
+ ObjectPtr<StageNode> n = make_object<StageNode>(*s.operator->());
return Stage(n);
}
for (Stage s : n->stages) {
if (s->attach_stage.defined()) {
CHECK(smap.find(s->attach_stage) != smap.end())
- << s->attach_stage << " not found in " << (*this);
+ << s->attach_stage << " not found in " << (*this);
s->attach_stage = smap.at(s->attach_stage);
}
if (s->group.defined()) {
- CHECK(smap.find(s->group) != smap.end())
- << s->group << " not found in " << (*this);
+ CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this);
s->group = smap.at(s->group);
}
}
for (Stage s : n->groups) {
if (s->attach_stage.defined()) {
CHECK(smap.find(s->attach_stage) != smap.end())
- << s->attach_stage << " not found in " << (*this);
+ << s->attach_stage << " not found in " << (*this);
s->attach_stage = smap.at(s->attach_stage);
}
if (s->group.defined()) {
- CHECK(smap.find(s->group) != smap.end())
- << s->group << " not found in " << (*this);
+ CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this);
s->group = smap.at(s->group);
}
}
Stage Schedule::operator[](const Operation& op) {
auto it = (*this)->stage_map.find(op);
CHECK(it != (*this)->stage_map.end())
- << "Cannot find Stage for operator " << op
- << " in the schedule";
+ << "Cannot find Stage for operator " << op << " in the schedule";
return (*it).second;
}
return g;
}
-Array<Tensor> RemapTensor(ScheduleNode* self,
- const Array<Tensor>& arr) {
+Array<Tensor> RemapTensor(ScheduleNode* self, const Array<Tensor>& arr) {
self->InitCache();
const auto& op2stage_cache = self->op2stage_cache_;
Array<Tensor> ret;
for (Tensor t : arr) {
if (!op2stage_cache.count(t->op.get())) {
- CHECK(self->stage_map.count(t->op))
- << "Given tensor is not in the schedule plan";
+ CHECK(self->stage_map.count(t->op)) << "Given tensor is not in the schedule plan";
t = self->stage_map[t->op]->op.output(t->value_index);
}
ret.push_back(t);
}
// Group the schedule stages.
-Stage Schedule::create_group(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
+Stage Schedule::create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
bool include_inputs) {
ScheduleNode* self = operator->();
self->InitCache();
const auto& op2stage_cache = self->op2stage_cache_;
// Get the ops.
- Array<Operation> ops = te::GetSubGraph(
- RemapTensor(self, outputs),
- RemapTensor(self, inputs),
- include_inputs);
+ Array<Operation> ops =
+ te::GetSubGraph(RemapTensor(self, outputs), RemapTensor(self, inputs), include_inputs);
// local counter entry
// Automatically initialize to 0 during creation.
struct Entry {
// Propagate the counter statistics from by checking if subgroup
// Is full and propagate.
std::vector<Stage> stack;
- for (auto &kv : counter) {
+ for (auto& kv : counter) {
if (!kv.first.same_as(parent_group)) {
if (kv.first->num_child_stages == kv.second.count) {
stack.push_back(kv.first);
}
}
// Verification and remappig the subgroups.
- for (auto &kv : counter) {
+ for (auto& kv : counter) {
if (kv.first.same_as(parent_group)) continue;
CHECK_EQ(kv.first->num_child_stages, kv.second.count)
<< "Trying to group region that intersect with an already existed group";
return gstage;
}
-void ScheduleNode::InvalidateCache() {
- op2stage_cache_.clear();
-}
+void ScheduleNode::InvalidateCache() { op2stage_cache_.clear(); }
void ScheduleNode::InitCache() {
if (op2stage_cache_.size() == stages.size()) return;
return sch;
}
-IterVarRelation SplitNode::make(IterVar parent,
- IterVar outer,
- IterVar inner,
- PrimExpr factor,
+IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
PrimExpr nparts) {
auto n = make_object<SplitNode>();
n->parent = parent;
return IterVarRelation(n);
}
-IterVarRelation FuseNode::make(
- IterVar outer, IterVar inner, IterVar fused) {
+IterVarRelation FuseNode::make(IterVar outer, IterVar inner, IterVar fused) {
auto n = make_object<FuseNode>();
n->outer = outer;
n->inner = inner;
typedef dmlc::ThreadLocalStore<TVMSpecializationThreadLocalEntry> TVMSpecializationThreadLocalStore;
void SpecializedCondition::EnterWithScope() {
- TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+ TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get();
entry->condition_stack.push(*this);
}
void SpecializedCondition::ExitWithScope() {
- TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+ TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get();
CHECK(!entry->condition_stack.empty());
CHECK(entry->condition_stack.top().same_as(*this));
entry->condition_stack.pop();
}
SpecializedCondition SpecializedCondition::Current() {
- TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+ TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get();
SpecializedCondition cond;
if (entry->condition_stack.size() > 0) {
cond = entry->condition_stack.top();
class SpecializedCondition::Internal {
public:
- static void EnterScope(SpecializedCondition cond) {
- cond.EnterWithScope();
- }
+ static void EnterScope(SpecializedCondition cond) { cond.EnterWithScope(); }
- static void ExitScope(SpecializedCondition cond) {
- cond.ExitWithScope();
- }
+ static void ExitScope(SpecializedCondition cond) { cond.ExitWithScope(); }
};
TVM_REGISTER_NODE_TYPE(StageNode);
// Printer
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const StageNode*>(node.get());
- if (op->op.defined()) {
- p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
- } else {
- p->stream << "group-stage(" << op << ")";
- }
-})
-.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IterVarAttrNode*>(node.get());
- p->stream << IterVarType2String(op->iter_type);
-})
-.set_dispatch<SplitNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SplitNode*>(node.get());
- p->stream << "split(parent=";
- p->Print(op->parent);
- p->stream << ", outer=";
- p->Print(op->outer);
- p->stream << ", inner=";
- p->Print(op->inner);
- p->stream << ')';
-})
-.set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FuseNode*>(node.get());
- p->stream << "split(";
- p->stream << "outer=";
- p->Print(op->outer);
- p->stream << ", inner=";
- p->Print(op->inner);
- p->stream << ", fused=";
- p->Print(op->fused);
- p->stream << ')';
-})
-.set_dispatch<RebaseNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const RebaseNode*>(node.get());
- p->stream << "rebase(";
- p->stream << "parent=";
- p->Print(op->parent);
- p->stream << ", rebased=";
- p->Print(op->rebased);
- p->stream << ')';
-})
-.set_dispatch<SingletonNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SingletonNode*>(node.get());
- p->stream << "singleton(";
- p->Print(op->iter);
- p->stream << ')';
-})
-.set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ScheduleNode*>(node.get());
- p->stream << "schedule(" << op << ")";
-})
-.set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SpecializedConditionNode*>(node.get());
- p->stream << "specialized_condition(";
- p->Print(op->clauses);
- p->stream << ')';
-});
-
+ .set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const StageNode*>(node.get());
+ if (op->op.defined()) {
+ p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
+ } else {
+ p->stream << "group-stage(" << op << ")";
+ }
+ })
+ .set_dispatch<IterVarAttrNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IterVarAttrNode*>(node.get());
+ p->stream << IterVarType2String(op->iter_type);
+ })
+ .set_dispatch<SplitNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const SplitNode*>(node.get());
+ p->stream << "split(parent=";
+ p->Print(op->parent);
+ p->stream << ", outer=";
+ p->Print(op->outer);
+ p->stream << ", inner=";
+ p->Print(op->inner);
+ p->stream << ')';
+ })
+ .set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const FuseNode*>(node.get());
+ p->stream << "split(";
+ p->stream << "outer=";
+ p->Print(op->outer);
+ p->stream << ", inner=";
+ p->Print(op->inner);
+ p->stream << ", fused=";
+ p->Print(op->fused);
+ p->stream << ')';
+ })
+ .set_dispatch<RebaseNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const RebaseNode*>(node.get());
+ p->stream << "rebase(";
+ p->stream << "parent=";
+ p->Print(op->parent);
+ p->stream << ", rebased=";
+ p->Print(op->rebased);
+ p->stream << ')';
+ })
+ .set_dispatch<SingletonNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const SingletonNode*>(node.get());
+ p->stream << "singleton(";
+ p->Print(op->iter);
+ p->stream << ')';
+ })
+ .set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ScheduleNode*>(node.get());
+ p->stream << "schedule(" << op << ")";
+ })
+ .set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const SpecializedConditionNode*>(node.get());
+ p->stream << "specialized_condition(";
+ p->Print(op->clauses);
+ p->stream << ')';
+ });
-TVM_REGISTER_GLOBAL("te.CreateSchedule")
-.set_body_typed(create_schedule);
+TVM_REGISTER_GLOBAL("te.CreateSchedule").set_body_typed(create_schedule);
-TVM_REGISTER_GLOBAL("te.StageSetScope")
-.set_body_method(&Stage::set_scope);
+TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope);
-TVM_REGISTER_GLOBAL("te.StageBind")
-.set_body_method(&Stage::bind);
+TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
- IterVar outer, inner;
- stage.split(parent, factor, &outer, &inner);
- return Array<IterVar>({outer, inner});
-});
+ .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
+ IterVar outer, inner;
+ stage.split(parent, factor, &outer, &inner);
+ return Array<IterVar>({outer, inner});
+ });
TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
- IterVar outer, inner;
- stage.split_by_nparts(parent, nparts, &outer, &inner);
- return Array<IterVar>({outer, inner});
-});
+ .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
+ IterVar outer, inner;
+ stage.split_by_nparts(parent, nparts, &outer, &inner);
+ return Array<IterVar>({outer, inner});
+ });
-TVM_REGISTER_GLOBAL("te.StageFuse")
-.set_body_typed([](Stage stage, Array<IterVar> axes) {
- IterVar fused;
- stage.fuse(axes, &fused);
- return fused;
- });
+TVM_REGISTER_GLOBAL("te.StageFuse").set_body_typed([](Stage stage, Array<IterVar> axes) {
+ IterVar fused;
+ stage.fuse(axes, &fused);
+ return fused;
+});
-TVM_REGISTER_GLOBAL("te.StageComputeAt")
-.set_body_method(&Stage::compute_at);
+TVM_REGISTER_GLOBAL("te.StageComputeAt").set_body_method(&Stage::compute_at);
-TVM_REGISTER_GLOBAL("te.StageComputeInline")
-.set_body_method(&Stage::compute_inline);
+TVM_REGISTER_GLOBAL("te.StageComputeInline").set_body_method(&Stage::compute_inline);
-TVM_REGISTER_GLOBAL("te.StageComputeRoot")
-.set_body_method(&Stage::compute_root);
+TVM_REGISTER_GLOBAL("te.StageComputeRoot").set_body_method(&Stage::compute_root);
-TVM_REGISTER_GLOBAL("te.StageReorder")
-.set_body_method(&Stage::reorder);
+TVM_REGISTER_GLOBAL("te.StageReorder").set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("te.StageTile")
-.set_body_typed([](
- Stage stage,
- IterVar x_parent, IterVar y_parent,
- PrimExpr x_factor, PrimExpr y_factor
-) {
- IterVar x_outer, y_outer, x_inner, y_inner;
- stage.tile(x_parent, y_parent,
- x_factor, y_factor,
- &x_outer, &y_outer,
- &x_inner, &y_inner);
- return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
- });
+ .set_body_typed([](Stage stage, IterVar x_parent, IterVar y_parent, PrimExpr x_factor,
+ PrimExpr y_factor) {
+ IterVar x_outer, y_outer, x_inner, y_inner;
+ stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner);
+ return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
+ });
-TVM_REGISTER_GLOBAL("te.StageEnvThreads")
-.set_body_method(&Stage::env_threads);
+TVM_REGISTER_GLOBAL("te.StageEnvThreads").set_body_method(&Stage::env_threads);
-TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
-.set_body_method(&Stage::set_store_predicate);
+TVM_REGISTER_GLOBAL("te.StageSetStorePredicate").set_body_method(&Stage::set_store_predicate);
-TVM_REGISTER_GLOBAL("te.StageUnroll")
-.set_body_method(&Stage::unroll);
+TVM_REGISTER_GLOBAL("te.StageUnroll").set_body_method(&Stage::unroll);
-TVM_REGISTER_GLOBAL("te.StageVectorize")
-.set_body_method(&Stage::vectorize);
+TVM_REGISTER_GLOBAL("te.StageVectorize").set_body_method(&Stage::vectorize);
-TVM_REGISTER_GLOBAL("te.StageTensorize")
-.set_body_method(&Stage::tensorize);
+TVM_REGISTER_GLOBAL("te.StageTensorize").set_body_method(&Stage::tensorize);
-TVM_REGISTER_GLOBAL("te.StageParallel")
-.set_body_method(&Stage::parallel);
+TVM_REGISTER_GLOBAL("te.StageParallel").set_body_method(&Stage::parallel);
-TVM_REGISTER_GLOBAL("te.StagePragma")
-.set_body_method(&Stage::pragma);
+TVM_REGISTER_GLOBAL("te.StagePragma").set_body_method(&Stage::pragma);
-TVM_REGISTER_GLOBAL("te.StagePrefetch")
-.set_body_method(&Stage::prefetch);
+TVM_REGISTER_GLOBAL("te.StagePrefetch").set_body_method(&Stage::prefetch);
-TVM_REGISTER_GLOBAL("te.StageStorageAlign")
-.set_body_method(&Stage::storage_align);
+TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_align);
-TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
-.set_body_method(&Stage::double_buffer);
+TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer);
-TVM_REGISTER_GLOBAL("te.StageOpenGL")
-.set_body_method(&Stage::opengl);
+TVM_REGISTER_GLOBAL("te.StageOpenGL").set_body_method(&Stage::opengl);
-TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
-.set_body_method(&Schedule::normalize);
+TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize);
-TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
-.set_body_method(&Schedule::create_group);
+TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group);
-TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
-.set_body_method(&Schedule::cache_read);
+TVM_REGISTER_GLOBAL("te.ScheduleCacheRead").set_body_method(&Schedule::cache_read);
-TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- if (args[1].IsObjectRef<Tensor>()) {
- *ret = args[0].operator Schedule()
- .cache_write(args[1].operator Tensor(), args[2]);
- } else {
- *ret = args[0].operator Schedule()
- .cache_write(args[1].operator Array<Tensor>(), args[2]);
- }
- });
+TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite").set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (args[1].IsObjectRef<Tensor>()) {
+ *ret = args[0].operator Schedule().cache_write(args[1].operator Tensor(), args[2]);
+ } else {
+ *ret = args[0].operator Schedule().cache_write(args[1].operator Array<Tensor>(), args[2]);
+ }
+});
-TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
-.set_body_method(&Schedule::rfactor);
+TVM_REGISTER_GLOBAL("te.ScheduleRFactor").set_body_method(&Schedule::rfactor);
-TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition")
-.set_body_typed([](Array<PrimExpr> condition) {
- return SpecializedCondition(condition);
+TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition").set_body_typed([](Array<PrimExpr> condition) {
+ return SpecializedCondition(condition);
});
-TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = SpecializedCondition::Current();
+TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization").set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = SpecializedCondition::Current();
});
TVM_REGISTER_GLOBAL("te.EnterSpecializationScope")
-.set_body_typed(SpecializedCondition::Internal::EnterScope);
+ .set_body_typed(SpecializedCondition::Internal::EnterScope);
TVM_REGISTER_GLOBAL("te.ExitSpecializationScope")
-.set_body_typed(SpecializedCondition::Internal::ExitScope);
+ .set_body_typed(SpecializedCondition::Internal::ExitScope);
} // namespace te
} // namespace tvm
* \file schedule_ops.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/tir/stmt_functor.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <utility>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+
#include <unordered_map>
#include <unordered_set>
-#include "graph.h"
-#include "../operation/op_util.h"
+#include <utility>
+
#include "../../tir/transforms/ir_util.h"
+#include "../operation/op_util.h"
+#include "graph.h"
namespace tvm {
namespace te {
using namespace tir;
-Stmt MakePipeline(const Stage& s,
- const std::unordered_map<IterVar, Range>& dom_map,
- Stmt consumer,
+Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_map, Stmt consumer,
bool debug_keep_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
if (s->double_buffer) {
- producer = AttrStmtNode::make(
- s->op, tir::attr::double_buffer_scope, 1, producer);
+ producer = AttrStmtNode::make(s->op, tir::attr::double_buffer_scope, 1, producer);
}
Stmt pipeline = producer;
}
pipeline = s->op->BuildRealize(s, dom_map, pipeline);
// use attribute to mark scope of the operation.
- pipeline = AttrStmtNode::make(
- s->op, tir::attr::realize_scope,
- StringImmNode::make(s->scope),
- pipeline);
+ pipeline =
+ AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImmNode::make(s->scope), pipeline);
if (s->is_opengl) {
- pipeline = AttrStmtNode::make(
- s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
+ pipeline =
+ AttrStmtNode::make(s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
}
return pipeline;
}
// inject the operator's realization on the stmt.
class InjectAttach : public StmtMutator {
public:
- InjectAttach(const Stage& stage,
- const Stage& attach_spec,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop)
- : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
+ InjectAttach(const Stage& stage, const Stage& attach_spec,
+ const std::unordered_map<IterVar, Range>& dom_map, bool debug_keep_trivial_loop)
+ : stage_(stage),
+ attach_spec_(attach_spec),
+ dom_map_(dom_map),
debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt VisitStmt(const Stmt& input_stmt) final {
CHECK(input_stmt.defined());
auto stmt = StmtMutator::VisitStmt(input_stmt);
const AttrStmtNode* op = stmt.as<AttrStmtNode>();
- if (op != nullptr &&
- op->attr_key == tir::attr::loop_scope) {
- if (attach_spec_->attach_type == kScope &&
- op->node == attach_spec_->attach_ivar) {
- CHECK(!found_attach)
- << "Find IterVar" << attach_spec_->attach_ivar
- << " in multiple places in the IR";
+ if (op != nullptr && op->attr_key == tir::attr::loop_scope) {
+ if (attach_spec_->attach_type == kScope && op->node == attach_spec_->attach_ivar) {
+ CHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar
+ << " in multiple places in the IR";
found_attach = true;
- stmt = AttrStmtNode::make(
- op->node, op->attr_key, op->value,
- MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+ stmt =
+ AttrStmtNode::make(op->node, op->attr_key, op->value,
+ MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
}
return stmt;
// inject the operator's realization on the stmt.
class InjectScanStep : public StmtMutator {
public:
- InjectScanStep(const Stage& stage,
- const Operation& scan_op,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool is_init,
+ InjectScanStep(const Stage& stage, const Operation& scan_op,
+ const std::unordered_map<IterVar, Range>& dom_map, bool is_init,
bool debug_keep_trivial_loop)
- : stage_(stage), scan_op_(scan_op),
- dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
+ : stage_(stage),
+ scan_op_(scan_op),
+ dom_map_(dom_map),
+ is_init_(is_init),
+ debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt VisitStmt(const Stmt& input_stmt) final {
CHECK(input_stmt.defined());
auto stmt = StmtMutator::VisitStmt(input_stmt);
// update
const AttrStmtNode* op = stmt.as<AttrStmtNode>();
- if (op != nullptr &&
- ((op->attr_key == tir::attr::scan_update_scope && !is_init_) ||
- (op->attr_key == tir::attr::scan_init_scope && is_init_))) {
+ if (op != nullptr && ((op->attr_key == tir::attr::scan_update_scope && !is_init_) ||
+ (op->attr_key == tir::attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
- stmt = AttrStmtNode::make(
- op->node, op->attr_key, op->value,
- MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+ stmt =
+ AttrStmtNode::make(op->node, op->attr_key, op->value,
+ MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
}
return stmt;
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::loop_scope ||
- op->attr_key == tir::attr::scan_init_scope) {
+ if (op->attr_key == tir::attr::loop_scope || op->attr_key == tir::attr::scan_init_scope) {
return this->VisitStmt(op->body);
} else if (op->attr_key == tir::attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- Stmt ret = AttrStmtNode::make(
- it->second, op->attr_key, op->value, op->body);
+ Stmt ret = AttrStmtNode::make(it->second, op->attr_key, op->value, op->body);
return this->VisitStmt(ret);
} else {
return this->VisitStmt(op->body);
if (it != replace_op_.end()) {
if (it->second.defined()) {
return AttrStmtNode::make(
- Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
- op->attr_key, op->value, this->VisitStmt(op->body));
+ Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)}, op->attr_key,
+ op->value, this->VisitStmt(op->body));
} else {
return this->VisitStmt(op->body);
}
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- return AttrStmtNode::make(
- it->second.output(tensor->value_index),
- op->attr_key, op->value, this->VisitStmt(op->body));
+ return AttrStmtNode::make(it->second.output(tensor->value_index), op->attr_key, op->value,
+ this->VisitStmt(op->body));
} else {
return this->VisitStmt(op->body);
}
auto it = replace_realize_.find(key);
if (it != replace_realize_.end()) {
if (it->second.defined()) {
- Stmt ret = RealizeNode::make(
- it->second->op, it->second->value_index,
- op->dtype, op->bounds, op->condition, op->body);
+ Stmt ret = RealizeNode::make(it->second->op, it->second->value_index, op->dtype, op->bounds,
+ op->condition, op->body);
return this->VisitStmt(ret);
} else {
return this->VisitStmt(op->body);
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second;
- Stmt ret = ProvideNode::make(
- dst->op, dst->value_index, op->value, op->args);
+ Stmt ret = ProvideNode::make(dst->op, dst->value_index, op->value, op->args);
return this->VisitStmt(ret);
} else {
return StmtExprMutator::VisitStmt_(op);
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second;
- PrimExpr ret = CallNode::make(
- op->dtype, dst->op->name, op->args,
- op->call_type, dst->op, dst->value_index);
+ PrimExpr ret = CallNode::make(op->dtype, dst->op->name, op->args, op->call_type, dst->op,
+ dst->value_index);
return this->VisitExpr(ret);
}
}
if (!s->op.same_as(s->origin_op)) {
for (int i = 0; i < s->op->num_outputs(); ++i) {
Tensor target = s->origin_op.output(i);
- AddReplace(s->op.output(i), target,
- target, s->origin_op);
+ AddReplace(s->op.output(i), target, target, s->origin_op);
}
}
// Specially add replacements for scan op.
}
private:
- void AddReplace(Tensor src,
- Tensor dst,
- Tensor repl_realize = Tensor(),
+ void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(),
Operation repl_op = Operation()) {
TensorKey key{src->op, src->value_index};
replace_buffer_[key] = dst;
arith::Analyzer analyzer_;
};
-Stmt ScheduleOps(
- Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
+Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates
if (!scan) continue;
for (Tensor t : scan->init) {
if (scan_init.count(t->op)) {
- CHECK(scan_init.at(t->op).same_as(s->op))
- << "Scan init tensor can only belong to one scan";
+ CHECK(scan_init.at(t->op).same_as(s->op)) << "Scan init tensor can only belong to one scan";
} else {
scan_init[t->op] = s->op;
}
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
- CHECK_NE(s->attach_type, kInline)
- << "call schedule.normalize before scheduleops";
+ CHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops";
CHECK(s->op.defined());
// no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue;
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
body = mu(std::move(body));
- CHECK(mu.found_attach)
- << "did not find attachment point for scan.init";
+ CHECK(mu.found_attach) << "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined());
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
body = mu(std::move(body));
- CHECK(mu.found_attach)
- << "did not find attachment point for scan.update";
+ CHECK(mu.found_attach) << "did not find attachment point for scan.update";
} else if (attach_spec->attach_type == kInlinedAlready) {
// do nothing
} else if (attach_spec->attach_type == kGroupRoot) {
CHECK(body.defined());
InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
body = mutator(std::move(body));
- CHECK(mutator.found_attach)
- << "did not find attachment point for " << s << " in "
- << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
- << ", body:\n"
- << body;
+ CHECK(mutator.found_attach) << "did not find attachment point for " << s << " in "
+ << attach_spec->attach_stage->op << " x "
+ << attach_spec->attach_ivar << ", body:\n"
+ << body;
}
}
SchedulePostProc post_proc;
return post_proc(std::move(body));
}
-TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("schedule.ScheduleOps").set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], false);
else
* \brief Rewrite the Stmt generated by ScheduleOps
* to accomondate tensorcore.
*/
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/target/target_info.h>
+#include <tvm/te/operation.h>
+#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
-#include <tvm/te/operation.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/op.h>
-#include <tvm/tir/buffer.h>
-#include <tvm/target/target_info.h>
-#include <tvm/target/target.h>
-#include <tvm/runtime/device_api.h>
+
#include <unordered_map>
+
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
namespace te {
using namespace te;
+using intrinsic::tvm_address_of;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
-using intrinsic::tvm_address_of;
struct Tile {
int m{-1};
}
}
-PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) {
+PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) {
auto cast = input.as<CastNode>();
if (cast == nullptr) {
return input;
// MMAMatcher matches C = Cast(A)*Cast(B)+C,
// where A & B are fp16/int8 local buffers,
// and C is fp32/int32 local buffer.
-class MMAMatcher: public StmtVisitor {
+class MMAMatcher : public StmtVisitor {
public:
explicit MMAMatcher(Map<Tensor, Buffer> extern_buffer) {
for (auto kv : extern_buffer) {
}
}
- inline bool Matched() const {return matched_;}
+ inline bool Matched() const { return matched_; }
friend class ScheduleAnalyser;
friend class BufferAnalyser;
DataType dtype;
bool external{false};
bool released{false};
- bool same_as(const BufferInfo &bi) {
+ bool same_as(const BufferInfo& bi) {
if (this->dtype != bi.dtype) return false;
if (this->name != bi.name) return false;
if (this->external != bi.external) return false;
auto* load_c = add->a.as<CallNode>();
BufferInfo buffer_c;
- if (!check_local_buffer_(load_c, &buffer_c)
- || !buffer_c.same_as(store_buffer)
- || !(buffer_c.dtype == DataType::Float(32) ||
- buffer_c.dtype == DataType::Int(32))) {
+ if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) ||
+ !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) {
return false;
}
auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype);
auto load_a = load_a_expr.as<CallNode>();
BufferInfo buffer_a;
- if (!check_local_buffer_(load_a, &buffer_a)
- || !(buffer_a.dtype == DataType::Float(16) ||
- buffer_a.dtype == DataType::Int(8) ||
- buffer_a.dtype == DataType::UInt(8) ||
- buffer_a.dtype == DataType::Int(4) ||
- buffer_a.dtype == DataType::UInt(4) ||
- buffer_a.dtype == DataType::Int(1))) {
+ if (!check_local_buffer_(load_a, &buffer_a) ||
+ !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) ||
+ buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) ||
+ buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) {
return false;
}
auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype);
auto load_b = load_b_expr.as<CallNode>();
BufferInfo buffer_b;
- if (!check_local_buffer_(load_b, &buffer_b)
- || !(buffer_b.dtype == DataType::Float(16) ||
- buffer_b.dtype == DataType::Int(8) ||
- buffer_b.dtype == DataType::UInt(8) ||
- buffer_b.dtype == DataType::Int(4) ||
- buffer_a.dtype == DataType::UInt(4) ||
- buffer_a.dtype == DataType::Int(1))) {
+ if (!check_local_buffer_(load_b, &buffer_b) ||
+ !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) ||
+ buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) ||
+ buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) {
return false;
}
frag_reg_.insert(buffer_b.name);
buf_name_.insert(std::make_pair(load_a, buffer_a.name));
buf_name_.insert(std::make_pair(load_b, buffer_b.name));
- mma_sync_.insert(std::make_pair(op,
- Array<PrimExpr>{load_a_expr, load_b_expr, add->a}));
+ mma_sync_.insert(std::make_pair(op, Array<PrimExpr>{load_a_expr, load_b_expr, add->a}));
return true;
}
// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major
class ScheduleAnalyser {
public:
- explicit ScheduleAnalyser(const MMAMatcher &mma_matcher)
- : mma_sync_(mma_matcher.mma_sync_),
- buf_name_(mma_matcher.buf_name_) {}
+ explicit ScheduleAnalyser(const MMAMatcher& mma_matcher)
+ : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {}
bool MatrixIdentify(Schedule schedule) {
// TODO(minmin): handle the case where MatMul is not the output stage
}
const VarNode* axis_var[2];
const VarNode* reduce_axis_var;
- axis_var[0] = axis[axis.size()-2]->var.as<VarNode>();
- axis_var[1] = axis[axis.size()-1]->var.as<VarNode>();
+ axis_var[0] = axis[axis.size() - 2]->var.as<VarNode>();
+ axis_var[1] = axis[axis.size() - 1]->var.as<VarNode>();
reduce_axis_var = reduce_axis[0]->var.as<VarNode>();
BodyVisitor body_visitor;
matrix_major_.insert(std::make_pair(compute->name, "col_major"));
}
- for (auto &mma_sync : mma_sync_) {
- auto &operands = mma_sync.second;
+ for (auto& mma_sync : mma_sync_) {
+ auto& operands = mma_sync.second;
auto* load_a = operands[0].as<CallNode>();
auto* load_b = operands[1].as<CallNode>();
auto input0 = simplify_name(buf_name_.find(load_a)->second);
class BufferAnalyser : public StmtExprVisitor {
public:
explicit BufferAnalyser(Map<Tensor, Buffer> extern_buffer,
- const ScheduleAnalyser &schedule_analyser,
- const MMAMatcher &mma_matcher)
+ const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher)
: matrix_abc_(schedule_analyser.matrix_abc_),
matrix_major_(schedule_analyser.matrix_major_),
frag_reg_(mma_matcher.frag_reg_) {
if (op->attr_key == tir::attr::thread_extent) {
if (const IntImmNode* value = op->value.as<IntImmNode>()) {
thread_extent_.insert(
- std::make_pair(
- op->node.as<IterVarNode>()->var->name_hint,
- value->value));
+ std::make_pair(op->node.as<IterVarNode>()->var->name_hint, value->value));
}
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == tir::attr::realize_scope) {
StmtExprVisitor::VisitStmt_(op);
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key.f;
+ CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f;
const BufferInfo& bi = it->second;
- CHECK(!bi.released)
- << "Read a buffer that is already out of scope";
+ CHECK(!bi.released) << "Read a buffer that is already out of scope";
if (matrix_abc_.count(key.GetName())) {
if (bi.shape.size() < 2) {
strides_.insert(std::make_pair(key.GetName(), strides));
if (frag_reg_.count(bi.name)) {
- PrimExpr dst = CallNode::make(bi.dtype,
- bi.name,
- op->args,
- CallNode::Halide,
- op->func,
- 0);
+ PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0);
frag_load_.insert(std::make_pair(op, dst));
auto rel_index = bi.RelIndex(op->args);
const CallNode* value = op->value.as<CallNode>();
if (value != nullptr && frag_reg_.count(value->name)) {
- PrimExpr dst = CallNode::make(bi.dtype,
- bi.name,
- op->args,
- CallNode::Halide,
- op->func,
- 0);
+ PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0);
frag_store_.insert(std::make_pair(op, dst));
}
}
if (op->call_type == CallNode::Halide) {
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key.f;
+ CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f;
const BufferInfo& bi = it->second;
- CHECK(!bi.released)
- << "Read a buffer that is already out of scope";
+ CHECK(!bi.released) << "Read a buffer that is already out of scope";
if (matrix_abc_.count(op->name)) {
if (bi.shape.size() < 2) {
if (dim < avec.size() && avec[dim].align_factor != 0) {
PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
- stride = stride + \
- indexmod(factor + offset - indexmod(stride, factor), factor);
+ stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
stride = analyzer_.Simplify(stride);
}
rstrides.push_back(stride);
}
bool supported_warp_tile_() {
- if (warp_tile_.m == 16 &&
- warp_tile_.n == 16 &&
- warp_tile_.k == 16) {
+ if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) {
return true;
}
- if (warp_tile_.m == 8 &&
- warp_tile_.n == 32 &&
- warp_tile_.k == 16) {
+ if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) {
return true;
}
- if (warp_tile_.m == 32 &&
- warp_tile_.n == 8 &&
- warp_tile_.k == 16) {
+ if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) {
return true;
}
- if (warp_tile_.m == 8 &&
- warp_tile_.n == 8 &&
- warp_tile_.k == 32) {
+ if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) {
return true;
}
- if (warp_tile_.m == 8 &&
- warp_tile_.n == 8 &&
- warp_tile_.k == 128) {
+ if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) {
return true;
}
}
std::unordered_map<TensorKey, BufferInfo> buf_map_;
- std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
+ std::unordered_map<TensorKey, std::vector<DimAlignInfo>> dim_align_;
std::unordered_map<const Object*, std::string> storage_scope_;
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
// ThreadIdxMutator does the thread index unification inside a warp
class ThreadIdxMutator : public StmtExprMutator {
public:
- explicit ThreadIdxMutator(PrimExpr warp_y): warp_y_(warp_y) {}
+ explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {}
PrimExpr VisitExpr_(const VarNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
// based on tensor core intrinsics
class TensorCoreIRMutator : public StmtExprMutator {
public:
- explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser,
- const BufferAnalyser &buffer_analyser)
+ explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser,
+ const BufferAnalyser& buffer_analyser)
: matrix_abc_(schedule_analyser.matrix_abc_),
- matrix_major_(schedule_analyser.matrix_major_),
- mma_sync_(schedule_analyser.mma_sync_),
- strides_(buffer_analyser.strides_),
- frag_reg_(buffer_analyser.frag_reg_),
- loop_scaling_(buffer_analyser.index_visitor.loop_scaling_),
- frag_load_(buffer_analyser.frag_load_),
- frag_store_(buffer_analyser.frag_store_),
- warp_tile_(buffer_analyser.warp_tile_),
- warp_threads_y_(buffer_analyser.warp_threads_y_) {}
+ matrix_major_(schedule_analyser.matrix_major_),
+ mma_sync_(schedule_analyser.mma_sync_),
+ strides_(buffer_analyser.strides_),
+ frag_reg_(buffer_analyser.frag_reg_),
+ loop_scaling_(buffer_analyser.index_visitor.loop_scaling_),
+ frag_load_(buffer_analyser.frag_load_),
+ frag_store_(buffer_analyser.frag_store_),
+ warp_tile_(buffer_analyser.warp_tile_),
+ warp_threads_y_(buffer_analyser.warp_threads_y_) {}
Stmt VisitStmt_(const RealizeNode* op) final {
TensorKey key{op->func, op->value_index};
for (size_t i = 0; i < op->bounds.size() - 2; ++i) {
new_bounds.push_back(op->bounds[i]);
}
- CHECK_GE(op->bounds.size(), 2)
- << "Less than 2 dimensions for matrix " << key.GetName();
- new_bounds.push_back(Range::make_by_min_extent(
- op->bounds[op->bounds.size() - 2]->min, new_extents[0]));
- new_bounds.push_back(Range::make_by_min_extent(
- op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
-
- return RealizeNode::make(op->func, op->value_index,
- op->dtype, new_bounds,
- op->condition, op->body);
+ CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key.GetName();
+ new_bounds.push_back(
+ Range::make_by_min_extent(op->bounds[op->bounds.size() - 2]->min, new_extents[0]));
+ new_bounds.push_back(
+ Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
+
+ return RealizeNode::make(op->func, op->value_index, op->dtype, new_bounds, op->condition,
+ op->body);
}
return stmt;
}
}
auto it = matrix_abc_.find(simplify_name(node->name));
- CHECK(it != matrix_abc_.end())
- << "Cannot find matrix info for " << node->name;
+ CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name;
auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second);
Stmt body = this->VisitStmt(op->body);
- return AttrStmtNode::make(op->node,
- op->attr_key,
- matrix_abc,
- body);
+ return AttrStmtNode::make(op->node, op->attr_key, matrix_abc, body);
}
}
return stmt;
Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = mma_sync_.find(op);
if (it != mma_sync_.end()) {
- const auto &operands = it->second;
+ const auto& operands = it->second;
PrimExpr a = operands[0];
auto ca = a.as<CallNode>();
PrimExpr b = operands[1];
ObjectPtr<BufferNode> buffer_node_b = make_object<BufferNode>();
ObjectPtr<BufferNode> buffer_node_c = make_object<BufferNode>();
- auto mma_sync_call =
- [&buffer_node_a, &buffer_node_b, &ca, &cb]
- (const Buffer &buffer) {
- Buffer buffer_a(buffer_node_a);
- Buffer buffer_b(buffer_node_b);
- if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
- return EvaluateNode::make(
- CallNode::make(DataType::Handle(),
- intrinsic::tvm_bmma_sync,
- {buffer->data, buffer->elem_offset,
- buffer_a->data, buffer_a->elem_offset,
- buffer_b->data, buffer_b->elem_offset,
- buffer->data, buffer->elem_offset},
- CallNode::Intrinsic));
- } else {
- return EvaluateNode::make(
- CallNode::make(DataType::Handle(),
- intrinsic::tvm_mma_sync,
- {buffer->data, buffer->elem_offset,
- buffer_a->data, buffer_a->elem_offset,
- buffer_b->data, buffer_b->elem_offset,
- buffer->data, buffer->elem_offset},
- CallNode::Intrinsic));
- }
- };
+ auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) {
+ Buffer buffer_a(buffer_node_a);
+ Buffer buffer_b(buffer_node_b);
+ if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
+ return EvaluateNode::make(CallNode::make(
+ DataType::Handle(), intrinsic::tvm_bmma_sync,
+ {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
+ buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
+ CallNode::Intrinsic));
+ } else {
+ return EvaluateNode::make(CallNode::make(
+ DataType::Handle(), intrinsic::tvm_mma_sync,
+ {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
+ buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
+ CallNode::Intrinsic));
+ }
+ };
- auto call_add_c =
- [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer &buffer) {
- return add_buffer_bind_scope_(cc, buffer_node_c,
- TensorKey{cc->func, cc->value_index}, mma_sync_call, cc->dtype);
- };
+ auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) {
+ return add_buffer_bind_scope_(cc, buffer_node_c, TensorKey{cc->func, cc->value_index},
+ mma_sync_call, cc->dtype);
+ };
- auto call_add_b =
- [this, &cb, &buffer_node_b, &call_add_c](const Buffer &buffer) {
- return add_buffer_bind_scope_(cb, buffer_node_b,
- TensorKey{cb->func, cb->value_index}, call_add_c, cb->dtype);
- };
+ auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) {
+ return add_buffer_bind_scope_(cb, buffer_node_b, TensorKey{cb->func, cb->value_index},
+ call_add_c, cb->dtype);
+ };
- return add_buffer_bind_scope_(ca, buffer_node_a,
- TensorKey{ca->func, ca->value_index}, call_add_b, ca->dtype);
+ return add_buffer_bind_scope_(ca, buffer_node_a, TensorKey{ca->func, ca->value_index},
+ call_add_b, ca->dtype);
}
auto it2 = frag_load_.find(op);
if (it2 != frag_load_.end()) {
PrimExpr dst = it2->second;
- if (op->value.as<FloatImmNode>() != nullptr ||
- op->value.as<IntImmNode>() != nullptr) {
+ if (op->value.as<FloatImmNode>() != nullptr || op->value.as<IntImmNode>() != nullptr) {
auto call = dst.as<CallNode>();
- auto fill_fragment_call =
- [this, &op](const Buffer &buffer) {
- return EvaluateNode::make(
- CallNode::make(DataType::Handle(),
- intrinsic::tvm_fill_fragment,
- {buffer->data,
- warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, op->value},
- CallNode::Intrinsic));
- };
+ auto fill_fragment_call = [this, &op](const Buffer& buffer) {
+ return EvaluateNode::make(CallNode::make(DataType::Handle(), intrinsic::tvm_fill_fragment,
+ {buffer->data, warp_tile_.m, warp_tile_.n,
+ warp_tile_.k, buffer->elem_offset, op->value},
+ CallNode::Intrinsic));
+ };
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
- return add_buffer_bind_scope_(call, buffer_node,
- TensorKey{call->func, call->value_index},
+ return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index},
fill_fragment_call, call->dtype);
}
const CallNode* value = op->value.as<CallNode>();
- CHECK(value != nullptr)
- << "Can only load fragment from a buffer";
+ CHECK(value != nullptr) << "Can only load fragment from a buffer";
auto it = strides_.find(value->name);
- CHECK(it != strides_.end())
- << "Cannot find stride for " << value->name;
+ CHECK(it != strides_.end()) << "Cannot find stride for " << value->name;
auto strides = it->second;
CHECK_GE(strides.size(), 2);
- PrimExpr stride = strides[strides.size()-2];
+ PrimExpr stride = strides[strides.size() - 2];
// thread index unification inside a warp
PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
PrimExpr mutated_value = thread_idx_mutator(op->value);
- PrimExpr src = CallNode::make(value->dtype,
- "&",
- {mutated_value},
- CallNode::Extern);
+ PrimExpr src = CallNode::make(value->dtype, "&", {mutated_value}, CallNode::Extern);
auto call = dst.as<CallNode>();
PrimExpr matrix_major;
auto iter2 = matrix_major_.find(simplify_name(call->name));
- CHECK(iter2 != matrix_major_.end())
- << "Can not determine matrix major for " << call->name;
+ CHECK(iter2 != matrix_major_.end()) << "Can not determine matrix major for " << call->name;
if (iter2->second == "col_major") {
matrix_major = StringImmNode::make("col_major");
} else if (iter2->second == "row_major") {
LOG(FATAL) << "invalid matrix major for " << call->name;
}
- auto load_matrix_call =
- [this, &src, &stride, &matrix_major](const Buffer &buffer) {
+ auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
return EvaluateNode::make(
- CallNode::make(DataType::Handle(),
- intrinsic::tvm_load_matrix_sync,
- {buffer->data,
- warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, src, stride, matrix_major},
- CallNode::Intrinsic));
+ CallNode::make(DataType::Handle(), intrinsic::tvm_load_matrix_sync,
+ {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+ buffer->elem_offset, src, stride, matrix_major},
+ CallNode::Intrinsic));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
- return add_buffer_bind_scope_(call, buffer_node,
- TensorKey{op->func, op->value_index},
+ return add_buffer_bind_scope_(call, buffer_node, TensorKey{op->func, op->value_index},
load_matrix_call, call->dtype);
}
if (it3 != frag_store_.end()) {
TensorKey key{op->func, op->value_index};
auto it = strides_.find(key.GetName());
- CHECK(it != strides_.end())
- << "Cannot find stride for " << key.GetName();
+ CHECK(it != strides_.end()) << "Cannot find stride for " << key.GetName();
auto strides = it->second;
CHECK_GE(strides.size(), 2);
- PrimExpr stride = strides[strides.size()-2];
+ PrimExpr stride = strides[strides.size() - 2];
PrimExpr dst = it3->second;
// thread index unification inside a warp
PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator(dst);
- dst = CallNode::make(DataType::Handle(),
- "&",
- {dst},
- CallNode::Extern);
+ dst = CallNode::make(DataType::Handle(), "&", {dst}, CallNode::Extern);
auto call = op->value.as<CallNode>();
- auto store_matrix_call =
- [this, &dst, &stride](const Buffer &buffer) {
- return EvaluateNode::make(
- CallNode::make(DataType::Handle(),
- intrinsic::tvm_store_matrix_sync,
- {buffer->data,
- warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, dst, stride,
- StringImmNode::make("col_major")},
- CallNode::Intrinsic));
- };
+ auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
+ return EvaluateNode::make(
+ CallNode::make(DataType::Handle(), intrinsic::tvm_store_matrix_sync,
+ {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+ buffer->elem_offset, dst, stride, StringImmNode::make("col_major")},
+ CallNode::Intrinsic));
+ };
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
- return add_buffer_bind_scope_(call, buffer_node,
- TensorKey{call->func, call->value_index},
+ return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index},
store_matrix_call, call->dtype);
}
if (it != loop_scaling_.end()) {
int scale_factor = it->second;
int scaled_extent_value = 1;
- if (const IntImmNode *ori_extent = op->extent.as<IntImmNode>()) {
+ if (const IntImmNode* ori_extent = op->extent.as<IntImmNode>()) {
int ori_extent_value = ori_extent->value;
scaled_extent_value = ori_extent_value / scale_factor;
}
PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
- stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type,
- op->device_api, op->body);
+ stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api,
+ op->body);
}
}
return stmt;
}
private:
- Array<PrimExpr> get_tile_size_(const std::string &name) {
- auto it = matrix_abc_.find(name);
- auto it2 = matrix_major_.find(name);
- CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end())
- << "Cannot find matrix info for " << name;
- PrimExpr size0 = make_const(DataType::Int(32), 16);
- PrimExpr size1 = make_const(DataType::Int(32), 16);
- if (it->second == "matrix_a" && it2->second == "col_major") {
- size0 = make_const(DataType::Int(32), warp_tile_.k);
- size1 = make_const(DataType::Int(32), warp_tile_.m);
- }
- if (it->second == "matrix_a" && it2->second == "row_major") {
- size0 = make_const(DataType::Int(32), warp_tile_.m);
- size1 = make_const(DataType::Int(32), warp_tile_.k);
- }
- if (it->second == "matrix_b" && it2->second == "row_major") {
- size0 = make_const(DataType::Int(32), warp_tile_.k);
- size1 = make_const(DataType::Int(32), warp_tile_.n);
- }
- if (it->second == "matrix_b" && it2->second == "col_major") {
- size0 = make_const(DataType::Int(32), warp_tile_.n);
- size1 = make_const(DataType::Int(32), warp_tile_.k);
- }
- if (it->second == "matrix_c") {
- size0 = make_const(DataType::Int(32), warp_tile_.n);
- size1 = make_const(DataType::Int(32), warp_tile_.m);
- }
- Array<PrimExpr> tile_size = {size0, size1};
- return tile_size;
+ Array<PrimExpr> get_tile_size_(const std::string& name) {
+ auto it = matrix_abc_.find(name);
+ auto it2 = matrix_major_.find(name);
+ CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end())
+ << "Cannot find matrix info for " << name;
+ PrimExpr size0 = make_const(DataType::Int(32), 16);
+ PrimExpr size1 = make_const(DataType::Int(32), 16);
+ if (it->second == "matrix_a" && it2->second == "col_major") {
+ size0 = make_const(DataType::Int(32), warp_tile_.k);
+ size1 = make_const(DataType::Int(32), warp_tile_.m);
+ }
+ if (it->second == "matrix_a" && it2->second == "row_major") {
+ size0 = make_const(DataType::Int(32), warp_tile_.m);
+ size1 = make_const(DataType::Int(32), warp_tile_.k);
+ }
+ if (it->second == "matrix_b" && it2->second == "row_major") {
+ size0 = make_const(DataType::Int(32), warp_tile_.k);
+ size1 = make_const(DataType::Int(32), warp_tile_.n);
+ }
+ if (it->second == "matrix_b" && it2->second == "col_major") {
+ size0 = make_const(DataType::Int(32), warp_tile_.n);
+ size1 = make_const(DataType::Int(32), warp_tile_.k);
+ }
+ if (it->second == "matrix_c") {
+ size0 = make_const(DataType::Int(32), warp_tile_.n);
+ size1 = make_const(DataType::Int(32), warp_tile_.m);
+ }
+ Array<PrimExpr> tile_size = {size0, size1};
+ return tile_size;
}
- Stmt add_buffer_bind_scope_(const CallNode* call,
- const ObjectPtr<BufferNode> &buffer_node, const TensorKey &key,
- const std::function<Stmt(const Buffer &buffer)> &call_back,
- DataType datatype) {
+ Stmt add_buffer_bind_scope_(const CallNode* call, const ObjectPtr<BufferNode>& buffer_node,
+ const TensorKey& key,
+ const std::function<Stmt(const Buffer& buffer)>& call_back,
+ DataType datatype) {
auto it = bounds_.find(key);
CHECK(it != bounds_.end());
Array<PrimExpr> min_bound;
CHECK_EQ(call->args.size(), min_bound.size());
for (size_t i = 0; i < min_bound.size(); i++) {
elem_offset = AddNode::make(
- elem_offset, MulNode::make(
- strides[i], SubNode::make(call->args[i], min_bound[i])));
+ elem_offset, MulNode::make(strides[i], SubNode::make(call->args[i], min_bound[i])));
}
auto it2 = matrix_abc_.find(simplify_name(call->name));
- CHECK(it2 != matrix_abc_.end())
- << "Cannot find matrix info for " << call->name;
+ CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << call->name;
buffer_node->data = Var(call->name, DataType::Handle());
buffer_node->name = call->name;
buffer_node->scope = "wmma." + it2->second;
args.push_back(call->args[i]);
args.push_back(shape[i]);
}
- auto tuple = CallNode::make(DataType::Handle(),
- intrinsic::tvm_tuple,
- args,
- CallNode::Intrinsic);
+ auto tuple =
+ CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic);
Array<ObjectRef> node = {buffer, tensor};
- return AttrStmtNode::make(node,
- "buffer_bind_scope",
- tuple,
- call_back(buffer));
+ return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer));
}
std::unordered_map<std::string, std::string> matrix_abc_;
int warp_threads_y_{-1};
};
-Stmt SchedulePostProcRewriteForTensorCore(
- Stmt stmt,
- Schedule schedule,
- Map<Tensor, Buffer> extern_buffer) {
+Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
+ Map<Tensor, Buffer> extern_buffer) {
// Check if current lower target is CUDA
auto target = tvm::Target::Current(true);
if (target.defined() && target->target_name != "cuda") {
return stmt;
}
- BufferAnalyser buffer_analyser(extern_buffer,
- schedule_analyser, mma_matcher);
+ BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher);
buffer_analyser(stmt);
if (!buffer_analyser.QualifiedForTensorCore()) {
return stmt;
}
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
-.set_body_typed([](Stmt stmt,
- Schedule schedule,
- Map<te::Tensor, Buffer> extern_buffer) {
- return SchedulePostProcRewriteForTensorCore(
- stmt, schedule, extern_buffer);
-});
+ .set_body_typed([](Stmt stmt, Schedule schedule, Map<te::Tensor, Buffer> extern_buffer) {
+ return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer);
+ });
} // namespace te
} // namespace tvm
* - Add annotation of extern buffers using the buffer_map field
* in the PrimFunc type.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/te/operation.h>
-#include <utility>
+
#include <unordered_map>
+#include <utility>
namespace tvm {
namespace te {
class TensorToBufferMapper : public StmtExprMutator {
public:
explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map)
- : buffer_map_(buffer_map) {
- }
+ : buffer_map_(buffer_map) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
auto ret = StmtExprMutator::VisitStmt_(op);
Operation operation = Downcast<Operation>(op->node);
for (int i = operation->num_outputs(); i != 0; --i) {
Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
- body = AttrStmtNode::make(
- buffer, op->attr_key, op->value, body);
+ body = AttrStmtNode::make(buffer, op->attr_key, op->value, body);
}
return body;
} else if (op->attr_key == tir::attr::buffer_bind_scope) {
- Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
+ Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node);
Tensor tensor = Downcast<Tensor>(tuple[1]);
- return AttrStmtNode::make(
- Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)},
- op->attr_key, op->value, op->body);
- } else if (op->attr_key == tir::attr::buffer_dim_align||
+ return AttrStmtNode::make(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key,
+ op->value, op->body);
+ } else if (op->attr_key == tir::attr::buffer_dim_align ||
op->attr_key == tir::attr::prefetch_scope) {
Tensor tensor = Downcast<Tensor>(op->node);
Buffer buffer = GetOrAllocBuffer(tensor);
- return AttrStmtNode::make(
- buffer, op->attr_key, op->value, op->body);
+ return AttrStmtNode::make(buffer, op->attr_key, op->value, op->body);
} else {
return ret;
}
}
private:
- Buffer GetOrAllocBuffer(const Tensor& tensor) {
- return GetBuffer(tensor, true);
- }
+ Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); }
Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) {
auto it = buffer_map_.find(tensor);
std::unordered_map<Tensor, Buffer> buffer_map_;
};
-
-PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
- Stmt body,
+PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
std::unordered_map<Tensor, Buffer> extern_buffer;
}
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
-.set_body_typed(SchedulePostProcToPrimFunc);
+ .set_body_typed(SchedulePostProcToPrimFunc);
} // namespace te
} // namespace tvm
* \brief Verify if there was any compact buffer bound to a statement.
*/
#include <tvm/runtime/registry.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/te/tensor.h>
-#include <tvm/te/schedule_pass.h>
#include <unordered_map>
return verifier.Verify(stmt);
}
-TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer")
-.set_body_typed(VerifyCompactBuffer);
+TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer").set_body_typed(VerifyCompactBuffer);
} // namespace te
} // namespace tvm
* \file tensor.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/te/tensor.h>
#include <tvm/te/operation.h>
+#include <tvm/te/tensor.h>
#include <tvm/te/tensor_intrin.h>
+
#include <memory>
namespace tvm {
namespace te {
IterVar thread_axis(Range dom, std::string tag) {
- return IterVarNode::make(
- dom, Var(tag), kThreadIndex, tag);
+ return IterVarNode::make(dom, Var(tag), kThreadIndex, tag);
}
IterVar reduce_axis(Range dom, std::string name) {
- return IterVarNode::make(
- dom, Var(name), kCommReduce);
+ return IterVarNode::make(dom, Var(name), kCommReduce);
}
-Var var(std::string name_hint, DataType t) {
- return Var(name_hint, t);
-}
+Var var(std::string name_hint, DataType t) { return Var(name_hint, t); }
// Tensor
PrimExpr Tensor::operator()(Array<Var> indices) const {
PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
using tir::CallNode;
if (ndim() != 0) {
- CHECK_EQ(ndim(), indices.size())
- << "Tensor dimension mismatch in read"
- << "ndim = " << ndim() << ", indices.size=" << indices.size();
+ CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read"
+ << "ndim = " << ndim() << ", indices.size=" << indices.size();
}
- auto n = CallNode::make(
- (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
- (*this)->op, (*this)->value_index);
+ auto n = CallNode::make((*this)->dtype, (*this)->op->name, indices, CallNode::Halide, (*this)->op,
+ (*this)->value_index);
return n;
}
return Tensor(node);
}
-Tensor TensorNode::make(Array<PrimExpr> shape,
- DataType dtype,
- Operation op,
- int value_index) {
+Tensor TensorNode::make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
auto n = make_object<TensorNode>();
n->shape = std::move(shape);
n->dtype = dtype;
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* t = static_cast<const TensorNode*>(node.get());
- p->stream << "Tensor(shape=" << t->shape
- << ", op.name=" << t->op->name << ')';
- });
+ .set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* t = static_cast<const TensorNode*>(node.get());
+ p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')';
+ });
TVM_REGISTER_NODE_TYPE(TensorNode);
-
// TensorIntrin
-TensorIntrin TensorIntrinNode::make(std::string name,
- Operation op,
- Array<Tensor> inputs,
- Array<Buffer> buffers,
- Array<Var> scalar_params,
- Stmt body,
- Stmt reduce_init,
- Stmt reduce_update) {
+TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array<Tensor> inputs,
+ Array<Buffer> buffers, Array<Var> scalar_params, Stmt body,
+ Stmt reduce_init, Stmt reduce_update) {
auto n = make_object<TensorIntrinNode>();
n->name = std::move(name);
n->op = std::move(op);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const TensorIntrinNode*>(node.get());
- p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
- });
+ .set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const TensorIntrinNode*>(node.get());
+ p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
-
// TensorIntrinCall
-TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
- Array<IterVar> reduce_axis,
+TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array<Tensor> tensors,
+ Array<Region> regions, Array<IterVar> reduce_axis,
Array<PrimExpr> scalar_inputs) {
auto n = make_object<TensorIntrinCallNode>();
n->intrin = std::move(intrin);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
- p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
- });
+ .set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
+ p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
+ });
TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
-TVM_REGISTER_GLOBAL("te.Tensor")
-.set_body_typed(TensorNode::make);
+TVM_REGISTER_GLOBAL("te.Tensor").set_body_typed(TensorNode::make);
-TVM_REGISTER_GLOBAL("te.TensorIntrin")
-.set_body_typed(TensorIntrinNode::make);
+TVM_REGISTER_GLOBAL("te.TensorIntrin").set_body_typed(TensorIntrinNode::make);
-TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
-.set_body_typed(TensorIntrinCallNode::make);
+TVM_REGISTER_GLOBAL("te.TensorIntrinCall").set_body_typed(TensorIntrinCallNode::make);
-TVM_REGISTER_GLOBAL("te.TensorEqual")
-.set_body_method(&Tensor::operator==);
+TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==);
-TVM_REGISTER_GLOBAL("te.TensorHash")
-.set_body_typed([](Tensor tensor) -> int64_t {
- return static_cast<int64_t>(std::hash<Tensor>()(tensor));
- });
+TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t {
+ return static_cast<int64_t>(std::hash<Tensor>()(tensor));
+});
-TVM_REGISTER_GLOBAL("te.OpGetOutput")
-.set_body_typed([](Operation op, int64_t output) {
+TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output));
});
-TVM_REGISTER_GLOBAL("te.OpNumOutputs")
-.set_body_method<Operation>(&OperationNode::num_outputs);
+TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method<Operation>(&OperationNode::num_outputs);
-TVM_REGISTER_GLOBAL("te.OpInputTensors")
-.set_body_method<Operation>(&OperationNode::InputTensors);
+TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method<Operation>(&OperationNode::InputTensors);
} // namespace te
} // namespace tvm
* \file tir/analysis/deep_equal.cc
* \brief Deep equality checking.
*/
-#include <tvm/node/structural_equal.h>
#include <tvm/node/reflection.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
namespace tvm {
namespace tir {
-class DeepCmpSEqualHandler :
- public SEqualReducer::Handler {
+class DeepCmpSEqualHandler : public SEqualReducer::Handler {
public:
// use direct recursion.
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false));
}
- ObjectRef MapLhsToRhs(const ObjectRef& lhs) final {
- return ObjectRef(nullptr);
- }
+ ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); }
- void MarkGraphNode() final {
- }
+ void MarkGraphNode() final {}
private:
// reflection vtable
}
TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal")
-.set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) {
- return ExprDeepEqual()(lhs, rhs);
-});
+ .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) {
+ return ExprDeepEqual()(lhs, rhs);
+ });
} // namespace tir
} // namespace tvm
* \file side_effect.cc
* \brief side effect analysis
*/
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
-#include <tvm/tir/analysis.h>
namespace tvm {
namespace tir {
void VisitExpr_(const CallNode* op) final {
if (!op->is_pure()) {
- has_side_effect_ = true; return;
+ has_side_effect_ = true;
+ return;
} else {
ExprVisitor::VisitExpr_(op);
}
* \file simple_analysis.cc
* \brief Implementation of simple passes
*/
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/analysis.h>
namespace tvm {
namespace tir {
class VarTouchVisitor : public ExprVisitor {
public:
- explicit VarTouchVisitor(
- std::function<bool(const VarNode*)> var_set)
- : var_set_(var_set) {}
+ explicit VarTouchVisitor(std::function<bool(const VarNode*)> var_set) : var_set_(var_set) {}
void VisitExpr(const PrimExpr& e) final {
if (use_var_) return;
ExprVisitor::VisitExpr(e);
}
- void VisitExpr_(const VarNode* op) final {
- Handle(op);
- }
+ void VisitExpr_(const VarNode* op) final { Handle(op); }
void VisitExpr_(const LoadNode* op) final {
Handle(op->buffer_var.get());
std::function<bool(const VarNode*)> var_set_;
};
-
-bool ExprUseVar(const PrimExpr& e,
- std::function<bool(const VarNode*)> var_set) {
+bool ExprUseVar(const PrimExpr& e, std::function<bool(const VarNode*)> var_set) {
VarTouchVisitor visitor(var_set);
visitor(e);
return visitor.use_var_;
class GPUCodeVerifier : public StmtVisitor {
public:
- bool Verify(Stmt stmt,
- int64_t max_local_memory_per_block,
- int64_t max_shared_memory_per_block,
- int64_t max_threads_per_block,
- int64_t max_thread_x,
- int64_t max_thread_y,
+ bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
+ int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
int64_t max_thread_z) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
}
Var var = op->node.as<IterVarNode>()->var;
- const auto *extent = op->value.as<IntImmNode>();
+ const auto* extent = op->value.as<IntImmNode>();
CHECK(extent);
// record the number of threads in a block
private:
int nest_level_{0};
- std::unordered_set<const VarNode *> visited_local_buffers_;
- std::unordered_set<const VarNode *> visited_shared_buffers_;
+ std::unordered_set<const VarNode*> visited_local_buffers_;
+ std::unordered_set<const VarNode*> visited_shared_buffers_;
std::unordered_set<std::string> visited_threads_;
size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;
}
};
-bool VerifyGPUCode(const PrimFunc& func,
- Map<std::string, PrimExpr> constraints) {
+bool VerifyGPUCode(const PrimFunc& func, Map<std::string, PrimExpr> constraints) {
GPUCodeVerifier verifier;
int64_t max_local_memory_per_block = INT64_MAX;
LOG(FATAL) << "Invalid check item: " << iter.first;
}
- return verifier.Verify(func->body,
- max_local_memory_per_block,
- max_shared_memory_per_block,
- max_threads_per_block,
- max_thread_x,
- max_thread_y,
- max_thread_z);
+ return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
+ max_threads_per_block, max_thread_x, max_thread_y, max_thread_z);
}
-
-TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code")
-.set_body_typed(VerifyGPUCode);
+TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
namespace transform {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
- CHECK(VerifyGPUCode(func, constraints))
- << "RuntimeError: GPU constraint violated"
- << func;
+ CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU constraint violated" << func;
}
}
return mod;
return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode")
-.set_body_typed(VerifyGPUCode);
+TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode);
} // namespace transform
} // namespace tir
* \file verify_memory.cc
* \brief Pass to check if memory accesses are legal.
*/
-#include <tvm/tir/expr.h>
#include <tvm/ir/transform.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/target/target.h>
-#include <tvm/runtime/registry.h>
-
namespace tvm {
namespace tir {
public:
/// Special member functions
//@{
- explicit MemoryAccessVerifier(PrimFunc f, int device_type)
- : func_(f), dev_type_(device_type) {}
+ explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {}
virtual ~MemoryAccessVerifier() = default;
- MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
- MemoryAccessVerifier(MemoryAccessVerifier &&) = delete;
- MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete;
- MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete;
+ MemoryAccessVerifier(const MemoryAccessVerifier&) = delete;
+ MemoryAccessVerifier(MemoryAccessVerifier&&) = delete;
+ MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete;
+ MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete;
//@}
/// Interface to perform memory access verification
protected:
/// Visitor implementation
//@{
- void VisitExpr(const PrimExpr &n) final {
+ void VisitExpr(const PrimExpr& n) final {
if (Failed()) return;
StmtExprVisitor::VisitExpr(n);
}
- void VisitStmt(const Stmt &n) final {
+ void VisitStmt(const Stmt& n) final {
if (Failed()) return;
StmtExprVisitor::VisitStmt(n);
}
}
void VisitStmt_(const AttrStmtNode* op) final {
- if (!InThreadEnv() && (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::pipeline_exec_scope)) {
+ if (!InThreadEnv() &&
+ (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) {
EnterThreadEnv();
StmtExprVisitor::VisitStmt_(op);
ExitThreadEnv();
//@}
/// Check if the value of a Variable comes from function argument.
- bool IsFromFunctionArgs(const VarNode *var) const {
- const VarNode *V = var;
+ bool IsFromFunctionArgs(const VarNode* var) const {
+ const VarNode* V = var;
for (auto kv : func_->buffer_map) {
if (V == kv.second->data.get()) return true;
}
// The value is expected to come from a tvm_struct_get Call.
// Get the first argument of tvm_struct_get, and continue.
- const auto &iter = defs_.find(V);
+ const auto& iter = defs_.find(V);
if (iter == defs_.end()) return false;
- const CallNode *C = iter->second.as<const CallNode>();
+ const CallNode* C = iter->second.as<const CallNode>();
if (!C || C->name != intrinsic::tvm_struct_get) return false;
V = C->args[0].as<VarNode>();
}
}
/// Handle memory access to a Variable
- void HandleLoadStoreToVariable(const Var &var) {
+ void HandleLoadStoreToVariable(const Var& var) {
// We skip the access within thread env.
if (InThreadEnv()) return;
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
static bool IsGPUDevice(int dev_type) {
- return kDLGPU == dev_type || kDLOpenCL == dev_type ||
- kDLVulkan == dev_type || kDLMetal == dev_type ||
- kDLROCM == dev_type || kOpenGL == dev_type;
+ return kDLGPU == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
+ kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type;
}
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device.
- static bool IsFPGADevice(int dev_type) {
- return kDLSDAccel == dev_type || kDLAOCL == dev_type;
- }
+ static bool IsFPGADevice(int dev_type) { return kDLSDAccel == dev_type || kDLAOCL == dev_type; }
private:
/// Status of visitor
bool in_thread_env_{false};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
//@}
- tir::PrimFunc func_{nullptr}; ///< Function to be verified.
- int dev_type_{kDLCPU}; ///< Device type
- std::unordered_map<const VarNode *, PrimExpr> defs_; ///< Variable definitions
+ tir::PrimFunc func_{nullptr}; ///< Function to be verified.
+ int dev_type_{kDLCPU}; ///< Device type
+ std::unordered_map<const VarNode*, PrimExpr> defs_; ///< Variable definitions
};
} // namespace
/// Interface of VerifyMemory pass
bool VerifyMemory(const PrimFunc& func) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
- CHECK(target.defined())
- << "LowerWarpMemory: Require the target attribute";
+ CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
- if (func->GetAttr<Integer>(
- tvm::attr::kCallingConv,
- Integer(CallingConv::kDefault)) == CallingConv::kDefault) {
+ if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
+ CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->device_type);
v.Run();
return !v.Failed();
}
}
-TVM_REGISTER_GLOBAL("tir.analysis.verify_memory")
-.set_body_typed(VerifyMemory);
+TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory);
namespace transform {
Pass VerifyMemory() {
- auto pass_func = [=](IRModule mod, PassContext ctx) {
- for (auto kv : mod->functions) {
- if (auto* n = kv.second.as<PrimFuncNode>()) {
- auto func = GetRef<PrimFunc>(n);
- CHECK(VerifyMemory(func))
- << "RuntimeError: Direct host side access to device memory is detected."
- << " Did you forget to bind?\n"
- << func;
- }
- }
- return mod;
- };
+ auto pass_func =
+ [=](IRModule mod, PassContext ctx) {
+ for (auto kv : mod->functions) {
+ if (auto* n = kv.second.as<PrimFuncNode>()) {
+ auto func = GetRef<PrimFunc>(n);
+ CHECK(VerifyMemory(func))
+ << "RuntimeError: Direct host side access to device memory is detected."
+ << " Did you forget to bind?\n"
+ << func;
+ }
+ }
+ return mod;
+ };
return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory")
-.set_body_typed(VerifyMemory);
+TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory);
} // namespace transform
} // namespace tir
* \file verify_ssa.cc
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/analysis.h>
-#include <unordered_set>
+
#include <unordered_map>
+#include <unordered_set>
#include <vector>
namespace tvm {
void MarkDef(const VarNode* v, bool allow_dup = false) {
if (defined_.count(v) != 0) {
if (!allow_dup) {
- is_ssa = false; return;
+ is_ssa = false;
+ return;
}
} else {
defined_[v] = 1;
std::unordered_map<const VarNode*, int> defined_;
};
-
bool VerifySSA(const PrimFunc& func) {
IRVerifySSA visitor;
visitor.Run(func);
return visitor.is_ssa;
}
-TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa")
-.set_body_typed(VerifySSA);
-
+TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA);
namespace transform {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
- CHECK(VerifySSA(func))
- << "RuntimeError: IR is not in SSA form"
- << func;
+ CHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func;
}
}
return mod;
return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.VerifySSA")
-.set_body_typed(VerifySSA);
+TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA);
} // namespace transform
/*!
* \file buffer.cc
*/
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
-#include <tvm/runtime/device_api.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/arith/analyzer.h>
#include <iterator>
#include <stack>
+
#include "../../arith/compute_expr.h"
namespace tvm {
return array;
}
-Buffer decl_buffer(Array<PrimExpr> shape,
- DataType dtype,
- std::string name) {
- return BufferNode::make(
- Var(name, PointerType(PrimType(dtype))),
- dtype,
- shape,
- Array<PrimExpr>(),
- PrimExpr(),
- name,
- "",
- 0, 0,
- kDefault);
+Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, std::string name) {
+ return BufferNode::make(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array<PrimExpr>(),
+ PrimExpr(), name, "", 0, 0, kDefault);
}
// Split the given expression w.r.t the add operator
-inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr &expr) {
+inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
using namespace tir;
std::vector<const PrimExpr*> ret;
std::stack<const PrimExpr*> split_buffer;
return ret;
}
-
// Searches for the following types of expr:
// mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
// mod_l_expr = c
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c)
// Currently the we will not search the add/mult combinations exhaustively
// as it will take too much computation.
-inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr &mult_expr,
- const PrimExpr &mod_l_expr,
- const PrimExpr &mod_r_expr) {
+inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr& mult_expr,
+ const PrimExpr& mod_l_expr,
+ const PrimExpr& mod_r_expr) {
using namespace tir;
const MulNode* mult_ptr = mult_expr.as<MulNode>();
if (!mult_ptr) return std::make_pair(false, PrimExpr());
return std::make_pair(false, PrimExpr());
} else if (inner_div_ptr) {
PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
- if (expr_equal(overall_mult, inner_div_ptr->b)
- && expr_equal(overall_mult, mod_r_expr)
- && expr_equal(inner_div_ptr->a, mod_l_expr)) {
+ if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) &&
+ expr_equal(inner_div_ptr->a, mod_l_expr)) {
// Found!
PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
return std::make_pair(true, ret);
inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles,
std::list<PrimExpr>* mult_exprs,
std::list<std::pair<PrimExpr, PrimExpr> >* mod_exprs,
- PrimExpr* no_opt_sum,
- bool* has_mult,
- bool* has_mod) {
+ PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) {
using namespace tir;
*has_mult = false;
*has_mod = false;
// The search will be performed repeatively until no pattern is found.
// Return: a pair with (false, Expr()) if cannot be optimized.
// a pair with (true, optimized_expr) if can be optimized
-inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) {
+inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
using namespace tir;
// 1. Prepare the lists.
// We store two lists, a list that contain all the elements that match Mul and
PrimExpr no_opt_sum;
bool has_mult;
bool has_mod;
- MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs,
- &no_opt_sum, &has_mult, &has_mod);
+ MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod);
bool find_opt = false;
std::list<std::pair<PrimExpr, PrimExpr> >::iterator search_mod_it = mod_exprs.begin();
// 2. Exhaustive Search
std::list<PrimExpr>::iterator mult_it = mult_exprs.begin();
bool inner_find_opt = false;
while (mult_it != mult_exprs.end()) {
- std::pair<bool, PrimExpr> ret = MergeMulModInner(*mult_it,
- search_mod_it->first,
- search_mod_it->second);
+ std::pair<bool, PrimExpr> ret =
+ MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second);
if (ret.first) {
inner_find_opt = true;
auto temp_mod_it = search_mod_it;
mod_exprs.erase(temp_mod_it);
mult_exprs.erase(mult_it);
std::vector<const PrimExpr*> ret_eles = ExprSplitAddition(ret.second);
- MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs,
- &no_opt_sum, &has_mult, &has_mod);
+ MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult,
+ &has_mod);
if (has_mult) {
search_mod_it = mod_exprs.begin();
} else if (has_mod && search_mod_it == mod_exprs.end()) {
no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
}
for (std::list<std::pair<PrimExpr, PrimExpr> >::iterator it = mod_exprs.begin();
- it != mod_exprs.end(); ++it) {
- no_opt_sum = no_opt_sum.get() ?
- no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second);
+ it != mod_exprs.end(); ++it) {
+ no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second)
+ : indexmod(it->first, it->second);
}
return no_opt_sum;
}
PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType dtype) const {
// specially handle bool, stored asDataType::Int(8)
const BufferNode* n = operator->();
- CHECK(dtype.element_of() == n->dtype.element_of() &&
- dtype.lanes() % n->dtype.lanes() == 0)
- << "Cannot load " << dtype
- << " from buffer of " << n->dtype;
+ CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0)
+ << "Cannot load " << dtype << " from buffer of " << n->dtype;
if (dtype == DataType::Bool()) {
return tir::CastNode::make(
DataType::Bool(),
- tir::LoadNode::make(
- DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
- const_true()));
+ tir::LoadNode::make(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
+ const_true()));
} else {
- return tir::LoadNode::make(
- dtype, n->data, BufferOffset(n, begin, dtype),
- const_true(dtype.lanes()));
+ return tir::LoadNode::make(dtype, n->data, BufferOffset(n, begin, dtype),
+ const_true(dtype.lanes()));
}
}
// specially handle bool, stored asDataType::Int(8)
const BufferNode* n = operator->();
DataType dtype = value.dtype();
- CHECK(dtype.element_of() == n->dtype.element_of() &&
- dtype.lanes() % n->dtype.lanes() == 0)
- << "Cannot load " << dtype
- << " from buffer of " << n->dtype;
+ CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0)
+ << "Cannot load " << dtype << " from buffer of " << n->dtype;
if (value.dtype() == DataType::Bool()) {
- return tir::StoreNode::make(n->data,
- tir::CastNode::make(DataType::Int(8), value),
- BufferOffset(n, begin, DataType::Int(8)),
- const_true());
+ return tir::StoreNode::make(n->data, tir::CastNode::make(DataType::Int(8), value),
+ BufferOffset(n, begin, DataType::Int(8)), const_true());
} else {
return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype),
- const_true(dtype.lanes()));
+ const_true(dtype.lanes()));
}
}
std::vector<PrimExpr> temp;
auto n = make_object<BufferNode>(*operator->());
PrimExpr acc = make_const(n->DefaultIndexType(), 1);
- for (size_t i = n->shape.size(); i != 0 ; --i) {
+ for (size_t i = n->shape.size(); i != 0; --i) {
temp.push_back(acc);
acc = acc * n->shape[i - 1];
}
// check if stride is needed.
for (size_t i = 0; i < extents.size(); ++i) {
if (!can_relax) {
- if (!is_zero(begins[i]) ||
- !is_zero(ana.Simplify(extents[i] - n->shape[i]))) {
+ if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) {
need_stride = true;
}
}
return MakeStrideView().MakeSlice(begins, extents);
}
}
- return BufferNode::make(n->data,
- n->dtype,
- extents,
- strides,
- elem_offset,
- n->name + "_slice",
- n->scope,
- n->data_alignment,
- 0,
- n->buffer_type);
+ return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice",
+ n->scope, n->data_alignment, 0, n->buffer_type);
}
-PrimExpr Buffer::access_ptr(int access_mask,
- DataType ptr_type,
- int content_lanes,
+PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
PrimExpr offset) const {
const BufferNode* self = operator->();
PrimExpr e_dtype;
if (content_lanes > 1) {
e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.dtype(), content_lanes);
- elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(),
- content_lanes);
+ elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes);
} else {
e_dtype = tir::TypeAnnotation(self->dtype);
}
- Array<PrimExpr> acc_args{
- e_dtype, self->data, elem_offset,
- extent, make_const(DataType::Int(32), access_mask)};
- return tir::CallNode::make(
- ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic);
+ Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
+ make_const(DataType::Int(32), access_mask)};
+ return tir::CallNode::make(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args,
+ tir::CallNode::Intrinsic);
}
-Buffer BufferNode::make(Var data,
- DataType dtype,
- Array<PrimExpr> shape,
- Array<PrimExpr> strides,
- PrimExpr elem_offset,
- std::string name,
- std::string scope,
- int data_alignment,
- int offset_factor,
- BufferType buffer_type) {
+Buffer BufferNode::make(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
+ PrimExpr elem_offset, std::string name, std::string scope,
+ int data_alignment, int offset_factor, BufferType buffer_type) {
auto n = make_object<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferNode*>(node.get());
- p->stream << "buffer(" << op->name << ", " << op << ")";
-});
+ .set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferNode*>(node.get());
+ p->stream << "buffer(" << op->name << ", " << op << ")";
+ });
TVM_REGISTER_NODE_TYPE(BufferNode);
+TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_EQ(args.size(), 10);
+ auto buffer_type = args[9].operator std::string();
+ BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
+ *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7],
+ args[8], type);
+});
-TVM_REGISTER_GLOBAL("tir.Buffer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args.size(), 10);
- auto buffer_type = args[9].operator std::string();
- BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
- *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
- args[5], args[6], args[7], args[8], type);
- });
-
-TVM_REGISTER_GLOBAL("tir.BufferAccessPtr")
-.set_body_method(&Buffer::access_ptr);
+TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr);
-TVM_REGISTER_GLOBAL("tir.BufferVLoad")
-.set_body_method(&Buffer::vload);
+TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload);
-TVM_REGISTER_GLOBAL("tir.BufferVStore")
-.set_body_method(&Buffer::vstore);
+TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore);
} // namespace tir
} // namespace tvm
* \file src/lang/data_layout.cc
* \brief Data Layout expression.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/arith/analyzer.h>
#include <cctype>
namespace tvm {
namespace tir {
-using tir::Var;
using tir::IterVar;
using tir::IterVarNode;
+using tir::Var;
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode);
const LayoutAxis LayoutAxis::UPPER_CASE[] = {
- LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
- LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
- LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
- LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
- LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
- LayoutAxis('Z')
-};
+ LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
+ LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
+ LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
+ LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
+ LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
+ LayoutAxis('Z')};
const LayoutAxis LayoutAxis::LOWER_CASE[] = {
- LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
- LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
- LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
- LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
- LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
- LayoutAxis('z')
-};
+ LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
+ LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
+ LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
+ LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
+ LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
+ LayoutAxis('z')};
const LayoutAxis& LayoutAxis::Get(const char name) {
CHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z'))
- << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z.";
- return (name >= 'A' && name <= 'Z') ?
- LayoutAxis::UPPER_CASE[name-'A'] :
- LayoutAxis::LOWER_CASE[name-'a'];
+ << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z.";
+ return (name >= 'A' && name <= 'Z') ? LayoutAxis::UPPER_CASE[name - 'A']
+ : LayoutAxis::LOWER_CASE[name - 'a'];
}
const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
CHECK_GT(factor->value, 0);
repr << factor->value;
}
- CHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis "
- << axis->var.get()->name_hint;
+ CHECK_EQ(axis->var.get()->name_hint.size(), 1)
+ << "Invalid layout axis " << axis->var.get()->name_hint;
char c = axis->var.get()->name_hint[0];
CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c;
repr << axis->var.get()->name_hint;
data_ = std::move(node);
}
-Layout::Layout(const std::string& name) { // NOLINT(*)
+Layout::Layout(const std::string& name) { // NOLINT(*)
if (name == "__undef__") return;
auto node = make_object<LayoutNode>();
int32_t factor = 0;
for (char c : name) {
if (c >= 'A' && c <= 'Z') {
- CHECK_EQ(factor, 0) << "Invalid layout " << name
- << ": invalid factor size " << factor
+ CHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
- IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)),
- Var(std::string(1, c)), tir::kDataPar);
+ IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)),
+ tir::kDataPar);
node->axes.push_back(axis);
} else if (c >= 'a' && c <= 'z') {
- CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
- << factor << " for dimension " << c;
- IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)),
- Var(std::string(1, c)), tir::kDataPar);
+ CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
+ << " for dimension " << c;
+ IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)),
+ tir::kDataPar);
node->axes.push_back(axis);
factor = 0;
} else if (c >= '0' && c <= '9') {
for (const IterVar& v : node->axes) {
char axis = v->var.get()->name_hint[0];
if (axis >= 'a' && axis <= 'z') {
- CHECK(exist_axis[axis-'a'+'A']) << "Invalid layout " << name << ": missing axis "
- << std::toupper(axis);
+ CHECK(exist_axis[axis - 'a' + 'A'])
+ << "Invalid layout " << name << ": missing axis " << std::toupper(axis);
}
}
data_ = std::move(node);
}
-Layout LayoutNode::make(const std::string& layout) {
- return Layout(layout);
-}
+Layout LayoutNode::make(const std::string& layout) { return Layout(layout); }
Layout Layout::SubLayout(size_t pos, size_t len) const {
if (!defined() || pos > ndim()) return Layout::Undef();
return Layout(new_layout);
}
-Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const {
+Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const {
if (!defined()) return Layout::Undef();
const std::string& name = operator->()->name;
const auto axes = operator->()->axes;
- CHECK(target_pos <= this->ndim()) << "Invalid split position "
- << target_pos << " for layout " << name;
+ CHECK(target_pos <= this->ndim())
+ << "Invalid split position " << target_pos << " for layout " << name;
CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis;
CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name;
- CHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis
- << " has already been split in " << name;
+ CHECK(!this->Contains(axis.ToSubordinate()))
+ << "Axis " << axis << " has already been split in " << name;
CHECK(factor > 0) << "Invalid split size " << factor;
Array<IterVar> new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<LayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* l = static_cast<const LayoutNode*>(node.get());
- p->stream << "Layout(" << l->name << ")";
- });
+ .set_dispatch<LayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* l = static_cast<const LayoutNode*>(node.get());
+ p->stream << "Layout(" << l->name << ")";
+ });
-inline bool GetStoreRule(Array<PrimExpr>* rule,
- const Layout& src_layout,
+inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
const Layout& dst_layout) {
- if (!src_layout.defined() || src_layout.name().empty() ||
- !dst_layout.defined() || dst_layout.name().empty()) {
+ if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() ||
+ dst_layout.name().empty()) {
return false;
}
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(src_index.size(), self->src_layout->axes.size())
- << "Input mismatch with layout " << self->src_layout;
+ << "Input mismatch with layout " << self->src_layout;
return TransformIndex(src_index, self->src_layout->axes, self->forward_rule);
}
-
Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
- << "Output mismatch with layout " << self->dst_layout;
+ << "Output mismatch with layout " << self->dst_layout;
return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule);
}
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
if (orig_shape_const) {
CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
- << "Input shape mismatch at index " << i << ". Expected "
- << orig_axis->dom->extent << ", get " << orig_shape;
+ << "Input shape mismatch at index " << i << ". Expected " << orig_axis->dom->extent
+ << ", get " << orig_shape;
}
}
bind_map[orig_axis->var.get()] = PrimExpr(0);
Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
- return TransformShape(shape, self->src_layout->axes,
- self->dst_layout->axes, self->forward_rule);
+ return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule);
}
Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
- return TransformShape(shape, self->dst_layout->axes,
- self->src_layout->axes, self->backward_rule);
+ return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->backward_rule);
}
BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
- p->stream << "BijectiveLayout(" << b->src_layout.name()
- << "->" << b->dst_layout.name() << ")";
- });
+ .set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
+ p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name()
+ << ")";
+ });
-TVM_REGISTER_GLOBAL("tir.Layout")
-.set_body_typed(LayoutNode::make);
+TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make);
-TVM_REGISTER_GLOBAL("tir.LayoutIndexOf")
-.set_body_typed([](Layout layout, std::string axis) -> int {
+TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
-.set_body_typed([](Layout layout, std::string axis) -> int {
- return layout.FactorOf(LayoutAxis::make(axis));
-});
+ .set_body_typed([](Layout layout, std::string axis) -> int {
+ return layout.FactorOf(LayoutAxis::make(axis));
+ });
-TVM_REGISTER_GLOBAL("tir.LayoutNdim")
-.set_body_typed([](Layout layout) -> int {
+TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int {
return layout.ndim();
});
-TVM_REGISTER_GLOBAL("tir.LayoutGetItem")
-.set_body_typed([](Layout layout, int idx) -> std::string {
+TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
-.set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
- return BijectiveLayout(src_layout, dst_layout);
-});
+ .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
+ return BijectiveLayout(src_layout, dst_layout);
+ });
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
-.set_body_method(&BijectiveLayout::ForwardIndex);
+ .set_body_method(&BijectiveLayout::ForwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
-.set_body_method(&BijectiveLayout::BackwardIndex);
+ .set_body_method(&BijectiveLayout::BackwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
-.set_body_method(&BijectiveLayout::ForwardShape);
+ .set_body_method(&BijectiveLayout::ForwardShape);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
-.set_body_method(&BijectiveLayout::BackwardShape);
+ .set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir
} // namespace tvm
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <memory>
+
#include <limits>
+#include <memory>
#include "../../support/str_escape.h"
data_ = std::move(n);
}
-
-TVM_REGISTER_GLOBAL("tir.Var")
-.set_body_typed([](std::string name_hint, runtime::TVMArgValue type) {
+TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](std::string name_hint, runtime::TVMArgValue type) {
if (type.IsObjectRef<Type>()) {
return Var(name_hint, type.operator Type());
} else {
}
});
-TVM_REGISTER_GLOBAL("tir.SizeVar")
-.set_body_typed([](std::string s, DataType t) {
- return SizeVar(s, t);
+TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](std::string s, DataType t) {
+ return SizeVar(s, t);
});
-
-IterVar IterVarNode::make(Range dom,
- Var var,
- IterVarType t,
- std::string thread_tag) {
+IterVar IterVarNode::make(Range dom, Var var, IterVarType t, std::string thread_tag) {
ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
n->dom = dom;
n->var = var;
}
TVM_REGISTER_GLOBAL("tir.IterVar")
-.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
- return IterVarNode::make(
- dom, var,
- static_cast<IterVarType>(iter_type),
- thread_tag);
-});
+ .set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
+ return IterVarNode::make(dom, var, static_cast<IterVarType>(iter_type), thread_tag);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IterVarNode*>(node.get());
- p->stream << "iter_var(";
- if (op->var->name_hint.length() != 0) {
- p->stream << op->var->name_hint << ", ";
- }
- if (op->dom.defined()) {
- p->stream << op->dom;
- }
- if (op->thread_tag.length() != 0) {
- p->stream << ", " << op->thread_tag;
- }
- p->stream << ")";
- });
-
+ .set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IterVarNode*>(node.get());
+ p->stream << "iter_var(";
+ if (op->var->name_hint.length() != 0) {
+ p->stream << op->var->name_hint << ", ";
+ }
+ if (op->dom.defined()) {
+ p->stream << op->dom;
+ }
+ if (op->thread_tag.length() != 0) {
+ p->stream << ", " << op->thread_tag;
+ }
+ p->stream << ")";
+ });
TVM_REGISTER_NODE_TYPE(IterVarNode);
return PrimExpr(node);
}
-TVM_REGISTER_GLOBAL("tir.StringImm")
-.set_body_typed(StringImmNode::make);
-
+TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed(StringImmNode::make);
PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined());
return PrimExpr(node);
}
-
PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
return PrimExpr(node);
}
-
PrimExpr NotNode::make(PrimExpr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool());
return PrimExpr(node);
}
-
-
PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
CHECK(false_value.defined()) << "ValueError: true_value is undefined";
CHECK(condition.dtype().is_bool());
- CHECK(condition.dtype().lanes() == true_value.dtype().lanes() ||
- condition.dtype().lanes() == 1);
+ CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1);
CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types";
ObjectPtr<SelectNode> node = make_object<SelectNode>();
return PrimExpr(node);
}
-const char* CallNode::vectorizable_intrinsics[] = {
- "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
- "log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, tir::CallNode::shift_right,
- tir::CallNode::likely, tir::CallNode::popcount
-};
+const char* CallNode::vectorizable_intrinsics[] = {"floor",
+ "ceil",
+ "sign",
+ "trunc",
+ "fabs",
+ "round",
+ "exp",
+ "tanh",
+ "sqrt",
+ "log",
+ "sin",
+ "cos",
+ "pow",
+ "tan",
+ tir::CallNode::shift_left,
+ tir::CallNode::shift_right,
+ tir::CallNode::likely,
+ tir::CallNode::popcount};
bool CallNode::is_vectorizable() const {
size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*);
return false;
}
-PrimExpr CallNode::make(DataType dtype,
- std::string name,
- Array<PrimExpr> args,
- CallType call_type,
- FunctionRef func,
- int value_index) {
+PrimExpr CallNode::make(DataType dtype, std::string name, Array<PrimExpr> args, CallType call_type,
+ FunctionRef func, int value_index) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined());
}
return PrimExpr(node);
}
-PrimExpr ShuffleNode::make(Array<PrimExpr> vectors,
- Array<PrimExpr> indices) {
+PrimExpr ShuffleNode::make(Array<PrimExpr> vectors, Array<PrimExpr> indices) {
CHECK_NE(vectors.size(), 0U);
CHECK_NE(indices.size(), 0U);
return make({vector}, {Integer(index)});
}
-CommReducer CommReducerNode::make(Array<Var> lhs,
- Array<Var> rhs,
- Array<PrimExpr> result,
+CommReducer CommReducerNode::make(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
Array<PrimExpr> identity_element) {
auto node = make_object<CommReducerNode>();
node->lhs = lhs;
value_map.Set(rhs[i], b[i]);
}
auto ret = this->result;
- ret.MutateByApply([&value_map] (const PrimExpr& e) {
- return Substitute(e, value_map);
- });
+ ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); });
return ret;
}
-TVM_REGISTER_GLOBAL("tir.CommReducer")
-.set_body_typed(CommReducerNode::make);
+TVM_REGISTER_GLOBAL("tir.CommReducer").set_body_typed(CommReducerNode::make);
TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
-.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
-
+ .set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
-PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
- Array<IterVar> axis, PrimExpr condition, int value_index) {
+PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
+ PrimExpr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
- CHECK_EQ(axis[i]->iter_type, kCommReduce)
- << "Can only take axis created by reduce_axis";
+ CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis";
}
if (!condition.defined()) {
condition = const_true();
return PrimExpr(n);
}
-
-TVM_REGISTER_GLOBAL("tir.Reduce")
-.set_body_typed(ReduceNode::make);
-
+TVM_REGISTER_GLOBAL("tir.Reduce").set_body_typed(ReduceNode::make);
PrimExpr AnyNode::make() {
auto n = make_object<AnyNode>();
data_ = std::move(node);
}
-TVM_REGISTER_GLOBAL("tir.BufferLoad")
-.set_body_typed([](Buffer buffer, Array<PrimExpr> indices) {
+TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array<PrimExpr> indices) {
return BufferLoad(buffer, indices);
});
TVM_REGISTER_NODE_TYPE(BufferLoadNode);
-
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const StringImmNode*>(node.get());
- p->stream << '\"' << support::StrEscape(op->value) << '\"';
-});
+ .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const StringImmNode*>(node.get());
+ p->stream << '\"' << support::StrEscape(op->value) << '\"';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const CastNode*>(node.get());
- p->stream << op->dtype << '(';
- p->Print(op->value);
- p->stream << ')';
- })
-.set_dispatch<VarNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const VarNode*>(node.get());
- // omit the type
- // stream << op->name << "." << op->type;
- p->stream << op->name_hint;
- })
-.set_dispatch<SizeVarNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SizeVarNode*>(node.get());
- p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
- })
-.set_dispatch<AddNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AddNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " + ";
- p->Print(op->b);
- p->stream << ')';
- })
-.set_dispatch<SubNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SubNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " - ";
- p->Print(op->b);
- p->stream << ')';
- })
-.set_dispatch<MulNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MulNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << "*";
- p->Print(op->b);
- p->stream << ')';
- })
-.set_dispatch<DivNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const DivNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << "/";
- p->Print(op->b);
- p->stream << ')';
- })
-.set_dispatch<ModNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ModNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " % ";
- p->Print(op->b);
- p->stream << ')';
-})
-.set_dispatch<MinNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MinNode*>(node.get());
- p->stream << "min(";
- p->Print(op->a);
- p->stream << ", ";
- p->Print(op->b);
- p->stream << ")";
-})
-.set_dispatch<MaxNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MaxNode*>(node.get());
- p->stream << "max(";
- p->Print(op->a);
- p->stream << ", ";
- p->Print(op->b);
- p->stream << ")";
-})
-.set_dispatch<EQNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const EQNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " == ";
- p->Print(op->b);
- p->stream << ')';
-})
-.set_dispatch<NENode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const NENode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " != ";
- p->Print(op->b);
- p->stream << ')';
-})
-.set_dispatch<LTNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LTNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " < ";
- p->Print(op->b);
- p->stream << ')';
-})
-.set_dispatch<LENode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LENode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " <= ";
- p->Print(op->b);
- p->stream << ')';
-})
-.set_dispatch<GTNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const GTNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " > ";
- p->Print(op->b);
- p->stream << ')';
-})
-.set_dispatch<GENode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const GENode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " >= ";
- p->Print(op->b);
- p->stream << ')';
-});
+ .set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const CastNode*>(node.get());
+ p->stream << op->dtype << '(';
+ p->Print(op->value);
+ p->stream << ')';
+ })
+ .set_dispatch<VarNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const VarNode*>(node.get());
+ // omit the type
+ // stream << op->name << "." << op->type;
+ p->stream << op->name_hint;
+ })
+ .set_dispatch<SizeVarNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const SizeVarNode*>(node.get());
+ p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
+ })
+ .set_dispatch<AddNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const AddNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " + ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<SubNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const SubNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " - ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<MulNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const MulNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << "*";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<DivNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const DivNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << "/";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<ModNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ModNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " % ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<MinNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const MinNode*>(node.get());
+ p->stream << "min(";
+ p->Print(op->a);
+ p->stream << ", ";
+ p->Print(op->b);
+ p->stream << ")";
+ })
+ .set_dispatch<MaxNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const MaxNode*>(node.get());
+ p->stream << "max(";
+ p->Print(op->a);
+ p->stream << ", ";
+ p->Print(op->b);
+ p->stream << ")";
+ })
+ .set_dispatch<EQNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const EQNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " == ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<NENode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const NENode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " != ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<LTNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const LTNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " < ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<LENode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const LENode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " <= ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<GTNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const GTNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " > ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+ .set_dispatch<GENode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const GENode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " >= ";
+ p->Print(op->b);
+ p->stream << ')';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FloorDivNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FloorDivNode*>(node.get());
- p->stream << "floordiv(" << op->a << ", " << op->b << ")";
-});
+ .set_dispatch<FloorDivNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const FloorDivNode*>(node.get());
+ p->stream << "floordiv(" << op->a << ", " << op->b << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FloorModNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FloorModNode*>(node.get());
- p->stream << "floormod(" << op->a << ", " << op->b << ")";
-});
+ .set_dispatch<FloorModNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const FloorModNode*>(node.get());
+ p->stream << "floormod(" << op->a << ", " << op->b << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<AndNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AndNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " && ";
- p->Print(op->b);
- p->stream << ')';
-});
+ .set_dispatch<AndNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const AndNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " && ";
+ p->Print(op->b);
+ p->stream << ')';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<OrNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const OrNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " || ";
- p->Print(op->b);
- p->stream << ')';
-});
+ .set_dispatch<OrNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const OrNode*>(node.get());
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " || ";
+ p->Print(op->b);
+ p->stream << ')';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<NotNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const NotNode*>(node.get());
- p->stream << '!';
- p->Print(op->a);
-});
+ .set_dispatch<NotNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const NotNode*>(node.get());
+ p->stream << '!';
+ p->Print(op->a);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<SelectNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SelectNode*>(node.get());
- p->stream << "select(";
- p->Print(op->condition);
- p->stream << ", ";
- p->Print(op->true_value);
- p->stream << ", ";
- p->Print(op->false_value);
- p->stream << ")";
-});
+ .set_dispatch<SelectNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const SelectNode*>(node.get());
+ p->stream << "select(";
+ p->Print(op->condition);
+ p->stream << ", ";
+ p->Print(op->true_value);
+ p->stream << ", ";
+ p->Print(op->false_value);
+ p->stream << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<LoadNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LoadNode*>(node.get());
- p->stream << op->buffer_var << "[";
- p->Print(op->index);
- p->stream << "]";
- if (!is_one(op->predicate)) {
+ .set_dispatch<LoadNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const LoadNode*>(node.get());
+ p->stream << op->buffer_var << "[";
+ p->Print(op->index);
+ p->stream << "]";
+ if (!is_one(op->predicate)) {
p->stream << " if ";
p->Print(op->predicate);
- }
-});
+ }
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RampNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const RampNode*>(node.get());
- p->stream << "ramp(";
- p->Print(op->base);
- p->stream << ", ";
- p->Print(op->stride);
- p->stream << ", " << op->lanes << ")";
-});
+ .set_dispatch<RampNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const RampNode*>(node.get());
+ p->stream << "ramp(";
+ p->Print(op->base);
+ p->stream << ", ";
+ p->Print(op->stride);
+ p->stream << ", " << op->lanes << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<BroadcastNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BroadcastNode*>(node.get());
- p->stream << "x" << op->lanes << "(";
- p->Print(op->value);
- p->stream << ")";
-});
+ .set_dispatch<BroadcastNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BroadcastNode*>(node.get());
+ p->stream << "x" << op->lanes << "(";
+ p->Print(op->value);
+ p->stream << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<CallNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const CallNode*>(node.get());
- p->stream << op->name << "(";
- for (size_t i = 0; i < op->args.size(); ++i) {
- p->Print(op->args[i]);
- if (i < op->args.size() - 1) {
- p->stream << ", ";
+ .set_dispatch<CallNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const CallNode*>(node.get());
+ p->stream << op->name << "(";
+ for (size_t i = 0; i < op->args.size(); ++i) {
+ p->Print(op->args[i]);
+ if (i < op->args.size() - 1) {
+ p->stream << ", ";
+ }
}
- }
- p->stream << ")";
- });
+ p->stream << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferLoadNode*>(node.get());
- p->stream << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) {
- p->stream << ", ";
+ .set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferLoadNode*>(node.get());
+ p->stream << op->buffer->name << "[";
+ for (size_t i = 0; i < op->indices.size(); ++i) {
+ p->Print(op->indices[i]);
+ if (i < op->indices.size() - 1) {
+ p->stream << ", ";
+ }
}
- }
- p->stream << "]";
- });
+ p->stream << "]";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LetNode*>(node.get());
- p->stream << "(let " << op->var << " = ";
- p->Print(op->value);
- p->stream << " in ";
- p->Print(op->body);
- p->stream << ")";
-});
+ .set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const LetNode*>(node.get());
+ p->stream << "(let " << op->var << " = ";
+ p->Print(op->value);
+ p->stream << " in ";
+ p->Print(op->body);
+ p->stream << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<AnyNode>([](const ObjectRef& node, ReprPrinter* p) {
- p->stream << "?";
-});
+ .set_dispatch<AnyNode>([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ReduceNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ReduceNode*>(node.get());
- p->stream << "reduce(combiner="
- << op->combiner;
- p->stream << ", source=" << op->source;
- p->stream << ", axis=" << op->axis;
- p->stream << ", where=" << op->condition;
- p->stream << ", value_index=" << op->value_index;
- p->stream << ")";
- });
+ .set_dispatch<ReduceNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ReduceNode*>(node.get());
+ p->stream << "reduce(combiner=" << op->combiner;
+ p->stream << ", source=" << op->source;
+ p->stream << ", axis=" << op->axis;
+ p->stream << ", where=" << op->condition;
+ p->stream << ", value_index=" << op->value_index;
+ p->stream << ")";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<CommReducerNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const CommReducerNode*>(node.get());
- p->stream << "comm_reducer(result=" << op->result
- << ", lhs=" << op->lhs
- << ", rhs=" << op->rhs
- << ", identity_element=" << op->identity_element
- << ")";
- });
+ .set_dispatch<CommReducerNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const CommReducerNode*>(node.get());
+ p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs
+ << ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")";
+ });
TVM_REGISTER_NODE_TYPE(StringImmNode);
TVM_REGISTER_NODE_TYPE(CastNode);
TVM_REGISTER_NODE_TYPE(ReduceNode);
TVM_REGISTER_NODE_TYPE(AnyNode);
+TVM_REGISTER_GLOBAL("tir.Add").set_body_typed(AddNode::make);
-TVM_REGISTER_GLOBAL("tir.Add")
-.set_body_typed(AddNode::make);
+TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed(SubNode::make);
-TVM_REGISTER_GLOBAL("tir.Sub")
-.set_body_typed(SubNode::make);
+TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed(MulNode::make);
-TVM_REGISTER_GLOBAL("tir.Mul")
-.set_body_typed(MulNode::make);
+TVM_REGISTER_GLOBAL("tir.Div").set_body_typed(DivNode::make);
-TVM_REGISTER_GLOBAL("tir.Div")
-.set_body_typed(DivNode::make);
+TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed(ModNode::make);
-TVM_REGISTER_GLOBAL("tir.Mod")
-.set_body_typed(ModNode::make);
+TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed(FloorDivNode::make);
-TVM_REGISTER_GLOBAL("tir.FloorDiv")
-.set_body_typed(FloorDivNode::make);
+TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed(FloorModNode::make);
-TVM_REGISTER_GLOBAL("tir.FloorMod")
-.set_body_typed(FloorModNode::make);
+TVM_REGISTER_GLOBAL("tir.Min").set_body_typed(MinNode::make);
-TVM_REGISTER_GLOBAL("tir.Min")
-.set_body_typed(MinNode::make);
+TVM_REGISTER_GLOBAL("tir.Max").set_body_typed(MaxNode::make);
-TVM_REGISTER_GLOBAL("tir.Max")
-.set_body_typed(MaxNode::make);
+TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed(EQNode::make);
-TVM_REGISTER_GLOBAL("tir.EQ")
-.set_body_typed(EQNode::make);
+TVM_REGISTER_GLOBAL("tir.NE").set_body_typed(NENode::make);
-TVM_REGISTER_GLOBAL("tir.NE")
-.set_body_typed(NENode::make);
+TVM_REGISTER_GLOBAL("tir.LT").set_body_typed(LTNode::make);
-TVM_REGISTER_GLOBAL("tir.LT")
-.set_body_typed(LTNode::make);
+TVM_REGISTER_GLOBAL("tir.LE").set_body_typed(LENode::make);
-TVM_REGISTER_GLOBAL("tir.LE")
-.set_body_typed(LENode::make);
+TVM_REGISTER_GLOBAL("tir.GT").set_body_typed(GTNode::make);
-TVM_REGISTER_GLOBAL("tir.GT")
-.set_body_typed(GTNode::make);
+TVM_REGISTER_GLOBAL("tir.GE").set_body_typed(GENode::make);
-TVM_REGISTER_GLOBAL("tir.GE")
-.set_body_typed(GENode::make);
+TVM_REGISTER_GLOBAL("tir.And").set_body_typed(AndNode::make);
-TVM_REGISTER_GLOBAL("tir.And")
-.set_body_typed(AndNode::make);
+TVM_REGISTER_GLOBAL("tir.Or").set_body_typed(OrNode::make);
-TVM_REGISTER_GLOBAL("tir.Or")
-.set_body_typed(OrNode::make);
+TVM_REGISTER_GLOBAL("tir.Not").set_body_typed(NotNode::make);
-TVM_REGISTER_GLOBAL("tir.Not")
-.set_body_typed(NotNode::make);
+TVM_REGISTER_GLOBAL("tir.Select").set_body_typed(SelectNode::make);
-TVM_REGISTER_GLOBAL("tir.Select")
-.set_body_typed(SelectNode::make);
+TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed(RampNode::make);
-TVM_REGISTER_GLOBAL("tir.Ramp")
-.set_body_typed(RampNode::make);
+TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed(CastNode::make);
-TVM_REGISTER_GLOBAL("tir.Cast")
-.set_body_typed(CastNode::make);
+TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed(BroadcastNode::make);
-TVM_REGISTER_GLOBAL("tir.Broadcast")
-.set_body_typed(BroadcastNode::make);
+TVM_REGISTER_GLOBAL("tir.Shuffle").set_body_typed(ShuffleNode::make);
-TVM_REGISTER_GLOBAL("tir.Shuffle")
-.set_body_typed(ShuffleNode::make);
-
-TVM_REGISTER_GLOBAL("tir.Let")
-.set_body_typed(LetNode::make);
-
-TVM_REGISTER_GLOBAL("tir.Load")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DataType t = args[0];
- if (args.size() == 3) {
- *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
- } else {
- *ret = LoadNode::make(t, args[1], args[2], args[3]);
- }
- });
+TVM_REGISTER_GLOBAL("tir.Let").set_body_typed(LetNode::make);
-TVM_REGISTER_GLOBAL("tir.Call")
-.set_body_typed([](
- DataType type, std::string name,
- Array<ObjectRef> args, int call_type,
- FunctionRef func, int value_index
-) {
- Array<PrimExpr> prim_expr_args;
- for (const auto& it : args) {
- CHECK(it->IsInstance<runtime::StringObj>() ||
- it->IsInstance<PrimExprNode>());
- if (const auto* str = it.as<runtime::StringObj>()) {
- prim_expr_args.push_back(StringImmNode::make(str->data));
- } else {
- prim_expr_args.push_back(Downcast<PrimExpr>(it));
- }
+TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) {
+ DataType t = args[0];
+ if (args.size() == 3) {
+ *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
+ } else {
+ *ret = LoadNode::make(t, args[1], args[2], args[3]);
}
- return CallNode::make(type,
- name,
- prim_expr_args,
- static_cast<CallNode::CallType>(call_type),
- func,
- value_index);
});
+TVM_REGISTER_GLOBAL("tir.Call")
+ .set_body_typed([](DataType type, std::string name, Array<ObjectRef> args, int call_type,
+ FunctionRef func, int value_index) {
+ Array<PrimExpr> prim_expr_args;
+ for (const auto& it : args) {
+ CHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>());
+ if (const auto* str = it.as<runtime::StringObj>()) {
+ prim_expr_args.push_back(StringImmNode::make(str->data));
+ } else {
+ prim_expr_args.push_back(Downcast<PrimExpr>(it));
+ }
+ }
+ return CallNode::make(type, name, prim_expr_args, static_cast<CallNode::CallType>(call_type),
+ func, value_index);
+ });
+
} // namespace tir
} // namespace tvm
* \file expr_functor.cc
*/
#include <tvm/tir/expr_functor.h>
+
#include "functor_common.h"
namespace tvm {
VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
-#define DEFINE_BINOP_VISIT_(OP) \
- void ExprVisitor::VisitExpr_(const OP* op) { \
- this->VisitExpr(op->a); \
- this->VisitExpr(op->b); \
+#define DEFINE_BINOP_VISIT_(OP) \
+ void ExprVisitor::VisitExpr_(const OP* op) { \
+ this->VisitExpr(op->a); \
+ this->VisitExpr(op->b); \
}
DEFINE_BINOP_VISIT_(AddNode);
void ExprVisitor::VisitExpr_(const ReduceNode* op) {
VisitArray(op->axis, [this](const IterVar& r) {
- this->VisitExpr(r->dom->min);
- this->VisitExpr(r->dom->extent);
- });
+ this->VisitExpr(r->dom->min);
+ this->VisitExpr(r->dom->extent);
+ });
VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitExpr(op->condition);
}
-void ExprVisitor::VisitExpr_(const CastNode* op) {
- this->VisitExpr(op->value);
-}
+void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); }
-void ExprVisitor::VisitExpr_(const NotNode* op) {
- this->VisitExpr(op->a);
-}
+void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); }
void ExprVisitor::VisitExpr_(const SelectNode* op) {
this->VisitExpr(op->condition);
VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
-void ExprVisitor::VisitExpr_(const BroadcastNode* op) {
- this->VisitExpr(op->value);
-}
+void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); }
-PrimExpr ExprMutator::VisitExpr_(const VarNode* op) {
- return GetRef<PrimExpr>(op);
-}
+PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef<PrimExpr>(op); }
PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
return this->VisitExpr_(static_cast<const VarNode*>(op));
PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
PrimExpr body = this->VisitExpr(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
if (args.same_as(op->args)) {
return GetRef<PrimExpr>(op);
} else {
- return CallNode::make(op->dtype,
- op->name,
- args,
- op->call_type,
- op->func,
- op->value_index);
+ return CallNode::make(op->dtype, op->name, args, op->call_type, op->func, op->value_index);
}
}
-#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
- PrimExpr ExprMutator::VisitExpr_(const OP *op) { \
- return GetRef<PrimExpr>(op); \
- }
+#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
+ PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef<PrimExpr>(op); }
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
-#define DEFINE_BIOP_EXPR_MUTATE_(OP) \
- PrimExpr ExprMutator::VisitExpr_(const OP* op) { \
- PrimExpr a = this->VisitExpr(op->a); \
- PrimExpr b = this->VisitExpr(op->b); \
- if (a.same_as(op->a) && \
- b.same_as(op->b)) { \
- return GetRef<PrimExpr>(op); \
- } else { \
- return OP::make(a, b); \
- } \
+#define DEFINE_BIOP_EXPR_MUTATE_(OP) \
+ PrimExpr ExprMutator::VisitExpr_(const OP* op) { \
+ PrimExpr a = this->VisitExpr(op->a); \
+ PrimExpr b = this->VisitExpr(op->b); \
+ if (a.same_as(op->a) && b.same_as(op->b)) { \
+ return GetRef<PrimExpr>(op); \
+ } else { \
+ return OP::make(a, b); \
+ } \
}
DEFINE_BIOP_EXPR_MUTATE_(AddNode);
DEFINE_BIOP_EXPR_MUTATE_(OrNode);
PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
- auto fitervar = [this](const IterVar& v) {
+ auto fitervar = [this](const IterVar& v) {
Range r = v->dom;
PrimExpr min = this->VisitExpr(r->min);
PrimExpr extent = this->VisitExpr(r->extent);
- if (min.same_as(r->min) &&
- extent.same_as(r->extent)) {
+ if (min.same_as(r->min) && extent.same_as(r->extent)) {
return v;
} else {
- return IterVarNode::make(
- Range::make_by_min_extent(min, extent),
- v->var, v->iter_type, v->thread_tag);
+ return IterVarNode::make(Range::make_by_min_extent(min, extent), v->var, v->iter_type,
+ v->thread_tag);
}
};
Array<IterVar> axis = MutateArray(op->axis, fitervar);
PrimExpr condition = this->VisitExpr(op->condition);
- if (axis.same_as(op->axis) &&
- source.same_as(op->source) &&
- condition.same_as(op->condition)) {
+ if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition)) {
return GetRef<PrimExpr>(op);
} else {
- return ReduceNode::make(
- op->combiner, source, axis, condition, op->value_index);
+ return ReduceNode::make(op->combiner, source, axis, condition, op->value_index);
}
}
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr true_value = this->VisitExpr(op->true_value);
PrimExpr false_value = this->VisitExpr(op->false_value);
- if (condition.same_as(op->condition) &&
- true_value.same_as(op->true_value) &&
+ if (condition.same_as(op->condition) && true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
PrimExpr ExprMutator::VisitExpr_(const RampNode* op) {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
- if (base.same_as(op->base) &&
- stride.same_as(op->stride)) {
+ if (base.same_as(op->base) && stride.same_as(op->stride)) {
return GetRef<PrimExpr>(op);
} else {
return RampNode::make(base, stride, op->lanes);
namespace tir {
// Get the function type of a PrimFunc
-PrimFunc::PrimFunc(Array<tir::Var> params,
- Stmt body,
- Type ret_type,
- Map<tir::Var, Buffer> buffer_map,
- DictAttrs attrs) {
+PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
+ Map<tir::Var, Buffer> buffer_map, DictAttrs attrs) {
// Assume void-return type for now
// TODO(tvm-team) consider type deduction from body.
if (!ret_type.defined()) {
TVM_REGISTER_NODE_TYPE(PrimFuncNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
- // TODO(tvm-team) redirect to Text printer once we have a good text format.
- auto* node = static_cast<const PrimFuncNode*>(ref.get());
- p->stream << "PrimFunc(" << node->params << ") ";
- if (node->attrs.defined()) {
- p->stream << "attrs=" << node->attrs;
- }
- p->stream << " {\n";
- p->indent += 2;
- p->Print(node->body);
- p->indent -= 2;
- p->stream << "}\n";
-});
-
+ .set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ // TODO(tvm-team) redirect to Text printer once we have a good text format.
+ auto* node = static_cast<const PrimFuncNode*>(ref.get());
+ p->stream << "PrimFunc(" << node->params << ") ";
+ if (node->attrs.defined()) {
+ p->stream << "attrs=" << node->attrs;
+ }
+ p->stream << " {\n";
+ p->indent += 2;
+ p->Print(node->body);
+ p->indent -= 2;
+ p->stream << "}\n";
+ });
TVM_REGISTER_GLOBAL("tir.PrimFunc")
-.set_body_typed([](Array<tir::Var> params,
- Stmt body,
- Type ret_type,
- Map<tir::Var, Buffer> buffer_map,
- DictAttrs attrs) {
- return PrimFunc(params, body, ret_type, buffer_map, attrs);
-});
+ .set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type,
+ Map<tir::Var, Buffer> buffer_map, DictAttrs attrs) {
+ return PrimFunc(params, body, ret_type, buffer_map, attrs);
+ });
} // namespace tir
} // namespace tvm
namespace tir {
// Implementation of Visitors
-template<typename T, typename F>
+template <typename T, typename F>
inline void VisitArray(const Array<T>& arr, F fvisit) {
for (size_t i = 0; i < arr.size(); i++) {
fvisit(arr[i]);
}
// Implementation of mutators
-template<typename T, typename F>
-inline Array<T> MutateArray(const Array<T>& arr,
- F fmutate,
- bool allow_copy_on_write = false) {
+template <typename T, typename F>
+inline Array<T> MutateArray(const Array<T>& arr, F fmutate, bool allow_copy_on_write = false) {
if (allow_copy_on_write) {
// if we allow copy on write, we can directly
// call the inplace mutate function.
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
+
#include <cmath>
// Centralized header for constant folders.
#include "../../arith/const_fold.h"
using namespace tir;
-
runtime::DataType GetRuntimeDataType(const Type& type) {
- if (auto * n = type.as<PrimTypeNode>()) {
+ if (auto* n = type.as<PrimTypeNode>()) {
return n->dtype;
} else if (type.as<PointerTypeNode>()) {
return DataType::Handle();
} else if (IsVoidType(type)) {
return DataType::Void();
} else {
- LOG(FATAL) << "Type " << type
- << " does not have a corresponding runtime::DataType";
+ LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType";
return DataType::Handle();
}
}
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
return tir::CallNode::make(
t, tir::intrinsic::tvm_large_uint_imm,
- {make_const(DataType::UInt(32), low),
- make_const(DataType::UInt(32), high)},
+ {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)},
tir::CallNode::PureIntrinsic);
}
} else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
rhs = tir::BroadcastNode::make(rhs, ltype.lanes());
} else {
- CHECK(ltype.lanes() == rtype.lanes())
- << "Cannot match type " << ltype << " vs " << rtype;
+ CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype;
}
if (lhs.dtype() == rhs.dtype()) return;
// Only do very simple type coversion
}
namespace tir {
-template<typename ValueType>
-inline bool ConstPowerHelper(ValueType val, int *shift) {
+template <typename ValueType>
+inline bool ConstPowerHelper(ValueType val, int* shift) {
if (val <= 0) return false;
shift[0] = 0;
while (val != 0) {
PrimExpr reinterpret(const DataType& t, PrimExpr value) {
if (value.dtype() == t) return value;
- return tir::CallNode::make(
- t, tir::CallNode::reinterpret, { value }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic);
}
PrimExpr operator+(PrimExpr a, PrimExpr b) {
// negation
PrimExpr operator-(PrimExpr a) {
- using tir::IntImmNode;
using tir::FloatImmNode;
+ using tir::IntImmNode;
const IntImmNode* pa = a.as<IntImmNode>();
const FloatImmNode* fa = a.as<FloatImmNode>();
if (pa) return IntImm(a.dtype(), -pa->value);
return tir::ModNode::make(a, b);
}
-PrimExpr operator/(PrimExpr a, PrimExpr b) {
- return div(a, b);
-}
+PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); }
-PrimExpr operator%(PrimExpr a, PrimExpr b) {
- return truncmod(a, b);
-}
+PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); }
// TODO(tqchen): switch to floordiv
-PrimExpr indexdiv(PrimExpr a, PrimExpr b) {
- return floordiv(a, b);
-}
+PrimExpr indexdiv(PrimExpr a, PrimExpr b) { return floordiv(a, b); }
-PrimExpr indexmod(PrimExpr a, PrimExpr b) {
- return floormod(a, b);
-}
+PrimExpr indexmod(PrimExpr a, PrimExpr b) { return floormod(a, b); }
PrimExpr floordiv(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
PrimExpr min(PrimExpr a, PrimExpr b) {
// inf-aware simplificaiton
- using arith::is_pos_inf;
using arith::is_neg_inf;
+ using arith::is_pos_inf;
if (is_pos_inf(a)) return b;
if (is_neg_inf(a)) return a;
if (is_pos_inf(b)) return a;
PrimExpr max(PrimExpr a, PrimExpr b) {
// inf-aware simplificaiton
- using arith::is_pos_inf;
using arith::is_neg_inf;
+ using arith::is_pos_inf;
if (is_pos_inf(a)) return a;
if (is_neg_inf(a)) return b;
if (is_pos_inf(b)) return b;
return false_value;
}
}
- return tir::CallNode::make(
- true_value.dtype(),
- tir::intrinsic::tvm_if_then_else,
- {cond, true_value, false_value},
- tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(true_value.dtype(), tir::intrinsic::tvm_if_then_else,
+ {cond, true_value, false_value}, tir::CallNode::PureIntrinsic);
}
PrimExpr likely(PrimExpr cond) {
if (is_const(cond)) return cond;
- return tir::CallNode::make(cond.dtype(),
- tir::CallNode::likely,
- { cond },
- tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(cond.dtype(), tir::CallNode::likely, {cond},
+ tir::CallNode::PureIntrinsic);
}
PrimExpr operator>(PrimExpr a, PrimExpr b) {
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
- "Shift amount must be non-negative and less than " << rtype.bits()
- << " for type " << rtype;
- if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
- if (pb) {
- if (pb->value == 0) return a;
- }
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::shift_right, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pb)
+ CHECK(pb->value >= 0 && pb->value < rtype.bits())
+ << "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
+ << rtype;
+ if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
+ if (pb) {
+ if (pb->value == 0) return a;
+ }
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::shift_right, {a, b},
+ tir::CallNode::PureIntrinsic);
}
PrimExpr operator<<(PrimExpr a, PrimExpr b) {
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
- "Shift amount must be non-negative and less than " << rtype.bits()
- << " for type " << rtype;
- if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
- if (pb) {
- if (pb->value == 0) return a;
- }
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::shift_left, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pb)
+ CHECK(pb->value >= 0 && pb->value < rtype.bits())
+ << "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
+ << rtype;
+ if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
+ if (pb) {
+ if (pb->value == 0) return a;
+ }
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::shift_left, {a, b},
+ tir::CallNode::PureIntrinsic);
}
PrimExpr operator&(PrimExpr a, PrimExpr b) {
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_and, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_and, {a, b},
+ tir::CallNode::PureIntrinsic);
}
PrimExpr operator|(PrimExpr a, PrimExpr b) {
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_or, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_or, {a, b},
+ tir::CallNode::PureIntrinsic);
}
PrimExpr operator^(PrimExpr a, PrimExpr b) {
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_xor, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_xor, {a, b},
+ tir::CallNode::PureIntrinsic);
}
PrimExpr operator~(PrimExpr a) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_not, { a }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_not, {a},
+ tir::CallNode::PureIntrinsic);
}
PrimExpr pow(PrimExpr x, PrimExpr y) {
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "power only applies to float";
- return tir::CallNode::make(
- x.dtype(), "pow", { x, y }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic);
}
PrimExpr abs(PrimExpr x) {
return x;
} else {
LOG(FATAL) << "Data type " << x.dtype()
- <<" not supported for absolute op. Skipping absolute op...";
+ << " not supported for absolute op. Skipping absolute op...";
return x;
}
}
}
if (x.dtype().bits() == 16) {
return tir::CallNode::make(t, tir::CallNode::isnan,
- {cast(DataType::Float(32, t.lanes()), std::move(x))},
- tir::CallNode::PureIntrinsic);
+ {cast(DataType::Float(32, t.lanes()), std::move(x))},
+ tir::CallNode::PureIntrinsic);
} else {
return tir::CallNode::make(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic);
}
} else {
- LOG(FATAL) << "Data type " << x.dtype()
- <<" not supported for isnan op. Skipping isnan op...";
+ LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op...";
return x;
}
}
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::AddNode::make(x, y);
PrimExpr identity_element = make_zero(source.dtype());
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::AndNode::make(x, y);
PrimExpr identity_element = make_const(source.dtype(), true);
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::OrNode::make(x, y);
PrimExpr identity_element = make_const(source.dtype(), false);
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::MaxNode::make(x, y);
PrimExpr identity_element = min_value(source.dtype());
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::MinNode::make(x, y);
PrimExpr identity_element = max_value(source.dtype());
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::MulNode::make(x, y);
PrimExpr identity_element = make_const(source.dtype(), 1);
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
PrimExpr fmod(PrimExpr x, PrimExpr y) {
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "fmod only applies to float";
- return tir::CallNode::make(x.dtype(), "fmod", { x, y }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic);
}
PrimExpr floor(PrimExpr x) {
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
- return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
- std::floor(fx->value)));
+ return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)));
}
return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic);
}
-
// expose basic functions to node namespace
-TVM_REGISTER_GLOBAL("node._const")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- if (args[0].type_code() == kDLInt) {
- *ret = tir::make_const(args[1], args[0].operator int64_t());
- } else if (args[0].type_code() == kDLFloat) {
- *ret = tir::make_const(args[1], args[0].operator double());
- } else {
- LOG(FATAL) << "only accept int or float";
- }
- });
-
-TVM_REGISTER_GLOBAL("node.LargeUIntImm")
-.set_body_typed(LargeUIntImm);
-
-TVM_REGISTER_GLOBAL("node.String")
-.set_body_typed(tir::StringImmNode::make);
+TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (args[0].type_code() == kDLInt) {
+ *ret = tir::make_const(args[1], args[0].operator int64_t());
+ } else if (args[0].type_code() == kDLFloat) {
+ *ret = tir::make_const(args[1], args[0].operator double());
+ } else {
+ LOG(FATAL) << "only accept int or float";
+ }
+});
-TVM_REGISTER_GLOBAL("tir.min_value")
-.set_body_typed(min_value);
+TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm);
-TVM_REGISTER_GLOBAL("tir.max_value")
-.set_body_typed(max_value);
+TVM_REGISTER_GLOBAL("node.String").set_body_typed(tir::StringImmNode::make);
-TVM_REGISTER_GLOBAL("tir.abs")
-.set_body_typed(tvm::abs);
+TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value);
-TVM_REGISTER_GLOBAL("tir.isnan")
-.set_body_typed(tvm::isnan);
+TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
-TVM_REGISTER_GLOBAL("tir.isfinite")
-.set_body_typed(tvm::isfinite);
+TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
-TVM_REGISTER_GLOBAL("tir.isinf")
-.set_body_typed(tvm::isinf);
+TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
-TVM_REGISTER_GLOBAL("tir.floor")
-.set_body_typed(tvm::floor);
+TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
-TVM_REGISTER_GLOBAL("tir.ceil")
-.set_body_typed(tvm::ceil);
+TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf);
-TVM_REGISTER_GLOBAL("tir.round")
-.set_body_typed(tvm::round);
+TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor);
-TVM_REGISTER_GLOBAL("tir.nearbyint")
-.set_body_typed(tvm::nearbyint);
+TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil);
-TVM_REGISTER_GLOBAL("tir.trunc")
-.set_body_typed(tvm::trunc);
+TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round);
-TVM_REGISTER_GLOBAL("tir._cast")
-.set_body_typed(tvm::cast);
+TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint);
+TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
+TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
// operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func) \
- TVM_REGISTER_GLOBAL("tir."#Node) \
- .set_body_typed([](PrimExpr a, PrimExpr b) { \
- return (Func(a, b)); \
+#define REGISTER_MAKE_BINARY_OP(Node, Func) \
+ TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b) { \
+ return (Func(a, b)); \
})
-#define REGISTER_MAKE_BIT_OP(Node, Func) \
- TVM_REGISTER_GLOBAL("tir."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- bool lhs_is_int = args[0].type_code() == kDLInt; \
- bool rhs_is_int = args[1].type_code() == kDLInt; \
- if (lhs_is_int) { \
- *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
- } else if (rhs_is_int) { \
- *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
- } else { \
- *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
- } \
+#define REGISTER_MAKE_BIT_OP(Node, Func) \
+ TVM_REGISTER_GLOBAL("tir." #Node).set_body([](TVMArgs args, TVMRetValue* ret) { \
+ bool lhs_is_int = args[0].type_code() == kDLInt; \
+ bool rhs_is_int = args[1].type_code() == kDLInt; \
+ if (lhs_is_int) { \
+ *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
+ } else if (rhs_is_int) { \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
+ } else { \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
+ } \
})
-
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
-REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
-REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
+REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
-.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
- return if_then_else(cond, true_value, false_value);
-});
+ .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
+ return if_then_else(cond, true_value, false_value);
+ });
} // namespace tvm
* \file tvm/tir/stmt.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
-
+#include <tvm/tir/stmt.h>
namespace tvm {
namespace tir {
return Stmt(node);
}
-TVM_REGISTER_GLOBAL("tir.LetStmt")
-.set_body_typed(LetStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed(LetStmtNode::make);
-Stmt AttrStmtNode::make(ObjectRef node,
- std::string attr_key,
- PrimExpr value,
- Stmt body) {
+Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
auto n = make_object<AttrStmtNode>();
n->node = node;
n->attr_key = std::move(attr_key);
return Stmt(n);
}
-TVM_REGISTER_GLOBAL("tir.AttrStmt")
-.set_body_typed(AttrStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.AttrStmt").set_body_typed(AttrStmtNode::make);
Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
- CHECK(message.dtype() == DataType::Int(32) ||
- message.as<StringImmNode>())
- << "TypeError: AssertStmt message must be an int or string:"
- << message << "\n";
+ CHECK(message.dtype() == DataType::Int(32) || message.as<StringImmNode>())
+ << "TypeError: AssertStmt message must be an int or string:" << message << "\n";
ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>();
node->condition = std::move(condition);
}
TVM_REGISTER_GLOBAL("tir.AssertStmt")
-.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
- if (const auto* str = message.as<StringObj>()) {
- auto msg = StringImmNode::make(str->data);
- return AssertStmtNode::make(condition, msg, body);
- } else {
- return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
- }
-});
+ .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
+ if (const auto* str = message.as<StringObj>()) {
+ auto msg = StringImmNode::make(str->data);
+ return AssertStmtNode::make(condition, msg, body);
+ } else {
+ return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
+ }
+ });
-Stmt ForNode::make(Var loop_var,
- PrimExpr min,
- PrimExpr extent,
- ForType for_type,
- DeviceAPI device_api,
- Stmt body) {
+Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
+ DeviceAPI device_api, Stmt body) {
CHECK(min.defined());
CHECK(extent.defined());
CHECK(min.dtype().is_scalar());
return Stmt(node);
}
-TVM_REGISTER_GLOBAL("tir.For")
-.set_body_typed([](
- Var loop_var, PrimExpr min, PrimExpr extent,
- int for_type, int device_api, Stmt body) {
- return ForNode::make(loop_var,
- min,
- extent,
- static_cast<ForType>(for_type),
- static_cast<DeviceAPI>(device_api),
- body);
+TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent,
+ int for_type, int device_api, Stmt body) {
+ return ForNode::make(loop_var, min, extent, static_cast<ForType>(for_type),
+ static_cast<DeviceAPI>(device_api), body);
});
-
Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
CHECK(value.defined());
CHECK(index.defined());
return Stmt(node);
}
-
-TVM_REGISTER_GLOBAL("tir.Store")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- PrimExpr value = args[1];
- if (args.size() == 3) {
- *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
- } else {
- *ret = StoreNode::make(args[0], value, args[2], args[3]);
- }
- });
-
+TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) {
+ PrimExpr value = args[1];
+ if (args.size() == 3) {
+ *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
+ } else {
+ *ret = StoreNode::make(args[0], value, args[2], args[3]);
+ }
+});
Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
- CHECK(value_index >=0 && value_index < func->num_outputs())
+ CHECK(value_index >= 0 && value_index < func->num_outputs())
<< "value index output function return value bound";
CHECK(value.defined()) << "Provide of undefined value\n";
return Stmt(node);
}
-TVM_REGISTER_GLOBAL("tir.Provide")
-.set_body_typed(ProvideNode::make);
+TVM_REGISTER_GLOBAL("tir.Provide").set_body_typed(ProvideNode::make);
-
-Stmt AllocateNode::make(Var buffer_var,
- DataType dtype,
- Array<PrimExpr> extents,
- PrimExpr condition,
+Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body) {
- for (size_t i = 0; i < extents.size(); ++i) {
- CHECK(extents[i].defined());
- CHECK(extents[i].dtype().is_scalar());
- }
- CHECK(body.defined());
- CHECK(condition.defined());
- CHECK(condition.dtype().is_bool());
-
- ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
- node->buffer_var = std::move(buffer_var);
- node->dtype = dtype;
- node->extents = std::move(extents);
- node->condition = std::move(condition);
- node->body = std::move(body);
- return Stmt(node);
+ for (size_t i = 0; i < extents.size(); ++i) {
+ CHECK(extents[i].defined());
+ CHECK(extents[i].dtype().is_scalar());
+ }
+ CHECK(body.defined());
+ CHECK(condition.defined());
+ CHECK(condition.dtype().is_bool());
+
+ ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
+ node->buffer_var = std::move(buffer_var);
+ node->dtype = dtype;
+ node->extents = std::move(extents);
+ node->condition = std::move(condition);
+ node->body = std::move(body);
+ return Stmt(node);
}
// overloaded, needs special handling
// has default args
TVM_REGISTER_GLOBAL("tir.Allocate")
-.set_body_typed([](
- Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
- ){
- return AllocateNode::make(buffer_var, type, extents, condition, body);
-});
+ .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
+ Stmt body) {
+ return AllocateNode::make(buffer_var, type, extents, condition, body);
+ });
int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
- if (const IntImmNode *int_size = extents[i].as<IntImmNode>()) {
+ if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
result *= int_size->value;
if (result > std::numeric_limits<int32_t>::max()) {
return 0;
return Stmt(node);
}
-TVM_REGISTER_GLOBAL("tir.Free")
-.set_body_typed(FreeNode::make);
-
+TVM_REGISTER_GLOBAL("tir.Free").set_body_typed(FreeNode::make);
-Stmt RealizeNode::make(FunctionRef func,
- int value_index,
- DataType dtype,
- Region bounds,
- PrimExpr condition,
- Stmt body) {
+Stmt RealizeNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds,
+ PrimExpr condition, Stmt body) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
return Stmt(node);
}
-
-TVM_REGISTER_GLOBAL("tir.Realize")
-.set_body_typed(RealizeNode::make);
-
+TVM_REGISTER_GLOBAL("tir.Realize").set_body_typed(RealizeNode::make);
Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
data_ = make_object<PrefetchNode>(buffer, bounds);
}
-TVM_REGISTER_GLOBAL("tir.Prefetch")
-.set_body_typed([](Buffer buffer, Array<Range> bounds) {
+TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array<Range> bounds) {
return Prefetch(buffer, bounds);
});
-
SeqStmt::SeqStmt(Array<Stmt> seq) {
auto node = make_object<SeqStmtNode>();
node->seq = std::move(seq);
data_ = std::move(node);
}
-TVM_REGISTER_GLOBAL("tir.SeqStmt")
-.set_body_typed([](Array<Stmt> seq) {
+TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq) {
return SeqStmt(std::move(seq));
});
return Stmt(node);
}
-TVM_REGISTER_GLOBAL("tir.IfThenElse")
-.set_body_typed(IfThenElseNode::make);
-
+TVM_REGISTER_GLOBAL("tir.IfThenElse").set_body_typed(IfThenElseNode::make);
Stmt EvaluateNode::make(PrimExpr value) {
CHECK(value.defined());
return Stmt(node);
}
-TVM_REGISTER_GLOBAL("tir.Evaluate")
-.set_body_typed(EvaluateNode::make);
+TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed(EvaluateNode::make);
BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
}
TVM_REGISTER_GLOBAL("tir.BufferStore")
-.set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
- return BufferStore(buffer, value, indices);
-});
+ .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+ return BufferStore(buffer, value, indices);
+ });
TVM_REGISTER_NODE_TYPE(BufferStoreNode);
-
-BufferRealize::BufferRealize(Buffer buffer,
- Array<Range> bounds,
- PrimExpr condition,
- Stmt body) {
- data_ = make_object<BufferRealizeNode>(
- buffer, bounds, condition, body);
+BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
+ data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body);
}
TVM_REGISTER_GLOBAL("tir.BufferRealize")
-.set_body_typed([](Buffer buffer,
- Array<Range> bounds,
- PrimExpr condition,
- Stmt body) {
- return BufferRealize(buffer, bounds, condition, body);
-});
+ .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
+ return BufferRealize(buffer, bounds, condition, body);
+ });
TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
// Printers
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LetStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "let " << op->var << " = ";
- p->Print(op->value);
- p->stream << '\n';
- p->Print(op->body);
- });
+ .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const LetStmtNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "let " << op->var << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ p->Print(op->body);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AttrStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "// attr [";
- p->Print(op->node);
- p->stream << "] "
- << op->attr_key << " = ";
- p->Print(op->value);
- p->stream << '\n';
- p->Print(op->body);
- });
+ .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const AttrStmtNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "// attr [";
+ p->Print(op->node);
+ p->stream << "] " << op->attr_key << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ p->Print(op->body);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AssertStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "assert(";
- p->Print(op->condition);
- p->stream << ", ";
- p->Print(op->message);
- p->stream << ")\n";
- p->Print(op->body);
- });
-
-std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
+ .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const AssertStmtNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "assert(";
+ p->Print(op->condition);
+ p->stream << ", ";
+ p->Print(op->message);
+ p->stream << ")\n";
+ p->Print(op->body);
+ });
+
+std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*)
switch (type) {
case ForType::Serial:
out << "for";
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ForNode*>(node.get());
- p->PrintIndent();
- p->stream << op->for_type << " (" << op->loop_var << ", ";
- p->Print(op->min);
- p->stream << ", ";
- p->Print(op->extent);
- p->stream << ") {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- p->stream << "}\n";
-});
+ .set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ForNode*>(node.get());
+ p->PrintIndent();
+ p->stream << op->for_type << " (" << op->loop_var << ", ";
+ p->Print(op->min);
+ p->stream << ", ";
+ p->Print(op->extent);
+ p->stream << ") {\n";
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const StoreNode*>(node.get());
- p->PrintIndent();
- p->stream << op->buffer_var << "[";
- p->Print(op->index);
- p->stream << "] = ";
- p->Print(op->value);
- if (!is_one(op->predicate)) {
- p->stream << " if ";
- p->Print(op->predicate);
- }
- p->stream << '\n';
- });
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ProvideNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ProvideNode*>(node.get());
- p->PrintIndent();
- p->stream << op->func->func_name() << "(";
- for (size_t i = 0; i < op->args.size(); ++i) {
- p->Print(op->args[i]);
- if (i < op->args.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (op->func->num_outputs() != 1) {
- p->stream << ".value[" << op->value_index << "]";
- }
- p->stream << " =";
- p->Print(op->value);
- p->stream << '\n';
- });
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferStoreNode*>(node.get());
- p->PrintIndent();
- p->stream << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) p->stream << ", ";
- }
- p->stream << "]";
- p->stream << " = ";
- p->Print(op->value);
- p->stream << '\n';
- });
+ .set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const StoreNode*>(node.get());
+ p->PrintIndent();
+ p->stream << op->buffer_var << "[";
+ p->Print(op->index);
+ p->stream << "] = ";
+ p->Print(op->value);
+ if (!is_one(op->predicate)) {
+ p->stream << " if ";
+ p->Print(op->predicate);
+ }
+ p->stream << '\n';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AllocateNode*>(node.get());
- p->PrintIndent();
- p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
- for (size_t i = 0; i < op->extents.size(); ++i) {
- p->stream << " * ";
- p->Print(op->extents[i]);
- }
- p->stream << "]";
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << "\n";
- p->Print(op->body);
- });
+ .set_dispatch<ProvideNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ProvideNode*>(node.get());
+ p->PrintIndent();
+ p->stream << op->func->func_name() << "(";
+ for (size_t i = 0; i < op->args.size(); ++i) {
+ p->Print(op->args[i]);
+ if (i < op->args.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (op->func->num_outputs() != 1) {
+ p->stream << ".value[" << op->value_index << "]";
+ }
+ p->stream << " =";
+ p->Print(op->value);
+ p->stream << '\n';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FreeNode*>(node.get());
- p->PrintIndent();
- p->stream << "free " << op->buffer_var;
- p->stream << '\n';
- });
+ .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferStoreNode*>(node.get());
+ p->PrintIndent();
+ p->stream << op->buffer->name << "[";
+ for (size_t i = 0; i < op->indices.size(); ++i) {
+ p->Print(op->indices[i]);
+ if (i < op->indices.size() - 1) p->stream << ", ";
+ }
+ p->stream << "]";
+ p->stream << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferRealizeNode*>(node.get());
- p->PrintIndent();
- p->stream << "buffer_realize " << op->buffer->name << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
+ .set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const AllocateNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
+ for (size_t i = 0; i < op->extents.size(); ++i) {
+ p->stream << " * ";
+ p->Print(op->extents[i]);
+ }
p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << " {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- p->stream << "}\n";
- });
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << "\n";
+ p->Print(op->body);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const RealizeNode*>(node.get());
- p->PrintIndent();
- p->stream << "realize " << op->func->func_name() << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (op->func->num_outputs() != 1) {
- p->stream << ".value[" << op->value_index << "]";
- }
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << " {\n";
+ .set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const FreeNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "free " << op->buffer_var;
+ p->stream << '\n';
+ });
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferRealizeNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "buffer_realize " << op->buffer->name << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << " {\n";
- p->PrintIndent();
- p->stream << "}\n";
- });
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const PrefetchNode*>(node.get());
- p->PrintIndent();
- p->stream << "prefetch " << op->buffer << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- });
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SeqStmtNode*>(node.get());
- for (Stmt stmt : op->seq) {
- p->Print(stmt);
- }
- });
+ .set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const RealizeNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "realize " << op->func->func_name() << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (op->func->num_outputs() != 1) {
+ p->stream << ".value[" << op->value_index << "]";
+ }
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << " {\n";
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IfThenElseNode*>(node.get());
- p->PrintIndent();
- while (true) {
- p->stream << "if (" << op->condition << ") {\n";
p->indent += 2;
- p->Print(op->then_case);
+ p->Print(op->body);
p->indent -= 2;
- if (!op->else_case.defined()) {
- break;
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const PrefetchNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "prefetch " << op->buffer << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
}
+ p->stream << ")";
+ });
- if (const IfThenElseNode *nested_if = op->else_case.as<IfThenElseNode>()) {
- p->PrintIndent();
- p->stream << "} else ";
- op = nested_if;
- } else {
- p->PrintIndent();
- p->stream << "} else {\n";
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const SeqStmtNode*>(node.get());
+ for (Stmt stmt : op->seq) {
+ p->Print(stmt);
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IfThenElseNode*>(node.get());
+ p->PrintIndent();
+ while (true) {
+ p->stream << "if (" << op->condition << ") {\n";
p->indent += 2;
- p->Print(op->else_case);
+ p->Print(op->then_case);
p->indent -= 2;
- break;
+
+ if (!op->else_case.defined()) {
+ break;
+ }
+
+ if (const IfThenElseNode* nested_if = op->else_case.as<IfThenElseNode>()) {
+ p->PrintIndent();
+ p->stream << "} else ";
+ op = nested_if;
+ } else {
+ p->PrintIndent();
+ p->stream << "} else {\n";
+ p->indent += 2;
+ p->Print(op->else_case);
+ p->indent -= 2;
+ break;
+ }
}
- }
- p->PrintIndent();
- p->stream << "}\n";
-});
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const EvaluateNode*>(node.get());
- p->PrintIndent();
- p->Print(op->value);
- p->stream << "\n";
- });
-
-template<typename T>
-void PrintList(const Array<T> &exprs, ReprPrinter* p) {
+ .set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const EvaluateNode*>(node.get());
+ p->PrintIndent();
+ p->Print(op->value);
+ p->stream << "\n";
+ });
+
+template <typename T>
+void PrintList(const Array<T>& exprs, ReprPrinter* p) {
for (size_t i = 0; i < exprs.size(); ++i) {
p->Print(exprs[i]);
if (i < exprs.size() - 1) {
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ShuffleNode*>(node.get());
- p->stream << "shuffle(";
- PrintList(op->vectors, p);
- p->stream << ", ";
- PrintList(op->indices, p);
- p->stream << ")";
- });
+ .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ShuffleNode*>(node.get());
+ p->stream << "shuffle(";
+ PrintList(op->vectors, p);
+ p->stream << ", ";
+ PrintList(op->indices, p);
+ p->stream << ")";
+ });
TVM_REGISTER_NODE_TYPE(AttrStmtNode);
TVM_REGISTER_NODE_TYPE(PrefetchNode);
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>
+
#include <functional>
+
#include "functor_common.h"
namespace tvm {
void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
- this->VisitExpr(r->min);
- this->VisitExpr(r->extent);
- });
+ this->VisitExpr(r->min);
+ this->VisitExpr(r->extent);
+ });
this->VisitExpr(op->condition);
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const RealizeNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
- this->VisitExpr(r->min);
- this->VisitExpr(r->extent);
- });
+ this->VisitExpr(r->min);
+ this->VisitExpr(r->extent);
+ });
this->VisitStmt(op->body);
this->VisitExpr(op->condition);
}
void StmtVisitor::VisitStmt_(const PrefetchNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
- this->VisitExpr(r->min);
- this->VisitExpr(r->extent);
- });
+ this->VisitExpr(r->min);
+ this->VisitExpr(r->extent);
+ });
}
void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
- VisitArray(op->seq, [this](const Stmt& s) {
- this->VisitStmt(s);
- });
-}
-
-void StmtVisitor::VisitStmt_(const EvaluateNode* op) {
- this->VisitExpr(op->value);
+ VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
}
+void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); }
class StmtMutator::Internal {
public:
Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
Stmt body = this->VisitStmt(op->body);
- if (min.same_as(op->min) &&
- extent.same_as(op->extent) &&
- body.same_as(op->body)) {
+ if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
Stmt body = this->VisitStmt(op->body);
PrimExpr condition = this->VisitExpr(op->condition);
- if (extents.same_as(op->extents) &&
- body.same_as(op->body) &&
- condition.same_as(op->condition)) {
+ if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
if (op->else_case.defined()) {
else_case = this->VisitStmt(op->else_case);
}
- if (condition.same_as(op->condition) &&
- then_case.same_as(op->then_case) &&
+ if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
PrimExpr value = this->VisitExpr(op->value);
PrimExpr index = this->VisitExpr(op->index);
PrimExpr predicate = this->VisitExpr(op->predicate);
- if (value.same_as(op->value) &&
- index.same_as(op->index) &&
- predicate.same_as(op->predicate)) {
+ if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
PrimExpr value = this->VisitExpr(op->value);
Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
- if (value.same_as(op->value) &&
- indices.same_as(op->indices)) {
+ if (value.same_as(op->value) && indices.same_as(op->indices)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
PrimExpr condition = this->VisitExpr(op->condition);
Stmt body = this->VisitStmt(op->body);
- if (bounds.same_as(op->bounds) &&
- condition.same_as(op->condition) &&
- body.same_as(op->body)) {
+ if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
Array<PrimExpr> args = Internal::Mutate(this, op->args);
PrimExpr value = this->VisitExpr(op->value);
- if (args.same_as(op->args) &&
- value.same_as(op->value)) {
+ if (args.same_as(op->args) && value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
Region bounds = Internal::Mutate(this, op->bounds);
Stmt body = this->VisitStmt(op->body);
PrimExpr condition = this->VisitExpr(op->condition);
- if (bounds.same_as(op->bounds) &&
- body.same_as(op->body) &&
- condition.same_as(op->condition)) {
+ if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
}
// advanced visit function for seqstmt.
-Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op,
- bool flatten_before_visit,
+Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
std::function<Stmt(const Stmt&)> fmutate) {
if (flatten_before_visit) {
// Pass 1, check if we need to flatten.
}
// function to run the visit.
auto frunvisit = [&](const SeqStmtNode* op) {
- Array<Stmt> seq =
- fmutate != nullptr ?
- MutateArray(op->seq, fmutate, allow_copy_on_write_) :
- Internal::Mutate(this, op->seq);
+ Array<Stmt> seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_)
+ : Internal::Mutate(this, op->seq);
if (seq.same_as(op->seq)) {
return GetRef<Stmt>(op);
} else {
PrimExpr message = this->VisitExpr(op->message);
Stmt body = this->VisitStmt(op->body);
- if (condition.same_as(op->condition) &&
- message.same_as(op->message) &&
- body.same_as(op->body)) {
+ if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
}
}
-Stmt StmtMutator::VisitStmt_(const FreeNode* op) {
- return GetRef<Stmt>(op);
-}
-
+Stmt StmtMutator::VisitStmt_(const FreeNode* op) { return GetRef<Stmt>(op); }
// Implementations of IRTransform, PostOrderVisit and Substitute
-class IRApplyVisit :
- public StmtExprVisitor {
+class IRApplyVisit : public StmtExprVisitor {
public:
explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
std::unordered_set<const Object*> visited_;
};
-void PostOrderVisit(const ObjectRef& node,
- std::function<void(const ObjectRef&)> fvisit) {
+void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) {
if (node.as<StmtNode>()) {
IRApplyVisit visitor(fvisit);
visitor(Downcast<Stmt>(node));
}
}
-class IRTransformer final :
- public StmtExprMutator {
+class IRTransformer final : public StmtExprMutator {
public:
- IRTransformer(const runtime::PackedFunc& f_preorder,
- const runtime::PackedFunc& f_postorder,
+ IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder,
const std::unordered_set<uint32_t>& only_enable)
- : f_preorder_(f_preorder),
- f_postorder_(f_postorder),
- only_enable_(only_enable) {
- }
+ : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {}
Stmt VisitStmt(const Stmt& stmt) final {
- return MutateInternal<Stmt>(stmt, [this](const Stmt& s) {
- return this->BaseVisitStmt(s);
- });
+ return MutateInternal<Stmt>(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); });
}
PrimExpr VisitExpr(const PrimExpr& expr) final {
- return MutateInternal<PrimExpr>(expr, [this](const PrimExpr& e) {
- return this->BaseVisitExpr(e);
- });
+ return MutateInternal<PrimExpr>(expr,
+ [this](const PrimExpr& e) { return this->BaseVisitExpr(e); });
}
private:
// NOTE: redirect to parent's call
// This is used to get around limitation of gcc-4.8
- Stmt BaseVisitStmt(const Stmt& s) {
- return StmtMutator::VisitStmt(s);
- }
- PrimExpr BaseVisitExpr(const PrimExpr& e) {
- return ExprMutator::VisitExpr(e);
- }
+ Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); }
+ PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); }
template <typename T, typename F>
T MutateInternal(const T& node, F fmutate) {
- if (only_enable_.size() &&
- !only_enable_.count(node->type_index())) {
+ if (only_enable_.size() && !only_enable_.count(node->type_index())) {
return fmutate(node);
}
if (f_preorder_ != nullptr) {
const std::unordered_set<uint32_t>& only_enable_;
};
-Stmt IRTransform(Stmt ir_node,
- const runtime::PackedFunc& f_preorder,
- const runtime::PackedFunc& f_postorder,
- Optional<Array<String>> only_enable) {
+Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder,
+ const runtime::PackedFunc& f_postorder, Optional<Array<String>> only_enable) {
std::unordered_set<uint32_t> only_type_index;
if (only_enable.defined()) {
for (auto s : only_enable.value()) {
class IRSubstitue : public StmtExprMutator {
public:
- explicit IRSubstitue(std::function<Optional<PrimExpr>(const Var&)> vmap)
- : vmap_(vmap) {
- }
+ explicit IRSubstitue(std::function<Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {}
PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<LoadNode>();
if (auto mapped_var = vmap_(op->buffer_var)) {
- return LoadNode::make(
- op->dtype, Downcast<Var>(mapped_var.value()), op->index, op->predicate);
+ return LoadNode::make(op->dtype, Downcast<Var>(mapped_var.value()), op->index, op->predicate);
} else {
return ret;
}
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();
if (auto mapped_var = vmap_(op->buffer_var)) {
- return StoreNode::make(
- Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
+ return StoreNode::make(Downcast<Var>(mapped_var.value()), op->value, op->index,
+ op->predicate);
} else {
return ret;
}
std::function<Optional<PrimExpr>(const Var&)> vmap_;
};
-Stmt Substitute(Stmt stmt,
- std::function<Optional<PrimExpr>(const Var&)> vmap) {
+Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstitue(vmap)(std::move(stmt));
}
-PrimExpr Substitute(PrimExpr expr,
- std::function<Optional<PrimExpr>(const Var&)> vmap) {
+PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstitue(vmap)(std::move(expr));
}
+TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform);
-TVM_REGISTER_GLOBAL("tir.IRTransform")
-.set_body_typed(IRTransform);
-
-
-TVM_REGISTER_GLOBAL("tir.PostOrderVisit")
-.set_body_typed([](ObjectRef node, PackedFunc f) {
- tir::PostOrderVisit(node, [f](const ObjectRef& n) {
- f(n);
- });
+TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
+ tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
});
TVM_REGISTER_GLOBAL("tir.Substitute")
-.set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef{
- if (node->IsInstance<StmtNode>()) {
- return Substitute(Downcast<Stmt>(node), vmap);
- } else {
- return Substitute(Downcast<PrimExpr>(node), vmap);
- }
-});
+ .set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
+ if (node->IsInstance<StmtNode>()) {
+ return Substitute(Downcast<Stmt>(node), vmap);
+ } else {
+ return Substitute(Downcast<PrimExpr>(node), vmap);
+ }
+ });
} // namespace tir
} // namespace tvm
* \file tir/ir/transform.cc
* \brief TIR specific transformation passes.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/node/repr_printer.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
-
namespace tvm {
namespace tir {
namespace transform {
-
/*!
* \brief Function level pass that applies transformations to all
* TIR functions within the module.
/*! \brief The pass function called on each. */
runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("pass_info", &pass_info);
- }
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
/*!
* \brief Run a function pass on given pass context.
}
// Perform Module -> Module optimizations at the PrimFunc level.
-IRModule PrimFuncPassNode::operator()(IRModule mod,
- const PassContext& pass_ctx) const {
+IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const tvm::Array<runtime::String>& required) {
+ int opt_level, const std::string& name, const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return PrimFuncPass(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(PrimFuncPassNode);
TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass")
-.set_body_typed([](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
- PassInfo pass_info) {
- return PrimFuncPass(pass_func, pass_info);
-});
+ .set_body_typed(
+ [](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
+ PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PrimFuncPassNode*>(ref.get());
- const PassInfo info = node->Info();
- p->stream << "PrimFuncPass(" << info->name
- << ", opt_level=" << info->opt_level << ")";
-});
+ .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PrimFuncPassNode*>(ref.get());
+ const PassInfo info = node->Info();
+ p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")";
+ });
} // namespace transform
} // namespace tir
/*!
* \file hoist_if_then_else.cc
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/runtime/registry.h>
+#include <queue>
#include <unordered_map>
#include <unordered_set>
-#include <queue>
+
#include "../../arith/interval_set.h"
#include "../../runtime/thread_storage_scope.h"
}
});
- PackedFunc replace_target_for = PackedFunc(
- [&](TVMArgs args, TVMRetValue *ret){
- const ObjectRef& current_for = args[0];
- if (current_for.get() == top_for_node) {
- *ret = new_if_stmt;
- }
- });
+ PackedFunc replace_target_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
+ const ObjectRef& current_for = args[0];
+ if (current_for.get() == top_for_node) {
+ *ret = new_if_stmt;
+ }
+ });
return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"For"});
}
Stmt else_for;
CHECK(if_stmt.as<IfThenElseNode>());
- PackedFunc replace_then_case = PackedFunc(
- [&](TVMArgs args, TVMRetValue *ret){
- const ObjectRef& node = args[0];
- if (node == if_stmt) {
- *ret = node.as<IfThenElseNode>()->then_case;
- }
- });
+ PackedFunc replace_then_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
+ const ObjectRef& node = args[0];
+ if (node == if_stmt) {
+ *ret = node.as<IfThenElseNode>()->then_case;
+ }
+ });
- PackedFunc replace_else_case = PackedFunc(
- [&](TVMArgs args, TVMRetValue *ret){
- const ObjectRef& node = args[0];
- if (node == if_stmt) {
- *ret = node.as<IfThenElseNode>()->else_case;
- }
- });
+ PackedFunc replace_else_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
+ const ObjectRef& node = args[0];
+ if (node == if_stmt) {
+ *ret = node.as<IfThenElseNode>()->else_case;
+ }
+ });
then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
// Locate all For nodes and capture child IfThenElse nodes.
void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
- PostOrderVisit(stmt, [&](const ObjectRef& node){
+ PostOrderVisit(stmt, [&](const ObjectRef& node) {
const ForNode* for_node = node.as<ForNode>();
if (!for_node) return;
CHECK(for_node);
std::vector<Stmt> new_for_list{for_stmt};
for_tracking_map_.insert({for_stmt.get(), new_for_list});
- if (cond_var_map_[if_stmt]
- .count(for_node->loop_var.get())) {
- std::vector<Stmt> updated_for_list(for_list.begin(),
- for_list.begin() + i);
+ if (cond_var_map_[if_stmt].count(for_node->loop_var.get())) {
+ std::vector<Stmt> updated_for_list(for_list.begin(), for_list.begin() + i);
if2for_map_[if_stmt] = updated_for_list;
break;
} else {
// We keep all For nodes tracing in for_tracking_map_. When we get a
// hoisted IfThenElse, we match it with tracing For nodes to pick
// the updated one.
-size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt,
- const Stmt& if_stmt) {
+size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt) {
std::vector<Stmt> tracked_for_list = for_tracking_map_[for_stmt.get()];
size_t updated_for_idx = 0;
for (size_t i = 0; i < tracked_for_list.size(); ++i) {
- const Stmt& current_for =
- tracked_for_list.at(tracked_for_list.size() - 1 - i);
+ const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i);
if (is_first_if(current_for, if_stmt)) {
updated_for_idx = tracked_for_list.size() - 1 - i;
break;
for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) {
const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i);
size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if);
- const Stmt& updated_for_node =
- for_tracking_map_[for_stmt.get()].at(updated_for_idx);
+ const Stmt& updated_for_node = for_tracking_map_[for_stmt.get()].at(updated_for_idx);
auto generated_for_pair = RemoveIf(updated_for_node, new_if);
const Stmt& then_for = generated_for_pair.first;
- const Stmt& else_for = generated_for_pair.second;;
+ const Stmt& else_for = generated_for_pair.second;
+
for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for;
if (else_for.get()) {
new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for);
if (i < if2for_map_[if_stmt.get()].size() - 1) {
const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
- const Stmt& actual_next_for =
- for_tracking_map_[original_next_for.get()].at(updated_for_idx);
+ const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx);
Stmt update_for_stmt = update_for(actual_next_for, new_if);
- for_tracking_map_[original_next_for.get()].
- at(updated_for_idx) = update_for_stmt;
+ for_tracking_map_[original_next_for.get()].at(updated_for_idx) = update_for_stmt;
}
}
return new_if;
// Mutate For nodes in post order DFS manner.
Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
- PackedFunc replace_top_for = PackedFunc(
- [&](TVMArgs args, TVMRetValue *ret){
- const ObjectRef& current_for = args[0];
- const ForNode* for_node = current_for.as<ForNode>();
- if (!for_node) return;
-
- if (top_for_var_map_.count(for_node->loop_var.get())) {
- std::vector<Stmt> new_if_list;
- for (const Stmt& if_stmt :
- top_for_var_map_[for_node->loop_var.get()]) {
- new_if_list.emplace_back(HoistIf(if_stmt));
- }
+ PackedFunc replace_top_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
+ const ObjectRef& current_for = args[0];
+ const ForNode* for_node = current_for.as<ForNode>();
+ if (!for_node) return;
- const IfThenElseNode* next_if_node;
- const IfThenElseNode* current_if_node =
- new_if_list.back().as<IfThenElseNode>();
- Stmt new_for = Stmt();
- for (size_t i = new_if_list.size() - 1; i > 0; --i) {
- CHECK(current_if_node);
- const Stmt current_if_stmt =
- IfThenElseNode::make(current_if_node->condition,
- current_if_node->then_case,
- current_if_node->else_case);
- next_if_node = new_if_list[i - 1].as<IfThenElseNode>();
- CHECK(next_if_node);
- new_for = IfThenElseNode::make(next_if_node->condition, current_if_stmt,
- next_if_node->else_case);
- current_if_node = new_for.as<IfThenElseNode>();
- }
+ if (top_for_var_map_.count(for_node->loop_var.get())) {
+ std::vector<Stmt> new_if_list;
+ for (const Stmt& if_stmt : top_for_var_map_[for_node->loop_var.get()]) {
+ new_if_list.emplace_back(HoistIf(if_stmt));
+ }
- if (!new_for.get()) {
- const IfThenElseNode* first_if_node = new_if_list[0].as<IfThenElseNode>();
- CHECK(first_if_node);
- new_for = IfThenElseNode::make(first_if_node->condition,
- first_if_node->then_case,
- first_if_node->else_case);
- }
- *ret = new_for;
+ const IfThenElseNode* next_if_node;
+ const IfThenElseNode* current_if_node = new_if_list.back().as<IfThenElseNode>();
+ Stmt new_for = Stmt();
+ for (size_t i = new_if_list.size() - 1; i > 0; --i) {
+ CHECK(current_if_node);
+ const Stmt current_if_stmt = IfThenElseNode::make(
+ current_if_node->condition, current_if_node->then_case, current_if_node->else_case);
+ next_if_node = new_if_list[i - 1].as<IfThenElseNode>();
+ CHECK(next_if_node);
+ new_for =
+ IfThenElseNode::make(next_if_node->condition, current_if_stmt, next_if_node->else_case);
+ current_if_node = new_for.as<IfThenElseNode>();
}
- });
- return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"For"});
-}
-Stmt HoistIfThenElse(Stmt stmt) {
- return IfThenElseHoist().VisitAndMutate(stmt);
+ if (!new_for.get()) {
+ const IfThenElseNode* first_if_node = new_if_list[0].as<IfThenElseNode>();
+ CHECK(first_if_node);
+ new_for = IfThenElseNode::make(first_if_node->condition, first_if_node->then_case,
+ first_if_node->else_case);
+ }
+ *ret = new_for;
+ }
+ });
+ return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"For"});
}
+Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); }
-TVM_REGISTER_GLOBAL("testing.HoistIfThenElse")
-.set_body_typed(HoistIfThenElse);
+TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse);
} // namespace tir
} // namespace tvm
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
*/
-#include <tvm/tir/expr.h>
-#include <tvm/runtime/device_api.h>
-#include "ir_util.h"
#include "arg_binder.h"
+
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/expr.h>
+
#include "../../arith/compute_expr.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
-void BinderAddAssert(arith::Analyzer* ana,
- PrimExpr cond,
- const std::string& arg_name,
+void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg_name,
std::vector<Stmt>* asserts) {
PrimExpr scond = ana->Simplify(cond);
if (is_zero(scond)) {
- LOG(FATAL) << "Bind have an unmet assertion: "
- << cond << ", " << " on argument " << arg_name;
+ LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", "
+ << " on argument " << arg_name;
}
if (!is_one(scond)) {
std::ostringstream os;
}
}
-bool ArgBinder::Bind_(const PrimExpr& arg,
- const PrimExpr& value,
- const std::string& arg_name,
+bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.dtype(), value.dtype());
if (const VarNode* v = arg.as<VarNode>()) {
return false;
}
-void ArgBinder::Bind(const PrimExpr& arg,
- const PrimExpr& value,
- const std::string& arg_name,
+void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
bool with_let) {
Bind_(arg, value, arg_name, with_let);
}
-void ArgBinder::BindArray(const Array<PrimExpr>& arg,
- const Array<PrimExpr>& value,
+void ArgBinder::BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& value,
const std::string& arg_name) {
- CHECK_EQ(arg.size(), value.size())
- << "Argument " << arg_name << " array size mismatch";
+ CHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch";
for (size_t i = 0; i < arg.size(); ++i) {
std::ostringstream os;
os << arg_name << "[" << i << "]";
}
}
-void ArgBinder::BindBuffer(const Buffer& arg,
- const Buffer& value,
- const std::string& arg_name,
+void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name,
bool fuzzy_match) {
- CHECK_EQ(arg->scope, value->scope)
- << "Argument " << arg_name
- << " Buffer bind scope mismatch";
+ CHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch";
CHECK_EQ(arg->dtype, value->dtype)
- << "Argument " << arg_name
- << " Buffer bind data type mismatch";
+ << "Argument " << arg_name << " Buffer bind data type mismatch";
if (value->data_alignment % arg->data_alignment != 0) {
LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
<< " required_alignment=" << arg->data_alignment
PrimExpr offset = value->elem_offset;
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
- BinderAddAssert(&analyzer_,
- truncmod(offset, factor) == zero,
- arg_name + ".elem_offset", &asserts_);
+ BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
+ &asserts_);
}
}
size_t diff = value->shape.size() - arg->shape.size();
for (size_t i = 0; i < diff; ++i) {
CHECK(is_one(analyzer_.Simplify(value->shape[i])))
- << "Argument " << arg_name << " shape mismatch"
- << arg->shape << " vs " << value->shape;
+ << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape;
}
for (size_t i = 0; i < arg->shape.size(); ++i) {
std::ostringstream os;
return TVMStructGet(t, arr, 0, kind);
}
-void ArgBinder::BindDLTensor(const Buffer& buffer,
- const PrimExpr& device_type,
- const PrimExpr& device_id,
- const Var& handle,
+void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
+ const PrimExpr& device_id, const Var& handle,
const std::string& arg_name) {
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = EvaluateNode::make(0);
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
- PrimExpr a_ndim = make_const(tvm_ndim_type,
- static_cast<int64_t>(buffer->shape.size()));
+ PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
- ndim_err_msg << arg_name
- << ".ndim is expected to equal "
- << buffer->shape.size();
+ ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size();
auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
// type checks
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
- IntImm(DataType::UInt(8), dtype.code()) &&
- TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
- IntImm(DataType::UInt(8), dtype.bits()) &&
- TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
- IntImm(DataType::UInt(16), dtype.lanes()));
- if (!(dtype == DataType::Int(4) ||
- dtype == DataType::UInt(4) ||
- dtype == DataType::Int(1))) {
+ IntImm(DataType::UInt(8), dtype.code()) &&
+ TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
+ IntImm(DataType::UInt(8), dtype.bits()) &&
+ TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
+ IntImm(DataType::UInt(16), dtype.lanes()));
+ if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) {
auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
- init_nest_.emplace_back(AttrStmtNode::make(
- vptr, tir::attr::storage_alignment,
- IntImm(DataType::Int(32), buffer->data_alignment), nop));
+ init_nest_.emplace_back(AttrStmtNode::make(vptr, tir::attr::storage_alignment,
+ IntImm(DataType::Int(32), buffer->data_alignment),
+ nop));
}
Var v_shape(arg_name + ".shape", DataType::Handle());
init_nest_.emplace_back(LetStmtNode::make(
v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
- if (dtype == DataType::Int(4) ||
- dtype == DataType::UInt(4) ||
- dtype == DataType::Int(1)) {
+ if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) {
break;
}
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
- Bind_(buffer->shape[k],
- cast(buffer->shape[k].dtype(),
- LoadNode::make(tvm_shape_type, v_shape,
- IntImm(DataType::Int(32), k), const_true(1))),
- field_name.str(), true);
+ Bind_(
+ buffer->shape[k],
+ cast(buffer->shape[k].dtype(),
+ LoadNode::make(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))),
+ field_name.str(), true);
}
// strides field
Var v_strides(arg_name + ".strides", DataType::Handle());
def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(LetStmtNode::make(
- v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides),
- nop));
- PrimExpr is_null = CallNode::make(
- DataType::Bool(1), intrinsic::tvm_handle_is_null,
- {v_strides}, CallNode::PureIntrinsic);
+ v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop));
+ PrimExpr is_null = CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides},
+ CallNode::PureIntrinsic);
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
- PrimExpr svalue = cast(
- stype,
- LoadNode::make(tvm_shape_type, v_strides,
- IntImm(DataType::Int(32), k), const_true(1)));
+ PrimExpr svalue = cast(stype, LoadNode::make(tvm_shape_type, v_strides,
+ IntImm(DataType::Int(32), k), const_true(1)));
conds.push_back(expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
<< " expected to be compact array";
if (conds.size() != 0) {
auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
- Stmt check =
- AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
- stride_msg, EvaluateNode::make(0));
+ Stmt check = AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
+ stride_msg, EvaluateNode::make(0));
check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
size_t k = i - 1;
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
- PrimExpr value = cast(buffer->shape[k].dtype(),
- LoadNode::make(tvm_shape_type, v_strides,
- IntImm(DataType::Int(32), k), const_true(1)));
+ PrimExpr value = cast(
+ buffer->shape[k].dtype(),
+ LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
value = tvm::if_then_else(is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
field_name << v_strides->name_hint << '[' << k << ']';
Bind_(buffer->strides[k],
cast(buffer->shape[k].dtype(),
- LoadNode::make(tvm_shape_type, v_strides,
- IntImm(DataType::Int(32), k), const_true(1))),
+ LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k),
+ const_true(1))),
field_name.str(), true);
}
}
if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
- TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
+ TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
if (Bind_(buffer->elem_offset,
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
- BinderAddAssert(&analyzer_,
- truncmod(offset, factor) == zero,
- arg_name + ".elem_offset", &asserts_);
+ BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
+ &asserts_);
}
}
}
// device info.
- Bind_(device_type,
- TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType),
+ Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType),
arg_name + ".device_type", true);
- Bind_(device_id,
- TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId),
+ Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId),
arg_name + ".device_id", true);
}
#ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_
#define TVM_TIR_TRANSFORMS_ARG_BINDER_H_
-#include <tvm/tir/expr.h>
-#include <tvm/tir/buffer.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/expr.h>
#include <string>
-#include <vector>
#include <unordered_map>
+#include <vector>
namespace tvm {
namespace tir {
* \param def_map A definition map that contains definition of known variables.
* ArgBinder will update this def_map when adding new definitions.
*/
- explicit ArgBinder(
- std::unordered_map<const VarNode*, PrimExpr>* def_map)
- : def_map_(def_map) {
- }
+ explicit ArgBinder(std::unordered_map<const VarNode*, PrimExpr>* def_map) : def_map_(def_map) {}
/*!
* \brief Try to bind arg to value, generate constraint if necessary.
* \param arg The argument to be binded.
* \param arg_name argument name.
* \param with_let Whether add lets during bind
*/
- void Bind(const PrimExpr& arg,
- const PrimExpr& value,
- const std::string& arg_name,
+ void Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
bool with_let = false);
/*!
* \brief Bind array to array
* \param value The target expression value
* \param arg_name argument name.
*/
- void BindArray(const Array<PrimExpr>& arg,
- const Array<PrimExpr>& value,
+ void BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& value,
const std::string& arg_name);
/*!
* \brief Bind symbolic buffer to another symbolic buffer
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
- * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1.
+ * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as
+ * arg's higher dimensions are of 1.
*/
- void BindBuffer(const Buffer& arg,
- const Buffer& value,
- const std::string& arg_name,
+ void BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name,
bool fuzzy_match);
/*!
* \brief Bind symbolic buffer to a DLTensor handle.
* \param handle The DLTensor handle.
* \param arg_name argument name.
*/
- void BindDLTensor(const Buffer& buffer,
- const PrimExpr& device_type,
- const PrimExpr& device_id,
- const Var& handle,
- const std::string& arg_name);
+ void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id,
+ const Var& handle, const std::string& arg_name);
/*! \return The defs generated in binding. */
- const std::vector<Var>& defs() const {
- return defs_;
- }
+ const std::vector<Var>& defs() const { return defs_; }
/*! \return The asserts generated in binding */
- const std::vector<Stmt>& asserts() const {
- return asserts_;
- }
+ const std::vector<Stmt>& asserts() const { return asserts_; }
/*!
* \brief Initialization nest generated
* This is only non-empty when BindDLTensor is called.
* Let statement is usually generated when bind to DLTensor and memory load is involved.
* \return The initialization nest generated during binding.
*/
- const std::vector<Stmt>& init_nest() const {
- return init_nest_;
- }
+ const std::vector<Stmt>& init_nest() const { return init_nest_; }
/*! \return Handle data type of the data */
- const Map<Var, PrimExpr>& def_handle_dtype() const {
- return def_handle_dtype_;
- }
+ const Map<Var, PrimExpr>& def_handle_dtype() const { return def_handle_dtype_; }
private:
// Internal bind function
- bool Bind_(const PrimExpr& arg,
- const PrimExpr& value,
- const std::string& arg_name,
+ bool Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
bool with_lets);
/*! \brief The definition map, can be uses to substitute */
std::unordered_map<const VarNode*, PrimExpr>* def_map_;
*/
// Instrument checkers for out of the bounds access.
-#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include <vector>
+#include <tvm/tir/transform.h>
+
#include <unordered_map>
#include <utility>
+#include <vector>
namespace tvm {
namespace tir {
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::buffer_bound) {
- if (const VarNode *key = op->node.as<VarNode>()) {
+ if (const VarNode* key = op->node.as<VarNode>()) {
mem_to_shape[key] = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
// Hashtable which maps buffer_var to shape.
- std::unordered_map<const VarNode *, PrimExpr> mem_to_shape;
+ std::unordered_map<const VarNode*, PrimExpr> mem_to_shape;
};
class BoundChecker : public StmtExprMutator {
public:
- explicit BoundChecker(
- const std::unordered_map<const VarNode *, PrimExpr> &mem_to_shape)
+ explicit BoundChecker(const std::unordered_map<const VarNode*, PrimExpr>& mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
Stmt VisitStmt_(const AllocateNode* op) final {
PrimExpr condition = MakeCondition();
if (!condition.as<StringImmNode>()) {
Stmt nop = EvaluateNode::make(1);
- Stmt then_case =
- StoreNode::make(op->buffer_var, op->value, op->index, op->predicate);
- Stmt else_case =
- AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop);
+ Stmt then_case = StoreNode::make(op->buffer_var, op->value, op->index, op->predicate);
+ Stmt else_case = AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop);
Stmt body = IfThenElseNode::make(condition, then_case, else_case);
return body;
}
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}
- void Update(const Var& buffer_var,
- const Array<PrimExpr>& new_shape,
- const DataType& type) {
+ void Update(const Var& buffer_var, const Array<PrimExpr>& new_shape, const DataType& type) {
// Sanity check at first.
if (!new_shape.size()) {
return;
// Scalarize the shape.
PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
- CastNode::make(DataType::UInt(64), new_shape[0]));
+ CastNode::make(DataType::UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
// Cast to unsigned to avoid integer overflow at first.
shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()),
- CastNode::make(DataType::UInt(64), new_shape[i])));
+ CastNode::make(DataType::UInt(64), new_shape[i])));
}
mem_to_shape_[buffer_var.get()] = shape;
}
return false;
}
- if (const RampNode *ramp_index = index.as<RampNode>()) {
- return ramp_index->base.defined() &&
- ramp_index->base.dtype().is_scalar() &&
- ramp_index->stride.defined() &&
- ramp_index->stride.dtype().is_scalar() && (ramp_index->lanes > 0);
+ if (const RampNode* ramp_index = index.as<RampNode>()) {
+ return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() &&
+ ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() &&
+ (ramp_index->lanes > 0);
}
return true;
}
bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const {
- return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
- IndexIsValid(index) && !unsafe_rewritten_;
+ return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) &&
+ !unsafe_rewritten_;
}
void Collect(PrimExpr index, Var buffer_var) {
- store_scope_bound_collector_.push_back(
- std::make_pair(index, mem_to_shape_[buffer_var.get()]));
+ store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()]));
}
PrimExpr MakeCondition() {
PrimExpr index = buffer_to_mem.first;
PrimExpr upper_bound = buffer_to_mem.second;
- if (const RampNode *ramp_index = index.as<RampNode>()) {
+ if (const RampNode* ramp_index = index.as<RampNode>()) {
// In case index is base + stride * i.
// Non inclusive range.
- index = AddNode::make(
- ramp_index->base,
- MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(),
- ramp_index->lanes - 1)));
+ index = AddNode::make(ramp_index->base, MulNode::make(ramp_index->stride,
+ make_const(ramp_index->stride.dtype(),
+ ramp_index->lanes - 1)));
}
// Try to simplify index and bound.
PrimExpr current_condition =
AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound));
- condition =
- !i ? current_condition : AndNode::make(condition, current_condition);
+ condition = !i ? current_condition : AndNode::make(condition, current_condition);
}
return condition;
}
// Pool which collects the pair of index and shape for specific store/load.
std::vector<std::pair<PrimExpr, PrimExpr>> store_scope_bound_collector_;
// Error message.
- const char *const error_message_ = "OUT OF THE BOUNDS";
+ const char* const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
- std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
+ std::unordered_map<const VarNode*, PrimExpr> mem_to_shape_;
// internal analyzer
arith::Analyzer analyzer_;
};
}
TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
-.set_body_typed(InstrumentBoundCheckers);
+ .set_body_typed(InstrumentBoundCheckers);
} // namespace transform
*
* \file combine_context_call.cc
*/
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <tvm/node/structural_equal.h>
-#include <tvm/node/structural_hash.h>
-#include <tvm/runtime/registry.h>
-
#include <unordered_map>
if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
CHECK_EQ(op->args.size(), 1U);
PrimExpr ctx = op->args[0];
- auto it = ctx_map_.find(ctx);
+ auto it = ctx_map_.find(ctx);
if (it != ctx_map_.end()) {
return it->second;
} else {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::coproc_uop_scope) {
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> temp;
std::swap(temp, ctx_map_);
}
}
- Stmt Combine(Stmt stmt) {
- return BuildContext(ctx_map_, this->VisitStmt(stmt));
- }
+ Stmt Combine(Stmt stmt) { return BuildContext(ctx_map_, this->VisitStmt(stmt)); }
private:
static Stmt BuildContext(
- const std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual>& cmap,
- Stmt body) {
+ const std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual>& cmap, Stmt body) {
for (const auto& kv : cmap) {
body = LetStmtNode::make(kv.second, kv.first, body);
}
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> ctx_map_;
};
-
namespace transform {
Pass CombineContextCall() {
return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
-.set_body_typed(CombineContextCall);
+TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall);
} // namespace transform
} // namespace tir
* \file coproc_sync.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
#include <unordered_map>
#include <unordered_set>
+
#include "ir_util.h"
#include "storage_access.h"
// Synchronization planning with co-processor.
class CoProcSyncPlanner : public StorageAccessVisitor {
public:
- explicit CoProcSyncPlanner(
- const std::unordered_set<const VarNode*>& touched,
- const std::string& coproc_name)
- : touched_(touched), coproc_name_(coproc_name) {
- }
+ explicit CoProcSyncPlanner(const std::unordered_set<const VarNode*>& touched,
+ const std::string& coproc_name)
+ : touched_(touched), coproc_name_(coproc_name) {}
void Plan(const Stmt& stmt) {
this->VisitStmt(stmt);
std::unordered_map<const Object*, std::vector<Stmt> > sync_;
protected:
- bool Enabled(const VarNode* buf,
- const StorageScope& scope) const final {
+ bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
return touched_.count(buf);
}
// Plan the sync
- std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const ForNode* loop) final {
+ std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
return PlanSync(seq, loop, false);
}
private:
// Plan write synchronization if write is not coherent
- std::vector<AccessEntry> PlanSync(
- std::vector<StmtEntry> seq, const ForNode* loop,
- bool force_sync_at_end) {
+ std::vector<AccessEntry> PlanSync(std::vector<StmtEntry> seq, const ForNode* loop,
+ bool force_sync_at_end) {
// detect write barriers
// access by the co-processor.
std::vector<AccessEntry> co_access;
auto find_conflict = [&](const AccessEntry& acc) {
for (const AccessEntry& x : co_access) {
if (x.buffer.same_as(acc.buffer) &&
- ((acc.type == kRead && x.type == kWrite) ||
- acc.type == kWrite)) {
+ ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) {
return true;
}
}
bool sync_write = false;
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
- sync_write = true; break;
+ sync_write = true;
+ break;
}
if (acc.type == kSync) {
co_access.clear();
const StmtEntry& s = seq[i];
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
- sync_at_end = true; break;
+ sync_at_end = true;
+ break;
}
}
if (sync_.count(s.stmt) || sync_at_end) break;
}
std::vector<Stmt> GetSync(std::string sync_name) {
- return {EvaluateNode::make(CallNode::make(
- DataType::Int(32),
- sync_name,
- {}, CallNode::Intrinsic))};
+ return {
+ EvaluateNode::make(CallNode::make(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))};
}
const std::unordered_set<const VarNode*>& touched_;
// Detect memory barriers when coproc read/write memory
class CoProcBarrierDetector : public StorageAccessVisitor {
public:
- explicit CoProcBarrierDetector(
- const std::unordered_set<const VarNode*>& touched,
- const std::string& coproc_name)
+ explicit CoProcBarrierDetector(const std::unordered_set<const VarNode*>& touched,
+ const std::string& coproc_name)
: touched_(touched) {
read_barrier_name_ = coproc_name + ".coproc_read_barrier";
write_barrier_name_ = coproc_name + ".coproc_write_barrier";
std::unordered_map<const Object*, std::vector<Stmt> > barrier_after_;
protected:
- bool Enabled(const VarNode* buf,
- const StorageScope& scope) const final {
+ bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
return touched_.count(buf);
}
// Plan the sync
- std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const ForNode* loop) final {
+ std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
if (read_barrier_) {
return PlanReadBarrier(seq, loop);
} else {
private:
// Plan write barrier at Read after write point.
- std::vector<AccessEntry> PlanWriteBarrier(
- std::vector<StmtEntry> seq, const ForNode* loop) {
+ std::vector<AccessEntry> PlanWriteBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
std::vector<AccessEntry> read_seq;
std::unordered_map<const VarNode*, std::vector<AccessEntry> > write_set;
auto fupdate = [&](size_t i, const AccessEntry& acc) {
- auto it = write_set.find(acc.buffer.get());
+ auto it = write_set.find(acc.buffer.get());
if (it != write_set.end()) {
CHECK_NE(i, 0U);
- barrier_after_[seq[i - 1].stmt].push_back(
- MakeBarrier(write_barrier_name_, it->second));
+ barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second));
write_set.erase(it);
}
};
fupdate(seq.size(), acc);
}
}
- for (const auto &kv : write_set) {
+ for (const auto& kv : write_set) {
read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end());
}
return read_seq;
}
- std::vector<AccessEntry> PlanReadBarrier(
- std::vector<StmtEntry> seq, const ForNode* loop) {
+ std::vector<AccessEntry> PlanReadBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
std::vector<AccessEntry> write_seq;
std::unordered_map<const VarNode*, std::vector<AccessEntry> > read_set;
auto fupdate = [&](size_t i, const AccessEntry& acc) {
- auto it = read_set.find(acc.buffer.get());
+ auto it = read_set.find(acc.buffer.get());
if (it != read_set.end()) {
CHECK_NE(i, seq.size());
- barrier_before_[seq[i].stmt].push_back(
- MakeBarrier(read_barrier_name_, it->second));
+ barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second));
read_set.erase(it);
}
};
fupdate(0, acc);
}
}
- for (const auto &kv : read_set) {
+ for (const auto& kv : read_set) {
write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end());
}
return write_seq;
}
Range none;
Range r = arith::Union(wset).cover_range(none);
- CHECK(r.defined())
- << "Cannot deduce write range of " << wvec[0].buffer;
+ CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer;
PrimExpr min = r->min;
PrimExpr extent = r->extent;
return EvaluateNode::make(CallNode::make(
- DataType::Int(32), func,
- {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic));
+ DataType::Int(32), func, {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
+ CallNode::Intrinsic));
}
// Write barrier name
bool read_barrier_{false};
const std::unordered_set<const VarNode*>& touched_;
};
-
class CoProcInstDepDetector : public StmtVisitor {
public:
- explicit CoProcInstDepDetector(
- const IterVar& coproc_axis,
- const std::string& coproc_name)
+ explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name)
: coproc_axis_(coproc_axis) {
sync_push_name_ = coproc_name + ".coproc_dep_push";
sync_pop_name_ = coproc_name + ".coproc_dep_pop";
}
void VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::coproc_scope &&
- op->node.same_as(coproc_axis_)) {
+ if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) {
const IntImmNode* ctx_id = op->value.as<IntImmNode>();
CHECK(ctx_id != nullptr);
curr_state_.clear();
curr_state_.node = op;
CHECK(first_state_.node != nullptr);
// loop carry dependency
- InjectSync(last_state_, first_state_,
- &(curr_state_.exit_push),
- &(curr_state_.enter_pop));
+ InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop));
curr_state_.enter_ctx = first_state_.enter_ctx;
curr_state_.exit_ctx = last_state_.exit_ctx;
}
curr_state.node = op;
MatchFixEnterPop(first_state_);
MatchFixExitPush(last_state_);
- curr_state.enter_ctx.insert(
- first_state_.enter_ctx.begin(),
- first_state_.enter_ctx.end());
- curr_state.exit_ctx.insert(
- last_state_.exit_ctx.begin(),
- last_state_.exit_ctx.end());
+ curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
+ curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
}
first_state_.clear();
last_state_.clear();
curr_state.node = op;
MatchFixEnterPop(first_state_);
MatchFixExitPush(last_state_);
- curr_state.enter_ctx.insert(
- first_state_.enter_ctx.begin(),
- first_state_.enter_ctx.end());
- curr_state.exit_ctx.insert(
- last_state_.exit_ctx.begin(),
- last_state_.exit_ctx.end());
+ curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
+ curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
}
}
// update in the trace.
// record the push/pop sequence that could be possibly un-matched.
// return the push/pop message at enter/exit of the Block
// after considering the existing unmatcheded events and added events
- void InjectSync(const SyncState& prev,
- const SyncState& next,
+ void InjectSync(const SyncState& prev, const SyncState& next,
std::vector<std::pair<int, int> >* prev_exit_push,
std::vector<std::pair<int, int> >* next_enter_pop) {
prev_exit_push->clear();
next_enter_pop->clear();
// quick path
- if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 &&
- prev.exit_ctx.size() == 1 && next.enter_ctx.size() == 1) {
+ if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 &&
+ next.enter_ctx.size() == 1) {
int from = *prev.exit_ctx.begin();
int to = *next.enter_ctx.begin();
if (from != to) {
// policy 1
std::vector<Stmt> prev_after, next_before;
for (const std::pair<int, int>& p : pending) {
- if (std::find(prev.exit_push.begin(),
- prev.exit_push.end(), p) ==
- prev.exit_push.end()) {
+ if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) {
vpush.push_back(p);
prev_after.emplace_back(MakePush(p.first, p.second));
}
- if (std::find(next.enter_pop.begin(),
- next.enter_pop.end(), p) ==
- next.enter_pop.end()) {
+ if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) {
vpop.push_back(p);
next_before.emplace_back(MakePop(p.first, p.second));
}
}
}
if (prev_after.size() != 0) {
- auto &v1 = insert_after_[prev.node];
+ auto& v1 = insert_after_[prev.node];
v1.insert(v1.end(), prev_after.begin(), prev_after.end());
}
if (next_before.size() != 0) {
- auto &v2 = insert_before_[next.node];
+ auto& v2 = insert_before_[next.node];
v2.insert(v2.end(), next_before.begin(), next_before.end());
}
}
void MatchFixEnterPop(const SyncState& state) {
if (state.enter_pop.size() == 0) return;
- auto &vec = insert_before_[state.node];
+ auto& vec = insert_before_[state.node];
for (const std::pair<int, int>& p : state.enter_pop) {
vec.push_back(MakePush(p.first, p.second));
}
void MatchFixExitPush(const SyncState& state) {
if (state.exit_push.size() == 0) return;
- auto &vec = insert_after_[state.node];
+ auto& vec = insert_after_[state.node];
for (const std::pair<int, int>& p : state.exit_push) {
vec.push_back(MakePop(p.first, p.second));
}
}
Stmt MakePush(int from, int to) {
- return EvaluateNode::make(CallNode::make(
- DataType::Int(32), sync_push_name_,
- {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- CallNode::Intrinsic));
+ return EvaluateNode::make(
+ CallNode::make(DataType::Int(32), sync_push_name_,
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
+ CallNode::Intrinsic));
}
Stmt MakePop(int from, int to) {
- return EvaluateNode::make(CallNode::make(
- DataType::Int(32), sync_pop_name_,
- {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- CallNode::Intrinsic));
+ return EvaluateNode::make(
+ CallNode::make(DataType::Int(32), sync_pop_name_,
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
+ CallNode::Intrinsic));
}
// sync states.
SyncState first_state_, last_state_, curr_state_;
std::string sync_push_name_, sync_pop_name_;
};
-
class CoProcSyncInserter : public StmtMutator {
public:
Stmt Insert(Stmt stmt) {
if (visitor.coproc_.size() == 0) return stmt;
std::unordered_set<const VarNode*> touched;
- for (const auto &kv : visitor.touched_) {
+ for (const auto& kv : visitor.touched_) {
if (kv.second.normal && kv.second.coproc) {
touched.insert(kv.first);
}
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
// Detect barrier
- CoProcInstDepDetector sync_detector(
- *visitor.coproc_.begin(), coproc_name);
+ CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name);
sync_detector.Plan(stmt);
for (const auto& kv : sync_detector.insert_before_) {
auto& vec = insert_before_[kv.first];
Stmt new_stmt = StmtMutator::VisitStmt(stmt);
return SeqStmt::Flatten(
- it_before != insert_before_.end() ? it_before->second : std::vector<Stmt>(),
- new_stmt,
- it_after != insert_after_.end() ? it_after->second : std::vector<Stmt>());
+ it_before != insert_before_.end() ? it_before->second : std::vector<Stmt>(), new_stmt,
+ it_after != insert_after_.end() ? it_after->second : std::vector<Stmt>());
}
private:
std::unordered_map<const Object*, std::vector<Stmt> > insert_after_;
};
-
-Stmt CoProcSync(Stmt stmt) {
- return CoProcSyncInserter().Insert(std::move(stmt));
-}
+Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); }
namespace transform {
return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.CoProcSync")
-.set_body_typed(CoProcSync);
+TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync);
} // namespace transform
* \file decorate_device_scope.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
Stmt DecorateDeviceScope(Stmt&& stmt) {
- Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)),
- tir::attr::device_scope,
- 0,
- stmt);
+ Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt);
return body;
}
return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope")
-.set_body_typed(DecorateDeviceScope);
+TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope);
} // namespace transform
} // namespace tir
* \brief Replace certain copy with copy intrinsics.
* \file copy_intrin_rewrite.cc
*/
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/transform.h>
-#include <tvm/arith/pattern.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/arith/pattern.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
#include "../../arith/pattern_match.h"
namespace tvm {
class CopyIntrinInjector : public StmtMutator {
public:
- CopyIntrinInjector(const std::string& pragma_key,
- const PackedFunc& flower_copy_fromto)
- : pragma_key_(attr::pragma_scope_prefix+ pragma_key),
- flower_copy_fromto_(flower_copy_fromto) {
- }
+ CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto)
+ : pragma_key_(attr::pragma_scope_prefix + pragma_key),
+ flower_copy_fromto_(flower_copy_fromto) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
storage_scope_[buf] = op->value.as<StringImmNode>()->value;
} else if (op->attr_key == pragma_key_) {
Stmt ret;
- CHECK(MatchCopyPattern(op->body, &ret))
- << "Cannot match copy pattern of " << op->body;
+ CHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body;
return ret;
}
return StmtMutator::VisitStmt_(op);
}
private:
- bool MatchCopyPattern(Stmt stmt, Stmt *out) {
+ bool MatchCopyPattern(Stmt stmt, Stmt* out) {
using namespace arith;
Stmt body = stmt;
// Expr sel_cond, sel_true_value, sel_false_value;
// match select or if
PVar<PrimExpr> sel_cond, sel_true_value, sel_false_value;
- bool has_cond =
- if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
- select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
+ bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
+ select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
const CastNode* cast = store->value.as<CastNode>();
const LoadNode* load = store->value.as<LoadNode>();
for (const ForNode* op : loops) {
loop_vars.push_back(op->loop_var);
}
- Array<PrimExpr> store_strides =
- arith::DetectLinearEquation(store->index, loop_vars);
- Array<PrimExpr> load_strides =
- arith::DetectLinearEquation(load->index, loop_vars);
- if (load_strides.size() == 0 || store_strides.size() == 0) return false;
+ Array<PrimExpr> store_strides = arith::DetectLinearEquation(store->index, loop_vars);
+ Array<PrimExpr> load_strides = arith::DetectLinearEquation(load->index, loop_vars);
+ if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<PrimExpr> dst_shape;
const size_t loop_var_size = loop_vars.size();
if (loop_var_size == 0) {
PrimExpr pad_value;
PrimExpr src_elem_offset = load_strides[loop_var_size];
if (has_cond) {
- Array<PrimExpr> clip_bound =
- arith::DetectClipBound(sel_cond.Eval(), loop_vars);
+ Array<PrimExpr> clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars);
pad_value = sel_false_value.Eval();
if (clip_bound.size() == 0) return false;
CHECK_EQ(src_shape.size(), loop_vars.size());
Array<PrimExpr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<PrimExpr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
if (loop_var_size == 0) {
- src_strides.push_back(make_const(DataType::Int(32), 1));
- dst_strides.push_back(make_const(DataType::Int(32), 1));
+ src_strides.push_back(make_const(DataType::Int(32), 1));
+ dst_strides.push_back(make_const(DataType::Int(32), 1));
}
- Buffer dst = BufferNode::make(
- store->buffer_var,
- store->value.dtype(),
- dst_shape,
- dst_strides,
- store_strides[loop_var_size],
- store->buffer_var->name_hint,
- GetStorageScope(store->buffer_var.get()),
- 0, 0, kDefault);
- Buffer src = BufferNode::make(
- load->buffer_var,
- load->dtype,
- src_shape,
- src_strides,
- src_elem_offset,
- load->buffer_var->name_hint,
- GetStorageScope(load->buffer_var.get()),
- 0, 0, kDefault);
+ Buffer dst = BufferNode::make(store->buffer_var, store->value.dtype(), dst_shape, dst_strides,
+ store_strides[loop_var_size], store->buffer_var->name_hint,
+ GetStorageScope(store->buffer_var.get()), 0, 0, kDefault);
+ Buffer src = BufferNode::make(load->buffer_var, load->dtype, src_shape, src_strides,
+ src_elem_offset, load->buffer_var->name_hint,
+ GetStorageScope(load->buffer_var.get()), 0, 0, kDefault);
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
CHECK(out->defined()) << "flower function did not return correct stmt";
return true;
arith::Analyzer analyzer_;
};
-Stmt InjectCopyIntrin(Stmt stmt,
- const std::string& pragma_key,
+Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key,
const PackedFunc& flower_copy_fromto) {
return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
}
-
namespace transform {
-Pass InjectCopyIntrin(std::string pragma_key,
- PackedFunc flower_copy_fromto) {
+Pass InjectCopyIntrin(std::string pragma_key, PackedFunc flower_copy_fromto) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
- n->body = CopyIntrinInjector(
- pragma_key, flower_copy_fromto)(std::move(n->body));
+ n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin")
-.set_body_typed(InjectCopyIntrin);
+TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin").set_body_typed(InjectCopyIntrin);
} // namespace transform
* \file inject_double_buffer.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/tir/transform.h>
-#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/op.h>
-#include "ir_util.h"
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
#include "../../arith/compute_expr.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
std::unordered_set<const VarNode*> touched_;
};
-
class StripDoubleBufferWrite : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* op) final {
class DoubleBufferInjector : public StmtExprMutator {
public:
- explicit DoubleBufferInjector(int split_loop)
- : split_loop_(split_loop) {}
+ explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {}
Stmt Inject(Stmt stmt) {
DoubleBufferDetector detector;
Stmt VisitStmt_(const AllocateNode* op) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
- it->second.stride = arith::ComputeReduce<MulNode>(
- op->extents, PrimExpr()) * op->dtype.lanes();
+ it->second.stride =
+ arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
}
CHECK(it->second.loop != nullptr);
auto& alloc_nest = loop_allocs_[it->second.loop];
- alloc_nest.emplace_back(AttrStmtNode::make(
- op->buffer_var, attr::storage_scope,
- StringImmNode::make(it->second.scope),
- EvaluateNode::make(0)));
- alloc_nest.emplace_back(AllocateNode::make(
- op->buffer_var, op->dtype, new_extents, op->condition,
- EvaluateNode::make(0)));
+ alloc_nest.emplace_back(AttrStmtNode::make(op->buffer_var, attr::storage_scope,
+ StringImmNode::make(it->second.scope),
+ EvaluateNode::make(0)));
+ alloc_nest.emplace_back(AllocateNode::make(op->buffer_var, op->dtype, new_extents,
+ op->condition, EvaluateNode::make(0)));
return op->body;
} else {
return StmtExprMutator::VisitStmt_(op);
<< "It is better to split with multiple of 2";
CHECK(is_zero(old_loop->min));
PrimExpr zero = old_loop->min;
- PrimExpr new_ext =
- old_loop->extent - make_const(old_loop->loop_var.dtype(), 1);
+ PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1);
PrimExpr factor = make_const(new_ext.dtype(), split_loop_);
PrimExpr outer_ext = new_ext / factor;
PrimExpr tail_base = outer_ext * factor;
vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
- Stmt loop = ForNode::make(
- outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
- SeqStmt::Flatten(loop_seq));
+ Stmt loop = ForNode::make(outer_var, zero, outer_ext, old_loop->for_type,
+ old_loop->device_api, SeqStmt::Flatten(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
PrimExpr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx;
tail_seq.emplace_back(
- IfThenElseNode::make(idx < old_loop->extent,
- Substitute(tail_body, vmap)));
+ IfThenElseNode::make(idx < old_loop->extent, Substitute(tail_body, vmap)));
}
stmt = SeqStmt::Flatten(loop, tail_seq);
}
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.stride.defined());
- return StoreNode::make(op->buffer_var,
- op->value,
- e.switch_write_var * e.stride + op->index,
- op->predicate);
+ return StoreNode::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index,
+ op->predicate);
} else {
return stmt;
}
const StorageEntry& e = it->second;
CHECK(e.stride.defined());
CHECK(e.switch_read_var.defined());
- return LoadNode::make(op->dtype,
- op->buffer_var,
- e.switch_read_var * e.stride + op->index,
- op->predicate);
+ return LoadNode::make(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index,
+ op->predicate);
} else {
return expr;
}
private:
Stmt MakeProducer(const AttrStmtNode* op) {
const Var buffer = Downcast<Var>(op->node);
- CHECK_NE(loop_nest_.size(), 0U)
- << "Double buffer scope must be inside a loop";
+ CHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node;
PrimExpr one = make_const(e.loop->loop_var.dtype(), 1);
PrimExpr two = make_const(e.loop->loop_var.dtype(), 2);
PrimExpr loop_shift = e.loop->loop_var + one;
- e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
- e.loop->loop_var.dtype());
+ e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.loop->loop_var.dtype());
e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true;
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const VarNode*, StorageEntry> dbuffer_info_;
};
-
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
-
namespace transform {
Pass InjectDoubleBuffer(int split_loop) {
return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer")
-.set_body_typed(InjectDoubleBuffer);
+TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer);
} // namespace transform
* \file inject_prefetch.cc
*/
// Inject prefetch op in HalideIR
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/bound.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <tvm/arith/bound.h>
-#include <tvm/arith/analyzer.h>
+
#include <unordered_set>
namespace tvm {
namespace tir {
-using arith::IntSet;
using arith::DomainTouched;
+using arith::IntSet;
class PrefetchInjector : public StmtMutator {
public:
}
Stmt VisitStmt_(const ForNode* op) final {
- auto &var = op->loop_var;
+ auto& var = op->loop_var;
loop_nest_.push_back(var);
if (op->for_type == ForType::Vectorized) {
vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1);
private:
std::vector<Var> loop_nest_;
- std::unordered_map<const VarNode *, IntSet> vectorized_;
+ std::unordered_map<const VarNode*, IntSet> vectorized_;
static const Range none;
};
const Range PrefetchInjector::none;
-Stmt InjectPrefetch(Stmt stmt) {
- return PrefetchInjector()(std::move(stmt));
-}
-
+Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); }
namespace transform {
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch")
-.set_body_typed(InjectPrefetch);
+TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch").set_body_typed(InjectPrefetch);
} // namespace transform
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+
#include <unordered_set>
-#include "ir_util.h"
+
#include "../../arith/compute_expr.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
// If expression is touched by var.
class ExprTouched final : public StmtExprVisitor {
public:
- explicit ExprTouched(const std::unordered_set<const VarNode*> &touched,
- bool check_write)
+ explicit ExprTouched(const std::unordered_set<const VarNode*>& touched, bool check_write)
: touched_var_(touched), check_write_(check_write) {}
void VisitExpr(const PrimExpr& n) final {
if (expr_touched_ && !check_write_) return;
StmtExprVisitor::VisitExpr(n);
}
- void VisitStmt(const Stmt& n) final {
+ void VisitStmt(const Stmt& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
StmtExprVisitor::VisitStmt(n);
}
- void VisitExpr_(const LoadNode *op) final {
+ void VisitExpr_(const LoadNode* op) final {
HandleUseVar(op->buffer_var.get());
StmtExprVisitor::VisitExpr_(op);
}
- void VisitExpr_(const VarNode *op) final {
- HandleUseVar(op);
- }
- void VisitExpr_(const CallNode *op) final {
+ void VisitExpr_(const VarNode* op) final { HandleUseVar(op); }
+ void VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
const auto* rw_mask = op->args[4].as<IntImmNode>();
const VarNode* buffer_var = op->args[1].as<VarNode>();
used_vars_.push_back(var);
}
}
- void HandleWriteVar(const VarNode* var) {
- write_vars_.push_back(var);
- }
+ void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); }
// the fields.
bool expr_touched_{false};
std::vector<const VarNode*> used_vars_;
Record(op->buffer_var.get(), tc);
this->VisitStmt(op->body);
}
- void Record(const VarNode* var,
- const ExprTouched& tc) {
+ void Record(const VarNode* var, const ExprTouched& tc) {
if (touched_var_.count(var)) return;
if (tc.expr_touched_) {
touched_var_.insert(var);
}
}
- std::unordered_set<const VarNode*>
- TouchedVar(const Stmt& stmt,
- const VarNode* var) {
+ std::unordered_set<const VarNode*> TouchedVar(const Stmt& stmt, const VarNode* var) {
touched_var_.insert(var);
this->VisitStmt(stmt);
// do a DFS to push affect around dependency.
- std::vector<const VarNode*> pending(
- touched_var_.begin(), touched_var_.end());
+ std::vector<const VarNode*> pending(touched_var_.begin(), touched_var_.end());
while (!pending.empty()) {
const VarNode* v = pending.back();
pending.pop_back();
// Whether variable is touched by the thread variable.
std::unordered_set<const VarNode*> touched_var_;
// x -> all the buffers x read from
- std::unordered_map<const VarNode*,
- std::vector<const VarNode*> > affect_;
+ std::unordered_map<const VarNode*, std::vector<const VarNode*> > affect_;
};
-
// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
class VTInjector : public StmtExprMutator {
public:
// constructor
- VTInjector(Var var,
- int num_threads,
- const std::unordered_set<const VarNode*>& touched_var,
+ VTInjector(Var var, int num_threads, const std::unordered_set<const VarNode*>& touched_var,
bool allow_share)
- : var_(var), num_threads_(num_threads),
- touched_var_(touched_var), allow_share_(allow_share) {
- }
+ : var_(var),
+ num_threads_(num_threads),
+ touched_var_(touched_var),
+ allow_share_(allow_share) {}
// Inject VTLoop when needed.
Stmt VisitStmt(const Stmt& s) final {
CHECK(!visit_touched_var_);
auto stmt = StmtExprMutator::VisitStmt(s);
if (visit_touched_var_ || trigger_base_inject_) {
- if (!vt_loop_injected_) {
+ if (!vt_loop_injected_) {
return InjectVTLoop(stmt, false);
}
visit_touched_var_ = false;
}
// Variable
PrimExpr VisitExpr_(const VarNode* op) final {
- CHECK(!alloc_remap_.count(op))
- << "Buffer address may get rewritten in virtual thread";
+ CHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread";
if (touched_var_.count(op)) {
visit_touched_var_ = true;
}
}
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
- return LoadNode::make(op->dtype, op->buffer_var,
- RewriteIndex(op->index, it->second),
- op->predicate);
+ return LoadNode::make(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second),
+ op->predicate);
} else {
return expr;
}
visit_touched_var_ = true;
PrimExpr offset = this->VisitExpr(op->args[2]);
PrimExpr extent = this->VisitExpr(op->args[3]);
- PrimExpr stride =
- it->second / make_const(offset.dtype(), dtype.lanes());
+ PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
offset = stride * var_ + offset;
- return CallNode::make(
- op->dtype, op->name,
- {op->args[0], op->args[1], offset, extent, op->args[4]},
- op->call_type);
+ return CallNode::make(op->dtype, op->name,
+ {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return allow_share_ ? GetRef<PrimExpr>(op) : var_;
} else {
trigger_base_inject_ = !allow_share_;
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
- return StoreNode::make(op->buffer_var,
- op->value,
- RewriteIndex(op->index, it->second),
- op->predicate);
+ return StoreNode::make(op->buffer_var, op->value, RewriteIndex(op->index, it->second),
+ op->predicate);
} else {
return stmt;
}
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
} else if (!allow_share_ && !vt_loop_injected_ &&
- (op->attr_key == attr::coproc_uop_scope ||
- op->attr_key == attr::coproc_scope)) {
+ (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) {
return InjectVTLoop(GetRef<Stmt>(op), true);
} else {
Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return AttrStmtNode::make(op->node, op->attr_key, value, body);
}
visit_touched_var_ = false;
Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return LetStmtNode::make(op->var, value, body);
visit_touched_var_ = false;
Stmt body = this->VisitStmt(op->body);
++max_loop_depth_;
- if (extent.same_as(op->extent) &&
- body.same_as(op->body)) {
+ if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return ForNode::make(
- op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+ return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
}
}
// IfThenElse
else_case = this->VisitStmt(op->else_case);
max_loop_depth_ = std::max(temp, max_loop_depth_);
}
- if (condition.same_as(op->condition) &&
- then_case.same_as(op->then_case) &&
+ if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
// always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
- PrimExpr stride = arith::ComputeReduce<MulNode>(
- op->extents, PrimExpr()) * op->dtype.lanes();
+ PrimExpr stride = arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
Array<PrimExpr> other;
other.push_back(make_const(op->extents[0].dtype(), num_threads_));
for (PrimExpr e : extents) {
// Mutate the body.
body = this->VisitStmt(op->body);
}
- if (!changed &&
- body.same_as(op->body) &&
- condition.same_as(op->condition)) {
+ if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) {
return GetRef<Stmt>(op);
} else {
- return AllocateNode::make(
- op->buffer_var, op->dtype,
- extents, condition, body);
+ return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body);
}
}
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return ForNode::make(idx, make_zero(idx.dtype()),
- make_const(idx.dtype(), num_threads_),
- ForType::Serial, DeviceAPI::None, stmt);
+ return ForNode::make(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
+ ForType::Serial, DeviceAPI::None, stmt);
}
}
std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
};
-
class VirtualThreadInjector : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* op) final {
return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread")
-.set_body_typed(InjectVirtualThread);
+TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread);
} // namespace transform
* \file ir_util.cc
* \brief Helper functions to construct and compose IR nodes.
*/
+#include "ir_util.h"
+
#include <tvm/tir/stmt_functor.h>
-#include <utility>
-#include <unordered_set>
+
#include <unordered_map>
-#include "ir_util.h"
+#include <unordered_set>
+#include <utility>
namespace tvm {
namespace tir {
return body;
}
-
class IRConvertSSA final : public StmtExprMutator {
public:
PrimExpr VisitExpr_(const VarNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
if (scope_.count(op->buffer_var.get())) {
- return LoadNode::make(
- op->dtype, scope_[op->buffer_var.get()].back(),
- op->index, op->predicate);
+ return LoadNode::make(op->dtype, scope_[op->buffer_var.get()].back(), op->index,
+ op->predicate);
} else {
return expr;
}
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<StoreNode>();
if (scope_.count(op->buffer_var.get())) {
- return StoreNode::make(
- scope_[op->buffer_var.get()].back(), op->value,
- op->index, op->predicate);
+ return StoreNode::make(scope_[op->buffer_var.get()].back(), op->value, op->index,
+ op->predicate);
} else {
return stmt;
}
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<ForNode>();
- return ForNode::make(
- new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
+ return ForNode::make(new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<AllocateNode>();
- return AllocateNode::make(
- new_var, op->dtype, op->extents, op->condition,
- op->body);
+ return AllocateNode::make(new_var, op->dtype, op->extents, op->condition, op->body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
alloc = new_alloc.as<AllocateNode>();
CHECK(alloc);
- return AttrStmtNode::make(
- alloc->buffer_var, op->attr_key, op->value, new_alloc);
+ return AttrStmtNode::make(alloc->buffer_var, op->attr_key, op->value, new_alloc);
}
}
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
if (scope_.count(v) && scope_[v].size() != 0) {
- return AttrStmtNode::make(
- scope_[v].back(), op->attr_key, op->value, op->body);
+ return AttrStmtNode::make(scope_[v].back(), op->attr_key, op->value, op->body);
} else {
return stmt;
}
}
private:
- std::unordered_map<const VarNode*, std::vector<Var> > scope_;
+ std::unordered_map<const VarNode*, std::vector<Var>> scope_;
std::unordered_set<const VarNode*> defined_;
};
-Stmt ConvertSSA(Stmt stmt) {
- return IRConvertSSA()(std::move(stmt));
-}
+Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); }
} // namespace tir
} // namespace tvm
#ifndef TVM_TIR_TRANSFORMS_IR_UTIL_H_
#define TVM_TIR_TRANSFORMS_IR_UTIL_H_
+#include <tvm/runtime/device_api.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <tvm/runtime/device_api.h>
+
#include <vector>
namespace tvm {
* \return if update happens, return the new array, else return the
* original array
*/
-template<typename T, typename F>
+template <typename T, typename F>
inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
std::vector<T> new_arr(arr.size());
bool changed = false;
* \param kind The data kind.
* \return the get expression.
*/
-inline PrimExpr TVMStructGet(
- DataType dtype, Var handle, int index,
- intrinsic::TVMStructFieldKind kind) {
- Array<PrimExpr> args ={
- handle,
- make_const(DataType::Int(32), index),
- make_const(DataType::Int(32), static_cast<int>(kind))};
+inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index,
+ intrinsic::TVMStructFieldKind kind) {
+ Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
+ make_const(DataType::Int(32), static_cast<int>(kind))};
return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic);
}
return CallNode::make(
DataType::Handle(), intrinsic::tvm_address_of,
{LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
- const_true(dtype.lanes()))},
+ const_true(dtype.lanes()))},
CallNode::PureIntrinsic);
}
offset = offset * make_const(offset.dtype(), dtype.lanes());
offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
}
- return CallNode::make(
- DataType::Handle(), intrinsic::tvm_address_of,
- {LoadNode::make(dtype, handle, offset,
- const_true(dtype.lanes()))},
- CallNode::PureIntrinsic);
+ return CallNode::make(DataType::Handle(), intrinsic::tvm_address_of,
+ {LoadNode::make(dtype, handle, offset, const_true(dtype.lanes()))},
+ CallNode::PureIntrinsic);
}
/*!
* \param value The value to be set.
* \return the set stmt.
*/
-inline Stmt TVMStructSet(
- Var handle, int index,
- intrinsic::TVMStructFieldKind kind, PrimExpr value) {
- Array<PrimExpr> args ={
- handle,
- make_const(DataType::Int(32), index),
- make_const(DataType::Int(32), static_cast<int>(kind)),
- value};
+inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind,
+ PrimExpr value) {
+ Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
+ make_const(DataType::Int(32), static_cast<int>(kind)), value};
return EvaluateNode::make(
CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
}
*/
inline DataType APIType(DataType t) {
if (t.is_handle()) return t;
- CHECK_EQ(t.lanes(), 1)
- << "Cannot pass vector type through packed API.";
+ CHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API.";
if (t.is_uint() || t.is_int()) return DataType::Int(64);
CHECK(t.is_float());
return DataType::Float(64);
return align;
}
-
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \file lift_attr_scope.cc
*/
#include <tvm/runtime/registry.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
#include "ir_util.h"
namespace tvm {
// to a few specified attr keys
class AttrScopeLifter : public StmtMutator {
public:
- explicit AttrScopeLifter(std::string attr_key)
- : attr_key_(attr_key) {}
+ explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) {
stmt = operator()(std::move(stmt));
if (attr_node_.defined()) {
- stmt = AttrStmtNode::make(
- attr_node_, attr_key_, attr_value_, stmt);
+ stmt = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, stmt);
}
return stmt;
}
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
if (attr_node_.defined()) {
- Stmt body = AttrStmtNode::make(
- attr_node_, attr_key_, attr_value_, op->body);
+ Stmt body = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
- return AllocateNode::make(
- op->buffer_var, op->dtype,
- op->extents, op->condition, body);
+ return AllocateNode::make(op->buffer_var, op->dtype, op->extents, op->condition, body);
} else {
return stmt;
}
// check if all decorations are common.
for (size_t begin = 0; begin < attr_node.size();) {
size_t end = begin + 1;
- while (end < attr_node.size() &&
- attr_node[end].same_as(attr_node[begin]) &&
+ while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) &&
ValueSame(attr_value[end], attr_value[begin])) {
++end;
}
}
Stmt stmt = SeqStmt::Flatten(seq);
if (attr_node[begin].defined()) {
- stmt = AttrStmtNode::make(
- attr_node[begin], attr_key_, attr_value[begin], stmt);
+ stmt = AttrStmtNode::make(attr_node[begin], attr_key_, attr_value[begin], stmt);
}
reorg.push_back(stmt);
begin = end;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->VisitStmt(op->else_case);
- if (attr_node_.defined() &&
- attr_value_.defined() &&
- first_node.defined() &&
- first_value.defined() &&
- attr_node_.same_as(first_node) &&
+ if (attr_node_.defined() && attr_value_.defined() && first_node.defined() &&
+ first_value.defined() && attr_node_.same_as(first_node) &&
ValueSame(attr_value_, first_value)) {
- if (then_case.same_as(op->then_case) &&
- else_case.same_as(op->else_case)) {
+ if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElseNode::make(op->condition, then_case, else_case);
}
} else {
if (first_node.defined()) {
- then_case = AttrStmtNode::make(
- first_node, attr_key_, first_value, then_case);
+ then_case = AttrStmtNode::make(first_node, attr_key_, first_value, then_case);
}
if (attr_node_.defined()) {
- else_case = AttrStmtNode::make(
- attr_node_, attr_key_, attr_value_, else_case);
+ else_case = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
}
- if (then_case.same_as(op->then_case) &&
- else_case.same_as(op->else_case)) {
+ if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElseNode::make(op->condition, then_case, else_case);
return AttrScopeLifter(attr_key).Lift(std::move(stmt));
}
-
namespace transform {
Pass LiftAttrScope(std::string attr_key) {
return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope")
-.set_body_typed(LiftAttrScope);
+TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope);
} // namespace transform
/*!
* \file loop_partition.cc
*/
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/bound.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/arith/bound.h>
+#include <tvm/tir/transform.h>
+
#include <unordered_map>
#include <unordered_set>
-#include "ir_util.h"
+
#include "../../arith/interval_set.h"
#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
-using arith::IntSet;
using arith::DeduceBound;
using arith::Intersect;
+using arith::IntSet;
using PartitionKey = std::pair<const Object*, bool>;
struct PartitionKeyHash {
class CandidateSelector final : public StmtExprVisitor {
public:
using VarIsUsed = bool;
- explicit CandidateSelector(bool split_const_loop)
- : split_const_loop_(split_const_loop) {}
+ explicit CandidateSelector(bool split_const_loop) : split_const_loop_(split_const_loop) {}
void VisitStmt_(const ForNode* op) final {
// partition const loop when sets split_const_loop_
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
- const IterVarNode *iv = op->node.as<IterVarNode>();
+ const IterVarNode* iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
class PartitionFinder : public StmtExprVisitor {
public:
explicit PartitionFinder(Var current_var,
- const std::unordered_map<const VarNode*, IntSet>& hint_map,
- const std::unordered_map<const VarNode*, IntSet>& relax_map)
- : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
- for (const auto& kv : hint_map) {
- out_vars_.insert(kv.first);
- }
- for (const auto& kv : relax_map) {
- out_vars_.insert(kv.first);
- }
- }
+ const std::unordered_map<const VarNode*, IntSet>& hint_map,
+ const std::unordered_map<const VarNode*, IntSet>& relax_map)
+ : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
+ for (const auto& kv : hint_map) {
+ out_vars_.insert(kv.first);
+ }
+ for (const auto& kv : relax_map) {
+ out_vars_.insert(kv.first);
+ }
+ }
void VisitStmt_(const ForNode* op) final {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
void VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(CallNode::likely)) {
PrimExpr cond = op->args[0];
- if (ExprUseVars(cond,
- std::unordered_set<const VarNode*>({current_var_.get()}))) {
+ if (ExprUseVars(cond, std::unordered_set<const VarNode*>({current_var_.get()}))) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
- IntSet interval =
- DeduceBound(current_var_, cond, hint_map_, relax_map_);
+ IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.is_nothing()) {
// cond is true within interval
partitions[{cond.get(), true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
- IntSet interval =
- DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
+ IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.is_nothing()) {
// cond is false within interval
partitions[{cond.get(), false}] = interval;
class ConditionEliminator : public StmtExprMutator {
public:
explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
- : ps_(ps), cond_value_(cond_value) {}
+ : ps_(ps), cond_value_(cond_value) {}
PrimExpr VisitExpr(const PrimExpr& e) final {
if (ps_.find(e.get()) != ps_.end()) {
bool cond_value_;
};
-
// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public StmtMutator {
public:
- explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps,
- PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
+ explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps, PrimExpr cond)
+ : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
// likely conditions
class LoopPartitioner : public StmtMutator {
public:
- explicit LoopPartitioner(bool split_const_loop)
- : selector(CandidateSelector(split_const_loop)) {}
+ explicit LoopPartitioner(bool split_const_loop) : selector(CandidateSelector(split_const_loop)) {}
Stmt VisitAndMutate(Stmt stmt) {
selector(stmt);
Stmt VisitStmt_(const ForNode* op) final {
if (selector.candidates.count(op)) {
- Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var,
- op->min, op->min + op->extent - 1, op->body, false);
+ Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var, op->min, op->min + op->extent - 1,
+ op->body, false);
if (s.defined()) return s;
}
// normal path when loop partition fails
// normal loop variable can be put into hint map.
- hint_map_.insert({op->loop_var.get(),
- IntSet::interval(op->min, op->min + op->extent - 1)});
+ hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)});
Stmt res = StmtMutator::VisitStmt_(op);
hint_map_.erase(op->loop_var.get());
return res;
return StmtMutator::VisitStmt_(op);
}
- const IterVarNode *iv = op->node.as<IterVarNode>();
+ const IterVarNode* iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
if (selector.candidates.count(op)) {
Stmt res;
if (scope.rank == 1) {
// threadIdx should be put into relax map, in case of divergence.
- relax_map_.insert({var.get(),
- IntSet::interval(make_zero(var.dtype()), op->value - 1)});
+ relax_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)});
res = StmtMutator::VisitStmt_(op);
relax_map_.erase(var.get());
} else {
- hint_map_.insert({var.get(),
- IntSet::interval(make_zero(var.dtype()), op->value - 1)});
+ hint_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)});
res = StmtMutator::VisitStmt_(op);
hint_map_.erase(var.get());
}
}
private:
- Stmt TryPartition(const Object* op, const Stmt& stmt, Var var,
- PrimExpr min, PrimExpr max, Stmt body, bool partition_thread_scope);
+ Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max,
+ Stmt body, bool partition_thread_scope);
- std::pair<IntSet, std::unordered_set<const Object*>>
- GetIntervalAndCondset(const Partition &partitions,
- const arith::IntervalSet &for_interval,
- bool cond_value);
+ std::pair<IntSet, std::unordered_set<const Object*>> GetIntervalAndCondset(
+ const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value);
inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);
// Returns an interval (in the first component) in which all the conditions
// given in the second component provably have value given by cond_value
-std::pair<IntSet, std::unordered_set<const Object*>>
-LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
- const arith::IntervalSet &for_interval,
- bool cond_value) {
+std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetIntervalAndCondset(
+ const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) {
Array<IntSet> sets;
std::unordered_set<const Object*> cond_set;
- for (const auto &kv : partitions) {
+ for (const auto& kv : partitions) {
if (kv.first.second == cond_value) {
arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
- arith::IntervalSet intersection = arith::Intersect(
- &analyzer_, interval, for_interval);
+ arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval);
if (!intersection->IsEmpty()) {
sets.push_back(kv.second);
cond_set.insert(kv.first.first);
* which will eventually be simplified to empty code. And because only one loop was generated
* from loop 2 we stop recursing.
*/
-Stmt LoopPartitioner::TryPartition(const Object* node,
- const Stmt& stmt,
- Var var,
- PrimExpr min,
- PrimExpr max,
- Stmt body,
- bool partition_thread_scope) {
+Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min,
+ PrimExpr max, Stmt body, bool partition_thread_scope) {
using namespace arith;
// include hint of var.
hint_map_.insert({var.get(), IntSet::interval(min, max)});
std::unordered_set<const Object*> cond_set;
// find an interval in which all conditions on var are true
std::tie(middle_interval, cond_set) =
- GetIntervalAndCondset(finder.partitions, for_interval, true);
+ GetIntervalAndCondset(finder.partitions, for_interval, true);
if (middle_interval.is_nothing()) {
// if such interval doesn't exist, find an interval in which all
// conditions on var are false
if (!analyzer_.CanProve(body_begin == min)) {
PrimExpr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
- LOG(WARNING) << "Cannot prove: " << cond
- << ", when generating the pre doubt loop";
+ LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop";
body_begin = MaxNode::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
// require the extent to be non-negative
PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
- LOG(WARNING) << "Cannot prove: " << cond
- << ", when generating the post doubt loop";
- post_doubt_begin = MinNode::make(post_doubt_begin, max+1);
+ LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop";
+ post_doubt_begin = MinNode::make(post_doubt_begin, max + 1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
- Stmt post_body =
- Substitute(body, {{Var{var}, var + post_doubt_begin}});
+ Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
}
return s;
}
-inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt body) {
- const ForNode *for_node = static_cast<const ForNode*>(node);
+inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) {
+ const ForNode* for_node = static_cast<const ForNode*>(node);
CHECK(for_node);
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
// If the loop extent is 1, do not create the loop anymore
class RemoveLikelyTags : public StmtExprMutator {
public:
- PrimExpr VisitExpr_(const CallNode *op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(CallNode::likely)) {
CHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
return stmt;
}
-
namespace transform {
Pass LoopPartition(bool split_const_loop) {
return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.LoopPartition")
-.set_body_typed(LoopPartition);
+TVM_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition);
} // namespace transform
* \brief Pass for lowering custom datatypes
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <tvm/target/target.h>
-#include <tvm/runtime/registry.h>
+
#include "../../target/datatype/registry.h"
namespace tvm {
if (toBeLowered) {
auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
- return AllocateNode::make(
- allocate->buffer_var, new_allocate_type, allocate->extents,
- allocate->condition, allocate->body);
+ return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents,
+ allocate->condition, allocate->body);
}
return stmt;
}
return expr;
}
-#define DEFINE_MUTATE__(OP, NodeName) \
- inline PrimExpr VisitExpr_(const NodeName* op) final { \
- auto type_code = op->dtype.code(); \
+#define DEFINE_MUTATE__(OP, NodeName) \
+ inline PrimExpr VisitExpr_(const NodeName* op) final { \
+ auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
- PrimExpr expr = StmtExprMutator::VisitExpr_(op); \
- op = expr.as<NodeName>(); \
- if (toBeLowered) { \
- auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
- CHECK(lower) << #OP " lowering function for target " << target_ << " type " \
- << static_cast<unsigned>(type_code) << " not found"; \
- return (*lower)(expr); \
- } \
- return expr; \
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op); \
+ op = expr.as<NodeName>(); \
+ if (toBeLowered) { \
+ auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
+ CHECK(lower) << #OP " lowering function for target " << target_ << " type " \
+ << static_cast<unsigned>(type_code) << " not found"; \
+ return (*lower)(expr); \
+ } \
+ return expr; \
}
DEFINE_MUTATE__(Add, AddNode);
std::string target_;
};
-
namespace transform {
Pass LowerCustomDatatypes() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
- CHECK(target.defined())
- << "LowerCustomDatatypes: Require the target attribute";
+ CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute";
n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body));
return f;
return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes")
-.set_body_typed(LowerCustomDatatypes);
+TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes);
} // namespace transform
* \file lower_device_storage_access.cc
* \brief Lower the special device storage access.
*/
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
-#include <tvm/tir/buffer.h>
#include <tvm/arith/analyzer.h>
-#include <tvm/target/target_info.h>
#include <tvm/runtime/registry.h>
-#include "ir_util.h"
+#include <tvm/target/target_info.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
-using runtime::StorageScope;
using runtime::StorageRank;
+using runtime::StorageScope;
class StorageAccessInfoLower : public StmtExprMutator {
public:
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
- return LetStmtNode::make(
- op->buffer_var, info->head_address, op->body);
+ return LetStmtNode::make(op->buffer_var, info->head_address, op->body);
} else {
return op->body;
}
PrimExpr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.info.defined()) {
- return MakeTaggedAccessPtr(
- op->dtype, buffer_var, dtype, offset,
- it->second.info);
+ return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info);
}
CHECK(op->dtype.is_handle());
// Change to address_of
return AddressOffset(buffer_var, dtype, offset);
}
- PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
- Var buffer_var,
- DataType dtype,
- PrimExpr offset,
+ PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, PrimExpr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
- CHECK(info->head_address.defined())
- << buffer_var << " is not adddressable.";
+ CHECK(info->head_address.defined()) << buffer_var << " is not adddressable.";
return AddressOffset(buffer_var, dtype, offset);
}
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(info->unit_bits % dtype_bits, 0);
- return cast(ptr_type,
- analyzer_.Simplify(offset / make_const(
- offset.dtype(), info->unit_bits / dtype_bits)));
+ return cast(ptr_type, analyzer_.Simplify(
+ offset / make_const(offset.dtype(), info->unit_bits / dtype_bits)));
}
// The storage entry.
struct StorageEntry {
arith::Analyzer analyzer_;
};
-Stmt LowerStorageAccessInfo(Stmt stmt) {
- return StorageAccessInfoLower()(std::move(stmt));
-}
+Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); }
namespace transform {
n->body = StorageAccessInfoLower()(std::move(n->body));
return f;
};
- return CreatePrimFuncPass(
- pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
+ return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo")
-.set_body_typed(LowerDeviceStorageAccessInfo);
+ .set_body_typed(LowerDeviceStorageAccessInfo);
} // namespace transform
} // namespace tir
* Lower intrinsic calls and ops to device specific ir when possible.
* \file lower_intrin.cc
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/op.h>
-#include <tvm/target/target.h>
#include <unordered_set>
-#include "../../arith/pattern_match.h"
+
#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/pattern_match.h"
namespace tvm {
namespace tir {
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
- using IRMutatorWithAnalyzer::VisitStmt_;
using IRMutatorWithAnalyzer::VisitExpr_;
+ using IRMutatorWithAnalyzer::VisitStmt_;
IntrinInjecter(arith::Analyzer* analyzer, std::string target_name)
: IRMutatorWithAnalyzer(analyzer) {
}
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->call_type == CallNode::Intrinsic ||
- op->call_type == CallNode::PureIntrinsic) {
+ if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
PrimExpr r = ApplyPattern(op->name, GetRef<PrimExpr>(op));
if (r.defined()) return r;
}
const DataType& dtype = op->dtype;
CHECK(dtype.is_int() || dtype.is_uint());
- if (support_bitwise_op_ &&
- is_const_power_of_two_integer(op->b, &shift)) {
+ if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to right shift if possible.
return op->a >> make_const(dtype, shift);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common path, positive divisor
- if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
- analyzer_->CanProveGreaterEqual(e, 0)) {
+ if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
} else {
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
- return tir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
+ return tir::SelectNode::make(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
}
}
} else {
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
- return tir::SelectNode::make(
- (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
- rdiv, rdiv - make_const(dtype, 1));
+ return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
+ rdiv - make_const(dtype, 1));
}
}
const DataType& dtype = op->dtype;
CHECK(dtype.is_int() || dtype.is_uint());
- if (support_bitwise_op_ &&
- is_const_power_of_two_integer(op->b, &shift)) {
+ if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to masking if possible.
- int64_t mask = (
- static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
+ int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
return op->a & make_const(dtype, mask);
}
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
- return tir::SelectNode::make(
- (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
- rmod, rmod + op->b);
+ return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod,
+ rmod + op->b);
}
}
PVar<PrimExpr> x, y;
PVar<IntImm> c;
auto e = GetRef<PrimExpr>(op);
- if (max(floordiv(x, y), c).Match(e) &&
- c.Eval()->value >= 0 &&
+ if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
}
return e;
}
- PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c,
- const AddNode* op) {
+ PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) {
// emit fma instruction: a * b + c
PrimExpr lhs = SwapBroadcastCast(a);
PrimExpr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
- PrimExpr r = (*fma_)(CallNode::make(
- op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
+ PrimExpr r =
+ (*fma_)(CallNode::make(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
- CHECK(target.defined())
- << "LowerIntrin: Require the target attribute";
+ CHECK(target.defined()) << "LowerIntrin: Require the target attribute";
arith::Analyzer analyzer;
- n->body =
- IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body));
+ n->body = IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin")
-.set_body_typed(LowerIntrin);
+TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin);
} // namespace transform
* Lower allreduce to device implementable ir.
* \file lower_thread_allreduce.cc
*/
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/target/target.h>
-#include <tvm/runtime/registry.h>
#include <unordered_set>
-#include "ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
explicit ThreadAllreduceBuilder(const TargetNode* target)
- : target_(target), warp_size_(target->thread_warp_size) {}
+ : target_(target), warp_size_(target->thread_warp_size) {}
- Stmt VisitStmt_(const AttrStmtNode *op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
return ret;
}
} else if (op->attr_key == attr::reduce_scope) {
- const CommReducerNode *combiner = op->node.as<CommReducerNode>();
+ const CommReducerNode* combiner = op->node.as<CommReducerNode>();
CHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = StmtExprMutator::VisitStmt_(op);
if (it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
if (warp_allocs_.count(repl)) {
- stmt = AllocateNode::make(repl->buffer_var, repl->dtype,
- repl->extents, repl->condition, op->body);
+ stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition,
+ op->body);
stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope,
- StringImmNode::make("local"), stmt);
+ StringImmNode::make("local"), stmt);
} else {
// use volatile access to shared buffer.
- stmt = AttrStmtNode::make(
- repl->buffer_var, attr::volatile_scope, 1, op->body);
- stmt = AllocateNode::make(
- repl->buffer_var, repl->dtype,
- repl->extents, repl->condition, stmt);
- stmt = AttrStmtNode::make(
- repl->buffer_var, attr::storage_scope,
- StringImmNode::make("shared"), stmt);
+ stmt = AttrStmtNode::make(repl->buffer_var, attr::volatile_scope, 1, op->body);
+ stmt =
+ AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
+ stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope,
+ StringImmNode::make("shared"), stmt);
}
return stmt;
} else {
// make allreduce.
Stmt MakeAllreduce(const CallNode* call) {
CHECK(!reduce_combiner_.empty());
- const CommReducerNode *combiner = reduce_combiner_.back();
+ const CommReducerNode* combiner = reduce_combiner_.back();
size_t size = combiner->result.size();
- const IntImmNode *size_of_args = call->args[0].as<IntImmNode>();
+ const IntImmNode* size_of_args = call->args[0].as<IntImmNode>();
CHECK(size_of_args) << call->args[0]->GetTypeKey();
CHECK_EQ(size, size_of_args->value);
Array<PrimExpr> inits = combiner->identity_element;
std::vector<PrimExpr> values(size);
std::vector<DataType> types(size);
- PrimExpr cond = call->args[size+1];
+ PrimExpr cond = call->args[size + 1];
for (size_t idx = 0; idx < size; ++idx) {
- values[idx] = call->args[1+idx];
+ values[idx] = call->args[1 + idx];
if (!is_one(cond)) {
values[idx] = SelectNode::make(cond, values[idx], inits[idx]);
}
}
std::vector<const VarNode*> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
- const VarNode* buffer = call->args[2+size+idx].as<VarNode>();
+ const VarNode* buffer = call->args[2 + size + idx].as<VarNode>();
CHECK(buffer);
buffers[idx] = buffer;
}
e.scope = runtime::ThreadScope::make(iv->thread_tag);
e.iv = iv;
CHECK_LE(e.scope.rank, 1);
- CHECK_GE(e.scope.dim_index, 0)
- << "vthread do not work with cross thread reduction";
+ CHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
const auto* ptr = attr->value.as<IntImmNode>();
- CHECK(ptr)
- << "Need constant extent for reduce set " << iv;
+ CHECK(ptr) << "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
}
}
}
- CHECK_EQ(nmatch, reduce_set.size())
- << "Not all reduce index are presented in the context";
+ CHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce index are presented in the context";
std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end());
// the size of each index.
PrimExpr index(0);
for (size_t idx = 0; idx < size; ++idx) {
- shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
+ shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
PrimExpr pred = const_true(types[idx].lanes());
seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred));
// Uses a local variable to store the shuffled data.
// Later on, this allocation will be properly attached to this statement.
Var var("t" + std::to_string(idx), types[idx]);
- Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred,
- EvaluateNode::make(0));
+ Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, EvaluateNode::make(0));
local_vars.push_back(s);
}
Var mask_var("mask", DataType::UInt(32));
{
PrimExpr pred = const_true(1);
- PrimExpr mask = CallNode::make(DataType::UInt(32),
- intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
+ PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {},
+ CallNode::Intrinsic);
seq.emplace_back(StoreNode::make(mask_var, mask, index, pred));
// Push allocation with an empty body. Later this will be fixed
// when the entire body is ready.
- auto stmt = AllocateNode::make(mask_var, mask_var->dtype,
- {PrimExpr(1)}, pred, EvaluateNode::make(0));
+ auto stmt = AllocateNode::make(mask_var, mask_var->dtype, {PrimExpr(1)}, pred,
+ EvaluateNode::make(0));
local_vars.push_back(stmt);
}
Var var = shared_bufs[i];
load_remap_[buffers[i]] = LoadNode::make(types[i], var, index, pred);
Array<PrimExpr> extents{PrimExpr(1)};
- auto node = AllocateNode::make(var, types[i], extents, pred,
- EvaluateNode::make(0));
+ auto node = AllocateNode::make(var, types[i], extents, pred, EvaluateNode::make(0));
alloc_remap_[buffers[i]] = node;
warp_allocs_.insert(node.get());
}
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
PrimExpr pred = const_true(types[i].lanes());
- Var buffer_var = Downcast<Var>(call->args[2+size+i]);
+ Var buffer_var = Downcast<Var>(call->args[2 + size + i]);
stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
}
return SeqStmt::Flatten(stores);
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
- shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
+ shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
PrimExpr pred = const_true(types[idx].lanes());
- seq.emplace_back(StoreNode::make(
- shared_bufs[idx], values[idx],
- BufIndex(reduce_index, group_index, reduce_extent), pred));
+ seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx],
+ BufIndex(reduce_index, group_index, reduce_extent), pred));
}
seq.emplace_back(SyncThread("shared"));
- seq.emplace_back(MakeBufAllreduce(
- combiner, types, shared_bufs,
- reduce_index, group_index, reduce_extent, threadx_extent));
+ seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
+ reduce_extent, threadx_extent));
for (size_t idx = 0; idx < size; ++idx) {
CHECK(!load_remap_.count(buffers[idx]));
PrimExpr pred = const_true(types[idx].lanes());
load_remap_[buffers[idx]] = LoadNode::make(
- types[idx], shared_bufs[idx],
- BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
+ types[idx], shared_bufs[idx],
+ BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
alloc_remap_[buffers[idx]] = AllocateNode::make(
- shared_bufs[idx], types[idx],
- {PrimExpr(group_extent), PrimExpr(reduce_extent)},
- pred, EvaluateNode::make(0));
+ shared_bufs[idx], types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred,
+ EvaluateNode::make(0));
}
}
for (auto var : local_vars) {
const AllocateNode* repl = var.as<AllocateNode>();
if (repl) {
- body = AllocateNode::make(repl->buffer_var, repl->dtype,
- repl->extents, repl->condition, body);
+ body =
+ AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope,
- StringImmNode::make("local"), body);
+ StringImmNode::make("local"), body);
}
}
}
// make allreduce.
- Stmt MakeBufAllreduce(const CommReducerNode *combiner,
- const std::vector<DataType>& types,
- const Array<Var>& shared_bufs,
- PrimExpr reduce_index,
- PrimExpr group_index,
- int reduce_extent,
- int threadx_extent) {
+ Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
+ const Array<Var>& shared_bufs, PrimExpr reduce_index, PrimExpr group_index,
+ int reduce_extent, int threadx_extent) {
// Get next power of two
int reduce_align = 1;
while (reduce_extent > reduce_align) {
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
b.push_back(LoadNode::make(types[i], shared_bufs[i],
- BufIndex(reduce_index + offset, group_index, reduce_extent),
- const_true()));
+ BufIndex(reduce_index + offset, group_index, reduce_extent),
+ const_true()));
a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true()));
}
Array<PrimExpr> ret = (*combiner)(a, b);
}
CHECK(threadx_extent >= 1 && warp_size_ >= 1);
// normal synchronization
- while (reduce_align > threadx_extent ||
- reduce_align > warp_size_) {
- reduce_align = reduce_align >> 1;
+ while (reduce_align > threadx_extent || reduce_align > warp_size_) {
+ reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
// Flatten the thread index.
// Also return a warp number,
- PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec,
- int* out_total_extent) {
+ PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec, int* out_total_extent) {
int& total_extent = *out_total_extent;
total_extent = 1;
if (tvec.size() == 0) {
}
// sync thread op.
static Stmt SyncThread(const std::string& sync) {
- return EvaluateNode::make(
- CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImmNode::make(sync)},
- CallNode::Intrinsic));
+ return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+ {StringImmNode::make(sync)}, CallNode::Intrinsic));
}
// Emit warp shuffle intrinsic calls.
- PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val,
- int delta_or_lane) {
+ PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) {
PrimExpr pred = const_true(1);
PrimExpr index(0);
PrimExpr mask = LoadNode::make(DataType::UInt(32), mask_var, index, pred);
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
- Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane),
- width, width};
+ Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width};
return CallNode::make(val.dtype(), name, args, CallNode::Intrinsic);
}
e.extent = static_cast<int>(ptr->value);
}
- return e.extent == warp_size_ &&
- e.scope.dim_index == 0 &&
- e.scope.rank == 1;
+ return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank == 1;
}
// The target.
std::vector<const AttrStmtNode*> thread_extents_;
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap
- std::unordered_map<const VarNode *, PrimExpr> load_remap_;
+ std::unordered_map<const VarNode*, PrimExpr> load_remap_;
// Allocate remap
- std::unordered_map<const VarNode *, Stmt> alloc_remap_;
+ std::unordered_map<const VarNode*, Stmt> alloc_remap_;
// Allocate from warp reductions
- std::unordered_set<const void *> warp_allocs_;
+ std::unordered_set<const void*> warp_allocs_;
// Internal analyzer
arith::Analyzer analyzer_;
};
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
- CHECK(target.defined())
- << "LowerThreadAllreduce: Require the target attribute";
+ CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
const TargetNode* target_node = target.as<TargetNode>();
n->body = ThreadAllreduceBuilder(target_node)(n->body);
return f;
return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce")
-.set_body_typed(LowerThreadAllreduce);
+TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce);
} // namespace transform
} // namespace tir
* Lower TVM related builtin intrinsics such as packed call.
* \file tir/transforms/lower_tvm_buildin.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <tvm/runtime/registry.h>
#include <unordered_set>
inline PrimExpr StackAlloca(std::string type, size_t num) {
Array<PrimExpr> args = {StringImmNode::make(type), ConstInt32(num)};
- return CallNode::make(
- DataType::Handle(),
- intrinsic::tvm_stack_alloca,
- args, CallNode::Intrinsic);
+ return CallNode::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic);
}
// Calculate the statistics of packed function.
stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) {
- stmt = LetStmtNode::make(
- stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
+ stmt = LetStmtNode::make(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
}
if (max_array_stack_ != 0) {
- stmt = LetStmtNode::make(
- stack_array_, StackAlloca("array", max_array_stack_), stmt);
+ stmt = LetStmtNode::make(stack_array_, StackAlloca("array", max_array_stack_), stmt);
}
if (max_arg_stack_ != 0) {
- stmt = LetStmtNode::make(
- stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
- stmt = LetStmtNode::make(
- stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
+ stmt = LetStmtNode::make(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
+ stmt = LetStmtNode::make(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
}
return stmt;
}
}
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
- Stmt throw_last_error = EvaluateNode::make(
- CallNode::make(DataType::Int(32),
- intrinsic::tvm_throw_last_error, {},
- CallNode::Intrinsic));
+ Stmt throw_last_error = EvaluateNode::make(CallNode::make(
+ DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic));
- Stmt body = SeqStmt({
- IfThenElseNode::make(
- CallNode::make(DataType::Bool(1),
- intrinsic::tvm_handle_is_null,
- {op->buffer_var}, CallNode::PureIntrinsic),
- throw_last_error),
- op->body});
+ Stmt body = SeqStmt(
+ {IfThenElseNode::make(CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null,
+ {op->buffer_var}, CallNode::PureIntrinsic),
+ throw_last_error),
+ op->body});
Stmt alloca = LetStmtNode::make(
op->buffer_var,
- CallNode::make(op->buffer_var.dtype(),
- "TVMBackendAllocWorkspace",
- {cast(DataType::Int(32), device_type_),
- cast(DataType::Int(32), device_id_),
- cast(DataType::UInt(64), total_bytes),
- IntImm(DataType::Int(32), op->dtype.code()),
- IntImm(DataType::Int(32), op->dtype.bits())},
- CallNode::Extern),
+ CallNode::make(
+ op->buffer_var.dtype(), "TVMBackendAllocWorkspace",
+ {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
+ cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
+ IntImm(DataType::Int(32), op->dtype.bits())},
+ CallNode::Extern),
body);
- PrimExpr free_op = CallNode::make(DataType::Int(32),
- "TVMBackendFreeWorkspace",
- {cast(DataType::Int(32), device_type_),
- cast(DataType::Int(32), device_id_),
- op->buffer_var},
- CallNode::Extern);
- Stmt free_stmt = IfThenElseNode::make(
- free_op != make_zero(DataType::Int(32)), throw_last_error);
+ PrimExpr free_op = CallNode::make(DataType::Int(32), "TVMBackendFreeWorkspace",
+ {cast(DataType::Int(32), device_type_),
+ cast(DataType::Int(32), device_id_), op->buffer_var},
+ CallNode::Extern);
+ Stmt free_stmt =
+ IfThenElseNode::make(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});
- body = AttrStmtNode::make(
- op->buffer_var, attr::storage_alignment,
- make_const(DataType::Int(32), runtime::kTempAllocaAlignment),
- body);
+ body = AttrStmtNode::make(op->buffer_var, attr::storage_alignment,
+ make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
return body;
}
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 0; i < op->args.size(); ++i) {
- prep_seq_.emplace_back(
- StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
- ConstInt32(stack_begin +i), const_true(1)));
+ prep_seq_.emplace_back(StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
+ ConstInt32(stack_begin + i), const_true(1)));
}
return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
}
run_array_stack_ += 1;
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
PrimExpr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
strides = make_zero(DataType::Handle());
}
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
DataType dtype = op->args[4].dtype();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
- make_const(DataType::UInt(8), dtype.bits())));
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
- make_const(DataType::UInt(16), dtype.lanes())));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
+ make_const(DataType::UInt(8), dtype.bits())));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
+ make_const(DataType::UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
PrimExpr byte_offset = op->args[5];
if (!is_zero(byte_offset)) {
byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
}
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
- cast(DataType::UInt(64), byte_offset)));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
+ cast(DataType::UInt(64), byte_offset)));
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
- cast(DataType::Int(32), device_id_)));
- prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
- cast(DataType::Int(32), device_type_)));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
+ cast(DataType::Int(32), device_id_)));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
+ cast(DataType::Int(32), device_type_)));
return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packed.
if (t != api_type) {
arg = CastNode::make(api_type, arg);
}
- prep_seq_.emplace_back(TVMStructSet(
- stack_value_, static_cast<int>(arg_stack_begin + i - 1),
- intrinsic::kTVMValueContent, arg));
+ prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast<int>(arg_stack_begin + i - 1),
+ intrinsic::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kTVMStr;
}
if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
prep_seq_.emplace_back(
- StoreNode::make(stack_tcode_,
- ConstInt32(arg_tcode),
- stack_index, const_true(1)));
+ StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
run_shape_stack_ = restore_shape_stack;
run_array_stack_ = restore_array_stack;
run_arg_stack_ = arg_stack_begin;
- Array<PrimExpr> packed_args = {
- op->args[0],
- stack_value_,
- stack_tcode_,
- ConstInt32(arg_stack_begin),
- ConstInt32(arg_stack_begin + op->args.size() - 1)
- };
- return CallNode::make(
- DataType::Int(32), intrinsic::tvm_call_packed_lowered,
- packed_args, CallNode::Intrinsic);
+ Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
+ ConstInt32(arg_stack_begin),
+ ConstInt32(arg_stack_begin + op->args.size() - 1)};
+ return CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args,
+ CallNode::Intrinsic);
}
- PrimExpr MakeCallTracePacked(const CallNode *op) {
+ PrimExpr MakeCallTracePacked(const CallNode* op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
if (t != api_type) {
arg = CastNode::make(api_type, arg);
}
- prep_seq_.emplace_back(TVMStructSet(
- stack_value_, static_cast<int>(arg_stack_begin + i - 1),
- intrinsic::kTVMValueContent, arg));
+ prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast<int>(arg_stack_begin + i - 1),
+ intrinsic::kTVMValueContent, arg));
int arg_tcode = api_type.code();
CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq_.emplace_back(
- StoreNode::make(stack_tcode_,
- ConstInt32(arg_tcode),
- stack_index, const_true(1)));
+ StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
// Update the top of the stack, so we can use more than one
// packed function's arguments with the one stack.
run_arg_stack_ = arg_stack_begin + args_size - 1;
- Array<PrimExpr> packed_args = {
- op->args[0],
- stack_value_,
- stack_tcode_,
- ConstInt32(arg_stack_begin),
- ConstInt32(arg_stack_begin + op->args.size() - 1),
- // Pass traced value.
- op->args[args_size - 1]
- };
- return CallNode::make(
- op->dtype, intrinsic::tvm_call_trace_packed_lowered,
- packed_args, CallNode::Intrinsic);
+ Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
+ ConstInt32(arg_stack_begin),
+ ConstInt32(arg_stack_begin + op->args.size() - 1),
+ // Pass traced value.
+ op->args[args_size - 1]};
+ return CallNode::make(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args,
+ CallNode::Intrinsic);
}
private:
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin")
-.set_body_typed(LowerTVMBuiltin);
+TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin);
} // namespace transform
} // namespace tir
*/
// Thanks to Andrew Adams and Vinod Grover for
// explaining the concept of warp shuffle.
-#include <tvm/arith/pattern.h>
#include <tvm/arith/analyzer.h>
-
+#include <tvm/arith/pattern.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>
-#include <tvm/target/target.h>
-#include <tvm/runtime/registry.h>
#include <unordered_set>
-#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h"
+#include "../../arith/pattern_match.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
// store warp_mem[m * warp_index + (width * m) * y + x]
class WarpStoreCoeffFinder : private StmtVisitor {
public:
- WarpStoreCoeffFinder(const VarNode* buffer,
- Var warp_index,
- arith::Analyzer* analyzer)
- : buffer_(buffer),
- warp_index_(warp_index),
- analyzer_(analyzer) {
- }
+ WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer)
+ : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {}
// find the warp co-efficient in the statement given the warp size
int Find(const Stmt& stmt) {
this->VisitStmt(stmt);
private:
/// Visitor implementation
- void VisitStmt_(const StoreNode *op) final {
+ void VisitStmt_(const StoreNode* op) final {
if (op->buffer_var.get() == buffer_) {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
}
void UpdatePattern(const PrimExpr& index) {
- Array<PrimExpr> m =
- arith::DetectLinearEquation(index, {warp_index_});
- CHECK_EQ(m.size(), 2U)
- << "LowerWarpMemory failed due to store index=" << index;
+ Array<PrimExpr> m = arith::DetectLinearEquation(index, {warp_index_});
+ CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index;
PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
<< "LowerWarpMemory failed due to store index=" << index
- << ", require positive constant coefficient on warp index " << warp_index_
- << " but get " << mcoeff;
+ << ", require positive constant coefficient on warp index " << warp_index_ << " but get "
+ << mcoeff;
if (warp_coeff_ != 0) {
CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
arith::Analyzer* analyzer_;
};
-
// Visitor to find the warp index
class WarpIndexFinder : private StmtVisitor {
public:
- explicit WarpIndexFinder(int warp_size)
- : warp_size_(warp_size) {
- }
+ explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {}
// find the warp co-efficient and the shuffle width in the statement
std::pair<Var, int> Find(const Stmt& stmt) {
this->VisitStmt(stmt);
private:
/// Visitor implementation
- void VisitStmt_(const AttrStmtNode *op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
auto* value_as_int = op->value.as<IntImmNode>();
- CHECK(value_as_int &&
- value_as_int->value <= warp_size_ &&
+ CHECK(value_as_int && value_as_int->value <= warp_size_ &&
warp_size_ % value_as_int->value == 0)
<< "Expect threadIdx.x 's size to be no larger than, and a factor of"
- << " warp size(" << warp_size_ << ")" << " to enable warp memory"
+ << " warp size(" << warp_size_ << ")"
+ << " to enable warp memory"
<< " but get " << op->value << " instead";
if (warp_index_.defined()) {
CHECK(warp_index_.same_as(iv))
- << "Find two instance of " << warp_index_->thread_tag
- << " in the same kernel. "
+ << "Find two instance of " << warp_index_->thread_tag << " in the same kernel. "
<< "Please create it using thread_axis once and reuse the axis "
<< "across multiple binds in the same kernel";
} else {
Stmt Rewrite(const AllocateNode* op) {
buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size();
- CHECK_GT(alloc_size, 0)
- << "warp memory only support constant alloc size";
+ CHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size";
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
- warp_coeff_ = WarpStoreCoeffFinder(
- buffer_, warp_index_, analyzer_).Find(op->body);
+ warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
<< "Warp memory must be multiple of the extent of threadIdx.x";
warp_group_ = alloc_size / (width_ * warp_coeff_);
- return AllocateNode::make(
- op->buffer_var,
- op->dtype,
- {make_const(DataType::Int(32), alloc_size / width_)},
- op->condition,
- this->VisitStmt(op->body));
+ return AllocateNode::make(op->buffer_var, op->dtype,
+ {make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
+ this->VisitStmt(op->body));
}
protected:
PrimExpr VisitExpr_(const VarNode* op) override {
- CHECK(op != buffer_)
- << "Cannot access address of warp memory directly";
+ CHECK(op != buffer_) << "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op);
}
std::tie(local_index, group) = SplitIndexByGroup(op->index);
// invariance: local index must do not contain warp id
CHECK(!ExprUseVar(local_index, warp_index_))
- << "LowerWarpMemory failed to rewrite load to shuffle for index "
- << op->index << " local_index=" << local_index;
- PrimExpr load_value = LoadNode::make(
- op->dtype, op->buffer_var, local_index, op->predicate);
- PrimExpr mask = CallNode::make(DataType::UInt(32),
- intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
- return CallNode::make(load_value.dtype(),
- intrinsic::tvm_warp_shuffle,
- {mask, load_value, group, width_, warp_size_},
- CallNode::Intrinsic);
+ << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index
+ << " local_index=" << local_index;
+ PrimExpr load_value = LoadNode::make(op->dtype, op->buffer_var, local_index, op->predicate);
+ PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {},
+ CallNode::Intrinsic);
+ return CallNode::make(load_value.dtype(), intrinsic::tvm_warp_shuffle,
+ {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic);
} else {
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_);
y = y * m + x;
- PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)),
- m);
- return std::make_pair(analyzer_->canonical_simplify(y),
- analyzer_->canonical_simplify(z));
+ PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), m);
+ return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z));
}
}
arith::Analyzer* analyzer_;
};
-
// Bind bound information of variables to make analyzer more effective
// TODO(tqchen): consider a pass to inline the bound info into the expr
// so analysis can be context independent.
class BindVarBoundInfo : public StmtVisitor {
public:
- explicit BindVarBoundInfo(arith::Analyzer* analyzer)
- : analyzer_(analyzer) {}
+ explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {}
void VisitStmt_(const ForNode* op) final {
const Var& loop_var = op->loop_var;
}
void VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_dom_.count(iv->var.get())) {
// Mutator to change the read pattern
class WarpMemoryRewriter : private StmtMutator {
public:
- explicit WarpMemoryRewriter(int warp_size)
- : warp_size_(warp_size) {
- }
+ explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) {}
Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt;
warp_buffer_.insert(buf);
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
- return AttrStmtNode::make(
- op->node, op->attr_key, StringImmNode::make("local"), op->body);
+ return AttrStmtNode::make(op->node, op->attr_key, StringImmNode::make("local"), op->body);
}
}
return StmtMutator::VisitStmt_(op);
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
- CHECK(target.defined())
- << "LowerWarpMemory: Require the target attribute";
+ CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory")
-.set_body_typed(LowerWarpMemory);
+TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory);
} // namespace transform
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
-#include <tvm/tir/expr.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/tir/transform.h>
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/buffer.h>
-#include <tvm/target/target.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
-#include <vector>
-#include <utility>
#include <unordered_set>
+#include <utility>
+#include <vector>
-#include "ir_util.h"
#include "arg_binder.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
EvaluateNode::make(0));
}
-PrimFunc MakePackedAPI(PrimFunc&& func,
- int num_unpacked_args) {
+PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- CHECK(global_symbol)
- << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
+ CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
- CHECK(target.defined())
- << "MakePackedAPI: Require the target attribute";
+ CHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
int target_device_type = target.value()->device_type;
std::string name_hint = global_symbol.value();
// local function definitions
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
- Array<PrimExpr> call_args{
- v_packed_args,
- IntImm(DataType::Int(32), i),
- IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
+ Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
+ IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
- PrimExpr res = CallNode::make(
- api_type, intrinsic::tvm_struct_get, call_args,
- CallNode::PureIntrinsic);
+ PrimExpr res =
+ CallNode::make(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic);
// cast to the target version.
if (api_type != t) {
res = CastNode::make(t, res);
std::ostringstream os;
os << name_hint << ": num_args should be " << num_packed_args;
- seq_init.emplace_back(
- MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
+ seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}
// Need to re-declare vars, in case some arguments also appears in the buffer.
}
if (i < num_packed_args) {
// Value loads
- seq_init.emplace_back(LetStmtNode::make(
- v_arg, f_arg_value(v_arg.dtype(), i), nop));
+ seq_init.emplace_back(LetStmtNode::make(v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
- seq_init.emplace_back(LetStmtNode::make(
- tcode, LoadNode::make(
- DataType::Int(32), v_packed_arg_type_ids,
- IntImm(DataType::Int(32), i), const_true(1)),
- nop));
+ seq_init.emplace_back(
+ LetStmtNode::make(tcode,
+ LoadNode::make(DataType::Int(32), v_packed_arg_type_ids,
+ IntImm(DataType::Int(32), i), const_true(1)),
+ nop));
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back(
- AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
- tcode == kTVMNDArrayHandle ||
- tcode == kTVMDLTensorHandle ||
- tcode == kTVMNullptr,
+ AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
+ tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
tvm::tir::StringImmNode::make(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
}
for (const auto& kv : buffer_def) {
- binder.BindDLTensor(kv.second, device_type, device_id,
- kv.first, kv.first->name_hint);
+ binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
}
if (num_unpacked_args == 0) {
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}
- auto body = AttrStmtNode::make(
- make_zero(DataType::Int(32)), attr::compute_scope,
- StringImmNode::make(name_hint + "_compute_"), func_ptr->body);
+ auto body = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::compute_scope,
+ StringImmNode::make(name_hint + "_compute_"), func_ptr->body);
// Set device context
if (vmap.count(device_id.get())) {
PrimExpr node = StringImmNode::make("default");
- seq_check.push_back(AttrStmtNode::make(
- node, attr::device_context_id, device_id, nop));
- seq_check.push_back(AttrStmtNode::make(
- node, attr::device_context_type, device_type, nop));
+ seq_check.push_back(AttrStmtNode::make(node, attr::device_context_id, device_id, nop));
+ seq_check.push_back(AttrStmtNode::make(node, attr::device_context_type, device_type, nop));
if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
Stmt set_device = EvaluateNode::make(CallNode::make(
- DataType::Int(32), intrinsic::tvm_call_packed,
- {StringImmNode::make(runtime::symbol::tvm_set_device),
- device_type, device_id}, CallNode::Intrinsic));
+ DataType::Int(32), intrinsic::tvm_call_packed,
+ {StringImmNode::make(runtime::symbol::tvm_set_device), device_type, device_id},
+ CallNode::Intrinsic));
body = SeqStmt({set_device, body});
}
}
- func_ptr->body = MergeNest(
- {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
+ func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
func_ptr->params = args;
Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
}
-
func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32));
for (const auto& kv : mptr->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
- if (func->GetAttr<Integer>(
- tvm::attr::kCallingConv,
- Integer(CallingConv::kDefault)) == CallingConv::kDefault) {
+ if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
+ CallingConv::kDefault) {
auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args);
updates.push_back({kv.first, updated_func});
}
return m;
};
- return tvm::transform::CreateModulePass(
- pass_func, 0, "tir.MakePackedAPI", {});
+ return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI")
-.set_body_typed(MakePackedAPI);
+TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed(MakePackedAPI);
} // namespace transform
} // namespace tir
} // namespace tvm
* \brief narrow the datatype of indexing vars
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
-#include <tvm/runtime/registry.h>
+
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/ir_visitor_with_analyzer.h"
// - Use DataTypeRewritter to rewrite the components of an indexing expression.
using arith::Analyzer;
-using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;
+using arith::IRMutatorWithAnalyzer;
// Determine the result dtype for Var, IntImm and Cast,
// which will be stored in `vmap` eventually.
// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`
class DataTypeVisitor final : public StmtExprVisitor {
public:
- explicit DataTypeVisitor(int target_bits)
- : bits_(target_bits), target_bits_(target_bits) {}
+ explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {}
void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
(bound->max_value <= ubound && bound->min_value >= lbound)) {
bits = target_bits_;
}
- int tmp = bits > bits_ ? bits : bits_;
+ int tmp = bits > bits_ ? bits : bits_;
std::swap(bits_, tmp);
StmtExprVisitor::VisitExpr(e);
std::swap(bits_, tmp);
}
void VisitStmt_(const ForNode* op) {
- analyzer_.Bind(op->loop_var,
- Range::make_by_min_extent(op->min, op->extent));
+ analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
return StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
- analyzer_.Bind(iv->var,
- Range::make_by_min_extent(0, op->value));
+ analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value));
vextent_[iv->var.as<VarNode>()] = op->value.dtype();
StmtExprVisitor::VisitStmt_(op);
} else {
class DataTypeRewriter : public StmtExprMutator {
public:
- explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {}
+ explicit DataTypeRewriter(int target_bits) : visitor_(target_bits) {}
Stmt operator()(Stmt s) {
visitor_(s);
is_index_ = true;
PrimExpr index = this->VisitExpr(op->index);
is_index_ = false;
- Stmt s = StoreNode::make(op->buffer_var,
- op->value,
- index,
- op->predicate);
+ Stmt s = StoreNode::make(op->buffer_var, op->value, index, op->predicate);
return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<ForNode>();
- CHECK(op != nullptr)
- << "Expected type to be ForNode"
- << ", but get " << s->GetTypeKey();
+ CHECK(op != nullptr) << "Expected type to be ForNode"
+ << ", but get " << s->GetTypeKey();
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var>(e);
return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<AttrStmtNode>();
- CHECK(op != nullptr)
- << "Expected type to be AttrStmtNode"
- << ", but get " << s->GetTypeKey();
+ CHECK(op != nullptr) << "Expected type to be AttrStmtNode"
+ << ", but get " << s->GetTypeKey();
const IterVarNode* iv = op->node.as<IterVarNode>();
- CHECK(iv != nullptr)
- << "Expected type to be IterVarNode"
- << ", but get " << op->node->GetTypeKey();
+ CHECK(iv != nullptr) << "Expected type to be IterVarNode"
+ << ", but get " << op->node->GetTypeKey();
PrimExpr e = VisitExpr(iv->var);
Var var = Downcast<Var>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag);
}
- return AttrStmtNode::make(
- ivmap_[iv],
- op->attr_key,
- cast(var.dtype(), op->value),
- op->body);
+ return AttrStmtNode::make(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
}
return StmtExprMutator::VisitStmt_(op);
}
if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
PrimExpr e = StmtExprMutator::VisitExpr_(op);
const CastNode* new_op = e.as<CastNode>();
- CHECK(new_op != nullptr)
- << "Expected type to be CastNode"
- << ", but get " << e->GetTypeKey();
+ CHECK(new_op != nullptr) << "Expected type to be CastNode"
+ << ", but get " << e->GetTypeKey();
return CastNode::make(visitor_.vmap[op], new_op->value);
}
return StmtExprMutator::VisitExpr_(op);
bool is_index_{false};
};
-#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
- PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \
- PrimExpr a = this->VisitExpr(op->a); \
- PrimExpr b = this->VisitExpr(op->b); \
- if (a.same_as(op->a) && \
- b.same_as(op->b)) { \
- return GetRef<PrimExpr>(op); \
- } else { \
- return FUNC(a, b); \
- } \
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
+ PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \
+ PrimExpr a = this->VisitExpr(op->a); \
+ PrimExpr b = this->VisitExpr(op->b); \
+ if (a.same_as(op->a) && b.same_as(op->b)) { \
+ return GetRef<PrimExpr>(op); \
+ } else { \
+ return FUNC(a, b); \
+ } \
}
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
PrimExpr e = StmtExprMutator::VisitExpr_(op);
op = e.as<CallNode>();
- CHECK(op != nullptr)
- << "Expected type to be CallNode"
- << ", but get " << e->GetTypeKey();
+ CHECK(op != nullptr) << "Expected type to be CallNode"
+ << ", but get " << e->GetTypeKey();
if (op->call_type == CallNode::PureIntrinsic) {
if (op->name == intrinsic::tvm_if_then_else) {
return if_then_else(op->args[0], op->args[1], op->args[2]);
return e;
}
-Stmt NarrowDataType(Stmt stmt, int target_bits) {
- return DataTypeRewriter(target_bits)(stmt);
-}
+Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); }
namespace transform {
n->body = DataTypeRewriter(target_bits)(std::move(n->body));
return f;
};
- return CreatePrimFuncPass(
- pass_func, 0, "tir.NarrowDataType", {});
+ return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
-.set_body_typed(NarrowDataType);
+TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType);
} // namespace transform
} // namespace tir
/*!
* \file remap_thread_axis.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <tvm/runtime/registry.h>
-#include <unordered_map>
+#include <unordered_map>
namespace tvm {
namespace tir {
// Mutator to change the read pattern
class ThreadAxisRewriter : private StmtExprMutator {
public:
- explicit ThreadAxisRewriter(
- const std::unordered_map<std::string, IterVar>& tmap)
- : tmap_(tmap) {
- }
+ explicit ThreadAxisRewriter(const std::unordered_map<std::string, IterVar>& tmap) : tmap_(tmap) {}
- Stmt Rewrite(Stmt stmt) {
- return operator()(std::move(stmt));
- }
+ Stmt Rewrite(Stmt stmt) { return operator()(std::move(stmt)); }
private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
CHECK(vmap_[v].same_as(new_iv->var));
}
Stmt body = this->VisitStmt(op->body);
- return AttrStmtNode::make(
- new_iv, op->attr_key, op->value, body);
+ return AttrStmtNode::make(new_iv, op->attr_key, op->value, body);
}
}
return StmtExprMutator::VisitStmt_(op);
std::unordered_map<const VarNode*, Var> vmap_;
};
-
PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
}
auto opt_thread_axis = f->GetAttr<Array<IterVar>>(tir::attr::kDeviceThreadAxis);
- CHECK(opt_thread_axis != nullptr)
- << "Require attribute " << tir::attr::kDeviceThreadAxis;
+ CHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis;
auto thread_axis = opt_thread_axis.value();
auto* n = f.CopyOnWrite();
return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis);
}
-
namespace transform {
Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) {
return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis")
-.set_body_typed(RemapThreadAxis);
+TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis);
} // namespace transform
} // namespace tir
* \brief Remove no op from the stmt
*/
#include <tvm/runtime/registry.h>
-#include <tvm/tir/stmt.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
-#include <tvm/tir/transform.h>
+#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
#include <unordered_map>
namespace tvm {
auto n = CopyOnWrite(op);
size_t top = 0;
for (size_t i = 0; i < n->seq.size(); ++i) {
- if (!is_no_op(n->seq[i])) {
+ if (!is_no_op(n->seq[i])) {
n->seq.Set(top++, n->seq[i]);
}
}
}
};
-Stmt RemoveNoOp(Stmt stmt) {
- return NoOpRemover()(std::move(stmt));
-}
+Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); }
namespace transform {
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp")
-.set_body_typed(RemoveNoOp);
+TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp);
} // namespace transform
namespace tvm {
namespace tir {
-
// For now, rewrite unsafe select expression to if_then_else
// TODO(tqchen) pattern matching to support masked load
class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
public:
// select itself is always considered safe if condition is safe
// Because we will issue guard to make sure it is.
- bool VisitExpr_(const SelectNode* op) {
- return VisitExpr(op->condition);
- }
+ bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); }
bool VisitExpr_(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
return VisitExpr(op->args[0]);
bool VisitExpr_(const GENode* op) final { return BinaryOp(op); }
bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); }
bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); }
- bool VisitExpr_(const NotNode* op) final {
- return VisitExpr(op->a);
- }
- bool VisitExpr_(const LetNode* op) final {
- return VisitExpr(op->body) || VisitExpr(op->value);
- }
- bool VisitExpr_(const CastNode* op) final {
- return VisitExpr(op->value);
- }
- bool VisitExpr_(const BroadcastNode* op) final {
- return VisitExpr(op->value);
- }
- bool VisitExpr_(const RampNode* op) final {
- return VisitExpr(op->base) && VisitExpr(op->stride);
- }
+ bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); }
+ bool VisitExpr_(const LetNode* op) final { return VisitExpr(op->body) || VisitExpr(op->value); }
+ bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
+ bool VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); }
+ bool VisitExpr_(const RampNode* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); }
bool VisitExpr_(const ShuffleNode* op) final {
for (PrimExpr e : op->vectors) {
if (VisitExpr(e)) return true;
bool VisitExpr_(const StringImmNode* op) final { return false; }
private:
- template<typename T>
+ template <typename T>
bool BinaryOp(const T* op) {
return VisitExpr(op->a) || VisitExpr(op->b);
}
op = expr.as<SelectNode>();
UnsafeExprDetector unsafe;
bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
- if ((unsafe.VisitExpr(op->true_value) ||
- unsafe.VisitExpr(op->false_value)) &&
+ if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) &&
cond_is_scalar_bool) {
- return CallNode::make(
- op->dtype,
- intrinsic::tvm_if_then_else,
- {op->condition, op->true_value, op->false_value},
- CallNode::Intrinsic);
+ return CallNode::make(op->dtype, intrinsic::tvm_if_then_else,
+ {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic);
} else {
return expr;
}
}
};
-Stmt RewriteUnsafeSelect(Stmt stmt) {
- return UnsafeSelectRewriter()(std::move(stmt));
-}
+Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); }
namespace transform {
return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect")
-.set_body_typed(RewriteUnsafeSelect);
+TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect);
} // namespace transform
* \file simplify.cc
* \brief Statement simplifier based on analyzer
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/tir/op.h>
-#include <tvm/arith/analyzer.h>
#include "../../arith/ir_mutator_with_analyzer.h"
namespace tvm {
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
- explicit StmtSimplifier(Analyzer* analyzer)
- : IRMutatorWithAnalyzer(analyzer) {}
+ explicit StmtSimplifier(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::VisitStmt;
using Parent::VisitStmt_;
- PrimExpr VisitExpr(const PrimExpr& expr) final {
- return analyzer_->Simplify(expr);
- }
+ PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); }
- Stmt Simplify(Stmt stmt) {
- return operator()(std::move(stmt));
- }
+ Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }
Stmt VisitStmt_(const ForNode* op) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
return this->VisitStmt(op->body);
}
Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.Simplify")
-.set_body_typed(Simplify);
+TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify);
} // namespace transform
* under the License.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
}
};
-Stmt SkipAssert(Stmt stmt) {
- return AssertSkipper()(std::move(stmt));
-}
+Stmt SkipAssert(Stmt stmt) { return AssertSkipper()(std::move(stmt)); }
namespace transform {
return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.SkipAssert")
-.set_body_typed(SkipAssert);
+TVM_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert);
} // namespace transform
* \brief Split device function from host.
*/
#include <tvm/ir/transform.h>
-#include <tvm/tir/op.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/transform.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/target/target.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
+#include <tvm/tir/transform.h>
#include <unordered_map>
this->HandleDef(op->var.get());
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
- if (use_count_.at(op->var.get()) == 0 &&
- !HasSideEffect(op->value)) {
+ if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
- if (body.same_as(op->body) &&
- value.same_as(op->value)) {
+ if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
return LetStmtNode::make(op->var, value, body);
this->HandleDef(op->var.get());
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
- if (use_count_.at(op->var.get()) == 0 &&
- !HasSideEffect(op->value)) {
+ if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
- if (body.same_as(op->body) &&
- value.same_as(op->value)) {
+ if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
}
void HandleDef(const VarNode* v) {
- CHECK(!def_count_.count(v))
- << "variable " << v->name_hint
- << " has already been defined, the Stmt is not SSA";
- CHECK(!use_count_.count(v))
- << "variable " << v->name_hint
- << " has been used before definition!";
+ CHECK(!def_count_.count(v)) << "variable " << v->name_hint
+ << " has already been defined, the Stmt is not SSA";
+ CHECK(!use_count_.count(v)) << "variable " << v->name_hint
+ << " has been used before definition!";
use_count_[v] = 0;
def_count_[v] = 1;
}
std::unordered_map<const VarNode*, int> def_count_;
};
-
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
VarUseDefAnalysis m;
for (Var arg : args) {
return m.undefined_;
}
-
class HostDeviceSplitter : public StmtMutator {
public:
- explicit HostDeviceSplitter(IRModule* device_mod,
- Target device_target,
- std::string name_prefix)
- : device_mod_(device_mod),
- device_target_(device_target),
- name_prefix_(name_prefix) {
- }
+ explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix)
+ : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {}
Stmt VisitStmt_(const AllocateNode* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::pipeline_exec_scope ||
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
return SplitDeviceFunc(GetRef<Stmt>(op));
}
// Create a new version of v.
auto it = handle_data_type_.find(var.get());
if (it != handle_data_type_.end()) {
- tir::Var new_var(var->name_hint,
- PointerType(PrimType((*it).second->dtype)));
+ tir::Var new_var(var->name_hint, PointerType(PrimType((*it).second->dtype)));
params.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_);
device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
Integer(CallingConv::kDeviceKernelLaunch));
- device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
- runtime::String(kernel_symbol));
+ device_func =
+ WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
(*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext);
}
- return EvaluateNode::make(CallNode::make(
- DataType::Int(32), intrinsic::tvm_call_packed,
- call_args, CallNode::Intrinsic));
+ return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed,
+ call_args, CallNode::Intrinsic));
}
// target ir module
std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
-
PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
- CHECK(target.defined())
- << "SplitHostDevice: Require the target attribute";
+ CHECK(target.defined()) << "SplitHostDevice: Require the target attribute";
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
- HostDeviceSplitter splitter(
- device_mod,
- target.value(),
- static_cast<std::string>(global_symbol.value()));
+ HostDeviceSplitter splitter(device_mod, target.value(),
+ static_cast<std::string>(global_symbol.value()));
auto* n = func.CopyOnWrite();
n->body = splitter(std::move(n->body));
return std::move(func);
}
-
namespace transform {
Pass SplitHostDevice() {
return mod;
};
- return tvm::transform::CreateModulePass(
- pass_func, 0, "tir.SplitHostDevice", {});
+ return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice")
-.set_body_typed(SplitHostDevice);
+TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice);
} // namespace transform
} // namespace tir
/*!
* \file storage_access.cc
*/
+#include "storage_access.h"
+
#include <tvm/target/target_info.h>
+
#include <string>
#include <utility>
-#include "storage_access.h"
-#include "ir_util.h"
+
#include "../../arith/compute_expr.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::storage_scope) {
const VarNode* buf = op->node.as<VarNode>();
- storage_scope_[buf] =
- StorageScope::make(op->value.as<StringImmNode>()->value);
+ storage_scope_[buf] = StorageScope::make(op->value.as<StringImmNode>()->value);
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::double_buffer_write) {
CHECK(double_buffer_write_ == nullptr);
if (s.access.size() != 0) {
// relax the touched set to contain all ranges in the loop.
std::unordered_map<const VarNode*, arith::IntSet> relax_map;
- relax_map[op->loop_var.get()] = arith::IntSet::range(
- Range::make_by_min_extent(op->min, op->extent));
+ relax_map[op->loop_var.get()] =
+ arith::IntSet::range(Range::make_by_min_extent(op->min, op->extent));
for (AccessEntry& e : s.access) {
if (e.buffer.defined()) {
CHECK(e.touched.defined());
void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const LoadNode *l = op->args[0].as<LoadNode>();
+ const LoadNode* l = op->args[0].as<LoadNode>();
StmtExprVisitor::VisitExpr_(l);
} else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
e.threads = env_threads();
e.dtype = dtype;
e.buffer = Downcast<Var>(op->args[1]);
- e.touched = arith::IntSet::range(
- Range::make_by_min_extent(offset, extent));
+ e.touched = arith::IntSet::range(Range::make_by_min_extent(offset, extent));
e.scope = scope;
if (flag->value & 1) {
e.type = kRead;
#ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
#define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
+#include <tvm/arith/int_set.h>
#include <tvm/ir/attrs.h>
#include <tvm/tir/expr.h>
-#include <tvm/arith/int_set.h>
#include <tvm/tir/stmt_functor.h>
-#include <vector>
+
#include <unordered_map>
+#include <vector>
+
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace tir {
-using runtime::StorageScope;
using runtime::StorageRank;
+using runtime::StorageScope;
/*!
* \brief Base class of storage access analysis
*/
void VisitExpr_(const CallNode* op) final;
protected:
- StorageAccessVisitor() {
- scope_.push_back(std::vector<StmtEntry>());
- }
+ StorageAccessVisitor() { scope_.push_back(std::vector<StmtEntry>()); }
/*! \return number of conditions in the current scope. */
- int condition_counter() const {
- return condition_counter_;
- }
+ int condition_counter() const { return condition_counter_; }
/*! \return whether we are in device environment. */
- bool in_device_env() const {
- return in_device_env_;
- }
+ bool in_device_env() const { return in_device_env_; }
/*! \return environment threads */
- const Array<IterVar>& env_threads() const {
- return env_threads_;
- }
+ const Array<IterVar>& env_threads() const { return env_threads_; }
/*!
* \brief Whether we need analyze the buffer in current scope.
* \param buffer The buffer to be checked
* \param scope The scope of the buffer.
* \return Whether the analysis of buffer is enabled.
*/
- virtual bool Enabled(const VarNode* buffer,
- const StorageScope& scope) const {
- return true;
- }
+ virtual bool Enabled(const VarNode* buffer, const StorageScope& scope) const { return true; }
/*!
* \brief Summarize the sequence of operations into parent.
*
* \return The summarized sequence that represent access that
* the parent should taken care of to synchronize.
*/
- virtual std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const ForNode* loop) = 0;
+ virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) = 0;
/*!
* \brief Get the scope of the buffer array.
* \return The scope of the final buffer array.
*/
// The pass definition originates from Halide pipeline.
-#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target_info.h>
+#include <tvm/te/operation.h>
+#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
-#include <tvm/te/operation.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
-#include <tvm/tir/buffer.h>
-#include <tvm/target/target_info.h>
-#include <tvm/runtime/device_api.h>
+
#include <unordered_map>
-#include "ir_util.h"
-#include "arg_binder.h"
+
#include "../../arith/compute_expr.h"
#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h"
+#include "arg_binder.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
+using intrinsic::tvm_address_of;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
-using intrinsic::tvm_address_of;
class StorageFlattener : public StmtExprMutator {
public:
- explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map,
- int cache_line_size,
- bool create_bound_attributes,
- IRVisitorWithAnalyzer* bound_analyzer)
- : bound_analyzer_(bound_analyzer),
- create_bound_attributes_(create_bound_attributes) {
+ explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map, int cache_line_size,
+ bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer)
+ : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) {
for (auto kv : extern_buffer_map) {
BufferEntry e;
e.buffer = kv.second;
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<StoreNode>();
auto it = var_remap_.find(op->buffer_var.get());
- if (it != var_remap_.end() &&
- !it->second.same_as(op->buffer_var)) {
+ if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<VarNode>());
Var buf_var = Downcast<Var>(it->second);
return StoreNode::make(buf_var, op->value, op->index, op->predicate);
auto buffer = Downcast<tir::Buffer>(op->node);
Stmt body = this->VisitStmt(op->body);
auto it = buf_map_.find(buffer);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << buffer;
- body = AttrStmtNode::make(
- it->second.buffer->data, op->attr_key, op->value, std::move(body));
+ CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer;
+ body = AttrStmtNode::make(it->second.buffer->data, op->attr_key, op->value, std::move(body));
return body;
} else if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
const auto& key = op->buffer;
auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key;
+ CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second;
- CHECK(!e.released)
- << "Read a buffer that is already out of scope";
+ CHECK(!e.released) << "Read a buffer that is already out of scope";
if (is_opengl_) {
- return EvaluateNode::make(CallNode::make(
- DataType(),
- CallNode::glsl_texture_store,
- {e.buffer->data, op->value},
- CallNode::Intrinsic));
+ return EvaluateNode::make(CallNode::make(DataType(), CallNode::glsl_texture_store,
+ {e.buffer->data, op->value}, CallNode::Intrinsic));
} else {
Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
- shape_collector_.push_back(
- std::make_pair(e.buffer->data, e.buffer->shape));
+ shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
}
// To create bound attribute collector should has at least one item.
if (create_bound_attributes_ && shape_collector_.size()) {
for (size_t i = 0; i < shape_collector_.size(); ++i) {
- body = AttrStmtNode::make(
- shape_collector_[i].first, tir::attr::buffer_bound,
- MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
+ body = AttrStmtNode::make(shape_collector_[i].first, tir::attr::buffer_bound,
+ MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
}
}
return body;
}
// deduce current storage scope.
auto it = storage_scope_.find(op->buffer.get());
- CHECK(it != storage_scope_.end())
- << "Cannot find storage scope of " << op->buffer;
+ CHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer;
StorageScope skey;
const std::string& strkey = it->second;
if (strkey.length() == 0) {
if (curr_thread_scope_.size() != 0) {
- skey.rank = runtime::DefaultStorageRank(
- curr_thread_scope_.back().rank);
+ skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
}
} else {
skey = StorageScope::make(strkey);
strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
}
- e.buffer = BufferNode::make(
- Var(op->buffer->data->name_hint, DataType::Handle()),
- op->buffer->dtype, shape, strides, PrimExpr(),
- op->buffer->name, skey.to_string(),
- align, 0, kDefault);
+ e.buffer = BufferNode::make(Var(op->buffer->data->name_hint, DataType::Handle()),
+ op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
+ skey.to_string(), align, 0, kDefault);
buf_map_[key] = e;
Stmt body = this->VisitStmt(op->body);
}
if (strides.size() != 0) {
int first_dim = 0;
- ret = AllocateNode::make(
- e.buffer->data, storage_type,
- {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
- make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
+ ret = AllocateNode::make(e.buffer->data, storage_type,
+ {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
+ make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
} else {
shape = e.buffer->shape;
if (shape.size() == 0) {
shape.push_back(make_const(DataType::Int(32), 1));
}
- ret = AllocateNode::make(
- e.buffer->data, storage_type, shape,
- make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
+ ret = AllocateNode::make(e.buffer->data, storage_type, shape,
+ make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
}
- ret = AttrStmtNode::make(
- e.buffer->data, attr::storage_scope,
- StringImmNode::make(e.buffer->scope), ret);
+ ret = AttrStmtNode::make(e.buffer->data, attr::storage_scope,
+ StringImmNode::make(e.buffer->scope), ret);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound,
- MakeBound(e.buffer->dtype, e.buffer->shape), ret);
+ MakeBound(e.buffer->dtype, e.buffer->shape), ret);
}
return ret;
}
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
auto it = var_remap_.find(op->buffer_var.get());
- if (it != var_remap_.end() &&
- !it->second.same_as(op->buffer_var)) {
+ if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<VarNode>());
Var buf_var = Downcast<Var>(it->second);
return LoadNode::make(op->dtype, buf_var, op->index, op->predicate);
const auto& key = op->buffer;
auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key;
+ CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second;
- CHECK(!e.released)
- << "Read a buffer that is already out of scope";
+ CHECK(!e.released) << "Read a buffer that is already out of scope";
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
- shape_collector_.push_back(
- std::make_pair(e.buffer->data, e.buffer->shape));
+ shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
}
return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype);
}
-
- Stmt VisitStmt_(const PrefetchNode *op) final {
+ Stmt VisitStmt_(const PrefetchNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<PrefetchNode>();
CHECK(op != nullptr);
const auto& key = op->buffer;
auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key;
+ CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second;
- CHECK(!e.released)
- << "Read a buffer that is already out of scope";
+ CHECK(!e.released) << "Read a buffer that is already out of scope";
CHECK_EQ(e.buffer->shape.size(), op->bounds.size())
- << "Prefetch dim should be the same as buffer dim";
+ << "Prefetch dim should be the same as buffer dim";
- int block_size = 1,
- elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();
+ int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();
int starts = op->bounds.size() - 1;
for (int i = op->bounds.size() - 1; i > starts; --i) {
args.push_back(op->bounds[i]->min);
}
- auto &func_name = op->buffer->name;
- vars.push_back(Var(
- "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
+ auto& func_name = op->buffer->name;
+ vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
args.push_back(op->bounds[starts]->min + stride * vars.back());
for (int i = starts - 1; i >= 0; --i) {
- vars.push_back(Var(
- "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32)));
+ vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32)));
args.push_back(vars.back() + op->bounds[i]->min);
}
for (int i = starts; i >= 0; --i) {
if (i < starts) {
- stmt = ForNode::make(
- vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
+ stmt = ForNode::make(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None,
+ stmt);
} else {
PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
- PrimExpr address = CallNode::make(
- DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
- PrimExpr prefetch = CallNode::make(
- op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
+ PrimExpr address =
+ CallNode::make(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
+ PrimExpr prefetch = CallNode::make(op->buffer->dtype, CallNode::prefetch,
+ {address, 0, 3, 1}, CallNode::Intrinsic);
stmt = EvaluateNode::make(prefetch);
PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
}
PrimExpr VisitExpr_(const CallNode* op) final {
- CHECK(op->call_type != CallNode::Halide)
- << "Cannot handle Halide calls "
- << " please run SchedulePostProcToPrimFunc first";
+ CHECK(op->call_type != CallNode::Halide) << "Cannot handle Halide calls "
+ << " please run SchedulePostProcToPrimFunc first";
return StmtExprMutator::VisitExpr_(op);
}
return Stmt();
}
-
private:
// The specific tensor data layout is not determined before
// StorageFlatten pass. We use buffer_bind_scope
// We do support a few relaxed case, such as bindingx
// region with shape [1, 1, n, m] to buffer with shape [n, m]
Stmt HandleBufferBindScope(const AttrStmtNode* op) {
- Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
+ Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
CHECK_EQ(arr.size(), 2U);
const BufferNode* buffer = arr[0].as<BufferNode>();
const BufferNode* target = arr[1].as<BufferNode>();
auto key = GetRef<Buffer>(target);
auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find buffer of " << key;
+ CHECK(it != buf_map_.end()) << "Cannot find buffer of " << key;
const BufferEntry& be = it->second;
CHECK(!be.released);
CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
} else {
for (size_t i = 0; i < tuple->args.size(); i += 2) {
begins.push_back(tuple->args[i]);
- auto new_extent = bound_analyzer_->Simplify(tuple->args[i+1]);
+ auto new_extent = bound_analyzer_->Simplify(tuple->args[i + 1]);
extents.push_back(new_extent);
}
}
Buffer slice = be.buffer.MakeSlice(begins, extents);
if (buffer->strides.size() == 0) {
CHECK_EQ(slice->strides.size(), 0U)
- << "Trying to bind compact buffer to strided one strides="
- << slice->strides;
+ << "Trying to bind compact buffer to strided one strides=" << slice->strides;
} else {
slice = slice.MakeStrideView();
}
}
};
- bool ShapeIsValid(const Array<PrimExpr> &shape) {
+ bool ShapeIsValid(const Array<PrimExpr>& shape) {
// Zero-dimensional tensor does not need boundary check.
- if (!shape.size())
- return false;
+ if (!shape.size()) return false;
for (size_t i = 0; i < shape.size(); ++i) {
- if (!shape[i].defined() || !shape[i].dtype().is_scalar() ||
- is_negative_const(shape[i])) {
+ if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) {
return false;
}
}
return true;
}
- PrimExpr MakeBound(const DataType &type, const Array<PrimExpr> &shape) {
+ PrimExpr MakeBound(const DataType& type, const Array<PrimExpr>& shape) {
// We have already checked the shape size to be greater then 0.
PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
for (size_t i = 1; i < shape.size(); ++i) {
- bound = MulNode::make(
- bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i]));
+ bound =
+ MulNode::make(bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i]));
}
return bound;
}
// Buffer map
std::unordered_map<Buffer, BufferEntry, ObjectHash, ObjectEqual> buf_map_;
// Dimension alignment
- std::unordered_map<Buffer, std::vector<DimAlignInfo>,
- ObjectHash, ObjectEqual> dim_align_;
+ std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectHash, ObjectEqual> dim_align_;
// Storage scope
std::unordered_map<const Object*, std::string> storage_scope_;
// The current thread scope.
bool create_bound_attributes_{false};
};
-PrimFunc StorageFlatten(PrimFunc func,
- int cache_line_size,
- bool create_bound_attributes) {
+PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) {
auto fptr = func.CopyOnWrite();
IRVisitorWithAnalyzer bound_analyzer;
bound_analyzer(fptr->body);
- fptr->body = StorageFlattener(fptr->buffer_map,
- cache_line_size,
- create_bound_attributes,
+ fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes,
&bound_analyzer)(std::move(fptr->body));
return func;
}
-
namespace transform {
// TODO(tvm-team): consolidate configs to the PassContext
-Pass StorageFlatten(int cache_line_size,
- bool create_bound_attributes) {
+Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- return StorageFlatten(
- std::move(f), cache_line_size, create_bound_attributes);
+ return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes);
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten")
-.set_body_typed(StorageFlatten);
+TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten").set_body_typed(StorageFlatten);
} // namespace transform
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
-#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/transform.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target_info.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/target/target_info.h>
+#include <tvm/tir/transform.h>
+
#include <map>
-#include <unordered_set>
#include <unordered_map>
-#include "ir_util.h"
+#include <unordered_set>
+
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
const VarNode* buf = op->buffer_var.get();
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
- CHECK_LT(it->second.level, scope_.size())
- << "Load memory in places other than store.";
+ CHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
scope_[it->second.level].touched.push_back(buf);
}
}
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
- CHECK_LT(it->second.level, scope_.size())
- << " buf=" << buf->name_hint;
+ CHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second.level].touched.push_back(buf);
}
}
- template<typename T>
+ template <typename T>
void VisitNewScope(const T* op) {
scope_.push_back(StmtEntry());
StmtEntry e;
e.stmt = op;
- int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+ int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
// before scope.
linear_seq_.push_back(e);
StmtExprVisitor::VisitStmt_(op);
// after scope.
e.touched = std::move(scope_.back().touched);
scope_.pop_back();
- int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+ int64_t end_index = static_cast<int64_t>(linear_seq_.size());
CHECK_GT(end_index, begin_index);
e.scope_pair_offset = begin_index - end_index;
linear_seq_.push_back(e);
VisitNewScope(op);
} else if (op->attr_key == attr::storage_scope) {
const VarNode* buf = op->node.as<VarNode>();
- alloc_info_[buf].storage_scope =
- StorageScope::make(op->value.as<StringImmNode>()->value);
+ alloc_info_[buf].storage_scope = StorageScope::make(op->value.as<StringImmNode>()->value);
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
- void VisitStmt_(const IfThenElseNode* op) final {
- VisitNewScope(op);
- }
+ void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
- void VisitStmt_(const ForNode* op) final {
- VisitNewScope(op);
- }
+ void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
- void VisitStmt_(const AssertStmtNode* op) final {
- VisitNewScope(op);
- }
+ void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
//
class InplaceOpVerifier : public StmtExprVisitor {
public:
- bool Check(const Object* stmt,
- const VarNode* dst,
- const VarNode* src) {
+ bool Check(const Object* stmt, const VarNode* dst, const VarNode* src) {
dst_ = dst;
src_ = src;
result_ = true;
void VisitExpr_(const VarNode* op) final {
// assume all opaque access is unsafe
if (op == dst_ || op == src_) {
- result_ = false; return;
+ result_ = false;
+ return;
}
}
void VisitStmt_(const AttrStmtNode* op) final {
// always reject extern code
- if (op->attr_key == attr::extern_scope ||
- op->attr_key == attr::volatile_scope) {
- result_ = false; return;
+ if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) {
+ result_ = false;
+ return;
}
StmtExprVisitor::VisitStmt_(op);
}
const VarNode* buf = op->buffer_var.get();
// cannot read from dst_ (no reduction)
if (buf == dst_) {
- result_ = false; return;
+ result_ = false;
+ return;
}
// do not allow indirect memory load
if (mem_nest_ != 0) {
- result_ = false; return;
+ result_ = false;
+ return;
}
if (src_ == buf) {
- if (store_ == nullptr ||
- store_->value.dtype() != op->dtype ||
+ if (store_ == nullptr || store_->value.dtype() != op->dtype ||
!tir::ExprDeepEqual()(store_->index, op->index)) {
- result_ = false; return;
+ result_ = false;
+ return;
}
}
++mem_nest_;
--mem_nest_;
}
-
private:
// result of the check
bool result_{true};
for (StorageEntry* e : attach_map_.at(nullptr)) {
// CHECK_EQ(e->scope.rank, 0);
if (e->new_alloc.defined()) {
- nest.emplace_back(AttrStmtNode::make(
- e->alloc_var, attr::storage_scope,
- StringImmNode::make(e->scope.to_string()),
- EvaluateNode::make(0)));
+ nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope,
+ StringImmNode::make(e->scope.to_string()),
+ EvaluateNode::make(0)));
nest.push_back(e->new_alloc);
}
}
op = stmt.as<StoreNode>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt;
- return StoreNode::make(it->second->alloc_var,
- op->value,
- RemapIndex(op->value.dtype(), op->index, it->second),
- op->predicate);
+ return StoreNode::make(it->second->alloc_var, op->value,
+ RemapIndex(op->value.dtype(), op->index, it->second), op->predicate);
}
PrimExpr VisitExpr_(const LoadNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
- return LoadNode::make(op->dtype,
- it->second->alloc_var,
- RemapIndex(op->dtype, op->index, it->second),
- op->predicate);
+ return LoadNode::make(op->dtype, it->second->alloc_var,
+ RemapIndex(op->dtype, op->index, it->second), op->predicate);
}
PrimExpr VisitExpr_(const VarNode* op) final {
auto it = alloc_map_.find(op);
if (se->bits_offset != 0) {
offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
}
- return CallNode::make(
- op->dtype, op->name,
- {op->args[0], se->alloc_var, offset, extent, op->args[4]},
- op->call_type);
+ return CallNode::make(op->dtype, op->name,
+ {op->args[0], se->alloc_var, offset, extent, op->args[4]},
+ op->call_type);
} else {
return StmtExprMutator::VisitExpr_(op);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
return this->VisitStmt(op->body);
- } else if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::virtual_thread ||
+ } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
attr::IsPragmaKey(op->attr_key)) {
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
- return AttrStmtNode::make(
- op->node, op->attr_key, op->value,
- MakeAttach(svec, op->body));
+ return AttrStmtNode::make(op->node, op->attr_key, op->value, MakeAttach(svec, op->body));
} else {
return StmtExprMutator::VisitStmt_(op);
}
op = stmt.as<AttrStmtNode>();
auto it = alloc_map_.find(op->node.as<VarNode>());
if (it == alloc_map_.end()) return stmt;
- return AttrStmtNode::make(
- it->second->alloc_var, op->attr_key, op->value, op->body);
+ return AttrStmtNode::make(it->second->alloc_var, op->attr_key, op->value, op->body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const ForNode* op) final {
- CHECK(op->for_type != ForType::Vectorized)
- << "VectorizeLoop before LiftStorageAlloc";
+ CHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before LiftStorageAlloc";
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
- return ForNode::make(
- op->loop_var, op->min, op->extent, op->for_type, op->device_api,
- MakeAttach(svec, op->body));
+ return ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api,
+ MakeAttach(svec, op->body));
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const AllocateNode* op) final {
- return this->VisitStmt(op->body);
- }
+ Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); }
private:
struct StorageEntry {
std::vector<const VarNode*> kill;
};
- Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
- Stmt body) {
+ Stmt MakeAttach(const std::vector<StorageEntry*>& svec, Stmt body) {
std::vector<Stmt> nest;
for (StorageEntry* e : svec) {
if (e->new_alloc.defined()) {
- nest.emplace_back(AttrStmtNode::make(
- e->alloc_var, attr::storage_scope,
- StringImmNode::make(e->scope.to_string()),
- EvaluateNode::make(0)));
+ nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope,
+ StringImmNode::make(e->scope.to_string()),
+ EvaluateNode::make(0)));
nest.push_back(e->new_alloc);
}
}
attach_map_[e->attach_scope_].push_back(e);
}
// find allocation via attach map.
- for (auto &kv : attach_map_) {
+ for (auto& kv : attach_map_) {
// find the element with the most amount of bytes.
std::vector<StorageEntry*>& vec = kv.second;
// try to find merge, for tagged memory
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry* e = vec[i];
if (e->scope.tag.length() != 0) {
- CHECK_NE(e->const_nbits, 0U)
- << "Special tagged memory must be const size";
+ CHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size";
for (size_t j = 0; j < i; ++j) {
if (e->scope == vec[j]->scope) {
vec[j]->merged_children.push_back(e);
// already merged
if (e->bits_offset != 0) continue;
if (e->merged_children.size() != 0) {
- NewAllocTagMerged(e); continue;
+ NewAllocTagMerged(e);
+ continue;
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
if (e->allocs.size() == 1) {
// simply use the original allocation.
PrimExpr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
- make_const(DataType::Int(32), 1));
- e->new_alloc = AllocateNode::make(
- e->alloc_var, alloc_type, {sz},
- e->allocs[0]->condition, EvaluateNode::make(0));
+ make_const(DataType::Int(32), 1));
+ e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition,
+ EvaluateNode::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
// Build a merged allocation
PrimExpr combo_size;
for (const AllocateNode* op : e->allocs) {
- PrimExpr sz = arith::ComputeReduce<MulNode>(
- op->extents, make_const(DataType::Int(32), 1));
+ PrimExpr sz =
+ arith::ComputeReduce<MulNode>(op->extents, make_const(DataType::Int(32), 1));
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto* imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
- LOG(WARNING) << "The allocation requires : " << imm->value
- << " * " << nbits
+ LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits
<< " bits, which is greater than the maximum of"
" int32. The size is cast to int64."
<< "\n";
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = analyzer_.Simplify(combo_size);
- e->new_alloc = AllocateNode::make(
- e->alloc_var, alloc_type, {combo_size}, const_true(),
- EvaluateNode::make(0));
+ e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {combo_size}, const_true(),
+ EvaluateNode::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
// Always align to max_simd_bits
// so we can remap types by keeping this property
if (total_bits % align != 0) {
- total_bits += align - (total_bits % align);
+ total_bits += align - (total_bits % align);
}
e->alloc_var = e->allocs[0]->buffer_var;
for (StorageEntry* child : e->merged_children) {
child->alloc_var = e->alloc_var;
total_bits += child->const_nbits;
if (total_bits % align != 0) {
- total_bits += align - (total_bits % align);
+ total_bits += align - (total_bits % align);
}
}
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
- PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
- (total_bits + type_bits - 1) / type_bits);
- e->new_alloc = AllocateNode::make(
- e->alloc_var, e->elem_type, {alloc_size}, const_true(),
- EvaluateNode::make(0));
+ PrimExpr alloc_size =
+ make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits);
+ e->new_alloc = AllocateNode::make(e->alloc_var, e->elem_type, {alloc_size}, const_true(),
+ EvaluateNode::make(0));
if (info.defined()) {
CHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
visitor.Check(s.stmt, var, src)) {
uint64_t const_nbits =
static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
- ae.alloc->dtype.bits() *
- ae.alloc->dtype.lanes();
+ ae.alloc->dtype.bits() * ae.alloc->dtype.lanes();
if (src_entry->const_nbits == const_nbits && !inplace_found) {
// successfully inplace
dst_entry = src_entry;
// enter/exit new scope
if (s.stmt->IsInstance<AttrStmtNode>()) {
const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
- if (op->attr_key == attr::thread_extent ||
- op->attr_key == attr::virtual_thread ||
+ if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
attr::IsPragmaKey(op->attr_key)) {
PlanNewScope(op);
} else {
}
}
// Allocate new storage entry.
- StorageEntry* NewAlloc(const AllocateNode* op,
- const Object* attach_scope,
- const StorageScope& scope,
- size_t const_nbits) {
+ StorageEntry* NewAlloc(const AllocateNode* op, const Object* attach_scope,
+ const StorageScope& scope, size_t const_nbits) {
CHECK(op != nullptr);
// Re-use not successful, allocate a new buffer.
std::unique_ptr<StorageEntry> entry(new StorageEntry());
return e;
}
- StorageEntry* FindAlloc(const AllocateNode* op,
- const Object* attach_scope,
+ StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
const StorageScope& scope) {
CHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
- uint64_t const_nbits = static_cast<uint64_t>(
- op->constant_allocation_size() * op_elem_bits);
+ uint64_t const_nbits = static_cast<uint64_t>(op->constant_allocation_size() * op_elem_bits);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
if (scope.tag.length() == 0) {
if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
- if (const_nbits > 0 && const_nbits <= 32) {
+ if (const_nbits > 0 && const_nbits <= 32) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
}
auto end = const_free_map_.upper_bound(const_nbits * match_range);
// start looking at the buffer that is bigger than the required size first
for (auto it = mid; it != end; ++it) {
- StorageEntry *e = it->second;
+ StorageEntry* e = it->second;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
// when not divided, no reuse, eg, float4 vs float3
// then start looking at smaller buffers.
for (auto it = mid; it != begin;) {
--it;
- StorageEntry *e = it->second;
+ StorageEntry* e = it->second;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->dtype.element_of()) continue;
}
} else {
// Simple strategy: round roubin.
- for (auto it = sym_free_list_.begin();
- it != sym_free_list_.end(); ++it) {
+ for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
StorageEntry* e = *it;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
// This rules only apply if we are using non special memory
if (e->scope.tag.length() == 0) {
// Disable sharing of local memory.
- if (e->scope.rank >= StorageRank::kWarp ||
- e->allocs[0]->dtype.is_handle()) return;
+ if (e->scope.rank >= StorageRank::kWarp || e->allocs[0]->dtype.is_handle()) return;
// disable reuse of small arrays
if (e->const_nbits > 0 && e->const_nbits <= 32) return;
}
arith::Analyzer analyzer_;
};
-
// Turn alloc into vector alloc
// if all its access is the same vector type.
class VectorAllocRewriter : public StmtExprMutator {
op = stmt.as<AllocateNode>();
const auto& tvec = acc_map_[op->buffer_var.get()];
- if (tvec.size() == 1 &&
- tvec[0].element_of() == op->dtype.element_of() &&
- tvec[0].lanes() % op->dtype.lanes() == 0 &&
- tvec[0].lanes() != op->dtype.lanes()) {
+ if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() &&
+ tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) {
int factor = tvec[0].lanes() / op->dtype.lanes();
Array<PrimExpr> extents = op->extents;
arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
if (me->base % factor == 0 && me->coeff % factor == 0) {
extents.Set(extents.size() - 1,
extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
- return AllocateNode::make(
- op->buffer_var, tvec[0], extents,
- op->condition, op->body);
+ return AllocateNode::make(op->buffer_var, tvec[0], extents, op->condition, op->body);
}
}
return stmt;
return VectorAllocRewriter()(std::move(stmt));
}
-
PrimFunc PointerValueTypeRewrite(PrimFunc f) {
auto* n = f.CopyOnWrite();
VectorAllocRewriter rewriter;
const auto& tvec = rewriter.acc_map_[var.get()];
if (tvec.size() == 1) {
- tir::Var new_var(var->name_hint,
- PointerType(PrimType(tvec[0])));
+ tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0])));
args.push_back(new_var);
remap_vars.Set(var, new_var);
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !var->type_annotation.defined()) {
- tir::Var new_var(var->name_hint,
- PointerType(PrimType(tvec[0].with_lanes(1))));
+ tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1))));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
return f;
}
-
namespace transform {
Pass StorageRewrite() {
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite")
-.set_body_typed(StorageRewrite);
-
+TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite);
Pass PointerValueTypeRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
}
TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite")
-.set_body_typed(PointerValueTypeRewrite);
+ .set_body_typed(PointerValueTypeRewrite);
} // namespace transform
* \brief Infer TensorCore metadata from tensor intrinsic.
* \file tensorcore_fragment.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
-#include "storage_access.h"
-#include "ir_util.h"
#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+#include "storage_access.h"
namespace tvm {
namespace tir {
std::string layout;
FragmentInfo() = default;
FragmentInfo(int _m, int _n, int _k, const std::string& _layout)
- : m(_m), n(_n), k(_k), layout(_layout) {}
+ : m(_m), n(_n), k(_k), layout(_layout) {}
};
void VisitExpr_(const CallNode* op) final {
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
class FragmentChecker : public StmtExprVisitor {
public:
- explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
+ explicit FragmentChecker(const FragmentGetter& getter) : fragment_getter(getter) {}
void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync
- if (op->is_intrinsic(intrinsic::tvm_mma_sync) ||
- op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
+ if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
CHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_var_d = op->args[0].as<VarNode>();
const VarNode* buffer_var_a = op->args[2].as<VarNode>();
return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
}
// Fragment information
- const FragmentGetter &fragment_getter;
+ const FragmentGetter& fragment_getter;
};
// Store the metadata into attributes
class InferFragmenter : public StmtMutator {
public:
- explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
+ explicit InferFragmenter(const FragmentGetter& getter) : fragment_getter(getter) {}
Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer);
// Add shape attribute to all fragments
- std::string shape = std::to_string(info.m) + ", " +
- std::to_string(info.n) + ", " +
- std::to_string(info.k);
+ std::string shape =
+ std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k);
PrimExpr shape_expr = StringImmNode::make(shape);
Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
if (info.layout != "") {
// Add shape attribute to matrix_a and matrix_b
Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout,
- StringImmNode::make(info.layout), shape_attr);
+ StringImmNode::make(info.layout), shape_attr);
return layout_attr;
} else {
return shape_attr;
private:
// Fragment information
- const FragmentGetter &fragment_getter;
+ const FragmentGetter& fragment_getter;
};
Stmt InferFragment(Stmt stmt) {
return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.InferFragment")
-.set_body_typed(InferFragment);
+TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment);
} // namespace transform
} // namespace tir
/*!
* \file thread_storage_sync.cc
*/
-#include <tvm/tir/expr.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
-#include <tvm/tir/transform.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <tvm/runtime/registry.h>
#include <unordered_map>
#include <unordered_set>
+#include "../../runtime/thread_storage_scope.h"
#include "ir_util.h"
#include "storage_access.h"
-#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace tir {
class ThreadSyncPlanner : public StorageAccessVisitor {
public:
- explicit ThreadSyncPlanner(StorageScope sync_scope)
- : sync_scope_(sync_scope) {}
+ explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {}
- // The syncs inserted before each statement
+ // The syncs inserted before each statement
std::unordered_set<const Object*> syncs_inserted_;
protected:
- bool Enabled(const VarNode* buf,
- const StorageScope& scope) const final {
+ bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
return in_device_env() && scope == sync_scope_;
}
// Plan the sync
- std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const ForNode* loop) final {
+ std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, false)) {
- sync_before_stmt = true; break;
+ sync_before_stmt = true;
+ break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
- sync_before_stmt = true; break;
+ sync_before_stmt = true;
+ break;
}
} else if (acc.type == kSync) {
- reads.clear(); writes.clear();
+ reads.clear();
+ writes.clear();
}
}
// If sync is inserted, remove the irrelevant things.
if (sync_before_stmt) {
- reads.clear(); writes.clear();
+ reads.clear();
+ writes.clear();
}
// Add the read/write of current statement
for (const AccessEntry& acc : s.access) {
} else if (acc.type == kWrite) {
writes.push_back(acc);
} else if (acc.type == kSync) {
- reads.clear(); writes.clear();
+ reads.clear();
+ writes.clear();
}
}
if (sync_before_stmt) {
- CHECK_EQ(condition_counter(), 0)
- << "Cannot insert syncs inside condition";
+ CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
syncs_inserted_.insert(s.stmt);
}
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
- sync_before_stmt = true; break;
+ sync_before_stmt = true;
+ break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
- sync_before_stmt = true; break;
+ sync_before_stmt = true;
+ break;
}
} else if (acc.type == kSync) {
- reads.clear(); writes.clear();
+ reads.clear();
+ writes.clear();
}
}
if (sync_before_stmt) {
- CHECK_EQ(condition_counter(), 0)
- << "Cannot insert syncs inside condition";
+ CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
syncs_inserted_.insert(s.stmt);
break;
}
private:
// find conflicting entry in vec.
- bool FindConflict(const std::vector<AccessEntry>& vec,
- const AccessEntry& e,
- bool loop_carry) {
+ bool FindConflict(const std::vector<AccessEntry>& vec, const AccessEntry& e, bool loop_carry) {
for (const AccessEntry& x : vec) {
if (x.buffer.same_as(e.buffer)) {
// Assumes no race between threads
// Same index value means no conflicts
// TODO(tqchen) more standard set based testing.
- if (e.touched.is_single_point() &&
- x.touched.is_single_point()) {
- if (ExprDeepEqual()(e.touched.point_value(),
- x.touched.point_value())) continue;
+ if (e.touched.is_single_point() && x.touched.is_single_point()) {
+ if (ExprDeepEqual()(e.touched.point_value(), x.touched.point_value())) continue;
}
- if (x.double_buffer_write &&
- e.type == kRead &&
- !loop_carry) continue;
+ if (x.double_buffer_write && e.type == kRead && !loop_carry) continue;
return true;
}
}
class ThreadSyncInserter : public StmtExprMutator {
public:
- ThreadSyncInserter(StorageScope sync_scope,
- const std::unordered_set<const Object*>& syncs)
+ ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs)
: sync_scope_(sync_scope), syncs_(syncs) {}
Stmt VisitStmt(const Stmt& stmt) final {
if (sync_scope_.rank == StorageRank::kGlobal) {
barrier = MakeGlobalBarrier();
} else {
- barrier = EvaluateNode::make(
- CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImmNode::make(sync_scope_.to_string())},
- CallNode::Intrinsic));
+ barrier = EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+ {StringImmNode::make(sync_scope_.to_string())},
+ CallNode::Intrinsic));
}
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
return ret;
} else if (op->attr_key == attr::storage_scope) {
const VarNode* buf = op->node.as<VarNode>();
- storage_scope_[buf] =
- StorageScope::make(op->value.as<StringImmNode>()->value);
+ storage_scope_[buf] = StorageScope::make(op->value.as<StringImmNode>()->value);
return StmtExprMutator::VisitStmt_(op);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
rw_stats_.clear();
- Stmt kinit = EvaluateNode::make(
- CallNode::make(
- DataType::Int(32),
- intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
+ Stmt kinit = EvaluateNode::make(CallNode::make(
+ DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
body = SeqStmt({kinit, body});
- body = AttrStmtNode::make(
- op->node, op->attr_key, op->value, body);
+ body = AttrStmtNode::make(op->node, op->attr_key, op->value, body);
return SeqStmt({prep, body});
}
Stmt MakeGlobalBarrier() {
IterVar iv = Downcast<IterVar>(attr->node);
runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag);
if (s.rank == 0) {
- num_blocks_ = (num_blocks_.defined() ?
- attr->value * num_blocks_ : attr->value);
+ num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value);
} else if (s.rank == 1) {
PrimExpr cond = iv->var == make_zero(iv->var.dtype());
is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
}
return EvaluateNode::make(
CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImmNode::make(sync_scope_.to_string()),
- is_lead_, num_blocks_},
- CallNode::Intrinsic));
+ {StringImmNode::make(sync_scope_.to_string()), is_lead_, num_blocks_},
+ CallNode::Intrinsic));
}
// data structure.
StorageScope sync_scope_;
return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.ThreadSync")
-.set_body_typed(ThreadSync);
+TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync);
} // namespace transform
} // namespace tir
* \file unroll_loop.cc
*/
// Unrolls the loop as in Halide pipeline.
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/arith/analyzer.h>
-#include <unordered_set>
+#include <tvm/tir/transform.h>
+
#include <unordered_map>
+#include <unordered_set>
#include <vector>
-#include "ir_util.h"
+
#include "../../arith/compute_expr.h"
+#include "ir_util.h"
namespace tvm {
namespace tir {
class LoopUnroller : public StmtExprMutator {
public:
- explicit LoopUnroller(int auto_max_step,
- int auto_max_depth,
- int auto_max_extent,
+ explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent,
bool explicit_unroll)
: auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
- explicit_unroll_(explicit_unroll) {
- }
+ explicit_unroll_(explicit_unroll) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
op = stmt.as<ForNode>();
int value = GetExtent(op);
// condition for auto unroll
- bool auto_unroll = (
- op->for_type == ForType::Serial &&
- value >= 0 &&
- normal_loop_depth_ == 0 &&
- unroll_depth_ <= auto_max_depth_);
+ bool auto_unroll = (op->for_type == ForType::Serial && value >= 0 && normal_loop_depth_ == 0 &&
+ unroll_depth_ <= auto_max_depth_);
- auto_unroll = auto_unroll && (
- value * step_count_ <= auto_max_step_||
- value <= auto_max_extent_);
+ auto_unroll =
+ auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_);
if (op->for_type == ForType::Unrolled) {
- CHECK_GE(value, 0)
- << "Cannot unroll non-constant loop";
+ CHECK_GE(value, 0) << "Cannot unroll non-constant loop";
auto_unroll = true;
}
if (auto_unroll) {
- step_count_ *= value;
+ step_count_ *= value;
unroll_depth_ += 1;
} else {
normal_loop_depth_ += 1;
} else {
if (auto_unroll) {
if (op->for_type != ForType::Unrolled) {
- return ForNode::make(
- op->loop_var, op->min, op->extent,
- ForType::Unrolled, op->device_api, op->body);
+ return ForNode::make(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api,
+ op->body);
}
}
return stmt;
int GetExtent(const ForNode* op) {
// constant folding.
PrimExpr extent = analyzer_.Simplify(op->extent);
- const IntImmNode *v1 = extent.as<IntImmNode>();
+ const IntImmNode* v1 = extent.as<IntImmNode>();
int value = -1;
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
arith::Analyzer analyzer_;
};
-
-Stmt UnrollLoop(Stmt stmt,
- int auto_max_step,
- int auto_max_depth,
- int auto_max_extent,
+Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_max_depth, int auto_max_extent,
bool explicit_unroll) {
- Stmt ret = LoopUnroller(
- auto_max_step,
- auto_max_depth,
- auto_max_extent,
- explicit_unroll)(stmt);
+ Stmt ret = LoopUnroller(auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
} else {
namespace transform {
-Pass UnrollLoop(int auto_max_step,
- int auto_max_depth,
- int auto_max_extent,
- bool explicit_unroll) {
+Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
- n->body = UnrollLoop(std::move(f->body),
- auto_max_step,
- auto_max_depth,
- auto_max_extent,
+ n->body = UnrollLoop(std::move(f->body), auto_max_step, auto_max_depth, auto_max_extent,
explicit_unroll);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop")
-.set_body_typed(UnrollLoop);
+TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop);
} // namespace transform
* \file vectorize_loop.cc
*/
// Loop vectorizer as in Halide pipeline.
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/arith/analyzer.h>
-#include <unordered_set>
+#include <tvm/tir/transform.h>
+
#include <unordered_map>
+#include <unordered_set>
#include <vector>
+
#include "../../arith/compute_expr.h"
namespace tvm {
return BroadcastNode::make(op->value, lanes);
}
}
- CHECK_EQ(e.dtype().lanes(), 1)
- << "Cannot broadcast lane=" << e.dtype().lanes()
- << " to " << lanes;
+ CHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to "
+ << lanes;
return BroadcastNode::make(e, lanes);
}
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
if (op->buffer_var.get() == buf_) {
- return LoadNode::make(op->dtype, op->buffer_var,
- op->index * var_lanes_ + var_,
- op->predicate);
+ return LoadNode::make(op->dtype, op->buffer_var, op->index * var_lanes_ + var_,
+ op->predicate);
} else {
return expr;
}
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<StoreNode>();
if (op->buffer_var.get() == buf_) {
- return StoreNode::make(op->buffer_var,
- op->value,
- op->index * var_lanes_ + var_,
- op->predicate);
+ return StoreNode::make(op->buffer_var, op->value, op->index * var_lanes_ + var_,
+ op->predicate);
} else {
return stmt;
}
class Vectorizer : public StmtExprMutator {
public:
- Vectorizer(Var var, int var_lanes)
- : var_(var), var_lanes_(var_lanes) {
+ Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = RampNode::make(0, 1, var_lanes);
}
}
}
- PrimExpr VisitExpr_(const AddNode* op) final {
- return AddSubVec(op);
- }
- PrimExpr VisitExpr_(const SubNode* op) final {
- return AddSubVec(op);
- }
+ PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op); }
+ PrimExpr VisitExpr_(const SubNode* op) final { return AddSubVec(op); }
PrimExpr VisitExpr_(const MulNode* op) final {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
- if (a.same_as(op->a) &&
- b.same_as(op->b)) {
+ if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
- return RampNode::make(
- a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
+ return RampNode::make(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
}
if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
- return RampNode::make(
- b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
+ return RampNode::make(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
}
}
return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
return BinaryVec(op);
}
- PrimExpr VisitExpr_(const DivNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const ModNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const FloorDivNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const FloorModNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const MinNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const MaxNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const EQNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const NENode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const LTNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const LENode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const GTNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const GENode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const AndNode* op) final {
- return BinaryVec(op);
- }
- PrimExpr VisitExpr_(const OrNode* op) final {
- return BinaryVec(op);
- }
+ PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); }
+ PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); }
PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
stride = BroadcastTo(stride, lanes);
Array<PrimExpr> elems;
for (int i = 0; i < lanes; ++i) {
- elems.push_back(
- RampNode::make(ShuffleNode::make_extract_element(base, i),
- ShuffleNode::make_extract_element(stride, i),
- op->lanes));
+ elems.push_back(RampNode::make(ShuffleNode::make_extract_element(base, i),
+ ShuffleNode::make_extract_element(stride, i), op->lanes));
}
return ShuffleNode::make_concat(elems);
}
- PrimExpr VisitExpr_(const SelectNode *op) final {
+ PrimExpr VisitExpr_(const SelectNode* op) final {
PrimExpr cond = this->VisitExpr(op->condition);
PrimExpr t = this->VisitExpr(op->true_value);
PrimExpr f = this->VisitExpr(op->false_value);
- if (cond.same_as(op->condition) &&
- t.same_as(op->true_value) &&
- f.same_as(op->false_value)) {
+ if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
- int lanes = std::max(std::max(
- cond.dtype().lanes(),
- t.dtype().lanes()), f.dtype().lanes());
+ int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes());
return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
- PrimExpr VisitExpr_(const CastNode *op) final {
+ PrimExpr VisitExpr_(const CastNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
if (v == var_.get()) {
return ramp_;
} else if (lets_.count(v)) {
- return lets_[v];
+ return lets_[v];
} else {
return GetRef<PrimExpr>(v);
}
}
// IfThenElse expr
- PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
+ PrimExpr MutateIfThenElseExpr_(const CallNode* op) {
PrimExpr cond = this->VisitExpr(op->args[0]);
- if (cond.dtype().is_vector()) {
+ if (cond.dtype().is_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]);
- if (cond.same_as(op->args[0]) &&
- t.same_as(op->args[1]) &&
- f.same_as(op->args[2])) {
+ if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
t = BroadcastTo(t, lanes);
f = BroadcastTo(f, lanes);
- return CallNode::make(
- op->dtype.with_lanes(lanes), op->name,
- {cond, t, f}, op->call_type, op->func, op->value_index);
+ return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type,
+ op->func, op->value_index);
}
}
// Call
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
- return CallNode::make(
- op->dtype, op->name, new_args, op->call_type, op->func, op->value_index);
+ return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func,
+ op->value_index);
}
} else {
int lane = 0;
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
- return CallNode::make(
- op->dtype.with_lanes(lane), op->name, new_args,
- op->call_type, op->func, op->value_index);
+ return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type,
+ op->func, op->value_index);
}
}
}
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
- return LoadNode::make(
- op->dtype.with_lanes(lanes),
- op->buffer_var,
- BroadcastTo(index, lanes),
- BroadcastTo(pred, lanes));
+ return LoadNode::make(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes),
+ BroadcastTo(pred, lanes));
}
}
// Let
return LetNode::make(v, value, this->VisitExpr(op->body));
} else {
PrimExpr body = this->VisitExpr(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
} else {
int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
lanes = std::max(lanes, pred.dtype().lanes());
- return StoreNode::make(op->buffer_var,
- BroadcastTo(value, lanes),
- BroadcastTo(index, lanes),
- BroadcastTo(pred, lanes));
+ return StoreNode::make(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
+ BroadcastTo(pred, lanes));
}
}
// For
return Scalarize(GetRef<Stmt>(op));
}
Stmt body = this->VisitStmt(op->body);
- if (extent.same_as(op->extent) &&
- body.same_as(op->body)) {
+ if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return ForNode::make(
- op->loop_var, op->min, extent,
- op->for_type, op->device_api, body);
+ return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
}
}
// IfThenElse
if (op->else_case.defined()) {
else_case = this->VisitStmt(op->else_case);
}
- if (condition.same_as(op->condition) &&
- then_case.same_as(op->then_case) &&
+ if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
// place the vector lanes in least significant dimension.
extents.push_back(var_lanes_);
// rewrite access to buffer internally.
- Stmt body = VecAllocAccess(
- op->buffer_var.get(), var_, var_lanes_)(op->body);
+ Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
- return AllocateNode::make(
- op->buffer_var, op->dtype,
- extents, condition, body);
+ return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body);
}
// scalarize the statement
Stmt Scalarize(Stmt stmt) {
if (!changed) return arr;
return Array<PrimExpr>(new_arr);
}
- template<typename T>
+ template <typename T>
PrimExpr BinaryVec(const T* op) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
- if (a.same_as(op->a) &&
- b.same_as(op->b)) {
+ if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
- template<typename T>
+ template <typename T>
PrimExpr AddSubVec(const T* op) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
- if (a.same_as(op->a) &&
- b.same_as(op->b)) {
+ if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (a.dtype().lanes() == 1 && b_ramp) {
return RampNode::make(
arith::Compute<T>(a, b_ramp->base),
- arith::Compute<T>(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
- b_ramp->lanes);
+ arith::Compute<T>(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
}
if (b.dtype().lanes() == 1 && a_ramp) {
- return RampNode::make(
- arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
+ return RampNode::make(arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
};
-Stmt VectorizeLoop(Stmt stmt) {
- return LoopVectorizer()(std::move(stmt));
-}
+Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }
class VectorizeSkipper : public StmtMutator {
public:
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op->for_type == ForType::Vectorized) {
- return ForNode::make(op->loop_var, op->min, op->extent,
- ForType::Serial, op->device_api,
+ return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
op->body);
} else {
- return stmt;
+ return stmt;
}
}
};
-Stmt SkipVectorize(Stmt stmt) {
- return VectorizeSkipper()(std::move(stmt));
-}
+Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); }
namespace transform {
return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop")
-.set_body_typed(VectorizeLoop);
+TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop);
} // namespace transform
TEST(Simplify, MinMax) {
tvm::arith::Analyzer ana;
auto x = tvm::te::var("x");
- auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ;
+ auto e1 = (tvm::max(x, 1) - tvm::max(x, 1));
auto e1s = ana.canonical_simplify(e1);
CHECK(tvm::tir::is_zero(e1s));
TEST(Simplify, Mul) {
tvm::arith::Analyzer ana;
auto x = tvm::te::var("x");
- auto e = (x * x) - (x * x) ;
+ auto e = (x * x) - (x * x);
auto es = ana.canonical_simplify(e);
CHECK(tvm::tir::is_zero(es));
}
auto es = ana.canonical_simplify(mod - x);
CHECK(tvm::tir::is_zero(es));
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir/attrs.h>
-#include <tvm/tir/op.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
namespace tvm {
namespace test {
double learning_rate;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") {
- TVM_ATTR_FIELD(axis)
- .set_default(10)
- .set_lower_bound(1)
- .set_upper_bound(10)
- .describe("axis field");
- TVM_ATTR_FIELD(name)
- .describe("name of the field");
+ TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe(
+ "axis field");
+ TVM_ATTR_FIELD(name).describe("name of the field");
TVM_ATTR_FIELD(expr)
.describe("expression field")
.set_default(tir::make_const(DataType::Int(32), 1));
- TVM_ATTR_FIELD(learning_rate)
- .describe("learning_rate")
- .set_default(0.1);
+ TVM_ATTR_FIELD(learning_rate).describe("learning_rate").set_default(0.1);
}
};
-}
-}
+} // namespace test
+} // namespace tvm
TEST(Attrs, Basic) {
using namespace tvm;
// Check docstring
std::ostringstream os;
n->PrintDocString(os);
- LOG(INFO) << "docstring\n"<< os.str();
+ LOG(INFO) << "docstring\n" << os.str();
CHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos);
}
-
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <topi/cuda/injective.h>
-#include <tvm/te/operation.h>
-#include <tvm/runtime/registry.h>
#include <tvm/driver/driver_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
-#include <string>
#include <cmath>
+#include <string>
TEST(BuildModule, Basic) {
using namespace tvm;
auto A = placeholder(shape, DataType::Float(32), "A");
auto B = placeholder(shape, DataType::Float(32), "B");
- auto C = compute(A->shape, [&A, &B](PrimExpr i) {
- return A[i] + B[i];
- }, "C");
+ auto C = compute(
+ A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "C");
- auto s = create_schedule({ C->op });
+ auto s = create_schedule({C->op});
auto cAxis = C->op.as<ComputeOpNode>()->axis;
IterVar bx, tx;
s[C].split(cAxis[0], 64, &bx, &tx);
- auto args = Array<Tensor>({ A, B, C });
+ auto args = Array<Tensor>({A, B, C});
std::unordered_map<Tensor, Buffer> binds;
auto config = BuildConfig::Create();
auto B = placeholder(shape, DataType::Float(32), "B");
auto C = placeholder(shape, DataType::Float(32), "C");
- auto elemwise_add = compute(A->shape, [&A, &B](PrimExpr i) {
- return A[i] + B[i];
- }, "elemwise_add");
+ auto elemwise_add = compute(
+ A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add");
auto copy = placeholder(shape, DataType::Float(32), "__copy");
- auto elemwise_sub = compute(C->shape, [©, &C](PrimExpr i) {
- return copy[i] - C[i];
- }, "elemwise_sub");
+ auto elemwise_sub = compute(
+ C->shape, [©, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub");
With<Target> cuda_scope(target_cuda);
auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
-
With<Target> llvm_scope(target_llvm);
auto s2 = create_schedule({elemwise_sub->op});
std::unordered_map<Tensor, Buffer> binds;
auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config);
auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config);
- Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1},
- {target_llvm, lowered_s2}};
+ Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}};
auto module = build(inputs, Target(), config);
// Assertion for build.
"\"float32\"]]}}";
// Setup inputs.
- auto a_val =
- runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
- auto b_val =
- runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
- auto c_val =
- runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+ auto a_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+ auto b_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+ auto c_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pa = (float*)(a_val->data);
auto pb = (float*)(b_val->data);
const runtime::PackedFunc* graph_runtime =
tvm::runtime::Registry::Get("tvm.graph_runtime.create");
- runtime::Module mod = (*graph_runtime)(
- json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id);
+ runtime::Module mod =
+ (*graph_runtime)(json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id);
// test FFI for module.
auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
test_ffi(runtime::Module(mod), static_cast<int>(kTVMModuleHandle));
test_ffi(Optional<runtime::Module>(mod), static_cast<int>(kTVMModuleHandle));
-
PackedFunc set_input = mod.GetFunction("set_input", false);
PackedFunc run = mod.GetFunction("run", false);
PackedFunc get_output = mod.GetFunction("get_output", false);
}
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/container.h>
-#include <tvm/tir/op.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
#include <new>
#include <unordered_map>
public:
// Need this so that destructor of temporary objects don't interrupt our
// testing.
- TestErrorSwitch(const TestErrorSwitch& other)
- : should_fail(other.should_fail) {
+ TestErrorSwitch(const TestErrorSwitch& other) : should_fail(other.should_fail) {
const_cast<TestErrorSwitch&>(other).should_fail = false;
}
}
};
-class TestArrayObj : public Object,
- public InplaceArrayBase<TestArrayObj, TestErrorSwitch> {
+class TestArrayObj : public Object, public InplaceArrayBase<TestArrayObj, TestErrorSwitch> {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.TestArrayObj";
TestErrorSwitch f2{true};
TestErrorSwitch f3{false};
std::vector<TestErrorSwitch> fields{f1, f2, f3};
- auto ptr =
- make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
+ auto ptr = make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
try {
ptr->WrongInit(fields.begin(), fields.end());
} catch (...) {
// since it's not initalized.
TestErrorSwitch f2{true};
std::vector<TestErrorSwitch> fields{f1, f2};
- auto ptr =
- make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
+ auto ptr = make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
try {
ptr->Init(fields.begin(), fields.end());
} catch (...) {
using namespace tvm;
PrimExpr a = 1, b = 2;
Map<PrimExpr, PrimExpr> map1{{a, b}};
- std::unordered_map<PrimExpr, PrimExpr, ObjectHash, ObjectEqual> map2(
- map1.begin(), map1.end());
+ std::unordered_map<PrimExpr, PrimExpr, ObjectHash, ObjectEqual> map2(map1.begin(), map1.end());
CHECK(map2[a].as<IntImmNode>()->value == 2);
}
String s2 = Downcast<String>(r);
}
-
TEST(Optional, Composition) {
Optional<String> opt0(nullptr);
Optional<String> opt1 = String("xyz");
TEST(CRTMemory, Alloc) {
for (int idx = 0; idx < 65536; idx++) {
- void * a = vmalloc(1);
+ void* a = vmalloc(1);
EXPECT_EQ(vleak_size, 1);
vfree(a);
EXPECT_EQ(vleak_size, 0);
TEST(CRTMemory, Realloc) {
for (int idx = 0; idx < 65536; idx++) {
- void * a = vrealloc(0, 1);
+ void* a = vrealloc(0, 1);
EXPECT_EQ(vleak_size, 1);
- void * b = vrealloc(a, 1);
+ void* b = vrealloc(a, 1);
EXPECT_EQ(a, b);
EXPECT_EQ(vleak_size, 1);
vfree(a);
}
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
CHECK(os.str() == "max(((x + 1) + 2), 100)");
}
-
TEST(ExprNodeRef, Basic) {
using namespace tvm;
using namespace tvm::tir;
CHECK(GetRef<ObjectRef>(op).same_as(z));
}
-
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <dmlc/logging.h>
#include <gtest/gtest.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/op.h>
#include <tvm/node/functor.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
TEST(IRF, Basic) {
auto z = x + 1;
NodeFunctor<int(const ObjectRef& n, int b)> f;
- f.set_dispatch<VarNode>([](const ObjectRef& n, int b) {
- return b;
- });
- f.set_dispatch<AddNode>([](const ObjectRef& n, int b) {
- return b + 2;
- });
- CHECK_EQ(f(x, 2), 2);
- CHECK_EQ(f(z, 2), 4);
+ f.set_dispatch<VarNode>([](const ObjectRef& n, int b) { return b; });
+ f.set_dispatch<AddNode>([](const ObjectRef& n, int b) { return b + 2; });
+ CHECK_EQ(f(x, 2), 2);
+ CHECK_EQ(f(z, 2), 4);
}
TEST(IRF, CountVar) {
auto z = x + 1 + y + y;
tir::PostOrderVisit(z, [&n_var](const ObjectRef& n) {
if (n.as<VarNode>()) ++n_var;
- });
+ });
CHECK_EQ(n_var, 2);
}
-
TEST(IRF, ExprTransform) {
using namespace tvm;
using namespace tvm::tir;
Var x("x");
auto z = x + 1;
- class MyExprFunctor
- : public tir::ExprFunctor<int(const PrimExpr&, int)> {
+ class MyExprFunctor : public tir::ExprFunctor<int(const PrimExpr&, int)> {
public:
- int VisitExpr_(const VarNode* op, int b) final {
- return b;
- }
- int VisitExpr_(const IntImmNode* op, int b) final {
- return op->value;
- }
+ int VisitExpr_(const VarNode* op, int b) final { return b; }
+ int VisitExpr_(const IntImmNode* op, int b) final { return op->value; }
int VisitExpr_(const AddNode* op, int b) final {
return VisitExpr(op->a, b) + VisitExpr(op->b, b);
}
};
MyExprFunctor f;
- CHECK_EQ(f(x, 2), 2);
- CHECK_EQ(f(z, 2), 3);
+ CHECK_EQ(f(x, 2), 2);
+ CHECK_EQ(f(z, 2), 3);
try {
f(z - 1, 2);
LOG(FATAL) << "should fail";
- } catch(dmlc::Error) {
+ } catch (dmlc::Error) {
}
}
Var x("x");
auto z = x + 1;
- class MyVisitor
- : public tir::ExprFunctor<void(const PrimExpr&)>,
- public tir::StmtFunctor<void(const Stmt&)> {
+ class MyVisitor : public tir::ExprFunctor<void(const PrimExpr&)>,
+ public tir::StmtFunctor<void(const Stmt&)> {
public:
int count = 0;
// implementation
- void VisitExpr_(const VarNode* op) final {
- ++count;
- }
- void VisitExpr_(const IntImmNode* op) final {
- }
+ void VisitExpr_(const VarNode* op) final { ++count; }
+ void VisitExpr_(const IntImmNode* op) final {}
void VisitExpr_(const AddNode* op) final {
VisitExpr(op->a);
VisitExpr(op->b);
}
- void VisitStmt_(const EvaluateNode* op) final {
- VisitExpr(op->value);
- }
+ void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); }
};
MyVisitor v;
v.VisitStmt(EvaluateNode::make(z));
CHECK_EQ(v.count, 1);
}
-
TEST(IRF, StmtVisitor) {
using namespace tvm;
using namespace tvm::tir;
Var x("x");
- class MyVisitor
- : public StmtExprVisitor {
+ class MyVisitor : public StmtExprVisitor {
public:
int count = 0;
// implementation
- void VisitExpr_(const VarNode* op) final {
- ++count;
- }
+ void VisitExpr_(const VarNode* op) final { ++count; }
};
MyVisitor v;
auto fmaketest = [&]() {
using namespace tvm::tir;
Var x("x");
- class MyVisitor
- : public tir::StmtMutator,
- public tir::ExprMutator {
+ class MyVisitor : public tir::StmtMutator, public tir::ExprMutator {
public:
using StmtMutator::operator();
using ExprMutator::operator();
protected:
// implementation
- PrimExpr VisitExpr_(const AddNode* op) final {
- return op->a;
- }
- Stmt VisitStmt_(const SeqStmtNode* op) final {
- return StmtMutator::VisitSeqStmt_(op, true);
- }
- PrimExpr VisitExpr(const PrimExpr& expr) final {
- return ExprMutator::VisitExpr(expr);
- }
+ PrimExpr VisitExpr_(const AddNode* op) final { return op->a; }
+ Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); }
+ PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); }
};
auto fmakealloc = [&]() {
auto z = x + 1;
}
{
- auto body = EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
+ auto body =
+ EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
auto res = v(std::move(body));
CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[0].same_as(x));
}
}
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <dmlc/logging.h>
#include <gtest/gtest.h>
-#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
+#include <tvm/runtime/object.h>
namespace tvm {
namespace test {
TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA);
};
-
TVM_REGISTER_OBJECT_TYPE(ObjBase);
TVM_REGISTER_OBJECT_TYPE(ObjA);
TVM_REGISTER_OBJECT_TYPE(ObjB);
CHECK(refB.as<ObjB>() != nullptr);
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <dmlc/logging.h>
#include <gtest/gtest.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/container.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <tvm/tir/transform.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/transform.h>
TEST(PackedFunc, Basic) {
using namespace tvm;
DLTensor a;
Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.num_args == 3);
- CHECK(args.values[0].v_float64 == 1.0);
- CHECK(args.type_codes[0] == kDLFloat);
- CHECK(args.values[1].v_handle == &a);
- CHECK(args.type_codes[1] == kTVMDLTensorHandle);
- CHECK(args.values[2].v_handle == &x);
- CHECK(args.type_codes[2] == kTVMOpaqueHandle);
- *rv = Var("a");
- })(1.0, &a, handle);
+ CHECK(args.num_args == 3);
+ CHECK(args.values[0].v_float64 == 1.0);
+ CHECK(args.type_codes[0] == kDLFloat);
+ CHECK(args.values[1].v_handle == &a);
+ CHECK(args.type_codes[1] == kTVMDLTensorHandle);
+ CHECK(args.values[2].v_handle == &x);
+ CHECK(args.type_codes[2] == kTVMOpaqueHandle);
+ *rv = Var("a");
+ })(1.0, &a, handle);
CHECK(v->name_hint == "a");
}
using namespace tvm::runtime;
Var x;
Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.num_args == 1);
- CHECK(args[0].IsObjectRef<ObjectRef>());
- Var b = args[0];
- CHECK(x.same_as(b));
- *rv = b;
- })(x);
+ CHECK(args.num_args == 1);
+ CHECK(args[0].IsObjectRef<ObjectRef>());
+ Var b = args[0];
+ CHECK(x.same_as(b));
+ *rv = b;
+ })(x);
CHECK(t.same_as(x));
}
TEST(PackedFunc, NDArray) {
using namespace tvm;
using namespace tvm::runtime;
- auto x = NDArray::Empty(
- {}, String2DLDataType("float32"),
- TVMContext{kDLCPU, 0});
+ auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0});
reinterpret_cast<float*>(x->data)[0] = 10.0f;
CHECK(x.use_count() == 1);
- PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) {
- *rv = args[0];
- });
+ PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; });
NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- NDArray y = args[0];
- DLTensor* ptr = args[0];
- CHECK(ptr == x.operator->());
- CHECK(x.same_as(y));
- CHECK(x.use_count() == 2);
- *rv = forward(y);
- })(x);
+ NDArray y = args[0];
+ DLTensor* ptr = args[0];
+ CHECK(ptr == x.operator->());
+ CHECK(x.same_as(y));
+ CHECK(x.use_count() == 2);
+ *rv = forward(y);
+ })(x);
CHECK(ret.use_count() == 2);
CHECK(ret.same_as(x));
}
using namespace tvm;
using namespace tvm::runtime;
PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.num_args == 1);
- std::string x = args[0];
- CHECK(x == "hello");
- String y = args[0];
- CHECK(y == "hello");
- *rv = x;
- })("hello");
+ CHECK(args.num_args == 1);
+ std::string x = args[0];
+ CHECK(x == "hello");
+ String y = args[0];
+ CHECK(y == "hello");
+ *rv = x;
+ })("hello");
PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.num_args == 1);
- runtime::String s = args[0];
- CHECK(s == "hello");
+ CHECK(args.num_args == 1);
+ runtime::String s = args[0];
+ CHECK(s == "hello");
})(runtime::String("hello"));
}
-
TEST(PackedFunc, func) {
using namespace tvm;
using namespace tvm::runtime;
- PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) {
- *rv = args[0].operator int() + 1;
- });
+ PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; });
// function as arguments
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
- PackedFunc f = args[0];
- // TVMArgValue -> Arguments as function
- *rv = f(args[1]).operator int();
- })(addone, 1);
+ PackedFunc f = args[0];
+ // TVMArgValue -> Arguments as function
+ *rv = f(args[1]).operator int();
+ })(addone, 1);
CHECK_EQ(r0, 2);
int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
- // TVMArgValue -> TVMRetValue
- *rv = args[1];
- })(2, 100);
+ // TVMArgValue -> TVMRetValue
+ *rv = args[1];
+ })(2, 100);
CHECK_EQ(r1, 100);
int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- // re-assignment
- *rv = args[0];
- // TVMRetValue -> Function argument
- *rv = addone(args[0].operator PackedFunc()(args[1], 1));
- })(addone, 100);
+ // re-assignment
+ *rv = args[0];
+ // TVMRetValue -> Function argument
+ *rv = addone(args[0].operator PackedFunc()(args[1], 1));
+ })(addone, 100);
CHECK_EQ(r2, 102);
}
using namespace tvm::runtime;
// automatic conversion of int to expr
PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
- PrimExpr x = args[0];
- *rv = x.as<tvm::tir::IntImmNode>()->value + 1;
+ PrimExpr x = args[0];
+ *rv = x.as<tvm::tir::IntImmNode>()->value + 1;
});
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
- PackedFunc f = args[0];
- // TVMArgValue -> Arguments as function
- *rv = f(args[1]).operator int();
- })(addone, 1);
+ PackedFunc f = args[0];
+ // TVMArgValue -> Arguments as function
+ *rv = f(args[1]).operator int();
+ })(addone, 1);
CHECK_EQ(r0, 2);
}
using namespace tvm;
using namespace tvm::runtime;
auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
- DataType x = args[0];
- *rv = x;
- });
- auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
- *rv = args[0];
- });
+ DataType x = args[0];
+ *rv = x;
+ });
+ auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; });
CHECK(get_type("int32").operator DataType() == DataType::Int(32));
CHECK(get_type("float").operator DataType() == DataType::Float(32));
CHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2));
using BindFunc = TypedPackedFunc<Int1Func(Int2Func, int value)>;
BindFunc ftyped;
ftyped = [](Int2Func f1, int value) -> Int1Func {
- auto binded = [f1, value](int x) {
- return f1(value, x);
- };
+ auto binded = [f1, value](int x) { return f1(value, x); };
Int1Func x(binded);
return x;
};
using tvm::runtime::detail::function_signature;
TypedPackedFunc<int(float)> x;
- auto f = [](int x) -> int {
- return x + 1;
- };
+ auto f = [](int x) -> int { return x + 1; };
std::function<void(float)> y;
- static_assert(std::is_same<function_signature<decltype(x)>::FType,
- int(float)>::value, "invariant1");
- static_assert(std::is_same<function_signature<decltype(f)>::FType,
- int(int)>::value, "invariant2");
- static_assert(std::is_same<function_signature<decltype(y)>::FType,
- void(float)>::value, "invariant3");
+ static_assert(std::is_same<function_signature<decltype(x)>::FType, int(float)>::value,
+ "invariant1");
+ static_assert(std::is_same<function_signature<decltype(f)>::FType, int(int)>::value,
+ "invariant2");
+ static_assert(std::is_same<function_signature<decltype(y)>::FType, void(float)>::value,
+ "invariant3");
}
-
TEST(PackedFunc, ObjectConversion) {
using namespace tvm;
using namespace tvm::tir;
using namespace tvm::runtime;
TVMRetValue rv;
- auto x = NDArray::Empty(
- {}, String2DLDataType("float32"),
- TVMContext{kDLCPU, 0});
+ auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0});
// assign null
rv = ObjectRef();
CHECK_EQ(rv.type_code(), kTVMNullptr);
CHECK(!rv.IsObjectRef<PrimExpr>());
auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle);
- CHECK(args[0].operator NDArray().same_as(x));
- CHECK(args[0].operator ObjectRef().same_as(x));
- CHECK(args[1].operator ObjectRef().get() == nullptr);
- CHECK(args[1].operator NDArray().get() == nullptr);
- CHECK(args[1].operator Module().get() == nullptr);
- CHECK(args[1].operator Array<NDArray>().get() == nullptr);
- CHECK(!args[0].IsObjectRef<PrimExpr>());
- });
+ CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle);
+ CHECK(args[0].operator NDArray().same_as(x));
+ CHECK(args[0].operator ObjectRef().same_as(x));
+ CHECK(args[1].operator ObjectRef().get() == nullptr);
+ CHECK(args[1].operator NDArray().get() == nullptr);
+ CHECK(args[1].operator Module().get() == nullptr);
+ CHECK(args[1].operator Array<NDArray>().get() == nullptr);
+ CHECK(!args[0].IsObjectRef<PrimExpr>());
+ });
pf1(x, ObjectRef());
pf1(ObjectRef(x), NDArray());
CHECK(!rv.IsObjectRef<NDArray>());
auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
- CHECK_EQ(args[0].type_code(), kTVMModuleHandle);
- CHECK(args[0].operator Module().same_as(m));
- CHECK(args[0].operator ObjectRef().same_as(m));
- CHECK(args[1].operator ObjectRef().get() == nullptr);
- CHECK(args[1].operator NDArray().get() == nullptr);
- CHECK(args[1].operator Module().get() == nullptr);
- CHECK(!args[0].IsObjectRef<PrimExpr>());
- });
+ CHECK_EQ(args[0].type_code(), kTVMModuleHandle);
+ CHECK(args[0].operator Module().same_as(m));
+ CHECK(args[0].operator ObjectRef().same_as(m));
+ CHECK(args[1].operator ObjectRef().get() == nullptr);
+ CHECK(args[1].operator NDArray().get() == nullptr);
+ CHECK(args[1].operator Module().get() == nullptr);
+ CHECK(!args[0].IsObjectRef<PrimExpr>());
+ });
pf2(m, ObjectRef());
pf2(ObjectRef(m), Module());
}
using namespace tvm;
using namespace tvm::runtime;
{
-
auto inspect = [](TVMArgs args, TVMRetValue* rv) {
for (int i = 0; i < args.size(); ++i) {
CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg);
}
};
- PackedFunc finspect(inspect);
+ PackedFunc finspect(inspect);
finspect(tir::Var("x"));
}
{
}
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
* under the License.
*/
+#include "../src/arith/pattern_match.h"
+
#include <gtest/gtest.h>
#include <tvm/tir/analysis.h>
-#include "../src/arith/pattern_match.h"
TEST(Pattern, Basic) {
using namespace tvm;
CHECK((px >= py && px < pz).Match(x >= y && x < z));
CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
{
- CHECK(select(px >= pz, py, py + pz).Match(
- tir::SelectNode::make((x + 1) >= 1, y, y + 1)));
+ CHECK(select(px >= pz, py, py + pz).Match(tir::SelectNode::make((x + 1) >= 1, y, y + 1)));
CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1));
}
// bit intrinsics
CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2)))));
// select
{
- CHECK(select(px > pz, py, py + pz).Match(
- tir::SelectNode::make(x > 1, y, y + 1)));
+ CHECK(select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 1, y, y + 1)));
CHECK(is_const_int(pz.Eval(), 1));
}
- CHECK(!select(px > pz, py, py + pz).Match(
- tir::SelectNode::make(x > 2, y, y + 1)));
- CHECK(!select(px > pz, py, py).Match(
- tir::SelectNode::make(x > 2, y, y + 1)));
+ CHECK(!select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 2, y, y + 1)));
+ CHECK(!select(px > pz, py, py).Match(tir::SelectNode::make(x > 2, y, y + 1)));
{
- CHECK(select(px, py, pz).Match(
- tir::SelectNode::make(x > 2, y, y + 1)));
+ CHECK(select(px, py, pz).Match(tir::SelectNode::make(x > 2, y, y + 1)));
CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1));
}
// if_then_else
{
- CHECK(if_then_else(px > pz, py, py + pz).Match(
- if_then_else(x > 1, y, y + 1)));
+ CHECK(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1)));
CHECK(is_const_int(pz.Eval(), 1));
}
// cast pattern
{
- CHECK(!cast(PConst<DataType>(
- DataType::Int(32)), px).Match(tir::CastNode::make(DataType::Float(64), x)));
+ CHECK(!cast(PConst<DataType>(DataType::Int(32)), px)
+ .Match(tir::CastNode::make(DataType::Float(64), x)));
CHECK(cast(pt, px).Match(tir::CastNode::make(DataType::Float(64), x)));
CHECK(pt.Eval() == DataType::Float(64));
auto zz = cast(pt, px).Eval();
- CHECK((cast(pt, px) - cast(pt, py)).Match(
- tir::CastNode::make(DataType::Float(64), x) - tir::CastNode::make(DataType::Int(64), x)));
+ CHECK((cast(pt, px) - cast(pt, py))
+ .Match(tir::CastNode::make(DataType::Float(64), x) -
+ tir::CastNode::make(DataType::Int(64), x)));
auto expr = tir::CastNode::make(DataType::Int(32), tir::CastNode::make(DataType::Float(64), x));
CHECK(!(cast(pt, cast(pt, px))).Match(expr));
}
// ramp pattern
{
- CHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(
- tir::RampNode::make(x, 1, 10)));
+ CHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(tir::RampNode::make(x, 1, 10)));
CHECK(planes.Eval() == 10);
- CHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(
- tir::RampNode::make(x, 2, 10)));
+ CHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(tir::RampNode::make(x, 2, 10)));
}
// broadcast pattern
{
- CHECK(broadcast(px, planes).Match(
- tir::BroadcastNode::make(x, 10)));
+ CHECK(broadcast(px, planes).Match(tir::BroadcastNode::make(x, 10)));
CHECK(planes.Eval() == 10);
- CHECK(broadcast(px * py , planes).Match(
- tir::BroadcastNode::make(x * 10, 10)));
+ CHECK(broadcast(px * py, planes).Match(tir::BroadcastNode::make(x * 10, 10)));
}
}
CHECK(!(v * c).Match((tx + 1) * 3));
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
*/
#include <gtest/gtest.h>
-#include <tvm/driver/driver_api.h>
-#include <tvm/te/operation.h>
-#include <tvm/relay/expr.h>
-#include <tvm/relay/type.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
-#include <tvm/relay/op_strategy.h>
-#include <tvm/relay/op_attr_types.h>
#include <topi/broadcast.h>
#include <topi/generic/injective.h>
-#include <tvm/runtime/packed_func.h>
+#include <tvm/driver/driver_api.h>
#include <tvm/ir/module.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
using namespace tvm;
using namespace tvm::relay;
TVM_REGISTER_GLOBAL("test.strategy")
-.set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs,
- const Type& out_type, const Target& target) {
- FTVMCompute fcompute = [](const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type) -> Array<te::Tensor> {
+ .set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type,
+ const Target& target) {
+ FTVMCompute fcompute = [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) -> Array<te::Tensor> {
CHECK_EQ(inputs.size(), 2U);
return {topi::add(inputs[0], inputs[1])};
- };
- FTVMSchedule fschedule = [](const Attrs& attrs,
- const Array<te::Tensor>& outs,
- const Target& target) {
+ };
+ FTVMSchedule fschedule = [](const Attrs& attrs, const Array<te::Tensor>& outs,
+ const Target& target) {
With<Target> target_scope(target);
return topi::generic::schedule_injective(target, outs);
- };
+ };
- auto n = make_object<OpStrategyNode>();
- auto strategy = tvm::relay::OpStrategy(std::move(n));
- strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10);
- return strategy;
-});
+ auto n = make_object<OpStrategyNode>();
+ auto strategy = tvm::relay::OpStrategy(std::move(n));
+ strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10);
+ return strategy;
+ });
TVM_REGISTER_GLOBAL("relay.backend.lower_call")
-.set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
- const Target& target) {
- static auto fstrategy = Op::GetAttr<relay::FTVMStrategy>("FTVMStrategy");
- Op op = Downcast<Op>(call->op);
- auto out_type = call->checked_type();
- OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);
- auto impl = strategy->specializations[0]->implementations[0];
- auto outs = impl.Compute(call->attrs, inputs, out_type);
- auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput");
- if (!f) {
- LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered";
- }
- return (*f)(outs, impl);
-});
+ .set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
+ const Target& target) {
+ static auto fstrategy = Op::GetAttr<relay::FTVMStrategy>("FTVMStrategy");
+ Op op = Downcast<Op>(call->op);
+ auto out_type = call->checked_type();
+ OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);
+ auto impl = strategy->specializations[0]->implementations[0];
+ auto outs = impl.Compute(call->attrs, inputs, out_type);
+ auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput");
+ if (!f) {
+ LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered";
+ }
+ return (*f)(outs, impl);
+ });
TEST(Relay, BuildModule) {
auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
CHECK(ref_count[z.get()] == 1);
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <gtest/gtest.h>
#include <tvm/node/structural_equal.h>
-#include <tvm/te/operation.h>
-#include <tvm/relay/expr.h>
-#include <tvm/relay/type.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/te/operation.h>
TEST(Relay, SelfReference) {
using namespace tvm;
auto tensor_type = relay::TensorType({}, DataType::Bool());
auto x = relay::Var("x", relay::Type());
- auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
+ auto f = relay::Function(tvm::Array<relay::Var>{x}, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>());
auto y = relay::Var("y", tensor_type);
- auto call = relay::Call(f, Array<relay::Expr>{ y });
- auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
+ auto call = relay::Call(f, Array<relay::Expr>{y});
+ auto fx = relay::Function(tvm::Array<relay::Var>{y}, call, relay::Type(), {});
auto mod = IRModule::FromExpr(fx);
mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main");
- auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
+ auto expected = relay::FuncType(tvm::Array<relay::Type>{tensor_type}, tensor_type, {}, {});
CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected));
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
-#include <tvm/node/structural_equal.h>
#include <tvm/driver/driver_api.h>
-#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
-TVM_REGISTER_GLOBAL("schedule")
- .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
- *rv = topi::generic::schedule_injective(args[0], args[1]);
- });
+TVM_REGISTER_GLOBAL("schedule").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+ *rv = topi::generic::schedule_injective(args[0], args[1]);
+});
TEST(Relay, Sequential) {
using namespace tvm;
auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32));
- auto c_data =
- tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+ auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
// Create a function for optimization.
auto c = relay::Constant(c_data);
auto z2 = relay::Call(add_op, {z, z1});
// Let expression and varaible a should be dead-code eliminated.
auto z3 = relay::Let(a, c, z2);
- relay::Function func =
- relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});
+ relay::Function func = relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});
// Get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register");
// Run sequential passes.
tvm::Array<relay::transform::Pass> pass_seqs{
- relay::transform::InferType(),
- relay::transform::DeadCodeElimination(),
- relay::transform::EliminateCommonSubexpr(),
- relay::transform::AlterOpLayout()
- };
+ relay::transform::InferType(), relay::transform::DeadCodeElimination(),
+ relay::transform::EliminateCommonSubexpr(), relay::transform::AlterOpLayout()};
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
auto mod = IRModule::FromExpr(func);
auto pass_ctx = relay::transform::PassContext::Create();
y1 = relay::Call(add_op, {x1, y1});
auto zz = relay::Call(add_op, {y1, c1});
zz = relay::Call(add_op, {zz, zz});
- relay::Function expected_func =
- relay::Function(relay::FreeVars(zz), zz, relay::Type(), {});
+ relay::Function expected_func = relay::Function(relay::FreeVars(zz), zz, relay::Type(), {});
// Infer type for the expected function.
auto mod1 = IRModule::FromExpr(expected_func);
CHECK(!tvm::tir::HasSideEffect(A[0]));
}
-
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
Tensor A = placeholder({m, l}, DataType::Float(32), "A");
Tensor B = placeholder({n, l}, DataType::Float(32), "B");
- auto C = compute({m, n}, [&](Var i, Var j) {
- return A[i][j];
- }, "C");
+ auto C = compute(
+ {m, n}, [&](Var i, Var j) { return A[i][j]; }, "C");
Tensor::Slice x = A[n];
}
te::Tensor B = te::placeholder({n, l}, DataType::Float(32), "B");
IterVar rv = reduce_axis(Range{0, l}, "k");
- auto C = te::compute({m, n}, [&](Var i, Var j) {
- return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
- }, "C");
+ auto C = te::compute(
+ {m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C");
LOG(INFO) << C->op.as<te::ComputeOpNode>()->body;
}
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
* under the License.
*/
+#include <gtest/gtest.h>
+#include <tvm/runtime/c_backend_api.h>
+
#include <atomic>
#include <memory>
#include <thread>
-#include <gtest/gtest.h>
-#include <tvm/runtime/c_backend_api.h>
-
constexpr size_t N = 128;
static FTVMParallelLambda atomic_add_task_id = [](int task_id, TVMParallelGroupEnv* penv,
* under the License.
*/
-#include <tvm/te/operation.h>
-#include <topi/elemwise.h>
#include <gtest/gtest.h>
+#include <topi/elemwise.h>
+#include <tvm/te/operation.h>
namespace topi {
TEST(Tensor, Basic) {
Tensor A = placeholder({m, l}, DataType::Float(32), "A");
auto C = topi::exp(A);
}
-}
+} // namespace topi
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
* under the License.
*/
-#include <random>
-
#include <dlpack/dlpack.h>
#include <gtest/gtest.h>
+
#include <map>
+#include <random>
#include <vector>
#ifdef USE_MICRO_STANDALONE_RUNTIME
#if defined(__APPLE__) && defined(__MACH__)
#include <gtest/gtest.h>
+#include <spawn.h>
+#include <sys/wait.h>
#include <topi/generic/injective.h>
#include <tvm/driver/driver_api.h>
-#include <tvm/te/operation.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-
-#include <spawn.h>
-#include <sys/wait.h>
+#include <tvm/te/operation.h>
TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
#include <topi/detail/constant_utils.h>
#include <topi/tags.h>
-#include <string>
#include <algorithm>
+#include <string>
namespace topi {
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
CHECK_GE(output_shape.size(), t->shape.size())
- << "Not a broadcast, output dimensionality smaller than input.\noutput: "
- << output_shape << "\nvs\ninput: " << t;
+ << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape
+ << "\nvs\ninput: " << t;
auto bh = detail::BroadcastShape(output_shape, t->shape);
CHECK_EQ(output_shape.size(), bh.common_shape.size());
for (size_t i = 0; i < output_shape.size(); ++i) {
auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
};
- return tvm::te::compute(
- tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
- l,
- name,
- tag);
+ return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
+ l, name, tag);
}
-#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
- inline tvm::PrimExpr Name(const tvm::PrimExpr& a, \
- const tvm::PrimExpr& b) { \
- ComputeRule; \
- } \
- inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \
- const tvm::te::Tensor& B, \
- std::string name = "T_" #Name, \
- std::string tag = kBroadcast) { \
- auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
- return detail::WithBroadcast(l, A, B, name, tag); \
- } \
- inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \
- const tvm::PrimExpr& B, \
- std::string name = "T_" #Name, \
- std::string tag = kElementWise) { \
- auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
- return tvm::te::compute(A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \
- return l(A(i), B); \
- }, name, tag); \
- } \
- inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \
- const tvm::te::Tensor& B, \
- std::string name = "T_" #Name, \
- std::string tag = kElementWise) { \
- auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
- return tvm::te::compute(B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \
- return l(A, B(i)); \
- }, name, tag); \
+#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
+ inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
+ inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
+ std::string name = "T_" #Name, std::string tag = kBroadcast) { \
+ auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
+ return detail::WithBroadcast(l, A, B, name, tag); \
+ } \
+ inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \
+ std::string name = "T_" #Name, std::string tag = kElementWise) { \
+ auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
+ return tvm::te::compute( \
+ A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \
+ } \
+ inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \
+ std::string name = "T_" #Name, std::string tag = kElementWise) { \
+ auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
+ return tvm::te::compute( \
+ B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \
}
-
-#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \
- inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \
- const tvm::te::Tensor& B) { \
- return topi::OpName(A, B); \
- } \
- inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \
- const tvm::te::Tensor& B) { \
- return topi::OpName(A, B); \
- } \
- inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \
- const tvm::PrimExpr& B) { \
- return topi::OpName(A, B); \
+#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \
+ inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \
+ return topi::OpName(A, B); \
+ } \
+ inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \
+ return topi::OpName(A, B); \
+ } \
+ inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \
+ return topi::OpName(A, B); \
}
/*!
#ifndef TOPI_CONTRIB_CUBLAS_H_
#define TOPI_CONTRIB_CUBLAS_H_
-#include <tvm/te/operation.h>
#include <topi/detail/extern.h>
+#include <tvm/te/operation.h>
namespace topi {
namespace contrib {
using namespace tvm::te;
using namespace topi::detail;
/*!
-* \brief Create an op that multiplies lhs and rhs with cuBLAS
-*
-* \param lhs The left matrix operand
-* \param rhs The right matrix operand
-* \param transa Whether to transpose lhs
-* \param transb Whether to transpose rhs
-*
-* \return The output tensor
-*/
-inline Tensor cublas_matmul(const Tensor& lhs,
- const Tensor& rhs,
- bool transa,
- bool transb) {
+ * \brief Create an op that multiplies lhs and rhs with cuBLAS
+ *
+ * \param lhs The left matrix operand
+ * \param rhs The right matrix operand
+ * \param transa Whether to transpose lhs
+ * \param transb Whether to transpose rhs
+ *
+ * \return The output tensor
+ */
+inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
auto n = transa ? lhs->shape[1] : lhs->shape[0];
auto m = transb ? rhs->shape[0] : rhs->shape[1];
return make_extern(
- { { n, m } }, { lhs->dtype }, { lhs, rhs },
- [&](Array<Buffer> ins, Array<Buffer> outs) {
- return call_packed({
- StringImmNode::make("tvm.contrib.cublas.matmul"),
- pack_buffer(ins[0]),
- pack_buffer(ins[1]),
- pack_buffer(outs[0]),
- transa,
- transb });
- }, "C", "", {})[0];
+ {{n, m}}, {lhs->dtype}, {lhs, rhs},
+ [&](Array<Buffer> ins, Array<Buffer> outs) {
+ return call_packed({StringImmNode::make("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]),
+ pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
+ },
+ "C", "", {})[0];
}
/*!
-* \brief Create an op that multiplies batch matrices
-* lhs and rhs with cuBLAS
-*
-* \param lhs The left matrix operand
-* \param rhs The right matrix operand
-* \param transa Whether to transpose lhs
-* \param transb Whether to transpose rhs
-*
-* \return The output tensor
-*/
-inline Tensor cublas_batch_matmul(const Tensor& lhs,
- const Tensor& rhs,
- bool transa,
- bool transb) {
+ * \brief Create an op that multiplies batch matrices
+ * lhs and rhs with cuBLAS
+ *
+ * \param lhs The left matrix operand
+ * \param rhs The right matrix operand
+ * \param transa Whether to transpose lhs
+ * \param transb Whether to transpose rhs
+ *
+ * \return The output tensor
+ */
+inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
auto b = lhs->shape[0];
auto n = transa ? lhs->shape[2] : lhs->shape[1];
auto m = transb ? rhs->shape[1] : rhs->shape[2];
- return make_extern(
- { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
- [&](Array<Buffer> ins, Array<Buffer> outs) {
- return call_packed({
- StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
- pack_buffer(ins[0]),
- pack_buffer(ins[1]),
- pack_buffer(outs[0]),
- transa,
- transb });
- }, "C", "", {})[0];
+ return make_extern({{b, n, m}}, {lhs->dtype}, {lhs, rhs},
+ [&](Array<Buffer> ins, Array<Buffer> outs) {
+ return call_packed({StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
+ pack_buffer(ins[0]), pack_buffer(ins[1]),
+ pack_buffer(outs[0]), transa, transb});
+ },
+ "C", "", {})[0];
}
} // namespace contrib
#define TOPI_CONTRIB_ROCBLAS_H_
#include <tvm/te/operation.h>
+
#include "topi/detail/extern.h"
namespace topi {
using namespace tvm;
using namespace tvm::te;
/*!
-* \brief Create an op that multiplies lhs and rhs with rocBLAS
-*
-* \param lhs The left matrix operand
-* \param rhs The right matrix operand
-* \param transa Whether to transpose lhs
-* \param transb Whether to transpose rhs
-*
-* \return The output tensor
-*/
-inline Tensor rocblas_matmul(const Tensor& lhs,
- const Tensor& rhs,
- bool transa,
- bool transb) {
+ * \brief Create an op that multiplies lhs and rhs with rocBLAS
+ *
+ * \param lhs The left matrix operand
+ * \param rhs The right matrix operand
+ * \param transa Whether to transpose lhs
+ * \param transb Whether to transpose rhs
+ *
+ * \return The output tensor
+ */
+inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
auto n = transa ? lhs->shape[1] : lhs->shape[0];
auto m = transb ? rhs->shape[0] : rhs->shape[1];
return make_extern(
- { { n, m } }, { lhs->dtype }, { lhs, rhs },
- [&](Array<Buffer> ins, Array<Buffer> outs) {
- return call_packed({
- StringImmNode::make("tvm.contrib.rocblas.matmul"),
- pack_buffer(ins[0]),
- pack_buffer(ins[1]),
- pack_buffer(outs[0]),
- transa,
- transb });
- }, "C", "", {})[0];
+ {{n, m}}, {lhs->dtype}, {lhs, rhs},
+ [&](Array<Buffer> ins, Array<Buffer> outs) {
+ return call_packed({StringImmNode::make("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]),
+ pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
+ },
+ "C", "", {})[0];
}
} // namespace contrib
#ifndef TOPI_CUDA_DENSE_H_
#define TOPI_CUDA_DENSE_H_
-#include <tvm/te/operation.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/array_utils.h>
-#include <topi/nn/dense.h>
#include <topi/contrib/cublas.h>
+#include <topi/detail/array_utils.h>
#include <topi/generic/extern.h>
+#include <topi/nn/dense.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule_pass.h>
namespace topi {
using namespace tvm;
namespace cuda {
/*!
-* \brief Implementation of dense for CUDA backend
-*
-* \param target The target device
-* \param data Tensor with shape [batch, in_dim]
-* \param weight Tensor with shape [out_dim, in_dim]
-* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
-* \param out_dtype Output data type. Used for mixed precision.
-*
-* \return Tensor with shape [batch, out_dim]
-*/
-inline tvm::te::Tensor dense_cuda(const Target& target,
- const tvm::te::Tensor& data,
- const tvm::te::Tensor& weight,
- const tvm::te::Tensor& bias,
- const DataType& out_dtype) {
+ * \brief Implementation of dense for CUDA backend
+ *
+ * \param target The target device
+ * \param data Tensor with shape [batch, in_dim]
+ * \param weight Tensor with shape [out_dim, in_dim]
+ * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
+ * \param out_dtype Output data type. Used for mixed precision.
+ *
+ * \return Tensor with shape [batch, out_dim]
+ */
+inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& data,
+ const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
+ const DataType& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
if (bias.defined()) {
- mm = tvm::te::compute({ batch, out_dim },
- [&](Var i, Var j) {
- return mm(i, j) + bias(j);
- }, "tensor", kBroadcast);
+ mm = tvm::te::compute(
+ {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast);
}
return mm;
}
/*!
-* \brief Create a CUDA schedule for dense
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
- if (target->target_name == "cuda" &&
- target->libs().count("cublas")) {
+ * \brief Create a CUDA schedule for dense
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
+ if (target->target_name == "cuda" && target->libs().count("cublas")) {
return topi::generic::schedule_extern(target, outs);
}
#ifndef TOPI_CUDA_INJECTIVE_H_
#define TOPI_CUDA_INJECTIVE_H_
+#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/fuse.h>
namespace topi {
using namespace tvm;
*
* \return A schedule for the given ops.
*/
-inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
+inline Schedule schedule_injective(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
#ifndef TOPI_CUDA_NORMALIZATION_H_
#define TOPI_CUDA_NORMALIZATION_H_
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
namespace topi {
using namespace tvm;
using namespace tvm::te;
namespace cuda {
/*!
-* \brief Create a CUDA schedule for LRN
-* \param outs The output tensors.
-* \return A schedule for the given ops.
-*/
+ * \brief Create a CUDA schedule for LRN
+ * \param outs The output tensors.
+ * \return A schedule for the given ops.
+ */
inline Schedule schedule_lrn(const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
#ifndef TOPI_CUDA_POOLING_H_
#define TOPI_CUDA_POOLING_H_
+#include <topi/detail/array_utils.h>
+#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/fuse.h>
-#include <topi/detail/array_utils.h>
namespace topi {
using namespace tvm;
namespace cuda {
/*!
-* \brief Create a CUDA schedule for pool
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a CUDA schedule for pool
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
/*!
-* \brief Create a CUDA schedule for global_pool
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a CUDA schedule for global_pool
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_global_pool(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
s[out].split(i, num_thread, &by, &ty);
IterVar bx, tx;
s[out].split(c, num_thread, &bx, &tx);
- s[out].reorder({ by, bx, ty, tx });
+ s[out].reorder({by, bx, ty, tx});
s[out].bind(ty, thread_y);
s[out].bind(tx, thread_x);
s[out].bind(by, block_y);
#ifndef TOPI_CUDA_REDUCTION_H_
#define TOPI_CUDA_REDUCTION_H_
+#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/fuse.h>
namespace topi {
using namespace tvm;
* an index, such as argmax or argmin.
*
* \return The schedule given by sch
-*/
-Schedule ScheduleReduce(const Target& target,
- Operation op,
- Schedule sch,
+ */
+Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
bool is_idx_reduce = false) {
Tensor data_out;
Tensor data_in;
}
auto out_stage = sch[data_out];
- CHECK_GT(out_stage->op.as<ComputeOpNode>()->reduce_axis.size(), 0) <<
- "reduce_axis must be greater than zero";
+ CHECK_GT(out_stage->op.as<ComputeOpNode>()->reduce_axis.size(), 0)
+ << "reduce_axis must be greater than zero";
bool all_reduce;
int num_thread;
}
} else {
if (is_idx_reduce) {
- sch[temp_idx_input].compute_at(stage_real,
- stage_real->op.as<ComputeOpNode>()->axis[0]);
- sch[temp_val_input].compute_at(stage_real,
- stage_real->op.as<ComputeOpNode>()->axis[0]);
+ sch[temp_idx_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
+ sch[temp_val_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
}
}
}
/*!
-* \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each
-* of the op's inputs.
-*
-* \param target The target to generate a schedule for.
-* \param s The schedule we are building
-* \param op The reduce op
-*/
+ * \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each
+ * of the op's inputs.
+ *
+ * \param target The target to generate a schedule for.
+ * \param s The schedule we are building
+ * \param op The reduce op
+ */
void TraverseAfterReduce(const Target& target, Schedule s, Operation op) {
if (is_broadcast(op->tag)) {
LOG(ERROR) << "Elementwise op after reduce is not yet supported";
}
/*!
-* \brief Create a CUDA schedule for a reduce operation.
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
+ * \brief Create a CUDA schedule for a reduce operation.
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
CHECK_EQ(outs.size(), 1) << "outs must have size 1";
Array<Operation> out_ops;
#ifndef TOPI_CUDA_SOFTMAX_H_
#define TOPI_CUDA_SOFTMAX_H_
+#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/fuse.h>
namespace topi {
using namespace tvm;
*
* \return A schedule for the given ops.
*/
-inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs) {
+inline Schedule schedule_softmax(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
*
* \return True iff the given array contains the given item.
*/
-template<typename T>
+template <typename T>
inline bool contains(Array<T> array, T item) {
for (auto& i : array) {
if (i == item) {
#ifndef TOPI_DETAIL_BROADCAST_H_
#define TOPI_DETAIL_BROADCAST_H_
-#include <tvm/te/operation.h>
#include <topi/detail/constant_utils.h>
+#include <tvm/te/operation.h>
#include <algorithm>
#include <deque>
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else {
- CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
- << " and " << shape2[s2_size - i] << " in: "
- << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end())
- << " and "
+ CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
+ << shape2[s2_size - i]
+ << " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) << " and "
<< tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
}
}
}
inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
- const tvm::Array<tvm::tir::Var>& ovars,
- const tvm::te::Tensor& T,
- const std::deque<tvm::tir::Var>& my_vars,
- const std::deque<tvm::tir::Var>& all_vars) {
+ const tvm::Array<tvm::tir::Var>& ovars, const tvm::te::Tensor& T,
+ const std::deque<tvm::tir::Var>& my_vars, const std::deque<tvm::tir::Var>& all_vars) {
tvm::Array<tvm::PrimExpr> ivars;
CHECK_EQ(ovars.size(), all_vars.size());
// N^2, could use a map but NBD.
}
template <typename FBinaryExpr>
-inline tvm::te::Tensor WithBroadcast(FBinaryExpr op,
- const tvm::te::Tensor& A,
- const tvm::te::Tensor& B,
- const std::string& name = "tensor",
- const std::string& tag = "") {
+inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A,
+ const tvm::te::Tensor& B, const std::string& name = "tensor",
+ const std::string& tag = "") {
auto bh = BroadcastShape(A->shape, B->shape);
auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
};
- return tvm::te::compute(
- tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
- l,
- name,
- tag);
+ return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
+ l, name, tag);
}
} // namespace detail
#ifndef TOPI_DETAIL_CONSTANT_UTILS_H_
#define TOPI_DETAIL_CONSTANT_UTILS_H_
-#include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h>
-#include <tvm/tir/analysis.h>
#include <tvm/te/operation.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
#include <string>
#include <vector>
*
* \return true if the given expr is a constant int or uint, false otherwise.
*/
-inline bool IsConstInt(PrimExpr expr) {
- return
- expr->IsInstance<tvm::tir::IntImmNode>();
-}
+inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImmNode>(); }
/*!
* \brief Get the value of the given constant integer expression. An error
*
* \return A vector of the integer values
*/
-inline std::vector<int> GetConstIntValues(
- Array<PrimExpr> exprs, const std::string& var_name) {
+inline std::vector<int> GetConstIntValues(Array<PrimExpr> exprs, const std::string& var_name) {
std::vector<int> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
- CHECK(IsConstInt(expr)) << "All elements of "
- << var_name << " must be constant integers";
+ CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));
}
return result;
*
* \return A vector of the int64_t values
*/
-inline std::vector<int64_t> GetConstInt64Values(
- Array<PrimExpr> exprs, const std::string& var_name) {
+inline std::vector<int64_t> GetConstInt64Values(Array<PrimExpr> exprs,
+ const std::string& var_name) {
std::vector<int64_t> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
}
/*!
- * \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again
- * \note This is stronger equality check than tvm::tir::Equal
+ * \brief Check whether the two expressions are equal or not, if not simplify the expressions and
+ * check again \note This is stronger equality check than tvm::tir::Equal
*
* \param lhs First expreesion
* \param rhs Second expreesion
bool result = expr_equal(lhs, rhs);
if (!result) {
PrimExpr zero(0);
- result = expr_equal(tvm::arith::Analyzer().Simplify(lhs-rhs), zero);
+ result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero);
}
return result;
}
#define TOPI_DETAIL_EXTERN_H_
#include <tvm/te/operation.h>
-#include <vector>
-#include <string>
+#include <string>
+#include <vector>
namespace topi {
namespace detail {
*
* \return The Buffer object
*/
-inline Buffer DeclExternBuffer(Array<PrimExpr> shape,
- DataType dtype,
- std::string name) {
+inline Buffer DeclExternBuffer(Array<PrimExpr> shape, DataType dtype, std::string name) {
auto data = var(name, DataType::Handle());
auto elem_offset = PrimExpr();
- return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
- -1, 0, kDefault);
+ return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", -1, 0,
+ kDefault);
}
/*!
* be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
* element of out_types.
*/
-inline Array<Tensor> make_extern(const Array< Array<PrimExpr> >& out_shapes,
+inline Array<Tensor> make_extern(const Array<Array<PrimExpr> >& out_shapes,
const std::vector<DataType>& out_types,
- const Array<Tensor>& inputs,
- FExtern fextern,
- std::string name,
- std::string tag,
- ::tvm::Map<std::string, ObjectRef> attrs) {
+ const Array<Tensor>& inputs, FExtern fextern, std::string name,
+ std::string tag, ::tvm::Map<std::string, ObjectRef> attrs) {
CHECK_EQ(out_shapes.size(), out_types.size())
- << "make_extern: out_shapes and out_types must have equal size";
+ << "make_extern: out_shapes and out_types must have equal size";
Array<Buffer> input_placeholders;
for (auto t : inputs) {
auto body = fextern(input_placeholders, output_placeholders);
auto body_stmt = tvm::tir::EvaluateNode::make(body);
- auto op = ExternOpNode::make(
- name, tag, attrs, inputs,
- input_placeholders, output_placeholders, body_stmt);
+ auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders,
+ body_stmt);
Array<Tensor> outputs;
for (size_t i = 0; i < output_placeholders.size(); ++i) {
*/
inline PrimExpr pack_buffer(Buffer buf) {
CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
- auto shape = tvm::tir::CallNode::make(
- DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
- buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
+ auto shape =
+ tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
+ buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
PrimExpr strides;
if (buf->strides.size() > 0) {
- strides = tvm::tir::CallNode::make(
- DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
- buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
+ strides =
+ tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
+ buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
} else {
strides = 0;
}
- Array<PrimExpr> pack_args{
- buf->data,
- shape,
- strides,
- make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
- make_const(buf->dtype, 0),
- buf->elem_offset
- };
+ Array<PrimExpr> pack_args{buf->data,
+ shape,
+ strides,
+ make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
+ make_const(buf->dtype, 0),
+ buf->elem_offset};
return tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array,
- pack_args, tvm::tir::CallNode::CallType::Intrinsic);
+ pack_args, tvm::tir::CallNode::CallType::Intrinsic);
}
/*!
* \return An expression representing the invocation
*/
inline PrimExpr call_packed(Array<PrimExpr> args) {
- return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed,
- args, tvm::tir::CallNode::CallType::Intrinsic);
+ return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args,
+ tvm::tir::CallNode::CallType::Intrinsic);
}
} // namespace detail
*/
/*!
-* \file pad_utils.h
-* \brief Padding helpers
-*/
+ * \file pad_utils.h
+ * \brief Padding helpers
+ */
#ifndef TOPI_DETAIL_PAD_UTILS_H_
#define TOPI_DETAIL_PAD_UTILS_H_
-#include <vector>
+#include <tvm/te/operation.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
-#include "tvm/tir/expr.h"
-#include "tvm/tir/op.h"
+#include <vector>
namespace topi {
namespace detail {
auto pad_top = indexdiv(pad_h + 1, 2);
auto pad_left = indexdiv(pad_w + 1, 2);
- return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left };
+ return {pad_top, pad_left, pad_h - pad_top, pad_w - pad_left};
}
} // namespace detail
*/
/*!
-* \file ravel_unravel.h
-* \brief Index ravel and unraval operations
-*/
+ * \file ravel_unravel.h
+ * \brief Index ravel and unravel operations
+ */
#ifndef TOPI_DETAIL_RAVEL_UNRAVEL_H_
#define TOPI_DETAIL_RAVEL_UNRAVEL_H_
using namespace tvm::te;
/*!
-* \brief Flatten the indices to 1D
-*
-* \param indices The input coordinates
-* \param shape Shape of the tensor
-*
-* \return The index after flattening
-*/
+ * \brief Flatten the indices to 1D
+ *
+ * \param indices The input coordinates
+ * \param shape Shape of the tensor
+ *
+ * \return The index after flattening
+ */
inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
CHECK_GT(indices.size(), 0) << "indices must not be empty";
}
/*!
-* \brief Convert flattened index to coordinate array
-*
-* \param idx The 1D index
-* \param shape Shape of the tensor
-*
-* \return The coordinate corresponding to the 1D index
-*/
+ * \brief Convert flattened index to coordinate array
+ *
+ * \param idx The 1D index
+ * \param shape Shape of the tensor
+ *
+ * \return The coordinate corresponding to the 1D index
+ */
inline Array<PrimExpr> UnravelIndex(PrimExpr idx, Array<PrimExpr> shape) {
std::vector<PrimExpr> indices;
#ifndef TOPI_DETAIL_TENSOR_UTILS_H_
#define TOPI_DETAIL_TENSOR_UTILS_H_
-
#include <tvm/te/operation.h>
namespace topi {
* \return The interpolated value in the given index.
*/
inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
- const PrimExpr max_y, const PrimExpr max_x) {
+ const PrimExpr max_y, const PrimExpr max_x) {
auto in_y = indices[2];
auto yf = tvm::floor(in_y);
auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));
auto C = input(indices[0], indices[1], y1, x0);
auto D = input(indices[0], indices[1], y1, x1);
- return A * ( 1 - x_lerp) * ( 1 - y_lerp) +
- B * x_lerp * (1 - y_lerp) +
- C * (1 - x_lerp) * y_lerp +
+ return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp +
D * x_lerp * y_lerp;
}
#ifndef TOPI_ELEMWISE_H_
#define TOPI_ELEMWISE_H_
-#include <tvm/tir/expr.h>
#include <topi/tags.h>
+#include <tvm/tir/expr.h>
+
#include <algorithm>
#include <string>
+
#include "broadcast.h"
namespace topi {
using namespace tvm::te;
// Unary intrinsic operators
-#define TOPI_DECLARE_UNARY_OP(OpName) \
- inline Tensor OpName(const Tensor& x, \
- std::string name = "T_" #OpName, \
- std::string tag = kElementWise) { \
- return compute(x->shape, [&](const Array<Var>& i) { \
- return ::tvm::OpName(x(i)); \
- }, name, tag); \
+#define TOPI_DECLARE_UNARY_OP(OpName) \
+ inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \
+ std::string tag = kElementWise) { \
+ return compute( \
+ x->shape, [&](const Array<Var>& i) { return ::tvm::OpName(x(i)); }, name, tag); \
}
TOPI_DECLARE_UNARY_OP(exp);
* \brief Fast_tanh_float implementation from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
*/
-inline Tensor fast_tanh_float(const Tensor& in,
- std::string name,
- std::string tag) {
+inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string tag) {
// Clamp the inputs to the range [-9, 9] since anything outside
// this range is +/-1.0f in single-precision.
auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0));
auto beta_4 = make_const(in->dtype, 1.18534705686654e-04);
auto beta_6 = make_const(in->dtype, 1.19825839466702e-06);
- return compute(x->shape,
- [&](const Array<Var>& i) {
- auto x2 = x(i) * x(i);
- auto p = x2 * alpha_13 + alpha_11;
- p = x2 * p + alpha_9;
- p = x2 * p + alpha_7;
- p = x2 * p + alpha_5;
- p = x2 * p + alpha_3;
- p = x2 * p + alpha_1;
- p = x(i) * p;
-
- auto q = x2 * beta_6 + beta_4;
- q = x2 * q + beta_2;
- q = x2 * q + beta_0;
- return p / q;
- },
- name, tag);
+ return compute(
+ x->shape,
+ [&](const Array<Var>& i) {
+ auto x2 = x(i) * x(i);
+ auto p = x2 * alpha_13 + alpha_11;
+ p = x2 * p + alpha_9;
+ p = x2 * p + alpha_7;
+ p = x2 * p + alpha_5;
+ p = x2 * p + alpha_3;
+ p = x2 * p + alpha_1;
+ p = x(i) * p;
+
+ auto q = x2 * beta_6 + beta_4;
+ q = x2 * q + beta_2;
+ q = x2 * q + beta_0;
+ return p / q;
+ },
+ name, tag);
}
/*!
-* \brief Creates an operation that returns hyperbolic tanh of a given tensor
-*
-* \param x The input tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is tanh
-*/
-inline Tensor fast_tanh(const Tensor& x,
- std::string name = "T_fast_tanh",
+ * \brief Creates an operation that returns hyperbolic tanh of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is tanh
+ */
+inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag);
} else {
// fallback to default implementation
- return compute(x->shape, [&](const Array<Var>& i) {
- return ::tvm::tanh(x(i));
- }, name, tag);
+ return compute(
+ x->shape, [&](const Array<Var>& i) { return ::tvm::tanh(x(i)); }, name, tag);
}
}
/*!
-* \brief Creates an operation that returns identity of a given tensor
-*
-* \param x The input tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the identity operation
-*/
-inline Tensor identity(const Tensor& x,
- std::string name = "T_identity",
+ * \brief Creates an operation that returns identity of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the identity operation
+ */
+inline Tensor identity(const Tensor& x, std::string name = "T_identity",
std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- return x(i);
- }, name, tag);
+ return compute(
+ x->shape, [&](const Array<Var>& i) { return x(i); }, name, tag);
}
/*!
-* \brief Creates an operation that returns the negation of a given tensor
-*
-* \param x The input tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the negation operation
-*/
-inline Tensor negative(const Tensor& x,
- std::string name = "T_negative",
+ * \brief Creates an operation that returns the negation of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the negation operation
+ */
+inline Tensor negative(const Tensor& x, std::string name = "T_negative",
std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- return -x(i);
- }, name, tag);
+ return compute(
+ x->shape, [&](const Array<Var>& i) { return -x(i); }, name, tag);
}
/*!
-* \brief Creates an operation that returns the logical NOT of a given tensor
-*
-* \param x The input tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the logical NOT operation
-*/
-inline Tensor logical_not(const Tensor& x,
- std::string name = "T_logical_not",
+ * \brief Creates an operation that returns the logical NOT of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the logical NOT operation
+ */
+inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not",
std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- return !x(i);
- }, name, tag);
+ return compute(
+ x->shape, [&](const Array<Var>& i) { return !x(i); }, name, tag);
}
/*!
-* \brief Creates an operation that returns the bitwise NOT of a given tensor
-*
-* \param x The input tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the bitwise NOT operation
-*/
-inline Tensor bitwise_not(const Tensor& x,
- std::string name = "T_bitwise_not",
+ * \brief Creates an operation that returns the bitwise NOT of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the bitwise NOT operation
+ */
+inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not",
std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- return ~x(i);
- }, name, tag);
+ return compute(
+ x->shape, [&](const Array<Var>& i) { return ~x(i); }, name, tag);
}
/*!
-* \brief Returns the sign of the tensor
-*
-* \param x The input tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the sign
-*/
-inline Tensor sign(const Tensor& x,
- std::string name = "T_sign",
- std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- PrimExpr zero = make_zero(x->dtype);
- PrimExpr one = make_const(x->dtype, 1);
- PrimExpr minus_one = make_const(x->dtype, -1);
- auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero);
- auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1);
- return s2;
- }, name, tag);
+ * \brief Returns the sign of the tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the sign
+ */
+inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) {
+ return compute(
+ x->shape,
+ [&](const Array<Var>& i) {
+ PrimExpr zero = make_zero(x->dtype);
+ PrimExpr one = make_const(x->dtype, 1);
+ PrimExpr minus_one = make_const(x->dtype, -1);
+ auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero);
+ auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1);
+ return s2;
+ },
+ name, tag);
}
/*!
-* \brief Creates an operation that returns rsqrt of a given tensor
-*
-* \param x The input tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the rsqrt operation
-*/
-inline Tensor rsqrt(const Tensor& x,
- std::string name = "tensor",
- std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- PrimExpr one = make_const(x->dtype, 1);
- return one/tvm::sqrt(x(i));
- }, name, tag);
+ * \brief Creates an operation that returns rsqrt of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the rsqrt operation
+ */
+inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) {
+ return compute(
+ x->shape,
+ [&](const Array<Var>& i) {
+ PrimExpr one = make_const(x->dtype, 1);
+ return one / tvm::sqrt(x(i));
+ },
+ name, tag);
}
/*!
-* \brief Creates an operation that clips each element of a tensor to
-* the interval [a_min, a_max]
-*
-* \param x The input tensor
-* \param a_min The inclusive lower bound of the interval
-* \param a_max The inclusive upper bound of the interval
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the clip operation
-*/
-inline Tensor clip(const Tensor& x,
- const PrimExpr& a_min,
- const PrimExpr& a_max,
- std::string name = "T_clip",
- std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- auto min_val = tvm::cast(x->dtype, a_min);
- auto max_val = tvm::cast(x->dtype, a_max);
- return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*)
- }, name, tag);
+ * \brief Creates an operation that clips each element of a tensor to
+ * the interval [a_min, a_max]
+ *
+ * \param x The input tensor
+ * \param a_min The inclusive lower bound of the interval
+ * \param a_max The inclusive upper bound of the interval
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the clip operation
+ */
+inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max,
+ std::string name = "T_clip", std::string tag = kElementWise) {
+ return compute(
+ x->shape,
+ [&](const Array<Var>& i) {
+ auto min_val = tvm::cast(x->dtype, a_min);
+ auto max_val = tvm::cast(x->dtype, a_max);
+ return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*)
+ },
+ name, tag);
}
/*!
*
* \return A Tensor whose op member is the cast operation
*/
-inline Tensor cast(const Tensor& x,
- DataType type,
- std::string name = "T_cast",
+inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
std::string tag = kElementWise) {
- return compute(x->shape, [&](const Array<Var>& i) {
- auto expr = x(i);
- if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
- if (expr.dtype().lanes() == type.lanes()) {
- return expr;
- } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
- return tvm::tir::BroadcastNode::make(expr, type.lanes());
- }
- }
-
- return tvm::cast(type, x(i));
- }, name, tag);
+ return compute(
+ x->shape,
+ [&](const Array<Var>& i) {
+ auto expr = x(i);
+ if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
+ if (expr.dtype().lanes() == type.lanes()) {
+ return expr;
+ } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
+ return tvm::tir::BroadcastNode::make(expr, type.lanes());
+ }
+ }
+
+ return tvm::cast(type, x(i));
+ },
+ name, tag);
}
/*!
*/
inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor",
std::string tag = kElementWise) {
- return compute(x->shape,
- [&](const Array<Var>& i) {
- return tvm::tir::CallNode::make(type, "reinterpret", {x(i)},
- tvm::tir::CallNode::PureIntrinsic);
- },
- name, tag);
+ return compute(
+ x->shape,
+ [&](const Array<Var>& i) {
+ return tvm::tir::CallNode::make(type, "reinterpret", {x(i)},
+ tvm::tir::CallNode::PureIntrinsic);
+ },
+ name, tag);
}
/*!
*
* \return A Tensor whose op member is the sum operation
*/
-inline Tensor elemwise_sum(const Array<Tensor>& xs,
- std::string name = "T_elemwise_sum",
+inline Tensor elemwise_sum(const Array<Tensor>& xs, std::string name = "T_elemwise_sum",
std::string tag = kElementWise) {
CHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor.";
- return compute(xs[0]->shape, [&](const Array<Var>& i) {
- auto sum_expr = xs[0](i);
- for (size_t j = 1; j < xs.size(); j++) {
- sum_expr = sum_expr + xs[j](i);
- }
- return sum_expr;
- }, name, tag);
+ return compute(
+ xs[0]->shape,
+ [&](const Array<Var>& i) {
+ auto sum_expr = xs[0](i);
+ for (size_t j = 1; j < xs.size(); j++) {
+ sum_expr = sum_expr + xs[j](i);
+ }
+ return sum_expr;
+ },
+ name, tag);
}
/*!
-* \brief Creates an operation that fill a tensor with fill_value
-*
-* \param shape The shape of a tensor
-* \param dtype The Type of fill_value
-* \param fill_value The value to be filled
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the full operation
-*/
-inline Tensor full(const Array<PrimExpr>& shape,
- DataType dtype,
- const PrimExpr fill_value,
- std::string name = "T_full",
- std::string tag = kElementWise) {
+ * \brief Creates an operation that fill a tensor with fill_value
+ *
+ * \param shape The shape of a tensor
+ * \param dtype The Type of fill_value
+ * \param fill_value The value to be filled
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the full operation
+ */
+inline Tensor full(const Array<PrimExpr>& shape, DataType dtype, const PrimExpr fill_value,
+ std::string name = "T_full", std::string tag = kElementWise) {
PrimExpr ev = cast(dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << dtype;
}
- return compute(shape, [&](const Array<Var>& i) {
- return ev;
- }, name, tag);
+ return compute(
+ shape, [&](const Array<Var>& i) { return ev; }, name, tag);
}
/*!
-* \brief Creates an operation that construct a tensor with same shape as input tensor,
-* then fill a tensor with fill_value
-*
-* \param x The input tensor
-* \param fill_value The value to be filled
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op memeber is the full_like operation
-*/
-inline Tensor full_like(const Tensor& x,
- const PrimExpr fill_value,
- std::string name = "T_full_like",
- std::string tag = kElementWise) {
+ * \brief Creates an operation that construct a tensor with same shape as input tensor,
+ * then fill a tensor with fill_value
+ *
+ * \param x The input tensor
+ * \param fill_value The value to be filled
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
 + * \return A Tensor whose op member is the full_like operation
+ */
+inline Tensor full_like(const Tensor& x, const PrimExpr fill_value,
+ std::string name = "T_full_like", std::string tag = kElementWise) {
PrimExpr ev = cast(x->dtype, fill_value);
- return compute(x->shape, [&](const Array<Var>& i) {
- return ev;
- }, name, tag);
+ return compute(
+ x->shape, [&](const Array<Var>& i) { return ev; }, name, tag);
}
/*!
* Approximation for fractional part:
* y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2))
*/
-inline Tensor fast_exp_float32(const Tensor& _x,
- std::string name,
- std::string tag) {
+inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) {
auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
auto one_half = make_const(DataType::Float(32), 0.5f);
auto b = make_const(DataType::Float(32), 127.0f);
- return compute(_x->shape,
- [&](const Array<Var>& i) {
- // clamp x
- auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
- // integer part
- auto n = ::tvm::floor(x * log2e + one_half);
- // fractional part
- auto f = x - n * ln2;
- auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3])* f+ p[4]) * f
- + p[5]) * f * f + f + one;
- // Return 2^m * exp(r).
- auto ef = tvm::reinterpret(DataType::Float(32),
- ::tvm::cast(DataType::Int(32), n + b) << 23);
- return ::tvm::max(ef * y, _x(i)); // NOLINT(*)
- },
- name, tag);
+ return compute(
+ _x->shape,
+ [&](const Array<Var>& i) {
+ // clamp x
+ auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
+ // integer part
+ auto n = ::tvm::floor(x * log2e + one_half);
+ // fractional part
+ auto f = x - n * ln2;
+ auto y =
+ (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one;
+ // Return 2^m * exp(r).
+ auto ef =
+ tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23);
+ return ::tvm::max(ef * y, _x(i)); // NOLINT(*)
+ },
+ name, tag);
}
-
/*!
* \brief Fast exponential function implementation
*
* \return A Tensor whose op member is exponent operation
*
*/
-inline Tensor fast_exp(const Tensor& x,
- std::string name = "T_fast_exp",
- std::string tag = kElementWise) {
+inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp",
+ std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
auto ret = fast_exp_float32(x, name, tag);
return ret;
} else {
- return compute(x->shape, [&](const Array<Var>& i) {
- return ::tvm::exp(x(i));
- }, name, tag);
+ return compute(
+ x->shape, [&](const Array<Var>& i) { return ::tvm::exp(x(i)); }, name, tag);
}
}
* \brief Fast_tanh_float implementation from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
*/
-inline Tensor fast_erf_float32(const Tensor& data,
- std::string name,
- std::string tag) {
+inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
auto plus_4 = make_const(DataType::Float(32), 4.f);
auto minus_4 = make_const(DataType::Float(32), -4.f);
auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f);
auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f);
- return compute(data->shape, [&](const Array<Var> &i) {
- // clamp x
- auto x = tvm::max(tvm::min(data(i), plus_4), minus_4);
- auto x2 = x * x;
-
- // Evaluate the numerator polynomial p.
- auto p = x2 * alpha_13 + alpha_11;
- p = x2 * p + alpha_9;
- p = x2 * p + alpha_7;
- p = x2 * p + alpha_5;
- p = x2 * p + alpha_3;
- p = x2 * p + alpha_1;
- p = x * p;
-
- // Evaluate the denominator polynomial p.
- auto q = x2 * beta_8 + beta_6;
- q = x2 * q + beta_4;
- q = x2 * q + beta_2;
- q = x2 * q + beta_0;
-
- return p / q;
- }, name, tag);
+ return compute(
+ data->shape,
+ [&](const Array<Var>& i) {
+ // clamp x
+ auto x = tvm::max(tvm::min(data(i), plus_4), minus_4);
+ auto x2 = x * x;
+
+ // Evaluate the numerator polynomial p.
+ auto p = x2 * alpha_13 + alpha_11;
+ p = x2 * p + alpha_9;
+ p = x2 * p + alpha_7;
+ p = x2 * p + alpha_5;
+ p = x2 * p + alpha_3;
+ p = x2 * p + alpha_1;
+ p = x * p;
+
+ // Evaluate the denominator polynomial p.
+ auto q = x2 * beta_8 + beta_6;
+ q = x2 * q + beta_4;
+ q = x2 * q + beta_2;
+ q = x2 * q + beta_0;
+
+ return p / q;
+ },
+ name, tag);
}
/*!
*
* \return A Tensor whose op member is erf operation
*/
-inline Tensor fast_erf(const Tensor& x,
- std::string name = "T_fast_erf",
+inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
auto ret = fast_erf_float32(x, name, tag);
#ifndef TOPI_GENERIC_DEFAULT_H_
#define TOPI_GENERIC_DEFAULT_H_
+#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/fuse.h>
namespace topi {
using namespace tvm;
namespace generic {
/*!
-* \brief Create a generic default schedule for the given output tensors.
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
+ * \brief Create a generic default schedule for the given output tensors.
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
inline Schedule default_schedule(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
}
/*!
-* \brief Create a generic default schedule for the given output tensors, and apply
-* auto inline
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
+ * \brief Create a generic default schedule for the given output tensors, and apply
+ * auto inline
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
inline Schedule default_schedule_auto_inline(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
#ifndef TOPI_GENERIC_EXTERN_H_
#define TOPI_GENERIC_EXTERN_H_
-#include <tvm/te/operation.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
#include <topi/detail/fuse.h>
#include <topi/generic/injective.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule_pass.h>
namespace topi {
using namespace tvm;
namespace generic {
/*!
-* \brief Schedule an extern op followed by injective operations
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the op.
-*/
+ * \brief Schedule an extern op followed by injective operations
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the op.
+ */
inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
#ifndef TOPI_GENERIC_INJECTIVE_H_
#define TOPI_GENERIC_INJECTIVE_H_
+#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/fuse.h>
namespace topi {
using namespace tvm;
*
* \return A schedule for the given ops.
*/
-inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
+inline Schedule schedule_injective(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
#ifndef TOPI_NN_H_
#define TOPI_NN_H_
-#include <topi/tags.h>
#include <topi/detail/constant_utils.h>
+#include <topi/tags.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
-#include <tvm/te/operation.h>
#include <algorithm>
#include <string>
* \return A Tensor whose op member is the relu operation
*/
template <typename T>
-inline tvm::te::Tensor relu(const tvm::te::Tensor& t,
- T threshold = static_cast<T>(0),
- std::string name = "T_relu",
- std::string tag = kElementWise) {
+inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast<T>(0),
+ std::string name = "T_relu", std::string tag = kElementWise) {
return tvm::te::compute(
t->shape,
[&](const tvm::Array<tvm::tir::Var>& i) {
auto threshold_const = tvm::tir::make_const(t->dtype, threshold);
return tvm::max(t(i), threshold_const);
},
- name,
- tag);
+ name, tag);
}
/*!
-* \brief Creates an operation that performs a leaky rectified linear unit
-*
-* \param t The input tensor
-* \param alpha The slope for the small gradient when t < 0
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the leaky relu operation
-*/
-inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t,
- double alpha = 0.1,
- std::string name = "T_leaky_relu",
- std::string tag = kElementWise) {
+ * \brief Creates an operation that performs a leaky rectified linear unit
+ *
+ * \param t The input tensor
+ * \param alpha The slope for the small gradient when t < 0
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the leaky relu operation
+ */
+inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1,
+ std::string name = "T_leaky_relu",
+ std::string tag = kElementWise) {
return tvm::te::compute(
- t->shape,
- [&](const tvm::Array<tvm::tir::Var>& i) {
- auto value = t(i);
- auto calpha = tvm::tir::make_const(value.dtype(), alpha);
- return tvm::tir::SelectNode::make(value > 0, value, value * calpha);
- },
- name,
- tag);
+ t->shape,
+ [&](const tvm::Array<tvm::tir::Var>& i) {
+ auto value = t(i);
+ auto calpha = tvm::tir::make_const(value.dtype(), alpha);
+ return tvm::tir::SelectNode::make(value > 0, value, value * calpha);
+ },
+ name, tag);
}
/*!
*
* \return A Tensor whose op member is the parametric relu operation
*/
-inline tvm::te::Tensor prelu(const tvm::te::Tensor &x,
- const tvm::te::Tensor &slope,
- const int axis = 1,
- std::string name = "T_prelu",
- std::string tag = kBroadcast) {
- CHECK((size_t)axis < x->shape.size()) <<
- "Wrong axis (" << axis << ")value. ";
- CHECK(topi::detail::GetConstInt(slope->shape[0]) ==
- topi::detail::GetConstInt(x->shape[axis]))
- << "Wrong slope shape received.";
+inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope,
+ const int axis = 1, std::string name = "T_prelu",
+ std::string tag = kBroadcast) {
+ CHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. ";
+ CHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis]))
+ << "Wrong slope shape received.";
- return tvm::te::compute(x->shape,
- [&](const tvm::Array<tvm::tir::Var> &indices) {
- auto xval = x(indices);
- return tvm::tir::SelectNode::make(
- xval > 0,
- xval,
- xval * slope(indices[axis]));
- },
- name,
- tag);
+ return tvm::te::compute(
+ x->shape,
+ [&](const tvm::Array<tvm::tir::Var>& indices) {
+ auto xval = x(indices);
+ return tvm::tir::SelectNode::make(xval > 0, xval, xval * slope(indices[axis]));
+ },
+ name, tag);
}
/*!
*
*
*/
-inline tvm::te::Tensor pad(const tvm::te::Tensor& t,
- const tvm::Array<tvm::PrimExpr>& pad_before,
- tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
- PrimExpr pad_value = PrimExpr(),
- std::string name = "T_pad",
- std::string tag = kElementWise,
- std::string pad_mode = "constant") {
+inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array<tvm::PrimExpr>& pad_before,
+ tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
+ PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
+ std::string tag = kElementWise, std::string pad_mode = "constant") {
if (pad_after.size() < pad_before.size()) {
for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
pad_after.push_back(pad_before[i]);
tvm::Array<tvm::PrimExpr> output_shape;
tvm::Array<tvm::PrimExpr> pad_before_int32;
tvm::Array<tvm::PrimExpr> pad_after_int32;
- for (const auto &ele : pad_before) {
+ for (const auto& ele : pad_before) {
pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
}
- for (const auto &ele : pad_after) {
+ for (const auto& ele : pad_after) {
pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
}
for (size_t i = 0; i < t->shape.size(); ++i) {
sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
}
if (pad_mode == "edge") {
- pad_idx.push_back(tvm::if_then_else(
- ovars[i] < pad_before[i],
- 0,
- tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
- t->shape[i] - 1,
- ovars[i] - pad_before[i])));
+ pad_idx.push_back(
+ tvm::if_then_else(ovars[i] < pad_before[i], 0,
+ tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
+ t->shape[i] - 1, ovars[i] - pad_before[i])));
} else if (pad_mode == "reflect") {
- pad_idx.push_back(tvm::if_then_else(
- ovars[i] < pad_before[i],
- pad_before[i] - ovars[i],
- tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
- t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
- ovars[i] - pad_before[i])));
+ pad_idx.push_back(
+ tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i],
+ tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
+ t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
+ ovars[i] - pad_before[i])));
}
}
if (sel.size() != 0) {
if (pad_mode == "constant") {
- return tvm::if_then_else(
- detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value);
+ return tvm::if_then_else(detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value);
} else if (pad_mode == "edge" || pad_mode == "reflect") {
- return tvm::if_then_else(
- detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx));
+ return tvm::if_then_else(detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx));
}
}
return t(indices);
* \return A Tensor whose op member is the 2-D convolution operation (NCHW
* layout)
*/
-inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I,
- const tvm::te::Tensor& W,
- int pad_h = 0,
- int pad_w = 0,
- int stride_h = 1,
- int stride_w = 1,
- std::string name = "T_conv2d_nchw",
- std::string tag = kConv2dNCHW) {
+inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
+ int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
+ std::string name = "T_conv2d_nchw",
+ std::string tag = kConv2dNCHW) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[2];
auto pW = I->shape[3];
tvm::Array<tvm::PrimExpr> output_shape{
- I->shape[0], // B
- W->shape[0], // O
- indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
- indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
+ I->shape[0], // B
+ W->shape[0], // O
+ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
+ indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
- auto T = (pad_h == 0 && pad_w == 0)
- ? I
- : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
+ auto T =
+ (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
- return tvm::sum(
- T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw),
- {i, kh, kw});
+ return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw});
};
return tvm::te::compute(output_shape, l, name, tag);
}
* \return A Tensor whose op member is the 2-D convolution operation
* (HWCN layout)
*/
-inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I,
- const tvm::te::Tensor& W,
- int pad_h = 0,
- int pad_w = 0,
- int stride_h = 1,
- int stride_w = 1,
- std::string name = "T_conv2d_hwcn",
- std::string tag = kConv2dHWCN) {
+inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
+ int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
+ std::string name = "T_conv2d_hwcn",
+ std::string tag = kConv2dHWCN) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[2];
tvm::Array<tvm::PrimExpr> output_shape{
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W
- I->shape[2], // B
- W->shape[3] // O
+ I->shape[2], // B
+ W->shape[3] // O
};
auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w});
auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
- return tvm::sum(
- T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o),
- {i, kh, kw});
+ return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw});
};
return tvm::te::compute(output_shape, l, name, tag);
}
-
/*!
* \brief Creates an operation that performs a 2-D depthwise convolution with
* an NCHW-layout
* \return A Tensor whose op member is the 2-D depthwise convolution operation
* (NCHW layout)
*/
-inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I,
- const tvm::te::Tensor& W,
- int pad_h = 0,
- int pad_w = 0,
- int stride_h = 1,
- int stride_w = 1,
- std::string name = "T_depthwise_conv2d_nchw",
- std::string tag = kDepthwiseConv2dNCHW) {
+inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
+ int pad_h = 0, int pad_w = 0, int stride_h = 1,
+ int stride_w = 1,
+ std::string name = "T_depthwise_conv2d_nchw",
+ std::string tag = kDepthwiseConv2dNCHW) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[2];
auto pW = I->shape[3];
auto pCM = W->shape[1]; // channel_multiplier
tvm::Array<tvm::PrimExpr> output_shape{
- I->shape[0], // B
- W->shape[1], // O
+ I->shape[0], // B
+ W->shape[1], // O
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
- auto T = (pad_h == 0 && pad_w == 0)
- ? I
- : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
+ auto T =
+ (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
- W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
+ W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
{i, kh, kw});
};
return tvm::te::compute(output_shape, l, name, tag);
}
-inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I,
- const tvm::te::Tensor& W,
- int pad_h = 0,
- int pad_w = 0,
- int stride_h = 1,
- int stride_w = 1,
- std::string name = "T_depthwise_conv2d_nhwc",
- std::string tag = kDepthwiseConv2dNHWC) {
+inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
+ int pad_h = 0, int pad_w = 0, int stride_h = 1,
+ int stride_w = 1,
+ std::string name = "T_depthwise_conv2d_nhwc",
+ std::string tag = kDepthwiseConv2dNHWC) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[1];
auto pW = I->shape[2];
auto pCM = W->shape[1]; // channel_multiplier
tvm::Array<tvm::PrimExpr> output_shape{
- I->shape[0], // B
+ I->shape[0], // B
indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H
- indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W
- W->shape[3], // O
+ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W
+ W->shape[3], // O
};
auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
- auto T = (pad_h == 0 && pad_w == 0)
- ? I
- : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
+ auto T =
+ (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) {
return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
- W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
+ W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
{kh, kw, i});
};
return tvm::te::compute(output_shape, l, name, tag);
* \return A Tensor whose op member is the 2-D groupconvolution operation
* (NCHW layout)
*/
-inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I,
- const tvm::te::Tensor& W,
- int pad_h = 0,
- int pad_w = 0,
- int stride_h = 1,
- int stride_w = 1,
- std::string name = "T_group_conv2d_ngchw",
- std::string tag = kGroupConv2d) {
+inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
+ int pad_h = 0, int pad_w = 0, int stride_h = 1,
+ int stride_w = 1,
+ std::string name = "T_group_conv2d_ngchw",
+ std::string tag = kGroupConv2d) {
CHECK_EQ(5, I->shape.size());
CHECK_EQ(5, W->shape.size());
auto pH = I->shape[2];
auto pW = I->shape[3];
tvm::Array<tvm::PrimExpr> output_shape{
- I->shape[0], // B
- I->shape[1], // G
- W->shape[2], // O
+ I->shape[0], // B
+ I->shape[1], // G
+ W->shape[2], // O
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W
};
tvm::tir::Var o = args[2];
tvm::tir::Var h = args[3];
tvm::tir::Var w = args[4];
- return tvm::sum(
- I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
- {i, kh, kw});
+ return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
+ {i, kh, kw});
};
return tvm::te::compute(output_shape, l, name, tag);
}
#ifndef TOPI_NN_BATCH_MATMUL_H_
#define TOPI_NN_BATCH_MATMUL_H_
-#include <tvm/te/operation.h>
#include <topi/tags.h>
+#include <tvm/te/operation.h>
#include <string>
using namespace tvm::te;
/*!
-* \brief Creates an operation that calculates matrix multiplication in batch.
-*
-* \param x Tensor with shape [batch, M, K]
-* \param y Tensor with shape [batch, N, K]
-*
-* \return Tensor with shape [batch, M, N]
-*/
-inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x,
- const tvm::te::Tensor& y) {
+ * \brief Creates an operation that calculates matrix multiplication in batch.
+ *
+ * \param x Tensor with shape [batch, M, K]
+ * \param y Tensor with shape [batch, N, K]
+ *
+ * \return Tensor with shape [batch, M, N]
+ */
+inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, const tvm::te::Tensor& y) {
CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data";
CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data";
auto k = tvm::te::reduce_axis(Range(0, K), "k");
auto result = tvm::te::compute(
- { batch, M, N },
- [&](Var b, Var i, Var j) {
- return tvm::sum(x(b, i, k) * y(b, j, k), { k });
- }, "tensor", "batch_matmul");
+ {batch, M, N}, [&](Var b, Var i, Var j) { return tvm::sum(x(b, i, k) * y(b, j, k), {k}); },
+ "tensor", "batch_matmul");
return result;
}
#ifndef TOPI_NN_BIAS_ADD_H_
#define TOPI_NN_BIAS_ADD_H_
-#include <tvm/te/operation.h>
-#include <topi/tags.h>
#include <topi/broadcast.h>
+#include <topi/tags.h>
#include <topi/transform.h>
+#include <tvm/te/operation.h>
#include <string>
namespace nn {
/*!
-* \brief Creates an operation that calculates data + bias
-*
-* \param data Tensor with shape [batch, in_dim]
-* \param bias Tensor with shape [batch].
-* \param axis The axis to add the bias to.
-* \return Tensor with shape [batch, in_dim]
-*/
-inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data,
- const tvm::te::Tensor& bias,
- int axis) {
+ * \brief Creates an operation that calculates data + bias
+ *
+ * \param data Tensor with shape [batch, in_dim]
+ * \param bias Tensor with shape [batch].
+ * \param axis The axis to add the bias to.
+ * \return Tensor with shape [batch, in_dim]
+ */
+inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data, const tvm::te::Tensor& bias,
+ int axis) {
int data_ndim = data->shape.size();
if (axis < 0) {
axis += data_ndim;
#ifndef TOPI_NN_BNN_H_
#define TOPI_NN_BNN_H_
-#include <tvm/te/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <topi/tags.h>
#include <topi/detail/constant_utils.h>
+#include <topi/tags.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/te/operation.h>
#include <string>
using namespace tvm::te;
/*!
-* \brief Binarization and bit-packing along a certain axis.
-*
-* \param data N-D tensor, can be any layout
-* \param axis The axis along which to do binarization and bit-packing. This axis
-* must have a size equal to an integer multiple of 32.
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return Output tensor with dtype uint32
-*/
-inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data,
- int axis,
- std::string name = "PackedInput",
- std::string tag = "binarize_pack") {
+ * \brief Binarization and bit-packing along a certain axis.
+ *
+ * \param data N-D tensor, can be any layout
+ * \param axis The axis along which to do binarization and bit-packing. This axis
+ * must have a size equal to an integer multiple of 32.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return Output tensor with dtype uint32
+ */
+inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis,
+ std::string name = "PackedInput",
+ std::string tag = "binarize_pack") {
auto ishape = data->shape;
CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0)
- << "binarize_pack: axis size must be a multiple of 32";
+ << "binarize_pack: axis size must be a multiple of 32";
arith::Analyzer analyzer;
auto n = ishape.size();
Array<PrimExpr> oshape;
for (size_t i = 0; i < n; ++i) {
- oshape.push_back(i == static_cast<size_t>(axis) ?
- analyzer.Simplify(indexdiv(ishape[i], 32)) :
- ishape[i]);
+ oshape.push_back(i == static_cast<size_t>(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32))
+ : ishape[i]);
}
return tvm::te::compute(
- oshape,
- [&](const Array<Var>& indices) {
- Array<PrimExpr> start_idx;
- for (size_t i = 0; i < n; ++i) {
- start_idx.push_back(i == static_cast<size_t>(axis) ?
- indices[i] * 32 :
- static_cast<PrimExpr>(indices[i]));
- }
- auto packed = make_const(DataType::UInt(32), 0);
- for (size_t j = 0; j < 32; ++j) {
- Array<PrimExpr> idx;
+ oshape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> start_idx;
for (size_t i = 0; i < n; ++i) {
- idx.push_back(i == static_cast<size_t>(axis) ?
- start_idx[i] + static_cast<int>(j) :
- start_idx[i]);
+ start_idx.push_back(i == static_cast<size_t>(axis) ? indices[i] * 32
+ : static_cast<PrimExpr>(indices[i]));
}
- auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0);
- packed = (packed | sign);
- if (j == 31) {
- return packed;
+ auto packed = make_const(DataType::UInt(32), 0);
+ for (size_t j = 0; j < 32; ++j) {
+ Array<PrimExpr> idx;
+ for (size_t i = 0; i < n; ++i) {
+ idx.push_back(i == static_cast<size_t>(axis) ? start_idx[i] + static_cast<int>(j)
+ : start_idx[i]);
+ }
+ auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0);
+ packed = (packed | sign);
+ if (j == 31) {
+ return packed;
+ }
+ packed = packed << 1;
}
- packed = packed << 1;
- }
- return packed; // never reached, but suppress compiler warning
- }, name, tag);
+ return packed; // never reached, but suppress compiler warning
+ },
+ name, tag);
}
/*!
-* \brief Binary matrix multiplication using xor and bit-count
-*
-* \param data Tensor with shape [batch, in_dim], dtype is uint32
-* \param weight Tensor with shape [out_dim, in_dim], dtype is uint32
-*
-* \return Tensor with shape [batch, out_dim], dtype is float32
-*/
-inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data,
- const tvm::te::Tensor& weight) {
+ * \brief Binary matrix multiplication using xor and bit-count
+ *
+ * \param data Tensor with shape [batch, in_dim], dtype is uint32
+ * \param weight Tensor with shape [out_dim, in_dim], dtype is uint32
+ *
+ * \return Tensor with shape [batch, out_dim], dtype is float32
+ */
+inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) {
CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight";
CHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data";
auto k = tvm::te::reduce_axis(Range(0, in_dim), "k");
auto matmul = tvm::te::compute(
- { batch, out_dim },
- [&](Var i, Var j) {
- return tvm::sum(popcount(data(i, k) ^ weight(j, k)), { k });
- }, "tensor", "binary_dense");
+ {batch, out_dim},
+ [&](Var i, Var j) { return tvm::sum(popcount(data(i, k) ^ weight(j, k)), {k}); }, "tensor",
+ "binary_dense");
return tvm::te::compute(
- { batch, out_dim },
- [&](Var i, Var j) {
- return 32 * in_dim - 2.0f * matmul(i, j);
- }, "tensor", kElementWise);
+ {batch, out_dim}, [&](Var i, Var j) { return 32 * in_dim - 2.0f * matmul(i, j); }, "tensor",
+ kElementWise);
}
} // namespace nn
#ifndef TOPI_NN_DENSE_H_
#define TOPI_NN_DENSE_H_
-#include <tvm/te/operation.h>
#include <topi/tags.h>
+#include <tvm/te/operation.h>
#include <string>
using namespace tvm::te;
/*!
-* \brief Creates an operation that calculates data * weight^T + bias
-*
-* \param data Tensor with shape [batch, in_dim]
-* \param weight Tensor with shape [out_dim, in_dim]
-* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
-* \param out_dtype Output data type. Used for mixed precision.
-*
-* \return Tensor with shape [batch, out_dim]
-*/
-inline tvm::te::Tensor dense(const tvm::te::Tensor& data,
- const tvm::te::Tensor& weight,
- const tvm::te::Tensor& bias,
- const DataType& out_dtype) {
+ * \brief Creates an operation that calculates data * weight^T + bias
+ *
+ * \param data Tensor with shape [batch, in_dim]
+ * \param weight Tensor with shape [out_dim, in_dim]
+ * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
+ * \param out_dtype Output data type. Used for mixed precision.
+ *
+ * \return Tensor with shape [batch, out_dim]
+ */
+inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight,
+ const tvm::te::Tensor& bias, const DataType& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
auto k = tvm::te::reduce_axis(Range(0, in_dim), "k");
auto matmul = tvm::te::compute(
- { batch, out_dim },
- [&](Var i, Var j) {
- return tvm::sum(tvm::cast(out_dtype, data(i, k)) *
- tvm::cast(out_dtype, weight(j, k)), { k });
- }, "tensor", "dense");
+ {batch, out_dim},
+ [&](Var i, Var j) {
+ return tvm::sum(tvm::cast(out_dtype, data(i, k)) * tvm::cast(out_dtype, weight(j, k)), {k});
+ },
+ "tensor", "dense");
if (bias.defined()) {
matmul = tvm::te::compute(
- { batch, out_dim },
- [&](Var i, Var j) {
- return matmul(i, j) + tvm::cast(out_dtype, bias(j));
- }, "tensor", kBroadcast);
+ {batch, out_dim},
+ [&](Var i, Var j) { return matmul(i, j) + tvm::cast(out_dtype, bias(j)); }, "tensor",
+ kBroadcast);
}
return matmul;
#ifndef TOPI_NN_DILATE_H_
#define TOPI_NN_DILATE_H_
-#include <tvm/te/operation.h>
-#include <tvm/arith/analyzer.h>
#include <topi/tags.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/te/operation.h>
#include <string>
using namespace tvm::te;
/*!
-* \brief Create a new expression of the logical and of all
-* conditions in the arguments.
-*
-* \param args The arguments to find the logical conjunction of
-*
-* \return The logical conjunction expression
-*/
+ * \brief Create a new expression of the logical and of all
+ * conditions in the arguments.
+ *
+ * \param args The arguments to find the logical conjunction of
+ *
+ * \return The logical conjunction expression
+ */
PrimExpr all(Array<PrimExpr> args) {
CHECK_GT(args.size(), 0) << "all requires at least one argument";
}
/*!
-* \brief Dilate data with zeros
-*
-* \param x The input tensor, this can have any number of
-* dimensions and any layout.
-* \param strides Dilation stride for each dimension. Stride 1
-* means no dilation.
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return The output tensor.
-*/
-inline Tensor dilate(const Tensor& x,
- Array<PrimExpr> strides,
- std::string name = "tensor",
+ * \brief Dilate data with zeros
+ *
+ * \param x The input tensor, this can have any number of
+ * dimensions and any layout.
+ * \param strides Dilation stride for each dimension. Stride 1
+ * means no dilation.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The output tensor.
+ */
+inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name = "tensor",
std::string tag = kInjective) {
auto n = x->shape.size();
- CHECK_EQ(n, strides.size())
- << "strides size (" << strides.size()
- << ") must match dimension of x (" << n << ")";
+ CHECK_EQ(n, strides.size()) << "strides size (" << strides.size()
+ << ") must match dimension of x (" << n << ")";
Array<PrimExpr> out_shape;
arith::Analyzer analyzer;
for (size_t i = 0; i < n; ++i) {
-    out_shape.push_back(analyzer.Simplify(
-        (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
+    // Dilated extent along axis i is (n - 1) * stride + 1. The `+ 1` must be added
+    // after the product — placing it inside the cast argument computed
+    // (n - 1) * (stride + 1) instead, yielding a wrong output shape.
+    out_shape.push_back(
+        analyzer.Simplify((x->shape[i] - 1) * cast(DataType::Int(32), strides[i]) + 1));
}
return tvm::te::compute(
- out_shape,
- [&](const Array<Var>& indices) {
- Array<PrimExpr> not_zero;
- Array<PrimExpr> index_tuple;
- for (size_t i = 0; i < n; ++i) {
- if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
- index_tuple.push_back(indices[i]);
- } else {
- index_tuple.push_back(indexdiv(indices[i], strides[i]));
- not_zero.push_back((indexmod(indices[i], strides[i])) == 0);
+ out_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> not_zero;
+ Array<PrimExpr> index_tuple;
+ for (size_t i = 0; i < n; ++i) {
+ if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
+ index_tuple.push_back(indices[i]);
+ } else {
+ index_tuple.push_back(indexdiv(indices[i], strides[i]));
+ not_zero.push_back((indexmod(indices[i], strides[i])) == 0);
+ }
+ }
+ if (not_zero.size() > 0) {
+ auto all_not_zero = all(not_zero);
+ return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
}
- }
- if (not_zero.size() > 0) {
- auto all_not_zero = all(not_zero);
- return tvm::if_then_else(
- all_not_zero, x(index_tuple), make_const(x->dtype, 0));
- }
- return x(index_tuple);
- }, name, tag);
+ return x(index_tuple);
+ },
+ name, tag);
}
} // namespace nn
#ifndef TOPI_NN_FLATTEN_H_
#define TOPI_NN_FLATTEN_H_
-#include <tvm/te/operation.h>
-#include <topi/tags.h>
#include <topi/detail/constant_utils.h>
+#include <topi/tags.h>
+#include <tvm/te/operation.h>
#include <string>
#include <vector>
using namespace tvm::te;
/*!
-* \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions.
-* This requires the input tensor to have constant sized dimensions.
-*
-* \param x The input tensor.
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A 2-D tensor.
-*/
-inline Tensor flatten(const Tensor& x,
- std::string name = "tensor",
- std::string tag = kInjective) {
+ * \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions.
+ * This requires the input tensor to have constant sized dimensions.
+ *
+ * \param x The input tensor.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A 2-D tensor.
+ */
+inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) {
auto ishape = x->shape;
PrimExpr dim = 1;
for (size_t i = 1; i < ishape.size(); ++i) {
dim = dim * ishape[i];
}
- Array<PrimExpr> oshape({ ishape[0], dim });
+ Array<PrimExpr> oshape({ishape[0], dim});
std::vector<PrimExpr> extra_shape;
for (size_t i = 1; i < ishape.size(); ++i) {
std::reverse(extra_shape.begin(), extra_shape.end());
return tvm::te::compute(
- oshape, [&](Var i, Var j) {
- PrimExpr idx = j;
- std::vector<PrimExpr> index;
- for (auto s : extra_shape) {
- index.push_back(indexmod(idx, s));
- idx = indexdiv(idx, s);
- }
- index.push_back(i);
- std::reverse(index.begin(), index.end());
- return x(index);
- }, name, tag);
+ oshape,
+ [&](Var i, Var j) {
+ PrimExpr idx = j;
+ std::vector<PrimExpr> index;
+ for (auto s : extra_shape) {
+ index.push_back(indexmod(idx, s));
+ idx = indexdiv(idx, s);
+ }
+ index.push_back(i);
+ std::reverse(index.begin(), index.end());
+ return x(index);
+ },
+ name, tag);
}
} // namespace nn
#ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_
#define TOPI_NN_LOCAL_RESPONSE_NORM_H_
-#include <tvm/te/operation.h>
#include <topi/tags.h>
+#include <tvm/te/operation.h>
#include <string>
using namespace tvm::te;
/*!
-* \brief Local response normalization inference operator
-*
-* \param data The input tensor. 4-D shape NCHW or NHWC
-* \param size Integer to define normalisation window size
-* \param axis Input data layout channel axis
-* \param alpha Float scaling factor
-* \param beta Exponent value
-* \param bias Offset to avoid dividing by zero
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the Local response normalization operation
-*/
-inline Tensor lrn(const Tensor& data,
- int size,
- int axis = 1,
- float alpha = 0.0001,
- float beta = 0.75,
- float bias = 2,
- std::string name = "tensor",
+ * \brief Local response normalization inference operator
+ *
+ * \param data The input tensor. 4-D shape NCHW or NHWC
+ * \param size Integer to define normalisation window size
+ * \param axis Input data layout channel axis
+ * \param alpha Float scaling factor
+ * \param beta Exponent value
+ * \param bias Offset to avoid dividing by zero
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the Local response normalization operation
+ */
+inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.0001,
+ float beta = 0.75, float bias = 2, std::string name = "tensor",
std::string tag = kBroadcast) {
CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
CHECK_EQ(size % 2, 1) << "size should be odd number";
CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
auto input_shape = data->shape;
- Array<PrimExpr> pad_before{ 0, 0, 0, 0};
- Array<PrimExpr> pad_after{ 0, 0, 0, 0};
- pad_before.Set(axis, static_cast<PrimExpr>(size/2));
- pad_after.Set(axis, static_cast<PrimExpr>(size/2));
+ Array<PrimExpr> pad_before{0, 0, 0, 0};
+ Array<PrimExpr> pad_after{0, 0, 0, 0};
+ pad_before.Set(axis, static_cast<PrimExpr>(size / 2));
+ pad_after.Set(axis, static_cast<PrimExpr>(size / 2));
auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
auto rxs = tvm::te::reduce_axis(Range(0, size), "rxs");
Tensor sqr_sum;
if (axis == 1) {
- sqr_sum = tvm::te::compute(input_shape,
- [&](Var i, Var l, Var j, Var k) {
- return tvm::sum(pad_data(i, l + rxs, j, k) *
- pad_data(i, l + rxs, j, k),
- {rxs});
- });
+ sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) {
+ return tvm::sum(pad_data(i, l + rxs, j, k) * pad_data(i, l + rxs, j, k), {rxs});
+ });
} else if (axis == 3) {
- sqr_sum = tvm::te::compute(input_shape,
- [&](Var i, Var l, Var j, Var k) {
- return tvm::sum(pad_data(i, l, j, k + rxs) *
- pad_data(i, l, j, k + rxs),
- {rxs});
- });
+ sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) {
+ return tvm::sum(pad_data(i, l, j, k + rxs) * pad_data(i, l, j, k + rxs), {rxs});
+ });
}
- auto sqrt_sum_up = tvm::te::compute(
- input_shape,
- [&](Var i, Var j, Var k, Var l) {
- return tvm::pow(bias +
- (div(alpha * sqr_sum(i, j, k, l), size)),
- beta);
- });
+ auto sqrt_sum_up = tvm::te::compute(input_shape, [&](Var i, Var j, Var k, Var l) {
+ return tvm::pow(bias + (div(alpha * sqr_sum(i, j, k, l), size)), beta);
+ });
return topi::divide(data, sqrt_sum_up);
}
} // namespace nn
#ifndef TOPI_NN_MAPPING_H_
#define TOPI_NN_MAPPING_H_
-#include <tvm/te/operation.h>
#include <topi/tags.h>
+#include <tvm/te/operation.h>
#include <string>
using namespace tvm::te;
/*!
-* \brief Scale and shift with NCHW order
-*
-* \param x The input tensor.
-* \param scale Scale tensor, 1-D of size channel
-* \param shift Shift tensor, 1-D of size channel
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the scale shift operation
-*/
-inline Tensor scale_shift_nchw(const Tensor& x,
- const Tensor& scale,
- const Tensor& shift,
- std::string name = "ScaleShift",
- std::string tag = kBroadcast) {
+ * \brief Scale and shift with NCHW order
+ *
+ * \param x The input tensor.
+ * \param scale Scale tensor, 1-D of size channel
+ * \param shift Shift tensor, 1-D of size channel
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the scale shift operation
+ */
+inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tensor& shift,
+                               std::string name = "ScaleShift", std::string tag = kBroadcast) {
  return tvm::te::compute(
-      x->shape,
-      [&](Var b, Var c, Var h, Var w) {
-        return x(b, c, h, w) * scale(c) + shift(w);
-      }, name, tag);
+      // scale and shift are both 1-D channel-sized tensors, so both must be indexed
+      // by the channel var c. Indexing shift by the width var w was a wrong-axis bug.
+      x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(c); },
+      name, tag);
}
/*!
-* \brief Scale and shift with NHWC order
-*
-* \param x The input tensor.
-* \param scale Scale tensor, 1-D of size channel
-* \param shift Shift tensor, 1-D of size channel
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the scale shift operation
-*/
-inline Tensor scale_shift_nhwc(const Tensor& x,
- const Tensor& scale,
- const Tensor& shift,
- std::string name = "ScaleShift",
- std::string tag = kBroadcast) {
+ * \brief Scale and shift with NHWC order
+ *
+ * \param x The input tensor.
+ * \param scale Scale tensor, 1-D of size channel
+ * \param shift Shift tensor, 1-D of size channel
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the scale shift operation
+ */
+inline Tensor scale_shift_nhwc(const Tensor& x, const Tensor& scale, const Tensor& shift,
+                               std::string name = "ScaleShift", std::string tag = kBroadcast) {
  return tvm::te::compute(
-      x->shape,
-      [&](Var b, Var h, Var w, Var c) {
-        return x(b, h, w, c) * scale(c) + shift(w);
-      }, name, tag);
+      // shift, like scale, is a 1-D channel-sized tensor: index it by the channel
+      // var c, not the width var w (NHWC puts channels last, so c is the 4th var).
+      x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(c); },
+      name, tag);
}
} // namespace nn
kMaxPool,
};
-
/*!
-* \brief Perform pooling on height and width dimension of data.
-*
-* \param x The input tensor
-* \param kernel_size Vector of two ints: {kernel_height, kernel_width}
-* \param stride_size Vector of two ints: {stride_height, stride_width}
-* \param padding_size Vector of two ints: {padding_height, padding_width}
-* \param pool_type The type of pooling operator
-* \param ceil_mode Whether to use ceil when calculating the output size
-* \param height_axis index of the height dimension
-* \param width_axis index of the width dimension
-* \param count_include_pad Whether include padding in the calculation
-*
-* \return The output tensor in same layout order
-*/
-inline Tensor pool_impl(const Tensor& x,
- const Array<PrimExpr>& kernel_size,
- const Array<PrimExpr>& stride_size,
- const Array<PrimExpr>& padding_size,
- PoolType pool_type,
- bool ceil_mode,
- const size_t height_axis,
- const size_t width_axis,
- bool count_include_pad) {
+ * \brief Perform pooling on height and width dimension of data.
+ *
+ * \param x The input tensor
+ * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
+ * \param stride_size Vector of two ints: {stride_height, stride_width}
+ * \param padding_size Vector of two ints: {padding_height, padding_width}
+ * \param pool_type The type of pooling operator
+ * \param ceil_mode Whether to use ceil when calculating the output size
+ * \param height_axis index of the height dimension
+ * \param width_axis index of the width dimension
+ * \param count_include_pad Whether include padding in the calculation
+ *
+ * \return The output tensor in same layout order
+ */
+inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
+ PoolType pool_type, bool ceil_mode, const size_t height_axis,
+ const size_t width_axis, bool count_include_pad) {
CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);
arith::Analyzer analyzer;
- auto out_height = analyzer.Simplify(
- indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
- auto out_width = analyzer.Simplify(
- indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
+ auto out_height =
+ analyzer.Simplify(indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
+ auto out_width =
+ analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
- const int64_t *padding_h0 = as_const_int(pad_top);
- const int64_t *padding_w0 = as_const_int(pad_left);
- const int64_t *padding_h1 = as_const_int(pad_bottom);
- const int64_t *padding_w1 = as_const_int(pad_right);
+ const int64_t* padding_h0 = as_const_int(pad_top);
+ const int64_t* padding_w0 = as_const_int(pad_left);
+ const int64_t* padding_h1 = as_const_int(pad_bottom);
+ const int64_t* padding_w1 = as_const_int(pad_right);
const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
if (pool_type == kMaxPool) {
- auto temp = do_pad ? pad(
- x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
- return tvm::te::compute(out_shape, [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- for (const Var& var : output) indices.push_back(var);
- indices.Set(height_axis, output[height_axis] * stride_height + dheight);
- indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
- return tvm::max(temp(indices), { dheight, dwidth });
- }, "tensor", "pool_max");
+ auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
+ return tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ for (const Var& var : output) indices.push_back(var);
+ indices.Set(height_axis, output[height_axis] * stride_height + dheight);
+ indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
+ return tvm::max(temp(indices), {dheight, dwidth});
+ },
+ "tensor", "pool_max");
} else if (pool_type == kAvgPool) {
// Pad the inputs
auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
// TVM compute for summing the pooling window.
- auto pool_sum = tvm::te::compute(out_shape,
- [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- for (const Var& var : output) indices.push_back(var);
- indices.Set(height_axis, output[height_axis] * stride_height + dheight);
- indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
- return tvm::sum(temp(indices), { dheight, dwidth });
- }, "tensor", "pool_sum");
+ auto pool_sum = tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ for (const Var& var : output) indices.push_back(var);
+ indices.Set(height_axis, output[height_axis] * stride_height + dheight);
+ indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
+ return tvm::sum(temp(indices), {dheight, dwidth});
+ },
+ "tensor", "pool_sum");
// TVM compute for dividing the reduced window sum by kernel size.
- return tvm::te::compute(out_shape,
- [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- for (const Var& var : output) indices.push_back(var);
- if (count_include_pad) {
- return div(pool_sum(indices), (kernel_height * kernel_width));
- } else {
- PrimExpr h_start = output[height_axis] * stride_height - pad_top;
- PrimExpr w_start = output[width_axis] * stride_width - pad_left;
- PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height);
- PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
- h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
- w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
- PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
- make_const(DataType::DataType::Int(32), 1));
- return div(pool_sum(indices), divide_factor);
- }
- }, "tensor", kElementWise);
+ return tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ for (const Var& var : output) indices.push_back(var);
+ if (count_include_pad) {
+ return div(pool_sum(indices), (kernel_height * kernel_width));
+ } else {
+ PrimExpr h_start = output[height_axis] * stride_height - pad_top;
+ PrimExpr w_start = output[width_axis] * stride_width - pad_left;
+ PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height);
+ PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
+ h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
+ w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
+ PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+ make_const(DataType::DataType::Int(32), 1));
+ return div(pool_sum(indices), divide_factor);
+ }
+ },
+ "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
}
-inline Tensor pool_grad_impl(const Tensor& out_grad,
- const Tensor& x,
- const Array<PrimExpr>& kernel_size,
- const Array<PrimExpr>& stride_size,
- const Array<PrimExpr>& padding_size,
- PoolType pool_type, bool ceil_mode,
- const size_t height_axis, const size_t width_axis,
+inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
+ const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
+ const Array<PrimExpr>& padding_size, PoolType pool_type,
+ bool ceil_mode, const size_t height_axis, const size_t width_axis,
bool count_include_pad) {
CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
- auto windowh = tvm::te::reduce_axis(
- Range(0, (kernel_height + stride_height - 1) / stride_height));
- auto windoww = tvm::te::reduce_axis(
- Range(0, (kernel_width + stride_width - 1) / stride_width));
+ auto windowh =
+ tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
+ auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
auto argmax = MakeArgmaxReducer();
- auto pad_x = do_pad ? pad(
- x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
-
- auto mp_argmax =
- tvm::te::compute(
- out_shape,
- [&](const Array<Var>& inds) {
- Array<PrimExpr> window_inds{inds.begin(), inds.end()};
- window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
- window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
- auto idx = detail::RavelIndex(window_inds, ravel_shape);
- return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
- },
- "maxpool_grad_argmax", kCommReduceIdx);
+ auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
+
+ auto mp_argmax = tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& inds) {
+ Array<PrimExpr> window_inds{inds.begin(), inds.end()};
+ window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
+ window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
+ auto idx = detail::RavelIndex(window_inds, ravel_shape);
+ return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
+ },
+ "maxpool_grad_argmax", kCommReduceIdx);
auto mp_inds = mp_argmax[0];
return tvm::te::compute(
x->shape,
[&](const Array<Var>& inds) {
- Array<PrimExpr> pad_inds {inds.begin(), inds.end()};
+ Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
auto idx = detail::RavelIndex(pad_inds, ravel_shape);
- Array<PrimExpr> out_idx {inds.begin(), inds.end()};
+ Array<PrimExpr> out_idx{inds.begin(), inds.end()};
out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
(pad_inds[width_axis] - kernel_width) / stride_width + 1);
return tvm::sum(
- tvm::if_then_else(tir::AndNode::make(
- tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
- out_idx[width_axis] >= out_idx_lower_w),
- mp_inds(out_idx) == idx),
+ tvm::if_then_else(
+ tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
+ out_idx[width_axis] >= out_idx_lower_w),
+ mp_inds(out_idx) == idx),
out_grad(out_idx), make_const(x->dtype, 0)),
{windowh, windoww});
},
"T_pool_grad", "pool_grad_max");
} else if (pool_type == kAvgPool) {
- auto windowh = tvm::te::reduce_axis(
- Range(0, (kernel_height + stride_height - 1) / stride_height));
- auto windoww = tvm::te::reduce_axis(
- Range(0, (kernel_width + stride_width - 1) / stride_width));
+ auto windowh =
+ tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
+ auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
return tvm::te::compute(
x->shape,
[&](const Array<Var>& inds) {
out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
- PrimExpr out_idx_lower_h = tir::SelectNode::make(
- pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
- (pad_h_idx - kernel_height) / stride_height + 1);
- PrimExpr out_idx_lower_w = tir::SelectNode::make(
- pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
- (pad_w_idx - kernel_width) / stride_width + 1);
+ PrimExpr out_idx_lower_h =
+ tir::SelectNode::make(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
+ (pad_h_idx - kernel_height) / stride_height + 1);
+ PrimExpr out_idx_lower_w =
+ tir::SelectNode::make(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
+ (pad_w_idx - kernel_width) / stride_width + 1);
PrimExpr divide_factor; // number of pooled elements
if (count_include_pad) {
PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
h_start = tir::MaxNode::make(h_start, make_const(DataType::Int(32), 0));
w_start = tir::MaxNode::make(w_start, make_const(DataType::Int(32), 0));
- divide_factor =
- tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
- make_const(DataType::Int(32), 1));
+ divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+ make_const(DataType::Int(32), 1));
}
- return tvm::sum(tvm::if_then_else(
- tir::AndNode::make(
- tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
- out_idx[height_axis] < out_height),
- tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w,
- out_idx[width_axis] < out_width)),
- out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
+ return tvm::sum(
+ tvm::if_then_else(
+ tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
+ out_idx[height_axis] < out_height),
+ tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w,
+ out_idx[width_axis] < out_width)),
+ out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
{windowh, windoww});
},
"T_pool_grad", "pool_grad_avg");
}
}
-inline bool find_depth_height_width(const std::string& layout,
- int* depth_axis,
- int* height_axis,
+inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
int* width_axis) {
*depth_axis = -1, *height_axis = -1, *width_axis = -1;
int curr_idx = 0;
for (size_t i = 0; i < layout.size(); ++i) {
- if ((layout[i] >= 'A' && layout[i] <= 'Z') ||
- (layout[i] >= 'a' && layout[i] <= 'z')) {
+ if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
if (layout[i] == 'D') {
if (*depth_axis != -1) return false;
*depth_axis = curr_idx;
return true;
}
-inline bool find_height_width(const std::string& layout,
- int* height_axis,
- int* width_axis) {
+inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
int dummy;
- CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
+ CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
if (*height_axis != -1 && *width_axis != -1) {
return true;
}
return false;
}
-inline bool find_width(const std::string& layout,
- int* width_axis) {
+inline bool find_width(const std::string& layout, int* width_axis) {
int dummy;
- CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
+ CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
if (*width_axis != -1) {
return true;
}
}
/*!
-* \brief Perform pooling on height and width dimension of data.
-* It decides the height and width dimension according to the layout string,
-* in which 'W' and 'H' means width and height respectively.
-* Width and height dimension cannot be split.
-* For example, NCHW, NCHW16c, etc. are valid for pool,
-* while NCHW16w, NCHW16h are not.
-* See \a layout for more information of the layout string convention.
-* \param x The input tensor.
-* \param kernel_size Vector of two ints: {kernel_height, kernel_width}
-* \param stride_size Vector of two ints: {stride_height, stride_width}
-* \param padding_size Vector of two ints: {padding_height, padding_width}
-* \param pool_type The type of pooling operator
-* \param ceil_mode Whether to use ceil when calculating the output size
-* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
-* where upper case indicates a dimension and
-* the corresponding lower case (with factor size) indicates the split dimension.
-* For example, NCHW16c can describe a 5-D tensor of
-* [batch_size, channel, height, width, channel_block].
-* (in which factor size `16` will not be used in pooling but for other operators,
-* it can be used to decide the output shape).
-* Since pooling does not care about the factor size of dimensions
-* other than `H` and `W`, one can pass `NCHWc` as well.
-* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
-*
-*
-* \return The output tensor in the same layout
-*/
-inline Tensor pool(const Tensor& x,
- const Array<PrimExpr>& kernel_size,
- const Array<PrimExpr>& stride_size,
- const Array<PrimExpr>& padding_size,
- PoolType pool_type,
- bool ceil_mode,
- const std::string& layout = "NCHW",
+ * \brief Perform pooling on height and width dimension of data.
+ * It decides the height and width dimension according to the layout string,
+ * in which 'W' and 'H' means width and height respectively.
+ * Width and height dimension cannot be split.
+ * For example, NCHW, NCHW16c, etc. are valid for pool,
+ * while NCHW16w, NCHW16h are not.
+ * See \a layout for more information of the layout string convention.
+ * \param x The input tensor.
+ * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
+ * \param stride_size Vector of two ints: {stride_height, stride_width}
+ * \param padding_size Vector of two ints: {padding_height, padding_width}
+ * \param pool_type The type of pooling operator
+ * \param ceil_mode Whether to use ceil when calculating the output size
+ * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
+ * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ * where upper case indicates a dimension and
+ * the corresponding lower case (with factor size) indicates the split dimension.
+ * For example, NCHW16c can describe a 5-D tensor of
+ * [batch_size, channel, height, width, channel_block].
+ * (in which factor size `16` will not be used in pooling but for other operators,
+ * it can be used to decide the output shape).
+ * Since pooling does not care about the factor size of dimensions
+ * other than `H` and `W`, one can pass `NCHWc` as well.
+ * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
+ *
+ *
+ * \return The output tensor in the same layout
+ */
+inline Tensor pool(const Tensor& x, const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
+ PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
bool count_include_pad = true) {
int height_axis = -1, width_axis = -1;
- CHECK(find_height_width(layout, &height_axis, &width_axis))
- << "Unsupported layout " << layout;
- return pool_impl(x, kernel_size, stride_size, padding_size,
- pool_type, ceil_mode, height_axis, width_axis,
- count_include_pad);
+ CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
+ return pool_impl(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis,
+ width_axis, count_include_pad);
}
/*!
height_axis, width_axis, count_include_pad);
}
-inline PrimExpr start_index(const Var& out_index,
- const PrimExpr& odim,
- const PrimExpr& idim) {
+inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
return indexdiv(out_index * idim, odim);
}
-inline PrimExpr end_index(const Var& out_index,
- const PrimExpr& odim,
- const PrimExpr& idim) {
+inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
- return tvm::tir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0,
- tmp, tmp + 1);
+ return tvm::tir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
}
/*!
-* \brief Perform adaptive pooling on N dimensional data
-*
-* \param x The input tensor
-* \param output_size int vector of size in each dimension
-* \param pool_type The type of pooling operator
-* \param axes indices of each dimension
-*
-* \return The output tensor in same layout order
-*/
-inline Tensor adaptive_pool_impl(const Tensor& x,
- const Array<PrimExpr>& output_size,
- PoolType pool_type,
- const std::vector<int>& axes) {
+ * \brief Perform adaptive pooling on N dimensional data
+ *
+ * \param x The input tensor
+ * \param output_size int vector of size in each dimension
+ * \param pool_type The type of pooling operator
+ * \param axes indices of each dimension
+ *
+ * \return The output tensor in same layout order
+ */
+inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
+ PoolType pool_type, const std::vector<int>& axes) {
const auto n_dim = output_size.size();
CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";
};
if (pool_type == kMaxPool) {
- return tvm::te::compute(out_shape, [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- Array<tir::IterVar> reduce_axes;
- std::tie(indices, reduce_axes) = get_iter_vars(output, true);
- return tvm::max(x(indices), reduce_axes); // NOLINT(*)
- }, "tensor", "adaptive_pool_max");
+ return tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ Array<tir::IterVar> reduce_axes;
+ std::tie(indices, reduce_axes) = get_iter_vars(output, true);
+ return tvm::max(x(indices), reduce_axes); // NOLINT(*)
+ },
+ "tensor", "adaptive_pool_max");
} else if (pool_type == kAvgPool) {
- auto pool_sum = tvm::te::compute(out_shape, [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- Array<tir::IterVar> reduce_axes;
- std::tie(indices, reduce_axes) = get_iter_vars(output, true);
- return tvm::sum(x(indices), reduce_axes);
- }, "tensor", "adaptive_pool_sum");
-
- return tvm::te::compute(out_shape, [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- Array<tir::IterVar> reduce_axes;
- std::tie(indices, reduce_axes) = get_iter_vars(output, false);
-
- PrimExpr divide_factor = tvm::cast(x->dtype, 1);
- for (size_t i = 0; i < n_dim; ++i) {
- divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
- }
+ auto pool_sum = tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ Array<tir::IterVar> reduce_axes;
+ std::tie(indices, reduce_axes) = get_iter_vars(output, true);
+ return tvm::sum(x(indices), reduce_axes);
+ },
+ "tensor", "adaptive_pool_sum");
- return div(pool_sum(indices), divide_factor);
- }, "tensor", kElementWise);
+ return tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ Array<tir::IterVar> reduce_axes;
+ std::tie(indices, reduce_axes) = get_iter_vars(output, false);
+
+ PrimExpr divide_factor = tvm::cast(x->dtype, 1);
+ for (size_t i = 0; i < n_dim; ++i) {
+ divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
+ }
+
+ return div(pool_sum(indices), divide_factor);
+ },
+ "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
/*!
-* \brief Adaptively perform pooling on height and width dimension of data.
-* The pooling kernel and stride sizes are automatically chosen for desired output sizes.
-* It decides the height and width dimension according to the layout string,
-* in which 'W' and 'H' means width and height respectively.
-* Width and height dimension cannot be split.
-* For example, NCHW, NCHW16c, etc. are valid for pool,
-* while NCHW16w, NCHW16h are not.
-* See \a layout for more information of the layout string convention.
-*
-* \param x The input tensor
-* \param output_size Vector of two ints: {output_height, output_width}
-* \param pool_type The type of pooling operator
-* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
-* where upper case indicates a dimension and
-* the corresponding lower case (with factor size) indicates the split dimension.
-* For example, NCHW16c can describe a 5-D tensor of
-* [batch_size, channel, height, width, channel_block].
-* (in which factor size `16` will not be used in pooling but for other operators,
-* it can be used to decide the output shape).
-* Since pooling does not care about the factor size of dimensions
-* other than `H` and `W`, one can pass `NCHWc` as well.
-*
-* \return The output tensor in same layout order
-*/
-inline Tensor adaptive_pool(const Tensor& x,
- const Array<PrimExpr>& output_size,
- PoolType pool_type,
+ * \brief Adaptively perform pooling on height and width dimension of data.
+ * The pooling kernel and stride sizes are automatically chosen for desired output sizes.
+ * It decides the height and width dimension according to the layout string,
+ * in which 'W' and 'H' means width and height respectively.
+ * Width and height dimension cannot be split.
+ * For example, NCHW, NCHW16c, etc. are valid for pool,
+ * while NCHW16w, NCHW16h are not.
+ * See \a layout for more information of the layout string convention.
+ *
+ * \param x The input tensor
+ * \param output_size Vector of two ints: {output_height, output_width}
+ * \param pool_type The type of pooling operator
+ * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
+ * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ * where upper case indicates a dimension and
+ * the corresponding lower case (with factor size) indicates the split dimension.
+ * For example, NCHW16c can describe a 5-D tensor of
+ * [batch_size, channel, height, width, channel_block].
+ * (in which factor size `16` will not be used in pooling but for other operators,
+ * it can be used to decide the output shape).
+ * Since pooling does not care about the factor size of dimensions
+ * other than `H` and `W`, one can pass `NCHWc` as well.
+ *
+ * \return The output tensor in same layout order
+ */
+inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
const std::string& layout = "NCHW") {
int height_axis = -1, width_axis = -1;
- CHECK(find_height_width(layout, &height_axis, &width_axis))
- << "Unsupported layout " << layout;
+ CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
}
/*!
-* \brief Adaptively perform pooling on three dimensional data.
-* See the two dimensional version above for details.
-* \param x The input tensor
-* \param output_size Vector of three ints: {output_depth, output_height, output_width}
-* \param pool_type The type of pooling operator
-* \param layout The input layout. The default is "NCDHW".
-*/
-inline Tensor adaptive_pool3d(const Tensor& x,
- const Array<PrimExpr>& output_size,
- PoolType pool_type,
- const std::string& layout = "NCDHW") {
+ * \brief Adaptively perform pooling on three dimensional data.
+ * See the two dimensional version above for details.
+ * \param x The input tensor
+ * \param output_size Vector of three ints: {output_depth, output_height, output_width}
+ * \param pool_type The type of pooling operator
+ * \param layout The input layout. The default is "NCDHW".
+ */
+inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
+ PoolType pool_type, const std::string& layout = "NCDHW") {
int depth_axis = -1, height_axis = -1, width_axis = -1;
CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
- << "Unsupported layout " << layout;
+ << "Unsupported layout " << layout;
return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
}
/*!
-* \brief Perform global pooling on height and width dimension of data.
-* It decides the height and width dimension according to the layout string,
-* in which 'W' and 'H' means width and height respectively.
-* Width and height dimension cannot be split.
-* For example, NCHW, NCHW16c, ... are valid for global_pool,
-* while NCHW16w, NCHW16h are not.
-* See \a layout for more information of the layout string convention.
-*
-* \param x The input tensor represent as layout
-* \param pool_type The type of pooling operator
-* \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
-* where upper case indicates a dimension and
-* the corresponding lower case (with factor size) indicates the sub-dimension.
-* For example, `NCHW16c` can describe a 5-D tensor of
-* [batch_size, channel, height, width, channel_block].
-* (in which factor size `16` will not be used in pooling but for other operators,
-* it can be used to decide the output shape).
-* Since pooling does not care about the factor size of
-* dimensions other than `H` and `W`, one can pass `NCHWc` as well.
-*
-* \return The output tensor in same layout with height and width dimension size of 1.
-* e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
-*/
-inline Tensor global_pool(const Tensor& x,
- PoolType pool_type,
- const std::string& layout = "NCHW") {
+ * \brief Perform global pooling on height and width dimension of data.
+ * It decides the height and width dimension according to the layout string,
+ * in which 'W' and 'H' means width and height respectively.
+ * Width and height dimension cannot be split.
+ * For example, NCHW, NCHW16c, ... are valid for global_pool,
+ * while NCHW16w, NCHW16h are not.
+ * See \a layout for more information of the layout string convention.
+ *
+ * \param x The input tensor represent as layout
+ * \param pool_type The type of pooling operator
+ * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
+ * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ * where upper case indicates a dimension and
+ * the corresponding lower case (with factor size) indicates the sub-dimension.
+ * For example, `NCHW16c` can describe a 5-D tensor of
+ * [batch_size, channel, height, width, channel_block].
+ * (in which factor size `16` will not be used in pooling but for other operators,
+ * it can be used to decide the output shape).
+ * Since pooling does not care about the factor size of
+ * dimensions other than `H` and `W`, one can pass `NCHWc` as well.
+ *
+ * \return The output tensor in same layout with height and width dimension size of 1.
+ * e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
+ */
+inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
}
/*!
-* \brief Perform pooling on N-dimension of data.
-*
-* \param x The input tensor
-* \param kernel_size Vector of N ints
-* \param stride_size Vector of N ints
-* \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
-* head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
-* \param pool_type The type of pooling operator
-* \param ceil_mode Whether to use ceil when calculating the output size
-* \param axis Vector of indices for the N dimensions
-* \param count_include_pad Whether include padding in the calculation
-*
-* \return The output tensor in same layout order
-*/
-inline Tensor pool_impl_nd(const Tensor& x,
- const Array<PrimExpr>& kernel_size,
- const Array<PrimExpr>& stride_size,
- const Array<PrimExpr>& padding_size,
- PoolType pool_type,
- bool ceil_mode,
- const std::vector<int>& axis,
+ * \brief Perform pooling on N-dimension of data.
+ *
+ * \param x The input tensor
+ * \param kernel_size Vector of N ints
+ * \param stride_size Vector of N ints
+ * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
+ * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
+ * \param pool_type The type of pooling operator
+ * \param ceil_mode Whether to use ceil when calculating the output size
+ * \param axis Vector of indices for the N dimensions
+ * \param count_include_pad Whether include padding in the calculation
+ *
+ * \return The output tensor in same layout order
+ */
+inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
+ PoolType pool_type, bool ceil_mode, const std::vector<int>& axis,
bool count_include_pad) {
int k_size = kernel_size.size();
int x_size = x->shape.size();
CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel";
CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of"
- " kernel";
+ " kernel";
CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel";
Array<IterVar> daxis;
stride[i] = cast(DataType::Int(32), stride_size[i]);
pad_head[i] = cast(DataType::Int(32), padding_size[i]);
pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
- const int64_t *padding0 = as_const_int(pad_head[i]);
- const int64_t *padding1 = as_const_int(pad_tail[i]);
+ const int64_t* padding0 = as_const_int(pad_head[i]);
+ const int64_t* padding1 = as_const_int(pad_tail[i]);
do_pad = (do_pad) ? do_pad : ((padding0 && *padding0) || (padding1 && *padding1));
if (ceil_mode) {
arith::Analyzer analyzer;
auto out_dim = analyzer.Simplify(
- indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
+ indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
out_shape.Set(ii, out_dim);
}
if (pool_type == kMaxPool) {
- auto temp = do_pad ? pad(
- x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
- return tvm::te::compute(out_shape, [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- for (const Var& var : output) indices.push_back(var);
-
- for (int i = 0; i < k_size; i++) {
- int ii = axis[i];
- indices.Set(ii, output[ii] * stride[i] + daxis[i]);
- }
+ auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
+ return tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ for (const Var& var : output) indices.push_back(var);
+
+ for (int i = 0; i < k_size; i++) {
+ int ii = axis[i];
+ indices.Set(ii, output[ii] * stride[i] + daxis[i]);
+ }
- return tvm::max(temp(indices), daxis);
- }, "tensor", "pool_max");
+ return tvm::max(temp(indices), daxis);
+ },
+ "tensor", "pool_max");
} else if (pool_type == kAvgPool) {
// Pad the inputs
auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
// TVM compute for summing the pooling window.
- auto pool_sum = tvm::te::compute(out_shape,
- [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- for (const Var& var : output) indices.push_back(var);
-
- for (int i = 0; i < k_size; i++) {
- int ii = axis[i];
- indices.Set(ii, output[ii] * stride[i] + daxis[i]);
- }
- return tvm::sum(temp(indices), daxis);
- }, "tensor", "pool_sum");
+ auto pool_sum = tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ for (const Var& var : output) indices.push_back(var);
+
+ for (int i = 0; i < k_size; i++) {
+ int ii = axis[i];
+ indices.Set(ii, output[ii] * stride[i] + daxis[i]);
+ }
+ return tvm::sum(temp(indices), daxis);
+ },
+ "tensor", "pool_sum");
// TVM compute for dividing the reduced window sum by kernel size.
- return tvm::te::compute(out_shape,
- [&](const Array<Var>& output) {
- Array<PrimExpr> indices;
- for (const Var& var : output) indices.push_back(var);
- if (count_include_pad) {
- auto kernel_size = make_const(DataType::Int(32), 1);
- for (int i = 0; i < k_size; i++) {
- kernel_size *= kernel[i];
- }
- return div(pool_sum(indices), kernel_size);
- } else {
- std::vector<PrimExpr> start(k_size);
- std::vector<PrimExpr> end(k_size);
- auto kernel_size = make_const(DataType::Int(32), 1);
- for (int i = 0; i < k_size; i++) {
- int ii = axis[i];
- start[i] = output[ii] * stride[i] - pad_head[i];
- end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
- start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
- kernel_size *= (end[i] - start[i]);
- }
-
- PrimExpr divide_factor = tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
- return div(pool_sum(indices), divide_factor);
- }
- }, "tensor", kElementWise);
+ return tvm::te::compute(
+ out_shape,
+ [&](const Array<Var>& output) {
+ Array<PrimExpr> indices;
+ for (const Var& var : output) indices.push_back(var);
+ if (count_include_pad) {
+ auto kernel_size = make_const(DataType::Int(32), 1);
+ for (int i = 0; i < k_size; i++) {
+ kernel_size *= kernel[i];
+ }
+ return div(pool_sum(indices), kernel_size);
+ } else {
+ std::vector<PrimExpr> start(k_size);
+ std::vector<PrimExpr> end(k_size);
+ auto kernel_size = make_const(DataType::Int(32), 1);
+ for (int i = 0; i < k_size; i++) {
+ int ii = axis[i];
+ start[i] = output[ii] * stride[i] - pad_head[i];
+ end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
+ start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
+ kernel_size *= (end[i] - start[i]);
+ }
+
+ PrimExpr divide_factor =
+ tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
+ return div(pool_sum(indices), divide_factor);
+ }
+ },
+ "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
/*!
-* \brief Perform pooling on the width dimension of data.
-* Width axis is determined by the layout string
-* in which 'W' means width.
-* Width dimension cannot be split.
-* For example, NCW, NCW16c, etc. are valid for pool,
-* while NCW16w is not.
-* See \a layout for more information of the layout string convention.
-* \param x The input tensor.
-* \param kernel_size Vector of three ints: {kernel_width}
-* \param stride_size Vector of three ints: {stride_width}
-* \param padding_size Vector of six ints: {head_pad_width, tail_pad_width}
-* \param pool_type The type of pooling operator
-* \param ceil_mode Whether to use ceil when calculating the output size
-* \param layout The input layout. Pooling supports any layout as long as 'W' appears.
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
-* where upper case indicates a dimension and
-* the corresponding lower case (with factor size) indicates the split dimension.
-* For example, NCW16c can describe a 4-D tensor of
-* [batch_size, channel, width, channel_block].
-* (in which factor size `16` will not be used in pooling but for other operators,
-* it can be used to decide the output shape).
-* Since pooling does not care about the factor size of dimensions
-* other than `W`, one can pass `NCWc` as well.
-* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
-*
-*
-* \return The output tensor in the same layout
-*/
-inline Tensor pool1d(const Tensor& x,
- const Array<PrimExpr>& kernel_size,
- const Array<PrimExpr>& stride_size,
- const Array<PrimExpr>& padding_size,
- PoolType pool_type,
- bool ceil_mode,
- const std::string& layout = "NCW",
+ * \brief Perform pooling on the width dimension of data.
+ * Width axis is determined by the layout string
+ * in which 'W' means width.
+ * Width dimension cannot be split.
+ * For example, NCW, NCW16c, etc. are valid for pool,
+ * while NCW16w is not.
+ * See \a layout for more information of the layout string convention.
+ * \param x The input tensor.
+ * \param kernel_size Vector of three ints: {kernel_width}
+ * \param stride_size Vector of three ints: {stride_width}
+ * \param padding_size Vector of six ints: {head_pad_width, tail_pad_width}
+ * \param pool_type The type of pooling operator
+ * \param ceil_mode Whether to use ceil when calculating the output size
+ * \param layout The input layout. Pooling supports any layout as long as 'W' appears.
+ * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ * where upper case indicates a dimension and
+ * the corresponding lower case (with factor size) indicates the split dimension.
+ * For example, NCW16c can describe a 4-D tensor of
+ * [batch_size, channel, width, channel_block].
+ * (in which factor size `16` will not be used in pooling but for other operators,
+ * it can be used to decide the output shape).
+ * Since pooling does not care about the factor size of dimensions
+ * other than `W`, one can pass `NCWc` as well.
+ * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
+ *
+ *
+ * \return The output tensor in the same layout
+ */
+inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
+ PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW",
bool count_include_pad = true) {
int width_axis = -1;
- CHECK(find_width(layout, &width_axis))
- << "Unsupported layout " << layout;
+ CHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
std::vector<int> axis = {width_axis};
- return pool_impl_nd(x, kernel_size, stride_size, padding_size,
- pool_type, ceil_mode, axis, count_include_pad);
+ return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis,
+ count_include_pad);
}
/*!
-* \brief Perform pooling on depth, height and width dimension of data.
-* It decides the depth, height and width dimension according to the layout string,
-* in which 'D', 'W' and 'H' means depth, width and height respectively.
-* Depth, Width and height dimension cannot be split.
-* For example, NCDHW, NCDHW16c, etc. are valid for pool,
-* while NCDHW16d, NCDHW16w or NCDHW16h are not.
-* See \a layout for more information of the layout string convention.
-* \param x The input tensor.
-* \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
-* \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
-* \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
-* tail_pad_depth, tail_pad_height, tail_pad_width}
-* \param pool_type The type of pooling operator
-* \param ceil_mode Whether to use ceil when calculating the output size
-* \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
-* where upper case indicates a dimension and
-* the corresponding lower case (with factor size) indicates the split dimension.
-* For example, NCDHW16c can describe a 6-D tensor of
-* [batch_size, channel, depth, height, width, channel_block].
-* (in which factor size `16` will not be used in pooling but for other operators,
-* it can be used to decide the output shape).
-* Since pooling does not care about the factor size of dimensions
-* other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
-* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
-*
-*
-* \return The output tensor in the same layout
-*/
-inline Tensor pool3d(const Tensor& x,
- const Array<PrimExpr>& kernel_size,
- const Array<PrimExpr>& stride_size,
- const Array<PrimExpr>& padding_size,
- PoolType pool_type,
- bool ceil_mode,
- const std::string& layout = "NCDHW",
+ * \brief Perform pooling on depth, height and width dimension of data.
+ * It decides the depth, height and width dimension according to the layout string,
+ * in which 'D', 'W' and 'H' means depth, width and height respectively.
+ * Depth, width and height dimensions cannot be split.
+ * For example, NCDHW, NCDHW16c, etc. are valid for pool,
+ * while NCDHW16d, NCDHW16w or NCDHW16h are not.
+ * See \a layout for more information of the layout string convention.
+ * \param x The input tensor.
+ * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
+ * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
+ * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
+ * tail_pad_depth, tail_pad_height, tail_pad_width}
+ * \param pool_type The type of pooling operator
+ * \param ceil_mode Whether to use ceil when calculating the output size
+ * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
+ * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ * where upper case indicates a dimension and
+ * the corresponding lower case (with factor size) indicates the split dimension.
+ * For example, NCDHW16c can describe a 6-D tensor of
+ * [batch_size, channel, depth, height, width, channel_block].
+ * (in which factor size `16` will not be used in pooling but for other operators,
+ * it can be used to decide the output shape).
+ * Since pooling does not care about the factor size of dimensions
+ * other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
+ * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
+ *
+ *
+ * \return The output tensor in the same layout
+ */
+inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
+ PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW",
bool count_include_pad = true) {
int depth_axis = -1, height_axis = -1, width_axis = -1;
CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
- << "Unsupported layout " << layout;
+ << "Unsupported layout " << layout;
std::vector<int> axis = {depth_axis, height_axis, width_axis};
- return pool_impl_nd(x, kernel_size, stride_size, padding_size,
- pool_type, ceil_mode, axis, count_include_pad);
+ return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis,
+ count_include_pad);
}
} // namespace nn
#ifndef TOPI_NN_SOFTMAX_H_
#define TOPI_NN_SOFTMAX_H_
-#include <tvm/te/operation.h>
#include <topi/reduction.h>
#include <topi/tags.h>
+#include <tvm/te/operation.h>
#include <algorithm>
#include <string>
using namespace tvm::te;
/*!
-* \brief Softmax activation
-*
-* \param x The input tensor. Can be any dimension
-* \param axis The channel axis along which softmax is performed
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the softmax operation
-*/
-inline Tensor softmax(const Tensor &x,
- int axis = -1,
- std::string name = "tensor",
+ * \brief Softmax activation
+ *
+ * \param x The input tensor. Can be any dimension
+ * \param axis The channel axis along which softmax is performed
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the softmax operation
+ */
+inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor",
std::string tag = "softmax_output") {
auto input_shape = x->shape;
auto ndim = input_shape.size();
tvm::Map<std::string, ObjectRef> attrs;
attrs.Set("axis", Integer(axis));
- auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
- const IterVar &reduce_index) {
+ auto insert_reduce_index = [axis, ndim](const Array<Var>& indices, const IterVar& reduce_index) {
Array<PrimExpr> eval_range;
int arg_counter = 0;
for (size_t i = 0; i < ndim; ++i) {
return eval_range;
};
- auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
+ auto get_non_reduce_indices = [axis, ndim](const Array<Var>& indices) {
Array<PrimExpr> non_reduce_indices;
for (size_t i = 0; i < ndim; ++i) {
- if (static_cast<int>(i) != axis)
- non_reduce_indices.push_back(indices[i]);
+ if (static_cast<int>(i) != axis) non_reduce_indices.push_back(indices[i]);
}
return non_reduce_indices;
};
- auto _compute_max = [&](const Array<Var> &indices) {
+ auto _compute_max = [&](const Array<Var>& indices) {
auto eval_range = insert_reduce_index(indices, k1);
return topi::MaxOp(x(eval_range), {k1});
};
- auto _compute_exp = [&](const Tensor &max_elem,
- const Array<Var> &indices) {
+ auto _compute_exp = [&](const Tensor& max_elem, const Array<Var>& indices) {
auto non_reduce_indices = get_non_reduce_indices(indices);
return tvm::exp(x(indices) - max_elem(non_reduce_indices));
};
- auto _compute_expsum = [&](const Tensor &exp,
- const Array<Var> &indices) {
+ auto _compute_expsum = [&](const Tensor& exp, const Array<Var>& indices) {
auto eval_range = insert_reduce_index(indices, k2);
return tvm::sum(exp(eval_range), {k2});
};
- auto _normalize = [&](const Tensor &exp, const Tensor &expsum,
- const Array<Var> &indices) {
+ auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array<Var>& indices) {
auto non_reduce_indices = get_non_reduce_indices(indices);
return exp(indices) / expsum(non_reduce_indices);
};
auto max_elem = tvm::te::compute(reduced_shape, _compute_max);
- auto exp = tvm::te::compute(input_shape, [&](const Array<Var> &indices) {
- return _compute_exp(max_elem, indices);
- });
- auto expsum = tvm::te::compute(reduced_shape, [&](const Array<Var> &indices) {
- return _compute_expsum(exp, indices);
- });
- return tvm::te::compute(input_shape, [&](const Array<Var> &indices) {
- return _normalize(exp, expsum, indices);
- }, name, tag, attrs);
+ auto exp = tvm::te::compute(
+ input_shape, [&](const Array<Var>& indices) { return _compute_exp(max_elem, indices); });
+ auto expsum = tvm::te::compute(
+ reduced_shape, [&](const Array<Var>& indices) { return _compute_expsum(exp, indices); });
+ return tvm::te::compute(
+ input_shape, [&](const Array<Var>& indices) { return _normalize(exp, expsum, indices); },
+ name, tag, attrs);
}
/*!
-* \brief Log softmax activation
-*
-* \param x The input tensor. 2-D where log softmax is performed along the second dimension
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the log softmax operation
-*/
-inline Tensor log_softmax(const Tensor& x,
- std::string name = "tensor",
+ * \brief Log softmax activation
+ *
+ * \param x The input tensor. 2-D where log softmax is performed along the second dimension
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the log softmax operation
+ */
+inline Tensor log_softmax(const Tensor& x, std::string name = "tensor",
std::string tag = "log_softmax_output") {
CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";
PrimExpr n = x->shape[1];
auto k = tvm::te::reduce_axis(Range(0, n), "k");
- auto max_elem = tvm::te::compute(
- { m }, [&](Var i) {
- return tvm::max(x(i, k), Array<IterVar>{ k }); });
+ auto max_elem =
+ tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array<IterVar>{k}); });
k = tvm::te::reduce_axis(Range(0, n), "k");
- auto expsum = tvm::te::compute(
- { m }, [&](Var i) {
- return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); });
+ auto expsum =
+ tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); });
return tvm::te::compute(
- x->shape, [&](Var i, Var j) {
- return x(i, j) - max_elem(i) - tvm::log(expsum(i));
- }, name, tag);
+ x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name,
+ tag);
}
} // namespace nn
#ifndef TOPI_REDUCTION_H_
#define TOPI_REDUCTION_H_
-#include <tvm/te/operation.h>
#include <topi/broadcast.h>
+#include <topi/detail/constant_utils.h>
+#include <topi/detail/ravel_unravel.h>
#include <topi/elemwise.h>
#include <topi/tags.h>
#include <topi/transform.h>
-#include <topi/detail/ravel_unravel.h>
-#include <topi/detail/constant_utils.h>
+#include <tvm/te/operation.h>
#include <algorithm>
+#include <iterator>
#include <string>
#include <vector>
-#include <iterator>
namespace topi {
using namespace tvm;
using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
/*! \brief The operation to use for CommReduceIdx */
-using FCommReduce = std::function<
- Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis, PrimExpr* condition)>;
+using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis,
+ PrimExpr* condition)>;
/*!
-* \brief Convert a reduction axis which could be empty or have negative
-* elements into a real axis with valid dimension indices.
-*
-* \param ndim Number of dimensions in the target.
-* \param axis The axis parameter.
-*
-* \return A non-empty sorted array of valid dimension indices, with no duplicates.
-* If the input axis is empty, the result will be an axis including all dimensions.
-* If any input element is negative, it will be treated as an offset from the
-* last dimension (same as python indexing rules).
-*/
+ * \brief Convert a reduction axis which could be empty or have negative
+ * elements into a real axis with valid dimension indices.
+ *
+ * \param ndim Number of dimensions in the target.
+ * \param axis The axis parameter.
+ *
+ * \return A non-empty sorted array of valid dimension indices, with no duplicates.
+ * If the input axis is empty, the result will be an axis including all dimensions.
+ * If any input element is negative, it will be treated as an offset from the
+ * last dimension (same as python indexing rules).
+ */
inline std::vector<int> GetRealAxis(int ndim, const Array<Integer>& axis) {
std::vector<int> real_axis;
if (!axis.defined() || axis.size() == 0) {
real_axis.push_back(static_cast<int>(val));
}
std::sort(real_axis.begin(), real_axis.end());
- real_axis.resize(
- std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin());
+ real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin());
}
return real_axis;
}
Array<IterVar> reduce_axes;
for (auto i : real_axis) {
std::string name = "k" + std::to_string(i);
- reduce_axes.push_back(
- tvm::te::reduce_axis(Range(0, data->shape[i]), name));
+ reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name));
}
return reduce_axes;
}
/*! \brief Calculate the target shape for a reduce op */
-inline Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis,
- const Tensor& data,
- bool keepdims,
- bool atleast1d) {
+inline Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis, const Tensor& data,
+ bool keepdims, bool atleast1d) {
auto ndim = data->shape.size();
Array<PrimExpr> target_shape;
if (keepdims) {
*
* \return The result tensor.
*/
-inline Tensor DoCommReduce(const Tensor& data,
- FReduce func,
- const Array<PrimExpr>& target_shape,
+inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array<PrimExpr>& target_shape,
const std::vector<int>& reduce_axes,
const std::vector<int>& squeeze_axes) {
auto r_axes = MakeReduceAxes(reduce_axes, data);
*
* \return The result tensor.
*/
-inline Tensor CommReduce(const Tensor& data,
- const Array<Integer>& axis,
- FReduce func,
- bool keepdims,
- bool atleast1d) {
+inline Tensor CommReduce(const Tensor& data, const Array<Integer>& axis, FReduce func,
+ bool keepdims, bool atleast1d) {
auto ndim = data->shape.size();
CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
return DoCommReduce(data, func, target_shape, real_axis,
- keepdims ? std::vector<int>() : real_axis);
+ keepdims ? std::vector<int>() : real_axis);
}
/*!
-* \brief Create an index reduction operation.
-*
-* \param data The input tensor.
-* \param axis The axes along which the reduction is performed.
-* \param func The reduction function
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return The result tensor.
-*/
-inline Tensor CommReduceIdx(const Tensor& data,
- const Array<Integer>& axis,
- FCommReduce func,
- bool keepdims,
- bool atleast1d) {
+ * \brief Create an index reduction operation.
+ *
+ * \param data The input tensor.
+ * \param axis The axes along which the reduction is performed.
+ * \param func The reduction function
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return The result tensor.
+ */
+inline Tensor CommReduceIdx(const Tensor& data, const Array<Integer>& axis, FCommReduce func,
+ bool keepdims, bool atleast1d) {
auto ndim = data->shape.size();
CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_axes = MakeReduceAxes(real_axis, data);
auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
- auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
- (const Array<Var>& indices) {
+ auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
+ &data](const Array<Var>& indices) {
Array<PrimExpr> eval_range;
Array<PrimExpr> eval_indices;
int arg_counter = 0;
ravel_shape.push_back(data->shape[i]);
}
auto idx = detail::RavelIndex(eval_indices, ravel_shape);
- return func({ idx, data(eval_range) }, reduce_axes, nullptr);
+ return func({idx, data(eval_range)}, reduce_axes, nullptr);
};
- auto temp_idx_val = tvm::te::compute(target_shape, compute,
- data->op->name + "_red_temp", kCommReduceIdx);
+ auto temp_idx_val =
+ tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
auto temp_idx = temp_idx_val[0];
auto temp_val = temp_idx_val[1];
return tvm::te::compute(
- target_shape,
- [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
- data->op->name + "_red",
- kCommReduceIdx);
+ target_shape, [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
+ data->op->name + "_red", kCommReduceIdx);
}
/*! \brief A combiner function for a reduction */
*
* \return A reducer function which creates a reduce expression over an axis.
*/
-inline FCommReduce MakeCommReducer(FCombine fcombine,
- FIdentity fidentity,
+inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
std::string name = "reduce") {
- return [fcombine, fidentity, name]
- (Array<PrimExpr> exprs, const Array<IterVar>& axis, PrimExpr* condition) {
+ return [fcombine, fidentity, name](Array<PrimExpr> exprs, const Array<IterVar>& axis,
+ PrimExpr* condition) {
Array<Var> lhs, rhs;
std::vector<DataType> dtypes;
Array<PrimExpr> outputs;
for (size_t i = 0; i < exprs.size(); ++i) {
outputs.push_back(
- tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast<int>(i)));
+ tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast<int>(i)));
}
return outputs;
};
}
/*! \brief Wrap tvm::min to ensure we get the correct overload */
-inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis) {
- return tvm::min(source, axis);
-}
+inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis) { return tvm::min(source, axis); }
/*! \brief Wrap tvm::max to ensure we get the correct overload */
inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis) {
}
/*!
-* \brief Creates an operation that sums array elements over a given axis
-*
-* \param data The input tensor
-* \param axis The axis to sum over. If axis is empty, the operation will
-* sum over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the sum operation
-*/
-inline Tensor sum(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates an operation that sums array elements over a given axis
+ *
+ * \param data The input tensor
+ * \param axis The axis to sum over. If axis is empty, the operation will
+ * sum over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the sum operation
+ */
+inline Tensor sum(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
}
std::vector<int> reduce_axes;
std::vector<int> squeeze_axes;
- for (int i_ax = ishape.size() - 1,
- o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
+ for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
--o_ax;
continue;
}
/*!
-* \brief Creates an operation that computes the logical AND of elements
-* over a given axis
-*
-* \param data The input boolean tensor
-* \param axis The axes to reduce. If axis is empty, the operation will
-* perform logical AND over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the all operation
-*/
-inline Tensor all(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates an operation that computes the logical AND of elements
+ * over a given axis
+ *
+ * \param data The input boolean tensor
+ * \param axis The axes to reduce. If axis is empty, the operation will
+ * perform logical AND over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the all operation
+ */
+inline Tensor all(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
}
/*!
-* \brief Creates an operation that computes the logical OR of elements
-* over a given axis
-*
-* \param data The input boolean tensor
-* \param axis The axes to reduce. If axis is empty, the operation will
-* perform logical OR over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the all operation
-*/
-inline Tensor any(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates an operation that computes the logical OR of elements
+ * over a given axis
+ *
+ * \param data The input boolean tensor
+ * \param axis The axes to reduce. If axis is empty, the operation will
+ * perform logical OR over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the all operation
+ */
+inline Tensor any(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
}
/*!
-* \brief Creates an operation that finds the minimum of elements over
-* a given axis.
-*
-* \param data The input tensor
-* \param axis The axis to find the minimum over. If axis is empty, the
-* operation will find the minimum over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the min operation
-*/
-inline Tensor min(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates an operation that finds the minimum of elements over
+ * a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis to find the minimum over. If axis is empty, the
+ * operation will find the minimum over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the min operation
+ */
+inline Tensor min(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, MinOp, keepdims, atleast1d);
}
/*!
-* \brief Creates an operation that finds the maximum of elements over
-* a given axis.
-*
-* \param data The input tensor
-* \param axis The axis to find the maximum over. If axis is empty, the
-* operation will find the maximum over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the max operation
-*/
-inline Tensor max(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates an operation that finds the maximum of elements over
+ * a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis to find the maximum over. If axis is empty, the
+ * operation will find the maximum over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the max operation
+ */
+inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
}
/*!
-* \brief Creates an operation that finds the indices of the minimum
-* values over a given axis.
-*
-* \param data The input tensor
-* \param axis The axis along which the argmin is performed. If axis is empty,
-* the operation will find the minimum index over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the argmin operation
-*/
-inline Tensor argmin(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates an operation that finds the indices of the minimum
+ * values over a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis along which the argmin is performed. If axis is empty,
+ * the operation will find the minimum index over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the argmin operation
+ */
+inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;
auto fidentity = [](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
- result.push_back(tvm::max_value(types[1])); // val
+ result.push_back(tvm::max_value(types[1])); // val
return result;
};
auto func = MakeCommReducer(fcombine, fidentity, "argmin");
auto fidentity = [](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
- result.push_back(tvm::min_value(types[1])); // val
+ result.push_back(tvm::min_value(types[1])); // val
return result;
};
return MakeCommReducer(fcombine, fidentity, "argmax");
}
/*!
-* \brief Creates an operation that finds the indices of the maximum
-* values over a given axis.
-*
-* \param data The input tensor
-* \param axis The axis along which the argmax is performed. If axis is empty,
-* the operation will find the maximum index over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the argmax operation
-*/
-inline Tensor argmax(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates an operation that finds the indices of the maximum
+ * values over a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis along which the argmax is performed. If axis is empty,
+ * the operation will find the maximum index over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the argmax operation
+ */
+inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
auto reducer = MakeArgmaxReducer();
return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}
/*!
-* \brief Creates product operation over given axis.
-*
-* \param data The input tensor
-* \param axis The axis to do product over. If axis is empty, the
-* operation will do the product over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the prod operation
-*/
-inline Tensor prod(const Tensor& data,
- const Array<Integer>& axis,
- bool keepdims = false,
+ * \brief Creates product operation over given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis to do product over. If axis is empty, the
+ * operation will do the product over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
+ *
+ * \return A Tensor whose op member is the prod operation
+ */
+inline Tensor prod(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
}
#ifndef TOPI_ROCM_DENSE_H_
#define TOPI_ROCM_DENSE_H_
-#include <tvm/te/operation.h>
-#include <tvm/target/generic_func.h>
#include <topi/tags.h>
-#include "topi/detail/array_utils.h"
-#include "topi/nn/dense.h"
+#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
+
#include "topi/contrib/rocblas.h"
-#include "topi/generic/extern.h"
#include "topi/cuda/dense.h"
+#include "topi/detail/array_utils.h"
+#include "topi/generic/extern.h"
+#include "topi/nn/dense.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
-* \brief Implementation of dense for rocm backend
-*
-* \param target The target device
-* \param data Tensor with shape [batch, in_dim]
-* \param weight Tensor with shape [out_dim, in_dim]
-* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
-* \param out_dtype Output data type. Used for mixed precision.
-*
-* \return Tensor with shape [batch, out_dim]
-*/
-inline tvm::te::Tensor dense_rocm(const Target& target,
- const tvm::te::Tensor& data,
- const tvm::te::Tensor& weight,
- const tvm::te::Tensor& bias,
- const DataType& out_dtype) {
+ * \brief Implementation of dense for rocm backend
+ *
+ * \param target The target device
+ * \param data Tensor with shape [batch, in_dim]
+ * \param weight Tensor with shape [out_dim, in_dim]
+ * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
+ * \param out_dtype Output data type. Used for mixed precision.
+ *
+ * \return Tensor with shape [batch, out_dim]
+ */
+inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& data,
+ const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
+ const DataType& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
if (bias.defined()) {
- mm = tvm::te::compute({ batch, out_dim },
- [&](Var i, Var j) {
- return mm(i, j) + bias(j);
- }, "tensor", kBroadcast);
+ mm = tvm::te::compute(
+ {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast);
}
return mm;
}
/*!
-* \brief Create a rocm schedule for dense
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
- if (target->target_name == "rocm" &&
- target->libs().count("rocblas")) {
+ * \brief Create a rocm schedule for dense
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
+ if (target->target_name == "rocm" && target->libs().count("rocblas")) {
return topi::generic::schedule_extern(target, outs);
}
#ifndef TOPI_ROCM_INJECTIVE_H_
#define TOPI_ROCM_INJECTIVE_H_
-#include <topi/tags.h>
#include <topi/detail/fuse.h>
-#include <tvm/te/operation.h>
+#include <topi/tags.h>
#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
#include "topi/cuda/injective.h"
*
* \return A schedule for the given ops.
*/
-inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
+inline Schedule schedule_injective(const Target& target, const Array<Tensor>& outs) {
return topi::cuda::schedule_injective(target, outs);
}
#ifndef TOPI_ROCM_NORMALIZATION_H_
#define TOPI_ROCM_NORMALIZATION_H_
-#include <tvm/te/operation.h>
-#include <tvm/target/generic_func.h>
#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
namespace topi {
using namespace tvm;
using namespace tvm::te;
namespace rocm {
/*!
-* \brief Create a rocm schedule for LRN
-* \param outs The output tensors.
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_lrn(const Array<Tensor>& outs) {
- return topi::cuda::schedule_lrn(outs);
-}
+ * \brief Create a rocm schedule for LRN
+ * \param outs The output tensors.
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_lrn(const Array<Tensor>& outs) { return topi::cuda::schedule_lrn(outs); }
} // namespace rocm
} // namespace topi
#ifndef TOPI_ROCM_POOLING_H_
#define TOPI_ROCM_POOLING_H_
-#include <tvm/te/operation.h>
-#include <tvm/target/generic_func.h>
-#include <topi/tags.h>
-#include <topi/detail/fuse.h>
-#include <topi/detail/array_utils.h>
#include <topi/cuda/pooling.h>
+#include <topi/detail/array_utils.h>
+#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
namespace topi {
using namespace tvm;
namespace rocm {
/*!
-* \brief Create a rocm schedule for pool
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a rocm schedule for pool
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) {
return topi::cuda::schedule_pool(target, outs);
}
/*!
-* \brief Create a rocm schedule for global_pool
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a rocm schedule for global_pool
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_global_pool(const Target& target, const Array<Tensor>& outs) {
return topi::cuda::schedule_global_pool(target, outs);
}
#ifndef TOPI_ROCM_REDUCTION_H_
#define TOPI_ROCM_REDUCTION_H_
-#include <topi/tags.h>
#include <topi/detail/fuse.h>
-#include <tvm/te/operation.h>
+#include <topi/tags.h>
#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
#include "topi/cuda/reduction.h"
namespace rocm {
/*!
-* \brief Create a rocm schedule for a reduce operation.
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
+ * \brief Create a rocm schedule for a reduce operation.
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
return topi::cuda::schedule_reduce(target, outs);
}
#ifndef TOPI_ROCM_SOFTMAX_H_
#define TOPI_ROCM_SOFTMAX_H_
-#include <topi/tags.h>
#include <topi/detail/fuse.h>
-#include <tvm/te/operation.h>
+#include <topi/tags.h>
#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
#include "topi/cuda/softmax.h"
*
* \return A schedule for the given ops.
*/
-inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs) {
+inline Schedule schedule_softmax(const Target& target, const Array<Tensor>& outs) {
return topi::cuda::schedule_softmax(target, outs);
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
constexpr auto kGroupConv2d = "group_conv2d";
inline bool is_broadcast(std::string tag) {
- return
- tag.rfind(kElementWise, 0) == 0 ||
- tag.rfind(kBroadcast, 0) == 0;
+ return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0;
}
inline bool is_injective(std::string tag) {
- return
- tag.rfind(kElementWise, 0) == 0 ||
- tag.rfind(kBroadcast, 0) == 0 ||
- tag.rfind(kInjective, 0) == 0;
+ return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0 ||
+ tag.rfind(kInjective, 0) == 0;
}
} // namespace topi
#ifndef TOPI_TRANSFORM_H_
#define TOPI_TRANSFORM_H_
-#include <tvm/tir/data_layout.h>
-#include <tvm/te/operation.h>
-#include <topi/tags.h>
-#include <topi/detail/ravel_unravel.h>
#include <topi/detail/constant_utils.h>
+#include <topi/detail/ravel_unravel.h>
#include <topi/detail/tensor_utils.h>
+#include <topi/tags.h>
+#include <tvm/te/operation.h>
+#include <tvm/tir/data_layout.h>
-#include <string>
-#include <vector>
-#include <iterator>
#include <algorithm>
+#include <iterator>
#include <limits>
+#include <string>
#include <unordered_set>
+#include <vector>
namespace topi {
using namespace tvm;
using namespace topi::detail;
/*!
-* \brief Creates an operation to insert new dimensions of length 1
-*
-* \param x The input tensor
-* \param axis The index of the first new dimension (allows negative
-* indices as offsets from the last dimension)
-* \param num_newaxis The number of new dimensions to insert
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the dim expansion operation
-*/
-inline Tensor expand_dims(const Tensor& x,
- int axis,
- int num_newaxis = 1,
- std::string name = "T_expand_dims",
- std::string tag = kBroadcast) {
+ * \brief Creates an operation to insert new dimensions of length 1
+ *
+ * \param x The input tensor
+ * \param axis The index of the first new dimension (allows negative
+ * indices as offsets from the last dimension)
+ * \param num_newaxis The number of new dimensions to insert
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the dim expansion operation
+ */
+inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
+ std::string name = "T_expand_dims", std::string tag = kBroadcast) {
int ndim = static_cast<int>(x->shape.size());
CHECK(-ndim - 1 <= axis && axis <= ndim)
- << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
- << ", but got axis = " << axis
- << ", and data.ndim = " << ndim;
- CHECK(num_newaxis >= 0)
- << "expand_dims only accepts `num_newaxis >= 0`"
- << ", but got num_newaxis = " << num_newaxis;
+ << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
+ << ", but got axis = " << axis << ", and data.ndim = " << ndim;
+ CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
+ << ", but got num_newaxis = " << num_newaxis;
if (axis < 0) {
// Calculate offset from last dimension
axis = ndim + axis + 1;
}
return compute(
- new_shape, [&](const Array<Var>& indices) {
- Array<PrimExpr> idx;
- for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
- idx.push_back(indices[i]);
- }
- for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
- idx.push_back(indices[i]);
- }
- return x(idx);
- }, name, tag);
+ new_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> idx;
+ for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
+ idx.push_back(indices[i]);
+ }
+ for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
+ idx.push_back(indices[i]);
+ }
+ return x(idx);
+ },
+ name, tag);
}
/*!
-* \brief Permute the dimensions of an array
-*
-* \param x The input tensor
-* \param axes The indices of the permutation. If this is empty,
-* the dimensions will be reversed.
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the transpose operation
-*/
-inline Tensor transpose(const Tensor& x,
- Array<Integer> axes,
- std::string name = "T_transpose",
+ * \brief Permute the dimensions of an array
+ *
+ * \param x The input tensor
+ * \param axes The indices of the permutation. If this is empty,
+ * the dimensions will be reversed.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the transpose operation
+ */
+inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose",
std::string tag = kInjective) {
if (!axes.defined() || axes.size() == 0) {
axes = Array<Integer>();
axes.Set(i, new_axis);
}
CHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
- << "axis=" << axis << " is invalid for the "
- << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
+ << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
+ << "-dimensional input tensor";
for (size_t j = 0; j < axes.size(); ++j) {
- if (i !=j) {
+ if (i != j) {
CHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
}
}
}
return compute(
- new_shape, [&](const Array<Var>& indices) {
- std::vector<PrimExpr> idx;
- for (size_t i = 0; i < axes.size(); ++i) {
- idx.push_back(1);
- }
- for (size_t i = 0; i < axes.size(); ++i) {
- int axis = static_cast<int>(axes[i]->value);
- idx[axis] = indices[i];
- }
- return x(idx);
- }, name, tag);
+ new_shape,
+ [&](const Array<Var>& indices) {
+ std::vector<PrimExpr> idx;
+ for (size_t i = 0; i < axes.size(); ++i) {
+ idx.push_back(1);
+ }
+ for (size_t i = 0; i < axes.size(); ++i) {
+ int axis = static_cast<int>(axes[i]->value);
+ idx[axis] = indices[i];
+ }
+ return x(idx);
+ },
+ name, tag);
}
/*!
-* \brief flip/reverse elements of an array in a particular axis
-*
-* \param x The input tensor
-* \param axis The axis along which the tensors will be reveresed
-* (allows negative indices)
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the reverse operation
-*/
-inline Tensor flip(const Tensor& x,
- int axis = 0,
- std::string name = "T_flip",
+ * \brief flip/reverse elements of an array in a particular axis
+ *
+ * \param x The input tensor
+ * \param axis The axis along which the tensors will be reversed
+ * (allows negative indices)
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the reverse operation
+ */
+inline Tensor flip(const Tensor& x, int axis = 0, std::string name = "T_flip",
std::string tag = kInjective) {
size_t src_tensor_dim = x->shape.size();
int axis_inp = axis;
}
CHECK((0 <= axis) && (axis < static_cast<int>(x->shape.size())))
- << "axis=" << axis_inp << " is invalid for the "
- << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
+ << "axis=" << axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
+ << "-dimensional input tensor";
// Reverse the Input Tensor in the axis specified
return compute(
- x->shape, [&](const Array<Var>& indices) {
- Array<PrimExpr> real_indices;
- for (size_t i = 0; i < src_tensor_dim; ++i) {
- if (i == static_cast<size_t>(axis)) {
- real_indices.push_back(x->shape[i] - indices[i] - 1);
- } else {
- real_indices.push_back(indices[i]);
+ x->shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> real_indices;
+ for (size_t i = 0; i < src_tensor_dim; ++i) {
+ if (i == static_cast<size_t>(axis)) {
+ real_indices.push_back(x->shape[i] - indices[i] - 1);
+ } else {
+ real_indices.push_back(indices[i]);
+ }
}
- }
- return x(real_indices);
- }, name, tag);
+ return x(real_indices);
+ },
+ name, tag);
}
/*!
-* \brief Reshape a tensor
-*
-* \param x The input tensor
-* \param newshape The new shape
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the reshape operation
-*/
-inline Tensor reshape(const Tensor& x,
- Array<PrimExpr> newshape,
- std::string name = "T_reshape",
+ * \brief Reshape a tensor
+ *
+ * \param x The input tensor
+ * \param newshape The new shape
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the reshape operation
+ */
+inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
Array<PrimExpr> target_shape;
- for (const auto &ele : newshape) {
+ for (const auto& ele : newshape) {
if (ele.as<IntImmNode>()) {
target_shape.push_back(cast(DataType::Int(32), ele));
} else {
}
if (is_empty_shape(target_shape)) {
- return compute(target_shape,
- [&](const Array<Var> &indices) { return tvm::cast(x->dtype, 0); },
- name, tag);
+ return compute(
+ target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
} else {
return compute(
- target_shape, [&](const Array<Var>& indices) {
- return x(UnravelIndex(
- RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape),
- x_shape));
- }, name, tag);
+ target_shape,
+ [&](const Array<Var>& indices) {
+ return x(UnravelIndex(
+ RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
+ },
+ name, tag);
}
}
* \return A Tensor of coordinate arrays.
*/
-inline Tensor unravel_index(const Tensor& x,
- const Tensor& shape,
- std::string name = "T_unravel",
+inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
std::string tag = kInjective) {
auto x_shape = x->shape;
auto shape_shape = shape->shape;
}
/*!
-* \brief Remove size 1 dimensions from the shape of a tensor.
-* The removed dimensions must have a constant size of 1.
-*
-* \param x The input tensor
-* \param axis Indices of the dimensions to remove. If this is empty,
-* all entries with a constant size of 1 will be removed.
+ * \brief Remove size 1 dimensions from the shape of a tensor.
+ * The removed dimensions must have a constant size of 1.
+ *
+ * \param x The input tensor
+ * \param axis Indices of the dimensions to remove. If this is empty,
+ * all entries with a constant size of 1 will be removed.
* \param atleast1d Whether the output need to be atleast1d.
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the squeeze operation
-*/
-inline Tensor squeeze(const Tensor& x,
- Array<Integer> axis,
- bool atleast1d = false,
- std::string name = "T_squeeze",
- std::string tag = kInjective) {
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the squeeze operation
+ */
+inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
+ std::string name = "T_squeeze", std::string tag = kInjective) {
auto ndim = x->shape.size();
std::vector<int> axis_val;
if (!axis.defined() || axis.size() == 0) {
if (val < 0) {
val += static_cast<int>(x->shape.size());
}
- CHECK_EQ(GetConstInt(x->shape[val]), 1) <<
- "Dimension " << val << " must have size 1";
+ CHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
axis_val.push_back(val);
}
}
}
return compute(
- out_shape, [&](const Array<Var>& indices) {
- Array<PrimExpr> real_indices;
- int flag = 0;
- for (size_t i = 0; i < ndim; ++i) {
- if (axis_set.count(static_cast<int>(i)) == 0) {
- real_indices.push_back(indices[i - flag]);
- } else {
- real_indices.push_back(0);
- flag += 1;
+ out_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> real_indices;
+ int flag = 0;
+ for (size_t i = 0; i < ndim; ++i) {
+ if (axis_set.count(static_cast<int>(i)) == 0) {
+ real_indices.push_back(indices[i - flag]);
+ } else {
+ real_indices.push_back(0);
+ flag += 1;
+ }
}
- }
- return x(real_indices);
- }, name, tag);
+ return x(real_indices);
+ },
+ name, tag);
}
/*!
-* \brief Join a sequence of tensors along an existing axis
-*
-* \param inputs The input tensors
-* \param axis The axis along which the tensors will be joined
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the concatenate operation
-*/
-inline Tensor concatenate(const Array<Tensor>& inputs,
- int axis = 0,
- std::string name = "T_concat",
+ * \brief Join a sequence of tensors along an existing axis
+ *
+ * \param inputs The input tensors
+ * \param axis The axis along which the tensors will be joined
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the concatenate operation
+ */
+inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat",
std::string tag = kInjective) {
int ndim = static_cast<int>(inputs[0]->shape.size());
- CHECK(-ndim <= axis && axis < ndim)
- << "concatenate only accepts `axis` in [-ndim, ndim)"
- << ", but got axis = " << axis
- << ", and ndim = " << ndim;
+ CHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
+ << ", but got axis = " << axis << ", and ndim = " << ndim;
if (axis < 0) {
axis += ndim;
}
- CHECK_LT(axis, inputs[0]->shape.size()) <<
- "axis out of bounds";
+ CHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";
Array<PrimExpr> axis_sizes;
for (auto t : inputs) {
}
return compute(
- out_shape, [&](const Array<Var>& indices) {
- auto ret = inputs[0](indices);
- auto ind = indices[axis];
- for (size_t i = 0; i < inputs.size() - 1; ++i) {
- ind -= axis_sizes[i];
+ out_shape,
+ [&](const Array<Var>& indices) {
+ auto ret = inputs[0](indices);
+ auto ind = indices[axis];
+ for (size_t i = 0; i < inputs.size() - 1; ++i) {
+ ind -= axis_sizes[i];
+
+ Array<PrimExpr> idx;
+ for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
+ idx.push_back(indices[i]);
+ }
+ idx.push_back(ind);
+ for (size_t i = axis + 1; i < indices.size(); ++i) {
+ idx.push_back(indices[i]);
+ }
- Array<PrimExpr> idx;
- for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
- idx.push_back(indices[i]);
+ ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
}
- idx.push_back(ind);
- for (size_t i = axis + 1; i < indices.size(); ++i) {
- idx.push_back(indices[i]);
- }
-
- ret = tvm::if_then_else(ind >= 0,
- inputs[i + 1](idx),
- ret);
- }
- return ret;
- }, name, tag);
+ return ret;
+ },
+ name, tag);
}
/*!
-* \brief Join a sequence of tensors along a new axis.
-*
-* \param inputs The input tensors
-* \param axis The axis along which the tensors will be stacked
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the stack operation
-*/
-inline Tensor stack(const Array<Tensor>& inputs,
- int axis = 0,
- std::string name = "T_stack",
+ * \brief Join a sequence of tensors along a new axis.
+ *
+ * \param inputs The input tensors
+ * \param axis The axis along which the tensors will be stacked
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the stack operation
+ */
+inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
std::string tag = kInjective) {
int ndim = static_cast<int>(inputs[0]->shape.size());
CHECK(-ndim - 1 <= axis && axis <= ndim)
- << "stack only accepts `axis` in [-ndim, ndim)"
- << ", but got axis = " << axis
- << ", and ndim = " << ndim;
+ << "stack only accepts `axis` in [-ndim, ndim)"
+ << ", but got axis = " << axis << ", and ndim = " << ndim;
if (axis < 0) {
axis += ndim + 1;
}
- CHECK_LT(axis, inputs[0]->shape.size() + 1) <<
- "axis out of bounds";
+ CHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";
const int stack_size = static_cast<int>(inputs.size());
Array<PrimExpr> out_shape;
- for (size_t i = 0; i < static_cast<size_t>(axis); ++i)
- out_shape.push_back(inputs[0]->shape[i]);
+ for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
out_shape.push_back(stack_size);
for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
out_shape.push_back(inputs[0]->shape[i]);
return compute(
- out_shape, [&](const Array<Var>& indices) {
- Array<PrimExpr> idx;
- for (size_t i = 0; i < indices.size(); ++i)
- if (i != static_cast<size_t>(axis))
- idx.push_back(indices[i]);
- auto ind = indices[axis];
- auto ret = inputs[0](idx);
- for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
- ret = tvm::if_then_else(ind == i + 1,
- inputs[i + 1](idx),
- ret);
- }
- return ret;
- }, name, tag);
+ out_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> idx;
+ for (size_t i = 0; i < indices.size(); ++i)
+ if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
+ auto ind = indices[axis];
+ auto ret = inputs[0](idx);
+ for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
+ ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
+ }
+ return ret;
+ },
+ name, tag);
}
/*!
-* \brief Split a tensor into multiple sub-tensors
-*
-* \param x The input tensor
-* \param split_indices The indices to split the input at. This must be in ascending
-* order.
-* \param axis The axis to split along.
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the split operation
-*/
-inline Array<Tensor> split(const Tensor& x,
- Array<Integer> split_indices,
- int axis,
- std::string name = "T_split",
- std::string tag = kInjective) {
+ * \brief Split a tensor into multiple sub-tensors
+ *
+ * \param x The input tensor
+ * \param split_indices The indices to split the input at. This must be in ascending
+ * order.
+ * \param axis The axis to split along.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the split operation
+ */
+inline Array<Tensor> split(const Tensor& x, Array<Integer> split_indices, int axis,
+ std::string name = "T_split", std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
}
for (Integer idx : split_indices) {
int val = static_cast<int>(idx->value);
- CHECK_GT(val, begin_ids.back())
- << "split_indices must be sorted";
+ CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted";
begin_ids.push_back(val);
}
- Array< Array<PrimExpr> > out_shapes;
+ Array<Array<PrimExpr> > out_shapes;
for (size_t i = 0; i < begin_ids.size(); ++i) {
int out_axis_size;
if (i == begin_ids.size() - 1) {
Array<Tensor> result;
for (size_t i = 0; i < begin_ids.size(); ++i) {
- result.push_back(
- compute(
- out_shapes[i], [&](const Array<Var>& indices) {
+ result.push_back(compute(
+ out_shapes[i],
+ [&](const Array<Var>& indices) {
auto begin = begin_ids[i];
Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
}
return x(real_indices);
- }, name, tag));
+ },
+ name, tag));
}
return result;
}
/*!
-* \brief strided_slice of a tensor
-*
-* \param x The input tensor
-* \param begin The indices to begin with in the slicing
-* \param end Indicies indicating end of the slice
-* \param strides Specifies the stride values, it can be negative
-* in that case, the input tensor will be reversed in that particular axis
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the split operation
-*/
-inline Tensor strided_slice(const Tensor& x,
- const Array<Integer>& begin,
- const Array<Integer>& end,
- const Array<Integer>& strides,
- std::string name = "T_strided_slice",
+ * \brief strided_slice of a tensor
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing
+ * \param end Indices indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the split operation
+ */
+inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
+ const Array<Integer>& strides, std::string name = "T_strided_slice",
std::string tag = kInjective) {
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
// Setup the ranges.
int64_t end_i = index_canonicalization(end_vec[i]);
int interval = std::abs(end_i - begin_i);
- int slice_size = static_cast<int>((interval
- + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
+ int slice_size =
+ static_cast<int>((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
- << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
- << "] is invalid for axis=" << i;
+ << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
+ << "] is invalid for axis=" << i;
begin_expr.push_back(make_const(begin[0].dtype(), begin_i));
- strides_expr.push_back(make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
- stride_vec[i]));
+ strides_expr.push_back(
+ make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i]));
out_shape.push_back(slice_size);
}
return compute(
- out_shape, [&](const Array<Var>& indices) {
- Array<PrimExpr> real_indices;
- for (size_t i = 0; i < src_tensor_dim; ++i) {
- real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
- }
- return x(real_indices);
- }, name, tag);
+ out_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> real_indices;
+ for (size_t i = 0; i < src_tensor_dim; ++i) {
+ real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
+ }
+ return x(real_indices);
+ },
+ name, tag);
}
/*!
-* \brief Split a tensor into a number of sub-tensors
-*
-* \param x The input tensor
-* \param num_sections The number of sections to split the tensor into.
-* this must be an integer factor of the size of the axis being split.
-* \param axis The axis to split along.
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the split operation
-*/
-inline Array<Tensor> split_sections(const Tensor& x,
- int num_sections,
- int axis,
+ * \brief Split a tensor into a number of sub-tensors
+ *
+ * \param x The input tensor
+ * \param num_sections The number of sections to split the tensor into.
+ * this must be an integer factor of the size of the axis being split.
+ * \param axis The axis to split along.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the split operation
+ */
+inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
std::string name = "T_split_sections",
std::string tag = kInjective) {
if (axis < 0) {
CHECK_GT(num_sections, 0) << "Slice count must be > 0";
CHECK_EQ(src_axis_size % num_sections, 0)
- << "num_sections must be an integer factor of the size of axis " << axis
- << " (" << src_axis_size << ")";
+ << "num_sections must be an integer factor of the size of axis " << axis << " ("
+ << src_axis_size << ")";
Array<Integer> split_indices;
auto seg_size = src_axis_size / num_sections;
}
/*!
-* \brief Take elements from an flattened input array when axis is None.
-*
-* \param a The source array.
-* \param indices The indices of the values to extract.
-* \param mode The mode of the operation.
-* \param name The name of the operation.
-* \param mode The mode of to handle out of bound indices.
-* \param tag The tag to mark the operation.
-*
-* \return A Tensor whose op member is the take operation
-*/
-inline Tensor take(const Tensor& a,
- const Tensor& indices,
- std::string mode = "clip",
- std::string name = "T_take",
- std::string tag = kInjective) {
+ * \brief Take elements from a flattened input array when axis is None.
+ *
+ * \param a The source array.
+ * \param indices The indices of the values to extract.
+ * \param mode The mode to handle out-of-bound indices
+ *        ("clip", "fast", or "wrap").
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the take operation
+ */
+inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip",
+ std::string name = "T_take", std::string tag = kInjective) {
Array<PrimExpr> a_shape = a->shape;
Array<PrimExpr> out_shape = indices->shape;
PrimExpr a_size = 1;
if (mode == "clip") {
return compute(
- out_shape, [&](const Array<Var>& out_index) {
+ out_shape,
+ [&](const Array<Var>& out_index) {
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape));
- }, name, tag);
+ },
+ name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
- out_shape, [&](const Array<Var>& out_index) {
- return a(UnravelIndex(indices(out_index), a_shape));
- }, name, tag);
+ out_shape,
+ [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
+ name, tag);
} else { // mode == "wrap"
return compute(
- out_shape, [&](const Array<Var>& out_index) {
+ out_shape,
+ [&](const Array<Var>& out_index) {
auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
return a(UnravelIndex(idx, a_shape));
- }, name, tag);
+ },
+ name, tag);
}
}
-
/*!
-* \brief Mask the out-of-boundary elements of each sequence.
-*
-* \param data The source array.
-* \param valid_length The real length of each sequence.
-* \param mask_value The masking value.
-* \param axis The axis of the temporal dimension of the sequence
-* \param name The name of the operation.
-* \param tag The tag to mark the operation.
-*
-* \return A Tensor whose op member is the sequence_mask operation
-*/
-inline Tensor sequence_mask(const Tensor& data,
- const Tensor& valid_length,
- double mask_value,
- int axis,
- std::string name = "T_sequence_mask",
+ * \brief Mask the out-of-boundary elements of each sequence.
+ *
+ * \param data The source array.
+ * \param valid_length The real length of each sequence.
+ * \param mask_value The masking value.
+ * \param axis The axis of the temporal dimension of the sequence
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the sequence_mask operation
+ */
+inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
+ int axis, std::string name = "T_sequence_mask",
std::string tag = kInjective) {
CHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
auto batch_dim = data->shape[1 - axis];
Array<PrimExpr> out_shape = data->shape;
Tensor out = compute(
- out_shape, [&](const Array<Var>& out_index) {
+ out_shape,
+ [&](const Array<Var>& out_index) {
Array<PrimExpr> len_index;
auto tid = out_index[axis];
auto bid = out_index[1 - axis];
len_index.push_back(bid);
- PrimExpr ret = tvm::if_then_else(
- tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
- tvm::tir::make_const(data->dtype, mask_value), data(out_index));
+ PrimExpr ret =
+ tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
+ tvm::tir::make_const(data->dtype, mask_value), data(out_index));
return ret;
- }, name, tag);
+ },
+ name, tag);
return out;
}
/*!
-* \brief Take elements from an array along an axis.
-*
-* \param a The source array.
-* \param indices The indices of the values to extract.
-* \param axis The axis over which to select values. By default,
-* the flattened input array is used.
-* \param mode The mode for handling out of bound indices.
-* \param name The name of the operation.
-* \param tag The tag to mark the operation.
-*
-* \return A Tensor whose op member is the take operation
-*/
-inline Tensor take(const Tensor& a,
- const Tensor& indices,
- int axis,
- std::string mode = "clip",
- std::string name = "T_take",
- std::string tag = kInjective) {
+ * \brief Take elements from an array along an axis.
+ *
+ * \param a The source array.
+ * \param indices The indices of the values to extract.
+ * \param axis The axis over which to select values. By default,
+ * the flattened input array is used.
+ * \param mode The mode for handling out of bound indices.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the take operation
+ */
+inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip",
+ std::string name = "T_take", std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(a->shape.size());
}
}
if (mode == "clip") {
return compute(
- out_shape, [&](const Array<Var>& out_index) {
+ out_shape,
+ [&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
- for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
+ for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
- auto idx = tvm::min(tvm::max(0, indices(indices_position)),
- axis_dim - 1);
+ auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
- }, name, tag);
+ },
+ name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
- out_shape, [&](const Array<Var>& out_index) {
+ out_shape,
+ [&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
- for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
+ for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<PrimExpr> real_indices;
real_indices.push_back(out_index[j]);
}
return a(real_indices);
- }, name, tag);
+ },
+ name, tag);
} else { // mode == "wrap"
return compute(
- out_shape, [&](const Array<Var>& out_index) {
+ out_shape,
+ [&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
- for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
+ for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<PrimExpr> real_indices;
real_indices.push_back(out_index[j]);
}
return a(real_indices);
- }, name, tag);
+ },
+ name, tag);
}
}
/*!
-* \brief Return the elements, either from x or y, depending on the condition.
-*
-* \param condition The condition array.
-* \param x First array to be selected.
-* \param y Second array to be selected.
-* \param name The name of the operation.
-* \param tag The tag to mark the operation.
-*
-* \return A Tensor selected from x or y depending on condition.
-*/
-inline Tensor where(const Tensor& condition,
- const Tensor& x,
- const Tensor& y,
- std::string name = "T_where",
- std::string tag = kBroadcast) {
+ * \brief Return the elements, either from x or y, depending on the condition.
+ *
+ * \param condition The condition array.
+ * \param x First array to be selected.
+ * \param y Second array to be selected.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor selected from x or y depending on condition.
+ */
+inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
+ std::string name = "T_where", std::string tag = kBroadcast) {
CHECK_EQ(x->shape.size(), y->shape.size())
- << "x and y must have the same shape.Got different number of dimension: "
- << x->shape.size() << " vs " << y->shape.size();
- CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
- << x->dtype << " vs " << y->dtype;
+ << "x and y must have the same shape.Got different number of dimension: " << x->shape.size()
+ << " vs " << y->shape.size();
+ CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
+ << y->dtype;
Array<PrimExpr> oshape = x->shape;
Tensor out;
if (condition->shape.size() != 1) {
CHECK_EQ(condition->shape.size(), x->shape.size())
- << "condition array must be either have the same shape as x or to be a "
- "1-D array.Got different number of dimension: "
- << condition->shape.size() << " vs " << x->shape.size();
+ << "condition array must be either have the same shape as x or to be a "
+ "1-D array.Got different number of dimension: "
+ << condition->shape.size() << " vs " << x->shape.size();
out = compute(
- oshape, [&](const Array<Var>& indices) {
- return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices));
- }, name, tag);
+ oshape,
+ [&](const Array<Var>& indices) {
+ return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices));
+ },
+ name, tag);
} else {
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
- << "If condition is 1-D, the first dimension must be the same as x: "
- << condition->shape[0] << " vs " << x->shape[0];
+ << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0]
+ << " vs " << x->shape[0];
out = compute(
- oshape, [&](const Array<Var>& indices) {
- Array<PrimExpr> condition_idx{indices[0]};
- return tvm::tir::SelectNode::make(condition(condition_idx) != 0,
- x(indices), y(indices));
- }, name, tag);
+ oshape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> condition_idx{indices[0]};
+ return tvm::tir::SelectNode::make(condition(condition_idx) != 0, x(indices), y(indices));
+ },
+ name, tag);
}
return out;
}
/*!
-* \brief Creates an operation to repeat elements of an array
-*
-* \param x The input tensor
-* \param repeats The number of repetitions for each element
-* \param axis The axis along which to repeat values (allows
-* negative indices as offsets from the last dimension)
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the repeat operation
-*/
-inline Tensor repeat(const Tensor& x,
- int repeats,
- int axis,
- std::string name = "T_repeat",
+ * \brief Creates an operation to repeat elements of an array
+ *
+ * \param x The input tensor
+ * \param repeats The number of repetitions for each element
+ * \param axis The axis along which to repeat values (allows
+ * negative indices as offsets from the last dimension)
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the repeat operation
+ */
+inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
std::string tag = kBroadcast) {
int ndim = static_cast<int>(x->shape.size());
CHECK(-ndim - 1 <= axis && axis <= ndim)
- << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
- << ", but got axis = " << axis
- << ", and data.ndim = " << ndim;
- CHECK(repeats >= 1)
- << "repeat only accepts `repeats >= 1`"
- << ", but got repeats = " << repeats;
+ << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
+ << ", but got axis = " << axis << ", and data.ndim = " << ndim;
+ CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
+ << ", but got repeats = " << repeats;
if (axis < 0) {
// Calculate offset from last dimension
axis += ndim;
}
return compute(
- new_shape, [&](const Array<Var>& indices) {
- Array<PrimExpr> idx;
- for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
- idx.push_back(indices[i]);
- }
- idx.push_back(indexdiv(indices[axis], repeats));
- for (size_t i = axis + 1; i < indices.size(); ++i) {
- idx.push_back(indices[i]);
- }
- return x(idx);
- }, name, tag);
+ new_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> idx;
+ for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
+ idx.push_back(indices[i]);
+ }
+ idx.push_back(indexdiv(indices[axis], repeats));
+ for (size_t i = axis + 1; i < indices.size(); ++i) {
+ idx.push_back(indices[i]);
+ }
+ return x(idx);
+ },
+ name, tag);
}
/*!
-* \brief Creates an operation to tile elements of an array
-*
-* \param x The input tensor
-* \param reps The number of times for repeating the tensor
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the tile operation
-*/
-inline Tensor tile(const Tensor& x,
- Array<Integer> reps,
- std::string name = "T_tile",
+ * \brief Creates an operation to tile elements of an array
+ *
+ * \param x The input tensor
+ * \param reps The number of times for repeating the tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the tile operation
+ */
+inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile",
std::string tag = kBroadcast) {
size_t ndim = x->shape.size();
size_t rdim = reps.size();
reps_shape.push_back(reps[i]);
}
} else if (ndim > rdim) {
- for (size_t i = 0; i < ndim; ++i)
- data_shape.push_back(x->shape[i]);
- for (size_t i = 0; i < (ndim - rdim); ++i)
- reps_shape.push_back(1);
- for (size_t i = 0; i < rdim; ++i)
- reps_shape.push_back(reps[i]);
+ for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
+ for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
+ for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
} else {
- for (size_t i = 0; i < (rdim - ndim); ++i)
- data_shape.push_back(1);
- for (size_t i = 0; i < ndim; ++i)
- data_shape.push_back(x->shape[i]);
- for (size_t i = 0; i < rdim; ++i)
- reps_shape.push_back(reps[i]);
+ for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
+ for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
+ for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
}
- for (size_t i = 0; i < tdim; ++i)
- new_shape.push_back(data_shape[i] * reps_shape[i]);
+ for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);
if (is_empty_shape(new_shape)) {
- return compute(new_shape,
- [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0);},
- name, tag);
+ return compute(
+ new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
} else {
return compute(
- new_shape, [&](const Array<Var>& indices) {
- Array<PrimExpr> idx;
- if (ndim >= rdim) {
- for (size_t i = 0; i < ndim; ++i)
- idx.push_back(indexmod(indices[i], x->shape[i]));
- } else {
- for (size_t i = 0; i < ndim; ++i)
- idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
- }
- return x(idx);
- }, name, tag);
+ new_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> idx;
+ if (ndim >= rdim) {
+ for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
+ } else {
+ for (size_t i = 0; i < ndim; ++i)
+ idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
+ }
+ return x(idx);
+ },
+ name, tag);
}
}
/*!
-* \brief Gather elements from a n-dimension array.
-*
-* \param data The source array.
-* \param indices The indices of the values to extract.
-* \param name The name of the operation.
-* \param tag The tag to mark the operation.
-*
-* \return A Tensor whose op member is the gather_nd operation
-*/
-inline Tensor gather_nd(const Tensor& data,
- const Tensor& indices,
- std::string name = "T_gather_nd",
+ * \brief Gather elements from an n-dimensional array.
+ *
+ * \param data The source array.
+ * \param indices The indices of the values to extract.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the gather_nd operation
+ */
+inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string name = "T_gather_nd",
std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
out_shape.push_back(make_const(DataType::Int(32), 1));
}
return compute(
- out_shape, [&](const Array<Var>& out_index) {
- Array<PrimExpr> indices_position;
- indices_position.push_back(0);
- for (size_t i = 0; i < ndim_i - 1; ++i) {
- indices_position.push_back(out_index[i]);
- }
- Array<PrimExpr> real_indices;
- for (size_t i = 0; i < indices_dim0; ++i) {
- indices_position.Set(0, make_const(DataType::Int(32), i));
- if (indices->dtype.is_int()) {
- real_indices.push_back(indices(indices_position));
- } else {
- real_indices.push_back(
- tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
- }
- }
- for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
- real_indices.push_back(out_index[i]);
+ out_shape,
+ [&](const Array<Var>& out_index) {
+ Array<PrimExpr> indices_position;
+ indices_position.push_back(0);
+ for (size_t i = 0; i < ndim_i - 1; ++i) {
+ indices_position.push_back(out_index[i]);
+ }
+ Array<PrimExpr> real_indices;
+ for (size_t i = 0; i < indices_dim0; ++i) {
+ indices_position.Set(0, make_const(DataType::Int(32), i));
+ if (indices->dtype.is_int()) {
+ real_indices.push_back(indices(indices_position));
+ } else {
+ real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
}
- return data(real_indices);
- }, name, tag);
+ }
+ for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
+ real_indices.push_back(out_index[i]);
+ }
+ return data(real_indices);
+ },
+ name, tag);
}
/*!
*
* \return A Tensor whose op member is the matmul operation
*/
-inline tvm::te::Tensor matmul(const tvm::te::Tensor& A,
- const tvm::te::Tensor& B,
- bool trans_a = false,
- bool trans_b = false,
- std::string name = "T_matmul",
- std::string tag = kMatMul) {
- tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0],
- B->shape[trans_b ? 0 : 1]};
+inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B,
+ bool trans_a = false, bool trans_b = false,
+ std::string name = "T_matmul", std::string tag = kMatMul) {
+ tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
- return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]),
- {k});
+ return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
};
return tvm::te::compute(output_shape, l, name, tag);
}
*
* \return A Tensor computing the result
*/
-inline Tensor tensordot(const Tensor& A,
- const tvm::te::Tensor& B,
- int axes = 2,
- std::string name = "T_tensordot",
- std::string tag = kMatMul) {
+inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
+ std::string name = "T_tensordot", std::string tag = kMatMul) {
CHECK_GE(A->shape.size(), axes);
CHECK_GE(B->shape.size(), axes);
Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
- for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it)
- output_shape.push_back(*it);
+ for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);
Array<IterVar> iter_vars;
for (int i = 0; i < axes; ++i)
iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));
- auto func =
- [&A, &B, &iter_vars, axes]
- (const Array<Var>& input_indices) {
- Array<PrimExpr> A_indices(
- input_indices.begin(),
- input_indices.begin() + (A->shape.size() - axes));
- for (auto& v : iter_vars)
- A_indices.push_back(v);
-
- Array<PrimExpr> B_indices;
- for (auto& v : iter_vars)
- B_indices.push_back(v);
-
- auto it = input_indices.begin() + (A->shape.size() - axes);
- for (; it != input_indices.end(); ++it)
- B_indices.push_back(*it);
-
- // Some passes don't like reductions with empty axis, so avoid it here
- if (iter_vars.empty())
- return A(A_indices) * B(B_indices);
- else
- return sum(A(A_indices) * B(B_indices), iter_vars);
- };
+ auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) {
+ Array<PrimExpr> A_indices(input_indices.begin(),
+ input_indices.begin() + (A->shape.size() - axes));
+ for (auto& v : iter_vars) A_indices.push_back(v);
+
+ Array<PrimExpr> B_indices;
+ for (auto& v : iter_vars) B_indices.push_back(v);
+
+ auto it = input_indices.begin() + (A->shape.size() - axes);
+ for (; it != input_indices.end(); ++it) B_indices.push_back(*it);
+
+ // Some passes don't like reductions with empty axis, so avoid it here
+ if (iter_vars.empty())
+ return A(A_indices) * B(B_indices);
+ else
+ return sum(A(A_indices) * B(B_indices), iter_vars);
+ };
return compute(output_shape, func, name, tag);
}
*
* \return A Tensor computing the result
*/
-inline Tensor tensordot(const Tensor& A,
- const tvm::te::Tensor& B,
- Array<PrimExpr> A_axes,
- Array<PrimExpr> B_axes,
- std::string name = "T_tensordot",
+inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
+ Array<PrimExpr> B_axes, std::string name = "T_tensordot",
std::string tag = kMatMul) {
CHECK_EQ(A_axes.size(), B_axes.size());
output_shape.push_back(B->shape[i]);
Array<IterVar> iter_vars;
- for (unsigned i = 0; i < B_axes_val.size(); ++i)
- iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));
-
- auto func =
- [&A, &B, &iter_vars, A_axes_val, B_axes_val]
- (const Array<Var>& input_indices) {
- int idx_input = 0;
- Array<PrimExpr> A_indices;
- for (unsigned i = 0; i < A->shape.size(); ++i) {
- auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
- if (axes_pos == A_axes_val.end())
- A_indices.push_back(input_indices[idx_input++]);
- else
- A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
- }
+ for (unsigned i = 0; i < B_axes_val.size(); ++i)
+ iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));
+
+ auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
+ int idx_input = 0;
+ Array<PrimExpr> A_indices;
+ for (unsigned i = 0; i < A->shape.size(); ++i) {
+ auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
+ if (axes_pos == A_axes_val.end())
+ A_indices.push_back(input_indices[idx_input++]);
+ else
+ A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
+ }
- Array<PrimExpr> B_indices;
- for (unsigned i = 0; i < B->shape.size(); ++i) {
- auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
- if (axes_pos == B_axes_val.end())
- B_indices.push_back(input_indices[idx_input++]);
- else
- B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
- }
- return sum(A(A_indices) * B(B_indices), iter_vars);
- };
+ Array<PrimExpr> B_indices;
+ for (unsigned i = 0; i < B->shape.size(); ++i) {
+ auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
+ if (axes_pos == B_axes_val.end())
+ B_indices.push_back(input_indices[idx_input++]);
+ else
+ B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
+ }
+ return sum(A(A_indices) * B(B_indices), iter_vars);
+ };
return compute(output_shape, func, name, tag);
}
-inline Tensor arange(const PrimExpr& start,
- const PrimExpr& stop,
- const PrimExpr& step,
- DataType dtype,
- std::string name = "T_arange",
- std::string tag = kInjective) {
- PrimExpr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil(
- tvm::cast(tvm::DataType::Float(32), stop - start) / step));
+inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
+ DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
+ PrimExpr num_elem = tvm::cast(
+ tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
Array<PrimExpr> shape;
- return compute({num_elem}, [&](const Array<Var>& indices) {
- return tvm::cast(dtype, start + step * indices[0]);
- }, name, tag);
+ return compute(
+ {num_elem},
+ [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
+ tag);
}
/*!
* \param tag output tensor tag.
* \return A tensor with shape in \p dst_layout
*/
-inline Tensor layout_transform(const Tensor& src,
- const std::string& src_layout,
+inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
const std::string& dst_layout,
const std::string name = "T_layout_trans",
const std::string tag = kInjective) {
}
CHECK(src_layout_struct.defined() && dst_layout_struct.defined())
- << "cannot convert from/to undefined layout";
+ << "cannot convert from/to undefined layout";
auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
- CHECK(layout_converter.defined())
- << "cannot convert from " << src_layout << " to " << dst_layout;
+ CHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout;
Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);
return compute(
- dst_shape, [&](const Array<Var>& dst_indices) {
- Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
- Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
- return src(src_indices);
- }, name, tag);
+ dst_shape,
+ [&](const Array<Var>& dst_indices) {
+ Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
+ Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
+ return src(src_indices);
+ },
+ name, tag);
}
/*!
* \param tag output tensor tag.
* \return Tensor of input shape.
*/
-inline Tensor shape(const Tensor& src,
- DataType dtype,
- const std::string name = "T_shape",
+inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
const std::string tag = kInjective) {
int ndim = static_cast<int>(src->shape.size());
Array<PrimExpr> out_shape{ndim};
- return compute(out_shape, [&](const Array<Var>& indices) {
- auto idx = indices[0];
- PrimExpr ret = 0;
- for (int i = 0; i < ndim; ++i) {
- ret = tvm::if_then_else(idx == i, src->shape[i], ret);
- }
- return tvm::cast(dtype, ret);
- }, name, tag);
+ return compute(
+ out_shape,
+ [&](const Array<Var>& indices) {
+ auto idx = indices[0];
+ PrimExpr ret = 0;
+ for (int i = 0; i < ndim; ++i) {
+ ret = tvm::if_then_else(idx == i, src->shape[i], ret);
+ }
+ return tvm::cast(dtype, ret);
+ },
+ name, tag);
}
/*!
* \param tag output tensor tag.
* \return Tensor of input shape.
*/
-inline Tensor ndarray_size(const Tensor& src,
- const DataType& dtype,
+inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
const std::string& name = "ndarray_size",
const std::string& tag = kInjective) {
int ndim = static_cast<int>(src->shape.size());
Array<PrimExpr> out_ndarray_size = {1};
- return compute(out_ndarray_size, [&](const Array<Var>& indices) {
- PrimExpr ret = 1;
- for (int i = 0; i < ndim; ++i) {
- ret *= src->shape[i];
- }
- return tvm::cast(dtype, ret);
- }, name, tag);
+ return compute(
+ out_ndarray_size,
+ [&](const Array<Var>& indices) {
+ PrimExpr ret = 1;
+ for (int i = 0; i < ndim; ++i) {
+ ret *= src->shape[i];
+ }
+ return tvm::cast(dtype, ret);
+ },
+ name, tag);
}
/*!
* \param tag output tensor tag.
* \return one-hot tensor.
*/
-inline Tensor one_hot(const Tensor& indices,
- const PrimExpr on_value,
- const PrimExpr off_value,
- int depth,
- int axis,
- const DataType& dtype,
- const std::string name = "T_one_hot",
- const std::string tag = kInjective) {
+inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
+ int depth, int axis, const DataType& dtype,
+ const std::string name = "T_one_hot", const std::string tag = kInjective) {
Array<PrimExpr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
PrimExpr on_value_cast = cast(dtype, on_value);
PrimExpr off_value_cast = cast(dtype, off_value);
- return compute(oshape, [&](const Array<Var>& iter_vars) {
- Array<Var> indices_indices;
- for (size_t i = 0; i < iter_vars.size(); i++) {
- if (static_cast<int>(i) == true_axis) {
- continue;
- }
+ return compute(
+ oshape,
+ [&](const Array<Var>& iter_vars) {
+ Array<Var> indices_indices;
+ for (size_t i = 0; i < iter_vars.size(); i++) {
+ if (static_cast<int>(i) == true_axis) {
+ continue;
+ }
- indices_indices.push_back(iter_vars[i]);
- }
+ indices_indices.push_back(iter_vars[i]);
+ }
- auto idx = iter_vars[true_axis];
- return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
- }, name, tag);
+ auto idx = iter_vars[true_axis];
+ return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast,
+ off_value_cast);
+ },
+ name, tag);
}
} // namespace topi
#ifndef TOPI_VISION_REORG_H_
#define TOPI_VISION_REORG_H_
-#include <tvm/te/operation.h>
#include <topi/detail/constant_utils.h>
#include <topi/reduction.h>
#include <topi/tags.h>
#include <topi/transform.h>
+#include <tvm/te/operation.h>
#include <algorithm>
#include <string>
using namespace tvm::te;
/*!
-* \brief Reorg operation
-*
-* \param data The input tensor. Can be any dimension
-* \param stride The input integer used as stride in reorg operation
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the reorg operation
-*/
-inline Tensor reorg(const Tensor &data,
- int stride = 1,
- std::string name = "tensor",
+ * \brief Reorg operation
+ *
+ * \param data The input tensor. Can be any dimension
+ * \param stride The input integer used as stride in reorg operation
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the reorg operation
+ */
+inline Tensor reorg(const Tensor& data, int stride = 1, std::string name = "tensor",
std::string tag = "reorg_output") {
auto input_shape = data->shape;
int w_in = GetConstInt(input_shape[3]);
int out_c = c_in / (stride * stride);
- auto out = tvm::te::compute(input_shape,
- [&](Var b, Var k, Var j, Var i) {
- return data(b * stride * stride,
- indexmod(k, out_c) * stride * stride,
- (j*stride + indexdiv(indexdiv(k, out_c), stride)) * stride,
- (i*stride + indexmod(indexdiv(k, out_c), stride)));
- },
- name,
- tag);
+ auto out = tvm::te::compute(
+ input_shape,
+ [&](Var b, Var k, Var j, Var i) {
+ return data(b * stride * stride, indexmod(k, out_c) * stride * stride,
+ (j * stride + indexdiv(indexdiv(k, out_c), stride)) * stride,
+ (i * stride + indexmod(indexdiv(k, out_c), stride)));
+ },
+ name, tag);
out_c = c_in * stride * stride;
int out_h = h_in / stride;
#ifndef TOPI_X86_BNN_H_
#define TOPI_X86_BNN_H_
-#include <topi/tags.h>
#include <topi/detail/fuse.h>
-#include <tvm/te/operation.h>
+#include <topi/tags.h>
#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
namespace topi {
using namespace tvm;
namespace x86 {
/*!
-* \brief Create a generic schedule for binarize_pack
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_binarize_pack(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a generic schedule for binarize_pack
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_binarize_pack(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
/*!
-* \brief Create a generic schedule for binary_dense
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_binary_dense(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a generic schedule for binary_dense
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_binary_dense(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
#ifndef TOPI_X86_DEFAULT_H_
#define TOPI_X86_DEFAULT_H_
-#include <topi/tags.h>
#include <topi/detail/fuse.h>
+#include <topi/tags.h>
+#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/target/generic_func.h>
namespace topi {
using namespace tvm;
namespace x86 {
/*!
-* \brief Helper to create a default x86 schedule for the given ops.
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-* \param auto_inline Whether to apply the auto inline step.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule MakeDefaultSchedule(const Target &target,
- const Array<Tensor>& outs,
+ * \brief Helper to create a default x86 schedule for the given ops.
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ * \param auto_inline Whether to apply the auto inline step.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule MakeDefaultSchedule(const Target& target, const Array<Tensor>& outs,
bool auto_inline) {
Array<Operation> out_ops;
for (auto t : outs) {
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
- auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h
+ auto fused = detail::Fuse(s[x], {n, c}); // for nhwc layout, fuse n and h
s[x].parallel(fused);
} else {
s[x].parallel(axis[0]);
}
/*!
-* \brief Create a default x86 schedule for the given ops.
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule default_schedule(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a default x86 schedule for the given ops.
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule default_schedule(const Target& target, const Array<Tensor>& outs) {
return MakeDefaultSchedule(target, outs, false);
}
/*!
-* \brief Create a default x86 schedule for the given ops, with auto inline
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule default_schedule_auto_inline(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create a default x86 schedule for the given ops, with auto inline
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule default_schedule_auto_inline(const Target& target, const Array<Tensor>& outs) {
return MakeDefaultSchedule(target, outs, true);
}
#ifndef TOPI_X86_INJECTIVE_H_
#define TOPI_X86_INJECTIVE_H_
-#include <topi/tags.h>
#include <topi/detail/fuse.h>
-#include <tvm/te/operation.h>
+#include <topi/tags.h>
#include <tvm/target/generic_func.h>
+#include <tvm/te/operation.h>
namespace topi {
using namespace tvm;
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
- auto fused = detail::Fuse(sch[out], { n, c }); // for nhwc layout, fuse n and h
+ auto fused = detail::Fuse(sch[out], {n, c}); // for nhwc layout, fuse n and h
sch[out].parallel(fused);
} else {
sch[out].parallel(axis[0]);
}
/*!
-* \brief Create an x86 schedule for the given injective ops.
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
+ * \brief Create an x86 schedule for the given injective ops.
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule schedule_injective(const Target& target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
*/
/*!
-* \brief Registration of broadcast operators
-* \file broadcast.cc
-*/
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
-
+ * \brief Registration of broadcast operators
+ * \file broadcast.cc
+ */
#include <topi/broadcast.h>
#include <topi/util.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
-#define TOPI_REGISTER_BCAST_OP(OpName, Op) \
- TVM_REGISTER_GLOBAL(OpName) \
- .set_body([](TVMArgs args, TVMRetValue *rv) { \
- bool lhs_is_tensor = args[0].IsObjectRef<tvm::te::Tensor>(); \
- bool rhs_is_tensor = args[1].IsObjectRef<tvm::te::Tensor>(); \
- if (lhs_is_tensor && rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::te::Tensor(), \
- args[1].operator tvm::te::Tensor()); \
- } else if (!lhs_is_tensor && rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::PrimExpr(), \
- args[1].operator tvm::te::Tensor()); \
- } else if (lhs_is_tensor && !rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::te::Tensor(), \
- args[1].operator tvm::PrimExpr()); \
- } else if (!lhs_is_tensor && !rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::PrimExpr(), \
- args[1].operator tvm::PrimExpr()); \
- } \
- }); \
+#define TOPI_REGISTER_BCAST_OP(OpName, Op) \
+ TVM_REGISTER_GLOBAL(OpName).set_body([](TVMArgs args, TVMRetValue* rv) { \
+ bool lhs_is_tensor = args[0].IsObjectRef<tvm::te::Tensor>(); \
+ bool rhs_is_tensor = args[1].IsObjectRef<tvm::te::Tensor>(); \
+ if (lhs_is_tensor && rhs_is_tensor) { \
+ *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::te::Tensor()); \
+ } else if (!lhs_is_tensor && rhs_is_tensor) { \
+ *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::te::Tensor()); \
+ } else if (lhs_is_tensor && !rhs_is_tensor) { \
+ *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::PrimExpr()); \
+ } else if (!lhs_is_tensor && !rhs_is_tensor) { \
+ *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::PrimExpr()); \
+ } \
+ });
TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
-TVM_REGISTER_GLOBAL("topi.broadcast_to")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = broadcast_to(args[0], args[1]);
- });
+});
} // namespace topi
*/
/*!
-* \brief Registration of elemwise operators
-* \file elemwise.cc
-*/
+ * \brief Registration of elemwise operators
+ * \file elemwise.cc
+ */
+#include <topi/elemwise.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <topi/elemwise.h>
-
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
-TVM_REGISTER_GLOBAL("topi.acos")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.acos").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = acos(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.acosh")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.acosh").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = acosh(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.asin")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.asin").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = asin(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.asinh")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.asinh").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = asinh(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.atanh")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.atanh").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = atanh(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.exp")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = exp(args[0]);
- });
+TVM_REGISTER_GLOBAL("topi.exp").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = exp(args[0]); });
-TVM_REGISTER_GLOBAL("topi.fast_exp")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.fast_exp").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = fast_exp(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.erf")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = erf(args[0]);
- });
+TVM_REGISTER_GLOBAL("topi.erf").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = erf(args[0]); });
-TVM_REGISTER_GLOBAL("topi.fast_erf")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.fast_erf").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = fast_erf(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.tan")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = tan(args[0]);
- });
+TVM_REGISTER_GLOBAL("topi.tan").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tan(args[0]); });
-TVM_REGISTER_GLOBAL("topi.cos")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = cos(args[0]);
- });
+TVM_REGISTER_GLOBAL("topi.cos").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cos(args[0]); });
-TVM_REGISTER_GLOBAL("topi.cosh")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cosh").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = cosh(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.sin")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = sin(args[0]);
- });
+TVM_REGISTER_GLOBAL("topi.sin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sin(args[0]); });
-TVM_REGISTER_GLOBAL("topi.sinh")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.sinh").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = sinh(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.tanh")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.tanh").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = tanh(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.fast_tanh")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = fast_tanh(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.atan")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.atan").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = atan(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.sigmoid")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.sigmoid").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = sigmoid(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.sqrt")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.sqrt").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = sqrt(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.rsqrt")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-*rv = rsqrt(args[0]);
- });
+TVM_REGISTER_GLOBAL("topi.rsqrt").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = rsqrt(args[0]);
+});
-TVM_REGISTER_GLOBAL("topi.log")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = log(args[0]);
- });
+TVM_REGISTER_GLOBAL("topi.log").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log(args[0]); });
-TVM_REGISTER_GLOBAL("topi.log2")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.log2").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = log2(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.log10")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.log10").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = log10(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.identity")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.identity").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = identity(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.negative")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.negative").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = negative(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.clip")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.clip").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = clip(args[0], args[1], args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.cast")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cast").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = cast(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.reinterpret")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("topi.reinterpret").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = reinterpret(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.elemwise_sum")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = elemwise_sum(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.sign")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.sign").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = sign(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.full")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.full").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = full(args[0], args[1], args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.full_like")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.full_like").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = full_like(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.logical_not")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.logical_not").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = logical_not(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.bitwise_not")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = bitwise_not(args[0]);
- });
+});
} // namespace topi
*/
/*!
-* \brief Registration of NN operators
-* \file nn.cc
-*/
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
-
+ * \brief Registration of NN operators
+ * \file nn.cc
+ */
#include <topi/nn.h>
+#include <topi/nn/batch_matmul.h>
#include <topi/nn/bias_add.h>
#include <topi/nn/bnn.h>
#include <topi/nn/dense.h>
#include <topi/nn/dilate.h>
#include <topi/nn/flatten.h>
+#include <topi/nn/local_response_norm.h>
#include <topi/nn/mapping.h>
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
-#include <topi/nn/local_response_norm.h>
-#include <topi/nn/batch_matmul.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
namespace topi {
using namespace tvm::runtime;
/* Ops from nn.h */
-TVM_REGISTER_GLOBAL("topi.nn.relu")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.relu").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = relu<float>(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = leaky_relu(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.nn.prelu")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = prelu(args[0], args[1], args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.nn.pad")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.pad").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = pad(args[0], args[1], args[2], args[3]);
- });
+});
/* Ops from nn/dense.h */
-TVM_REGISTER_GLOBAL("topi.nn.dense")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.dense").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::dense(args[0], args[1], args[2], args[3]);
- });
+});
/* Ops from nn/bias_add.h */
-TVM_REGISTER_GLOBAL("topi.nn.bias_add")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::bias_add(args[0], args[1], args[2]);
- });
+});
/* Ops from nn/batch_matmul.h */
-TVM_REGISTER_GLOBAL("topi.nn.batch_matmul")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::batch_matmul(args[0], args[1]);
- });
+});
/* Ops from nn/dilate.h */
-TVM_REGISTER_GLOBAL("topi.nn.dilate")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::dilate(args[0], args[1]);
- });
+});
/* Ops from nn/flatten.h */
-TVM_REGISTER_GLOBAL("topi.nn.flatten")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::flatten(args[0]);
- });
+});
/* Ops from nn/mapping.h */
-TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::scale_shift_nchw(args[0], args[1], args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::scale_shift_nhwc(args[0], args[1], args[2]);
- });
+});
/* Ops from nn/pooling.h */
-TVM_REGISTER_GLOBAL("topi.nn.pool")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.pool").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::pool(args[0], args[1], args[2], args[3],
- static_cast<nn::PoolType>(static_cast<int>(args[4])),
- args[5], args[6], args[7]);
- });
+ static_cast<nn::PoolType>(static_cast<int>(args[4])), args[5], args[6], args[7]);
+});
-TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4],
- static_cast<nn::PoolType>(static_cast<int>(args[5])),
- args[6], args[7], args[8]);
- });
-
-TVM_REGISTER_GLOBAL("topi.nn.global_pool")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = nn::global_pool(args[0],
- static_cast<nn::PoolType>(static_cast<int>(args[1])), args[2]);
- });
-
-TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = nn::adaptive_pool(args[0], args[1],
- static_cast<nn::PoolType>(static_cast<int>(args[2])),
+ static_cast<nn::PoolType>(static_cast<int>(args[5])), args[6], args[7],
+ args[8]);
+});
+
+TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = nn::global_pool(args[0], static_cast<nn::PoolType>(static_cast<int>(args[1])), args[2]);
+});
+
+TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = nn::adaptive_pool(args[0], args[1], static_cast<nn::PoolType>(static_cast<int>(args[2])),
args[3]);
});
-TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = nn::adaptive_pool3d(args[0], args[1],
- static_cast<nn::PoolType>(static_cast<int>(args[2])),
+TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = nn::adaptive_pool3d(args[0], args[1], static_cast<nn::PoolType>(static_cast<int>(args[2])),
args[3]);
});
-TVM_REGISTER_GLOBAL("topi.nn.pool1d")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::pool1d(args[0], args[1], args[2], args[3],
- static_cast<nn::PoolType>(static_cast<int>(args[4])),
- args[5], args[6], args[7]);
- });
+ static_cast<nn::PoolType>(static_cast<int>(args[4])), args[5], args[6], args[7]);
+});
-TVM_REGISTER_GLOBAL("topi.nn.pool3d")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::pool3d(args[0], args[1], args[2], args[3],
- static_cast<nn::PoolType>(static_cast<int>(args[4])),
- args[5], args[6], args[7]);
- });
+ static_cast<nn::PoolType>(static_cast<int>(args[4])), args[5], args[6], args[7]);
+});
/* Ops from nn/softmax.h */
-TVM_REGISTER_GLOBAL("topi.nn.softmax")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::softmax(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::log_softmax(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.nn.lrn")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = nn::lrn(args[0], args[1], args[2],
- static_cast<double>(args[3]),
- static_cast<double>(args[4]),
- static_cast<double>(args[5]));
- });
+TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = nn::lrn(args[0], args[1], args[2], static_cast<double>(args[3]),
+ static_cast<double>(args[4]), static_cast<double>(args[5]));
+});
/* Ops from nn/bnn.h */
-TVM_REGISTER_GLOBAL("topi.nn.binarize_pack")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::binarize_pack(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.nn.binary_dense")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::binary_dense(args[0], args[1]);
- });
+});
} // namespace topi
*/
/*!
-* \brief Registration of reduction operators
-* \file reduction.cc
-*/
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
-
+ * \brief Registration of reduction operators
+ * \file reduction.cc
+ */
#include <topi/reduction.h>
#include <topi/util.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
-TVM_REGISTER_GLOBAL("topi.sum")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.sum").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.min")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.min").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.max")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.argmin")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.argmax")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.prod")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.all")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.all").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.any")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
- });
+});
} // namespace topi
*/
/*!
-* \brief Registration of TVM schedules
-* \file schedule.cc
-*/
+ * \brief Registration of TVM schedules
+ * \file schedule.cc
+ */
#define TOPI_REDUCE_ATLEAST1D 0
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/ir/expr.h>
-#include <tvm/target/generic_func.h>
-
-#include <topi/generic/default.h>
-#include <topi/generic/extern.h>
-#include <topi/generic/injective.h>
-
#include <topi/cuda/dense.h>
#include <topi/cuda/injective.h>
+#include <topi/cuda/normalization.h>
#include <topi/cuda/pooling.h>
#include <topi/cuda/reduction.h>
#include <topi/cuda/softmax.h>
-#include <topi/cuda/normalization.h>
-
-#include <topi/x86/bnn.h>
-#include <topi/x86/default.h>
-#include <topi/x86/injective.h>
-
+#include <topi/detail/tensor_utils.h>
+#include <topi/generic/default.h>
+#include <topi/generic/extern.h>
+#include <topi/generic/injective.h>
#include <topi/rocm/dense.h>
#include <topi/rocm/injective.h>
+#include <topi/rocm/normalization.h>
#include <topi/rocm/pooling.h>
#include <topi/rocm/reduction.h>
#include <topi/rocm/softmax.h>
-#include <topi/rocm/normalization.h>
-
-#include <topi/detail/tensor_utils.h>
+#include <topi/x86/bnn.h>
+#include <topi/x86/default.h>
+#include <topi/x86/injective.h>
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/generic_func.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
-TVM_REGISTER_GLOBAL("topi.TEST_create_target")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.TEST_create_target").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = tvm::Target::Create(args[0]);
- });
+});
/* Generic schedules */
-TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.generic.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args[2]) {
*rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::generic::default_schedule(args[0], args[1]);
}
- });
+});
-TVM_REGISTER_GLOBAL("topi.generic.schedule_extern")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.generic.schedule_extern").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::generic::schedule_extern(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.generic.schedule_injective")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.generic.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
- });
+});
TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
- });
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
+ });
/* x86 schedules */
-TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::x86::schedule_binarize_pack(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::x86::schedule_binary_dense(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.x86.default_schedule")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.x86.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args[2]) {
*rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::x86::default_schedule(args[0], args[1]);
}
- });
+});
-TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.x86.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::x86::schedule_injective(args[0], args[1]);
- });
+});
TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
- });
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
+ });
/* ROCm schedules */
-TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::rocm::schedule_dense(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::rocm::schedule_injective(args[0], args[1]);
- });
+});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
- });
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
+ });
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::rocm::schedule_pool(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::rocm::schedule_global_pool(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::rocm::schedule_reduce(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::rocm::schedule_softmax(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::rocm::schedule_lrn(args[0]);
- });
+});
/* CUDA schedules */
-TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::cuda::schedule_dense(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::cuda::schedule_injective(args[0], args[1]);
- });
+});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
- });
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
+ });
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::cuda::schedule_pool(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::cuda::schedule_global_pool(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::cuda::schedule_reduce(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::cuda::schedule_softmax(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::cuda::schedule_lrn(args[0]);
- });
+});
/* Utility functions */
-TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.util.is_empty_shape").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = topi::detail::is_empty_shape(args[0]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
- });
+});
/*! \brief Builder function for instantiating schedules. */
-using FTVMScheduleBuilder = std::function<
- tvm::te::Schedule(const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
+using FTVMScheduleBuilder = std::function<tvm::te::Schedule(
+ const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
/*!
* \brief Helper function for registering generic functions matching the
if (argNodeRef->type_index() == outs->type_index()) {
outs = args[0];
} else {
- outs = Array<Tensor> { args[0] };
+ outs = Array<Tensor>{args[0]};
}
*ret = builder(target, outs);
}
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
-.set_default(WrapSchedule(topi::generic::schedule_injective))
-.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective))
-.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective));
+ .set_default(WrapSchedule(topi::generic::schedule_injective))
+ .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective))
+ .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective));
TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
-.set_default(WrapSchedule(topi::generic::default_schedule))
-.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
-.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax));
+ .set_default(WrapSchedule(topi::generic::default_schedule))
+ .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
+ .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax));
TVM_REGISTER_GENERIC_FUNC(schedule_dense)
-.set_default(WrapSchedule(topi::generic::default_schedule))
-.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense))
-.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense));
+ .set_default(WrapSchedule(topi::generic::default_schedule))
+ .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense))
+ .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense));
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
-.set_default(WrapSchedule(topi::generic::default_schedule));
+ .set_default(WrapSchedule(topi::generic::default_schedule));
TVM_REGISTER_GENERIC_FUNC(schedule_pool)
-.set_default(WrapSchedule(topi::generic::default_schedule))
-.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
-.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool));
+ .set_default(WrapSchedule(topi::generic::default_schedule))
+ .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
+ .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
-.set_default(WrapSchedule(topi::generic::default_schedule))
-.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
-.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool));
+ .set_default(WrapSchedule(topi::generic::default_schedule))
+ .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
+ .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
-.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
-.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline))
-.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce));
+ .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
+ .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline))
+ .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce));
TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
-.set_default(WrapSchedule(topi::generic::default_schedule))
-.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack));
+ .set_default(WrapSchedule(topi::generic::default_schedule))
+ .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack));
TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
-.set_default(WrapSchedule(topi::generic::default_schedule))
-.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));
+ .set_default(WrapSchedule(topi::generic::default_schedule))
+ .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense));
/*! \brief Builder function for instantiating schedules from existing schedules. */
-using FTVMScheduleFromExistingBuilder = std::function<
- tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>;
+using FTVMScheduleFromExistingBuilder =
+ std::function<tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>;
/*!
* \brief Helper function for registering generic functions matching the
* \return The wrapped schedule builder
*/
inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
- return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
- *ret = builder(args[0], args[1]);
- });
+ return PackedFunc(
+ [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); });
}
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
-.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
-.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
-.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(
- topi::cuda::schedule_injective_from_existing));
+ .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
+ .register_func({"cpu"}, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
+ .register_func({"cuda", "gpu"},
+ WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
/*! \brief Builder function for instantiating dense ops. */
-using FTVMDenseOpBuilder = std::function<tvm::te::Tensor(const Target& target,
- const tvm::te::Tensor& data,
- const tvm::te::Tensor& weight,
- const tvm::te::Tensor& bias,
- const DataType& out_dtype)>;
+using FTVMDenseOpBuilder = std::function<tvm::te::Tensor(
+ const Target& target, const tvm::te::Tensor& data, const tvm::te::Tensor& weight,
+ const tvm::te::Tensor& bias, const DataType& out_dtype)>;
/*!
-* \brief Helper function for registering dense ops matching the
-* FTVMDenseOpBuilder signature. The op builder function is wrapped
-* with a PackedFunc suitable for passing to a tvm::GenericFunc.
-*
-* \param builder The op builder to wrap.
-*
-* \return The wrapped op builder
-*/
+ * \brief Helper function for registering dense ops matching the
+ * FTVMDenseOpBuilder signature. The op builder function is wrapped
+ * with a PackedFunc suitable for passing to a tvm::GenericFunc.
+ *
+ * \param builder The op builder to wrap.
+ *
+ * \return The wrapped op builder
+ */
inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::Current(false);
}
TVM_REGISTER_GENERIC_FUNC(dense)
-.set_default(WrapDenseOp([](const Target& target,
- const tvm::te::Tensor& data,
- const tvm::te::Tensor& weight,
- const tvm::te::Tensor& bias,
- const DataType& out_dtype) {
- return topi::nn::dense(data, weight, bias, out_dtype);
-}))
-.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda))
-.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm));
+ .set_default(WrapDenseOp([](const Target& target, const tvm::te::Tensor& data,
+ const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
+ const DataType& out_dtype) {
+ return topi::nn::dense(data, weight, bias, out_dtype);
+ }))
+ .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda))
+ .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm));
} // namespace topi
*/
/*!
-* \brief Registration of transform operators
-* \file transform.cc
-*/
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
-
+ * \brief Registration of transform operators
+ * \file transform.cc
+ */
#include <topi/transform.h>
#include <topi/util.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
-TVM_REGISTER_GLOBAL("topi.expand_dims")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.expand_dims").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = expand_dims(args[0], args[1], args[2]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.transpose")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.transpose").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = transpose(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.flip")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.flip").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = flip(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.reshape")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = reshape(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.squeeze")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.squeeze").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = squeeze(args[0], ArrayOrInt(args[1]));
- });
+});
-TVM_REGISTER_GLOBAL("topi.concatenate")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = concatenate(args[0], args[1]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.stack")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = stack(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.shape")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.shape").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = shape(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.ndarray_size")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = ndarray_size(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.split")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
*rv = split_sections(args[0], args[1], args[2]);
} else {
}
});
-TVM_REGISTER_GLOBAL("topi.layout_transform")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = layout_transform(args[0], args[1], args[2]);
});
-TVM_REGISTER_GLOBAL("topi.take")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args.size() == 3) {
std::string mode = args[2];
*rv = take(args[0], args[1], mode);
std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
}
- });
+});
-TVM_REGISTER_GLOBAL("topi.sequence_mask")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body([](TVMArgs args, TVMRetValue* rv) {
double pad_val = args[2];
int axis = args[3];
*rv = sequence_mask(args[0], args[1], pad_val, axis);
});
-TVM_REGISTER_GLOBAL("topi.where")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.where").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = where(args[0], args[1], args[2]);
});
-TVM_REGISTER_GLOBAL("topi.arange")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.arange").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = arange(args[0], args[1], args[2], args[3]);
});
-TVM_REGISTER_GLOBAL("topi.repeat")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.repeat").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = repeat(args[0], args[1], args[2]);
});
-TVM_REGISTER_GLOBAL("topi.tile")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.tile").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = tile(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.gather_nd")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = gather_nd(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.unravel_index")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = unravel_index(args[0], args[1]);
- });
-
-TVM_REGISTER_GLOBAL("topi.matmul")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- switch ( args.size() ) {
- case 2: *rv = matmul(args[0], args[1]); break;
- case 3: *rv = matmul(args[0], args[1], args[2]); break;
- case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break;
- default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
- }});
-
-TVM_REGISTER_GLOBAL("topi.tensordot")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+});
+
+TVM_REGISTER_GLOBAL("topi.matmul").set_body([](TVMArgs args, TVMRetValue* rv) {
+ switch (args.size()) {
+ case 2:
+ *rv = matmul(args[0], args[1]);
+ break;
+ case 3:
+ *rv = matmul(args[0], args[1], args[2]);
+ break;
+ case 4:
+ *rv = matmul(args[0], args[1], args[2], args[3]);
+ break;
+ default:
+ CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
+ }
+});
+
+TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args.size() == 2) {
*rv = tensordot(args[0], args[1]);
} else if (args.size() == 3) {
Array<PrimExpr> axes = args[3];
*rv = tensordot(args[0], args[1], args[2], axes);
}
- });
+});
-TVM_REGISTER_GLOBAL("topi.strided_slice")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);
- });
+});
-TVM_REGISTER_GLOBAL("topi.one_hot")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
int depth = args[3];
int axis = args[4];
DataType dtype = args[5];
*rv = one_hot(args[0], args[1], args[2], depth, axis, dtype);
- });
+});
} // namespace topi
*/
/*!
-* \brief Registration of vision operators
-* \file vision.cc
-*/
+ * \brief Registration of vision operators
+ * \file vision.cc
+ */
+#include <topi/vision/reorg.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <topi/vision/reorg.h>
-
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
-TVM_REGISTER_GLOBAL("topi.vision.reorg")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = vision::reorg(args[0], args[1]);
- });
+});
} // namespace topi
* \brief TVM device API for VTA
*/
-#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
+#include <tvm/runtime/registry.h>
-#include "runtime.h"
#include "../../src/runtime/workspace_pool.h"
-
+#include "runtime.h"
namespace tvm {
namespace runtime {
}
}
- void* AllocDataSpace(TVMContext ctx,
- size_t size,
- size_t alignment,
- DLDataType type_hint) final {
+ void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final {
return VTABufferAlloc(size);
}
- void FreeDataSpace(TVMContext ctx, void* ptr) final {
- VTABufferFree(ptr);
- }
+ void FreeDataSpace(TVMContext ctx, void* ptr) final { VTABufferFree(ptr); }
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
int kind_mask = 0;
if (ctx_from.device_type != kDLCPU) {
if (ctx_to.device_type != kDLCPU) {
kind_mask |= 1;
}
- VTABufferCopy(from, from_offset,
- to, to_offset,
- size, kind_mask);
+ VTABufferCopy(from, from_offset, to, to_offset, size, kind_mask);
}
- void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
- }
+ void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {}
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
static const std::shared_ptr<VTADeviceAPI>& Global() {
- static std::shared_ptr<VTADeviceAPI> inst =
- std::make_shared<VTADeviceAPI>();
+ static std::shared_ptr<VTADeviceAPI> inst = std::make_shared<VTADeviceAPI>();
return inst;
}
};
struct VTAWorkspacePool : public WorkspacePool {
- VTAWorkspacePool() :
- WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {}
+ VTAWorkspacePool() : WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {}
};
void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
- return dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()
- ->AllocWorkspace(ctx, size);
+ return dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()->AllocWorkspace(ctx, size);
}
void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
// Register device api with override.
static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ =
-::tvm::runtime::Registry::Register("device_api.ext_dev", true)
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = VTADeviceAPI::Global().get();
- *rv = static_cast<void*>(ptr);
- });
+ ::tvm::runtime::Registry::Register("device_api.ext_dev", true)
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = VTADeviceAPI::Global().get();
+ *rv = static_cast<void*>(ptr);
+ });
} // namespace runtime
} // namespace tvm
* The runtime depends on specific instruction
* stream spec as specified in hw_spec.h
*/
-#include <vta/driver.h>
-#include <vta/hw_spec.h>
+#include "runtime.h"
+
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <vta/driver.h>
+#include <vta/hw_spec.h>
#include <algorithm>
#include <cassert>
#include <cstring>
-#include <vector>
#include <memory>
-
-#include "runtime.h"
+#include <vector>
namespace vta {
// Avoid bad configurations.
-static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8,
- "VTA_UOP_WIDTH do not match VTAUop size");
+static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, "VTA_UOP_WIDTH do not match VTAUop size");
/*! \brief Enable coherent access of data buffers between VTA and CPU */
static const bool kBufferCoherent = VTA_COHERENT_ACCESSES;
*/
struct DataBuffer {
/*! \return Virtual address of the data. */
- void* virt_addr() const {
- return data_;
- }
+ void* virt_addr() const { return data_; }
/*! \return Physical address of the data. */
- vta_phy_addr_t phy_addr() const {
- return phy_addr_;
- }
+ vta_phy_addr_t phy_addr() const { return phy_addr_; }
/*!
* \brief Invalidate the cache of given location in data buffer.
* \param offset The offset to the data.
*/
void InvalidateCache(size_t offset, size_t size) {
if (!kBufferCoherent && kAlwaysCache) {
- VTAInvalidateCache(reinterpret_cast<char *>(data_) + offset,
- phy_addr_ + offset,
- size);
+ VTAInvalidateCache(reinterpret_cast<char*>(data_) + offset, phy_addr_ + offset, size);
}
}
/*!
*/
void FlushCache(size_t offset, size_t size) {
if (!kBufferCoherent && kAlwaysCache) {
- VTAFlushCache(reinterpret_cast<char *>(data_) + offset,
- phy_addr_ + offset,
- size);
+ VTAFlushCache(reinterpret_cast<char*>(data_) + offset, phy_addr_ + offset, size);
}
}
/*!
* \brief Performs a copy operation from host memory to buffer allocated with VTAMemAlloc.
- * \param dst The desination buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc().
- * \param src The source buffer in host memory.
- * \param size Size of the region in Bytes.
+ * \param dst The destination buffer in FPGA-accessible memory.
+ *            Has to be allocated with VTAMemAlloc().
+ * \param src The source buffer in host memory.
+ * \param size Size of the region in Bytes.
*/
void MemCopyFromHost(void* dst, const void* src, size_t size) {
VTAMemCopyFromHost(dst, src, size);
* \param src The source buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc().
* \param size Size of the region in Bytes.
*/
- void MemCopyToHost(void* dst, const void* src, size_t size) {
- VTAMemCopyToHost(dst, src, size);
- }
+ void MemCopyToHost(void* dst, const void* src, size_t size) { VTAMemCopyToHost(dst, src, size); }
/*!
* \brief Allocate a buffer of a given size.
* \param size The size of the buffer.
* \return The corresponding data buffer header.
*/
static DataBuffer* FromHandle(const void* buffer) {
- return const_cast<DataBuffer*>(
- reinterpret_cast<const DataBuffer*>(buffer));
+ return const_cast<DataBuffer*>(reinterpret_cast<const DataBuffer*>(buffer));
}
private:
* \param signature The pointer to signature.
* \param nbytes Number of bytes.
*/
- UopKernel(const char* signature, int nbytes)
- : signature_(signature, signature + nbytes) {
- }
+ UopKernel(const char* signature, int nbytes) : signature_(signature, signature + nbytes) {}
/*!
* \brief Verify if the signature is correct.
* \param signature Signature ptr.
return memcmp(signature, signature_.data(), nbytes) == 0;
}
/*! \return Whether the kernel is cached in SRAM. */
- bool cached() const {
- return sram_begin_ != sram_end_;
- }
+ bool cached() const { return sram_begin_ != sram_end_; }
/*! \return The length of the micro op sequence. */
- size_t size() const {
- return seq_.size();
- }
+ size_t size() const { return seq_.size(); }
/*! \return The micro-op data. */
- const VTAUop* data() const {
- return seq_.data();
- }
+ const VTAUop* data() const { return seq_.data(); }
/*! \return The loop structure. */
- const std::vector<LoopEntry>& loop() const {
- return loop_;
- }
+ const std::vector<LoopEntry>& loop() const { return loop_; }
/*!
* \brief Declare loop start.
* \param extent The loop extent.
* \param src_factor Loop factor of input index
* \param wgt_factor Loop factor of weight index.
*/
- void PushLoopBegin(uint32_t extent,
- uint32_t dst_factor,
- uint32_t src_factor,
+ void PushLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor,
uint32_t wgt_factor) {
LoopEntry le;
le.extent = extent;
/*!
* \brief Declare loop end.
*/
- void PushLoopEnd() {
- --loop_ptr_;
- }
+ void PushLoopEnd() { --loop_ptr_; }
/*!
* \brief Push micro op into kernel.
* \param mode Set to GEMM mode if set to 0, ALU mode is set to 1.
* \param use_imm Use immediate in ALU mode if set to true.
* \param imm_val Immediate value in ALU mode.
*/
- void Push(uint32_t mode,
- uint32_t reset_out,
- uint32_t dst_index,
- uint32_t src_index,
- uint32_t wgt_index,
- uint32_t opcode,
- uint32_t use_imm,
- int32_t imm_val) {
+ void Push(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index,
+ uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) {
// The loop nest structure
VerifyDep(dst_index);
VTAUop op;
uint32_t size = seq_.size();
printf("There are %u uops\n", size);
for (uint32_t i = 0; i < size; ++i) {
- printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n",
- i,
- seq_[i].dst_idx,
- seq_[i].src_idx,
+ printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", i, seq_[i].dst_idx, seq_[i].src_idx,
seq_[i].wgt_idx);
}
printf("\n");
}
}
// The uop buffer
- template<int, bool, bool>
+ template <int, bool, bool>
friend class UopQueue;
friend class CommandQueue;
// SRAM location if begin != end
}
}
/*! \return Content of DRAM buffer. */
- char* dram_buffer() const {
- return dram_buffer_;
- }
+ char* dram_buffer() const { return dram_buffer_; }
/*! \return Physical address of DRAM. */
vta_phy_addr_t dram_phy_addr() const {
CHECK(fpga_buff_phy_);
return fpga_buff_phy_;
}
/*! \return Whether there is pending information. */
- bool pending() const {
- return sram_begin_ != sram_end_;
- }
+ bool pending() const { return sram_begin_ != sram_end_; }
/*! \brief Initialize the space of the buffer. */
void InitSpace(uint32_t elem_bytes, uint32_t max_bytes, bool coherent, bool always_cache) {
coherent_ = coherent;
always_cache_ = always_cache;
elem_bytes_ = elem_bytes;
// Allocate buffer ahead of time
- fpga_buff_ = static_cast<char*>(VTAMemAlloc(
- max_bytes, coherent_ || always_cache_));
+ fpga_buff_ = static_cast<char*>(VTAMemAlloc(max_bytes, coherent_ || always_cache_));
CHECK(fpga_buff_ != nullptr);
fpga_buff_phy_ = VTAMemGetPhyAddr(fpga_buff_);
}
/*!
* \brief Micro op buffer that manages the micro op cache.
*/
-template<int kMaxBytes, bool kCoherent, bool kAlwaysCache>
+template <int kMaxBytes, bool kCoherent, bool kAlwaysCache>
class UopQueue : public BaseQueue<VTAUop> {
public:
- void InitSpace() {
- BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache);
- }
+ void InitSpace() { BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); }
// Push data to the queue
- template<typename FAutoSync>
+ template <typename FAutoSync>
void Push(UopKernel* kernel, FAutoSync fautosync) {
// if the micro-op is cached in VTA SRAM, skip
if (kernel->cached()) return;
cache_idx_ = 0;
BaseQueue<VTAUop>::Reset();
}
- void AutoReadBarrier() {
- ReadBarrier();
- }
+ void AutoReadBarrier() { ReadBarrier(); }
/*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. */
void ReadBarrier() {
CHECK(fpga_buff_ != nullptr);
uint32_t offset = 0;
for (uint32_t i = 0; i < cache_.size(); ++i) {
uint32_t ksize = cache_[i]->size() * kElemBytes;
- VTAMemCopyFromHost(static_cast<char*>(fpga_buff_) + offset,
- cache_[i]->data(),
- ksize);
+ VTAMemCopyFromHost(static_cast<char*>(fpga_buff_) + offset, cache_[i]->data(), ksize);
// Update offset
offset += ksize;
}
// Flush if we're using a shared memory system
// and if interface is non-coherent
if (!coherent_ && always_cache_) {
- VTAFlushCache(fpga_buff_,
- fpga_buff_phy_,
- offset);
+ VTAFlushCache(fpga_buff_, fpga_buff_phy_, offset);
}
}
class UopKernelMap {
public:
// Simple hash map
- UopKernel** Get(void* signature,
- int nbytes) {
+ UopKernel** Get(void* signature, int nbytes) {
uint32_t key = 0;
CHECK(nbytes == 0 || nbytes == sizeof(int));
if (nbytes == sizeof(int)) {
std::vector<UopKernel*> kmap_;
};
-enum PipelineStage : int {
- kNoneStage = 0,
- kLoadStage = 1,
- kComputeStage = 2,
- kStoreStage = 3
-};
+enum PipelineStage : int { kNoneStage = 0, kLoadStage = 1, kComputeStage = 2, kStoreStage = 3 };
// Instruction Queue
-template<int kMaxBytes, bool kCoherent, bool kAlwaysCache>
+template <int kMaxBytes, bool kCoherent, bool kAlwaysCache>
class InsnQueue : public BaseQueue<VTAGenericInsn> {
public:
/*! \brief Initialize the space. */
std::fill(pending_pop_next_, pending_pop_next_ + 4, 0);
}
/*! \return The data pointer. */
- VTAGenericInsn* data() {
- return dram_buffer_.data();
- }
+ VTAGenericInsn* data() { return dram_buffer_.data(); }
/*! \return Number of instructions. */
- uint32_t count() {
- return dram_buffer_.size();
- }
+ uint32_t count() { return dram_buffer_.size(); }
// Insert dependency push of load
void DepPop(int from, int to) {
// NOTE: This instruction executes on queue[to]
if (GetPipelineStage(mptr) == from) {
if (from < to && !mptr->push_next_dep) {
// push(LD->C) or push(C->ST)
- mptr->push_next_dep = true; return;
+ mptr->push_next_dep = true;
+ return;
} else if (from > to && !mptr->push_prev_dep) {
// push(C->LD) or push(ST->C)
- mptr->push_prev_dep = true; return;
+ mptr->push_prev_dep = true;
+ return;
}
}
}
}
}
// Create a new instruction for a GEMM stage
- VTAGemInsn* CreateGemInsn() {
- return reinterpret_cast<VTAGemInsn*>(
- Create(kComputeStage));
- }
+ VTAGemInsn* CreateGemInsn() { return reinterpret_cast<VTAGemInsn*>(Create(kComputeStage)); }
// Create a new instruction for a ALU stage
- VTAAluInsn* CreateAluInsn() {
- return reinterpret_cast<VTAAluInsn*>(
- Create(kComputeStage));
- }
+ VTAAluInsn* CreateAluInsn() { return reinterpret_cast<VTAAluInsn*>(Create(kComputeStage)); }
// Create a new instruction for a memory stage
VTAMemInsn* CreateMemInsn(int memory_type) {
- return reinterpret_cast<VTAMemInsn*>(
- Create(GetMemPipelineStage(memory_type)));
+ return reinterpret_cast<VTAMemInsn*>(Create(GetMemPipelineStage(memory_type)));
}
// create a new instruction for a store stage
- VTAMemInsn* CreateStoreInsn() {
- return reinterpret_cast<VTAMemInsn*>(
- Create(kStoreStage));
- }
+ VTAMemInsn* CreateStoreInsn() { return reinterpret_cast<VTAMemInsn*>(Create(kStoreStage)); }
// Rewrite instruction stream to force serial execution
void RewriteForceSerial() {
int insn_count = count();
}
CommitPendingPop(kComputeStage);
} else {
- pending_pop_next_[kComputeStage] = 0;
+ pending_pop_next_[kComputeStage] = 0;
}
DepPush(kComputeStage, kLoadStage);
DepPop(kLoadStage, kComputeStage);
}
// Helper function: Get Opcode string
const char* getOpcodeString(int opcode, bool use_imm) {
- // The string name
- if (opcode == VTA_ALU_OPCODE_MIN) {
- if (use_imm) {
- return "min imm";
- } else {
- return "min";
- }
- } else if (opcode == VTA_ALU_OPCODE_MAX) {
- if (use_imm) {
- return "max imm";
- } else {
- return "max";
- }
- } else if (opcode == VTA_ALU_OPCODE_ADD) {
- if (use_imm) {
- return "add imm";
- } else {
- return "add";
- }
- } else if (opcode == VTA_ALU_OPCODE_SHR) {
- return "shr";
+ // The string name
+ if (opcode == VTA_ALU_OPCODE_MIN) {
+ if (use_imm) {
+ return "min imm";
+ } else {
+ return "min";
}
+ } else if (opcode == VTA_ALU_OPCODE_MAX) {
+ if (use_imm) {
+ return "max imm";
+ } else {
+ return "max";
+ }
+ } else if (opcode == VTA_ALU_OPCODE_ADD) {
+ if (use_imm) {
+ return "add imm";
+ } else {
+ return "add";
+ }
+ } else if (opcode == VTA_ALU_OPCODE_SHR) {
+ return "shr";
+ }
- return "unknown op";
+ return "unknown op";
}
// Dump instructions in the queue
void DumpInsn() {
printf("NOP-MEMORY-STAGE\n");
}
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
- static_cast<int>(c.mem.pop_prev_dep),
- static_cast<int>(c.mem.pop_next_dep),
- static_cast<int>(c.mem.push_prev_dep),
- static_cast<int>(c.mem.push_next_dep));
+ static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
+ static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
// Count status in queues
if (c.mem.opcode == VTA_OPCODE_STORE) {
CHECK(c.mem.pop_next_dep == false);
if (c.mem.pop_prev_dep) g2s_queue--;
if (c.mem.push_prev_dep) s2g_queue++;
} else if (c.mem.opcode == VTA_OPCODE_LOAD &&
- (c.mem.memory_type == VTA_MEM_ID_INP ||
- c.mem.memory_type == VTA_MEM_ID_WGT) ) {
+ (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) {
CHECK(c.mem.pop_prev_dep == false);
CHECK(c.mem.push_prev_dep == false);
if (c.mem.pop_next_dep) g2l_queue--;
printf("STORE:\n");
}
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
- static_cast<int>(c.mem.pop_prev_dep),
- static_cast<int>(c.mem.pop_next_dep),
- static_cast<int>(c.mem.push_prev_dep),
- static_cast<int>(c.mem.push_next_dep));
- printf("\tDRAM: 0x%08x, SRAM:0x%04x\n",
- static_cast<int>(c.mem.dram_base),
+ static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
+ static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
+ printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", static_cast<int>(c.mem.dram_base),
static_cast<int>(c.mem.sram_base));
- printf("\ty: size=%d, pad=[%d, %d]\n",
- static_cast<int>(c.mem.y_size),
- static_cast<int>(c.mem.y_pad_0),
- static_cast<int>(c.mem.y_pad_1));
- printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n",
- static_cast<int>(c.mem.x_size),
- static_cast<int>(c.mem.x_stride),
- static_cast<int>(c.mem.x_pad_0),
+ printf("\ty: size=%d, pad=[%d, %d]\n", static_cast<int>(c.mem.y_size),
+ static_cast<int>(c.mem.y_pad_0), static_cast<int>(c.mem.y_pad_1));
+ printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", static_cast<int>(c.mem.x_size),
+ static_cast<int>(c.mem.x_stride), static_cast<int>(c.mem.x_pad_0),
static_cast<int>(c.mem.x_pad_1));
} else if (c.mem.opcode == VTA_OPCODE_GEMM) {
// Print instruction field information
printf("GEMM\n");
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
- static_cast<int>(c.mem.pop_prev_dep),
- static_cast<int>(c.mem.pop_next_dep),
- static_cast<int>(c.mem.push_prev_dep),
- static_cast<int>(c.mem.push_next_dep));
+ static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
+ static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.gemm.reset_reg));
- printf("\trange (%d, %d)\n",
- static_cast<int>(c.gemm.uop_bgn),
+ printf("\trange (%d, %d)\n", static_cast<int>(c.gemm.uop_bgn),
static_cast<int>(c.gemm.uop_end));
printf("\touter loop - iter: %d, wgt: %d, inp: %d, acc: %d\n",
- static_cast<int>(c.gemm.iter_out),
- static_cast<int>(c.gemm.wgt_factor_out),
- static_cast<int>(c.gemm.src_factor_out),
- static_cast<int>(c.gemm.dst_factor_out));
+ static_cast<int>(c.gemm.iter_out), static_cast<int>(c.gemm.wgt_factor_out),
+ static_cast<int>(c.gemm.src_factor_out), static_cast<int>(c.gemm.dst_factor_out));
printf("\tinner loop - iter: %d, wgt: %d, inp: %d, acc: %d\n",
- static_cast<int>(c.gemm.iter_in),
- static_cast<int>(c.gemm.wgt_factor_in),
- static_cast<int>(c.gemm.src_factor_in),
- static_cast<int>(c.gemm.dst_factor_in));
+ static_cast<int>(c.gemm.iter_in), static_cast<int>(c.gemm.wgt_factor_in),
+ static_cast<int>(c.gemm.src_factor_in), static_cast<int>(c.gemm.dst_factor_in));
} else if (c.mem.opcode == VTA_OPCODE_ALU) {
// Print instruction field information
printf("ALU - %s\n", getOpcodeString(c.alu.alu_opcode, c.alu.use_imm));
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
- static_cast<int>(c.mem.pop_prev_dep),
- static_cast<int>(c.mem.pop_next_dep),
- static_cast<int>(c.mem.push_prev_dep),
- static_cast<int>(c.mem.push_next_dep));
+ static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
+ static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.alu.reset_reg));
- printf("\trange (%d, %d)\n",
- static_cast<int>(c.alu.uop_bgn),
+ printf("\trange (%d, %d)\n", static_cast<int>(c.alu.uop_bgn),
static_cast<int>(c.alu.uop_end));
- printf("\touter loop - iter: %d, dst: %d, src: %d\n",
- static_cast<int>(c.alu.iter_out),
- static_cast<int>(c.alu.dst_factor_out),
- static_cast<int>(c.alu.src_factor_out));
- printf("\tinner loop - iter: %d, dst: %d, src: %d\n",
- static_cast<int>(c.alu.iter_in),
- static_cast<int>(c.alu.dst_factor_in),
- static_cast<int>(c.alu.src_factor_in));
+ printf("\touter loop - iter: %d, dst: %d, src: %d\n", static_cast<int>(c.alu.iter_out),
+ static_cast<int>(c.alu.dst_factor_out), static_cast<int>(c.alu.src_factor_out));
+ printf("\tinner loop - iter: %d, dst: %d, src: %d\n", static_cast<int>(c.alu.iter_in),
+ static_cast<int>(c.alu.dst_factor_in), static_cast<int>(c.alu.src_factor_in));
} else if (c.mem.opcode == VTA_OPCODE_FINISH) {
printf("FINISH\n");
}
// Count status in queues
if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) {
if (c.mem.opcode == VTA_OPCODE_STORE) {
- CHECK(c.mem.pop_next_dep == false);
- CHECK(c.mem.push_next_dep == false);
- if (c.mem.pop_prev_dep) g2s_queue--;
- if (c.mem.push_prev_dep) s2g_queue++;
+ CHECK(c.mem.pop_next_dep == false);
+ CHECK(c.mem.push_next_dep == false);
+ if (c.mem.pop_prev_dep) g2s_queue--;
+ if (c.mem.push_prev_dep) s2g_queue++;
} else if (c.mem.opcode == VTA_OPCODE_LOAD &&
- (c.mem.memory_type == VTA_MEM_ID_INP ||
- c.mem.memory_type == VTA_MEM_ID_WGT) ) {
- CHECK(c.mem.pop_prev_dep == false);
- CHECK(c.mem.push_prev_dep == false);
- if (c.mem.pop_next_dep) g2l_queue--;
- if (c.mem.push_next_dep) l2g_queue++;
+ (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) {
+ CHECK(c.mem.pop_prev_dep == false);
+ CHECK(c.mem.push_prev_dep == false);
+ if (c.mem.pop_next_dep) g2l_queue--;
+ if (c.mem.push_next_dep) l2g_queue++;
} else {
- if (c.mem.pop_prev_dep) l2g_queue--;
- if (c.mem.push_prev_dep) g2l_queue++;
- if (c.mem.pop_next_dep) s2g_queue--;
- if (c.mem.push_next_dep) g2s_queue++;
+ if (c.mem.pop_prev_dep) l2g_queue--;
+ if (c.mem.push_prev_dep) g2l_queue++;
+ if (c.mem.pop_next_dep) s2g_queue--;
+ if (c.mem.push_next_dep) g2s_queue++;
}
- } else if (c.mem.opcode == VTA_OPCODE_GEMM ||
- c.mem.opcode == VTA_OPCODE_ALU) {
+ } else if (c.mem.opcode == VTA_OPCODE_GEMM || c.mem.opcode == VTA_OPCODE_ALU) {
// Print instruction field information
if (c.gemm.pop_prev_dep) l2g_queue--;
if (c.gemm.push_prev_dep) g2l_queue++;
// Handle the LD<->compute queue
// NOTE: pop executes on target(stage)
CHECK(stage > 0 && stage < 4);
- if (pending_pop_prev_[stage] ||
- pending_pop_next_[stage]) {
- PushNoop(stage, false, false,
- pending_pop_prev_[stage],
- pending_pop_next_[stage]);
+ if (pending_pop_prev_[stage] || pending_pop_next_[stage]) {
+ PushNoop(stage, false, false, pending_pop_prev_[stage], pending_pop_next_[stage]);
pending_pop_prev_[stage] = 0;
pending_pop_next_[stage] = 0;
}
}
return false;
}
- void AutoReadBarrier() {
- ReadBarrier();
- }
+ void AutoReadBarrier() { ReadBarrier(); }
/*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. */
void ReadBarrier() {
CHECK(fpga_buff_ != nullptr);
uint32_t buff_size = dram_buffer_.size() * elem_bytes_;
CHECK(buff_size <= kMaxBytes);
// Copy contents of DRAM buffer to FPGA buff
- VTAMemCopyFromHost(fpga_buff_,
- dram_buffer_.data(),
- buff_size);
+ VTAMemCopyFromHost(fpga_buff_, dram_buffer_.data(), buff_size);
// Flush if we're using a shared memory system
// and if interface is non-coherent
if (!coherent_ && always_cache_) {
- VTAFlushCache(fpga_buff_,
- fpga_buff_phy_,
- buff_size);
+ VTAFlushCache(fpga_buff_, fpga_buff_phy_, buff_size);
}
}
// Get stage of memory and computation
static PipelineStage GetPipelineStageAll(VTAMemInsn* insn) {
- PipelineStage stage = GetPipelineStage(insn);
- if (stage != kNoneStage) return stage;
- return GetMemPipelineStage(insn->memory_type);
+ PipelineStage stage = GetPipelineStage(insn);
+ if (stage != kNoneStage) return stage;
+ return GetMemPipelineStage(insn->memory_type);
}
// Push no-op
- void PushNoop(int stage,
- bool push_prev_dep, bool push_next_dep,
- bool pop_prev_dep, bool pop_next_dep) {
+ void PushNoop(int stage, bool push_prev_dep, bool push_next_dep, bool pop_prev_dep,
+ bool pop_next_dep) {
VTAMemInsn* insn = reinterpret_cast<VTAMemInsn*>(NextInsn());
insn->opcode = (stage == kStoreStage ? VTA_OPCODE_STORE : VTA_OPCODE_LOAD);
insn->push_prev_dep = push_prev_dep;
*/
class CommandQueue {
public:
- CommandQueue() {
- this->InitSpace();
- }
+ CommandQueue() { this->InitSpace(); }
void InitSpace() {
uop_queue_.InitSpace();
insn_queue_.InitSpace();
CHECK(device_ != nullptr);
}
- ~CommandQueue() {
- VTADeviceFree(device_);
- }
+ ~CommandQueue() { VTADeviceFree(device_); }
uint32_t GetElemBytes(uint32_t memory_id) {
uint32_t elem_bytes = 0;
switch (memory_id) {
case VTA_MEM_ID_UOP:
- elem_bytes = VTA_UOP_ELEM_BYTES;
- break;
+ elem_bytes = VTA_UOP_ELEM_BYTES;
+ break;
case VTA_MEM_ID_INP:
- elem_bytes = VTA_INP_ELEM_BYTES;
- break;
+ elem_bytes = VTA_INP_ELEM_BYTES;
+ break;
case VTA_MEM_ID_WGT:
- elem_bytes = VTA_WGT_ELEM_BYTES;
- break;
+ elem_bytes = VTA_WGT_ELEM_BYTES;
+ break;
case VTA_MEM_ID_ACC:
- elem_bytes = VTA_ACC_ELEM_BYTES;
- break;
+ elem_bytes = VTA_ACC_ELEM_BYTES;
+ break;
case VTA_MEM_ID_OUT:
- elem_bytes = VTA_OUT_ELEM_BYTES;
- break;
+ elem_bytes = VTA_OUT_ELEM_BYTES;
+ break;
default:
- LOG(FATAL) << "Memory id not recognized:" << memory_id;
- break;
+ LOG(FATAL) << "Memory id not recognized:" << memory_id;
+ break;
}
/*
* elements size should not larger than VTA_PAGE_BYTES.
return elem_bytes;
}
- void LoadBuffer2D(void* src_dram_addr,
- uint32_t src_elem_offset,
- uint32_t x_size,
- uint32_t y_size,
- uint32_t x_stride,
- uint32_t x_pad_before,
- uint32_t y_pad_before,
- uint32_t x_pad_after,
- uint32_t y_pad_after,
- uint32_t dst_sram_index,
+ void LoadBuffer2D(void* src_dram_addr, uint32_t src_elem_offset, uint32_t x_size, uint32_t y_size,
+ uint32_t x_stride, uint32_t x_pad_before, uint32_t y_pad_before,
+ uint32_t x_pad_after, uint32_t y_pad_after, uint32_t dst_sram_index,
uint32_t dst_memory_type) {
VTAMemInsn* insn = insn_queue_.CreateMemInsn(dst_memory_type);
insn->opcode = VTA_OPCODE_LOAD;
this->CheckInsnOverFlow();
}
- void StoreBuffer2D(uint32_t src_sram_index,
- uint32_t src_memory_type,
- void* dst_dram_addr,
- uint32_t dst_elem_offset,
- uint32_t x_size,
- uint32_t y_size,
+ void StoreBuffer2D(uint32_t src_sram_index, uint32_t src_memory_type, void* dst_dram_addr,
+ uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size,
uint32_t x_stride) {
VTAMemInsn* insn = insn_queue_.CreateStoreInsn();
insn->opcode = VTA_OPCODE_STORE;
this->CheckInsnOverFlow();
}
- void DepPush(int from_qid, int to_qid) {
- insn_queue_.DepPush(from_qid, to_qid);
- }
+ void DepPush(int from_qid, int to_qid) { insn_queue_.DepPush(from_qid, to_qid); }
- void DepPop(int from_qid, int to_qid) {
- insn_queue_.DepPop(from_qid, to_qid);
- }
+ void DepPop(int from_qid, int to_qid) { insn_queue_.DepPop(from_qid, to_qid); }
void ReadBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) {
if (!(debug_flag_ & VTA_DEBUG_SKIP_READ_BARRIER)) {
uint32_t elem_bytes = (elem_bits + 8 - 1) / 8;
- DataBuffer::FromHandle(buffer)->FlushCache(
- elem_bytes * start, elem_bytes * extent);
+ DataBuffer::FromHandle(buffer)->FlushCache(elem_bytes * start, elem_bytes * extent);
}
}
void WriteBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) {
if (!(debug_flag_ & VTA_DEBUG_SKIP_WRITE_BARRIER)) {
uint32_t elem_bytes = (elem_bits + 8 - 1) / 8;
- DataBuffer::FromHandle(buffer)->InvalidateCache(
- elem_bytes * start, elem_bytes * extent);
+ DataBuffer::FromHandle(buffer)->InvalidateCache(elem_bytes * start, elem_bytes * extent);
}
}
insn_queue_.DumpInsn();
}
// Make sure that the last instruction is a finish instruction
- CHECK(reinterpret_cast<VTAMemInsn*>(
- insn_queue_.data())[insn_queue_.count()-1].opcode == VTA_OPCODE_FINISH);
+ CHECK(reinterpret_cast<VTAMemInsn*>(insn_queue_.data())[insn_queue_.count() - 1].opcode ==
+ VTA_OPCODE_FINISH);
// Make sure that we don't exceed contiguous physical memory limits
CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER);
- int timeout = VTADeviceRun(
- device_,
- insn_queue_.dram_phy_addr(),
- insn_queue_.count(),
- wait_cycles);
+ int timeout =
+ VTADeviceRun(device_, insn_queue_.dram_phy_addr(), insn_queue_.count(), wait_cycles);
CHECK_EQ(timeout, 0);
// Reset buffers
uop_queue_.Reset();
}
// Set debug flag
- void SetDebugFlag(int debug_flag) {
- debug_flag_ = debug_flag;
- }
+ void SetDebugFlag(int debug_flag) { debug_flag_ = debug_flag; }
- void PushGEMMOp(void** uop_handle,
- int (*finit)(void*),
- void* signature,
- int nbytes) {
+ void PushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
UopKernelMap** uptr = reinterpret_cast<UopKernelMap**>(uop_handle);
if (uptr[0] == nullptr) {
uptr[0] = new UopKernelMap();
this->CheckInsnOverFlow();
}
- void PushALUUop(void** uop_handle,
- int (*finit)(void*),
- void* signature,
- int nbytes) {
+ void PushALUUop(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
UopKernelMap** uptr = reinterpret_cast<UopKernelMap**>(uop_handle);
if (uptr[0] == nullptr) {
uptr[0] = new UopKernelMap();
}
static std::shared_ptr<CommandQueue>& ThreadLocal() {
- static std::shared_ptr<CommandQueue> inst =
- std::make_shared<CommandQueue>();
+ static std::shared_ptr<CommandQueue> inst = std::make_shared<CommandQueue>();
if (inst == nullptr) {
inst = std::make_shared<CommandQueue>();
}
return inst;
}
- static void Shutdown() {
- ThreadLocal().reset();
- }
+ static void Shutdown() { ThreadLocal().reset(); }
private:
// Push GEMM uop to the command buffer
void PushGEMMOp(UopKernel* kernel) {
- uop_queue_.Push(kernel,
- [this]() { this->AutoSync(); });
+ uop_queue_.Push(kernel, [this]() { this->AutoSync(); });
if (uop_queue_.pending()) {
VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP);
insn->opcode = VTA_OPCODE_LOAD;
insn->reset_reg = kernel->reset_out_;
insn->uop_bgn = kernel->sram_begin_;
insn->uop_end = kernel->sram_end_;
- const std::vector<UopKernel::LoopEntry> &loop = kernel->loop();
+ const std::vector<UopKernel::LoopEntry>& loop = kernel->loop();
if (loop.size() > 0) {
insn->iter_out = loop[0].extent;
insn->wgt_factor_out = loop[0].wgt_factor;
// Push ALU uop to the command buffer
void PushALUUop(UopKernel* kernel) {
- uop_queue_.Push(kernel,
- [this]() { this->AutoSync(); });
+ uop_queue_.Push(kernel, [this]() { this->AutoSync(); });
if (uop_queue_.pending()) {
VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP);
insn->opcode = VTA_OPCODE_LOAD;
insn->alu_opcode = kernel->opcode_;
insn->use_imm = kernel->use_imm_;
insn->imm = kernel->imm_val_;
- const std::vector<UopKernel::LoopEntry> &loop = kernel->loop();
+ const std::vector<UopKernel::LoopEntry>& loop = kernel->loop();
if (loop.size() == 0) {
insn->iter_out = 1;
insn->dst_factor_out = 0;
}
}
// Auto sync when instruction overflow
- void AutoSync() {
- this->Synchronize(1 << 31);
- }
+ void AutoSync() { this->Synchronize(1 << 31); }
// Internal debug flag
int debug_flag_{0};
} // namespace vta
-void* VTABufferAlloc(size_t size) {
- return vta::DataBuffer::Alloc(size);
-}
+void* VTABufferAlloc(size_t size) { return vta::DataBuffer::Alloc(size); }
-void VTABufferFree(void* buffer) {
- vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer));
-}
+void VTABufferFree(void* buffer) { vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); }
-void VTABufferCopy(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
+void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
int kind_mask) {
vta::DataBuffer* from_buffer = nullptr;
vta::DataBuffer* to_buffer = nullptr;
// This is an FPGA to host mem transfer
from_buffer->InvalidateCache(from_offset, size);
from_buffer->MemCopyToHost(static_cast<char*>(to) + to_offset,
- static_cast<const char*>(from) + from_offset,
- size);
+ static_cast<const char*>(from) + from_offset, size);
} else if (to_buffer) {
// This is a host to FPGA mem transfer
to_buffer->MemCopyFromHost(static_cast<char*>(to) + to_offset,
- static_cast<const char*>(from) + from_offset,
- size);
+ static_cast<const char*>(from) + from_offset, size);
to_buffer->FlushCache(to_offset, size);
}
}
-VTACommandHandle VTATLSCommandHandle() {
- return vta::CommandQueue::ThreadLocal().get();
-}
+VTACommandHandle VTATLSCommandHandle() { return vta::CommandQueue::ThreadLocal().get(); }
-void VTARuntimeShutdown() {
- vta::CommandQueue::Shutdown();
-}
+void VTARuntimeShutdown() { vta::CommandQueue::Shutdown(); }
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) {
- static_cast<vta::CommandQueue*>(cmd)->
- SetDebugFlag(debug_flag);
+ static_cast<vta::CommandQueue*>(cmd)->SetDebugFlag(debug_flag);
}
void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) {
return vta::DataBuffer::FromHandle(buffer)->virt_addr();
}
-void VTAWriteBarrier(VTACommandHandle cmd,
- void* buffer,
- uint32_t elem_bits,
- uint32_t start,
+void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start,
uint32_t extent) {
- static_cast<vta::CommandQueue*>(cmd)->
- WriteBarrier(buffer, elem_bits, start, extent);
+ static_cast<vta::CommandQueue*>(cmd)->WriteBarrier(buffer, elem_bits, start, extent);
}
-void VTAReadBarrier(VTACommandHandle cmd,
- void* buffer,
- uint32_t elem_bits,
- uint32_t start,
+void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start,
uint32_t extent) {
- static_cast<vta::CommandQueue*>(cmd)->
- ReadBarrier(buffer, elem_bits, start, extent);
+ static_cast<vta::CommandQueue*>(cmd)->ReadBarrier(buffer, elem_bits, start, extent);
}
-void VTALoadBuffer2D(VTACommandHandle cmd,
- void* src_dram_addr,
- uint32_t src_elem_offset,
- uint32_t x_size,
- uint32_t y_size,
- uint32_t x_stride,
- uint32_t x_pad_before,
- uint32_t y_pad_before,
- uint32_t x_pad_after,
- uint32_t y_pad_after,
- uint32_t dst_sram_index,
- uint32_t dst_memory_type) {
- static_cast<vta::CommandQueue*>(cmd)->
- LoadBuffer2D(src_dram_addr, src_elem_offset,
- x_size, y_size, x_stride,
- x_pad_before, y_pad_before,
- x_pad_after, y_pad_after,
- dst_sram_index, dst_memory_type);
+void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset,
+ uint32_t x_size, uint32_t y_size, uint32_t x_stride, uint32_t x_pad_before,
+ uint32_t y_pad_before, uint32_t x_pad_after, uint32_t y_pad_after,
+ uint32_t dst_sram_index, uint32_t dst_memory_type) {
+ static_cast<vta::CommandQueue*>(cmd)->LoadBuffer2D(
+ src_dram_addr, src_elem_offset, x_size, y_size, x_stride, x_pad_before, y_pad_before,
+ x_pad_after, y_pad_after, dst_sram_index, dst_memory_type);
}
-void VTAStoreBuffer2D(VTACommandHandle cmd,
- uint32_t src_sram_index,
- uint32_t src_memory_type,
- void* dst_dram_addr,
- uint32_t dst_elem_offset,
- uint32_t x_size,
- uint32_t y_size,
- uint32_t x_stride) {
- static_cast<vta::CommandQueue*>(cmd)->
- StoreBuffer2D(src_sram_index, src_memory_type,
- dst_dram_addr, dst_elem_offset,
- x_size, y_size, x_stride);
+void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, uint32_t src_memory_type,
+ void* dst_dram_addr, uint32_t dst_elem_offset, uint32_t x_size,
+ uint32_t y_size, uint32_t x_stride) {
+ static_cast<vta::CommandQueue*>(cmd)->StoreBuffer2D(
+ src_sram_index, src_memory_type, dst_dram_addr, dst_elem_offset, x_size, y_size, x_stride);
}
-void VTAUopPush(uint32_t mode,
- uint32_t reset_out,
- uint32_t dst_index,
- uint32_t src_index,
- uint32_t wgt_index,
- uint32_t opcode,
- uint32_t use_imm,
- int32_t imm_val) {
- vta::CommandQueue::ThreadLocal()->record_kernel()
- ->Push(mode, reset_out, dst_index, src_index,
- wgt_index, opcode, use_imm, imm_val);
+void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index,
+ uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) {
+ vta::CommandQueue::ThreadLocal()->record_kernel()->Push(mode, reset_out, dst_index, src_index,
+ wgt_index, opcode, use_imm, imm_val);
}
-void VTAUopLoopBegin(uint32_t extent,
- uint32_t dst_factor,
- uint32_t src_factor,
+void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor,
uint32_t wgt_factor) {
- vta::CommandQueue::ThreadLocal()->record_kernel()
- ->PushLoopBegin(extent, dst_factor, src_factor, wgt_factor);
+ vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopBegin(extent, dst_factor, src_factor,
+ wgt_factor);
}
-void VTAUopLoopEnd() {
- vta::CommandQueue::ThreadLocal()->record_kernel()
- ->PushLoopEnd();
-}
+void VTAUopLoopEnd() { vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopEnd(); }
-int VTAPushGEMMOp(void** uop_handle,
- int (*finit)(void*),
- void* signature,
- int nbytes) {
- vta::CommandQueue::ThreadLocal()->
- PushGEMMOp(uop_handle, finit, signature, nbytes);
+int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
+ vta::CommandQueue::ThreadLocal()->PushGEMMOp(uop_handle, finit, signature, nbytes);
return 0;
}
-int VTAPushALUOp(void** uop_handle,
- int (*finit)(void*),
- void* signature,
- int nbytes) {
- vta::CommandQueue::ThreadLocal()->
- PushALUUop(uop_handle, finit, signature, nbytes);
+int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
+ vta::CommandQueue::ThreadLocal()->PushALUUop(uop_handle, finit, signature, nbytes);
return 0;
}
int VTADepPush(VTACommandHandle cmd, int from_qid, int to_qid) {
- static_cast<vta::CommandQueue*>(cmd)->
- DepPush(from_qid, to_qid);
+ static_cast<vta::CommandQueue*>(cmd)->DepPush(from_qid, to_qid);
return 0;
}
int VTADepPop(VTACommandHandle cmd, int from_qid, int to_qid) {
- static_cast<vta::CommandQueue*>(cmd)->
- DepPop(from_qid, to_qid);
+ static_cast<vta::CommandQueue*>(cmd)->DepPop(from_qid, to_qid);
return 0;
}
void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) {
- static_cast<vta::CommandQueue*>(cmd)->
- Synchronize(wait_cycles);
+ static_cast<vta::CommandQueue*>(cmd)->Synchronize(wait_cycles);
}
* \param size Size of copy.
* \param kind_mask The memory copy kind.
*/
-TVM_DLL void VTABufferCopy(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- int kind_mask);
+TVM_DLL void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset,
+ size_t size, int kind_mask);
/*! \brief VTA command handle */
typedef void* VTACommandHandle;
* \param start The start of the region (in elements).
* \param extent The end of the region (in elements).
*/
-TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd,
- void* buffer,
- uint32_t elem_bits,
- uint32_t start,
+TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start,
uint32_t extent);
/*!
* \param start The start of the region (in elements).
* \param extent The end of the region (in elements).
*/
-TVM_DLL void VTAReadBarrier(VTACommandHandle cmd,
- void* buffer,
- uint32_t elem_bits,
- uint32_t start,
+TVM_DLL void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start,
uint32_t extent);
/*!
* \param dst_sram_index Destination SRAM index.
* \param dst_memory_type Destination memory type.
*/
-TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd,
- void* src_dram_addr,
- uint32_t src_elem_offset,
- uint32_t x_size,
- uint32_t y_size,
- uint32_t x_stride,
- uint32_t x_pad_before,
- uint32_t y_pad_before,
- uint32_t x_pad_after,
- uint32_t y_pad_after,
- uint32_t dst_sram_index,
+TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset,
+ uint32_t x_size, uint32_t y_size, uint32_t x_stride,
+ uint32_t x_pad_before, uint32_t y_pad_before, uint32_t x_pad_after,
+ uint32_t y_pad_after, uint32_t dst_sram_index,
uint32_t dst_memory_type);
/*!
* \param y_size The number of rows.
* \param x_stride The x axis stride.
*/
-TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd,
- uint32_t src_sram_index,
- uint32_t src_memory_type,
- void* dst_dram_addr,
- uint32_t dst_elem_offset,
- uint32_t x_size,
- uint32_t y_size,
+TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index,
+ uint32_t src_memory_type, void* dst_dram_addr,
+ uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size,
uint32_t x_stride);
/*!
* \param use_imm Use immediate in ALU mode if set to true.
* \param imm_val Immediate value in ALU mode.
*/
-TVM_DLL void VTAUopPush(uint32_t mode,
- uint32_t reset_out,
- uint32_t dst_index,
- uint32_t src_index,
- uint32_t wgt_index,
- uint32_t opcode,
- uint32_t use_imm,
- int32_t imm_val);
+TVM_DLL void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index,
+ uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val);
/*!
* \brief Mark start of a micro op loop.
* \param src_factor The input factor.
* \param wgt_factor The weight factor.
*/
-TVM_DLL void VTAUopLoopBegin(uint32_t extent,
- uint32_t dst_factor,
- uint32_t src_factor,
+TVM_DLL void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor,
uint32_t wgt_factor);
/*!
* \param nbytes Number of bytes to in the closure arguments.
* \return 0 if success.
*/
-TVM_DLL int VTAPushGEMMOp(void** uop_handle,
- int (*finit)(void*),
- void* signature,
- int nbytes);
+TVM_DLL int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes);
/*!
* \brief Push ALU uop kernel into the command handle.
* \param nbytes Number of bytes to in the closure arguments.
* \return 0 if success.
*/
-TVM_DLL int VTAPushALUOp(void** uop_handle,
- int (*finit)(void*),
- void* signature,
- int nbytes);
+TVM_DLL int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes);
/*!
* \brief Push dependence token.
#define DMLC_LOG_NODATE 1
#define DMLC_LOG_FATAL_THROW 0
-
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
#include "../../src/runtime/rpc/rpc_local_session.h"
extern "C" {
* \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer
* \return 0 if success.
*/
-TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle,
- TVMFunctionHandle *out);
+TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out);
// --- APIs to be implemented by the frontend. ---
/*!
* \param resource_handle The handle additional resouce handle from fron-end.
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
*/
-extern int TVMWasmPackedCFunc(TVMValue* args,
- int* type_codes,
- int num_args,
- TVMRetValueHandle ret,
+extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret,
void* resource_handle);
/*!
extern void TVMWasmPackedCFuncFinalizer(void* resource_handle);
} // extern "C"
-
void* TVMWasmAllocSpace(int size) {
int num_count = (size + 7) / 8;
return new int64_t[num_count];
}
-void TVMWasmFreeSpace(void* arr) {
- delete[] static_cast<int64_t*>(arr);
-}
+void TVMWasmFreeSpace(void* arr) { delete[] static_cast<int64_t*>(arr); }
-int TVMWasmFuncCreateFromCFunc(void* resource_handle,
- TVMFunctionHandle *out) {
- return TVMFuncCreateFromCFunc(
- TVMWasmPackedCFunc, resource_handle,
- TVMWasmPackedCFuncFinalizer, out);
+int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) {
+ return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer,
+ out);
}
-
namespace tvm {
namespace runtime {
// functions in the JS runtime.
class AsyncLocalSession : public LocalSession {
public:
- AsyncLocalSession() {
- }
+ AsyncLocalSession() {}
PackedFuncHandle GetFunction(const std::string& name) final {
if (name == "runtime.RPCTimeEvaluator") {
} else if (auto* fp = tvm::runtime::Registry::Get(name)) {
// return raw handle because the remote need to explicitly manage it.
return new PackedFunc(*fp);
- } else if(auto* fp = tvm::runtime::Registry::Get("__async." + name)) {
+ } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) {
auto* rptr = new PackedFunc(*fp);
async_func_set_.insert(rptr);
return rptr;
}
}
- void AsyncCallFunc(PackedFuncHandle func,
- const TVMValue* arg_values,
- const int* arg_type_codes,
- int num_args,
- FAsyncCallback callback) final {
+ void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes,
+ int num_args, FAsyncCallback callback) final {
auto it = async_func_set_.find(func);
if (it != async_func_set_.end()) {
PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) {
int code = args[0];
TVMRetValue rv;
rv = args[1];
- this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) {
- callback(RPCCode::kReturn, encoded_args);
- });
+ this->EncodeReturn(std::move(rv),
+ [&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); });
});
TVMRetValue temp;
// special handle time evaluator.
try {
TVMArgs args(arg_values, arg_type_codes, num_args);
- PackedFunc retfunc = this->GetTimeEvaluator(
- args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
+ PackedFunc retfunc =
+ this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
TVMRetValue rv;
rv = retfunc;
this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) {
}
}
- void AsyncCopyToRemote(void* local_from,
- size_t local_from_offset,
- void* remote_to,
- size_t remote_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_to,
- DLDataType type_hint,
- FAsyncCallback on_complete) final {
+ void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
+ size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
+ DLDataType type_hint, FAsyncCallback on_complete) final {
TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
try {
- this->GetDeviceAPI(remote_ctx_to)->CopyDataFromTo(
- local_from, local_from_offset,
- remote_to, remote_to_offset,
- nbytes, cpu_ctx, remote_ctx_to, type_hint, nullptr);
+ this->GetDeviceAPI(remote_ctx_to)
+ ->CopyDataFromTo(local_from, local_from_offset, remote_to, remote_to_offset, nbytes,
+ cpu_ctx, remote_ctx_to, type_hint, nullptr);
this->AsyncStreamWait(remote_ctx_to, nullptr, on_complete);
} catch (const std::runtime_error& e) {
this->SendException(on_complete, e.what());
}
}
- void AsyncCopyFromRemote(void* remote_from,
- size_t remote_from_offset,
- void* local_to,
- size_t local_to_offset,
- size_t nbytes,
- TVMContext remote_ctx_from,
- DLDataType type_hint,
- FAsyncCallback on_complete) final {
+ void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
+ size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from,
+ DLDataType type_hint, FAsyncCallback on_complete) final {
TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
try {
- this->GetDeviceAPI(remote_ctx_from)->CopyDataFromTo(
- remote_from, remote_from_offset,
- local_to, local_to_offset,
- nbytes, remote_ctx_from, cpu_ctx, type_hint, nullptr);
+ this->GetDeviceAPI(remote_ctx_from)
+ ->CopyDataFromTo(remote_from, remote_from_offset, local_to, local_to_offset, nbytes,
+ remote_ctx_from, cpu_ctx, type_hint, nullptr);
this->AsyncStreamWait(remote_ctx_from, nullptr, on_complete);
} catch (const std::runtime_error& e) {
this->SendException(on_complete, e.what());
}
}
- void AsyncStreamWait(TVMContext ctx,
- TVMStreamHandle stream,
- FAsyncCallback on_complete) final {
+ void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete) final {
if (ctx.device_type == kDLCPU) {
TVMValue value;
int32_t tcode = kTVMNullptr;
}
}
- bool IsAsync() const final {
- return true;
- }
+ bool IsAsync() const final { return true; }
private:
std::unordered_set<void*> async_func_set_;
const PackedFunc* async_wait_{nullptr};
// time evaluator
- PackedFunc GetTimeEvaluator(Optional<Module> opt_mod,
- std::string name,
- int device_type,
- int device_id,
- int number,
- int repeat,
- int min_repeat_ms) {
+ PackedFunc GetTimeEvaluator(Optional<Module> opt_mod, std::string name, int device_type,
+ int device_id, int number, int repeat, int min_repeat_ms) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
if (opt_mod.defined()) {
Module m = opt_mod.value();
std::string tkey = m->type_key();
- return WrapWasmTimeEvaluator(
- m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
+ return WrapWasmTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
} else {
auto* pf = runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
- return WrapWasmTimeEvaluator(
- *pf, ctx, number, repeat, min_repeat_ms);
+ return WrapWasmTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms);
}
}
// time evaluator
- PackedFunc WrapWasmTimeEvaluator(PackedFunc pf,
- TVMContext ctx,
- int number,
- int repeat,
+ PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat,
int min_repeat_ms) {
- auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](
- TVMArgs args, TVMRetValue *rv) {
+ auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) {
// the function is a async function.
PackedFunc on_complete = args[args.size() - 1];
// keep argument alive in finvoke so that they
};
auto* time_exec = runtime::Registry::Get("__async.wasm.TimeExecution");
CHECK(time_exec != nullptr) << "Cannot find wasm.GetTimer in the global function";
- (*time_exec)(TypedPackedFunc<void(int)>(finvoke),
- ctx, number, repeat, min_repeat_ms, on_complete);
+ (*time_exec)(TypedPackedFunc<void(int)>(finvoke), ctx, number, repeat, min_repeat_ms,
+ on_complete);
};
return PackedFunc(ftimer);
}
};
-TVM_REGISTER_GLOBAL("wasm.LocalSession")
-.set_body_typed([]() {
+TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
return CreateRPCSessionModule(std::make_shared<AsyncLocalSession>());
});
#define DMLC_LOG_NODATE 1
#define DMLC_LOG_FATAL_THROW 0
-#include <tvm/runtime/c_runtime_api.h>
#include <dmlc/logging.h>
+#include <tvm/runtime/c_runtime_api.h>
#include "src/runtime/c_runtime_api.cc"
#include "src/runtime/cpu_device_api.cc"
-#include "src/runtime/workspace_pool.cc"
+#include "src/runtime/file_util.cc"
+#include "src/runtime/graph/graph_runtime.cc"
#include "src/runtime/library_module.cc"
-#include "src/runtime/system_library.cc"
-
#include "src/runtime/module.cc"
#include "src/runtime/ndarray.cc"
#include "src/runtime/object.cc"
#include "src/runtime/registry.cc"
-#include "src/runtime/file_util.cc"
-#include "src/runtime/graph/graph_runtime.cc"
-#include "src/runtime/rpc/rpc_session.cc"
+#include "src/runtime/rpc/rpc_channel.cc"
#include "src/runtime/rpc/rpc_endpoint.cc"
#include "src/runtime/rpc/rpc_event_impl.cc"
-#include "src/runtime/rpc/rpc_channel.cc"
#include "src/runtime/rpc/rpc_local_session.cc"
#include "src/runtime/rpc/rpc_module.cc"
-
+#include "src/runtime/rpc/rpc_session.cc"
+#include "src/runtime/system_library.cc"
+#include "src/runtime/workspace_pool.cc"
// --- Implementations of backend and wasm runtime API. ---
-int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
- void* cdata,
- int num_task) {
+int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
TVMParallelGroupEnv env;
env.num_task = 1;
flambda(0, &env, cdata);
return 0;
}
-int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
- return 0;
-}
+int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { return 0; }
// --- Environment PackedFuncs for testing ---
-namespace tvm {
+namespace tvm {
namespace runtime {
-TVM_REGISTER_GLOBAL("testing.echo")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
+TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0];
});
-TVM_REGISTER_GLOBAL("testing.add_one")
-.set_body_typed([](int x) {
- return x + 1;
-});
+TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; });
-TVM_REGISTER_GLOBAL("testing.wrap_callback")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- PackedFunc pf = args[0];
- *ret = runtime::TypedPackedFunc<void()>([pf](){
- pf();
- });
- });
+TVM_REGISTER_GLOBAL("testing.wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) {
+ PackedFunc pf = args[0];
+ *ret = runtime::TypedPackedFunc<void()>([pf]() { pf(); });
+});
} // namespace runtime
} // namespace tvm
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/device_api.h>
+
#include "../../src/runtime/meta_data.h"
-#include "../../src/runtime/workspace_pool.h"
#include "../../src/runtime/vulkan/vulkan_shader.h"
+#include "../../src/runtime/workspace_pool.h"
namespace tvm {
namespace runtime {
static WebGPUThreadEntry* ThreadLocal();
};
-
// All the implementations are redirectly to the JS side.
class WebGPUDeviceAPI : public DeviceAPI {
public:
copy_within_gpu_ = getter("deviceCopyWithinGPU");
}
- void SetDevice(TVMContext ctx) final {
- }
+ void SetDevice(TVMContext ctx) final {}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == kExist) {
*rv = 1;
}
}
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
-
double ptr_number = alloc_space_(nbytes);
return reinterpret_cast<void*>(static_cast<int64_t>(ptr_number));
}
- void FreeDataSpace(TVMContext ctx, void* ptr) final {
- return free_space_(ptr);
- }
+ void FreeDataSpace(TVMContext ctx, void* ptr) final { return free_space_(ptr); }
- void CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to, size_t to_offset, size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to, DLDataType type_hint,
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
if (static_cast<int>(ctx_from.device_type) == kDLWebGPU &&
static_cast<int>(ctx_to.device_type) == kDLWebGPU) {
return;
}
- void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
- LOG(FATAL) << "Not implemented";
- }
+ void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; }
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
LOG(FATAL) << "Not implemented";
}
static const std::shared_ptr<WebGPUDeviceAPI>& Global() {
- static std::shared_ptr<WebGPUDeviceAPI> inst =
- std::make_shared<WebGPUDeviceAPI>();
+ static std::shared_ptr<WebGPUDeviceAPI> inst = std::make_shared<WebGPUDeviceAPI>();
return inst;
}
TypedPackedFunc<void(void* ptr)> free_space_;
TypedPackedFunc<void(void* from, void* to, int64_t to_offset, int64_t nbytes)> copy_to_gpu_;
TypedPackedFunc<void(void* from, int64_t from_offset, void* to, int64_t nbytes)> copy_from_gpu_;
- TypedPackedFunc<void(void* from, int64_t from_offset,
- void* to, int64_t to_offset, int64_t nbytes)> copy_within_gpu_;
+ TypedPackedFunc<void(void* from, int64_t from_offset, void* to, int64_t to_offset,
+ int64_t nbytes)>
+ copy_within_gpu_;
};
-
typedef dmlc::ThreadLocalStore<WebGPUThreadEntry> WebGPUThreadStore;
WebGPUThreadEntry::WebGPUThreadEntry()
- : pool(static_cast<DLDeviceType>(kDLWebGPU), WebGPUDeviceAPI::Global()) {
-}
-
-WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() {
- return WebGPUThreadStore::Get();
-}
+ : pool(static_cast<DLDeviceType>(kDLWebGPU), WebGPUDeviceAPI::Global()) {}
+WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); }
class WebGPUModuleNode final : public runtime::ModuleNode {
public:
explicit WebGPUModuleNode(std::unordered_map<std::string, VulkanShader> smap,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source)
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: smap_(smap), fmap_(fmap), source_(source) {
auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader");
CHECK(fp != nullptr);
const char* type_key() const final { return "webgpu"; }
- PackedFunc GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) final {
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
auto it = smap_.find(name);
if (it != smap_.end()) {
FunctionInfo info = fmap_.at(name);
LOG(FATAL) << "Not implemented";
}
- void SaveToBinary(dmlc::Stream* stream) final {
- LOG(FATAL) << "Not implemented";
- }
+ void SaveToBinary(dmlc::Stream* stream) final { LOG(FATAL) << "Not implemented"; }
std::string GetSource(const std::string& format) final {
// can only return source code.
TypedPackedFunc<PackedFunc(std::string finfo, TVMByteArray shader_data)> create_shader_;
};
-
Module WebGPUModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::unordered_map<std::string, VulkanShader> smap;
}
// for now webgpu is hosted via a vulkan module.
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan")
-.set_body_typed(WebGPUModuleLoadBinary);
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary);
-TVM_REGISTER_GLOBAL("device_api.webgpu")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = WebGPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});