add device id arg to test_net (fix #232)
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 26 Apr 2014 22:22:42 +0000 (15:22 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 26 Apr 2014 22:22:42 +0000 (15:22 -0700)
tools/test_net.cpp

index 559fa73..69c774a 100644 (file)
 using namespace caffe;  // NOLINT(build/namespaces)
 
 int main(int argc, char** argv) {
-  if (argc < 4 || argc > 5) {
+  if (argc < 4 || argc > 6) {
     LOG(ERROR) << "test_net net_proto pretrained_net_proto iterations "
-        << "[CPU/GPU]";
+        << "[CPU/GPU] [Device ID]";
     return 1;
   }
 
-  cudaSetDevice(0);
   Caffe::set_phase(Caffe::TEST);
 
-  if (argc == 5 && strcmp(argv[4], "GPU") == 0) {
+  if (argc >= 5 && strcmp(argv[4], "GPU") == 0) {
     LOG(ERROR) << "Using GPU";
     Caffe::set_mode(Caffe::GPU);
+    if (argc == 6) {
+      int device_id = atoi(argv[5]);
+      Caffe::SetDevice(device_id);
+    }
   } else {
     LOG(ERROR) << "Using CPU";
     Caffe::set_mode(Caffe::CPU);