1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #ifndef COMPONENTS_CAST_CHANNEL_CAST_SOCKET_H_
6 #define COMPONENTS_CAST_CHANNEL_CAST_SOCKET_H_
13 #include "base/cancelable_callback.h"
14 #include "base/gtest_prod_util.h"
15 #include "base/macros.h"
16 #include "base/memory/ref_counted.h"
17 #include "base/memory/weak_ptr.h"
18 #include "base/observer_list.h"
19 #include "base/threading/thread_checker.h"
20 #include "base/timer/timer.h"
21 #include "components/cast_channel/cast_auth_util.h"
22 #include "components/cast_channel/cast_channel_enum.h"
23 #include "components/cast_channel/cast_socket.h"
24 #include "components/cast_channel/cast_transport.h"
25 #include "net/base/completion_callback.h"
26 #include "net/base/io_buffer.h"
27 #include "net/base/ip_endpoint.h"
28 #include "net/log/net_log_source.h"
29 #include "services/network/public/mojom/network_context.mojom.h"
32 class X509Certificate;
35 namespace cast_channel {
41 // Cast device capabilities.
42 enum CastDeviceCapability : int {
49 MULTIZONE_GROUP = 1 << 5
52 // Public interface of the CastSocket class.
55 // Invoked when CastSocket opens.
56 // |socket|: raw pointer of opened socket (this pointer). Guaranteed to be
57 // valid in callback function. Do not pass |socket| around.
58 using OnOpenCallback = base::OnceCallback<void(CastSocket* socket)>;
62 virtual ~Observer() {}
64 // Invoked when an error occurs on |socket|.
65 virtual void OnError(const CastSocket& socket,
66 ChannelError error_state) = 0;
68 // Invoked when |socket| receives a message.
69 virtual void OnMessage(const CastSocket& socket,
70 const CastMessage& message) = 0;
72 virtual void OnReadyStateChanged(const CastSocket& socket);
75 virtual ~CastSocket() {}
77 // Used by BrowserContextKeyedAPIFactory.
78 static const char* service_name() { return "CastSocketImplManager"; }
80 // Connects the channel to the peer. If successful, the channel will be in
81 // READY_STATE_OPEN. DO NOT delete the CastSocket object in |callback|.
82 // Instead use Close().
83 // |callback| will be invoked with any ChannelError that occurred, or
84 // CHANNEL_ERROR_NONE if successful.
85 // If the CastSocket is destroyed while the connection is pending, |callback|
86 // will be invoked with CHANNEL_ERROR_UNKNOWN. In this case, invoking
87 // |callback| must not result in any re-entrancy behavior.
88 virtual void Connect(OnOpenCallback callback) = 0;
90 // Closes the channel if not already closed. On completion, the channel will
91 // be in READY_STATE_CLOSED.
93 // It is fine to delete this object in |callback|.
94 virtual void Close(const net::CompletionCallback& callback) = 0;
96 // The IP endpoint for the destination of the channel.
97 virtual const net::IPEndPoint& ip_endpoint() const = 0;
99 // Channel id generated by the CastChannelService.
100 virtual int id() const = 0;
102 // Sets the channel id generated by CastChannelService.
103 virtual void set_id(int id) = 0;
105 // The ready state of the channel.
106 virtual ReadyState ready_state() const = 0;
108 // Returns the last error that occurred on this channel, or
109 // CHANNEL_ERROR_NONE if no error has occurred.
110 virtual ChannelError error_state() const = 0;
112 // True when keep-alive signaling is handled for this socket.
113 virtual bool keep_alive() const = 0;
115 // Whether the channel is audio only as identified by the device
116 // certificate during channel authentication.
117 virtual bool audio_only() const = 0;
119 // Marks a socket as invalid due to an error, and sends an OnError
120 // event to |delegate_|.
121 // The OnError event receipient is responsible for closing the socket in the
122 // event of an error.
123 // Setting the error state does not close the socket if it is open.
124 virtual void SetErrorState(ChannelError error_state) = 0;
126 // Returns a pointer to the socket's message transport layer. Can be used to
127 // send and receive CastMessages over the socket.
128 virtual CastTransport* transport() const = 0;
130 // Registers |observer| with the socket to receive messages and error events.
131 virtual void AddObserver(Observer* observer) = 0;
133 // Unregisters |observer|.
134 virtual void RemoveObserver(Observer* observer) = 0;
137 // Holds parameters necessary to open a Cast channel (CastSocket) to a Cast
139 struct CastSocketOpenParams {
140 // IP endpoint of the Cast device.
141 net::IPEndPoint ip_endpoint;
143 // Connection timeout interval. If this value is not set, Cast socket will not
144 // report CONNECT_TIMEOUT error and may hang when connecting to a Cast device.
145 base::TimeDelta connect_timeout;
147 // Amount of idle time to wait before disconnecting. Cast socket will ping
148 // Cast device periodically at |ping_interval| to check liveness. If it does
149 // not receive response in |liveness_timeout|, it reports PING_TIMEOUT error.
150 // |liveness_timeout| should always be larger than or equal to
152 // If this value is not set, there is not periodic ping and Cast socket is
153 // always assumed alive.
154 base::TimeDelta liveness_timeout;
156 // Amount of idle time to wait before pinging the Cast device. See comments
157 // for |liveness_timeout|.
158 base::TimeDelta ping_interval;
160 // A bit vector representing the capabilities of the sink. The values are
161 // defined in components/cast_channel/cast_socket.h.
162 uint64_t device_capabilities;
164 CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
165 base::TimeDelta connect_timeout);
166 CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
167 base::TimeDelta connect_timeout,
168 base::TimeDelta liveness_timeout,
169 base::TimeDelta ping_interval,
170 uint64_t device_capabilities);
173 // This class implements a channel between Chrome and a Cast device using a TCP
174 // socket with SSL. The channel may authenticate that the receiver is a genuine
175 // Cast device. All CastSocketImpl objects must be used only on the IO thread.
177 // NOTE: Not called "CastChannel" to reduce confusion with the generated API
179 class CastSocketImpl : public CastSocket {
181 using NetworkContextGetter =
182 base::RepeatingCallback<network::mojom::NetworkContext*()>;
183 CastSocketImpl(NetworkContextGetter network_context_getter,
184 const CastSocketOpenParams& open_params,
185 const scoped_refptr<Logger>& logger);
187 CastSocketImpl(NetworkContextGetter network_context_getter,
188 const CastSocketOpenParams& open_params,
189 const scoped_refptr<Logger>& logger,
190 const AuthContext& auth_context);
192 // Ensures that the socket is closed.
193 ~CastSocketImpl() override;
195 // CastSocket interface.
196 void Connect(OnOpenCallback callback) override;
197 CastTransport* transport() const override;
198 void Close(const net::CompletionCallback& callback) override;
199 const net::IPEndPoint& ip_endpoint() const override;
200 int id() const override;
201 void set_id(int channel_id) override;
202 ReadyState ready_state() const override;
203 ChannelError error_state() const override;
204 bool keep_alive() const override;
205 bool audio_only() const override;
206 void AddObserver(Observer* observer) override;
207 void RemoveObserver(Observer* observer) override;
209 static net::NetworkTrafficAnnotationTag GetNetworkTrafficAnnotationTag();
212 // CastTransport::Delegate methods for receiving handshake messages.
213 class AuthTransportDelegate : public CastTransport::Delegate {
215 explicit AuthTransportDelegate(CastSocketImpl* socket);
217 // Gets the error state of the channel.
218 // Returns CHANNEL_ERROR_NONE if no errors are present.
219 ChannelError error_state() const;
221 // Gets recorded error details.
222 LastError last_error() const;
224 // CastTransport::Delegate interface.
225 void OnError(ChannelError error_state) override;
226 void OnMessage(const CastMessage& message) override;
227 void Start() override;
230 CastSocketImpl* socket_;
231 ChannelError error_state_;
232 LastError last_error_;
235 // CastTransport::Delegate methods to receive normal messages and errors.
236 class CastSocketMessageDelegate : public CastTransport::Delegate {
238 CastSocketMessageDelegate(CastSocketImpl* socket);
239 ~CastSocketMessageDelegate() override;
241 // CastTransport::Delegate implementation.
242 void OnError(ChannelError error_state) override;
243 void OnMessage(const CastMessage& message) override;
244 void Start() override;
247 CastSocketImpl* const socket_;
248 DISALLOW_COPY_AND_ASSIGN(CastSocketMessageDelegate);
251 // Replaces the internally-constructed transport object with one provided
252 // by the caller (e.g. a mock).
253 void SetTransportForTesting(std::unique_ptr<CastTransport> transport);
255 void SetPeerCertForTesting(scoped_refptr<net::X509Certificate> peer_cert);
257 // Verifies whether the socket complies with cast channel policy.
258 // Audio only channel policy mandates that a device declaring a video out
259 // capability must not have a certificate with audio only policy.
260 bool VerifyChannelPolicy(const AuthResult& result);
265 FRIEND_TEST_ALL_PREFIXES(MockCastSocketTest, TestObservers);
266 friend class AuthTransportDelegate;
268 void SetErrorState(ChannelError error_state) override;
270 // Frees resources and cancels pending callbacks. |ready_state_| will be set
271 // READY_STATE_CLOSED on completion. A no-op if |ready_state_| is already
272 // READY_STATE_CLOSED.
273 void CloseInternal();
275 // Verifies whether the challenge reply received from the peer is valid:
276 // 1. Signature in the reply is valid.
277 // 2. Certificate is rooted to a trusted CA.
278 virtual bool VerifyChallengeReply();
280 // Invoked by a cancelable closure when connection setup time
281 // exceeds the interval specified at |connect_timeout|.
282 void OnConnectTimeout();
284 /////////////////////////////////////////////////////////////////////////////
285 // Following methods work together to implement the following flow:
286 // 1. Create a new TCP socket and connect to it
287 // 2. Create a new SSL socket and try connecting to it
288 // 3. If connection fails due to invalid cert authority, then extract the
289 // peer certificate from the error.
290 // 4. Whitelist the peer certificate and try #1 and #2 again.
291 // 5. If SSL socket is connected successfully, and if protocol is casts://
292 // then issue an auth challenge request.
293 // 6. Validate the auth challenge response.
// Entry point that drives the connection state transitions.
void DoConnectLoop(int result);
// Each of the Do* methods below runs in its corresponding connection state.
// For example, when the connection state is TCP_CONNECT, DoTcpConnect is
// called, and so on.
// Do* state handler for TCP connect completion.
int DoTcpConnectComplete(int result);
303 int DoSslConnectComplete(int result);
304 int DoAuthChallengeSend();
305 int DoAuthChallengeSendComplete(int result);
306 int DoAuthChallengeReplyComplete(int result);
308 // Callback from network::mojom::NetworkContext::CreateTCPConnectedSocket.
309 void OnConnect(int result,
310 const base::Optional<net::IPEndPoint>& local_addr,
311 const base::Optional<net::IPEndPoint>& peer_addr,
312 mojo::ScopedDataPipeConsumerHandle receive_stream,
313 mojo::ScopedDataPipeProducerHandle send_stream);
314 void OnUpgradeToTLS(int result,
315 mojo::ScopedDataPipeConsumerHandle receive_stream,
316 mojo::ScopedDataPipeProducerHandle send_stream,
317 const base::Optional<net::SSLInfo>& ssl_info);
318 /////////////////////////////////////////////////////////////////////////////
320 // Resets the cancellable callback used for async invocations of
322 void ResetConnectLoopCallback();
324 // Posts a task to invoke |connect_loop_callback_| with |result| on the
325 // current message loop.
326 void PostTaskToStartConnectLoop(int result);
328 // Runs the external connection callback and resets it.
329 void DoConnectCallback();
331 virtual base::OneShotTimer* GetTimer();
333 void SetConnectState(ConnectionState connect_state);
334 void SetReadyState(ReadyState ready_state);
336 THREAD_CHECKER(thread_checker_);
338 // The id of the channel.
341 // Cast socket related settings.
342 CastSocketOpenParams open_params_;
344 // Shared logging object, used to log CastSocket events for diagnostics.
345 scoped_refptr<Logger> logger_;
347 NetworkContextGetter network_context_getter_;
349 // Owned ptr to the underlying TCP socket.
350 network::mojom::TCPConnectedSocketPtr tcp_socket_;
352 // Owned ptr to the underlying SSL socket.
353 network::mojom::TLSClientSocketPtr socket_;
355 // Helper class to write to the SSL socket.
356 std::unique_ptr<MojoDataPump> mojo_data_pump_;
358 // Certificate of the peer. This field may be empty if the peer
359 // certificate is not yet fetched.
360 scoped_refptr<net::X509Certificate> peer_cert_;
362 // The challenge context for the current connection.
363 const AuthContext auth_context_;
365 // Reply received from the receiver to a challenge request.
366 std::unique_ptr<CastMessage> challenge_reply_;
368 // Callbacks invoked when the socket is connected or fails to connect.
369 std::vector<OnOpenCallback> connect_callbacks_;
371 // Callback invoked by |connect_timeout_timer_| to cancel the connection.
372 base::CancelableClosure connect_timeout_callback_;
374 // Timer invoked when the connection has timed out.
375 std::unique_ptr<base::OneShotTimer> connect_timeout_timer_;
377 // Set when a timeout is triggered and the connection process has
381 // Whether the channel is audio only as identified by the device
382 // certificate during channel authentication.
385 // Connection flow state machine state.
386 ConnectionState connect_state_;
388 // Write flow state machine state.
389 WriteState write_state_;
391 // Read flow state machine state.
392 ReadState read_state_;
394 // The last error encountered by the channel.
395 ChannelError error_state_;
397 // The current status of the channel.
398 ReadyState ready_state_;
400 // Callback which, when invoked, will re-enter the connection state machine.
401 // Oustanding callbacks will be cancelled when |this| is destroyed.
402 // The callback signature is based on net::CompletionCallback, which passes
403 // operation result codes as byte counts in the success case, or as
404 // net::Error enum values for error cases.
405 base::CancelableCallback<void(int)> connect_loop_callback_;
407 // Cast message formatting and parsing layer.
408 std::unique_ptr<CastTransport> transport_;
410 // Caller's message read and error handling delegate.
411 std::unique_ptr<CastTransport::Delegate> delegate_;
413 // Raw pointer to the auth handshake delegate. Used to get detailed error
415 AuthTransportDelegate* auth_delegate_;
417 // List of socket observers.
418 base::ObserverList<Observer>::Unchecked observers_;
420 base::WeakPtrFactory<CastSocketImpl> weak_factory_;
422 DISALLOW_COPY_AND_ASSIGN(CastSocketImpl);
424 } // namespace cast_channel
426 #endif // COMPONENTS_CAST_CHANNEL_CAST_SOCKET_H_