tools: hv: use getmntent in hv_vss_daemon
[platform/adaptation/renesas_rcar/renesas_kernel.git] / tools / hv / hv_vss_daemon.c
1 /*
2  * An implementation of the host initiated guest snapshot for Hyper-V.
3  *
4  *
5  * Copyright (C) 2013, Microsoft, Inc.
6  * Author : K. Y. Srinivasan <kys@microsoft.com>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License version 2 as published
10  * by the Free Software Foundation.
11  *
12  * This program is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, GOOD TITLE or
15  * NON INFRINGEMENT.  See the GNU General Public License for more
16  * details.
17  *
18  */
19
20
21 #include <sys/types.h>
22 #include <sys/socket.h>
23 #include <sys/poll.h>
24 #include <linux/types.h>
25 #include <stdio.h>
26 #include <mntent.h>
27 #include <stdlib.h>
28 #include <unistd.h>
29 #include <string.h>
30 #include <ctype.h>
31 #include <errno.h>
32 #include <arpa/inet.h>
33 #include <linux/connector.h>
34 #include <linux/hyperv.h>
35 #include <linux/netlink.h>
36 #include <syslog.h>
37
38 static char vss_recv_buffer[4096];
39 static char vss_send_buffer[4096];
40 static struct sockaddr_nl addr;
41
42 #ifndef SOL_NETLINK
43 #define SOL_NETLINK 270
44 #endif
45
46
47 static int vss_operate(int operation)
48 {
49         char *fs_op;
50         char cmd[512];
51         char match[] = "/dev/";
52         FILE *mounts;
53         struct mntent *ent;
54         int error = 0, root_seen = 0;
55
56         switch (operation) {
57         case VSS_OP_FREEZE:
58                 fs_op = "-f ";
59                 break;
60         case VSS_OP_THAW:
61                 fs_op = "-u ";
62                 break;
63         default:
64                 return -1;
65         }
66
67         mounts = setmntent("/proc/mounts", "r");
68         if (mounts == NULL)
69                 return -1;
70
71         while((ent = getmntent(mounts))) {
72                 if (strncmp(ent->mnt_fsname, match, strlen(match)))
73                         continue;
74                 if (strcmp(ent->mnt_dir, "/") == 0) {
75                         root_seen = 1;
76                         continue;
77                 }
78                 snprintf(cmd, sizeof(cmd), "fsfreeze %s '%s'", fs_op, ent->mnt_dir);
79                 syslog(LOG_INFO, "VSS cmd is %s\n", cmd);
80                 error |= system(cmd);
81         }
82         endmntent(mounts);
83
84         if (root_seen) {
85                 sprintf(cmd, "fsfreeze %s /", fs_op);
86                 syslog(LOG_INFO, "VSS cmd is %s\n", cmd);
87                 error |= system(cmd);
88         }
89
90         return error;
91 }
92
93 static int netlink_send(int fd, struct cn_msg *msg)
94 {
95         struct nlmsghdr *nlh;
96         unsigned int size;
97         struct msghdr message;
98         char buffer[64];
99         struct iovec iov[2];
100
101         size = NLMSG_SPACE(sizeof(struct cn_msg) + msg->len);
102
103         nlh = (struct nlmsghdr *)buffer;
104         nlh->nlmsg_seq = 0;
105         nlh->nlmsg_pid = getpid();
106         nlh->nlmsg_type = NLMSG_DONE;
107         nlh->nlmsg_len = NLMSG_LENGTH(size - sizeof(*nlh));
108         nlh->nlmsg_flags = 0;
109
110         iov[0].iov_base = nlh;
111         iov[0].iov_len = sizeof(*nlh);
112
113         iov[1].iov_base = msg;
114         iov[1].iov_len = size;
115
116         memset(&message, 0, sizeof(message));
117         message.msg_name = &addr;
118         message.msg_namelen = sizeof(addr);
119         message.msg_iov = iov;
120         message.msg_iovlen = 2;
121
122         return sendmsg(fd, &message, 0);
123 }
124
125 int main(void)
126 {
127         int fd, len, nl_group;
128         int error;
129         struct cn_msg *message;
130         struct pollfd pfd;
131         struct nlmsghdr *incoming_msg;
132         struct cn_msg   *incoming_cn_msg;
133         int     op;
134         struct hv_vss_msg *vss_msg;
135
136         if (daemon(1, 0))
137                 return 1;
138
139         openlog("Hyper-V VSS", 0, LOG_USER);
140         syslog(LOG_INFO, "VSS starting; pid is:%d", getpid());
141
142         fd = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_CONNECTOR);
143         if (fd < 0) {
144                 syslog(LOG_ERR, "netlink socket creation failed; error:%d", fd);
145                 exit(EXIT_FAILURE);
146         }
147         addr.nl_family = AF_NETLINK;
148         addr.nl_pad = 0;
149         addr.nl_pid = 0;
150         addr.nl_groups = 0;
151
152
153         error = bind(fd, (struct sockaddr *)&addr, sizeof(addr));
154         if (error < 0) {
155                 syslog(LOG_ERR, "bind failed; error:%d", error);
156                 close(fd);
157                 exit(EXIT_FAILURE);
158         }
159         nl_group = CN_VSS_IDX;
160         setsockopt(fd, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &nl_group, sizeof(nl_group));
161         /*
162          * Register ourselves with the kernel.
163          */
164         message = (struct cn_msg *)vss_send_buffer;
165         message->id.idx = CN_VSS_IDX;
166         message->id.val = CN_VSS_VAL;
167         message->ack = 0;
168         vss_msg = (struct hv_vss_msg *)message->data;
169         vss_msg->vss_hdr.operation = VSS_OP_REGISTER;
170
171         message->len = sizeof(struct hv_vss_msg);
172
173         len = netlink_send(fd, message);
174         if (len < 0) {
175                 syslog(LOG_ERR, "netlink_send failed; error:%d", len);
176                 close(fd);
177                 exit(EXIT_FAILURE);
178         }
179
180         pfd.fd = fd;
181
182         while (1) {
183                 struct sockaddr *addr_p = (struct sockaddr *) &addr;
184                 socklen_t addr_l = sizeof(addr);
185                 pfd.events = POLLIN;
186                 pfd.revents = 0;
187                 poll(&pfd, 1, -1);
188
189                 len = recvfrom(fd, vss_recv_buffer, sizeof(vss_recv_buffer), 0,
190                                 addr_p, &addr_l);
191
192                 if (len < 0) {
193                         syslog(LOG_ERR, "recvfrom failed; pid:%u error:%d %s",
194                                         addr.nl_pid, errno, strerror(errno));
195                         close(fd);
196                         return -1;
197                 }
198
199                 if (addr.nl_pid) {
200                         syslog(LOG_WARNING,
201                                 "Received packet from untrusted pid:%u",
202                                 addr.nl_pid);
203                         continue;
204                 }
205
206                 incoming_msg = (struct nlmsghdr *)vss_recv_buffer;
207
208                 if (incoming_msg->nlmsg_type != NLMSG_DONE)
209                         continue;
210
211                 incoming_cn_msg = (struct cn_msg *)NLMSG_DATA(incoming_msg);
212                 vss_msg = (struct hv_vss_msg *)incoming_cn_msg->data;
213                 op = vss_msg->vss_hdr.operation;
214                 error =  HV_S_OK;
215
216                 switch (op) {
217                 case VSS_OP_FREEZE:
218                 case VSS_OP_THAW:
219                         error = vss_operate(op);
220                         if (error)
221                                 error = HV_E_FAIL;
222                         break;
223                 default:
224                         syslog(LOG_ERR, "Illegal op:%d\n", op);
225                 }
226                 vss_msg->error = error;
227                 len = netlink_send(fd, incoming_cn_msg);
228                 if (len < 0) {
229                         syslog(LOG_ERR, "net_link send failed; error:%d", len);
230                         exit(EXIT_FAILURE);
231                 }
232         }
233
234 }