#include "content/browser/renderer_host/websocket_dispatcher_host.h"
+#include <algorithm>
#include <vector>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/memory/ref_counted.h"
+#include "base/memory/weak_ptr.h"
#include "content/browser/renderer_host/websocket_host.h"
#include "content/common/websocket.h"
#include "content/common/websocket_messages.h"
// This number is unlikely to occur by chance.
static const int kMagicRenderProcessId = 506116062;
+class WebSocketDispatcherHostTest;
+
// A mock of WebsocketHost which records received messages.
class MockWebSocketHost : public WebSocketHost {
public:
MockWebSocketHost(int routing_id,
WebSocketDispatcherHost* dispatcher,
- net::URLRequestContext* url_request_context)
- : WebSocketHost(routing_id, dispatcher, url_request_context) {
- }
+ net::URLRequestContext* url_request_context,
+ WebSocketDispatcherHostTest* owner);
virtual ~MockWebSocketHost() {}
- virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE{
+ virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE {
received_messages_.push_back(message);
return true;
}
+ virtual void GoAway() OVERRIDE;
+
std::vector<IPC::Message> received_messages_;
+ base::WeakPtr<WebSocketDispatcherHostTest> owner_;
};
class WebSocketDispatcherHostTest : public ::testing::Test {
public:
- WebSocketDispatcherHostTest() {
+ WebSocketDispatcherHostTest()
+ : weak_ptr_factory_(this) {
dispatcher_host_ = new WebSocketDispatcherHost(
kMagicRenderProcessId,
base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext,
base::Unretained(this)));
}
- virtual ~WebSocketDispatcherHostTest() {}
+ virtual ~WebSocketDispatcherHostTest() {
+ // We need to invalidate the issued WeakPtrs at the beginning of the
+ // destructor in order not to access destructed member variables.
+ weak_ptr_factory_.InvalidateWeakPtrs();
+ }
+
+ void GoAway(int routing_id) {
+ gone_hosts_.push_back(routing_id);
+ }
+
+ base::WeakPtr<WebSocketDispatcherHostTest> GetWeakPtr() {
+ return weak_ptr_factory_.GetWeakPtr();
+ }
protected:
scoped_refptr<WebSocketDispatcherHost> dispatcher_host_;
// Stores allocated MockWebSocketHost instances. Doesn't take ownership of
// them.
std::vector<MockWebSocketHost*> mock_hosts_;
+ std::vector<int> gone_hosts_;
+
+ base::WeakPtrFactory<WebSocketDispatcherHostTest> weak_ptr_factory_;
private:
net::URLRequestContext* OnGetRequestContext() {
WebSocketHost* CreateWebSocketHost(int routing_id) {
MockWebSocketHost* host =
- new MockWebSocketHost(routing_id, dispatcher_host_.get(), NULL);
+ new MockWebSocketHost(routing_id, dispatcher_host_.get(), NULL, this);
mock_hosts_.push_back(host);
return host;
}
};
+MockWebSocketHost::MockWebSocketHost(
+ int routing_id,
+ WebSocketDispatcherHost* dispatcher,
+ net::URLRequestContext* url_request_context,
+ WebSocketDispatcherHostTest* owner)
+ : WebSocketHost(routing_id, dispatcher, url_request_context),
+ owner_(owner->GetWeakPtr()) {}
+
+void MockWebSocketHost::GoAway() {
+ if (owner_)
+ owner_->GoAway(routing_id());
+}
+
TEST_F(WebSocketDispatcherHostTest, Construct) {
// Do nothing.
}
}
}
+TEST_F(WebSocketDispatcherHostTest, Destruct) {
+ WebSocketHostMsg_AddChannelRequest message1(
+ 123, GURL("ws://example.com/test"), std::vector<std::string>(),
+ url::Origin("http://example.com"), -1);
+ WebSocketHostMsg_AddChannelRequest message2(
+ 456, GURL("ws://example.com/test2"), std::vector<std::string>(),
+ url::Origin("http://example.com"), -1);
+
+ ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message1));
+ ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message2));
+
+ ASSERT_EQ(2u, mock_hosts_.size());
+
+ mock_hosts_.clear();
+ dispatcher_host_ = NULL;
+
+ ASSERT_EQ(2u, gone_hosts_.size());
+ // The gone_hosts_ ordering is not predictable because it depends on the
+ // hash_map ordering.
+ std::sort(gone_hosts_.begin(), gone_hosts_.end());
+ EXPECT_EQ(123, gone_hosts_[0]);
+ EXPECT_EQ(456, gone_hosts_[1]);
+}
+
} // namespace
} // namespace content