3 * Copyright 2009, Google, Inc.
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright notice,
11 * this list of conditions and the following disclaimer in the documentation
12 * and/or other materials provided with the distribution.
13 * 3. The name of the author may not be used to endorse or promote products
14 * derived from this software without specific prior written permission.
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 #ifndef TALK_P2P_BASE_FAKESESSION_H_
29 #define TALK_P2P_BASE_FAKESESSION_H_
35 #include "talk/base/buffer.h"
36 #include "talk/base/fakesslidentity.h"
37 #include "talk/base/sigslot.h"
38 #include "talk/base/sslfingerprint.h"
39 #include "talk/base/messagequeue.h"
40 #include "talk/p2p/base/session.h"
41 #include "talk/p2p/base/transport.h"
42 #include "talk/p2p/base/transportchannel.h"
43 #include "talk/p2p/base/transportchannelimpl.h"
49 struct PacketMessageData : public talk_base::MessageData {
50 PacketMessageData(const char* data, size_t len) : packet(data, len) {
52 talk_base::Buffer packet;
55 // Fake transport channel class, which can be passed to anything that needs a
56 // transport channel. Can be informed of another FakeTransportChannel via
58 class FakeTransportChannel : public TransportChannelImpl,
59 public talk_base::MessageHandler {
61 explicit FakeTransportChannel(Transport* transport,
62 const std::string& content_name,
64 : TransportChannelImpl(content_name, component),
65 transport_(transport),
71 role_(ICEROLE_UNKNOWN),
73 ice_proto_(ICEPROTO_HYBRID),
74 remote_ice_mode_(ICEMODE_FULL),
75 dtls_fingerprint_("", NULL, 0),
76 ssl_role_(talk_base::SSL_CLIENT) {
78 ~FakeTransportChannel() {
82 uint64 IceTiebreaker() const { return tiebreaker_; }
83 TransportProtocol protocol() const { return ice_proto_; }
84 IceMode remote_ice_mode() const { return remote_ice_mode_; }
85 const std::string& ice_ufrag() const { return ice_ufrag_; }
86 const std::string& ice_pwd() const { return ice_pwd_; }
87 const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
88 const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
89 const talk_base::SSLFingerprint& dtls_fingerprint() const {
90 return dtls_fingerprint_;
93 void SetAsync(bool async) {
97 virtual Transport* GetTransport() {
101 virtual void SetIceRole(IceRole role) { role_ = role; }
102 virtual IceRole GetIceRole() const { return role_; }
103 virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; }
104 virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; }
105 virtual void SetIceCredentials(const std::string& ice_ufrag,
106 const std::string& ice_pwd) {
107 ice_ufrag_ = ice_ufrag;
110 virtual void SetRemoteIceCredentials(const std::string& ice_ufrag,
111 const std::string& ice_pwd) {
112 remote_ice_ufrag_ = ice_ufrag;
113 remote_ice_pwd_ = ice_pwd;
116 virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; }
117 virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest,
119 dtls_fingerprint_ = talk_base::SSLFingerprint(alg, digest, digest_len);
122 virtual bool SetSslRole(talk_base::SSLRole role) {
126 virtual bool GetSslRole(talk_base::SSLRole* role) const {
131 virtual void Connect() {
132 if (state_ == STATE_INIT) {
133 state_ = STATE_CONNECTING;
136 virtual void Reset() {
137 if (state_ != STATE_INIT) {
140 dest_->state_ = STATE_INIT;
147 void SetWritable(bool writable) {
148 set_writable(writable);
151 void SetDestination(FakeTransportChannel* dest) {
152 if (state_ == STATE_CONNECTING && dest) {
153 // This simulates the delivery of candidates.
156 if (identity_ && dest_->identity_) {
158 dest_->do_dtls_ = true;
159 NegotiateSrtpCiphers();
161 state_ = STATE_CONNECTED;
162 dest_->state_ = STATE_CONNECTED;
164 dest_->set_writable(true);
165 } else if (state_ == STATE_CONNECTED && !dest) {
166 // Simulates loss of connectivity, by asymmetrically forgetting dest_.
168 state_ = STATE_CONNECTING;
173 virtual int SendPacket(const char* data, size_t len,
174 talk_base::DiffServCodePoint dscp, int flags) {
175 if (state_ != STATE_CONNECTED) {
179 if (flags != PF_SRTP_BYPASS && flags != 0) {
183 PacketMessageData* packet = new PacketMessageData(data, len);
185 talk_base::Thread::Current()->Post(this, 0, packet);
187 talk_base::Thread::Current()->Send(this, 0, packet);
189 return static_cast<int>(len);
191 virtual int SetOption(talk_base::Socket::Option opt, int value) {
194 virtual int GetError() {
198 virtual void OnSignalingReady() {
200 virtual void OnCandidate(const Candidate& candidate) {
203 virtual void OnMessage(talk_base::Message* msg) {
204 PacketMessageData* data = static_cast<PacketMessageData*>(
206 dest_->SignalReadPacket(dest_, data->packet.data(),
207 data->packet.length(), 0);
211 bool SetLocalIdentity(talk_base::SSLIdentity* identity) {
212 identity_ = identity;
217 void SetRemoteCertificate(talk_base::FakeSSLCertificate* cert) {
221 virtual bool IsDtlsActive() const {
225 virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) {
226 srtp_ciphers_ = ciphers;
230 virtual bool GetSrtpCipher(std::string* cipher) {
231 if (!chosen_srtp_cipher_.empty()) {
232 *cipher = chosen_srtp_cipher_;
238 virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const {
242 *identity = identity_->GetReference();
246 virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const {
250 *cert = remote_cert_->GetReference();
254 virtual bool ExportKeyingMaterial(const std::string& label,
255 const uint8* context,
260 if (!chosen_srtp_cipher_.empty()) {
261 memset(result, 0xff, result_len);
268 virtual void NegotiateSrtpCiphers() {
269 for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin();
270 it1 != srtp_ciphers_.end(); ++it1) {
271 for (std::vector<std::string>::const_iterator it2 =
272 dest_->srtp_ciphers_.begin();
273 it2 != dest_->srtp_ciphers_.end(); ++it2) {
275 chosen_srtp_cipher_ = *it1;
276 dest_->chosen_srtp_cipher_ = *it2;
283 virtual bool GetStats(ConnectionInfos* infos) OVERRIDE {
286 infos->push_back(info);
291 enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
292 Transport* transport_;
293 FakeTransportChannel* dest_;
296 talk_base::SSLIdentity* identity_;
297 talk_base::FakeSSLCertificate* remote_cert_;
299 std::vector<std::string> srtp_ciphers_;
300 std::string chosen_srtp_cipher_;
303 IceProtocolType ice_proto_;
304 std::string ice_ufrag_;
305 std::string ice_pwd_;
306 std::string remote_ice_ufrag_;
307 std::string remote_ice_pwd_;
308 IceMode remote_ice_mode_;
309 talk_base::SSLFingerprint dtls_fingerprint_;
310 talk_base::SSLRole ssl_role_;
313 // Fake transport class, which can be passed to anything that needs a Transport.
314 // Can be informed of another FakeTransport via SetDestination (low-tech way
315 // of doing candidates)
316 class FakeTransport : public Transport {
318 typedef std::map<int, FakeTransportChannel*> ChannelMap;
319 FakeTransport(talk_base::Thread* signaling_thread,
320 talk_base::Thread* worker_thread,
321 const std::string& content_name,
322 PortAllocator* alllocator = NULL)
323 : Transport(signaling_thread, worker_thread,
324 content_name, "test_type", NULL),
330 DestroyAllChannels();
333 const ChannelMap& channels() const { return channels_; }
335 void SetAsync(bool async) { async_ = async; }
336 void SetDestination(FakeTransport* dest) {
338 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
340 it->second->SetLocalIdentity(identity_);
341 SetChannelDestination(it->first, it->second);
345 void SetWritable(bool writable) {
346 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
348 it->second->SetWritable(writable);
352 void set_identity(talk_base::SSLIdentity* identity) {
353 identity_ = identity;
356 using Transport::local_description;
357 using Transport::remote_description;
360 virtual TransportChannelImpl* CreateTransportChannel(int component) {
361 if (channels_.find(component) != channels_.end()) {
364 FakeTransportChannel* channel =
365 new FakeTransportChannel(this, content_name(), component);
366 channel->SetAsync(async_);
367 SetChannelDestination(component, channel);
368 channels_[component] = channel;
371 virtual void DestroyTransportChannel(TransportChannelImpl* channel) {
372 channels_.erase(channel->component());
375 virtual void SetIdentity_w(talk_base::SSLIdentity* identity) {
376 identity_ = identity;
378 virtual bool GetIdentity_w(talk_base::SSLIdentity** identity) {
382 *identity = identity_->GetReference();
387 FakeTransportChannel* GetFakeChannel(int component) {
388 ChannelMap::iterator it = channels_.find(component);
389 return (it != channels_.end()) ? it->second : NULL;
391 void SetChannelDestination(int component,
392 FakeTransportChannel* channel) {
393 FakeTransportChannel* dest_channel = NULL;
395 dest_channel = dest_->GetFakeChannel(component);
397 dest_channel->SetLocalIdentity(dest_->identity_);
400 channel->SetDestination(dest_channel);
403 // Note, this is distinct from the Channel map owned by Transport.
404 // This map just tracks the FakeTransportChannels created by this class.
405 ChannelMap channels_;
406 FakeTransport* dest_;
408 talk_base::SSLIdentity* identity_;
411 // Fake session class, which can be passed into a BaseChannel object for
412 // test purposes. Can be connected to other FakeSessions via Connect().
413 class FakeSession : public BaseSession {
415 explicit FakeSession()
416 : BaseSession(talk_base::Thread::Current(),
417 talk_base::Thread::Current(),
419 fail_create_channel_(false) {
421 explicit FakeSession(bool initiator)
422 : BaseSession(talk_base::Thread::Current(),
423 talk_base::Thread::Current(),
424 NULL, "", "", initiator),
425 fail_create_channel_(false) {
427 FakeSession(talk_base::Thread* worker_thread, bool initiator)
428 : BaseSession(talk_base::Thread::Current(),
430 NULL, "", "", initiator),
431 fail_create_channel_(false) {
434 FakeTransport* GetTransport(const std::string& content_name) {
435 return static_cast<FakeTransport*>(
436 BaseSession::GetTransport(content_name));
439 void Connect(FakeSession* dest) {
440 // Simulate the exchange of candidates.
441 CompleteNegotiation();
442 dest->CompleteNegotiation();
443 for (TransportMap::const_iterator it = transport_proxies().begin();
444 it != transport_proxies().end(); ++it) {
445 static_cast<FakeTransport*>(it->second->impl())->SetDestination(
446 dest->GetTransport(it->first));
450 virtual TransportChannel* CreateChannel(
451 const std::string& content_name,
452 const std::string& channel_name,
454 if (fail_create_channel_) {
457 return BaseSession::CreateChannel(content_name, channel_name, component);
460 void set_fail_channel_creation(bool fail_channel_creation) {
461 fail_create_channel_ = fail_channel_creation;
464 // TODO: Hoist this into Session when we re-work the Session code.
465 void set_ssl_identity(talk_base::SSLIdentity* identity) {
466 for (TransportMap::const_iterator it = transport_proxies().begin();
467 it != transport_proxies().end(); ++it) {
468 // We know that we have a FakeTransport*
470 static_cast<FakeTransport*>(it->second->impl())->set_identity
476 virtual Transport* CreateTransport(const std::string& content_name) {
477 return new FakeTransport(signaling_thread(), worker_thread(), content_name);
480 void CompleteNegotiation() {
481 for (TransportMap::const_iterator it = transport_proxies().begin();
482 it != transport_proxies().end(); ++it) {
483 it->second->CompleteNegotiation();
484 it->second->ConnectChannels();
489 bool fail_create_channel_;
492 } // namespace cricket
494 #endif // TALK_P2P_BASE_FAKESESSION_H_