Python Multi-GPU
python/train.py (platform/upstream/caffeonacl.git)
#!/usr/bin/env python
"""
Trains a model using one or more GPUs.
"""
from multiprocessing import Process

import caffe


def train(
        solver,  # solver proto definition
        snapshot,  # solver snapshot to restore
        gpus,  # list of device ids
        timing=False,  # show timing info for compute and communications
):
    # NCCL uses a uid to identify a session
    uid = caffe.NCCL.new_uid()

    caffe.init_log()
    caffe.log('Using devices %s' % str(gpus))

    # Launch one solver process per GPU, then wait for all of them to finish.
    procs = []
    for rank in range(len(gpus)):
        p = Process(target=solve,
                    args=(solver, snapshot, gpus, timing, uid, rank))
        p.daemon = True
        p.start()
        procs.append(p)
    for p in procs:
        p.join()


def time(solver, nccl):
    """Registers timing callbacks: logs per-layer forward/backward times and
    the allreduce time every `display` iterations."""
    fprop = []
    bprop = []
    total = caffe.Timer()
    allrd = caffe.Timer()
    # One timer per layer for the forward pass and one for the backward pass.
    for _ in range(len(solver.net.layers)):
        fprop.append(caffe.Timer())
        bprop.append(caffe.Timer())
    display = solver.param.display

    def show_time():
        if solver.iter % display == 0:
            s = '\n'
            for i in range(len(solver.net.layers)):
                s += 'forw %3d %8s ' % (i, solver.net.layers[i].layer_param.name)
                s += ': %.2f\n' % fprop[i].ms
            for i in range(len(solver.net.layers) - 1, -1, -1):
                s += 'back %3d %8s ' % (i, solver.net.layers[i].layer_param.name)
                s += ': %.2f\n' % bprop[i].ms
            s += 'solver total: %.2f\n' % total.ms
            s += 'allreduce: %.2f\n' % allrd.ms
            caffe.log(s)

    # Start/stop the per-layer timers around each layer's forward and backward.
    solver.net.before_forward(lambda layer: fprop[layer].start())
    solver.net.after_forward(lambda layer: fprop[layer].stop())
    solver.net.before_backward(lambda layer: bprop[layer].start())
    solver.net.after_backward(lambda layer: bprop[layer].stop())
    # `total` runs from iteration start until gradients are ready; `allrd` then
    # covers the NCCL allreduce registered just below, after which the timings
    # are logged.
    solver.add_callback(lambda: total.start(), lambda: (total.stop(), allrd.start()))
    solver.add_callback(nccl)
    solver.add_callback(lambda: '', lambda: (allrd.stop(), show_time()))


def solve(proto, snapshot, gpus, timing, uid, rank):
    """Per-process worker: trains on a single GPU, synchronizing with the
    other ranks through NCCL."""
    caffe.set_mode_gpu()
    caffe.set_device(gpus[rank])
    caffe.set_solver_count(len(gpus))
    caffe.set_solver_rank(rank)

    solver = caffe.SGDSolver(proto)
    if snapshot and len(snapshot) != 0:
        solver.restore(snapshot)

    # Join the NCCL session and broadcast the initial weights from rank 0 so
    # every rank starts from the same state.
    nccl = caffe.NCCL(solver, uid)
    nccl.bcast()

    if timing and rank == 0:
        time(solver, nccl)
    else:
        solver.add_callback(nccl)

    # Overlap communication with backprop: reduce each layer's gradients as
    # soon as its backward pass completes.
    if solver.param.layer_wise_reduce:
        solver.net.after_backward(nccl)
    solver.step(solver.param.max_iter)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--solver", required=True, help="Solver proto definition.")
    parser.add_argument("--snapshot", help="Solver snapshot to restore.")
    parser.add_argument("--gpus", type=int, nargs='+', default=[0],
                        help="List of device ids.")
    parser.add_argument("--timing", action='store_true', help="Show timing info.")
    args = parser.parse_args()

    train(args.solver, args.snapshot, args.gpus, args.timing)
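Example usage, as a minimal sketch: the solver path below is a placeholder for a real solver prototxt, and the device ids assume two visible GPUs.

    # Equivalent to: python train.py --solver models/my_model/solver.prototxt --gpus 0 1 --timing
    from train import train

    train(solver='models/my_model/solver.prototxt',  # placeholder solver definition
          snapshot=None,   # or a .solverstate file to resume from
          gpus=[0, 1],     # one training process is spawned per device id
          timing=True)     # rank 0 logs per-layer and allreduce timings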