blob: b9cf896f85698af730ea00f6fe00da206ec30446 [file] [log] [blame]
// Copyright 2016 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/filter/filter_source_stream.h"
#include <utility>
#include "base/check_op.h"
#include "base/containers/fixed_flat_map.h"
#include "base/functional/bind.h"
#include "base/metrics/histogram_macros.h"
#include "base/notreached.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/string_util.h"
#include "components/miracle_parameter/common/public/miracle_parameter.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
namespace net {
namespace {
// Encoding names recognized by FilterSourceStream::ParseEncodingType(). That
// function lower-cases its input before comparing against these strings.
// NOTE(review): presumably these are Content-Encoding header tokens — confirm
// against callers.
constexpr char kDeflate[] = "deflate";
constexpr char kGZip[] = "gzip";
// Alias: ParseEncodingType() maps this to the same TYPE_GZIP as kGZip.
constexpr char kXGZip[] = "x-gzip";
constexpr char kBrotli[] = "br";
constexpr char kZstd[] = "zstd";
// Feature (enabled by default) that hosts the buffer-size parameter below.
BASE_FEATURE(kBufferSizeForFilterSourceStreamFeature,
"BufferSizeForFilterSourceStreamFeature",
base::FEATURE_ENABLED_BY_DEFAULT);
// Defines GetBufferSizeForFilterSourceStream(). It returns the size, in bytes,
// of the input buffer that FilterSourceStream::Read() allocates on its first
// call and that DoReadData() requests from upstream. Per the macro arguments,
// the value comes from the "BufferSizeForFilterSourceStream" parameter of the
// feature above, with a fallback of 32 * 1024 (32 KiB).
MIRACLE_PARAMETER_FOR_INT(GetBufferSizeForFilterSourceStream,
kBufferSizeForFilterSourceStreamFeature,
"BufferSizeForFilterSourceStream",
32 * 1024)
} // namespace
// Wraps |upstream|, which supplies the (still encoded) input this filter
// transforms. |type| identifies this filter's own source type.
FilterSourceStream::FilterSourceStream(SourceType type,
                                       std::unique_ptr<SourceStream> upstream)
    : SourceStream(type), upstream_(std::move(upstream)) {
  // A filter cannot operate without a stream to pull input from.
  DCHECK(upstream_ != nullptr);
}

FilterSourceStream::~FilterSourceStream() = default;
// Reads up to |read_buffer_size| filtered bytes into |read_buffer|.
// Returns the number of bytes written, 0 at end of stream, a net error, or
// ERR_IO_PENDING. In the ERR_IO_PENDING case, |callback| is kept and run with
// the final result once the asynchronous work finishes.
int FilterSourceStream::Read(IOBuffer* read_buffer,
                             int read_buffer_size,
                             CompletionOnceCallback callback) {
  DCHECK_EQ(STATE_NONE, next_state_);
  DCHECK(read_buffer);
  DCHECK_LT(0, read_buffer_size);

  // The input buffer is created lazily. On the very first Read() nothing is
  // buffered yet, so start by pulling bytes from |upstream_|. On later calls,
  // start by filtering: leftover input may still be buffered, and the filter
  // decides whether more input is needed.
  const bool is_first_read = !input_buffer_;
  if (is_first_read) {
    input_buffer_ = base::MakeRefCounted<IOBufferWithSize>(
        GetBufferSizeForFilterSourceStream());
  }
  next_state_ = is_first_read ? STATE_READ_DATA : STATE_FILTER_DATA;

  output_buffer_ = read_buffer;
  output_buffer_size_ = base::checked_cast<size_t>(read_buffer_size);

  const int rv = DoLoop(OK);
  // The callback is only needed when completion is reported later.
  if (rv == ERR_IO_PENDING)
    callback_ = std::move(callback);
  return rv;
}
// Returns a comma-separated chain of source types, upstream first, e.g.
// "<upstream description>,<this type>". If upstream has an empty description,
// returns only this stream's type.
std::string FilterSourceStream::Description() const {
  std::string upstream_description = upstream_->Description();
  return upstream_description.empty()
             ? GetTypeAsString()
             : upstream_description + "," + GetTypeAsString();
}
// Returns false once upstream has reported a result <= OK (end of data or an
// error). After that, no further input can arrive.
bool FilterSourceStream::MayHaveMoreBytes() const {
  const bool upstream_exhausted = upstream_end_reached_;
  return !upstream_exhausted;
}
// Maps an encoding name to a SourceType, ignoring ASCII case.
// An empty string yields TYPE_NONE; "gzip" and "x-gzip" both yield TYPE_GZIP.
// Any name not in the table yields TYPE_UNKNOWN.
FilterSourceStream::SourceType FilterSourceStream::ParseEncodingType(
    const std::string& encoding) {
  const std::string lower_encoding = base::ToLowerASCII(encoding);
  // MakeFixedFlatMapSorted() requires the entries to already be sorted by key.
  static constexpr auto kEncodingMap =
      base::MakeFixedFlatMapSorted<base::StringPiece, SourceType>({
          {"", TYPE_NONE},
          {kBrotli, TYPE_BROTLI},
          {kDeflate, TYPE_DEFLATE},
          {kGZip, TYPE_GZIP},
          {kXGZip, TYPE_GZIP},
          {kZstd, TYPE_ZSTD},
      });
  // Deduce the iterator with plain `auto`. Writing `auto*` only compiles while
  // the container's iterator happens to be a raw pointer, which is an
  // implementation detail that checked/hardened iterators break.
  const auto it = kEncodingMap.find(lower_encoding);
  if (it == kEncodingMap.end()) {
    return TYPE_UNKNOWN;
  }
  return it->second;
}
int FilterSourceStream::DoLoop(int result) {
DCHECK_NE(STATE_NONE, next_state_);
int rv = result;
do {
State state = next_state_;
next_state_ = STATE_NONE;
switch (state) {
case STATE_READ_DATA:
rv = DoReadData();
break;
case STATE_READ_DATA_COMPLETE:
rv = DoReadDataComplete(rv);
break;
case STATE_FILTER_DATA:
DCHECK_LE(0, rv);
rv = DoFilterData();
break;
default:
NOTREACHED() << "bad state: " << state;
rv = ERR_UNEXPECTED;
break;
}
} while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
return rv;
}
// Issues a read from |upstream_| into |input_buffer_| and schedules
// STATE_READ_DATA_COMPLETE. Returns upstream's result, which may be
// ERR_IO_PENDING. In that case OnIOComplete() resumes the loop.
int FilterSourceStream::DoReadData() {
  // More input is only requested after the filter has drained all buffered
  // input, or on the first read, before |drainable_input_buffer_| exists.
  DCHECK(!drainable_input_buffer_ ||
         drainable_input_buffer_->BytesRemaining() == 0);

  next_state_ = STATE_READ_DATA_COMPLETE;

  // base::Unretained is safe here because |this| owns |upstream_|.
  auto on_read_complete = base::BindOnce(&FilterSourceStream::OnIOComplete,
                                         base::Unretained(this));
  // NOTE(review): the read length must not exceed the size Read() allocated
  // |input_buffer_| with. This relies on GetBufferSizeForFilterSourceStream()
  // returning the same value on every call.
  return upstream_->Read(input_buffer_.get(),
                         GetBufferSizeForFilterSourceStream(),
                         std::move(on_read_complete));
}
// Handles the result of an upstream read. Returns |result| unchanged.
int FilterSourceStream::DoReadDataComplete(int result) {
  DCHECK_NE(ERR_IO_PENDING, result);

  // A result of 0 (no more data) or a net error means upstream is finished.
  if (result <= OK)
    upstream_end_reached_ = true;

  // Any non-negative result, including 0, exposes the bytes just read
  // (possibly none) to the filter.
  if (result >= OK) {
    drainable_input_buffer_ =
        base::MakeRefCounted<DrainableIOBuffer>(input_buffer_, result);
    next_state_ = STATE_FILTER_DATA;
  }
  return result;
}
int FilterSourceStream::DoFilterData() {
DCHECK(output_buffer_);
DCHECK(drainable_input_buffer_);
size_t consumed_bytes = 0;
base::expected<size_t, Error> bytes_output = FilterData(
output_buffer_.get(), output_buffer_size_, drainable_input_buffer_.get(),
drainable_input_buffer_->BytesRemaining(), &consumed_bytes,
upstream_end_reached_);
const auto bytes_remaining =
base::checked_cast<size_t>(drainable_input_buffer_->BytesRemaining());
if (bytes_output.has_value() && bytes_output.value() == 0) {
DCHECK_EQ(consumed_bytes, bytes_remaining);
} else {
DCHECK_LE(consumed_bytes, bytes_remaining);
}
// FilterData() is not allowed to return ERR_IO_PENDING.
if (!bytes_output.has_value())
DCHECK_NE(ERR_IO_PENDING, bytes_output.error());
if (consumed_bytes > 0)
drainable_input_buffer_->DidConsume(consumed_bytes);
// Received data or encountered an error.
if (!bytes_output.has_value()) {
CHECK_LT(bytes_output.error(), 0);
return bytes_output.error();
}
if (bytes_output.value() != 0)
return base::checked_cast<int>(bytes_output.value());
// If no data is returned, continue reading if |this| needs more input.
if (NeedMoreData()) {
DCHECK_EQ(0, drainable_input_buffer_->BytesRemaining());
next_state_ = STATE_READ_DATA;
}
return 0;
}
// Completion callback for an asynchronous upstream read. It resumes the state
// machine. Once a final result is available, it releases the caller's output
// buffer and runs |callback_|, which Read() stored when it returned
// ERR_IO_PENDING.
void FilterSourceStream::OnIOComplete(int result) {
  DCHECK_EQ(STATE_READ_DATA_COMPLETE, next_state_);
  const int rv = DoLoop(result);
  if (rv != ERR_IO_PENDING) {
    output_buffer_ = nullptr;
    output_buffer_size_ = 0;
    // Run last: nothing below touches |this| after the callback.
    std::move(callback_).Run(rv);
  }
}
// Returns true while upstream has not yet reported end of data or an error,
// i.e. while another upstream read could still yield input.
bool FilterSourceStream::NeedMoreData() const {
  if (upstream_end_reached_)
    return false;
  return true;
}
} // namespace net