1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "components/cast_channel/cast_transport.h"
13 #include "base/bind.h"
14 #include "base/format_macros.h"
15 #include "base/location.h"
16 #include "base/numerics/safe_conversions.h"
17 #include "base/single_thread_task_runner.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/threading/thread_task_runner_handle.h"
20 #include "components/cast_channel/cast_framer.h"
21 #include "components/cast_channel/cast_message_util.h"
22 #include "components/cast_channel/logger.h"
23 #include "components/cast_channel/proto/cast_channel.pb.h"
24 #include "net/base/net_errors.h"
// Logs at VLOG verbosity |level| with a "[<ip_endpoint_>, auth=SSL_VERIFIED] "
// prefix. The expansion reads a member named |ip_endpoint_|, so this macro is
// only usable where that member is in scope (CastTransportImpl methods). The
// "auth=SSL_VERIFIED" tag is a fixed literal, not derived from any state.
26 #define VLOG_WITH_CONNECTION(level) \
27 VLOG(level) << "[" << ip_endpoint_.ToString() << ", auth=SSL_VERIFIED] "
29 namespace cast_channel {
// Creates a transport that writes and reads framed Cast messages over
// |channel|. Initial state: the write machine is IDLE, the read machine is in
// READ (ready for Start()), and no error is recorded.
31 CastTransportImpl::CastTransportImpl(Channel* channel,
33 const net::IPEndPoint& ip_endpoint,
34 scoped_refptr<Logger> logger)
37 write_state_(WriteState::IDLE),
38 read_state_(ReadState::READ),
39 error_state_(ChannelError::NONE),
40 channel_id_(channel_id),
41 ip_endpoint_(ip_endpoint),
45 // Buffer is reused across messages to minimize unnecessary buffer allocations.
47 read_buffer_ = base::MakeRefCounted<net::GrowableIOBuffer>();
// Size the buffer for the largest message the framer will accept.
48 read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size());
// The framer parses the bytes that DoRead() reads into this same buffer.
49 framer_.reset(new MessageFramer(read_buffer_));
// Destructor. Like the other sequence-checked entry points, it DCHECKs that it
// runs on the sequence this object is bound to.
52 CastTransportImpl::~CastTransportImpl() {
53 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// Returns true if |write_state| ends the write loop in OnWriteResult().
// The terminal states are WRITE_ERROR (failure) and IDLE (nothing left to write).
57 bool CastTransportImpl::IsTerminalWriteState(WriteState write_state) {
58 return write_state == WriteState::WRITE_ERROR ||
59 write_state == WriteState::IDLE;
// Returns true only for READ_ERROR. The read loop in OnReadResult() stops in
// this state and then reports the recorded error to the delegate.
62 bool CastTransportImpl::IsTerminalReadState(ReadState read_state) {
63 return read_state == ReadState::READ_ERROR;
// Sets the delegate that receives notifications. OnMessage() is called for each
// valid incoming message, and OnError() when either state machine fails.
// Ownership of |delegate| moves into |delegate_|, replacing any previous
// delegate. Start() DCHECKs that a delegate is set, so call this first.
67 void CastTransportImpl::SetReadDelegate(std::unique_ptr<Delegate> delegate) {
68 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
70 delegate_ = std::move(delegate);
// Fails every queued write. Each request is popped and its callback is posted
// to the current thread's task runner with net::ERR_FAILED. Callbacks are
// posted rather than run inline, so none of them executes inside this call.
76 void CastTransportImpl::FlushWriteQueue() {
77 for (; !write_queue_.empty(); write_queue_.pop()) {
78 net::CompletionCallback& callback = write_queue_.front().callback;
// BindOnce copies |callback| into the task, so popping the request afterwards
// does not invalidate it.
79 base::ThreadTaskRunnerHandle::Get()->PostTask(
80 FROM_HERE, base::BindOnce(callback, net::ERR_FAILED));
// Serializes |message| and queues it for writing to |channel_|.
// |callback| is posted, never run synchronously here, with one of:
// - net::ERR_FAILED if |message| cannot be serialized;
// - net::ERR_FAILED if the request is still queued when FlushWriteQueue() runs;
// - net::OK once the whole serialized message has been written.
// DCHECKs that |message| is a valid CastMessage.
85 void CastTransportImpl::SendMessage(const CastMessage& message,
86 const net::CompletionCallback& callback) {
87 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
88 DCHECK(IsCastMessageValid(message));
89 std::string serialized_message;
90 if (!MessageFramer::Serialize(message, &serialized_message)) {
91 base::ThreadTaskRunnerHandle::Get()->PostTask(
92 FROM_HERE, base::BindOnce(callback, net::ERR_FAILED));
95 WriteRequest write_request(message.namespace_(), serialized_message,
98 write_queue_.push(write_request);
// Start the write loop only if it is idle. Otherwise DoWriteCallback() moves
// on to this request once the ones ahead of it have been written.
99 if (write_state_ == WriteState::IDLE) {
100 SetWriteState(WriteState::WRITE);
101 OnWriteResult(net::OK);
// Bundles one serialized message with its namespace and completion callback.
// The payload is wrapped in a DrainableIOBuffer so that DoWriteComplete() can
// track partial writes with DidConsume()/BytesRemaining().
105 CastTransportImpl::WriteRequest::WriteRequest(
106 const std::string& namespace_,
107 const std::string& payload,
108 const net::CompletionCallback& callback)
109 : message_namespace(namespace_), callback(callback) {
110 VLOG(2) << "WriteRequest size: " << payload.size();
111 io_buffer = base::MakeRefCounted<net::DrainableIOBuffer>(
112 base::MakeRefCounted<net::StringIOBuffer>(payload), payload.size());
// Copy constructor. SendMessage() copies each request into |write_queue_|.
115 CastTransportImpl::WriteRequest::WriteRequest(const WriteRequest& other) =
// Out-of-line destructor with an empty body; members release their own resources.
118 CastTransportImpl::WriteRequest::~WriteRequest() {}
// Sets the read machine's next state. The inequality guard makes no observable
// difference; the result equals an unconditional assignment.
120 void CastTransportImpl::SetReadState(ReadState read_state) {
121 if (read_state_ != read_state)
122 read_state_ = read_state;
// Sets the write machine's next state. The inequality guard makes no observable
// difference; the result equals an unconditional assignment.
125 void CastTransportImpl::SetWriteState(WriteState write_state) {
126 if (write_state_ != write_state)
127 write_state_ = write_state;
// Records |error_state| as the transport's current error and logs it at VLOG
// level 2. Any previously recorded error is overwritten. The recorded value is
// what the state machines pass to the delegate's OnError().
130 void CastTransportImpl::SetErrorState(ChannelError error_state) {
131 VLOG_WITH_CONNECTION(2) << "SetErrorState: "
132 << ::cast_channel::ChannelErrorToString(error_state);
133 error_state_ = error_state;
// Drives the write state machine. It is entered in two ways:
// - with net::OK from SendMessage(), to start writing;
// - as the completion callback of channel_->Write(), bound in DoWrite().
// Must not be entered while IDLE (DCHECKed). With an empty queue it sets the
// state to IDLE. Otherwise it runs state handlers until one is waiting on I/O
// (net::ERR_IO_PENDING) or a terminal state (WRITE_ERROR or IDLE) is reached.
// Ending in WRITE_ERROR reports |error_state_| via the delegate's OnError().
136 void CastTransportImpl::OnWriteResult(int result) {
137 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
138 DCHECK_NE(WriteState::IDLE, write_state_);
// Nothing left to write: return to IDLE.
139 if (write_queue_.empty()) {
140 SetWriteState(WriteState::IDLE);
144 // Network operations can either finish synchronously or asynchronously.
145 // This method executes the state machine transitions in a loop so that
146 // write state transitions happen even when network operations finish synchronously.
150 VLOG_WITH_CONNECTION(2)
151 << "OnWriteResult (state=" << AsInteger(write_state_) << ", "
152 << "result=" << rv << ", "
153 << "queue size=" << write_queue_.size() << ")";
// Snapshot the current state, then clear it so that each handler below must
// explicitly choose the next state.
155 WriteState state = write_state_;
156 write_state_ = WriteState::UNKNOWN;
158 case WriteState::WRITE:
161 case WriteState::WRITE_COMPLETE:
162 rv = DoWriteComplete(rv);
164 case WriteState::DO_CALLBACK:
165 rv = DoWriteCallback();
167 case WriteState::HANDLE_ERROR:
168 rv = DoWriteHandleError(rv);
169 DCHECK_EQ(WriteState::WRITE_ERROR, write_state_);
// Unexpected state: treat it as an unknown error and stop the machine.
172 NOTREACHED() << "Unknown state in write state machine: "
174 SetWriteState(WriteState::WRITE_ERROR);
175 SetErrorState(ChannelError::UNKNOWN);
176 rv = net::ERR_FAILED;
179 } while (rv != net::ERR_IO_PENDING && !IsTerminalWriteState(write_state_));
// The loop stopped in the error state: report the recorded error.
181 if (write_state_ == WriteState::WRITE_ERROR) {
183 DCHECK_NE(ChannelError::NONE, error_state_);
184 VLOG_WITH_CONNECTION(2) << "Sending OnError().";
185 delegate_->OnError(error_state_);
// WRITE state: writes the remaining bytes of the request at the head of
// |write_queue_| to |channel_|. The next state, WRITE_COMPLETE, is set before
// the call, and net::ERR_IO_PENDING is always returned. The loop resumes when
// |channel_| invokes OnWriteResult() with the write result.
// NOTE(review): base::Unretained(this) assumes |channel_| never runs the
// callback after this transport is destroyed — confirm the ownership model.
189 int CastTransportImpl::DoWrite() {
190 DCHECK(!write_queue_.empty());
191 WriteRequest& request = write_queue_.front();
193 VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
194 << request.io_buffer->size() << " bytes_written "
195 << request.io_buffer->BytesConsumed();
// Set before writing, so OnWriteResult() dispatches to DoWriteComplete() when
// the completion callback runs.
197 SetWriteState(WriteState::WRITE_COMPLETE);
199 channel_->Write(request.io_buffer.get(), request.io_buffer->BytesRemaining(),
200 base::BindOnce(&CastTransportImpl::OnWriteResult,
201 base::Unretained(this)));
202 return net::ERR_IO_PENDING;
// WRITE_COMPLETE state: handles |result| from channel_->Write().
// - result <= 0: logs a SOCKET_WRITE event, records CAST_SOCKET_ERROR and
//   moves to HANDLE_ERROR. Returns |result|, or net::ERR_FAILED when it is 0.
// - result > 0: marks that many bytes of the head request as consumed. Moves to
//   DO_CALLBACK if the message is fully sent, else back to WRITE for the rest.
205 int CastTransportImpl::DoWriteComplete(int result) {
206 VLOG_WITH_CONNECTION(2) << "DoWriteComplete result=" << result;
207 DCHECK(!write_queue_.empty());
208 if (result <= 0) { // NOTE that 0 also indicates an error
209 logger_->LogSocketEventWithRv(channel_id_, ChannelEvent::SOCKET_WRITE,
211 SetErrorState(ChannelError::CAST_SOCKET_ERROR);
212 SetWriteState(WriteState::HANDLE_ERROR);
213 return result == 0 ? net::ERR_FAILED : result;
216 // Some bytes were successfully written
217 WriteRequest& request = write_queue_.front();
218 scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
219 io_buffer->DidConsume(result);
220 if (io_buffer->BytesRemaining() == 0) { // Message fully sent
221 SetWriteState(WriteState::DO_CALLBACK);
223 SetWriteState(WriteState::WRITE);
// DO_CALLBACK state: the head request has been fully written. Posts its
// callback with net::OK to the current thread's task runner, then moves to IDLE
// if |write_queue_| is empty, or to WRITE to start the next request.
229 int CastTransportImpl::DoWriteCallback() {
230 VLOG_WITH_CONNECTION(2) << "DoWriteCallback";
231 DCHECK(!write_queue_.empty());
233 WriteRequest& request = write_queue_.front();
234 base::ThreadTaskRunnerHandle::Get()->PostTask(
235 FROM_HERE, base::BindOnce(request.callback, net::OK));
238 if (write_queue_.empty()) {
239 SetWriteState(WriteState::IDLE);
241 SetWriteState(WriteState::WRITE);
// HANDLE_ERROR state for writes. DCHECKs that an error has already been
// recorded and that |result| is a net error (< 0). Moves to the terminal
// WRITE_ERROR state and returns net::ERR_FAILED; OnWriteResult() then reports
// |error_state_| to the delegate.
247 int CastTransportImpl::DoWriteHandleError(int result) {
248 VLOG_WITH_CONNECTION(2) << "DoWriteHandleError result=" << result;
249 DCHECK_NE(ChannelError::NONE, error_state_);
250 DCHECK_LT(result, 0);
251 SetWriteState(WriteState::WRITE_ERROR);
252 return net::ERR_FAILED;
// Begins reading from |channel_|. Two preconditions are DCHECKed: a delegate
// must be set (see SetReadDelegate()), and the read machine must still be in
// its initial READ state. Reading then continues in OnReadResult() until the
// read machine reaches READ_ERROR.
255 void CastTransportImpl::Start() {
256 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
258 DCHECK_EQ(ReadState::READ, read_state_);
259 DCHECK(delegate_) << "Read delegate must be set prior to calling Start()";
262 SetReadState(ReadState::READ);
264 // Start the read state machine.
265 OnReadResult(net::OK);
// Drives the read state machine. It is entered with net::OK from Start(), and
// then as the completion callback of each read issued by DoRead(). It runs state
// handlers until one is waiting on I/O (net::ERR_IO_PENDING) or READ_ERROR is
// reached. In the latter case |error_state_| is reported via the delegate's
// OnError().
268 void CastTransportImpl::OnReadResult(int result) {
269 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
270 // Network operations can either finish synchronously or asynchronously.
271 // This method executes the state machine transitions in a loop so that
272 // read state transitions happen even when network operations finish synchronously.
276 VLOG_WITH_CONNECTION(2) << "OnReadResult(state=" << AsInteger(read_state_)
277 << ", result=" << rv << ")";
// Snapshot the current state, then clear it so that each handler below must
// explicitly choose the next state.
278 ReadState state = read_state_;
279 read_state_ = ReadState::UNKNOWN;
282 case ReadState::READ:
285 case ReadState::READ_COMPLETE:
286 rv = DoReadComplete(rv);
288 case ReadState::DO_CALLBACK:
289 rv = DoReadCallback();
291 case ReadState::HANDLE_ERROR:
292 rv = DoReadHandleError(rv);
293 DCHECK_EQ(read_state_, ReadState::READ_ERROR);
// Unexpected state: treat it as an unknown error and stop the machine.
296 NOTREACHED() << "Unknown state in read state machine: "
298 SetReadState(ReadState::READ_ERROR);
299 SetErrorState(ChannelError::UNKNOWN);
300 rv = net::ERR_FAILED;
303 } while (rv != net::ERR_IO_PENDING && !IsTerminalReadState(read_state_));
// READ_ERROR is the only terminal read state: report the recorded error.
305 if (IsTerminalReadState(read_state_)) {
306 DCHECK_EQ(ReadState::READ_ERROR, read_state_);
307 VLOG_WITH_CONNECTION(2) << "Sending OnError().";
308 delegate_->OnError(error_state_);
// READ state: asks the framer how many bytes it needs next, then issues an
// asynchronous read of up to that many bytes into |read_buffer_|, with
// OnReadResult() as the completion callback. The next state, READ_COMPLETE, is
// set first, and net::ERR_IO_PENDING is always returned. base::checked_cast
// CHECK-fails if the requested size does not fit in uint32_t.
// NOTE(review): base::Unretained(this) assumes the read callback cannot run
// after this transport is destroyed — confirm.
312 int CastTransportImpl::DoRead() {
313 VLOG_WITH_CONNECTION(2) << "DoRead";
314 SetReadState(ReadState::READ_COMPLETE);
316 // Determine how many bytes need to be read.
317 size_t num_bytes_to_read = framer_->BytesRequested();
318 DCHECK_GT(num_bytes_to_read, 0u);
320 // Read up to num_bytes_to_read into |read_buffer_|.
322 read_buffer_.get(), base::checked_cast<uint32_t>(num_bytes_to_read),
323 base::BindOnce(&CastTransportImpl::OnReadResult, base::Unretained(this)));
324 return net::ERR_IO_PENDING;
// READ_COMPLETE state: handles the result of a read.
// - Read error: the branch logs a SOCKET_READ event and "peer closed the
//   socket", records CAST_SOCKET_ERROR and moves to HANDLE_ERROR. It returns
//   |result|, or net::ERR_FAILED when |result| is 0.
// - Otherwise the |result| new bytes are fed to the framer, with three outcomes:
//   * complete message and no framing error -> DO_CALLBACK;
//   * framing error -> records INVALID_MESSAGE and moves to HANDLE_ERROR;
//   * no message yet and no error -> back to READ for more bytes.
327 int CastTransportImpl::DoReadComplete(int result) {
328 VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;
330 logger_->LogSocketEventWithRv(channel_id_, ChannelEvent::SOCKET_READ,
332 VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket.";
333 SetErrorState(ChannelError::CAST_SOCKET_ERROR);
334 SetReadState(ReadState::HANDLE_ERROR);
335 return result == 0 ? net::ERR_FAILED : result;
// Any previous message must already have been delivered and released by
// DoReadCallback().
339 DCHECK(!current_message_);
340 ChannelError framing_error;
341 current_message_ = framer_->Ingest(result, &message_size, &framing_error);
342 if (current_message_.get() && (framing_error == ChannelError::NONE)) {
343 DCHECK_GT(message_size, static_cast<size_t>(0));
344 SetReadState(ReadState::DO_CALLBACK);
345 } else if (framing_error != ChannelError::NONE) {
346 DCHECK(!current_message_);
347 SetErrorState(ChannelError::INVALID_MESSAGE);
348 SetReadState(ReadState::HANDLE_ERROR);
350 DCHECK(!current_message_);
351 SetReadState(ReadState::READ);
// DO_CALLBACK state: validates the framed |current_message_|.
// - Invalid message: records INVALID_MESSAGE, moves to HANDLE_ERROR and returns
//   net::ERR_INVALID_RESPONSE. |current_message_| is not reset on this path.
// - Valid message: the next state is set to READ, the message is delivered to
//   the delegate's OnMessage(), and it is then released.
356 int CastTransportImpl::DoReadCallback() {
357 VLOG_WITH_CONNECTION(2) << "DoReadCallback";
358 if (!IsCastMessageValid(*current_message_)) {
359 SetReadState(ReadState::HANDLE_ERROR);
360 SetErrorState(ChannelError::INVALID_MESSAGE);
361 return net::ERR_INVALID_RESPONSE;
363 SetReadState(ReadState::READ);
364 delegate_->OnMessage(*current_message_);
365 current_message_.reset();
// HANDLE_ERROR state for reads. DCHECKs that an error has already been
// recorded and that |result| <= 0; unlike DoWriteHandleError(), a 0 result is
// tolerated. Moves to the terminal READ_ERROR state and returns
// net::ERR_FAILED; OnReadResult() then reports |error_state_| to the delegate.
369 int CastTransportImpl::DoReadHandleError(int result) {
370 VLOG_WITH_CONNECTION(2) << "DoReadHandleError";
371 DCHECK_NE(ChannelError::NONE, error_state_);
372 DCHECK_LE(result, 0);
373 SetReadState(ReadState::READ_ERROR);
374 return net::ERR_FAILED;
377 } // namespace cast_channel