using namespace caffe; // NOLINT(build/namespaces)
int main(int argc, char** argv) {
- if (argc < 4 || argc > 5) {
+ if (argc < 4 || argc > 6) {
LOG(ERROR) << "test_net net_proto pretrained_net_proto iterations "
- << "[CPU/GPU]";
+ << "[CPU/GPU] [Device ID]";
return 1;
}
- cudaSetDevice(0);
Caffe::set_phase(Caffe::TEST);
- if (argc == 5 && strcmp(argv[4], "GPU") == 0) {
+ if (argc >= 5 && strcmp(argv[4], "GPU") == 0) {
LOG(ERROR) << "Using GPU";
Caffe::set_mode(Caffe::GPU);
+ if (argc == 6) {
+ int device_id = atoi(argv[5]);
+ Caffe::SetDevice(device_id);
+ }
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);