8dd8b0df337e9048fc327e6b5f55ae01afaac82b
[platform/upstream/caffeonacl.git] / src / programs / dump_network.cpp
1 // Copyright 2013 Yangqing Jia
2 //
3 // This program takes in a trained network and an input blob, and then dumps
4 // all the intermediate blobs produced by the net to individual binary
5 // files stored in protobuffer binary formats.
6 // Usage:
7 //    dump_network input_net_param trained_net_param input_blob output_prefix 0/1
8 // if input_net_param is 'none', we will directly load the network from
9 // trained_net_param. If the last argv is 1, we will do a forward-backward pass
10 // before dumping everyting, and also dump the who network.
11
12 #include <cuda_runtime.h>
13 #include <fcntl.h>
14 #include <google/protobuf/text_format.h>
15
16 #include <cstring>
17
18 #include "caffe/blob.hpp"
19 #include "caffe/common.hpp"
20 #include "caffe/net.hpp"
21 #include "caffe/filler.hpp"
22 #include "caffe/proto/caffe.pb.h"
23 #include "caffe/util/io.hpp"
24 #include "caffe/optimization/solver.hpp"
25
26 using namespace caffe;
27
28 int main(int argc, char** argv) {
29   cudaSetDevice(1);
30   Caffe::set_mode(Caffe::GPU);
31   Caffe::set_phase(Caffe::TEST);
32
33   NetParameter net_param;
34   NetParameter trained_net_param;
35
36   if (strcmp(argv[1], "none") == 0) {
37     // We directly load the net param from trained file
38     ReadProtoFromBinaryFile(argv[2], &net_param);
39   } else {
40     ReadProtoFromTextFile(argv[1], &net_param);
41   }
42   ReadProtoFromBinaryFile(argv[2], &trained_net_param);
43   
44   vector<Blob<float>* > input_vec;
45   if (strcmp(argv[3], "none") != 0) {
46     BlobProto input_blob_proto;
47     ReadProtoFromBinaryFile(argv[3], &input_blob_proto);
48     shared_ptr<Blob<float> > input_blob(new Blob<float>());
49     input_blob->FromProto(input_blob_proto);
50     input_vec.push_back(input_blob.get());
51   }
52
53   shared_ptr<Net<float> > caffe_net(new Net<float>(net_param, input_vec));
54   caffe_net->CopyTrainedLayersFrom(trained_net_param);
55
56   string output_prefix(argv[4]);
57   // Run the network without training.
58   LOG(ERROR) << "Performing Forward";
59   caffe_net->Forward(input_vec);
60   if (argc > 4 && strcmp(argv[4], "1")) {
61     LOG(ERROR) << "Performing Backward";
62     caffe_net->Backward();
63     // Dump the network
64     NetParameter output_net_param;
65     caffe_net->ToProto(&output_net_param, true);
66     WriteProtoToBinaryFile(output_net_param, output_prefix + output_net_param.name());
67   }
68   // Now, let's dump all the layers
69
70   const vector<string>& blob_names = caffe_net->blob_names();
71   const vector<shared_ptr<Blob<float> > >& blobs = caffe_net->blobs();
72   for (int blobid = 0; blobid < caffe_net->blobs().size(); ++blobid) {
73     // Serialize blob
74     LOG(ERROR) << "Dumping " << blob_names[blobid];
75     BlobProto output_blob_proto;
76     blobs[blobid]->ToProto(&output_blob_proto);
77     WriteProtoToBinaryFile(output_blob_proto, output_prefix + blob_names[blobid]);
78   }
79
80   return 0;
81 }