1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/extensions/api/socket/tcp_socket.h"
7 #include "chrome/browser/extensions/api/api_resource.h"
8 #include "net/base/address_list.h"
9 #include "net/base/ip_endpoint.h"
10 #include "net/base/net_errors.h"
11 #include "net/base/rand_callback.h"
12 #include "net/socket/tcp_client_socket.h"
14 namespace extensions {
// Reported by Listen() when the socket is already in CLIENT mode
// (i.e. Connect() was called on it first).
16 const char kTCPSocketTypeInvalidError[] =
17 "Cannot call both connect and listen on the same socket.";
// Assigned to *error_msg by Listen() after the server socket's Listen()
// call (presumably only when that call fails — confirm).
18 const char kSocketListenError[] = "Could not listen on the specified port.";
// Lazily-created (base::LazyInstance) factory for the
// ApiResourceManager that tracks ResumableTCPSocket resources.
20 static base::LazyInstance<ProfileKeyedAPIFactory<
21 ApiResourceManager<ResumableTCPSocket> > >
22 g_factory = LAZY_INSTANCE_INITIALIZER;
// Returns the shared factory (g_factory) for the ResumableTCPSocket
// resource manager, creating it on first use.
26 ProfileKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> >*
27 ApiResourceManager<ResumableTCPSocket>::GetFactoryInstance() {
28 return &g_factory.Get();
// Lazily-created (base::LazyInstance) factory for the
// ApiResourceManager that tracks ResumableTCPServerSocket resources.
31 static base::LazyInstance<ProfileKeyedAPIFactory<
32 ApiResourceManager<ResumableTCPServerSocket> > >
33 g_server_factory = LAZY_INSTANCE_INITIALIZER;
// Returns the shared factory (g_server_factory) for the
// ResumableTCPServerSocket resource manager, creating it on first use.
37 ProfileKeyedAPIFactory<ApiResourceManager<ResumableTCPServerSocket> >*
38 ApiResourceManager<ResumableTCPServerSocket>::GetFactoryInstance() {
39 return &g_server_factory.Get();
// Creates a socket whose role is not decided yet (socket_mode_ UNKNOWN).
// The first Connect() or Listen() call that passes its mode check switches
// it to CLIENT or SERVER mode respectively.
42 TCPSocket::TCPSocket(const std::string& owner_extension_id)
43 : Socket(owner_extension_id),
44 socket_mode_(UNKNOWN) {
// Wraps an already-created client socket. The socket starts in CLIENT mode
// and its cached connection state is seeded from is_connected.
// NOTE(review): socket_ is managed with reset()/get() elsewhere, so this
// presumably takes ownership of tcp_client_socket — confirm in the header.
47 TCPSocket::TCPSocket(net::TCPClientSocket* tcp_client_socket,
48 const std::string& owner_extension_id,
50 : Socket(owner_extension_id),
51 socket_(tcp_client_socket),
52 socket_mode_(CLIENT) {
53 this->is_connected_ = is_connected;
// Wraps an already-created server socket. The socket starts in SERVER mode.
// NOTE(review): server_socket_ is managed with reset()/get() elsewhere, so
// this presumably takes ownership of tcp_server_socket — confirm.
56 TCPSocket::TCPSocket(net::TCPServerSocket* tcp_server_socket,
57 const std::string& owner_extension_id)
58 : Socket(owner_extension_id),
59 server_socket_(tcp_server_socket),
60 socket_mode_(SERVER) {
// Test-only factory: returns a new heap-allocated TCPSocket in CLIENT mode
// wrapping tcp_client_socket. The caller owns the returned object.
64 TCPSocket* TCPSocket::CreateSocketForTesting(
65 net::TCPClientSocket* tcp_client_socket,
66 const std::string& owner_extension_id,
68 return new TCPSocket(tcp_client_socket, owner_extension_id, is_connected);
// Test-only factory: returns a new heap-allocated TCPSocket in SERVER mode
// wrapping tcp_server_socket. The caller owns the returned object.
72 TCPSocket* TCPSocket::CreateServerSocketForTesting(
73 net::TCPServerSocket* tcp_server_socket,
74 const std::string& owner_extension_id) {
75 return new TCPSocket(tcp_server_socket, owner_extension_id);
// Destructor.
// NOTE(review): the body is not part of this listing; confirm it tears down
// the client/server sockets and drops pending callbacks (as Disconnect()
// does) before the object goes away, since in-flight completions are bound
// with base::Unretained(this).
78 TCPSocket::~TCPSocket() {
// Starts connecting to address:port and reports the outcome to callback as
// a net error code.
//
// - A socket in SERVER mode, or one with a connect already pending, is
//   rejected by running callback with net::ERR_CONNECTION_FAILED.
// - Otherwise the socket is switched to CLIENT mode.
// - An address/port that cannot be converted into a net::AddressList sets
//   the result to net::ERR_ADDRESS_INVALID.
// - A synchronous result (anything but net::ERR_IO_PENDING) is dispatched
//   through OnConnectComplete(), so callback runs on that path too.
82 void TCPSocket::Connect(const std::string& address,
84 const CompletionCallback& callback) {
85 DCHECK(!callback.is_null());
// A non-null connect_callback_ means an earlier Connect() has not finished.
87 if (socket_mode_ == SERVER || !connect_callback_.is_null()) {
88 callback.Run(net::ERR_CONNECTION_FAILED);
91 DCHECK(!server_socket_.get());
92 socket_mode_ = CLIENT;
93 connect_callback_ = callback;
95 int result = net::ERR_CONNECTION_FAILED;
100 net::AddressList address_list;
101 if (!StringAndPortToAddressList(address, port, &address_list)) {
102 result = net::ERR_ADDRESS_INVALID;
106 socket_.reset(new net::TCPClientSocket(address_list, NULL,
107 net::NetLog::Source()));
// NOTE(review): connect_callback_ was already assigned above; this
// re-assignment is redundant.
109 connect_callback_ = callback;
// base::Unretained(this): the completion callback holds a raw pointer, so
// this object must outlive the pending connect.
110 result = socket_->Connect(base::Bind(
111 &TCPSocket::OnConnectComplete, base::Unretained(this)));
114 if (result != net::ERR_IO_PENDING)
115 OnConnectComplete(result);
// Marks the socket disconnected and performs these cleanup steps:
// - disconnects the client socket;
// - destroys the server socket and any accepted socket that has not been
//   handed out;
// - drops the pending connect/read/accept callbacks without running them.
118 void TCPSocket::Disconnect() {
119 is_connected_ = false;
// NOTE(review): socket_ can be null, e.g. for server-mode sockets: Listen()
// DCHECKs that it is unset, and the SERVER-mode constructor never sets it.
// Confirm this call is guarded by a null check.
121 socket_->Disconnect();
122 server_socket_.reset(NULL);
123 connect_callback_.Reset();
124 read_callback_.Reset();
125 accept_callback_.Reset();
126 accept_socket_.reset(NULL);
// Explicit binding is not supported for TCP sockets: address and port are
// ignored and net::ERR_FAILED is always returned.
129 int TCPSocket::Bind(const std::string& address, int port) {
130 return net::ERR_FAILED;
// Reads up to count bytes from the client socket and delivers
// (result, buffer) to callback.
//
// Rejections are reported by running callback with a NULL buffer:
// - net::ERR_FAILED if the socket is not in CLIENT mode;
// - net::ERR_IO_PENDING if another read is still outstanding (the new
//   request is rejected, not queued);
// - net::ERR_INVALID_ARGUMENT for an invalid count (presumably
//   count < 0 — confirm the guarding condition);
// - net::ERR_SOCKET_NOT_CONNECTED if there is no connected socket.
//
// A synchronous read result (anything but net::ERR_IO_PENDING) is dispatched
// through OnReadComplete(), so callback runs on that path too. The
// asynchronous completion is bound with base::Unretained(this).
133 void TCPSocket::Read(int count,
134 const ReadCompletionCallback& callback) {
135 DCHECK(!callback.is_null());
137 if (socket_mode_ != CLIENT) {
138 callback.Run(net::ERR_FAILED, NULL);
// Only one read may be in flight; read_callback_ is cleared in
// OnReadComplete().
142 if (!read_callback_.is_null()) {
143 callback.Run(net::ERR_IO_PENDING, NULL);
148 callback.Run(net::ERR_INVALID_ARGUMENT, NULL);
152 if (!socket_.get() || !IsConnected()) {
153 callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL);
157 read_callback_ = callback;
158 scoped_refptr<net::IOBuffer> io_buffer = new net::IOBuffer(count);
159 int result = socket_->Read(io_buffer.get(), count,
160 base::Bind(&TCPSocket::OnReadComplete, base::Unretained(this),
163 if (result != net::ERR_IO_PENDING)
164 OnReadComplete(io_buffer, result);
// Datagram-style receive is not supported on TCP sockets. count is ignored,
// and callback always receives net::ERR_FAILED with NULL/0 for the
// remaining arguments.
167 void TCPSocket::RecvFrom(int count,
168 const RecvFromCompletionCallback& callback) {
169 callback.Run(net::ERR_FAILED, NULL, NULL, 0);
// Datagram-style send is not supported on TCP sockets. Nothing is sent, and
// callback always receives net::ERR_FAILED.
172 void TCPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer,
174 const std::string& address,
176 const CompletionCallback& callback) {
177 callback.Run(net::ERR_FAILED);
// Forwards the keep-alive setting (enable, delay) to the client socket and
// returns its result.
// NOTE(review): dereferences socket_, which can be null for server-mode or
// never-connected sockets; confirm a null check precedes this.
180 bool TCPSocket::SetKeepAlive(bool enable, int delay) {
183 return socket_->SetKeepAlive(enable, delay);
// Forwards the no-delay setting to the client socket and returns its result.
// NOTE(review): dereferences socket_, which can be null for server-mode or
// never-connected sockets; confirm a null check precedes this.
186 bool TCPSocket::SetNoDelay(bool no_delay) {
189 return socket_->SetNoDelay(no_delay);
// Switches the socket to SERVER mode and listens on address:port with the
// given backlog. Returns a net error code.
//
// - A socket already in CLIENT mode is rejected with
//   net::ERR_NOT_IMPLEMENTED, and *error_msg is set to
//   kTCPSocketTypeInvalidError.
// - An address/port that cannot be converted to a net::IPEndPoint yields
//   net::ERR_INVALID_ARGUMENT; *error_msg is not touched on this path.
// - A server socket is created only if none exists yet, so one supplied
//   through the SERVER-mode constructor is reused.
// - The server socket's Listen() result is presumably returned; on that
//   path *error_msg is set to kSocketListenError (presumably only on
//   failure — confirm).
//
// NOTE(review): socket_mode_ becomes SERVER before the address is
// validated. A call with a bad address therefore still commits the socket
// to server mode.
192 int TCPSocket::Listen(const std::string& address, int port, int backlog,
193 std::string* error_msg) {
194 if (socket_mode_ == CLIENT) {
195 *error_msg = kTCPSocketTypeInvalidError;
196 return net::ERR_NOT_IMPLEMENTED;
198 DCHECK(!socket_.get());
199 socket_mode_ = SERVER;
201 scoped_ptr<net::IPEndPoint> bind_address(new net::IPEndPoint());
202 if (!StringAndPortToIPEndPoint(address, port, bind_address.get()))
203 return net::ERR_INVALID_ARGUMENT;
205 if (!server_socket_.get()) {
206 server_socket_.reset(new net::TCPServerSocket(NULL,
207 net::NetLog::Source()));
209 int result = server_socket_->Listen(*bind_address, backlog);
211 *error_msg = kSocketListenError;
// Accepts one incoming connection on the server socket and reports
// (result, accepted socket) to callback.
//
// - Fails with net::ERR_FAILED and a NULL socket if the socket is not in
//   SERVER mode, has no server socket, or already has an accept
//   outstanding.
// - If the accept is pending, callback is stored and later run from
//   OnAccept().
// - If the accept succeeds synchronously, callback is stored and
//   OnAccept() is invoked directly to deliver the socket.
// - Any other result is reported to callback with a NULL socket (presumably
//   in an else-branch — confirm the branch structure).
//
// The asynchronous completion is bound with base::Unretained(this).
215 void TCPSocket::Accept(const AcceptCompletionCallback &callback) {
216 if (socket_mode_ != SERVER || !server_socket_.get()) {
217 callback.Run(net::ERR_FAILED, NULL);
221 // Limits to only 1 blocked accept call.
222 if (!accept_callback_.is_null()) {
223 callback.Run(net::ERR_FAILED, NULL);
227 int result = server_socket_->Accept(&accept_socket_, base::Bind(
228 &TCPSocket::OnAccept, base::Unretained(this)));
229 if (result == net::ERR_IO_PENDING) {
230 accept_callback_ = callback;
231 } else if (result == net::OK) {
// OnAccept() DCHECKs that accept_callback_ is set, so store it first.
232 accept_callback_ = callback;
233 this->OnAccept(result);
235 callback.Run(result, NULL);
// Returns the connection state after re-checking it against the underlying
// socket via RefreshConnectionStatus().
239 bool TCPSocket::IsConnected() {
240 RefreshConnectionStatus();
241 return is_connected_;
// Fills *address with the client socket's peer address. Returns true when
// the underlying call returns 0 (net::OK), which is why its result is
// negated.
// NOTE(review): dereferences socket_; confirm a null check precedes this.
244 bool TCPSocket::GetPeerAddress(net::IPEndPoint* address) {
247 return !socket_->GetPeerAddress(address);
// Fills *address with the local endpoint. The client socket is used when
// present (presumably — confirm the guarding condition); otherwise the
// server socket is used.
// Returns true when the underlying call returns 0 (net::OK), hence the
// negation.
250 bool TCPSocket::GetLocalAddress(net::IPEndPoint* address) {
252 return !socket_->GetLocalAddress(address);
253 } else if (server_socket_.get()) {
254 return !server_socket_->GetLocalAddress(address);
// Identifies this socket as Socket::TYPE_TCP.
260 Socket::SocketType TCPSocket::GetSocketType() const {
261 return Socket::TYPE_TCP;
// Writes io_buffer_size bytes from io_buffer to the client socket and
// returns the result of socket_->Write(), with callback passed through.
// Returns net::ERR_FAILED if the socket is not in CLIENT mode, and
// net::ERR_SOCKET_NOT_CONNECTED if there is no connected client socket.
264 int TCPSocket::WriteImpl(net::IOBuffer* io_buffer,
266 const net::CompletionCallback& callback) {
267 if (socket_mode_ != CLIENT)
268 return net::ERR_FAILED;
269 else if (!socket_.get() || !IsConnected())
270 return net::ERR_SOCKET_NOT_CONNECTED;
272 return socket_->Write(io_buffer, io_buffer_size, callback);
// Re-validates the cached is_connected_ flag against the client socket.
// Does nothing when the socket is already marked disconnected or holds a
// server socket. Otherwise it reacts when the client socket reports it is
// no longer connected.
// NOTE(review): that handling body is not shown here; presumably it clears
// is_connected_ or calls Disconnect() — confirm.
275 void TCPSocket::RefreshConnectionStatus() {
276 if (!is_connected_) return;
277 if (server_socket_) return;
278 if (!socket_->IsConnected()) {
// Completion handler for Connect(). It does the following:
// - records whether the connection succeeded (result == net::OK);
// - runs the pending connect callback with result;
// - clears the callback, so a later Connect() is no longer rejected as
//   "already pending".
283 void TCPSocket::OnConnectComplete(int result) {
284 DCHECK(!connect_callback_.is_null());
285 DCHECK(!is_connected_);
286 is_connected_ = result == net::OK;
287 connect_callback_.Run(result);
288 connect_callback_.Reset();
// Completion handler for Read(). Passes result and the read buffer to the
// pending read callback, then clears it so another Read() may be issued.
291 void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer,
293 DCHECK(!read_callback_.is_null());
294 read_callback_.Run(result, io_buffer);
295 read_callback_.Reset();
// Completion handler for Accept().
// - On net::OK with an accepted socket, ownership is released from
//   accept_socket_ and the socket is handed to the callback as a
//   net::TCPClientSocket*.
// - In all other cases the callback receives result and NULL.
// - The pending accept callback is cleared afterwards.
// NOTE(review): the static_cast assumes accept_socket_ always holds a
// net::TCPClientSocket; confirm the server socket guarantees this.
298 void TCPSocket::OnAccept(int result) {
299 DCHECK(!accept_callback_.is_null());
300 if (result == net::OK && accept_socket_.get()) {
301 accept_callback_.Run(
302 result, static_cast<net::TCPClientSocket*>(accept_socket_.release()));
304 accept_callback_.Run(result, NULL);
306 accept_callback_.Reset();
// Creates a resumable TCP socket with no role yet, delegating to
// TCPSocket(owner_extension_id), which starts it in UNKNOWN mode.
309 ResumableTCPSocket::ResumableTCPSocket(const std::string& owner_extension_id)
310 : TCPSocket(owner_extension_id),
// Wraps an existing client socket as a resumable TCP socket. Delegates to
// the CLIENT-mode TCPSocket constructor, which seeds the connection state
// from is_connected.
316 ResumableTCPSocket::ResumableTCPSocket(net::TCPClientSocket* tcp_client_socket,
317 const std::string& owner_extension_id,
319 : TCPSocket(tcp_client_socket, owner_extension_id, is_connected),
// Reports whether this socket is persistent.
// NOTE(review): the body is not shown here; presumably it returns a member
// flag set by the caller — confirm.
325 bool ResumableTCPSocket::IsPersistent() const {
// Creates a resumable server socket with no role yet, delegating to
// TCPSocket(owner_extension_id). SERVER mode is entered on Listen().
329 ResumableTCPServerSocket::ResumableTCPServerSocket(
330 const std::string& owner_extension_id)
331 : TCPSocket(owner_extension_id),
// Reports whether this server socket is persistent.
// NOTE(review): the body is not shown here; presumably it returns a member
// flag set by the caller — confirm.
336 bool ResumableTCPServerSocket::IsPersistent() const {
340 } // namespace extensions