blob: faa0b3e5e8c9fcf35c88289a49e5e64fda2475b6 [file] [log] [blame]
// Copyright 2013 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_handshake_stream_create_helper.h"
#include <string>
#include <utility>
#include <vector>
#include "base/containers/span.h"
#include "base/functional/callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/notreached.h"
#include "base/strings/string_piece.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/default_tick_clock.h"
#include "base/time/time.h"
#include "net/base/auth.h"
#include "net/base/completion_once_callback.h"
#include "net/base/host_port_pair.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/load_flags.h"
#include "net/base/net_errors.h"
#include "net/base/network_anonymization_key.h"
#include "net/base/network_handle.h"
#include "net/base/privacy_mode.h"
#include "net/base/proxy_server.h"
#include "net/base/request_priority.h"
#include "net/base/test_completion_callback.h"
#include "net/cert/cert_verify_result.h"
#include "net/dns/public/host_resolver_results.h"
#include "net/dns/public/secure_dns_policy.h"
#include "net/http/http_request_info.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_response_info.h"
#include "net/http/transport_security_state.h"
#include "net/log/net_log.h"
#include "net/log/net_log_with_source.h"
#include "net/quic/address_utils.h"
#include "net/quic/crypto/proof_verifier_chromium.h"
#include "net/quic/mock_crypto_client_stream_factory.h"
#include "net/quic/mock_quic_data.h"
#include "net/quic/quic_chromium_alarm_factory.h"
#include "net/quic/quic_chromium_connection_helper.h"
#include "net/quic/quic_chromium_packet_reader.h"
#include "net/quic/quic_chromium_packet_writer.h"
#include "net/quic/quic_context.h"
#include "net/quic/quic_http_utils.h"
#include "net/quic/quic_server_info.h"
#include "net/quic/quic_session_key.h"
#include "net/quic/quic_test_packet_maker.h"
#include "net/quic/test_quic_crypto_client_config_handle.h"
#include "net/quic/test_task_runner.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool.h"
#include "net/socket/connect_job.h"
#include "net/socket/socket_tag.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/websocket_endpoint_lock_manager.h"
#include "net/spdy/spdy_session_key.h"
#include "net/spdy/spdy_test_util_common.h"
#include "net/ssl/ssl_config_service_defaults.h"
#include "net/ssl/ssl_info.h"
#include "net/test/cert_test_util.h"
#include "net/test/gtest_util.h"
#include "net/test/test_data_directory.h"
#include "net/test/test_with_task_environment.h"
#include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
#include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
#include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
#include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
#include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
#include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "net/websockets/websocket_basic_handshake_stream.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_stream.h"
#include "net/websockets/websocket_test_util.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"
#include "url/origin.h"
#include "url/scheme_host_port.h"
#include "url/url_constants.h"
// Forward declarations for types this test refers to by name. Where a
// complete type is needed, its definition is expected to come from the
// headers included above.
namespace net {

struct WebSocketHandshakeRequestInfo;
struct WebSocketHandshakeResponseInfo;

class HttpNetworkSession;
class URLRequest;
class WebSocketHttp2HandshakeStream;
class WebSocketHttp3HandshakeStream;
class X509Certificate;

}  // namespace net
using ::net::test::IsError;
using ::net::test::IsOk;
using ::testing::StrictMock;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::_;
namespace net {
namespace {
// Selects which handshake stream implementation a parameterized test instance
// exercises. The values are the implicit 0, 1, 2, spelled out.
enum HandshakeStreamType {
  BASIC_HANDSHAKE_STREAM = 0,  // Built on a mock ClientSocketHandle.
  HTTP2_HANDSHAKE_STREAM = 1,  // Built on a SpdySession.
  HTTP3_HANDSHAKE_STREAM = 2,  // Built on a QuicChromiumClientSession.
};
// This class encapsulates the details of creating a mock ClientSocketHandle.
class MockClientSocketHandleFactory {
public:
MockClientSocketHandleFactory()
: common_connect_job_params_(
socket_factory_maker_.factory(),
/*host_resolver=*/nullptr,
/*http_auth_cache=*/nullptr,
/*http_auth_handler_factory=*/nullptr,
/*spdy_session_pool=*/nullptr,
/*quic_supported_versions=*/nullptr,
/*quic_stream_factory=*/nullptr,
/*proxy_delegate=*/nullptr,
/*http_user_agent_settings=*/nullptr,
/*ssl_client_context=*/nullptr,
/*socket_performance_watcher_factory=*/nullptr,
/*network_quality_estimator=*/nullptr,
/*net_log=*/nullptr,
/*websocket_endpoint_lock_manager=*/nullptr,
/*http_server_properties=*/nullptr,
/*alpn_protos=*/nullptr,
/*application_settings=*/nullptr,
/*ignore_certificate_errors=*/nullptr),
pool_(1, 1, &common_connect_job_params_) {}
MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
MockClientSocketHandleFactory& operator=(
const MockClientSocketHandleFactory&) = delete;
// The created socket expects |expect_written| to be written to the socket,
// and will respond with |return_to_read|. The test will fail if the expected
// text is not written, or if all the bytes are not read.
std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
const std::string& expect_written,
const std::string& return_to_read) {
socket_factory_maker_.SetExpectations(expect_written, return_to_read);
auto socket_handle = std::make_unique<ClientSocketHandle>();
socket_handle->Init(
ClientSocketPool::GroupId(
url::SchemeHostPort(url::kHttpScheme, "a", 80),
PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
SecureDnsPolicy::kAllow),
scoped_refptr<ClientSocketPool::SocketParams>(),
absl::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
return socket_handle;
}
private:
WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
const CommonConnectJobParams common_connect_job_params_;
MockTransportClientSocketPool pool_;
};
class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
public:
~TestConnectDelegate() override = default;
void OnCreateRequest(URLRequest* request) override {}
void OnSuccess(
std::unique_ptr<WebSocketStream> stream,
std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
void OnFailure(const std::string& failure_message,
int net_error,
absl::optional<int> response_code) override {}
void OnStartOpeningHandshake(
std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
void OnSSLCertificateError(
std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
ssl_error_callbacks,
int net_error,
const SSLInfo& ssl_info,
bool fatal) override {}
int OnAuthRequired(const AuthChallengeInfo& auth_info,
scoped_refptr<HttpResponseHeaders> response_headers,
const IPEndPoint& host_port_pair,
base::OnceCallback<void(const AuthCredentials*)> callback,
absl::optional<AuthCredentials>* credentials) override {
*credentials = absl::nullopt;
return OK;
}
};
class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
public:
~MockWebSocketStreamRequestAPI() override = default;
MOCK_METHOD1(OnBasicHandshakeStreamCreated,
void(WebSocketBasicHandshakeStream* handshake_stream));
MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
void(WebSocketHttp2HandshakeStream* handshake_stream));
MOCK_METHOD1(OnHttp3HandshakeStreamCreated,
void(WebSocketHttp3HandshakeStream* handshake_stream));
MOCK_METHOD3(OnFailure,
void(const std::string& message,
int net_error,
absl::optional<int> response_code));
};
// Parameterized fixture: GetParam() selects which HandshakeStreamType
// CreateAndInitializeStream() drives through a complete opening handshake.
class WebSocketHandshakeStreamCreateHelperTest
: public TestWithParam<HandshakeStreamType>,
public WithTaskEnvironment {
protected:
// Every instance builds |quic_version_| and |mock_quic_data_|, although only
// the HTTP3_HANDSHAKE_STREAM case uses them.
WebSocketHandshakeStreamCreateHelperTest()
: quic_version_(quic::HandshakeProtocol::PROTOCOL_TLS1_3,
quic::QuicTransportVersion::QUIC_VERSION_IETF_RFC_V1),
mock_quic_data_(quic_version_) {}
// Runs a full opening handshake of the type selected by GetParam() against
// mocked transport data, then returns the result of Upgrade().
// |sub_protocols| is passed to the create helper. |extra_request_headers|
// must appear in the request the client writes. |extra_response_headers| are
// added to the mocked server response.
// Failures are reported with EXPECT_*, because ASSERT_* cannot be used in a
// function that returns a value.
std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
const std::vector<std::string>& sub_protocols,
const WebSocketExtraHeaders& extra_request_headers,
const WebSocketExtraHeaders& extra_response_headers) {
const char kPath[] = "/";
const char kOrigin[] = "http://origin.example.org";
const GURL url("wss://www.example.org/");
NetLogWithSource net_log;
WebSocketHandshakeStreamCreateHelper create_helper(
&connect_delegate_, sub_protocols, &stream_request_);
// Exactly one "...HandshakeStreamCreated" notification is expected, matching
// the parameter. |stream_request_| is a StrictMock, so any other call fails.
switch (GetParam()) {
case BASIC_HANDSHAKE_STREAM:
EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
break;
case HTTP2_HANDSHAKE_STREAM:
EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
break;
case HTTP3_HANDSHAKE_STREAM:
EXPECT_CALL(stream_request_, OnHttp3HandshakeStreamCreated(_)).Times(1);
break;
default:
NOTREACHED();
}
// The handshake must never be reported as failed.
EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);
HttpRequestInfo request_info;
request_info.url = url;
request_info.method = "GET";
request_info.load_flags = LOAD_DISABLE_CACHE;
request_info.traffic_annotation =
MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
auto headers = WebSocketCommonTestHeaders();
switch (GetParam()) {
case BASIC_HANDSHAKE_STREAM: {
// The mock socket requires the standard WebSocket request (with
// |extra_request_headers|) to be written and replies with the standard
// response (with |extra_response_headers|).
std::unique_ptr<ClientSocketHandle> socket_handle =
socket_handle_factory_.CreateClientSocketHandle(
WebSocketStandardRequest(kPath, "www.example.org",
url::Origin::Create(GURL(kOrigin)),
/*send_additional_request_headers=*/{},
extra_request_headers),
WebSocketStandardResponse(
WebSocketExtraHeadersToString(extra_response_headers)));
std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
create_helper.CreateBasicStream(std::move(socket_handle), false,
&websocket_endpoint_lock_manager_);
// If in future the implementation type returned by CreateBasicStream()
// changes, this static_cast will be wrong. However, in that case the
// test will fail and AddressSanitizer should identify the issue.
// A fixed key presumably keeps the written request byte-identical to the
// mock's expectation. TODO confirm against WebSocketStandardRequest().
static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
handshake->RegisterRequest(&request_info);
int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
CompletionOnceCallback());
EXPECT_THAT(rv, IsOk());
HttpResponseInfo response;
TestCompletionCallback request_callback;
rv = handshake->SendRequest(headers, &response,
request_callback.callback());
EXPECT_THAT(rv, IsOk());
TestCompletionCallback response_callback;
rv = handshake->ReadResponseHeaders(response_callback.callback());
EXPECT_THAT(rv, IsOk());
// The HTTP/1.1 handshake succeeds with 101 and the Upgrade headers.
EXPECT_EQ(101, response.headers->response_code());
EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
return handshake->Upgrade();
}
case HTTP2_HANDSHAKE_STREAM: {
// Scripted SPDY exchange on stream 1:
//   seq 0: the client writes request HEADERS (not FIN);
//   seq 1: the server answers with response HEADERS (not FIN);
//   seq 2: a read that returns 0.
SpdyTestUtil spdy_util;
spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
kPath, "www.example.org", kOrigin, extra_request_headers);
spdy::SpdySerializedFrame request_headers(
spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
DEFAULT_PRIORITY, false));
MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
spdy::Http2HeaderBlock response_header_block =
WebSocketHttp2Response(extra_response_headers);
spdy::SpdySerializedFrame response_headers(
spdy_util.ConstructSpdyResponseHeaders(
1, std::move(response_header_block), false));
MockRead reads[] = {CreateMockRead(response_headers, 1),
MockRead(ASYNC, 0, 2)};
SequencedSocketData data(reads, writes);
SSLSocketDataProvider ssl(ASYNC, OK);
ssl.ssl_info.cert =
ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
SpdySessionDependencies session_deps;
session_deps.socket_factory->AddSocketDataProvider(&data);
session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
std::unique_ptr<HttpNetworkSession> http_network_session =
SpdySessionDependencies::SpdyCreateSession(&session_deps);
const SpdySessionKey key(
HostPortPair::FromURL(url), ProxyChain::Direct(),
PRIVACY_MODE_DISABLED, SpdySessionKey::IsProxySession::kFalse,
SocketTag(), NetworkAnonymizationKey(), SecureDnsPolicy::kAllow);
base::WeakPtr<SpdySession> spdy_session =
CreateSpdySession(http_network_session.get(), key, net_log);
std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);
handshake->RegisterRequest(&request_info);
int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
NetLogWithSource(),
CompletionOnceCallback());
EXPECT_THAT(rv, IsOk());
HttpResponseInfo response;
TestCompletionCallback request_callback;
rv = handshake->SendRequest(headers, &response,
request_callback.callback());
// Unlike the basic case, sending and reading both complete
// asynchronously here.
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = request_callback.WaitForResult();
EXPECT_THAT(rv, IsOk());
TestCompletionCallback response_callback;
rv = handshake->ReadResponseHeaders(response_callback.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = response_callback.WaitForResult();
EXPECT_THAT(rv, IsOk());
// Over HTTP/2 the handshake succeeds with 200 rather than 101.
EXPECT_EQ(200, response.headers->response_code());
return handshake->Upgrade();
}
case HTTP3_HANDSHAKE_STREAM: {
const quic::QuicStreamId client_data_stream_id(
quic::QuicUtils::GetFirstBidirectionalStreamId(
quic_version_.transport_version, quic::Perspective::IS_CLIENT));
quic::QuicCryptoClientConfig crypto_config(
quic::test::crypto_test_utils::ProofVerifierForTesting());
const quic::QuicConnectionId connection_id(
quic::test::TestConnectionId(2));
test::QuicTestPacketMaker client_maker(
quic_version_, connection_id, &clock_, "mail.example.org",
quic::Perspective::IS_CLIENT,
/*client_headers_include_h2_stream_dependency_=*/false);
test::QuicTestPacketMaker server_maker(
quic_version_, connection_id, &clock_, "mail.example.org",
quic::Perspective::IS_SERVER,
/*client_headers_include_h2_stream_dependency_=*/false);
IPEndPoint peer_addr(IPAddress(192, 0, 2, 23), 443);
quic::test::MockConnectionIdGenerator connection_id_generator;
testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor;
ProofVerifyDetailsChromium verify_details;
MockCryptoClientStreamFactory crypto_client_stream_factory;
TransportSecurityState transport_security_state;
SSLConfigServiceDefaults ssl_config_service;
// NOTE(review): this flag assignment and QuicEnableVersion() below change
// process-wide state and are never undone here. Confirm that no other test
// in the same binary depends on the previous values.
FLAGS_quic_enable_http3_grease_randomness = false;
clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
quic::QuicEnableVersion(quic_version_);
quic::test::MockRandom random_generator{0};
spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
kPath, "www.example.org", kOrigin, extra_request_headers);
// Scripted QUIC traffic, in order:
//   write: client SETTINGS (synchronous);
//   write: client request HEADERS on |client_data_stream_id| (async);
//   read:  server response HEADERS (async);
//   read:  pending (ERR_IO_PENDING);
//   write: client ACK + RST with QUIC_STREAM_CANCELLED.
// |packet_number| numbers the client packets consecutively.
int packet_number = 1;
mock_quic_data_.AddWrite(
SYNCHRONOUS,
client_maker.MakeInitialSettingsPacket(packet_number++));
mock_quic_data_.AddWrite(
ASYNC,
client_maker.MakeRequestHeadersPacket(
packet_number++, client_data_stream_id,
/*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
std::move(request_header_block), nullptr));
spdy::Http2HeaderBlock response_header_block =
WebSocketHttp2Response(extra_response_headers);
mock_quic_data_.AddRead(
ASYNC, server_maker.MakeResponseHeadersPacket(
/*packet_number=*/1, client_data_stream_id,
/*fin=*/false, std::move(response_header_block),
/*spdy_headers_frame_length=*/nullptr));
mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
mock_quic_data_.AddWrite(SYNCHRONOUS,
client_maker.MakeAckAndRstPacket(
packet_number++, client_data_stream_id,
quic::QUIC_STREAM_CANCELLED, 1, 0,
/*include_stop_sending_if_v99=*/true));
auto socket = std::make_unique<MockUDPClientSocket>(
mock_quic_data_.InitializeAndGetSequencedSocketData(),
NetLog::Get());
socket->Connect(peer_addr);
scoped_refptr<test::TestTaskRunner> runner =
base::MakeRefCounted<test::TestTaskRunner>(&clock_);
auto helper = std::make_unique<QuicChromiumConnectionHelper>(
&clock_, &random_generator);
auto alarm_factory =
std::make_unique<QuicChromiumAlarmFactory>(runner.get(), &clock_);
// Ownership of 'writer' is passed to 'QuicConnection'.
QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
socket.get(),
base::SingleThreadTaskRunner::GetCurrentDefault().get());
// Ownership of |connection| is handed to |session_| below.
quic::QuicConnection* connection = new quic::QuicConnection(
connection_id, quic::QuicSocketAddress(),
net::ToQuicSocketAddress(peer_addr), helper.get(),
alarm_factory.get(), writer, true /* owns_writer */,
quic::Perspective::IS_CLIENT,
quic::test::SupportedVersions(quic_version_),
connection_id_generator);
connection->set_visitor(&visitor);
// Load a certificate that is valid for *.example.org
scoped_refptr<X509Certificate> test_cert(
ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
EXPECT_TRUE(test_cert.get());
verify_details.cert_verify_result.verified_cert = test_cert;
verify_details.cert_verify_result.is_issued_by_known_root = true;
crypto_client_stream_factory.AddProofVerifyDetails(&verify_details);
base::TimeTicks dns_end = base::TimeTicks::Now();
base::TimeTicks dns_start = dns_end - base::Milliseconds(1);
// NOTE(review): |session_| is a fixture member and outlives this function.
// It and |connection| receive pointers to objects local to this case
// (crypto_client_stream_factory, transport_security_state,
// ssl_config_service, crypto_config, helper, alarm_factory, runner,
// connection_id_generator, random_generator). Confirm that none of these is
// dereferenced after return, e.g. while the returned stream or |session_|
// is torn down.
session_ = std::make_unique<QuicChromiumClientSession>(
connection, std::move(socket),
/*stream_factory=*/nullptr, &crypto_client_stream_factory, &clock_,
&transport_security_state, &ssl_config_service,
/*server_info=*/nullptr,
QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
SocketTag(), NetworkAnonymizationKey(),
SecureDnsPolicy::kAllow,
/*require_dns_https_alpn=*/false),
/*require_confirmation=*/false,
/*migrate_session_early_v2=*/false,
/*migrate_session_on_network_change_v2=*/false,
/*default_network=*/handles::kInvalidNetworkHandle,
quic::QuicTime::Delta::FromMilliseconds(
kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
/*migrate_idle_session=*/true, /*allow_port_migration=*/false,
kDefaultIdleSessionMigrationPeriod,
/*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork,
kMaxMigrationsToNonDefaultNetworkOnWriteError,
kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
kQuicYieldAfterPacketsRead,
quic::QuicTime::Delta::FromMilliseconds(
kQuicYieldAfterDurationMilliseconds),
/*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config),
dns_start, dns_end, base::DefaultTickClock::GetInstance(),
base::SingleThreadTaskRunner::GetCurrentDefault().get(),
/*socket_performance_watcher=*/nullptr,
HostResolverEndpointResult(), NetLog::Get());
session_->Initialize();
// Blackhole QPACK decoder stream instead of constructing mock writes.
session_->qpack_decoder()->set_qpack_stream_sender_delegate(
&noop_qpack_stream_sender_delegate_);
// The mock crypto stream is expected to finish the handshake synchronously.
TestCompletionCallback callback;
EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
EXPECT_TRUE(session_->OneRttKeysAvailable());
std::unique_ptr<QuicChromiumClientSession::Handle> session_handle =
session_->CreateHandle(
url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));
std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
create_helper.CreateHttp3Stream(std::move(session_handle),
{} /* dns_aliases */);
handshake->RegisterRequest(&request_info);
int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
CompletionOnceCallback());
EXPECT_THAT(rv, IsOk());
HttpResponseInfo response;
TestCompletionCallback request_callback;
rv = handshake->SendRequest(headers, &response,
request_callback.callback());
EXPECT_THAT(rv, IsOk());
// Reading from the mock socket starts only after the request has been
// sent. The response headers then arrive asynchronously, so
// ReadResponseHeaders() is expected to return ERR_IO_PENDING.
session_->StartReading();
TestCompletionCallback response_callback;
rv = handshake->ReadResponseHeaders(response_callback.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = response_callback.WaitForResult();
EXPECT_THAT(rv, IsOk());
// Over HTTP/3 the handshake succeeds with 200 rather than 101.
EXPECT_EQ(200, response.headers->response_code());
return handshake->Upgrade();
}
default:
NOTREACHED();
return nullptr;
}
}
private:
MockClientSocketHandleFactory socket_handle_factory_;
TestConnectDelegate connect_delegate_;
StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
// For HTTP3_HANDSHAKE_STREAM. |quic_version_| must be declared before
// |mock_quic_data_|, whose constructor reads it.
quic::ParsedQuicVersion quic_version_;
quic::MockClock clock_;
std::unique_ptr<QuicChromiumClientSession> session_;
test::MockQuicData mock_quic_data_;
quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
};
// Run every TEST_P below once per handshake stream type.
INSTANTIATE_TEST_SUITE_P(All,
                         WebSocketHandshakeStreamCreateHelperTest,
                         ::testing::Values(BASIC_HANDSHAKE_STREAM,
                                           HTTP2_HANDSHAKE_STREAM,
                                           HTTP3_HANDSHAKE_STREAM));
// Confirm that the basic case works as expected: with no sub-protocols and no
// extra headers, the resulting stream reports no extensions and no
// sub-protocol.
TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
  std::unique_ptr<WebSocketStream> stream =
      CreateAndInitializeStream({}, {}, {});
  // If no stream came back, fail here rather than crash on a null dereference.
  ASSERT_TRUE(stream);
  EXPECT_EQ("", stream->GetExtensions());
  EXPECT_EQ("", stream->GetSubProtocol());
}
// Verify that the sub-protocols are passed through: the client offers "chat"
// and "superchat", the server selects "superchat", and the stream reports the
// server's choice.
TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
  const std::vector<std::string> sub_protocols = {"chat", "superchat"};
  std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
      sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
      {{"Sec-WebSocket-Protocol", "superchat"}});
  // If no stream came back, fail here rather than crash on a null dereference.
  ASSERT_TRUE(stream);
  EXPECT_EQ("superchat", stream->GetSubProtocol());
}
// Verify that the extension name is available. Bad extension names are tested
// in websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
  std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
      {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
  // If no stream came back, fail here rather than crash on a null dereference.
  ASSERT_TRUE(stream);
  EXPECT_EQ("permessage-deflate", stream->GetExtensions());
}
// Verify that extension parameters are available. Bad parameters are tested
// in websocket_stream_test.cc. The server's Sec-WebSocket-Extensions value is
// expected back from GetExtensions() unchanged.
TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
  // Single source of truth for the header value that is sent and expected back.
  const char kExtensions[] =
      "permessage-deflate;"
      " client_max_window_bits=14; server_max_window_bits=14;"
      " server_no_context_takeover; client_no_context_takeover";
  std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
      {}, {}, {{"Sec-WebSocket-Extensions", kExtensions}});
  // If no stream came back, fail here rather than crash on a null dereference.
  ASSERT_TRUE(stream);
  EXPECT_EQ(kExtensions, stream->GetExtensions());
}
} // namespace
} // namespace net