3 * Copyright 2009, Google, Inc.
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright notice,
11 * this list of conditions and the following disclaimer in the documentation
12 * and/or other materials provided with the distribution.
13 * 3. The name of the author may not be used to endorse or promote products
14 * derived from this software without specific prior written permission.
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 #ifndef WEBRTC_P2P_BASE_FAKESESSION_H_
29 #define WEBRTC_P2P_BASE_FAKESESSION_H_
35 #include "webrtc/p2p/base/session.h"
36 #include "webrtc/p2p/base/transport.h"
37 #include "webrtc/p2p/base/transportchannel.h"
38 #include "webrtc/p2p/base/transportchannelimpl.h"
39 #include "webrtc/base/buffer.h"
40 #include "webrtc/base/fakesslidentity.h"
41 #include "webrtc/base/messagequeue.h"
42 #include "webrtc/base/sigslot.h"
43 #include "webrtc/base/sslfingerprint.h"
// Message payload that carries one outgoing packet from
// FakeTransportChannel::SendPacket() to FakeTransportChannel::OnMessage(),
// which reads it back through packet.data() / packet.length().
// NOTE(review): the declaration of the `packet` member and the end of this
// struct are not visible in this excerpt; presumably `packet` is an
// rtc::Buffer (webrtc/base/buffer.h is included) — confirm.
49 struct PacketMessageData : public rtc::MessageData {
// Builds `packet` from the `len` bytes starting at `data`.
50 PacketMessageData(const char* data, size_t len) : packet(data, len) {
55 // Fake transport channel class, which can be passed to anything that needs a
56 // transport channel. Can be informed of another FakeTransportChannel via
// SetDestination(), which stands in for candidate exchange: it moves both ends
// to the connected state and marks the destination writable.
58 class FakeTransportChannel : public TransportChannelImpl,
59 public rtc::MessageHandler {
// `content_name` and `component` are forwarded to TransportChannelImpl;
// `transport` is stored in transport_ (presumably what GetTransport() returns
// — its body is not visible here). The visible initializers start the channel
// with an unknown ICE role, ICEPROTO_HYBRID, a full-ICE remote, an empty DTLS
// fingerprint, the SSL client role and zero connections.
61 explicit FakeTransportChannel(Transport* transport,
62 const std::string& content_name,
64 : TransportChannelImpl(content_name, component),
65 transport_(transport),
71 role_(ICEROLE_UNKNOWN),
73 ice_proto_(ICEPROTO_HYBRID),
74 remote_ice_mode_(ICEMODE_FULL),
75 dtls_fingerprint_("", NULL, 0),
76 ssl_role_(rtc::SSL_CLIENT),
77 connection_count_(0) {
// NOTE(review): the destructor body is not visible in this excerpt.
79 ~FakeTransportChannel() {
// Test-only accessors for the state recorded by the setters further below.
83 uint64 IceTiebreaker() const { return tiebreaker_; }
// NOTE(review): returns the IceProtocolType member as a TransportProtocol —
// confirm the two names refer to the same enum.
84 TransportProtocol protocol() const { return ice_proto_; }
85 IceMode remote_ice_mode() const { return remote_ice_mode_; }
86 const std::string& ice_ufrag() const { return ice_ufrag_; }
87 const std::string& ice_pwd() const { return ice_pwd_; }
88 const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
89 const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
// Fingerprint last stored by SetRemoteFingerprint().
90 const rtc::SSLFingerprint& dtls_fingerprint() const {
91 return dtls_fingerprint_;
// Chooses between asynchronous (Thread::Post) and synchronous (Thread::Send)
// delivery in SendPacket(). Body not visible; presumably sets the flag that
// SendPacket() consults.
94 void SetAsync(bool async) {
98 virtual Transport* GetTransport() {
// TransportChannelImpl ICE configuration: these simply record the values.
102 virtual void SetIceRole(IceRole role) { role_ = role; }
103 virtual IceRole GetIceRole() const { return role_; }
104 virtual size_t GetConnectionCount() const { return connection_count_; }
105 virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; }
106 virtual bool GetIceProtocolType(IceProtocolType* type) const {
110 virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; }
// Records the local ICE credentials for later inspection by tests.
111 virtual void SetIceCredentials(const std::string& ice_ufrag,
112 const std::string& ice_pwd) {
113 ice_ufrag_ = ice_ufrag;
// Records the remote ICE credentials for later inspection by tests.
116 virtual void SetRemoteIceCredentials(const std::string& ice_ufrag,
117 const std::string& ice_pwd) {
118 remote_ice_ufrag_ = ice_ufrag;
119 remote_ice_pwd_ = ice_pwd;
122 virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; }
// Stores the remote DTLS fingerprint built from `alg` and the digest bytes;
// readable back via dtls_fingerprint(). Return value not visible here.
123 virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest,
125 dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len);
128 virtual bool SetSslRole(rtc::SSLRole role) {
132 virtual bool GetSslRole(rtc::SSLRole* role) const {
// Moves a fresh channel from STATE_INIT to STATE_CONNECTING; the actual
// "connection" happens in SetDestination().
137 virtual void Connect() {
138 if (state_ == STATE_INIT) {
139 state_ = STATE_CONNECTING;
// Returns a used channel to STATE_INIT; the visible part also resets the peer
// channel's state. The rest of the body is not visible in this excerpt.
142 virtual void Reset() {
143 if (state_ != STATE_INIT) {
146 dest_->state_ = STATE_INIT;
// Test hook: directly sets the writable state reported by the base class.
153 void SetWritable(bool writable) {
154 set_writable(writable);
// Pairs this channel with `dest`.
// - CONNECTING and `dest` non-null: if both ends have a local identity, enables
//   DTLS on the destination and negotiates SRTP ciphers; then marks both ends
//   CONNECTED and the destination writable. (The lines that record `dest` as
//   dest_ are not visible here.)
// - CONNECTED and `dest` null: forgets dest_ and drops back to CONNECTING.
157 void SetDestination(FakeTransportChannel* dest) {
158 if (state_ == STATE_CONNECTING && dest) {
159 // This simulates the delivery of candidates.
162 if (identity_ && dest_->identity_) {
164 dest_->do_dtls_ = true;
165 NegotiateSrtpCiphers();
167 state_ = STATE_CONNECTED;
168 dest_->state_ = STATE_CONNECTED;
170 dest_->set_writable(true);
171 } else if (state_ == STATE_CONNECTED && !dest) {
172 // Simulates loss of connectivity, by asymmetrically forgetting dest_.
174 state_ = STATE_CONNECTING;
// Test hook: updates the reported connection count and fires
// SignalConnectionRemoved only when the count decreases.
179 void SetConnectionCount(size_t connection_count) {
180 size_t old_connection_count = connection_count_;
181 connection_count_ = connection_count;
182 if (connection_count_ < old_connection_count)
183 SignalConnectionRemoved(this);
// Hands a copy of `data` to OnMessage() on the current thread, which raises
// SignalReadPacket on the destination channel, and returns `len`. The visible
// guards reject sending when not STATE_CONNECTED, and when `flags` is neither
// 0 nor PF_SRTP_BYPASS; their return values are not visible in this excerpt.
// Post (async) vs Send (sync) is chosen on a line not visible here, presumably
// from the SetAsync() setting. `options` is not used in any visible line.
186 virtual int SendPacket(const char* data, size_t len,
187 const rtc::PacketOptions& options, int flags) {
188 if (state_ != STATE_CONNECTED) {
192 if (flags != PF_SRTP_BYPASS && flags != 0) {
196 PacketMessageData* packet = new PacketMessageData(data, len);
198 rtc::Thread::Current()->Post(this, 0, packet);
200 rtc::Thread::Current()->Send(this, 0, packet);
202 return static_cast<int>(len);
204 virtual int SetOption(rtc::Socket::Option opt, int value) {
207 virtual int GetError() {
211 virtual void OnSignalingReady() {
213 virtual void OnCandidate(const Candidate& candidate) {
// rtc::MessageHandler callback for packets queued by SendPacket(): raises
// SignalReadPacket on the destination, passing dest_ as the channel, a packet
// time of CreatePacketTime(0) and flags 0.
// NOTE(review): the line that extracts the payload and any `delete` of the
// PacketMessageData allocated in SendPacket() are not visible — confirm it is
// freed. In async mode dest_ may also have been cleared by
// SetDestination(NULL) before this runs — confirm a null check exists.
216 virtual void OnMessage(rtc::Message* msg) {
217 PacketMessageData* data = static_cast<PacketMessageData*>(
219 dest_->SignalReadPacket(dest_, data->packet.data(),
220 data->packet.length(),
221 rtc::CreatePacketTime(0), 0);
// Stores the local identity pointer; SetDestination() enables DTLS only when
// both ends have one. Ownership is not visible here.
225 bool SetLocalIdentity(rtc::SSLIdentity* identity) {
226 identity_ = identity;
231 void SetRemoteCertificate(rtc::FakeSSLCertificate* cert) {
235 virtual bool IsDtlsActive() const {
// Records the SRTP ciphers this end offers to NegotiateSrtpCiphers().
239 virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) {
240 srtp_ciphers_ = ciphers;
// Writes the negotiated cipher to `cipher` when one has been chosen; the
// return values are not visible in this excerpt.
244 virtual bool GetSrtpCipher(std::string* cipher) {
245 if (!chosen_srtp_cipher_.empty()) {
246 *cipher = chosen_srtp_cipher_;
// Hands back identity_->GetReference() (presumably a new object the caller
// owns — confirm); the no-identity path is not visible here.
252 virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const {
256 *identity = identity_->GetReference();
// Same pattern as GetLocalIdentity(), for the remote certificate.
260 virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const {
264 *cert = remote_cert_->GetReference();
// Fake keying-material export: once an SRTP cipher has been chosen, fills
// `result` with 0xff bytes instead of deriving real keys.
268 virtual bool ExportKeyingMaterial(const std::string& label,
269 const uint8* context,
274 if (!chosen_srtp_cipher_.empty()) {
275 memset(result, 0xff, result_len);
// Compares every cipher of this end against every cipher of dest_ and records
// the selected pair: *it1 on this end and *it2 on the destination. The
// comparison line and any early exit are not visible; presumably the first
// equal pair wins.
282 virtual void NegotiateSrtpCiphers() {
283 for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin();
284 it1 != srtp_ciphers_.end(); ++it1) {
285 for (std::vector<std::string>::const_iterator it2 =
286 dest_->srtp_ciphers_.begin();
287 it2 != dest_->srtp_ciphers_.end(); ++it2) {
289 chosen_srtp_cipher_ = *it1;
290 dest_->chosen_srtp_cipher_ = *it2;
// Appends a ConnectionInfo to `infos`; how it is filled is not visible here.
297 virtual bool GetStats(ConnectionInfos* infos) OVERRIDE {
300 infos->push_back(info);
// Connection lifecycle: INIT --Connect()--> CONNECTING --SetDestination(peer)-->
// CONNECTED; SetDestination(NULL) drops CONNECTED back to CONNECTING.
305 enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
306 Transport* transport_;
// Peer channel that receives our packets; forgotten on simulated loss.
307 FakeTransportChannel* dest_;
310 rtc::SSLIdentity* identity_;
311 rtc::FakeSSLCertificate* remote_cert_;
// Offered SRTP ciphers and the one selected by NegotiateSrtpCiphers().
313 std::vector<std::string> srtp_ciphers_;
314 std::string chosen_srtp_cipher_;
317 IceProtocolType ice_proto_;
318 std::string ice_ufrag_;
319 std::string ice_pwd_;
320 std::string remote_ice_ufrag_;
321 std::string remote_ice_pwd_;
322 IceMode remote_ice_mode_;
323 rtc::SSLFingerprint dtls_fingerprint_;
324 rtc::SSLRole ssl_role_;
325 size_t connection_count_;
328 // Fake transport class, which can be passed to anything that needs a Transport.
329 // Can be informed of another FakeTransport via SetDestination (low-tech way
330 // of doing candidates)
// Every channel it creates is a FakeTransportChannel, tracked by component
// number in channels_.
331 class FakeTransport : public Transport {
// Component number -> FakeTransportChannel created by this transport.
333 typedef std::map<int, FakeTransportChannel*> ChannelMap;
// Creates a transport of type "test_type".
// NOTE(review): `alllocator` (sic) is not used in any visible line; NULL is
// passed as Transport's last argument (presumably the allocator), so an
// allocator supplied by a caller is silently ignored. Confirm whether it
// should be forwarded.
334 FakeTransport(rtc::Thread* signaling_thread,
335 rtc::Thread* worker_thread,
336 const std::string& content_name,
337 PortAllocator* alllocator = NULL)
338 : Transport(signaling_thread, worker_thread,
339 content_name, "test_type", NULL),
// Destructor (its signature line is not visible in this excerpt): destroys
// all channels via DestroyAllChannels().
345 DestroyAllChannels();
// Read-only view of the FakeTransportChannels created by this transport.
348 const ChannelMap& channels() const { return channels_; }
// Sets the async mode that CreateTransportChannel() applies to channels
// created afterwards; channels that already exist are not updated.
350 void SetAsync(bool async) { async_ = async; }
// Pairs this transport with `dest`: every existing channel gets this
// transport's identity and is wired to dest's channel with the same component.
// (Storing `dest` into dest_ happens on a line not visible here.)
351 void SetDestination(FakeTransport* dest) {
353 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
355 it->second->SetLocalIdentity(identity_);
356 SetChannelDestination(it->first, it->second);
// Test hook: forces the writable state of every channel.
360 void SetWritable(bool writable) {
361 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
363 it->second->SetWritable(writable);
// Sets the identity that SetDestination() hands to each channel.
367 void set_identity(rtc::SSLIdentity* identity) {
368 identity_ = identity;
371 using Transport::local_description;
372 using Transport::remote_description;
// Transport factory hook: creates the FakeTransportChannel for `component`,
// applies the current async mode, wires it to dest_'s matching channel and
// records it in channels_. The branch taken when `component` already exists
// is not visible here (presumably returns NULL).
375 virtual TransportChannelImpl* CreateTransportChannel(int component) {
376 if (channels_.find(component) != channels_.end()) {
379 FakeTransportChannel* channel =
380 new FakeTransportChannel(this, content_name(), component);
381 channel->SetAsync(async_);
382 SetChannelDestination(component, channel);
383 channels_[component] = channel;
// Forgets `channel` by component number.
// NOTE(review): no `delete channel` is visible here — confirm the
// FakeTransportChannel allocated in CreateTransportChannel() is released.
386 virtual void DestroyTransportChannel(TransportChannelImpl* channel) {
387 channels_.erase(channel->component());
// Worker-thread identity setter and getter; the getter hands back
// identity_->GetReference(). Its null-identity path is not visible here.
390 virtual void SetIdentity_w(rtc::SSLIdentity* identity) {
391 identity_ = identity;
393 virtual bool GetIdentity_w(rtc::SSLIdentity** identity) {
397 *identity = identity_->GetReference();
// Returns the channel for `component`, or NULL if none was created.
402 FakeTransportChannel* GetFakeChannel(int component) {
403 ChannelMap::iterator it = channels_.find(component);
404 return (it != channels_.end()) ? it->second : NULL;
// Looks up dest_'s channel for `component`, gives it dest_'s identity, and
// points `channel` at it. dest_channel stays NULL when no peer channel is
// found (the guarding conditions are not visible here), which
// FakeTransportChannel::SetDestination treats as a connectivity loss.
406 void SetChannelDestination(int component,
407 FakeTransportChannel* channel) {
408 FakeTransportChannel* dest_channel = NULL;
410 dest_channel = dest_->GetFakeChannel(component);
412 dest_channel->SetLocalIdentity(dest_->identity_);
415 channel->SetDestination(dest_channel);
418 // Note, this is distinct from the Channel map owned by Transport.
419 // This map just tracks the FakeTransportChannels created by this class.
420 ChannelMap channels_;
// Peer transport set by SetDestination().
421 FakeTransport* dest_;
423 rtc::SSLIdentity* identity_;
426 // Fake session class, which can be passed into a BaseChannel object for
427 // test purposes. Can be connected to other FakeSessions via Connect().
// Its transports are always FakeTransports (see CreateTransport()).
428 class FakeSession : public BaseSession {
// Uses the current thread as both signaling and worker thread; the remaining
// BaseSession arguments are not visible in this excerpt.
430 explicit FakeSession()
431 : BaseSession(rtc::Thread::Current(),
432 rtc::Thread::Current(),
434 fail_create_channel_(false) {
// Same threads as above, with empty strings for two BaseSession arguments,
// NULL for another, and the given initiator flag.
436 explicit FakeSession(bool initiator)
437 : BaseSession(rtc::Thread::Current(),
438 rtc::Thread::Current(),
439 NULL, "", "", initiator),
440 fail_create_channel_(false) {
// Signaling on the current thread; the worker-thread argument line is not
// visible (presumably `worker_thread`).
442 FakeSession(rtc::Thread* worker_thread, bool initiator)
443 : BaseSession(rtc::Thread::Current(),
445 NULL, "", "", initiator),
446 fail_create_channel_(false) {
// Returns the transport for `content_name` downcast to FakeTransport.
// The cast relies on CreateTransport() always creating FakeTransports.
449 FakeTransport* GetTransport(const std::string& content_name) {
450 return static_cast<FakeTransport*>(
451 BaseSession::GetTransport(content_name));
// Simulates signaling between this session and `dest`: completes negotiation
// on both sessions, then points each of this session's FakeTransports at the
// transport `dest` has for the same content name.
// NOTE(review): if `dest` has no transport for a content name, GetTransport()
// presumably returns NULL, which is passed straight to SetDestination().
454 void Connect(FakeSession* dest) {
455 // Simulate the exchange of candidates.
456 CompleteNegotiation();
457 dest->CompleteNegotiation();
458 for (TransportMap::const_iterator it = transport_proxies().begin();
459 it != transport_proxies().end(); ++it) {
460 static_cast<FakeTransport*>(it->second->impl())->SetDestination(
461 dest->GetTransport(it->first));
// Delegates to BaseSession::CreateChannel unless failure has been requested
// via set_fail_channel_creation(); the failure return is not visible here
// (presumably NULL).
465 virtual TransportChannel* CreateChannel(
466 const std::string& content_name,
467 const std::string& channel_name,
469 if (fail_create_channel_) {
472 return BaseSession::CreateChannel(content_name, channel_name, component);
// Test hook: makes subsequent CreateChannel() calls fail.
475 void set_fail_channel_creation(bool fail_channel_creation) {
476 fail_create_channel_ = fail_channel_creation;
479 // TODO: Hoist this into Session when we re-work the Session code.
// Applies the identity to every transport via FakeTransport::set_identity
// (the argument line is not visible; presumably `(identity)`).
480 void set_ssl_identity(rtc::SSLIdentity* identity) {
481 for (TransportMap::const_iterator it = transport_proxies().begin();
482 it != transport_proxies().end(); ++it) {
483 // We know that we have a FakeTransport*
485 static_cast<FakeTransport*>(it->second->impl())->set_identity
// BaseSession factory hook: creates a FakeTransport on the session's signaling
// and worker threads. No allocator is passed.
491 virtual Transport* CreateTransport(const std::string& content_name) {
492 return new FakeTransport(signaling_thread(), worker_thread(), content_name);
// For every transport proxy, completes negotiation and connects its channels.
495 void CompleteNegotiation() {
496 for (TransportMap::const_iterator it = transport_proxies().begin();
497 it != transport_proxies().end(); ++it) {
498 it->second->CompleteNegotiation();
499 it->second->ConnectChannels();
// When true, CreateChannel() takes its failure path.
504 bool fail_create_channel_;
507 } // namespace cricket
509 #endif // WEBRTC_P2P_BASE_FAKESESSION_H_