From b2eb98f6c371cca2689908d66d9771a26c69f90a Mon Sep 17 00:00:00 2001
From: Zachary DeVito
Date: Tue, 15 Jan 2019 16:25:28 -0800
Subject: [PATCH] Remove cuda from autograd profiler (#15898)

Summary:
This puts stubs in the autograd profiler for the use of cuda APIs allowing the cuda parts of libtorch to be linked separately from the CPU parts.

This also edits the buck build.

Previous:
 For GPU builds: _C -> csrc -> caffe2
 For CPU builds: _C -> csrc-cpu -> caffe2

Now:
 GPU: _C -> libtorch_cuda -> (libtorch -> caffe2, for CPU)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15898

Reviewed By: ailzhang

Differential Revision: D13617991

Pulled By: zdevito

fbshipit-source-id: 6d84a50bb356a54b4217f93219902755601b00e1
---
 tools/build_variables.py                           | 138 +++++++++++++++------
 torch/CMakeLists.txt                               |  10 +-
 .../api/include/torch/nn/parallel/data_parallel.h  |   2 +
 torch/csrc/autograd/functions/comm.cpp             |  12 --
 torch/csrc/autograd/functions/comm.h               |   4 -
 torch/csrc/autograd/profiler.cpp                   |  83 ++++++-------
 torch/csrc/autograd/profiler.h                     |  88 +++++++------
 torch/csrc/autograd/profiler_cuda.cpp              |  65 ++++++++++
 torch/csrc/cuda/comm.cpp                           |   4 -
 torch/csrc/cuda/comm.h                             |   4 -
 torch/csrc/cuda/cuda_check.h                       |   3 -
 11 files changed, 253 insertions(+), 160 deletions(-)
 create mode 100644 torch/csrc/autograd/profiler_cuda.cpp
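Note on the mechanism: the CPU-only build of the profiler now sees only the abstract CUDAStubs interface, whose default methods either fail with an error or report CUDA as disabled, and the separately linked CUDA library installs a concrete implementation via registerCUDAMethods() from a static initializer in the new profiler_cuda.cpp. Below is a minimal, self-contained sketch of that pattern, assuming nothing beyond the standard library; the names ProfilerBackend, CudaBackend and register_backend are illustrative stand-ins, not the actual torch API (the real code is in the diff that follows).

    // Illustrative sketch only: ProfilerBackend / CudaBackend / register_backend are
    // hypothetical stand-ins for the CUDAStubs / CUDAMethods / registerCUDAMethods
    // machinery added by this patch.
    #include <cstdio>
    #include <stdexcept>

    // Interface compiled into the CPU-only library. Every default method is a safe
    // fallback, so this translation unit references no CUDA symbols at all.
    struct ProfilerBackend {
      virtual void mark(const char* /*name*/) {
        throw std::runtime_error("CUDA used in profiler but not enabled.");
      }
      virtual bool enabled() { return false; }
      virtual ~ProfilerBackend() = default;
    };

    // Constant-initialized pointer to the fallback, so it is valid even before any
    // dynamic static initializer (including the registration below) has run.
    static ProfilerBackend default_backend;
    static ProfilerBackend* backend = &default_backend;

    void register_backend(ProfilerBackend* b) { backend = b; }

    // In the real build this part lives in the separately linked CUDA library; it
    // swaps in the real implementation from a static initializer at load time.
    struct CudaBackend : ProfilerBackend {
      void mark(const char* name) override {
        std::printf("nvtx mark: %s\n", name);  // stand-in for an nvtx call
      }
      bool enabled() override { return true; }
    };

    namespace {
    struct Register {
      Register() {
        static CudaBackend methods;
        register_backend(&methods);
      }
    } register_cuda_backend;  // registration happens as a side effect of construction
    } // namespace

    int main() {
      // CPU-side callers only consult the registered backend and never touch
      // CUDA headers or libraries directly.
      if (backend->enabled()) {
        backend->mark("example_range");
      }
      return 0;
    }

The important property is that the fallback pointer is constant-initialized: a registerCUDAMethods() call made during static initialization of the CUDA library can therefore never observe an uninitialized pointer, which is what the "constant initialization" comment added to profiler.cpp relies on.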
diff --git a/tools/build_variables.py b/tools/build_variables.py
index b6511c3..a24e795 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -2,6 +2,7 @@
 # torch/csrc/{autgrad,jit}/generated. In fbcode, this distinction is
 # not currently relevant so they are combined into one list.
 from __future__ import absolute_import, division, print_function, unicode_literals
+load("@bazel_skylib//lib:new_sets.bzl", "sets")
 
 
 GENERATED_CPP = [
@@ -22,8 +23,8 @@ GENERATED_CPP = [
     "python_variable_methods.cpp",
 ]
 
-# copied from https://github.com/pytorch/pytorch/blob/master/tools/cpp_build/libtorch/CMakeLists.txt
-torch_sources_no_python_default = [
+# copied from https://github.com/pytorch/pytorch/blob/master/tools/cpp_build/torch/CMakeLists.txt
+libtorch_sources = [
     ":generate-code=Functions.cpp",
     ":generate-code=register_aten_ops_0.cpp",
     ":generate-code=register_aten_ops_1.cpp",
@@ -39,7 +40,6 @@ torch_sources_no_python_default = [
     "torch/csrc/autograd/function.cpp",
     "torch/csrc/autograd/functions/accumulate_grad.cpp",
     "torch/csrc/autograd/functions/basic_ops.cpp",
-    "torch/csrc/autograd/functions/comm.cpp",
     "torch/csrc/autograd/functions/tensor.cpp",
     "torch/csrc/autograd/functions/utils.cpp",
     "torch/csrc/autograd/grad_mode.cpp",
@@ -110,39 +110,39 @@ torch_sources_no_python_default = [
     "torch/csrc/jit/fuser/interface.cpp",
 ]
 
+libtorch_cuda_sources = [
+    "torch/csrc/cuda/comm.cpp",
+    "torch/csrc/cuda/nccl.cpp",
+    "torch/csrc/jit/fuser/cuda/fused_kernel.cpp",
+    "torch/csrc/autograd/profiler_cuda.cpp",
+    "torch/csrc/autograd/functions/comm.cpp"
+]
 
-def torch_vars():
-    r = {}
-    # We start torch_sources with all cpp files, and exclude some.
-    # This is a much better approach than listing all of them manually because
-    # the number of excluded files is small and doesn"t change very frequently
-    r["torch_sources"] = (
-        native.glob(
-            ["torch/csrc/**/*.cpp"],
-            exclude=[
-                # remove anything that has "generic" in it"s path
-                "torch/csrc/**/generic/**/*.cpp",
-                # distributed only uses Module.cpp
-                # so remove all other files and just include that
-                "torch/csrc/distributed/**/*.cpp",
-            ],
-        )
-        + [
-            "torch/csrc/distributed/Module.cpp",
-            "torch/csrc/distributed/c10d/init.cpp",
-            "torch/csrc/distributed/c10d/ddp.cpp",
-        ]
-        + [":generate-code=" + x for x in GENERATED_CPP]
-    )
-
-    r["torch_sources_no_python"] = (
-        torch_sources_no_python_default
-        + ["torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/jit/fuser/cuda/fused_kernel.cpp"]
-    )
-    r["torch_sources_no_python_cpu"] = torch_sources_no_python_default
+def add_torch_libs():
+    r = {}
+    # We start torch_python_sources with all cpp files, and exclude some
+    # including the files already contained in the torch and cuda bindings
+    globbed_sources = (native.glob(
+        ["torch/csrc/**/*.cpp"],
+        exclude=[
+            # remove anything that has "generic" in it"s path
+            "torch/csrc/**/generic/**/*.cpp",
+            # distributed only uses Module.cpp
+            # so remove all other files and just include that
+            "torch/csrc/distributed/**/*.cpp",
+        ],
+    ) + [
+        "torch/csrc/distributed/Module.cpp",
+        "torch/csrc/distributed/c10d/init.cpp",
+        "torch/csrc/distributed/c10d/ddp.cpp",
+    ] + [":generate-code=" + x for x in GENERATED_CPP])
+    libtorch_python_sources = sets.to_list(sets.difference(
+        sets.make(globbed_sources),
+        sets.make(libtorch_sources + libtorch_cuda_sources),
+    ))
 
-    r["torch_csrc_flags"] = {
+    common_flags = {
         "compiler_flags": [
             "-D_THP_CORE",
             "-DUSE_C10D",
@@ -180,13 +180,69 @@ def torch_vars():
         ],
     }
 
-    r["torch_csrc_flags_cpu"] = dict(r["torch_csrc_flags"])
+    cpp_library(
+        name="libtorch",
+        srcs=libtorch_sources,
+        link_whole=True,
+        deps=[
+            ":generated-autograd-headers",
+            ":generated-autograd-headers-bare",
+            ":generated-jit-headers",
+            "//caffe2/aten:ATen-cpu",
+            "//caffe2/caffe2:caffe2_cpu",
+            "//caffe2/torch/lib/libshm:libshm",
+        ],
+        external_deps=[
+            ("nanopb", None, "protobuf-nanopb"),
+            ("protobuf", None),
+        ],
+        **common_flags
+    )
+
+    cpp_library(
+        name="libtorch_cuda",
+        srcs=libtorch_cuda_sources,
+        link_whole=True,
+        propagated_pp_flags=[
+            "-DUSE_CUDA",
+        ],
+        deps=[
+            ":generated-autograd-headers",
+            ":generated-autograd-headers-bare",
+            ":generated-jit-headers",
+            ":libtorch",
+            "//caffe2/aten:ATen",
+            "//caffe2/aten:generated-aten-headers-cuda",
+            "//caffe2/caffe2:caffe2_cpu",
+            "//caffe2/torch/lib/libshm:libshm",
+        ],
+        external_deps=[
+            ("cudnn", "7.1.2", "cudnn-lazy"),
+            ("nccl", "2.1.15", "nccl-lazy"),
+            ("cuda", None, "nvToolsExt-lazy"),
+            ("cuda", None, "nvrtc-lazy"),
+            ("cuda", None, "nvrtc-builtins-lazy"),
+        ],
+        **common_flags
+    )
+
+    cpp_python_extension(
+        name="_C",
+        srcs=libtorch_python_sources,
+        base_module="torch",
+        deps=[
+            ":libtorch_cuda",
+            ":thnn",
+            ":torch-lib-headers",
+            "//caffe2/torch/lib/THD:THD",
+            "//caffe2/torch/lib/c10d:c10d",
+            "//caffe2/torch/lib/libshm:libshm",
+        ],
+        external_deps=[
+            ("numpy", None, "cpp"),
+            ("pybind11", None),
+        ],
+        **common_flags
+    )
 
-    r["torch_csrc_flags_cpu"]["preprocessor_flags"] = [
-        "-Icaffe2",
-        "-Icaffe2/torch/csrc/api/include",
-        "-Icaffe2/torch/csrc",
-        "-Icaffe2/torch/csrc/nn",
-        "-Icaffe2/torch/lib",
-    ]
     return r
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 72f3431..02d6abc 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -129,7 +129,6 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/autograd/function.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/functions/accumulate_grad.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/functions/basic_ops.cpp
-  ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/functions/tensor.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/functions/utils.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp
@@ -144,7 +143,6 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/autograd/saved_variable.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/variable.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp
-  ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
   ${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
   ${TORCH_SRC_DIR}/csrc/jit/export.cpp
   ${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp
@@ -227,6 +225,14 @@ else ()
   endif()
 endif ()
 
+if (USE_CUDA)
+  list(APPEND TORCH_SRCS
+    ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
+    ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
+    ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
+  )
+endif()
+
 if (NOT NO_API)
   list(APPEND TORCH_SRCS
     ${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
diff --git a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h
index ea23398..ebdb174 100644
--- a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h
+++ b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h
@@ -6,7 +6,9 @@
 #include
 #include
 
+#ifdef USE_CUDA
 #include
+#endif
 
 #include
 #include
diff --git a/torch/csrc/autograd/functions/comm.cpp b/torch/csrc/autograd/functions/comm.cpp
index e604859..3ae4faf 100644
--- a/torch/csrc/autograd/functions/comm.cpp
+++ b/torch/csrc/autograd/functions/comm.cpp
@@ -1,5 +1,3 @@
-#ifdef USE_CUDA
-
 #include
 
 #include
@@ -31,7 +29,6 @@ Scatter::Scatter(
       unsqueeze_scalars_(unsqueeze_scalars) {}
 
 variable_list Scatter::apply(variable_list&& inputs) {
-#ifdef USE_CUDA
   AT_ASSERT(inputs.size() == 1);
   auto& input = inputs.front();
 
@@ -63,16 +60,12 @@ variable_list Scatter::apply(variable_list&& inputs) {
   set_history(variables, grad_fn);
 
   return variables;
-#else
-  AT_ERROR("Scatter is only supported in CUDA environments");
-#endif
 }
 
 Gather::Gather(const at::Device& destination_device, int64_t dim)
     : destination_device_(destination_device), dim_(dim) {}
 
 variable_list Gather::apply(variable_list&& inputs) {
-#ifdef USE_CUDA
   bool all_are_zero_dim = true;
   for (const auto& input : inputs) {
     AT_CHECK(
@@ -125,12 +118,7 @@ variable_list Gather::apply(variable_list&& inputs) {
   auto variable = torch::cuda::gather(tensors, dim_, destination_index);
   set_history(variable, grad_fn);
   return {variable};
-#else
-  AT_ERROR("Gather is only supported in CUDA environments");
-#endif
 }
 
 } // namespace autograd
 } // namespace torch
-
-#endif
diff --git a/torch/csrc/autograd/functions/comm.h b/torch/csrc/autograd/functions/comm.h
index e806546..552c085 100644
--- a/torch/csrc/autograd/functions/comm.h
+++ b/torch/csrc/autograd/functions/comm.h
@@ -1,7 +1,5 @@
 #pragma once
 
-#ifdef USE_CUDA
-
 #include
 #include
 #include
@@ -44,5 +42,3 @@ struct TORCH_API Gather : public Function {
 
 } // namespace autograd
 } // namespace torch
-
-#endif
diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp
index 25b2780..9a6550c 100644
--- a/torch/csrc/autograd/profiler.cpp
+++ b/torch/csrc/autograd/profiler.cpp
@@ -1,14 +1,20 @@
 #include
 #include
 
-#ifdef USE_CUDA
-#include
-#endif
-
 #include
 
 namespace torch { namespace autograd { namespace profiler {
 
+CUDAStubs default_stubs;
+constexpr CUDAStubs* default_stubs_addr = &default_stubs;
+// constant initialization, so it is guarenteed to be initialized before
+// static initialization calls which may invoke registerCUDAMethods
+static CUDAStubs* cuda_stubs = default_stubs_addr;
+
+TORCH_API void registerCUDAMethods(CUDAStubs* stubs) {
+  cuda_stubs = stubs;
+}
+
 ProfilerState state = ProfilerState::Disabled;
 uint16_t next_thread_id = 0;
 std::mutex all_event_lists_mutex;
@@ -31,12 +37,7 @@ void mark(std::string name, bool include_cuda /* = true */) {
     return;
   }
   if (state == ProfilerState::NVTX) {
-#ifdef USE_CUDA
-    nvtxMarkA(name.c_str());
-#else
-    throw std::logic_error(
-        "mark called with NVTX tracing, but compiled without CUDA");
-#endif
+    cuda_stubs->nvtxMarkA(name.c_str());
   } else {
     getEventList().record(
         EventKind::Mark,
@@ -56,18 +57,13 @@ void pushRangeImpl(T name, const char* msg="", int64_t sequence_nr=-1) {
     return;
   }
   if (state == ProfilerState::NVTX) {
-#ifdef USE_CUDA
     if(sequence_nr >= 0) {
       std::stringstream s;
       s << name << msg << sequence_nr;
-      nvtxRangePushA(s.str().c_str());
+      cuda_stubs->nvtxRangePushA(s.str().c_str());
     } else {
-      nvtxRangePushA(c_str(name));
+      cuda_stubs->nvtxRangePushA(c_str(name));
     }
-#else
-    throw std::logic_error(
-        "pushRange called with NVTX tracing, but compiled without CUDA");
-#endif
   } else {
     getEventList().record(
         EventKind::PushRange,
@@ -86,12 +82,7 @@ void popRange() {
     return;
   }
   if (state == ProfilerState::NVTX) {
-#ifdef USE_CUDA
-    nvtxRangePop();
-#else
-    throw std::logic_error(
-        "popRange called with NVTX tracing, but compiled without CUDA");
-#endif
+    cuda_stubs->nvtxRangePop();
   } else {
     getEventList().record(
         EventKind::PopRange,
@@ -103,7 +94,7 @@ void popRange() {
 
 RecordFunction::RecordFunction(Function* fn) {
   // typeid(*fn).name() would avoid an additional string allocation.
-  // However, typeid(*fn).name() would cause nvtx annotations for all user-defined 
+  // However, typeid(*fn).name() would cause nvtx annotations for all user-defined
   // (Python-side) custom autograd function backward() methods to have the same name,
   // because they route through the same C++ side class.
   // fn->name() ensures that nvtx annotations for custom function backward() methods
@@ -124,48 +115,32 @@ RecordFunction::RecordFunction(const char* name, int64_t current_sequence_nr)
   pushRangeImpl(name, ", seq=", current_sequence_nr);
 }
 
-#ifdef USE_CUDA
-static void onEachDevice(std::function<void(int)> op) {
-  at::cuda::OptionalCUDAGuard device_guard;
-  int count;
-  TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
-  for(int i = 0; i < count; i++) {
-    device_guard.set_index(i);
-    op(i);
-  }
-}
-#endif
-
 void enableProfiler(ProfilerState new_state) {
   AT_ASSERT(new_state != ProfilerState::Disabled);
-#ifndef USE_CUDA
-  if (new_state == ProfilerState::NVTX)
+  if (new_state == ProfilerState::NVTX && !cuda_stubs->enabled())
     throw std::runtime_error("Can't use NVTX profiler - PyTorch was compiled without CUDA");
-#endif
   if (state != ProfilerState::Disabled && new_state != state) {
     throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running");
   }
   state = new_state;
-#ifdef USE_CUDA
   if(state == ProfilerState::CUDA) {
     // event recording appears to have some startup overhead, so we need to
     // to generate some dummy events first before recording syncrhonization events
     for(int i = 0; i < 5; i++) {
-      onEachDevice([](int d) {
+      cuda_stubs->onEachDevice([](int d) {
         mark("__cuda_startup");
-        cudaDeviceSynchronize();
+        cuda_stubs->synchronize();
       });
     }
 
     // cuda events must be on the same device, so we need a start event recorded
     // for each gpu. we then use this event to synchronize time on the GPU
     // with the CPU clock.
-    onEachDevice([](int d) {
+    cuda_stubs->onEachDevice([](int d) {
       mark("__cuda_start_event");
     });
   }
-#endif
   mark("__start_profile", false);
 }
 
@@ -197,4 +172,24 @@ thread_event_lists disableProfiler() {
   }
 }
 
+void Event::record(bool record_cuda) {
+  if (record_cuda) {
+    cuda_stubs->record(&device_, &event, &cpu_ns_);
+    return;
+  }
+  cpu_ns_ = getTime();
+}
+
+double Event::cuda_elapsed_us(const Event & e) {
+  if(!e.has_cuda() || !has_cuda()) {
+    throw std::logic_error("Events were not recorded for CUDA");
+  }
+  if(e.device() != device()) {
+    throw std::logic_error("Events are not on the same device");
+  }
+  return cuda_stubs->elapsed(event, e.event);
+}
+
+CUDAStubs::~CUDAStubs() = default;
+
 }}}
diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h
index 0750cdf..e262f08 100644
--- a/torch/csrc/autograd/profiler.h
+++ b/torch/csrc/autograd/profiler.h
@@ -1,8 +1,5 @@
 #pragma once
 
-#ifdef USE_CUDA
-#include
-#endif
 #include
 #include
 #include
@@ -16,21 +13,54 @@
 #include
 #include
 #include
-#include
-#ifdef USE_CUDA
-#include
-#include
-#endif
 
 #ifndef _WIN32
 #include
 #endif
 
+typedef struct CUevent_st* CUDAEventStub;
+
 namespace torch { namespace autograd {
 
 struct Function;
 
 namespace profiler {
 
+struct TORCH_API CUDAStubs {
+  virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) {
+    fail();
+  }
+  virtual float elapsed(CUDAEventStub event, CUDAEventStub event2) {
+    fail();
+    return 0.f;
+  }
+  virtual void nvtxMarkA(const char* name) {
+    fail();
+  }
+  virtual void nvtxRangePushA(const char* name) {
+    fail();
+  }
+  virtual void nvtxRangePop() {
+    fail();
+  }
+  virtual bool enabled() {
+    return false;
+  }
+  virtual void onEachDevice(std::function<void(int)> op) {
+    fail();
+  }
+  virtual void synchronize() {
+    fail();
+  }
+  virtual ~CUDAStubs();
+
+private:
+  void fail() {
+    AT_ERROR("CUDA used in profiler but not enabled.");
+  }
+};
+
+TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
+
 constexpr inline size_t ceilToMultiple(size_t a, size_t b) {
   return ((a + b - 1) / b) * b;
 }
@@ -64,7 +94,7 @@ enum class EventKind : uint16_t {
   PopRange
 };
 
-struct Event final {
+struct TORCH_API Event final {
   Event(EventKind kind, std::string name, uint16_t thread_id, bool record_cuda)
   : owned_name_(new std::string(std::move(name)))
   , name_ptr_(owned_name_->c_str())
@@ -75,19 +105,7 @@ struct Event final {
   , kind_(kind)
   , thread_id_(thread_id) { record(record_cuda); }
 
-  void record(bool record_cuda) {
-#ifdef USE_CUDA
-    if (record_cuda) {
-      TORCH_CUDA_CHECK(cudaGetDevice(&device_));
-      TORCH_CUDA_CHECK(cudaEventCreate(&event));
-      auto stream = at::cuda::getCurrentCUDAStream();
-      cpu_ns_ = getTime();
-      TORCH_CUDA_CHECK(cudaEventRecord(event, stream));
-      return;
-    }
-#endif
-    cpu_ns_ = getTime();
-  }
+  void record(bool record_cuda);
   std::string kind() const {
     switch(kind_) {
       case EventKind::Mark: return "mark";
@@ -105,29 +123,9 @@ struct Event final {
   double cpu_elapsed_us(const Event & e) {
     return (e.cpu_ns_ - cpu_ns_)/(1000.0);
   }
-  double cuda_elapsed_us(const Event & e) {
-#ifdef USE_CUDA
-    if(!e.has_cuda() || !has_cuda()) {
-      throw std::logic_error("Events were not recorded for CUDA");
-    }
-    if(e.device() != device()) {
-      throw std::logic_error("Events are not on the same device");
-    }
-    TORCH_CUDA_CHECK(cudaEventSynchronize(event));
-    TORCH_CUDA_CHECK(cudaEventSynchronize(e.event));
-    float ms;
-    TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event, e.event));
-    return ms*1000.0;
-#else
-    throw std::logic_error("CUDA not enabled");
-#endif
-  }
+  double cuda_elapsed_us(const Event & e);
   bool has_cuda() const {
-#ifdef USE_CUDA
     return event != nullptr;
-#else
-    return false;
-#endif
   }
   int device() const {
     return device_;
@@ -142,9 +140,7 @@ private:
   EventKind kind_;
   uint16_t thread_id_;
   int device_ = -1;
-#ifdef USE_CUDA
-  cudaEvent_t event = nullptr;
-#endif
+  struct CUevent_st* event = nullptr;
 };
 
 // a linked-list of fixed sized vectors, to avoid
diff --git a/torch/csrc/autograd/profiler_cuda.cpp b/torch/csrc/autograd/profiler_cuda.cpp
new file mode 100644
index 0000000..cbea26d
--- /dev/null
+++ b/torch/csrc/autograd/profiler_cuda.cpp
@@ -0,0 +1,65 @@
+#include
+#include
+#include
+#include
+
+#include
+
+namespace torch { namespace autograd { namespace profiler {
+
+namespace {
+
+struct CUDAMethods : public CUDAStubs {
+  void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) override {
+    TORCH_CUDA_CHECK(cudaGetDevice(device));
+    TORCH_CUDA_CHECK(cudaEventCreate(event));
+    auto stream = at::cuda::getCurrentCUDAStream();
+    *cpu_ns = getTime();
+    TORCH_CUDA_CHECK(cudaEventRecord(*event, stream));
+  }
+  float elapsed(CUDAEventStub event, CUDAEventStub event2) override {
+    TORCH_CUDA_CHECK(cudaEventSynchronize(event));
+    TORCH_CUDA_CHECK(cudaEventSynchronize(event2));
+    float ms;
+    TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event, event2));
+    return ms*1000.0;
+  }
+  void nvtxMarkA(const char* name) override {
+    ::nvtxMark(name);
+  }
+  void nvtxRangePushA(const char* name) override {
+    ::nvtxRangePushA(name);
+  }
+  void nvtxRangePop() override {
+    ::nvtxRangePop();
+  }
+  void onEachDevice(std::function<void(int)> op) override {
+    at::cuda::OptionalCUDAGuard device_guard;
+    int count;
+    TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
+    for(int i = 0; i < count; i++) {
+      device_guard.set_index(i);
+      op(i);
+    }
+  }
+  void synchronize() override {
+    cudaDeviceSynchronize();
+  }
+  bool enabled() override {
+    return true;
+  }
+
+};
+
+struct RegisterCUDAMethods {
+  RegisterCUDAMethods() {
+    static CUDAMethods methods;
+    registerCUDAMethods(&methods);
+  }
+};
+RegisterCUDAMethods reg;
+
+} // namespaces
+} // namespace profiler
+} // namespace autograd
+} // namespace torch
diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp
index c6bc582..d4045b8 100644
--- a/torch/csrc/cuda/comm.cpp
+++ b/torch/csrc/cuda/comm.cpp
@@ -1,7 +1,5 @@
 #include
 
-#ifdef USE_CUDA
-
 #include
 
 #include
@@ -244,5 +242,3 @@ at::Tensor gather(
   return result;
 }
 }} // namespace torch::cuda
-
-#endif
diff --git a/torch/csrc/cuda/comm.h b/torch/csrc/cuda/comm.h
index 960eddf..8d7671d 100644
--- a/torch/csrc/cuda/comm.h
+++ b/torch/csrc/cuda/comm.h
@@ -1,7 +1,5 @@
 #pragma once
 
-#ifdef USE_CUDA
-
 #include
 #include
 #include
@@ -30,5 +28,3 @@ at::Tensor gather(
     int64_t dim,
     c10::optional destination_index);
 }}
-
-#endif
diff --git a/torch/csrc/cuda/cuda_check.h b/torch/csrc/cuda/cuda_check.h
index 327e90d..b000b1a 100644
--- a/torch/csrc/cuda/cuda_check.h
+++ b/torch/csrc/cuda/cuda_check.h
@@ -1,6 +1,5 @@
 #pragma once
 
-#ifdef USE_CUDA
 #include
 #include
 #include
@@ -37,5 +36,3 @@ static inline void cudaCheck(cudaError_t result, const char * file, int line) {
 #define TORCH_CUDA_CHECK(result) ::torch::cudaCheck(result,__FILE__,__LINE__);
 
 }
-
-#endif
-- 
2.7.4