Tensor EmptyTensorFromProto(const TensorProto& tensor_proto) {
auto context = ContextFromProto(tensor_proto);
- context->SwitchToDevice(0);
+ context->SwitchToDevice();
if (NumelFromTensorProto(tensor_proto) == 0 &&
tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
// TODO: remove when serialization of dtype uninitialized tensor is removed
void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
auto tensor_proto = blob_proto.tensor();
auto context = ContextFromProto(tensor_proto);
- context->SwitchToDevice(0);
+ context->SwitchToDevice();
if (NumelFromTensorProto(tensor_proto) == 0 &&
tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
// TODO: remove after empty Tensor serialization is forbidden
auto uniq_ptr = ContextFromProto(tensor_proto);
// since CopyFromProtoAsIs accepts BaseContext*
auto context = uniq_ptr.get();
- context->SwitchToDevice(0);
+ context->SwitchToDevice();
int64_t chunkBegin = 0;
auto chunkEnd = tensor->numel();
CaffeCudaSetDevice(gpu_id_);
}
+ // void SwitchToDevice()
using BaseContext::SwitchToDevice;
inline void WaitEvent(const Event& ev) override {
if (!HasCudaGPU())
return;
CUDAContext context_outer(0); // gpu id
- context_outer.SwitchToDevice(0); // logical stream id
+ context_outer.SwitchToDevice();
if (NumCudaDevices() >= 2) {
auto before_stream = context_outer.cuda_stream();
context_different_device.SwitchToDevice(10);
// go back
- context_outer.SwitchToDevice(0); // logical stream id
+ context_outer.SwitchToDevice();
EXPECT_EQ(context_outer.cuda_stream(), before_stream);
// do nothing - infers the current device and stream
: OperatorBase(operator_def, ws), context_(operator_def.device_option()) {
// In the constructor, we switch to the device so that the child class
// constructors will run on that device.
- context_.SwitchToDevice(0);
+ context_.SwitchToDevice();
}
explicit Operator(
const c10::FunctionSchema& fn_schema,
: OperatorBase(fn_schema, inputs, outputs) {
// In the constructor, we switch to the device so that the child class
// constructors will run on that device.
- context_.SwitchToDevice(0);
+ context_.SwitchToDevice();
}
// Destructor with an empty body; declared noexcept and marked override.
~Operator() noexcept override {}
prefetch_success_(true),
finalize_(false),
no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
- context_.SwitchToDevice(0);
+ context_.SwitchToDevice();
}
virtual ~PrefetchOperator() noexcept {
bool Run(int /* unused */ /*stream_id*/) override {
if (no_prefetch_) {
- context_.SwitchToDevice(0);
+ context_.SwitchToDevice();
bool result = Prefetch() && CopyPrefetched();
context_.FinishDeviceComputation();
return result;
prefetch_thread_.reset(
new std::thread([this] { this->PrefetchWorker(); }));
}
- context_.SwitchToDevice(0);
+ context_.SwitchToDevice();
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
while (!prefetched_)
consumer_.wait(lock);