Imported Upstream version 0.8.0
[platform/upstream/multipath-tools.git] / libmpathcmd / mpath_cmd.c
1 /*
2  * Copyright (C) 2015 Red Hat, Inc.
3  *
4  * This file is part of the device-mapper multipath userspace tools.
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public License
8  * as published by the Free Software Foundation; either version 2.1
9  * of the License, or (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include <stdlib.h>
21 #include <unistd.h>
22 #include <stdio.h>
23 #include <sys/types.h>
24 #include <sys/socket.h>
25 #include <sys/un.h>
26 #include <poll.h>
27 #include <string.h>
28 #include <errno.h>
29
30 #include "mpath_cmd.h"
31
32 /*
33  * keep reading until its all read
34  */
35 static ssize_t read_all(int fd, void *buf, size_t len, unsigned int timeout)
36 {
37         size_t total = 0;
38         ssize_t n;
39         int ret;
40         struct pollfd pfd;
41
42         while (len) {
43                 pfd.fd = fd;
44                 pfd.events = POLLIN;
45                 ret = poll(&pfd, 1, timeout);
46                 if (!ret) {
47                         errno = ETIMEDOUT;
48                         return -1;
49                 } else if (ret < 0) {
50                         if (errno == EINTR)
51                                 continue;
52                         return -1;
53                 } else if (!(pfd.revents & POLLIN))
54                         continue;
55                 n = recv(fd, buf, len, 0);
56                 if (n < 0) {
57                         if ((errno == EINTR) || (errno == EAGAIN))
58                                 continue;
59                         return -1;
60                 }
61                 if (!n)
62                         return total;
63                 buf = n + (char *)buf;
64                 len -= n;
65                 total += n;
66         }
67         return total;
68 }
69
70 /*
71  * keep writing until it's all sent
72  */
73 static size_t write_all(int fd, const void *buf, size_t len)
74 {
75         size_t total = 0;
76
77         while (len) {
78                 ssize_t n = send(fd, buf, len, MSG_NOSIGNAL);
79                 if (n < 0) {
80                         if ((errno == EINTR) || (errno == EAGAIN))
81                                 continue;
82                         return total;
83                 }
84                 if (!n)
85                         return total;
86                 buf = n + (const char *)buf;
87                 len -= n;
88                 total += n;
89         }
90         return total;
91 }
92
93 /*
94  * connect to a unix domain socket
95  */
96 int mpath_connect(void)
97 {
98         int fd, len;
99         struct sockaddr_un addr;
100
101         memset(&addr, 0, sizeof(addr));
102         addr.sun_family = AF_LOCAL;
103         addr.sun_path[0] = '\0';
104         len = strlen(DEFAULT_SOCKET) + 1 + sizeof(sa_family_t);
105         strncpy(&addr.sun_path[1], DEFAULT_SOCKET, len);
106
107         fd = socket(AF_LOCAL, SOCK_STREAM, 0);
108         if (fd == -1)
109                 return -1;
110
111         if (connect(fd, (struct sockaddr *)&addr, len) == -1) {
112                 close(fd);
113                 return -1;
114         }
115
116         return fd;
117 }
118
119 int mpath_disconnect(int fd)
120 {
121         return close(fd);
122 }
123
124 ssize_t mpath_recv_reply_len(int fd, unsigned int timeout)
125 {
126         size_t len;
127         ssize_t ret;
128
129         ret = read_all(fd, &len, sizeof(len), timeout);
130         if (ret < 0)
131                 return ret;
132         if (ret != sizeof(len)) {
133                 errno = EIO;
134                 return -1;
135         }
136         if (len <= 0 || len >= MAX_REPLY_LEN) {
137                 errno = ERANGE;
138                 return -1;
139         }
140         return len;
141 }
142
143 int mpath_recv_reply_data(int fd, char *reply, size_t len,
144                           unsigned int timeout)
145 {
146         ssize_t ret;
147
148         ret = read_all(fd, reply, len, timeout);
149         if (ret < 0)
150                 return ret;
151         if (ret != len) {
152                 errno = EIO;
153                 return -1;
154         }
155         reply[len - 1] = '\0';
156         return 0;
157 }
158
159 int mpath_recv_reply(int fd, char **reply, unsigned int timeout)
160 {
161         int err;
162         ssize_t len;
163
164         *reply = NULL;
165         len = mpath_recv_reply_len(fd, timeout);
166         if (len <= 0)
167                 return len;
168         *reply = malloc(len);
169         if (!*reply)
170                 return -1;
171         err = mpath_recv_reply_data(fd, *reply, len, timeout);
172         if (err) {
173                 free(*reply);
174                 *reply = NULL;
175                 return -1;
176         }
177         return 0;
178 }
179
180 int mpath_send_cmd(int fd, const char *cmd)
181 {
182         size_t len;
183
184         if (cmd != NULL)
185                 len = strlen(cmd) + 1;
186         else
187                 len = 0;
188         if (write_all(fd, &len, sizeof(len)) != sizeof(len))
189                 return -1;
190         if (len && write_all(fd, cmd, len) != len)
191                 return -1;
192         return 0;
193 }
194
195 int mpath_process_cmd(int fd, const char *cmd, char **reply,
196                       unsigned int timeout)
197 {
198         if (mpath_send_cmd(fd, cmd) != 0)
199                 return -1;
200         return mpath_recv_reply(fd, reply, timeout);
201 }