blob: 551729ac542b735bb353b4643cbc1812d22a87c1 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/filter/zstd_source_stream.h"
#include <algorithm>
#include <unordered_map>
#include <utility>
#define ZSTD_STATIC_LINKING_ONLY
#include "base/bits.h"
#include "base/check_op.h"
#include "base/metrics/histogram_macros.h"
#include "base/numerics/safe_conversions.h"
#include "net/base/io_buffer.h"
#include "third_party/zstd/src/lib/zstd.h"
#include "third_party/zstd/src/lib/zstd_errors.h"
namespace net {
namespace {
// Human-readable name of this filter type, returned by GetTypeAsString().
constexpr char kZstd[] = "ZSTD";
// Deleter that lets a std::unique_ptr own a ZSTD_DCtx by releasing it with
// ZSTD_freeDCtx().
struct FreeContextDeleter {
  void operator()(ZSTD_DCtx* context) const { ZSTD_freeDCtx(context); }
};
// ZstdSourceStream applies Zstd content decoding to a data stream.
// Zstd format speciication: https://datatracker.ietf.org/doc/html/rfc8878
class ZstdSourceStream : public FilterSourceStream {
public:
explicit ZstdSourceStream(std::unique_ptr<SourceStream> upstream,
scoped_refptr<IOBuffer> dictionary = nullptr,
size_t dictionary_size = 0u)
: FilterSourceStream(SourceStream::TYPE_ZSTD, std::move(upstream)),
dictionary_(std::move(dictionary)),
dictionary_size_(dictionary_size) {
ZSTD_customMem custom_mem = {&customMalloc, &customFree, this};
dctx_.reset(ZSTD_createDCtx_advanced(custom_mem));
CHECK(dctx_);
// Following RFC 8878 recommendation (see section 3.1.1.1.2 Window
// Descriptor) of using a maximum 8MB memory buffer to decompress frames
// to '... protect decoders from unreasonable memory requirements'.
int window_log_max = 23;
if (dictionary_) {
// For shared dictionary case, allow using larger window size (Log2Ceiling
// of `dictionary_size`). It is safe because we have the size limit per
// shared dictionary and the total dictionary size limit.
window_log_max =
std::max(base::bits::Log2Ceiling(
base::checked_cast<uint32_t>(dictionary_size_)),
window_log_max);
}
ZSTD_DCtx_setParameter(dctx_.get(), ZSTD_d_windowLogMax, window_log_max);
if (dictionary_) {
size_t result = ZSTD_DCtx_loadDictionary_advanced(
dctx_.get(), reinterpret_cast<const void*>(dictionary_->data()),
dictionary_size_, ZSTD_dlm_byRef, ZSTD_dct_rawContent);
DCHECK(!ZSTD_isError(result));
}
}
ZstdSourceStream(const ZstdSourceStream&) = delete;
ZstdSourceStream& operator=(const ZstdSourceStream&) = delete;
~ZstdSourceStream() override {
if (ZSTD_isError(decoding_result_)) {
ZSTD_ErrorCode error_code = ZSTD_getErrorCode(decoding_result_);
UMA_HISTOGRAM_ENUMERATION(
"Net.ZstdFilter.ErrorCode", static_cast<int>(error_code),
static_cast<int>(ZSTD_ErrorCode::ZSTD_error_maxCode));
}
UMA_HISTOGRAM_ENUMERATION("Net.ZstdFilter.Status", decoding_status_);
if (decoding_status_ == ZstdDecodingStatus::kEndOfFrame) {
// CompressionRatio is undefined when there is no output produced.
if (produced_bytes_ != 0) {
UMA_HISTOGRAM_PERCENTAGE(
"Net.ZstdFilter.CompressionRatio",
static_cast<int>((consumed_bytes_ * 100) / produced_bytes_));
}
}
UMA_HISTOGRAM_MEMORY_KB("Net.ZstdFilter.MaxMemoryUsage",
(max_allocated_ / 1024));
}
private:
static void* customMalloc(void* opaque, size_t size) {
return reinterpret_cast<ZstdSourceStream*>(opaque)->customMalloc(size);
}
void* customMalloc(size_t size) {
void* address = malloc(size);
CHECK(address);
malloc_sizes_.insert(std::make_pair(address, size));
total_allocated_ += size;
if (total_allocated_ > max_allocated_) {
max_allocated_ = total_allocated_;
}
return address;
}
static void customFree(void* opaque, void* address) {
return reinterpret_cast<ZstdSourceStream*>(opaque)->customFree(address);
}
void customFree(void* address) {
free(address);
auto it = malloc_sizes_.find(address);
CHECK(it != malloc_sizes_.end());
const size_t size = it->second;
total_allocated_ -= size;
malloc_sizes_.erase(it);
}
// SourceStream implementation
std::string GetTypeAsString() const override { return kZstd; }
base::expected<size_t, Error> FilterData(IOBuffer* output_buffer,
size_t output_buffer_size,
IOBuffer* input_buffer,
size_t input_buffer_size,
size_t* consumed_bytes,
bool upstream_end_reached) override {
CHECK(dctx_);
ZSTD_inBuffer input = {input_buffer->data(), input_buffer_size, 0};
ZSTD_outBuffer output = {output_buffer->data(), output_buffer_size, 0};
const size_t result = ZSTD_decompressStream(dctx_.get(), &output, &input);
decoding_result_ = result;
produced_bytes_ += output.pos;
consumed_bytes_ += input.pos;
*consumed_bytes = input.pos;
if (ZSTD_isError(result)) {
decoding_status_ = ZstdDecodingStatus::kDecodingError;
return base::unexpected(ERR_CONTENT_DECODING_FAILED);
} else if (input.pos < input.size) {
// Given a valid frame, zstd won't consume the last byte of the frame
// until it has flushed all of the decompressed data of the frame.
// Therefore, instead of checking if the return code is 0, we can
// just check if input.pos < input.size.
return output.pos;
} else {
CHECK_EQ(input.pos, input.size);
if (result != 0u) {
// The return value from ZSTD_decompressStream did not end on a frame,
// but we reached the end of the file. We assume this is an error, and
// the input was truncated.
if (upstream_end_reached) {
decoding_status_ = ZstdDecodingStatus::kDecodingError;
}
} else {
CHECK_EQ(result, 0u);
CHECK_LE(output.pos, output.size);
// Finished decoding a frame.
decoding_status_ = ZstdDecodingStatus::kEndOfFrame;
}
return output.pos;
}
}
size_t total_allocated_ = 0;
size_t max_allocated_ = 0;
std::unordered_map<void*, size_t> malloc_sizes_;
const scoped_refptr<IOBuffer> dictionary_;
const size_t dictionary_size_;
std::unique_ptr<ZSTD_DCtx, FreeContextDeleter> dctx_;
ZstdDecodingStatus decoding_status_ = ZstdDecodingStatus::kDecodingInProgress;
size_t decoding_result_ = 0;
size_t consumed_bytes_ = 0;
size_t produced_bytes_ = 0;
};
} // namespace
// Wraps `previous` in a stream that Zstd-decodes its data without using a
// dictionary.
std::unique_ptr<FilterSourceStream> CreateZstdSourceStream(
    std::unique_ptr<SourceStream> previous) {
  auto zstd_stream = std::make_unique<ZstdSourceStream>(std::move(previous));
  return zstd_stream;
}
// Wraps `previous` in a stream that Zstd-decodes its data using the first
// `dictionary_size` bytes of `dictionary` as the decompression dictionary.
std::unique_ptr<FilterSourceStream> CreateZstdSourceStreamWithDictionary(
    std::unique_ptr<SourceStream> previous,
    scoped_refptr<IOBuffer> dictionary,
    size_t dictionary_size) {
  auto zstd_stream = std::make_unique<ZstdSourceStream>(
      std::move(previous), std::move(dictionary), dictionary_size);
  return zstd_stream;
}
} // namespace net