Imported Upstream version 0.8.2
[platform/upstream/multipath-tools.git] / libmpathcmd / mpath_cmd.c
1 /*
2  * Copyright (C) 2015 Red Hat, Inc.
3  *
4  * This file is part of the device-mapper multipath userspace tools.
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public License
8  * as published by the Free Software Foundation; either version 2.1
9  * of the License, or (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include <stdlib.h>
21 #include <unistd.h>
22 #include <stdio.h>
23 #include <sys/types.h>
24 #include <sys/socket.h>
25 #include <sys/un.h>
26 #include <poll.h>
27 #include <string.h>
28 #include <errno.h>
29 #include <fcntl.h>
30
31 #include "mpath_cmd.h"
32
33 /*
34  * keep reading until its all read
35  */
36 static ssize_t read_all(int fd, void *buf, size_t len, unsigned int timeout)
37 {
38         size_t total = 0;
39         ssize_t n;
40         int ret;
41         struct pollfd pfd;
42
43         while (len) {
44                 pfd.fd = fd;
45                 pfd.events = POLLIN;
46                 ret = poll(&pfd, 1, timeout);
47                 if (!ret) {
48                         errno = ETIMEDOUT;
49                         return -1;
50                 } else if (ret < 0) {
51                         if (errno == EINTR)
52                                 continue;
53                         return -1;
54                 } else if (!(pfd.revents & POLLIN))
55                         continue;
56                 n = recv(fd, buf, len, 0);
57                 if (n < 0) {
58                         if ((errno == EINTR) || (errno == EAGAIN))
59                                 continue;
60                         return -1;
61                 }
62                 if (!n)
63                         return total;
64                 buf = n + (char *)buf;
65                 len -= n;
66                 total += n;
67         }
68         return total;
69 }
70
71 /*
72  * keep writing until it's all sent
73  */
74 static size_t write_all(int fd, const void *buf, size_t len)
75 {
76         size_t total = 0;
77
78         while (len) {
79                 ssize_t n = send(fd, buf, len, MSG_NOSIGNAL);
80                 if (n < 0) {
81                         if ((errno == EINTR) || (errno == EAGAIN))
82                                 continue;
83                         return total;
84                 }
85                 if (!n)
86                         return total;
87                 buf = n + (const char *)buf;
88                 len -= n;
89                 total += n;
90         }
91         return total;
92 }
93
94 /*
95  * connect to a unix domain socket
96  */
97 int __mpath_connect(int nonblocking)
98 {
99         int fd, len;
100         struct sockaddr_un addr;
101         int flags = 0;
102
103         memset(&addr, 0, sizeof(addr));
104         addr.sun_family = AF_LOCAL;
105         addr.sun_path[0] = '\0';
106         strncpy(&addr.sun_path[1], DEFAULT_SOCKET, sizeof(addr.sun_path) - 1);
107         len = strlen(DEFAULT_SOCKET) + 1 + sizeof(sa_family_t);
108         if (len > sizeof(struct sockaddr_un))
109                 len = sizeof(struct sockaddr_un);
110
111         fd = socket(AF_LOCAL, SOCK_STREAM, 0);
112         if (fd == -1)
113                 return -1;
114
115         if (nonblocking) {
116                 flags = fcntl(fd, F_GETFL, 0);
117                 if (flags != -1)
118                         (void)fcntl(fd, F_SETFL, flags|O_NONBLOCK);
119         }
120
121         if (connect(fd, (struct sockaddr *)&addr, len) == -1) {
122                 int err = errno;
123
124                 close(fd);
125                 errno = err;
126                 return -1;
127         }
128
129         if (nonblocking && flags != -1)
130                 (void)fcntl(fd, F_SETFL, flags);
131
132         return fd;
133 }
134
135 /*
136  * connect to a unix domain socket
137  */
138 int mpath_connect(void)
139 {
140         return __mpath_connect(0);
141 }
142
143 int mpath_disconnect(int fd)
144 {
145         return close(fd);
146 }
147
148 ssize_t mpath_recv_reply_len(int fd, unsigned int timeout)
149 {
150         size_t len;
151         ssize_t ret;
152
153         ret = read_all(fd, &len, sizeof(len), timeout);
154         if (ret < 0)
155                 return ret;
156         if (ret != sizeof(len)) {
157                 errno = EIO;
158                 return -1;
159         }
160         if (len <= 0 || len >= MAX_REPLY_LEN) {
161                 errno = ERANGE;
162                 return -1;
163         }
164         return len;
165 }
166
167 int mpath_recv_reply_data(int fd, char *reply, size_t len,
168                           unsigned int timeout)
169 {
170         ssize_t ret;
171
172         ret = read_all(fd, reply, len, timeout);
173         if (ret < 0)
174                 return ret;
175         if (ret != len) {
176                 errno = EIO;
177                 return -1;
178         }
179         reply[len - 1] = '\0';
180         return 0;
181 }
182
183 int mpath_recv_reply(int fd, char **reply, unsigned int timeout)
184 {
185         int err;
186         ssize_t len;
187
188         *reply = NULL;
189         len = mpath_recv_reply_len(fd, timeout);
190         if (len <= 0)
191                 return len;
192         *reply = malloc(len);
193         if (!*reply)
194                 return -1;
195         err = mpath_recv_reply_data(fd, *reply, len, timeout);
196         if (err) {
197                 free(*reply);
198                 *reply = NULL;
199                 return -1;
200         }
201         return 0;
202 }
203
204 int mpath_send_cmd(int fd, const char *cmd)
205 {
206         size_t len;
207
208         if (cmd != NULL)
209                 len = strlen(cmd) + 1;
210         else
211                 len = 0;
212         if (write_all(fd, &len, sizeof(len)) != sizeof(len))
213                 return -1;
214         if (len && write_all(fd, cmd, len) != len)
215                 return -1;
216         return 0;
217 }
218
219 int mpath_process_cmd(int fd, const char *cmd, char **reply,
220                       unsigned int timeout)
221 {
222         if (mpath_send_cmd(fd, cmd) != 0)
223                 return -1;
224         return mpath_recv_reply(fd, reply, timeout);
225 }