1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "components/cast_channel/cast_message_handler.h"
10 #include "base/rand_util.h"
11 #include "base/strings/stringprintf.h"
12 #include "base/time/default_tick_clock.h"
13 #include "components/cast_channel/cast_socket_service.h"
14 #include "services/data_decoder/public/cpp/safe_json_parser.h"
15 #include "services/service_manager/public/cpp/connector.h"
17 namespace cast_channel {
21 // The max launch timeout amount for session launch requests.
22 constexpr base::TimeDelta kLaunchMaxTimeout = base::TimeDelta::FromMinutes(2);
24 void ReportParseError(const std::string& error) {
25 DVLOG(2) << "Error parsing JSON message: " << error;
30 GetAppAvailabilityRequest::GetAppAvailabilityRequest(
32 GetAppAvailabilityCallback callback,
33 const base::TickClock* clock,
34 const std::string& app_id)
35 : PendingRequest(request_id, std::move(callback), clock), app_id(app_id) {}
37 GetAppAvailabilityRequest::~GetAppAvailabilityRequest() = default;
39 VirtualConnection::VirtualConnection(int channel_id,
40 const std::string& source_id,
41 const std::string& destination_id)
42 : channel_id(channel_id),
44 destination_id(destination_id) {}
45 VirtualConnection::~VirtualConnection() = default;
47 bool VirtualConnection::operator<(const VirtualConnection& other) const {
48 return std::tie(channel_id, source_id, destination_id) <
49 std::tie(other.channel_id, other.source_id, other.destination_id);
52 InternalMessage::InternalMessage(CastMessageType type, base::Value message)
53 : type(type), message(std::move(message)) {}
54 InternalMessage::~InternalMessage() = default;
56 CastMessageHandler::CastMessageHandler(
57 CastSocketService* socket_service,
58 std::unique_ptr<service_manager::Connector> connector,
59 const base::Token& data_decoder_batch_id,
60 const std::string& user_agent,
61 const std::string& browser_version,
62 const std::string& locale)
63 : sender_id_(base::StringPrintf("sender-%d", base::RandInt(0, 1000000))),
64 connector_(std::move(connector)),
65 data_decoder_batch_id_(data_decoder_batch_id),
66 user_agent_(user_agent),
67 browser_version_(browser_version),
69 socket_service_(socket_service),
70 clock_(base::DefaultTickClock::GetInstance()),
71 weak_ptr_factory_(this) {
72 DETACH_FROM_SEQUENCE(sequence_checker_);
73 socket_service_->task_runner()->PostTask(
74 FROM_HERE, base::BindOnce(&CastSocketService::AddObserver,
75 base::Unretained(socket_service_),
76 base::Unretained(this)));
79 CastMessageHandler::~CastMessageHandler() {
80 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
81 socket_service_->RemoveObserver(this);
84 void CastMessageHandler::EnsureConnection(int channel_id,
85 const std::string& source_id,
86 const std::string& destination_id) {
87 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
88 CastSocket* socket = socket_service_->GetSocket(channel_id);
90 DVLOG(2) << "Socket not found: " << channel_id;
94 DoEnsureConnection(socket, source_id, destination_id);
97 CastMessageHandler::PendingRequests*
98 CastMessageHandler::GetOrCreatePendingRequests(int channel_id) {
99 CastMessageHandler::PendingRequests* requests = nullptr;
100 auto pending_it = pending_requests_.find(channel_id);
101 if (pending_it != pending_requests_.end()) {
102 return pending_it->second.get();
105 auto new_requests = std::make_unique<CastMessageHandler::PendingRequests>();
106 requests = new_requests.get();
107 pending_requests_.emplace(channel_id, std::move(new_requests));
111 void CastMessageHandler::RequestAppAvailability(
113 const std::string& app_id,
114 GetAppAvailabilityCallback callback) {
115 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
117 int channel_id = socket->id();
118 auto* requests = GetOrCreatePendingRequests(channel_id);
119 int request_id = NextRequestId();
121 DVLOG(2) << __func__ << ", channel_id: " << channel_id
122 << ", app_id: " << app_id << ", request_id: " << request_id;
123 if (requests->AddAppAvailabilityRequest(
124 std::make_unique<GetAppAvailabilityRequest>(
125 request_id, std::move(callback), clock_, app_id))) {
126 SendCastMessage(socket, CreateGetAppAvailabilityRequest(
127 sender_id_, request_id, app_id));
131 void CastMessageHandler::RequestReceiverStatus(int channel_id) {
132 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
134 CastSocket* socket = socket_service_->GetSocket(channel_id);
136 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
140 int request_id = NextRequestId();
141 SendCastMessage(socket, CreateReceiverStatusRequest(sender_id_, request_id));
144 void CastMessageHandler::SendBroadcastMessage(
146 const std::vector<std::string>& app_ids,
147 const BroadcastRequest& request) {
148 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
150 CastSocket* socket = socket_service_->GetSocket(channel_id);
152 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
156 int request_id = NextRequestId();
157 DVLOG(2) << __func__ << ", channel_id: " << channel_id
158 << ", request_id: " << request_id;
160 // Note: Even though the message is formatted like a request, we don't care
161 // about the response, as broadcasts are fire-and-forget.
162 CastMessage message =
163 CreateBroadcastRequest(sender_id_, request_id, app_ids, request);
164 SendCastMessage(socket, message);
167 void CastMessageHandler::LaunchSession(int channel_id,
168 const std::string& app_id,
169 base::TimeDelta launch_timeout,
170 LaunchSessionCallback callback) {
171 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
172 CastSocket* socket = socket_service_->GetSocket(channel_id);
174 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
175 std::move(callback).Run(LaunchSessionResponse());
179 auto* requests = GetOrCreatePendingRequests(channel_id);
180 int request_id = NextRequestId();
181 // Cap the max launch timeout to avoid long-living pending requests.
182 launch_timeout = std::min(launch_timeout, kLaunchMaxTimeout);
183 DVLOG(2) << __func__ << ", channel_id: " << channel_id
184 << ", request_id: " << request_id;
185 if (requests->AddLaunchRequest(std::make_unique<LaunchSessionRequest>(
186 request_id, std::move(callback), clock_),
189 socket, CreateLaunchRequest(sender_id_, request_id, app_id, locale_));
193 void CastMessageHandler::StopSession(int channel_id,
194 const std::string& session_id,
195 ResultCallback callback) {
196 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
197 CastSocket* socket = socket_service_->GetSocket(channel_id);
199 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
203 auto* requests = GetOrCreatePendingRequests(channel_id);
204 int request_id = NextRequestId();
205 DVLOG(2) << __func__ << ", channel_id: " << channel_id
206 << ", request_id: " << request_id;
207 if (requests->AddStopRequest(std::make_unique<StopSessionRequest>(
208 request_id, std::move(callback), clock_))) {
209 SendCastMessage(socket,
210 CreateStopRequest(sender_id_, request_id, session_id));
214 Result CastMessageHandler::SendAppMessage(int channel_id,
215 const CastMessage& message) {
216 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
217 DCHECK(!IsCastInternalNamespace(message.namespace_()))
218 << ": unexpected app message namespace: " << message.namespace_();
220 CastSocket* socket = socket_service_->GetSocket(channel_id);
222 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
223 return Result::kFailed;
226 SendCastMessage(socket, message);
230 base::Optional<int> CastMessageHandler::SendMediaRequest(
232 const base::Value& body,
233 const std::string& source_id,
234 const std::string& destination_id) {
235 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
237 CastSocket* socket = socket_service_->GetSocket(channel_id);
239 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
240 return base::nullopt;
243 int request_id = NextRequestId();
245 socket, CreateMediaRequest(body, request_id, source_id, destination_id));
249 Result CastMessageHandler::SendSetVolumeRequest(int channel_id,
250 const base::Value& body,
251 const std::string& source_id,
252 ResultCallback callback) {
253 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
255 CastSocket* socket = socket_service_->GetSocket(channel_id);
257 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
258 return Result::kFailed;
261 auto* requests = GetOrCreatePendingRequests(channel_id);
262 int request_id = NextRequestId();
264 requests->AddVolumeRequest(std::make_unique<SetVolumeRequest>(
265 request_id, std::move(callback), clock_));
266 SendCastMessage(socket, CreateSetVolumeRequest(body, request_id, source_id));
270 void CastMessageHandler::AddObserver(Observer* observer) {
271 observers_.AddObserver(observer);
274 void CastMessageHandler::RemoveObserver(Observer* observer) {
275 observers_.RemoveObserver(observer);
278 void CastMessageHandler::OnError(const CastSocket& socket,
279 ChannelError error_state) {
280 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
281 int channel_id = socket.id();
283 base::EraseIf(virtual_connections_,
284 [&channel_id](const VirtualConnection& connection) {
285 return connection.channel_id == channel_id;
288 pending_requests_.erase(channel_id);
291 void CastMessageHandler::OnMessage(const CastSocket& socket,
292 const CastMessage& message) {
293 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
294 DVLOG(2) << __func__ << ", channel_id: " << socket.id()
295 << ", message: " << CastMessageToString(message);
296 if (IsCastInternalNamespace(message.namespace_())) {
297 if (message.payload_type() ==
298 cast_channel::CastMessage_PayloadType_STRING) {
299 data_decoder::SafeJsonParser::ParseBatch(
300 connector_.get(), message.payload_utf8(),
301 base::BindRepeating(&CastMessageHandler::HandleCastInternalMessage,
302 weak_ptr_factory_.GetWeakPtr(), socket.id(),
303 message.source_id(), message.destination_id()),
304 base::BindRepeating(&ReportParseError), data_decoder_batch_id_);
306 DLOG(ERROR) << "Dropping internal message with binary payload: "
307 << message.namespace_();
310 DVLOG(2) << "Got app message from cast channel with namespace: "
311 << message.namespace_();
312 for (auto& observer : observers_)
313 observer.OnAppMessage(socket.id(), message);
317 void CastMessageHandler::OnReadyStateChanged(const CastSocket& socket) {
318 if (socket.ready_state() == ReadyState::CLOSED)
319 pending_requests_.erase(socket.id());
322 void CastMessageHandler::HandleCastInternalMessage(
324 const std::string& source_id,
325 const std::string& destination_id,
326 std::unique_ptr<base::Value> payload) {
327 if (!payload->is_dict()) {
328 ReportParseError("Parsed message not a dictionary");
332 // Check if the socket still exists as it might have been removed during
334 if (!socket_service_->GetSocket(channel_id)) {
335 DVLOG(2) << __func__ << ": socket not found: " << channel_id;
339 base::Optional<int> request_id = GetRequestIdFromResponse(*payload);
341 auto requests_it = pending_requests_.find(channel_id);
342 if (requests_it != pending_requests_.end())
343 requests_it->second->HandlePendingRequest(*request_id, *payload);
346 CastMessageType type = ParseMessageTypeFromPayload(*payload);
347 if (type == CastMessageType::kOther) {
348 DVLOG(2) << "Unknown message type: " << *payload;
352 if (type == CastMessageType::kCloseConnection) {
353 // Source / destination is flipped.
354 virtual_connections_.erase(
355 VirtualConnection(channel_id, destination_id, source_id));
359 InternalMessage internal_message(type, std::move(*payload));
360 for (auto& observer : observers_)
361 observer.OnInternalMessage(channel_id, internal_message);
364 void CastMessageHandler::SendCastMessage(CastSocket* socket,
365 const CastMessage& message) {
366 // A virtual connection must be opened to the receiver before other messages
368 DoEnsureConnection(socket, message.source_id(), message.destination_id());
369 socket->transport()->SendMessage(
370 message, base::BindRepeating(&CastMessageHandler::OnMessageSent,
371 weak_ptr_factory_.GetWeakPtr()));
374 void CastMessageHandler::DoEnsureConnection(CastSocket* socket,
375 const std::string& source_id,
376 const std::string& destination_id) {
377 VirtualConnection connection(socket->id(), source_id, destination_id);
378 if (virtual_connections_.find(connection) != virtual_connections_.end())
381 DVLOG(1) << "Creating VC for channel: " << connection.channel_id
382 << ", source: " << connection.source_id
383 << ", dest: " << connection.destination_id;
384 CastMessage virtual_connection_request = CreateVirtualConnectionRequest(
385 connection.source_id, connection.destination_id,
386 connection.destination_id == kPlatformReceiverId
387 ? VirtualConnectionType::kStrong
388 : VirtualConnectionType::kInvisible,
389 user_agent_, browser_version_);
390 socket->transport()->SendMessage(
391 virtual_connection_request,
392 base::BindRepeating(&CastMessageHandler::OnMessageSent,
393 weak_ptr_factory_.GetWeakPtr()));
395 // We assume the virtual connection request will succeed; otherwise this
396 // will eventually self-correct.
397 virtual_connections_.insert(connection);
400 void CastMessageHandler::OnMessageSent(int result) {
401 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
402 DVLOG_IF(2, result < 0) << "SendMessage failed with code: " << result;
405 CastMessageHandler::PendingRequests::PendingRequests() {}
406 CastMessageHandler::PendingRequests::~PendingRequests() {
407 for (auto& request : pending_app_availability_requests_) {
408 std::move(request->callback)
409 .Run(request->app_id, GetAppAvailabilityResult::kUnknown);
412 if (pending_launch_session_request_) {
413 LaunchSessionResponse response;
414 response.result = LaunchSessionResponse::kError;
415 std::move(pending_launch_session_request_->callback)
416 .Run(std::move(response));
419 if (pending_stop_session_request_)
420 std::move(pending_stop_session_request_->callback).Run(Result::kFailed);
422 for (auto& request : pending_volume_requests_by_id_)
423 std::move(request.second->callback).Run(Result::kFailed);
426 bool CastMessageHandler::PendingRequests::AddAppAvailabilityRequest(
427 std::unique_ptr<GetAppAvailabilityRequest> request) {
428 const std::string& app_id = request->app_id;
429 int request_id = request->request_id;
430 request->timeout_timer.Start(
431 FROM_HERE, kRequestTimeout,
433 &CastMessageHandler::PendingRequests::AppAvailabilityTimedOut,
434 base::Unretained(this), request_id));
436 // Look for a request with the given app ID.
437 bool found = std::find_if(pending_app_availability_requests_.begin(),
438 pending_app_availability_requests_.end(),
439 [&app_id](const auto& old_request) {
440 return old_request->app_id == app_id;
441 }) != pending_app_availability_requests_.end();
442 pending_app_availability_requests_.emplace_back(std::move(request));
446 bool CastMessageHandler::PendingRequests::AddLaunchRequest(
447 std::unique_ptr<LaunchSessionRequest> request,
448 base::TimeDelta timeout) {
449 if (pending_launch_session_request_)
452 int request_id = request->request_id;
453 request->timeout_timer.Start(
456 &CastMessageHandler::PendingRequests::LaunchSessionTimedOut,
457 base::Unretained(this), request_id));
458 pending_launch_session_request_ = std::move(request);
462 bool CastMessageHandler::PendingRequests::AddStopRequest(
463 std::unique_ptr<StopSessionRequest> request) {
464 if (pending_stop_session_request_)
467 int request_id = request->request_id;
468 request->timeout_timer.Start(
469 FROM_HERE, kRequestTimeout,
470 base::BindOnce(&CastMessageHandler::PendingRequests::StopSessionTimedOut,
471 base::Unretained(this), request_id));
472 pending_stop_session_request_ = std::move(request);
476 void CastMessageHandler::PendingRequests::AddVolumeRequest(
477 std::unique_ptr<SetVolumeRequest> request) {
478 int request_id = request->request_id;
479 request->timeout_timer.Start(
480 FROM_HERE, kRequestTimeout,
481 base::BindOnce(&CastMessageHandler::PendingRequests::SetVolumeTimedOut,
482 base::Unretained(this), request_id));
483 pending_volume_requests_by_id_.emplace(request_id, std::move(request));
486 void CastMessageHandler::PendingRequests::HandlePendingRequest(
488 const base::Value& response) {
489 // Look up an app availability request by its |request_id|.
490 auto app_availability_it =
491 std::find_if(pending_app_availability_requests_.begin(),
492 pending_app_availability_requests_.end(),
493 [request_id](const auto& request_ptr) {
494 return request_ptr->request_id == request_id;
496 // If we found a request, process and remove all requests with the same
497 // |app_id|, which will of course include the one we just found.
498 if (app_availability_it != pending_app_availability_requests_.end()) {
499 std::string app_id = (*app_availability_it)->app_id;
500 GetAppAvailabilityResult result =
501 GetAppAvailabilityResultFromResponse(response, app_id);
502 base::EraseIf(pending_app_availability_requests_,
503 [&app_id, result](const auto& request_ptr) {
504 if (request_ptr->app_id == app_id) {
505 std::move(request_ptr->callback).Run(app_id, result);
513 if (pending_launch_session_request_ &&
514 pending_launch_session_request_->request_id == request_id) {
515 std::move(pending_launch_session_request_->callback)
516 .Run(GetLaunchSessionResponse(response));
517 pending_launch_session_request_.reset();
521 if (pending_stop_session_request_ &&
522 pending_stop_session_request_->request_id == request_id) {
523 std::move(pending_stop_session_request_->callback).Run(Result::kOk);
524 pending_stop_session_request_.reset();
528 auto volume_it = pending_volume_requests_by_id_.find(request_id);
529 if (volume_it != pending_volume_requests_by_id_.end()) {
530 std::move(volume_it->second->callback).Run(Result::kOk);
531 pending_volume_requests_by_id_.erase(volume_it);
536 void CastMessageHandler::PendingRequests::AppAvailabilityTimedOut(
538 DVLOG(1) << __func__ << ", request_id: " << request_id;
540 auto it = std::find_if(pending_app_availability_requests_.begin(),
541 pending_app_availability_requests_.end(),
542 [&request_id](const auto& request) {
543 return request->request_id == request_id;
546 CHECK(it != pending_app_availability_requests_.end());
547 std::move((*it)->callback)
548 .Run((*it)->app_id, GetAppAvailabilityResult::kUnknown);
549 pending_app_availability_requests_.erase(it);
552 void CastMessageHandler::PendingRequests::LaunchSessionTimedOut(
554 DVLOG(1) << __func__ << ", request_id: " << request_id;
555 CHECK(pending_launch_session_request_);
556 CHECK(pending_launch_session_request_->request_id == request_id);
558 LaunchSessionResponse response;
559 response.result = LaunchSessionResponse::kTimedOut;
560 std::move(pending_launch_session_request_->callback).Run(std::move(response));
561 pending_launch_session_request_.reset();
564 void CastMessageHandler::PendingRequests::StopSessionTimedOut(int request_id) {
565 DVLOG(1) << __func__ << ", request_id: " << request_id;
566 CHECK(pending_stop_session_request_);
567 CHECK(pending_stop_session_request_->request_id == request_id);
569 std::move(pending_stop_session_request_->callback).Run(Result::kFailed);
570 pending_stop_session_request_.reset();
573 void CastMessageHandler::PendingRequests::SetVolumeTimedOut(int request_id) {
574 DVLOG(1) << __func__ << ", request_id: " << request_id;
575 auto it = pending_volume_requests_by_id_.find(request_id);
576 DCHECK(it != pending_volume_requests_by_id_.end());
577 std::move(it->second->callback).Run(Result::kFailed);
578 pending_volume_requests_by_id_.erase(it);
581 } // namespace cast_channel