namespace netlink {
-Netlink::Netlink(int protocol)
+Netlink::Netlink(int protocol) :
+ sequence(1)
{
fd = ::socket(PF_NETLINK, SOCK_RAW, protocol);
if (fd < 0) {
}
Netlink::Netlink(Netlink&& netlink) :
- fd(netlink.fd)
+ fd(netlink.fd), sequence(netlink.sequence)
{
netlink.fd = -1;
}
}
}
-void Netlink::send(const void *buf, unsigned int size)
+void Netlink::send(int type, const std::vector<char>& data)
{
struct sockaddr_nl addr;
int ret;
addr.nl_pid = 0;
addr.nl_groups = 0;
+ char buf[NLMSG_SPACE(data.size())];
+ auto *nlh = (struct nlmsghdr *)buf;
+
+ ::memset(nlh, 0, sizeof(*nlh));
+
+ nlh->nlmsg_len = sizeof(buf);
+ nlh->nlmsg_type = type;
+ nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ nlh->nlmsg_seq = sequence++;
+
+ // overflow - The sequence of message from kernel is 0,
+ // so this is not to use 0 as a sequence number.
+ if (sequence == 0)
+ sequence++;
+
+ ::memcpy(NLMSG_DATA(buf), data.data(), data.size());
+
+
do {
- ret = ::sendto(fd, buf, size, 0,
+ ret = ::sendto(fd, buf, sizeof(buf), 0,
(struct sockaddr*)&addr, sizeof(addr));
} while (ret < 0 && errno == EINTR);
if (ret < 0) {
- throw runtime::Exception("Failed to send audit packet");
+ throw runtime::Exception("Failed to send a netlink message");
+ }
+
+ if (recv(MSG_PEEK).first == NLMSG_ERROR) {
+ auto reply = recv().second;
+ auto err = (struct nlmsgerr*)reply.data();
+ if (err->error)
+ throw runtime::Exception("Netlink error: " +
+ std::to_string(err->error));
+ } else {
+ throw runtime::Exception("Missing ack of netlink message");
}
}
-void Netlink::recv(void *buf, unsigned int size, int options)
+Netlink::Message Netlink::recv(int options)
{
- struct sockaddr_nl nladdr;
- socklen_t nladdrlen = sizeof(nladdr);
int ret;
+ //To get message size
+ struct nlmsghdr nlh;
do {
- ret = ::recvfrom(fd, buf, size, options,
+ ret = ::recv(fd, &nlh, sizeof(nlh), options | MSG_PEEK);
+ } while (ret < 0 && errno == EINTR);
+
+ if (ret < 0) {
+ throw runtime::Exception("Failed to get the size of netlink message");
+ }
+
+ struct sockaddr_nl nladdr;
+ socklen_t nladdrlen = sizeof(nladdr);
+ char buf[nlh.nlmsg_len + NLMSG_HDRLEN];
+ do {
+ ret = ::recvfrom(fd, buf, sizeof(buf), options,
(struct sockaddr*)&nladdr, &nladdrlen);
} while (ret < 0 && errno == EINTR);
+ Message msg = {nlh.nlmsg_type, std::vector<char>(nlh.nlmsg_len)};
+
+ ::memcpy(msg.second.data(), NLMSG_DATA(buf), msg.second.size());
+
if (ret < 0) {
throw runtime::Exception("Failed to receive audit packet");
}
if (nladdrlen != sizeof(nladdr)) {
- throw runtime::Exception("Bad address size reading audit netlink socket");
+ throw runtime::Exception("Bad address size in netlink socket");
}
if (nladdr.nl_pid) {
- throw runtime::Exception("Spoofed packet received on audit netlink socket");
+ throw runtime::Exception("Spoofed packet received on netlink socket");
}
+
+ return msg;
}
} // namespace runtime