1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "components/cast_channel/cast_socket_service.h"
7 #include "base/memory/ptr_util.h"
8 #include "base/task/post_task.h"
9 #include "components/cast_channel/cast_socket.h"
10 #include "components/cast_channel/logger.h"
11 #include "content/public/browser/browser_task_traits.h"
12 #include "content/public/browser/browser_thread.h"
14 using content::BrowserThread;
16 namespace cast_channel {
// Counter backing channel-id allocation. It starts at 0 and AddSocket()
// pre-increments it before use, so the first id handed out is 1.
18 int CastSocketService::last_channel_id_ = 0;
// Builds the service with a freshly allocated Logger and a single-thread task
// runner requested with the content::BrowserThread::IO trait. The methods of
// this class DCHECK that they run on |task_runner_|'s thread.
20 CastSocketService::CastSocketService()
21 : logger_(new Logger()),
22 // IO thread's task runner is used because of:
23 // (1) ChromeURLRequestContextGetter::GetURLRequestContext, which is
24 // called by CastMediaSinkServiceImpl, must run on IO thread. (2) Parts of
25 // CastChannel extension API functions run on IO thread.
26 task_runner_(base::CreateSingleThreadTaskRunnerWithTraits(
27 {content::BrowserThread::IO})) {}
// Defaulted destructor. GetInstance() allocates the instance with `new` into
// a function-local static pointer and nothing in this file deletes it.
29 // This is a leaky singleton and the dtor won't be called.
30 CastSocketService::~CastSocketService() = default;
// Accessor for the singleton CastSocketService. The instance is heap-allocated
// the first time the function-local static below is initialized.
// NOTE(review): the remainder of this body (the return statement) is not
// visible in this view; presumably it returns |instance| -- confirm.
33 CastSocketService* CastSocketService::GetInstance() {
34 static CastSocketService* instance = new CastSocketService();
// Hands out a scoped_refptr to a Logger.
// NOTE(review): the body is not visible in this view; presumably it returns
// the |logger_| member created in the constructor -- confirm.
38 scoped_refptr<Logger> CastSocketService::GetLogger() {
// Takes ownership of |socket| and stores it in |sockets_| under a newly
// allocated channel id. Must run on |task_runner_|'s thread (DCHECKed).
// NOTE(review): some statements of this body, including the return, are not
// visible in this view; presumably the function returns |socket_ptr|, the
// non-owning pointer captured below -- confirm.
42 CastSocket* CastSocketService::AddSocket(std::unique_ptr<CastSocket> socket) {
43 DCHECK(task_runner_->BelongsToCurrentThread());
// Pre-increment: the id used is the counter value after the bump.
45 int id = ++last_channel_id_;
// Capture the raw pointer before |socket| is moved into the map on the next
// line; after the move |socket| no longer refers to the object.
48 auto* socket_ptr = socket.get();
49 sockets_.insert(std::make_pair(id, std::move(socket)));
// Detaches the socket registered under |channel_id| from |sockets_|.
// When an entry exists, its owning pointer is moved into the local |socket|
// and the map entry is erased; otherwise |socket| stays null.
// Must run on |task_runner_|'s thread, and |channel_id| must be positive
// (both DCHECKed).
// NOTE(review): the closing lines of this body are not visible in this view;
// presumably the function returns |socket| -- confirm.
53 std::unique_ptr<CastSocket> CastSocketService::RemoveSocket(int channel_id) {
54 DCHECK(task_runner_->BelongsToCurrentThread());
55 DCHECK(channel_id > 0);
56 auto socket_it = sockets_.find(channel_id);
58 std::unique_ptr<CastSocket> socket;
59 if (socket_it != sockets_.end()) {
// Move ownership out first, then erase the entry left holding a
// moved-from pointer.
60 socket = std::move(socket_it->second);
61 sockets_.erase(socket_it);
// Looks up the socket registered under |channel_id|.
// Returns a non-owning pointer to it, or nullptr when |sockets_| has no entry
// for that id; ownership stays with |sockets_|.
// Must run on |task_runner_|'s thread, and |channel_id| must be positive
// (both DCHECKed).
66 CastSocket* CastSocketService::GetSocket(int channel_id) const {
67 DCHECK(task_runner_->BelongsToCurrentThread());
68 DCHECK(channel_id > 0);
69 const auto& socket_it = sockets_.find(channel_id);
70 return socket_it == sockets_.end() ? nullptr : socket_it->second.get();
// Finds a socket by remote endpoint rather than by channel id.
// Scans |sockets_| with std::find_if and returns a non-owning pointer to the
// first entry whose ip_endpoint() compares equal to |ip_endpoint|, or nullptr
// when no entry matches. Must run on |task_runner_|'s thread (DCHECKed).
// NOTE(review): the lambda's capture line and the end of the find_if call are
// not visible in this view.
73 CastSocket* CastSocketService::GetSocket(
74 const net::IPEndPoint& ip_endpoint) const {
75 DCHECK(task_runner_->BelongsToCurrentThread());
76 auto it = std::find_if(
77 sockets_.begin(), sockets_.end(),
79 const std::pair<const int, std::unique_ptr<CastSocket>>& pair) {
80 return pair.second->ip_endpoint() == ip_endpoint;
82 return it == sockets_.end() ? nullptr : it->second.get();
// Connects a socket for |open_params.ip_endpoint|.
// The visible steps are: look up a socket for the endpoint via GetSocket();
// register a socket through AddSocket() (either |socket_for_test_| or a new
// CastSocketImpl built from |network_context_getter|, |open_params| and
// |logger_|); add every observer in |observers_| to the socket; and call
// Connect() with |open_cb|. Must run on |task_runner_|'s thread (DCHECKed).
// NOTE(review): the condition enclosing the creation branch and the `else`
// between the two creation paths are not visible in this view. Judging by the
// "does not exist" comment below, creation presumably happens only when the
// lookup returned nullptr -- confirm.
85 void CastSocketService::OpenSocket(NetworkContextGetter network_context_getter,
86 const CastSocketOpenParams& open_params,
87 CastSocket::OnOpenCallback open_cb) {
88 DCHECK(task_runner_->BelongsToCurrentThread());
90 const net::IPEndPoint& ip_endpoint = open_params.ip_endpoint;
91 auto* socket = GetSocket(ip_endpoint);
93 // If cast socket does not exist.
94 if (socket_for_test_) {
// The injected test socket is moved into the registry.
95 socket = AddSocket(std::move(socket_for_test_));
// The raw `new` result is wrapped by base::WrapUnique on the following line,
// so ownership passes to AddSocket(); |socket| stays a non-owning pointer.
97 socket = new CastSocketImpl(network_context_getter, open_params, logger_);
98 AddSocket(base::WrapUnique(socket));
102 for (auto& observer : observers_)
103 socket->AddObserver(&observer);
105 socket->Connect(std::move(open_cb));
// Registers |observer| in |observers_| and on every socket currently held in
// |sockets_|. Must run on |task_runner_|'s thread (DCHECKed).
// NOTE(review): the statement controlled by the HasObserver() check is not
// visible in this view; presumably it is an early return that makes repeated
// registration a no-op -- confirm.
108 void CastSocketService::AddObserver(CastSocket::Observer* observer) {
109 DCHECK(task_runner_->BelongsToCurrentThread());
111 if (observers_.HasObserver(observer))
114 observers_.AddObserver(observer);
115 for (auto& socket_it : sockets_)
116 socket_it.second->AddObserver(observer);
// Unregisters |observer|: it is first removed from every socket in
// |sockets_| and then from |observers_|.
// Must run on |task_runner_|'s thread (DCHECKed).
// NOTE(review): a couple of lines after the thread check are not visible in
// this view.
119 void CastSocketService::RemoveObserver(CastSocket::Observer* observer) {
120 DCHECK(task_runner_->BelongsToCurrentThread());
123 for (auto& socket_it : sockets_)
124 socket_it.second->RemoveObserver(observer);
125 observers_.RemoveObserver(observer);
// Test hook: stores |socket_for_test| in |socket_for_test_|, taking ownership.
// OpenSocket() checks |socket_for_test_| and, when it is set, moves it into
// AddSocket() instead of constructing a CastSocketImpl.
128 void CastSocketService::SetSocketForTest(
129 std::unique_ptr<cast_channel::CastSocket> socket_for_test) {
130 socket_for_test_ = std::move(socket_for_test);
133 } // namespace cast_channel