3 Trains a model using one or more GPUs.
5 from multiprocessing import Process
# train(solver, snapshot, gpus, timing=False): launch one `solve` worker
# process per device id in `gpus`, all sharing a single NCCL session uid.
# NOTE(review): this listing is fragmentary. Each line carries its original
# line number, indentation is stripped, and the `def train(` header line is
# not visible. Neither are the lines after the Process(...) call, so how the
# workers are started and joined cannot be confirmed here. TODO confirm
# against the full file.
11 solver, # solver proto definition
12 snapshot, # solver snapshot to restore
13 gpus, # list of device ids
14 timing=False, # show timing info for compute and communications
16 # NCCL uses a uid to identify a session
# The uid is created once, in the parent, and handed to every worker below.
# Each worker then builds caffe.NCCL(solver, uid), so all ranks join the
# same session.
17 uid = caffe.NCCL.new_uid()
20 caffe.log('Using devices %s' % str(gpus))
# One worker per device. `rank` indexes into `gpus` and is forwarded to
# solve() together with the shared uid and the timing flag.
23 for rank in range(len(gpus)):
24 p = Process(target=solve,
25 args=(solver, snapshot, gpus, timing, uid, rank))
# NOTE(review): nothing after this Process(...) construction is visible.
# Presumably each `p` is started here and joined after the loop. TODO confirm.
# time(solver, nccl): instrument `solver` with per-layer forward/backward
# timers plus `total` and `allrd` timers. It periodically builds a text
# report of those timings and registers `nccl` as a solver callback.
# NOTE(review): fragmentary listing. The following lines are not visible
# here; TODO confirm against the full file:
#   - the creation of the fprop/bprop lists and of the total/allrd timers;
#   - the header of the nested function called below as show_time();
#   - the initialisation of `s`;
#   - the statement that emits the finished report.
33 def time(solver, nccl):
# One forward timer and one backward timer per layer, indexed by layer
# position in solver.net.layers.
38 for _ in range(len(solver.net.layers)):
39 fprop.append(caffe.Timer())
40 bprop.append(caffe.Timer())
41 display = solver.param.display
# Build the report only on iterations that are a multiple of `display`.
# NOTE(review): if `display` can be 0 (e.g. left unset in the solver proto;
# TODO confirm the default), this modulo raises ZeroDivisionError.
44 if solver.iter % display == 0:
# Forward rows in layer order: index, layer name, then the timer's `.ms`
# reading (presumably milliseconds).
46 for i in range(len(solver.net.layers)):
47 s += 'forw %3d %8s ' % (i, solver.net.layers[i].layer_param.name)
48 s += ': %.2f\n' % fprop[i].ms
# Backward rows are listed from the last layer down to layer 0.
49 for i in range(len(solver.net.layers) - 1, -1, -1):
50 s += 'back %3d %8s ' % (i, solver.net.layers[i].layer_param.name)
51 s += ': %.2f\n' % bprop[i].ms
52 s += 'solver total: %.2f\n' % total.ms
53 s += 'allreduce: %.2f\n' % allrd.ms
# Per-layer hooks. The layer index the net passes in selects that layer's
# timer.
56 solver.net.before_forward(lambda layer: fprop[layer].start())
57 solver.net.after_forward(lambda layer: fprop[layer].stop())
58 solver.net.before_backward(lambda layer: bprop[layer].start())
59 solver.net.after_backward(lambda layer: bprop[layer].stop())
# Three callbacks are registered in this order:
#   1. its first hook starts `total`; its second hook stops `total` and
#      starts `allrd`;
#   2. `nccl`;
#   3. its first hook is a no-op (returns ''); its second hook stops `allrd`
#      and builds the report.
# Assuming the second hooks fire in registration order, `allrd` brackets the
# work done by the nccl callback. NOTE(review): confirm the hook semantics of
# pycaffe Solver.add_callback.
60 solver.add_callback(lambda: total.start(), lambda: (total.stop(), allrd.start()))
61 solver.add_callback(nccl)
62 solver.add_callback(lambda: '', lambda: (allrd.stop(), show_time()))
# solve(proto, snapshot, gpus, timing, uid, rank): body of one worker
# process. It:
#   - selects device gpus[rank] and registers this process as solver `rank`
#     of len(gpus);
#   - builds an SGDSolver from `proto`, optionally restoring `snapshot`;
#   - attaches a caffe.NCCL object for session `uid`;
#   - runs solver.param.max_iter iterations.
# NOTE(review): fragmentary listing. Indentation is stripped and some lines
# of the original body are not visible here.
65 def solve(proto, snapshot, gpus, timing, uid, rank):
67 caffe.set_device(gpus[rank])
# Tell Caffe how many solvers participate and which one this process is.
68 caffe.set_solver_count(len(gpus))
69 caffe.set_solver_rank(rank)
71 solver = caffe.SGDSolver(proto)
# Restore only when a non-empty snapshot path was given. The len() check is
# redundant, since an empty string is already falsy.
72 if snapshot and len(snapshot) != 0:
73 solver.restore(snapshot)
75 nccl = caffe.NCCL(solver, uid)
# NOTE(review): the body of this `if`, and any `else`, are not visible in
# this listing. time(solver, nccl) itself calls solver.add_callback(nccl).
# The add_callback line below is therefore presumably the `else` branch, so
# the timing rank does not register nccl twice. Presumably the missing body
# is time(solver, nccl). TODO confirm against the full file.
78 if timing and rank == 0:
81 solver.add_callback(nccl)
# When layer_wise_reduce is set, nccl is also passed to after_backward, so it
# runs after each layer's backward pass.
83 if solver.param.layer_wise_reduce:
84 solver.net.after_backward(nccl)
85 solver.step(solver.param.max_iter)
# Script entry point: parse the command-line options and call train().
# NOTE(review): `argparse` is used below, but its import is not visible in
# this listing. Confirm it is imported before ArgumentParser is used.
88 if __name__ == '__main__':
90 parser = argparse.ArgumentParser()
# --solver is mandatory and --snapshot is optional. --gpus accepts one or
# more integer device ids (default [0]). --timing is a boolean flag.
92 parser.add_argument("--solver", required=True, help="Solver proto definition.")
93 parser.add_argument("--snapshot", help="Solver snapshot to restore.")
94 parser.add_argument("--gpus", type=int, nargs='+', default=[0],
95 help="List of device ids.")
96 parser.add_argument("--timing", action='store_true', help="Show timing info.")
97 args = parser.parse_args()
99 train(args.solver, args.snapshot, args.gpus, args.timing)